Mirror of https://github.com/mii443/tokenizers.git, synced 2025-08-23 00:35:35 +00:00
fix wordpiece tokenizer when no vocabfile provided
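In short: fromOptions previously always resolved the sepToken and clsToken ids, which failed when no vocabFile was provided and the model was therefore an empty WordPiece; with this change those lookups and the BERT post-processor are only configured when a vocabFile is given and addSpecialTokens is enabled. Below is a minimal usage sketch of what the fix enables; it assumes BertWordPieceTokenizer is re-exported from the package entry point, and the import path and vocab path are illustrative only:

    import { BertWordPieceTokenizer } from "tokenizers";

    async function demo(): Promise<void> {
      // Before the fix this call threw while looking up the sepToken/clsToken ids
      // against the empty WordPiece model; it now resolves without error.
      const emptyTokenizer = await BertWordPieceTokenizer.fromOptions();

      // With a vocabFile the special-token checks and the BERT post-processing
      // are still applied as before.
      const fullTokenizer = await BertWordPieceTokenizer.fromOptions({
        vocabFile: "./vocab.txt" // illustrative path
      });

      console.log(emptyTokenizer !== undefined, fullTokenizer !== undefined);
    }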
@@ -0,0 +1,42 @@
+import { BertWordPieceOptions, BertWordPieceTokenizer } from "./bert-wordpiece.tokenizer";
+import { mocked } from "ts-jest/utils";
+import { Tokenizer } from "../bindings/tokenizer";
+
+jest.mock("../bindings/models");
+jest.mock("../bindings/tokenizer");
+
+describe("BertWordPieceTokenizer", () => {
+  describe("fromOptions", () => {
+    it("should not throw any error if no vocabFile is provided", async () => {
+      await BertWordPieceTokenizer.fromOptions();
+    });
+
+    describe("when a vocabFile is provided and `addSpecialTokens === true`", () => {
+      it("should throw a `sepToken error` if no `sepToken` is provided", () => {
+        const options: BertWordPieceOptions = {
+          vocabFile: "./fake.txt",
+          sepToken: undefined
+        };
+
+        expect.assertions(1);
+        BertWordPieceTokenizer.fromOptions(options)
+          .catch(e => expect(e).toBeDefined());
+      });
+
+      it("should throw a `clsToken error` if no `clsToken` is provided", () => {
+        const options: BertWordPieceOptions = {
+          vocabFile: "./fake.txt",
+          clsToken: undefined
+        };
+
+        mocked(Tokenizer.prototype.tokenToId).mockImplementationOnce(() => 10);
+
+        expect.assertions(1);
+        BertWordPieceTokenizer.fromOptions(options)
+          .catch(e => expect(e).toBeDefined());
+      });
+    });
+  });
+});
@@ -1,7 +1,7 @@
 import { promisify } from "util";
 import { BaseTokenizer } from "./base.tokenizer";
 import { Tokenizer } from "../bindings/tokenizer";
-import { Model, wordPiece } from "../bindings/models";
+import { Model, WordPiece } from "../bindings/models";
 import { bertNormalizer } from "../bindings/normalizers";
 import { bertPreTokenizer } from "../bindings/pre-tokenizers";
 import { bertProcessing } from "../bindings/post-processors";
@@ -115,39 +115,41 @@ export class BertWordPieceTokenizer extends BaseTokenizer {
    * @param [options] Optional tokenizer options
    */
   static async fromOptions(options?: BertWordPieceOptions): Promise<BertWordPieceTokenizer> {
-    const mergedOptions = { ...this.defaultBertOptions, ...options };
+    const opts = { ...this.defaultBertOptions, ...options };
 
     let model: Model;
-    if (mergedOptions.vocabFile) {
+    if (opts.vocabFile) {
       // const fromFiles = promisify(WordPiece.fromFiles);
-      model = wordPiece.fromFiles(mergedOptions.vocabFile, { unkToken: mergedOptions.unkToken });
+      model = WordPiece.fromFiles(opts.vocabFile, { unkToken: opts.unkToken });
       // model = await fromFiles(mergedOptions.vocabFile, mergedOptions.unkToken, null);
     } else {
-      model = wordPiece.empty();
+      model = WordPiece.empty();
     }
 
     const tokenizer = new Tokenizer(model);
 
-    const normalizer = bertNormalizer(mergedOptions);
+    const normalizer = bertNormalizer(opts);
     tokenizer.setNormalizer(normalizer);
     tokenizer.setPreTokenizer(bertPreTokenizer());
 
-    const sepTokenId = tokenizer.tokenToId(mergedOptions.sepToken);
-    if (sepTokenId === undefined) {
-      throw new Error("sepToken not found in the vocabulary");
-    }
-
-    const clsTokenId = tokenizer.tokenToId(mergedOptions.clsToken);
-    if (clsTokenId === undefined) {
-      throw new Error("clsToken not found in the vocabulary");
-    }
-
-    if (mergedOptions.addSpecialTokens) {
-      const processor = bertProcessing([mergedOptions.sepToken, sepTokenId], [mergedOptions.clsToken, clsTokenId]);
-      tokenizer.setPostProcessor(processor);
-    }
+    if (opts.vocabFile && opts.addSpecialTokens) {
+      const sepTokenId = tokenizer.tokenToId(opts.sepToken);
+      if (sepTokenId === undefined) {
+        throw new Error("sepToken not found in the vocabulary");
+      }
+
+      const clsTokenId = tokenizer.tokenToId(opts.clsToken);
+      if (clsTokenId === undefined) {
+        throw new Error("clsToken not found in the vocabulary");
+      }
+
+      if (opts.addSpecialTokens) {
+        const processor = bertProcessing([opts.sepToken, sepTokenId], [opts.clsToken, clsTokenId]);
+        tokenizer.setPostProcessor(processor);
+      }
+    }
 
-    const decoder = wordPieceDecoder(mergedOptions.wordpiecesPrefix);
+    const decoder = wordPieceDecoder(opts.wordpiecesPrefix);
     tokenizer.setDecoder(decoder);
 
     return new BertWordPieceTokenizer(tokenizer);