Merges cannot handle tokens containing spaces. (#909)
* Merges cannot handle tokens containing spaces. This fixes that while keeping backward support. We don't want to merge that blindly.
* Update the tests.
* Fixing clippy.
* Add a test with spaces in the token/merge.
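The failure is easy to reproduce outside the library. A minimal sketch of the problem, using a hypothetical helper (not code from this PR): in the legacy format a merge is stored as the single string "left right", so two different pairs can flatten to the same string and only one of them survives the trip back.

// Hypothetical stand-in for legacy merge parsing, for illustration only:
// the pair is recovered by splitting the line on its first space.
fn parse_legacy_merge(line: &str) -> Option<(&str, &str)> {
    let mut parts = line.splitn(2, ' ');
    Some((parts.next()?, parts.next()?))
}

fn main() {
    // Fine for space-free tokens:
    assert_eq!(parse_legacy_merge("a b"), Some(("a", "b")));

    // But ("a", "b c d") and ("a b", "c d") both flatten to "a b c d",
    // and only one reading can win on the way back:
    assert_eq!(parse_legacy_merge("a b c d"), Some(("a", "b c d")));
}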
@@ -30,14 +30,14 @@ impl Serialize for BPE {
             .map(|(pair, (rank, _))| (pair, rank))
             .collect();
         merges.sort_unstable_by_key(|k| *k.1);
-        let merges_str = merges
+        let merges = merges
             .into_iter()
-            .map(|(pair, _)| format!("{} {}", self.vocab_r[&pair.0], self.vocab_r[&pair.1]))
+            .map(|(pair, _)| (self.vocab_r[&pair.0].clone(), self.vocab_r[&pair.1].clone()))
             .collect::<Vec<_>>();
         let ordered_vocab = OrderedVocabIter::new(&self.vocab_r);

         model.serialize_field("vocab", &ordered_vocab)?;
-        model.serialize_field("merges", &merges_str)?;
+        model.serialize_field("merges", &merges)?;

         model.end()
     }
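To see the effect of this hunk without the BPE model itself, here is a minimal sketch using plain serde_json (assumed as a dependency, with hypothetical values): the old code joined each pair with format!("{} {}", ...) into a Vec<String>, while the new code keeps Vec<(String, String)>, which serde serializes as nested JSON arrays.

fn main() {
    let merges = vec![("a".to_string(), "b c d".to_string())];

    // Old shape: pairs joined into single strings.
    let legacy: Vec<String> = merges
        .iter()
        .map(|(l, r)| format!("{} {}", l, r))
        .collect();
    assert_eq!(serde_json::to_string(&legacy).unwrap(), r#"["a b c d"]"#);

    // New shape: pairs kept as tuples; serde emits two-element arrays,
    // so the token boundary survives even with spaces inside tokens.
    assert_eq!(
        serde_json::to_string(&merges).unwrap(),
        r#"[["a","b c d"]]"#
    );
}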
@@ -81,7 +81,14 @@ impl<'de> Visitor<'de> for BPEVisitor {
     {
         let mut builder = BpeBuilder::new();
         let mut vocab: Option<HashMap<String, u32>> = None;
-        let mut merges: Option<Vec<String>> = None;
+
+        #[derive(Debug, Deserialize)]
+        #[serde(untagged)]
+        enum MergeType {
+            Tuple(Vec<(String, String)>),
+            Legacy(Vec<String>),
+        }
+        let mut merges: Option<MergeType> = None;
         while let Some(key) = map.next_key::<String>()? {
             match key.as_ref() {
                 "dropout" => {
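The #[serde(untagged)] attribute is what keeps old files loadable: serde tries each variant in declaration order and keeps the first one that fits, so no explicit version field is needed. A self-contained sketch of the same trick (requires serde with the derive feature and serde_json):

use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq)]
#[serde(untagged)]
enum MergeType {
    Tuple(Vec<(String, String)>),
    Legacy(Vec<String>),
}

fn main() {
    // New files: an array of two-element arrays matches the Tuple variant.
    let new: MergeType = serde_json::from_str(r#"[["a","b c d"]]"#).unwrap();
    assert_eq!(new, MergeType::Tuple(vec![("a".into(), "b c d".into())]));

    // Old files: an array of plain strings fails Tuple, then matches Legacy.
    let old: MergeType = serde_json::from_str(r#"["a b"]"#).unwrap();
    assert_eq!(old, MergeType::Legacy(vec!["a b".into()]));
}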
@@ -134,8 +141,12 @@ impl<'de> Visitor<'de> for BPEVisitor {
             }
         }
         if let (Some(vocab), Some(merges)) = (vocab, merges) {
-            let merges =
-                convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?;
+            let merges = match merges {
+                MergeType::Tuple(merges) => merges,
+                MergeType::Legacy(merges) => {
+                    convert_merges_to_hashmap(merges.into_iter(), &vocab).map_err(Error::custom)?
+                }
+            };
             builder = builder.vocab_and_merges(vocab, merges);
             Ok(builder.build().map_err(Error::custom)?)
         } else {
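convert_merges_to_hashmap is the crate's existing helper for the legacy path; the simplified stand-in below (not the real implementation, which also validates tokens against the vocab) shows the shape of the work the Legacy arm delegates: every joined string must be split back into a pair, which is exactly the step that stops being reliable once tokens contain spaces.

// Simplified, hypothetical stand-in for the legacy conversion step.
fn convert_legacy_merges(
    merges: impl Iterator<Item = String>,
) -> Result<Vec<(String, String)>, String> {
    merges
        .map(|line| {
            let mut parts = line.splitn(2, ' ');
            match (parts.next(), parts.next()) {
                (Some(l), Some(r)) => Ok((l.to_string(), r.to_string())),
                _ => Err(format!("invalid merge line: {line:?}")),
            }
        })
        .collect()
}

fn main() {
    let merges = convert_legacy_merges(["a b".to_string()].into_iter()).unwrap();
    assert_eq!(merges, vec![("a".to_string(), "b".to_string())]);
}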
@@ -167,13 +178,40 @@ mod test {
             .build()
             .unwrap();

+        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
+        let legacy = serde_json::from_str(legacy).unwrap();
+        assert_eq!(bpe, legacy);
+
         let data = serde_json::to_string(&bpe).unwrap();
         assert_eq!(
             data,
-            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
+            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
         );
         let reconstructed = serde_json::from_str(&data).unwrap();
+        assert_eq!(bpe, reconstructed);
+
+        // With a space in the token
+        let vocab: Vocab = [
+            ("<unk>".into(), 0),
+            ("a".into(), 1),
+            ("b c d".into(), 2),
+            ("ab c d".into(), 3),
+        ]
+        .iter()
+        .cloned()
+        .collect();
+        let bpe = BpeBuilder::default()
+            .vocab_and_merges(vocab, vec![("a".to_string(), "b c d".to_string())])
+            .unk_token("<unk>".to_string())
+            .ignore_merges(true)
+            .build()
+            .unwrap();
+        let data = serde_json::to_string(&bpe).unwrap();
+        assert_eq!(
+            data,
+            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b c d":2,"ab c d":3},"merges":[["a","b c d"]]}"#
+        );
+        let reconstructed = serde_json::from_str(&data).unwrap();
         assert_eq!(bpe, reconstructed);
     }

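The new test above exercises the full model; the same round-trip property can be checked in isolation on the pair representation. A minimal sketch with plain serde_json and hypothetical values, contrasted with the lossy legacy flattening:

fn main() {
    // New format: serialize/deserialize is lossless, spaces included.
    let merges = vec![("a".to_string(), "b c d".to_string())];
    let data = serde_json::to_string(&merges).unwrap();
    let back: Vec<(String, String)> = serde_json::from_str(&data).unwrap();
    assert_eq!(merges, back);

    // Legacy format: ("a b", "c d") flattens to "a b c d", and splitting
    // on the first space hands back a different pair entirely.
    let flat = format!("{} {}", "a b", "c d");
    let mut parts = flat.splitn(2, ' ');
    let recovered = (parts.next().unwrap(), parts.next().unwrap());
    assert_eq!(recovered, ("a", "b c d")); // not ("a b", "c d")
}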
@@ -312,11 +312,14 @@ mod tests {
             .unwrap();

         let model = ModelWrapper::BPE(bpe);
+        let legacy = r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#;
+        let legacy = serde_json::from_str(legacy).unwrap();
+        assert_eq!(model, legacy);

         let data = serde_json::to_string(&model).unwrap();
         assert_eq!(
             data,
-            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":["a b"]}"#
+            r#"{"type":"BPE","dropout":null,"unk_token":"<unk>","continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":true,"vocab":{"<unk>":0,"a":1,"b":2,"ab":3},"merges":[["a","b"]]}"#
         );
         let reconstructed = serde_json::from_str(&data).unwrap();
         assert_eq!(model, reconstructed);