Skip to content

Commit

Permalink
Add varbuilder get_unchecked (huggingface#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler authored Dec 10, 2024
1 parent a3814f5 commit 2c7408b
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 45 deletions.
158 changes: 115 additions & 43 deletions candle-nn/src/var_builder.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
//! A `VarBuilder` for variable retrieval from models
//!
//! A `VarBuilder` is used to retrieve variables used by a model. These variables can either come
//! from a pre-trained checkpoint, e.g. using `VarBuilder::from_mmaped_safetensors`, or initialized
//! for training, e.g. using `VarBuilder::from_varmap`.
Expand Down Expand Up @@ -60,6 +58,9 @@ pub trait Backend: Send + Sync {
dev: &Device,
) -> Result<Tensor>;

/// Retrieve a tensor based only on its name; unlike `get`, no expected shape
/// is supplied, so implementations perform no shape verification.
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;

fn contains_tensor(&self, name: &str) -> bool;
}

Expand All @@ -74,6 +75,9 @@ pub trait SimpleBackend: Send + Sync {
dev: &Device,
) -> Result<Tensor>;

/// Retrieve a tensor based only on its name; unlike `get`, no expected shape
/// is supplied, so implementations perform no shape verification.
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor>;

fn contains_tensor(&self, name: &str) -> bool;
}

Expand All @@ -90,6 +94,10 @@ impl Backend for Box<dyn SimpleBackend + '_> {
self.as_ref().get(s, name, h, dtype, dev)
}

/// Forward the unchecked lookup to the boxed simple backend.
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
    let backend = self.as_ref();
    backend.get_unchecked(name, dtype, dev)
}

/// Forward the containment check to the boxed simple backend.
fn contains_tensor(&self, name: &str) -> bool {
    (**self).contains_tensor(name)
}
Expand Down Expand Up @@ -196,14 +204,27 @@ impl<B: Backend> VarBuilderArgs<'_, B> {
name: &str,
hints: B::Hints,
) -> Result<Tensor> {
self.get_with_hints_dtype(s, name, hints, self.dtype)
self.get_with_hints_dtype(s, name, hints, self.data.dtype)
}

/// Retrieve the tensor associated with the given name at the current path.
pub fn get<S: Into<Shape>>(&self, s: S, name: &str) -> Result<Tensor> {
self.get_with_hints(s, name, Default::default())
}

/// Retrieve the tensor associated with the given name at the current path,
/// using the builder's default dtype and without any shape verification.
pub fn get_unchecked(&self, name: &str) -> Result<Tensor> {
    self.get_unchecked_dtype(name, self.data.dtype)
}

/// Retrieve the tensor associated with the given name & dtype at the current
/// path. No expected shape is supplied, so the backend does no shape check.
pub fn get_unchecked_dtype(&self, name: &str, dtype: DType) -> Result<Tensor> {
    // Qualify the name with the builder's current prefix before the lookup.
    let full_name = self.path(name);
    self.data
        .backend
        .get_unchecked(&full_name, dtype, &self.data.device)
}

/// Retrieve the tensor associated with the given name & dtype at the current path.
pub fn get_with_hints_dtype<S: Into<Shape>>(
&self,
Expand Down Expand Up @@ -251,6 +272,12 @@ impl SimpleBackend for Zeros {
Tensor::zeros(s, dtype, dev)
}

/// `Zeros` materializes tensors from a requested shape; with no shape
/// available there is nothing to construct, so this always errors.
fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
    candle::bail!(
        "`Zeros` requires a shape for tensor retrieval, use `get` instead of `get_unchecked`"
    )
}

/// `Zeros` can synthesize a tensor for any name, so every lookup "exists".
fn contains_tensor(&self, _name: &str) -> bool {
    true
}
Expand All @@ -265,6 +292,19 @@ impl SimpleBackend for HashMap<String, Tensor> {
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = self.get_unchecked(name, dtype, dev)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
}
.bt())?
}
Ok(tensor)
}

fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
let tensor = self
.get(name)
.ok_or_else(|| {
Expand All @@ -274,14 +314,6 @@ impl SimpleBackend for HashMap<String, Tensor> {
.bt()
})?
.clone();
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
}
.bt())?
}
tensor.to_device(dev)?.to_dtype(dtype)
}

