[BREAKING CHANGE] Ignore added_tokens (both special and not) in the decoder (#1513)

* [BREAKING CHANGE] Ignore added_tokens (both special and not) in the
decoder

`ByteLevel` was messing up some `AddedTokens` whose content overlaps the
UTF-8 range used by the byte-level mapping.

This commit tests the extent of the damage of ignoring the decoder for
those tokens (a sketch of the failure mode follows the change list below).

* Format.

* Installing cargo audit.

* Minor fix.

* Fixing "bug" in node/python.

* Autoformat.

* Clippy.

* Only prefix space when there's no decoder.
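
To illustrate the failure mode: `ByteLevel` decoding maps each character of a token back to a single byte, so an added token stored verbatim, whose characters happen to sit inside the byte-level alphabet, collapses into raw (and often invalid) bytes. A minimal self-contained sketch, mimicking but not reproducing the crate's actual byte-to-unicode table:

// Sketch of why piping an AddedToken through ByteLevel decoding corrupts
// it. The ranges below mimic the GPT-2 byte-to-unicode table, where
// printable latin-1 characters stand for themselves as single bytes.
fn byte_level_decode(token: &str) -> String {
    let bytes: Vec<u8> = token
        .chars()
        .map(|c| match c as u32 {
            // Printable latin-1 characters represent themselves as a byte.
            n @ (0x21..=0x7E | 0xA1..=0xAC | 0xAE..=0xFF) => n as u8,
            // Everything else is shifted by 256 in the real table,
            // e.g. 'Ġ' (U+0120) stands for the space byte 0x20.
            n => (n - 256) as u8,
        })
        .collect();
    String::from_utf8_lossy(&bytes).into_owned()
}

fn main() {
    // "café" is stored verbatim as an added token, but 'é' (U+00E9) lies
    // inside the byte-level alphabet, so decoding reads it as the lone
    // byte 0xE9, which is not valid UTF-8 by itself:
    assert_eq!(byte_level_decode("café"), "caf\u{FFFD}");
}

Skipping the decoder for added tokens sidesteps this entirely: their stored content is spliced into the output verbatim.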
Author:    Nicolas Patry
Date:      2024-05-06 11:49:38 +02:00
Committer: GitHub
parent f2ec3b239b
commit 25aee8b88c

4 changed files with 51 additions and 16 deletions


@@ -63,6 +63,12 @@ jobs:
           toolchain: stable
           components: rustfmt, clippy
+      - name: Install audit
+        uses: actions-rs/cargo@v1
+        with:
+          command: install
+          args: cargo-audit
       - name: Install Python
         uses: actions/setup-python@v4
         with:


@@ -36,6 +36,12 @@ jobs:
           command: install
           args: cargo-readme
+      - name: Install audit
+        uses: actions-rs/cargo@v1
+        with:
+          command: install
+          args: cargo-audit
       - name: Build
        uses: actions-rs/cargo@v1
        with:
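
Note that both workflows only install cargo-audit here; actually running the check would take a follow-up step (not part of this commit), along these lines:

      - name: Run audit
        uses: actions-rs/cargo@v1
        with:
          command: audit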


@@ -216,6 +216,10 @@ impl AddedVocabulary {
     }

     /// Get the token matching the given id if it exists
+    #[deprecated(
+        since = "0.19.0",
+        note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id))` instead"
+    )]
     pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
         self.added_tokens_map_r
             .get(&id)
@@ -223,6 +227,10 @@ impl AddedVocabulary {
             .or_else(|| model.id_to_token(id))
     }

+    pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
+        self.added_tokens_map_r.get(&id).map(|t| t.content.clone())
+    }
+
     //
     pub fn set_encode_special_tokens(&mut self, value: bool) {
         self.encode_special_tokens = value;
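
A self-contained sketch of the lookup chaining that the deprecation note asks callers to do themselves; the mock types are illustrative stand-ins, not the crate's real `AddedVocabulary` and `Model`:

use std::collections::HashMap;

// Illustrative stand-ins for the crate's AddedVocabulary and Model.
struct MockAddedVocab(HashMap<u32, String>);
struct MockModel(HashMap<u32, String>);

impl MockAddedVocab {
    // Mirrors simple_id_to_token: consults added tokens only, no model.
    fn simple_id_to_token(&self, id: u32) -> Option<String> {
        self.0.get(&id).cloned()
    }
}

impl MockModel {
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.0.get(&id).cloned()
    }
}

fn main() {
    let added = MockAddedVocab(HashMap::from([(50_000, "<mask>".to_string())]));
    let model = MockModel(HashMap::from([(42, "hello".to_string())]));

    // The chaining the deprecation note asks for: added vocabulary
    // first, with the model only consulted as a fallback.
    let lookup = |id| {
        added
            .simple_id_to_token(id)
            .or_else(|| model.id_to_token(id))
    };
    assert_eq!(lookup(50_000).as_deref(), Some("<mask>"));
    assert_eq!(lookup(42).as_deref(), Some("hello"));
}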


@@ -699,7 +699,9 @@
     /// Converts an id to the corresponding token.
     pub fn id_to_token(&self, id: u32) -> Option<String> {
-        self.added_vocabulary.id_to_token(id, &self.model)
+        self.added_vocabulary
+            .simple_id_to_token(id)
+            .or_else(|| self.model.id_to_token(id))
     }

     /// set the added vocab's splitting scheme
@@ -845,22 +847,35 @@
     /// Decode the given ids, back to a String
     pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
-        let tokens = ids
-            .iter()
-            .filter_map(|id| {
-                self.added_vocabulary
-                    .id_to_token(*id, &self.model)
-                    .filter(|token| {
-                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
-                    })
-            })
-            .collect::<Vec<_>>();
-
-        if let Some(decoder) = &self.decoder {
-            decoder.decode(tokens)
-        } else {
-            Ok(tokens.join(" "))
-        }
+        let mut result = String::with_capacity(ids.len());
+        let mut chunks = Vec::with_capacity(ids.len());
+        for id in ids {
+            if let Some(added_token) = self.added_vocabulary.simple_id_to_token(*id) {
+                if skip_special_tokens && self.added_vocabulary.is_special_token(&added_token) {
+                    continue;
+                }
+                let text_chunk = if let Some(decoder) = &self.decoder {
+                    decoder.decode(chunks.clone())?
+                } else {
+                    chunks.join(" ")
+                };
+                result.push_str(&text_chunk);
+                if !result.is_empty() && self.decoder.is_none() {
+                    result.push(' ');
+                }
+                result.push_str(&added_token);
+                chunks.clear();
+            } else if let Some(token) = self.model.id_to_token(*id) {
+                chunks.push(token);
+            }
+        }
+        let text_chunk = if let Some(decoder) = &self.decoder {
+            decoder.decode(chunks.clone())?
+        } else {
+            chunks.join(" ")
+        };
+        result.push_str(&text_chunk);
+        Ok(result)
     }
 }
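
For clarity, here is a self-contained sketch of the new decode flow, with `skip_special_tokens` and the space-joining fallback omitted; the closures stand in for the added vocabulary, the model, and the `Decoder`:

// Sketch of the buffer-and-flush decode: model tokens are accumulated and
// decoded in runs, while added tokens are spliced in verbatim and never
// pass through the decoder.
fn decode_sketch(
    ids: &[u32],
    added: impl Fn(u32) -> Option<String>,
    model: impl Fn(u32) -> Option<String>,
    decoder: impl Fn(Vec<String>) -> String,
) -> String {
    let mut result = String::new();
    let mut chunks: Vec<String> = Vec::new();
    for &id in ids {
        if let Some(added_token) = added(id) {
            // Flush buffered model tokens through the decoder...
            result.push_str(&decoder(std::mem::take(&mut chunks)));
            // ...then splice the added token in untouched.
            result.push_str(&added_token);
        } else if let Some(token) = model(id) {
            chunks.push(token);
        }
    }
    result.push_str(&decoder(chunks)); // final flush
    result
}

fn main() {
    let added = |id: u32| (id == 9).then(|| "<mask>".to_string());
    let model = |id: u32| Some(format!("tok{id}"));
    let decoder = |tokens: Vec<String>| tokens.concat();
    assert_eq!(
        decode_sketch(&[1, 9, 2], added, model, decoder),
        "tok1<mask>tok2"
    );
}

The point of the chunk buffer: model tokens still go through the decoder in runs, but every added token acts as a flush point and is emitted verbatim, so the decoder can no longer corrupt it.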