Skip to content

Commit

Permalink
generalize Id; test renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Nov 8, 2023
1 parent 0a394c7 commit 52b6264
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions dfdx/src/nn/layers/id.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use crate::prelude::*;
#[derive(Default, Debug, Clone, Copy, CustomModule)]
pub struct Id;

impl<S: Shape, E: Dtype, D: Device<E>, T: Tape<E, D>> Module<Tensor<S, E, D, T>> for Id {
type Output = Tensor<S, E, D, T>;
fn try_forward(&self, x: Tensor<S, E, D, T>) -> Result<Self::Output, Error> {
impl<Input> Module<Input> for Id {
type Output = Input;
fn try_forward(&self, x: Input) -> Result<Self::Output, Error> {
Ok(x)
}
}
Expand Down
14 changes: 7 additions & 7 deletions dfdx/src/nn/layers/on.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::marker::PhantomData;

// TODO: try making a Call module, whih allows calling an arbitrary method on the input.

/// Access the input that is stored in a wrapper structure.
/// Applies module `T` into an input field from a wrapper.
#[derive(
Debug, Default, Clone, ResetParams, ZeroGrads, UpdateParams, LoadSafeTensors, SaveSafeTensors,
)]
Expand Down Expand Up @@ -52,10 +52,10 @@ mod tests {
// input is (Input, Input)
pub input_to_wrapper: split1::FromTuple,

// input is Split1 { Input, Input }
// input is Split1<Input, Input>
pub t: On<split1::forward, T>,

// input is Split1 { T::Output, Input }
// input is Split1<T::Output, Input>
pub input_to_tuple: split1::IntoTuple,

// input is (T::Output, Input)
Expand All @@ -64,7 +64,7 @@ mod tests {
}

#[test]
fn test_residual_add_backward() {
fn test_input_wrapper_struct() {
let dev: TestDevice = Default::default();

let model = dev.build_module::<f32>(<ResidualAdd1<LinearConstConfig<2, 2>>>::default());
Expand Down Expand Up @@ -106,10 +106,10 @@ mod tests {
// input is (Input, Input)
pub input_to_wrapper: split2::FromTuple,

// input is Split2 ( Input, Input )
// input is Split2<Input, Input>
pub t: On<split2::_0, T>,

// input is Split2 ( T::Output, Input )
// input is Split2<T::Output, Input>
pub input_to_tuple: split2::IntoTuple,

// input is (T::Output, Input)
Expand All @@ -118,7 +118,7 @@ mod tests {
}

#[test]
fn test_residual_add_backward2() {
fn test_input_wrapper_tuple_struct() {
let dev: TestDevice = Default::default();

let model = dev.build_module::<f32>(<ResidualAdd2<LinearConstConfig<2, 2>>>::default());
Expand Down

0 comments on commit 52b6264

Please sign in to comment.