Expand All @@ -302,6 +334,10 @@ impl SimpleBackend for VarMap {
VarMap::get(self, s, name, h, dtype, dev)
}

/// Unchecked lookup for a `VarMap` backend.
///
/// The fully-qualified call deliberately targets the inherent
/// `VarMap::get_unchecked` rather than recursing into this trait method.
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
    VarMap::get_unchecked(self, name, dtype, dev)
}

/// A `VarMap` contains a tensor iff its underlying map has the key.
fn contains_tensor(&self, name: &str) -> bool {
    let vars = self.data().lock().unwrap();
    vars.contains_key(name)
}
Expand All @@ -317,11 +353,24 @@ impl SimpleBackend for SafeTensorWithRouting<'_> {
/// Checked retrieval: fetch the tensor via `get_unchecked`, then verify that
/// its shape matches the expected `s`, erroring with `UnexpectedShape`
/// otherwise. (The stale deleted-side diff line `path: &str,` is removed so
/// the signature has a single name parameter.)
fn get(
    &self,
    s: Shape,
    name: &str,
    _: crate::Init,
    dtype: DType,
    dev: &Device,
) -> Result<Tensor> {
    let tensor = self.get_unchecked(name, dtype, dev)?;
    if tensor.shape() != &s {
        // Report both the caller-expected and the on-disk shapes.
        Err(candle::Error::UnexpectedShape {
            msg: format!("shape mismatch for {name}"),
            expected: s,
            got: tensor.shape().clone(),
        }
        .bt())?
    }
    Ok(tensor)
}

fn get_unchecked(&self, path: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
let index = self.routing.get(path).ok_or_else(|| {
Error::CannotFindTensor {
path: path.to_string(),
Expand All @@ -332,14 +381,6 @@ impl SimpleBackend for SafeTensorWithRouting<'_> {
.tensor(path)?
.load(dev)?
.to_dtype(dtype)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {path}"),
expected: s,
got: tensor.shape().clone(),
}
.bt())?
}
Ok(tensor)
}

Expand All @@ -352,22 +393,15 @@ impl SimpleBackend for candle::npy::NpzTensors {
fn get(
&self,
s: Shape,
path: &str,
name: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = match self.get(path)? {
None => Err(Error::CannotFindTensor {
path: path.to_string(),
}
.bt())?,
Some(tensor) => tensor,
};
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
let tensor = self.get_unchecked(name, dtype, dev)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {path}"),
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
}
Expand All @@ -376,6 +410,18 @@ impl SimpleBackend for candle::npy::NpzTensors {
Ok(tensor)
}

fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
let tensor = match self.get(name)? {
None => Err(Error::CannotFindTensor {
path: name.to_string(),
}
.bt())?,
Some(tensor) => tensor,
};
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
Ok(tensor)
}

/// An entry exists iff the archive lookup succeeds and yields `Some`.
fn contains_tensor(&self, name: &str) -> bool {
    matches!(self.get(name), Ok(Some(_)))
}
Expand All @@ -385,22 +431,15 @@ impl SimpleBackend for candle::pickle::PthTensors {
fn get(
&self,
s: Shape,
path: &str,
name: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = match self.get(path)? {
None => Err(Error::CannotFindTensor {
path: path.to_string(),
}
.bt())?,
Some(tensor) => tensor,
};
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
let tensor = self.get_unchecked(name, dtype, dev)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {path}"),
msg: format!("shape mismatch for {name}"),
expected: s,
got: tensor.shape().clone(),
}
Expand All @@ -409,6 +448,18 @@ impl SimpleBackend for candle::pickle::PthTensors {
Ok(tensor)
}

fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
let tensor = match self.get(name)? {
None => Err(Error::CannotFindTensor {
path: name.to_string(),
}
.bt())?,
Some(tensor) => tensor,
};
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
Ok(tensor)
}

fn contains_tensor(&self, name: &str) -> bool {
self.get(name).map_or(false, |v| v.is_some())
}
Expand All @@ -423,7 +474,7 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors {
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
let tensor = self.get_unchecked(name, dtype, dev)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
Expand All @@ -435,6 +486,10 @@ impl SimpleBackend for candle::safetensors::MmapedSafetensors {
Ok(tensor)
}

/// Load the named tensor onto `dev`, then convert it to `dtype`. No shape
/// check is performed.
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
    let loaded = self.load(name, dev)?;
    loaded.to_dtype(dtype)
}

