From bcfb8d527be0ae5aa41028d58a3735039223bc69 Mon Sep 17 00:00:00 2001 From: shahasad <43590019+shahasad@users.noreply.github.com> Date: Thu, 24 Oct 2019 20:39:16 -0700 Subject: [PATCH] Fix model path marshalling in csharp, and re-enable the pretrained model tests (#2236) --- .../src/Microsoft.ML.OnnxRuntime/InferenceSession.cs | 6 +----- csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs | 12 ++++++++++-- .../src/Microsoft.ML.OnnxRuntime/SessionOptions.cs | 4 ++-- .../Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs | 4 +++- 4 files changed, 16 insertions(+), 10 deletions(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs index 0310eb7db0cf1..068fae6f75392 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs @@ -215,11 +215,7 @@ private void Init(string modelPath, SessionOptions options) { var envHandle = OnnxRuntime.Handle; var session = IntPtr.Zero; - - if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out session)); - else - NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out session)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeMethods.GetPlatformSerializedString(modelPath), options.Handle, out session)); InitWithSessionHandle(session, options); } diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs index a7cb482d7dc43..18bcb1b50d39e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs @@ -364,10 +364,10 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca ExecutionMode execution_mode); public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode; - public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, [MarshalAs(UnmanagedType.LPWStr)]string optimizedModelFilepath); + public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath); public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath; - public delegate IntPtr /*(OrtStatus*)*/ DOrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix); + public delegate IntPtr /*(OrtStatus*)*/ DOrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, byte[] profilePathPrefix); public static DOrtEnableProfiling OrtEnableProfiling; public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableProfiling(IntPtr /* OrtSessionOptions* */ options); @@ -659,5 +659,13 @@ public enum MemoryType public static DOrtReleaseValue OrtReleaseValue; #endregion + + public static byte[] GetPlatformSerializedString(string str) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + return System.Text.Encoding.Unicode.GetBytes(str + Char.MinValue); + else + return System.Text.Encoding.UTF8.GetBytes(str + Char.MinValue); + } } //class NativeMethods } //namespace diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs index 1eb0984270ce0..5a34d2dca52f4 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs @@ -201,7 +201,7 @@ public bool EnableProfiling { if (!_enableProfiling && value) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableProfiling(_nativePtr, ProfileOutputPathPrefix)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableProfiling(_nativePtr, NativeMethods.GetPlatformSerializedString(ProfileOutputPathPrefix))); _enableProfiling = true; } else if (_enableProfiling && !value) @@ -226,7 +226,7 @@ public string OptimizedModelFilePath { if (value != _optimizedModelFilePath) { - NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, value)); + NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, NativeMethods.GetPlatformSerializedString(value))); _optimizedModelFilePath = value; } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs index 4072fd3f26f1e..c566e8853a924 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs @@ -344,6 +344,8 @@ private static Dictionary GetSkippedModels() if (System.Environment.Is64BitProcess == false) { skipModels["test_vgg19"] = "Get preallocated buffer for initializer conv4_4_b_0 failed"; + skipModels["tf_pnasnet_large"] = "Get preallocated buffer for initializer ConvBnFusion_BN_B_cell_5/comb_iter_1/left/bn_sep_7x7_1/beta:0_203 failed"; + skipModels["tf_nasnet_large"] = "Get preallocated buffer for initializer ConvBnFusion_BN_B_cell_11/beginning_bn/beta:0_331 failed"; } return skipModels; @@ -390,7 +392,7 @@ public static IEnumerable GetSkippedModelForTest() } - [Theory(Skip = "TestPreTrainedModels is flaky and is blocking CI build progress. Enable it once this is fixed.")] + [Theory] [MemberData(nameof(GetModelsForTest))] [MemberData(nameof(GetSkippedModelForTest), Skip = "Skipped due to Error, please fix the error and enable the test")] private void TestPreTrainedModels(string opset, string modelName)