diff --git a/e2e/benchmarks/benchmark_util.js b/e2e/benchmarks/benchmark_util.js
index 2f2db9b2c45..414489d780b 100644
--- a/e2e/benchmarks/benchmark_util.js
+++ b/e2e/benchmarks/benchmark_util.js
@@ -280,6 +280,63 @@ async function timeInference(predict, numRuns = 1) {
return timeInfo;
}
+/**
+ * Time the first model inference, optionally using the parallel compilation
+ * feature of the current backend.
+ *
+ * The inference time includes the time spent by both `predict()` and the
+ * `data()` calls on the tensors in the prediction.
+ *
+ * ```js
+ * // Benchmark the first inference time with parallel compilation.
+ * const modelUrl =
+ * 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2';
+ * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true});
+ * const zeros = tf.zeros([1, 224, 224, 3]);
+ * const firstInferenceTime =
+ * await timeFirstInference(() => model.predict(zeros), true);
+ * ```
+ *
+ * @param predict The predict function to execute and time.
+ * @param parallelCompile Whether to use parallel compilation. This currently
+ *     only takes effect for the WebGL and WebGPU backends.
+ */
+async function timeFirstInference(predict, parallelCompile = false) {
+ const start = performance.now();
+
+  // Parallel compilation
+ if (parallelCompile && tf.getBackend() === 'webgl') {
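+    // Compile-only pass: with ENGINE_COMPILE_ONLY set, predict() compiles
+    // the WebGL programs in parallel instead of running them; the code then
+    // waits for compilation to finish, caches the uniform locations, and
+    // disposes the unread outputs.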
+ tf.env().set('ENGINE_COMPILE_ONLY', true);
+ const compileRes = predict();
+ tf.env().set('ENGINE_COMPILE_ONLY', false);
+ await tf.backend().checkCompileCompletionAsync();
+ tf.backend().getUniformLocations();
+ tf.dispose(compileRes);
+ } else if (parallelCompile && tf.getBackend() === 'webgpu') {
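+    // The equivalent compile-only pass for WebGPU pipelines.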
+ tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true);
+ const compileRes = predict();
+ tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false);
+ await tf.backend().checkCompileCompletionAsync();
+ tf.dispose(compileRes);
+ } else if (parallelCompile && isTflite()) {
+    throw new Error('Parallel compilation for TFLite is not supported.');
+ }
+
+ // First inference
+ let res = predict();
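+  // Parallel compilation only supports synchronous predict functions.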
+ if (parallelCompile && res instanceof Promise) {
+    throw new Error(
+        'Parallel compilation for async functions is not supported.');
+ }
+ res = await res;
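+  // Download the output values so the measured time covers both predict()
+  // and data(), as documented above.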
+ await downloadValuesFromTensorContainer(res);
+ const elapsedTime = performance.now() - start;
+
+ tf.dispose(res);
+ return elapsedTime;
+}
+
/**
* Downloads the values from the `tensorContainer` from any `tf.Tensor`s found
* within the `tensorContainer`. Returns a promise of `TypedArray` or
diff --git a/e2e/benchmarks/local-benchmark/index.html b/e2e/benchmarks/local-benchmark/index.html
index 9b9e3bdeb63..b5f3be09d42 100644
--- a/e2e/benchmarks/local-benchmark/index.html
+++ b/e2e/benchmarks/local-benchmark/index.html
@@ -94,6 +94,7 @@
TensorFlow.js Model Benchmark
let runTimes = 50;
let profileTimes = 1;
let warmupTimes = 1;
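+ // Whether to benchmark the first inference with parallel compilation.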
+ let parallelCompile = false;
// Default do not run any task.
let task = '';
function getURLState(url) {
@@ -113,6 +114,9 @@ TensorFlow.js Model Benchmark
if (params.has('task')) {
task = params.get('task');
}
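+ // e.g. append "?parallelCompile=true" to the URL to enable it.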
+ if (params.has('parallelCompile')) {
+ parallelCompile = params.get('parallelCompile') === 'true';
+ }
return params;
}
// Controllers can be updated by URI or clicked by user.
@@ -120,6 +124,7 @@ TensorFlow.js Model Benchmark
let numWarmupsController = null;
let numRunsController = null;
let numProfilesController = null;
+ let parallelCompileController = null;
let modelParameterFolder = null;
let runButton = null;
let correctnessButton = null;
@@ -189,6 +194,7 @@ TensorFlow.js Model Benchmark
numWarmups: warmupTimes,
numRuns: runTimes,
numProfiles: profileTimes,
+ parallelCompile: parallelCompile,
benchmark: 'MobileNetV3',
run: (v) => {
runBenchmark().catch(e => {
@@ -406,7 +412,27 @@ TensorFlow.js Model Benchmark
tbody.appendChild(tr);
}
- async function warmUpAndRecordTime() {
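+ // Time the first inference (which includes compilation) and report it
+ // separately from the warm-up runs.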
+ async function measureFirstPredictTime() {
+   let predictFn;
+   let input = null;
+   if (state.benchmark === 'custom') {
+     input = generateInputFromDef(
+         state.inputs, model instanceof tf.GraphModel);
+     predictFn = isTflite() ?
+         () => predict(undefined, input) :
+         getPredictFnForModel(model, input);
+   } else {
+     predictFn = () => predict(model);
+   }
+
+   try {
+     const firstInferenceTime = await timeFirstInference(
+         predictFn, state.parallelCompile);
+     const tag = '1st inference time' +
+         (state.parallelCompile ? ' (parallel compile)' : '');
+     appendRow(timeTable, 'backend', state.backend);
+     appendRow(timeTable, tag, printTime(firstInferenceTime));
+   } finally {
+     // Dispose the generated custom-model input, mirroring warmUp().
+     if (input != null) {
+       tf.dispose(input);
+     }
+   }
+ }
+
+ async function warmUp() {
const numWarmups = state.numWarmups;
if (numWarmups == 0) {
await showMsg('Skip warming up');
@@ -416,23 +442,19 @@ TensorFlow.js Model Benchmark
}
await showMsg(`Warming up for ${numWarmups} time(s)`);
- let timeInfo;
if (state.benchmark === 'custom') {
const input = generateInputFromDef(state.inputs, model instanceof tf.GraphModel);
try {
- timeInfo = isTflite() ?
-     await timeInference(() => predict(undefined, input), numWarmups) :
-     await timeModelInference(model, input, numWarmups);
+ if (isTflite()) {
+   await timeInference(() => predict(undefined, input), numWarmups);
+ } else {
+   await timeModelInference(model, input, numWarmups);
+ }
} finally {
tf.dispose(input);
}
} else {
- timeInfo = await timeInference(() => predict(model), numWarmups);
+ await timeInference(() => predict(model), numWarmups);
}
-
await showMsg(null);
- appendRow(timeTable, 'backend', state.backend);
- appendRow(timeTable, 'Warmup time', printTime(timeInfo.times[0]));
}
async function showInputs() {
@@ -771,7 +793,8 @@ TensorFlow.js Model Benchmark
}
await showBenchmarkingParameters();
- await warmUpAndRecordTime();
+ await measureFirstPredictTime();
+ await warmUp();
await measureAveragePredictTime();
await profileMemoryAndKernelTime();
urlState = null;
@@ -919,6 +942,7 @@ TensorFlow.js Model Benchmark
numRunsController = parameterFolder.add(state, 'numRuns');
numProfilesController = parameterFolder.add(state, 'numProfiles');
parameterFolder.add(state, 'kernelTiming', ['aggregate', 'individual']);
+ parallelCompileController = parameterFolder.add(state, 'parallelCompile');
parameterFolder.open();
// Show model parameter UI when loading the page.
@@ -960,6 +984,8 @@ TensorFlow.js Model Benchmark
// Remove the "Test correctness" button which doesn't apply to tfjs-tflite.
correctnessButton.destroy();
correctnessButton = null;
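+ // Parallel compilation isn't supported for tfjs-tflite, so remove its
+ // toggle as well.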
+ parallelCompileController.destroy();
+ parallelCompileController = null;
// Update models drop down to only show models that support tflite.
const tfliteBenchmarks = Object.keys(benchmarks)
@@ -973,6 +999,7 @@ TensorFlow.js Model Benchmark
// tflite to non-tflite backend.
if (correctnessButton == null) {
correctnessButton = gui.add(state, 'testCorrectness');
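+ // Restore the parallel compilation toggle when leaving tflite.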
+ parallelCompileController = parameterFolder.add(state, 'parallelCompile');
const allBenchmarks = Object.keys(benchmarks);
updateModelsDropdown(allBenchmarks);
modelController.setValue(allBenchmarks[0]);