Mirror of https://github.com/mii443/tokenizers.git
synced 2025-08-22 08:15:49 +00:00
Add safety comments (#1651)
* Unsafe comment for from_u32_unchecked
* Add safety comments and type assertion for HashSet parallel iteration
* Add safety comment for String splice
* fixes
* fmt
* pos
committed by GitHub
parent 6ea758872d
commit 5512a424bf
@@ -454,7 +454,7 @@ impl BpeTrainer {
         // 3. Tokenize words
         //
         self.update_progress(&progress, word_counts.len(), "Tokenize words");
-        let (words, counts) =
+        let (mut words, counts) =
             self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
         self.finalize_progress(&progress, words.len());
@@ -530,14 +530,29 @@ impl BpeTrainer {
             merges.push((top.pair, new_token_id));

             // Merge the new pair in every words
-            let changes = top
-                .pos
+            // Safety: This is just a type assertion, the code below may no longer be safe
+            // if the type of `pos` changes
+            let pos: &HashSet<usize> = &top.pos;
+
+            let words_len = words.len();
+            struct WordPtr(*mut Word);
+            // Safety: We do not actually use this for concurrent access to the same memory,
+            // only to different chunks within the same allocation.
+            unsafe impl Sync for WordPtr {}
+            let word_start = WordPtr(words.as_mut_ptr());
+
+            let changes = pos
                 .maybe_par_iter()
                 .flat_map(|&i| {
-                    let word = &words[i] as *const _ as *mut Word;
-                    // We can merge each of these words in parallel here because each position
-                    // can be there only once (HashSet). So this is safe.
+                    // Safety:
+                    // We are producing a valid pointer since we are indexing in bounds
+                    //
+                    // We can access each `word` here in parallel because each position
+                    // can be there only once (pos is a HashSet).
                     unsafe {
+                        assert!(i < words_len);
+                        // This is words[i], but avoids needing to go through &T (which triggers UB)
+                        let word = word_start.0.add(i);
+                        // let word: &mut Word = &mut (*word);
                         (*word)
                             .merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
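The WordPtr construction above is the heart of the commit: parallel in-place mutation of disjoint Vec elements through a raw pointer, justified by the fact that a HashSet yields each index at most once. The following is a minimal self-contained sketch of the same pattern, using rayon directly rather than the crate's maybe_par_iter() wrapper; SyncSlicePtr and bump_unique are illustrative names, not part of the repository.

use std::collections::HashSet;
use rayon::prelude::*;

// Wrapper that promises the compiler the raw pointer may be shared
// across threads.
struct SyncSlicePtr(*mut u64);
// Safety: every thread dereferences a distinct offset, so there is no
// concurrent access to the same element.
unsafe impl Sync for SyncSlicePtr {}
impl SyncSlicePtr {
    // Accessor method so closures borrow the whole wrapper (which is
    // Sync) rather than capturing the raw pointer field directly.
    fn get(&self) -> *mut u64 {
        self.0
    }
}

fn bump_unique(data: &mut [u64], pos: &HashSet<usize>) {
    let len = data.len();
    let start = SyncSlicePtr(data.as_mut_ptr());
    pos.par_iter().for_each(|&i| {
        // `pos` is a HashSet, so each index appears at most once and no
        // two threads ever write to the same element.
        assert!(i < len);
        unsafe {
            // Offset from the base pointer instead of going through
            // `&data[i]`, which would create an aliasing reference.
            *start.get().add(i) += 1;
        }
    });
}

fn main() {
    let mut data = vec![0u64; 8];
    let pos: HashSet<usize> = [1, 3, 5].into_iter().collect();
    bump_unique(&mut data, &pos);
    assert_eq!(data, vec![0, 1, 0, 1, 0, 1, 0, 0]);
}

The accessor method matters under the 2021 edition: it makes the closure borrow the whole Sync wrapper instead of capturing the raw (non-Sync) pointer field directly.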
@@ -28,6 +28,9 @@ pub(crate) fn bytes_char() -> HashMap<u8, char> {
         }
     }

+    // Safety: cs contains all values from bs (between 0 and 255),
+    // and some values of value 2⁸ + n, where n is between 0 and 255. This is between 255 and 512.
+    // Both ranges are valid UTF-32 values (which is fully saturated until 0xD000)
     bs.into_iter()
         .zip(cs)
         .map(|(f, t)| (f, unsafe { std::char::from_u32_unchecked(t) }))
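The invariant this safety comment states can be rebuilt with the checked constructor. Below is a standalone sketch, under the assumption that the table construction mirrors bytes_char() (the name bytes_char_checked is illustrative): every value handed to from_u32 is at most 255 + 256 = 511, far below the first surrogate at 0xD800, so the checked call can never fail, and from_u32_unchecked is sound for the same inputs.

use std::collections::HashMap;

fn bytes_char_checked() -> HashMap<u8, char> {
    // Printable bytes keep their own code point.
    let mut bs: Vec<u8> = vec![];
    bs.extend(b'!'..=b'~');
    bs.extend(0xA1..=0xAC);
    bs.extend(0xAE..=0xFF);

    let mut cs: Vec<u32> = bs.iter().map(|b| *b as u32).collect();
    let mut n = 0;
    // Every remaining byte is shifted up to 256 + n.
    for b in 0..=255u8 {
        if !bs.contains(&b) {
            bs.push(b);
            cs.push(256 + n);
            n += 1;
        }
    }

    bs.into_iter()
        .zip(cs)
        .map(|(f, t)| {
            // The invariant from the Safety comment: all values < 0x200.
            debug_assert!(t < 0x200);
            // from_u32 returns None only for surrogates (0xD800..=0xDFFF)
            // and values above 0x10FFFF, so this never panics here.
            (f, char::from_u32(t).expect("valid scalar value"))
        })
        .collect()
}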
@@ -411,8 +411,16 @@ impl NormalizedString {
             .collect::<String>();

         self.alignments.splice(n_range.clone(), alignments);
+
+        // This bounds check already happens above (`self.normalized[n_range.clone()]`), but future
+        // code could change to mutate `self` or `self.normalized` in the interim.
+        // Perform it again and hope the optimizer collapses it.
+        assert!(self.normalized.get(n_range.clone()).is_some());
         unsafe {
             self.normalized
+                // Safety: This is safe as long as we do not splice across a
+                // UTF-8 character, and we only add UTF-8 text. `normalized` is a String
+                // so the latter is trivially true, and we assert for the former above.
                 .as_mut_vec()
                 .splice(n_range, normalized.bytes());
         }
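The assert-then-splice pattern above generalizes to any in-place byte edit of a String. A minimal sketch with a hypothetical helper splice_str: str::get returns Some only when both endpoints of the range fall on char boundaries, which is exactly the precondition that keeps the Vec<u8> behind as_mut_vec() valid UTF-8.

use std::ops::Range;

fn splice_str(s: &mut String, range: Range<usize>, replacement: &str) {
    // Some(_) here means the range is in bounds and lies on char
    // boundaries; splicing &str bytes into it preserves UTF-8.
    assert!(s.get(range.clone()).is_some());
    unsafe {
        // Safety: the range is on char boundaries (asserted above) and the
        // inserted bytes come from a &str, so the String stays valid UTF-8.
        s.as_mut_vec().splice(range, replacement.bytes());
    }
}

fn main() {
    let mut s = String::from("héllo world");
    splice_str(&mut s, 0..6, "goodbye"); // "héllo" is 6 bytes (é is 2)
    assert_eq!(s, "goodbye world");
}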