Skip to content

Commit

Permalink
add test for try_forward_mut (fails)
Browse files Browse the repository at this point in the history
  • Loading branch information
swfsql committed Nov 6, 2023
1 parent 370334f commit 20a958d
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions dfdx/src/nn/layers/batch_norm2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,4 +311,22 @@ mod tests {
let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default());
opt.update(&mut bn, &g).expect("");
}

#[derive(Default, Clone, Sequential)]
struct Arch {
pub batch: BatchNorm2DConstConfig<3>,
}

#[test]
fn test_batchnorm2d_update_with_derive() {
let dev: TestDevice = Default::default();

let x1: Tensor<Rank3<3, 4, 5>, TestDtype, _> = dev.sample_normal();
let mut bn = dev.build_module::<TestDtype>(Arch::default());
let y = bn.forward_mut(x1.leaky_trace());
let g = y.square().mean().backward();

let mut opt = crate::nn::optim::Sgd::new(&bn, Default::default());
opt.update(&mut bn, &g).expect("");
}
}

0 comments on commit 20a958d

Please sign in to comment.