Skip to content

Commit

Permalink
Merge pull request #256 from yuzawa-san/arenas
Browse files Browse the repository at this point in the history
use Arenas
  • Loading branch information
yuzawa-san authored Nov 24, 2024
2 parents 2d4d186 + 3047375 commit d7909d8
Show file tree
Hide file tree
Showing 25 changed files with 169 additions and 204 deletions.
13 changes: 6 additions & 7 deletions src/main/java/com/jyuzawa/onnxruntime/ApiImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import java.lang.System.Logger.Level;
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SegmentAllocator;
import java.util.Collections;
import java.util.EnumSet;
import java.util.Set;
Expand Down Expand Up @@ -359,20 +358,20 @@ void checkStatus(MemorySegment status) {
}
}

MemorySegment create(SegmentAllocator allocator, Function<MemorySegment, MemorySegment> constructor) {
MemorySegment pointer = allocator.allocate(C_POINTER);
MemorySegment create(Arena arena, Function<MemorySegment, MemorySegment> constructor) {
MemorySegment pointer = arena.allocate(C_POINTER);
checkStatus(constructor.apply(pointer));
return pointer.getAtIndex(C_POINTER, 0);
}

int extractInt(SegmentAllocator allocator, Function<MemorySegment, MemorySegment> method) {
MemorySegment pointer = allocator.allocate(C_INT);
int extractInt(Arena arena, Function<MemorySegment, MemorySegment> method) {
MemorySegment pointer = arena.allocate(C_INT);
checkStatus(method.apply(pointer));
return pointer.getAtIndex(C_INT, 0);
}

long extractLong(SegmentAllocator allocator, Function<MemorySegment, MemorySegment> method) {
MemorySegment pointer = allocator.allocate(C_LONG);
long extractLong(Arena arena, Function<MemorySegment, MemorySegment> method) {
MemorySegment pointer = arena.allocate(C_LONG);
checkStatus(method.apply(pointer));
return pointer.getAtIndex(C_LONG, 0);
}
Expand Down
12 changes: 6 additions & 6 deletions src/main/java/com/jyuzawa/onnxruntime/EnvironmentImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ final class EnvironmentImpl extends ManagedImpl implements Environment {
MemorySegment threadingOptionsAddress = builder.newThreadingOptions(temporarySession);
try {
this.address = api.create(
memorySession,
arena,
out -> api.CreateEnvWithCustomLoggerAndGlobalThreadPools.apply(
OnnxRuntimeLoggingLevel.LOG_CALLBACK,
MemorySegment.NULL,
Expand All @@ -41,7 +41,7 @@ final class EnvironmentImpl extends ManagedImpl implements Environment {
}
} else {
this.address = api.create(
memorySession,
arena,
out -> api.CreateEnvWithCustomLogger.apply(
OnnxRuntimeLoggingLevel.LOG_CALLBACK,
MemorySegment.NULL,
Expand All @@ -51,8 +51,8 @@ final class EnvironmentImpl extends ManagedImpl implements Environment {
}
api.checkStatus(api.SetLanguageProjection.apply(address, ORT_PROJECTION_JAVA()));
this.memoryInfo = api.create(
memorySession, out -> api.CreateCpuMemoryInfo.apply(OrtArenaAllocator(), OrtMemTypeDefault(), out));
this.ortAllocator = api.create(memorySession, out -> api.GetAllocatorWithDefaultOptions.apply(out));
arena, out -> api.CreateCpuMemoryInfo.apply(OrtArenaAllocator(), OrtMemTypeDefault(), out));
this.ortAllocator = api.create(arena, out -> api.GetAllocatorWithDefaultOptions.apply(out));
Map<String, Long> arenaConfig = builder.arenaConfig;
if (arenaConfig == null) {
api.RegisterAllocator.apply(address, ortAllocator);
Expand Down Expand Up @@ -163,8 +163,8 @@ public Builder setArenaConfig(Map<String, Long> config) {
return this;
}

private MemorySegment newThreadingOptions(Arena memorySession) {
MemorySegment threadingOptions = api.create(memorySession, out -> api.CreateThreadingOptions.apply(out));
private MemorySegment newThreadingOptions(Arena arena) {
MemorySegment threadingOptions = api.create(arena, out -> api.CreateThreadingOptions.apply(out));
if (globalDenormalAsZero != null && globalSpinControl) {
api.checkStatus(api.SetGlobalDenormalAsZero.apply(threadingOptions));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ protected ExecutionProviderCPUConfig(Map<String, String> properties) {
}

@Override
void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegment sessionOptions) {
void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
// default is true
// https://github.com/microsoft/onnxruntime/blob/fb85b31facb9fb3fc99c76f99c93ea8f06ada39b/onnxruntime/core/providers/cpu/cpu_execution_provider.h#L14
String useArena = properties.getOrDefault(USE_ARENA, "1");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.SegmentAllocator;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiConsumer;
Expand Down Expand Up @@ -47,12 +46,9 @@ protected void copyLong(String key, MemorySegment config, BiConsumer<MemorySegme
}

protected void copyString(
String key,
MemorySegment config,
SegmentAllocator allocator,
BiConsumer<MemorySegment, MemorySegment> consumer) {
get(key).ifPresent(val -> consumer.accept(config, allocator.allocateFrom(val)));
String key, MemorySegment config, Arena arena, BiConsumer<MemorySegment, MemorySegment> consumer) {
get(key).ifPresent(val -> consumer.accept(config, arena.allocateFrom(val)));
}

abstract void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegment sessionOptions);
abstract void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions);
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ protected ExecutionProviderCoreMLConfig(Map<String, String> properties) {
}

@Override
final void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegment sessionOptions) {
final void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
int flags = 0;
if (TRUE_VALUE.equals(properties.get("use_cpu_only"))) {
flags |= onnxruntime_all_h.COREML_FLAG_USE_CPU_ONLY();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ protected ExecutionProviderMIGraphXConfig(Map<String, String> properties) {
}

@Override
final void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegment sessionOptions) {
MemorySegment config = OrtMIGraphXProviderOptions.allocate(memorySession);
final void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
MemorySegment config = OrtMIGraphXProviderOptions.allocate(arena);
copyInteger("device_id", config, OrtMIGraphXProviderOptions::device_id);
copyInteger("migraphx_fp16_enable", config, OrtMIGraphXProviderOptions::migraphx_fp16_enable);
copyInteger("migraphx_int8_enable", config, OrtMIGraphXProviderOptions::migraphx_int8_enable);
Expand All @@ -28,21 +28,13 @@ final void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegmen
copyString(
"migraphx_int8_calibration_table_name",
config,
memorySession,
arena,
OrtMIGraphXProviderOptions::migraphx_int8_calibration_table_name);
copyInteger("migraphx_save_compiled_model", config, OrtMIGraphXProviderOptions::migraphx_save_compiled_model);
copyString(
"migraphx_save_model_path",
config,
memorySession,
OrtMIGraphXProviderOptions::migraphx_save_model_path);
copyString("migraphx_save_model_path", config, arena, OrtMIGraphXProviderOptions::migraphx_save_model_path);

copyInteger("migraphx_load_compiled_model", config, OrtMIGraphXProviderOptions::migraphx_load_compiled_model);
copyString(
"migraphx_load_model_path",
config,
memorySession,
OrtMIGraphXProviderOptions::migraphx_load_model_path);
copyString("migraphx_load_model_path", config, arena, OrtMIGraphXProviderOptions::migraphx_load_model_path);
copyBoolean("migraphx_exhaustive_tune", config, OrtMIGraphXProviderOptions::migraphx_exhaustive_tune);
api.checkStatus(api.SessionOptionsAppendExecutionProvider_MIGraphX.apply(sessionOptions, config));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@ protected ExecutionProviderMapConfig(Map<String, String> properties) {
}

protected abstract void appendToSessionOptions(
Arena memorySession,
Arena arena,
ApiImpl api,
MemorySegment sessionOptions,
MemorySegment keys,
MemorySegment values,
int numProperties);

@Override
final void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegment sessionOptions) {
final void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
int numProps = properties.size();
MemorySegment keys = memorySession.allocate(C_POINTER, numProps);
MemorySegment values = memorySession.allocate(C_POINTER, numProps);
MemorySegment keys = arena.allocate(C_POINTER, numProps);
MemorySegment values = arena.allocate(C_POINTER, numProps);
int i = 0;
for (Map.Entry<String, String> entry : properties.entrySet()) {
keys.setAtIndex(C_POINTER, i, memorySession.allocateFrom(entry.getKey()));
values.setAtIndex(C_POINTER, i, memorySession.allocateFrom(entry.getValue()));
keys.setAtIndex(C_POINTER, i, arena.allocateFrom(entry.getKey()));
values.setAtIndex(C_POINTER, i, arena.allocateFrom(entry.getValue()));
i++;
}
appendToSessionOptions(memorySession, api, sessionOptions, keys, values, numProps);
appendToSessionOptions(arena, api, sessionOptions, keys, values, numProps);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ protected ExecutionProviderObjectConfig(Map<String, String> properties) {

@Override
protected final void appendToSessionOptions(
Arena memorySession,
Arena arena,
ApiImpl api,
MemorySegment sessionOptions,
MemorySegment keys,
MemorySegment values,
int numProperties) {
MemorySegment config = api.create(memorySession, out -> create(api, out));
MemorySegment config = api.create(arena, out -> create(api, out));
try {
api.checkStatus(update(api, config, keys, values, numProperties));
api.checkStatus(append(api, sessionOptions, config));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ protected ExecutionProviderOpenVINOConfig(Map<String, String> properties) {

@Override
protected void appendToSessionOptions(
Arena memorySession,
Arena arena,
ApiImpl api,
MemorySegment sessionOptions,
MemorySegment keys,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ protected ExecutionProviderROCMConfig(Map<String, String> properties) {
}

@Override
final void appendToSessionOptions(Arena memorySession, ApiImpl api, MemorySegment sessionOptions) {
MemorySegment config = OrtROCMProviderOptions.allocate(memorySession);
final void appendToSessionOptions(Arena arena, ApiImpl api, MemorySegment sessionOptions) {
MemorySegment config = OrtROCMProviderOptions.allocate(arena);
copyInteger("device_id", config, OrtROCMProviderOptions::device_id);
copyInteger("miopen_conv_exhaustive_search", config, OrtROCMProviderOptions::miopen_conv_exhaustive_search);
copyLong("gpu_mem_limit", config, OrtROCMProviderOptions::gpu_mem_limit);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ static final ExecutionProviderConfigFactory of(String name) {

@Override
protected void appendToSessionOptions(
Arena memorySession,
Arena arena,
ApiImpl api,
MemorySegment sessionOptions,
MemorySegment keys,
MemorySegment values,
int numProperties) {
api.checkStatus(api.SessionOptionsAppendExecutionProvider.apply(
sessionOptions, memorySession.allocateFrom(name), keys, values, numProperties));
sessionOptions, arena.allocateFrom(name), keys, values, numProperties));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ protected ExecutionProviderVitisAIConfig(Map<String, String> properties) {

@Override
protected void appendToSessionOptions(
Arena memorySession,
Arena arena,
ApiImpl api,
MemorySegment sessionOptions,
MemorySegment keys,
Expand Down
26 changes: 11 additions & 15 deletions src/main/java/com/jyuzawa/onnxruntime/IoBindingImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

final class IoBindingImpl implements IoBinding {
private final ApiImpl api;
private final Arena memorySession;
private final Arena arena;
private final MemorySegment ioBinding;
private final MemorySegment runOptions;
private final NamedCollectionImpl<OnnxValue> inputs;
Expand All @@ -23,38 +23,34 @@ final class IoBindingImpl implements IoBinding {

IoBindingImpl(Builder builder) {
// NOTE: this is shared since we want to allow closing from another thread.
this.memorySession = Arena.ofShared();
this.arena = Arena.ofShared();
this.api = builder.api;
this.session = builder.session.address();
this.ioBinding = builder.api.create(memorySession, out -> builder.api.CreateIoBinding.apply(session, out));
this.runOptions = api.create(memorySession, out -> api.CreateRunOptions.apply(out));
this.ioBinding = builder.api.create(arena, out -> builder.api.CreateIoBinding.apply(session, out));
this.runOptions = api.create(arena, out -> api.CreateRunOptions.apply(out));
Map<String, String> config = builder.config;
if (config != null && !config.isEmpty()) {
for (Map.Entry<String, String> entry : config.entrySet()) {
api.checkStatus(api.AddRunConfigEntry.apply(
runOptions,
memorySession.allocateFrom(entry.getKey()),
memorySession.allocateFrom(entry.getValue())));
runOptions, arena.allocateFrom(entry.getKey()), arena.allocateFrom(entry.getValue())));
}
}
List<NodeInfoImpl> rawInputs = builder.inputs;
List<NodeInfoImpl> rawOutputs = builder.outputs;
this.closeables = new ArrayList<>(rawInputs.size() + rawOutputs.size());
ValueContext valueContext = new ValueContext(
builder.api,
memorySession,
memorySession,
arena,
builder.session.environment.ortAllocator,
builder.session.environment.memoryInfo,
closeables);
this.inputs = add(rawInputs, valueContext, memorySession, api, ioBinding, true);
this.outputs = add(rawOutputs, valueContext, memorySession, api, ioBinding, false);
this.inputs = add(rawInputs, valueContext, api, ioBinding, true);
this.outputs = add(rawOutputs, valueContext, api, ioBinding, false);
}

private static final NamedCollectionImpl<OnnxValue> add(
List<NodeInfoImpl> nodes,
ValueContext valueContext,
Arena memorySession,
ApiImpl api,
MemorySegment ioBinding,
boolean isInput) {
Expand Down Expand Up @@ -82,7 +78,7 @@ public void close() {
for (Runnable closeable : closeables) {
closeable.run();
}
memorySession.close();
arena.close();
}

@Override
Expand Down Expand Up @@ -168,8 +164,8 @@ public IoBinding setLogVerbosityLevel(int level) {

@Override
public IoBinding setRunTag(String runTag) {
try (Arena allocator = Arena.ofConfined()) {
MemorySegment segment = allocator.allocateFrom(runTag);
try (Arena arena = Arena.ofConfined()) {
MemorySegment segment = arena.allocateFrom(runTag);
api.checkStatus(api.RunOptionsSetRunTag.apply(runOptions, segment));
}
return this;
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/com/jyuzawa/onnxruntime/ManagedImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
abstract class ManagedImpl implements AutoCloseable {

protected final ApiImpl api;
protected final Arena memorySession;
protected final Arena arena;

protected ManagedImpl(ApiImpl api, Arena memorySession) {
protected ManagedImpl(ApiImpl api, Arena arena) {
this.api = api;
this.memorySession = memorySession;
this.arena = arena;
}

@Override
public void close() {
memorySession.close();
arena.close();
}

abstract MemorySegment address();
Expand Down
Loading

0 comments on commit d7909d8

Please sign in to comment.