From e8b80b9832a580092fdebd16d0e7a57675078814 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Tue, 19 Nov 2024 15:11:51 +0800 Subject: [PATCH] [js/webgpu] Donot record with computePassEncoder when capturing --- js/web/lib/wasm/jsep/backend-webgpu.ts | 25 +++++++---- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 43 ++++++++++++++----- .../lib/wasm/jsep/webgpu/program-manager.ts | 10 ++--- 3 files changed, 55 insertions(+), 23 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a0010df4643a4..824b7ab44d10b 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -24,10 +24,13 @@ import { } from './webgpu/types'; interface CommandInfo { - readonly kernelId: number; - readonly computePipeline: GPUComputePipeline; - readonly bindGroup: GPUBindGroup; - readonly dispatchGroup: [number, number, number]; + readonly kernelId?: number; + readonly computePipeline?: GPUComputePipeline; + readonly bindGroup?: GPUBindGroup; + readonly dispatchGroup?: [number, number, number]; + readonly source?: GPUBuffer; + readonly dest?: GPUBuffer; + readonly size?: number; } interface KernelInfo { @@ -909,10 +912,16 @@ export class WebGpuBackend { for (let i = 0; i < length; i++) { const computePassEncoder = this.getComputePassEncoder(); const command = sessionCommandList![i]; - this.writeTimestamp(this.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(command.computePipeline); - computePassEncoder.setBindGroup(0, command.bindGroup); - computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + if (command.bindGroup) { + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline!); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup!); + } else { + this.writeTimestamp(this.pendingDispatchNumber * 2); + const commandEncoder = this.getCommandEncoder(); + commandEncoder.copyBufferToBuffer(command.source!, 0, command.dest!, 0, command.size!); + } this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); this.pendingDispatchNumber++; if (this.queryType !== 'none') { diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 1c6016500e7d3..3ba6d82f03677 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -274,16 +274,39 @@ class GpuDataManagerImpl implements GpuDataManager { const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize); - // GPU copy - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - commandEncoder.copyBufferToBuffer( - sourceGpuDataCache.gpuData.buffer, - 0, - destinationGpuDataCache.gpuData.buffer, - 0, - size, - ); + if (this.backend.sessionStatus === 'capturing') { + const commandInfo = { + source: sourceGpuDataCache.gpuData.buffer, + dest: destinationGpuDataCache.gpuData.buffer, + size: size, + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(commandInfo); + } else { + // GPU copy + const commandEncoder = this.backend.getCommandEncoder(); + this.backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + sourceGpuDataCache.gpuData.buffer, + 0, + destinationGpuDataCache.gpuData.buffer, + 0, + size, + ); + } + + this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); + this.backend.pendingDispatchNumber++; + + if ( + this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber || + this.backend.queryType === 'at-passes' + ) { + this.backend.endComputePass(); + } + if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) { + this.backend.flush(); + } } registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number { diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 2c5180c5db3ee..d36255b27cff6 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -41,7 +41,6 @@ export class ProgramManager { ): void { TRACE_FUNC_BEGIN(buildArtifact.programInfo.name); const device = this.backend.device; - const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); const entries = []; for (const input of inputs) { @@ -68,11 +67,12 @@ export class ProgramManager { }; const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); sessionCommandList!.push(commandInfo); + } else { + const computePassEncoder = this.backend.getComputePassEncoder(); + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); + computePassEncoder.dispatchWorkgroups(...dispatchGroup); } - - computePassEncoder.setPipeline(buildArtifact.computePipeline); - computePassEncoder.setBindGroup(0, bindGroup); - computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++;