Python - Trainers can get/set their attributes

This commit is contained in:
Anthony MOI
2020-11-24 17:46:58 -05:00
committed by Anthony MOI
parent 3eb7ef6d0a
commit a351d1c604
7 changed files with 679 additions and 51 deletions

View File

@ -1,6 +1,8 @@
use std::collections::HashMap;
use std::collections::{hash_map::DefaultHasher, HashMap};
use std::hash::{Hash, Hasher};
use numpy::PyArray1;
use pyo3::class::basic::CompareOp;
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;
@ -106,6 +108,19 @@ impl PyAddedToken {
}
}
impl From<tk::AddedToken> for PyAddedToken {
fn from(token: tk::AddedToken) -> Self {
Self {
content: token.content,
single_word: Some(token.single_word),
lstrip: Some(token.lstrip),
rstrip: Some(token.rstrip),
normalized: Some(token.normalized),
is_special_token: !token.normalized,
}
}
}
#[pymethods]
impl PyAddedToken {
#[new]
@ -205,6 +220,21 @@ impl PyObjectProtocol for PyAddedToken {
bool_to_python(token.normalized)
))
}
fn __richcmp__(&self, other: Py<PyAddedToken>, op: CompareOp) -> bool {
use CompareOp::*;
Python::with_gil(|py| match op {
Lt | Le | Gt | Ge => false,
Eq => self.get_token() == other.borrow(py).get_token(),
Ne => self.get_token() != other.borrow(py).get_token(),
})
}
fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.get_token().hash(&mut hasher);
hasher.finish()
}
}
struct TextInputSequence<'s>(tk::InputSequence<'s>);