From 5f66ddd0f97e12634f0655dd339490198e2d0078 Mon Sep 17 00:00:00 2001 From: Amit Moryossef Date: Fri, 7 Apr 2023 21:10:36 +0200 Subject: [PATCH 1/3] feat(pix2pix): add inference batching --- src/app/modules/pix2pix/pix2pix.model.ts | 64 +++++++++++-------- .../human-pose-viewer.component.ts | 14 ++-- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/src/app/modules/pix2pix/pix2pix.model.ts b/src/app/modules/pix2pix/pix2pix.model.ts index 21ada41b..e19e3f0f 100644 --- a/src/app/modules/pix2pix/pix2pix.model.ts +++ b/src/app/modules/pix2pix/pix2pix.model.ts @@ -1,7 +1,9 @@ -import type {Tensor, Tensor3D} from '@tensorflow/tfjs'; +import type {Tensor, Tensor4D} from '@tensorflow/tfjs'; import type {LayersModel} from '@tensorflow/tfjs-layers'; import {loadTFDS} from '../../core/services/tfjs/tfjs.loader'; +type Image = ImageBitmap | ImageData; + class ModelNotLoadedError extends Error { constructor() { super('Model not loaded'); @@ -62,7 +64,7 @@ async function loadGeneratorModel(layers: Map) { lstmSequence[2].config.recurrent_initializer.class_name = 'Zeros'; // Make model stateful lstmSequence[0].config.batch_input_shape[0] = 1; - lstmSequence[0].config.batch_input_shape[1] = 1; + // lstmSequence[0].config.batch_input_shape[1] = 1; lstmSequence[2].config.stateful = true; // Convert JSON to blob @@ -102,50 +104,63 @@ function removeGreenScreen(data: Uint8ClampedArray): Uint8ClampedArray { return data; } -let queuePromise: Promise = Promise.resolve(); +let queuePromise: Promise = Promise.resolve([null]); let globalQueueId = 0; +let imageQueue: Image[] = []; +// const BATCH_TIMEOUT = 50; // Adjust this based on the desired maximum waiting time for new images (in ms) + export async function translateQueue(queueId: number, image: ImageBitmap | ImageData): Promise { globalQueueId = queueId; + imageQueue.push(image); - const tensor = await translate(image); // Lazy tensor evaluation const tf = await tfPromise; + // Adjust this based on your hardware and performance requirements + const MAX_BATCH_SIZE = tf.getBackend() === 'webgpu' ? 8 : 1; + // Chain the model evaluation per frame - queuePromise = queuePromise.then(() => { + queuePromise = queuePromise.then(async (lastImages: any[]) => { if (globalQueueId !== queueId) { return null; } - return tensor.buffer(); // 60-70ms + // Remove the oldest image from the results + lastImages.shift(); + if (lastImages.length > 0) { + return lastImages; + } + + const now = performance.now(); + + // Dequeue images up to max batch size + const randomBatchSize = Math.floor(Math.random() * MAX_BATCH_SIZE) + 1; + const images = imageQueue.splice(0, randomBatchSize); + + const lazyTensor = await translate(images); + const buffer = await lazyTensor.buffer(); // 60-70ms + const tensor = buffer.toTensor(); + console.log('Batching', images.length, (performance.now() - now).toFixed(1), 'ms'); + return tensor.unstack(); }); - const imageBuffer = await queuePromise; - let outputImage = await tf.browser.toPixels(imageBuffer.toTensor()); // ~1-3ms + const outputImages = await queuePromise; // Get first image from batch + if (outputImages === null) { + return null; + } + let outputImage = await tf.browser.toPixels(outputImages[0]); // ~1-3ms outputImage = removeGreenScreen(outputImage); // ~0.1-0.2ms return outputImage; } -const frameTimes = []; -let lastFrameTime = null; - function upscale(tensor: Tensor) { return (upscaler.predict(tensor) as Tensor) .depthToSpace(3, 'NHWC') // Could not convert the depthToSpace operation to tfjs, must use this instead .clipByValue(0, 1); // Clipping to [0, 1] as upscale model may output values greater than 1 } -async function translate(image: ImageBitmap | ImageData): Promise { - if (lastFrameTime) { - frameTimes.push(Date.now() - lastFrameTime); - if (frameTimes.length > 20) { - const totalTime = frameTimes.slice(frameTimes.length - 20).reduce((a, b) => a + b, 0); - console.log('average', (totalTime / 20).toFixed(1), 'ms'); - } - } - lastFrameTime = Date.now(); - +async function translate(images: Image[]): Promise { if (!model) { throw new ModelNotLoadedError(); } @@ -153,10 +168,10 @@ async function translate(image: ImageBitmap | ImageData): Promise { const tf = await tfPromise; return tf.tidy(() => { - const pixels = tf.browser.fromPixels(image, 3); // 0.1-0.3ms + const pixels = tf.stack(images.map(image => tf.browser.fromPixels(image, 3))); // 0.1-0.3ms const pixelsTensor = pixels.toFloat(); const input = tf.sub(tf.div(pixelsTensor, tf.scalar(255 / 2)), tf.scalar(1)); // # Normalizing the images to [-1, 1] - const tensor = tf.reshape(input, [1, 1, ...input.shape]); // Add batch and time dimensions + const tensor = tf.stack([input]); // Add batch dimension, model expects Tensor5D // Must apply model in training=True mode to avoid using aggregated norm statistics let pred = model.apply(tensor, {training: true}) as Tensor; //6-8ms, but works @@ -167,7 +182,6 @@ async function translate(image: ImageBitmap | ImageData): Promise { pred = tf.squeeze(pred, [0]); // Remove time dimension pred = upscale(pred); - pred = tf.squeeze(pred); // Remove batch dimension - return pred as Tensor3D; + return pred as Tensor4D; }); } diff --git a/src/app/pages/translate/pose-viewers/human-pose-viewer/human-pose-viewer.component.ts b/src/app/pages/translate/pose-viewers/human-pose-viewer/human-pose-viewer.component.ts index ec21c953..3643b66c 100644 --- a/src/app/pages/translate/pose-viewers/human-pose-viewer/human-pose-viewer.component.ts +++ b/src/app/pages/translate/pose-viewers/human-pose-viewer/human-pose-viewer.component.ts @@ -51,6 +51,7 @@ export class HumanPoseViewerComponent extends BasePoseViewerComponent implements // To avoid communication time losses, we create a queue sent to be translated let queued = 0; + let nextFramePromise: Promise = Promise.resolve(); const iterFrame = async () => { // Verify element is not destroyed @@ -68,8 +69,14 @@ export class HumanPoseViewerComponent extends BasePoseViewerComponent implements } queued++; - await new Promise(requestAnimationFrame); // Await animation frame due to canvas change - const image = await transferableImage(poseCanvas, poseCtx); + // Await animation frame due to canvas change, prevent multiple sends for the same frame + nextFramePromise = nextFramePromise.then(async () => { + await new Promise(requestAnimationFrame); + const image = await transferableImage(poseCanvas, poseCtx); + await pose.nextFrame(); + return image; + }); + const image = await nextFramePromise; await pose.nextFrame(); this.translateFrame(image, canvas, ctx).then(() => { queued--; @@ -77,8 +84,7 @@ export class HumanPoseViewerComponent extends BasePoseViewerComponent implements }); }; - for (let i = 0; i < 3; i++) { - // Leaving at 1, need to fix VideoEncoder timestamps + for (let i = 0; i < 8; i++) { await iterFrame(); } }), From d71f47b00119f356f15a75aff779f1ba3a255560 Mon Sep 17 00:00:00 2001 From: Amit Moryossef Date: Fri, 7 Apr 2023 21:11:43 +0200 Subject: [PATCH 2/3] feat(pix2pix): add inference batching --- src/app/modules/pix2pix/pix2pix.model.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/app/modules/pix2pix/pix2pix.model.ts b/src/app/modules/pix2pix/pix2pix.model.ts index e19e3f0f..75f314f7 100644 --- a/src/app/modules/pix2pix/pix2pix.model.ts +++ b/src/app/modules/pix2pix/pix2pix.model.ts @@ -134,8 +134,7 @@ export async function translateQueue(queueId: number, image: ImageBitmap | Image const now = performance.now(); // Dequeue images up to max batch size - const randomBatchSize = Math.floor(Math.random() * MAX_BATCH_SIZE) + 1; - const images = imageQueue.splice(0, randomBatchSize); + const images = imageQueue.splice(0, MAX_BATCH_SIZE); const lazyTensor = await translate(images); const buffer = await lazyTensor.buffer(); // 60-70ms From 5d9fe2e5f19490533951e7474a583d7fc232c145 Mon Sep 17 00:00:00 2001 From: Amit Moryossef Date: Fri, 7 Apr 2023 22:04:25 +0200 Subject: [PATCH 3/3] fix(pix2pix): don't skip frames, optimize green screen function --- src/app/modules/pix2pix/pix2pix.model.ts | 51 +++++++++---------- .../human-pose-viewer.component.ts | 1 - 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/src/app/modules/pix2pix/pix2pix.model.ts b/src/app/modules/pix2pix/pix2pix.model.ts index 75f314f7..34b83685 100644 --- a/src/app/modules/pix2pix/pix2pix.model.ts +++ b/src/app/modules/pix2pix/pix2pix.model.ts @@ -85,29 +85,20 @@ export async function loadModel( resetDropout(model.layers); // Extremely important, as we are performing inference in training mode } -function isGreen(r: number, g: number, b: number) { - return g > 255 / 2 && g > r * 1.5 && g > b * 1.5; -} - -function removeGreenScreen(data: Uint8ClampedArray): Uint8ClampedArray { - // TODO consider - // https://github.com/bhj/gl-chromakey - // https://github.com/Sean-Bradley/Three.js-TypeScript-Boilerplate/blob/webcam/src/client/client.ts - // (easiest) https://developer.vonage.com/blog/2020/06/24/use-a-green-screen-in-javascript-with-vonage-video - - // This takes 0.15ms for 256x256 images, would perhaps be good to do this in wasm. - for (let i = 0; i < data.length; i += 4) { - if (isGreen(data[i], data[i + 1], data[i + 2])) { - data[i + 3] = 0; - } - } - return data; +function removeGreenScreenTensorflow(image: Tensor, tf): Tensor { + const [red, green, blue] = tf.split(image, 3, -1); + const greenMask = green.greater(0.5); + const redMask = green.greater(red.mul(1.5)); + const blueMask = green.greater(blue.mul(1.5)); + const alpha = tf.logicalNot(greenMask.logicalAnd(redMask).logicalAnd(blueMask)).cast('float32'); + return tf.concat([red, green, blue, alpha], -1); } let queuePromise: Promise = Promise.resolve([null]); let globalQueueId = 0; let imageQueue: Image[] = []; + // const BATCH_TIMEOUT = 50; // Adjust this based on the desired maximum waiting time for new images (in ms) export async function translateQueue(queueId: number, image: ImageBitmap | ImageData): Promise { @@ -143,14 +134,12 @@ export async function translateQueue(queueId: number, image: ImageBitmap | Image return tensor.unstack(); }); - const outputImages = await queuePromise; // Get first image from batch + const outputImages = await queuePromise; // Get output images from batch if (outputImages === null) { return null; } - let outputImage = await tf.browser.toPixels(outputImages[0]); // ~1-3ms - outputImage = removeGreenScreen(outputImage); // ~0.1-0.2ms - return outputImage; + return tf.browser.toPixels(outputImages[0]); // ~1-3ms } function upscale(tensor: Tensor) { @@ -167,20 +156,28 @@ async function translate(images: Image[]): Promise { const tf = await tfPromise; return tf.tidy(() => { - const pixels = tf.stack(images.map(image => tf.browser.fromPixels(image, 3))); // 0.1-0.3ms - const pixelsTensor = pixels.toFloat(); - const input = tf.sub(tf.div(pixelsTensor, tf.scalar(255 / 2)), tf.scalar(1)); // # Normalizing the images to [-1, 1] - const tensor = tf.stack([input]); // Add batch dimension, model expects Tensor5D + const one = tf.scalar(1); + const half = tf.scalar(0.5); + + const pixelsBatch = images.map(image => tf.browser.fromPixels(image, 3)); + const pixelsTensor = tf + .stack(pixelsBatch) // 0.1-0.3ms + .toFloat() + // Normalizing the images to [-1, 1] + .div(tf.scalar(255 / 2)) + .sub(one); + const tensor = tf.stack([pixelsTensor]); // Add batch dimension, model expects Tensor5D // Must apply model in training=True mode to avoid using aggregated norm statistics let pred = model.apply(tensor, {training: true}) as Tensor; //6-8ms, but works - // let pred = model.predict(tensor) as Tensor; // 3-4ms, but returns black screen - pred = pred.mul(tf.scalar(0.5)).add(tf.scalar(0.5)); // Normalization to range [0, 1] + pred = pred.mul(half).add(half); // Normalization to range [0, 1] pred = tf.squeeze(pred, [0]); // Remove time dimension pred = upscale(pred); + pred = removeGreenScreenTensorflow(pred, tf); + return pred as Tensor4D; }); } diff --git a/src/app/pages/translate/pose-viewers/human-pose-viewer/human-pose-viewer.component.ts b/src/app/pages/translate/pose-viewers/human-pose-viewer/human-pose-viewer.component.ts index 3643b66c..651c441c 100644 --- a/src/app/pages/translate/pose-viewers/human-pose-viewer/human-pose-viewer.component.ts +++ b/src/app/pages/translate/pose-viewers/human-pose-viewer/human-pose-viewer.component.ts @@ -77,7 +77,6 @@ export class HumanPoseViewerComponent extends BasePoseViewerComponent implements return image; }); const image = await nextFramePromise; - await pose.nextFrame(); this.translateFrame(image, canvas, ctx).then(() => { queued--; iterFrame();