{
/** The buffer holding the indices. */
final T indices;
+
/** The buffer holding the values. */
final Buffer values;
diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensor.java b/java/src/main/java/ai/onnxruntime/OnnxTensor.java
index 0078adb6402f8..e1ee2c14fd9d1 100644
--- a/java/src/main/java/ai/onnxruntime/OnnxTensor.java
+++ b/java/src/main/java/ai/onnxruntime/OnnxTensor.java
@@ -14,12 +14,14 @@
import java.nio.LongBuffer;
import java.nio.ShortBuffer;
import java.util.Optional;
+import java.util.logging.Logger;
/**
* A Java object wrapping an OnnxTensor. Tensors are the main input to the library, and can also be
* returned as outputs.
*/
public class OnnxTensor extends OnnxTensorLike {
+ private static final Logger logger = Logger.getLogger(OnnxTensor.class.getName());
/**
* This reference is held for OnnxTensors backed by a java.nio.Buffer to ensure the buffer does
@@ -97,6 +99,7 @@ public OnnxValueType getType() {
*/
@Override
public Object getValue() throws OrtException {
+ checkClosed();
if (info.isScalar()) {
switch (info.type) {
case FLOAT:
@@ -144,16 +147,21 @@ public Object getValue() throws OrtException {
@Override
public String toString() {
- return "OnnxTensor(info=" + info.toString() + ")";
+ return "OnnxTensor(info=" + info.toString() + ",closed=" + closed + ")";
}
/**
- * Closes the tensor, releasing it's underlying memory (if it's not backed by an NIO buffer). If
- * it is backed by a buffer then the memory is released when the buffer is GC'd.
+ * Closes the tensor, releasing its underlying memory (if it's not backed by an NIO buffer). If it
+ * is backed by a buffer then the memory is released when the buffer is GC'd.
*/
@Override
- public void close() {
- close(OnnxRuntime.ortApiHandle, nativeHandle);
+ public synchronized void close() {
+ if (!closed) {
+ close(OnnxRuntime.ortApiHandle, nativeHandle);
+ closed = true;
+ } else {
+ logger.warning("Closing an already closed tensor.");
+ }
}
/**
@@ -165,6 +173,7 @@ public void close() {
* @return A ByteBuffer copy of the OnnxTensor.
*/
public ByteBuffer getByteBuffer() {
+ checkClosed();
if (info.type != OnnxJavaType.STRING) {
ByteBuffer buffer = getBuffer(OnnxRuntime.ortApiHandle, nativeHandle);
ByteBuffer output = ByteBuffer.allocate(buffer.capacity());
@@ -183,6 +192,7 @@ public ByteBuffer getByteBuffer() {
* @return A FloatBuffer copy of the OnnxTensor.
*/
public FloatBuffer getFloatBuffer() {
+ checkClosed();
if (info.type == OnnxJavaType.FLOAT) {
// if it's fp32 use the efficient copy.
FloatBuffer buffer = getBuffer().asFloatBuffer();
@@ -212,6 +222,7 @@ public FloatBuffer getFloatBuffer() {
* @return A DoubleBuffer copy of the OnnxTensor.
*/
public DoubleBuffer getDoubleBuffer() {
+ checkClosed();
if (info.type == OnnxJavaType.DOUBLE) {
DoubleBuffer buffer = getBuffer().asDoubleBuffer();
DoubleBuffer output = DoubleBuffer.allocate(buffer.capacity());
@@ -230,6 +241,7 @@ public DoubleBuffer getDoubleBuffer() {
* @return A ShortBuffer copy of the OnnxTensor.
*/
public ShortBuffer getShortBuffer() {
+ checkClosed();
if ((info.type == OnnxJavaType.INT16)
|| (info.type == OnnxJavaType.FLOAT16)
|| (info.type == OnnxJavaType.BFLOAT16)) {
@@ -250,6 +262,7 @@ public ShortBuffer getShortBuffer() {
* @return An IntBuffer copy of the OnnxTensor.
*/
public IntBuffer getIntBuffer() {
+ checkClosed();
if (info.type == OnnxJavaType.INT32) {
IntBuffer buffer = getBuffer().asIntBuffer();
IntBuffer output = IntBuffer.allocate(buffer.capacity());
@@ -268,6 +281,7 @@ public IntBuffer getIntBuffer() {
* @return A LongBuffer copy of the OnnxTensor.
*/
public LongBuffer getLongBuffer() {
+ checkClosed();
if (info.type == OnnxJavaType.INT64) {
LongBuffer buffer = getBuffer().asLongBuffer();
LongBuffer output = LongBuffer.allocate(buffer.capacity());
diff --git a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java
index c2989fe296dc2..bbfd4e981ece2 100644
--- a/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java
+++ b/java/src/main/java/ai/onnxruntime/OnnxTensorLike.java
@@ -28,6 +28,9 @@ public abstract class OnnxTensorLike implements OnnxValue {
/** The size and shape information for this tensor. */
protected final TensorInfo info;
+ /** Is this value closed? */
+ protected boolean closed;
+
/**
* Constructs a tensor-like (the base class of OnnxTensor and OnnxSparseTensor).
*
@@ -39,6 +42,7 @@ public abstract class OnnxTensorLike implements OnnxValue {
this.nativeHandle = nativeHandle;
this.allocatorHandle = allocatorHandle;
this.info = info;
+ this.closed = false;
}
/**
@@ -59,4 +63,16 @@ long getNativeHandle() {
public TensorInfo getInfo() {
return info;
}
+
+ @Override
+ public synchronized boolean isClosed() {
+ return closed;
+ }
+
+ /** Checks if the OnnxValue is closed, if so throws {@link IllegalStateException}. */
+ protected void checkClosed() {
+ if (closed) {
+ throw new IllegalStateException("Trying to use a closed OnnxValue");
+ }
+ }
}
diff --git a/java/src/main/java/ai/onnxruntime/OnnxValue.java b/java/src/main/java/ai/onnxruntime/OnnxValue.java
index 752a0e74267d3..e829bc80f09f6 100644
--- a/java/src/main/java/ai/onnxruntime/OnnxValue.java
+++ b/java/src/main/java/ai/onnxruntime/OnnxValue.java
@@ -64,7 +64,14 @@ public enum OnnxValueType {
*/
public ValueInfo getInfo();
- /** Closes the OnnxValue, freeing it's native memory. */
+ /**
+ * Checks if this value is closed (i.e., the native object has been released).
+ *
+ * @return True if the value is closed and the native object has been released.
+ */
+ public boolean isClosed();
+
+ /** Closes the OnnxValue, freeing its native memory. */
@Override
public void close();
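
A minimal Java sketch of the close-state contract these changes introduce (the tensor values are arbitrary; behavior follows the methods added in this diff):

```java
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;

public final class CloseStateDemo {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    OnnxTensor t = OnnxTensor.createTensor(env, new long[] {1, 2, 3});
    System.out.println(t.isClosed()); // false
    t.close();                        // frees the native memory
    t.close();                        // safe no-op, logs a warning instead of double-freeing
    try {
      t.getValue();                   // accessors now fail fast after close
    } catch (IllegalStateException e) {
      System.out.println(e.getMessage()); // "Trying to use a closed OnnxValue"
    }
  }
}
```
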
diff --git a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java
index 39a5121fad7a2..70af10ff8cd79 100644
--- a/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java
+++ b/java/src/main/java/ai/onnxruntime/OrtProviderOptions.java
@@ -5,11 +5,14 @@
package ai.onnxruntime;
import java.io.IOException;
+import java.util.logging.Logger;
/** An abstract base class for execution provider options classes. */
// Note this lives in ai.onnxruntime to allow subclasses to access the OnnxRuntime.ortApiHandle
// package private field.
public abstract class OrtProviderOptions implements AutoCloseable {
+ private static final Logger logger = Logger.getLogger(OrtProviderOptions.class.getName());
+
static {
try {
OnnxRuntime.init();
@@ -21,6 +24,9 @@ public abstract class OrtProviderOptions implements AutoCloseable {
/** The native pointer. */
protected final long nativeHandle;
+ /** Is the native object closed? */
+ protected boolean closed;
+
/**
* Constructs a OrtProviderOptions wrapped around a native pointer.
*
@@ -28,6 +34,7 @@ public abstract class OrtProviderOptions implements AutoCloseable {
*/
protected OrtProviderOptions(long nativeHandle) {
this.nativeHandle = nativeHandle;
+ this.closed = false;
}
/**
@@ -46,9 +53,30 @@ protected static long getApiHandle() {
*/
public abstract OrtProvider getProvider();
+ /**
+ * Is the native object closed?
+ *
+ * @return True if the native object has been released.
+ */
+ public synchronized boolean isClosed() {
+ return closed;
+ }
+
@Override
public void close() {
- close(OnnxRuntime.ortApiHandle, nativeHandle);
+ if (!closed) {
+ close(OnnxRuntime.ortApiHandle, nativeHandle);
+ closed = true;
+ } else {
+ logger.warning("Closing an already closed tensor.");
+ }
+ }
+
+ /** Checks if the OrtProviderOptions is closed, if so throws {@link IllegalStateException}. */
+ protected void checkClosed() {
+ if (closed) {
+ throw new IllegalStateException("Trying to use a closed OrtProviderOptions");
+ }
}
/**
diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java
index 49ddf29c22335..eeede3a1bed0b 100644
--- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java
+++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java
@@ -12,6 +12,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
+import java.util.logging.Logger;
/**
* Wraps an ONNX training model and allows training and inference calls.
@@ -1049,8 +1050,12 @@ private native void exportModelForInference(
/** Wrapper class for the checkpoint state. */
static final class OrtCheckpointState implements AutoCloseable {
+ private static final Logger logger = Logger.getLogger(OrtCheckpointState.class.getName());
+
final long nativeHandle;
+ private boolean closed;
+
/**
* Wraps an object around the checkpoint native handle.
*
@@ -1058,6 +1063,7 @@ static final class OrtCheckpointState implements AutoCloseable {
*/
OrtCheckpointState(long nativeHandle) {
this.nativeHandle = nativeHandle;
+ this.closed = false;
}
/**
@@ -1097,6 +1103,7 @@ static OrtCheckpointState loadCheckpoint(String checkpoint) throws OrtException
* @throws OrtException If the checkpoint failed to save.
*/
public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtException {
+ checkClosed();
Objects.requireNonNull(outputPath, "checkpoint path must not be null");
String outputStr = outputPath.toString();
saveCheckpoint(
@@ -1115,6 +1122,7 @@ public void saveCheckpoint(Path outputPath, boolean saveOptimizer) throws OrtExc
* @throws OrtException If the call failed.
*/
public void addProperty(String name, float value) throws OrtException {
+ checkClosed();
addProperty(
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
}
@@ -1127,6 +1135,7 @@ public void addProperty(String name, float value) throws OrtException {
* @throws OrtException If the call failed.
*/
public void addProperty(String name, int value) throws OrtException {
+ checkClosed();
addProperty(
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
}
@@ -1139,6 +1148,7 @@ public void addProperty(String name, int value) throws OrtException {
* @throws OrtException If the call failed.
*/
public void addProperty(String name, String value) throws OrtException {
+ checkClosed();
addProperty(
OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, nativeHandle, name, value);
}
@@ -1152,6 +1162,7 @@ public void addProperty(String name, String value) throws OrtException {
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public float getFloatProperty(OrtAllocator allocator, String name) throws OrtException {
+ checkClosed();
return getFloatProperty(
OnnxRuntime.ortApiHandle,
OnnxRuntime.ortTrainingApiHandle,
@@ -1169,6 +1180,7 @@ public float getFloatProperty(OrtAllocator allocator, String name) throws OrtExc
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public int getIntProperty(OrtAllocator allocator, String name) throws OrtException {
+ checkClosed();
return getIntProperty(
OnnxRuntime.ortApiHandle,
OnnxRuntime.ortTrainingApiHandle,
@@ -1186,6 +1198,7 @@ public int getIntProperty(OrtAllocator allocator, String name) throws OrtExcepti
* @throws OrtException If the property does not exist, or is of the wrong type.
*/
public String getStringProperty(OrtAllocator allocator, String name) throws OrtException {
+ checkClosed();
return getStringProperty(
OnnxRuntime.ortApiHandle,
OnnxRuntime.ortTrainingApiHandle,
@@ -1194,9 +1207,25 @@ public String getStringProperty(OrtAllocator allocator, String name) throws OrtE
name);
}
+ /** Checks if the OrtCheckpointState is closed, if so throws {@link IllegalStateException}. */
+ private void checkClosed() {
+ if (closed) {
+ throw new IllegalStateException("Trying to use a closed OrtCheckpointState");
+ }
+ }
+
+ public synchronized boolean isClosed() {
+ return closed;
+ }
+
@Override
- public void close() {
- close(OnnxRuntime.ortTrainingApiHandle, nativeHandle);
+ public synchronized void close() {
+ if (!closed) {
+ close(OnnxRuntime.ortTrainingApiHandle, nativeHandle);
+ closed = true;
+ } else {
+ logger.warning("Closing a checkpoint twice");
+ }
}
/*
diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java
index 69ccb954e8afe..1c21387b50455 100644
--- a/java/src/main/java/ai/onnxruntime/TensorInfo.java
+++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java
@@ -7,6 +7,7 @@
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.util.Arrays;
+import java.util.stream.Collectors;
/** Describes an {@link OnnxTensor}, including it's size, shape and element type. */
public class TensorInfo implements ValueInfo {
@@ -159,6 +160,12 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
/** The shape of the tensor. */
final long[] shape;
+ /** The names of the unbound dimensions. */
+ final String[] dimensionNames;
+
+ /** True if any of the dimension names are non-empty. */
+ private final boolean hasNames;
+
/** The Java type of this tensor. */
public final OnnxJavaType type;
@@ -177,6 +184,9 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
*/
TensorInfo(long[] shape, OnnxJavaType type, OnnxTensorType onnxType) {
this.shape = shape;
+ this.dimensionNames = new String[shape.length];
+ Arrays.fill(dimensionNames, "");
+ this.hasNames = false;
this.type = type;
this.onnxType = onnxType;
this.numElements = elementCount(shape);
@@ -188,10 +198,20 @@ public static OnnxTensorType mapFromJavaType(OnnxJavaType type) {
* Called from JNI.
*
* @param shape The tensor shape.
+ * @param names The dimension names.
* @param typeInt The native type int.
*/
- TensorInfo(long[] shape, int typeInt) {
+ TensorInfo(long[] shape, String[] names, int typeInt) {
this.shape = shape;
+ this.dimensionNames = names;
+ boolean hasNames = false;
+ for (String s : names) {
+ if (!s.isEmpty()) {
+ hasNames = true;
+ break;
+ }
+ }
+ this.hasNames = hasNames;
this.onnxType = OnnxTensorType.mapFromInt(typeInt);
this.type = OnnxJavaType.mapFromOnnxTensorType(this.onnxType);
this.numElements = elementCount(shape);
@@ -206,15 +226,42 @@ public long[] getShape() {
return Arrays.copyOf(shape, shape.length);
}
+ /**
+ * Get a copy of the tensor's named dimensions.
+ *
+ * @return A copy of the tensor's named dimensions.
+ */
+ public String[] getDimensionNames() {
+ return Arrays.copyOf(dimensionNames, dimensionNames.length);
+ }
+
@Override
public String toString() {
- return "TensorInfo(javaType="
- + type.toString()
- + ",onnxType="
- + onnxType.toString()
- + ",shape="
- + Arrays.toString(shape)
- + ")";
+ String output =
+ "TensorInfo(javaType="
+ + type.toString()
+ + ",onnxType="
+ + onnxType.toString()
+ + ",shape="
+ + Arrays.toString(shape);
+ if (hasNames) {
+ output =
+ output
+ + ",dimNames=["
+ + Arrays.stream(dimensionNames)
+ .map(
+ a -> {
+ if (a.isEmpty()) {
+ return "\"\"";
+ } else {
+ return a;
+ }
+ })
+ .collect(Collectors.joining(","))
+ + "]";
+ }
+ output = output + ")";
+ return output;
}
/**
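
A short Java sketch of reading the new dimension names, assuming a model (hypothetical path) with a symbolic batch dimension; symbolic dimensions report -1 in the shape and their name in `getDimensionNames()`:

```java
import ai.onnxruntime.NodeInfo;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.TensorInfo;
import java.util.Arrays;
import java.util.Map;

public final class DimensionNamesDemo {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    try (OrtSession session = env.createSession("model.onnx")) { // hypothetical model
      for (Map.Entry<String, NodeInfo> e : session.getInputInfo().entrySet()) {
        TensorInfo info = (TensorInfo) e.getValue().getInfo();
        // e.g. A: shape=[-1, 2], dimNames=[n, ]
        System.out.println(e.getKey() + ": shape=" + Arrays.toString(info.getShape())
            + ", dimNames=" + Arrays.toString(info.getDimensionNames()));
      }
    }
  }
}
```
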
diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java
index eb124decf75f3..cec3fadf446ca 100644
--- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java
+++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
@@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags {
/** Enables CoreML on subgraphs. */
ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002)
/** Only enable usage of CoreML if the device has an Apple Neural Engine. */
- ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004),
+ ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004)
+ /**
+ * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also
+ * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs
+ * have dynamic shapes.
+ */
+ ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008)
+ /**
+ * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or
+ * later.
+ */
+ CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010)
/** The native value of the enum. */
public final int value;
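
A hedged Java sketch of requesting the new flags, assuming the existing `SessionOptions.addCoreML(EnumSet)` overload and a hypothetical model path:

```java
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.providers.CoreMLFlags;
import java.util.EnumSet;

public final class CoreMLFlagsDemo {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) {
      // Request the ML Program format and keep CoreML off dynamically-shaped inputs.
      opts.addCoreML(
          EnumSet.of(CoreMLFlags.CREATE_MLPROGRAM, CoreMLFlags.ONLY_ALLOW_STATIC_INPUT_SHAPES));
      try (OrtSession session = env.createSession("model.onnx", opts)) {
        // run inference as usual
      }
    }
  }
}
```
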
diff --git a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java
index 02207b2949e54..961163035c9a6 100644
--- a/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java
+++ b/java/src/main/java/ai/onnxruntime/providers/StringConfigProviderOptions.java
@@ -32,6 +32,7 @@ protected StringConfigProviderOptions(long nativeHandle) {
* @throws OrtException If the addition failed.
*/
public void add(String key, String value) throws OrtException {
+ checkClosed();
Objects.requireNonNull(key, "Key must not be null");
Objects.requireNonNull(value, "Value must not be null");
options.put(key, value);
diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c
index 879ba8a310618..7b26291581395 100644
--- a/java/src/main/native/OrtJniUtil.c
+++ b/java/src/main/native/OrtJniUtil.c
@@ -342,7 +342,6 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
if (code != ORT_OK) {
return NULL;
}
- //printf("numDim %d\n",numDim);
int64_t* dimensions = (int64_t*) malloc(sizeof(int64_t)*numDim);
code = checkOrtStatus(jniEnv, api, api->GetDimensions(info, dimensions, numDim));
if (code != ORT_OK) {
@@ -358,12 +357,31 @@ jobject convertToTensorInfo(JNIEnv *jniEnv, const OrtApi * api, const OrtTensorT
free(dimensions);
dimensions = NULL;
+ // Create the string array for the names.
+ const char** dimensionNames = (const char**) malloc(sizeof(char*)*numDim);
+ if (dimensionNames == NULL) {
+ throwOrtException(jniEnv, 1, "Not enough memory");
+ return NULL;
+ }
+ code = checkOrtStatus(jniEnv, api, api->GetSymbolicDimensions(info, dimensionNames, numDim));
+ if (code != ORT_OK) {
+ // extraction failed, exception has been thrown, return to Java.
+ free(dimensionNames);
+ return NULL;
+ }
+ jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String");
+ jobjectArray names = (*jniEnv)->NewObjectArray(jniEnv, safecast_size_t_to_jsize(numDim), stringClazz, NULL);
+ for (size_t i = 0; i < numDim; i++) {
+ jobject javaName = (*jniEnv)->NewStringUTF(jniEnv, dimensionNames[i]);
+ (*jniEnv)->SetObjectArrayElement(jniEnv, names, safecast_size_t_to_jsize(i), javaName);
+ }
+ free(dimensionNames);
+
// Create the TensorInfo object
static const char *tensorInfoClassName = "ai/onnxruntime/TensorInfo";
jclass clazz = (*jniEnv)->FindClass(jniEnv, tensorInfoClassName);
- jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([JI)V");
- //printf("TensorInfo class %p, methodID %p\n",clazz,tensorInfoConstructor);
- jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, onnxTypeInt);
+ jmethodID tensorInfoConstructor = (*jniEnv)->GetMethodID(jniEnv,clazz, "<init>", "([J[Ljava/lang/String;I)V");
+ jobject tensorInfo = (*jniEnv)->NewObject(jniEnv, clazz, tensorInfoConstructor, shape, names, onnxTypeInt);
return tensorInfo;
}
diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
index 3a1c0d1bb8fa1..337f4c1921c6e 100644
--- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
+++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
@@ -8,7 +8,7 @@
#include "onnxruntime/core/session/onnxruntime_c_api.h"
#include "OrtJniUtil.h"
#include "ai_onnxruntime_OrtSession_SessionOptions.h"
-#ifdef WIN32
+#ifdef _WIN32
#include <Windows.h>
#else
#include <dlfcn.h>
@@ -318,7 +318,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC
// Iterate the handles, calling the appropriate close function
for (jint i = 0; i < numHandles; i++) {
-#ifdef WIN32
+#ifdef _WIN32
FreeLibrary((void*)handles[i]);
#else
dlclose((void*)handles[i]);
@@ -630,7 +630,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addMIG
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDirectML
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceID) {
(void)jobj;
- #ifdef USE_DIRECTML
+ #ifdef USE_DML
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*) handle, deviceID));
#else
(void)apiHandle;(void)handle;(void)deviceID; // Parameters used when DirectML is defined.
diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
index 9f7b8d3a3dcfc..464234c34798a 100644
--- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
+++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
@@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes
}
}
wchar_t* optimizerStr = NULL;
- if (optimizerPath == NULL) {
+ if (optimizerPath != NULL) {
optimizerStr = copyAndPad(jniEnv, optimizerPath);
if (optimizerStr == NULL) {
// exception has been thrown in Java, go to cleanup and return null.
diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java
index e975117fb75bd..ac65cbab146bf 100644
--- a/java/src/test/java/ai/onnxruntime/InferenceTest.java
+++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -69,7 +69,9 @@ public void environmentTest() {
// Checks that the environment instance is the same.
OrtEnvironment otherEnv = OrtEnvironment.getEnvironment();
assertSame(env, otherEnv);
+ TestHelpers.quietLogger(OrtEnvironment.class);
otherEnv = OrtEnvironment.getEnvironment("test-name");
+ TestHelpers.loudLogger(OrtEnvironment.class);
assertSame(env, otherEnv);
}
@@ -588,6 +590,12 @@ public void testSymbolicDimensionAssignment() throws OrtException {
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo aInfo = (TensorInfo) infoMap.get("A").getInfo();
assertArrayEquals(new long[] {-1, 2}, aInfo.shape);
+ assertEquals(2, aInfo.dimensionNames.length);
+ assertEquals("n", aInfo.dimensionNames[0]);
+ assertEquals("", aInfo.dimensionNames[1]);
+ TensorInfo bInfo = (TensorInfo) infoMap.get("B").getInfo();
+ assertEquals(1, bInfo.dimensionNames.length);
+ assertEquals("m", bInfo.dimensionNames[0]);
}
}
// Check that when the options are assigned it overrides the symbolic dimension
@@ -643,6 +651,12 @@ public void testCoreML() throws OrtException {
runProvider(OrtProvider.CORE_ML);
}
+ @Test
+ @EnabledIfSystemProperty(named = "USE_DML", matches = "1")
+ public void testDirectML() throws OrtException {
+ runProvider(OrtProvider.DIRECT_ML);
+ }
+
private void runProvider(OrtProvider provider) throws OrtException {
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
assertTrue(providers.size() > 1);
@@ -665,7 +679,7 @@ private void runProvider(OrtProvider provider) throws OrtException {
// CoreML gives slightly different answers on a 2020 13" M1 MBP
assertArrayEquals(expectedOutput, resultArray, 1e-2f);
} else {
- assertArrayEquals(expectedOutput, resultArray, 1e-6f);
+ assertArrayEquals(expectedOutput, resultArray, 1e-5f);
}
} catch (OrtException e) {
throw new IllegalStateException("Failed to execute a scoring operation", e);
@@ -1918,6 +1932,8 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet<OrtProvider> provid
options.addNnapi();
break;
case DIRECT_ML:
+ options.setMemoryPatternOptimization(false);
+ options.setExecutionMode(ExecutionMode.SEQUENTIAL);
options.addDirectML(0);
break;
case ACL:
diff --git a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java
index a5f285ba86a14..c060cf73ecf14 100644
--- a/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java
+++ b/java/src/test/java/ai/onnxruntime/OnnxTensorTest.java
@@ -4,6 +4,10 @@
*/
package ai.onnxruntime;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
import ai.onnxruntime.platform.Fp16Conversions;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
@@ -97,8 +101,8 @@ public void testBufferCreation() throws OrtException {
float[] arrValues = new float[] {0, 1, 2, 3, 4};
try (OnnxTensor t = OnnxTensor.createTensor(env, arrValues)) {
// array creation isn't backed by buffers
- Assertions.assertFalse(t.ownsBuffer());
- Assertions.assertFalse(t.getBufferRef().isPresent());
+ assertFalse(t.ownsBuffer());
+ assertFalse(t.getBufferRef().isPresent());
FloatBuffer buf = t.getFloatBuffer();
float[] output = new float[arrValues.length];
buf.get(output);
@@ -146,7 +150,7 @@ public void testBufferCreation() throws OrtException {
directBuffer.rewind();
try (OnnxTensor t = OnnxTensor.createTensor(env, directBuffer, new long[] {1, 5})) {
// direct buffers don't trigger a copy
- Assertions.assertFalse(t.ownsBuffer());
+ assertFalse(t.ownsBuffer());
// tensors backed by buffers can get the buffer ref back out
Assertions.assertTrue(t.getBufferRef().isPresent());
FloatBuffer buf = t.getFloatBuffer();
@@ -428,4 +432,21 @@ public void testBf16RoundTrip() {
}
}
}
+
+ @Test
+ public void testClose() throws OrtException {
+ OrtEnvironment env = OrtEnvironment.getEnvironment();
+ long[] input = new long[] {1, 2, 3, 4, 5};
+ OnnxTensor value = OnnxTensor.createTensor(env, input);
+ assertFalse(value.isClosed());
+ long[] output = (long[]) value.getValue();
+ assertArrayEquals(input, output);
+ value.close();
+ // check use after close throws
+ assertThrows(IllegalStateException.class, value::getValue);
+ // check double close doesn't crash (emits warning)
+ TestHelpers.quietLogger(OnnxTensor.class);
+ value.close();
+ TestHelpers.loudLogger(OnnxTensor.class);
+ }
}
diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java
index 55d8169434d48..c13cdf222b15b 100644
--- a/java/src/test/java/ai/onnxruntime/TestHelpers.java
+++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java
@@ -22,6 +22,8 @@
import java.util.Comparator;
import java.util.List;
import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
import java.util.regex.Pattern;
import org.junit.jupiter.api.Assertions;
static void flattenStringBase(String[] input, List<String> output) {
output.addAll(Arrays.asList(input));
}
+ static void loudLogger(Class<?> loggerClass) {
+ Logger l = Logger.getLogger(loggerClass.getName());
+ l.setLevel(Level.INFO);
+ }
+
+ static void quietLogger(Class<?> loggerClass) {
+ Logger l = Logger.getLogger(loggerClass.getName());
+ l.setLevel(Level.OFF);
+ }
+
public static Path getResourcePath(String path) {
return new File(TestHelpers.class.getResource(path).getFile()).toPath();
}
diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java
index 1ed883ace36e5..0e3bc15ba9c70 100644
--- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java
+++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java
@@ -96,7 +96,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions
OnnxValue resultTensor = result.get(0);
float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue());
assertEquals(expectedOutput.length, resultArray.length);
- assertArrayEquals(expectedOutput, resultArray, 1e-6f);
+ assertArrayEquals(expectedOutput, resultArray, 1e-5f);
} catch (OrtException e) {
throw new IllegalStateException("Failed to execute a scoring operation", e);
}
diff --git a/java/src/test/java/sample/ScoreMNIST.java b/java/src/test/java/sample/ScoreMNIST.java
index 5587b58e17f52..6ecbc5cd56d10 100644
--- a/java/src/test/java/sample/ScoreMNIST.java
+++ b/java/src/test/java/sample/ScoreMNIST.java
@@ -30,6 +30,7 @@
public class ScoreMNIST {
private static final Logger logger = Logger.getLogger(ScoreMNIST.class.getName());
+
/** Pattern for splitting libsvm format files. */
private static final Pattern splitPattern = Pattern.compile("\\s+");
diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts
index 3e1e833addb91..e90efd7b97c29 100644
--- a/js/common/lib/backend-impl.ts
+++ b/js/common/lib/backend-impl.ts
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
import {Backend} from './backend.js';
+import {InferenceSession} from './inference-session.js';
interface BackendInfo {
backend: Backend;
@@ -10,6 +11,7 @@ interface BackendInfo {
initPromise?: Promise<void>;
initialized?: boolean;
aborted?: boolean;
+ error?: string;
}
const backends: Map<string, BackendInfo> = new Map();
@@ -60,43 +62,100 @@ export const registerBackend = (name: string, backend: Backend, priority: number
};
/**
- * Resolve backend by specified hints.
+ * Try to resolve and initialize a backend.
*
- * @param backendHints - a list of execution provider names to lookup. If omitted use registered backends as list.
- * @returns a promise that resolves to the backend.
+ * @param backendName - the name of the backend.
+ * @returns the backend instance if resolved and initialized successfully, or an error message if failed.
+ */
+const tryResolveAndInitializeBackend = async(backendName: string): Promise<Backend|string> => {
+ const backendInfo = backends.get(backendName);
+ if (!backendInfo) {
+ return 'backend not found.';
+ }
+
+ if (backendInfo.initialized) {
+ return backendInfo.backend;
+ } else if (backendInfo.aborted) {
+ return backendInfo.error!;
+ } else {
+ const isInitializing = !!backendInfo.initPromise;
+ try {
+ if (!isInitializing) {
+ backendInfo.initPromise = backendInfo.backend.init(backendName);
+ }
+ await backendInfo.initPromise;
+ backendInfo.initialized = true;
+ return backendInfo.backend;
+ } catch (e) {
+ if (!isInitializing) {
+ backendInfo.error = `${e}`;
+ backendInfo.aborted = true;
+ }
+ return backendInfo.error!;
+ } finally {
+ delete backendInfo.initPromise;
+ }
+ }
+};
+
+/**
+ * Resolve execution providers from the specific session options.
+ *
+ * @param options - the session options object.
+ * @returns a promise that resolves to a tuple of an initialized backend instance and a session options object with
+ * filtered EP list.
*
* @ignore
*/
-export const resolveBackend = async(backendHints: readonly string[]): Promise<Backend> => {
- const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;
- const errors = [];
- for (const backendName of backendNames) {
- const backendInfo = backends.get(backendName);
- if (backendInfo) {
- if (backendInfo.initialized) {
- return backendInfo.backend;
- } else if (backendInfo.aborted) {
- continue; // current backend is unavailable; try next
- }
+export const resolveBackendAndExecutionProviders = async(options: InferenceSession.SessionOptions):
+ Promise<[backend: Backend, options: InferenceSession.SessionOptions]> => {
+ // extract backend hints from session options
+ const eps = options.executionProviders || [];
+ const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
+ const backendNames = backendHints.length === 0 ? backendsSortedByPriority : backendHints;
- const isInitializing = !!backendInfo.initPromise;
- try {
- if (!isInitializing) {
- backendInfo.initPromise = backendInfo.backend.init(backendName);
+ // try to resolve and initialize all requested backends
+ let backend: Backend|undefined;
+ const errors = [];
+ const availableBackendNames = new Set<string>();
+ for (const backendName of backendNames) {
+ const resolveResult = await tryResolveAndInitializeBackend(backendName);
+ if (typeof resolveResult === 'string') {
+ errors.push({name: backendName, err: resolveResult});
+ } else {
+ if (!backend) {
+ backend = resolveResult;
+ }
+ if (backend === resolveResult) {
+ availableBackendNames.add(backendName);
+ }
}
- await backendInfo.initPromise;
- backendInfo.initialized = true;
- return backendInfo.backend;
- } catch (e) {
- if (!isInitializing) {
- errors.push({name: backendName, err: e});
+ }
+
+ // if no backend is available, throw error.
+ if (!backend) {
+ throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
+ }
+
+ // for each explicitly requested backend, if it's not available, output warning message.
+ for (const {name, err} of errors) {
+ if (backendHints.includes(name)) {
+ // eslint-disable-next-line no-console
+ console.warn(`removing requested execution provider "${
+ name}" from session options because it is not available: ${err}`);
}
- backendInfo.aborted = true;
- } finally {
- delete backendInfo.initPromise;
}
- }
- }
- throw new Error(`no available backend found. ERR: ${errors.map(e => `[${e.name}] ${e.err}`).join(', ')}`);
-};
+ const filteredEps = eps.filter(i => availableBackendNames.has(typeof i === 'string' ? i : i.name));
+
+ return [
+ backend, new Proxy(options, {
+ get: (target, prop) => {
+ if (prop === 'executionProviders') {
+ return filteredEps;
+ }
+ return Reflect.get(target, prop);
+ }
+ })
+ ];
+ };
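
A TypeScript sketch of the resulting fallback behavior (hypothetical model path):

```typescript
import * as ort from 'onnxruntime-web';

// Ask for WebGPU with a WASM fallback. If the 'webgpu' backend fails to
// initialize, the resolver logs a console warning, filters it out of the
// session options, and creates the session on the first backend that worked.
const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu', 'wasm'],
});
```
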
diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts
index 9bfcb12206057..8c07bdd5c5c4a 100644
--- a/js/common/lib/backend.ts
+++ b/js/common/lib/backend.ts
@@ -58,7 +58,7 @@ export interface TrainingSessionHandler extends SessionHandler {
options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType>;
getParametersSize(trainableOnly: boolean): Promise<number>;
- loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
+ loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
getContiguousParameters(trainableOnly: boolean): Promise<OnnxValue>;
}
@@ -77,8 +77,8 @@ export interface Backend {
Promise<InferenceSessionHandler>;
createTrainingSessionHandler?
- (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer,
- evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer,
+ (checkpointStateUriOrBuffer: TrainingSession.UriOrBuffer, trainModelUriOrBuffer: TrainingSession.UriOrBuffer,
+ evalModelUriOrBuffer: TrainingSession.UriOrBuffer, optimizerModelUriOrBuffer: TrainingSession.UriOrBuffer,
options: InferenceSession.SessionOptions): Promise<TrainingSessionHandler>;
}
diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts
index 0cded7e5edbcb..c8df1613b3268 100644
--- a/js/common/lib/env.ts
+++ b/js/common/lib/env.ts
@@ -33,6 +33,14 @@ export declare namespace Env {
*/
simd?: boolean;
+ /**
+ * set or get a boolean value indicating whether to enable trace.
+ *
+ * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored.
+ * @defaultValue `false`
+ */
+ trace?: boolean;
+
/**
* Set or get a number specifying the timeout for initialization of WebAssembly backend, in milliseconds. A zero
* value indicates no timeout is set.
@@ -103,6 +111,7 @@ export declare namespace Env {
kernelId: number;
kernelType: string;
kernelName: string;
+ programName: string;
startTime: number;
endTime: number;
}
@@ -134,13 +143,52 @@ export declare namespace Env {
*/
ondata?: (data: WebGpuProfilingData) => void;
};
+ /**
+ * Set or get the power preference.
+ *
+ * Setting this property only has effect before the first WebGPU inference session is created. The value will be
+ * used as options for `navigator.gpu.requestAdapter()`.
+ *
+ * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
+ *
+ * @defaultValue `undefined`
+ */
+ powerPreference?: 'low-power'|'high-performance';
+ /**
+ * Set or get the force fallback adapter flag.
+ *
+ * Setting this property only has effect before the first WebGPU inference session is created. The value will be
+ * used as options for `navigator.gpu.requestAdapter()`.
+ *
+ * See {@link https://gpuweb.github.io/gpuweb/#dictdef-gpurequestadapteroptions} for more details.
+ *
+ * @defaultValue `undefined`
+ */
+ forceFallbackAdapter?: boolean;
+ /**
+ * Set or get the adapter for WebGPU.
+ *
+ * Setting this property only has effect before the first WebGPU inference session is created. The value will be
+ * used as the GPU adapter for the underlying WebGPU backend to create GPU device.
+ *
+ * If this property is not set, it will be available to get after the first WebGPU inference session is created. The
+ * value will be the GPU adapter that was created by the underlying WebGPU backend.
+ *
+ * When used with TypeScript, the type of this property is `GPUAdapter` defined in "@webgpu/types".
+ * Use `const adapter = env.webgpu.adapter as GPUAdapter;` in TypeScript to access this property with correct type.
+ *
+ * see comments on {@link Tensor.GpuBufferType}
+ */
+ adapter: unknown;
/**
* Get the device for WebGPU.
*
+ * This property is only available after the first WebGPU inference session is created.
+ *
* When use with TypeScript, the type of this property is `GPUDevice` defined in "@webgpu/types".
* Use `const device = env.webgpu.device as GPUDevice;` in TypeScript to access this property with correct type.
*
- * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types".
+ * see comments on {@link Tensor.GpuBufferType} for more details about why not use types defined in "@webgpu/types".
*/
readonly device: unknown;
/**
@@ -159,6 +207,7 @@ export interface Env {
* @defaultValue `'warning'`
*/
logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal';
+
/**
* Indicate whether run in debug mode.
*
@@ -166,6 +215,13 @@ export interface Env {
*/
debug?: boolean;
+ /**
+ * set or get a boolean value indicating whether to enable trace.
+ *
+ * @defaultValue `false`
+ */
+ trace?: boolean;
+
/**
* Get version of the current package.
*/
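
A TypeScript sketch of the new WebGPU adapter controls (hypothetical model path; all of these only take effect before the first WebGPU session is created):

```typescript
import * as ort from 'onnxruntime-web';

ort.env.webgpu.powerPreference = 'high-performance';
// Or hand the backend a pre-selected adapter instead:
// ort.env.webgpu.adapter = await navigator.gpu.requestAdapter();

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
});

// After creation the chosen adapter/device are readable; with "@webgpu/types"
// installed they can be cast to GPUAdapter/GPUDevice.
const device = ort.env.webgpu.device;
```
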
diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts
index 9cbfcc4e8bcdc..3ed56b3c2e812 100644
--- a/js/common/lib/index.ts
+++ b/js/common/lib/index.ts
@@ -11,7 +11,7 @@
* - [onnxruntime-react-native](https://www.npmjs.com/package/onnxruntime-react-native)
*
* See also:
- * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript.html)
+ * - [Get Started](https://onnxruntime.ai/docs/get-started/with-javascript/)
* - [Inference examples](https://github.com/microsoft/onnxruntime-inference-examples/tree/main/js)
*
* @packageDocumentation
@@ -21,5 +21,9 @@ export * from './backend.js';
export * from './env.js';
export * from './inference-session.js';
export * from './tensor.js';
+export * from './tensor-conversion.js';
+export * from './tensor-factory.js';
+export * from './trace.js';
+export * from './onnx-model.js';
export * from './onnx-value.js';
export * from './training-session.js';
diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts
index 9bc2088f2088a..ab4c6a3e0c46b 100644
--- a/js/common/lib/inference-session-impl.ts
+++ b/js/common/lib/inference-session-impl.ts
@@ -1,11 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {resolveBackend} from './backend-impl.js';
+import {resolveBackendAndExecutionProviders} from './backend-impl.js';
import {InferenceSessionHandler} from './backend.js';
import {InferenceSession as InferenceSessionInterface} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {Tensor} from './tensor.js';
+import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from './trace.js';
type SessionOptions = InferenceSessionInterface.SessionOptions;
type RunOptions = InferenceSessionInterface.RunOptions;
@@ -20,6 +21,7 @@ export class InferenceSession implements InferenceSessionInterface {
run(feeds: FeedsType, options?: RunOptions): Promise<ReturnType>;
run(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise<ReturnType>;
async run(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise<ReturnType> {
+ TRACE_FUNC_BEGIN();
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
@@ -117,6 +119,7 @@ export class InferenceSession implements InferenceSessionInterface {
}
}
}
+ TRACE_FUNC_END();
return returnValue;
}
@@ -132,6 +135,7 @@ export class InferenceSession implements InferenceSessionInterface {
static async create(
arg0: string|ArrayBufferLike|Uint8Array, arg1?: SessionOptions|number, arg2?: number,
arg3?: SessionOptions): Promise<InferenceSessionInterface> {
+ TRACE_FUNC_BEGIN();
// either load from a file or buffer
let filePathOrUint8Array: string|Uint8Array;
let options: SessionOptions = {};
@@ -191,11 +195,10 @@ export class InferenceSession implements InferenceSessionInterface {
throw new TypeError('Unexpected argument[0]: must be \'path\' or \'buffer\'.');
}
- // get backend hints
- const eps = options.executionProviders || [];
- const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
- const backend = await resolveBackend(backendHints);
- const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options);
+ // resolve backend, update session options with validated EPs, and create session handler
+ const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
+ const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, optionsWithValidatedEPs);
+ TRACE_FUNC_END();
return new InferenceSession(handler);
}
diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts
index c7760692eed00..14db5c59d972a 100644
--- a/js/common/lib/inference-session.ts
+++ b/js/common/lib/inference-session.ts
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
import {InferenceSession as InferenceSessionImpl} from './inference-session-impl.js';
+import {OnnxModelOptions} from './onnx-model.js';
import {OnnxValue, OnnxValueDataLocation} from './onnx-value.js';
/* eslint-disable @typescript-eslint/no-redeclare */
@@ -43,7 +44,7 @@ export declare namespace InferenceSession {
/**
* A set of configurations for session behavior.
*/
- export interface SessionOptions {
+ export interface SessionOptions extends OnnxModelOptions {
/**
* An array of execution provider options.
*
@@ -110,7 +111,7 @@ export declare namespace InferenceSession {
optimizedModelFilePath?: string;
/**
- * Wether enable profiling.
+ * Whether to enable profiling.
*
* This setting is a placeholder for a future use.
*/
@@ -153,6 +154,12 @@ export declare namespace InferenceSession {
*/
preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation};
+ /**
+ * Whether to enable graph capture.
+ * This setting is available only in ONNXRuntime Web for WebGPU EP.
+ */
+ enableGraphCapture?: boolean;
+
/**
* Store configurations for a session. See
* https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/
@@ -179,22 +186,22 @@ export declare namespace InferenceSession {
// #region execution providers
// Currently, we have the following backends to support execution providers:
- // Backend Node.js binding: supports 'cpu' and 'cuda'.
- // Backend WebAssembly: supports 'cpu', 'wasm', 'xnnpack' and 'webnn'.
+ // Backend Node.js binding: supports 'cpu', 'dml' (win32), 'coreml' (macOS) and 'cuda' (linux).
+ // Backend WebAssembly: supports 'cpu', 'wasm', 'webgpu' and 'webnn'.
// Backend ONNX.js: supports 'webgl'.
// Backend React Native: supports 'cpu', 'xnnpack', 'coreml' (iOS), 'nnapi' (Android).
interface ExecutionProviderOptionMap {
+ coreml: CoreMLExecutionProviderOption;
cpu: CpuExecutionProviderOption;
- coreml: CoreMlExecutionProviderOption;
cuda: CudaExecutionProviderOption;
dml: DmlExecutionProviderOption;
+ nnapi: NnapiExecutionProviderOption;
tensorrt: TensorRtExecutionProviderOption;
wasm: WebAssemblyExecutionProviderOption;
webgl: WebGLExecutionProviderOption;
- xnnpack: XnnpackExecutionProviderOption;
webgpu: WebGpuExecutionProviderOption;
webnn: WebNNExecutionProviderOption;
- nnapi: NnapiExecutionProviderOption;
+ xnnpack: XnnpackExecutionProviderOption;
}
type ExecutionProviderName = keyof ExecutionProviderOptionMap;
@@ -212,10 +219,6 @@ export declare namespace InferenceSession {
readonly name: 'cuda';
deviceId?: number;
}
- export interface CoreMlExecutionProviderOption extends ExecutionProviderOption {
- readonly name: 'coreml';
- coreMlFlags?: number;
- }
export interface DmlExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'dml';
deviceId?: number;
@@ -240,14 +243,45 @@ export declare namespace InferenceSession {
}
export interface WebNNExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'webnn';
- deviceType?: 'cpu'|'gpu';
+ deviceType?: 'cpu'|'gpu'|'npu';
numThreads?: number;
powerPreference?: 'default'|'low-power'|'high-performance';
}
export interface CoreMLExecutionProviderOption extends ExecutionProviderOption {
readonly name: 'coreml';
+ /**
+ * The bit flags for CoreML execution provider.
+ *
+ * ```
+ * COREML_FLAG_USE_CPU_ONLY = 0x001
+ * COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002
+ * COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004
+ * COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008
+ * COREML_FLAG_CREATE_MLPROGRAM = 0x010
+ * ```
+ *
+ * See include/onnxruntime/core/providers/coreml/coreml_provider_factory.h for more details.
+ *
+ * This flag is available only in ONNXRuntime (Node.js binding).
+ */
+ coreMlFlags?: number;
+ /**
+ * Specify whether to use CPU only in CoreML EP.
+ *
+ * This setting is available only in ONNXRuntime (react-native).
+ */
useCPUOnly?: boolean;
+ /**
+ * Specify whether to enable CoreML EP on subgraph.
+ *
+ * This setting is available only in ONNXRuntime (react-native).
+ */
enableOnSubgraph?: boolean;
+ /**
+ * Specify whether to only enable CoreML EP for Apple devices with ANE (Apple Neural Engine).
+ *
+ * This setting is available only in ONNXRuntime (react-native).
+ */
onlyEnableDeviceWithANE?: boolean;
}
export interface NnapiExecutionProviderOption extends ExecutionProviderOption {
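
A TypeScript sketch of the new `'npu'` device type for WebNN (hypothetical model path; whether it works depends on the browser's WebNN support):

```typescript
import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: [{name: 'webnn', deviceType: 'npu', powerPreference: 'high-performance'}],
});
```
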
diff --git a/js/common/lib/onnx-model.ts b/js/common/lib/onnx-model.ts
new file mode 100644
index 0000000000000..1cd3eedb6fcca
--- /dev/null
+++ b/js/common/lib/onnx-model.ts
@@ -0,0 +1,57 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+/**
+ * A string that represents a file's URL or path.
+ *
+ * Path is available only in onnxruntime-node or onnxruntime-web running in Node.js.
+ */
+export type FileUrlOrPath = string;
+
+/**
+ * A Blob object that represents a file.
+ */
+export type FileBlob = Blob;
+
+/**
+ * A Uint8Array, ArrayBuffer or SharedArrayBuffer object that represents a file content.
+ *
+ * When it is an ArrayBuffer or SharedArrayBuffer, the whole buffer is assumed to be the file content.
+ */
+export type FileData = Uint8Array|ArrayBufferLike;
+
+/**
+ * Represents a file that can be loaded by the ONNX Runtime JavaScript API.
+ */
+export type FileType = FileUrlOrPath|FileBlob|FileData;
+
+/**
+ * Represents an external data file.
+ */
+export interface ExternalDataFileDescription {
+ /**
+ * Specify the external data file.
+ */
+ data: FileType;
+ /**
+ * Specify the file path.
+ */
+ path: string;
+}
+
+/**
+ * Represents an external data file.
+ *
+ * When using a string, it should be a file URL or path that is in the same directory as the model file.
+ */
+export type ExternalDataFileType = ExternalDataFileDescription|FileUrlOrPath;
+
+/**
+ * Options for model loading.
+ */
+export interface OnnxModelOptions {
+ /**
+ * Specifying a list of files that represents the external data.
+ */
+ externalData?: readonly ExternalDataFileType[];
+}
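
A TypeScript sketch of the new `externalData` option (file names hypothetical):

```typescript
import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('model.onnx', {
  externalData: [
    // a URL/path string is resolved next to the model file...
    'model.onnx.data',
    // ...or supply the bytes yourself together with the path the model references:
    // {data: externalBytes, path: 'model.onnx.data'},
  ],
});
```
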
diff --git a/js/common/lib/onnx-value.ts b/js/common/lib/onnx-value.ts
index a16a30d25d839..72369ce8b4209 100644
--- a/js/common/lib/onnx-value.ts
+++ b/js/common/lib/onnx-value.ts
@@ -3,7 +3,7 @@
import {Tensor} from './tensor.js';
-type NonTensorType = never;
+export type NonTensorType = never;
/**
* Type OnnxValue Represents both tensors and non-tensors value for model's inputs/outputs.
diff --git a/js/common/lib/tensor-conversion-impl.ts b/js/common/lib/tensor-conversion-impl.ts
index 22397321e8c6b..b1de48a10c0e1 100644
--- a/js/common/lib/tensor-conversion-impl.ts
+++ b/js/common/lib/tensor-conversion-impl.ts
@@ -8,10 +8,11 @@ import {Tensor} from './tensor.js';
* implementation of Tensor.toDataURL()
*/
export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions): string => {
- const canvas = document.createElement('canvas');
+ const canvas = typeof document !== 'undefined' ? document.createElement('canvas') : (new OffscreenCanvas(1, 1));
canvas.width = tensor.dims[3];
canvas.height = tensor.dims[2];
- const pixels2DContext = canvas.getContext('2d');
+ const pixels2DContext =
+ canvas.getContext('2d') as (CanvasRenderingContext2D | OffscreenCanvasRenderingContext2D | null);
if (pixels2DContext != null) {
// Default values for height and width & format
@@ -88,7 +89,11 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions
pixels2DContext.fillRect(j, i, 1, 1);
}
}
- return canvas.toDataURL();
+ if ('toDataURL' in canvas) {
+ return canvas.toDataURL();
+ } else {
+ throw new Error('toDataURL is not supported');
+ }
} else {
throw new Error('Can not access image data');
}
@@ -98,7 +103,9 @@ export const tensorToDataURL = (tensor: Tensor, options?: TensorToDataUrlOptions
* implementation of Tensor.toImageData()
*/
export const tensorToImageData = (tensor: Tensor, options?: TensorToImageDataOptions): ImageData => {
- const pixels2DContext = document.createElement('canvas').getContext('2d');
+ const pixels2DContext = typeof document !== 'undefined' ?
+ document.createElement('canvas').getContext('2d') :
+ new OffscreenCanvas(1, 1).getContext('2d') as OffscreenCanvasRenderingContext2D;
let image: ImageData;
if (pixels2DContext != null) {
// Default values for height and width & format
diff --git a/js/common/lib/tensor-factory-impl.ts b/js/common/lib/tensor-factory-impl.ts
index 7228c4a97055b..19c62cb54bfed 100644
--- a/js/common/lib/tensor-factory-impl.ts
+++ b/js/common/lib/tensor-factory-impl.ts
@@ -110,13 +110,31 @@ export const tensorFromImage = async(
let data: Uint8ClampedArray|undefined;
let bufferToTensorOptions: BufferToTensorOptions = options ?? {};
+ const createCanvas = () => {
+ if (typeof document !== 'undefined') {
+ return document.createElement('canvas');
+ } else if (typeof OffscreenCanvas !== 'undefined') {
+ return new OffscreenCanvas(1, 1);
+ } else {
+ throw new Error('Canvas is not supported');
+ }
+ };
+ const createCanvasContext = (canvas: HTMLCanvasElement|OffscreenCanvas) => {
+ if (canvas instanceof HTMLCanvasElement) {
+ return canvas.getContext('2d');
+ } else if (canvas instanceof OffscreenCanvas) {
+ return canvas.getContext('2d') as OffscreenCanvasRenderingContext2D;
+ } else {
+ return null;
+ }
+ };
// filling and checking image configuration options
if (isHTMLImageEle) {
// HTMLImageElement - image object - format is RGBA by default
- const canvas = document.createElement('canvas');
+ const canvas = createCanvas();
canvas.width = image.width;
canvas.height = image.height;
- const pixels2DContext = canvas.getContext('2d');
+ const pixels2DContext = createCanvasContext(canvas);
if (pixels2DContext != null) {
let height = image.height;
@@ -166,12 +184,12 @@ export const tensorFromImage = async(
bufferToTensorOptions.width = width;
if (options !== undefined) {
- const tempCanvas = document.createElement('canvas');
+ const tempCanvas = createCanvas();
tempCanvas.width = width;
tempCanvas.height = height;
- const pixels2DContext = tempCanvas.getContext('2d');
+ const pixels2DContext = createCanvasContext(tempCanvas);
if (pixels2DContext != null) {
pixels2DContext.putImageData(image, 0, 0);
@@ -188,10 +206,10 @@ export const tensorFromImage = async(
throw new Error('Please provide image config with format for Imagebitmap');
}
- const canvas = document.createElement('canvas');
+ const canvas = createCanvas();
canvas.width = image.width;
canvas.height = image.height;
- const pixels2DContext = canvas.getContext('2d');
+ const pixels2DContext = createCanvasContext(canvas);
if (pixels2DContext != null) {
const height = image.height;
@@ -206,8 +224,8 @@ export const tensorFromImage = async(
}
} else if (isString) {
return new Promise((resolve, reject) => {
- const canvas = document.createElement('canvas');
- const context = canvas.getContext('2d');
+ const canvas = createCanvas();
+ const context = createCanvasContext(canvas);
if (!image || !context) {
return reject();
}
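
A TypeScript sketch of the effect: tensor creation from images now also works where `document` is unavailable, e.g. inside a web worker, by falling back to `OffscreenCanvas` (dimensions hypothetical):

```typescript
import {Tensor} from 'onnxruntime-common';

// No `document` in a worker; the factory uses OffscreenCanvas instead.
const image = new ImageData(new Uint8ClampedArray(224 * 224 * 4), 224, 224);
const tensor = await Tensor.fromImage(image);
```
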
diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts
index 6e19d7fb898a3..431de4c3635c2 100644
--- a/js/common/lib/tensor-factory.ts
+++ b/js/common/lib/tensor-factory.ts
@@ -253,7 +253,7 @@ export interface TensorFactory {
/**
* create a tensor from an ImageBitmap object
*
- * @param bitMap - the ImageBitmap object to create tensor from
+ * @param bitmap - the ImageBitmap object to create tensor from
* @param options - An optional object representing options for creating tensor from URL.
*
* The following default settings will be applied:
diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts
index c4a43ea27fea1..b29cb8cbd6d35 100644
--- a/js/common/lib/tensor-impl-type-mapping.ts
+++ b/js/common/lib/tensor-impl-type-mapping.ts
@@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map<string, SupportedTypedArrayConstructors>([
-let isBigIntChecked = false;
-export const checkBigInt = () => {
- if (!isBigIntChecked) {
- isBigIntChecked = true;
- const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function';
- const isBigUint64ArrayAvailable =
- typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function';
+// a dummy type declaration for Float16Array in case any polyfill is available.
+declare global {
+ // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any
+ const Float16Array: any;
+}
+
+// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for
+// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array
+// polyfill if available.
+let isTypedArrayChecked = false;
+export const checkTypedArray = () => {
+ if (!isTypedArrayChecked) {
+ isTypedArrayChecked = true;
+ const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from;
+ const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from;
+ const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from;
if (isBigInt64ArrayAvailable) {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array);
@@ -53,5 +58,12 @@ export const checkBigInt = () => {
NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array);
NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64');
}
+ if (isFloat16ArrayAvailable) {
+ NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array);
+ NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16');
+ } else {
+ // if Float16Array is not available, use 'Uint16Array' to store the data.
+ NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array);
+ }
}
};
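
A TypeScript sketch of what the new mapping implies for float16 tensor creation:

```typescript
import {Tensor} from 'onnxruntime-common';

// Raw fp16 bits in a Uint16Array always work:
const bits = new Uint16Array([0x3c00, 0x4000]); // 1.0 and 2.0 in IEEE fp16
const t = new Tensor('float16', bits, [2]);

// A plain number array only works when a Float16Array polyfill is loaded;
// without one, Uint16Array.from([1, 2]) would store the wrong bit patterns,
// so the constructor throws instead:
// const bad = new Tensor('float16', [1.0, 2.0], [2]);
```
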
diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts
index e3e2b9c728556..56682ef98e117 100644
--- a/js/common/lib/tensor-impl.ts
+++ b/js/common/lib/tensor-impl.ts
@@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js';
import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js';
import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js';
import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js';
-import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
+import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js';
import {calculateSize, tensorReshape} from './tensor-utils-impl.js';
import {Tensor as TensorInterface} from './tensor.js';
@@ -67,8 +67,8 @@ export class Tensor implements TensorInterface {
arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters|
TextureConstructorParameters|GpuBufferConstructorParameters,
arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) {
- // perform one-time check for BigInt support
- checkBigInt();
+ // perform one-time check for BigInt/Float16Array support
+ checkTypedArray();
let type: TensorType;
let dims: readonly number[];
@@ -103,7 +103,7 @@ export class Tensor implements TensorInterface {
}
case 'gpu-buffer': {
if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' &&
- type !== 'bool')) {
+ type !== 'uint8' && type !== 'bool')) {
throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`);
}
this.gpuBufferData = arg0.gpuBuffer;
@@ -142,7 +142,9 @@ export class Tensor implements TensorInterface {
throw new TypeError(`Unsupported tensor type: ${arg0}.`);
}
if (Array.isArray(arg1)) {
- if (arg0 === 'float16') {
+ if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) {
+ // When no Float16Array polyfill is used, we cannot create a 'float16' tensor from a number array.
+ //
// Throw error here because when user try to use number array as data,
// e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call
// Uint16Array.from(arg1) which generates wrong data.
diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts
index 6c08d1fe8e057..20319ebb800c2 100644
--- a/js/common/lib/tensor.ts
+++ b/js/common/lib/tensor.ts
@@ -135,7 +135,7 @@ export declare namespace Tensor {
/**
* supported data types for constructing a tensor from a WebGPU buffer
*/
- export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool';
+ export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool';
/**
* represent where the tensor data is stored
@@ -160,7 +160,7 @@ export interface Tensor extends TypedTensorBase<Tensor.Type>, TypedTensorUtils<Tensor.Type> {
diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts
new file mode 100644
--- /dev/null
+++ b/js/common/lib/trace.ts
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {env} from './env-impl.js';
+
+/**
+ * @ignore
+ */
+export const TRACE = (deviceType: string, label: string) => {
+ if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
+ return;
+ }
+ // eslint-disable-next-line no-console
+ console.timeStamp(`${deviceType}::ORT::${label}`);
+};
+
+const TRACE_FUNC = (msg: string, extraMsg?: string) => {
+ const stack = new Error().stack?.split(/\r\n|\r|\n/g) || [];
+ let hasTraceFunc = false;
+ for (let i = 0; i < stack.length; i++) {
+ if (hasTraceFunc && !stack[i].includes('TRACE_FUNC')) {
+ let label = `FUNC_${msg}::${stack[i].trim().split(' ')[1]}`;
+ if (extraMsg) {
+ label += `::${extraMsg}`;
+ }
+ TRACE('CPU', label);
+ return;
+ }
+ if (stack[i].includes('TRACE_FUNC')) {
+ hasTraceFunc = true;
+ }
+ }
+};
+
+/**
+ * @ignore
+ */
+export const TRACE_FUNC_BEGIN = (extraMsg?: string) => {
+ if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
+ return;
+ }
+ TRACE_FUNC('BEGIN', extraMsg);
+};
+
+/**
+ * @ignore
+ */
+export const TRACE_FUNC_END = (extraMsg?: string) => {
+ if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) {
+ return;
+ }
+ TRACE_FUNC('END', extraMsg);
+};
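For context, `TRACE_FUNC` walks `new Error().stack` to find the first frame outside itself and emits a `console.timeStamp` mark named `<device>::ORT::FUNC_<BEGIN|END>::<caller>`. A sketch of turning this on from application code, assuming only the flag names shown in the diff:

```ts
import {env} from 'onnxruntime-common';

// Either flag enables tracing: `env.trace` wins when it is defined, and the
// library falls back to the older `env.wasm.trace` otherwise.
env.trace = true;

// From here on, TRACE_FUNC_BEGIN()/TRACE_FUNC_END() calls inside the library
// show up as DevTools timeline marks, e.g. "CPU::ORT::FUNC_BEGIN::run".
```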
diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts
index 23bd4421ae672..bae38b0dfda5a 100644
--- a/js/common/lib/training-session-impl.ts
+++ b/js/common/lib/training-session-impl.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {resolveBackend} from './backend-impl.js';
+import {resolveBackendAndExecutionProviders} from './backend-impl.js';
import {SessionHandler, TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
@@ -55,13 +55,12 @@ export class TrainingSession implements TrainingSessionInterface {
const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
const options: SessionOptions = sessionOptions || {};
- // get backend hints
- const eps = options.executionProviders || [];
- const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
- const backend = await resolveBackend(backendHints);
+ // resolve backend, update session options with validated EPs, and create session handler
+ const [backend, optionsWithValidatedEPs] = await resolveBackendAndExecutionProviders(options);
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
- trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
+ trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel,
+ optionsWithValidatedEPs);
return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
} else {
throw new Error(noBackendErrMsg);
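From the call site this change is transparent: execution providers are still passed through the session options, they are just validated once inside `resolveBackendAndExecutionProviders` before the handler is created. A sketch using the training bundle, with placeholder file names:

```ts
import * as ort from 'onnxruntime-web/training';

const trainingSession = await ort.TrainingSession.create(
    {
      checkpointState: 'checkpoint.ckpt',      // required
      trainModel: 'training_model.onnx',       // required
      evalModel: 'eval_model.onnx',            // optional
      optimizerModel: 'optimizer_model.onnx',  // optional
    },
    // Validated and filtered by resolveBackendAndExecutionProviders before
    // createTrainingSessionHandler is invoked.
    {executionProviders: ['wasm']});
```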
diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts
index e54aed90e702c..f9de77e3ac7d0 100644
--- a/js/common/lib/training-session.ts
+++ b/js/common/lib/training-session.ts
@@ -11,7 +11,7 @@ export declare namespace TrainingSession {
/**
* Either URI file path (string) or Uint8Array containing model or checkpoint information.
*/
- type URIorBuffer = string|Uint8Array;
+ type UriOrBuffer = string|Uint8Array;
}
/**
@@ -98,13 +98,13 @@ export interface TrainingSession {
getParametersSize(trainableOnly: boolean): Promise<number>;
/**
- * Copies parameter values from the given array to the training state. Currently, only supporting models with
+ * Copies parameter values from the given buffer to the training state. Currently, only supporting models with
* parameters of type Float32.
*
- * @param buffer - Float32 buffer containing parameters converted to a Uint8Array.
+ * @param buffer - A Uint8Array representation of Float32 parameters.
* @param trainableOnly - True if trainable parameters only to be modified, false otherwise. Default value is true.
*/
- loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise<void>;
+ loadParametersBuffer(buffer: Uint8Array, trainableOnly: boolean): Promise<void>;
/**
* Copies the model parameters to a contiguous buffer. Usually used in the context of Federated Learning.
@@ -157,19 +157,19 @@ export interface TrainingSessionCreateOptions {
/**
* URI or buffer for a .ckpt file that contains the checkpoint for the training model.
*/
- checkpointState: TrainingSession.URIorBuffer;
+ checkpointState: TrainingSession.UriOrBuffer;
/**
* URI or buffer for the .onnx training file.
*/
- trainModel: TrainingSession.URIorBuffer;
+ trainModel: TrainingSession.UriOrBuffer;
/**
* Optional. URI or buffer for the .onnx optimizer model file.
*/
- optimizerModel?: TrainingSession.URIorBuffer;
+ optimizerModel?: TrainingSession.UriOrBuffer;
/**
* Optional. URI or buffer for the .onnx eval model file.
*/
- evalModel?: TrainingSession.URIorBuffer;
+ evalModel?: TrainingSession.UriOrBuffer;
}
/**
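A sketch of the round trip these methods enable, assuming `session` is a created `TrainingSession` and the model's parameters are Float32, as the contract above requires:

```ts
// Report the parameter size for the trainable parameters only.
const size = await session.getParametersSize(/* trainableOnly */ true);
console.log(`parameter size: ${size}`);

// Snapshot the parameters into a contiguous Uint8Array (float32 data underneath).
const snapshot = await session.copyParametersToBuffer(true);

// ... aggregate snapshots across clients (the federated-learning case mentioned
// above), then push the merged bytes back into the training state:
await session.loadParametersBuffer(snapshot, true);
```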
diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts
index 96c2361cceabe..40f970ddf02ae 100644
--- a/js/common/lib/version.ts
+++ b/js/common/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.17.0';
+export const version = '1.18.0';
diff --git a/js/common/package-lock.json b/js/common/package-lock.json
index 84f6dba83fa59..3988ac80707e0 100644
--- a/js/common/package-lock.json
+++ b/js/common/package-lock.json
@@ -1,21 +1,21 @@
{
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"license": "MIT",
"devDependencies": {
- "typedoc": "^0.23.22"
+ "typedoc": "^0.25.7"
}
},
"node_modules/ansi-sequence-parser": {
- "version": "1.1.0",
- "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz",
- "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==",
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz",
+ "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==",
"dev": true
},
"node_modules/balanced-match": {
@@ -34,9 +34,9 @@
}
},
"node_modules/jsonc-parser": {
- "version": "3.2.0",
- "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz",
- "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==",
+ "version": "3.2.1",
+ "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz",
+ "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==",
"dev": true
},
"node_modules/lunr": {
@@ -46,9 +46,9 @@
"dev": true
},
"node_modules/marked": {
- "version": "4.2.12",
- "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz",
- "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==",
+ "version": "4.3.0",
+ "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz",
+ "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==",
"dev": true,
"bin": {
"marked": "bin/marked.js"
@@ -58,24 +58,24 @@
}
},
"node_modules/minimatch": {
- "version": "7.4.2",
- "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz",
- "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==",
+ "version": "9.0.3",
+ "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz",
+ "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==",
"dev": true,
"dependencies": {
"brace-expansion": "^2.0.1"
},
"engines": {
- "node": ">=10"
+ "node": ">=16 || 14 >=14.17"
},
"funding": {
"url": "https://github.com/sponsors/isaacs"
}
},
"node_modules/shiki": {
- "version": "0.14.1",
- "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz",
- "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==",
+ "version": "0.14.7",
+ "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz",
+ "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==",
"dev": true,
"dependencies": {
"ansi-sequence-parser": "^1.1.0",
@@ -85,30 +85,30 @@
}
},
"node_modules/typedoc": {
- "version": "0.23.26",
- "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz",
- "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==",
+ "version": "0.25.7",
+ "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz",
+ "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==",
"dev": true,
"dependencies": {
"lunr": "^2.3.9",
- "marked": "^4.2.12",
- "minimatch": "^7.1.3",
- "shiki": "^0.14.1"
+ "marked": "^4.3.0",
+ "minimatch": "^9.0.3",
+ "shiki": "^0.14.7"
},
"bin": {
"typedoc": "bin/typedoc"
},
"engines": {
- "node": ">= 14.14"
+ "node": ">= 16"
},
"peerDependencies": {
- "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x"
+ "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x"
}
},
"node_modules/typescript": {
- "version": "4.9.5",
- "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
- "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
+ "version": "5.2.2",
+ "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz",
+ "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==",
"dev": true,
"peer": true,
"bin": {
@@ -116,7 +116,7 @@
"tsserver": "bin/tsserver"
},
"engines": {
- "node": ">=4.2.0"
+ "node": ">=14.17"
}
},
"node_modules/vscode-oniguruma": {
@@ -134,9 +134,9 @@
},
"dependencies": {
"ansi-sequence-parser": {
- "version": "1.1.0",
- "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz",
- "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==",
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz",
+ "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==",
"dev": true
},
"balanced-match": {
@@ -155,9 +155,9 @@
}
},
"jsonc-parser": {
- "version": "3.2.0",
- "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz",
- "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==",
+ "version": "3.2.1",
+ "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz",
+ "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==",
"dev": true
},
"lunr": {
@@ -167,24 +167,24 @@
"dev": true
},
"marked": {
- "version": "4.2.12",
- "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz",
- "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==",
+ "version": "4.3.0",
+ "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz",
+ "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==",
"dev": true
},
"minimatch": {
- "version": "7.4.2",
- "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz",
- "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==",
+ "version": "9.0.3",
+ "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz",
+ "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==",
"dev": true,
"requires": {
"brace-expansion": "^2.0.1"
}
},
"shiki": {
- "version": "0.14.1",
- "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz",
- "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==",
+ "version": "0.14.7",
+ "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz",
+ "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==",
"dev": true,
"requires": {
"ansi-sequence-parser": "^1.1.0",
@@ -194,21 +194,21 @@
}
},
"typedoc": {
- "version": "0.23.26",
- "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz",
- "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==",
+ "version": "0.25.7",
+ "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz",
+ "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==",
"dev": true,
"requires": {
"lunr": "^2.3.9",
- "marked": "^4.2.12",
- "minimatch": "^7.1.3",
- "shiki": "^0.14.1"
+ "marked": "^4.3.0",
+ "minimatch": "^9.0.3",
+ "shiki": "^0.14.7"
}
},
"typescript": {
- "version": "4.9.5",
- "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz",
- "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==",
+ "version": "5.2.2",
+ "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz",
+ "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==",
"dev": true,
"peer": true
},
diff --git a/js/common/package.json b/js/common/package.json
index beab7d29be263..cd2612aab4984 100644
--- a/js/common/package.json
+++ b/js/common/package.json
@@ -2,14 +2,14 @@
"license": "MIT",
"type": "module",
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"repository": {
"url": "https://github.com/Microsoft/onnxruntime.git",
"type": "git"
},
"author": "fs-eire",
"scripts": {
- "build:cjs": "tsc --module commonjs --outDir ./dist/cjs",
+ "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs",
"build:esm": "tsc",
"build:bundles": "webpack",
"build": "node ./build.js",
@@ -18,7 +18,7 @@
"test": "mocha ./test/**/*.js --timeout 30000"
},
"devDependencies": {
- "typedoc": "^0.23.22"
+ "typedoc": "^0.25.7"
},
"main": "dist/cjs/index.js",
"exports": {
diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json
index 2e4927ac3b325..e9068ad837a81 100644
--- a/js/common/test/tsconfig.json
+++ b/js/common/test/tsconfig.json
@@ -2,7 +2,7 @@
"extends": "../../tsconfig.tools.json",
"exclude": ["type-tests/**/*.ts"],
"compilerOptions": {
- "module": "ES2022",
+ "module": "Node16",
"sourceMap": true
}
}
diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt
index c3898fbad7401..8157df288eeb9 100644
--- a/js/node/CMakeLists.txt
+++ b/js/node/CMakeLists.txt
@@ -66,9 +66,17 @@ if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
execute_process(COMMAND ${CMAKE_AR} /def:${CMAKE_JS_NODELIB_DEF} /out:${CMAKE_JS_NODELIB_TARGET} ${CMAKE_STATIC_LINKER_FLAGS})
endif()
+if (WIN32)
+ if (${ONNXRUNTIME_GENERATOR} MATCHES "Ninja")
+ set(ONNXRUNTIME_WIN_BIN_DIR ${ONNXRUNTIME_BUILD_DIR})
+ else()
+ set(ONNXRUNTIME_WIN_BIN_DIR ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE})
+ endif()
+ message(STATUS "onnxruntime dist dir: ${ONNXRUNTIME_WIN_BIN_DIR}")
+endif()
# add libraries
if (WIN32)
- target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE})
+ target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_WIN_BIN_DIR})
else()
target_link_directories(onnxruntime_binding PRIVATE ${ONNXRUNTIME_BUILD_DIR})
endif()
@@ -95,14 +103,14 @@ if (WIN32)
add_custom_command(
TARGET onnxruntime_binding POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
- ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}/onnxruntime.dll
+ ${ONNXRUNTIME_WIN_BIN_DIR}/onnxruntime.dll
${dist_folder}
)
if (USE_DML)
add_custom_command(
TARGET onnxruntime_binding POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
- ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}/DirectML.dll
+ ${ONNXRUNTIME_WIN_BIN_DIR}/DirectML.dll
${dist_folder}
)
endif ()
@@ -110,7 +118,7 @@ if (WIN32)
add_custom_command(
TARGET onnxruntime_binding POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
- ${ONNXRUNTIME_BUILD_DIR}/${CMAKE_BUILD_TYPE}/onnxruntime.pdb
+ ${ONNXRUNTIME_WIN_BIN_DIR}/onnxruntime.pdb
${dist_folder}
COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE_DIR:onnxruntime_binding>/onnxruntime_binding.pdb ${dist_folder}
)
diff --git a/js/node/README.md b/js/node/README.md
index 98b2ea66de2a8..234eaa111a220 100644
--- a/js/node/README.md
+++ b/js/node/README.md
@@ -22,7 +22,7 @@ Following platforms are supported with pre-built binaries:
- Linux x64 CPU NAPI_v3
- MacOS x64 CPU NAPI_v3
-To use on platforms without pre-built binaries, you can build Node.js binding from source and consume it by `npm install <onnxruntime_repo_root>/js/node/`. See also [instructions](https://www.onnxruntime.ai/docs/how-to/build.html#apis-and-language-bindings) for building ONNX Runtime Node.js binding locally.
+To use on platforms without pre-built binaries, you can build Node.js binding from source and consume it by `npm install <onnxruntime_repo_root>/js/node/`. See also [instructions](https://onnxruntime.ai/docs/build/inferencing.html#apis-and-language-bindings) for building ONNX Runtime Node.js binding locally.
# GPU Support
diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts
index e8eb0e9babf5a..927953b4f1dd6 100644
--- a/js/node/lib/backend.ts
+++ b/js/node/lib/backend.ts
@@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
Promise<SessionHandler.ReturnType> {
return new Promise((resolve, reject) => {
- process.nextTick(() => {
+ setImmediate(() => {
try {
resolve(this.#inferenceSession.run(feeds, fetches, options));
} catch (e) {
@@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend {
async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler> {
return new Promise((resolve, reject) => {
- process.nextTick(() => {
+ setImmediate(() => {
try {
resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {}));
} catch (e) {
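The switch from `process.nextTick` to `setImmediate` matters because the native `run()` here is synchronous: `nextTick` callbacks run before the event loop proceeds, so back-to-back inference calls could starve pending I/O, while `setImmediate` defers to the check phase after I/O events have been handled. The ordering difference in isolation:

```ts
setImmediate(() => console.log('3: setImmediate - runs after pending I/O events'));
process.nextTick(() => console.log('2: nextTick - runs before the event loop continues'));
console.log('1: synchronous code runs first');
// Prints 1, 2, 3: nextTick always jumps the queue; setImmediate waits its turn.
```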
diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts
index 96c2361cceabe..40f970ddf02ae 100644
--- a/js/node/lib/version.ts
+++ b/js/node/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.17.0';
+export const version = '1.18.0';
diff --git a/js/node/package-lock.json b/js/node/package-lock.json
index c1cf8af4bb80e..62b47698a1438 100644
--- a/js/node/package-lock.json
+++ b/js/node/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-node",
- "version": "1.17.0",
+ "version": "1.18.0",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-node",
- "version": "1.17.0",
+ "version": "1.18.0",
"license": "MIT",
"os": [
"win32",
@@ -27,10 +27,10 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.17.0",
+ "version": "1.18.0",
"license": "MIT",
"devDependencies": {
- "typedoc": "^0.23.22"
+ "typedoc": "^0.25.7"
}
},
"node_modules/@protobufjs/aspromise": {
@@ -336,9 +336,9 @@
"dev": true
},
"node_modules/follow-redirects": {
- "version": "1.15.2",
- "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.2.tgz",
- "integrity": "sha512-VQLG33o04KaQ8uYi2tVNbdrWp1QWxNNea+nmIB4EVM28v0hmP17z7aG1+wAkNzVq4KeXTq3221ye5qTJP91JwA==",
+ "version": "1.15.6",
+ "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
+ "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
"dev": true,
"funding": [
{
@@ -1242,9 +1242,9 @@
"dev": true
},
"follow-redirects": {
- "version": "1.15.2",
- "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.2.tgz",
- "integrity": "sha512-VQLG33o04KaQ8uYi2tVNbdrWp1QWxNNea+nmIB4EVM28v0hmP17z7aG1+wAkNzVq4KeXTq3221ye5qTJP91JwA==",
+ "version": "1.15.6",
+ "resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.6.tgz",
+ "integrity": "sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==",
"dev": true
},
"form-data": {
@@ -1503,7 +1503,7 @@
"onnxruntime-common": {
"version": "file:../common",
"requires": {
- "typedoc": "^0.23.22"
+ "typedoc": "^0.25.7"
}
},
"parse-json": {
diff --git a/js/node/package.json b/js/node/package.json
index 8e591d8f46b9d..026840742e29e 100644
--- a/js/node/package.json
+++ b/js/node/package.json
@@ -13,7 +13,7 @@
3
]
},
- "version": "1.17.0",
+ "version": "1.18.0",
"dependencies": {
"onnxruntime-common": "file:../common"
},
diff --git a/js/node/script/build.ts b/js/node/script/build.ts
index dfa88821a8d09..cc59507179085 100644
--- a/js/node/script/build.ts
+++ b/js/node/script/build.ts
@@ -23,6 +23,8 @@ if (ARCH !== 'x64' && ARCH !== 'ia32' && ARCH !== 'arm64' && ARCH !== 'arm') {
}
// --onnxruntime-build-dir=
const ONNXRUNTIME_BUILD_DIR = buildArgs['onnxruntime-build-dir'];
+// --onnxruntime-generator=
+const ONNXRUNTIME_GENERATOR = buildArgs['onnxruntime-generator'];
// --rebuild
const REBUILD = !!buildArgs.rebuild;
// --use_dml
@@ -55,6 +57,9 @@ const args = [
if (ONNXRUNTIME_BUILD_DIR && typeof ONNXRUNTIME_BUILD_DIR === 'string') {
args.push(`--CDONNXRUNTIME_BUILD_DIR=${ONNXRUNTIME_BUILD_DIR}`);
}
+if (ONNXRUNTIME_GENERATOR && typeof ONNXRUNTIME_GENERATOR === 'string') {
+ args.push(`--CDONNXRUNTIME_GENERATOR=${ONNXRUNTIME_GENERATOR}`);
+}
if (USE_DML) {
args.push('--CDUSE_DML=ON');
}
diff --git a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
index fd085f9533801..707a356b949ab 100644
--- a/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
+++ b/js/react_native/android/src/main/java/ai/onnxruntime/reactnative/OnnxruntimeModule.java
@@ -199,6 +199,12 @@ private WritableMap loadModelImpl(String uri, byte[] modelData, ReadableMap opti
if (modelData != null && modelData.length > 0) {
// load model via model data array
ortSession = ortEnvironment.createSession(modelData, sessionOptions);
+ } else if (uri.startsWith("file://") || uri.startsWith("/")) {
+ // load model from local
+ if (uri.startsWith("file://")) {
+ uri = uri.substring(7);
+ }
+ ortSession = ortEnvironment.createSession(uri, sessionOptions);
} else {
// load model via model path string uri
InputStream modelStream =
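From the JavaScript side this is transparent: a local model path, with or without the `file://` scheme, now reaches the native `createSession(path, ...)` overload instead of being copied through an `InputStream`. A sketch with a hypothetical Android path:

```ts
import {InferenceSession} from 'onnxruntime-react-native';

// An absolute path or file:// URI is handed to the native API directly;
// other URIs still go through the InputStream copy path shown above.
const session = await InferenceSession.create(
    'file:///data/user/0/com.example.app/files/model.onnx');
```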
diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock
index 9e20a286c4e27..6f05faf046098 100644
--- a/js/react_native/e2e/yarn.lock
+++ b/js/react_native/e2e/yarn.lock
@@ -3351,9 +3351,9 @@ invariant@^2.2.4:
loose-envify "^1.0.0"
ip@^1.1.5:
- version "1.1.8"
- resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
- integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
+ version "1.1.9"
+ resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
+ integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==
is-accessor-descriptor@^0.1.6:
version "0.1.6"
diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts
index 96c2361cceabe..40f970ddf02ae 100644
--- a/js/react_native/lib/version.ts
+++ b/js/react_native/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.17.0';
+export const version = '1.18.0';
diff --git a/js/react_native/package.json b/js/react_native/package.json
index 39e6cb08bb06a..47324a76fe55f 100644
--- a/js/react_native/package.json
+++ b/js/react_native/package.json
@@ -36,7 +36,7 @@
"registry": "https://registry.npmjs.org/"
},
"source": "lib/index",
- "version": "1.17.0",
+ "version": "1.18.0",
"main": "dist/commonjs/index",
"homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md",
"files": [
diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock
index ff9be7fbe3a5b..bbb0c4f3d1e22 100644
--- a/js/react_native/yarn.lock
+++ b/js/react_native/yarn.lock
@@ -3701,9 +3701,9 @@ invariant@^2.2.4:
loose-envify "^1.0.0"
ip@^1.1.5:
- version "1.1.8"
- resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
- integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
+ version "1.1.9"
+ resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
+ integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==
is-absolute@^1.0.0:
version "1.0.0"
@@ -5254,7 +5254,7 @@ onetime@^5.1.0, onetime@^5.1.2:
mimic-fn "^2.1.0"
"onnxruntime-common@file:../common":
- version "1.17.0"
+ version "1.18.0"
open@^6.2.0:
version "6.4.0"
diff --git a/js/web/README.md b/js/web/README.md
index c75a40ad6da28..906c78a1b7ec4 100644
--- a/js/web/README.md
+++ b/js/web/README.md
@@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f
With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience.
-ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend.
+ONNX Runtime Web can run on both CPU and GPU. On the CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into a WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We keep improving op coverage and optimizing performance in the WebGL backend.
See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports.
@@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun
## Documents
-### Developement
+### Development
Refer to the following links for development information:
diff --git a/js/web/docs/webgl-operators.md b/js/web/docs/webgl-operators.md
index 7c129b66bfa3d..cd25819a2069e 100644
--- a/js/web/docs/webgl-operators.md
+++ b/js/web/docs/webgl-operators.md
@@ -29,7 +29,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [BitwiseOr](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseOr) | |
| [BitwiseXor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BitwiseXor) | |
| [BlackmanWindow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#BlackmanWindow) | |
-| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-19) |
+| [Cast](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast) | [6-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-6), [9-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-9), [13-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-13), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Cast-21) |
| [CastLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#CastLike) | |
| [Ceil](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Ceil) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Ceil-13) |
| [Celu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Celu) | |
@@ -62,7 +62,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Exp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Exp) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Exp-13) |
| [Expand](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand) | |
| [EyeLike](https://github.com/onnx/onnx/blob/main/docs/Operators.md#EyeLike) | |
-| [Flatten](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten) | [1-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-1), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13) |
+| [Flatten](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Flatten) | [1-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-1), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Flatten-21) |
| [Floor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Floor) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Floor-13) |
| [GRU](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GRU) | |
| [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13) |
@@ -82,7 +82,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [HardSigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSigmoid) | |
| [HardSwish](https://github.com/onnx/onnx/blob/main/docs/Operators.md#HardSwish) | |
| [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | |
-| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) |
+| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-21) |
| [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | |
| [ImageDecoder](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ImageDecoder) | |
| [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) |
@@ -124,7 +124,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [OptionalHasElement](https://github.com/onnx/onnx/blob/main/docs/Operators.md#OptionalHasElement) | |
| [Or](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Or) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Or-7) |
| [PRelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#PRelu) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-7), [9-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-9), [16+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#PRelu-16) |
-| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-19) |
+| [Pad](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pad) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-18), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pad-21) |
| [Pow](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Pow) | [7-11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-7), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-12), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-13), [15+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Pow-15) |
| [QLinearConv](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearConv) | |
| [QLinearMatMul](https://github.com/onnx/onnx/blob/main/docs/Operators.md#QLinearMatMul) | |
@@ -148,7 +148,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) |
| [RegexFullMatch](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RegexFullMatch) | |
| [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) |
-| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) |
+| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-21) |
| [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) |
| [ReverseSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReverseSequence) | |
| [RoiAlign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RoiAlign) | |
@@ -166,7 +166,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [SequenceInsert](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceInsert) | |
| [SequenceLength](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceLength) | |
| [SequenceMap](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SequenceMap) | |
-| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-19) |
+| [Shape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-1), [13-14](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-13), [15-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-15), [19-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-19), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Shape-21) |
| [Shrink](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shrink) | |
| [Sigmoid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sigmoid) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sigmoid-13) |
| [Sign](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sign) | |
@@ -182,7 +182,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Split](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split) | [2-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Split-2), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Split-11) |
| [SplitToSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SplitToSequence) | |
| [Sqrt](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-13) |
-| [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13) |
+| [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-21) |
| [StringConcat](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringConcat) | |
| [StringNormalizer](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringNormalizer) | |
| [StringSplit](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringSplit) | |
@@ -194,10 +194,10 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [ThresholdedRelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ThresholdedRelu) | |
| [Tile](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tile) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tile-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tile-13) |
| [TopK](https://github.com/onnx/onnx/blob/main/docs/Operators.md#TopK) | |
-| [Transpose](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-1), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-13) |
+| [Transpose](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Transpose) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-1), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Transpose-21) |
| [Trilu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Trilu) | |
| [Unique](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unique) | |
-| [Unsqueeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-13) |
+| [Unsqueeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Unsqueeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-11), [13-20](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-13), [21+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-21) |
| [Upsample](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Upsample) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Upsample-7), [9](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Upsample-9) |
| [Where](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where) | |
| [Xor](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Xor-7) |
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index 2f510308d9306..c93f4f3cce68f 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -34,6 +34,7 @@ Do not modify directly.*
| Cos | ai.onnx(7+) | |
| Cosh | ai.onnx(9+) | |
| CumSum | ai.onnx(11-13,14+) | |
+| DepthToSpace | ai.onnx(11-12,13+); com.ms.internal.nhwc(11-12,13+) | |
| Div | ai.onnx(7-12,13,14+) | |
| Einsum | ai.onnx(12+) | |
| Elu | ai.onnx(6+) | |
@@ -41,6 +42,7 @@ Do not modify directly.*
| Erf | ai.onnx(9-12,13+) | |
| Exp | ai.onnx(6-12,13+) | |
| Expand | ai.onnx(8-12,13+) | |
+| FastGelu | com.microsoft(1+) | |
| Flatten | ai.onnx(1-8,9-10,11-12,13+) | |
| Floor | ai.onnx(6-12,13+) | |
| FusedConv | com.microsoft(1+) | |
@@ -52,6 +54,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
+| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
@@ -60,6 +63,7 @@ Do not modify directly.*
| LessOrEqual | ai.onnx(12-15,16+) | |
| Log | ai.onnx(6-12,13+) | |
| MatMul | ai.onnx(1-12,13+) | |
+| MatMulNBits | com.microsoft(1+) | |
| MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation |
| MemcpyFromHost | ai.onnx(1+) | |
| MemcpyToHost | ai.onnx(1+) | |
@@ -84,11 +88,14 @@ Do not modify directly.*
| Relu | ai.onnx(6-12,13,14+) | |
| Reshape | ai.onnx(5-12,13,14+) | no GPU kernel |
| Resize | ai.onnx(10,11-12,13-17,18,19+); com.ms.internal.nhwc(10,11-12,13-17,18,19+) | CoordinateTransformMode align_corners is not supported with downsampling |
+| RotaryEmbedding | com.microsoft(1+) | |
| Shape | ai.onnx(1-12,13-14,15+) | no GPU kernel; an ORT warning is generated - need to fix |
| Sigmoid | ai.onnx(6-12,13+) | |
+| SimplifiedLayerNormalization | ai.onnx(1+) | |
| Sin | ai.onnx(7+) | |
| Sinh | ai.onnx(9+) | |
| SkipLayerNormalization | com.microsoft(1+) | |
+| SkipSimplifiedLayerNormalization | com.microsoft(1+) | |
| Slice | ai.onnx(1-9,10,11-12,13+) | |
| Softmax | ai.onnx(1-10,11-12,13+) | |
| Split | ai.onnx(1,2-10,11-12,13-17,18+) | |
diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js
index 8fce79843f617..507da0de2b4ad 100644
--- a/js/web/karma.conf.js
+++ b/js/web/karma.conf.js
@@ -9,6 +9,8 @@ const karmaPlugins = args['karma-plugins'] || undefined;
const timeoutMocha = args['timeout-mocha'] || 60000;
const forceLocalHost = !!args['force-localhost'];
+// user data directory; will be passed to the Edge/Chrome/ChromeCanary/Firefox launchers
+const userDataDir = args['user-data-dir'];
// parse chromium flags
let chromiumFlags = args['chromium-flags'];
if (!chromiumFlags) {
@@ -86,11 +88,12 @@ module.exports = function(config) {
hostname,
listenAddress,
customLaunchers: {
- // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly.
- EdgeTest: {base: 'Edge', flags: chromiumFlags},
- ChromeTest: {base: 'Chrome', flags: chromiumFlags},
- ChromeTestHeadless: {base: 'ChromeHeadless', flags: chromiumFlags},
- ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags},
+ // Chromium-based browsers
+ EdgeTest: {base: 'Edge', flags: chromiumFlags, edgeDataDir: userDataDir},
+ ChromeTest: {base: 'Chrome', flags: chromiumFlags, chromeDataDir: userDataDir},
+ ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags, chromeDataDir: userDataDir},
+ FirefoxTest: {base: 'Firefox', profile: userDataDir},
+
//
// ==== BrowserStack browsers ====
//
diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts
index 2d123cdb71290..31ecffb07e40c 100644
--- a/js/web/lib/backend-wasm.ts
+++ b/js/web/lib/backend-wasm.ts
@@ -26,7 +26,17 @@ export const initializeFlags = (): void => {
env.wasm.proxy = false;
}
+ if (typeof env.wasm.trace !== 'boolean') {
+ env.wasm.trace = false;
+ }
+
if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) {
+ // Web: when crossOriginIsolated is false, SharedArrayBuffer is not available so WebAssembly threads will not work.
+ // Node.js: onnxruntime-web does not support multi-threading in Node.js.
+ if ((typeof self !== 'undefined' && !self.crossOriginIsolated) ||
+ (typeof process !== 'undefined' && process.versions && process.versions.node)) {
+ env.wasm.numThreads = 1;
+ }
const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency;
env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2));
}
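Note the precedence here: an explicit, valid `numThreads` set by the application is honored untouched; the branch above only fills in defaults when the value is unset or invalid. A sketch:

```ts
import * as ort from 'onnxruntime-web';

// Explicit values are honored as-is; the defaulting logic only runs when the
// value is unset, a non-integer, or <= 0.
ort.env.wasm.numThreads = 1;  // e.g. force single-threaded when the page is
                              // not cross-origin isolated
```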
diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts
index fb714bf5996f1..2c9cd88a375bd 100644
--- a/js/web/lib/build-def.d.ts
+++ b/js/web/lib/build-def.d.ts
@@ -19,7 +19,7 @@ interface BuildDefinitions {
*/
readonly DISABLE_WEBGPU: boolean;
/**
- * defines whether to disable the whole WebAssembly backend in the build.
+ * defines whether to disable the whole WebNN backend in the build.
*/
readonly DISABLE_WASM: boolean;
/**
diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts
index 499327741c82b..b212c0f49df3b 100644
--- a/js/web/lib/index.ts
+++ b/js/web/lib/index.ts
@@ -23,13 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) {
require('./backend-wasm-training').wasmBackend;
if (!BUILD_DEFS.DISABLE_WEBGPU) {
registerBackend('webgpu', wasmBackend, 5);
+ registerBackend('webnn', wasmBackend, 5);
}
registerBackend('cpu', wasmBackend, 10);
registerBackend('wasm', wasmBackend, 10);
- if (BUILD_DEFS.DISABLE_TRAINING) {
- registerBackend('xnnpack', wasmBackend, 9);
- registerBackend('webnn', wasmBackend, 9);
- }
}
Object.defineProperty(env.versions, 'web', {value: version, enumerable: true});
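Since `'webnn'` is now registered against the wasm backend at the same priority (5) as `'webgpu'`, instead of inside the removed training-only branch, requesting it looks like any other execution provider. A sketch with a placeholder model path:

```ts
import * as ort from 'onnxruntime-web';

const session = await ort.InferenceSession.create('model.onnx', {
  // Listed in preference order; resolution falls back to 'wasm' when the
  // browser has no WebNN support.
  executionProviders: ['webnn', 'wasm'],
});
```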
diff --git a/js/web/lib/onnxjs/model.ts b/js/web/lib/onnxjs/model.ts
index f9a1b6e76089d..8e689626011be 100644
--- a/js/web/lib/onnxjs/model.ts
+++ b/js/web/lib/onnxjs/model.ts
@@ -16,6 +16,7 @@ export class Model {
constructor() {}
load(buf: Uint8Array, graphInitializer?: Graph.Initializer, isOrtFormat?: boolean): void {
+ let onnxError: Error|undefined;
if (!isOrtFormat) {
// isOrtFormat === false || isOrtFormat === undefined
try {
@@ -25,10 +26,19 @@ export class Model {
if (isOrtFormat !== undefined) {
throw e;
}
+ onnxError = e;
}
}
- this.loadFromOrtFormat(buf, graphInitializer);
+ try {
+ this.loadFromOrtFormat(buf, graphInitializer);
+ } catch (e) {
+ if (isOrtFormat !== undefined) {
+ throw e;
+ }
+ // Tried both formats and failed (when isOrtFormat === undefined)
+ throw new Error(`Failed to load model as ONNX format: ${onnxError}\nas ORT format: ${e}`);
+ }
}
private loadFromOnnxFormat(buf: Uint8Array, graphInitializer?: Graph.Initializer): void {
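The pattern above is worth spelling out: the first failure is captured rather than rethrown, so that when the format is unknown (`isOrtFormat === undefined`) and both loads fail, the caller sees both causes in one error. Reduced to its skeleton:

```ts
// Try the primary loader, remember its error, fall back to the secondary
// loader, and surface both errors only if both attempts fail.
function loadWithFallback(tryOnnx: () => void, tryOrt: () => void): void {
  let onnxError: Error|undefined;
  try {
    tryOnnx();
    return;  // ONNX format succeeded; nothing else to do
  } catch (e) {
    onnxError = e as Error;  // remember, do not rethrow yet
  }
  try {
    tryOrt();
  } catch (e) {
    throw new Error(`Failed to load model as ONNX format: ${onnxError}\nas ORT format: ${e}`);
  }
}
```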
diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts
index 96c2361cceabe..40f970ddf02ae 100644
--- a/js/web/lib/version.ts
+++ b/js/web/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.17.0';
+export const version = '1.18.0';
diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index 00431a4e86d5b..56925b728e9a3 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -13,25 +13,105 @@ export declare namespace JSEP {
type ReleaseKernelFunction = (kernel: number) => void;
type RunFunction =
(kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array<Promise<string|null>>) => number;
+ type CaptureBeginFunction = () => void;
+ type CaptureEndFunction = () => void;
+ type ReplayFunction = () => void;
+
+ export interface Module extends WebGpuModule {
+ /**
+ * Mount the external data file to an internal map, which will be used during session initialization.
+ *
+ * @param externalDataFilePath - specify the relative path of the external data file.
+ * @param externalDataFileData - specify the content data.
+ */
+ mountExternalData(externalDataFilePath: string, externalDataFileData: Uint8Array): void;
+ /**
+ * Unmount all external data files from the internal map.
+ */
+ unmountExternalData(): void;
+
+ /**
+ * This is the entry point of JSEP initialization. This function is called once per backend when initializing ONNX
+ * Runtime. It initializes Asyncify support. If name is 'webgpu', it also initializes the WebGPU backend and
+ * registers a few callbacks that will be called in C++ code.
+ */
+ jsepInit(name: 'webgpu', initParams: [
+ backend: BackendType, alloc: AllocFunction, free: FreeFunction, upload: UploadFunction,
+ download: DownloadFunction, createKernel: CreateKernelFunction, releaseKernel: ReleaseKernelFunction,
+ run: RunFunction, captureBegin: CaptureBeginFunction, captureEnd: CaptureEndFunction, replay: ReplayFunction
+ ]): void;
+ jsepInit(name: 'webnn', initParams?: never): void;
+ }
+
+ export interface WebGpuModule {
+ /**
+ * [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
+ *
+ * @param context - specify the kernel context pointer.
+ * @param index - specify the index of the output.
+ * @param data - specify the pointer to encoded data of type and dims.
+ */
+ _JsepOutput(context: number, index: number, data: number): number;
+ /**
+ * [exported from wasm] Get name of an operator node.
+ *
+ * @param kernel - specify the kernel pointer.
+ * @returns the pointer to a C-style UTF8 encoded string representing the node name.
+ */
+ _JsepGetNodeName(kernel: number): number;
+
+ /**
+ * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
+ *
+ * @param sessionId - specify the session ID.
+ * @param index - specify an integer to represent which input/output it is registering for. For input, it is the
+ * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
+ * corresponding to the session's outputNames.
+ * @param buffer - specify the GPU buffer to register.
+ * @param size - specify the original data size in bytes.
+ * @returns the GPU data ID for the registered GPU buffer.
+ */
+ jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
+ /**
+ * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
+ *
+ * @param dataId - specify the GPU data ID
+ * @returns the GPU buffer.
+ */
+ jsepGetBuffer: (dataId: number) => GPUBuffer;
+ /**
+ * [exported from js_internal_api.js] Create a downloader function used to read data back from a GPU tensor's buffer.
+ *
+ * @param gpuBuffer - specify the GPU buffer
+ * @param size - specify the original data size in bytes.
+ * @param type - specify the tensor type.
+ * @returns the generated downloader function.
+ */
+ jsepCreateDownloader:
+ (gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
+ /**
+ * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before
+ * _OrtRun[WithBinding]() is called.
+ * @param sessionId - specify the session ID.
+ */
+ jsepOnRunStart: (sessionId: number) => void;
+ /**
+ * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is
+ * called.
+ * @param sessionId - specify the session ID.
+ * @returns
+ */
+ jsepOnReleaseSession: (sessionId: number) => void;
+ }
}
-export interface OrtWasmModule extends EmscriptenModule {
- // #region emscripten functions
- stackSave(): number;
- stackRestore(stack: number): void;
- stackAlloc(size: number): number;
-
- UTF8ToString(offset: number, maxBytesToRead?: number): string;
- lengthBytesUTF8(str: string): number;
- stringToUTF8(str: string, offset: number, maxBytes: number): void;
- // #endregion
-
- // #region ORT APIs
+export interface OrtInferenceAPIs {
_OrtInit(numThreads: number, loggingLevel: number): number;
_OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void;
- _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number;
+ _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise<number>;
_OrtReleaseSession(sessionHandle: number): void;
_OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number;
_OrtGetInputName(sessionHandle: number, index: number): number;
@@ -71,112 +151,61 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtReleaseRunOptions(runOptionsHandle: number): void;
_OrtEndProfiling(sessionHandle: number): number;
- // #endregion
+}
- // #region ORT Training APIs
- _OrtTrainingLoadCheckpoint?(dataOffset: number, dataLength: number): number;
+export interface OrtTrainingAPIs {
+ _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number;
- _OrtTrainingReleaseCheckpoint?(checkpointHandle: number): void;
+ _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void;
- _OrtTrainingCreateSession?
- (sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number,
- evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number;
+ _OrtTrainingCreateSession(
+ sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number,
+ evalOffset: number, evalLength: number, optimizerOffset: number, optimizerLength: number): number;
- _OrtTrainingLazyResetGrad?(trainingHandle: number): number;
+ _OrtTrainingLazyResetGrad(trainingHandle: number): number;
- _OrtTrainingRunTrainStep?
- (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
- runOptionsHandle: number): number;
+ _OrtTrainingRunTrainStep(
+ trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
+ runOptionsHandle: number): number;
- _OrtTrainingOptimizerStep?(trainingHandle: number, runOptionsHandle: number): number;
+ _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number;
- _OrtTrainingEvalStep?
- (trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
- runOptionsHandle: number): number;
+ _OrtTrainingEvalStep(
+ trainingHandle: number, inputsOffset: number, inputCount: number, outputsOffset: number, outputCount: number,
+ runOptionsHandle: number): number;
- _OrtTrainingGetParametersSize?(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
+ _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number;
- _OrtTrainingCopyParametersToBuffer?
- (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
+ _OrtTrainingCopyParametersToBuffer(
+ trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
- _OrtTrainingCopyParametersFromBuffer?
- (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
+ _OrtTrainingCopyParametersFromBuffer(
+ trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
- _OrtTrainingGetModelInputOutputCount?
- (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
- _OrtTrainingGetModelInputOutputName?
- (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;
+ _OrtTrainingGetModelInputOutputCount(
+ trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
+ _OrtTrainingGetModelInputOutputName(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean):
+ number;
- _OrtTrainingReleaseSession?(trainingHandle: number): void;
+ _OrtTrainingReleaseSession(trainingHandle: number): void;
+}
+
+export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial<OrtTrainingAPIs>,
+ Partial<JSEP.Module> {
+ // #region emscripten functions
+ stackSave(): number;
+ stackRestore(stack: number): void;
+ stackAlloc(size: number): number;
+
+ UTF8ToString(offset: number, maxBytesToRead?: number): string;
+ lengthBytesUTF8(str: string): number;
+ stringToUTF8(str: string, offset: number, maxBytes: number): void;
// #endregion
// #region config
+ numThreads?: number;
mainScriptUrlOrBlob?: string|Blob;
// #endregion
-
- // #region JSEP
- /**
- * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime.
- * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code.
- */
- jsepInit?
- (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction,
- download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction,
- releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void;
-
- /**
- * [exported from wasm] Specify a kernel's output when running OpKernel::Compute().
- *
- * @param context - specify the kernel context pointer.
- * @param index - specify the index of the output.
- * @param data - specify the pointer to encoded data of type and dims.
- */
- _JsepOutput(context: number, index: number, data: number): number;
- /**
- * [exported from wasm] Get name of an operator node.
- *
- * @param kernel - specify the kernel pointer.
- * @returns the pointer to a C-style UTF8 encoded string representing the node name.
- */
- _JsepGetNodeName(kernel: number): number;
-
- /**
- * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output.
- *
- * @param sessionId - specify the session ID.
- * @param index - specify an integer to represent which input/output it is registering for. For input, it is the
- * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index
- * corresponding to the session's ouputNames.
- * @param buffer - specify the GPU buffer to register.
- * @param size - specify the original data size in byte.
- * @returns the GPU data ID for the registered GPU buffer.
- */
- jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number;
- /**
- * [exported from js_internal_api.js] Unregister all user GPU buffers for a session.
- *
- * @param sessionId - specify the session ID.
- */
- jsepUnregisterBuffers?: (sessionId: number) => void;
- /**
- * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID.
- *
- * @param dataId - specify the GPU data ID
- * @returns the GPU buffer.
- */
- jsepGetBuffer: (dataId: number) => GPUBuffer;
- /**
- * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor.
- *
- * @param gpuBuffer - specify the GPU buffer
- * @param size - specify the original data size in byte.
- * @param type - specify the tensor type.
- * @returns the generated downloader function.
- */
- jsepCreateDownloader:
- (gpuBuffer: GPUBuffer, size: number,
- type: Tensor.GpuBufferDataTypes) => () => Promise<Tensor.DataTypeMap[Tensor.GpuBufferDataTypes]>;
- // #endregion
}
declare const moduleFactory: EmscriptenModuleFactory;
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 6c3d22352772e..1b421029cc7ae 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -1,14 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {Env, Tensor} from 'onnxruntime-common';
+import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
+
+import {DataType, tensorDataTypeEnumToString} from '../wasm-common';
import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager';
import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules';
import {ProgramManager} from './webgpu/program-manager';
-import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency} from './webgpu/types';
+import {AdapterInfo, ComputeContext, GpuArchitecture, GpuData, GpuVendor, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types';
+
+interface CommandInfo {
+ readonly kernelId: number;
+ readonly computePipeline: GPUComputePipeline;
+ readonly bindGroup: GPUBindGroup;
+ readonly dispatchGroup: [number, number, number];
+}
+
+interface KernelInfo {
+ readonly kernelType: string;
+ readonly kernelName: string;
+ readonly kernelEntry: RunFunction;
+ readonly attributes: [((attribute: unknown) => unknown)|undefined, unknown];
+}
+
+interface PendingKernelInfo {
+ readonly kernelId: number;
+ readonly programName: string;
+ readonly inputTensorViews: readonly TensorView[];
+ readonly outputTensorViews: readonly TensorView[];
+}
const getProgramInputTensorInfoDependencyKey =
(inputTensors: readonly TensorView[], inputDependencies: readonly ProgramInputTensorInfoDependency[]): string => {
@@ -71,11 +94,32 @@ const getProgramInfoUniqueKey =
return key;
};
+class AdapterInfoImpl implements AdapterInfo {
+ readonly architecture?: string;
+ readonly vendor?: string;
+
+ constructor(adapterInfo: GPUAdapterInfo) {
+ if (adapterInfo) {
+ this.architecture = adapterInfo.architecture;
+ this.vendor = adapterInfo.vendor;
+ }
+ }
+
+ isArchitecture(architecture: GpuArchitecture): boolean {
+ return this.architecture === architecture;
+ }
+
+ isVendor(vendor: GpuVendor): boolean {
+ return this.vendor === vendor;
+ }
+}
+
/**
* this class is designed to store status and being used as a singleton for JSEP. It will be passed to jsepInit() as
* the first parameter so that it is stored for future use.
*/
export class WebGpuBackend {
+ adapterInfo: AdapterInfoImpl;
device: GPUDevice;
/**
* an instance of GpuDataManager to manage a GpuDataId -> GpuBuffer mapping
@@ -87,6 +131,13 @@ export class WebGpuBackend {
*/
programManager: ProgramManager;
+ /**
+ * representing the ID of the session that is currently being run.
+ * `null` means no session is being run.
+ * only valid when session.run is executed.
+ */
+ currentSessionId: number|null = null;
+
/**
* representing the kernel ID of which is currently being computed (CPU code perspective).
* `null` means no kernel is being computed.
@@ -122,22 +173,33 @@ export class WebGpuBackend {
return data;
}
- /**
- * a KernelID -> kernel info mapping. value is
- * [ op_type, name, run function, [optional] preprocess_attribute_once function ]
- */
- kernels: Map<number, [string, string, RunFunction, [((attribute: unknown) => unknown) | undefined, unknown]]>;
-
+ // KernelID -> kernelInfo mapping
+ kernels: Map<number, KernelInfo>;
private commandEncoder: GPUCommandEncoder|null = null;
private computePassEncoder: GPUComputePassEncoder|null = null;
+ maxDispatchNumber = 16;
pendingDispatchNumber = 0;
- queryData?: GpuData;
- querySet?: GPUQuerySet;
- querySetCount = 2;
- queryTimeBase?: bigint;
+ // info of kernels pending submission for a single batch
+ private pendingKernels: PendingKernelInfo[] = [];
+ // queryReadBuffer -> pendingKernels mapping for all the batches
+ private pendingQueries: Map<GPUBuffer, PendingKernelInfo[]> = new Map();
+ private queryResolveBuffer?: GPUBuffer;
+ private querySet?: GPUQuerySet;
+ private queryTimeBase?: bigint;
+ queryType: TimestampQuery;
env: Env;
+ sessionStatus: SessionState = 'default';
+ /**
+ * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for the corresponding session.
+ */
+ capturedCommandList: Map<number, CommandInfo[]> = new Map();
+
+ /**
+ * a SessionID -> PendingKernelInfo[] mapping for profiling.
+ */
+ private capturedPendingKernels: Map<number, PendingKernelInfo[]> = new Map();
/**
* a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping.
@@ -161,7 +223,9 @@ export class WebGpuBackend {
requiredFeatures,
};
- if (adapter.features.has('timestamp-query')) {
+ if (adapter.features.has('chromium-experimental-timestamp-query-inside-passes')) {
+ requiredFeatures.push('chromium-experimental-timestamp-query-inside-passes' as GPUFeatureName);
+ } else if (adapter.features.has('timestamp-query')) {
requiredFeatures.push('timestamp-query');
}
if (adapter.features.has('shader-f16')) {
@@ -169,6 +233,7 @@ export class WebGpuBackend {
}
this.device = await adapter.requestDevice(deviceDescriptor);
+ this.adapterInfo = new AdapterInfoImpl(await adapter.requestAdapterInfo());
this.gpuDataManager = createGpuDataManager(this);
this.programManager = new ProgramManager(this);
this.kernels = new Map();
@@ -187,7 +252,13 @@ export class WebGpuBackend {
}
};
- Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
+ Object.defineProperty(
+ this.env.webgpu, 'device', {value: this.device, writable: false, enumerable: true, configurable: false});
+ Object.defineProperty(
+ this.env.webgpu, 'adapter', {value: adapter, writable: false, enumerable: true, configurable: false});
+
+ // init queryType, which is necessary for InferenceSession.create
+ this.setQueryType();
}
dispose(): void {
@@ -206,22 +277,18 @@ export class WebGpuBackend {
getComputePassEncoder(): GPUComputePassEncoder {
if (!this.computePassEncoder) {
+ const commandEncoder = this.getCommandEncoder();
const computePassDescriptor: GPUComputePassDescriptor = {};
- if (this.isQueryEnabled()) {
- if (typeof this.querySet === 'undefined') {
- this.querySet = this.device.createQuerySet({
- type: 'timestamp',
- count: this.querySetCount,
- });
- }
+
+ if (this.queryType === 'at-passes') {
computePassDescriptor.timestampWrites = {
- querySet: this.querySet,
- beginningOfPassWriteIndex: 0,
- endOfPassWriteIndex: 1,
+ querySet: this.querySet!,
+ beginningOfPassWriteIndex: this.pendingDispatchNumber * 2,
+ endOfPassWriteIndex: this.pendingDispatchNumber * 2 + 1,
};
}
- this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor);
+ this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor);
}
return this.computePassEncoder;
}
@@ -234,19 +301,95 @@ export class WebGpuBackend {
}
flush(): void {
- if (this.commandEncoder) {
- this.endComputePass();
- this.device.queue.submit([this.getCommandEncoder().finish()]);
- this.gpuDataManager.refreshPendingBuffers();
- this.commandEncoder = null;
- this.pendingDispatchNumber = 0;
+ if (!this.commandEncoder) {
+ return;
}
- }
- isQueryEnabled(): boolean {
- return this.device.features.has('timestamp-query') &&
- (this.env.webgpu.profiling?.mode === 'default' ||
- (!this.env.webgpu.profiling?.mode && this.env.webgpu.profilingMode === 'default'));
+ TRACE_FUNC_BEGIN();
+
+ this.endComputePass();
+ let queryReadBuffer: GPUBuffer;
+ if (this.queryType !== 'none') {
+ this.commandEncoder.resolveQuerySet(
+ this.querySet!, 0, this.pendingDispatchNumber * 2, this.queryResolveBuffer!, 0);
+
+ queryReadBuffer = this.device.createBuffer(
+ // eslint-disable-next-line no-bitwise
+ {size: this.pendingDispatchNumber * 2 * 8, usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST});
+
+ this.pendingQueries.set(queryReadBuffer, this.pendingKernels);
+ this.pendingKernels = [];
+ this.commandEncoder.copyBufferToBuffer(
+ this.queryResolveBuffer!, 0, queryReadBuffer, 0, this.pendingDispatchNumber * 2 * 8);
+ }
+
+ this.device.queue.submit([this.commandEncoder.finish()]);
+ this.gpuDataManager.refreshPendingBuffers();
+ this.commandEncoder = null;
+ this.pendingDispatchNumber = 0;
+
+ if (this.queryType !== 'none') {
+ void queryReadBuffer!.mapAsync(GPUMapMode.READ).then(() => {
+ const mappedData = new BigUint64Array(queryReadBuffer.getMappedRange());
+ const pendingKernels = this.pendingQueries.get(queryReadBuffer)!;
+ for (let i = 0; i < mappedData.length / 2; i++) {
+ const pendingKernelInfo = pendingKernels[i];
+ const kernelId = pendingKernelInfo.kernelId;
+ const kernelInfo = this.kernels.get(kernelId)!;
+ const kernelType = kernelInfo.kernelType;
+ const kernelName = kernelInfo.kernelName;
+ const programName = pendingKernelInfo.programName;
+ const inputTensorViews = pendingKernelInfo.inputTensorViews;
+ const outputTensorViews = pendingKernelInfo.outputTensorViews;
+ const startTimeU64 = mappedData[i * 2];
+ const endTimeU64 = mappedData[i * 2 + 1];
+
+ if (typeof this.queryTimeBase === 'undefined') {
+ this.queryTimeBase = startTimeU64;
+ }
+
+ const startTime = Number(startTimeU64 - this.queryTimeBase);
+ const endTime = Number(endTimeU64 - this.queryTimeBase);
+
+ if (!Number.isSafeInteger(startTime) || !Number.isSafeInteger(endTime)) {
+ throw new RangeError('incorrect timestamp range');
+ }
+
+ if (this.env.webgpu.profiling?.ondata) {
+ this.env.webgpu.profiling.ondata({
+ version: 1,
+ inputsMetadata: inputTensorViews.map(
+ value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
+ outputsMetadata: outputTensorViews.map(
+ value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
+ kernelId,
+ kernelType,
+ kernelName,
+ programName,
+ startTime,
+ endTime,
+ });
+ } else {
+ // if no callback is provided, print the profiling message to console
+ let inputShapes = '';
+ inputTensorViews.forEach((value, i) => {
+ inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
+ });
+ let outputShapes = '';
+ outputTensorViews.forEach((value, i) => {
+ outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
+ });
+ // eslint-disable-next-line no-console
+ console.log(`[profiling] kernel "${kernelId}|${kernelType}|${kernelName}|${programName}" ${inputShapes}${
+ outputShapes}execution time: ${endTime - startTime} ns`);
+ }
+ TRACE('GPU', `${programName}::${startTimeU64}::${endTimeU64}`);
+ }
+ queryReadBuffer.unmap();
+ this.pendingQueries.delete(queryReadBuffer);
+ });
+ }
+ TRACE_FUNC_END();
}
/**
@@ -263,14 +406,20 @@ export class WebGpuBackend {
run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[],
createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView,
createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] {
+ TRACE_FUNC_BEGIN(program.name);
// create info for inputs
const inputDatas: GpuData[] = [];
for (let i = 0; i < inputTensorViews.length; ++i) {
- const gpuData = this.gpuDataManager.get(inputTensorViews[i].data);
+ const data = inputTensorViews[i].data;
+ // if tensor view data is 0, the input is a zero-sized tensor and there is no GPU data for it.
+ if (data === 0) {
+ continue;
+ }
+ const gpuData = this.gpuDataManager.get(data);
if (!gpuData) {
- throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`);
+ throw new Error(`no GPU data for input: ${data}`);
}
- inputDatas[i] = gpuData;
+ inputDatas.push(gpuData);
}
const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews);
@@ -300,6 +449,11 @@ export class WebGpuBackend {
const tensorView = (isTemporary || isPersistent) ?
createIntermediateOutput(outputs[i].dataType, outputs[i].dims) :
createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims);
+ outputTensorViews.push(tensorView);
+ // if tensor view data is 0, the output is a zero-sized tensor and there is no GPU data for it.
+ if (tensorView.data === 0) {
+ continue;
+ }
const gpuData = this.gpuDataManager.get(tensorView.data);
if (!gpuData) {
throw new Error(`no GPU data for output: ${tensorView.data}`);
@@ -315,10 +469,24 @@ export class WebGpuBackend {
}
persistentData.push(gpuData);
}
- outputTensorViews.push(tensorView);
outputDatas.push(gpuData);
}
+ // when there is any zero-sized tensor in the inputs or outputs, we should report an error unless all outputs are
+ // zero-sized tensors.
+ if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) {
+ // if all outputs are zero-sized tensors, there is no need to run the program.
+ if (outputDatas.length === 0) {
+ TRACE_FUNC_END(program.name);
+ return outputTensorViews;
+ }
+ // if some outputs are zero-sized tensors, report an error.
+ //
+ // TODO: so far we haven't seen any use case where outputs include both zero-sized and non-zero-sized tensors.
+ // If such a use case arises, we need to make a change here to support it.
+ throw new Error(
+ `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`);
+ }
// load uniforms
// TODO: add cache for uniform (is it necessary?)
@@ -334,13 +502,26 @@ export class WebGpuBackend {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
- const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
+ const sizeOfElement = v.type === DataType.float16 ? 2 : 4;
+ let sizeOfVecOrMat;
+ let baseAlignment;
+ if (v.type === DataType.float16) {
+ baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement);
+ sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length;
+ } else {
+ baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16;
+ sizeOfVecOrMat = 16;
+ }
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
offsets.push(currentOffset);
- // When data.length > 4, the uniform variable is of type array<vec4<f32>, N>, where N =
- // Math.ceil(data.length / 4) and SizeOf(vec4<f32>) = 16. The total byte length is N *
- // SizeOf(vec4<f32>).
- currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
+ // For non-float16 type, when data.length > 4, the uniform variable is of type array<vec4<f32>, N>, where
+ // N = Math.ceil(data.length / 4) and SizeOf(vec4<f32>) = 16. The total byte length is N *
+ // SizeOf(vec4<f32>). For float16 type, when data.length > 4, the uniform variable is of type
+ // array<mat2x4<f16>, N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4<f16>) = 16. The total byte
+ // length is N * SizeOf(mat2x4<f16>).
+ const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4;
+ currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat :
+ data.length * sizeOfElement;
});
// Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
@@ -351,12 +532,17 @@ export class WebGpuBackend {
programUniforms.forEach((v, i) => {
const offset = offsets[i];
const data = typeof v.data === 'number' ? [v.data] : v.data;
- if (v.type === 'int32') {
+ if (v.type === DataType.int32) {
new Int32Array(arrayBuffer, offset, data.length).set(data);
- } else if (v.type === 'uint32') {
+ } else if (v.type === DataType.uint32) {
new Uint32Array(arrayBuffer, offset, data.length).set(data);
- } else {
+ } else if (v.type === DataType.float16) {
+ // TODO: use Float16Array.
+ new Uint16Array(arrayBuffer, offset, data.length).set(data);
+ } else if (v.type === DataType.float) {
new Float32Array(arrayBuffer, offset, data.length).set(data);
+ } else {
+ throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`);
}
});
@@ -379,14 +565,47 @@ export class WebGpuBackend {
LOG_DEBUG('info', () => `[artifact] key: ${key}, programName: ${program.name}`);
}
+ // validate uniform variables
+ if (programUniforms && artifact.uniformVariablesInfo) {
+ if (programUniforms.length !== artifact.uniformVariablesInfo.length) {
+ throw new Error(`Uniform variables count mismatch: expect ${artifact.uniformVariablesInfo.length}, got ${
+ programUniforms.length} in program "${artifact.programInfo.name}".`);
+ }
+ for (let i = 0; i < programUniforms.length; i++) {
+ const uniform = programUniforms[i];
+ const actualType = uniform.type;
+ const actualLength = typeof uniform.data === 'number' ? 1 : uniform.data.length;
+ const [type, length] = artifact.uniformVariablesInfo[i];
+ if (actualType !== type || actualLength !== length) {
+ throw new Error(`Uniform variable ${i} mismatch: expect type ${type} with size ${length}, got type ${
+ actualType} with size ${actualLength} in program "${artifact.programInfo.name}".`);
+ }
+ }
+ }
+
LOG_DEBUG(
'info',
() => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${
normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`);
- this.programManager.run(
- artifact, inputTensorViews, outputTensorViews, inputDatas, outputDatas, normalizedDispatchGroup,
- uniformBufferBinding);
+ if (this.queryType !== 'none' || this.sessionStatus === 'capturing') {
+ const pendingKernelInfo: PendingKernelInfo = {
+ kernelId: this.currentKernelId!,
+ programName: artifact.programInfo.name,
+ inputTensorViews,
+ outputTensorViews,
+ };
+ this.pendingKernels.push(pendingKernelInfo);
+
+ if (this.sessionStatus === 'capturing') {
+ const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
+ sessionPendingKernels!.push(pendingKernelInfo);
+ }
+ }
+
+ this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding);
+
+ TRACE_FUNC_END(program.name);
return outputTensorViews;
}
@@ -412,13 +631,19 @@ export class WebGpuBackend {
return this.gpuDataManager.release(ptr);
}
- createKernel(opType: string, kernelId: number, attribute: unknown, nodeName: string): void {
- const op = WEBGPU_OP_RESOLVE_RULES.get(opType);
+ createKernel(kernelType: string, kernelId: number, attribute: unknown, kernelName: string): void {
+ const op = WEBGPU_OP_RESOLVE_RULES.get(kernelType);
if (!op) {
- throw new Error(`kernel not implemented: ${opType}`);
+ throw new Error(`kernel not implemented: ${kernelType}`);
}
- this.kernels.set(kernelId, [opType, nodeName, op[0], [op[1], attribute]]);
+ const kernelInfo: KernelInfo = {
+ kernelType,
+ kernelName,
+ kernelEntry: op[0],
+ attributes: [op[1], attribute],
+ };
+ this.kernels.set(kernelId, kernelInfo);
}
releaseKernel(kernelId: number): void {
@@ -439,9 +664,12 @@ export class WebGpuBackend {
if (!kernel) {
throw new Error(`kernel not created: ${kernelId}`);
}
- const [opType, nodeName, kernelEntry, attributes] = kernel;
+ const kernelType = kernel.kernelType;
+ const kernelName = kernel.kernelName;
+ const kernelEntry = kernel.kernelEntry;
+ const attributes = kernel.attributes;
if (this.currentKernelId !== null) {
- throw new Error(`kernel "[${opType}] ${nodeName}" is not allowed to be called recursively`);
+ throw new Error(`kernel "[${kernelType}] ${kernelName}" is not allowed to be called recursively`);
}
this.currentKernelId = kernelId;
@@ -451,7 +679,7 @@ export class WebGpuBackend {
attributes[0] = undefined;
}
- LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${opType}] ${nodeName}"...`);
+ LOG_DEBUG('info', () => `[WebGPU] Start to run kernel "[${kernelType}] ${kernelName}"...`);
const useErrorScope = this.env.debug;
@@ -464,12 +692,12 @@ export class WebGpuBackend {
kernelEntry(context, attributes[1]);
return 0; // ORT_OK
} catch (e) {
- errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`));
+ errors.push(Promise.resolve(`[WebGPU] Kernel "[${kernelType}] ${kernelName}" failed. ${e}`));
return 1; // ORT_FAIL
} finally {
if (useErrorScope) {
errors.push(this.device.popErrorScope().then(
- err => err ? `GPU validation error for kernel "[${opType}] ${nodeName}": ${err.message}` : null));
+ err => err ? `GPU validation error for kernel "[${kernelType}] ${kernelName}": ${err.message}` : null));
}
for (const data of this.temporaryData) {
@@ -515,4 +743,98 @@ export class WebGpuBackend {
};
}
// #endregion
+ writeTimestamp(index: number): void {
+ if (this.queryType !== 'inside-passes') {
+ return;
+ }
+
+ // eslint-disable-next-line @typescript-eslint/no-explicit-any
+ (this.computePassEncoder as any).writeTimestamp(this.querySet, index);
+ }
+ setQueryType(): void {
+ this.queryType = 'none';
+ if (this.env.webgpu.profiling?.mode === 'default' ||
+ (typeof this.env.trace === 'undefined' ? this.env.wasm.trace : this.env.trace)) {
+ if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) {
+ this.queryType = 'inside-passes';
+ } else if (this.device.features.has('timestamp-query')) {
+ this.queryType = 'at-passes';
+ }
+
+ if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
+ this.querySet = this.device.createQuerySet({
+ type: 'timestamp',
+ count: this.maxDispatchNumber * 2,
+ });
+ this.queryResolveBuffer = this.device.createBuffer(
+ // eslint-disable-next-line no-bitwise
+ {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE});
+ }
+ }
+ }
+
+ captureBegin(): void {
+ LOG_DEBUG('info', 'captureBegin');
+ if (!this.capturedCommandList.get(this.currentSessionId!)) {
+ this.capturedCommandList.set(this.currentSessionId!, []);
+ }
+ if (!this.capturedPendingKernels.get(this.currentSessionId!)) {
+ this.capturedPendingKernels.set(this.currentSessionId!, []);
+ }
+ // flush the left commands before we change the status.
+ this.flush();
+ this.sessionStatus = 'capturing';
+ }
+ captureEnd(): void {
+ LOG_DEBUG('info', 'captureEnd');
+ // flush the left commands before we change the status.
+ this.flush();
+ this.sessionStatus = 'default';
+ }
+ replay(): void {
+ LOG_DEBUG('info', 'replay');
+ this.sessionStatus = 'replaying';
+ const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!);
+ const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!);
+ const length = sessionCommandList!.length;
+ this.pendingKernels = [];
+ for (let i = 0; i < length; i++) {
+ const computePassEncoder = this.getComputePassEncoder();
+ const command = sessionCommandList![i];
+ this.writeTimestamp(this.pendingDispatchNumber * 2);
+ computePassEncoder.setPipeline(command.computePipeline);
+ computePassEncoder.setBindGroup(0, command.bindGroup);
+ computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
+ this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
+ this.pendingDispatchNumber++;
+ if (this.queryType !== 'none') {
+ this.pendingKernels.push(sessionPendingKernels![i]);
+ }
+ if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
+ this.endComputePass();
+ }
+ if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
+ this.flush();
+ }
+ }
+ // flush the left commands before we change the status.
+ this.flush();
+ this.sessionStatus = 'default';
+ }
+
+ onReleaseSession(sessionId: number): void {
+ this.unregisterBuffers(sessionId);
+ if (this.capturedCommandList.has(sessionId)) {
+ this.capturedCommandList.delete(sessionId);
+ }
+ if (this.capturedPendingKernels.has(sessionId)) {
+ this.capturedPendingKernels.delete(sessionId);
+ }
+ this.gpuDataManager.onReleaseSession(sessionId);
+ }
+
+ onRunStart(sessionId: number): void {
+ this.currentSessionId = sessionId;
+ this.setQueryType();
+ }
}
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index 3c6edf3ebb35d..1ceae2394f462 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -10,7 +10,7 @@ import {WebGpuBackend} from './backend-webgpu';
import {LOG_DEBUG} from './log';
import {TensorView} from './tensor-view';
import {ShapeUtil} from './util';
-import {ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';
+import {AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo} from './webgpu/types';
/* eslint-disable no-bitwise */
@@ -54,6 +54,7 @@ class TensorViewImpl implements TensorView {
}
class ComputeContextImpl implements ComputeContext {
+ readonly adapterInfo: AdapterInfo;
readonly opKernelContext: number;
readonly inputs: readonly TensorView[];
readonly outputCount: number;
@@ -66,6 +67,7 @@ class ComputeContextImpl implements ComputeContext {
private customDataOffset = 0;
private customDataSize = 0;
constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) {
+ this.adapterInfo = backend.adapterInfo;
const heapU32 = module.HEAPU32;
// extract context data
@@ -90,6 +92,17 @@ class ComputeContextImpl implements ComputeContext {
this.inputs = inputs;
}
+ getMaxComputeWorkgroupSizes(): [number, number, number] {
+ return [
+ this.backend.device.limits.maxComputeWorkgroupSizeX, this.backend.device.limits.maxComputeWorkgroupSizeY,
+ this.backend.device.limits.maxComputeWorkgroupSizeZ
+ ];
+ }
+
+ getMaxComputeWorkgroupStoragesize(): number {
+ return this.backend.device.limits.maxComputeWorkgroupStorageSize;
+ }
+
compute(program: ProgramInfo, inputsOutputsMapping?: ComputeContextInputsOutputsMapping): TensorView[] {
// prepare inputs. inputs should always be valid data.
const mappedInputs =
@@ -104,7 +117,8 @@ class ComputeContextImpl implements ComputeContext {
throw new Error(`Unsupported data type: ${dataType}`);
}
const bufferSize = elementSize * ShapeUtil.size(dims);
- return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims);
+ const gpuDataId = bufferSize > 0 ? this.backend.gpuDataManager.create(bufferSize).id : 0;
+ return new TensorViewImpl(this.module, dataType, gpuDataId, dims);
};
return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput);
}
@@ -118,7 +132,7 @@ class ComputeContextImpl implements ComputeContext {
for (let i = 0; i < dims.length; i++) {
this.module.HEAPU32[offset++] = dims[i];
}
- return this.module._JsepOutput(this.opKernelContext, index, data);
+ return this.module._JsepOutput!(this.opKernelContext, index, data);
} catch (e) {
throw new Error(
`Failed to generate kernel's output[${index}] with dims [${dims}]. ` +
@@ -133,27 +147,39 @@ class ComputeContextImpl implements ComputeContext {
/**
* Initialize JSEP with WebGPU backend.
*
- * This function will be called only once after the WebAssembly module is loaded and initialized ("_OrtInit" is called).
- * This function expects:
+ * This function will be called after the WebAssembly module is loaded and initialized ("_OrtInit" is called), once for
+ * each of the following EPs if they are specified:
+ * - "webgpu"
+ * - "webnn"
+ *
+ * For WebGPU, this function expects:
* - WebGPU is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
* - WebGPU is available in current environment. (a valid GPUAdapter is passed in)
+ *
+ * For WebNN, this function expects:
+ * - WebNN is enabled in build (BUILD_DEFS.DISABLE_WEBGPU === false).
+ * - WebNN is available in current environment. (navigator.ml is not undefined)
+ *
* If the WebAssembly module is not built with JSEP support, this function will throw an error. This will invalidate
- * 'webgpu' backend.
+ * 'webgpu'/'webnn' backend.
*
+ * @param name - the name of the EP, either "webgpu" or "webnn"
* @param module - the ORT WebAssembly module
* @param env - the ORT environment variable (ort.env)
* @param gpuAdapter - the pre-created GPU adapter
*/
-export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapter): Promise<void> => {
+export const init =
+ async(name: 'webgpu'|'webnn', module: OrtWasmModule, env: Env, gpuAdapter?: GPUAdapter): Promise<void> => {
const jsepInit = module.jsepInit;
if (!jsepInit) {
throw new Error('Failed to initialize JSEP. The WebAssembly module is not built with JSEP support.');
}
- const backend = new WebGpuBackend();
- await backend.initialize(env, gpuAdapter);
+ if (name === 'webgpu') {
+ const backend = new WebGpuBackend();
+ await backend.initialize(env, gpuAdapter!);
- jsepInit(
+ jsepInit('webgpu', [
// backend
backend,
@@ -170,7 +196,7 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
backend.memcpy(src, dst);
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
- const data = module.HEAPU8.subarray(src, src + size);
+ const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size);
backend.upload(dst, data);
}
},
@@ -182,13 +208,13 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
'verbose',
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`);
- await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset, dataOffset + size));
+ await backend.download(
+ gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size));
},
// jsepCreateKernel
- (name: string, kernel: number, attribute: unknown) => backend.createKernel(
- name, kernel, attribute,
- env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),
+ (kernelType: string, kernelId: number, attribute: unknown) => backend.createKernel(
+ kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))),
// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),
@@ -201,5 +227,15 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte
contextDataOffset}`);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
- });
+ },
+ // jsepCaptureBegin
+ () => backend.captureBegin(),
+ // jsepCaptureEnd
+ () => backend.captureEnd(),
+ // jsepReplay
+ () => backend.replay()
+ ]);
+ } else {
+ jsepInit('webnn');
+ }
};
diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts
index 6922d7ff5df6e..9a1d5463f7843 100644
--- a/js/web/lib/wasm/jsep/util.ts
+++ b/js/web/lib/wasm/jsep/util.ts
@@ -56,7 +56,16 @@ export class BroadcastUtil {
if (aLen !== bLen && aLen > 1 && bLen > 1) {
return undefined;
}
- cdims[crank - i] = Math.max(aLen, bLen);
+ const max = Math.max(aLen, bLen);
+ if (aLen && bLen) {
+ cdims[crank - i] = Math.max(aLen, bLen);
+ } else {
+ // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable.
+ if (max > 1) {
+ return undefined;
+ }
+ cdims[crank - i] = 0;
+ }
}
return cdims;
@@ -92,6 +101,34 @@ export class ShapeUtil {
return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length);
}
+ /**
+ * convert dims when the data is packed into a wider type, e.g. uint8 data packed into uint32.
+ */
+ static convertShape(dims: readonly number[], size = 4): readonly number[] {
+ const rank = dims.length;
+ if (rank === 0) {
+ return [];
+ }
+ const newDims = new Array(rank);
+ let i = rank - 1;
+ while (i >= 0) {
+ if (dims[i] % size === 0) {
+ newDims[i] = dims[i] / size;
+ break;
+ }
+ if (size % dims[i] !== 0) {
+ throw new Error('cannot convert shape');
+ }
+ newDims[i] = 1;
+ size /= dims[i];
+ i--;
+ }
+ for (i--; i >= 0; i--) {
+ newDims[i] = dims[i];
+ }
+ return newDims;
+ }
+
/**
* calculate the size (number of elements) from the given axis (inclusive)
*/
diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
index 6f3d9a52d9f5d..c17bd1e1477ec 100644
--- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
@@ -60,9 +60,15 @@ export interface GpuDataManager {
unregisterExternalBuffer(buffer: GPUBuffer): void;
/**
- * destroy all gpu buffers. Call this when the session.release is called.
+ * destroy all gpu buffers.
*/
dispose(): void;
+
+ /**
+ * release session related data.
+ * @param sessionId - specify the session ID.
+ */
+ onReleaseSession(sessionId: number): void;
}
interface StorageCacheValue {
@@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager {
// The external buffers registered users for IO Binding.
private externalBuffers: Map<GPUBuffer, GpuDataId>;
+ // The pendingBuffers for capture graph.
+ // a SessionID -> GPUBuffer[] mapping.
+ private capturedPendingBuffers: Map<number, GPUBuffer[]>;
+
constructor(private backend: WebGpuBackend) {
this.storageCache = new Map();
this.freeBuffers = new Map();
@@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager {
this.buffersForUploadingPending = [];
this.buffersPending = [];
this.externalBuffers = new Map();
+ this.capturedPendingBuffers = new Map();
}
upload(id: GpuDataId, data: Uint8Array): void {
@@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager {
() => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${
id}, buffer is the same, skip.`);
return id;
+ } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) {
+ throw new Error(`Registering a different external buffer under graph capture mode is not supported yet.
+ Please use the previous external buffer!`);
}
this.externalBuffers.delete(previousBuffer);
} else {
@@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager {
buffer.destroy();
}
this.buffersForUploadingPending = [];
- for (const buffer of this.buffersPending) {
- // eslint-disable-next-line no-bitwise
- if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
- // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
- this.freeBuffers.get(buffer.size)!.push(buffer);
+
+ if (this.buffersPending.length === 0) {
+ return;
+ }
+
+ if (this.backend.sessionStatus === 'default') {
+ for (const buffer of this.buffersPending) {
// eslint-disable-next-line no-bitwise
- } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
- // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
- this.freeUniformBuffers.get(buffer.size)!.push(buffer);
- } else {
- buffer.destroy();
+ if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) {
+ // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing.
+ this.freeBuffers.get(buffer.size)!.push(buffer);
+ // eslint-disable-next-line no-bitwise
+ } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) {
+ // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing.
+ this.freeUniformBuffers.get(buffer.size)!.push(buffer);
+ } else {
+ buffer.destroy();
+ }
+ }
+ this.buffersPending = [];
+ } else {
+ // Don't release intermediate tensors in non-default mode.
+ // TODO: reuse the storage buffers in non-default mode.
+ let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!);
+ if (!capturedBuffers) {
+ capturedBuffers = [];
+ this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers);
}
+ for (const buffer of this.buffersPending) {
+ capturedBuffers.push(buffer);
+ }
+ this.buffersPending = [];
}
- this.buffersPending = [];
}
dispose() {
@@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager {
storage.gpuData.buffer.destroy();
});
+ this.capturedPendingBuffers.forEach((buffers) => {
+ buffers.forEach(buffer => {
+ buffer.destroy();
+ });
+ });
this.storageCache = new Map();
this.freeBuffers = new Map();
this.freeUniformBuffers = new Map();
+ this.capturedPendingBuffers = new Map();
+ }
+
+ onReleaseSession(sessionId: number) {
+ // release the captured pending buffers.
+ const pendingBuffers = this.capturedPendingBuffers.get(sessionId);
+ if (pendingBuffers) {
+ pendingBuffers.forEach(buffer => {
+ buffer.destroy();
+ });
+ this.capturedPendingBuffers.delete(sessionId);
+ }
}
}
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index 8e1ec782079be..5627365100d9b 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
-import {attention, parseAttentionAttributes} from './ops/attention';
+import {attention} from './ops/attention';
import {batchNorm} from './ops/batch-norm';
import {biasAdd} from './ops/bias-add';
import {biasSplitGelu} from './ops/bias-split-gelu';
@@ -11,21 +11,25 @@ import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
import {cumsum, parseCumSumAttributes} from './ops/cumsum';
+import {depthToSpace, parseDepthToSpaceAttributes} from './ops/depth-to-space';
import {einsum, parseEinsumAttributes} from './ops/einsum';
import {expand} from './ops/expand';
+import {fastGelu} from './ops/fast-gelu';
import {gather, parseGatherAttributes} from './ops/gather';
import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements';
import {gemm, parseGemmAttributes} from './ops/gemm';
-import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
-import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
+import {instanceNorm} from './ops/instance-norm';
+import {layerNorm} from './ops/layer-norm';
import {matMul} from './ops/matmul';
+import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits';
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
-import {pad, parsePadAttributes} from './ops/pad';
+import {pad} from './ops/pad';
import * as pool from './ops/pool';
import {range} from './ops/range';
import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
-import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm';
+import {rotaryEmbedding} from './ops/rotary-embedding';
+import {skipLayerNorm} from './ops/skip-layer-norm';
import {parseSliceAttributes, slice} from './ops/slice';
import {parseSoftmaxAttributes, softmax} from './ops/softmax';
import {parseSplitAttributes, split} from './ops/split';
@@ -50,7 +54,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Asinh', [unaryOps.asinh]],
['Atan', [unaryOps.atan]],
['Atanh', [unaryOps.atanh]],
- ['Attention', [attention, parseAttentionAttributes]],
+ ['Attention', [attention]],
// TODO: support new attributes for AveragePool-10
['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
['BatchNormalization', [batchNorm]],
@@ -65,6 +69,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Cos', [unaryOps.cos]],
['Cosh', [unaryOps.cosh]],
['CumSum', [cumsum, parseCumSumAttributes]],
+ ['DepthToSpace', [depthToSpace, parseDepthToSpaceAttributes]],
['Div', [binaryOps.div]],
['Einsum', [einsum, parseEinsumAttributes]],
['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
@@ -72,6 +77,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Erf', [unaryOps.erf]],
['Exp', [unaryOps.exp]],
['Expand', [expand]],
+ ['FastGelu', [fastGelu]],
['Floor', [unaryOps.floor]],
['FusedConv', [conv, parseConvAttributes]],
['Gather', [gather, parseGatherAttributes]],
@@ -82,20 +88,22 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
- ['InstanceNormalization', [instanceNorm, parseInstanceNormAttributes]],
- ['LayerNormalization', [layerNorm, parseLayerNormAttributes]],
+ ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
+ ['InstanceNormalization', [instanceNorm]],
+ ['LayerNormalization', [layerNorm]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
['Less', [binaryOps.less]],
['LessOrEqual', [binaryOps.lessOrEqual]],
['Log', [unaryOps.log]],
['MatMul', [matMul]],
+ ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]],
// TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
['Mul', [binaryOps.mul]],
['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
['Neg', [unaryOps.neg]],
['Not', [unaryOps.not]],
- ['Pad', [pad, parsePadAttributes]],
+ ['Pad', [pad]],
['Pow', [binaryOps.pow]],
['Range', [range]],
['Reciprocal', [unaryOps.reciprocal]],
@@ -111,11 +119,12 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['ReduceSumSquare', [reduceSumSquare]],
['Relu', [unaryOps.relu]],
['Resize', [resize, parseResizeAttributes]],
+ ['RotaryEmbedding', [rotaryEmbedding]],
['Sigmoid', [unaryOps.sigmoid]],
['Sin', [unaryOps.sin]],
['Sinh', [unaryOps.sinh]],
['Slice', [slice, parseSliceAttributes]],
- ['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]],
+ ['SkipLayerNormalization', [skipLayerNorm]],
['Split', [split, parseSplitAttributes]],
['Sqrt', [unaryOps.sqrt]],
['Softmax', [softmax, parseSoftmaxAttributes]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
index 3638938df7dbe..24006d393592a 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts
@@ -19,12 +19,13 @@
//
// modified to fit the needs of the project
+import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
-import {ProgramInfo, ProgramUniform} from '../../types';
-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common';
+import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvAttributes} from '../conv';
-import {getActivationSnippet} from '../fuse-utils';
+import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
import {biasSnippet, typeSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
@@ -88,10 +89,10 @@ const conv2dCommonSnippet =
let outRow = ${row} / outWidth;
let outCol = ${row} % outWidth;
- let WRow = ${col} / (filterDims[1] * inChannels);
- let WCol = ${col} / inChannels % filterDims[1];
- let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0];
- let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1];
+ let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels);
+ let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]);
+ let xRow = outRow * uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0];
+ let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1];
let xCh = ${col} % inChannels;
var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0);
// The bounds checking is always needed since we use it to pad zero for
@@ -108,7 +109,7 @@ const conv2dCommonSnippet =
${readXSnippet}` :
`
let col = colIn * ${innerElementSizeX};
- if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
+ if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) :
@@ -117,7 +118,7 @@ const conv2dCommonSnippet =
${readXSnippet}` :
`
let col = colIn * ${innerElementSizeX};
- if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
+ if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`);
@@ -129,9 +130,8 @@ const conv2dCommonSnippet =
isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType);
const bType =
isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType);
- const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType);
+ const applyActivation = getActivationSnippet(attributes, resType, dataType);
const userCode = `
- ${activationFunction}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
${isChannelsLast ? sampleX : sampleW}
}
@@ -142,7 +142,7 @@ const conv2dCommonSnippet =
fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) {
let col = colIn * ${innerElementSize};
- if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
+ if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer)
{
var value = valueIn;
let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'};
@@ -181,31 +181,40 @@ export const createConv2DMatMulProgramInfo =
LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`);
const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1;
-
const tileAOuter = workGroupSize[1] * elementsPerThread[1];
const tileBOuter = workGroupSize[0] * elementsPerThread[0];
const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]);
-
const fitAOuter = dimAOuter % tileAOuter === 0;
const fitBOuter = dimBOuter % tileBOuter === 0;
const fitInner = dimInner % tileInner === 0;
-
const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1];
- const t = tensorTypeToWsglStorageType(inputs[0].dataType);
- // TODO: support component 2, 3.
- const components = isVec4 ? 4 : 1;
- const programUniforms: ProgramUniform[] =
- [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}];
- const x =
- inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
- const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
- const inputVariables = [x, w];
+ const programUniforms: ProgramUniform[] = [
+ {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter},
+ {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]},
+ {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations}
+ ];
+ appendActivationUniformsData(attributes, programUniforms);
+ programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims));
+ const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
+ if (hasBias) {
+ programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
+ inputDependencies.push('rank');
+ }
+ programUniforms.push(...createTensorShapeVariables(outputShape));
- programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
- programUniforms.push(...createTensorShapeVariables(inputs[1].dims));
+ const getShaderSource = (shaderHelper: ShaderHelper) => {
+ const uniforms: UniformsArrayType = [
+ {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'},
+ {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2},
+ {name: 'dilation', type: 'i32', length: 2}
+ ];
+ appendActivationUniforms(attributes, uniforms);
- let declareFunctions = `
+ // TODO: support component 2, 3.
+ const components = isVec4 ? 4 : 1;
+ const t = tensorTypeToWsglStorageType(inputs[0].dataType);
+ let declareFunctions = `
fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) {
result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value);
}
@@ -213,51 +222,50 @@ export const createConv2DMatMulProgramInfo =
let flatIndex = getOutputIndexFromCoords(vec4<i32>(d0, d1, d2, d3));
setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value);
}`;
- if (hasBias) {
- const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
- inputVariables.push(bias);
-
- programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
-
- declareFunctions += `
+ const x = inputVariable(
+ 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize);
+ const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components);
+ const inputVariables = [x, w];
+ const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
+ if (hasBias) {
+ const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
+ inputVariables.push(bias);
+ declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? `vec4<${t}>` : t} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
- }
- const output = outputVariable('result', inputs[0].dataType, outputShape.length, components);
- programUniforms.push(...createTensorShapeVariables(outputShape));
- return {
- name: 'Conv2DMatMul',
- shaderCache: {hint: attributes.cacheKey},
- getRunData: () => ({
- outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
- dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
- programUniforms,
- }),
- getShaderSource: (shaderHelper: ShaderHelper) => `
+ }
+
+ return `
${utilFunctions('uniforms.result_strides')}
//struct Uniforms { xShape : vec4<i32>, wShape : vec4<i32>, outShape : vec4<i32>,
// outShapeStrides: vec3<i32>, filterDims : vec2<i32>, pad : vec2<i32>, stride : vec2<i32>,
// dilation : vec2<i32>, dimAOuter : i32, dimBOuter : i32, dimInner : i32 };
- ${
- shaderHelper.registerUniform('dimAOuter', 'i32')
- .registerUniform('dimBOuter', 'i32')
- .registerUniform('dimInner', 'i32')
- .declareVariables(...inputVariables, output)}
- const filterDims : vec2<i32> = vec2<i32>(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]});
- const pad : vec2<i32> = vec2<i32>(${attributes.pads[0]}, ${attributes.pads[1]});
- const stride : vec2<i32> = vec2<i32>(${attributes.strides[0]}, ${attributes.strides[1]});
- const dilation : vec2<i32> = vec2<i32>(${attributes.dilations[0]}, ${attributes.dilations[1]});
+ ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}
${declareFunctions}
${
conv2dCommonSnippet(
isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1],
elementsSize[2], t)}
- ${
+ ${
isVec4 ?
makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined,
- sequentialAccessByThreads)}`
+ sequentialAccessByThreads)}`;
+ };
+ return {
+ name: 'Conv2DMatMul',
+ shaderCache: {
+ hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${
+ tileAOuter};${tileBOuter};${tileInner}`,
+ inputDependencies
+ },
+ getRunData: () => ({
+ outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
+ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]},
+ programUniforms,
+ }),
+ getShaderSource
};
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
index d425155857e14..080b24a2432aa 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts
@@ -19,20 +19,21 @@
//
// modified to fit the needs of the project
+import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
-import {ProgramInfo, ProgramUniform} from '../../types';
-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common';
+import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
-import {getActivationSnippet} from '../fuse-utils';
+import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';
-import {biasSnippet, typeSnippet} from './activation_util';
+import {biasSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';
const conv2dTransposeCommonSnippet =
- (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => {
- const type = typeSnippet(innerElementSize, 'f32');
+ (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string,
+ innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
@@ -46,7 +47,7 @@ const conv2dTransposeCommonSnippet =
let v1 = w[getIndexFromCoords4D(coord1, vec4