Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup csharp API SessionOptions and RunOptions to be consistent with other APIs #1570

Merged
merged 11 commits into from
Aug 14, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ static void UseApi()

// Optional : Create session options and set the graph optimization level for the session
SessionOptions options = new SessionOptions();
options.SetSessionGraphOptimizationLevel(2);
options.GraphOptimizationLevel = 2;

using (var session = new InferenceSession(modelPath, options))
{
Expand Down
154 changes: 83 additions & 71 deletions csharp/src/Microsoft.ML.OnnxRuntime/InferenceSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public class InferenceSession : IDisposable
{
protected IntPtr _nativeHandle;
protected Dictionary<string, NodeMetadata> _inputMetadata, _outputMetadata;
private SessionOptions _builtInSessionOptions = null;
private RunOptions _builtInRunOptions = null;


#region Public API
Expand All @@ -28,63 +30,26 @@ public class InferenceSession : IDisposable
/// </summary>
/// <param name="modelPath"></param>
public InferenceSession(string modelPath)
: this(modelPath, SessionOptions.Default)
{
_builtInSessionOptions = new SessionOptions(); // need to be disposed
Init(modelPath, _builtInSessionOptions);
}


/// <summary>
/// Constructs an InferenceSession from a model file, with some additional session options
/// </summary>
/// <param name="modelPath"></param>
/// <param name="options"></param>
public InferenceSession(string modelPath, SessionOptions options)
{
var envHandle = OnnxRuntime.Handle;

_nativeHandle = IntPtr.Zero;
try
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options._nativePtr, out _nativeHandle));
else
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options._nativePtr, out _nativeHandle));

// Initialize input/output metadata
_inputMetadata = new Dictionary<string, NodeMetadata>();
_outputMetadata = new Dictionary<string, NodeMetadata>();

// get input count
UIntPtr inputCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount));

// get all the output names
for (ulong i = 0; i < (ulong)inputCount; i++)
{
var iname = GetInputName(i);
_inputMetadata[iname] = GetInputMetadata(i);
}
// get output count
UIntPtr outputCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount));

// get all the output names
for (ulong i = 0; i < (ulong)outputCount; i++)
{
_outputMetadata[GetOutputName(i)] = GetOutputMetadata(i);
}

}
catch (OnnxRuntimeException e)
{
if (_nativeHandle != IntPtr.Zero)
{
NativeMethods.OrtReleaseSession(_nativeHandle);
_nativeHandle = IntPtr.Zero;
}
throw e;
}
Init(modelPath, options);
}


/// <summary>
/// Meta data regarding the input nodes, keyed by input names
/// </summary>
public IReadOnlyDictionary<string, NodeMetadata> InputMetadata
{
get
Expand All @@ -93,6 +58,9 @@ public IReadOnlyDictionary<string, NodeMetadata> InputMetadata
}
}

/// <summary>
/// Metadata regarding the output nodes, keyed by output names
/// </summary>
public IReadOnlyDictionary<string, NodeMetadata> OutputMetadata
{
get
Expand All @@ -101,11 +69,12 @@ public IReadOnlyDictionary<string, NodeMetadata> OutputMetadata
}
}


/// <summary>
/// Runs the loaded model for the given inputs, and fetches all the outputs.
/// </summary>
/// <param name="inputs"></param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue</returns>
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyCollection<NamedOnnxValue> inputs)
{
string[] outputNames = new string[_outputMetadata.Count];
Expand All @@ -118,21 +87,22 @@ public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyColl
/// </summary>
/// <param name="inputs"></param>
/// <param name="outputNames"></param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue</returns>
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyCollection<NamedOnnxValue> inputs, IReadOnlyCollection<string> outputNames)
{
return Run(inputs, outputNames, RunOptions.Default);
IDisposableReadOnlyCollection<DisposableNamedOnnxValue> result = null;
result = Run(inputs, outputNames, _builtInRunOptions);
return result;
}

/// <summary>
/// Runs the loaded model for the given inputs, and fetches the specified outputs in <paramref name="outputNames"/>.
/// Runs the loaded model for the given inputs, and fetches the specified outputs in <paramref name="outputNames". Uses the given RunOptions for this run./>.
/// </summary>
/// <param name="inputs"></param>
/// <param name="outputNames"></param>
/// <param name="options"></param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue</returns>
//TODO: kept internal until RunOptions is made public
internal IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyCollection<NamedOnnxValue> inputs, IReadOnlyCollection<string> outputNames, RunOptions options)
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> Run(IReadOnlyCollection<NamedOnnxValue> inputs, IReadOnlyCollection<string> outputNames, RunOptions options)
{
var inputNames = new string[inputs.Count];
var inputTensors = new IntPtr[inputs.Count];
Expand Down Expand Up @@ -211,6 +181,58 @@ internal ModelMetadata ModelMetadata
#endregion

#region private methods

protected void Init(string modelPath, SessionOptions options)
{
var envHandle = OnnxRuntime.Handle;

_nativeHandle = IntPtr.Zero;
try
{
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.Unicode.GetBytes(modelPath), options.Handle, out _nativeHandle));
else
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, System.Text.Encoding.UTF8.GetBytes(modelPath), options.Handle, out _nativeHandle));

