Use thiserror crate for Errors (#951)

* Use `thiserror` crate for Errors

* cargo fmt

* `#[source]` redundant when `#[from]` is present
This commit is contained in:
Mishig Davaadorj
2022-03-17 09:38:21 +01:00
committed by GitHub
parent 4b6055d4fb
commit 1f1f86dd32
7 changed files with 26 additions and 119 deletions

View File

@ -59,6 +59,7 @@ cached-path = { version = "0.5", optional = true }
aho-corasick = "0.7"
paste = "1.0.6"
macro_rules_attribute = "0.0.2"
thiserror = "1.0.30"
[features]
default = ["progressbar", "http"]

View File

@ -1,5 +1,5 @@
//! [Byte Pair Encoding](https://www.aclweb.org/anthology/P16-1162/) model.
use std::{convert::From, io, iter, mem};
use std::{iter, mem};
mod model;
mod serialization;
@ -9,65 +9,32 @@ mod word;
type Pair = (u32, u32);
/// Errors that can be encountered while using or constructing a `BPE` model.
#[derive(Debug)]
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// An error encountered while reading files mainly.
Io(std::io::Error),
#[error("IoError: {0}")]
Io(#[from] std::io::Error),
/// An error forwarded from Serde, while parsing JSON
JsonError(serde_json::Error),
#[error("JsonError: {0}")]
JsonError(#[from] serde_json::Error),
/// When the vocab.json file is in the wrong format
#[error("Bad vocabulary json file")]
BadVocabulary,
/// When the merges.txt file is in the wrong format. This error holds the line
/// number of the line that caused the error.
#[error("Merges text file invalid at line {0}")]
BadMerges(usize),
/// If a token found in merges, is not in the vocab
#[error("Token `{0}` out of vocabulary")]
MergeTokenOutOfVocabulary(String),
/// If the provided unk token is out of vocabulary
#[error("Unk token `{0}` not found in the vocabulary")]
UnkTokenOutOfVocabulary(String),
/// Dropout not between 0 and 1.
#[error("Dropout should be between 0 and 1")]
InvalidDropout,
}
impl From<io::Error> for Error {
fn from(error: io::Error) -> Self {
Self::Io(error)
}
}
impl From<serde_json::Error> for Error {
fn from(error: serde_json::Error) -> Self {
Self::JsonError(error)
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Io(e) => write!(f, "IoError: {}", e),
Self::JsonError(e) => write!(f, "JsonError: {}", e),
Self::BadVocabulary => write!(f, "Bad vocabulary json file"),
Self::BadMerges(line) => write!(f, "Merges text file invalid at line {}", line),
Self::MergeTokenOutOfVocabulary(token) => {
write!(f, "Token `{}` out of vocabulary", token)
}
Self::UnkTokenOutOfVocabulary(token) => {
write!(f, "Unk token `{}` not found in the vocabulary", token)
}
Self::InvalidDropout => write!(f, "Dropout should be between 0 and 1"),
}
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(e) => Some(e),
Self::JsonError(e) => Some(e),
_ => None,
}
}
}
/// Provides access to the `FirstLastIterator` to any Iterator
pub(crate) trait WithFirstLastIterator: Iterator + Sized {
fn with_first_and_last(self) -> FirstLastIterator<Self>;

View File

@ -65,31 +65,16 @@ impl std::fmt::Debug for Unigram {
static K_UNK_PENALTY: f64 = 10.0;
#[derive(Debug)]
#[derive(thiserror::Error, Debug)]
pub enum UnigramError {
#[error("The vocabulary is empty but at least <unk> is needed")]
EmptyVocabulary,
#[error("The `unk_id` is larger than vocabulary size")]
UnkIdNotInVocabulary,
#[error("Encountered an unknown token but `unk_id` is missing")]
MissingUnkId,
}
impl std::fmt::Display for UnigramError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::EmptyVocabulary => {
write!(f, "The vocabulary is empty but at least <unk> is needed")
}
Self::UnkIdNotInVocabulary => {
write!(f, "The `unk_id` is larger than vocabulary size")
}
Self::MissingUnkId => {
write!(f, "Encountered an unknown token but `unk_id` is missing")
}
}
}
}
impl std::error::Error for UnigramError {}
impl Default for Unigram {
fn default() -> Self {
let vocab = vec![("<unk>".to_string(), 0.0)];

View File

@ -2,7 +2,6 @@ use super::OrderedVocabIter;
use crate::tokenizer::{Model, Result, Token};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
use std::fs::File;
use std::io::{BufReader, Read, Write};
use std::path::{Path, PathBuf};
@ -15,24 +14,13 @@ pub use trainer::*;
type Vocab = HashMap<String, u32>;
#[derive(Debug)]
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("WordLevel error: Missing [UNK] token from the vocabulary")]
MissingUnkToken,
#[error("Bad vocabulary json file")]
BadVocabulary,
}
impl std::error::Error for Error {}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::MissingUnkToken => write!(
fmt,
"WordLevel error: Missing [UNK] token from the vocabulary"
),
Self::BadVocabulary => write!(fmt, "Bad vocabulary json file"),
}
}
}
struct Config {
files: Option<String>,

View File

@ -6,7 +6,6 @@ use crate::tokenizer::{Model, Result, Token};
use std::{
borrow::Cow,
collections::HashMap,
fmt,
fs::File,
io::prelude::*,
io::{BufRead, BufReader},
@ -17,22 +16,11 @@ mod serialization;
mod trainer;
pub use trainer::*;
#[derive(Debug)]
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("WordPiece error: Missing [UNK] token from the vocabulary")]
MissingUnkToken,
}
impl std::error::Error for Error {}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match self {
Self::MissingUnkToken => write!(
fmt,
"WordPiece error: Missing [UNK] token from the vocabulary"
),
}
}
}
type Vocab = HashMap<String, u32>;
type VocabR = HashMap<u32, String>;

View File

@ -11,7 +11,6 @@
use std::{
collections::HashMap,
fmt,
fs::{read_to_string, File},
io::prelude::*,
io::BufReader,
@ -240,17 +239,10 @@ where
}
}
#[derive(Debug)]
#[derive(thiserror::Error, Debug)]
#[error("{0}")]
pub struct BuilderError(String);
impl std::error::Error for BuilderError {}
impl fmt::Display for BuilderError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Builder for Tokenizer structs.
///
/// `build()` fails if the `model` is missing.

View File

@ -43,30 +43,16 @@ impl Default for TruncationParams {
}
}
#[derive(Debug)]
#[derive(thiserror::Error, Debug)]
pub enum TruncationError {
/// We are supposed to truncate the pair sequence, but it has not been provided.
#[error("Truncation error: Second sequence not provided")]
SecondSequenceNotProvided,
/// We cannot truncate the target sequence enough to respect the provided max length.
#[error("Truncation error: Sequence to truncate too short to respect the provided max_length")]
SequenceTooShort,
}
impl std::fmt::Display for TruncationError {
fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
use TruncationError::*;
match self {
SecondSequenceNotProvided => {
write!(fmt, "Truncation error: Second sequence not provided")
}
SequenceTooShort => write!(
fmt,
"Truncation error: Sequence to truncate too short to respect the provided max_length"
),
}
}
}
impl std::error::Error for TruncationError {}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum TruncationStrategy {
LongestFirst,