From b4b31d73cd9b2701b04b1a19b750dcf8a4da9e73 Mon Sep 17 00:00:00 2001
From: Anthony MOI <m.anthony.moi@gmail.com>
Date: Tue, 10 Dec 2019 16:20:31 -0500
Subject: [PATCH] Expose vocabulary size

---
 bindings/python/src/tokenizer.rs       | 5 +++++
 tokenizers/src/models/bpe/model.rs     | 4 ++++
 tokenizers/src/models/wordpiece/mod.rs | 4 ++++
 tokenizers/src/tokenizer.rs            | 6 ++++++
 4 files changed, 19 insertions(+)

diff --git a/bindings/python/src/tokenizer.rs b/bindings/python/src/tokenizer.rs
index 51b69e03..289aa691 100644
--- a/bindings/python/src/tokenizer.rs
+++ b/bindings/python/src/tokenizer.rs
@@ -30,6 +30,11 @@ impl Tokenizer {
         }
     }
 
+    #[getter]
+    fn get_vocab_size(&self) -> usize {
+        self.tokenizer.get_vocab_size()
+    }
+
     fn with_model(&mut self, model: &mut Model) -> PyResult<()> {
         if let Some(model) = model.model.to_pointer() {
             self.tokenizer.with_model(model);
diff --git a/tokenizers/src/models/bpe/model.rs b/tokenizers/src/models/bpe/model.rs
index 2862bab4..d46118b0 100644
--- a/tokenizers/src/models/bpe/model.rs
+++ b/tokenizers/src/models/bpe/model.rs
@@ -96,6 +96,10 @@ impl BPE {
 }
 
 impl Model for BPE {
+    fn get_vocab_size(&self) -> usize {
+        self.vocab.len()
+    }
+
     fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
         if sentence.len() == 0 {
             return vec![];
diff --git a/tokenizers/src/models/wordpiece/mod.rs b/tokenizers/src/models/wordpiece/mod.rs
index 315842b5..4f0a13cd 100644
--- a/tokenizers/src/models/wordpiece/mod.rs
+++ b/tokenizers/src/models/wordpiece/mod.rs
@@ -48,6 +48,10 @@ impl WordPiece {
 }
 
 impl Model for WordPiece {
+    fn get_vocab_size(&self) -> usize {
+        self.vocab.len()
+    }
+
     fn tokenize(&self, sentence: Vec<String>) -> Vec<Token> {
         let mut output_tokens = vec![];
 
diff --git a/tokenizers/src/tokenizer.rs b/tokenizers/src/tokenizer.rs
index 6f720cf0..81b300aa 100644
--- a/tokenizers/src/tokenizer.rs
+++ b/tokenizers/src/tokenizer.rs
@@ -37,6 +37,7 @@ pub trait Model {
     fn decode(&self, ids: Vec<u32>) -> Vec<String>;
     fn token_to_id(&self, token: &str) -> Option<u32>;
     fn id_to_token(&self, id: u32) -> Option<String>;
+    fn get_vocab_size(&self) -> usize;
 }
 
 /// A PostProcessor has the responsibility to post process an encoded output of the Tokenizer.
@@ -166,6 +167,11 @@ impl Tokenizer {
         self
     }
 
+    /// Get the size of the vocabulary
+    pub fn get_vocab_size(&self) -> usize {
+        self.model.get_vocab_size()
+    }
+
     /// Converts a token in the corresponding id.
     pub fn token_to_id(&self, token: &str) -> Option<u32> {
         self.model.token_to_id(token)