Introduce WordLevel model for TransformerXL (#125)
* Added lookup table model mapping strings to the ids present in a vocab map.
* RustFmt
* Formatting.
* Fix invalid void return on Rust side.
* Python binding for LookupTable model
* Enable loading from Python's side.
* Renamed LookupTable to WordLevel
* RustFmt happy now.
* clippy happy now.
* Addressing mismatching names.
* Addressing mismatching names (one missing).

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
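At its core, WordLevel is nothing more than a vocabulary lookup with a fallback to a configurable unknown token. A minimal standalone sketch of the idea (illustrative only, independent of the actual types introduced by this diff):

use std::collections::HashMap;

// Sketch: a word-level model resolves each token to its vocab id,
// falling back to the unknown token's id when the token is absent.
fn lookup(vocab: &HashMap<String, u32>, token: &str, unk: &str) -> Option<u32> {
    vocab.get(token).or_else(|| vocab.get(unk)).copied()
}

fn main() {
    let mut vocab = HashMap::new();
    vocab.insert("<unk>".to_string(), 0);
    vocab.insert("hello".to_string(), 1);

    assert_eq!(lookup(&vocab, "hello", "<unk>"), Some(1));
    assert_eq!(lookup(&vocab, "world", "<unk>"), Some(0)); // unknown -> "<unk>"
}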
@@ -28,6 +28,7 @@ fn models(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<models::Model>()?;
     m.add_class::<models::BPE>()?;
     m.add_class::<models::WordPiece>()?;
+    m.add_class::<models::WordLevel>()?;
     Ok(())
 }
@@ -153,3 +153,40 @@ impl WordPiece {
         }
     }
 }
+
+#[pyclass]
+pub struct WordLevel {}
+
+#[pymethods]
+impl WordLevel {
+    #[staticmethod]
+    #[args(kwargs = "**")]
+    fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
+        let mut unk_token = String::from("<unk>");
+
+        if let Some(kwargs) = kwargs {
+            for (key, val) in kwargs {
+                let key: &str = key.extract()?;
+                match key {
+                    "unk_token" => unk_token = val.extract()?,
+                    _ => println!("Ignored unknown kwargs option {}", key),
+                }
+            }
+        }
+
+        match tk::models::wordlevel::WordLevel::from_files(
+            vocab,
+            unk_token,
+        ) {
+            Err(e) => {
+                println!("Errors: {:?}", e);
+                Err(exceptions::Exception::py_err(
+                    "Error while initializing WordLevel",
+                ))
+            }
+            Ok(model) => Ok(Model {
+                model: Container::Owned(Box::new(model)),
+            }),
+        }
+    }
+}
@@ -3,3 +3,4 @@ from .. import models
 Model = models.Model
 BPE = models.BPE
 WordPiece = models.WordPiece
+WordLevel = models.WordLevel
@@ -17,7 +17,7 @@ class Model:
         pass


-class BPE:
+class BPE(Model):
    """ BytePairEncoding model class """

    @staticmethod
@@ -62,7 +62,7 @@ class BPE:
         pass


-class WordPiece:
+class WordPiece(Model):
    """ WordPiece model class """

    @staticmethod
@@ -87,3 +87,22 @@ class WordPiece:
     def empty() -> Model:
         """ Instantiate an empty WordPiece Model. """
         pass
+
+
+class WordLevel(Model):
+    """
+    The simplest tokenizer model, mapping tokens from a vocab file to their corresponding ids.
+    """
+
+    @staticmethod
+    def from_files(vocab: str, unk_token: str) -> Model:
+        """ Instantiate a WordLevel Model from the given vocab file.
+
+        Args:
+            vocab: string:
+                Path to a vocabulary file.
+
+            unk_token: str:
+                The unknown token to be used by the model.
+        """
+        pass
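For reference, the vocab file this stub documents is a flat JSON object mapping each token to its numeric id (the exact format is enforced by the Rust loader further down). A minimal illustrative example of such a file:

{
    "<unk>": 0,
    "hello": 1,
    "world": 2
}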
@@ -1,4 +1,5 @@
 //! Popular tokenizer models.

 pub mod bpe;
+pub mod wordlevel;
 pub mod wordpiece;
tokenizers/src/models/wordlevel/mod.rs (new file, 183 lines)
@@ -0,0 +1,183 @@
+use crate::tokenizer::{Model, Result, Token};
+use serde_json::Value;
+use std::collections::HashMap;
+use std::fmt;
+use std::fs::File;
+use std::io::{BufReader, Read, Write};
+use std::path::{Path, PathBuf};
+
+#[derive(Debug)]
+pub enum Error {
+    MissingUnkToken,
+    BadVocabulary,
+}
+impl std::error::Error for Error {}
+
+impl fmt::Display for Error {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Error::MissingUnkToken => write!(
+                fmt,
+                "WordLevel error: Missing [UNK] token from the vocabulary"
+            ),
+            Error::BadVocabulary => write!(fmt, "Bad vocabulary json file"),
+        }
+    }
+}
+
+struct Config {
+    vocab: HashMap<String, u32>,
+    unk_token: String,
+}
+
+/// A `WordLevelBuilder` can be used to create a `WordLevel`
+/// model with a custom configuration.
+pub struct WordLevelBuilder {
+    config: Config,
+}
+
+impl Default for WordLevelBuilder {
+    fn default() -> Self {
+        Self {
+            config: Config {
+                vocab: HashMap::new(),
+                unk_token: String::from("<unk>"),
+            },
+        }
+    }
+}
+
+impl WordLevelBuilder {
+    /// Construct a new `WordLevelBuilder`.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the vocab (token -> ID) mapping.
+    pub fn vocab(mut self, vocab: HashMap<String, u32>) -> Self {
+        self.config.vocab = vocab;
+        self
+    }
+
+    /// Set the `UNK` token for the vocab.
+    pub fn unk_token(mut self, unk_token: String) -> Self {
+        self.config.unk_token = unk_token;
+        self
+    }
+
+    /// Constructs a `WordLevel` model that uses the `WordLevelBuilder`'s configuration.
+    pub fn build(self) -> WordLevel {
+        let vocab_r = self
+            .config
+            .vocab
+            .iter()
+            .map(|(key, val)| (*val, key.to_owned()))
+            .collect();
+        WordLevel {
+            vocab: self.config.vocab,
+            vocab_r,
+            unk_token: self.config.unk_token,
+        }
+    }
+}
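A hedged sketch of using the builder directly; `WordLevelBuilder::new()` is the public entry point, and the snippet assumes a crate-internal context with the `Model` trait in scope so that `token_to_id` and `id_to_token` (implemented further down) are callable:

// Sketch only: construct a WordLevel by hand through the builder.
use std::collections::HashMap;

let mut vocab = HashMap::new();
vocab.insert("<unk>".to_string(), 0);
vocab.insert("the".to_string(), 1);

let model = WordLevelBuilder::new()
    .vocab(vocab)
    .unk_token("<unk>".to_string())
    .build();

// build() also derives the reversed id -> token map (vocab_r).
assert_eq!(model.token_to_id("the"), Some(1));
assert_eq!(model.id_to_token(1), Some("the".to_string()));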
+
+pub struct WordLevel {
+    vocab: HashMap<String, u32>,
+    vocab_r: HashMap<u32, String>,
+    unk_token: String,
+}
+
+impl WordLevel {
+    fn builder() -> WordLevelBuilder {
+        WordLevelBuilder::new()
+    }
+
+    /// Initialize a WordLevel model from a vocab file.
+    pub fn from_files(vocab_path: &str, unk_token: String) -> Result<WordLevel> {
+        // Read vocab.json
+        let vocab_file = File::open(vocab_path)?;
+        let mut vocab_file = BufReader::new(vocab_file);
+        let mut buffer = String::new();
+        let mut vocab = HashMap::new();
+
+        vocab_file.read_to_string(&mut buffer)?;
+        let json: Value = serde_json::from_str(&buffer)?;
+
+        match json {
+            Value::Object(m) => {
+                for (token, id) in m {
+                    if let Value::Number(id) = id {
+                        let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
+                        vocab.insert(token, id);
+                    }
+                }
+            }
+            _ => return Err(Box::new(Error::BadVocabulary)),
+        };
+
+        Ok(Self::builder().vocab(vocab).unk_token(unk_token).build())
+    }
+}
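So the loader accepts only a flat JSON object of token -> id: non-numeric values are silently skipped, and a non-object root is rejected as BadVocabulary. A hedged round-trip sketch (the file path is illustrative; the fragment assumes a `Result`-returning context for `?` and the `Model` trait in scope for `token_to_id`):

// Sketch: write a tiny vocab.json and load it back through from_files.
let path = std::env::temp_dir().join("wordlevel-vocab.json");
std::fs::write(&path, r#"{"<unk>": 0, "hello": 1, "world": 2}"#)?;

let model = WordLevel::from_files(path.to_str().unwrap(), String::from("<unk>"))?;
assert_eq!(model.token_to_id("world"), Some(2));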
+
+impl Default for WordLevel {
+    fn default() -> Self {
+        Self {
+            vocab: HashMap::new(),
+            vocab_r: HashMap::new(),
+            unk_token: String::from("<unk>"),
+        }
+    }
+}
+
+impl Model for WordLevel {
+    fn tokenize(&self, tokens: Vec<(String, (usize, usize))>) -> Result<Vec<Token>> {
+        let mut output_tokens = vec![];
+
+        for (token, initial_offsets) in tokens {
+            let t = Token {
+                id: *self
+                    .vocab
+                    .get(&*token)
+                    .or_else(|| self.vocab.get(&*self.unk_token))
+                    .ok_or(Error::MissingUnkToken)?,
+                value: token,
+                offsets: initial_offsets,
+            };
+
+            output_tokens.push(t);
+        }
+
+        Ok(output_tokens)
+    }
+
+    fn token_to_id(&self, token: &str) -> Option<u32> {
+        self.vocab.get(token).copied()
+    }
+
+    fn id_to_token(&self, id: u32) -> Option<String> {
+        self.vocab_r.get(&id).cloned()
+    }
+
+    fn get_vocab_size(&self) -> usize {
+        self.vocab.keys().len()
+    }
+
+    fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
+        // Write vocab.txt
+        let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))]
+            .iter()
+            .collect();
+        let mut vocab_file = File::create(&vocab_path)?;
+        let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
+        vocab.sort_unstable_by_key(|k| *k.1);
+        vocab_file.write_all(
+            &vocab
+                .into_iter()
+                .map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
+                .flatten()
+                .collect::<Vec<_>>()[..],
+        )?;
+
+        Ok(vec![vocab_path])
+    }
+}
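To make the runtime behavior concrete: `tokenize` consumes pre-split `(token, offsets)` pairs and only resolves ids, keeping the original values and offsets, while `save` emits a plain `<name>-vocab.txt` with one token per line, sorted by id. A hedged, crate-internal sketch reusing the `model` from the snippets above (the `/tmp` destination is illustrative; `?` again assumes a `Result` context):

// Sketch: the unknown-token fallback in action.
let tokens = model.tokenize(vec![
    ("hello".to_string(), (0, 5)),
    ("unseen".to_string(), (6, 12)), // absent from the vocab
])?;
assert_eq!(tokens[0].id, 1);
assert_eq!(tokens[1].id, 0); // resolved to the id of "<unk>"
// If "<unk>" itself were missing from the vocab, tokenize would fail
// with Error::MissingUnkToken rather than guessing an id.

// Persist the vocabulary: writes "/tmp/wordlevel-vocab.txt".
let written = model.save(std::path::Path::new("/tmp"), "wordlevel")?;
assert_eq!(written.len(), 1);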