Input Wrapper #883

Closed
wants to merge 7 commits into from
@swfsql swfsql commented Nov 6, 2023

This is a draft; it closes #878.
Note: If this design ends up being useful, it could be implemented as a separate library (the PR only adds code, and the additions don't conflict with anything), but for gathering feedback it's simpler and more straightforward to keep it as a draft PR for now.

  • Add #[input_wrapper].
    • Add the heck dep to convert from CamelCase into snake_case.
  • Add layers.
    • Id, which just forwards the input.
    • On, which applies some Module to a field of an input wrapper.
      • Contains a test demonstrating its usage.
    • Add, which calls try_add for the inputs.
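As a rough illustration of the two simplest layers in the list above, here is a minimal sketch. It does not use dfdx: the `Module` trait below is a simplified stand-in (the real trait exposes `try_forward` and returns a `Result`), and `Add` uses `std::ops::Add` in place of `try_add`. Both simplifications are assumptions made to keep the example self-contained.

```rust
/// Simplified stand-in for the crate's `Module` trait (assumption:
/// the real one is fallible and is called through `try_forward`).
pub trait Module<Input> {
    type Output;
    fn forward(&self, x: Input) -> Self::Output;
}

/// `Id`: forwards the input unchanged.
#[derive(Debug, Default, Clone, Copy)]
pub struct Id;

impl<Input> Module<Input> for Id {
    type Output = Input;
    fn forward(&self, x: Input) -> Input {
        x
    }
}

/// `Add`: sums the two elements of a tuple input.
/// Here `std::ops::Add` stands in for the `try_add` call mentioned in the PR.
#[derive(Debug, Default, Clone, Copy)]
pub struct Add;

impl<A: std::ops::Add<B>, B> Module<(A, B)> for Add {
    type Output = <A as std::ops::Add<B>>::Output;
    fn forward(&self, x: (A, B)) -> Self::Output {
        x.0 + x.1
    }
}

fn main() {
    println!("Id:  {}", Id.forward(5_i32));
    println!("Add: {}", Add.forward((2.0_f32, 3.0_f32)));
}
```

`On` is not shown here because its `Module` impls depend on the wrapper struct; those impls are what `#[input_wrapper]` generates.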

This is how it gets used:

#[input_wrapper]
pub struct Split1<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}
#[derive(Default, Clone, Sequential)]
pub struct ResidualAdd1<T: Clone + std::fmt::Debug> {
    // input is Input
    pub split: SplitInto<(Id, Id)>,
    // input is (Input, Input)
    pub input_to_wrapper: split1::FromTuple,
    // input is Split1<Input, Input>
    pub t: On<split1::forward, T>,
    // input is Split1<T::Output, Input>
    pub input_to_tuple: split1::IntoTuple,
    // input is (T::Output, Input)
    pub add: ops::Add,
    // input is T::Output = Input
}
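To make the data flow through `ResidualAdd1` easier to follow, here is a self-contained sketch that walks the same steps by hand on a plain `f32`. Everything in it is a simplified assumption rather than the PR's code: the `Module` trait is a stand-in without error handling, the wrapper impls are written manually instead of being generated, and `Double` and `residual_add` are hypothetical names introduced only for this example.

```rust
use std::marker::PhantomData;

/// Simplified stand-in for the crate's `Module` trait (no `Result`).
pub trait Module<Input> {
    type Output;
    fn forward(&self, x: Input) -> Self::Output;
}

/// Hand-written version of the wrapper; the attribute would generate the rest.
pub struct Split1<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

pub mod split1 {
    /// Marker type that names the `forward` field.
    #[allow(non_camel_case_types)]
    pub struct forward;
}

impl<F, S> From<(F, S)> for Split1<F, S> {
    fn from(x: (F, S)) -> Self {
        Split1 { forward: x.0, skip: x.1 }
    }
}

impl<F, S> From<Split1<F, S>> for (F, S) {
    fn from(x: Split1<F, S>) -> Self {
        (x.forward, x.skip)
    }
}

/// Applies `M` to the field selected by the marker type `Field`.
pub struct On<Field, M> {
    pub t: M,
    pub _field: PhantomData<Field>,
}

impl<M: Module<F>, F, S> Module<Split1<F, S>> for On<split1::forward, M> {
    type Output = Split1<<M as Module<F>>::Output, S>;
    fn forward(&self, x: Split1<F, S>) -> Self::Output {
        // Only `forward` is transformed; `skip` passes through untouched.
        Split1 { forward: self.t.forward(x.forward), skip: x.skip }
    }
}

/// Hypothetical inner module standing in for `T`.
pub struct Double;

impl Module<f32> for Double {
    type Output = f32;
    fn forward(&self, x: f32) -> f32 {
        2.0 * x
    }
}

/// Mirrors the field order of `ResidualAdd1`, one step per line.
pub fn residual_add<T: Module<f32, Output = f32>>(t: T, input: f32) -> f32 {
    let tuple = (input, input); // SplitInto<(Id, Id)>
    let wrapped: Split1<f32, f32> = tuple.into(); // split1::FromTuple
    let on = On { t, _field: PhantomData::<split1::forward> };
    let wrapped = on.forward(wrapped); // On<split1::forward, T>
    let (out, skip): (f32, f32) = wrapped.into(); // split1::IntoTuple
    out + skip // ops::Add
}

fn main() {
    println!("{}", residual_add(Double, 3.0));
}
```

With `Double` as the inner module, the result is `T(x) + x`, which is the residual connection the struct describes.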

This is what gets generated from the above:

rust code

pub struct Split1<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}
/// Automatically generated by `input_wrapper`. The containing items are visible on your project's documentation.
pub mod split1 {
    use super::Split1;
    /// Indicates the [`Split1::forward`] field.  \nThis field is the `0` value (`0`-based index).
    #[allow(non_camel_case_types)]
    #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct forward;

    /// Indicates the [`Split1::skip`] field.  \nThis field is the `1` value (`0`-based index).
    #[allow(non_camel_case_types)]
    #[derive(Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
    pub struct skip;

    /// Indicates a conversion from a (Forward, Skip) tuple into a `Split1<Forward, Skip>`.
    #[derive(
        Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule,
    )]
    pub struct FromTuple;

    /// Indicates a conversion from a `Split1<Forward, Skip>` into a (Forward, Skip) tuple.
    #[derive(
        Debug, Default, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, crate::prelude::CustomModule,
    )]
    pub struct IntoTuple;

    /// Conversion of a tuple into a [`Split1`].
    impl<Forward, Skip> From<(Forward, Skip)> for Split1<Forward, Skip> {
        fn from(x: (Forward, Skip)) -> Self {
            Split1 {
                forward: x.0,
                skip: x.1,
            }
        }
    }
    /// Conversion of a [`Split1`] into a tuple.
    impl<Forward, Skip> From<Split1<Forward, Skip>> for (Forward, Skip) {
        fn from(x: Split1<Forward, Skip>) -> Self {
            (x.forward, x.skip)
        }
    }
    /// Module to convert a tuple into a [`Split1`].
    impl<Forward, Skip> crate::prelude::Module<(Forward, Skip)> for FromTuple {
        type Output = Split1<Forward, Skip>;
        fn try_forward(&self, x: (Forward, Skip)) -> Result<Self::Output, crate::prelude::Error> {
            Ok(x.into())
        }
    }
    /// Module to convert a [`Split1`] into a tuple.
    impl<Forward, Skip> crate::prelude::Module<Split1<Forward, Skip>> for IntoTuple {
        type Output = (Forward, Skip);
        fn try_forward(
            &self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            Ok(x.into())
        }
    }
    /// Module that access [`Split1::forward`] and then applies Module `M` on it.
    impl<M: crate::prelude::Module<Forward>, Forward, Skip>
        crate::prelude::Module<Split1<Forward, Skip>> for crate::prelude::On<forward, M>
    {
        type Output = Split1<<M as crate::prelude::Module<Forward>>::Output, Skip>;
        fn try_forward(
            &self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            let x0 = x.forward;
            let x1 = x.skip;
            let x0 = self.t.try_forward(x0)?;
            let x = Split1 {
                forward: x0,
                skip: x1,
            };
            Ok(x)
        }
        fn try_forward_mut(
            &mut self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            let x0 = x.forward;
            let x1 = x.skip;
            let x0 = self.t.try_forward_mut(x0)?;
            let x = Split1 {
                forward: x0,
                skip: x1,
            };
            Ok(x)
        }
    }
    /// Module that access [`Split1::skip`] and then applies Module `M` on it.
    impl<M: crate::prelude::Module<Skip>, Forward, Skip>
        crate::prelude::Module<Split1<Forward, Skip>> for crate::prelude::On<skip, M>
    {
        type Output = Split1<Forward, <M as crate::prelude::Module<Skip>>::Output>;
        fn try_forward(
            &self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            let x0 = x.forward;
            let x1 = x.skip;
            let x1 = self.t.try_forward(x1)?;
            let x = Split1 {
                forward: x0,
                skip: x1,
            };
            Ok(x)
        }
        fn try_forward_mut(
            &mut self,
            x: Split1<Forward, Skip>,
        ) -> Result<Self::Output, crate::prelude::Error> {
            let x0 = x.forward;
            let x1 = x.skip;
            let x1 = self.t.try_forward_mut(x1)?;
            let x = Split1 {
                forward: x0,
                skip: x1,
            };
            Ok(x)
        }
    }
}

@swfsql swfsql force-pushed the derive-input-wrapper branch from b4ad71e to 52b6264 Compare November 8, 2023 00:28
swfsql commented Nov 25, 2023

To add some info on this, this is how I was able to define a U-Net:

(Note: in this case I was using a version of dfdx that had some other local changes, especially experimental ones.)

rust code

#[input_wrapper]
#[derive(Clone, Debug)]
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

impl<Forward, Skip, const AXIS: isize> TryConcatTensorAlong<Axis<AXIS>> for Split<Forward, Skip>
where
    (Forward, Skip): TryConcatTensorAlong<Axis<AXIS>>,
{
    type Output = <(Forward, Skip) as TryConcatTensorAlong<Axis<AXIS>>>::Output;
    fn try_concat_tensor_along(self, ax: Axis<AXIS>) -> Result<Self::Output, Error> {
        let (forward, skip) = self.into();
        (forward, skip).try_concat_tensor_along(ax)
    }
}

/// From:
/// ```ignore
/// batch * CH_IN * height * width
/// ```
///
/// To:
/// ```ignore
/// batch * CH_OUT * height * width
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(ConvBlock)]
pub struct ConvBlockConfig<const CH_IN: usize, const CH_OUT: usize> {
    pub conv_1: Conv2DConstConfig<CH_IN, CH_OUT, 3, 1, 1>,
    pub norm_1: BatchNorm2DConstConfig<CH_OUT>,
    pub a_1: ReLU,
    //
    pub conv_2: Conv2DConstConfig<CH_OUT, CH_OUT, 3, 1, 1>,
    pub norm_2: BatchNorm2DConstConfig<CH_OUT>,
    pub a_2: ReLU,
}

/// From:
/// ```ignore
/// batch * CH_IN * height * width
/// ```
///
/// To:
/// ```ignore
/// Split {
///     forward: batch * CH_OUT * height/2 * width/2,
///     skip: batch * CH_OUT * height * width,
/// }
/// ```
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(DownBlock)]
pub struct DownBlockConfig<const CH_IN: usize, const CH_OUT: usize> {
    pub conv: ConvBlockConfig<CH_IN, CH_OUT>,
    //
    pub split: SplitInto<(Id, Id)>,
    pub wrapper: split::FromTuple,
    pub pool: On<split::forward, MaxPool2DConst<2, 2, 0>>,
}

/// From:
/// ```ignore
/// Split {
///     forward: batch * CH_INF * height/2 * width/2,
///     skip: batch * CH_INS * height * width,
/// }
/// ```
///
/// To:
/// ```ignore
/// batch * CH_OUT * height * width
/// ```
///
/// Notes:
/// - `CH_INF` refers to the #channels from [`Split::forward`].
/// - `CH_INS` refers to the #channels from [`Split::skip`], but this parameter is not directly passed to this structure.
/// - `CH_CONCAT` is supposed to be `CH_OUT + CH_INS`.
#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(UpBlock)]
pub struct UpBlockConfig<const CH_INF: usize, const CH_OUT: usize, const CH_CONCAT: usize> {
    // for keras' padding='same', the PADDING value must be set to:
    // ((kernel-1) * dilation + 1) // 2 = ((3-1) * 1 + 1) // 2 = 1
    pub conv_trans:
        On<split::forward, ConvTrans2DConstConfig<CH_INF, CH_OUT, 3, 2, 1, 1, 1, 1>>,
    pub bias: On<split::forward, Bias2DConstConfig<CH_OUT>>,
    // concat "skip" and "forward" along channels
    pub tuple: split::IntoTuple,
    pub concat: ops::ConcatTensorAlong<Axis<1>>,
    pub conv: ConvBlockConfig<CH_CONCAT, CH_OUT>,
}

/// Just applies `M`.
type Onc0<M> = M;
/// Accesses `F` and then applies `M`.
type Onc1<F, M> = On<F, M>;
/// Accesses `F` consecutively 2 times and then applies `M`.
type Onc2<F, M> = On<F, On<F, M>>;
/// Accesses `F` consecutively 3 times and then applies `M`.
type Onc3<F, M> = On<F, Onc2<F, M>>;
/// Accesses `F` consecutively 4 times and then applies `M`.
type Onc4<F, M> = On<F, Onc3<F, M>>;
/// Accesses `F` consecutively 5 times and then applies `M`.
type Onc5<F, M> = On<F, Onc4<F, M>>;

#[derive(Debug, Clone, Default, dfdx::Sequential)]
#[built(Model)]
pub struct ModelConfig {
    // encoder
    pub down_block_0: Onc0<DownBlockConfig<3, 32>>,
    pub down_block_1: Onc1<split::forward, DownBlockConfig<32, 64>>,
    pub down_block_2: Onc2<split::forward, DownBlockConfig<64, 128>>,
    pub down_block_3: Onc3<split::forward, DownBlockConfig<128, 256>>,

    // bottleneck
    // note: this increases the channels but reduces neither the height nor the width
    pub conv_bottle: Onc4<split::forward, ConvBlockConfig<256, 512>>,

    // decoder
    pub up_block_4: Onc3<split::forward, UpBlockConfig<512, 256, 512>>,
    pub up_block_3: Onc2<split::forward, UpBlockConfig<256, 128, 256>>,
    pub up_block_2: Onc1<split::forward, UpBlockConfig<128, 64, 128>>,
    pub up_block_1: Onc0<UpBlockConfig<64, 32, 64>>,

    // yclass channel conversion
    pub conv_2: Conv2DConstConfig<32, 13, 1, 1, 0>,
    pub bias_2: Bias2DConstConfig<13>,
}
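The `OncN` aliases rely on the fact that nesting `On` reaches into nested wrappers: each down block leaves a `Split` inside the `forward` field of the previous one. The sketch below illustrates that with plain integers. It is not dfdx code: the `Module` trait is a simplified stand-in, the `On` impl is written by hand instead of being generated, and `Halve` and `apply_two_levels_down` are hypothetical names used only here.

```rust
use std::marker::PhantomData;

/// Simplified stand-in for the crate's `Module` trait (no `Result`).
pub trait Module<Input> {
    type Output;
    fn forward(&self, x: Input) -> Self::Output;
}

#[derive(Debug, PartialEq)]
pub struct Split<Forward, Skip> {
    pub forward: Forward,
    pub skip: Skip,
}

pub mod split {
    /// Marker type that names the `forward` field.
    #[allow(non_camel_case_types)]
    #[derive(Debug, Default)]
    pub struct forward;
}

/// Applies `M` to the field selected by the marker type `Field`.
#[derive(Debug, Default)]
pub struct On<Field, M> {
    pub t: M,
    pub _field: PhantomData<Field>,
}

impl<M: Module<F>, F, S> Module<Split<F, S>> for On<split::forward, M> {
    type Output = Split<<M as Module<F>>::Output, S>;
    fn forward(&self, x: Split<F, S>) -> Self::Output {
        Split { forward: self.t.forward(x.forward), skip: x.skip }
    }
}

/// Same shape as the alias in the model: access `F` twice, then apply `M`.
type Onc2<F, M> = On<F, On<F, M>>;

/// Hypothetical leaf module.
#[derive(Debug, Default)]
pub struct Halve;

impl Module<u32> for Halve {
    type Output = u32;
    fn forward(&self, x: u32) -> u32 {
        x / 2
    }
}

/// Applies `Halve` to `x.forward.forward`, leaving both skips untouched.
pub fn apply_two_levels_down(x: Split<Split<u32, u32>, u32>) -> Split<Split<u32, u32>, u32> {
    let layer: Onc2<split::forward, Halve> = Default::default();
    layer.forward(x)
}

fn main() {
    let nested = Split { forward: Split { forward: 8, skip: 1 }, skip: 2 };
    println!("{:?}", apply_two_levels_down(nested));
}
```

This is why the encoder can keep stacking down blocks without naming the intermediate types: each level only needs one more `On<split::forward, _>` around the module.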

And this is how I defined and trained a simple RNN (based on this exercise).
The trained model was able to generate dino-like names.

rust code

pub mod model {
    use super::*;

    #[input_wrapper]
    #[derive(Clone, Debug)]
    pub struct Input<A, X> {
        pub a_prev: A,
        pub x: X,
    }

    impl<A, X, const AXIS: isize> TryConcatTensorAlong<Axis<AXIS>> for Input<A, X>
    where
        (A, X): TryConcatTensorAlong<Axis<AXIS>>,
    {
        type Output = <(A, X) as TryConcatTensorAlong<Axis<AXIS>>>::Output;
        fn try_concat_tensor_along(self, ax: Axis<AXIS>) -> Result<Self::Output, Error> {
            (self.a_prev, self.x).try_concat_tensor_along(ax)
        }
    }

    #[input_wrapper]
    #[derive(Clone, Debug)]
    pub struct Output<A, Y> {
        pub a: A,
        pub y: Y,
    }

    impl<AS: Shape, YS: Shape, E: Dtype, D: Device<E>>
        Output<Tensor<AS, E, D, OwnedTape<E, D>>, Tensor<YS, E, D, OwnedTape<E, D>>>
    {
        pub fn merge_tapes_on_y(self) -> Self {
            let (a, at) = self.a.split_tape();
            let (y, yt) = self.y.split_tape();
            Self {
                a: a.leaky_traced(),
                y: y.put_tape(at.merge(yt)),
            }
        }
    }

    //
    /// Input:
    /// ```ignore
    /// Input {
    ///     a_prev: A,
    ///     x: X,
    /// }
    /// ```
    ///
    /// Output:
    /// ```ignore
    /// Output {
    ///     a: A,
    ///     y: Y,
    /// }
    /// ```
    #[derive(Debug, Clone, Default, dfdx::Sequential)]
    #[built(Cell)]
    pub struct CellConstConfig<
        const NA: usize,
        const NX: usize,
        const NY: usize,
        const CONCAT_AXIS: isize,
        const NAPNX: usize = { NA + NX },
    > {
        // doing concat(a_prev, x) dot concat(wa^t, wb^t) + b is the same as
        // doing a_prev dot wa^t + x dot wb^t + b
        //
        // pub amul: On<input::a_prev, MatMulConstConfig<NA, NA>>,
        // pub xmul: On<input::x, MatMulConstConfig<NX, NA>>,
        // pub ax_tuple: input::IntoTuple,
        // pub ax_add: ops::Add,
        // pub bias: Bias1DConstConfig<NA>,
        //
        pub concat_input: ops::ConcatTensorAlong<Axis<CONCAT_AXIS>>,
        pub ax_linear: LinearConstConfig<NAPNX, NA>,
        //
        pub g1: Tanh,
        pub ay_tuple: SplitInto<(Id, Id)>,
        pub ay: output::FromTuple,
        pub ylinear: On<output::y, LinearConstConfig<NA, NY>>,
    }
}

#[test]
fn test_rnn() -> anyhow::Result<()> {
    let dev = Cuda::try_build(0, 0)?;

    const T_: usize = 2;
    const BATCH: usize = 3;
    // for unbatched (1D) tensors, the concat axis is 0
    // for batched (2D) tensors, the concat axis is 1
    const CONCAT_AXIS: isize = 1;
    const NA: usize = 2;
    const NX: usize = 3;
    const NY: usize = 3;
    type XT<T = NoneTape> = Tensor<Rank2<BATCH, NX>, f32, Device_, T>;
    type AT<T = NoneTape> = Tensor<Rank2<BATCH, NA>, f32, Device_, T>;
    type YT<T = NoneTape> = Tensor<Rank2<BATCH, NY>, f32, Device_, T>;

    let mut model =
        dev.build_module::<f32>(model::CellConstConfig::<NA, NX, NY, CONCAT_AXIS>::default());
    let mut grads = model.alloc_grads();

    let mut opt = dfdx::prelude::optim::Adam::new(
        &model,
        AdamConfig {
            lr: 1e-4,
            ..Default::default() // weight_decay: Some(dfdx::nn::optim::WeightDecay::L2(0.001)),
        },
    );

    const EPOCHS: usize = 2;
    for e in 0..EPOCHS {
        let a_prev: AT = dev.zeros();
        let mut x: XT = dev.zeros();

        let mut a_prev_t: AT<_> = a_prev.leaky_traced();

        let mut batch_loss = 0f32;

        for _t in 0..T_ {
            let y_t: YT = dev.sample_uniform();

            let x_t: XT<OwnedTape<f32, Device_>> = x.traced(grads);

            let input = model::Input {
                a_prev: a_prev_t,
                x: x_t,
            };

            let prediction = model.forward_mut(input);
            let prediction = prediction.merge_tapes_on_y();
            let loss_t =
                dfdx::losses::cross_entropy_with_logits_loss(prediction.y, y_t.clone());
            batch_loss += loss_t.array();

            // Note:
            // Running backprop and model update for each timestep t.
            // A different approach would be to run backprop at the last timestep and update once.
            // Or do something in between.
            grads = loss_t.backward();
            opt.update(&mut model, &grads).unwrap();

            x = y_t;
            a_prev_t = prediction.a;
        }
        println!("epoch: {}; loss: {}", e, batch_loss);
        // grads.drop_non_leafs();
        model.zero_grads(&mut grads);
    }

    Ok(())
}
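The comment inside `CellConstConfig` says that concatenating `a_prev` and `x` and applying one linear layer is the same as applying two separate matrix products and adding them. A small numeric check of that identity, using plain integer arrays instead of dfdx tensors, could look like the sketch below. The function names, the `wx` weight name, and the sample values are all assumptions made for illustration; weights are stored as one row per output unit.

```rust
const NA: usize = 2;
const NX: usize = 3;

/// a_prev * wa^T + x * wx^T + b, computed with two separate products.
pub fn separate(
    a: &[i64; NA],
    x: &[i64; NX],
    wa: &[[i64; NA]; NA],
    wx: &[[i64; NX]; NA],
    b: &[i64; NA],
) -> [i64; NA] {
    let mut out = *b;
    for i in 0..NA {
        for j in 0..NA {
            out[i] += wa[i][j] * a[j];
        }
        for k in 0..NX {
            out[i] += wx[i][k] * x[k];
        }
    }
    out
}

/// concat(a_prev, x) * concat(wa, wx)^T + b, computed with a single product.
pub fn concatenated(
    a: &[i64; NA],
    x: &[i64; NX],
    wa: &[[i64; NA]; NA],
    wx: &[[i64; NX]; NA],
    b: &[i64; NA],
) -> [i64; NA] {
    // Concatenate the inputs along the feature axis.
    let v: Vec<i64> = a.iter().chain(x.iter()).copied().collect();
    let mut out = *b;
    for i in 0..NA {
        // Concatenate the matching weight rows in the same order.
        let row: Vec<i64> = wa[i].iter().chain(wx[i].iter()).copied().collect();
        out[i] += row.iter().zip(&v).map(|(w, v)| w * v).sum::<i64>();
    }
    out
}

fn main() {
    let (a, x) = ([1, 2], [3, 4, 5]);
    let (wa, wx) = ([[1, 0], [2, 1]], [[1, 1, 1], [0, 2, 0]]);
    let b = [10, 20];
    println!("separate:     {:?}", separate(&a, &x, &wa, &wx, &b));
    println!("concatenated: {:?}", concatenated(&a, &x, &wa, &wx, &b));
}
```

Both functions compute the same sums in a different grouping, which is why the cell can use a single `LinearConstConfig<NAPNX, NA>` after the concat instead of the commented-out pair of matmuls plus bias.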

@swfsql swfsql force-pushed the derive-input-wrapper branch from 52b6264 to 82c314b Compare December 4, 2023 22:57
swfsql commented Mar 1, 2024

I'll prioritize moving this experiment to a separate crate, but feel free to ping me if anyone has a question or suggestion.

@swfsql swfsql closed this Mar 1, 2024
Successfully merging this pull request may close these issues.

Consider helpers for accessing tensors from tuples and input wrappers