New version. Staticmethods need to return an `IntoPy<PyObject>`,
which is non-trivial for `PyClassInitializer`. Instead I added lower-level
staticmethods that return raw objects, and the `from_file(s)` methods
are implemented directly in Python.
Nicolas Patry
2020-09-22 18:10:56 +02:00
parent 98a30eead1
commit 60c1e25910
8 changed files with 72 additions and 31 deletions
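The resulting two-layer pattern, sketched from the Python side (an illustrative example with placeholder paths, not code from this diff):

    from tokenizers.models import BPE

    # Lower-level staticmethod, implemented in Rust: returns raw objects
    # (the vocab as a plain dict plus the merges), which pyo3 can convert
    # without an IntoPy<PyObject> impl for PyClassInitializer.
    vocab, merges = BPE.read_files("vocab.json", "merges.txt")

    # Convenience layer, implemented in Python: compose the two steps.
    bpe = BPE(vocab, merges)
    # equivalent to: bpe = BPE.from_files("vocab.json", "merges.txt")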

View File

@@ -80,6 +80,10 @@ class BertWordPieceTokenizer(BaseTokenizer):
         super().__init__(tokenizer, parameters)

+    def from_file(vocab: str, **kwargs):
+        vocab = WordPiece.read_file(vocab)
+        return BertWordPieceTokenizer(vocab, **kwargs)
+
     def train(
         self,
         files: Union[str, List[str]],
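A hypothetical call site for the new method (placeholder vocab path; `lowercase` is one of this tokenizer's existing kwargs). Note that `from_file` is not marked `@staticmethod` here, unlike the variants in the files below, so it only works when invoked on the class itself:

    from tokenizers import BertWordPieceTokenizer

    # WordPiece.read_file runs in Rust; the tokenizer itself is built in Python.
    tokenizer = BertWordPieceTokenizer.from_file("bert-vocab.txt", lowercase=True)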

View File

@@ -77,6 +77,11 @@ class ByteLevelBPETokenizer(BaseTokenizer):
         super().__init__(tokenizer, parameters)

+    @staticmethod
+    def from_files(vocab_filename: str, merges_filename: str, **kwargs):
+        vocab, merges = BPE.read_files(vocab_filename, merges_filename)
+        return ByteLevelBPETokenizer(vocab, merges, **kwargs)
+
     def train(
         self,
         files: Union[str, List[str]],
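A matching usage sketch (placeholder paths); the same change repeats verbatim for CharBPETokenizer and SentencePieceBPETokenizer in the next two files, so one example covers all three:

    from tokenizers import ByteLevelBPETokenizer

    # BPE.read_files runs in Rust; the constructor call is pure Python.
    tokenizer = ByteLevelBPETokenizer.from_files("vocab.json", "merges.txt")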

View File

@@ -94,6 +94,11 @@ class CharBPETokenizer(BaseTokenizer):
         super().__init__(tokenizer, parameters)

+    @staticmethod
+    def from_files(vocab_filename: str, merges_filename: str, **kwargs):
+        vocab, merges = BPE.read_files(vocab_filename, merges_filename)
+        return CharBPETokenizer(vocab, merges, **kwargs)
+
     def train(
         self,
         files: Union[str, List[str]],

View File

@@ -47,6 +47,11 @@ class SentencePieceBPETokenizer(BaseTokenizer):
         super().__init__(tokenizer, parameters)

+    @staticmethod
+    def from_files(vocab_filename: str, merges_filename: str, **kwargs):
+        vocab, merges = BPE.read_files(vocab_filename, merges_filename)
+        return SentencePieceBPETokenizer(vocab, merges, **kwargs)
+
     def train(
         self,
         files: Union[str, List[str]],

View File

@@ -62,8 +62,6 @@ class BPE(Model):
         fuse_unk: (`optional`) bool:
             Multiple unk tokens get fused into only 1
     """
-
-    @staticmethod
     def __init__(
         self,
         vocab: Optional[Union[str, Dict[str, int]]],
@@ -77,6 +75,15 @@ class BPE(Model):
     ):
         pass

+    @staticmethod
+    def read_files(vocab_filename: str, merges_filename: str) -> Tuple[Vocab, Merges]:
+        pass
+
+    @staticmethod
+    def from_files(vocab_filename: str, merges_filename: str, **kwargs) -> BPE:
+        vocab, merges = BPE.read_files(vocab_filename, merges_filename)
+        return BPE(vocab, merges, **kwargs)
+
 class WordPiece(Model):
     """ WordPiece model class
@@ -101,6 +108,15 @@ class WordPiece(Model):
     ):
         pass

+    @staticmethod
+    def read_file(vocab_filename: str) -> Vocab:
+        pass
+
+    @staticmethod
+    def from_files(vocab_filename: str, **kwargs) -> WordPiece:
+        vocab = WordPiece.read_file(vocab_filename)
+        return WordPiece(vocab, **kwargs)
+
 class WordLevel(Model):
     """
     Most simple tokenizer model based on mapping token from a vocab file to their corresponding id.
@@ -118,6 +134,15 @@ class WordLevel(Model):
     def __init__(self, vocab: Optional[Union[str, Dict[str, int]]], unk_token: Optional[str]):
         pass

+    @staticmethod
+    def read_file(vocab_filename: str) -> Vocab:
+        pass
+
+    @staticmethod
+    def from_files(vocab_filename: str, **kwargs) -> WordLevel:
+        vocab = WordLevel.read_file(vocab_filename)
+        return WordLevel(vocab, **kwargs)
+
 class Unigram(Model):
     """UnigramEncoding model class

View File

@@ -7,7 +7,7 @@ use pyo3::exceptions;
 use pyo3::prelude::*;
 use pyo3::types::*;
 use serde::{Deserialize, Serialize};
-use tk::models::bpe::{BpeBuilder, BPE};
+use tk::models::bpe::{BpeBuilder, Merges, Vocab, BPE};
 use tk::models::unigram::Unigram;
 use tk::models::wordlevel::WordLevel;
 use tk::models::wordpiece::{WordPiece, WordPieceBuilder};
@@ -217,16 +217,13 @@ impl PyBPE {
     }

     #[staticmethod]
-    #[args(kwargs = "**")]
-    fn from_files(
-        vocab_filename: String,
-        merges_filename: String,
-        kwargs: Option<&PyDict>,
-    ) -> PyResult<(Self, PyModel)> {
-        let mut builder = BPE::builder();
-        builder = builder.files(vocab_filename, merges_filename);
-        PyBPE::with_builder(builder, kwargs)
+    fn read_files(vocab_filename: &str, merges_filename: &str) -> PyResult<(Vocab, Merges)> {
+        BPE::read_files(vocab_filename, merges_filename).map_err(|e| {
+            exceptions::PyValueError::new_err(format!(
+                "Error while reading vocab&merges files: {}",
+                e
+            ))
+        })
     }
 }
@@ -292,10 +289,10 @@ impl PyWordPiece {
     }

     #[staticmethod]
-    fn from_file(vocab: String, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
-        let mut builder = WordPiece::builder();
-        builder = builder.files(vocab);
-        PyWordPiece::with_builder(builder, kwargs)
+    fn read_file(vocab_filename: &str) -> PyResult<Vocab> {
+        WordPiece::read_files(vocab_filename).map_err(|e| {
+            exceptions::PyValueError::new_err(format!("Error while reading WordPiece file: {}", e))
+        })
     }
 }
@@ -356,15 +353,10 @@ impl PyWordLevel {
     }

     #[staticmethod]
-    fn from_file(vocab_filename: &str, kwargs: Option<&PyDict>) -> PyResult<(Self, PyModel)> {
-        let unk_token = PyWordLevel::get_unk(kwargs)?;
-        let model = WordLevel::from_files(vocab_filename, unk_token).map_err(|e| {
-            exceptions::PyException::new_err(format!(
-                "Error while loading WordLevel from file: {}",
-                e
-            ))
-        })?;
-        Ok((PyWordLevel {}, PyModel::new(Arc::new(model.into()))))
+    fn read_file(vocab_filename: &str) -> PyResult<Vocab> {
+        WordLevel::read_files(vocab_filename).map_err(|e| {
+            exceptions::PyValueError::new_err(format!("Error while reading WordLevel file: {}", e))
+        })
     }
 }
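This is the heart of the fix: the old staticmethods returned `(Self, PyModel)`, which pyo3 cannot convert without an `IntoPy<PyObject>` impl on the class-initializer side, while `Vocab` and `Merges` are plain `HashMap`s that pyo3 converts out of the box. A quick check from the Python side (a sketch with placeholder paths):

    from tokenizers.models import BPE

    vocab, merges = BPE.read_files("vocab.json", "merges.txt")
    assert isinstance(vocab, dict)  # HashMap<String, u32> -> dict[str, int]
    # merges converts to a plain Python container as well (see the Merges alias).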

View File

@@ -12,9 +12,9 @@ use std::{
     path::{Path, PathBuf},
 };

-type Vocab = HashMap<String, u32>;
+pub type Vocab = HashMap<String, u32>;
 type VocabR = HashMap<u32, String>;
-type Merges = HashMap<Pair, (u32, u32)>;
+pub type Merges = HashMap<Pair, (u32, u32)>;

 struct Config {
     files: Option<(String, String)>,

View File

@@ -9,6 +9,8 @@ use std::path::{Path, PathBuf};
 mod serialization;

+type Vocab = HashMap<String, u32>;
+
 #[derive(Debug)]
 pub enum Error {
     MissingUnkToken,
@@ -105,9 +107,7 @@ impl WordLevel {
         WordLevelBuilder::new()
     }

-    /// Initialize a WordLevel model from vocab and merges file.
-    pub fn from_files(vocab_path: &str, unk_token: String) -> Result<WordLevel> {
-        // Read vocab.json
+    pub fn read_files(vocab_path: &str) -> Result<Vocab> {
         let vocab_file = File::open(vocab_path)?;
         let mut vocab_file = BufReader::new(vocab_file);
         let mut buffer = String::new();
@@ -127,7 +127,12 @@ impl WordLevel {
            }
            _ => return Err(Box::new(Error::BadVocabulary)),
        };

+        Ok(vocab)
+    }
+
+    /// Initialize a WordLevel model from a vocab file.
+    pub fn from_files(vocab_path: &str, unk_token: String) -> Result<WordLevel> {
+        let vocab = WordLevel::read_files(vocab_path)?;
         Ok(Self::builder().vocab(vocab).unk_token(unk_token).build())
     }
 }
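The core crate now mirrors the binding split: `read_files` only parses the vocab file, and `from_files` layers `unk_token` and the builder on top. Through the bindings, this is what enables the pure-Python flow (a sketch with a placeholder path):

    from tokenizers.models import WordLevel

    # read_file is the Rust-backed half exposed to Python; unk_token handling
    # stays in the ordinary constructor rather than kwargs parsed in Rust.
    vocab = WordLevel.read_file("vocab.json")
    model = WordLevel(vocab, unk_token="<unk>")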