Python - Expose encode method on Model

This commit is contained in:
Anthony MOI
2020-03-18 23:04:32 -04:00
parent 8de6ef5a37
commit a397a1da63
6 changed files with 139 additions and 34 deletions

View File

@@ -4,7 +4,7 @@ use crate::error::PyError;
use pyo3::prelude::*;
use pyo3::types::*;
use pyo3::{PyObjectProtocol, PySequenceProtocol};
use tk::tokenizer::PaddingDirection;
use tk::tokenizer::{Offsets, PaddingDirection};
#[pyclass(dict)]
#[repr(transparent)]

View File

@@ -1,5 +1,6 @@
extern crate tokenizers as tk;
use super::encoding::Encoding;
use super::error::ToPyResult;
use super::utils::Container;
use pyo3::exceptions;
@@ -35,6 +36,67 @@ impl Model {
.map(|path| path.to_string_lossy().into_owned())
.collect())
}
#[args(type_id = 0)]
fn encode(&self, sequence: &PyList, type_id: u32) -> PyResult<Encoding> {
if sequence.is_empty() {
return Ok(Encoding::new(tk::tokenizer::Encoding::default()));
}
enum Mode {
NoOffsets,
Offsets,
};
let mode = sequence
.iter()
.next()
.map(|item| {
if item.extract::<String>().is_ok() {
Ok(Mode::NoOffsets)
} else if item.extract::<(String, (usize, usize))>().is_ok() {
Ok(Mode::Offsets)
} else {
Err(exceptions::ValueError::py_err(
"Input must be a list[str] or list[(str, (int, int))]",
))
}
})
.unwrap()?;
let mut total_len = 0;
let sequence = sequence
.iter()
.enumerate()
.map(|(i, item)| match mode {
Mode::NoOffsets => item
.extract::<String>()
.map_err(|_| {
exceptions::ValueError::py_err(format!(
"Value at index {} should be a `str`",
i
))
})
.map(|s| {
let len = s.chars().count();
total_len += len;
(s, (total_len - len, total_len))
}),
Mode::Offsets => item.extract::<(String, (usize, usize))>().map_err(|_| {
exceptions::ValueError::py_err(format!(
"Value at index {} should be a `(str, (int, int))`",
i
))
}),
})
.collect::<Result<Vec<_>, PyErr>>()?;
ToPyResult(self.model.execute(|model| {
model
.tokenize(sequence)
.map(|tokens| Encoding::new(tk::tokenizer::Encoding::from_tokens(tokens, type_id)))
}))
.into()
}
}
/// BPE Model

View File

@@ -2,6 +2,8 @@ from .. import Tokenizer, Encoding
from typing import List, Union, Tuple, Optional
Offsets = Tuple[int, int]
class BaseTokenizer:
def __init__(self, tokenizer: Tokenizer, parameters=None):
@@ -136,6 +138,25 @@ class BaseTokenizer:
"""
return self._tokenizer.normalize(sequence)
def encode_tokenized(
    self, sequence: Union[List[str], List[Tuple[str, Offsets]]], type_id: int = 0
) -> Encoding:
    """ Encode the given tokenized sequence. Let us skip the Normalizer and PreTokenizer
    by providing already tokenized substrings.

    Args:
        sequence: Union[List[str], List[Tuple[str, Offsets]]]:
            Either a list of strings, or a list of tuples (string, offsets) where offset
            is a tuple (int, int)

        type_id: int:
            The type id of the given sequence

    Returns:
        An Encoding
    """
    # Bug fix: forward type_id to the model. It was previously dropped,
    # so callers passing a non-zero type_id silently got type_id=0.
    return self._tokenizer.model.encode(sequence, type_id)
def encode(
self, sequence: str, pair: Optional[str] = None, add_special_tokens: bool = True
) -> Encoding:

View File

@@ -1,4 +1,7 @@
from typing import List, Optional
from .. import Encoding
from typing import List, Optional, Union, Tuple
Offsets = Tuple[int, int]
class Model:
""" Base class for all models
@@ -15,6 +18,23 @@ class Model:
Any file with the same name that already exist in this folder will be overwritten.
"""
pass
def encode(
    self, sequence: Union[List[str], List[Tuple[str, Offsets]]], type_id: int = 0
) -> Encoding:
    """ Encode the given pre-tokenized input.

    Accepts either raw token strings or (token, offsets) pairs; when only
    strings are given, offsets are derived from the tokens themselves.

    Args:
        sequence: Union[List[str], List[Tuple[str, Offsets]]]:
            The input tokens, as plain strings or as (string, offsets)
            tuples where offsets is an (int, int) pair

        type_id: int:
            The type id assigned to the produced sequence

    Returns:
        An Encoding
    """
    pass
class BPE(Model):
""" BytePairEncoding model class """