Skip to content

Commit

Permalink
[java] Adding ability to load a model from a memory mapped byte buffer (
Browse files Browse the repository at this point in the history
#20062)

### Description
Adds support for constructing an `OrtSession` from a
`java.nio.ByteBuffer`. These buffers can be memory mapped from files
which means there doesn't need to be copies of the model protobuf held
in Java, reducing peak memory usage during session construction.

### Motivation and Context
Reduces memory usage on model construction by not requiring as many
copies on the Java side. Should help with #19599.
  • Loading branch information
Craigacp authored Sep 15, 2024
1 parent c63dd02 commit 02e00dc
Show file tree
Hide file tree
Showing 4 changed files with 138 additions and 2 deletions.
49 changes: 48 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtEnvironment.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
/*
* Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;

import ai.onnxruntime.OrtSession.SessionOptions;
import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.EnumSet;
import java.util.Objects;
import java.util.logging.Logger;
Expand Down Expand Up @@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption
return new OrtSession(this, modelPath, allocator, options);
}

/**
* Create a session using the specified {@link SessionOptions}, model and the default memory
* allocator.
*
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
* @param options The session options.
* @return An {@link OrtSession} with the specified model.
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
*/
public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options)
throws OrtException {
return createSession(modelBuffer, defaultAllocator, options);
}

/**
* Create a session using the default {@link SessionOptions}, model and the default memory
* allocator.
*
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
* @return An {@link OrtSession} with the specified model.
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
*/
public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException {
return createSession(modelBuffer, new OrtSession.SessionOptions());
}

/**
* Create a session using the specified {@link SessionOptions} and model buffer.
*
* @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer.
* @param allocator The memory allocator to use.
* @param options The session options.
* @return An {@link OrtSession} with the specified model.
* @throws OrtException If the model failed to parse, wasn't compatible or caused an error.
*/
OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options)
throws OrtException {
Objects.requireNonNull(modelBuffer, "model array must not be null");
if (modelBuffer.remaining() == 0) {
throw new OrtException("Invalid model buffer, no elements remaining.");
} else if (!modelBuffer.isDirect()) {
throw new OrtException("ByteBuffer is not direct.");
}
return new OrtSession(this, modelBuffer, allocator, options);
}

/**
* Create a session using the specified {@link SessionOptions}, model and the default memory
* allocator.
Expand Down
35 changes: 35 additions & 0 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import ai.onnxruntime.providers.OrtFlags;
import ai.onnxruntime.providers.OrtTensorRTProviderOptions;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -94,6 +95,31 @@ public class OrtSession implements AutoCloseable {
allocator);
}

/**
* Creates a session reading the model from the supplied byte buffer.
*
* <p>Must be a direct byte buffer.
*
* @param env The environment.
* @param modelBuffer The model protobuf as a byte buffer.
* @param allocator The allocator to use.
* @param options Session configuration options.
* @throws OrtException If the model was corrupted or some other error occurred in native code.
*/
OrtSession(
OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options)
throws OrtException {
this(
createSession(
OnnxRuntime.ortApiHandle,
env.getNativeHandle(),
modelBuffer,
modelBuffer.position(),
modelBuffer.remaining(),
options.getNativeHandle()),
allocator);
}

/**
* Private constructor to build the Java object wrapped around a native session.
*
Expand Down Expand Up @@ -514,6 +540,15 @@ private static native long createSession(
private static native long createSession(
long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException;

private static native long createSession(
long apiHandle,
long envHandle,
ByteBuffer modelBuffer,
int bufferPos,
int bufferSize,
long optsHandle)
throws OrtException;

private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException;

private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle)
Expand Down
25 changes: 24 additions & 1 deletion java/src/main/native/ai_onnxruntime_OrtSession.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
#include <jni.h>
Expand Down Expand Up @@ -48,6 +48,29 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la
return (jlong)session;
}

/*
* Class: ai_onnxruntime_OrtSession
* Method: createSession
* Signature: (JJLjava/nio/ByteBuffer;IIJ)J
*/
JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_nio_ByteBuffer_2IIJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong optsHandle) {
(void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object.
const OrtApi* api = (const OrtApi*)apiHandle;
OrtEnv* env = (OrtEnv*)envHandle;
OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle;
OrtSession* session = NULL;

// Extract the buffer
char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer);
// Increment by bufferPos bytes
bufferArr = bufferArr + bufferPos;

// Create the session
checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, bufferArr, bufferSize, opts, &session));

return (jlong)session;
}

/*
* Class: ai_onnxruntime_OrtSession
* Method: createSession
Expand Down
31 changes: 31 additions & 0 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
import ai.onnxruntime.OrtSession.SessionOptions.OptLevel;
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.channels.FileChannel.MapMode;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
Expand Down Expand Up @@ -338,6 +342,33 @@ public void partialInputsTest() throws OrtException {
}
}

@Test
public void createSessionFromByteBuffer() throws IOException, OrtException {
Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx");
try (RandomAccessFile file = new RandomAccessFile(modelPath.toFile(), "r");
FileChannel channel = file.getChannel()) {
MappedByteBuffer modelBuffer = channel.map(MapMode.READ_ONLY, 0, channel.size());
try (OrtSession.SessionOptions options = new SessionOptions();
OrtSession session = env.createSession(modelBuffer, options)) {
assertNotNull(session);
assertEquals(1, session.getNumInputs()); // 1 input node
Map<String, NodeInfo> inputInfoList = session.getInputInfo();
assertNotNull(inputInfoList);
assertEquals(1, inputInfoList.size());
NodeInfo input = inputInfoList.get("data_0");
assertEquals("data_0", input.getName()); // input node name
assertTrue(input.getInfo() instanceof TensorInfo);
TensorInfo inputInfo = (TensorInfo) input.getInfo();
assertEquals(OnnxJavaType.FLOAT, inputInfo.type);
int[] expectedInputDimensions = new int[] {1, 3, 224, 224};
assertEquals(expectedInputDimensions.length, inputInfo.shape.length);
for (int i = 0; i < expectedInputDimensions.length; i++) {
assertEquals(expectedInputDimensions[i], inputInfo.shape[i]);
}
}
}
}

@Test
public void createSessionFromByteArray() throws IOException, OrtException {
Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx");
Expand Down

0 comments on commit 02e00dc

Please sign in to comment.