
Passing state to BatchNorm when it is used inside nn.Sequential #448

Closed
paganpasta opened this issue Aug 15, 2023 · 3 comments
Labels
feature New feature

Comments

@paganpasta
Contributor

Hi,

This is a continuation of the effort to bring eqxvision up to date with the changes in equinox (paganpasta/eqxvision#72).

One of the changes now required is to pass the state explicitly when calling a norm layer.
Models are often recursive collections of layers with a BatchNorm somewhere inside them. For example,
ResNet is built mainly from 4 blocks, each implemented as a Sequential, and each of these Sequential blocks contains a BatchNorm.

I see two options: either modify equinox.nn.Sequential to accommodate an optional state input, or simply implement a modified Sequential in the Eqxvision repository.

@patrick-kidger, I wanted to ask your advice on the best way to support such scenarios before making any drastic changes.
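For the second option (a modified Sequential living in Eqxvision), a minimal Equinox-free sketch might look like the following. All class names here (`StatefulLayer`, `RunningSum`, `Scale`, `StatefulSequential`) are hypothetical, plain Python classes stand in for `eqx.Module`, and `RunningSum` is only a toy stand-in for BatchNorm's running statistics. Because the container itself subclasses the stateful marker, nested blocks (as in ResNet) also receive the state.

```python
# Hypothetical, Equinox-free sketch: plain Python classes stand in for
# eqx.Module so the state-threading logic can run on its own.


class StatefulLayer:
    """Marker base class: layers called as layer(x, state) -> (x, state)."""

    def __call__(self, x, state):
        raise NotImplementedError


class RunningSum(StatefulLayer):
    """Toy stand-in for BatchNorm: identity on x, accumulates x into state[name]."""

    def __init__(self, name):
        self.name = name

    def __call__(self, x, state):
        new_state = dict(state)  # functional update: never mutate the input
        new_state[self.name] = new_state.get(self.name, 0.0) + x
        return x, new_state


class Scale:
    """Stateless layer: called as layer(x)."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, x):
        return self.factor * x


class StatefulSequential(StatefulLayer):
    """Always takes and returns state. Subclassing StatefulLayer means a
    nested StatefulSequential (like ResNet's blocks) also receives state."""

    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x, state):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer):
                x, state = layer(x, state)
            else:
                x = layer(x)
        return x, state


block = StatefulSequential(Scale(2.0), RunningSum("bn1"))
model = StatefulSequential(block, Scale(3.0), RunningSum("bn2"))
out, state = model(1.0, {})
print(out, state)  # 6.0 {'bn1': 2.0, 'bn2': 6.0}
```

The trade-off of this variant is that state becomes mandatory: every caller must pass and receive it, even for models with no stateful layers.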

@patrick-kidger
Owner

Heyhey! I think we could probably do something like:

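```python
import abc

import equinox as eqx

sentinel = object()  # unique default meaning "no state was passed in"


class StatefulLayer(eqx.Module):
    """Marker base class for layers called as layer(x, state) -> (x, state)."""

    @abc.abstractmethod
    def __call__(self, x, state):
        ...


class BatchNorm(StatefulLayer):
    ...  # body elided; __call__(x, state) returns (output, new_state)


class Sequential(eqx.Module):
    layers: tuple

    def __call__(self, x, state=sentinel):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer) and state is not sentinel:
                # Thread the state through stateful layers only.
                x, state = layer(x, state)
            else:
                x = layer(x)
        # Return state only if the caller passed some in.
        if state is sentinel:
            return x
        else:
            return x, state
```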

In other words: (a) use the same trick that LayerNorm uses to optionally pass through stateful input, and (b) add an explicit way of checking whether a layer should take an extra state input.

I'd be happy to take a PR on this to Equinox if you want. (I'd also be happy to see this appear in Eqxvision; I don't have strong feelings either way.)
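To see the sentinel trick and the `isinstance` dispatch from the proposal above in action without installing Equinox, here is a minimal pure-Python sketch. The toy layers `CallCounter` and `AddOne` are hypothetical, and plain classes stand in for `eqx.Module`; this illustrates the pattern, not Equinox's actual implementation.

```python
# Equinox-free sketch of the proposal: an optional `state` argument guarded
# by a sentinel, plus an isinstance check to decide which layers receive it.
# The layer classes are hypothetical toys.

_sentinel = object()  # unique default meaning "caller passed no state"


class StatefulLayer:
    """Marker base class: called as layer(x, state) -> (x, state)."""

    def __call__(self, x, state):
        raise NotImplementedError


class CallCounter(StatefulLayer):
    """Toy stateful layer: returns x unchanged, increments state[name]."""

    def __init__(self, name):
        self.name = name

    def __call__(self, x, state):
        return x, {**state, self.name: state.get(self.name, 0) + 1}


class AddOne:
    """Stateless layer."""

    def __call__(self, x):
        return x + 1


class Sequential:
    def __init__(self, *layers):
        self.layers = layers

    def __call__(self, x, state=_sentinel):
        for layer in self.layers:
            if isinstance(layer, StatefulLayer) and state is not _sentinel:
                x, state = layer(x, state)
            else:
                x = layer(x)
        # Mirror the caller: return state only if state was passed in.
        if state is _sentinel:
            return x
        return x, state


stateless = Sequential(AddOne(), AddOne())
print(stateless(0))  # 2

stateful = Sequential(AddOne(), CallCounter("bn"), AddOne())
print(stateful(0, {}))  # (2, {'bn': 1})
```

Using a private `object()` rather than `None` as the default means any value, including `None`, can still be passed as a legitimate state. Two consequences of this sketch as written: calling `stateful(0)` without state invokes `CallCounter` as `layer(x)` and raises a `TypeError`; and a `Sequential` nested inside another is not a `StatefulLayer`, so its inner layers would not receive state unless the container also advertises itself as stateful.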

@hlzl

hlzl commented Sep 25, 2023

I think this can be closed, since #505 addresses it.

@patrick-kidger patrick-kidger added feature New feature and removed question User queries labels Sep 25, 2023
@patrick-kidger
Owner

Yup! Problem solved. :)


3 participants