// Initialize input/output metadata
_inputMetadata = new Dictionary<string, NodeMetadata>();
_outputMetadata = new Dictionary<string, NodeMetadata>();

// get input count
UIntPtr inputCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount));

// get all the output names
for (ulong i = 0; i < (ulong)inputCount; i++)
{
var iname = GetInputName(i);
_inputMetadata[iname] = GetInputMetadata(i);
}
// get output count
UIntPtr outputCount = UIntPtr.Zero;
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount));

// get all the output names
for (ulong i = 0; i < (ulong)outputCount; i++)
{
_outputMetadata[GetOutputName(i)] = GetOutputMetadata(i);
}

}
catch (OnnxRuntimeException e)
{
if (_nativeHandle != IntPtr.Zero)
{
NativeMethods.OrtReleaseSession(_nativeHandle);
_nativeHandle = IntPtr.Zero;
}
throw e;
}

_builtInRunOptions = new RunOptions(); // create a default built-in run option, and avoid creating a new one every run() call
}


private string GetOutputName(ulong index)
{
IntPtr nameHandle = IntPtr.Zero;
Expand Down Expand Up @@ -358,6 +380,15 @@ protected virtual void Dispose(bool disposing)
if (disposing)
{
// cleanup managed resources
if (_builtInSessionOptions != null)
{
_builtInSessionOptions.Dispose();
}

if (_builtInRunOptions != null)
{
_builtInRunOptions.Dispose();
}
}

// cleanup unmanaged resources
Expand Down Expand Up @@ -426,24 +457,5 @@ internal class ModelMetadata
//TODO: placeholder for Model metadata. Currently C-API does not expose this.
}

/// Sets various runtime options.
/// TODO: currently uses Default options only. kept internal until fully implemented
internal class RunOptions
{
protected static readonly Lazy<RunOptions> _default = new Lazy<RunOptions>(() => new RunOptions());

public static RunOptions Default
{
get
{
return _default.Value;
}
}

private void RuntOptions()
{

}
}

}
31 changes: 31 additions & 0 deletions csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,37 @@ IntPtr[] outputValues /* An array of output value pointers. Array must be alloca
//[DllImport(nativeLib, CharSet = charSet)]
//public static extern void OrtAddCustomOp(IntPtr /*(OrtSessionOptions*)*/ options, string custom_op_path);

#endregion

#region RunOptions API
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtCreateRunOptions( out IntPtr /* OrtRunOptions** */ runOptions);

[DllImport(nativeLib, CharSet = charSet)]
public static extern void OrtReleaseRunOptions(IntPtr /*(OrtRunOptions*)*/options);

[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsSetRunLogVerbosityLevel(IntPtr /* OrtRunOptions* */ options, LogLevel value);

[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsSetRunTag(IntPtr /* OrtRunOptions* */ options, string /* const char* */ runTag);

[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsGetRunLogVerbosityLevel(IntPtr /* OrtRunOptions* */ options, out LogLevel verbosityLevel);

[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsGetRunTag(IntPtr /* const OrtRunOptions* */options, out IntPtr /* const char** */ runtag);

// Set a flag so that any running OrtRun* calls that are using this instance of OrtRunOptions
// will exit as soon as possible if the flag is true.
[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsEnableTerminate(IntPtr /* OrtRunOptions* */ options);

[DllImport(nativeLib, CharSet = charSet)]
public static extern IntPtr /*(OrtStatus*)*/ OrtRunOptionsDisableTerminate(IntPtr /* OrtRunOptions* */ options);



#endregion

#region Allocator/AllocatorInfo API
Expand Down
32 changes: 31 additions & 1 deletion csharp/src/Microsoft.ML.OnnxRuntime/OnnxRuntime.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ internal struct GlobalOptions //Options are currently not accessible to user
public LogLevel LogLevel { get; set; }
}

internal enum LogLevel
public enum LogLevel
{
Verbose = 0,
Info = 1,
Expand Down Expand Up @@ -51,6 +51,9 @@ public override bool IsInvalid
private OnnxRuntime() //Problem: it is not possible to pass any option for a Singleton
:base(IntPtr.Zero, true)
{
// Check LibC version on Linux, before doing any onnxruntime initialization
CheckLibcVersionGreaterThanMinimum();

handle = IntPtr.Zero;
try
{
Expand Down Expand Up @@ -78,5 +81,32 @@ protected override bool ReleaseHandle()
Delete(handle);
return true;
}

[DllImport("libc", ExactSpelling = true, CallingConvention = CallingConvention.Cdecl)]
private static extern IntPtr gnu_get_libc_version();

private static void CheckLibcVersionGreaterThanMinimum()
{
// require libc version 2.23 or higher
var minVersion = new Version(2, 23);
var curVersion = new Version(0, 0);
if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
try
{
curVersion = Version.Parse(Marshal.PtrToStringAnsi(gnu_get_libc_version()));
if (curVersion >= minVersion)
return;
}
catch (Exception)
{
// trap any obscure exception
}
throw new OnnxRuntimeException(ErrorCode.RuntimeException,
$"libc.so version={curVersion} does not meet the minimun of 2.23 required by OnnxRuntime. " +
"Linux distribution should be similar to Ubuntu 16.04 or higher");
}
}

}
}
Loading