Skip to content

Commit

Permalink
Fix db-pedia-infer backend
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Jan 22, 2025
1 parent b33bd24 commit 79d6e19
Showing 1 changed file with 10 additions and 22 deletions.
32 changes: 10 additions & 22 deletions examples/text-classification/examples/db-pedia-infer.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use text_classification::DbPediaDataset;

use burn::tensor::backend::AutodiffBackend;
use burn::tensor::backend::Backend;

#[cfg(not(feature = "f16"))]
#[allow(dead_code)]
type ElemType = f32;
#[cfg(feature = "f16")]
type ElemType = burn::tensor::f16;

pub fn launch<B: AutodiffBackend>(device: B::Device) {
pub fn launch<B: Backend>(device: B::Device) {
text_classification::inference::infer::<B, DbPediaDataset>(
device,
"/tmp/text-classification-db-pedia",
Expand All @@ -34,24 +34,18 @@ pub fn launch<B: AutodiffBackend>(device: B::Device) {
feature = "ndarray-blas-accelerate",
))]
mod ndarray {
use burn::backend::{
ndarray::{NdArray, NdArrayDevice},
Autodiff,
};
use burn::backend::ndarray::{NdArray, NdArrayDevice};

use crate::{launch, ElemType};

pub fn run() {
launch::<Autodiff<NdArray<ElemType>>>(NdArrayDevice::Cpu);
launch::<NdArray<ElemType>>(NdArrayDevice::Cpu);
}
}

#[cfg(feature = "tch-gpu")]
mod tch_gpu {
use burn::backend::{
libtorch::{LibTorch, LibTorchDevice},
Autodiff,
};
use burn::backend::libtorch::{LibTorch, LibTorchDevice};

use crate::{launch, ElemType};

Expand All @@ -61,35 +55,29 @@ mod tch_gpu {
#[cfg(target_os = "macos")]
let device = LibTorchDevice::Mps;

launch::<Autodiff<LibTorch<ElemType>>>(device);
launch::<LibTorch<ElemType>>(device);
}
}

#[cfg(feature = "tch-cpu")]
mod tch_cpu {
use burn::backend::{
tch::{LibTorch, LibTorchDevice},
Autodiff,
};
use burn::backend::tch::{LibTorch, LibTorchDevice};

use crate::{launch, ElemType};

pub fn run() {
launch::<Autodiff<LibTorch<ElemType>>>(LibTorchDevice::Cpu);
launch::<LibTorch<ElemType>>(LibTorchDevice::Cpu);
}
}

#[cfg(feature = "wgpu")]
mod wgpu {
use burn::backend::{
wgpu::{Wgpu, WgpuDevice},
Autodiff,
};
use burn::backend::wgpu::{Wgpu, WgpuDevice};

use crate::{launch, ElemType};

pub fn run() {
launch::<Autodiff<Wgpu<ElemType, i32>>>(WgpuDevice::default());
launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());
}
}

Expand Down

0 comments on commit 79d6e19

Please sign in to comment.