/// A tensor is present iff the underlying safetensors lookup succeeds.
fn contains_tensor(&self, name: &str) -> bool {
    matches!(self.get(name), Ok(_))
}
Expand All @@ -449,7 +504,7 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
let tensor = self.get_unchecked(name, dtype, dev)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
Expand All @@ -461,6 +516,10 @@ impl SimpleBackend for candle::safetensors::BufferedSafetensors {
Ok(tensor)
}

/// Load the named tensor onto `dev`, then convert it to `dtype`. No shape
/// check is performed.
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
    let loaded = self.load(name, dev)?;
    loaded.to_dtype(dtype)
}

/// A tensor is present iff the underlying safetensors lookup succeeds.
fn contains_tensor(&self, name: &str) -> bool {
    matches!(self.get(name), Ok(_))
}
Expand All @@ -475,7 +534,7 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = self.load(name, dev)?.to_dtype(dtype)?;
let tensor = self.get_unchecked(name, dtype, dev)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {name}"),
Expand All @@ -487,6 +546,10 @@ impl SimpleBackend for candle::safetensors::SliceSafetensors<'_> {
Ok(tensor)
}

/// Load the named tensor onto `dev`, then convert it to `dtype`. No shape
/// check is performed.
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
    let loaded = self.load(name, dev)?;
    loaded.to_dtype(dtype)
}

/// A tensor is present iff the underlying safetensors lookup succeeds.
fn contains_tensor(&self, name: &str) -> bool {
    matches!(self.get(name), Ok(_))
}
Expand Down Expand Up @@ -739,6 +802,10 @@ impl Backend for ShardedSafeTensors {
Tensor::from_raw_buffer(&raw, view_dtype, &shape, dev)?.to_dtype(dtype)
}

/// Sharded loading goes through the hinted `get` path; a bare name-only
/// lookup is rejected outright, as the error message states.
fn get_unchecked(&self, _name: &str, _dtype: DType, _dev: &Device) -> Result<Tensor> {
    candle::bail!("`get_unchecked` does not make sense for `ShardedSafeTensors`, use `get`.");
}

/// A tensor is present iff the inner safetensors lookup succeeds.
fn contains_tensor(&self, name: &str) -> bool {
    self.0.get(name).is_ok()
}
Expand Down Expand Up @@ -772,6 +839,11 @@ impl<R: Renamer + Sync + Send> SimpleBackend for Rename<'_, R> {
.to_device(dev)
}

/// Apply the renaming scheme, then do an unchecked lookup through the inner
/// `VarBuilder` and move the result to the requested device.
fn get_unchecked(&self, name: &str, dtype: DType, dev: &Device) -> Result<Tensor> {
    let renamed = self.renamer.rename(name);
    self.inner
        .get_unchecked_dtype(&renamed, dtype)?
        .to_device(dev)
}

fn contains_tensor(&self, name: &str) -> bool {
let name = self.renamer.rename(name);
self.inner.contains_tensor(&name)
Expand Down
7 changes: 5 additions & 2 deletions candle-nn/src/var_map.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
//! A `VarMap` is a store that holds named variables.
//!
use candle::{DType, Device, Result, Shape, Tensor, Var};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
Expand Down Expand Up @@ -115,6 +113,11 @@ impl VarMap {
Ok(tensor)
}

/// Unchecked retrieval is not supported for a `VarMap`: retrieving or adding
/// a variable requires a shape, so callers must use `get` instead.
/// (The previous doc line was a stale copy-paste from `get`.)
pub fn get_unchecked(&self, _path: &str, _dtype: DType, _device: &Device) -> Result<Tensor> {
    candle::bail!("`get_unchecked` does not make sense for `VarMap`, use `get`.");
}

/// Shared access to the underlying name-to-variable map.
pub fn data(&self) -> &Mutex<HashMap<String, Var>> {
    &self.data
}
Expand Down

0 comments on commit 2c7408b

Please sign in to comment.