Skip to content

Commit

Permalink
Fix no default features flags + update cubecl (#2725)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Jan 21, 2025
1 parent 140ea75 commit b33bd24
Show file tree
Hide file tree
Showing 14 changed files with 132 additions and 53 deletions.
157 changes: 114 additions & 43 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ ratatui = "0.29.0"

# WGPU stuff
text_placeholder = "0.5.1"
wgpu = "23.0.0"
wgpu = "24.0.0"

# Benchmarks and Burnbench
arboard = "3.4.1"
Expand Down Expand Up @@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "2cc42af02671d90255ab823e29a4a3ad2e564333" }
cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "dd0c2fa86bc58133cde596e13ddb1d7ad33ac44c" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-autodiff/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ std = []
async = [] # Require std

[dependencies]
burn-common = { path = "../burn-common", version = "0.17.0" }
burn-common = { path = "../burn-common", version = "0.17.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true }

Expand Down
1 change: 1 addition & 0 deletions crates/burn-core/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![recursion_limit = "135"]

//! The core crate of Burn.
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-jit/src/kernel/conv/conv2d/im2col.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ fn im2col_kernel<F: Float>(

#[cfg(not(test))]
pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::PLANE_DIM_APPROX);
let cube_count_per_batch = (out_h * out_w).div_ceil(burn_common::PLANE_DIM_APPROX);
let max_cube_count = u16::MAX as usize;
let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
if max_simultaneous == 0 {
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-jit/src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ pub use mask::*;
pub(crate) use unary_float::*;
pub(crate) use unary_numeric::*;

pub use cubecl::{Kernel, PLANE_DIM_APPROX};
pub use burn_common::PLANE_DIM_APPROX;
pub use cubecl::Kernel;

/// Convolution kernels
pub mod conv;
Expand Down
3 changes: 2 additions & 1 deletion crates/burn-jit/src/template/base.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use cubecl::{prelude::*, Compiler, ExecutionMode, KernelId};
use burn_common::ExecutionMode;
use cubecl::{prelude::*, Compiler, KernelId};

use super::SourceTemplate;

Expand Down
2 changes: 1 addition & 1 deletion crates/burn-ndarray/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ blas-openblas-system = [

# ** Please make sure all dependencies support no_std when std is disabled **

burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", optional = true }
burn-autodiff = { path = "../burn-autodiff", version = "0.17.0", default-features = false, optional = true }
burn-common = { path = "../burn-common", version = "0.17.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.17.0", default-features = false, features = ["repr"] }

Expand Down
1 change: 1 addition & 0 deletions crates/burn-router/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
#![recursion_limit = "138"]

//! Burn multi-backend router.
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-tensor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std = [
[dependencies]
burn-common = { path = "../burn-common", version = "0.17.0", default-features = false }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.17.0", optional = true }
cubecl = { workspace = true, optional = true, default-features = true }
cubecl = { workspace = true, optional = true, default-features = false }

bytemuck = { workspace = true, features = ["extern_crate_alloc"] }
colored = { workspace = true, optional = true }
Expand Down
2 changes: 1 addition & 1 deletion crates/burn-wgpu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ pub use burn_jit::{
pub use burn_jit::{tensor::JitTensor, JitBackend};
pub use burn_jit::{BoolElement, FloatElement, IntElement};
pub use cubecl::flex32;
pub use cubecl::ir::CubeDim;
pub use cubecl::wgpu::*;
pub use cubecl::CubeDim;

pub type Wgsl = cubecl::wgpu::WgslCompiler;
#[cfg(feature = "spirv")]
Expand Down
1 change: 1 addition & 0 deletions examples/guide/src/bin/infer.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![recursion_limit = "131"]
use burn::{backend::Wgpu, data::dataset::Dataset};
use guide::inference;

Expand Down
1 change: 1 addition & 0 deletions examples/image-classification-web/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![cfg_attr(not(test), no_std)]
#![recursion_limit = "135"]

pub mod model;
pub mod web;
Expand Down
2 changes: 2 additions & 0 deletions examples/server/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![recursion_limit = "141"]

pub fn start() {
let port = std::env::var("REMOTE_BACKEND_PORT")
.map(|port| match port.parse::<u16>() {
Expand Down

0 comments on commit b33bd24

Please sign in to comment.