PR feedback
* Added shouldTransferToMLBuffer to avoid creating MLBuffers for initializers
* Switched from getMLBuffer to ensureBuffer to fix issues when the graph is partitioned
* Switched from the custom DataType enum to the common one
egalli committed Jul 18, 2024
1 parent 1e07f85 commit cd1b01a
Showing 5 changed files with 39 additions and 29 deletions.
42 changes: 17 additions & 25 deletions js/web/lib/wasm/jsep/backend-webnn.ts
@@ -8,37 +8,25 @@

import {Tensor} from 'onnxruntime-common';

import {DataType} from '../wasm-common';
import {getInstance} from '../wasm-factory';

import {createView} from './tensor-view';
import {BufferId, BufferManager, createBufferManager} from './webnn/buffer-manager';

/*
* TensorProto::data_type from the ONNX specification.
*/
enum TensorProtoDataType {
float = 1,
uint8 = 2,
int8 = 3,
int32 = 6,
int64 = 7,
bool = 9,
float16 = 10,
uint32 = 12,
uint64 = 13,
}

/*
* TensorProto::data_type to WebNN OperandType mapping.
*/
const onnxDataTypeToWebnnDataType = new Map<TensorProtoDataType, MLOperandDataType>([
[TensorProtoDataType.float, 'float32'],
[TensorProtoDataType.float16, 'float16'],
[TensorProtoDataType.int32, 'int32'],
[TensorProtoDataType.uint32, 'uint32'],
[TensorProtoDataType.int64, 'int64'],
[TensorProtoDataType.uint64, 'uint64'],
[TensorProtoDataType.int8, 'int8'],
[TensorProtoDataType.uint8, 'uint8'],
[TensorProtoDataType.bool, 'uint8'],
const onnxDataTypeToWebnnDataType = new Map<DataType, MLOperandDataType>([
[DataType.float, 'float32'],
[DataType.float16, 'float16'],
[DataType.int32, 'int32'],
[DataType.uint32, 'uint32'],
[DataType.int64, 'int64'],
[DataType.uint64, 'uint64'],
[DataType.int8, 'int8'],
[DataType.uint8, 'uint8'],
[DataType.bool, 'uint8'],
]);
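
For context, a mapping like this is normally consumed by a small lookup that resolves the WebNN operand type for an ONNX tensor type and rejects anything WebNN cannot represent. The following is a hypothetical, self-contained TypeScript sketch, not part of this commit; the enum codes are the TensorProto values listed above, while the helper name and types are illustrative.

// Illustrative, self-contained sketch (not the ORT source): the same kind of
// lookup, with a hypothetical helper that rejects types WebNN cannot represent.
type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8';

// TensorProto codes from the ONNX specification (subset relevant to WebNN).
enum OnnxDataType { float = 1, uint8 = 2, int8 = 3, int32 = 6, int64 = 7, bool = 9, float16 = 10, uint32 = 12, uint64 = 13 }

const onnxToWebnn = new Map<OnnxDataType, MLOperandDataType>([
  [OnnxDataType.float, 'float32'],
  [OnnxDataType.float16, 'float16'],
  [OnnxDataType.int32, 'int32'],
  // WebNN has no bool operand type, so bool tensors are represented as uint8.
  [OnnxDataType.bool, 'uint8'],
  // ...remaining entries as in the map above.
]);

function toWebnnDataType(dataType: OnnxDataType): MLOperandDataType {
  const webnnType = onnxToWebnn.get(dataType);
  if (webnnType === undefined) {
    throw new Error(`Unsupported ONNX data type for WebNN: ${dataType}`);
  }
  return webnnType;
}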

/**
@@ -130,6 +118,10 @@ export class WebNNBackend {
}

public uploadBuffer(bufferId: BufferId, data: Uint8Array): void {
const wasm = getInstance();
if (!wasm.shouldTransferToMLBuffer) {
throw new Error('Trying to upload to a MLBuffer while shouldTransferToMLBuffer is false');
}
this.bufferManager.upload(bufferId, data);
}

11 changes: 8 additions & 3 deletions js/web/lib/wasm/wasm-core-impl.ts
@@ -261,6 +261,7 @@ export const createSession = async(
for (const provider of options?.executionProviders ?? []) {
const providerName = typeof provider === 'string' ? provider : provider.name;
if (providerName === 'webnn') {
wasm.shouldTransferToMLBuffer = false;
if (wasm.currentContext) {
throw new Error('WebNN execution provider is already set.');
}
@@ -294,6 +295,7 @@ export const createSession = async(
if (wasm.currentContext) {
wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
wasm.currentContext = undefined;
wasm.shouldTransferToMLBuffer = true;
}

const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
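
Together with the hunk above, this change brackets session creation: shouldTransferToMLBuffer is cleared before the WebNN execution provider is configured, so initializers uploaded while the session is built do not get MLBuffers, and it is set back to true once the MLContext has been registered. A condensed, hypothetical sketch of that lifecycle (simplified names, not the actual wasm-core-impl.ts control flow; MLContext typing assumed from the WebNN IDL):

// Hypothetical condensation of the flow shown in the diff; buildSession stands
// in for the real session-construction call.
interface WebnnWasmModule {
  shouldTransferToMLBuffer: boolean;
  currentContext?: MLContext;
  jsepRegisterMLContext?: (sessionHandle: number, context: MLContext) => void;
}

async function createWebnnSession(
    wasm: WebnnWasmModule, context: MLContext, buildSession: () => Promise<number>): Promise<number> {
  // While the session (and its initializers) is being constructed, keep the
  // flag off so the WebNN allocator/data-transfer paths skip MLBuffer creation.
  wasm.shouldTransferToMLBuffer = false;
  wasm.currentContext = context;

  const sessionHandle = await buildSession();

  if (wasm.currentContext) {
    wasm.jsepRegisterMLContext?.(sessionHandle, wasm.currentContext);
    wasm.currentContext = undefined;
    // Inputs and outputs created after this point may be transferred to MLBuffers again.
    wasm.shouldTransferToMLBuffer = true;
  }
  return sessionHandle;
}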
@@ -691,16 +693,19 @@ export const run = async(
'gpu-buffer'
]);
} else if (preferredLocation === 'ml-buffer' && size > 0) {
const getMLBuffer = wasm.jsepGetMLBuffer;
if (!getMLBuffer) {
const ensureBuffer = wasm.jsepEnsureBuffer;
if (!ensureBuffer) {
throw new Error('preferredLocation "ml-buffer" is not supported without using WebNN.');
}
const mlBuffer = getMLBuffer(dataOffset);
const elementSize = getTensorElementSize(dataType);
if (elementSize === undefined || !isMLBufferSupportedType(type)) {
throw new Error(`Unsupported data type: ${type}`);
}

// If the graph has been partitioned, the output tensor may not have been created yet. For this reason, we use
// ensureBuffer to get or create the MLBuffer.
const mlBuffer = ensureBuffer(dataOffset, dataType, dims);

// do not release the tensor right now. it will be released when user calls tensor.dispose().
keepOutputTensor = true;

5 changes: 5 additions & 0 deletions js/web/lib/wasm/wasm-types.ts
@@ -126,6 +126,11 @@ export declare namespace JSEP {
*/
currentContext: MLContext;

/**
* When set to false, disables creating MLBuffers. This is used to avoid creating MLBuffers for graph initializers.
*/
shouldTransferToMLBuffer: boolean;

/**
* [exported from pre-jsep.js] Register MLContext for a session.
* @param sessionId - specify the session ID.
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webnn/allocator.cc
@@ -12,6 +12,10 @@ void* WebNNBufferAllocator::Alloc(size_t size) {
if (size == 0) {
return nullptr;
}
if (!emscripten::val::module_property("shouldTransferToMLBuffer").as<bool>()) {
// We don't need to transfer the buffer to an MLBuffer, so we don't need to allocate a buffer id.
return nullptr;
}
void* p = EM_ASM_PTR({ return Module.jsepReserveBufferId(); });
allocations_[p] = size;
stats_.num_allocs++;
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/webnn/data_transfer.cc
@@ -6,7 +6,6 @@
#include <emscripten.h>
#include "core/framework/tensor.h"


namespace onnxruntime {
namespace webnn {

@@ -24,6 +23,11 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {

const auto& dst_device = dst.Location().device;

if (!emscripten::val::module_property("shouldTransferToMLBuffer").as<bool>()) {
// We don't need to transfer the buffer to an MLBuffer, so we don't need to copy the buffer.
return Status::OK();
}

if (dst_device.Type() == OrtDevice::GPU) {
EM_ASM({
Module.jsepUploadBuffer($0, HEAPU8.subarray($1, $1 + $2));
