Changed the Rust API for merges, which is now Vec<(String, String)>

This commit is contained in:
Nicolas Patry
2020-09-23 19:09:46 +02:00
parent 36832bfa12
commit 95cc8c47ad
13 changed files with 94 additions and 84 deletions

View File

@@ -49,7 +49,7 @@ export namespace BPE {
* @param options BPE model options
* @param __callback Callback called when model is loaded
*/
export function fromFiles(
export function fromFile(
vocab: string,
merges: string,
options: BPEOptions,
@@ -62,7 +62,7 @@ export namespace BPE {
* @param merges Path to a merge file
* @param __callback Callback called when model is loaded
*/
export function fromFiles(
export function fromFile(
vocab: string,
merges: string,
__callback: (err: Error, encoding: Model) => void
@@ -100,7 +100,7 @@ export namespace WordPiece {
* @param options WordPiece model options
* @param __callback Callback called when model is loaded
*/
export function fromFiles(
export function fromFile(
vocab: string,
options: WordPieceOptions,
__callback: (err: Error, encoding: Model) => void
@@ -111,7 +111,7 @@ export namespace WordPiece {
* @param vocab Path to a vocabulary file
* @param __callback Callback called when model is loaded
*/
export function fromFiles(
export function fromFile(
vocab: string,
__callback: (err: Error, encoding: Model) => void
): void;

View File

@@ -3,12 +3,12 @@ const native = require("./native");
module.exports = {
BPE: {
init: native.models_BPE_init,
fromFiles: native.models_BPE_from_files,
fromFile: native.models_BPE_from_file,
empty: native.models_BPE_empty,
},
WordPiece: {
init: native.models_WordPiece_init,
fromFiles: native.models_WordPiece_from_files,
fromFile: native.models_WordPiece_from_file,
empty: native.models_WordPiece_empty,
},
};

View File

@@ -6,25 +6,25 @@ import { BPE, WordPiece } from "./models";
const MOCKS_DIR = __dirname + "/__mocks__";
describe("WordPiece", () => {
describe("fromFiles", () => {
describe("fromFile", () => {
it("throws if called with only one argument", () => {
expect(() => (WordPiece as any).fromFiles("test")).toThrow("not enough arguments");
expect(() => (WordPiece as any).fromFile("test")).toThrow("not enough arguments");
});
it("throws if called with 2 arguments without a callback as third argument", () => {
expect(() => (WordPiece as any).fromFiles("test", {})).toThrow(
expect(() => (WordPiece as any).fromFile("test", {})).toThrow(
"not enough arguments"
);
});
describe("when called with 2 correct arguments", () => {
it("returns `undefined` ", () => {
expect(WordPiece.fromFiles(`${MOCKS_DIR}/vocab.txt`, () => {})).toBeUndefined();
expect(WordPiece.fromFile(`${MOCKS_DIR}/vocab.txt`, () => {})).toBeUndefined();
});
it("has its callback called with the loaded model", () => {
return new Promise((done) => {
WordPiece.fromFiles(`${MOCKS_DIR}/vocab.txt`, (err, model) => {
WordPiece.fromFile(`${MOCKS_DIR}/vocab.txt`, (err, model) => {
expect(model).toBeDefined();
done();
});
@@ -35,13 +35,13 @@ describe("WordPiece", () => {
describe("when called with 3 correct arguments", () => {
it("returns `undefined`", () => {
expect(
WordPiece.fromFiles(`${MOCKS_DIR}/vocab.txt`, {}, () => {})
WordPiece.fromFile(`${MOCKS_DIR}/vocab.txt`, {}, () => {})
).toBeUndefined();
});
it("has its callback called with the loaded model", () => {
return new Promise((done) => {
WordPiece.fromFiles(`${MOCKS_DIR}/vocab.txt`, {}, (err, model) => {
WordPiece.fromFile(`${MOCKS_DIR}/vocab.txt`, {}, (err, model) => {
expect(model).toBeDefined();
done();
});
@@ -52,13 +52,13 @@ describe("WordPiece", () => {
});
describe("BPE", () => {
describe("fromFiles", () => {
describe("fromFile", () => {
it("throws if called with only two arguments", () => {
expect(() => (BPE as any).fromFiles("test", "bis")).toThrow("not enough arguments");
expect(() => (BPE as any).fromFile("test", "bis")).toThrow("not enough arguments");
});
it("throws if called with 3 arguments without a callback as last argument", () => {
expect(() => (BPE as any).fromFiles("test", "bis", {})).toThrow(
expect(() => (BPE as any).fromFile("test", "bis", {})).toThrow(
"not enough arguments"
);
});
@@ -67,13 +67,13 @@ describe("BPE", () => {
describe("when called with 3 correct arguments", () => {
it("returns `undefined`", () => {
expect(
BPE.fromFiles(`${MOCKS_DIR}/vocab.json`, `${MOCKS_DIR}/merges.txt`, () => {})
BPE.fromFile(`${MOCKS_DIR}/vocab.json`, `${MOCKS_DIR}/merges.txt`, () => {})
).toBeUndefined();
});
it("has its callback called with the loaded model", () => {
return new Promise((done) => {
BPE.fromFiles(
BPE.fromFile(
`${MOCKS_DIR}/vocab.json`,
`${MOCKS_DIR}/merges.txt`,
(err, model) => {
@@ -88,13 +88,13 @@ describe("BPE", () => {
describe("when called with 4 correct arguments", () => {
it("returns `undefined`", () => {
expect(
BPE.fromFiles(`${MOCKS_DIR}/vocab.json`, `${MOCKS_DIR}/merges.txt`, {}, () => {})
BPE.fromFile(`${MOCKS_DIR}/vocab.json`, `${MOCKS_DIR}/merges.txt`, {}, () => {})
).toBeUndefined();
});
it("has its callback called with the loaded model", () => {
return new Promise((done) => {
BPE.fromFiles(
BPE.fromFile(
`${MOCKS_DIR}/vocab.json`,
`${MOCKS_DIR}/merges.txt`,
{},
@@ -108,16 +108,13 @@ describe("BPE", () => {
});
describe("When initialized from memory", () => {
it("returns `undefined`", () => {
const merges = new Map();
merges.set([0, 1], [0, 2]);
expect((BPE as any).init({ a: 0, b: 1, ab: 2 }, merges, () => {})).toBeUndefined();
expect(
(BPE as any).init({ a: 0, b: 1, ab: 2 }, [["a", "b"]], () => {})
).toBeUndefined();
});
it("has its callback called with the loaded model", () => {
return new Promise((done) => {
const merges = new Map();
merges.set([0, 1], [0, 2]);
(BPE as any).init({ a: 0, b: 1, ab: 2 }, merges, (err: any, model: any) => {
(BPE as any).init({ a: 0, b: 1, ab: 2 }, [["a", "b"]], (err: any, model: any) => {
expect(model).toBeDefined();
done();
});

View File

@@ -22,7 +22,7 @@ describe("Can modify pretokenizers on the fly", () => {
let tokenizer: Tokenizer;
beforeAll(async () => {
const model = await promisify<string, WordPieceOptions, Model>(WordPiece.fromFiles)(
const model = await promisify<string, WordPieceOptions, Model>(WordPiece.fromFile)(
`${MOCKS_DIR}/vocab.txt`,
{
continuingSubwordPrefix: "##",
@@ -60,7 +60,7 @@ describe("RawEncoding", () => {
) => Promise<RawEncoding>;
beforeAll(async () => {
const model = await promisify<string, WordPieceOptions, Model>(WordPiece.fromFiles)(
const model = await promisify<string, WordPieceOptions, Model>(WordPiece.fromFile)(
`${MOCKS_DIR}/vocab.txt`,
{
continuingSubwordPrefix: "##",

View File

@@ -132,8 +132,8 @@ export class BertWordPieceTokenizer extends BaseTokenizer<BertTokenizerConfig> {
let model: Model;
if (opts.vocabFile) {
const fromFiles = promisify<string, WordPieceOptions, Model>(WordPiece.fromFiles);
model = await fromFiles(opts.vocabFile, {
const fromFile = promisify<string, WordPieceOptions, Model>(WordPiece.fromFile);
model = await fromFile(opts.vocabFile, {
unkToken: getTokenContent(opts.unkToken),
continuingSubwordPrefix: opts.wordpiecesPrefix,
});

View File

@@ -108,8 +108,8 @@ export class BPETokenizer extends BaseTokenizer<BPETokenizerConfig> {
unkToken: unkToken,
};
const fromFiles = promisify<string, string, BPEOptions, Model>(BPE.fromFiles);
model = await fromFiles(opts.vocabFile, opts.mergesFile, modelOptions);
const fromFile = promisify<string, string, BPEOptions, Model>(BPE.fromFile);
model = await fromFile(opts.vocabFile, opts.mergesFile, modelOptions);
} else {
model = BPE.empty();
}

View File

@@ -93,8 +93,8 @@ export class ByteLevelBPETokenizer extends BaseTokenizer<ByteLevelBPETokenizerCo
let model: Model;
if (opts.vocabFile && opts.mergesFile) {
const fromFiles = promisify<string, string, BPEOptions, Model>(BPE.fromFiles);
model = await fromFiles(opts.vocabFile, opts.mergesFile, opts);
const fromFile = promisify<string, string, BPEOptions, Model>(BPE.fromFile);
model = await fromFile(opts.vocabFile, opts.mergesFile, opts);
} else {
model = BPE.empty();
}

View File

@@ -100,8 +100,8 @@ export class SentencePieceBPETokenizer extends BaseTokenizer<
unkToken: unkToken,
};
const fromFiles = promisify<string, string, BPEOptions, Model>(BPE.fromFiles);
model = await fromFiles(opts.vocabFile, opts.mergesFile, modelOptions);
const fromFile = promisify<string, string, BPEOptions, Model>(BPE.fromFile);
model = await fromFile(opts.vocabFile, opts.mergesFile, modelOptions);
} else {
model = BPE.empty();
}

View File

@@ -8,7 +8,11 @@ use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use tk::models::{bpe::BpeBuilder, wordpiece::WordPieceBuilder, ModelWrapper};
use tk::models::{
bpe::{BpeBuilder, Merges, Vocab},
wordpiece::WordPieceBuilder,
ModelWrapper,
};
use tk::Model as ModelTrait;
use tk::Token;
@@ -144,8 +148,8 @@ pub fn bpe_init(mut cx: FunctionContext) -> JsResult<JsUndefined> {
// Options not specified, callback instead
Err(_) => (BpeOptions::default(), cx.argument::<JsFunction>(2)?),
};
let vocab = cx.extract::<HashMap<String, u32>>(0)?;
let merges = cx.extract::<HashMap<(u32, u32), (u32, u32)>>(1)?;
let vocab = cx.extract::<Vocab>(0)?;
let merges = cx.extract::<Merges>(1)?;
let mut builder = tk::models::bpe::BPE::builder().vocab_and_merges(vocab, merges);

View File

@@ -36,11 +36,7 @@ class BPE(Model):
A dictionary of string keys and their ids {"am": 0,...}
merges: (`optional`) string:
A dictionnary of pairs of ids as keys and their merge correspondace:
{(id_left, id_right): (importance, id_merged), .... }
with vocab : {"a": 0, "b": 1", ... "ab": 4} the merge
{(0, 1): (0, 4) ,...}
corresponds to the "ab" merge, that is the most likely merge (0)
A list of pairs of tokens [("a", "b"),...]
cache_capacity: (`optional`) int:
The number of words that the BPE cache can contain. The cache allows
@@ -66,7 +62,7 @@ class BPE(Model):
def __init__(
self,
vocab: Optional[Union[str, Dict[str, int]]],
merges: Optional[Union[str, Dict[Tuple[int, int], Tuple[int, int]]]],
merges: Optional[Union[str, List[Tuple[str, str]]]],
cache_capacity: Optional[int],
dropout: Optional[float],
unk_token: Optional[str],

View File

@@ -12,7 +12,7 @@ class TestBPE:
assert isinstance(BPE(), BPE)
vocab = {"a": 0, "b": 1, "ab": 2}
merges = {(0, 1): (0, 2)}
merges = [("a", "b")]
assert isinstance(BPE(vocab, merges), Model)
assert isinstance(BPE.from_file(roberta_files["vocab"], roberta_files["merges"]), BPE)
with pytest.raises(ValueError, match="`vocab` and `merges` must be both specified"):

View File

@@ -14,7 +14,8 @@ use std::{
pub type Vocab = HashMap<String, u32>;
type VocabR = HashMap<u32, String>;
pub type Merges = HashMap<Pair, (u32, u32)>;
pub type MergeMap = HashMap<Pair, (u32, u32)>;
pub type Merges = Vec<(String, String)>;
struct Config {
files: Option<(String, String)>,
@@ -39,7 +40,7 @@ impl Default for BpeBuilder {
config: Config {
files: None,
vocab: HashMap::new(),
merges: HashMap::new(),
merges: vec![],
cache_capacity: DEFAULT_CACHE_CAPACITY,
dropout: None,
unk_token: None,
@@ -133,10 +134,33 @@ impl BpeBuilder {
capacity => Some(Cache::new(capacity)),
};
let vocab = self.config.vocab;
let merge_map: MergeMap = self
.config
.merges
.into_iter()
.enumerate()
.map(|(i, (a, b))| -> Result<(Pair, (u32, u32))> {
let a_id = vocab
.get(&a)
.ok_or_else(|| Error::MergeTokenOutOfVocabulary(a.to_owned()))?;
let b_id = vocab
.get(&b)
.ok_or_else(|| Error::MergeTokenOutOfVocabulary(b.to_owned()))?;
let new_token = format!("{}{}", a, b);
let new_id = vocab
.get(&new_token)
.ok_or(Error::MergeTokenOutOfVocabulary(new_token))?;
Ok(((*a_id, *b_id), (i as u32, *new_id)))
})
.collect::<Result<MergeMap>>()?;
// merges.insert(pair, (rank as u32, *new_id));
Ok(BPE {
vocab: self.config.vocab,
vocab,
vocab_r,
merges: self.config.merges,
merges: merge_map,
cache,
dropout: self.config.dropout,
unk_token: self.config.unk_token,
@@ -155,7 +179,7 @@ pub struct BPE {
/// Reversed vocabulary, to rebuild sentences.
pub(crate) vocab_r: VocabR,
/// Contains the mapping between Pairs and their (rank, new_id).
pub(crate) merges: Merges,
pub(crate) merges: MergeMap,
/// Contains the cache for optimizing the encoding step.
cache: Option<Cache<String, Word>>,
/// Dropout probability for merges. 0 = no dropout is the default. At 1.0, tokenization will
@@ -214,9 +238,9 @@ impl Clone for BPE {
/// "{pair_a} {pair_b}" into the format expected by the BPE struct
pub(crate) fn convert_merges_to_hashmap<I: Iterator<Item = String>>(
iter: I,
vocab: &Vocab,
_vocab: &Vocab,
) -> Result<Merges> {
let mut merges = HashMap::new();
let mut merges = vec![];
let lines = iter.filter(|l| !l.starts_with("#version"));
for (rank, line) in lines.enumerate() {
@@ -225,19 +249,7 @@ pub(crate) fn convert_merges_to_hashmap<I: Iterator<Item = String>>(
return Err(Error::BadMerges(rank + 1).into());
}
let a = vocab
.get(parts[0])
.ok_or_else(|| Error::MergeTokenOutOfVocabulary(parts[0].to_owned()))?;
let b = vocab
.get(parts[1])
.ok_or_else(|| Error::MergeTokenOutOfVocabulary(parts[1].to_owned()))?;
let pair = (*a, *b);
let new_token = format!("{}{}", parts[0], parts[1]);
let new_id = vocab
.get(&new_token)
.ok_or(Error::MergeTokenOutOfVocabulary(new_token))?;
merges.insert(pair, (rank as u32, *new_id));
merges.push((parts[0].to_string(), parts[1].to_string()));
}
Ok(merges)
@@ -500,7 +512,7 @@ mod tests {
.cloned()
.collect();
let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, HashMap::new())
.vocab_and_merges(vocab, vec![])
.unk_token("<unk>".to_string())
.build()
.unwrap();
@@ -534,7 +546,7 @@ mod tests {
.cloned()
.collect();
let bpe = BpeBuilder::default()
.vocab_and_merges(vocab, HashMap::new())
.vocab_and_merges(vocab, vec![])
.unk_token("<unk>".to_string())
.fuse_unk(true)
.build()
@@ -583,19 +595,16 @@ mod tests {
.iter()
.cloned()
.collect();
let merges: Merges = [
((vocab["r"], vocab["e"]), (1u32, vocab["re"])), // 'r-e' -> 're'
((vocab["a"], vocab["t"]), (2u32, vocab["at"])), // 'a-t' -> 'at'
((vocab["e"], vocab["d"]), (3u32, vocab["ed"])), // 'e-d' -> 'ed'
((vocab["u"], vocab["n"]), (4u32, vocab["un"])), // 'u-n' -> 'un'
((vocab["at"], vocab["ed"]), (5u32, vocab["ated"])), // 'at-ed' -> 'ated'
((vocab["re"], vocab["l"]), (6u32, vocab["rel"])), // 're-l' -> 'rel'
((vocab["rel"], vocab["ated"]), (7u32, vocab["related"])), // 'rel-ated' -> 'related'
((vocab["un"], vocab["related"]), (8u32, vocab["unrelated"])), // 'un-related' -> 'unrelated'
]
.iter()
.cloned()
.collect();
let merges: Merges = vec![
("r".to_string(), "e".to_string()),
("a".to_string(), "t".to_string()),
("e".to_string(), "d".to_string()),
("u".to_string(), "n".to_string()),
("at".to_string(), "ed".to_string()),
("re".to_string(), "l".to_string()),
("rel".to_string(), "ated".to_string()),
("un".to_string(), "related".to_string()),
];
let mut bpe = BPE::new(vocab, merges);
// With no dropout:

View File

@@ -555,8 +555,12 @@ impl BpeTrainer {
word_to_id,
merges
.into_iter()
.enumerate()
.map(|(index, (pair, new_id))| (pair, (index as u32, new_id)))
.map(|((a_id, b_id), _)| {
(
id_to_word[a_id as usize].clone(),
id_to_word[b_id as usize].clone(),
)
})
.collect(),
);
if let Some(prefix) = &self.continuing_subword_prefix {