fix WordPiece tokenizer when no vocabFile is provided

Pierric Cistac
2020-01-15 12:23:52 -05:00
parent 0db8467fba
commit 26a52dd660
2 changed files with 64 additions and 20 deletions
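In short: `fromOptions` previously called `tokenToId` on `sepToken` and `clsToken` unconditionally, so constructing a tokenizer without a `vocabFile` (i.e. with an empty `WordPiece` model) always threw `sepToken not found in the vocabulary`. Those lookups, and the BERT post-processor, are now only set up when a `vocabFile` is provided and `addSpecialTokens` is enabled. A minimal sketch of the usage this fixes (the `./vocab.txt` path is hypothetical, not part of this commit):

```ts
import { BertWordPieceTokenizer } from "./bert-wordpiece.tokenizer";

async function demo(): Promise<void> {
  // Fixed by this commit: omitting vocabFile yields an empty WordPiece
  // model and skips the sepToken/clsToken lookups instead of throwing.
  const emptyTokenizer = await BertWordPieceTokenizer.fromOptions();

  // Unchanged path: with a vocabFile and addSpecialTokens enabled, the
  // special tokens are resolved and bertProcessing is attached as the
  // post-processor.
  const tokenizer = await BertWordPieceTokenizer.fromOptions({
    vocabFile: "./vocab.txt" // hypothetical path
  });
}
```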

bert-wordpiece.tokenizer.test.ts (new file)

@@ -0,0 +1,42 @@
+import { BertWordPieceOptions, BertWordPieceTokenizer } from "./bert-wordpiece.tokenizer";
+import { mocked } from "ts-jest/utils";
+import { Tokenizer } from "../bindings/tokenizer";
+
+jest.mock("../bindings/models");
+jest.mock("../bindings/tokenizer");
+
+describe("BertWordPieceTokenizer", () => {
+  describe("fromOptions", () => {
+    it("should not throw any error if no vocabFile is provided", async () => {
+      await BertWordPieceTokenizer.fromOptions();
+    });
+
+    describe("when a vocabFile is provided and `addSpecialTokens === true`", () => {
+      it("should throw a `sepToken error` if no `sepToken` is provided", () => {
+        const options: BertWordPieceOptions = {
+          vocabFile: "./fake.txt",
+          sepToken: undefined
+        };
+
+        expect.assertions(1);
+
+        BertWordPieceTokenizer.fromOptions(options)
+          .catch(e => expect(e).toBeDefined());
+      });
+
+      it("should throw a `clsToken error` if no `clsToken` is provided", () => {
+        const options: BertWordPieceOptions = {
+          vocabFile: "./fake.txt",
+          clsToken: undefined
+        };
+
+        mocked(Tokenizer.prototype.tokenToId).mockImplementationOnce(() => 10);
+
+        expect.assertions(1);
+
+        BertWordPieceTokenizer.fromOptions(options)
+          .catch(e => expect(e).toBeDefined());
+      });
+    });
+  });
+});

bert-wordpiece.tokenizer.ts

@@ -1,7 +1,7 @@
 import { promisify } from "util";
 import { BaseTokenizer } from "./base.tokenizer";
 import { Tokenizer } from "../bindings/tokenizer";
-import { Model, wordPiece } from "../bindings/models";
+import { Model, WordPiece } from "../bindings/models";
 import { bertNormalizer } from "../bindings/normalizers";
 import { bertPreTokenizer } from "../bindings/pre-tokenizers";
 import { bertProcessing } from "../bindings/post-processors";
@@ -115,39 +115,41 @@ export class BertWordPieceTokenizer extends BaseTokenizer {
    * @param [options] Optional tokenizer options
    */
   static async fromOptions(options?: BertWordPieceOptions): Promise<BertWordPieceTokenizer> {
-    const mergedOptions = { ...this.defaultBertOptions, ...options };
+    const opts = { ...this.defaultBertOptions, ...options };

     let model: Model;
-    if (mergedOptions.vocabFile) {
+    if (opts.vocabFile) {
       // const fromFiles = promisify(WordPiece.fromFiles);
-      model = wordPiece.fromFiles(mergedOptions.vocabFile, { unkToken: mergedOptions.unkToken });
+      model = WordPiece.fromFiles(opts.vocabFile, { unkToken: opts.unkToken });
       // model = await fromFiles(mergedOptions.vocabFile, mergedOptions.unkToken, null);
     } else {
-      model = wordPiece.empty();
+      model = WordPiece.empty();
     }

     const tokenizer = new Tokenizer(model);

-    const normalizer = bertNormalizer(mergedOptions);
+    const normalizer = bertNormalizer(opts);
     tokenizer.setNormalizer(normalizer);
     tokenizer.setPreTokenizer(bertPreTokenizer());

-    const sepTokenId = tokenizer.tokenToId(mergedOptions.sepToken);
-    if (sepTokenId === undefined) {
-      throw new Error("sepToken not found in the vocabulary");
-    }
-    const clsTokenId = tokenizer.tokenToId(mergedOptions.clsToken);
-    if (clsTokenId === undefined) {
-      throw new Error("clsToken not found in the vocabulary");
-    }
-
-    if (mergedOptions.addSpecialTokens) {
-      const processor = bertProcessing([mergedOptions.sepToken, sepTokenId], [mergedOptions.clsToken, clsTokenId]);
-      tokenizer.setPostProcessor(processor);
-    }
+    if (opts.vocabFile && opts.addSpecialTokens) {
+      const sepTokenId = tokenizer.tokenToId(opts.sepToken);
+      if (sepTokenId === undefined) {
+        throw new Error("sepToken not found in the vocabulary");
+      }
+      const clsTokenId = tokenizer.tokenToId(opts.clsToken);
+      if (clsTokenId === undefined) {
+        throw new Error("clsToken not found in the vocabulary");
+      }
+
+      if (opts.addSpecialTokens) {
+        const processor = bertProcessing([opts.sepToken, sepTokenId], [opts.clsToken, clsTokenId]);
+        tokenizer.setPostProcessor(processor);
+      }
+    }

-    const decoder = wordPieceDecoder(mergedOptions.wordpiecesPrefix);
+    const decoder = wordPieceDecoder(opts.wordpiecesPrefix);
     tokenizer.setDecoder(decoder);

     return new BertWordPieceTokenizer(tokenizer);
   }
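For reference, the failure mode the new tests pin down, as a sketch mirroring them (the tests mock the native bindings; `./fake.txt` is their dummy path):

```ts
// With a vocabFile and addSpecialTokens enabled, a special token that
// cannot be resolved in the vocabulary still rejects the promise:
BertWordPieceTokenizer.fromOptions({
  vocabFile: "./fake.txt",
  sepToken: undefined
}).catch(e => {
  // Error: sepToken not found in the vocabulary
  console.error(e);
});
```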