From 048a78a5134b1a19ab5ed8f5364ea58b0b63f97b Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 19 Nov 2024 15:11:51 +0800 Subject: [PATCH] [js/webgpu] Do not record with computePassEncoder when capturing --- js/web/lib/wasm/jsep/backend-webgpu.ts | 26 ++++++++--- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 43 ++++++++++++++----- .../lib/wasm/jsep/webgpu/program-manager.ts | 10 ++--- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a0010df4643a4..81c977d66dbde 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -23,13 +23,21 @@ import { TimestampQuery, } from './webgpu/types'; -interface CommandInfo { +interface ComputeCommand { readonly kernelId: number; readonly computePipeline: GPUComputePipeline; readonly bindGroup: GPUBindGroup; readonly dispatchGroup: [number, number, number]; } +interface MemcpyCommand { + readonly source: GPUBuffer; + readonly dest: GPUBuffer; + readonly size: number; +} + +type Command = ComputeCommand | MemcpyCommand; + interface KernelInfo { readonly kernelType: string; readonly kernelName: string; @@ -234,9 +242,9 @@ export class WebGpuBackend { env: Env; sessionStatus: SessionState = 'default'; /** - * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + * a SessionID -> Command[] mapping. It's used to record all GPU commands for corresponding session. */ - capturedCommandList: Map = new Map(); + capturedCommandList: Map = new Map(); /** * a SessionID -> PendingKernelInfo[] mapping for profiling. 
@@ -910,9 +918,15 @@ export class WebGpuBackend { const computePassEncoder = this.getComputePassEncoder(); const command = sessionCommandList![i]; this.writeTimestamp(this.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(command.computePipeline); - computePassEncoder.setBindGroup(0, command.bindGroup); - computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + if ('bindGroup' in command) { + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + } else { + const commandEncoder = this.getCommandEncoder(); + this.endComputePass(); + commandEncoder.copyBufferToBuffer(command.source, 0, command.dest, 0, command.size); + } this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); this.pendingDispatchNumber++; if (this.queryType !== 'none') { diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 1c6016500e7d3..0dbf9a1e5eb96 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -274,16 +274,39 @@ class GpuDataManagerImpl implements GpuDataManager { const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize); - // GPU copy - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - commandEncoder.copyBufferToBuffer( - sourceGpuDataCache.gpuData.buffer, - 0, - destinationGpuDataCache.gpuData.buffer, - 0, - size, - ); + if (this.backend.sessionStatus === 'capturing') { + const command = { + source: sourceGpuDataCache.gpuData.buffer, + dest: destinationGpuDataCache.gpuData.buffer, + size: size, + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(command); + } else { + // GPU copy + const commandEncoder = this.backend.getCommandEncoder(); + this.backend.endComputePass(); + 
commandEncoder.copyBufferToBuffer( + sourceGpuDataCache.gpuData.buffer, + 0, + destinationGpuDataCache.gpuData.buffer, + 0, + size, + ); + } + + this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); + this.backend.pendingDispatchNumber++; + + if ( + this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber || + this.backend.queryType === 'at-passes' + ) { + this.backend.endComputePass(); + } + if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) { + this.backend.flush(); + } } registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number { diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 2c5180c5db3ee..d36255b27cff6 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -41,7 +41,6 @@ export class ProgramManager { ): void { TRACE_FUNC_BEGIN(buildArtifact.programInfo.name); const device = this.backend.device; - const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); const entries = []; for (const input of inputs) { @@ -68,11 +67,12 @@ export class ProgramManager { }; const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); sessionCommandList!.push(commandInfo); + } else { + const computePassEncoder = this.backend.getComputePassEncoder(); + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); + computePassEncoder.dispatchWorkgroups(...dispatchGroup); } - - computePassEncoder.setPipeline(buildArtifact.computePipeline); - computePassEncoder.setBindGroup(0, bindGroup); - computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++;