Mirror of https://github.com/mii443/tokenizers.git (synced 2025-08-22 16:25:30 +00:00)
Fix ByteLevel Decoder
The join was done after mapping the characters back to bytes and decoding each subword, which prevented bytes split across these subwords from being merged correctly. We need to join first.
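As a quick illustration of why the join order matters, here is a minimal, self-contained sketch (not the crate's actual code): it stands in for the full CHAR_BYTES table with a hypothetical two-entry map covering only 'Ã' and '©', the byte-level characters for the two UTF-8 bytes of "é". Decoding each subword on its own and joining afterwards produces replacement characters, while joining first recovers the original character.

use std::collections::HashMap;

fn main() {
    // Hypothetical two-entry stand-in for the full CHAR_BYTES table:
    // in the byte-level alphabet, the UTF-8 bytes of "é" (0xC3, 0xA9)
    // are represented by the visible characters 'Ã' and '©'.
    let char_bytes: HashMap<u32, u8> =
        [('Ã' as u32, 0xC3u8), ('©' as u32, 0xA9u8)].iter().cloned().collect();

    // "é" ends up split across two subword tokens.
    let tokens = vec![String::from("Ã"), String::from("©")];

    // Old behaviour: map each token to bytes and decode it on its own,
    // then join the partial strings. Each half of the multi-byte
    // character is invalid UTF-8 by itself, so it becomes U+FFFD.
    let broken = tokens
        .iter()
        .map(|token| {
            let bytes: Vec<u8> = token.chars().map(|c| char_bytes[&(c as u32)]).collect();
            String::from_utf8_lossy(&bytes).into_owned()
        })
        .collect::<Vec<_>>()
        .join("");
    assert_eq!(broken, "\u{FFFD}\u{FFFD}");

    // Fixed behaviour: join the tokens first, then convert the whole
    // character sequence back to bytes and decode it in one pass.
    let joined = tokens.join("");
    let bytes: Vec<u8> = joined.chars().map(|c| char_bytes[&(c as u32)]).collect();
    assert_eq!(String::from_utf8_lossy(&bytes).into_owned(), "é");
}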
@@ -58,17 +58,26 @@ print(f"Tokenizing {len(text)} lines")
 start = time.time()
 encoded_r = tokenize_r()
 end = time.time()
-print(f"Rust tokenizer took: {end - start} sec")
+time_r = end - start
+print(f"Rust tokenizer took: {time_r} sec")
 
 # Python version
 start = time.time()
 encoded_p = tokenize_p()
 end = time.time()
-print(f"Transformer tokenizer took: {end - start} sec")
+time_p = end - start
+print(f"Transformer tokenizer took: {time_p} sec")
 
+print(f"SpeedUp Ratio: {time_p / time_r}")
+
 ids_r = [ [ token.id for token in sentence ] for sentence in encoded_r ]
 assert(ids_r == encoded_p)
 
 decoded_r = tok_r.decode_batch(ids_r)
-print(f"Decoded sentences: {decoded_r}")
+for i in range(0, len(text)):
+    if decoded_r[i] != text[i]:
+        print(decoded_r[i])
+        print(text[i])
+        print("")
+
 assert(decoded_r == text)
@@ -48,17 +48,14 @@ impl PreTokenizer for ByteLevel {
 
 impl Decoder for ByteLevel {
     fn decode(&self, tokens: Vec<String>) -> String {
-        tokens
-            .into_iter()
-            .map(|token| {
-                let bytes = token
-                    .chars()
-                    .map(|c| CHAR_BYTES[&(c as u32)])
-                    .collect::<Vec<u8>>();
-                String::from_utf8_lossy(&bytes).into_owned()
-            })
-            .collect::<Vec<_>>()
-            .join("")
+        String::from_utf8_lossy(
+            &tokens
+                .join("")
+                .chars()
+                .map(|c| CHAR_BYTES[&(c as u32)])
+                .collect::<Vec<_>>(),
+        )
+        .into_owned()
     }
 }
 
@@ -93,4 +90,25 @@ mod tests {
             )
         );
     }
+
+    #[test]
+    fn decode_works_on_separated_tokens() {
+        let samples = vec![
+            String::from(
+                "A Nuskhuri abbreviation of იესუ ქრისტე ( iesu kriste ) \" Jesus Christ \"",
+            ),
+            String::from("An equal number have descenders , like p or q in English : გ , დ , ე , ვ , კ , ლ , ჟ , ტ , უ , ფ , ღ , ყ , ც"),
+        ];
+
+        let bl = ByteLevel;
+        for sample in samples {
+            let pre_tokenized = bl.pre_tokenize(&sample);
+            let separated_tokens = pre_tokenized
+                .into_iter()
+                .map(|token| token.split("").map(|t| t.into()).collect::<Vec<_>>())
+                .flatten()
+                .collect::<Vec<_>>();
+            assert_eq!(sample, bl.decode(separated_tokens));
+        }
+    }
 }