Fixing the progressbar. (#1353)

* Fixing the progressbar.

* Upgrade deps.

* Update cargo audit

* Ssh this action.

* Fixing esaxx by using slower rust version.

* Trying the new esaxx version.

* Publish.

* Get cache again.
This commit is contained in:
Nicolas Patry
2023-10-05 15:33:58 +02:00
committed by GitHub
parent 7e8e69a22c
commit aed491df8c
9 changed files with 50 additions and 57 deletions

View File

@ -54,7 +54,7 @@ jobs:
os: [ubuntu-latest, macos-latest] os: [ubuntu-latest, macos-latest]
steps: steps:
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v1 uses: actions/checkout@v2
- name: Install Rust - name: Install Rust
@ -99,9 +99,7 @@ jobs:
uses: actions-rs/cargo@v1 uses: actions-rs/cargo@v1
with: with:
command: audit command: audit
# ignoring specific CVE which probably isn't affecting this crate args: -D warnings -f ./bindings/python/Cargo.lock
# https://github.com/chronotope/chrono/issues/602
args: -D warnings -f ./bindings/python/Cargo.lock --ignore RUSTSEC-2020-0071 --ignore RUSTSEC-2021-0145
- name: Install - name: Install
working-directory: ./bindings/python working-directory: ./bindings/python

View File

@ -85,9 +85,7 @@ jobs:
uses: actions-rs/cargo@v1 uses: actions-rs/cargo@v1
with: with:
command: audit command: audit
# ignoring specific CVE which probably isn't affecting this crate args: -D warnings -f ./tokenizers/Cargo.lock
# https://github.com/chronotope/chrono/issues/602
args: -D warnings -f ./tokenizers/Cargo.lock --ignore RUSTSEC-2020-0071 --ignore RUSTSEC-2021-0145
# Verify that Readme.md is up to date. # Verify that Readme.md is up to date.
- name: Make sure, Readme generated from lib.rs matches actual Readme - name: Make sure, Readme generated from lib.rs matches actual Readme

View File

@ -9,25 +9,23 @@ name = "tokenizers"
crate-type = ["cdylib"] crate-type = ["cdylib"]
[dependencies] [dependencies]
rayon = "1.3" rayon = "1.8"
serde = { version = "1.0", features = [ "rc", "derive" ]} serde = { version = "1.0", features = [ "rc", "derive" ]}
serde_json = "1.0" serde_json = "1.0"
libc = "0.2" libc = "0.2"
env_logger = "0.7.1" env_logger = "0.10.0"
pyo3 = { version = "0.19" } pyo3 = { version = "0.19" }
numpy = "0.19.0" numpy = "0.19.0"
ndarray = "0.13" ndarray = "0.15"
onig = { version = "6.0", default-features = false } onig = { version = "6.4", default-features = false }
itertools = "0.9" itertools = "0.11"
[dependencies.tokenizers] [dependencies.tokenizers]
version = "0.14.1-dev.0" version = "0.14.1-dev.0"
path = "../../tokenizers" path = "../../tokenizers"
default-features = false
features = ["onig"]
[dev-dependencies] [dev-dependencies]
tempfile = "3.1" tempfile = "3.8"
pyo3 = { version = "0.19", features = ["auto-initialize"] } pyo3 = { version = "0.19", features = ["auto-initialize"] }
[features] [features]

View File

@ -45,31 +45,31 @@ harness = false
[dependencies] [dependencies]
lazy_static = "1.4" lazy_static = "1.4"
rand = "0.8" rand = "0.8"
onig = { version = "6.0", default-features = false, optional = true } onig = { version = "6.4", default-features = false, optional = true }
regex = "1.8" regex = "1.9"
regex-syntax = "0.7" regex-syntax = "0.7"
rayon = "1.7" rayon = "1.8"
rayon-cond = "0.1" rayon-cond = "0.3"
serde = { version = "1.0", features = [ "derive" ] } serde = { version = "1.0", features = [ "derive" ] }
serde_json = "1.0" serde_json = "1.0"
clap = { version = "4.0", features=["derive"], optional = true } clap = { version = "4.4", features=["derive"], optional = true }
unicode-normalization-alignments = "0.1" unicode-normalization-alignments = "0.1"
unicode_categories = "0.1" unicode_categories = "0.1"
unicode-segmentation = "1.10" unicode-segmentation = "1.10"
indicatif = {version = "0.15", optional = true} indicatif = {version = "0.17", optional = true}
itertools = "0.9" itertools = "0.11"
log = "0.4" log = "0.4"
derive_builder = "0.12" derive_builder = "0.12"
spm_precompiled = "0.1" spm_precompiled = "0.1"
hf-hub = { version = "0.2.0", optional = true } hf-hub = { version = "0.3.2", optional = true }
aho-corasick = "0.7" aho-corasick = "1.1"
paste = "1.0.6" paste = "1.0.14"
macro_rules_attribute = "0.1.2" macro_rules_attribute = "0.2.0"
thiserror = "1.0.30" thiserror = "1.0.49"
fancy-regex = { version = "0.10", optional = true} fancy-regex = { version = "0.11", optional = true}
getrandom = { version = "0.2.6" } getrandom = { version = "0.2.10" }
esaxx-rs = { version = "0.1", default-features = false, features=[]} esaxx-rs = { version = "0.1.10", default-features = false, features=[]}
monostate = "0.1.5" monostate = "0.1.9"
[features] [features]
default = ["progressbar", "cli", "onig", "esaxx_fast"] default = ["progressbar", "cli", "onig", "esaxx_fast"]
@ -80,8 +80,8 @@ cli = ["clap"]
unstable_wasm = ["fancy-regex", "getrandom/js"] unstable_wasm = ["fancy-regex", "getrandom/js"]
[dev-dependencies] [dev-dependencies]
criterion = "0.4" criterion = "0.5"
tempfile = "3.1" tempfile = "3.8"
assert_approx_eq = "1.1" assert_approx_eq = "1.1"
[profile.release] [profile.release]

View File

@ -223,7 +223,8 @@ impl BpeTrainer {
let p = ProgressBar::new(0); let p = ProgressBar::new(0);
p.set_style( p.set_style(
ProgressStyle::default_bar() ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"), .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
.expect("Invalid progress template"),
); );
Some(p) Some(p)
} else { } else {
@ -241,11 +242,10 @@ impl BpeTrainer {
} }
/// Update the progress bar with the new provided length and message /// Update the progress bar with the new provided length and message
fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) { fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &'static str) {
if let Some(p) = p { if let Some(p) = p {
p.set_message(message); p.set_message(message);
p.set_length(len as u64); p.set_length(len as u64);
p.set_draw_delta(len as u64 / 100);
p.reset(); p.reset();
} }
} }

