diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index d510972c236f0..6cad192479625 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -8,37 +8,25 @@
 
 import {Tensor} from 'onnxruntime-common';
 
+import {DataType} from '../wasm-common';
+import {getInstance} from '../wasm-factory';
+
 import {createView} from './tensor-view';
 import {BufferId, BufferManager, createBufferManager} from './webnn/buffer-manager';
 
-/*
- * TensorProto::data_type from the ONNX specification.
- */
-enum TensorProtoDataType {
-  float = 1,
-  uint8 = 2,
-  int8 = 3,
-  int32 = 6,
-  int64 = 7,
-  bool = 9,
-  float16 = 10,
-  uint32 = 12,
-  uint64 = 13,
-}
-
 /*
  * TensorProto::data_type to WebNN OperandType mapping.
  */
-const onnxDataTypeToWebnnDataType = new Map([
-  [TensorProtoDataType.float, 'float32'],
-  [TensorProtoDataType.float16, 'float16'],
-  [TensorProtoDataType.int32, 'int32'],
-  [TensorProtoDataType.uint32, 'uint32'],
-  [TensorProtoDataType.int64, 'int64'],
-  [TensorProtoDataType.uint64, 'uint64'],
-  [TensorProtoDataType.int8, 'int8'],
-  [TensorProtoDataType.uint8, 'uint8'],
-  [TensorProtoDataType.bool, 'uint8'],
+const onnxDataTypeToWebnnDataType = new Map([
+  [DataType.float, 'float32'],
+  [DataType.float16, 'float16'],
+  [DataType.int32, 'int32'],
+  [DataType.uint32, 'uint32'],
+  [DataType.int64, 'int64'],
+  [DataType.uint64, 'uint64'],
+  [DataType.int8, 'int8'],
+  [DataType.uint8, 'uint8'],
+  [DataType.bool, 'uint8'],
 ]);
 
 /**
@@ -130,6 +118,10 @@ export class WebNNBackend {
   }
 
   public uploadBuffer(bufferId: BufferId, data: Uint8Array): void {
+    const wasm = getInstance();
+    if (!wasm.shouldTransferToMLBuffer) {
+      throw new Error('Trying to upload to an MLBuffer while shouldTransferToMLBuffer is false');
+    }
     this.bufferManager.upload(bufferId, data);
   }
 
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index b23518dd20e73..43d2980de1c9f 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -261,6 +261,7 @@ export const createSession = async(
     for (const provider of options?.executionProviders ?? []) {
       const providerName = typeof provider === 'string' ? provider : provider.name;
       if (providerName === 'webnn') {
+        wasm.shouldTransferToMLBuffer = false;
         if (wasm.currentContext) {
           throw new Error('WebNN execution provider is already set.');
         }
@@ -294,6 +295,7 @@ export const createSession = async(
     if (wasm.currentContext) {
       wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
       wasm.currentContext = undefined;
+      wasm.shouldTransferToMLBuffer = true;
     }
 
     const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
@@ -691,16 +693,19 @@ export const run = async(
             'gpu-buffer'
           ]);
         } else if (preferredLocation === 'ml-buffer' && size > 0) {
-          const getMLBuffer = wasm.jsepGetMLBuffer;
-          if (!getMLBuffer) {
+          const ensureBuffer = wasm.jsepEnsureBuffer;
+          if (!ensureBuffer) {
             throw new Error('preferredLocation "ml-buffer" is not supported without using WebNN.');
           }
-          const mlBuffer = getMLBuffer(dataOffset);
           const elementSize = getTensorElementSize(dataType);
           if (elementSize === undefined || !isMLBufferSupportedType(type)) {
             throw new Error(`Unsupported data type: ${type}`);
           }
 
+          // If the graph has been partitioned, the output tensor may not have been created yet. For this reason,
+          // we use ensureBuffer to get/create the MLBuffer.
+          const mlBuffer = ensureBuffer(dataOffset, dataType, dims);
+
           // do not release the tensor right now. it will be released when user calls tensor.dispose().
           keepOutputTensor = true;
 
diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts
index afca278422edd..783620b146364 100644
--- a/js/web/lib/wasm/wasm-types.ts
+++ b/js/web/lib/wasm/wasm-types.ts
@@ -126,6 +126,11 @@ export declare namespace JSEP {
      */
     currentContext: MLContext;
 
+    /**
+     * Disables creating MLBuffers. This is used to avoid creating MLBuffers for graph initializers.
+     */
+    shouldTransferToMLBuffer: boolean;
+
     /**
      * [exported from pre-jsep.js] Register MLContext for a session.
      * @param sessionId - specify the session ID.
diff --git a/onnxruntime/core/providers/webnn/allocator.cc b/onnxruntime/core/providers/webnn/allocator.cc
index 355ee7e48b9f4..4b8188a6f8344 100644
--- a/onnxruntime/core/providers/webnn/allocator.cc
+++ b/onnxruntime/core/providers/webnn/allocator.cc
@@ -12,6 +12,10 @@ void* WebNNBufferAllocator::Alloc(size_t size) {
   if (size == 0) {
     return nullptr;
   }
+  if (!emscripten::val::module_property("shouldTransferToMLBuffer").as<bool>()) {
+    // We don't need to transfer the buffer to an MLBuffer, so we don't need to allocate a buffer id.
+    return nullptr;
+  }
   void* p = EM_ASM_PTR({ return Module.jsepReserveBufferId(); });
   allocations_[p] = size;
   stats_.num_allocs++;
diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc
index 66096c74a7950..3ba9d1171191a 100644
--- a/onnxruntime/core/providers/webnn/data_transfer.cc
+++ b/onnxruntime/core/providers/webnn/data_transfer.cc
@@ -6,7 +6,6 @@
 #include <emscripten.h>
 #include "core/framework/tensor.h"
 
-
 namespace onnxruntime {
 namespace webnn {
@@ -24,6 +23,11 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
 
   const auto& dst_device = dst.Location().device;
 
+  if (!emscripten::val::module_property("shouldTransferToMLBuffer").as<bool>()) {
+    // We don't need to transfer the buffer to an MLBuffer, so we don't need to copy the buffer.
+    return Status::OK();
+  }
+
   if (dst_device.Type() == OrtDevice::GPU) {
     EM_ASM({
       Module.jsepUploadBuffer($0, HEAPU8.subarray($1, $1 + $2));
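
For context, a minimal TypeScript sketch of the gating pattern this patch introduces: `shouldTransferToMLBuffer` is cleared before the WebNN EP is configured in `createSession` (so graph initializers stay on the CPU) and restored once the MLContext is registered, and the backend refuses uploads while it is false. The `OrtWasmModuleLike` and `WebNNBackendSketch` names below are illustrative only, not part of onnxruntime.

```ts
// Simplified shape of the wasm module flag added in wasm-types.ts.
interface OrtWasmModuleLike {
  shouldTransferToMLBuffer: boolean;
}

// Illustrative stand-in for WebNNBackend.uploadBuffer.
class WebNNBackendSketch {
  constructor(
      private readonly wasm: OrtWasmModuleLike,
      private readonly doUpload: (bufferId: number, data: Uint8Array) => void) {}

  uploadBuffer(bufferId: number, data: Uint8Array): void {
    // While createSession is processing graph initializers the flag is false,
    // so initializer data is not copied into MLBuffers.
    if (!this.wasm.shouldTransferToMLBuffer) {
      throw new Error('Trying to upload to an MLBuffer while shouldTransferToMLBuffer is false');
    }
    this.doUpload(bufferId, data);
  }
}

// createSession-side usage, mirroring the wasm-core-impl.ts changes:
//   wasm.shouldTransferToMLBuffer = false;  // before registering the WebNN EP
//   ...create the session, register the MLContext...
//   wasm.shouldTransferToMLBuffer = true;   // inference-time tensors may use MLBuffers again
```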