Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Java Bindings for Adapters API #1030

Merged
merged 32 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
dd81078
Init adapters java bindings
skyline75489 Nov 1, 2024
27c3876
public APIs
skyline75489 Nov 1, 2024
39207ea
Update
skyline75489 Nov 7, 2024
f33874f
Format
skyline75489 Nov 7, 2024
8be1a9d
Revert "Format"
skyline75489 Nov 7, 2024
5f5e484
[skip ci] spotless Apply
skyline75489 Nov 7, 2024
b6d6534
Run Java tests
skyline75489 Nov 7, 2024
cf7265a
Java tests
skyline75489 Nov 7, 2024
71df122
Update
skyline75489 Nov 7, 2024
e2f2e9a
[skip ci] Up
skyline75489 Nov 7, 2024
b633d19
Fix
skyline75489 Nov 7, 2024
084ec4d
More fix
skyline75489 Nov 7, 2024
19d59fa
[skip ci] shutdown properly
skyline75489 Nov 8, 2024
9956d1c
Fix on windows
skyline75489 Nov 8, 2024
5bffd43
More java fix
skyline75489 Nov 8, 2024
0b7c5bd
Update
skyline75489 Nov 8, 2024
e98439b
Fix
skyline75489 Nov 8, 2024
25a4caa
[skip ci] Fix adapters tests
skyline75489 Nov 11, 2024
c958763
Merge branch 'main' into jialli/adapters-jni
skyline75489 Nov 11, 2024
3061f6d
Format
skyline75489 Nov 11, 2024
14d4b30
Clean up
skyline75489 Nov 11, 2024
724249a
[skip ci] binary dir
skyline75489 Nov 11, 2024
c41d8a9
Fix build on win32
skyline75489 Nov 11, 2024
f25cc88
More fix
skyline75489 Nov 11, 2024
8f8bc3f
Windows ARM64
skyline75489 Nov 11, 2024
06557bd
Win ARM64 needs java21
skyline75489 Nov 11, 2024
e64a550
Fix build on macOS
skyline75489 Nov 11, 2024
13f15fb
[skip ci] resolve comments
skyline75489 Nov 18, 2024
75dc96a
Merge branch 'main' into jialli/adapters-jni
skyline75489 Nov 25, 2024
515dc29
[skip ci] comments
skyline75489 Nov 26, 2024
b8666b5
comments
skyline75489 Nov 28, 2024
8059f9e
Merge branch 'main' into jialli/adapters-jni
skyline75489 Nov 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions src/java/src/main/java/ai/onnxruntime/genai/Adapters.java
skyline75489 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.genai;


public final class Adapters implements AutoCloseable {
private long nativeHandle = 0;

/**
* Constructs an Adapters object with the given model.
*
* @param model The model.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public Adapters(Model model) throws GenAIException {
if (model.nativeHandle() == 0) {
throw new IllegalArgumentException("model has been freed and is invalid");
}

nativeHandle = createAdapters(model.nativeHandle());
}

/**
* Load an adapter from the specified path.
*
* @param adapterFilePath The path of the adapter.
* @param adapterName A unique user supplied adapter identifier.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public void loadAdapters(String adapterFilePath, String adapterName) throws GenAIException {
skyline75489 marked this conversation as resolved.
Show resolved Hide resolved
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}

loadAdapter(nativeHandle, adapterFilePath, adapterName);
}

/**
* Unload an adapter.
*
* @param adapterName A unique user supplied adapter identifier.
* @throws GenAIException If the call to the GenAI native API fails.
*/
public void unloadAdapters(String adapterName) throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}

unloadAdapter(nativeHandle, adapterName);
}

@Override
public void close() {
if (nativeHandle != 0) {
destroyAdapters(nativeHandle);
nativeHandle = 0;
}
}

long nativeHandle() {
return nativeHandle;
}

private native long createAdapters(long modelHandle)
throws GenAIException;

private native void destroyAdapters(long nativeHandle);

private native void loadAdapter(long nativeHandle, String adapterFilePath, String adapterName)
throws GenAIException;
skottmckay marked this conversation as resolved.
Show resolved Hide resolved

private native void unloadAdapter(long nativeHandle, String adapterName)
throws GenAIException;
}
12 changes: 12 additions & 0 deletions src/java/src/main/java/ai/onnxruntime/genai/Generator.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,15 @@ public int getLastTokenInSequence(long sequenceIndex) throws GenAIException {
return getSequenceLastToken(nativeHandle, sequenceIndex);
}

