[BREAKING CHANGE] Ignore added_tokens (both special and not) in the decoder (#1513)

* [BREAKING CHANGE] Ignore added_tokens (both special and not) in the decoder

This caused issues with `ByteLevel`, which mangles some `AddedTokens` whose
content falls within the UTF-8 range used by the byte-level mapping.

This commit tests the extent of the damage caused by ignoring the decoder for
those tokens.
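
To see why, here is a minimal, self-contained sketch of a GPT-2-style
byte-to-unicode table (an illustration of the failure mode only, not the
crate's actual `ByteLevel` code; "Ġmask" is a hypothetical added token).
The space byte 0x20 maps to 'Ġ' (U+0120), so an added token whose content
contains 'Ġ' collapses back into a plain space when run through a
byte-level decode:

    use std::collections::HashMap;

    // GPT-2-style table: "nice" printable bytes map to themselves, all
    // other bytes are shifted past U+00FF so every byte stays visible.
    fn bytes_to_unicode() -> HashMap<u8, char> {
        let keep: Vec<u16> =
            (0x21u16..=0x7E).chain(0xA1..=0xAC).chain(0xAE..=0xFF).collect();
        let mut table = HashMap::new();
        let mut n = 0u16;
        for b in 0u16..=255 {
            let c = if keep.contains(&b) {
                b
            } else {
                n += 1;
                255 + n
            };
            table.insert(b as u8, char::from_u32(u32::from(c)).unwrap());
        }
        table
    }

    fn main() {
        // Reverse map, as a byte-level decoder uses it: char -> byte.
        let to_byte: HashMap<char, u8> =
            bytes_to_unicode().into_iter().map(|(b, c)| (c, b)).collect();

        // Hypothetical added token containing 'Ġ' (U+0120), which is
        // also the byte-level image of the space byte 0x20.
        let added_token = "Ġmask";
        let bytes: Vec<u8> = added_token
            .chars()
            .map(|c| *to_byte.get(&c).unwrap_or(&b'?'))
            .collect();
        // The decoder rewrites the token: prints " mask", not "Ġmask".
        println!("{}", String::from_utf8_lossy(&bytes));
    }

Skipping the decoder for added tokens, as this commit does, keeps their
content byte-for-byte intact.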

* Format.

* Installing cargo audit.

* Minor fix.

* Fixing "bug" in node/python.

* Autoformat.

* Clippy.

* Only prefix space when there's no decoder.
Nicolas Patry, 2024-05-06 11:49:38 +02:00 (committed by GitHub)
parent f2ec3b239b
commit 25aee8b88c
4 changed files with 51 additions and 16 deletions


@@ -63,6 +63,12 @@ jobs:
           toolchain: stable
           components: rustfmt, clippy
+      - name: Install audit
+        uses: actions-rs/cargo@v1
+        with:
+          command: install
+          args: cargo-audit
       - name: Install Python
         uses: actions/setup-python@v4
         with:
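
The new step installs `cargo-audit`, which checks the dependencies pinned in
`Cargo.lock` against the RustSec advisory database; once installed it is
invoked as `cargo audit`. The same step is added to the second workflow below.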


@@ -36,6 +36,12 @@ jobs:
           command: install
           args: cargo-readme
+      - name: Install audit
+        uses: actions-rs/cargo@v1
+        with:
+          command: install
+          args: cargo-audit
       - name: Build
         uses: actions-rs/cargo@v1
         with:


@@ -216,6 +216,10 @@ impl AddedVocabulary {
     }

     /// Get the token matching the given id if it exists
+    #[deprecated(
+        since = "0.19.0",
+        note = "please use `added_vocabulary.simple_id_to_token(id).or_else(|| model.id_to_token(id))` instead"
+    )]
     pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
         self.added_tokens_map_r
             .get(&id)
@@ -223,6 +227,10 @@ impl AddedVocabulary {
             .or_else(|| model.id_to_token(id))
     }

+    pub fn simple_id_to_token(&self, id: u32) -> Option<String> {
+        self.added_tokens_map_r.get(&id).map(|t| t.content.clone())
+    }
+
     //
     pub fn set_encode_special_tokens(&mut self, value: bool) {
         self.encode_special_tokens = value;
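
A hedged sketch of the migration suggested by the deprecation note, using
hypothetical minimal stand-ins for `AddedVocabulary` and `Model` (the real
`simple_id_to_token` clones the token's `content` field, as in the hunk
above):

    use std::collections::HashMap;

    // Hypothetical stand-ins, only to show the lookup order.
    struct ToyAddedVocabulary {
        added_tokens_map_r: HashMap<u32, String>,
    }
    struct ToyModel {
        vocab_r: HashMap<u32, String>,
    }

    impl ToyAddedVocabulary {
        fn simple_id_to_token(&self, id: u32) -> Option<String> {
            self.added_tokens_map_r.get(&id).cloned()
        }
    }
    impl ToyModel {
        fn id_to_token(&self, id: u32) -> Option<String> {
            self.vocab_r.get(&id).cloned()
        }
    }

    fn main() {
        let added_vocabulary = ToyAddedVocabulary {
            added_tokens_map_r: HashMap::from([(100, "<s>".to_string())]),
        };
        let model = ToyModel {
            vocab_r: HashMap::from([(0, "hello".to_string())]),
        };

        // Old (deprecated): added_vocabulary.id_to_token(id, &model)
        // New: query the added vocabulary first, then fall back to the model.
        let lookup = |id: u32| {
            added_vocabulary
                .simple_id_to_token(id)
                .or_else(|| model.id_to_token(id))
        };
        assert_eq!(lookup(100).as_deref(), Some("<s>"));
        assert_eq!(lookup(0).as_deref(), Some("hello"));
        assert_eq!(lookup(7), None);
    }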


@@ -699,7 +699,9 @@ where
     /// Converts an id to the corresponding token.
     pub fn id_to_token(&self, id: u32) -> Option<String> {
-        self.added_vocabulary.id_to_token(id, &self.model)
+        self.added_vocabulary
+            .simple_id_to_token(id)
+            .or_else(|| self.model.id_to_token(id))
     }

     /// set the added vocab's splitting scheme
@@ -845,22 +847,35 @@ where
     /// Decode the given ids, back to a String
     pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
-        let tokens = ids
-            .iter()
-            .filter_map(|id| {
-                self.added_vocabulary
-                    .id_to_token(*id, &self.model)
-                    .filter(|token| {
-                        !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
-                    })
-            })
-            .collect::<Vec<_>>();
-
-        if let Some(decoder) = &self.decoder {
-            decoder.decode(tokens)
-        } else {
-            Ok(tokens.join(" "))
-        }
+        let mut result = String::with_capacity(ids.len());
+        let mut chunks = Vec::with_capacity(ids.len());
+        for id in ids {
+            if let Some(added_token) = self.added_vocabulary.simple_id_to_token(*id) {
+                if skip_special_tokens && self.added_vocabulary.is_special_token(&added_token) {
+                    continue;
+                }
+                let text_chunk = if let Some(decoder) = &self.decoder {
+                    decoder.decode(chunks.clone())?
+                } else {
+                    chunks.join(" ")
+                };
+                result.push_str(&text_chunk);
+                if !result.is_empty() && self.decoder.is_none() {
+                    result.push(' ');
+                }
+                result.push_str(&added_token);
+                chunks.clear();
+            } else if let Some(token) = self.model.id_to_token(*id) {
+                chunks.push(token);
+            }
+        }
+        let text_chunk = if let Some(decoder) = &self.decoder {
+            decoder.decode(chunks.clone())?
+        } else {
+            chunks.join(" ")
+        };
+        result.push_str(&text_chunk);
+        Ok(result)
     }
 }
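
For reference, a standalone sketch of the new chunked decode flow (toy
closures stand in for the model, the added vocabulary, and the decoder;
the space-prefixing branch for the no-decoder case is omitted). Model
tokens are buffered and flushed through the decoder chunk by chunk, while
added tokens are appended verbatim and never reach the decoder:

    fn toy_decode(ids: &[u32], skip_special_tokens: bool) -> String {
        // Toy stand-ins (hypothetical): a 2-entry model vocab, two added
        // tokens (id 100 special, id 101 not), a "decoder" that lowercases.
        let model = |id: u32| match id {
            0 => Some("Hello".to_string()),
            1 => Some("World".to_string()),
            _ => None,
        };
        let added = |id: u32| match id {
            100 => Some(("<s>".to_string(), true)),
            101 => Some(("[user]".to_string(), false)),
            _ => None,
        };
        let decoder = |chunks: Vec<String>| chunks.concat().to_lowercase();

        let mut result = String::new();
        let mut chunks: Vec<String> = Vec::new();
        for &id in ids {
            if let Some((token, special)) = added(id) {
                if skip_special_tokens && special {
                    continue;
                }
                // Flush buffered model tokens through the decoder ...
                result.push_str(&decoder(std::mem::take(&mut chunks)));
                // ... but append the added token verbatim, bypassing it.
                result.push_str(&token);
            } else if let Some(token) = model(id) {
                chunks.push(token);
            }
        }
        result.push_str(&decoder(chunks));
        result
    }

    fn main() {
        // "<s>" is skipped, "[user]" stays verbatim, model tokens decoded.
        assert_eq!(toy_decode(&[100, 0, 101, 1], true), "hello[user]world");
    }

Because added tokens never pass through the decoder, a `ByteLevel` decoder
can no longer rewrite their content, which is exactly the bug described in
the commit message.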