diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index d989f6f5f9..fbfb8200a8 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -1,5 +1,3 @@ -//! A `VarBuilder` for variable retrieval from models -//! //! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come //! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized //! for training, e.g. using `VarBuilder::from_varmap`. @@ -60,6 +58,9 @@ pub trait Backend: Send + Sync { dev: &Device, ) -> Result; + /// Retrieve a tensor based on the name. + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result; + fn contains_tensor(&self, name: &str) -> bool; } @@ -74,6 +75,9 @@ pub trait SimpleBackend: Send + Sync { dev: &Device, ) -> Result; + /// Retrieve a tensor based on the name. + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result; + fn contains_tensor(&self, name: &str) -> bool; } @@ -90,6 +94,10 @@ impl Backend for Box { self.as_ref().get(s, name, h, dtype, dev) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + self.as_ref().get_unchecked(name, dtype, dev) + } + fn contains_tensor(&self, name: &str) -> bool { self.as_ref().contains_tensor(name) } @@ -196,7 +204,7 @@ impl VarBuilderArgs<'_, B> { name: &str, hints: B::Hints, ) -> Result { - self.get_with_hints_dtype(s, name, hints, self.dtype) + self.get_with_hints_dtype(s, name, hints, self.data.dtype) } /// Retrieve the tensor associated with the given name at the current path. @@ -204,6 +212,19 @@ impl VarBuilderArgs<'_, B> { self.get_with_hints(s, name, Default::default()) } + /// Retrieve the tensor associated with the given name at the current path. + pub fn get_unchecked(&self, name: &str) -> Result { + self.get_unchecked_dtype(name, self.data.dtype) + } + + /// Retrieve the tensor associated with the given name & dtype at the current path. 
+ pub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result { + let name = self.path(name); + self.data + .backend + .get_unchecked(&name, dtype, &self.data.device) + } + /// Retrieve the tensor associated with the given name & dtype at the current path. pub fn get_with_hints_dtype>( &self, @@ -251,6 +272,12 @@ impl SimpleBackend for Zeros { Tensor::zeros(s, dtype, dev) } + fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result { + candle::bail!( + "`Zeros` requires a shape for tensor retrieval, use `get` instead of `get_unchecked`" + ) + } + fn contains_tensor(&self, _name: &str) -> bool { true } @@ -265,6 +292,19 @@ impl SimpleBackend for HashMap { dtype: DType, dev: &Device, ) -> Result { + let tensor = self.get_unchecked(name, dtype, dev)?; + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { let tensor = self .get(name) .ok_or_else(|| { @@ -274,14 +314,6 @@ impl SimpleBackend for HashMap { .bt() })? .clone(); - if tensor.shape() != &s { - Err(candle::Error::UnexpectedShape { - msg: format!("shape mismatch for {name}"), - expected: s, - got: tensor.shape().clone(), - } - .bt())? 
- } tensor.to_device(dev)?.to_dtype(dtype) } @@ -302,6 +334,10 @@ impl SimpleBackend for VarMap { VarMap::get(self, s, name, h, dtype, dev) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + VarMap::get_unchecked(self, name, dtype, dev) + } + fn contains_tensor(&self, name: &str) -> bool { self.data().lock().unwrap().contains_key(name) } @@ -317,11 +353,24 @@ impl SimpleBackend for SafeTensorWithRouting<'_> { fn get( &self, s: Shape, - path: &str, + name: &str, _: crate::Init, dtype: DType, dev: &Device, ) -> Result { + let tensor = self.get_unchecked(name, dtype, dev)?; + if tensor.shape() != &s { + Err(candle::Error::UnexpectedShape { + msg: format!("shape mismatch for {name}"), + expected: s, + got: tensor.shape().clone(), + } + .bt())? + } + Ok(tensor) + } + + fn get_unchecked(&self, path: &str, dtype: DType, dev: &Device) -> Result { let index = self.routing.get(path).ok_or_else(|| { Error::CannotFindTensor { path: path.to_string(), @@ -332,14 +381,6 @@ impl SimpleBackend for SafeTensorWithRouting<'_> { .tensor(path)? .load(dev)? .to_dtype(dtype)?; - if tensor.shape() != &s { - Err(candle::Error::UnexpectedShape { - msg: format!("shape mismatch for {path}"), - expected: s, - got: tensor.shape().clone(), - } - .bt())? - } Ok(tensor) } @@ -352,22 +393,15 @@ impl SimpleBackend for candle::npy::NpzTensors { fn get( &self, s: Shape, - path: &str, + name: &str, _: crate::Init, dtype: DType, dev: &Device, ) -> Result { - let tensor = match self.get(path)? 
{ - None => Err(Error::CannotFindTensor { - path: path.to_string(), - } - .bt())?, - Some(tensor) => tensor, - }; - let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { - msg: format!("shape mismatch for {path}"), + msg: format!("shape mismatch for {name}"), expected: s, got: tensor.shape().clone(), } @@ -376,6 +410,18 @@ impl SimpleBackend for candle::npy::NpzTensors { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + let tensor = match self.get(name)? { + None => Err(Error::CannotFindTensor { + path: name.to_string(), + } + .bt())?, + Some(tensor) => tensor, + }; + let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + Ok(tensor) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).map_or(false, |v| v.is_some()) } @@ -385,22 +431,15 @@ impl SimpleBackend for candle::pickle::PthTensors { fn get( &self, s: Shape, - path: &str, + name: &str, _: crate::Init, dtype: DType, dev: &Device, ) -> Result { - let tensor = match self.get(path)? { - None => Err(Error::CannotFindTensor { - path: path.to_string(), - } - .bt())?, - Some(tensor) => tensor, - }; - let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { - msg: format!("shape mismatch for {path}"), + msg: format!("shape mismatch for {name}"), expected: s, got: tensor.shape().clone(), } @@ -409,6 +448,18 @@ impl SimpleBackend for candle::pickle::PthTensors { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + let tensor = match self.get(name)? 
{ + None => Err(Error::CannotFindTensor { + path: name.to_string(), + } + .bt())?, + Some(tensor) => tensor, + }; + let tensor = tensor.to_device(dev)?.to_dtype(dtype)?; + Ok(tensor) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).map_or(false, |v| v.is_some()) } @@ -423,7 +474,7 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors { dtype: DType, dev: &Device, ) -> Result { - let tensor = self.load(name, dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { msg: format!("shape mismatch for {name}"), @@ -435,6 +486,10 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + self.load(name, dev)?.to_dtype(dtype) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).is_ok() } @@ -449,7 +504,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors { dtype: DType, dev: &Device, ) -> Result { - let tensor = self.load(name, dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { msg: format!("shape mismatch for {name}"), @@ -461,6 +516,10 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors { Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + self.load(name, dev)?.to_dtype(dtype) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).is_ok() } @@ -475,7 +534,7 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> { dtype: DType, dev: &Device, ) -> Result { - let tensor = self.load(name, dev)?.to_dtype(dtype)?; + let tensor = self.get_unchecked(name, dtype, dev)?; if tensor.shape() != &s { Err(candle::Error::UnexpectedShape { msg: format!("shape mismatch for {name}"), @@ -487,6 +546,10 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> { 
Ok(tensor) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + self.load(name, dev)?.to_dtype(dtype) + } + fn contains_tensor(&self, name: &str) -> bool { self.get(name).is_ok() } @@ -739,6 +802,10 @@ impl Backend for ShardedSafeTensors { Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype) } + fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result { + candle::bail!("`get_unchecked` does not make sense for `ShardedSafeTensors`, use `get`."); + } + fn contains_tensor(&self, name: &str) -> bool { self.0.get(name).is_ok() } @@ -772,6 +839,11 @@ impl SimpleBackend for Rename<'_, R> { .to_device(dev) } + fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result { + let name = self.renamer.rename(name); + self.inner.get_unchecked_dtype(&name, dtype)?.to_device(dev) + } + fn contains_tensor(&self, name: &str) -> bool { let name = self.renamer.rename(name); self.inner.contains_tensor(&name) diff --git a/candle-nn/src/var_map.rs b/candle-nn/src/var_map.rs index ba020746b5..76111e2ac5 100644 --- a/candle-nn/src/var_map.rs +++ b/candle-nn/src/var_map.rs @@ -1,5 +1,3 @@ -//! A `VarMap` is a store that holds named variables. -//! use candle::{DType, Device, Result, Shape, Tensor, Var}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -115,6 +113,11 @@ impl VarMap { Ok(tensor) } + /// Always returns an error: `VarMap` does not support `get_unchecked`, use `get` instead. + pub fn get_unchecked(&self, _path: &str, _dtype: DType, _device: &Device) -> Result { + candle::bail!("`get_unchecked` does not make sense for `VarMap`, use `get`."); + } + pub fn data(&self) -> &Mutex> { &self.data }