[gpt2pre 4] GPT2Preprocessor Layer (tensorflow#7814)
* Draft gpt2 Preprocessor class

* Add serialization test

* Add memory test

* lint

---------

Co-authored-by: Linchenn <[email protected]>
pforderique and Linchenn authored Jul 21, 2023
1 parent 983dc16 commit 81abd7b
Showing 2 changed files with 360 additions and 0 deletions.
220 changes: 220 additions & 0 deletions tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor.ts
@@ -0,0 +1,220 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/**
* GPT-2 preprocessor layer.
*/

/* Original source: keras-nlp/models/gpt2/gpt2_preprocessor.py */
import { Tensor, Tensor2D, serialization, tidy } from '@tensorflow/tfjs-core';

import { LayerArgs } from '../../../../engine/topology';
import { Preprocessor } from '../preprocessor';
import { GPT2Tokenizer } from './gpt2_tokenizer';
import { StartEndPacker } from '../../preprocessing/start_end_packer';
import { ValueError } from '../../../../errors';

export declare interface GPT2PreprocessorArgs extends LayerArgs {
/**
* A GPT2Tokenizer instance.
*/
tokenizer: GPT2Tokenizer;

/**
* The length of the packed inputs.
* Defaults to 1024.
*/
sequenceLength?: number;

/**
* If `true`, the preprocessor will prepend the tokenizer start token to each
* input sequence.
* Defaults to `true`.
*/
addStartToken?: boolean;

/**
* If `true`, the preprocessor will append the tokenizer end token to each
* input sequence.
* Defaults to `true`.
*/
addEndToken?: boolean;
}

export declare interface GPT2PreprocessorOptions {
/**
* Any label data. Will be passed through unaltered.
*/
y?: Tensor;

/**
* Any label weight data. Will be passed through unaltered.
*/
sampleWeight?: Tensor;

/**
* Pass to override the configured `sequenceLength` of the layer.
*/
sequenceLength?: number;
}

export declare interface PreprocessorOutputs {
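/** Token ids packed to `sequenceLength`, shaped `[batchSize, sequenceLength]`. */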
tokenIds: Tensor2D;
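/** Boolean mask, true at positions that hold real (non-padding) tokens. */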
paddingMask: Tensor2D;
}

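/**
 * Packs `x`, `y`, and `sampleWeight` into the shortest tuple that carries all
 * defined values, mirroring Keras' `(x, y, sample_weight)` convention.
 */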
function packXYSampleWeight(
x: PreprocessorOutputs, y?: Tensor, sampleWeight?: Tensor):
PreprocessorOutputs
| [PreprocessorOutputs, Tensor]
| [PreprocessorOutputs, Tensor, Tensor] {

if (y === undefined) {
return x;
} else if (sampleWeight === undefined) {
return [x, y];
} else {
return [x, y, sampleWeight];
}
}

/**
* GPT2 preprocessing layer which tokenizes and packs inputs.
*
* This preprocessing layer will do 2 things:
*
* - Tokenize the inputs using the `tokenizer`.
* - Construct a dictionary with keys `"tokenIds"` and `"paddingMask"` that can
* be passed directly to a `GPT2Backbone`.
*
* The `callAndPackArgs` method of this layer accepts three arguments: `x`,
* `y`, and `sampleWeight`. `x` can be a string or tensor representing a single
* segment, a list of strings representing a batch of single segments,
* or a list of tensors representing multiple segments to be packed together.
* `y` and `sampleWeight` are both optional, can have any format, and will be
* passed through unaltered.
*
* `GPT2Preprocessor` forces the input to have only one segment, as GPT2 is
* mainly used for generation tasks. For tasks having multi-segment inputs
* like "glue/mnli", please use a model designed for classification purposes
* such as BERT or RoBERTa.
*
* Examples:
*
* Directly calling the layer on data.
* ```js
* const features = ['a quick fox.', 'a fox quick.'];
* const vocabulary =
* new Map([['<|endoftext|>', 0], ['a', 4], ['Ġquick', 5], ['Ġfox', 6]]);
* const merges =
* ['Ġ q', 'u i', 'c k', 'ui ck', 'Ġq uick', 'Ġ f', 'o x', 'Ġf ox'];
* const tokenizer = new GPT2Tokenizer({vocabulary, merges});
*
* const preprocessor = new GPT2Preprocessor({tokenizer});
* preprocessor.call(tensor(features), {}).print();
* ```
*/
export class GPT2Preprocessor extends Preprocessor {
private readonly sequenceLength: number;
private readonly addStartToken: boolean;
private readonly addEndToken: boolean;
private readonly packer: StartEndPacker;

constructor(args: GPT2PreprocessorArgs) {
super(args);
this.tokenizer = args.tokenizer;
this.sequenceLength = args.sequenceLength ?? 1024;
this.addStartToken = args.addStartToken ?? true;
this.addEndToken = args.addEndToken ?? true;

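// GPT-2 marks both the start and the end of a sequence with the same
// <|endoftext|> token; see the expected ids in the unit tests below.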
const gpt2Tokenizer = this.tokenizer as GPT2Tokenizer;
this.packer = new StartEndPacker({
startValue: gpt2Tokenizer.startTokenId,
endValue: gpt2Tokenizer.endTokenId,
padValue: gpt2Tokenizer.padTokenId,
sequenceLength: this.sequenceLength,
});
}

override getConfig(): serialization.ConfigDict {
const config = {
sequenceLength: this.sequenceLength,
addStartToken: this.addStartToken,
addEndToken: this.addEndToken,
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}

override call(
inputs: Tensor|Tensor[], kwargs: GPT2PreprocessorOptions): Tensor|Tensor[] {
return this.callAndReturnPaddingMask(inputs, kwargs).tokenIds;
}

private callAndReturnPaddingMask(
inputs: Tensor|Tensor[],
kwargs: GPT2PreprocessorOptions
): PreprocessorOutputs {
return tidy(() => {
if (inputs instanceof Array) {
if (inputs.length !== 1) {
throw new ValueError(
'GPT2 requires each input feature to contain only ' +
`one segment, but received ${inputs.length}. If you are using ` +
'GPT2 for a multi-segment classification task, please refer to ' +
'classification models like BERT or RoBERTa.'
);
}
inputs = inputs[0];
}

const sequenceLength = kwargs.sequenceLength ?? this.sequenceLength;
const [tokenIds, paddingMask] = this.packer.callAndReturnPaddingMask(
this.tokenizer.call(inputs),
{
sequenceLength,
addStartValue: this.addStartToken,
addEndValue: this.addEndToken
}
);

return {
tokenIds: tokenIds as Tensor2D,
paddingMask: paddingMask as Tensor2D
};
});
}

/**
* Calls the layer and returns extra information like the paddingMask used to
* pack the sequence, the label data, and the sample weights used.
*/
callAndPackArgs(inputs: Tensor|Tensor[], kwargs: GPT2PreprocessorOptions):
PreprocessorOutputs
| [PreprocessorOutputs, Tensor]
| [PreprocessorOutputs, Tensor, Tensor] {
const x = this.callAndReturnPaddingMask(inputs, kwargs);
return packXYSampleWeight(x, kwargs.y, kwargs.sampleWeight);
}

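/** Returns the tokenizer class paired with this preprocessor. */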
static override tokenizerCls<T extends serialization.Serializable>(
cls: serialization.SerializableConstructor<T>) {
return GPT2Tokenizer;
}
}
serialization.registerClass(GPT2Preprocessor);
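For anyone trying the new layer outside this diff, here is a minimal sketch (not part of the commit) that drives it end to end. It reuses the toy vocabulary and merges from the unit tests below; the relative import paths are illustrative and depend on where your code lives:

```ts
import { Tensor, tensor } from '@tensorflow/tfjs-core';
// Illustrative relative paths, mirroring the files added in this commit.
import { GPT2Preprocessor, PreprocessorOutputs } from './gpt2_preprocessor';
import { GPT2Tokenizer } from './gpt2_tokenizer';

// Toy vocabulary and merges copied from gpt2_preprocessor_test.ts.
const vocabulary = new Map([
  ['!', 0], ['air', 1], ['Ġair', 2], ['plane', 3],
  ['Ġat', 4], ['port', 5], ['<|endoftext|>', 6],
]);
const merges = [
  'Ġ a', 'Ġ t', 'Ġ i', 'Ġ b', 'a i', 'p l', 'n e', 'Ġa t', 'p o',
  'r t', 'Ġt h', 'ai r', 'pl a', 'po rt', 'Ġai r', 'Ġa i', 'pla ne',
];

const preprocessor = new GPT2Preprocessor({
  tokenizer: new GPT2Tokenizer({vocabulary, merges}),
  sequenceLength: 8,
});

// With both `y` and `sampleWeight` set, callAndPackArgs returns an
// [x, y, sampleWeight] tuple.
const [x, y, sampleWeight] = preprocessor.callAndPackArgs(
  tensor(['airplane at airport']),
  {y: tensor([1]), sampleWeight: tensor([1.0])},
) as [PreprocessorOutputs, Tensor, Tensor];

x.tokenIds.print();     // [[6, 1, 3, 4, 2, 5, 6, 0]]
x.paddingMask.print();  // true for the 7 real tokens, false for the pad slot
```

If `y` is given without `sampleWeight`, the result is a two-element `[x, y]` tuple, and with neither it is just the `PreprocessorOutputs` object, per `packXYSampleWeight` above.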
140 changes: 140 additions & 0 deletions tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor_test.ts
@@ -0,0 +1,140 @@
/**
* @license
* Copyright 2023 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

/**
* Unit Tests for GPT2Preprocessor.
*/

import { Tensor, memory, serialization, tensor, tensor2d } from '@tensorflow/tfjs-core';

import { GPT2Preprocessor, PreprocessorOutputs } from './gpt2_preprocessor';
import { GPT2Tokenizer } from './gpt2_tokenizer';
import { expectTensorsClose } from '../../../../utils/test_utils';

describe('GPT2Preprocessor', () => {
let vocabulary: Map<string, number>;
let merges: string[];
let preprocessor: GPT2Preprocessor;

beforeEach(() => {
vocabulary = new Map([
['!', 0],
['air', 1],
['Ġair', 2],
['plane', 3],
['Ġat', 4],
['port', 5],
['<|endoftext|>', 6],
]);

merges = ['Ġ a', 'Ġ t', 'Ġ i', 'Ġ b', 'a i', 'p l', 'n e'].concat(
['Ġa t', 'p o', 'r t', 'Ġt h', 'ai r', 'pl a', 'po rt'],
['Ġai r', 'Ġa i', 'pla ne']
);
preprocessor = new GPT2Preprocessor({
tokenizer: new GPT2Tokenizer({vocabulary, merges}),
sequenceLength: 8
});
});

it('tokenize', () => {
const inputData = tensor(['airplane at airport']);

const output =
preprocessor.callAndPackArgs(inputData, {}) as PreprocessorOutputs;

expectTensorsClose(output.tokenIds, tensor2d([[6, 1, 3, 4, 2, 5, 6, 0]]));
expectTensorsClose(
output.paddingMask, tensor2d([[1, 1, 1, 1, 1, 1, 1, 0]], [1, 8], 'bool'));
});

it('no start end token', () => {
const inputData = tensor(Array<string>(4).fill('airplane at airport'));
preprocessor = new GPT2Preprocessor({
tokenizer: new GPT2Tokenizer({vocabulary, merges}),
sequenceLength: 8,
addStartToken: false,
addEndToken: false,
});
const expectedOutput = {
tokenIds: tensor2d(Array<number[]>(4).fill([1, 3, 4, 2, 5, 0, 0, 0])),
paddingMask: tensor2d(
Array<number[]>(4).fill([1, 1, 1, 1, 1, 0, 0, 0]), [4, 8], 'bool'),
};

const output =
preprocessor.callAndPackArgs(inputData, {}) as PreprocessorOutputs;

expectTensorsClose(output.tokenIds, expectedOutput.tokenIds);
expectTensorsClose(output.paddingMask, expectedOutput.paddingMask);
});

it('tokenize labeled batch', () => {
const inputData = tensor(Array<string>(4).fill('airplane at airport'));
const yIn = tensor([1, 1, 1, 1]);
const swIn = tensor([1., 1., 1., 1.]);
const expectedX = {
tokenIds: tensor2d(Array<number[]>(4).fill([6, 1, 3, 4, 2, 5, 6, 0])),
paddingMask: tensor2d(
Array<number[]>(4).fill([1, 1, 1, 1, 1, 1, 1, 0]), [4, 8], 'bool'),
};

const output = preprocessor.callAndPackArgs(
inputData, {y: yIn, sampleWeight: swIn}
) as [PreprocessorOutputs, Tensor, Tensor];

expectTensorsClose(output[0].tokenIds, expectedX.tokenIds);
expectTensorsClose(output[0].paddingMask, expectedX.paddingMask);
expectTensorsClose(output[1], yIn);
expectTensorsClose(output[2], swIn);
});

it('sequence length override', () => {
const inputData = tensor(['airplane at airport']);

const output = preprocessor.callAndPackArgs(
inputData, {sequenceLength: 4}
) as PreprocessorOutputs;

expectTensorsClose(output.tokenIds, tensor2d([[6, 1, 3, 6]]));
});

it('does not leak memory', () => {
const inputData = tensor(['airplane at airport']);

const numTensorsBefore = memory().numTensors;
preprocessor.callAndPackArgs(inputData, {sequenceLength: 4});
const numTensorsAfter = memory().numTensors;
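// callAndReturnPaddingMask runs inside tidy(), so only the two returned
// tensors (tokenIds and paddingMask) should outlive the call.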
expect(numTensorsAfter).toEqual(numTensorsBefore + 2);
});

it('serialization round-trip', () => {
const reserialized = GPT2Preprocessor.fromConfig(
GPT2Preprocessor, preprocessor.getConfig());

const originalConfig = preprocessor.getConfig();
const reserializedConfig = reserialized.getConfig();

// TODO(pforderique): Verify any tokenizer name consistency issues.
delete ((originalConfig['tokenizer'] as serialization.ConfigDict)
  ['config'] as serialization.ConfigDict)['name'];
delete ((reserializedConfig['tokenizer'] as serialization.ConfigDict)
  ['config'] as serialization.ConfigDict)['name'];

expect(reserializedConfig).toEqual(originalConfig);
});
});
