[Breaking] Combining separate device errors into single dfdx::tensor::Error enum #875

Merged 4 commits on Oct 25, 2023
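This PR collapses the per-device `D::Err` associated error type, together with the `OptimizerUpdateError` wrapper removed below, into the single shared `dfdx::tensor::Error` enum. A rough sketch of the enum's shape follows; only the `UnusedTensors(Vec<UniqueId>)` variant is confirmed by this diff, and the device-side variant shown is illustrative:

```rust
// Hypothetical sketch, not the exact enum: only `UnusedTensors` appears
// in this diff; a device-side variant of some form is assumed.
#[derive(Debug)]
pub enum Error {
    /// Stand-in for failures each backend previously reported via `D::Err`.
    OutOfMemory,
    /// Parameters that received no gradient during backprop; replaces
    /// `OptimizerUpdateError::UnusedTensors`.
    UnusedTensors(Vec<UniqueId>),
}
```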
56 changes: 14 additions & 42 deletions dfdx-core/src/nn_traits/mod.rs
@@ -3,17 +3,16 @@ mod vecs;

use std::vec::Vec;

-use crate::prelude::{Device, Dtype, Gradients, Shape, Tensor, UniqueId};
+use crate::prelude::{Device, Dtype, Error, Gradients, Shape, Tensor, UniqueId};

/// Mutable & Immutable forward of `Input` that produces [Module::Output].
pub trait Module<X> {
/// The type that this unit produces given `Input`.
type Output;
-type Error: std::fmt::Debug;

-fn try_forward(&self, x: X) -> Result<Self::Output, Self::Error>;
+fn try_forward(&self, x: X) -> Result<Self::Output, Error>;

-fn try_forward_mut(&mut self, x: X) -> Result<Self::Output, Self::Error> {
+fn try_forward_mut(&mut self, x: X) -> Result<Self::Output, Error> {
self.try_forward(x)
}

@@ -26,52 +25,25 @@ pub trait Module<X> {
}
}
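Since `Module` no longer declares an associated `Error` type, implementations return the shared enum directly. A minimal sketch of a custom layer under the new signature (the `Square` layer is invented for illustration, and the prelude import is assumed to bring the traits above into scope):

```rust
use dfdx_core::prelude::*;

/// Illustrative layer that squares its input elementwise.
struct Square;

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>>
    for Square
{
    type Output = Tensor<S, E, D, T>;

    // No `type Error` to declare anymore; the shared enum is used.
    fn try_forward(&self, x: Tensor<S, E, D, T>) -> Result<Self::Output, Error> {
        x.try_square()
    }
}
```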

-/// An error indicating that a parameter was not used in gradient
-/// computation, and was therefore not present in [Gradients]
-/// during an update.
-#[derive(Debug)]
-pub enum OptimizerUpdateError<Err> {
-UnusedTensors(Vec<UniqueId>),
-DeviceError(Err),
-}
-
-impl<Err: std::fmt::Display> std::fmt::Display for OptimizerUpdateError<Err> {
-fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-match self {
-Self::UnusedTensors(unused) => write!(f, "Unused tensors: {unused:?}"),
-Self::DeviceError(err) => write!(f, "{err}"),
-}
-}
-}
-
-#[cfg(feature = "std")]
-impl<Err: std::fmt::Debug + std::fmt::Display> std::error::Error for OptimizerUpdateError<Err> {}

/// Something that can update both tensors and a [UpdateParams]. At minimum [Optimizer::update_tensor()] must be implemented.
pub trait Optimizer<M, E: Dtype, D: Device<E>>: Sized {
fn update_tensor<S: Shape>(
&mut self,
t: &mut Tensor<S, E, D>,
gradients: &Gradients<E, D>,
missing_tensors: &mut Vec<UniqueId>,
-) -> Result<(), D::Err>;
+) -> Result<(), Error>;

-fn update(
-&mut self,
-module: &mut M,
-gradients: &Gradients<E, D>,
-) -> Result<(), OptimizerUpdateError<D::Err>>
+fn update(&mut self, module: &mut M, gradients: &Gradients<E, D>) -> Result<(), Error>
where
M: UpdateParams<E, D>,
{
let mut missing_tensors = Vec::new();
-module
-.try_update_params(self, gradients, &mut missing_tensors)
-.map_err(OptimizerUpdateError::DeviceError)?;
+module.try_update_params(self, gradients, &mut missing_tensors)?;
if missing_tensors.is_empty() {
Ok(())
} else {
-Err(OptimizerUpdateError::UnusedTensors(missing_tensors))
+Err(Error::UnusedTensors(missing_tensors))
}
}
}
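The practical effect at call sites: `update` previously returned `OptimizerUpdateError<D::Err>`, so callers had to unpack two error notions; now one match on `Error` covers both. A before/after sketch, where `opt`, `model`, and `grads` are placeholders:

```rust
// Before: device errors and unused-tensor errors arrived wrapped.
match opt.update(&mut model, &grads) {
    Err(OptimizerUpdateError::UnusedTensors(ids)) => eprintln!("unused: {ids:?}"),
    Err(OptimizerUpdateError::DeviceError(e)) => eprintln!("device: {e}"),
    Ok(()) => {}
}

// After: both failure modes are variants of the one `Error` enum.
match opt.update(&mut model, &grads) {
    Err(Error::UnusedTensors(ids)) => eprintln!("unused: {ids:?}"),
    Err(e) => eprintln!("other: {e:?}"),
    Ok(()) => {}
}
```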
@@ -82,15 +54,15 @@ pub trait BuildOnDevice<E: Dtype, D: Device<E>>: Clone {
fn build_on_device(&self, device: &D) -> Self::Built {
self.try_build_on_device(device).unwrap()
}
-fn try_build_on_device(&self, device: &D) -> Result<Self::Built, D::Err>;
+fn try_build_on_device(&self, device: &D) -> Result<Self::Built, crate::tensor::Error>;
}

/// Something that can have all of its parameters reset to a specific state (may be random or not random).
pub trait ResetParams<E: Dtype, D: Device<E>> {
fn reset_params(&mut self) {
self.try_reset_params().unwrap()
}
-fn try_reset_params(&mut self) -> Result<(), D::Err>;
+fn try_reset_params(&mut self) -> Result<(), crate::tensor::Error>;
}

/// Something that can have its params updated with an [Optimizer] and a set of [Gradients].
@@ -109,7 +81,7 @@ pub trait UpdateParams<E: Dtype, D: Device<E>> {
optimizer: &mut Optim,
gradients: &Gradients<E, D>,
missing_tensors: &mut Vec<UniqueId>,
-) -> Result<(), D::Err>;
+) -> Result<(), crate::tensor::Error>;
}

impl<S: Shape, E: Dtype, D: Device<E>> UpdateParams<E, D> for Tensor<S, E, D> {
@@ -118,7 +90,7 @@ impl<S: Shape, E: Dtype, D: Device<E>> UpdateParams<E, D> for Tensor<S, E, D> {
optimizer: &mut Optim,
gradients: &Gradients<E, D>,
missing_tensors: &mut Vec<UniqueId>,
-) -> Result<(), <D>::Err> {
+) -> Result<(), crate::tensor::Error> {
optimizer.update_tensor(self, gradients, missing_tensors)
}
}
@@ -128,12 +100,12 @@ pub trait ZeroGrads<E: Dtype, D: Device<E>> {
fn zero_grads(&self, grads: &mut Gradients<E, D>) {
self.try_zero_grads(grads).unwrap()
}
-fn try_zero_grads(&self, grads: &mut Gradients<E, D>) -> Result<(), D::Err>;
+fn try_zero_grads(&self, grads: &mut Gradients<E, D>) -> Result<(), crate::tensor::Error>;

fn alloc_grads(&self) -> Gradients<E, D> {
self.try_alloc_grads().unwrap()
}
-fn try_alloc_grads(&self) -> Result<Gradients<E, D>, D::Err> {
+fn try_alloc_grads(&self) -> Result<Gradients<E, D>, crate::tensor::Error> {
let mut grads = Gradients::leaky();
self.try_zero_grads(&mut grads)?;
grads.retain_current_grads_as_leafs();
@@ -275,7 +247,7 @@ pub trait BuildModuleExt<M>: Sized {
self.try_build_module(m).unwrap()
}

-fn try_build_module<E: Dtype>(&self, m: M) -> Result<M::Built, Self::Err>
+fn try_build_module<E: Dtype>(&self, m: M) -> Result<M::Built, Error>
where
M: BuildOnDevice<E, Self>,
M::Built: ResetParams<E, Self>,
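Model construction follows the same pattern: `try_build_module` now fails with the shared `Error` rather than a device-specific `Self::Err`. A hedged usage sketch, where `arch` stands for any value whose type implements `BuildOnDevice` for this device and dtype:

```rust
let dev = Cpu::default();
// `arch` is a placeholder architecture value; `?` now propagates the
// unified `Error` instead of a device-specific one.
let model = dev.try_build_module::<f32>(arch)?;
```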
21 changes: 12 additions & 9 deletions dfdx-core/src/nn_traits/tuples.rs
@@ -1,4 +1,8 @@
-use crate::{dtypes::Dtype, tensor::UniqueId, tensor_ops::Device};
+use crate::{
+dtypes::Dtype,
+tensor::{Error, UniqueId},
+tensor_ops::Device,
+};

use std::vec::Vec;

@@ -7,7 +11,7 @@ macro_rules! tuple_impls {

impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::BuildOnDevice<Elem, Dev>),+> crate::nn_traits::BuildOnDevice<Elem, Dev> for ($($name,)+) {
type Built = ($($name::Built, )+);
-fn try_build_on_device(&self, device: &Dev) -> Result<Self::Built, Dev::Err> {
+fn try_build_on_device(&self, device: &Dev) -> Result<Self::Built, Error> {
Ok(($(
self.$idx.try_build_on_device(device)?,
)+))
@@ -38,7 +42,7 @@ macro_rules! tuple_impls {
}

impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::ResetParams<Elem, Dev>),+> crate::nn_traits::ResetParams<Elem, Dev> for ($($name,)+) {
-fn try_reset_params(&mut self) -> Result<(), Dev::Err> {
+fn try_reset_params(&mut self) -> Result<(), Error> {
$(self.$idx.try_reset_params()?;)+
Ok(())
}
@@ -50,14 +54,14 @@
optimizer: &mut Optim,
gradients: &crate::prelude::Gradients<Elem, Dev>,
missing_tensors: &mut Vec<UniqueId>,
-) -> Result<(), Dev::Err> {
+) -> Result<(), Error> {
$(self.$idx.try_update_params(optimizer, gradients, missing_tensors)?;)+
Ok(())
}
}

impl<Dev: Device<Elem>, Elem: Dtype, $($name: crate::nn_traits::ZeroGrads<Elem, Dev>),+> crate::nn_traits::ZeroGrads<Elem, Dev> for ($($name,)+) {
-fn try_zero_grads(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>) -> Result<(), Dev::Err> {
+fn try_zero_grads(&self, grads: &mut crate::prelude::Gradients<Elem, Dev>) -> Result<(), Error> {
$(self.$idx.try_zero_grads(grads)?;)+
Ok(())
}
Expand Down Expand Up @@ -91,20 +95,19 @@ macro_rules! tuple_impls {
impl<
Input,
$last:
-$(crate::nn_traits::Module::<$rev_tail ::Output, Error=$rev_tail::Error>, $rev_tail: )*
+$(crate::nn_traits::Module::<$rev_tail ::Output>, $rev_tail: )*
crate::nn_traits::Module<Input>
> crate::nn_traits::Module<Input> for ($($name,)+) {
type Output = $last ::Output;
-type Error = $last ::Error;

/// Calls forward sequentially on each module in the tuple.
-fn try_forward(&self, x: Input) -> Result<Self::Output, Self::Error> {
+fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
$(let x = self.$idx.try_forward(x)?;)+
Ok(x)
}

/// Calls forward sequentially on each module in the tuple.
-fn try_forward_mut(&mut self, x: Input) -> Result<Self::Output, Self::Error> {
+fn try_forward_mut(&mut self, x: Input) -> Result<Self::Output, Error> {
$(let x = self.$idx.try_forward_mut(x)?;)+
Ok(x)
}
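Note the bound change above: each tuple element previously had to declare `Error = $rev_tail::Error` so the whole chain agreed on one error type. With the unified enum that constraint disappears, and layers compose freely. A small sketch of generic driver code this enables:

```rust
// Any module (tuple, Vec, or single layer) now fails with the same `Error`,
// so generic driver code needs no error-type plumbing.
fn forward_all<X, M: Module<X>>(m: &M, x: X) -> Result<M::Output, Error> {
    m.try_forward(x)
}
```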
22 changes: 14 additions & 8 deletions dfdx-core/src/nn_traits/vecs.rs
@@ -1,12 +1,16 @@
-use crate::{dtypes::Dtype, tensor::UniqueId, tensor_ops::Device};
+use crate::{
+dtypes::Dtype,
+tensor::{Error, UniqueId},
+tensor_ops::Device,
+};

use std::vec::Vec;

impl<E: Dtype, D: Device<E>, T: crate::nn_traits::BuildOnDevice<E, D>>
crate::nn_traits::BuildOnDevice<E, D> for Vec<T>
{
type Built = Vec<T::Built>;
-fn try_build_on_device(&self, device: &D) -> Result<Self::Built, <D>::Err> {
+fn try_build_on_device(&self, device: &D) -> Result<Self::Built, crate::tensor::Error> {
self.iter()
.map(|m_i| m_i.try_build_on_device(device))
.collect()
@@ -16,7 +20,7 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::BuildOnDevice<E, D>>
impl<E: Dtype, D: Device<E>, T: crate::nn_traits::ResetParams<E, D>>
crate::nn_traits::ResetParams<E, D> for Vec<T>
{
-fn try_reset_params(&mut self) -> Result<(), <D>::Err> {
+fn try_reset_params(&mut self) -> Result<(), crate::tensor::Error> {
for m_i in self.iter_mut() {
m_i.try_reset_params()?;
}
@@ -32,7 +36,7 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::UpdateParams<E, D>>
optimizer: &mut Optim,
gradients: &crate::tensor::Gradients<E, D>,
missing_tensors: &mut Vec<UniqueId>,
-) -> Result<(), D::Err> {
+) -> Result<(), crate::tensor::Error> {
for m_i in self.iter_mut() {
m_i.try_update_params(optimizer, gradients, missing_tensors)?;
}
@@ -43,7 +47,10 @@ impl<E: Dtype, D: Device<E>, T: crate::nn_traits::UpdateParams<E, D>>
impl<E: Dtype, D: Device<E>, T: crate::nn_traits::ZeroGrads<E, D>> crate::nn_traits::ZeroGrads<E, D>
for Vec<T>
{
-fn try_zero_grads(&self, grads: &mut crate::tensor::Gradients<E, D>) -> Result<(), <D>::Err> {
+fn try_zero_grads(
+&self,
+grads: &mut crate::tensor::Gradients<E, D>,
+) -> Result<(), crate::tensor::Error> {
for m_i in self.iter() {
m_i.try_zero_grads(grads)?;
}
@@ -82,15 +89,14 @@ impl<Input, T: crate::nn_traits::Module<Input, Output = Input>> crate::nn_traits
for Vec<T>
{
type Output = T::Output;
-type Error = T::Error;

-fn try_forward(&self, mut x: Input) -> Result<Self::Output, T::Error> {
+fn try_forward(&self, mut x: Input) -> Result<Self::Output, Error> {
for m_i in self.iter() {
x = m_i.try_forward(x)?;
}
Ok(x)
}
-fn try_forward_mut(&mut self, mut x: Input) -> Result<Self::Output, Self::Error> {
+fn try_forward_mut(&mut self, mut x: Input) -> Result<Self::Output, Error> {
for m_i in self.iter_mut() {
x = m_i.try_forward_mut(x)?;
}
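For `Vec<T>`, the `Output = Input` bound makes a vector of same-typed layers act as a sequential module, and any layer's failure now propagates as the shared `Error`. An illustrative sketch, assuming an activation such as `ReLU` that implements `Module` with `Output = Input`:

```rust
// Three identical activation layers applied in sequence; `x` is a
// placeholder input tensor, and any failure surfaces as `Error`.
let layers = vec![ReLU, ReLU, ReLU];
let y = layers.try_forward(x)?;
```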