Commit fe8d5ff

Merge branch 'main' of github.com:microsoft/onnxruntime into sajandhy…/webgpu_group_query_attention_update

satyajandhyala committed Sep 17, 2024
2 parents 2a57e70 + 291a535

Showing 93 changed files with 1,337 additions and 2,846 deletions.
4 changes: 2 additions & 2 deletions in cmake/onnxruntime_mlas.cmake
@@ -580,10 +580,10 @@ message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}")

if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11")
message(STATUS "Using -mavx2 -mfma -mavxvnni flags")
-  set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni")
+  set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni")
else()
message(STATUS "Using -mavx2 -mfma flags")
-  set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma")
+  set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c")
endif()
set(mlas_platform_srcs_avx512f
${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S
6 changes: 4 additions & 2 deletions in docs/ContribOperators.md
@@ -2521,6 +2521,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Only supports causal and local attention.
Supports rotary position embedding for CPU and CUDA.
Supports packed input for CPU and CUDA.
Supports continuous decoding for batch_size == 1 for CPU and CUDA.


#### Version

@@ -2561,9 +2563,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>past_value</tt> (optional) : T</dt>
<dd>past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>seqlens_k</tt> : M</dt>
- <dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
+ <dd>1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).</dd>
<dt><tt>total_sequence_length</tt> : M</dt>
- <dd>Scalar tensor of total sequence length (past + new).</dd>
+ <dd>Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. Used for checking inputs and determining prompt vs token generation case.</dd>
<dt><tt>cos_cache</tt> (optional) : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dt><tt>sin_cache</tt> (optional) : T</dt>
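Worked example (illustrative, not part of this diff): with batch_size = 1, 7 tokens already in the KV cache, and 1 new token, total_sequence_length is 8 and seqlens_k is [7]; a first prompt of 8 tokens with an empty cache also gives total_sequence_length = 8 and seqlens_k = [7], since seqlens_k is defined as total_sequence_lengths - 1 in both cases.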
2 changes: 1 addition & 1 deletion in docs/OperatorKernels.md
@@ -488,7 +488,7 @@ Do not modify directly.*
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
- |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
+ |MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(float16), tensor(uint8)<br/> **T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
49 changes: 48 additions & 1 deletion in java/src/main/java/ai/onnxruntime/OrtEnvironment.java
@@ -1,12 +1,13 @@
/*
- * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState;
import java.io.IOException;
+ import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.Objects;
import java.util.logging.Logger;
@@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption
return new OrtSession(this, modelPath, allocator, options);
}

/**
* Create a session using the specified {@link SessionOptions}, model and the default memory
* allocator.
*
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
* @param options The session options.
* @return An {@link OrtSession} with the specified model.
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
*/
public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options)
throws OrtException {
return createSession(modelBuffer, defaultAllocator, options);
}

/**
* Create a session using the default {@link SessionOptions}, model and the default memory
* allocator.
*
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
* @return An {@link OrtSession} with the specified model.
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
*/
public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException {
return createSession(modelBuffer, new OrtSession.SessionOptions());
}

/**
* Create a session using the specified {@link SessionOptions} and model buffer.
*
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
* @param allocator The memory allocator to use.
* @param options The session options.
* @return An {@link OrtSession} with the specified model.
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
*/
OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options)
throws OrtException {
Objects.requireNonNull(modelBuffer, "model array must not be null");
if (modelBuffer.remaining() == 0) {
throw new OrtException("Invalid model buffer, no elements remaining.");
} else if (!modelBuffer.isDirect()) {
throw new OrtException("ByteBuffer is not direct.");
}
return new OrtSession(this, modelBuffer, allocator, options);
}

/**
* Create a session using the specified {@link SessionOptions}, model and the default memory
* allocator.
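For orientation, a minimal usage sketch of the buffer-based createSession overload added above (not part of this commit): the class name and the model path "model.onnx" are hypothetical, and the model file is memory-mapped so the resulting buffer is direct, as the new overloads require.

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;

public final class ByteBufferSessionSketch {
  public static void main(String[] args) throws IOException, OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    // Memory-map the model file; a MappedByteBuffer is a direct buffer,
    // which the new createSession(ByteBuffer, ...) overloads require.
    try (RandomAccessFile file = new RandomAccessFile("model.onnx", "r");
        FileChannel channel = file.getChannel()) {
      MappedByteBuffer modelBuffer = channel.map(FileChannel.MapMode.READ_ONLY, 0, channel.size());
      try (OrtSession.SessionOptions options = new OrtSession.SessionOptions();
          OrtSession session = env.createSession(modelBuffer, options)) {
        System.out.println("Model inputs: " + session.getInputNames());
      }
    }
  }
}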
52 changes: 52 additions & 0 deletions in java/src/main/java/ai/onnxruntime/OrtSession.java
@@ -11,6 +11,7 @@
import ai.onnxruntime.providers.OrtFlags;
import ai.onnxruntime.providers.OrtTensorRTProviderOptions;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -94,6 +95,31 @@ public class OrtSession implements AutoCloseable {
allocator);
}

/**
* Creates a session reading the model from the supplied byte buffer.
*
* <p>Must be a direct byte buffer.
*
* @param env The environment.
* @param modelBuffer The model protobuf as a byte buffer.
* @param allocator The allocator to use.
* @param options Session configuration options.
* @throws OrtException If the model was corrupted or some other error occurred in native code.
*/
OrtSession(
OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options)
throws OrtException {
this(
createSession(
OnnxRuntime.ortApiHandle,
env.getNativeHandle(),
modelBuffer,
modelBuffer.position(),
modelBuffer.remaining(),
options.getNativeHandle()),
allocator);
}

/**
* Private constructor to build the Java object wrapped around a native session.
*
@@ -514,6 +540,15 @@ private static native long createSession(
private static native long createSession(
long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException;

private static native long createSession(
long apiHandle,
long envHandle,
ByteBuffer modelBuffer,
int bufferPos,
int bufferSize,
long optsHandle)
throws OrtException;

private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException;

private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle)
@@ -907,6 +942,20 @@ public void setSymbolicDimensionValue(String dimensionName, long dimensionValue)
OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue);
}

/**
* Set whether to use deterministic compute.
*
* <p>Default is false. If set to true, this will enable deterministic compute for GPU kernels
* where possible. Note that this most likely will have a performance cost.
*
* @param value Should the compute be deterministic?
* @throws OrtException If there was an error in native code.
*/
public void setDeterministicCompute(boolean value) throws OrtException {
checkClosed();
setDeterministicCompute(OnnxRuntime.ortApiHandle, nativeHandle, value);
}

/**
* Disables the per session thread pools. Must be used in conjunction with an environment
* containing global thread pools.
@@ -1292,6 +1341,9 @@ private native void registerCustomOpsUsingFunction(

private native void closeOptions(long apiHandle, long nativeHandle);

private native void setDeterministicCompute(
long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException;

private native void addFreeDimensionOverrideByName(
long apiHandle, long nativeHandle, String dimensionName, long dimensionValue)
throws OrtException;
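A rough sketch (not part of the diff) of how the new setDeterministicCompute option might be used; the class name and the model path "model.onnx" are placeholders.

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public final class DeterministicComputeSketch {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    try (OrtSession.SessionOptions options = new OrtSession.SessionOptions()) {
      // New in this change: request deterministic compute for GPU kernels
      // where possible; per the javadoc above, this most likely costs performance.
      options.setDeterministicCompute(true);
      try (OrtSession session = env.createSession("model.onnx", options)) {
        System.out.println("Inputs: " + session.getInputNames());
      }
    }
  }
}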
25 changes: 24 additions & 1 deletion in java/src/main/native/ai_onnxruntime_OrtSession.c
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
@@ -48,6 +48,29 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
return (jlong)session;
}

/*
* Class: ai_onnxruntime_OrtSession
* Method: createSession
* Signature: (JJLjava/nio/ByteBuffer;IIJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_nio_ByteBuffer_2IIJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong optsHandle) {
(void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtEnv* env = (OrtEnv*)envHandle;
OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle;
OrtSession* session = NULL;

// Extract the buffer
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
// Increment by bufferPos bytes
bufferArr = bufferArr + bufferPos;

// Create the session
checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, bufferArr, bufferSize, opts, &session));

return (jlong)session;
}

/*
* Class: ai_onnxruntime_OrtSession
* Method: createSession
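To illustrate how the JNI code above consumes the buffer (the Java side passes position() and remaining(), and the native side offsets the direct-buffer address by that position), here is a hypothetical sketch that copies model bytes into a freshly allocated direct buffer; the class name and "model.onnx" path are assumptions, not part of the commit.

import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;

public final class DirectBufferPositionSketch {
  public static void main(String[] args) throws IOException, OrtException {
    byte[] modelBytes = Files.readAllBytes(Paths.get("model.onnx"));
    // Heap ByteBuffers are rejected by the Java-side checks; copy into a direct buffer instead.
    ByteBuffer direct = ByteBuffer.allocateDirect(modelBytes.length);
    direct.put(modelBytes);
    direct.flip(); // position = 0, remaining = model size; the session reads exactly this window
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    try (OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        OrtSession session = env.createSession(direct, options)) {
      System.out.println("Loaded model with " + session.getNumInputs() + " input(s).");
    }
  }
}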
13 changes: 13 additions & 0 deletions in java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c
@@ -259,6 +259,19 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSes
checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel));
}

/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: setDeterministicCompute
* Signature: (JJZ)V
*/
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setDeterministicCompute
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean isDeterministic) {
(void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle;
checkOrtStatus(jniEnv,api,api->SetDeterministicCompute(options, isDeterministic));
}

/*
* Class: ai_onnxruntime_OrtSession_SessionOptions
* Method: registerCustomOpLibrary
32 changes: 32 additions & 0 deletions in java/src/test/java/ai/onnxruntime/InferenceTest.java
@@ -20,10 +20,14 @@
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
@@ -338,6 +342,33 @@ public void partialInputsTest() throws OrtException {
}
}

@Test
public void createSessionFromByteBuffer() throws IOException, OrtException {
Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx");
try (RandomAccessFile file = new RandomAccessFile(modelPath.toFile(), "r");
FileChannel channel = file.getChannel()) {
MappedByteBuffer modelBuffer = channel.map(MapMode.READ_ONLY, 0, channel.size());
try (OrtSession.SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelBuffer, options)) {
assertNotNull(session);
assertEquals(1, session.getNumInputs()); // 1 input node
Map<String, NodeInfo> inputInfoList = session.getInputInfo();
assertNotNull(inputInfoList);
assertEquals(1, inputInfoList.size());
NodeInfo input = inputInfoList.get("data_0");
assertEquals("data_0", input.getName()); // input node name
assertTrue(input.getInfo() instanceof TensorInfo);
TensorInfo inputInfo = (TensorInfo) input.getInfo();
assertEquals(OnnxJavaType.FLOAT, inputInfo.type);
int[] expectedInputDimensions = new int[] {1, 3, 224, 224};
assertEquals(expectedInputDimensions.length, inputInfo.shape.length);
for (int i = 0; i < expectedInputDimensions.length; i++) {
assertEquals(expectedInputDimensions[i], inputInfo.shape[i]);
}
}
}
}

@Test
public void createSessionFromByteArray() throws IOException, OrtException {
Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx");
@@ -1232,6 +1263,7 @@ public void testExtraSessionOptions() throws OrtException, IOException {
options.setLoggerId("monkeys");
options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL);
options.setSessionLogVerbosityLevel(5);
options.setDeterministicCompute(true);
Map<String, String> configEntries = options.getConfigEntries();
assertTrue(configEntries.isEmpty());
options.addConfigEntry("key", "value");
5 changes: 0 additions & 5 deletions in js/web/lib/backend-wasm-inference.ts

This file was deleted.

29 changes: 0 additions & 29 deletions in js/web/lib/backend-wasm-training.ts

This file was deleted.

2 changes: 2 additions & 0 deletions in js/web/lib/backend-wasm.ts
@@ -99,3 +99,5 @@ export class OnnxruntimeWebAssemblyBackend implements Backend {
return Promise.resolve(handler);
}
}

export const wasmBackend = new OnnxruntimeWebAssemblyBackend();
4 changes: 1 addition & 3 deletions in js/web/lib/index.ts
@@ -20,9 +20,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) {
}

if (!BUILD_DEFS.DISABLE_WASM) {
-  const wasmBackend = BUILD_DEFS.DISABLE_TRAINING
-    ? require('./backend-wasm-inference').wasmBackend
-    : require('./backend-wasm-training').wasmBackend;
+  const wasmBackend = require('./backend-wasm').wasmBackend;
if (!BUILD_DEFS.DISABLE_JSEP) {
registerBackend('webgpu', wasmBackend, 5);
registerBackend('webnn', wasmBackend, 5);