public void setActiveAdapter(Adapters adapters, String adapterName) throws GenAIException {
if (nativeHandle == 0) {
throw new IllegalStateException("Instance has been freed and is invalid");
}

setActiveAdapter(nativeHandle, adapters.nativeHandle(), adapterName);
}


/** Closes the Generator and releases any associated resources. */
@Override
public void close() {
Expand Down Expand Up @@ -169,4 +178,7 @@ private native int[] getSequenceNative(long nativeHandle, long sequenceIndex)

private native int getSequenceLastToken(long nativeHandle, long sequenceIndex)
throws GenAIException;

private native void setActiveAdapter(long nativeHandle, long adaptersNativeHandle, String adapterName)
throws GenAIException;
}
40 changes: 40 additions & 0 deletions src/java/src/main/native/ai_onnxruntime_genai_Adapters.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License.
*/
#include "ai_onnxruntime_genai_Adapters.h"

#include "ort_genai_c.h"
#include "utils.h"

using namespace Helpers;

JNIEXPORT jlong JNICALL
Java_ai_onnxruntime_genai_Adapters_createAdapters(JNIEnv* env, jobject thiz, jlong model_handle) {
const OgaModel* model = reinterpret_cast<const OgaModel*>(model_handle);
OgaAdapters* adapters = nullptr;
if (ThrowIfError(env, OgaCreateAdapters(model, &adapters))) {
return 0;
}

return reinterpret_cast<jlong>(adapters);
}

JNIEXPORT void JNICALL
Java_ai_onnxruntime_genai_Adapters_destroyAdapters(JNIEnv* env, jobject thiz, jlong native_handle) {
OgaDestroyAdapters(reinterpret_cast<OgaAdapters*>(native_handle));
}

JNIEXPORT void JNICALL
Java_ai_onnxruntime_genai_Adapters_loadAdapter(JNIEnv* env, jobject thiz, jlong native_handle,
jstring adapter_file_path, jstring adapter_name) {
CString file_path{env, adapter_file_path};
CString name{env, adapter_name};
ThrowIfError(env, OgaLoadAdapter(reinterpret_cast<OgaAdapters*>(native_handle), file_path, name));
}

JNIEXPORT void JNICALL
Java_ai_onnxruntime_genai_Adapters_unloadAdapter(JNIEnv* env, jobject thiz, jlong native_handle, jstring adapter_name) {
CString name{env, adapter_name};
ThrowIfError(env, OgaUnloadAdapter(reinterpret_cast<OgaAdapters*>(native_handle), name));
}
9 changes: 9 additions & 0 deletions src/java/src/main/native/ai_onnxruntime_genai_Generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,12 @@ Java_ai_onnxruntime_genai_Generator_getSequenceLastToken(JNIEnv* env, jobject th

return jint(tokens[num_tokens - 1]);
}

JNIEXPORT void JNICALL
Java_ai_onnxruntime_genai_Generator_setActiveAdapter(JNIEnv* env, jobject thiz, jlong native_handle,
jlong adapters_native_handle, jstring adapter_name) {
CString name{env, adapter_name};
ThrowIfError(env, OgaSetActiveAdapter(reinterpret_cast<OgaGenerator*>(native_handle),
reinterpret_cast<OgaAdapters*>(adapters_native_handle),
name));
}
4 changes: 4 additions & 0 deletions src/models/onnxruntime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o
#define PATH_MAX (4096)
#endif

#if !defined(__ANDROID__)
skyline75489 marked this conversation as resolved.
Show resolved Hide resolved

#define LOG_WHEN_ENABLED(LOG_FUNC) \
if (Generators::g_log.enabled && Generators::g_log.ort_lib) LOG_FUNC

Expand All @@ -110,6 +112,8 @@ p_session_->Run(nullptr, input_names, inputs, std::size(inputs), output_names, o
#define LOG_ERROR(...) LOG_WHEN_ENABLED(Generators::Log("error", __VA_ARGS__))
#define LOG_FATAL(...) LOG_WHEN_ENABLED(Generators::Log("fatal", __VA_ARGS__))

#endif

/** \brief Free functions and a few helpers are defined inside this namespace. Otherwise all types are the C API types
*
*/
Expand Down
Loading