View File

@ -88,7 +88,8 @@ impl UnigramTrainer {
let p = ProgressBar::new(0); let p = ProgressBar::new(0);
p.set_style( p.set_style(
ProgressStyle::default_bar() ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"), .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
.expect("Invalid progress template"),
); );
Some(p) Some(p)
} else { } else {
@ -431,11 +432,10 @@ impl UnigramTrainer {
} }
/// Update the progress bar with the new provided length and message /// Update the progress bar with the new provided length and message
fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &str) { fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &'static str) {
if let Some(p) = p { if let Some(p) = p {
p.set_message(message); p.set_message(message);
p.set_length(len as u64); p.set_length(len as u64);
p.set_draw_delta(len as u64 / 100);
p.reset(); p.reset();
} }
} }

View File

@ -166,10 +166,12 @@ impl AddedVocabulary {
pub fn new() -> Self { pub fn new() -> Self {
let trie = AhoCorasickBuilder::new() let trie = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostLongest) .match_kind(MatchKind::LeftmostLongest)
.build::<_, &&[u8]>([]); .build::<_, &&[u8]>([])
.expect("The trie should build correctly");
let normalized_trie = AhoCorasickBuilder::new() let normalized_trie = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostLongest) .match_kind(MatchKind::LeftmostLongest)
.build::<_, &&[u8]>([]); .build::<_, &&[u8]>([])
.expect("The normalized trie should build correctly");
Self { Self {
added_tokens_map: HashMap::new(), added_tokens_map: HashMap::new(),
added_tokens_map_r: HashMap::new(), added_tokens_map_r: HashMap::new(),
@ -314,7 +316,8 @@ impl AddedVocabulary {
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip(); let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
let trie = AhoCorasickBuilder::new() let trie = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostLongest) .match_kind(MatchKind::LeftmostLongest)
.build(tokens.iter().map(|token| &token.content)); .build(tokens.iter().map(|token| &token.content))
.expect("Failed to build tried when refreshing tokens");
self.split_trie = (trie, ids); self.split_trie = (trie, ids);
let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip(); let (ntokens, nids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
@ -330,7 +333,8 @@ impl AddedVocabulary {
.collect(); .collect();
let normalized_trie = AhoCorasickBuilder::new() let normalized_trie = AhoCorasickBuilder::new()
.match_kind(MatchKind::LeftmostLongest) .match_kind(MatchKind::LeftmostLongest)
.build(patterns.iter().map(|content| content.get())); .build(patterns.iter().map(|content| content.get()))
.expect("Failed to build tried when refreshing tokens (normalized)");
self.split_normalized_trie = (normalized_trie, nids); self.split_normalized_trie = (normalized_trie, nids);
} }

View File

@ -1078,11 +1078,11 @@ where
let progress = ProgressBar::new(len); let progress = ProgressBar::new(len);
progress.set_style( progress.set_style(
ProgressStyle::default_bar() ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>18!}%"), .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {percent:>18!}%")
.expect("Invalid progress template"),
); );
progress progress
.set_message(&format!("Pre-processing files ({:.2} Mo)", len / 1_000_000)); .set_message(format!("Pre-processing files ({:.2} Mo)", len / 1_000_000));
progress.set_draw_delta(len / 100); // Redraw only every 2%
Some(progress) Some(progress)
} else { } else {
None None
@ -1131,15 +1131,10 @@ where
let progress = ProgressBar::new(len); let progress = ProgressBar::new(len);
progress.set_style( progress.set_style(
ProgressStyle::default_bar() ProgressStyle::default_bar()
.template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"), .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
.expect("Invalid progress template"),
); );
progress.set_message("Pre-processing sequences"); progress.set_message("Pre-processing sequences");
if len > 0 {
progress.set_draw_delta(len / 100); // Redraw only every 2%
} else {
// Trying to have a good default to avoid progress tracking being the bottleneck
progress.set_draw_delta(1000);
}
Some(progress) Some(progress)
} else { } else {
None None

View File

@ -3,6 +3,7 @@ pub(crate) use indicatif::{ProgressBar, ProgressStyle};
#[cfg(not(feature = "progressbar"))] #[cfg(not(feature = "progressbar"))]
mod progressbar { mod progressbar {
use std::borrow::Cow;
pub struct ProgressBar; pub struct ProgressBar;
impl ProgressBar { impl ProgressBar {
pub fn new(_length: u64) -> Self { pub fn new(_length: u64) -> Self {
@ -10,8 +11,7 @@ mod progressbar {
} }
pub fn set_length(&self, _length: u64) {} pub fn set_length(&self, _length: u64) {}
pub fn set_draw_delta(&self, _draw_delta: u64) {} pub fn set_message(&self, _message: impl Into<Cow<'static, str>>) {}
pub fn set_message(&self, _message: &str) {}
pub fn finish(&self) {} pub fn finish(&self) {}
pub fn reset(&self) {} pub fn reset(&self) {}
pub fn inc(&self, _inc: u64) {} pub fn inc(&self, _inc: u64) {}
@ -23,8 +23,8 @@ mod progressbar {
pub fn default_bar() -> Self { pub fn default_bar() -> Self {
Self {} Self {}
} }
pub fn template(self, _template: &str) -> Self { pub fn template(self, _template: &str) -> Result<Self, String> {
self Ok(self)
} }
} }
} }