mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Add safety comments (#1651)
* Unsafe comment for from_u32_unchecked * Add safety comments and type assertion for HashSet parallel iteration * Add safety comment for String splice * fixes * fmt * pos
This commit is contained in:
committed by
GitHub
parent
6ea758872d
commit
5512a424bf
@ -454,7 +454,7 @@ impl BpeTrainer {
|
|||||||
// 3. Tokenize words
|
// 3. Tokenize words
|
||||||
//
|
//
|
||||||
self.update_progress(&progress, word_counts.len(), "Tokenize words");
|
self.update_progress(&progress, word_counts.len(), "Tokenize words");
|
||||||
let (words, counts) =
|
let (mut words, counts) =
|
||||||
self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
|
self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
|
||||||
self.finalize_progress(&progress, words.len());
|
self.finalize_progress(&progress, words.len());
|
||||||
|
|
||||||
@ -530,14 +530,29 @@ impl BpeTrainer {
|
|||||||
merges.push((top.pair, new_token_id));
|
merges.push((top.pair, new_token_id));
|
||||||
|
|
||||||
// Merge the new pair in every word
|
// Merge the new pair in every word
|
||||||
let changes = top
|
// Safety: This is just a type assertion, the code below may no longer be safe
|
||||||
.pos
|
// if the type of `pos` changes
|
||||||
|
let pos: &HashSet<usize> = &top.pos;
|
||||||
|
|
||||||
|
let words_len = words.len();
|
||||||
|
struct WordPtr(*mut Word);
|
||||||
|
// Safety: We do not actually use this for concurrent access to the same memory,
|
||||||
|
// only to different chunks within the same allocation.
|
||||||
|
unsafe impl Sync for WordPtr {}
|
||||||
|
let word_start = WordPtr(words.as_mut_ptr());
|
||||||
|
|
||||||
|
let changes = pos
|
||||||
.maybe_par_iter()
|
.maybe_par_iter()
|
||||||
.flat_map(|&i| {
|
.flat_map(|&i| {
|
||||||
let word = &words[i] as *const _ as *mut Word;
|
// Safety:
|
||||||
// We can merge each of these words in parallel here because each position
|
// We are producing a valid pointer since we are indexing in bounds
|
||||||
// can be there only once (HashSet). So this is safe.
|
//
|
||||||
|
// We can access each `word` here in parallel because each position
|
||||||
|
// can be there only once (pos is a HashSet).
|
||||||
unsafe {
|
unsafe {
|
||||||
|
assert!(i < words_len);
|
||||||
|
// This is words[i], but avoids needing to go through &T (which triggers UB)
|
||||||
|
let word = word_start.0.add(i);
|
||||||
// let word: &mut Word = &mut (*word);
|
// let word: &mut Word = &mut (*word);
|
||||||
(*word)
|
(*word)
|
||||||
.merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
|
.merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
|
||||||
|
@ -28,6 +28,9 @@ pub(crate) fn bytes_char() -> HashMap<u8, char> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Safety: cs contains all values from bs (between 0 and 255),
|
||||||
|
// and some values 2⁸ + n, where n is between 0 and 255, i.e. values between 256 and 511.
|
||||||
|
// Both ranges are valid Unicode scalar values (the code-point space is fully saturated up to 0xD800, where the surrogate range begins)
|
||||||
bs.into_iter()
|
bs.into_iter()
|
||||||
.zip(cs)
|
.zip(cs)
|
||||||
.map(|(f, t)| (f, unsafe { std::char::from_u32_unchecked(t) }))
|
.map(|(f, t)| (f, unsafe { std::char::from_u32_unchecked(t) }))
|
||||||
|
@ -411,8 +411,16 @@ impl NormalizedString {
|
|||||||
.collect::<String>();
|
.collect::<String>();
|
||||||
|
|
||||||
self.alignments.splice(n_range.clone(), alignments);
|
self.alignments.splice(n_range.clone(), alignments);
|
||||||
|
|
||||||
|
// This bounds check already happens above (`self.normalized[n_range.clone()]`), but future
|
||||||
|
// code could change to mutate `self` or `self.normalized` in the interim.
|
||||||
|
// Perform it again and hope the optimizer collapses it.
|
||||||
|
assert!(self.normalized.get(n_range.clone()).is_some());
|
||||||
unsafe {
|
unsafe {
|
||||||
self.normalized
|
self.normalized
|
||||||
|
// Safety: This is safe as long as we do not splice across a
|
||||||
|
// UTF-8 character, and we only add UTF-8 text. `normalized` is a String
|
||||||
|
// so the latter is trivially true, and we assert for the former above.
|
||||||
.as_mut_vec()
|
.as_mut_vec()
|
||||||
.splice(n_range, normalized.bytes());
|
.splice(n_range, normalized.bytes());
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user