Skip to content

Commit

Permalink
Fix model path marshalling in csharp, and re-enable the pretrained mo…
Browse files Browse the repository at this point in the history
…del tests (#2236)
  • Loading branch information
shahasad authored and snnn committed Oct 25, 2019
1 parent 90dea14 commit bcfb8d5
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
6 changes: 1 addition & 5 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,7 @@ private void Init(string modelPath, SessionOptions options)
{
var envHandle = OnnxRuntime.Handle;
var session = IntPtr.Zero;

if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out session));
else
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out session));
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeMethods.GetPlatformSerializedString(modelPath), options.Handle, out session));

InitWithSessionHandle(session, options);
}
Expand Down
12 changes: 10 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,10 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
ExecutionMode execution_mode);
public static DOrtSetSessionExecutionMode OrtSetSessionExecutionMode;

public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, [MarshalAs(UnmanagedType.LPWStr)]string optimizedModelFilepath);
public delegate IntPtr /*(OrtStatus*)*/ DOrtSetOptimizedModelFilePath(IntPtr /* OrtSessionOptions* */ options, byte[] optimizedModelFilepath);
public static DOrtSetOptimizedModelFilePath OrtSetOptimizedModelFilePath;

public delegate IntPtr /*(OrtStatus*)*/ DOrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, string profilePathPrefix);
public delegate IntPtr /*(OrtStatus*)*/ DOrtEnableProfiling(IntPtr /* OrtSessionOptions* */ options, byte[] profilePathPrefix);
public static DOrtEnableProfiling OrtEnableProfiling;

public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableProfiling(IntPtr /* OrtSessionOptions* */ options);
Expand Down Expand Up @@ -659,5 +659,13 @@ public enum MemoryType
public static DOrtReleaseValue OrtReleaseValue;

#endregion

public static byte[] GetPlatformSerializedString(string str)
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
return System.Text.Encoding.Unicode.GetBytes(str + Char.MinValue);
else
return System.Text.Encoding.UTF8.GetBytes(str + Char.MinValue);
}
} //class NativeMethods
} //namespace
4 changes: 2 additions & 2 deletions csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public bool EnableProfiling
{
if (!_enableProfiling && value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableProfiling(_nativePtr, ProfileOutputPathPrefix));
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableProfiling(_nativePtr, NativeMethods.GetPlatformSerializedString(ProfileOutputPathPrefix)));
_enableProfiling = true;
}
else if (_enableProfiling && !value)
Expand All @@ -226,7 +226,7 @@ public string OptimizedModelFilePath
{
if (value != _optimizedModelFilePath)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, value));
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(_nativePtr, NativeMethods.GetPlatformSerializedString(value)));
_optimizedModelFilePath = value;
}
}
Expand Down
4 changes: 3 additions & 1 deletion csharp/test/Microsoft.ML.OnnxRuntime.Tests/InferenceTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,8 @@ private static Dictionary<string, string> GetSkippedModels()
if (System.Environment.Is64BitProcess == false)
{
skipModels["test_vgg19"] = "Get preallocated buffer for initializer conv4_4_b_0 failed";
skipModels["tf_pnasnet_large"] = "Get preallocated buffer for initializer ConvBnFusion_BN_B_cell_5/comb_iter_1/left/bn_sep_7x7_1/beta:0_203 failed";
skipModels["tf_nasnet_large"] = "Get preallocated buffer for initializer ConvBnFusion_BN_B_cell_11/beginning_bn/beta:0_331 failed";
}

return skipModels;
Expand Down Expand Up @@ -390,7 +392,7 @@ public static IEnumerable<object[]> GetSkippedModelForTest()
}


[Theory(Skip = "TestPreTrainedModels is flaky and is blocking CI build progress. Enable it once this is fixed.")]
[Theory]
[MemberData(nameof(GetModelsForTest))]
[MemberData(nameof(GetSkippedModelForTest), Skip = "Skipped due to Error, please fix the error and enable the test")]
private void TestPreTrainedModels(string opset, string modelName)
Expand Down

0 comments on commit bcfb8d5

Please sign in to comment.