Skip to content

Commit

Permalink
Move Sequential and slice_model to common.py
Browse files Browse the repository at this point in the history
  • Loading branch information
n2cholas committed Mar 23, 2021
1 parent 9f5806a commit c682ab0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 54 deletions.
55 changes: 53 additions & 2 deletions jax_resnet/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from functools import partial
from typing import Callable, Iterable, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Tuple, Union

from flax import linen as nn
import flax
import flax.linen as nn
import jax.numpy as jnp

ModuleDef = Callable[..., Callable]

Expand Down Expand Up @@ -39,3 +41,52 @@ def __call__(self, x):
if not self.is_last:
x = self.activation(x)
return x


class Sequential(nn.Module):
layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]

@nn.compact
def __call__(self, x):
for layer in self.layers:
x = layer(x)
return x


def slice_model(
resnet: Sequential,
start: int = 0,
end: Optional[int] = None,
*,
variables: Optional[flax.core.FrozenDict] = None
) -> Union[Sequence, Tuple[Sequential, flax.core.FrozenDict]]:
"""Returns ResNet with a subset of the layers from indices [start, end).
Args:
resnet: A Sequential model (i.e. a flax.linen.Module with a `layers`
attribute holding all the layers).
start: integer indicating the first layer to keep.
end: integer indicating the first layer to exclude (can be negative,
has the same semantics as negative list indexing).
variables: The flax.FrozenDict extract a subset of the layer state
from.
Returns:
If variables is provided, a tuple with the sliced model and variables,
otherwise just the sliced model.
"""
if variables is None:
return Sequential(resnet.layers[start:end])
else:
end_ind = end if end is not None else 0
if end_ind < 0:
end_ind = max(int(s.split('_')[-1]) for s in variables['params']) + end_ind

sliced_variables: Dict[str, Any] = {}
for k, var_dict in variables.items(): # usually params and batch_stats
sliced_variables[k] = {}
for i in range(start, end_ind):
if f'layers_{i}' in var_dict:
sliced_variables[k][f'layers_{i}'] = var_dict[f'layers_{i}']

return Sequential(resnet.layers[start:end]), flax.core.freeze(sliced_variables)
54 changes: 2 additions & 52 deletions jax_resnet/resnet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union
from typing import Callable, Optional, Sequence, Tuple

import flax
import jax.numpy as jnp
from flax import linen as nn

from .common import ConvBlock, ModuleDef
from .common import ConvBlock, ModuleDef, Sequential
from .splat import SplAtConv2d

STAGE_SIZES = {
Expand Down Expand Up @@ -170,16 +169,6 @@ def __call__(self, x):
return self.activation(y + skip_cls(self.strides)(x, y.shape))


class Sequential(nn.Module):
layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]

@nn.compact
def __call__(self, x):
for layer in self.layers:
x = layer(x)
return x


def ResNet(
block_cls: ModuleDef,
stage_sizes: Sequence[int],
Expand Down Expand Up @@ -209,45 +198,6 @@ def ResNet(
return Sequential(layers)


def slice_model(
resnet: Sequential,
start: int = 0,
end: Optional[int] = None,
*,
variables: Optional[flax.core.FrozenDict] = None
) -> Union[Sequence, Tuple[Sequential, flax.core.FrozenDict]]:
"""Returns ResNet with a subset of the layers from indices [start, end).
Args:
resnet: A Sequential model (i.e. a flax.linen.Module with a `layers`
attribute holding all the layers).
start: integer indicating the first layer to keep.
end: integer indicating the first layer to exclude (can be negative,
has the same semantics as negative list indexing).
variables: The flax.FrozenDict extract a subset of the layer state
from.
Returns:
If variables is provided, a tuple with the sliced model and variables,
otherwise just the sliced model.
"""
if variables is None:
return Sequential(resnet.layers[start:end])
else:
end_ind = end if end is not None else 0
if end_ind < 0:
end_ind = max(int(s.split('_')[-1]) for s in variables['params']) + end_ind

sliced_variables: Dict[str, Any] = {}
for k, var_dict in variables.items(): # usually params and batch_stats
sliced_variables[k] = {}
for i in range(start, end_ind):
if f'layers_{i}' in var_dict:
sliced_variables[k][f'layers_{i}'] = var_dict[f'layers_{i}']

return Sequential(resnet.layers[start:end]), flax.core.freeze(sliced_variables)


# yapf: disable
ResNet18 = partial(ResNet, stage_sizes=STAGE_SIZES[18],
stem_cls=ResNetStem, block_cls=ResNetBlock)
Expand Down

0 comments on commit c682ab0

Please sign in to comment.