mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Fixing missing direction
in TruncationParams. (#868)
This commit is contained in:
@ -1,12 +0,0 @@
|
|||||||
from tokenizers import ByteLevelBPETokenizer
|
|
||||||
from tokenizers import pre_tokenizers, models, Tokenizer, trainers
|
|
||||||
|
|
||||||
tokenizer = Tokenizer(models.Unigram())
|
|
||||||
tokenizer.pre_tokenizer = pre_tokenizers.WhitespaceSplit()
|
|
||||||
trainer = trainers.UnigramTrainer(
|
|
||||||
vocab_size=400000000,
|
|
||||||
show_progress=True,
|
|
||||||
special_tokens=["<s>", "<pad>", "</s>", "<unk>", "mask"]
|
|
||||||
)
|
|
||||||
tokenizer.train(["data/big.txt"], trainer)
|
|
||||||
|
|
@ -8,6 +8,11 @@ pub enum TruncationDirection {
|
|||||||
Left,
|
Left,
|
||||||
Right,
|
Right,
|
||||||
}
|
}
|
||||||
|
impl Default for TruncationDirection {
|
||||||
|
fn default() -> Self {
|
||||||
|
TruncationDirection::Right
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl std::convert::AsRef<str> for TruncationDirection {
|
impl std::convert::AsRef<str> for TruncationDirection {
|
||||||
fn as_ref(&self) -> &str {
|
fn as_ref(&self) -> &str {
|
||||||
@ -20,6 +25,7 @@ impl std::convert::AsRef<str> for TruncationDirection {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct TruncationParams {
|
pub struct TruncationParams {
|
||||||
|
#[serde(default)]
|
||||||
pub direction: TruncationDirection,
|
pub direction: TruncationDirection,
|
||||||
pub max_length: usize,
|
pub max_length: usize,
|
||||||
pub strategy: TruncationStrategy,
|
pub strategy: TruncationStrategy,
|
||||||
@ -30,9 +36,9 @@ impl Default for TruncationParams {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_length: 512,
|
max_length: 512,
|
||||||
strategy: TruncationStrategy::LongestFirst,
|
strategy: TruncationStrategy::default(),
|
||||||
stride: 0,
|
stride: 0,
|
||||||
direction: TruncationDirection::Right,
|
direction: TruncationDirection::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -68,6 +74,12 @@ pub enum TruncationStrategy {
|
|||||||
OnlySecond,
|
OnlySecond,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Default for TruncationStrategy {
|
||||||
|
fn default() -> Self {
|
||||||
|
TruncationStrategy::LongestFirst
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl std::convert::AsRef<str> for TruncationStrategy {
|
impl std::convert::AsRef<str> for TruncationStrategy {
|
||||||
fn as_ref(&self) -> &str {
|
fn as_ref(&self) -> &str {
|
||||||
match self {
|
match self {
|
||||||
@ -325,4 +337,13 @@ mod tests {
|
|||||||
truncate_and_assert(get_medium(), get_medium(), &params, 0, 0);
|
truncate_and_assert(get_medium(), get_medium(), &params, 0, 0);
|
||||||
truncate_and_assert(get_long(), get_long(), &params, 0, 0);
|
truncate_and_assert(get_long(), get_long(), &params, 0, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_deserialize_defaults() {
|
||||||
|
let old_truncation_params = r#"{"max_length":256,"strategy":"LongestFirst","stride":0}"#;
|
||||||
|
|
||||||
|
let params: TruncationParams = serde_json::from_str(old_truncation_params).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(params.direction, TruncationDirection::Right);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user