Skip to content

Commit 0d31b72

Browse files
Support loading models from ModelArtifacts and ModelJSON in loadGraphModelSync (#6847)
Let users pass inputs of type ModelArtifacts or [ModelJSON, ArrayBuffer] to loadGraphModelSync instead of requiring them to construct a synchronous IOHandler. In the case of ModelArtifacts, the model weights are contained in the object. In the case of [ModelJSON, ArrayBuffer], the ArrayBuffer is a list of concatenated weights for the model. This change turns this code: const modelJson = JSON.parse(modelString); const weights = ... modelJson.weightData = weights; modelJson.weightSpecs = modelJson.weightsManifest[0].weights; const ioHandler = tf.io.fromMemorySync(modelJson); const model = tf.loadGraphModelSync(ioHandler); ... into this simpler code const modelJson = JSON.parse(modelString); const weights = ... const model = tf.loadGraphModelSync([modelJson, weights]);
1 parent feb0eeb commit 0d31b72

File tree

5 files changed

+173
-29
lines changed

5 files changed

+173
-29
lines changed

tfjs-converter/src/executor/graph_model.ts

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -499,23 +499,57 @@ export async function loadGraphModel(
499499
/**
500500
* Load a graph model given a synchronous IO handler with a 'load' method.
501501
*
502-
* @param modelSource The `io.IOHandlerSync` that loads the model.
502+
* @param modelSource The `io.IOHandlerSync` that loads the model, or the
503+
* `io.ModelArtifacts` that encode the model, or a tuple of
504+
* `[io.ModelJSON, ArrayBuffer]` of which the first element encodes the
505+
* model and the second contains the weights.
503506
*
504507
* @doc {heading: 'Models', subheading: 'Loading'}
505508
*/
509+
export function loadGraphModelSync(modelSource: io.IOHandlerSync
510+
| io.ModelArtifacts | [io.ModelJSON, /* Weights */ ArrayBuffer]):
511+
GraphModel<io.IOHandlerSync> {
506512

507-
export function loadGraphModelSync(modelSource: io.IOHandlerSync):
508-
GraphModel<io.IOHandlerSync> {
509513
if (modelSource == null) {
510514
throw new Error(
511-
'modelUrl in loadGraphModelSync() cannot be null. Please provide a ' +
512-
'url or an IOHandler that loads the model');
515+
'modelUrl in loadGraphModelSync() cannot be null. Please provide ' +
516+
'model artifacts or an IOHandler that loads the model');
513517
}
514-
if (!modelSource.load) {
515-
throw new Error(`modelUrl IO Handler ${modelSource} has no load function`);
518+
519+
let ioHandler: io.IOHandlerSync;
520+
if (modelSource instanceof Array) {
521+
const [modelJSON, weights] = modelSource;
522+
if (!modelJSON) {
523+
throw new Error('modelJSON must be the first element of the array');
524+
}
525+
if (!weights || !(weights instanceof ArrayBuffer)) {
526+
throw new Error('An ArrayBuffer of weights must be the second element of'
527+
+ ' the array');
528+
}
529+
if (!('modelTopology' in modelJSON)) {
530+
throw new Error('Model JSON is missing \'modelTopology\'');
531+
}
532+
if (!('weightsManifest' in modelJSON)) {
533+
throw new Error('Model JSON is missing \'weightsManifest\'');
534+
}
535+
536+
const weightSpecs = io.getWeightSpecs(modelJSON.weightsManifest);
537+
const modelArtifacts = io.getModelArtifactsForJSONSync(modelJSON,
538+
weightSpecs,
539+
weights);
540+
ioHandler = io.fromMemorySync(modelArtifacts);
541+
} else if ('load' in modelSource) {
542+
// Then modelSource is already an IOHandlerSync.
543+
ioHandler = modelSource;
544+
} else if ('modelTopology' in modelSource && 'weightSpecs' in modelSource
545+
&& 'weightData' in modelSource) {
546+
// modelSource is of type ModelArtifacts.
547+
ioHandler = io.fromMemorySync(modelSource);
548+
} else {
549+
throw new Error('Unknown model format');
516550
}
517-
const model = new GraphModel(modelSource);
518551

552+
const model = new GraphModel(ioHandler);
519553
model.load();
520554
return model;
521555
}

tfjs-converter/src/executor/graph_model_test.ts

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,13 @@ describe('loadGraphModel', () => {
425425
});
426426

427427
describe('loadGraphModelSync', () => {
428+
function checkModel(model: GraphModel<unknown>) {
429+
expect(model).toBeDefined();
430+
const bias = model.weights['Const'][0];
431+
expect(bias.dtype).toBe('int32');
432+
expect(bias.dataSync()).toEqual(new Int32Array([5]));
433+
}
434+
428435
it('Pass a custom io handler', () => {
429436
const customLoader: tfc.io.IOHandlerSync = {
430437
load: () => {
@@ -436,10 +443,69 @@ describe('loadGraphModelSync', () => {
436443
}
437444
};
438445
const model = loadGraphModelSync(customLoader);
439-
expect(model).toBeDefined();
440-
const bias = model.weights['Const'][0];
441-
expect(bias.dtype).toBe('int32');
442-
expect(bias.dataSync()).toEqual(new Int32Array([5]));
446+
checkModel(model);
447+
});
448+
449+
it('Pass the model artifacts directly', () => {
450+
const modelArtifacts: tfc.io.ModelArtifacts = {
451+
modelTopology: SIMPLE_MODEL,
452+
weightSpecs: weightsManifest,
453+
weightData: new Int32Array([5]).buffer,
454+
};
455+
const model = loadGraphModelSync(modelArtifacts);
456+
checkModel(model);
457+
});
458+
459+
it('Pass the model JSON and weights', () => {
460+
const modelJson: tfc.io.ModelJSON = {
461+
modelTopology: SIMPLE_MODEL,
462+
weightsManifest: [{paths: [], weights: weightsManifest}],
463+
};
464+
const weights = new Int32Array([5]).buffer;
465+
const model = loadGraphModelSync([modelJson, weights]);
466+
checkModel(model);
467+
});
468+
469+
it('Throws an error if ModelJSON is passed without weights', () => {
470+
const modelJson: tfc.io.ModelJSON = {
471+
modelTopology: SIMPLE_MODEL,
472+
weightsManifest: [{paths: [], weights: weightsManifest}],
473+
};
474+
expect(() => {
475+
return loadGraphModelSync([modelJson] as unknown as [io.ModelJSON,
476+
ArrayBuffer]);
477+
}).toThrowMatching(err =>
478+
err.message.includes('weights must be the second element'));
479+
});
480+
481+
it('Throws an error if modelJSON is missing \'modelTopology\'', () => {
482+
const badInput = {
483+
weightsManifest: [{paths: [] as string[], weights: weightsManifest}],
484+
};
485+
const weights = new Int32Array([5]).buffer;
486+
expect(() => {
487+
return loadGraphModelSync([badInput as io.ModelJSON, weights]);
488+
}).toThrowMatching(err =>
489+
err.message.includes('missing \'modelTopology\''));
490+
});
491+
492+
it('Throws an error if modelJSON is missing \'weightsManifest\'', () => {
493+
const badInput = {
494+
modelTopology: SIMPLE_MODEL,
495+
};
496+
const weights = new Int32Array([5]).buffer;
497+
expect(() => {
498+
return loadGraphModelSync([badInput as io.ModelJSON, weights]);
499+
}).toThrowMatching(err =>
500+
err.message.includes('missing \'weightsManifest\''));
501+
});
502+
503+
it('Throws an error if modelSource is an unknown format', () => {
504+
const badInput = {foo: 'bar'};
505+
expect(() => {
506+
return loadGraphModelSync(badInput as io.ModelArtifacts);
507+
}).toThrowMatching(err =>
508+
err.message.includes('Unknown model format'));
443509
});
444510

445511
it('Expect an error when moderUrl is null', () => {

tfjs-core/src/io/http.ts

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import {env} from '../environment';
2525

2626
import {assert} from '../util';
27-
import {concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts} from './io_utils';
27+
import {concatenateArrayBuffers, getModelArtifactsForJSON, getModelArtifactsInfoForJSON, getModelJSONForModelArtifacts, getWeightSpecs} from './io_utils';
2828
import {IORouter, IORouterRegistry} from './router_registry';
2929
import {IOHandler, LoadOptions, ModelArtifacts, ModelJSON, OnProgressCallback, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types';
3030
import {loadWeightsAsArrayBuffer} from './weights_loader';
@@ -187,10 +187,7 @@ export class HTTPRequest implements IOHandler {
187187
const [prefix, suffix] = parseUrl(weightPath);
188188
const pathPrefix = this.weightPathPrefix || prefix;
189189

190-
const weightSpecs = [];
191-
for (const entry of weightsManifest) {
192-
weightSpecs.push(...entry.weights);
193-
}
190+
const weightSpecs = getWeightSpecs(weightsManifest);
194191

195192
const fetchURLs: string[] = [];
196193
const urlPromises: Array<Promise<string>> = [];

tfjs-core/src/io/io.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import './local_storage';
2222

2323
import {browserFiles} from './browser_files';
2424
import {browserHTTPRequest, http, isHTTPScheme} from './http';
25-
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsInfoForJSON} from './io_utils';
25+
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsForJSON, getModelArtifactsForJSONSync, getModelArtifactsInfoForJSON, getWeightSpecs} from './io_utils';
2626
import {fromMemory, fromMemorySync, withSaveHandler, withSaveHandlerSync} from './passthrough';
2727
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
2828
import {IOHandler, IOHandlerSync, LoadHandler, LoadOptions, ModelArtifacts, ModelArtifactsInfo, ModelJSON, ModelStoreManager, OnProgressCallback, RequestDetails, SaveConfig, SaveHandler, SaveResult, TrainingConfig, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
@@ -39,8 +39,10 @@ export {
3939
fromMemorySync,
4040
getLoadHandlers,
4141
getModelArtifactsForJSON,
42+
getModelArtifactsForJSONSync,
4243
getModelArtifactsInfoForJSON,
4344
getSaveHandlers,
45+
getWeightSpecs,
4446
http,
4547
IOHandler,
4648
IOHandlerSync,

tfjs-core/src/io/io_utils.ts

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -403,19 +403,20 @@ export function getModelJSONForModelArtifacts(
403403
}
404404

405405
/**
406-
* Create `ModelArtifacts` from a JSON file.
406+
* Create `ModelArtifacts` from a JSON file and weights.
407407
*
408408
* @param modelJSON Object containing the parsed JSON of `model.json`
409-
* @param loadWeights Function that takes the JSON file's weights manifest,
410-
* reads weights from the listed path(s), and returns a Promise of the
411-
* weight manifest entries along with the weights data.
409+
* @param weightSpecs The list of WeightsManifestEntry for the model. Must be
410+
* passed if the modelJSON has a weightsManifest.
411+
* @param weightData An ArrayBuffer of weight data for the model corresponding
412+
* to the weights in weightSpecs. Must be passed if the modelJSON has a
413+
* weightsManifest.
412414
* @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
413415
*/
414-
export async function getModelArtifactsForJSON(
415-
modelJSON: ModelJSON,
416-
loadWeights: (weightsManifest: WeightsManifestConfig) => Promise<[
417-
/* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer
418-
]>): Promise<ModelArtifacts> {
416+
export function getModelArtifactsForJSONSync(
417+
modelJSON: ModelJSON, weightSpecs?: WeightsManifestEntry[],
418+
weightData?: ArrayBuffer): ModelArtifacts {
419+
419420
const modelArtifacts: ModelArtifacts = {
420421
modelTopology: modelJSON.modelTopology,
421422
format: modelJSON.format,
@@ -427,8 +428,12 @@ export async function getModelArtifactsForJSON(
427428
modelArtifacts.trainingConfig = modelJSON.trainingConfig;
428429
}
429430
if (modelJSON.weightsManifest != null) {
430-
const [weightSpecs, weightData] =
431-
await loadWeights(modelJSON.weightsManifest);
431+
if (!weightSpecs) {
432+
throw new Error('modelJSON has weightsManifest but weightSpecs is null');
433+
}
434+
if (!weightData) {
435+
throw new Error('modelJSON has weightsManifest but weightData is null');
436+
}
432437
modelArtifacts.weightSpecs = weightSpecs;
433438
modelArtifacts.weightData = weightData;
434439
}
@@ -445,6 +450,30 @@ export async function getModelArtifactsForJSON(
445450
return modelArtifacts;
446451
}
447452

453+
/**
454+
* Create `ModelArtifacts` from a JSON file.
455+
*
456+
* @param modelJSON Object containing the parsed JSON of `model.json`
457+
* @param loadWeights Function that takes the JSON file's weights manifest,
458+
* reads weights from the listed path(s), and returns a Promise of the
459+
* weight manifest entries along with the weights data.
460+
* @returns A Promise of the `ModelArtifacts`, as described by the JSON file.
461+
*/
462+
export async function getModelArtifactsForJSON(
463+
modelJSON: ModelJSON,
464+
loadWeights: (weightsManifest: WeightsManifestConfig) => Promise<[
465+
/* weightSpecs */ WeightsManifestEntry[], /* weightData */ ArrayBuffer
466+
]>): Promise<ModelArtifacts> {
467+
let weightSpecs: WeightsManifestEntry[] | undefined;
468+
let weightData: ArrayBuffer | undefined;
469+
470+
if (modelJSON.weightsManifest != null) {
471+
[weightSpecs, weightData] = await loadWeights(modelJSON.weightsManifest);
472+
}
473+
474+
return getModelArtifactsForJSONSync(modelJSON, weightSpecs, weightData);
475+
}
476+
448477
/**
449478
* Populate ModelArtifactsInfo fields for a model with JSON topology.
450479
* @param modelArtifacts
@@ -471,6 +500,22 @@ export function getModelArtifactsInfoForJSON(modelArtifacts: ModelArtifacts):
471500
};
472501
}
473502

503+
/**
504+
* Concatenate the weights stored in a WeightsManifestConfig into a list of
505+
* WeightsManifestEntry
506+
*
507+
* @param weightsManifest The WeightsManifestConfig to extract weights from.
508+
* @returns A list of WeightsManifestEntry of the weights in the weightsManifest
509+
*/
510+
export function getWeightSpecs(weightsManifest: WeightsManifestConfig):
511+
WeightsManifestEntry[] {
512+
const weightSpecs: WeightsManifestEntry[] = [];
513+
for (const entry of weightsManifest) {
514+
weightSpecs.push(...entry.weights);
515+
}
516+
return weightSpecs;
517+
}
518+
474519
/**
475520
* Computes mantisa table for casting Float16 to Float32
476521
* See http://www.fox-toolkit.org/ftp/fasthalffloatconversion.pdf

0 commit comments

Comments
 (0)