mirror of
https://github.com/mii443/tokenizers.git
synced 2025-08-22 16:25:30 +00:00
Use thiserror
crate for Errors (#951)
* Use `thiserror` crate for Errors * cargo fmt * `#[source]` redundant when `#[from]` is present
This commit is contained in:
@ -59,6 +59,7 @@ cached-path = { version = "0.5", optional = true }
|
||||
aho-corasick = "0.7"
|
||||
paste = "1.0.6"
|
||||
macro_rules_attribute = "0.0.2"
|
||||
thiserror = "1.0.30"
|
||||
|
||||
[features]
|
||||
default = ["progressbar", "http"]
|
||||
|
@ -1,5 +1,5 @@
|
||||
//! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
|
||||
use std::{convert::From, io, iter, mem};
|
||||
use std::{iter, mem};
|
||||
|
||||
mod model;
|
||||
mod serialization;
|
||||
@ -9,65 +9,32 @@ mod word;
|
||||
type Pair = (u32, u32);
|
||||
|
||||
/// Errors that can be encountered while using or constructing a `BPE` model.
|
||||
#[derive(Debug)]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
/// An error encountered while reading files mainly.
|
||||
Io(std::io::Error),
|
||||
#[error("IoError: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
/// An error forwarded from Serde, while parsing JSON
|
||||
JsonError(serde_json::Error),
|
||||
#[error("JsonError: {0}")]
|
||||
JsonError(#[from] serde_json::Error),
|
||||
/// When the vocab.json file is in the wrong format
|
||||
#[error("Bad vocabulary json file")]
|
||||
BadVocabulary,
|
||||
/// When the merges.txt file is in the wrong format. This error holds the line
|
||||
/// number of the line that caused the error.
|
||||
#[error("Merges text file invalid at line {0}")]
|
||||
BadMerges(usize),
|
||||
/// If a token found in merges, is not in the vocab
|
||||
#[error("Token `{0}` out of vocabulary")]
|
||||
MergeTokenOutOfVocabulary(String),
|
||||
/// If the provided unk token is out of vocabulary
|
||||
#[error("Unk token `{0}` not found in the vocabulary")]
|
||||
UnkTokenOutOfVocabulary(String),
|
||||
/// Dropout not between 0 and 1.
|
||||
#[error("Dropout should be between 0 and 1")]
|
||||
InvalidDropout,
|
||||
}
|
||||
|
||||
impl From<io::Error> for Error {
|
||||
fn from(error: io::Error) -> Self {
|
||||
Self::Io(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(error: serde_json::Error) -> Self {
|
||||
Self::JsonError(error)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Io(e) => write!(f, "IoError: {}", e),
|
||||
Self::JsonError(e) => write!(f, "JsonError: {}", e),
|
||||
Self::BadVocabulary => write!(f, "Bad vocabulary json file"),
|
||||
Self::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
|
||||
Self::MergeTokenOutOfVocabulary(token) => {
|
||||
write!(f, "Token `{}` out of vocabulary", token)
|
||||
}
|
||||
Self::UnkTokenOutOfVocabulary(token) => {
|
||||
write!(f, "Unk token `{}` not found in the vocabulary", token)
|
||||
}
|
||||
Self::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
match self {
|
||||
Self::Io(e) => Some(e),
|
||||
Self::JsonError(e) => Some(e),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides access to the `FirstLastIterator` to any Iterator
|
||||
pub(crate) trait WithFirstLastIterator: Iterator + Sized {
|
||||
fn with_first_and_last(self) -> FirstLastIterator<Self>;
|
||||
|
@ -65,31 +65,16 @@ impl std::fmt::Debug for Unigram {
|
||||
|
||||
static K_UNK_PENALTY: f64 = 10.0;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum UnigramError {
|
||||
#[error("The vocabulary is empty but at least <unk> is needed")]
|
||||
EmptyVocabulary,
|
||||
#[error("The `unk_id` is larger than vocabulary size")]
|
||||
UnkIdNotInVocabulary,
|
||||
#[error("Encountered an unknown token but `unk_id` is missing")]
|
||||
MissingUnkId,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for UnigramError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::EmptyVocabulary => {
|
||||
write!(f, "The vocabulary is empty but at least <unk> is needed")
|
||||
}
|
||||
Self::UnkIdNotInVocabulary => {
|
||||
write!(f, "The `unk_id` is larger than vocabulary size")
|
||||
}
|
||||
Self::MissingUnkId => {
|
||||
write!(f, "Encountered an unknown token but `unk_id` is missing")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for UnigramError {}
|
||||
|
||||
impl Default for Unigram {
|
||||
fn default() -> Self {
|
||||
let vocab = vec![("<unk>".to_string(), 0.0)];
|
||||
|
@ -2,7 +2,6 @@ use super::OrderedVocabIter;
|
||||
use crate::tokenizer::{Model, Result, Token};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
@ -15,24 +14,13 @@ pub use trainer::*;
|
||||
|
||||
type Vocab = HashMap<String, u32>;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("WordLevel error: Missing [UNK] token from the vocabulary")]
|
||||
MissingUnkToken,
|
||||
#[error("Bad vocabulary json file")]
|
||||
BadVocabulary,
|
||||
}
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::MissingUnkToken => write!(
|
||||
fmt,
|
||||
"WordLevel error: Missing [UNK] token from the vocabulary"
|
||||
),
|
||||
Self::BadVocabulary => write!(fmt, "Bad vocabulary json file"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct Config {
|
||||
files: Option<String>,
|
||||
|
@ -6,7 +6,6 @@ use crate::tokenizer::{Model, Result, Token};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
fmt,
|
||||
fs::File,
|
||||
io::prelude::*,
|
||||
io::{BufRead, BufReader},
|
||||
@ -17,22 +16,11 @@ mod serialization;
|
||||
mod trainer;
|
||||
pub use trainer::*;
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("WordPiece error: Missing [UNK] token from the vocabulary")]
|
||||
MissingUnkToken,
|
||||
}
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
|
||||
match self {
|
||||
Self::MissingUnkToken => write!(
|
||||
fmt,
|
||||
"WordPiece error: Missing [UNK] token from the vocabulary"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Vocab = HashMap<String, u32>;
|
||||
type VocabR = HashMap<u32, String>;
|
||||
|
@ -11,7 +11,6 @@
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fmt,
|
||||
fs::{read_to_string, File},
|
||||
io::prelude::*,
|
||||
io::BufReader,
|
||||
@ -240,17 +239,10 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
#[error("{0}")]
|
||||
pub struct BuilderError(String);
|
||||
|
||||
impl std::error::Error for BuilderError {}
|
||||
|
||||
impl fmt::Display for BuilderError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for Tokenizer structs.
|
||||
///
|
||||
/// `build()` fails if the `model` is missing.
|
||||
|
@ -43,30 +43,16 @@ impl Default for TruncationParams {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum TruncationError {
|
||||
/// We are supposed to truncate the pair sequence, but it has not been provided.
|
||||
#[error("Truncation error: Second sequence not provided")]
|
||||
SecondSequenceNotProvided,
|
||||
/// We cannot truncate the target sequence enough to respect the provided max length.
|
||||
#[error("Truncation error: Sequence to truncate too short to respect the provided max_length")]
|
||||
SequenceTooShort,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TruncationError {
|
||||
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
use TruncationError::*;
|
||||
match self {
|
||||
SecondSequenceNotProvided => {
|
||||
write!(fmt, "Truncation error: Second sequence not provided")
|
||||
}
|
||||
SequenceTooShort => write!(
|
||||
fmt,
|
||||
"Truncation error: Sequence to truncate too short to respect the provided max_length"
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl std::error::Error for TruncationError {}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
|
||||
pub enum TruncationStrategy {
|
||||
LongestFirst,
|
||||
|
Reference in New Issue
Block a user