From 2b10aaa05d3752186899bd5b5364d92164edc7ef Mon Sep 17 00:00:00 2001 From: shua Date: Wed, 12 Jun 2024 08:15:32 +0200 Subject: [PATCH 1/7] implement Slice op (#2260) --- candle-onnx/src/eval.rs | 80 +++++++++++++++++++++++ candle-onnx/tests/ops.rs | 135 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index f52e6c5cca..10a3b9377b 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option { DataType::Float16 => Some(DType::F16), DataType::Float => Some(DType::F32), DataType::Double => Some(DType::F64), + DataType::Bool => Some(DType::U8), _ => None, } } @@ -1053,6 +1054,85 @@ fn simple_eval_( ), } } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice + "Slice" => { + let data = get(&node.input[0])?; + let starts = get(&node.input[1])?; + let ends = get(&node.input[2])?; + let default_axes; + let default_steps; + let axes: &Tensor; + let steps: &Tensor; + // If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted, + // they are set to [1, ..., 1] of length len(starts) + match node.input.len() { + 3 => { + let len = starts.dims()[0]; + default_axes = Some(Tensor::arange(0, len as i64, starts.device())?); + axes = default_axes.as_ref().unwrap(); + default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?); + steps = default_steps.as_ref().unwrap(); + } + 4 => { + let len = starts.dims()[0]; + axes = get(&node.input[3])?; + default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?); + steps = default_steps.as_ref().unwrap(); + } + 5 => { + steps = get(&node.input[4])?; + axes = get(&node.input[3])?; + } + _ => bail!( + "Slice node is invalid, expected 3-5 inputs, got {}: {:?}", + node.input.len(), + node + ), + } + + let mut out = data.clone(); + for (i, axis) in axes.to_vec1::()?.into_iter().enumerate() { + // All negative elements of axes are made non-negative by + // adding r to them, where r = rank(input). + let axis = if axis < 0 { + axis + data.rank() as i64 + } else { + axis + } as usize; + + let data_dim = data.dims()[axis] as i64; + let mut s = starts.get(i)?.to_scalar::()?; + let mut e = ends.get(i)?.to_scalar::()?; + // All negative values in starts[i] and ends[i] have + // dims[axes[i]] added to them, where dims are the + // dimensions of input. + if s < 0 { + s += data_dim; + } + if e < 0 { + e += data_dim; + } + + let p = steps.get(i)?.to_scalar::()?; + // starts[i] is clamped into the range [0, dims[axes[i]]] + // for positive stepping and [0, dims[axes[i]]-1] for + // negative stepping. + // for positive stepping ends[axes[i]] is clamped to + // [0, dims[axes[i]]], while for negative stepping it is + // clamped to [-1, dims[axes[i]]-1]. + if p >= 0 { + s = s.clamp(0, data_dim); + e = e.clamp(0, data_dim); + } else { + s = s.clamp(0, data_dim - 1); + e = e.clamp(-1, data_dim - 1); + } + + let indexes = Tensor::arange_step(s, e, p, data.device())?; + out = out.index_select(&indexes, axis)? + } + values.insert(node.output[0].clone(), out); + } // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13 // TODO: This version is only compatible with ReduceMean V13 and below. "ReduceMean" => { diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index b4299af1bc..82d38aa490 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -3272,3 +3272,138 @@ fn test_pad() -> Result<()> { assert_eq!(actual.to_vec2::()?, expected.to_vec2::()?); Ok(()) } + +#[test] +fn test_slice() -> Result<()> { + let model = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Slice".to_string(), + input: vec![ + "data".to_string(), + "starts".to_string(), + "ends".to_string(), + "axes".to_string(), + "steps".to_string(), + ], + output: vec!["result".to_string()], + ..NodeProto::default() + }], + input: ["data", "starts", "ends", "axes", "steps"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + output: ["result"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + ..GraphProto::default() + })); + + /* + data = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + axes = [0, 1] + starts = [1, 0] + ends = [2, 3] + steps = [1, 2] + result = [ + [5, 7], + ] + */ + + let outputs = candle_onnx::simple_eval( + &model, + HashMap::from_iter([ + ( + "data".to_string(), + Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?, + ), + ( + "starts".to_string(), + Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?, + ), + ( + "ends".to_string(), + Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?, + ), + ( + "axes".to_string(), + Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?, + ), + ( + "steps".to_string(), + Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?, + ), + ]), + )?; + let actual = outputs.get("result").unwrap().to_vec2::()?; + assert_eq!(actual, vec![vec![5i64, 7]]); + + /* + data = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + starts = [0, 1] + ends = [-1, 1000] + result = [ + [2, 3, 4], + ] + */ + let model = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Slice".to_string(), + input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()], + output: vec!["result".to_string()], + ..NodeProto::default() + }], + input: ["data", "starts", "ends"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + output: ["result"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + ..GraphProto::default() + })); + let outputs = candle_onnx::simple_eval( + &model, + HashMap::from_iter([ + ( + "data".to_string(), + Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?, + ), + ( + "starts".to_string(), + Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?, + ), + ( + "ends".to_string(), + Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?, + ), + ]), + )?; + let actual = outputs.get("result").unwrap().to_vec2::()?; + assert_eq!(actual, vec![vec![2i64, 3, 4]]); + + Ok(()) +} From 36cf54525d93660f62c3601ba0988653f3567e0e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Tue, 18 Jun 2024 23:46:58 +0200 Subject: [PATCH 2/7] Fix the fast bf16 gemm cublas kernels. (#2274) * Use flash-attn in gemma. * Fix for the fast bf16 cublas gemm. * Fix some clippy lints. * Fix another lint. * Proper clippy fix. --- candle-core/examples/cuda_basics.rs | 5 ++++- candle-core/src/cpu_backend/mod.rs | 3 ++- candle-core/src/cpu_backend/utils.rs | 20 +++++++++++++++----- candle-core/src/cuda_backend/mod.rs | 8 +++----- candle-transformers/src/models/vgg.rs | 3 +-- 5 files changed, 25 insertions(+), 14 deletions(-) diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 00e937cb88..9af1b006e3 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -9,8 +9,10 @@ use candle_core::{Device, Tensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; - let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?; + let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)? + .to_dtype(candle_core::DType::BF16)?; candle_core::cuda::set_gemm_reduced_precision_f32(false); + candle_core::cuda::set_gemm_reduced_precision_bf16(false); let _x1 = x.matmul(&x)?; drop(_x1); let start_time = std::time::Instant::now(); @@ -19,6 +21,7 @@ fn main() -> Result<()> { println!("fp32: {:?}", start_time.elapsed()); drop(_x1); candle_core::cuda::set_gemm_reduced_precision_f32(true); + candle_core::cuda::set_gemm_reduced_precision_bf16(true); let _x1 = x.matmul(&x)?; drop(_x1); let start_time = std::time::Instant::now(); diff --git a/candle-core/src/cpu_backend/mod.rs b/candle-core/src/cpu_backend/mod.rs index 18b73e9b60..58773c8020 100644 --- a/candle-core/src/cpu_backend/mod.rs +++ b/candle-core/src/cpu_backend/mod.rs @@ -121,7 +121,8 @@ impl ReduceIndex { let dst_len = src_l.shape().elem_count() / reduce_dim_size; let mut dst: Vec = Vec::with_capacity(dst_len); let dst_to_set = dst.spare_capacity_mut(); - let dst_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(dst_to_set) }; + let dst_to_set = + unsafe { std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(dst_to_set) }; match src_l.contiguous_offsets() { Some((o1, o2)) => { let src = &src[o1..o2]; diff --git a/candle-core/src/cpu_backend/utils.rs b/candle-core/src/cpu_backend/utils.rs index af25a2aff9..3e0c69b4f7 100644 --- a/candle-core/src/cpu_backend/utils.rs +++ b/candle-core/src/cpu_backend/utils.rs @@ -174,7 +174,9 @@ pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [ (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => { let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; f_vec(&lhs[o_l1..o_l2], &rhs[o_r1..o_r2], ys_to_set); // SAFETY: values are all set by f_vec. unsafe { ys.set_len(el_count) }; @@ -185,7 +187,9 @@ pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [ let rhs = &rhs[ob.start..ob.start + ob.len]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; let mut dst_i = 0; for src_i in (o_l1..o_l2).step_by(ob.len) { f_vec( @@ -224,7 +228,9 @@ pub fn binary_map_vec T, FV: FnMut(&[T], &[T], &mut [ let lhs = &lhs[ob.start..ob.start + ob.len]; let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [T]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [T]>(ys_to_set) + }; let mut dst_i = 0; for src_i in (o_r1..o_r2).step_by(ob.len) { f_vec( @@ -311,7 +317,9 @@ pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U crate::StridedBlocks::SingleBlock { start_offset, len } => { let mut ys: Vec = Vec::with_capacity(len); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(ys_to_set) + }; f_vec(&vs[start_offset..start_offset + len], ys_to_set); // SAFETY: values are all set by f_vec. unsafe { ys.set_len(len) }; @@ -333,7 +341,9 @@ pub fn unary_map_vec U, FV: FnMut(&[T], &mut [U } else { let mut ys: Vec = Vec::with_capacity(el_count); let ys_to_set = ys.spare_capacity_mut(); - let ys_to_set = unsafe { std::mem::transmute::<_, &mut [U]>(ys_to_set) }; + let ys_to_set = unsafe { + std::mem::transmute::<&mut [std::mem::MaybeUninit], &mut [U]>(ys_to_set) + }; let mut dst_index = 0; for src_index in block_start_index { let vs = &vs[src_index..src_index + block_len]; diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index 9e72dcc810..7edad3d409 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -2035,15 +2035,13 @@ unsafe fn gemm_strided_batched_bf16( let alpha_f32: f32 = cfg.gemm.alpha.to_f32(); let beta_f32: f32 = cfg.gemm.beta.to_f32(); - let alpha = f16::from_f32(alpha_f32); - let beta = f16::from_f32(beta_f32); // The type for alpha and beta depends on the computeType. // https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmstridedbatchedex let (compute_type, alpha, beta) = if gemm_reduced_precision_bf16() { ( - sys::cublasComputeType_t::CUBLAS_COMPUTE_16F, - (&alpha) as *const f16 as *const _, - (&beta) as *const f16 as *const _, + sys::cublasComputeType_t::CUBLAS_COMPUTE_32F_FAST_16BF, + (&alpha_f32) as *const f32 as *const _, + (&beta_f32) as *const f32 as *const _, ) } else { ( diff --git a/candle-transformers/src/models/vgg.rs b/candle-transformers/src/models/vgg.rs index a20b5e3725..010643c8d2 100644 --- a/candle-transformers/src/models/vgg.rs +++ b/candle-transformers/src/models/vgg.rs @@ -54,8 +54,7 @@ impl ModuleT for Vgg<'_> { fn conv2d_block(convs: &[(usize, usize, &str)], vb: &VarBuilder) -> Result> { let layers = convs .iter() - .enumerate() - .map(|(_, &(in_c, out_c, name))| { + .map(|&(in_c, out_c, name)| { candle_nn::conv2d( in_c, out_c, From 6baa1d486bfd58da94dbd8630679bd1ed519970f Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 22 Jun 2024 23:21:20 +0200 Subject: [PATCH 3/7] Fix a bug in the metal implemtation of col2im1d. (#2284) --- candle-core/src/metal_backend/mod.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 06f6cd3715..fa83692df7 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -848,7 +848,6 @@ impl BackendStorage for MetalStorage { .device .new_buffer(dst_el, self.dtype, "conv_transpose1d")?; - let command_buffer = self.device.command_buffer()?; let name = match self.dtype { DType::F32 => "col2im1d_f32", DType::U32 => "col2im1d_u32", @@ -869,6 +868,12 @@ impl BackendStorage for MetalStorage { &kernel_l_mm, )? }; + // It is important for the command buffer to be obtained *after* the matmul + // kernel has run, otherwise we might use a command-buffer that has been commited + // already resulting in the following error. + // _status < MTLCommandBufferStatusCommitted > + // -[IOGPUMetalCommandBuffer setCurrentCommandEncoder:] + let command_buffer = self.device.command_buffer()?; candle_metal_kernels::call_col2im1d( &self.device.device, &command_buffer, From 242e006bbb26ff12581b3c04bfd069996fe1f6bb Mon Sep 17 00:00:00 2001 From: Jeroen Vlek Date: Mon, 24 Jun 2024 19:12:52 +0200 Subject: [PATCH 4/7] Depth Anything v2 (#2279) * define structs * construct ResidualConvUnit * forward() for ResidualConvUnit * implement FeatureFusionBlock * implement Scratch * implement DPTHead * add identity module * implement forward for DTPHead * add get_intermediate_layers to DinoVisionTransformer * implement DepthAnythingV2 * some minor tweaks * fix compile errors * fix var builder prefixes * setup initial example * use fixed patch size of 37 (518 / 14) * debugged until output * print min and max values * add some dynamism to the output location * scale input image * extract prep function * extract output path function * normalize image with magic mean and std * add spectral coloring * squeeze in the right place * make enterpolation optional * use bail instead of panic * omit unnecessary Shape call * remove empty curly braces * use bail instead of assert * use vb and pp * remove closures * extract config object * Apply rustfmt. * Fix some clippy lints. * More lints. * Use the array methods. --------- Co-authored-by: laurent --- candle-examples/Cargo.toml | 7 + .../examples/depth_anything_v2/README.md | 13 + .../examples/depth_anything_v2/color_map.rs | 50 ++ .../examples/depth_anything_v2/main.rs | 187 ++++++ candle-nn/src/ops.rs | 23 +- .../src/models/depth_anything_v2.rs | 553 ++++++++++++++++++ candle-transformers/src/models/dinov2.rs | 78 +++ candle-transformers/src/models/mod.rs | 1 + 8 files changed, 911 insertions(+), 1 deletion(-) create mode 100644 candle-examples/examples/depth_anything_v2/README.md create mode 100644 candle-examples/examples/depth_anything_v2/color_map.rs create mode 100644 candle-examples/examples/depth_anything_v2/main.rs create mode 100644 candle-transformers/src/models/depth_anything_v2.rs diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 5b90f140c2..fa5c620a48 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } +palette = { version = "0.7.6", optional = true } +enterpolation = { version = "0.2.1", optional = true} pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true } rayon = { workspace = true } rubato = { version = "0.15.0", optional = true } @@ -65,6 +67,7 @@ onnx = ["candle-onnx"] metal = ["candle/metal", "candle-nn/metal"] microphone = ["cpal"] encodec = ["cpal", "symphonia", "rubato"] +depth_anything_v2 = ["palette", "enterpolation"] [[example]] name = "llama_multiprocess" @@ -101,3 +104,7 @@ required-features = ["candle-datasets"] [[example]] name = "encodec" required-features = ["encodec"] + +[[example]] +name = "depth_anything_v2" +required-features = ["depth_anything_v2"] diff --git a/candle-examples/examples/depth_anything_v2/README.md b/candle-examples/examples/depth_anything_v2/README.md new file mode 100644 index 0000000000..163b398b89 --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/README.md @@ -0,0 +1,13 @@ +# candle-dinov2 + +[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which +builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer. + +This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it. + +## Running an example with color map and CUDA + +```bash +cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg +``` + diff --git a/candle-examples/examples/depth_anything_v2/color_map.rs b/candle-examples/examples/depth_anything_v2/color_map.rs new file mode 100644 index 0000000000..94be326fc5 --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/color_map.rs @@ -0,0 +1,50 @@ +use enterpolation::linear::ConstEquidistantLinear; +use enterpolation::Generator; +use palette::LinSrgb; + +use candle::Tensor; + +pub struct SpectralRColormap { + gradient: ConstEquidistantLinear, +} + +impl SpectralRColormap { + pub(crate) fn new() -> Self { + // Define a colormap similar to 'Spectral_r' by specifying key colors. + // got the colors from ChatGPT-4o + let gradient = ConstEquidistantLinear::::equidistant_unchecked([ + LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue + LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue + LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan + LinSrgb::new(0.6706, 0.8667, 0.6431), // Green + LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow + LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange + LinSrgb::new(0.9922, 0.6824, 0.3804), // Red + LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red + LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple + ]); + Self { gradient } + } + + fn get_color(&self, value: f32) -> LinSrgb { + self.gradient.gen(value) + } + + pub fn gray2color(&self, gray: &Tensor) -> candle::Result { + println!("Gray: {:?}", gray.dims()); + let gray_values: Vec = gray.flatten_all()?.to_vec1()?; + let rgb_values: Vec = gray_values + .iter() + .map(|g| self.get_color(*g)) + .flat_map(|rgb| [rgb.red, rgb.green, rgb.blue]) + .collect(); + + let [.., height, width] = gray.dims() else { + candle::bail!("Not enough dims!") + }; + + let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?; + + color.permute((2, 0, 1)) + } +} diff --git a/candle-examples/examples/depth_anything_v2/main.rs b/candle-examples/examples/depth_anything_v2/main.rs new file mode 100644 index 0000000000..ef337ebab4 --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/main.rs @@ -0,0 +1,187 @@ +//! Depth Anything V2 +//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2 + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use std::ffi::OsString; +use std::path::PathBuf; + +use clap::Parser; + +use candle::DType::{F32, U8}; +use candle::{DType, Device, Module, Result, Tensor}; +use candle_examples::{load_image, load_image_and_resize, save_image}; +use candle_nn::VarBuilder; +use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config}; +use candle_transformers::models::dinov2; + +use crate::color_map::SpectralRColormap; + +mod color_map; + +// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207 +const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406]; +const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225]; + +const DINO_IMG_SIZE: usize = 518; + +#[derive(Parser)] +struct Args { + #[arg(long)] + dinov2_model: Option, + + #[arg(long)] + depth_anything_v2_model: Option, + + #[arg(long)] + image: PathBuf, + + #[arg(long)] + output_dir: Option, + + #[arg(long)] + cpu: bool, + + #[arg(long)] + color_map: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + let device = candle_examples::device(args.cpu)?; + + let dinov2_model_file = match args.dinov2_model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-dino-v2".into()); + api.get("dinov2_vits14.safetensors")? + } + Some(dinov2_model) => dinov2_model, + }; + println!("Using file {:?}", dinov2_model_file); + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? }; + let dinov2 = dinov2::vit_small(vb)?; + println!("DinoV2 model built"); + + let depth_anything_model_file = match args.depth_anything_v2_model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into()); + api.get("depth_anything_v2_vits.safetensors")? + } + Some(depth_anything_model) => depth_anything_model, + }; + println!("Using file {:?}", depth_anything_model_file); + + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)? + }; + + let config = DepthAnythingV2Config::vit_small(); + let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?; + + let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?; + + println!("Loaded image {image:?}"); + + let depth = depth_anything.forward(&image)?; + + println!("Got predictions {:?}", depth.shape()); + + let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?; + + let output_path = full_output_path(&args.image, &args.output_dir); + println!("Saving image to {}", output_path.to_string_lossy()); + save_image(&output_image, output_path)?; + + Ok(()) +} + +fn full_output_path(image_path: &PathBuf, output_dir: &Option) -> PathBuf { + let input_file_name = image_path.file_name().unwrap(); + let mut output_file_name = OsString::from("depth_"); + output_file_name.push(input_file_name); + let mut output_path = match output_dir { + None => image_path.parent().unwrap().to_path_buf(), + Some(output_path) => output_path.clone(), + }; + output_path.push(output_file_name); + + output_path +} + +fn load_and_prep_image( + image_path: &PathBuf, + device: &Device, +) -> anyhow::Result<(usize, usize, Tensor)> { + let (_original_image, original_height, original_width) = load_image(&image_path, None)?; + + let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)? + .unsqueeze(0)? + .to_dtype(F32)? + .to_device(&device)?; + + let max_pixel_val = Tensor::try_from(255.0f32)? + .to_device(&device)? + .broadcast_as(image.shape())?; + let image = (image / max_pixel_val)?; + let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?; + + Ok((original_height, original_width, image)) +} + +fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result { + let mean_tensor = + Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?; + let std_tensor = + Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?; + image.sub(&mean_tensor)?.div(&std_tensor) +} + +fn post_process_image( + image: &Tensor, + original_height: usize, + original_width: usize, + color_map: bool, +) -> Result { + let out = image.interpolate2d(original_height, original_width)?; + let out = scale_image(&out)?; + + let out = if color_map { + let spectral_r = SpectralRColormap::new(); + spectral_r.gray2color(&out)? + } else { + let rgb_slice = [&out, &out, &out]; + Tensor::cat(&rgb_slice, 0)?.squeeze(1)? + }; + + let max_pixel_val = Tensor::try_from(255.0f32)? + .to_device(out.device())? + .broadcast_as(out.shape())?; + let out = (out * max_pixel_val)?; + + out.to_dtype(U8) +} + +fn scale_image(depth: &Tensor) -> Result { + let flat_values: Vec = depth.flatten_all()?.to_vec1()?; + + let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap(); + let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap(); + + let min_val_tensor = Tensor::try_from(*min_val)? + .to_device(depth.device())? + .broadcast_as(depth.shape())?; + let depth = (depth - min_val_tensor)?; + + let range = max_val - min_val; + let range_tensor = Tensor::try_from(range)? + .to_device(depth.device())? + .broadcast_as(depth.shape())?; + + depth / range_tensor +} diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 2a76ee5eed..9a360c472c 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,4 +1,4 @@ -use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D}; +use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D}; use rayon::prelude::*; /// Applies the softmax function to the input tensor, rescaling the element so that elements on @@ -926,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result { n => candle::bail!("replication-pad with a size of {n} is not supported"), } } + +#[derive(Clone, Debug)] +pub struct Identity; + +impl Identity { + pub fn new() -> Identity { + Self + } +} + +impl Default for Identity { + fn default() -> Self { + Self + } +} + +impl Module for Identity { + fn forward(&self, xs: &Tensor) -> Result { + Ok(xs.clone()) + } +} diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs new file mode 100644 index 0000000000..9eee6d1130 --- /dev/null +++ b/candle-transformers/src/models/depth_anything_v2.rs @@ -0,0 +1,553 @@ +use candle::D::Minus1; +use candle::{Module, Result, Tensor}; +use candle_nn::ops::Identity; +use candle_nn::{ + batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm, + BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder, +}; + +use crate::models::dinov2::DinoVisionTransformer; + +pub struct DepthAnythingV2Config { + out_channel_sizes: [usize; 4], + in_channel_size: usize, // embed_dim in the Dino model + num_features: usize, + use_batch_norm: bool, + use_class_token: bool, + layer_ids_vits: Vec, + input_image_size: usize, + target_patch_size: usize, +} + +impl DepthAnythingV2Config { + #[allow(clippy::too_many_arguments)] + pub fn new( + out_channel_sizes: [usize; 4], + in_channel_size: usize, + num_features: usize, + use_batch_norm: bool, + use_class_token: bool, + layer_ids_vits: Vec, + input_image_size: usize, + target_patch_size: usize, + ) -> Self { + Self { + out_channel_sizes, + in_channel_size, + num_features, + use_batch_norm, + use_class_token, + layer_ids_vits, + input_image_size, + target_patch_size, + } + } + + pub fn vit_small() -> Self { + Self { + out_channel_sizes: [48, 96, 192, 384], + in_channel_size: 384, + num_features: 64, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![2, 5, 8, 11], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_base() -> Self { + Self { + out_channel_sizes: [96, 192, 384, 768], + in_channel_size: 768, + num_features: 128, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![2, 5, 8, 11], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_large() -> Self { + Self { + out_channel_sizes: [256, 512, 1024, 1024], + in_channel_size: 1024, + num_features: 256, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![4, 11, 17, 23], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_giant() -> Self { + Self { + out_channel_sizes: [1536, 1536, 1536, 1536], + in_channel_size: 1536, + num_features: 384, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![9, 19, 29, 39], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } +} + +pub struct ResidualConvUnit { + activation: Activation, + conv1: Conv2d, + conv2: Conv2d, + batch_norm1: Option, + batch_norm2: Option, +} + +impl ResidualConvUnit { + pub fn new( + conf: &DepthAnythingV2Config, + activation: Activation, + vb: VarBuilder, + ) -> Result { + const KERNEL_SIZE: usize = 3; + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + let conv1 = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("conv1"), + )?; + let conv2 = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("conv2"), + )?; + + let (batch_norm1, batch_norm2) = match conf.use_batch_norm { + true => { + let batch_norm_cfg = BatchNormConfig { + eps: 1e-05, + remove_mean: false, + affine: true, + momentum: 0.1, + }; + ( + Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?), + Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?), + ) + } + false => (None, None), + }; + + Ok(Self { + activation, + conv1, + conv2, + batch_norm1, + batch_norm2, + }) + } +} + +impl Module for ResidualConvUnit { + fn forward(&self, xs: &Tensor) -> Result { + let out = self.activation.forward(xs)?; + let out = self.conv1.forward(&out)?; + let out = if let Some(batch_norm1) = &self.batch_norm1 { + batch_norm1.forward_train(&out)? + } else { + out + }; + + let out = self.activation.forward(&out)?; + let out = self.conv2.forward(&out)?; + let out = if let Some(batch_norm2) = &self.batch_norm2 { + batch_norm2.forward_train(&out)? + } else { + out + }; + + out + xs + } +} + +pub struct FeatureFusionBlock { + res_conv_unit1: ResidualConvUnit, + res_conv_unit2: ResidualConvUnit, + output_conv: Conv2d, + target_patch_size: usize, +} + +impl FeatureFusionBlock { + pub fn new( + conf: &DepthAnythingV2Config, + target_patch_size: usize, + activation: Activation, + vb: VarBuilder, + ) -> Result { + const KERNEL_SIZE: usize = 1; + let conv_cfg = Conv2dConfig { + padding: 0, + stride: 1, + dilation: 1, + groups: 1, + }; + let output_conv = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("out_conv"), + )?; + let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?; + let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?; + + Ok(Self { + res_conv_unit1, + res_conv_unit2, + output_conv, + target_patch_size, + }) + } +} + +impl Module for FeatureFusionBlock { + fn forward(&self, xs: &Tensor) -> Result { + let out = self.res_conv_unit2.forward(xs)?; + let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?; + + self.output_conv.forward(&out) + } +} + +pub struct Scratch { + layer1_rn: Conv2d, + layer2_rn: Conv2d, + layer3_rn: Conv2d, + layer4_rn: Conv2d, + refine_net1: FeatureFusionBlock, + refine_net2: FeatureFusionBlock, + refine_net3: FeatureFusionBlock, + refine_net4: FeatureFusionBlock, + output_conv1: Conv2d, + output_conv2: Sequential, +} + +impl Scratch { + pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result { + const KERNEL_SIZE: usize = 3; + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + + let layer1_rn = conv2d_no_bias( + conf.out_channel_sizes[0], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer1_rn"), + )?; + let layer2_rn = conv2d_no_bias( + conf.out_channel_sizes[1], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer2_rn"), + )?; + let layer3_rn = conv2d_no_bias( + conf.out_channel_sizes[2], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer3_rn"), + )?; + let layer4_rn = conv2d_no_bias( + conf.out_channel_sizes[3], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer4_rn"), + )?; + + let refine_net1 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 8, + Activation::Relu, + vb.pp("refinenet1"), + )?; + let refine_net2 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 4, + Activation::Relu, + vb.pp("refinenet2"), + )?; + let refine_net3 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 2, + Activation::Relu, + vb.pp("refinenet3"), + )?; + let refine_net4 = FeatureFusionBlock::new( + conf, + conf.target_patch_size, + Activation::Relu, + vb.pp("refinenet4"), + )?; + + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + let output_conv1 = conv2d( + conf.num_features, + conf.num_features / 2, + KERNEL_SIZE, + conv_cfg, + vb.pp("output_conv1"), + )?; + + let output_conv2 = seq(); + const HEAD_FEATURES_2: usize = 32; + const OUT_CHANNELS_2: usize = 1; + const KERNEL_SIZE_2: usize = 1; + let output_conv2 = output_conv2.add(conv2d( + conf.num_features / 2, + HEAD_FEATURES_2, + KERNEL_SIZE, + conv_cfg, + vb.pp("output_conv2").pp("0"), + )?); + let output_conv2 = output_conv2 + .add(Activation::Relu) + .add(conv2d( + HEAD_FEATURES_2, + OUT_CHANNELS_2, + KERNEL_SIZE_2, + conv_cfg, + vb.pp("output_conv2").pp("2"), + )?) + .add(Activation::Relu); + + Ok(Self { + layer1_rn, + layer2_rn, + layer3_rn, + layer4_rn, + refine_net1, + refine_net2, + refine_net3, + refine_net4, + output_conv1, + output_conv2, + }) + } +} + +const NUM_CHANNELS: usize = 4; + +pub struct DPTHead<'a> { + conf: &'a DepthAnythingV2Config, + projections: Vec, + resize_layers: Vec>, + readout_projections: Vec, + scratch: Scratch, +} + +impl<'a> DPTHead<'a> { + pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result { + let mut projections: Vec = Vec::with_capacity(conf.out_channel_sizes.len()); + for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() { + projections.push(conv2d( + conf.in_channel_size, + *out_channel_size, + 1, + Default::default(), + vb.pp("projects").pp(conv_index.to_string()), + )?); + } + + let resize_layers: Vec> = vec![ + Box::new(conv_transpose2d( + conf.out_channel_sizes[0], + conf.out_channel_sizes[0], + 4, + ConvTranspose2dConfig { + padding: 0, + stride: 4, + dilation: 1, + output_padding: 0, + }, + vb.pp("resize_layers").pp("0"), + )?), + Box::new(conv_transpose2d( + conf.out_channel_sizes[1], + conf.out_channel_sizes[1], + 2, + ConvTranspose2dConfig { + padding: 0, + stride: 2, + dilation: 1, + output_padding: 0, + }, + vb.pp("resize_layers").pp("1"), + )?), + Box::new(Identity::new()), + Box::new(conv2d( + conf.out_channel_sizes[3], + conf.out_channel_sizes[3], + 3, + Conv2dConfig { + padding: 1, + stride: 2, + dilation: 1, + groups: 1, + }, + vb.pp("resize_layers").pp("3"), + )?), + ]; + + let readout_projections = if conf.use_class_token { + let rop = Vec::with_capacity(NUM_CHANNELS); + for rop_index in 0..NUM_CHANNELS { + seq() + .add(linear( + 2 * conf.in_channel_size, + conf.in_channel_size, + vb.pp("readout_projects").pp(rop_index.to_string()), + )?) + .add(Activation::Gelu); + } + rop + } else { + vec![] + }; + + let scratch = Scratch::new(conf, vb.pp("scratch"))?; + + Ok(Self { + conf, + projections, + resize_layers, + readout_projections, + scratch, + }) + } +} + +impl Module for DPTHead<'_> { + fn forward(&self, xs: &Tensor) -> Result { + let mut out: Vec = Vec::with_capacity(NUM_CHANNELS); + for i in 0..NUM_CHANNELS { + let x = if self.conf.use_class_token { + let x = xs.get(i)?.get(0)?; + let class_token = xs.get(i)?.get(1)?; + let readout = class_token.unsqueeze(1)?.expand(x.shape())?; + let to_cat = [x, readout]; + let cat = Tensor::cat(&to_cat, Minus1)?; + self.readout_projections[i].forward(&cat)? + } else { + xs.get(i)? + }; + let x_dims = x.dims(); + + let x = x.permute((0, 2, 1))?.reshape(( + x_dims[0], + x_dims[x_dims.len() - 1], + self.conf.target_patch_size, + self.conf.target_patch_size, + ))?; + let x = self.projections[i].forward(&x)?; + + let x = self.resize_layers[i].forward(&x)?; + out.push(x); + } + + let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?; + let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?; + let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?; + let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?; + + let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?; + + let res3_out = self + .scratch + .refine_net3 + .res_conv_unit1 + .forward(&layer_3_rn)?; + let res3_out = path4.add(&res3_out)?; + let path3 = self.scratch.refine_net3.forward(&res3_out)?; + + let res2_out = self + .scratch + .refine_net2 + .res_conv_unit1 + .forward(&layer_2_rn)?; + let res2_out = path3.add(&res2_out)?; + let path2 = self.scratch.refine_net2.forward(&res2_out)?; + + let res1_out = self + .scratch + .refine_net1 + .res_conv_unit1 + .forward(&layer_1_rn)?; + let res1_out = path2.add(&res1_out)?; + let path1 = self.scratch.refine_net1.forward(&res1_out)?; + + let out = self.scratch.output_conv1.forward(&path1)?; + + let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?; + + self.scratch.output_conv2.forward(&out) + } +} + +pub struct DepthAnythingV2<'a> { + pretrained: &'a DinoVisionTransformer, + depth_head: DPTHead<'a>, + conf: &'a DepthAnythingV2Config, +} + +impl<'a> DepthAnythingV2<'a> { + pub fn new( + pretrained: &'a DinoVisionTransformer, + conf: &'a DepthAnythingV2Config, + vb: VarBuilder, + ) -> Result { + let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?; + + Ok(Self { + pretrained, + depth_head, + conf, + }) + } +} + +impl<'a> Module for DepthAnythingV2<'a> { + fn forward(&self, xs: &Tensor) -> Result { + let features = self.pretrained.get_intermediate_layers( + xs, + &self.conf.layer_ids_vits, + false, + false, + true, + )?; + let depth = self.depth_head.forward(&features)?; + + depth.relu() + } +} diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs index 757aa88ac4..00e501ce0d 100644 --- a/candle-transformers/src/models/dinov2.rs +++ b/candle-transformers/src/models/dinov2.rs @@ -258,6 +258,84 @@ impl DinoVisionTransformer { let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; &xs + &self.interpolate_pos_encoding(&xs, w, h)? } + + fn get_intermediate_layers_not_chunked( + &self, + xs: &Tensor, + blocks_to_take: &[usize], + ) -> Result> { + let mut xs = self.prepare_tokens_with_mask(xs)?; + let mut output = Vec::new(); + for (i, blk) in self.blocks.iter().enumerate() { + xs = blk.forward(&xs)?; + if blocks_to_take.contains(&i) { + output.push(xs.clone()); + } + } + if output.len() != blocks_to_take.len() { + candle::bail!( + "only {} / {} blocks found", + output.len(), + blocks_to_take.len() + ); + } + Ok(output) + } + + pub fn get_intermediate_layers( + &self, + xs: &Tensor, + blocks_to_take: &[usize], + reshape: bool, + return_class_token: bool, + norm: bool, + ) -> Result { + let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?; + let outputs = if norm { + outputs + .iter() + .map(|out| self.norm.forward(out)) + .collect::>>()? + } else { + outputs + }; + let class_tokens = outputs + .iter() + .map(|out| out.i((.., 0))) + .collect::>>()?; + let outputs = outputs + .iter() + .map(|out| out.i((.., 1..))) + .collect::>>()?; + + let outputs = if reshape { + let (b, _c, w, h) = xs.dims4()?; + let patch_size = self.patch_embed.patch_size.0; + let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size)); + outputs + .iter() + .map(|out| { + out.reshape((b, w / patch_size, h / patch_size, num_channels))? + .transpose(2, 3)? + .transpose(1, 2) + }) + .collect::>>()? + } else { + outputs + }; + + let outputs = if return_class_token { + outputs + .iter() + .zip(class_tokens.iter()) + .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1)) + .collect::>>()? + } else { + outputs + }; + + Tensor::stack(&outputs[..], 0) + } } impl Module for DinoVisionTransformer { diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 4628a3de43..89ae0f8a39 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -6,6 +6,7 @@ pub mod chatglm; pub mod clip; pub mod convmixer; pub mod convnext; +pub mod depth_anything_v2; pub mod dinov2; pub mod distilbert; pub mod efficientnet; From a3dd87f15e3656ee2bec4820ae72a2a4e5662b40 Mon Sep 17 00:00:00 2001 From: "drCathieSo.eth" Date: Sat, 29 Jun 2024 03:40:31 +0800 Subject: [PATCH 5/7] Adding Gemm and ArgMax operators to candle-onnx (#2231) * feat(gemm): implement Gemm operator in candle-onnx * feat(onnx): Add support for ArgMax operator in candle-onnx * Apply rustfmt. * Remove argmax as it was already present. --------- Co-authored-by: Laurent --- candle-onnx/src/eval.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 10a3b9377b..f7203b36f7 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1274,6 +1274,30 @@ fn simple_eval_( let output = candle_nn::ops::leaky_relu(input, alpha.into())?; values.insert(node.output[0].clone(), output); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm + "Gemm" => { + let a = get(&node.input[0])?; + let b = get(&node.input[1])?; + let c = get(&node.input[2])?; + + let alpha = get_attr_opt::(node, "alpha")?.copied().unwrap_or(1.0); + let beta = get_attr_opt::(node, "beta")?.copied().unwrap_or(1.0); + + let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?; + let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?; + + let trans_a = get_attr_opt::(node, "transA")?.copied().unwrap_or(0); + let trans_b = get_attr_opt::(node, "transB")?.copied().unwrap_or(0); + + let a = if trans_a == 0 { a.clone() } else { a.t()? }; + let b = if trans_b == 0 { b.clone() } else { b.t()? }; + + let output = a + .broadcast_mul(&alpha)? + .broadcast_matmul(&b)? + .broadcast_add(&c.broadcast_mul(&beta)?)?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } From e27aac0a062a6de125e2984eacdb7841664e86fd Mon Sep 17 00:00:00 2001 From: v-espitalier <125037408+v-espitalier@users.noreply.github.com> Date: Sat, 29 Jun 2024 11:49:15 +0200 Subject: [PATCH 6/7] Add DINOv2Reg4 + PlantCLEF2024 (#2293) * Add: DINOv2Reg4 with PlantCLEF2024 weights and example ( See https://arxiv.org/abs/2309.16588 and https://zenodo.org/records/10848263 ) * Remove extra files + update README to download them + remove extra lines * minor fix (README remove extra spaces) * minor fix (README: Fix image url) * Modif: Add back interpolate_pos_encoding() + fix when no interpolation + remove extra comments + Update README ( source image changed and so the predictions ) * Fix: Improve code lisibility with '$ cargo clippy' and '$ cargo fmt' * Another clippy fix. --------- Co-authored-by: x-VEspit Co-authored-by: laurent --- candle-examples/examples/dinov2reg4/README.md | 25 ++ candle-examples/examples/dinov2reg4/main.rs | 70 +++++ candle-examples/src/imagenet.rs | 18 ++ candle-transformers/src/models/dinov2reg4.rs | 281 ++++++++++++++++++ candle-transformers/src/models/mod.rs | 1 + 5 files changed, 395 insertions(+) create mode 100644 candle-examples/examples/dinov2reg4/README.md create mode 100644 candle-examples/examples/dinov2reg4/main.rs create mode 100644 candle-transformers/src/models/dinov2reg4.rs diff --git a/candle-examples/examples/dinov2reg4/README.md b/candle-examples/examples/dinov2reg4/README.md new file mode 100644 index 0000000000..ac86ca6911 --- /dev/null +++ b/candle-examples/examples/dinov2reg4/README.md @@ -0,0 +1,25 @@ +# candle-dinov2-reg4 + +[DINOv2-reg4](https://arxiv.org/abs/2309.16588) is the lastest version of DINOv2 with registers. +In this example, it is used as an plant species classifier: the model returns the +probability for the image to belong to each of the 7806 PlantCLEF2024 categories. + +## Running some example + +```bash +# Download classes names and a plant picture to identify +curl https://huggingface.co/vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights/raw/main/species_id_mapping.txt --output candle-examples/examples/dinov2reg4/species_id_mapping.txt +curl https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c --output candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg + +# Perform inference +cargo run --example dinov2reg4 --release -- --image candle-examples/examples/dinov2reg4/bd2d3830ac3270218ba82fd24e2290becd01317c.jpg + +> Orchis simia Lam. : 45.55% +> Orchis × bergonii Nanteuil: 9.80% +> Orchis italica Poir. : 9.66% +> Orchis × angusticruris Franch.: 2.76% +> Orchis × bivonae Tod. : 2.54% + +``` + +![Orchis Simia](https://bs.plantnet.org/image/o/bd2d3830ac3270218ba82fd24e2290becd01317c) diff --git a/candle-examples/examples/dinov2reg4/main.rs b/candle-examples/examples/dinov2reg4/main.rs new file mode 100644 index 0000000000..15270517c5 --- /dev/null +++ b/candle-examples/examples/dinov2reg4/main.rs @@ -0,0 +1,70 @@ +//! DINOv2 reg4 finetuned on PlantCLEF 2024 +//! https://arxiv.org/abs/2309.16588 +//! https://huggingface.co/spaces/BVRA/PlantCLEF2024 +//! https://zenodo.org/records/10848263 + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use clap::Parser; + +use candle::{DType, IndexOp, D}; +use candle_nn::{Module, VarBuilder}; +use candle_transformers::models::dinov2reg4; + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + image: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + let device = candle_examples::device(args.cpu)?; + + let image = candle_examples::imagenet::load_image518(args.image)?.to_device(&device)?; + println!("loaded image {image:?}"); + + let f_species_id_mapping = "candle-examples/examples/dinov2reg4/species_id_mapping.txt"; + let classes: Vec = std::fs::read_to_string(f_species_id_mapping) + .expect("missing classes file") + .split('\n') + .map(|s| s.to_string()) + .collect(); + + let model_file = match args.model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = + api.model("vincent-espitalier/dino-v2-reg4-with-plantclef2024-weights".into()); + api.get( + "vit_base_patch14_reg4_dinov2_lvd142m_pc24_onlyclassifier_then_all.safetensors", + )? + } + Some(model) => model.into(), + }; + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? }; + let model = dinov2reg4::vit_base(vb)?; + println!("model built"); + let logits = model.forward(&image.unsqueeze(0)?)?; + let prs = candle_nn::ops::softmax(&logits, D::Minus1)? + .i(0)? + .to_vec1::()?; + let mut prs = prs.iter().enumerate().collect::>(); + prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1)); + for &(category_idx, pr) in prs.iter().take(5) { + println!("{:24}: {:.2}%", classes[category_idx], 100. * pr); + } + Ok(()) +} diff --git a/candle-examples/src/imagenet.rs b/candle-examples/src/imagenet.rs index cefbd71bbe..781dcd4fc3 100644 --- a/candle-examples/src/imagenet.rs +++ b/candle-examples/src/imagenet.rs @@ -17,6 +17,24 @@ pub fn load_image224>(p: P) -> Result { .broadcast_div(&std) } +/// Loads an image from disk using the image crate, this returns a tensor with shape +/// (3, 518, 518). imagenet normalization is applied. +/// The model dinov2 reg4 analyzes images with dimensions 3x518x518 (resulting in 37x37 transformer tokens). +pub fn load_image518>(p: P) -> Result { + let img = image::io::Reader::open(p)? + .decode() + .map_err(candle::Error::wrap)? + .resize_to_fill(518, 518, image::imageops::FilterType::Triangle); + let img = img.to_rgb8(); + let data = img.into_raw(); + let data = Tensor::from_vec(data, (518, 518, 3), &Device::Cpu)?.permute((2, 0, 1))?; + let mean = Tensor::new(&[0.485f32, 0.456, 0.406], &Device::Cpu)?.reshape((3, 1, 1))?; + let std = Tensor::new(&[0.229f32, 0.224, 0.225], &Device::Cpu)?.reshape((3, 1, 1))?; + (data.to_dtype(candle::DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std) +} + pub const CLASS_COUNT: i64 = 1000; pub const CLASSES: [&str; 1000] = [ diff --git a/candle-transformers/src/models/dinov2reg4.rs b/candle-transformers/src/models/dinov2reg4.rs new file mode 100644 index 0000000000..6bbe2e2410 --- /dev/null +++ b/candle-transformers/src/models/dinov2reg4.rs @@ -0,0 +1,281 @@ +use candle::{IndexOp, Result, Tensor, D}; +use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder}; + +const IMG_SIZE: usize = 518; +const PATCH_SIZE: usize = 14; +const NUM_CLASSES: usize = 7806; // PlantCLEF2024 DINOv2 (https://zenodo.org/records/10848263) + +fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result { + if bias { + candle_nn::linear(in_dim, out_dim, vb) + } else { + candle_nn::linear_no_bias(in_dim, out_dim, vb) + } +} + +#[derive(Debug)] +struct Attention { + qkv: Linear, + proj: Linear, + num_heads: usize, + scale: f64, +} + +impl Attention { + fn new( + vb: VarBuilder, + dim: usize, + num_heads: usize, + qkv_bias: bool, + proj_bias: bool, + ) -> Result { + let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?; + let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?; + let scale = 1. / ((dim / num_heads) as f64).sqrt(); + Ok(Self { + qkv, + proj, + num_heads, + scale, + }) + } +} + +impl Module for Attention { + fn forward(&self, xs: &Tensor) -> Result { + let (b, n, c) = xs.dims3()?; + let qkv = self + .qkv + .forward(xs)? + .reshape((b, n, 3, self.num_heads, c / self.num_heads))? + .transpose(1, 2)? // 02134 + .transpose(0, 1)? // 20134 + .transpose(2, 3)?; // 20314 + let q = (qkv.i(0)? * self.scale)?; + let k = qkv.i(1)?.contiguous()?; + let v = qkv.i(2)?.contiguous()?; + let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?; + let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?; + self.proj.forward(&attn) + } +} + +#[derive(Debug)] +struct LayerScale { + gamma: Tensor, +} + +impl LayerScale { + fn new(vb: VarBuilder, dim: usize) -> Result { + let gamma = vb.get(dim, "gamma")?; + Ok(Self { gamma }) + } +} + +impl Module for LayerScale { + fn forward(&self, xs: &Tensor) -> Result { + xs.broadcast_mul(&self.gamma) + } +} + +#[derive(Debug)] +struct Mlp { + fc1: Linear, + fc2: Linear, +} + +impl Mlp { + fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result { + let out_features = in_features; + let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?; + let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?; + Ok(Self { fc1, fc2 }) + } +} + +impl Module for Mlp { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.fc1.forward(xs)?.gelu()?; + self.fc2.forward(&xs) + } +} + +#[derive(Debug)] +struct Block { + norm1: LayerNorm, + attn: Attention, + ls1: LayerScale, + norm2: LayerNorm, + mlp: Mlp, + ls2: LayerScale, +} + +impl Block { + fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result { + let norm1 = layer_norm(dim, 1e-6, vb.pp("norm1"))?; + let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?; + let ls1 = LayerScale::new(vb.pp("ls1"), dim)?; + let norm2 = layer_norm(dim, 1e-6, vb.pp("norm2"))?; + let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?; + let ls2 = LayerScale::new(vb.pp("ls2"), dim)?; + Ok(Self { + norm1, + attn, + ls1, + norm2, + mlp, + ls2, + }) + } +} + +impl Module for Block { + fn forward(&self, xs: &Tensor) -> Result { + let residual = xs; + let xs = self + .ls1 + .forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = self + .ls2 + .forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?; + xs + residual + } +} + +#[derive(Debug)] +struct PatchEmbed { + proj: candle_nn::Conv2d, + patch_size: (usize, usize), + num_patches: usize, +} + +impl PatchEmbed { + fn new( + vb: VarBuilder, + img_size: usize, + patch_size: usize, + in_chans: usize, + embed_dim: usize, + ) -> Result { + let config = candle_nn::Conv2dConfig { + stride: patch_size, + ..Default::default() + }; + let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?; + let num_patches = (img_size / patch_size) * (img_size / patch_size); + Ok(Self { + proj, + patch_size: (patch_size, patch_size), + num_patches, + }) + } +} + +impl Module for PatchEmbed { + fn forward(&self, xs: &Tensor) -> Result { + let (_b, _c, h, w) = xs.dims4()?; + let (patch_h, patch_w) = self.patch_size; + if (h % patch_h) != 0 { + candle::bail!("image height {h} is not a multiple of patch height {patch_h}") + } + if (w % patch_w) != 0 { + candle::bail!("image width {w} is not a multiple of patch width {patch_w}") + } + let xs = self.proj.forward(xs)?; + let (b, c, h, w) = xs.dims4()?; + // flatten embeddings. + xs.reshape((b, c, h * w))?.transpose(1, 2) + } +} + +#[derive(Debug)] +pub struct DinoVisionTransformer { + patch_embed: PatchEmbed, + cls_token: Tensor, + reg_token: Tensor, + pos_embed: Tensor, + blocks: Vec, + norm: LayerNorm, + head: Linear, +} + +impl DinoVisionTransformer { + pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result { + let patch_embed = + PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?; + let cls_token = vb.get((1, 1, embed_dim), "cls_token")?; + let reg_token = vb.get((1, 4, embed_dim), "reg_token")?; + let pos_embed = vb.get((1, patch_embed.num_patches, embed_dim), "pos_embed")?; + let head = linear(vb.pp("head"), embed_dim, NUM_CLASSES, true)?; + let norm = layer_norm(embed_dim, 1e-6, vb.pp("norm"))?; + let vb_b = vb.pp("blocks"); + let blocks = (0..depth) + .map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads)) + .collect::>>()?; + Ok(Self { + patch_embed, + cls_token, + reg_token, + pos_embed, + blocks, + norm, + head, + }) + } + + fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result { + let npatch = xs.dim(1)? - 1; + let n = self.pos_embed.dim(1)? - 1; + let sqrt_n = (n as f64).sqrt(); + if npatch == n && w == h { + return Ok(self.pos_embed.clone()); + } + let patch_pos_embed = &self.pos_embed; + let dim = xs.dim(D::Minus1)?; + let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1); + let patch_pos_embed = patch_pos_embed + .reshape((1, sqrt_n as usize, sqrt_n as usize, dim))? + .transpose(2, 3)? + .transpose(1, 2)?; + // This uses bicubic interpolation in the original implementation. + let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?; + let el_count = patch_pos_embed.shape().elem_count(); + patch_pos_embed + .transpose(1, 2)? + .transpose(2, 3)? + .reshape((1, el_count / dim, dim)) + } + + fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result { + let (_b, _nc, w, h) = xs.dims4()?; + if (w != IMG_SIZE) || (h != IMG_SIZE) { + panic!("Error: The input tensor should have the shape: Bx3x518x518."); + } + let xs = self.patch_embed.forward(xs)?; + let xs = (&xs + &self.interpolate_pos_encoding(&xs, w, h)?)?; + let xs = Tensor::cat(&[&self.cls_token, &self.reg_token, &xs], 1)?; + Ok(xs) + } +} + +impl Module for DinoVisionTransformer { + fn forward(&self, xs: &Tensor) -> Result { + let mut xs = self.prepare_tokens_with_mask(xs)?; + for blk in self.blocks.iter() { + xs = blk.forward(&xs)? + } + let xs = self.norm.forward(&xs)?; + let xs_norm_clstoken = xs.i((.., 0))?; + self.head.forward(&xs_norm_clstoken) + } +} + +pub fn vit_small(vb: VarBuilder) -> Result { + DinoVisionTransformer::new(vb, 12, 384, 6) +} + +pub fn vit_base(vb: VarBuilder) -> Result { + DinoVisionTransformer::new(vb, 12, 768, 12) +} diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 89ae0f8a39..2908d3457a 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -8,6 +8,7 @@ pub mod convmixer; pub mod convnext; pub mod depth_anything_v2; pub mod dinov2; +pub mod dinov2reg4; pub mod distilbert; pub mod efficientnet; pub mod efficientvit; From 74e9e4191167c162f61a9e8334cfe2445dd41d83 Mon Sep 17 00:00:00 2001 From: Czxck001 <10724409+Czxck001@users.noreply.github.com> Date: Sat, 29 Jun 2024 12:34:42 -0700 Subject: [PATCH 7/7] make up for the missing last token output of phi2 example (#2299) --- candle-examples/examples/phi/main.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index 1cfeb443a2..1a0d9aca53 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -114,6 +114,10 @@ impl TextGeneration { tokens.push(next_token); generated_tokens += 1; if next_token == eos_token { + if let Some(t) = self.tokenizer.decode_rest()? { + print!("{t}"); + std::io::stdout().flush()?; + } break; } if let Some(t) = self.tokenizer.next_token(next_token)? {