diff --git a/dfdx-core/src/tensor/cuda/device.rs b/dfdx-core/src/tensor/cuda/device.rs index 156bdddd..de6f7196 100644 --- a/dfdx-core/src/tensor/cuda/device.rs +++ b/dfdx-core/src/tensor/cuda/device.rs @@ -76,7 +76,7 @@ impl Cuda { #[cfg(feature = "cudnn")] let cudnn = cudarc::cudnn::Cudnn::new(dev.clone())?; let par_stream = Arc::new(dev.fork_default_stream()?); - let workspace = Arc::new(Mutex::new(dev.alloc_zeros::(0)?)); + let workspace = Arc::new(Mutex::new(dev.alloc_zeros::(1)?)); Ok(Self { cpu, dev,