Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
* Fixed issues when building under debug
* Disabled MLBuffer on CPU device types
* Renamed MlBuffer and MlContext to match specification
  • Loading branch information
egalli committed Jul 17, 2024
1 parent d823108 commit 85ca43b
Show file tree
Hide file tree
Showing 24 changed files with 154 additions and 120 deletions.
6 changes: 3 additions & 3 deletions js/common/lib/tensor-factory-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,10 @@ export const tensorFromGpuBuffer = <T extends TensorInterface.GpuBufferDataTypes
};

/**
* implementation of Tensor.fromMlBuffer().
* implementation of Tensor.fromMLBuffer().
*/
export const tensorFromMlBuffer = <T extends TensorInterface.GpuBufferDataTypes>(
mlBuffer: TensorInterface.MlBufferType, options: TensorFromGpuBufferOptions<T>): Tensor => {
export const tensorFromMLBuffer = <T extends TensorInterface.GpuBufferDataTypes>(
mlBuffer: TensorInterface.MLBufferType, options: TensorFromGpuBufferOptions<T>): Tensor => {
const {dataType, dims, download, dispose} = options;
return new Tensor({location: 'ml-buffer', type: dataType ?? 'float32', mlBuffer, dims, download, dispose});
};
Expand Down
8 changes: 4 additions & 4 deletions js/common/lib/tensor-factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ export interface GpuBufferConstructorParameters<T extends Tensor.GpuBufferDataTy
readonly gpuBuffer: Tensor.GpuBufferType;
}

export interface MlBufferConstructorParameters<T extends Tensor.MlBufferDataTypes = Tensor.MlBufferDataTypes> extends
export interface MLBufferConstructorParameters<T extends Tensor.MLBufferDataTypes = Tensor.MLBufferDataTypes> extends
CommonConstructorParameters<T>, GpuResourceConstructorParameters<T> {
/**
* Specify the location of the data to be 'ml-buffer'.
Expand All @@ -94,7 +94,7 @@ export interface MlBufferConstructorParameters<T extends Tensor.MlBufferDataType
/**
* Specify the WebNN buffer that holds the tensor data.
*/
readonly mlBuffer: Tensor.MlBufferType;
readonly mlBuffer: Tensor.MLBufferType;
}

// #endregion
Expand Down Expand Up @@ -212,7 +212,7 @@ export interface TensorFromGpuBufferOptions<T extends Tensor.GpuBufferDataTypes>
dataType?: T;
}

