Skip to content

Commit

Permalink
Add pytorch backend, tests, examples
Browse files Browse the repository at this point in the history
  • Loading branch information
rahulchaphalkar committed Oct 8, 2024
1 parent 9bc918f commit 32d61b2
Show file tree
Hide file tree
Showing 16 changed files with 1,512 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, wit};

/// Classify a sample image with a PyTorch (TorchScript) model through the
/// wasi-nn `wit` bindings and check that the top class is 281 (`tabby cat`
/// per the example's README/ImageNet labels).
pub fn main() -> Result<()> {
    // Read the TorchScript module bytes; the fixture directory must be
    // preopened/mapped by the host for these reads to succeed.
    let model_bytes = fs::read("fixture/model.pt")
        .context("the model file to be mapped to the fixture directory")?;

    // Build a graph for the PyTorch backend, targeting the CPU.
    let graph = wit::load(
        &[model_bytes],
        wit::GraphEncoding::Pytorch,
        wit::ExecutionTarget::Cpu,
    )?;

    // Read the pre-serialized input tensor for the sample image.
    let input_bytes = fs::read("fixture/kitten.tensor")
        .context("the tensor file to be mapped to the fixture directory")?;

    // Run inference; the model's tensors are named "input" and "output".
    let raw_scores = wit::classify(graph, ("input", input_bytes), "output")?;

    // Normalize raw scores into probabilities, then rank them.
    let probabilities = softmax(raw_scores);
    let ranked = sort_results(&probabilities);
    let top_five = &ranked[..5];

    assert_eq!(top_five[0].class_id(), 281);
    println!("found results, sorted top 5: {top_five:?}");
    Ok(())
}

/// Turn a raw score vector into a probability distribution.
///
/// Subtracts the maximum score before exponentiating so `exp` cannot
/// overflow; the shift does not change the mathematical result. An empty
/// input yields an empty output.
fn softmax(output_tensor: Vec<f32>) -> Vec<f32> {
    // Largest element; NEG_INFINITY is the identity for an empty input.
    let mut peak = f32::NEG_INFINITY;
    for &score in &output_tensor {
        peak = f32::max(peak, score);
    }

    // Shifted exponentials, summed below to form the normalizer.
    let shifted_exps: Vec<f32> = output_tensor
        .iter()
        .map(|&score| (score - peak).exp())
        .collect();
    let total: f32 = shifted_exps.iter().sum();

    // Divide through so the probabilities sum to one.
    shifted_exps.into_iter().map(|e| e / total).collect()
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use anyhow::{Context, Result};
use std::fs;
use test_programs::nn::{sort_results, witx};

/// Classify a sample image with a PyTorch (TorchScript) model through the
/// wasi-nn `witx` bindings and check that the top class is 281 (`tabby cat`
/// per the example's README/ImageNet labels).
pub fn main() -> Result<()> {
    // Read the TorchScript module bytes; the fixture directory must be
    // preopened/mapped by the host for these reads to succeed.
    let model_bytes = fs::read("fixture/model.pt")
        .context("the model file to be mapped to the fixture directory")?;

    // Build a graph for the PyTorch backend, targeting the CPU.
    let graph = witx::load(
        &[&model_bytes],
        witx::GraphEncoding::Pytorch,
        witx::ExecutionTarget::CPU,
    )?;

    // Read the pre-serialized input tensor for the sample image.
    let input_bytes = fs::read("fixture/kitten.tensor")
        .context("the tensor file to be mapped to the fixture directory")?;

    // Run inference, then normalize raw scores into probabilities.
    let raw_scores = witx::classify(graph, input_bytes)?;
    let probabilities = softmax(raw_scores);

    // Rank the classes and keep the five most likely.
    let ranked = sort_results(&probabilities);
    let top_five = &ranked[..5];

    assert_eq!(top_five[0].class_id(), 281);
    println!("found results, sorted top 5: {top_five:?}");
    Ok(())
}

/// Convert raw model scores into a probability distribution via softmax.
///
/// Uses the max-subtraction trick for numerical stability: shifting every
/// score by the maximum keeps `exp` from overflowing without changing the
/// result. An empty input yields an empty output.
fn softmax(output_tensor: Vec<f32>) -> Vec<f32> {
    let highest = output_tensor
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);

    // Numerators of the softmax fraction: exp of the shifted scores.
    let numerators: Vec<f32> = output_tensor
        .iter()
        .map(|&v| (v - highest).exp())
        .collect();

    // Shared denominator: the sum of all numerators.
    let denominator: f32 = numerators.iter().sum();

    numerators.iter().map(|&n| n / denominator).collect()
}
5 changes: 4 additions & 1 deletion crates/wasi-nn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ ort = { version = "2.0.0-rc.2", default-features = false, features = [
"copy-dylibs",
"download-binaries",
], optional = true }
tch = { version = "0.17.0", default-features = false, optional = true }

[target.'cfg(windows)'.dependencies.windows]
version = "0.52"
Expand Down Expand Up @@ -69,7 +70,9 @@ openvino = ["dep:openvino"]
onnx = ["dep:ort"]
# WinML is only available on Windows 10 1809 and later.
winml = ["dep:windows"]
# PyTorch is available on all platforms; requires Libtorch to be installed
pytorch = ["dep:tch"]

[[test]]
name = "test-programs"
harness = false
harness = false
196 changes: 196 additions & 0 deletions crates/wasi-nn/examples/classification-example-pytorch/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions crates/wasi-nn/examples/classification-example-pytorch/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[package]
name = "wasi-nn-example-pytorch"
version = "0.0.0"
edition = "2021"
publish = false

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
wasi-nn = "0.6.0"
anyhow = "1.0.86"
image = { version = "0.25.2", default-features = false, features = ["png"] }

# This crate is built with the wasm32-wasip1 target, so it's separate
# from the main Wasmtime build, so use this directive to exclude it
# from the parent directory's workspace.
[workspace]
15 changes: 15 additions & 0 deletions crates/wasi-nn/examples/classification-example-pytorch/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
This example project demonstrates using the `wasi-nn` API to perform PyTorch based inference. It consists of Rust code that is built using the `wasm32-wasip1` target.

To run this example:
1. Ensure you set appropriate Libtorch environment variables according to [tch-rs instructions](https://github.com/LaurentMazare/tch-rs?tab=readme-ov-file#libtorch-manual-install).
- Requires the C++ PyTorch library (libtorch) in version *v2.4.0* to be available on
your system.
- `export LIBTORCH=/path/to/libtorch`
2. Build Wasmtime with `wasmtime-wasi-nn/pytorch` feature.
3. Navigate to this example directory `crates/wasi-nn/examples/classification-example-pytorch`.
4. Build this example `cargo build --target=wasm32-wasip1`.
5. Run the generated wasm file with wasmtime after mapping the directory containing Resnet18 `model.pt` and sample image `kitten.png`
```
${Wasmtime_root_dir}/target/debug/wasmtime -S nn --dir ${Wasmtime_root_dir}/crates/wasi-nn/examples/classification-example-pytorch::. ${Wasmtime_root_dir}/crates/wasi-nn/examples/classification-example-pytorch/target/wasm32-wasip1/debug/wasi-nn-example-pytorch.wasm
```
6. Check that result `281` has highest probability, which corresponds to `tabby cat`.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading

0 comments on commit 32d61b2

Please sign in to comment.