[C API]: Compile ONNX bindings with GPU support #9
Using this branch of the ONNX Rust bindings (https://github.com/radu-matei/onnxruntime-rs/tree/cuda), the following patch works with CUDA 10.2:
Patch:

From 4315f65e7fd816f0568f4f12ee58640a61e6610b Mon Sep 17 00:00:00 2001
From: Radu M <[email protected]>
Date: Sun, 27 Jun 2021 11:12:39 +0000
Subject: [PATCH] Trying to enable CUDA
---
Cargo.lock | 6 ++++--
crates/wasi-nn-onnx-wasmtime/Cargo.toml | 2 +-
crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs | 1 +
3 files changed, 6 insertions(+), 3 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 5fc7a4d..cf77155 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,5 +1,7 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
+version = 3
+
[[package]]
name = "addr2line"
version = "0.15.2"
@@ -1489,7 +1491,7 @@ checksum = "692fcb63b64b1758029e0a96ee63e049ce8c5948587f2f7208df04625e5f6b56"
[[package]]
name = "onnxruntime"
version = "0.0.12"
-source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=owned-session#5f47f47b24793c0d0fbb314e854cc04395b9108f"
+source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=cuda#2e5a8649def1d6cdcdd02018f9ae7c415d5f6c25"
dependencies = [
"lazy_static",
"ndarray",
@@ -1501,7 +1503,7 @@ dependencies = [
[[package]]
name = "onnxruntime-sys"
version = "0.0.12"
-source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=owned-session#5f47f47b24793c0d0fbb314e854cc04395b9108f"
+source = "git+https://github.com/radu-matei/onnxruntime-rs?branch=cuda#2e5a8649def1d6cdcdd02018f9ae7c415d5f6c25"
dependencies = [
"flate2",
"tar",
diff --git a/crates/wasi-nn-onnx-wasmtime/Cargo.toml b/crates/wasi-nn-onnx-wasmtime/Cargo.toml
index a307e12..8979c8d 100644
--- a/crates/wasi-nn-onnx-wasmtime/Cargo.toml
+++ b/crates/wasi-nn-onnx-wasmtime/Cargo.toml
@@ -9,7 +9,7 @@ anyhow = "1.0"
byteorder = "1.4"
log = { version = "0.4", default-features = false }
ndarray = "0.15"
-onnxruntime = { git = "https://github.com/radu-matei/onnxruntime-rs", branch = "owned-session", optional = true }
+onnxruntime = { git = "https://github.com/radu-matei/onnxruntime-rs", branch = "cuda", optional = true }
thiserror = "1.0"
tract-data = "0.14"
tract-linalg = "0.14"
diff --git a/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs b/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
index 1fd2d7e..71d4a16 100644
--- a/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
+++ b/crates/wasi-nn-onnx-wasmtime/src/onnx_runtime.rs
@@ -141,6 +141,7 @@ impl WasiEphemeralNn for WasiNnOnnxCtx {
.build()?;
let session = environment
.new_owned_session_builder()?
+ .use_cuda()?
.with_optimization_level(GraphOptimizationLevel::All)?
.with_model_from_memory(model_bytes)?;
let session = OnnxSession::with_session(session)?;
--
2.17.1
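The patch calls `.use_cuda()?` unconditionally, so every session is built with the CUDA provider. A minimal sketch of making that step opt-in follows. `SessionBuilder`, `GraphOptimizationLevel` and `BuildError` here are hypothetical stand-ins that only mimic the consuming-builder shape seen in the patch; they are not the fork's real types, and the `WASI_NN_ONNX_CUDA` environment variable is an invented switch.

```rust
use std::env;

// Hypothetical stand-ins: the real builder types live in
// radu-matei/onnxruntime-rs (branch `cuda`) and may have different signatures.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum GraphOptimizationLevel {
    All,
}

#[derive(Debug)]
struct BuildError(String);

#[derive(Debug, Default)]
struct SessionBuilder {
    cuda: bool,
    level: Option<GraphOptimizationLevel>,
}

impl SessionBuilder {
    // Mirrors the shape of `.use_cuda()?` in the patch: consumes and returns the builder.
    fn use_cuda(mut self) -> Result<Self, BuildError> {
        self.cuda = true;
        Ok(self)
    }

    fn with_optimization_level(mut self, level: GraphOptimizationLevel) -> Result<Self, BuildError> {
        self.level = Some(level);
        Ok(self)
    }
}

/// Enable CUDA only when requested, then apply the same optimization level as the patch.
fn configure(builder: SessionBuilder, want_gpu: bool) -> Result<SessionBuilder, BuildError> {
    let builder = if want_gpu { builder.use_cuda()? } else { builder };
    builder.with_optimization_level(GraphOptimizationLevel::All)
}

fn main() {
    // Hypothetical opt-in switch: WASI_NN_ONNX_CUDA=1 turns the GPU path on.
    let want_gpu = env::var("WASI_NN_ONNX_CUDA").map(|v| v == "1").unwrap_or(false);
    match configure(SessionBuilder::default(), want_gpu) {
        Ok(b) => println!("cuda={} level={:?}", b.cuda, b.level),
        Err(BuildError(msg)) => eprintln!("session builder failed: {}", msg),
    }
}
```

With this shape, CPU-only hosts never touch the CUDA provider, which may also make it easier to compare CPU and GPU timings for the same model on one machine.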
CUDA 10.2 might be hitting this issue: microsoft/onnxruntime#5957. In any case, performance on a Tesla P100 is significantly worse than expected, and I suspect it has to do with the CUDA version.
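Since the CUDA version is the main suspect, it may help to confirm which toolkit a machine actually has before comparing builds. A small sketch that parses the banner printed by `nvcc --version` (which contains a line such as `Cuda compilation tools, release 10.2, V10.2.89`) and compares it against 11.0; the helper name is hypothetical and the input is a hard-coded sample.

```rust
/// Extract (major, minor) from `nvcc --version` output, e.g.
/// "Cuda compilation tools, release 10.2, V10.2.89" -> Some((10, 2)).
fn parse_nvcc_release(output: &str) -> Option<(u32, u32)> {
    let line = output.lines().find(|l| l.contains("release "))?;
    let after = line.split("release ").nth(1)?;
    // Keep only the "10.2" part before the first comma.
    let version = after.split(',').next()?.trim();
    let mut parts = version.split('.');
    let major: u32 = parts.next()?.parse().ok()?;
    let minor: u32 = parts.next()?.parse().ok()?;
    Some((major, minor))
}

fn main() {
    let sample = "nvcc: NVIDIA (R) Cuda compiler driver\n\
                  Cuda compilation tools, release 10.2, V10.2.89";
    match parse_nvcc_release(sample) {
        Some(v) if v >= (11, 0) => println!("CUDA {}.{}: 11.x or newer", v.0, v.1),
        Some(v) => println!("CUDA {}.{}: older than 11.0", v.0, v.1),
        None => println!("could not parse nvcc output"),
    }
}
```

On a real host the input would come from running `nvcc` via `std::process::Command::new("nvcc").arg("--version").output()` instead of the sample string.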
For Windows, we should also try compiling with DirectML support: https://www.onnxruntime.ai/docs/reference/execution-providers/DirectML-ExecutionProvider.html
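If both CUDA and DirectML end up supported, the runtime needs some order in which to try them. One possible policy, sketched here as an assumption rather than anything the bindings implement: prefer CUDA when a (hypothetical) probe says it is available, add DirectML only on Windows targets, and always keep the CPU provider as the last fallback.

```rust
/// Execution providers discussed in this issue, plus the default CPU fallback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExecutionProvider {
    Cuda,
    DirectMl,
    Cpu,
}

/// Ordered preference list: GPU providers first, CPU always last.
/// `cuda_available` is a hypothetical probe result supplied by the caller.
fn preferred_providers(cuda_available: bool) -> Vec<ExecutionProvider> {
    let mut providers = Vec::new();
    if cuda_available {
        providers.push(ExecutionProvider::Cuda);
    }
    // DirectML targets Windows, so only offer it when compiling for Windows.
    if cfg!(target_os = "windows") {
        providers.push(ExecutionProvider::DirectMl);
    }
    providers.push(ExecutionProvider::Cpu);
    providers
}

fn main() {
    println!("{:?}", preferred_providers(true));
}
```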
I've created a PR, nbigaouette/onnxruntime-rs#87, with CUDA 11 support for ONNX 1.7, based on nbigaouette/onnxruntime-rs#78. I think it is what you're looking for to test. Feel free to review the branch and point out any key issues :)
This would add the CUDA and DirectML headers and pull the appropriate shared object.
See nbigaouette/onnxruntime-rs#57.
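Pulling "the appropriate shared object" mostly comes down to picking the right prebuilt archive and library file per platform. A sketch of that selection for CPU and CUDA builds only (DirectML is not covered). The archive naming scheme is an assumption modeled on ONNX Runtime's GitHub release assets (for example `onnxruntime-linux-x64-gpu-1.7.0.tgz`) and should be checked against the actual release page; both helper names are hypothetical.

```rust
/// Name of a prebuilt ONNX Runtime release archive (assumed naming scheme).
fn release_archive(os: &str, gpu: bool, version: &str) -> Option<String> {
    let (platform, ext) = match os {
        "linux" => ("linux-x64", "tgz"),
        "windows" => ("win-x64", "zip"),
        "macos" => ("osx-x64", "tgz"),
        _ => return None,
    };
    // Assumption: no CUDA builds are published for macOS.
    if gpu && os == "macos" {
        return None;
    }
    let flavor = if gpu { "-gpu" } else { "" };
    Some(format!("onnxruntime-{}{}-{}.{}", platform, flavor, version, ext))
}

/// Shared library file the build would need to find at link and run time.
fn shared_library(os: &str) -> Option<&'static str> {
    match os {
        "linux" => Some("libonnxruntime.so"),
        "windows" => Some("onnxruntime.dll"),
        "macos" => Some("libonnxruntime.dylib"),
        _ => None,
    }
}

fn main() {
    // std::env::consts::OS is "linux", "windows", "macos", ... for the current target.
    let os = std::env::consts::OS;
    println!("{:?} {:?}", release_archive(os, true, "1.7.0"), shared_library(os));
}
```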
As I am not currently on a CUDA-enabled machine, I'm labeling this as help wanted.