From bac23eaeb3e0a5408758e34ecb1ca39ddc6bf06b Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 19 Nov 2024 15:11:51 +0800 Subject: [PATCH] [js/webgpu] Donot record with computePassEncoder when capturing --- js/web/lib/wasm/jsep/backend-webgpu.ts | 42 +++++++++++++------ .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 32 +++++++++----- .../lib/wasm/jsep/webgpu/program-manager.ts | 10 ++--- 3 files changed, 56 insertions(+), 28 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a0010df4643a4..e3ec1682629d4 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -23,13 +23,21 @@ import { TimestampQuery, } from './webgpu/types'; -interface CommandInfo { +interface ComputeCommand { readonly kernelId: number; readonly computePipeline: GPUComputePipeline; readonly bindGroup: GPUBindGroup; readonly dispatchGroup: [number, number, number]; } +interface MemcpyCommand { + readonly source: GPUBuffer; + readonly dest: GPUBuffer; + readonly size: number; +} + +type Command = ComputeCommand | MemcpyCommand; + interface KernelInfo { readonly kernelType: string; readonly kernelName: string; @@ -234,9 +242,9 @@ export class WebGpuBackend { env: Env; sessionStatus: SessionState = 'default'; /** - * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + * a SessionID -> Command[] mapping. It's used to record all GPU commands for corresponding session. */ - capturedCommandList: Map = new Map(); + capturedCommandList: Map = new Map(); /** * a SessionID -> PendingKernelInfo[] mapping for profiling. @@ -909,18 +917,26 @@ export class WebGpuBackend { for (let i = 0; i < length; i++) { const computePassEncoder = this.getComputePassEncoder(); const command = sessionCommandList![i]; - this.writeTimestamp(this.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(command.computePipeline); - computePassEncoder.setBindGroup(0, command.bindGroup); - computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); - this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); - this.pendingDispatchNumber++; - if (this.queryType !== 'none') { - this.pendingKernels.push(sessionPendingKernels![i]); - } - if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + if ('bindGroup' in command) { + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); + + this.pendingDispatchNumber++; + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + } else { + const commandEncoder = this.getCommandEncoder(); this.endComputePass(); + commandEncoder.copyBufferToBuffer(command.source, 0, command.dest, 0, command.size); } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { this.flush(); } diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 1c6016500e7d3..b468e665bd8cb 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -274,16 +274,28 @@ class GpuDataManagerImpl implements GpuDataManager { const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize); - // GPU copy - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - commandEncoder.copyBufferToBuffer( - sourceGpuDataCache.gpuData.buffer, - 0, - destinationGpuDataCache.gpuData.buffer, - 0, - size, - ); + if (this.backend.sessionStatus === 'capturing') { + const command = { + source: sourceGpuDataCache.gpuData.buffer, + dest: destinationGpuDataCache.gpuData.buffer, + size, + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(command); + + this.backend.pendingDispatchNumber++; + } else { + // GPU copy + const commandEncoder = this.backend.getCommandEncoder(); + this.backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + sourceGpuDataCache.gpuData.buffer, + 0, + destinationGpuDataCache.gpuData.buffer, + 0, + size, + ); + } } registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number { diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 2c5180c5db3ee..d36255b27cff6 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -41,7 +41,6 @@ export class ProgramManager { ): void { TRACE_FUNC_BEGIN(buildArtifact.programInfo.name); const device = this.backend.device; - const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); const entries = []; for (const input of inputs) { @@ -68,11 +67,12 @@ export class ProgramManager { }; const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); sessionCommandList!.push(commandInfo); + } else { + const computePassEncoder = this.backend.getComputePassEncoder(); + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); + computePassEncoder.dispatchWorkgroups(...dispatchGroup); } - - computePassEncoder.setPipeline(buildArtifact.computePipeline); - computePassEncoder.setBindGroup(0, bindGroup); - computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++;