[ignore_merges] Fix offsets (#1640)

* Fix the default offset creation

* Update the tests

* Fix clippy lints
Author: Arthur
Date: 2024-10-01 09:22:20 +02:00
Committed by: GitHub
Parent: b4a38c4f63
Commit: 3fb1371c1c
2 changed files with 12 additions and 7 deletions


@@ -462,7 +462,11 @@ impl BPE {
     fn tokenize_with_cache(&self, sequence: &str) -> Result<Vec<Token>> {
         if self.ignore_merges {
             if let Some(id) = self.vocab.get(sequence) {
-                return Ok(vec![Token::new(*id, sequence.to_string().clone(), (0, 0))]);
+                return Ok(vec![Token::new(
+                    *id,
+                    sequence.to_string().clone(),
+                    (0, sequence.len()),
+                )]);
             }
         }
         if let Some(ref hit) = self.cache.as_ref().and_then(|c| c.get(sequence)) {
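
The fast path above applies when `ignore_merges` is set and the whole pre-tokenized sequence is found verbatim in the vocabulary: it is returned as a single token. Previously that token's offsets were hard-coded to (0, 0); the fix makes them span the full input, (0, sequence.len()), i.e. the byte range of the entire sequence.
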
@@ -941,10 +945,13 @@ mod tests {
             .build()
             .unwrap();
         let tokens = bpe.tokenize(".:.:").unwrap();
-        assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 0))]);
+        assert_eq!(tokens, vec![Token::new(0u32, ".:.:".into(), (0, 4))]);
         let tokens = bpe.tokenize("Ġbelirtilen").unwrap();
-        assert_eq!(tokens, vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 0))]);
+        assert_eq!(
+            tokens,
+            vec![Token::new(1u32, "Ġbelirtilen".into(), (0, 12))]
+        );
         bpe.ignore_merges = false;
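
A quick sanity check on the new expected offsets (a standalone sketch, not part of the diff): Rust's str::len returns the byte length, which is what (0, sequence.len()) records, so the multi-byte 'Ġ' makes the second span 12 bytes even though the string is 11 characters long.

    fn main() {
        // ".:.:" is 4 ASCII bytes, hence the expected offsets (0, 4).
        assert_eq!(".:.:".len(), 4);

        // 'Ġ' (U+0120) encodes to 2 bytes in UTF-8, so "Ġbelirtilen" is
        // 11 characters but 12 bytes, hence the expected offsets (0, 12).
        assert_eq!("Ġbelirtilen".chars().count(), 11);
        assert_eq!("Ġbelirtilen".len(), 12);
    }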


@@ -1181,11 +1181,10 @@ where
         };
         trainer.feed(
-            sequences.map(|s| {
+            sequences.inspect(|s| {
                 if let Some(progress) = &progress {
                     progress.inc(s.len() as u64)
                 }
-                s
             }),
             |seq| {
                 let normalized = self.do_normalize(seq.as_ref())?;
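
For context on the clippy-driven change (a minimal standalone sketch, assuming nothing beyond std): Iterator::inspect calls a closure on a shared reference to each item and then yields the item unchanged, so a map whose closure performs a side effect and returns its input untouched collapses to inspect, dropping the trailing `s`.

    fn main() {
        let total: u64 = ["ab", "cde"]
            .into_iter()
            // inspect sees &item and passes the item through unchanged,
            // replacing the map-with-a-side-effect-then-return pattern.
            .inspect(|s| println!("feeding {s}"))
            .map(|s| s.len() as u64)
            .sum();
        assert_eq!(total, 5);
    }
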
@@ -1233,11 +1232,10 @@ where
         };
         trainer.feed(
-            sequences.map(|s| {
+            sequences.inspect(|_s| {
                 if let Some(progress) = &progress {
                     progress.inc(1)
                 }
-                s
             }),
             |seq| {
                 let normalized = self.do_normalize(seq.as_ref())?;
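
Same transformation as the previous hunk; here the binding is `_s` because the item itself is unused inside the closure (progress advances by one per sequence rather than by its byte length).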