[Local benchmark tool] Add parallel compile option (#7755)
FEATURE

* add parallel compile benchmark

* Update index.html
Linchenn committed Jun 27, 2023
1 parent f746187 commit 22196bb
Showing 2 changed files with 92 additions and 8 deletions.
57 changes: 57 additions & 0 deletions e2e/benchmarks/benchmark_util.js
@@ -280,6 +280,63 @@ async function timeInference(predict, numRuns = 1) {
  return timeInfo;
}

/**
 * Time the first model inference, optionally using the parallel compilation
 * feature of the current backend.
 *
 * The inference time includes the time spent by both `predict()` and the
 * `data()` calls on the tensors in the prediction result.
 *
 * ```js
 * // Benchmark the first inference time with parallel compilation.
 * const modelUrl =
 *     'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
 * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
 * const zeros = tf.zeros([1, 224, 224, 3]);
 * const firstInferenceTime =
 *     await timeFirstInference(() => model.predict(zeros), true);
 * ```
 *
 * @param predict The predict function to execute and time.
 * @param parallelCompile Whether to use parallel compilation. This currently
 *     only has an effect for the WebGL and WebGPU backends.
 * @returns The elapsed time of the first inference, in milliseconds.
 */
async function timeFirstInference(predict, parallelCompile = false) {
  const start = performance.now();

  // Parallel compile: run a compile-only pass first so that shader programs
  // are compiled in parallel before the timed inference below.
  if (parallelCompile && tf.getBackend() === 'webgl') {
    // With ENGINE_COMPILE_ONLY set, predict() only creates and compiles the
    // WebGL programs, without executing them.
    tf.env().set('ENGINE_COMPILE_ONLY', true);
    const compileRes = predict();
    tf.env().set('ENGINE_COMPILE_ONLY', false);
    // Wait for the parallel compilations to finish, then fetch the uniform
    // locations that program execution will need.
    await tf.backend().checkCompileCompletionAsync();
    tf.backend().getUniformLocations();
    tf.dispose(compileRes);
  } else if (parallelCompile && tf.getBackend() === 'webgpu') {
    // The equivalent compile-only pass for WebGPU; uniform locations are a
    // WebGL-only concept, so only the completion check is needed here.
    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
    const compileRes = predict();
    tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
    await tf.backend().checkCompileCompletionAsync();
    tf.dispose(compileRes);
  } else if (parallelCompile && isTflite()) {
    throw new Error('Parallel compilation for TFLite is not supported.');
  }

  // First inference
  let res = predict();
  if (parallelCompile && res instanceof Promise) {
    throw new Error(
        'Parallel compilation for async predict functions is not supported.');
  }
  res = await res;
  await downloadValuesFromTensorContainer(res);
  const elapsedTime = performance.now() - start;

  tf.dispose(res);
  return elapsedTime;
}
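
The same flag sequence can be lifted out of the benchmark harness to pre-warm a model in an application. A minimal sketch for the WebGL backend, using only the flag and backend methods that appear in `timeFirstInference` above (`model` and `input` are placeholder names; for WebGPU the flag is `WEBGPU_ENGINE_COMPILE_ONLY` and the `getUniformLocations()` step is dropped):

```js
// Sketch: pre-warm a tf.GraphModel with parallel shader compilation
// on the webgl backend, mirroring timeFirstInference() above.
async function prewarmWebGL(model, input) {
  tf.env().set('ENGINE_COMPILE_ONLY', true);
  const compileRes = model.predict(input);  // compiles programs, skips execution
  tf.env().set('ENGINE_COMPILE_ONLY', false);
  await tf.backend().checkCompileCompletionAsync();  // wait for parallel compiles
  tf.backend().getUniformLocations();
  tf.dispose(compileRes);  // compile-only outputs carry no real data
}
```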

/**
 * Downloads the values from the `tensorContainer` from any `tf.Tensor`s found
 * within the `tensorContainer`. Returns a promise of `TypedArray` or
43 changes: 35 additions & 8 deletions e2e/benchmarks/local-benchmark/index.html
@@ -94,6 +94,7 @@ <h2>TensorFlow.js Model Benchmark</h2>
let runTimes = 50;
let profileTimes = 1;
let warmupTimes = 1;
let parallelCompile = false;
// By default, do not run any task.
let task = '';
function getURLState(url) {
@@ -113,13 +114,17 @@ <h2>TensorFlow.js Model Benchmark</h2>
if (params.has('task')) {
task = params.get('task');
}
if (params.has('parallelCompile')) {
  // Note: Boolean() is truthy for any non-empty string, so
  // `?parallelCompile=false` also enables the option; omit the
  // parameter to leave it disabled.
  parallelCompile = Boolean(params.get('parallelCompile'));
}
return params;
}
// Controllers can be updated via the URI or by user clicks.
let modelController = null;
let numWarmupsController = null;
let numRunsController = null;
let numProfilesController = null;
let parallelCompileController = null;
let modelParameterFolder = null;
let runButton = null;
let correctnessButton = null;
@@ -189,6 +194,7 @@ <h2>TensorFlow.js Model Benchmark</h2>
numWarmups: warmupTimes,
numRuns: runTimes,
numProfiles: profileTimes,
parallelCompile: parallelCompile,
benchmark: 'MobileNetV3',
run: (v) => {
runBenchmark().catch(e => {
@@ -406,7 +412,27 @@ <h2>TensorFlow.js Model Benchmark</h2>
tbody.appendChild(tr);
}

- async function warmUpAndRecordTime() {
+ async function measureFirstPredictTime() {
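  // Measures and reports the first inference, which includes (parallel)
  // shader compilation; later warm-up runs are handled by warmUp() below.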
  let predictFn;
  if (state.benchmark === 'custom') {
    const input = generateInputFromDef(
        state.inputs, model instanceof tf.GraphModel);
    predictFn = isTflite() ?
        () => predict(undefined, input) :
        getPredictFnForModel(model, input);
  } else {
    predictFn = () => predict(model);
  }

  const firstInferenceTime = await timeFirstInference(
      predictFn, state.parallelCompile);
  const tag = '1st inference time' +
      (state.parallelCompile ? ' (parallel compile)' : '');
  appendRow(timeTable, 'backend', state.backend);
  appendRow(timeTable, tag, printTime(firstInferenceTime));
}

async function warmUp() {
  const numWarmups = state.numWarmups;
  if (numWarmups == 0) {
    await showMsg('Skip warming up');
@@ -416,23 +442,19 @@ <h2>TensorFlow.js Model Benchmark</h2>
  }
  await showMsg(`Warming up for ${numWarmups} time(s)`);

-  let timeInfo;
  if (state.benchmark === 'custom') {
    const input = generateInputFromDef(state.inputs, model instanceof tf.GraphModel);
    try {
-      timeInfo = isTflite() ?
+      isTflite() ?
          await timeInference(() => predict(undefined, input), numWarmups) :
          await timeModelInference(model, input, numWarmups);
    } finally {
      tf.dispose(input);
    }
  } else {
-    timeInfo = await timeInference(() => predict(model), numWarmups);
+    await timeInference(() => predict(model), numWarmups);
  }

  await showMsg(null);
-  appendRow(timeTable, 'backend', state.backend);
-  appendRow(timeTable, 'Warmup time', printTime(timeInfo.times[0]));
}

async function showInputs() {
@@ -771,7 +793,8 @@ <h2>TensorFlow.js Model Benchmark</h2>
}
await showBenchmarkingParameters();

-  await warmUpAndRecordTime();
+  await measureFirstPredictTime();
+  await warmUp();
await measureAveragePredictTime();
await profileMemoryAndKernelTime();
urlState = null;
@@ -919,6 +942,7 @@ <h2>TensorFlow.js Model Benchmark</h2>
numRunsController = parameterFolder.add(state, 'numRuns');
numProfilesController = parameterFolder.add(state, 'numProfiles');
parameterFolder.add(state, 'kernelTiming', ['aggregate', 'individual']);
parallelCompileController = parameterFolder.add(state, 'parallelCompile');
parameterFolder.open();

// Show model parameter UI when loading the page.
@@ -960,6 +984,8 @@ <h2>TensorFlow.js Model Benchmark</h2>
// Remove the "Test correctness" button which doesn't apply to tfjs-tflite.
correctnessButton.destroy();
correctnessButton = null;
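// The parallel compile option likewise doesn't apply to tfjs-tflite
// (timeFirstInference throws for it), so remove its toggle as well.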
parallelCompileController.destroy();
parallelCompileController = null;

// Update models drop down to only show models that support tflite.
const tfliteBenchmarks = Object.keys(benchmarks)
@@ -973,6 +999,7 @@ <h2>TensorFlow.js Model Benchmark</h2>
// tflite to non-tflite backend.
if (correctnessButton == null) {
correctnessButton = gui.add(state, 'testCorrectness');
parallelCompileController = parameterFolder.add(state, 'parallelCompile');
const allBenchmarks = Object.keys(benchmarks);
updateModelsDropdown(allBenchmarks);
modelController.setValue(allBenchmarks[0]);
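
Taken together, the new option can be toggled from the `parallelCompile` checkbox in the benchmark's parameter folder, or preset through the URL. A hypothetical example, assuming the local benchmark page is served from the repository root (host and port depend on your setup):

```
http://localhost:8080/e2e/benchmarks/local-benchmark/index.html?parallelCompile=true
```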
