diff --git a/e2e/benchmarks/benchmark_util.js b/e2e/benchmarks/benchmark_util.js index 2f2db9b2c45..414489d780b 100644 --- a/e2e/benchmarks/benchmark_util.js +++ b/e2e/benchmarks/benchmark_util.js @@ -280,6 +280,63 @@ async function timeInference(predict, numRuns = 1) { return timeInfo; } +/** + * Time one model inference with parallel compilation feature, based on the + * current backend. + * + * The inference time contains the time spent by both `predict()` and `data()` + * called by tensors in the prediction. + * + * ```js + * // Benchmark the first inference time with parallel compilation. + * const modelUrl = + * 'https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/2'; + * const model = await tf.loadGraphModel(modelUrl, {fromTFHub: true}); + * const zeros = tf.zeros([1, 224, 224, 3]); + * const firstInferenceTime = + * await timeFirstInference(() => model.predict(zeros), true); + * ``` + * + * @param predict The predict function to execute and time for. + * @param parallelCompile The boolean value to indicate whether to use parallel + * compilation. This currently only has effect for WebGL backend and WebGPU + * backend. 
+ */ +async function timeFirstInference(predict, parallelCompile = false) { + const start = performance.now(); + + // Parallel Compile + if (parallelCompile && tf.getBackend() === 'webgl') { + tf.env().set('ENGINE_COMPILE_ONLY', true); + const compileRes = predict(); + tf.env().set('ENGINE_COMPILE_ONLY', false); + await tf.backend().checkCompileCompletionAsync(); + tf.backend().getUniformLocations(); + tf.dispose(compileRes); + } else if (parallelCompile && tf.getBackend() === 'webgpu') { + tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', true); + const compileRes = predict(); + tf.env().set('WEBGPU_ENGINE_COMPILE_ONLY', false); + await tf.backend().checkCompileCompletionAsync(); + tf.dispose(compileRes); + } else if (parallelCompile && isTflite()) { + throw new Error('Parallel Compilation for TFlite is not supported.'); + } + + // First inference + let res = predict(); + if (parallelCompile && res instanceof Promise) { + throw new Error( + 'Parallel Compilation for async function is not supported.'); + } + res = await res; + await downloadValuesFromTensorContainer(res); + const elapsedTime = performance.now() - start; + + tf.dispose(res); + return elapsedTime; +} + /** * Downloads the values from the `tensorContainer` from any `tf.Tensor`s found * within the `tensorContainer`. Returns a promise of `TypedArray` or diff --git a/e2e/benchmarks/local-benchmark/index.html b/e2e/benchmarks/local-benchmark/index.html index 9b9e3bdeb63..b5f3be09d42 100644 --- a/e2e/benchmarks/local-benchmark/index.html +++ b/e2e/benchmarks/local-benchmark/index.html @@ -94,6 +94,7 @@

TensorFlow.js Model Benchmark

let runTimes = 50; let profileTimes = 1; let warmupTimes = 1; + let parallelCompile = false; // Default do not run any task. let task = ''; function getURLState(url) { @@ -113,6 +114,9 @@

TensorFlow.js Model Benchmark

if (params.has('task')) { task = params.get('task'); } + if (params.has('parallelCompile')) { + parallelCompile = Boolean(params.get('parallelCompile')); + } return params; } // Controllers can be updated by URI or clicked by user. @@ -120,6 +124,7 @@

TensorFlow.js Model Benchmark

let numWarmupsController = null; let numRunsController = null; let numProfilesController = null; + let parallelCompileController = null; let modelParameterFolder = null; let runButton = null; let correctnessButton = null; @@ -189,6 +194,7 @@

TensorFlow.js Model Benchmark

numWarmups: warmupTimes, numRuns: runTimes, numProfiles: profileTimes, + parallelCompile: parallelCompile, benchmark: 'MobileNetV3', run: (v) => { runBenchmark().catch(e => {

TensorFlow.js Model Benchmark

tbody.appendChild(tr); } - async function warmUpAndRecordTime() { + async function measureFirstPredictTime() { + let predictFn; + if (state.benchmark === 'custom') { + const input = generateInputFromDef( + state.inputs, model instanceof tf.GraphModel); + predictFn = isTflite() ? + () => predict(undefined, input) : + getPredictFnForModel(model, input); + } else { + predictFn = () => predict(model); + } + + const firstInferenceTime = await timeFirstInference( + predictFn, state.parallelCompile); + const tag = '1st inference time' + + (state.parallelCompile ? ' (parallel compile)' : ''); + appendRow(timeTable, 'backend', state.backend); + appendRow(timeTable, tag, printTime(firstInferenceTime)); + } + + async function warmUp() { const numWarmups = state.numWarmups; if (numWarmups == 0) { await showMsg('Skip warming up'); @@ -416,23 +442,19 @@

TensorFlow.js Model Benchmark

} await showMsg(`Warming up for ${numWarmups} time(s)`); - let timeInfo; if (state.benchmark === 'custom') { const input = generateInputFromDef(state.inputs, model instanceof tf.GraphModel); try { - timeInfo = isTflite() ? + isTflite() ? await timeInference(() => predict(undefined, input), numWarmups) : await timeModelInference(model, input, numWarmups); } finally { tf.dispose(input); } } else { - timeInfo = await timeInference(() => predict(model), numWarmups); + await timeInference(() => predict(model), numWarmups); } - await showMsg(null); - appendRow(timeTable, 'backend', state.backend); - appendRow(timeTable, 'Warmup time', printTime(timeInfo.times[0])); } async function showInputs() { @@ -771,7 +793,8 @@

TensorFlow.js Model Benchmark

} await showBenchmarkingParameters(); - await warmUpAndRecordTime(); + await measureFirstPredictTime(); + await warmUp(); await measureAveragePredictTime(); await profileMemoryAndKernelTime(); urlState = null; @@ -919,6 +942,7 @@

TensorFlow.js Model Benchmark

numRunsController = parameterFolder.add(state, 'numRuns'); numProfilesController = parameterFolder.add(state, 'numProfiles'); parameterFolder.add(state, 'kernelTiming', ['aggregate', 'individual']); + parallelCompileController = parameterFolder.add(state, 'parallelCompile'); parameterFolder.open(); // Show model parameter UI when loading the page. @@ -960,6 +984,8 @@

TensorFlow.js Model Benchmark

// Remove the "Test correctness" button which doesn't apply to tfjs-tflite. correctnessButton.destroy(); correctnessButton = null; + parallelCompileController.destroy(); + parallelCompileController = null; // Update models drop down to only show models that support tflite. const tfliteBenchmarks = Object.keys(benchmarks) @@ -973,6 +999,7 @@

TensorFlow.js Model Benchmark

// tflite to non-tflite backend. if (correctnessButton == null) { correctnessButton = gui.add(state, 'testCorrectness'); + parallelCompileController = parameterFolder.add(state, 'parallelCompile'); const allBenchmarks = Object.keys(benchmarks); updateModelsDropdown(allBenchmarks); modelController.setValue(allBenchmarks[0]);