feat(pix2pix): add inference batching #83

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into base: master
104 changes: 57 additions & 47 deletions src/app/modules/pix2pix/pix2pix.model.ts
@@ -1,7 +1,9 @@
-import type {Tensor, Tensor3D} from '@tensorflow/tfjs';
+import type {Tensor, Tensor4D} from '@tensorflow/tfjs';
 import type {LayersModel} from '@tensorflow/tfjs-layers';
 import {loadTFDS} from '../../core/services/tfjs/tfjs.loader';
 
+type Image = ImageBitmap | ImageData;
+
 class ModelNotLoadedError extends Error {
   constructor() {
     super('Model not loaded');
@@ -62,7 +64,7 @@ async function loadGeneratorModel(layers: Map<string, string>)
   lstmSequence[2].config.recurrent_initializer.class_name = 'Zeros';
   // Make model stateful
   lstmSequence[0].config.batch_input_shape[0] = 1;
-  lstmSequence[0].config.batch_input_shape[1] = 1;
+  // lstmSequence[0].config.batch_input_shape[1] = 1;
   lstmSequence[2].config.stateful = true;
 
   // Convert JSON to blob
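
This hunk is what unlocks batching: batch_input_shape[0] stays pinned to 1 (a stateful LSTM requires a fixed batch size), while commenting out the batch_input_shape[1] = 1 line leaves the time dimension dynamic, so a single call can carry several frames. A minimal, self-contained sketch of that configuration (toy layer sizes, not the actual generator):

import * as tf from '@tensorflow/tfjs';

// Toy stand-in for the generator's LSTM: batch pinned to 1, time dimension
// left dynamic (null), mirroring the commented-out batch_input_shape[1] above.
const lstm = tf.layers.lstm({
  units: 4,
  stateful: true,
  returnSequences: true,
  batchInputShape: [1, null, 8],
});
const model = tf.sequential({layers: [lstm]});

// One call may now carry several timesteps, while LSTM state persists across calls.
const out = model.predict(tf.zeros([1, 3, 8])) as tf.Tensor; // 3 frames at once
console.log(out.shape); // [1, 3, 4]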
@@ -83,91 +85,99 @@ export async function loadModel(
   resetDropout(model.layers); // Extremely important, as we are performing inference in training mode
 }
 
-function isGreen(r: number, g: number, b: number) {
-  return g > 255 / 2 && g > r * 1.5 && g > b * 1.5;
+function removeGreenScreenTensorflow(image: Tensor, tf): Tensor {
+  const [red, green, blue] = tf.split(image, 3, -1);
+  const greenMask = green.greater(0.5);
+  const redMask = green.greater(red.mul(1.5));
+  const blueMask = green.greater(blue.mul(1.5));
+  const alpha = tf.logicalNot(greenMask.logicalAnd(redMask).logicalAnd(blueMask)).cast('float32');
+  return tf.concat([red, green, blue, alpha], -1);
 }
 
-function removeGreenScreen(data: Uint8ClampedArray): Uint8ClampedArray {
-  // TODO consider
-  // https://github.com/bhj/gl-chromakey
-  // https://github.com/Sean-Bradley/Three.js-TypeScript-Boilerplate/blob/webcam/src/client/client.ts
-  // (easiest) https://developer.vonage.com/blog/2020/06/24/use-a-green-screen-in-javascript-with-vonage-video
-
-  // This takes 0.15ms for 256x256 images, would perhaps be good to do this in wasm.
-  for (let i = 0; i < data.length; i += 4) {
-    if (isGreen(data[i], data[i + 1], data[i + 2])) {
-      data[i + 3] = 0;
-    }
-  }
-  return data;
-}
-
-let queuePromise: Promise<any> = Promise.resolve();
+let queuePromise: Promise<any> = Promise.resolve([null]);
 let globalQueueId = 0;
+
+let imageQueue: Image[] = [];
+
+// const BATCH_TIMEOUT = 50; // Adjust this based on the desired maximum waiting time for new images (in ms)
 
 export async function translateQueue(queueId: number, image: ImageBitmap | ImageData): Promise<Uint8ClampedArray> {
   globalQueueId = queueId;
+  imageQueue.push(image);
 
-  const tensor = await translate(image); // Lazy tensor evaluation
   const tf = await tfPromise;
 
+  // Adjust this based on your hardware and performance requirements
+  const MAX_BATCH_SIZE = tf.getBackend() === 'webgpu' ? 8 : 1;
+
   // Chain the model evaluation per frame
-  queuePromise = queuePromise.then(() => {
+  queuePromise = queuePromise.then(async (lastImages: any[]) => {
     if (globalQueueId !== queueId) {
       return null;
     }
 
-    return tensor.buffer(); // 60-70ms
+    // Remove the oldest image from the results
+    lastImages.shift();
+    if (lastImages.length > 0) {
+      return lastImages;
+    }
+
+    const now = performance.now();
+
+    // Dequeue images up to max batch size
+    const images = imageQueue.splice(0, MAX_BATCH_SIZE);
+
+    const lazyTensor = await translate(images);
+    const buffer = await lazyTensor.buffer(); // 60-70ms
+    const tensor = buffer.toTensor();
+    console.log('Batching', images.length, (performance.now() - now).toFixed(1), 'ms');
+    return tensor.unstack();
   });
 
-  const imageBuffer = await queuePromise;
-  let outputImage = await tf.browser.toPixels(imageBuffer.toTensor()); // ~1-3ms
-  outputImage = removeGreenScreen(outputImage); // ~0.1-0.2ms
+  const outputImages = await queuePromise; // Get output images from batch
+  if (outputImages === null) {
+    return null;
+  }
 
-  return outputImage;
+  return tf.browser.toPixels(outputImages[0]); // ~1-3ms
 }
 
-const frameTimes = [];
-let lastFrameTime = null;
-
 function upscale(tensor: Tensor) {
   return (upscaler.predict(tensor) as Tensor)
     .depthToSpace(3, 'NHWC') // Could not convert the depthToSpace operation to tfjs, must use this instead
     .clipByValue(0, 1); // Clipping to [0, 1] as upscale model may output values greater than 1
 }
 
-async function translate(image: ImageBitmap | ImageData): Promise<Tensor3D> {
-  if (lastFrameTime) {
-    frameTimes.push(Date.now() - lastFrameTime);
-    if (frameTimes.length > 20) {
-      const totalTime = frameTimes.slice(frameTimes.length - 20).reduce((a, b) => a + b, 0);
-      console.log('average', (totalTime / 20).toFixed(1), 'ms');
-    }
-  }
-  lastFrameTime = Date.now();
-
+async function translate(images: Image[]): Promise<Tensor4D> {
   if (!model) {
     throw new ModelNotLoadedError();
   }
 
   const tf = await tfPromise;
 
   return tf.tidy(() => {
-    const pixels = tf.browser.fromPixels(image, 3); // 0.1-0.3ms
-    const pixelsTensor = pixels.toFloat();
-    const input = tf.sub(tf.div(pixelsTensor, tf.scalar(255 / 2)), tf.scalar(1)); // Normalizing the images to [-1, 1]
-    const tensor = tf.reshape(input, [1, 1, ...input.shape]); // Add batch and time dimensions
+    const one = tf.scalar(1);
+    const half = tf.scalar(0.5);
+
+    const pixelsBatch = images.map(image => tf.browser.fromPixels(image, 3));
+    const pixelsTensor = tf
+      .stack(pixelsBatch) // 0.1-0.3ms
+      .toFloat()
+      // Normalizing the images to [-1, 1]
+      .div(tf.scalar(255 / 2))
+      .sub(one);
+    const tensor = tf.stack([pixelsTensor]); // Add batch dimension, model expects Tensor5D
 
     // Must apply model in training=True mode to avoid using aggregated norm statistics
     let pred = model.apply(tensor, {training: true}) as Tensor; // 6-8ms, but works
 
     // let pred = model.predict(tensor) as Tensor; // 3-4ms, but returns black screen
-    pred = pred.mul(tf.scalar(0.5)).add(tf.scalar(0.5)); // Normalization to range [0, 1]
+    pred = pred.mul(half).add(half); // Normalization to range [0, 1]
 
     pred = tf.squeeze(pred, [0]); // Remove time dimension
     pred = upscale(pred);
 
-    pred = tf.squeeze(pred); // Remove batch dimension
-    return pred as Tensor3D;
+    pred = removeGreenScreenTensorflow(pred, tf);
+
+    return pred as Tensor4D;
   });
 }
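
Taken together, translateQueue now implements a small micro-batching scheme: the first caller whose turn arrives on the shared promise chain runs the model once on up to MAX_BATCH_SIZE queued frames, and each later caller peels one result off the shared array instead of running the model again. Here is that pattern in isolation, as a runnable sketch; the names enqueue and runBatch are illustrative, and the globalQueueId cancellation is omitted:

const MAX_BATCH_SIZE = 8;
const inputQueue: number[] = [];

// Stand-in for the real batched model call: one output per queued input.
async function runBatch(inputs: number[]): Promise<number[]> {
  return inputs.map(x => x * x);
}

// The chain always resolves to the unconsumed results of the batch in flight.
let chain: Promise<number[]> = Promise.resolve([]);

function enqueue(input: number): Promise<number> {
  inputQueue.push(input);
  chain = chain.then(async remaining => {
    // Later callers of the same batch reuse its leftover results.
    if (remaining.length > 0) {
      return remaining;
    }
    // First caller of a new batch drains up to MAX_BATCH_SIZE inputs at once.
    return runBatch(inputQueue.splice(0, MAX_BATCH_SIZE));
  });
  // Each caller consumes the front result; the array is shared down the chain.
  return chain.then(results => results.shift()!);
}

// Three concurrent callers, a single model invocation:
Promise.all([enqueue(2), enqueue(3), enqueue(4)]).then(console.log); // [4, 9, 16]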
@@ -51,6 +51,7 @@ export class HumanPoseViewerComponent extends BasePoseViewerComponent implements

   // To avoid communication time losses, we create a queue sent to be translated
   let queued = 0;
+  let nextFramePromise: Promise<any> = Promise.resolve();
 
   const iterFrame = async () => {
     // Verify element is not destroyed
@@ -68,17 +69,21 @@ export class HumanPoseViewerComponent extends BasePoseViewerComponent implements
     }
 
     queued++;
-    await new Promise(requestAnimationFrame); // Await animation frame due to canvas change
-    const image = await transferableImage(poseCanvas, poseCtx);
-    await pose.nextFrame();
+    // Await animation frame due to canvas change, prevent multiple sends for the same frame
+    nextFramePromise = nextFramePromise.then(async () => {
+      await new Promise(requestAnimationFrame);
+      const image = await transferableImage(poseCanvas, poseCtx);
+      await pose.nextFrame();
+      return image;
+    });
+    const image = await nextFramePromise;
     this.translateFrame(image, canvas, ctx).then(() => {
       queued--;
       iterFrame();
     });
   };
 
-  for (let i = 0; i < 3; i++) {
+  // Leaving at 1, need to fix VideoEncoder timestamps
+  for (let i = 0; i < 8; i++) {
     await iterFrame();
   }
 }),
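
The component-side change applies the same chaining idea to the producer: all concurrent iterFrame loops funnel through nextFramePromise, so each animation frame is captured and advanced exactly once even with eight loops in flight. A standalone sketch of that serialization pattern, with a stand-in captureFrame in place of the canvas and pose logic:

let frameCounter = 0;

// Stand-in for "await animation frame, grab canvas, advance pose".
async function captureFrame(): Promise<number> {
  await new Promise(resolve => setTimeout(resolve, 0));
  return frameCounter++; // each capture advances to the next frame exactly once
}

let nextFramePromise: Promise<number> = Promise.resolve(-1);

function getNextFrame(): Promise<number> {
  // Each caller waits for the previous capture before starting its own,
  // so two loops can never grab the same frame.
  nextFramePromise = nextFramePromise.then(() => captureFrame());
  return nextFramePromise;
}

// Eight concurrent workers (mirroring the loop above) receive distinct frames:
Promise.all(Array.from({length: 8}, () => getNextFrame())).then(console.log);
// [0, 1, 2, 3, 4, 5, 6, 7]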