feat: Zenzaiで学習機能を有効化する (#108)

* feat: add metadata for DicdataElement

* feat: ignore learned word in zenz evaluation

* feat: improve session command to support temporal memory
This commit is contained in:
Miwa
2024-06-27 00:32:44 +09:00
committed by GitHub
parent fbf09a76eb
commit f58f1603d4
5 changed files with 133 additions and 19 deletions

View File

@ -127,13 +127,29 @@ class ZenzContext {
case wholeResult(String)
}
func evaluate_candidate(input: String, candidate: String) -> CandidateEvaluationResult {
func getNextDicdataElement(for prefix: String, of candidate: Candidate) -> DicdataElement? {
var curPrefix = ""
for datum in candidate.data {
if curPrefix == prefix {
// prefix
return datum
} else if curPrefix.hasPrefix(prefix) {
// FIXME:
return nil
}
curPrefix.append(datum.word)
}
return nil
}
func evaluate_candidate(input: String, candidate: Candidate) -> CandidateEvaluationResult {
print("Evaluate", candidate)
// For zenz-v1 model, \u{EE00} is a token used for 'start query', and \u{EE01} is a token used for 'start answer'
// We assume \u{EE01}\(candidate) is always splitted into \u{EE01}_\(candidate) by zenz-v1 tokenizer
let prompt = "\u{EE00}\(input)\u{EE01}"
// Therefore, tokens = prompt_tokens + candidate_tokens is an appropriate operation.
let prompt_tokens = self.tokenize(text: prompt, add_bos: true, add_eos: false)
let candidate_tokens = self.tokenize(text: candidate, add_bos: false, add_eos: false)
let candidate_tokens = self.tokenize(text: candidate.text, add_bos: false, add_eos: false)
let tokens = prompt_tokens + candidate_tokens
let startOffset = prompt_tokens.count - 1
let pos_max = llama_kv_cache_seq_pos_max(self.context, 0)
@ -178,12 +194,19 @@ class ZenzContext {
var cchars = tokens[..<i].reduce(into: []) {
$0.append(contentsOf: token_to_piece(token: $1))
}
let acceptedPrefix = String(cString: cchars + [0]).dropFirst(prompt.count)
// adding "\0"
cchars += token_to_piece(token: max_token) + [0]
let string = String(cString: cchars)
//
let prefixConstraint = String(string.dropFirst(prompt.count))
return .fixRequired(prefixConstraint: prefixConstraint)
if let nextDicdataElement = getNextDicdataElement(for: String(acceptedPrefix), of: candidate), nextDicdataElement.metadata.contains(.isLearned) {
//
// pass
} else {
//
let prefixConstraint = String(string.dropFirst(prompt.count))
return .fixRequired(prefixConstraint: prefixConstraint)
}
}
}
score += log(max_exp) - log(exp_sum)