diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml new file mode 100644 index 0000000..489c766 --- /dev/null +++ b/.github/workflows/rust.yml @@ -0,0 +1,25 @@ +name: build + +on: [push, pull_request] + +jobs: + build: + + runs-on: ${{matrix.os}} + strategy: + fail-fast: true + matrix: + os: [macos-latest, ubuntu-latest, windows-latest] + + steps: + - uses: actions/checkout@v2 + - name: Check formatting + run: | + rustup component add rustfmt + rustup component add clippy + cargo fmt -- --check + cargo clippy --all-features --verbose -- -Dwarnings + - name: Build + run: cargo build --verbose + - name: Run tests + run: cargo test --all-features diff --git a/Cargo.lock b/Cargo.lock index c58c710..a222611 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,7 +8,7 @@ checksum = "7bbb73db36c1246e9034e307d0fba23f9a2e251faa47ade70c1bd252220c8311" [[package]] name = "esaxx-rs" -version = "0.1.1" +version = "0.1.2" dependencies = [ "cc", ] diff --git a/Cargo.toml b/Cargo.toml index 54de2a3..52713e2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "esaxx-rs" -version = "0.1.1" +version = "0.1.2" authors = ["Nicolas Patry "] edition = "2018" description = "Wrapping around sentencepiece's esaxxx library." diff --git a/README.md b/README.md index 6bacfa2..c3aa382 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,25 @@ +![](https://github.com/Narsil/esaxx-rs/workflows/build/badge.svg) + # esaxx-rs +This code implements a fast suffix tree / suffix array. + +This code is taken from ![sentencepiece](https://github.com/google/sentencepiece) +and to be used by ![hugging face](https://github.com/huggingface/tokenizers/). + + Small wrapper around sentencepiece's esaxx suffix array C++ library. Usage ```rust -let string = "abracadabra".to_string(); +let string = "abracadabra"; +let suffix = esaxx_rs::suffix(string).unwrap(); let chars: Vec<_> = string.chars().collect(); -let n = chars.len(); -let mut sa = vec![0; n]; -let mut l = vec![0; n]; -let mut r = vec![0; n]; -let mut d = vec![0; n]; -let mut node_num = 0; - -let alphabet_size = 0x110000; // All UCS4 range. -unsafe { - esaxx_int32( - chars.as_ptr() as *mut u32, - sa.as_mut_ptr(), - l.as_mut_ptr(), - r.as_mut_ptr(), - d.as_mut_ptr(), - n.try_into().unwrap(), - alphabet_size, - &mut node_num, - ); -} +let mut iter = suffix.iter(); +assert_eq!(iter.next().unwrap(), (&chars[..4], 2)); // abra +assert_eq!(iter.next(), Some((&chars[..1], 5))); // a +assert_eq!(iter.next(), Some((&chars[1..4], 2))); // bra +assert_eq!(iter.next(), Some((&chars[2..4], 2))); // ra +assert_eq!(iter.next(), Some((&chars[..0], 11))); // '' +assert_eq!(iter.next(), None); ``` - -Current version: 0.1.0 - -License: Apache diff --git a/README.tpl b/README.tpl index ee3febb..a972eef 100644 --- a/README.tpl +++ b/README.tpl @@ -1,7 +1,12 @@ +![](https://github.com/Narsil/esaxx-rs/workflows/build/badge.svg) + # {{crate}} +This code implements a fast suffix tree / suffix array. + +This code is taken from ![sentencepiece](https://github.com/google/sentencepiece) +and to be used by ![hugging face](https://github.com/huggingface/tokenizers/). + + {{readme}} -Current version: {{version}} - -License: {{license}} diff --git a/src/esa.rs b/src/esa.rs index 5cf557f..2ebf5dc 100644 --- a/src/esa.rs +++ b/src/esa.rs @@ -75,7 +75,7 @@ fn suffixtree( node_num } -pub fn esaxx_rs( +pub(crate) fn esaxx_rs( t: &StringT, sa: &mut SArray, l: &mut SArray, diff --git a/src/lib.rs b/src/lib.rs index 3115755..fd69d71 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,25 @@ //! assert_eq!(iter.next(), Some((&chars[..0], 11))); // '' //! assert_eq!(iter.next(), None); //! ``` +//! +//! The previous version uses unsafe optimized c++ code. +//! There exists another implementation a bit slower (~2x slower) that uses +//! safe rust. It's a bit slower because it uses usize (mostly 64bit) instead of i32 (32bit). +//! But it does seems to fix a few OOB issues in the cpp version +//! (which never seemed to cause real problems in tests but still.) +//! +//! ```rust +//! let string = "abracadabra"; +//! let suffix = esaxx_rs::suffix_rs(string).unwrap(); +//! let chars: Vec<_> = string.chars().collect(); +//! let mut iter = suffix.iter(); +//! assert_eq!(iter.next().unwrap(), (&chars[..4], 2)); // abra +//! assert_eq!(iter.next(), Some((&chars[..1], 5))); // a +//! assert_eq!(iter.next(), Some((&chars[1..4], 2))); // bra +//! assert_eq!(iter.next(), Some((&chars[2..4], 2))); // ra +//! assert_eq!(iter.next(), Some((&chars[..0], 11))); // '' +//! assert_eq!(iter.next(), None); +//! ``` #![feature(test)] extern crate test; @@ -21,11 +40,11 @@ mod esa; mod sais; mod types; -pub use esa::esaxx_rs; +use esa::esaxx_rs; use types::SuffixError; extern "C" { - pub fn esaxx_int32( + fn esaxx_int32( // This is char32 T: *const u32, SA: *mut i32, @@ -38,7 +57,7 @@ extern "C" { ) -> i32; } -pub fn esaxx( +fn esaxx( chars: &[char], sa: &mut [i32], l: &mut [i32], @@ -83,6 +102,8 @@ pub struct Suffix { node_num: usize, } +/// Creates the suffix array and provides an iterator over its items (Rust version) +/// See [suffix](fn.suffix.html) pub fn suffix_rs(string: &str) -> Result, SuffixError> { let chars: Vec<_> = string.chars().collect(); let n = chars.len(); @@ -102,6 +123,22 @@ pub fn suffix_rs(string: &str) -> Result, SuffixError> { }) } +/// Creates the suffix array and provides an iterator over its items (c++ unsafe version) +/// +/// Gives you an iterator over the suffixes of the input array and their count within +/// the input srtring. +/// ```rust +/// let string = "abracadabra"; +/// let suffix = esaxx_rs::suffix(string).unwrap(); +/// let chars: Vec<_> = string.chars().collect(); +/// let mut iter = suffix.iter(); +/// assert_eq!(iter.next().unwrap(), (&chars[..4], 2)); // abra +/// assert_eq!(iter.next(), Some((&chars[..1], 5))); // a +/// assert_eq!(iter.next(), Some((&chars[1..4], 2))); // bra +/// assert_eq!(iter.next(), Some((&chars[2..4], 2))); // ra +/// assert_eq!(iter.next(), Some((&chars[..0], 11))); // '' +/// assert_eq!(iter.next(), None); +/// ``` pub fn suffix(string: &str) -> Result, SuffixError> { let chars: Vec<_> = string.chars().collect(); let n = chars.len();