diff --git a/test/test_models.py b/test/test_models.py index 209f27209bf..d657475bafb 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -406,7 +406,7 @@ def test_mobilenet_norm_layer(model_fn): assert any(isinstance(x, nn.BatchNorm2d) for x in model.modules()) def get_gn(num_channels): - return nn.GroupNorm(32, num_channels) + return nn.GroupNorm(1, num_channels) model = model_fn(norm_layer=get_gn) assert not (any(isinstance(x, nn.BatchNorm2d) for x in model.modules()))