export interface TensorFromMlBufferOptions<T extends Tensor.MlBufferDataTypes> extends
export interface TensorFromMLBufferOptions<T extends Tensor.MLBufferDataTypes> extends
Pick<Tensor, 'dims'>, GpuResourceConstructorParameters<T> {
/**
* Describes the data type of the tensor.
Expand Down Expand Up @@ -345,7 +345,7 @@ export interface TensorFactory {
*
* @returns a tensor object
*/
fromMlBuffer<T extends Tensor.MlBufferDataTypes>(buffer: Tensor.MlBufferType, options: TensorFromMlBufferOptions<T>):
fromMLBuffer<T extends Tensor.MLBufferDataTypes>(buffer: Tensor.MLBufferType, options: TensorFromMLBufferOptions<T>):
TypedTensor<T>;

/**
Expand Down
20 changes: 10 additions & 10 deletions js/common/lib/tensor-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
import {tensorFromGpuBuffer, tensorFromImage, tensorFromMlBuffer, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, MlBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
import {tensorFromGpuBuffer, tensorFromImage, tensorFromMLBuffer, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, MLBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
import {Tensor as TensorInterface} from './tensor.js';
Expand All @@ -16,7 +16,7 @@ type TensorDataType = TensorInterface.DataType;
type TensorDataLocation = TensorInterface.DataLocation;
type TensorTextureType = TensorInterface.TextureType;
type TensorGpuBufferType = TensorInterface.GpuBufferType;
type TensorMlBufferType = TensorInterface.MlBufferType;
type TensorMLBufferType = TensorInterface.MLBufferType;

/**
* the implementation of Tensor interface.
Expand Down Expand Up @@ -68,14 +68,14 @@ export class Tensor implements TensorInterface {
*
* @param params - Specify the parameters to construct the tensor.
*/
constructor(params: MlBufferConstructorParameters);
constructor(params: MLBufferConstructorParameters);

/**
* implementation.
*/
constructor(
arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
TextureConstructorParameters|GpuBufferConstructorParameters|MlBufferConstructorParameters,
TextureConstructorParameters|GpuBufferConstructorParameters|MLBufferConstructorParameters,
arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
// perform one-time check for BigInt/Float16Array support
checkTypedArray();
Expand Down Expand Up @@ -273,9 +273,9 @@ export class Tensor implements TensorInterface {
return tensorFromGpuBuffer(gpuBuffer, options);
}

static fromMlBuffer<T extends TensorInterface.GpuBufferDataTypes>(
mlBuffer: TensorMlBufferType, options: TensorFromGpuBufferOptions<T>): TensorInterface {
return tensorFromMlBuffer(mlBuffer, options);
static fromMLBuffer<T extends TensorInterface.GpuBufferDataTypes>(
mlBuffer: TensorMLBufferType, options: TensorFromGpuBufferOptions<T>): TensorInterface {
return tensorFromMLBuffer(mlBuffer, options);
}

static fromPinnedBuffer<T extends TensorInterface.CpuPinnedDataTypes>(
Expand Down Expand Up @@ -326,7 +326,7 @@ export class Tensor implements TensorInterface {
/**
* stores the underlying WebNN MLBuffer when location is 'ml-buffer'. otherwise empty.
*/
private mlBufferData?: TensorMlBufferType;
private mlBufferData?: TensorMLBufferType;


/**
Expand Down Expand Up @@ -376,7 +376,7 @@ export class Tensor implements TensorInterface {
return this.gpuBufferData;
}

get mlBuffer(): TensorMlBufferType {
get mlBuffer(): TensorMLBufferType {
this.ensureValid();
if (!this.mlBufferData) {
throw new Error('The data is not stored as a WebNN buffer.');
Expand Down
4 changes: 2 additions & 2 deletions js/common/lib/tensor-utils-impl.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, MlBufferConstructorParameters, TextureConstructorParameters} from './tensor-factory.js';
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, MLBufferConstructorParameters, TextureConstructorParameters} from './tensor-factory.js';
import {Tensor} from './tensor-impl.js';

/**
Expand Down Expand Up @@ -56,7 +56,7 @@ export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor =
return new Tensor({
location: 'ml-buffer',
mlBuffer: tensor.mlBuffer,
type: tensor.type as MlBufferConstructorParameters['type'],
type: tensor.type as MLBufferConstructorParameters['type'],
dims,
});
default:
Expand Down
6 changes: 3 additions & 3 deletions js/common/lib/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ interface TypedTensorBase<T extends Tensor.Type> {
*
* If the data is not in a WebNN MLBuffer, throw error.
*/
readonly mlBuffer: Tensor.MlBufferType;
readonly mlBuffer: Tensor.MLBufferType;

/**
* Get the buffer data of the tensor.
Expand Down Expand Up @@ -144,7 +144,7 @@ export declare namespace Tensor {
*
* The specification for WebNN's ML Buffer is currently in flux.
*/
export type MlBufferType = unknown;
export type MLBufferType = unknown;

/**
* supported data types for constructing a tensor from a WebGPU buffer
Expand All @@ -154,7 +154,7 @@ export declare namespace Tensor {
/**
* supported data types for constructing a tensor from a WebNN MLBuffer
*/
export type MlBufferDataTypes = 'float32'|'float16'|'int8'|'uint8'|'int32'|'uint32'|'int64'|'uint64'|'bool';
export type MLBufferDataTypes = 'float32'|'float16'|'int8'|'uint8'|'int32'|'uint32'|'int64'|'uint64'|'bool';

/**
* represent where the tensor data is stored
Expand Down
26 changes: 13 additions & 13 deletions js/web/lib/wasm/jsep/backend-webnn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export class WebNNBackend {
/**
* Maps from MLContext to session ids.
*/
private sessionIdsByMlContext = new Map<MLContext, Set<number>>();
private sessionIdsByMLContext = new Map<MLContext, Set<number>>();
/**
* Current session id.
*/
Expand All @@ -68,38 +68,38 @@ export class WebNNBackend {
if (this.currentSessionId === undefined) {
throw new Error('No active session');
}
return this.getMlContext(this.currentSessionId);
return this.getMLContext(this.currentSessionId);
}

public registerMlContext(sessionId: number, mlContext: MLContext): void {
public registerMLContext(sessionId: number, mlContext: MLContext): void {
this.mlContextBySessionId.set(sessionId, mlContext);
let sessionIds = this.sessionIdsByMlContext.get(mlContext);
let sessionIds = this.sessionIdsByMLContext.get(mlContext);
if (!sessionIds) {
sessionIds = new Set();
this.sessionIdsByMlContext.set(mlContext, sessionIds);
this.sessionIdsByMLContext.set(mlContext, sessionIds);
}
sessionIds.add(sessionId);
}

public unregisterMlContext(sessionId: number): void {
public unregisterMLContext(sessionId: number): void {
const mlContext = this.mlContextBySessionId.get(sessionId)!;
if (!mlContext) {
throw new Error(`No MLContext found for session ${sessionId}`);
}
this.mlContextBySessionId.delete(sessionId);
const sessionIds = this.sessionIdsByMlContext.get(mlContext)!;
const sessionIds = this.sessionIdsByMLContext.get(mlContext)!;
sessionIds.delete(sessionId);
if (sessionIds.size === 0) {
this.sessionIdsByMlContext.delete(mlContext);
this.sessionIdsByMLContext.delete(mlContext);
}
}

public onReleaseSession(sessionId: number): void {
this.unregisterMlContext(sessionId);
this.bufferManager.releaseBuffersForContext(this.getMlContext(sessionId));
this.unregisterMLContext(sessionId);
this.bufferManager.releaseBuffersForContext(this.getMLContext(sessionId));
}

public getMlContext(sessionId: number): MLContext {
public getMLContext(sessionId: number): MLContext {
return this.mlContextBySessionId.get(sessionId)!;
}

Expand Down Expand Up @@ -137,14 +137,14 @@ export class WebNNBackend {
return this.bufferManager.download(bufferId);
}

public createMlBufferDownloader(bufferId: BufferId, type: Tensor.GpuBufferDataTypes): () => Promise<Tensor.DataType> {
public createMLBufferDownloader(bufferId: BufferId, type: Tensor.GpuBufferDataTypes): () => Promise<Tensor.DataType> {
return async () => {
const data = await this.bufferManager.download(bufferId);
return createView(data, type);
};
}

public registerMlBuffer(buffer: MLBuffer): BufferId {
public registerMLBuffer(buffer: MLBuffer): BufferId {
return this.bufferManager.registerBuffer(this.currentContext, buffer);
}

Expand Down
17 changes: 16 additions & 1 deletion js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,21 @@ export const init =
]);
} else {
const backend = new WebNNBackend();
jsepInit('webnn', [backend]);
jsepInit('webnn', [
backend,
// jsepReserveBufferId
() => backend.reserveBufferId(),
// jsepReleaseBufferId,
(bufferId: number) => backend.releaseBufferId(bufferId),
// jsepEnsureBuffer
(bufferId: number, onnxDataType: number, dimensions: number[]) =>
backend.ensureBuffer(bufferId, onnxDataType, dimensions),
// jsepUploadBuffer
(bufferId: number, data: Uint8Array) => {
backend.uploadBuffer(bufferId, data);
},
// jsepDownloadBuffer
async (bufferId: number) => backend.downloadBuffer(bufferId),
]);
}
};
10 changes: 5 additions & 5 deletions js/web/lib/wasm/proxy-messages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ export type GpuBufferMetadata = {
dispose?: () => void;
};

export type MlBufferMetadata = {
mlBuffer: Tensor.MlBufferType;
download?: () => Promise<Tensor.DataTypeMap[Tensor.MlBufferDataTypes]>;
export type MLBufferMetadata = {
mlBuffer: Tensor.MLBufferType;
download?: () => Promise<Tensor.DataTypeMap[Tensor.MLBufferDataTypes]>;
dispose?: () => void;
};

Expand All @@ -26,7 +26,7 @@ export type MlBufferMetadata = {
*/
export type UnserializableTensorMetadata =
[dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']|
[dataType: Tensor.Type, dims: readonly number[], data: MlBufferMetadata, location: 'ml-buffer']|
[dataType: Tensor.Type, dims: readonly number[], data: MLBufferMetadata, location: 'ml-buffer']|
[dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned'];

/**
Expand All @@ -37,7 +37,7 @@ export type UnserializableTensorMetadata =
* - cpu: Uint8Array
* - cpu-pinned: Uint8Array
* - gpu-buffer: GpuBufferMetadata
* - ml-buffer: MlBufferMetadata
* - ml-buffer: MLBufferMetadata
* - location: tensor data location
*/
export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata;
Expand Down
6 changes: 3 additions & 3 deletions js/web/lib/wasm/session-handler-inference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE

import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType, isMlBufferSupportedType} from './wasm-common';
import {isGpuBufferSupportedType, isMLBufferSupportedType} from './wasm-common';
import {isNode} from './wasm-utils-env';
import {loadFile} from './wasm-utils-load-file';

Expand Down Expand Up @@ -36,11 +36,11 @@ export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => {
}
case 'ml-buffer': {
const dataType = tensor[0];
if (!isMlBufferSupportedType(dataType)) {
if (!isMLBufferSupportedType(dataType)) {
throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`);
}
const {mlBuffer, download, dispose} = tensor[2];
return Tensor.fromMlBuffer(mlBuffer, {dataType, dims: tensor[1], download, dispose});
return Tensor.fromMLBuffer(mlBuffer, {dataType, dims: tensor[1], download, dispose});
}
default:
throw new Error(`invalid data location: ${tensor[3]}`);
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/wasm-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB
/**
* Check whether the given tensor type is supported by WebNN MLBuffer
*/
export const isMlBufferSupportedType = (type: Tensor.Type): type is Tensor.MlBufferDataTypes => type === 'float32' ||
export const isMLBufferSupportedType = (type: Tensor.Type): type is Tensor.MLBufferDataTypes => type === 'float32' ||
type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint64' ||
type === 'int8' || type === 'uint8' || type === 'bool';

Expand Down
20 changes: 10 additions & 10 deletions js/web/lib/wasm/wasm-core-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, isMlBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, isMLBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';
import {loadFile} from './wasm-utils-load-file';
Expand Down Expand Up @@ -292,7 +292,7 @@ export const createSession = async(

// clear current MLContext after session creation
if (wasm.currentContext) {
wasm.jsepRegisterMlContext!(sessionHandle, wasm.currentContext);
wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
wasm.currentContext = undefined;
}

Expand Down Expand Up @@ -446,11 +446,11 @@ export const prepareInputOutputTensor =
const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!;
dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes;

const registerMlBuffer = wasm.jsepRegisterMlBuffer;
if (!registerMlBuffer) {
const registerMLBuffer = wasm.jsepRegisterMLBuffer;
if (!registerMLBuffer) {
throw new Error('Tensor location "ml-buffer" is not supported without using WebNN.');
}
rawData = registerMlBuffer(mlBuffer);
rawData = registerMLBuffer(mlBuffer);
} else {
const data = tensor[2];

Expand Down Expand Up @@ -691,13 +691,13 @@ export const run = async(
'gpu-buffer'
]);
} else if (preferredLocation === 'ml-buffer' && size > 0) {
const getMlBuffer = wasm.jsepGetMlBuffer;
if (!getMlBuffer) {
const getMLBuffer = wasm.jsepGetMLBuffer;
if (!getMLBuffer) {
throw new Error('preferredLocation "ml-buffer" is not supported without using WebNN.');
}
const mlBuffer = getMlBuffer(dataOffset);
const mlBuffer = getMLBuffer(dataOffset);
const elementSize = getTensorElementSize(dataType);
if (elementSize === undefined || !isMlBufferSupportedType(type)) {
if (elementSize === undefined || !isMLBufferSupportedType(type)) {
throw new Error(`Unsupported data type: ${type}`);
}

Expand All @@ -707,7 +707,7 @@ export const run = async(
output.push([
type, dims, {
mlBuffer,
download: wasm.jsepCreateMlBufferDownloader!(dataOffset, type),
download: wasm.jsepCreateMLBufferDownloader!(dataOffset, type),
dispose: () => {
wasm.jsepReleaseBufferId!(dataOffset);
wasm._OrtReleaseTensor(tensor);
Expand Down
Loading

0 comments on commit 85ca43b

Please sign in to comment.