[gpt2pre 4] GPT2Preprocessor Layer (tensorflow#7814)
* Draft gpt2 Preprocessor class
* Add serialization test
* Add memory test
* lint

Co-authored-by: Linchenn <[email protected]>
1 parent 983dc16, commit 81abd7b
Showing 2 changed files with 360 additions and 0 deletions.
tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor.ts (220 additions, 0 deletions)
@@ -0,0 +1,220 @@
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * GPT-2 preprocessor layer.
 */

/* Original source: keras-nlp/models/gpt2/gpt2_preprocessor.py */
import { Tensor, Tensor2D, serialization, tidy } from '@tensorflow/tfjs-core';

import { LayerArgs } from '../../../../engine/topology';
import { Preprocessor } from '../preprocessor';
import { GPT2Tokenizer } from './gpt2_tokenizer';
import { StartEndPacker } from '../../preprocessing/start_end_packer';
import { ValueError } from '../../../../errors';

export declare interface GPT2PreprocessorArgs extends LayerArgs {
  /**
   * A GPT2Tokenizer instance.
   */
  tokenizer: GPT2Tokenizer;

  /**
   * The length of the packed inputs.
   * Defaults to 1024.
   */
  sequenceLength?: number;

  /**
   * If `true`, the preprocessor will prepend the tokenizer start token to
   * each input sequence.
   * Defaults to `true`.
   */
  addStartToken?: boolean;

  /**
   * If `true`, the preprocessor will append the tokenizer end token to each
   * input sequence.
   * Defaults to `true`.
   */
  addEndToken?: boolean;
}

export declare interface GPT2PreprocessorOptions {
  /**
   * Any label data. Will be passed through unaltered.
   */
  y?: Tensor;

  /**
   * Any label weight data. Will be passed through unaltered.
   */
  sampleWeight?: Tensor;

  /**
   * Pass to override the configured `sequenceLength` of the layer.
   */
  sequenceLength?: number;
}

export declare interface PreprocessorOutputs {
  tokenIds: Tensor2D;
  paddingMask: Tensor2D;
}

function packXYSampleWeight(
    x: PreprocessorOutputs, y?: Tensor, sampleWeight?: Tensor):
    PreprocessorOutputs
    | [PreprocessorOutputs, Tensor]
    | [PreprocessorOutputs, Tensor, Tensor] {
  if (y === undefined) {
    return x;
  } else if (sampleWeight === undefined) {
    return [x, y];
  } else {
    return [x, y, sampleWeight];
  }
}

/**
 * GPT2 preprocessing layer which tokenizes and packs inputs.
 *
 * This preprocessing layer will do two things:
 *
 * - Tokenize the inputs using the `tokenizer`.
 * - Construct a dictionary with keys `"tokenIds"` and `"paddingMask"` that
 *   can be passed directly to a `GPT2Backbone`.
 *
 * The call method of this layer accepts three arguments, `x`, `y`, and
 * `sampleWeight`. `x` can be a string or tensor representing a single
 * segment, a list of strings representing a batch of single segments,
 * or a list of tensors representing multiple segments to be packed together.
 * `y` and `sampleWeight` are both optional, can have any format, and will be
 * passed through unaltered.
 *
 * `GPT2Preprocessor` forces the input to have only one segment, as GPT2 is
 * mainly used for generation tasks. For tasks having multi-segment inputs
 * like "glue/mnli", please use a model designed for classification purposes
 * such as BERT or RoBERTa.
 *
 * Examples:
 *
 * Directly calling the layer on data.
 * ```js
 * const features = ['a quick fox.', 'a fox quick.'];
 * const vocabulary =
 *   new Map([['<|endoftext|>', 0], ['a', 4], ['Ġquick', 5], ['Ġfox', 6]]);
 * const merges =
 *   ['Ġ q', 'u i', 'c k', 'ui ck', 'Ġq uick', 'Ġ f', 'o x', 'Ġf ox'];
 * const tokenizer = new GPT2Tokenizer({vocabulary, merges});
 *
 * const preprocessor = new GPT2Preprocessor({tokenizer});
 * preprocessor.call(tensor(features), {}).print();
 * ```
 */
export class GPT2Preprocessor extends Preprocessor {
  private readonly sequenceLength: number;
  private readonly addStartToken: boolean;
  private readonly addEndToken: boolean;
  private readonly packer: StartEndPacker;

  constructor(args: GPT2PreprocessorArgs) {
    super(args);
    this.tokenizer = args.tokenizer;
    this.sequenceLength = args.sequenceLength ?? 1024;
    this.addStartToken = args.addStartToken ?? true;
    this.addEndToken = args.addEndToken ?? true;

    const gpt2Tokenizer = this.tokenizer as GPT2Tokenizer;
    this.packer = new StartEndPacker({
      startValue: gpt2Tokenizer.startTokenId,
      endValue: gpt2Tokenizer.endTokenId,
      padValue: gpt2Tokenizer.padTokenId,
      sequenceLength: this.sequenceLength,
    });
  }

  override getConfig(): serialization.ConfigDict {
    const config = {
      sequenceLength: this.sequenceLength,
      addStartToken: this.addStartToken,
      addEndToken: this.addEndToken,
    };
    const baseConfig = super.getConfig();
    Object.assign(config, baseConfig);
    return config;
  }

  override call(
      inputs: Tensor|Tensor[],
      kwargs: GPT2PreprocessorOptions): Tensor|Tensor[] {
    return this.callAndReturnPaddingMask(inputs, kwargs).tokenIds;
  }

  private callAndReturnPaddingMask(
      inputs: Tensor|Tensor[],
      kwargs: GPT2PreprocessorOptions
  ): PreprocessorOutputs {
    return tidy(() => {
      if (inputs instanceof Array) {
        if (inputs.length !== 1) {
          throw new ValueError(
            'GPT2 requires each input feature to contain only ' +
            `one segment, but received ${inputs.length}. If you are using ` +
            'GPT2 for a multi-segment classification task, please refer to ' +
            'classification models like BERT or RoBERTa.'
          );
        }
        inputs = inputs[0];
      }

      const sequenceLength = kwargs.sequenceLength ?? this.sequenceLength;
      const [tokenIds, paddingMask] = this.packer.callAndReturnPaddingMask(
        this.tokenizer.call(inputs),
        {
          sequenceLength,
          addStartValue: this.addStartToken,
          addEndValue: this.addEndToken
        }
      );

      return {
        tokenIds: tokenIds as Tensor2D,
        paddingMask: paddingMask as Tensor2D
      };
    });
  }

  /**
   * Calls the layer and returns extra information like the paddingMask used
   * to pack the sequence, the label data, and the sample weights used.
   */
  callAndPackArgs(inputs: Tensor|Tensor[], kwargs: GPT2PreprocessorOptions):
      PreprocessorOutputs
      | [PreprocessorOutputs, Tensor]
      | [PreprocessorOutputs, Tensor, Tensor] {
    const x = this.callAndReturnPaddingMask(inputs, kwargs);
    return packXYSampleWeight(x, kwargs.y, kwargs.sampleWeight);
  }

  static override tokenizerCls<T extends serialization.Serializable>(
      cls: serialization.SerializableConstructor<T>) {
    return GPT2Tokenizer;
  }
}
serialization.registerClass(GPT2Preprocessor);
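Taken together with `packXYSampleWeight`, `callAndPackArgs` has three possible return shapes, depending on whether `y` and `sampleWeight` are passed. A minimal usage sketch, not part of the commit itself: it assumes the relative import paths above and reuses the toy vocabulary from the unit tests below.

```ts
import { Tensor, tensor } from '@tensorflow/tfjs-core';

import { GPT2Preprocessor, PreprocessorOutputs } from './gpt2_preprocessor';
import { GPT2Tokenizer } from './gpt2_tokenizer';

// Toy vocabulary and merges, copied from the unit tests below.
const vocabulary = new Map([
  ['!', 0], ['air', 1], ['Ġair', 2], ['plane', 3],
  ['Ġat', 4], ['port', 5], ['<|endoftext|>', 6],
]);
const merges = [
  'Ġ a', 'Ġ t', 'Ġ i', 'Ġ b', 'a i', 'p l', 'n e',
  'Ġa t', 'p o', 'r t', 'Ġt h', 'ai r', 'pl a', 'po rt',
  'Ġai r', 'Ġa i', 'pla ne',
];

const preprocessor = new GPT2Preprocessor({
  tokenizer: new GPT2Tokenizer({vocabulary, merges}),
  sequenceLength: 8,
});

// No labels: just the {tokenIds, paddingMask} dictionary.
const x = preprocessor.callAndPackArgs(
  tensor(['airplane at airport']), {}) as PreprocessorOutputs;
x.tokenIds.print();     // [[6, 1, 3, 4, 2, 5, 6, 0]]
x.paddingMask.print();  // [[1, 1, 1, 1, 1, 1, 1, 0]]

// With labels: y (and optionally sampleWeight) pass through unaltered,
// so the result is a tuple.
const [xs, ys] = preprocessor.callAndPackArgs(
  tensor(['airplane at airport']), {y: tensor([1])}) as
  [PreprocessorOutputs, Tensor];
ys.print();  // [1]
```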
tfjs-layers/src/layers/nlp/models/gpt2/gpt2_preprocessor_test.ts (140 additions, 0 deletions)
@@ -0,0 +1,140 @@
/**
 * @license
 * Copyright 2023 Google LLC.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================================
 */

/**
 * Unit Tests for GPT2Preprocessor.
 */

import { Tensor, memory, serialization, tensor, tensor2d } from '@tensorflow/tfjs-core';

import { GPT2Preprocessor, PreprocessorOutputs } from './gpt2_preprocessor';
import { GPT2Tokenizer } from './gpt2_tokenizer';
import { expectTensorsClose } from '../../../../utils/test_utils';

describe('GPT2Preprocessor', () => {
  let vocabulary: Map<string, number>;
  let merges: string[];
  let preprocessor: GPT2Preprocessor;

  beforeEach(() => {
    vocabulary = new Map([
      ['!', 0],
      ['air', 1],
      ['Ġair', 2],
      ['plane', 3],
      ['Ġat', 4],
      ['port', 5],
      ['<|endoftext|>', 6],
    ]);

    merges = ['Ġ a', 'Ġ t', 'Ġ i', 'Ġ b', 'a i', 'p l', 'n e'].concat(
      ['Ġa t', 'p o', 'r t', 'Ġt h', 'ai r', 'pl a', 'po rt'],
      ['Ġai r', 'Ġa i', 'pla ne']
    );
    preprocessor = new GPT2Preprocessor({
      tokenizer: new GPT2Tokenizer({vocabulary, merges}),
      sequenceLength: 8
    });
  });

  it('tokenize', () => {
    const inputData = tensor(['airplane at airport']);

    const output =
      preprocessor.callAndPackArgs(inputData, {}) as PreprocessorOutputs;

    expectTensorsClose(output.tokenIds, tensor2d([[6, 1, 3, 4, 2, 5, 6, 0]]));
    expectTensorsClose(
      output.paddingMask,
      tensor2d([[1, 1, 1, 1, 1, 1, 1, 0]], [1, 8], 'bool'));
  });

  it('no start end token', () => {
    const inputData = tensor(Array<string>(4).fill('airplane at airport'));
    preprocessor = new GPT2Preprocessor({
      tokenizer: new GPT2Tokenizer({vocabulary, merges}),
      sequenceLength: 8,
      addStartToken: false,
      addEndToken: false,
    });
    const expectedOutput = {
      tokenIds: tensor2d(Array<number[]>(4).fill([1, 3, 4, 2, 5, 0, 0, 0])),
      paddingMask: tensor2d(
        Array<number[]>(4).fill([1, 1, 1, 1, 1, 0, 0, 0]), [4, 8], 'bool'),
    };

    const output =
      preprocessor.callAndPackArgs(inputData, {}) as PreprocessorOutputs;

    expectTensorsClose(output.tokenIds, expectedOutput.tokenIds);
    expectTensorsClose(output.paddingMask, expectedOutput.paddingMask);
  });

  it('tokenize labeled batch', () => {
    const inputData = tensor(Array<string>(4).fill('airplane at airport'));
    const yIn = tensor([1, 1, 1, 1]);
    const swIn = tensor([1., 1., 1., 1.]);
    const expectedX = {
      tokenIds: tensor2d(Array<number[]>(4).fill([6, 1, 3, 4, 2, 5, 6, 0])),
      paddingMask: tensor2d(
        Array<number[]>(4).fill([1, 1, 1, 1, 1, 1, 1, 0]), [4, 8], 'bool'),
    };

    const output = preprocessor.callAndPackArgs(
      inputData, {y: yIn, sampleWeight: swIn}
    ) as [PreprocessorOutputs, Tensor, Tensor];

    expectTensorsClose(output[0].tokenIds, expectedX.tokenIds);
    expectTensorsClose(output[0].paddingMask, expectedX.paddingMask);
    expectTensorsClose(output[1], yIn);
    expectTensorsClose(output[2], swIn);
  });

  it('sequence length override', () => {
    const inputData = tensor(['airplane at airport']);

    const output = preprocessor.callAndPackArgs(
      inputData, {sequenceLength: 4}
    ) as PreprocessorOutputs;

    expectTensorsClose(output.tokenIds, tensor2d([[6, 1, 3, 6]]));
  });

  it('does not leak memory', () => {
    const inputData = tensor(['airplane at airport']);

    const numTensorsBefore = memory().numTensors;
    preprocessor.callAndPackArgs(inputData, {sequenceLength: 4});
    const numTensorsAfter = memory().numTensors;
    expect(numTensorsAfter).toEqual(numTensorsBefore + 2);
  });

  it('serialization round-trip', () => {
    const reserialized = GPT2Preprocessor.fromConfig(
      GPT2Preprocessor, preprocessor.getConfig());

    const originalConfig = preprocessor.getConfig();
    const reserializedConfig = reserialized.getConfig();

    // TODO(pforderique): Verify any tokenizer name consistency issues.
    delete ((originalConfig['tokenizer'] as serialization.ConfigDict
      )['config'] as serialization.ConfigDict)['name'];
    delete ((reserializedConfig['tokenizer'] as serialization.ConfigDict
      )['config'] as serialization.ConfigDict)['name'];

    expect(reserializedConfig).toEqual(originalConfig);
  });
});
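A note on the `does not leak memory` expectation: `callAndReturnPaddingMask` runs inside `tidy()`, so every intermediate tensor created during tokenization and packing is disposed, and exactly two tensors survive the call, the returned `tokenIds` and `paddingMask`. A minimal sketch of the same accounting outside the test harness, not part of the commit; it assumes a `preprocessor` built as in `beforeEach` above.

```ts
import { memory, tensor } from '@tensorflow/tfjs-core';

import { PreprocessorOutputs } from './gpt2_preprocessor';

// Create the input first so it is not counted in the delta below.
const inputData = tensor(['airplane at airport']);

const before = memory().numTensors;
const out =
  preprocessor.callAndPackArgs(inputData, {}) as PreprocessorOutputs;
// tidy() has already disposed the intermediates; only the outputs remain.
console.log(memory().numTensors - before);  // 2

// The caller owns the outputs and should dispose them when done.
out.tokenIds.dispose();
out.paddingMask.dispose();
inputData.dispose();
```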