Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-22 16:25:30 +00:00
fix wordpiece tokenizer when no vocabfile provided
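This change makes `BertWordPieceTokenizer.fromOptions` usable without a vocab file: the `sepToken`/`clsToken` lookups and the `bertProcessing` post-processor now run only when a `vocabFile` is provided and `addSpecialTokens` is enabled, and without a vocab file the tokenizer falls back to an empty `WordPiece` model instead of failing the special-token lookup. A test covering the no-vocab-file case is added.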
@@ -0,0 +1,42 @@
+import { BertWordPieceOptions, BertWordPieceTokenizer } from "./bert-wordpiece.tokenizer";
+import { mocked } from "ts-jest/utils";
+import { Tokenizer } from "../bindings/tokenizer";
+
+jest.mock("../bindings/models");
+jest.mock("../bindings/tokenizer");
+
+describe("BertWordPieceTokenizer", () => {
+  describe("fromOptions", () => {
+
+    it("should not throw any error if no vocabFile is provided", async () => {
+      await BertWordPieceTokenizer.fromOptions();
+    });
+
+    describe("when a vocabFile is provided and `addSpecialTokens === true`", () => {
+      it("should throw a `sepToken error` if no `sepToken` is provided", () => {
+        const options: BertWordPieceOptions = {
+          vocabFile: "./fake.txt",
+          sepToken: undefined
+        };
+
+        expect.assertions(1);
+        BertWordPieceTokenizer.fromOptions(options)
+          .catch(e => expect(e).toBeDefined());
+      });
+
+      it("should throw a `clsToken error` if no `clsToken` is provided", () => {
+        const options: BertWordPieceOptions = {
+          vocabFile: "./fake.txt",
+          clsToken: undefined
+        };
+
+        mocked(Tokenizer.prototype.tokenToId).mockImplementationOnce(() => 10);
+
+        expect.assertions(1);
+        BertWordPieceTokenizer.fromOptions(options)
+          .catch(e => expect(e).toBeDefined());
+      });
+    });
+
+  });
+});
@@ -1,7 +1,7 @@
 import { promisify } from "util";
 import { BaseTokenizer } from "./base.tokenizer";
 import { Tokenizer } from "../bindings/tokenizer";
-import { Model, wordPiece } from "../bindings/models";
+import { Model, WordPiece } from "../bindings/models";
 import { bertNormalizer } from "../bindings/normalizers";
 import { bertPreTokenizer } from "../bindings/pre-tokenizers";
 import { bertProcessing } from "../bindings/post-processors";
@@ -115,39 +115,41 @@ export class BertWordPieceTokenizer extends BaseTokenizer {
    * @param [options] Optional tokenizer options
    */
   static async fromOptions(options?: BertWordPieceOptions): Promise<BertWordPieceTokenizer> {
-    const mergedOptions = { ...this.defaultBertOptions, ...options };
+    const opts = { ...this.defaultBertOptions, ...options };
 
     let model: Model;
-    if (mergedOptions.vocabFile) {
+    if (opts.vocabFile) {
       // const fromFiles = promisify(WordPiece.fromFiles);
-      model = wordPiece.fromFiles(mergedOptions.vocabFile, { unkToken: mergedOptions.unkToken });
+      model = WordPiece.fromFiles(opts.vocabFile, { unkToken: opts.unkToken });
       // model = await fromFiles(mergedOptions.vocabFile, mergedOptions.unkToken, null);
     } else {
-      model = wordPiece.empty();
+      model = WordPiece.empty();
     }
 
     const tokenizer = new Tokenizer(model);
 
-    const normalizer = bertNormalizer(mergedOptions);
+    const normalizer = bertNormalizer(opts);
     tokenizer.setNormalizer(normalizer);
     tokenizer.setPreTokenizer(bertPreTokenizer());
 
-    const sepTokenId = tokenizer.tokenToId(mergedOptions.sepToken);
-    if (sepTokenId === undefined) {
-      throw new Error("sepToken not found in the vocabulary");
+    if (opts.vocabFile && opts.addSpecialTokens) {
+      const sepTokenId = tokenizer.tokenToId(opts.sepToken);
+      if (sepTokenId === undefined) {
+        throw new Error("sepToken not found in the vocabulary");
+      }
+
+      const clsTokenId = tokenizer.tokenToId(opts.clsToken);
+      if (clsTokenId === undefined) {
+        throw new Error("clsToken not found in the vocabulary");
+      }
+
+      if (opts.addSpecialTokens) {
+        const processor = bertProcessing([opts.sepToken, sepTokenId], [opts.clsToken, clsTokenId]);
+        tokenizer.setPostProcessor(processor);
+      }
     }
 
-    const clsTokenId = tokenizer.tokenToId(mergedOptions.clsToken);
-    if (clsTokenId === undefined) {
-      throw new Error("clsToken not found in the vocabulary");
-    }
-
-    if (mergedOptions.addSpecialTokens) {
-      const processor = bertProcessing([mergedOptions.sepToken, sepTokenId], [mergedOptions.clsToken, clsTokenId]);
-      tokenizer.setPostProcessor(processor);
-    }
-
-    const decoder = wordPieceDecoder(mergedOptions.wordpiecesPrefix);
+    const decoder = wordPieceDecoder(opts.wordpiecesPrefix);
     tokenizer.setDecoder(decoder);
 
     return new BertWordPieceTokenizer(tokenizer);
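Below is a minimal usage sketch of the behaviour this change enables, assuming only the `fromOptions` API shown in the diff above; the import path mirrors the test file, and `./vocab.txt` is a placeholder path:

import { BertWordPieceTokenizer } from "./bert-wordpiece.tokenizer";

async function example(): Promise<void> {
  // Before this fix, calling fromOptions without a vocabFile always threw
  // "sepToken not found in the vocabulary", because the special-token lookup
  // ran against an empty WordPiece model; it now resolves.
  const blankTokenizer = await BertWordPieceTokenizer.fromOptions();

  // With a vocabFile and addSpecialTokens enabled, the sepToken/clsToken
  // lookups still run, and fromOptions still rejects if those tokens are
  // missing from the vocabulary.
  const bertTokenizer = await BertWordPieceTokenizer.fromOptions({ vocabFile: "./vocab.txt" });

  console.log(blankTokenizer, bertTokenizer);
}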