Fix ByteLevel Decoder

The join was done after mapping the characters back to bytes and building each subword separately, which prevented bytes that span these subwords from being merged back together correctly. We need to join first, and only then map the characters back to bytes.
Anthony MOI
2019-11-21 16:50:25 -05:00
parent 634415c098
commit 663644e041
2 changed files with 41 additions and 14 deletions
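
To illustrate the issue described in the commit message, here is a minimal, self-contained sketch. It is not the crate's actual code: byte_to_char / char_to_byte below are simplified stand-ins for the ByteLevel byte<->unicode tables (CHAR_BYTES and its inverse). A character like "é" occupies two UTF-8 bytes; when a merge boundary falls between them, decoding each sub-token to bytes separately produces invalid UTF-8 fragments (replacement characters), while joining the mapped characters first and converting back to bytes once recovers the original text.

// Minimal sketch, assuming a simplified byte<->char mapping instead of the
// real CHAR_BYTES table used by the ByteLevel pre-tokenizer/decoder.
fn byte_to_char(b: u8) -> char {
    // Simplified reversible mapping: shift every byte into a printable range.
    // (The real GPT-2 style mapping only remaps non-printable bytes.)
    char::from_u32(0x100 + b as u32).unwrap()
}

fn char_to_byte(c: char) -> u8 {
    (c as u32 - 0x100) as u8
}

fn main() {
    // "é" is two UTF-8 bytes (0xC3, 0xA9). Suppose BPE produced two
    // sub-tokens with the merge boundary falling between those bytes.
    let mapped: Vec<char> = "é".bytes().map(byte_to_char).collect();
    let tokens: Vec<String> = vec![mapped[0].to_string(), mapped[1].to_string()];

    // Old behaviour: map each token back to bytes on its own, then join.
    // Each half is invalid UTF-8 by itself, so we get two U+FFFD chars.
    let broken = tokens
        .iter()
        .map(|t| {
            let bytes: Vec<u8> = t.chars().map(char_to_byte).collect();
            String::from_utf8_lossy(&bytes).into_owned()
        })
        .collect::<Vec<_>>()
        .join("");

    // Fixed behaviour: join the tokens first, then map back to bytes once,
    // so both bytes of "é" are decoded together.
    let bytes: Vec<u8> = tokens.join("").chars().map(char_to_byte).collect();
    let fixed = String::from_utf8_lossy(&bytes).into_owned();

    assert_eq!(broken, "\u{FFFD}\u{FFFD}");
    assert_eq!(fixed, "é");
    println!("broken: {:?}, fixed: {:?}", broken, fixed);
}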


@@ -58,17 +58,26 @@ print(f"Tokenizing {len(text)} lines")
start = time.time()
encoded_r = tokenize_r()
end = time.time()
print(f"Rust tokenizer took: {end - start} sec")
time_r = end - start
print(f"Rust tokenizer took: {time_r} sec")
# Python version
start = time.time()
encoded_p = tokenize_p()
end = time.time()
print(f"Transformer tokenizer took: {end - start} sec")
time_p = end - start
print(f"Transformer tokenizer took: {time_p} sec")
print(f"SpeedUp Ratio: {time_p / time_r}")
ids_r = [ [ token.id for token in sentence ] for sentence in encoded_r ]
assert(ids_r == encoded_p)
decoded_r = tok_r.decode_batch(ids_r)
print(f"Decoded sentences: {decoded_r}")
for i in range(0, len(text)):
    if decoded_r[i] != text[i]:
        print(decoded_r[i])
        print(text[i])
        print("")
assert(decoded_r == text)


@@ -48,17 +48,14 @@ impl PreTokenizer for ByteLevel {
impl Decoder for ByteLevel {
    fn decode(&self, tokens: Vec<String>) -> String {
        tokens
            .into_iter()
            .map(|token| {
                let bytes = token
                    .chars()
                    .map(|c| CHAR_BYTES[&(c as u32)])
                    .collect::<Vec<u8>>();
                String::from_utf8_lossy(&bytes).into_owned()
            })
            .collect::<Vec<_>>()
            .join("")
        String::from_utf8_lossy(
            &tokens
                .join("")
                .chars()
                .map(|c| CHAR_BYTES[&(c as u32)])
                .collect::<Vec<_>>(),
        )
        .into_owned()
    }
}
@@ -93,4 +90,25 @@ mod tests {
            )
        );
    }

    #[test]
    fn decode_works_on_separated_tokens() {
        let samples = vec![
            String::from(
                "A Nuskhuri abbreviation of იესუ ქრისტე ( iesu kriste ) \" Jesus Christ \"",
            ),
            String::from("An equal number have descenders , like p or q in English : გ , დ , ე , ვ , კ , ლ , ჟ , ტ , უ , ფ , ღ , , ც"),
        ];

        let bl = ByteLevel;
        for sample in samples {
            let pre_tokenized = bl.pre_tokenize(&sample);
            let separated_tokens = pre_tokenized
                .into_iter()
                .map(|token| token.split("").map(|t| t.into()).collect::<Vec<_>>())
                .flatten()
                .collect::<Vec<_>>();
            assert_eq!(sample, bl.decode(separated_tokens));
        }
    }
}