Fixing missing direction in TruncationParams. (#868)

This commit is contained in:
Nicolas Patry
2022-01-04 14:21:46 +01:00
committed by GitHub
parent 7069988ffe
commit 4122a33f09
2 changed files with 23 additions and 14 deletions

View File

@ -1,12 +0,0 @@
from tokenizers import ByteLevelBPETokenizer
from tokenizers import pre_tokenizers, models, Tokenizer, trainers
tokenizer = Tokenizer(models.Unigram())
tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
trainer = trainers.UnigramTrainer(
vocab_size=400000000,
show_progress=True,
special_tokens=["<s>", "<pad>", "</s>", "<unk>", "mask"]
)
tokenizer.train(["data/big.txt"], trainer)

View File

@ -8,6 +8,11 @@ pub enum TruncationDirection {
Left,
Right,
}
impl Default for TruncationDirection {
fn default() -> Self {
TruncationDirection::Right
}
}
impl std::convert::AsRef<str> for TruncationDirection {
fn as_ref(&self) -> &str {
@ -20,6 +25,7 @@ impl std::convert::AsRef<str> for TruncationDirection {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TruncationParams {
#[serde(default)]
pub direction: TruncationDirection,
pub max_length: usize,
pub strategy: TruncationStrategy,
@ -30,9 +36,9 @@ impl Default for TruncationParams {
fn default() -> Self {
Self {
max_length: 512,
strategy: TruncationStrategy::LongestFirst,
strategy: TruncationStrategy::default(),
stride: 0,
direction: TruncationDirection::Right,
direction: TruncationDirection::default(),
}
}
}
@ -68,6 +74,12 @@ pub enum TruncationStrategy {
OnlySecond,
}
impl Default for TruncationStrategy {
fn default() -> Self {
TruncationStrategy::LongestFirst
}
}
impl std::convert::AsRef<str> for TruncationStrategy {
fn as_ref(&self) -> &str {
match self {
@ -325,4 +337,13 @@ mod tests {
truncate_and_assert(get_medium(), get_medium(), &params, 0, 0);
truncate_and_assert(get_long(), get_long(), &params, 0, 0);
}
#[test]
fn test_deserialize_defaults() {
let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#;
let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap();
assert_eq!(params.direction, TruncationDirection::Right);
}
}