Skip to content

Commit

Permalink
Merge pull request #255 from yuzawa-san/ort/1.20.0
Browse files Browse the repository at this point in the history
ORT v1.20.0 bump
  • Loading branch information
yuzawa-san authored Nov 21, 2024
2 parents 8109360 + 7127ace commit 2d4d186
Show file tree
Hide file tree
Showing 15 changed files with 873 additions and 172 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ jobs:
strategy:
fail-fast: false
matrix:
# baseline version
java: [22]
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
include:
# latest version
- java: 23
os: ubuntu-latest
runs-on: ${{ matrix.os }}
name: Build on ${{ matrix.os }} on Java ${{ matrix.java }}
steps:
Expand Down
4 changes: 1 addition & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ allprojects {
apply plugin: "java"
apply plugin: "com.diffplug.spotless"
java {
toolchain {
languageVersion = JavaLanguageVersion.of(22)
}
sourceCompatibility = targetCompatibility = JavaLanguageVersion.of(22)
}
repositories {
mavenCentral()
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
version=2.0.0-SNAPSHOT
com.jyuzawa.onnxruntime.library_version=1.19.2
com.jyuzawa.onnxruntime.library_version=1.20.0
com.jyuzawa.onnxruntime.library_baseline=2.0.0
org.gradle.parallel=true
org.gradle.caching=true
Expand Down
2 changes: 1 addition & 1 deletion gradle/wrapper/gradle-wrapper.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
distributionBase=GRADLE_USER_HOME
distributionPath=wrapper/dists
distributionUrl=https\://services.gradle.org/distributions/gradle-8.10.1-bin.zip
distributionUrl=https\://services.gradle.org/distributions/gradle-8.11.1-bin.zip
networkTimeout=10000
validateDistributionUrl=true
zipStoreBase=GRADLE_USER_HOME
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ final class ApiImpl implements Api {
final SetLanguageProjection.Function SetLanguageProjection;
final SetOptimizedModelFilePath.Function SetOptimizedModelFilePath;
final SetDeterministicCompute.Function SetDeterministicCompute;
final SetEpDynamicOptions.Function SetEpDynamicOptions;
final SetSessionExecutionMode.Function SetSessionExecutionMode;
final SetSessionGraphOptimizationLevel.Function SetSessionGraphOptimizationLevel;
final SetSessionLogId.Function SetSessionLogId;
Expand Down Expand Up @@ -255,6 +256,7 @@ final class ApiImpl implements Api {
this.SetLanguageProjection = OrtApi.SetLanguageProjectionFunction(memorySegment);
this.SetOptimizedModelFilePath = OrtApi.SetOptimizedModelFilePathFunction(memorySegment);
this.SetDeterministicCompute = OrtApi.SetDeterministicComputeFunction(memorySegment);
this.SetEpDynamicOptions = OrtApi.SetEpDynamicOptionsFunction(memorySegment);
this.SetSessionExecutionMode = OrtApi.SetSessionExecutionModeFunction(memorySegment);
this.SetSessionGraphOptimizationLevel = OrtApi.SetSessionGraphOptimizationLevelFunction(memorySegment);
this.SetSessionLogId = OrtApi.SetSessionLogIdFunction(memorySegment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

abstract class ExecutionProviderConfig {

protected static final String TRUE_VALUE = "1";

protected final Map<String, String> properties;

protected ExecutionProviderConfig(Map<String, String> properties) {
Expand All @@ -36,6 +38,10 @@ protected void copyInteger(String key, MemorySegment config, BiConsumer<MemorySe
get(key).ifPresent(val -> consumer.accept(config, Integer.parseInt(val)));
}

protected void copyBoolean(String key, MemorySegment config, BiConsumer<MemorySegment, Boolean> consumer) {
get(key).ifPresent(val -> consumer.accept(config, TRUE_VALUE.equals(val) || Boolean.valueOf(val)));
}

protected void copyLong(String key, MemorySegment config, BiConsumer<MemorySegment, Long> consumer) {
get(key).ifPresent(val -> consumer.accept(config, Long.parseLong(val)));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@

final class ExecutionProviderCoreMLConfig extends ExecutionProviderConfig {

private static final String TRUE_VALUE = "1";

protected ExecutionProviderCoreMLConfig(Map<String, String> properties) {
super(properties);
}
Expand All @@ -32,6 +30,9 @@ final void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegmen
if (TRUE_VALUE.equals(properties.get("allow_static_input_shapes"))) {
flags |= onnxruntime_all_h.COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES();
}
if (TRUE_VALUE.equals(properties.get("use_cpu_and_cpu"))) {
flags |= onnxruntime_all_h.COREML_FLAG_USE_CPU_AND_GPU();
}
try {
api.checkStatus(onnxruntime_all_h.OrtSessionOptionsAppendExecutionProvider_CoreML(sessionOptions, flags));
} catch (UnsatisfiedLinkError e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ final void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegmen
config,
memorySession,
OrtMIGraphXProviderOptions::migraphx_load_model_path);
copyBoolean("migraphx_exhaustive_tune", config, OrtMIGraphXProviderOptions::migraphx_exhaustive_tune);
api.checkStatus(api.SessionOptionsAppendExecutionProvider_MIGraphX.apply(sessionOptions, config));
}
}
12 changes: 12 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,20 @@ public interface Session extends AutoCloseable {
*/
Transaction.Builder newTransaction();

/**
* Create a new I/O Binding.
* @return a builder
* @since v1.4.0
*/
IoBinding.Builder newIoBinding();

/**
* Set DynamicOptions for EPs (Execution Providers)
* @param epDynamicOptions
* @since 2.0.0
*/
void setEpDynamicOptions(Map<String, String> epDynamicOptions);

/**
* A builder of a {@link Session}. Must provide either bytes or a path.
*
Expand Down
17 changes: 17 additions & 0 deletions src/main/java/com/jyuzawa/onnxruntime/SessionImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_CHAR;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_LONG;
import static com.jyuzawa.onnxruntime_extern.onnxruntime_all_h.C_POINTER;
import static java.lang.foreign.ValueLayout.JAVA_BYTE;

import java.io.IOException;
Expand Down Expand Up @@ -225,6 +226,22 @@ public IoBinding.Builder newIoBinding() {
return new IoBindingImpl.Builder(this);
}

@Override
public void setEpDynamicOptions(Map<String, String> epDynamicOptions) {
try (Arena tmpArena = Arena.ofConfined()) {
int size = epDynamicOptions.size();
MemorySegment keyArray = tmpArena.allocate(C_POINTER, size);
MemorySegment valueArray = tmpArena.allocate(C_POINTER, size);
int i = 0;
for (Map.Entry<String, String> entry : epDynamicOptions.entrySet()) {
keyArray.setAtIndex(C_POINTER, i, tmpArena.allocateFrom(entry.getKey()));
valueArray.setAtIndex(C_POINTER, i, tmpArena.allocateFrom(entry.getValue()));
i++;
}
api.checkStatus(api.SetEpDynamicOptions.apply(address, keyArray, valueArray, size));
}
}

static final class Builder implements Session.Builder {
private final ApiImpl api;
private final EnvironmentImpl environment;
Expand Down
Loading

0 comments on commit 2d4d186

Please sign in to comment.