Introduce WordLevel model for TransformerXL (#125)

* Added lookup table model mapping string to id present in a vocab map.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* RustFmt

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Formatting.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Fix invalid void return on Rust side.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Python binding for LookupTable model

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Enable loading from Python's side.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Renamed LookupTable to WordLevel

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* RustFmt happy now.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* clippy happy now.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing mismatching names.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Addressing mismatching names (one missing).

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Funtowicz Morgan
2020-02-05 16:51:35 +00:00
committed by GitHub
parent 9770be5661
commit 8200112e9b
6 changed files with 244 additions and 2 deletions

View File

@@ -28,6 +28,7 @@ fn models(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<models::Model>()?;
m.add_class::<models::BPE>()?;
m.add_class::<models::WordPiece>()?;
m.add_class::<models::WordLevel>()?;
Ok(())
}

View File

@@ -153,3 +153,40 @@ impl WordPiece {
}
}
}
#[pyclass]
pub struct WordLevel {}
#[pymethods]
impl WordLevel {
#[staticmethod]
#[args(kwargs = "**")]
fn from_files(vocab: &str, kwargs: Option<&PyDict>) -> PyResult<Model> {
let mut unk_token = String::from("<unk>");
if let Some(kwargs) = kwargs {
for (key, val) in kwargs {
let key: &str = key.extract()?;
match key {
"unk_token" => unk_token = val.extract()?,
_ => println!("Ignored unknown kwargs option {}", key),
}
}
}
match tk::models::wordlevel::WordLevel::from_files(
vocab,
unk_token,
) {
Err(e) => {
println!("Errors: {:?}", e);
Err(exceptions::Exception::py_err(
"Error while initializing WordLevel",
))
}
Ok(model) => Ok(Model {
model: Container::Owned(Box::new(model)),
}),
}
}
}

View File

@@ -3,3 +3,4 @@ from .. import models
Model = models.Model
BPE = models.BPE
WordPiece = models.WordPiece
WordLevel = models.WordLevel

View File

@@ -17,7 +17,7 @@ class Model:
pass
class BPE(Model):
    """ BytePairEncoding model class """
    @staticmethod
@@ -62,7 +62,7 @@ class BPE:
pass
class WordPiece(Model):
    """ WordPiece model class """
    @staticmethod
@@ -87,3 +87,22 @@ class WordPiece:
def empty() -> Model:
    """ Instantiate an empty WordPiece Model. """
    pass
class WordLevel(Model):
"""
Most simple tokenizer model based on mapping token from a vocab file to their corresponding id.
"""
@staticmethod
def from_files(vocab: str, unk_token: str) -> Model:
""" Instantiate a WordLevel Model from the given vocab file.
Args:
vocab: string:
Path to a vocabulary file.
unk_token: str:
The unknown token to be used by the model.
"""
pass

View File

@@ -1,4 +1,5 @@
//! Popular tokenizer models.
pub mod bpe;
pub mod wordlevel;
pub mod wordpiece;

View File

@@ -0,0 +1,183 @@
use crate::tokenizer::{Model, Result, Token};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
use std::fs::File;
use std::io::{BufReader, Read, Write};
use std::path::{Path, PathBuf};
#[derive(Debug)]
pub enum Error {
MissingUnkToken,
BadVocabulary,
}
impl std::error::Error for Error {}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::MissingUnkToken => write!(
fmt,
"WordLevel error: Missing [UNK] token from the vocabulary"
),
Error::BadVocabulary => write!(fmt, "Bad vocabulary json file"),
}
}
}
struct Config {
vocab: HashMap<String, u32>,
unk_token: String,
}
/// A `WordLevelBuilder` can be used to create a `WordLevel`
/// model with a custom configuration.
pub struct WordLevelBuilder {
config: Config,
}
impl Default for WordLevelBuilder {
fn default() -> Self {
Self {
config: Config {
vocab: HashMap::new(),
unk_token: String::from("<unk>"),
},
}
}
}
impl WordLevelBuilder {
/// Construct a new `WordLevelBuilder`.
pub fn new() -> Self {
Self::default()
}
/// Set the vocab (token -> ID) mapping.
pub fn vocab(mut self, vocab: HashMap<String, u32>) -> Self {
self.config.vocab = vocab;
self
}
/// The the `UNK` token for the vocab.
pub fn unk_token(mut self, unk_token: String) -> Self {
self.config.unk_token = unk_token;
self
}
/// Contructs a `WordLevel` model that uses the `WordLevelBuilder`'s configuration.
pub fn build(self) -> WordLevel {
let vocab_r = self
.config
.vocab
.iter()
.map(|(key, val)| (*val, key.to_owned()))
.collect();
WordLevel {
vocab: self.config.vocab,
vocab_r,
unk_token: self.config.unk_token,
}
}
}
pub struct WordLevel {
vocab: HashMap<String, u32>,
vocab_r: HashMap<u32, String>,
unk_token: String,
}
impl WordLevel {
fn builder() -> WordLevelBuilder {
WordLevelBuilder::new()
}
/// Initialize a WordLevel model from vocab and merges file.
pub fn from_files(vocab_path: &str, unk_token: String) -> Result<WordLevel> {
// Read vocab.json
let vocab_file = File::open(vocab_path)?;
let mut vocab_file = BufReader::new(vocab_file);
let mut buffer = String::new();
let mut vocab = HashMap::new();
vocab_file.read_to_string(&mut buffer)?;
let json: Value = serde_json::from_str(&buffer)?;
match json {
Value::Object(m) => {
for (token, id) in m {
if let Value::Number(id) = id {
let id = id.as_u64().ok_or(Error::BadVocabulary)? as u32;
vocab.insert(token, id);
}
}
}
_ => return Err(Box::new(Error::BadVocabulary)),
};
Ok(Self::builder().vocab(vocab).unk_token(unk_token).build())
}
}
impl Default for WordLevel {
fn default() -> Self {
Self {
vocab: HashMap::new(),
vocab_r: HashMap::new(),
unk_token: String::from("<unk>"),
}
}
}
impl Model for WordLevel {
fn tokenize(&self, tokens: Vec<(String, (usize, usize))>) -> Result<Vec<Token>> {
let mut output_tokens = vec![];
for (token, initial_offsets) in tokens {
let t = Token {
id: *self
.vocab
.get(&*token)
.or_else(|| self.vocab.get(&*self.unk_token))
.ok_or(Error::MissingUnkToken)?,
value: token,
offsets: initial_offsets,
};
output_tokens.push(t);
}
Ok(output_tokens)
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab_r.get(&id).cloned()
}
fn get_vocab_size(&self) -> usize {
self.vocab.keys().len()
}
fn save(&self, folder: &Path, name: &str) -> Result<Vec<PathBuf>> {
// Write vocab.txt
let vocab_path: PathBuf = [folder, Path::new(&format!("{}-vocab.txt", name))]
.iter()
.collect();
let mut vocab_file = File::create(&vocab_path)?;
let mut vocab: Vec<(&String, &u32)> = self.vocab.iter().collect();
vocab.sort_unstable_by_key(|k| *k.1);
vocab_file.write_all(
&vocab
.into_iter()
.map(|(token, _)| format!("{}\n", token).as_bytes().to_owned())
.flatten()
.collect::<Vec<_>>()[..],
)?;
Ok(vec![vocab_path])
}
}