mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Fixing missing direction
in TruncationParams. (#868)
This commit is contained in:
@ -1,12 +0,0 @@
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
from tokenizers import pre_tokenizers, models, Tokenizer, trainers
|
||||
|
||||
tokenizer = Tokenizer(models.Unigram())
|
||||
tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
|
||||
trainer = trainers.UnigramTrainer(
|
||||
vocab_size=400000000,
|
||||
show_progress=True,
|
||||
special_tokens=["<s>", "<pad>", "</s>", "<unk>", "mask"]
|
||||
)
|
||||
tokenizer.train(["data/big.txt"], trainer)
|
||||
|
@ -8,6 +8,11 @@ pub enum TruncationDirection {
|
||||
Left,
|
||||
Right,
|
||||
}
|
||||
impl Default for TruncationDirection {
|
||||
fn default() -> Self {
|
||||
TruncationDirection::Right
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::AsRef<str> for TruncationDirection {
|
||||
fn as_ref(&self) -> &str {
|
||||
@ -20,6 +25,7 @@ impl std::convert::AsRef<str> for TruncationDirection {
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TruncationParams {
|
||||
#[serde(default)]
|
||||
pub direction: TruncationDirection,
|
||||
pub max_length: usize,
|
||||
pub strategy: TruncationStrategy,
|
||||
@ -30,9 +36,9 @@ impl Default for TruncationParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_length: 512,
|
||||
strategy: TruncationStrategy::LongestFirst,
|
||||
strategy: TruncationStrategy::default(),
|
||||
stride: 0,
|
||||
direction: TruncationDirection::Right,
|
||||
direction: TruncationDirection::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -68,6 +74,12 @@ pub enum TruncationStrategy {
|
||||
OnlySecond,
|
||||
}
|
||||
|
||||
impl Default for TruncationStrategy {
|
||||
fn default() -> Self {
|
||||
TruncationStrategy::LongestFirst
|
||||
}
|
||||
}
|
||||
|
||||
impl std::convert::AsRef<str> for TruncationStrategy {
|
||||
fn as_ref(&self) -> &str {
|
||||
match self {
|
||||
@ -325,4 +337,13 @@ mod tests {
|
||||
truncate_and_assert(get_medium(), get_medium(), ¶ms, 0, 0);
|
||||
truncate_and_assert(get_long(), get_long(), ¶ms, 0, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_defaults() {
|
||||
let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#;
|
||||
|
||||
let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap();
|
||||
|
||||
assert_eq!(params.direction, TruncationDirection::Right);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user