diff --git a/README.md b/README.md
index 4bb3883..24b72d7 100644
--- a/README.md
+++ b/README.md
@@ -38,8 +38,7 @@ ResNeSt50,
 variables = pretrained_resnest(50)
 model = ResNeSt50()
 out = model.apply(variables,
                   jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
-                  mutable=False,  # Ensure `batch_stats` aren't updated.
-                  train=False)  # Use running mean/var for batchnorm.
+                  mutable=False)  # Ensure `batch_stats` aren't updated.
 ```
 
 You must install PyTorch yourself
@@ -52,6 +51,16 @@ match exactly. Feel free to use it via
 `pretrained_resnetd` (should be fine for transfer learning). You must install
 fast.ai yourself ([instructions](https://docs.fast.ai/)) to use this function.
 
+### Transfer Learning
+
+To extract a subset of the model, you can use
+`Sequential(model.layers[start:end])`.
+
+The `slice_variables` function (found in
+[`common.py`](https://github.com/n2cholas/jax-resnet/blob/main/jax_resnet/common.py))
+allows you to extract the corresponding subset of the variables dict. Check out
+that docstring for more information.
+
 ## References
 
 - [Deep Residual Learning for Image Recognition. Kaiming He, Xiangyu Zhang,
diff --git a/jax_resnet/common.py b/jax_resnet/common.py
index 82458b4..0066b24 100644
--- a/jax_resnet/common.py
+++ b/jax_resnet/common.py
@@ -1,7 +1,10 @@
 from functools import partial
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple,
+                    Union)
 
-from flax import linen as nn
+import flax
+import flax.linen as nn
+import jax.numpy as jnp
 
 ModuleDef = Callable[..., Callable]
 
@@ -21,7 +24,7 @@ class ConvBlock(nn.Module):
     force_conv_bias: bool = False
 
     @nn.compact
-    def __call__(self, x, train: bool = True):
+    def __call__(self, x):
         x = self.conv_cls(
             self.n_filters,
             self.kernel_size,
@@ -33,8 +36,68 @@ def __call__(self, x, train: bool = True):
         if self.norm_cls:
             scale_init = (nn.initializers.zeros
                           if self.is_last else nn.initializers.ones)
-            x = self.norm_cls(use_running_average=not train, scale_init=scale_init)(x)
+            mutable = self.is_mutable_collection('batch_stats')
+            x = self.norm_cls(use_running_average=not mutable, scale_init=scale_init)(x)
 
         if not self.is_last:
             x = self.activation(x)
         return x
+
+
+class Sequential(nn.Module):
+    layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]
+
+    @nn.compact
+    def __call__(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+def slice_variables(variables: Mapping[str, Any],
+                    start: int = 0,
+                    end: Optional[int] = None) -> flax.core.FrozenDict:
+    """Returns the variables dict corresponding to a sliced model.
+
+    You can retrieve the model corresponding to the sliced variables via
+    `Sequential(model.layers[start:end])`.
+
+    The variables mapping should have the same structure as a Sequential
+    model's variable dict (based on Flax):
+
+    ```python
+    variables = {
+        'group1': ['layers_a', 'layers_b', ...]
+        'group2': ['layers_a', 'layers_b', ...]
+        ...,
+    }
+
+    Typically, `'group1'` and `'group2'` would be `'params'` and
+    `'batch_stats'`, but they don't have to be. `a, b, ...` correspond to the
+    integer indices of the layers.
+
+    Args:
+        variables: A mapping (typically a flax.core.FrozenDict) containing the
+            model parameters and state.
+        start: integer indicating the first layer to keep.
+        end: integer indicating the first layer to exclude (can be negative,
+            has the same semantics as negative list indexing).
+
+    Returns:
+        A flax.core.FrozenDict with the subset of parameters/state requested.
+    """
+    last_ind = max(int(s.split('_')[-1]) for s in variables['params'])
+    if end is None:
+        end = last_ind + 1
+    elif end < 0:
+        end += last_ind + 1
+
+    sliced_variables: Dict[str, Any] = {}
+    for k, var_dict in variables.items():  # usually params and batch_stats
+        sliced_variables[k] = {
+            f'layers_{i-start}': var_dict[f'layers_{i}']
+            for i in range(start, end)
+            if f'layers_{i}' in var_dict
+        }
+
+    return flax.core.freeze(sliced_variables)
diff --git a/jax_resnet/pretrained.py b/jax_resnet/pretrained.py
index 8729644..25e7110 100644
--- a/jax_resnet/pretrained.py
+++ b/jax_resnet/pretrained.py
@@ -37,13 +37,12 @@ def pretrained_resnet(size: int) -> Tuple[ModuleDef, Mapping]:
     add_bn = _get_add_bn(pt2jax)
 
     def bname(num):
-        return f'ResNetBottleneckBlock_{num}'
+        return f'layers_{num}'
 
-    pt2jax['conv1.weight'] = ('params', 'ResNetStem_0', 'ConvBlock_0', 'Conv_0',
-                              'kernel')
-    add_bn('bn1', ('ResNetStem_0', 'ConvBlock_0', 'BatchNorm_0'))
+    pt2jax['conv1.weight'] = ('params', 'layers_0', 'ConvBlock_0', 'Conv_0', 'kernel')
+    add_bn('bn1', ('layers_0', 'ConvBlock_0', 'BatchNorm_0'))
 
-    b_ind = 0  # block_ind
+    b_ind = 2  # block_ind
     for b, n_blocks in enumerate(resnet.STAGE_SIZES[size], 1):
         for i in range(n_blocks):
             for j in range(3):
@@ -64,8 +63,9 @@ def bname(num):
 
             b_ind += 1
 
-    pt2jax['fc.weight'] = ('params', 'Dense_0', 'kernel')
-    pt2jax['fc.bias'] = ('params', 'Dense_0', 'bias')
+    b_ind += 1
+    pt2jax['fc.weight'] = ('params', bname(b_ind), 'kernel')
+    pt2jax['fc.bias'] = ('params', bname(b_ind), 'bias')
 
     variables = _pytorch_to_jax_params(pt2jax, state_dict, ('fc.weight',))
     model_cls = partial(getattr(resnet, f'ResNet{size}'), n_classes=1000)
@@ -106,12 +106,12 @@ def add_convblock(pt_layer, jax_layer):
         add_bn(f'{pt_layer}.1', (*jax_layer, 'BatchNorm_0'))
 
     def bname(num):
-        return f'ResNetDBottleneckBlock_{num}'
+        return f'layers_{num}'
 
     for i in range(3):
-        add_convblock(i, ('ResNetDStem_0', f'ConvBlock_{i}'))
+        add_convblock(i, ('layers_0', f'ConvBlock_{i}'))
 
-    b_ind = 0  # block_ind
+    b_ind = 2  # block_ind
     for b, n_blocks in enumerate(resnet.STAGE_SIZES[size], 4):
         for i in range(n_blocks):
             for j in range(3):
@@ -126,8 +126,9 @@ def bname(num):
 
             b_ind += 1
 
-    pt2jax['11.weight'] = ('params', 'Dense_0', 'kernel')
-    pt2jax['11.bias'] = ('params', 'Dense_0', 'bias')
+    b_ind += 1
+    pt2jax['11.weight'] = ('params', bname(b_ind), 'kernel')
+    pt2jax['11.bias'] = ('params', bname(b_ind), 'bias')
 
     variables = _pytorch_to_jax_params(pt2jax, state_dict, ('11.weight',))
     model_cls = partial(getattr(resnet, f'ResNetD{size}'), n_classes=1000)
@@ -159,20 +160,17 @@ def pretrained_resnest(size: int) -> Tuple[ModuleDef, Mapping]:
     add_bn = _get_add_bn(pt2jax)
 
     def bname(num):
-        return f'ResNeStBottleneckBlock_{num}'
+        return f'layers_{num}'
 
     # Stem
-    pt2jax['conv1.0.weight'] = ('params', 'ResNeStStem_0', 'ConvBlock_0', 'Conv_0',
-                                'kernel')
-    add_bn('conv1.1', ('ResNeStStem_0', 'ConvBlock_0', 'BatchNorm_0'))
-    pt2jax['conv1.3.weight'] = ('params', 'ResNeStStem_0', 'ConvBlock_1', 'Conv_0',
-                                'kernel')
-    add_bn('conv1.4', ('ResNeStStem_0', 'ConvBlock_1', 'BatchNorm_0'))
-    pt2jax['conv1.6.weight'] = ('params', 'ResNeStStem_0', 'ConvBlock_2', 'Conv_0',
-                                'kernel')
-    add_bn('bn1', ('ResNeStStem_0', 'ConvBlock_2', 'BatchNorm_0'))
-
-    b_ind = 0  # block_ind
+    pt2jax['conv1.0.weight'] = ('params', 'layers_0', 'ConvBlock_0', 'Conv_0', 'kernel')
+    add_bn('conv1.1', ('layers_0', 'ConvBlock_0', 'BatchNorm_0'))
+    pt2jax['conv1.3.weight'] = ('params', 'layers_0', 'ConvBlock_1', 'Conv_0', 'kernel')
+    add_bn('conv1.4', ('layers_0', 'ConvBlock_1', 'BatchNorm_0'))
+    pt2jax['conv1.6.weight'] = ('params', 'layers_0', 'ConvBlock_2', 'Conv_0', 'kernel')
+    add_bn('bn1', ('layers_0', 'ConvBlock_2', 'BatchNorm_0'))
+
+    b_ind = 2  # block_ind
     for b, n_blocks in enumerate(resnet.STAGE_SIZES[size], 1):
         for i in range(n_blocks):
             pt2jax[f'layer{b}.{i}.conv1.weight'] = ('params', bname(b_ind),
@@ -215,8 +213,9 @@ def bname(num):
 
             b_ind += 1
 
-    pt2jax['fc.weight'] = ('params', 'Dense_0', 'kernel')
-    pt2jax['fc.bias'] = ('params', 'Dense_0', 'bias')
+    b_ind += 1
+    pt2jax['fc.weight'] = ('params', bname(b_ind), 'kernel')
+    pt2jax['fc.bias'] = ('params', bname(b_ind), 'bias')
 
     variables = _pytorch_to_jax_params(pt2jax, state_dict, ('fc.weight',))
diff --git a/jax_resnet/resnet.py b/jax_resnet/resnet.py
index 2fe0ba1..77f6b4e 100644
--- a/jax_resnet/resnet.py
+++ b/jax_resnet/resnet.py
@@ -1,9 +1,10 @@
 from functools import partial
 from typing import Callable, Optional, Sequence, Tuple
 
+import jax.numpy as jnp
 from flax import linen as nn
 
-from .common import ConvBlock, ModuleDef
+from .common import ConvBlock, ModuleDef, Sequential
 from .splat import SplAtConv2d
 
 STAGE_SIZES = {
@@ -21,11 +22,11 @@ class ResNetStem(nn.Module):
     conv_block_cls: ModuleDef = ConvBlock
 
     @nn.compact
-    def __call__(self, x, train: bool = True):
+    def __call__(self, x):
         return self.conv_block_cls(64,
                                    kernel_size=(7, 7),
                                    strides=(2, 2),
-                                   padding=[(3, 3), (3, 3)])(x, train=train)
+                                   padding=[(3, 3), (3, 3)])(x)
 
 
 class ResNetDStem(nn.Module):
@@ -36,13 +37,13 @@ class ResNetDStem(nn.Module):
     adaptive_first_width: bool = True  # TODO: Find better name.
 
     @nn.compact
-    def __call__(self, x, train: bool = True):
+    def __call__(self, x):
         cls = partial(self.conv_block_cls, kernel_size=(3, 3), padding=((1, 1), (1, 1)))
         first_width = (8 * (x.shape[-1] + 1)
                        if self.adaptive_first_width else self.stem_width)
-        x = cls(first_width, strides=(2, 2))(x, train=train)
-        x = cls(self.stem_width, strides=(1, 1))(x, train=train)
-        x = cls(self.stem_width * 2, strides=(1, 1))(x, train=train)
+        x = cls(first_width, strides=(2, 2))(x)
+        x = cls(self.stem_width, strides=(1, 1))(x)
+        x = cls(self.stem_width * 2, strides=(1, 1))(x)
         return x
 
 
@@ -56,12 +57,12 @@ class ResNetSkipConnection(nn.Module):
     conv_block_cls: ModuleDef = ConvBlock
 
    @nn.compact
-    def __call__(self, x, out_shape, train: bool = True):
+    def __call__(self, x, out_shape):
         if x.shape != out_shape:
             x = self.conv_block_cls(out_shape[-1],
                                     kernel_size=(1, 1),
                                     strides=self.strides,
-                                    activation=lambda y: y)(x, train=train)
+                                    activation=lambda y: y)(x)
         return x
 
 
@@ -70,12 +71,11 @@ class ResNetDSkipConnection(nn.Module):
     conv_block_cls: ModuleDef = ConvBlock
 
     @nn.compact
-    def __call__(self, x, out_shape, train: bool = True):
+    def __call__(self, x, out_shape):
         if self.strides != (1, 1):
             x = nn.avg_pool(x, (2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
         if x.shape[-1] != out_shape[-1]:
-            x = self.conv_block_cls(out_shape[-1], (1, 1),
-                                    activation=lambda y: y)(x, train=train)
+            x = self.conv_block_cls(out_shape[-1], (1, 1), activation=lambda y: y)(x)
         return x
 
 
@@ -93,13 +93,14 @@ class ResNetBlock(nn.Module):
     skip_cls: ModuleDef = ResNetSkipConnection
 
     @nn.compact
-    def __call__(self, x, train: bool = True):
+    def __call__(self, x):
+        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
         y = self.conv_block_cls(self.n_hidden,
                                 padding=[(1, 1), (1, 1)],
-                                strides=self.strides)(x, train=train)
+                                strides=self.strides)(x)
         y = self.conv_block_cls(self.n_hidden,
                                 padding=[(1, 1), (1, 1)],
-                                is_last=True)(y, train=train)
-        return self.activation(y + self.skip_cls(self.strides)(x, y.shape, train=train))
+                                is_last=True)(y)
+        return self.activation(y + skip_cls(self.strides)(x, y.shape))
 
 
 class ResNetBottleneckBlock(nn.Module):
@@ -111,14 +112,14 @@ class ResNetBottleneckBlock(nn.Module):
     skip_cls: ModuleDef = ResNetSkipConnection
 
     @nn.compact
-    def __call__(self, x, train: bool = True):
-        y = self.conv_block_cls(self.n_hidden, kernel_size=(1, 1))(x, train=train)
+    def __call__(self, x):
+        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
+        y = self.conv_block_cls(self.n_hidden, kernel_size=(1, 1))(x)
         y = self.conv_block_cls(self.n_hidden,
                                 strides=self.strides,
-                                padding=((1, 1), (1, 1)))(y, train=train)
-        y = self.conv_block_cls(self.n_hidden * 4, kernel_size=(1, 1),
-                                is_last=True)(y, train=train)
-        return self.activation(y + self.skip_cls(self.strides)(x, y.shape, train=train))
+                                padding=((1, 1), (1, 1)))(y)
+        y = self.conv_block_cls(self.n_hidden * 4, kernel_size=(1, 1), is_last=True)(y)
+        return self.activation(y + skip_cls(self.strides)(x, y.shape))
 
 
 class ResNetDBlock(ResNetBlock):
@@ -139,15 +140,16 @@ class ResNeStBottleneckBlock(ResNetBottleneckBlock):
     splat_cls: ModuleDef = SplAtConv2d
 
     @nn.compact
-    def __call__(self, x, train: bool = True):
+    def __call__(self, x):
         # TODO: implement groups != 1 and radix != 2
         assert self.groups == 1
         assert self.radix == 2
 
+        skip_cls = partial(self.skip_cls, conv_block_cls=self.conv_block_cls)
         n_filters = self.n_hidden * 4
         group_width = int(self.n_hidden * (self.bottleneck_width / 64.)) * self.groups
 
-        y = self.conv_block_cls(group_width, kernel_size=(1, 1))(x, train=train)
+        y = self.conv_block_cls(group_width, kernel_size=(1, 1))(x)
 
         if self.strides != (1, 1) and self.avg_pool_first:
             y = nn.avg_pool(y, (3, 3), strides=self.strides, padding=[(1, 1), (1, 1)])
@@ -157,61 +159,43 @@ def __call__(self, x, train: bool = True):
                            strides=(1, 1),
                           padding=[(1, 1), (1, 1)],
                            groups=self.groups,
-                           radix=self.radix)(y, train=train)
+                           radix=self.radix)(y)
 
         if self.strides != (1, 1) and not self.avg_pool_first:
             y = nn.avg_pool(y, (3, 3), strides=self.strides, padding=[(1, 1), (1, 1)])
 
-        y = self.conv_block_cls(n_filters, kernel_size=(1, 1),
-                                is_last=True)(y, train=train)
+        y = self.conv_block_cls(n_filters, kernel_size=(1, 1), is_last=True)(y)
 
-        return self.activation(y + self.skip_cls(self.strides)(x, y.shape, train=train))
+        return self.activation(y + skip_cls(self.strides)(x, y.shape))
 
 
-class ResNet(nn.Module):
-    block_cls: ModuleDef
-    stage_sizes: Sequence[int]
-    n_classes: int
-
-    conv_cls: ModuleDef = nn.Conv
-    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9)
-
-    conv_block_cls: ModuleDef = ConvBlock
-    stem_cls: ModuleDef = ResNetStem
+def ResNet(
+    block_cls: ModuleDef,
+    stage_sizes: Sequence[int],
+    n_classes: int,
+    conv_cls: ModuleDef = nn.Conv,
+    norm_cls: Optional[ModuleDef] = partial(nn.BatchNorm, momentum=0.9),
+    conv_block_cls: ModuleDef = ConvBlock,
+    stem_cls: ModuleDef = ResNetStem,
     pool_fn: Callable = partial(nn.max_pool,
                                 window_shape=(3, 3),
                                 strides=(2, 2),
-                                padding=((1, 1), (1, 1)))
-
-    # When True, the model will propogate the top-level conv_cls and norm_cls
-    # through the conv_block_cls to all the submodules (stem, bottleneck, etc).
-    consistent_conv_block: bool = False
-    backbone_only: bool = False  # When True, no GlobalAveragePool or Dense
-
-    @nn.compact
-    def __call__(self, x, train: bool = True):
-        conv_block_cls = partial(self.conv_block_cls,
-                                 conv_cls=self.conv_cls,
-                                 norm_cls=self.norm_cls)
-        stem_cls, block_cls = self.stem_cls, self.block_cls
-        if self.consistent_conv_block:
-            stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
-            # TODO: set conv_block_cls for skip_cls
-            block_cls = partial(block_cls, conv_block_cls=conv_block_cls)
-
-        x = stem_cls()(x, train=train)
-        x = self.pool_fn(x)
-
-        for i, n_blocks in enumerate(self.stage_sizes):
-            for b in range(n_blocks):
-                strides = (1, 1) if i == 0 or b != 0 else (2, 2)
-                x = block_cls(n_hidden=2**(i + 6), strides=strides)(x, train=train)
-
-        if self.backbone_only:
-            return x
-
-        x = x.mean((-2, -3))  # global average pool
-        return nn.Dense(self.n_classes)(x)
+                                padding=((1, 1), (1, 1))),
+):
+    conv_block_cls = partial(conv_block_cls, conv_cls=conv_cls, norm_cls=norm_cls)
+    stem_cls = partial(stem_cls, conv_block_cls=conv_block_cls)
+    block_cls = partial(block_cls, conv_block_cls=conv_block_cls)
+
+    layers = [stem_cls(), pool_fn]
+
+    for i, n_blocks in enumerate(stage_sizes):
+        for b in range(n_blocks):
+            strides = (1, 1) if i == 0 or b != 0 else (2, 2)
+            layers.append(block_cls(n_hidden=2**(i + 6), strides=strides))
+
+    layers.append(partial(jnp.mean, axis=(1, 2)))  # global average pool
+    layers.append(nn.Dense(n_classes))
+    return Sequential(layers)
 
 
 # yapf: disable
diff --git a/jax_resnet/splat.py b/jax_resnet/splat.py
index dc9fa08..0865167 100644
--- a/jax_resnet/splat.py
+++ b/jax_resnet/splat.py
@@ -33,7 +33,7 @@ class SplAtConv2d(nn.Module):
     match_reference: bool = False
 
     @nn.compact
-    def __call__(self, x, train: bool = True):
+    def __call__(self, x):
         inter_channels = max(x.shape[-1] * self.radix // self.reduction_factor, 32)
 
         conv_block = self.conv_block_cls(self.channels * self.radix,
@@ -42,7 +42,7 @@ def __call__(self, x, train: bool = True):
                                          groups=self.groups * self.radix,
                                          padding=self.padding)
         conv_cls = conv_block.conv_cls  # type: ignore
-        x = conv_block(x, train=train)
+        x = conv_block(x)
 
         if self.radix > 1:
             # torch split takes split_size: int(rchannel//self.radix)
@@ -59,8 +59,7 @@ def __call__(self, x, train: bool = True):
         gap = self.conv_block_cls(inter_channels,
                                   kernel_size=(1, 1),
                                   groups=self.cardinality,
-                                  force_conv_bias=self.match_reference)(gap,
-                                                                        train=train)
+                                  force_conv_bias=self.match_reference)(gap)
 
         attn = conv_cls(self.channels * self.radix,
                         kernel_size=(1, 1),
diff --git a/setup.cfg b/setup.cfg
index b4eea04..e9be3e1 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -15,5 +15,8 @@ ignore = F405
 max-line-length = 88
 max-complexity = 15
 
+[isort]
+line_length = 88
+
 [mypy]
 ignore_missing_imports = True
diff --git a/tests/test_pretrained.py b/tests/test_pretrained.py
index a73a692..15f53c1 100644
--- a/tests/test_pretrained.py
+++ b/tests/test_pretrained.py
@@ -52,7 +52,7 @@ def _test_pretrained(size, pretrained_fn):
                             pretrained_vars)
     assert jax.tree_util.tree_all(eq_tree)
 
-    out = model.apply(pretrained_vars, arr, mutable=False, train=False)
+    out = model.apply(pretrained_vars, arr, mutable=False)
     assert out.shape == (1, 1000)
 
 
@@ -101,7 +101,7 @@ def _test_pretrained_resnet_activations(size):
         bottleneck.register_forward_hook(ptracker)  # Bottleneck
     pnet.relu.register_forward_hook(ptracker)  # Stem ReLU
 
-    jout = jnet.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False, train=False)
+    jout = jnet.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
     with torch.no_grad():
         pout = pnet(torch.ones((1, 3, 224, 224))).numpy()
 
@@ -149,7 +149,7 @@ def test_pretrained_resnetd_activation_shapes(size):
     pnet[2].register_forward_hook(ptracker)  # Stem
 
-    jout = jnet.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False, train=False)
+    jout = jnet.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
     with torch.no_grad():
         pout = pnet(torch.ones((1, 3, 224, 224))).numpy()
 
@@ -199,7 +199,7 @@ def _test_pretrained_resnest_activations(size):
     pnet.conv1[6].register_forward_hook(ptracker)  # Stem Conv
     pnet.relu.register_forward_hook(ptracker)  # Stem Output
 
-    jout = jnet.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False, train=False)
+    jout = jnet.apply(variables, jnp.ones((1, 224, 224, 3)), mutable=False)
     with torch.no_grad():
         pout = pnet(torch.ones((1, 3, 224, 224))).numpy()
 
diff --git a/tests/test_resnet.py b/tests/test_resnet.py
index 4428e28..db2d227 100644
--- a/tests/test_resnet.py
+++ b/tests/test_resnet.py
@@ -103,22 +103,26 @@ def test_resnest_stem_param_count():
     assert n_params(ResNeStStem()) == 112832
 
 
-@pytest.mark.parametrize(
-    'cls', [ResNet18, ResNet50, ResNetD18, ResNetD50, ResNeSt50, ResNeSt50Fast])
-def test_consistent_convblock(cls):
-    # TODO: improve this test, right now only checks if it runs
-    model = cls(n_classes=10, consistent_conv_block=True)
-    init_array = jnp.ones((2, 224, 224, 3), dtype=jnp.float32)
-    variables = model.init(jax.random.PRNGKey(0), init_array)
-    model.apply(variables, init_array, mutable=False, train=False)
-
-
-@pytest.mark.parametrize(
-    'cls', [ResNet18, ResNet50, ResNetD18, ResNetD50, ResNeSt50, ResNeSt50Fast])
-def test_backbone_only(cls):
-    # TODO: improve this test, right now only checks if it runs
-    model = cls(n_classes=10, backbone_only=True)
-    init_array = jnp.ones((2, 224, 224, 3), dtype=jnp.float32)
-    variables = model.init(jax.random.PRNGKey(0), init_array)
-    out = model.apply(variables, init_array, mutable=False, train=False)
-    assert out.ndim == 4
+@pytest.mark.parametrize('start,end', [(0, 5), (0, None), (0, -3), (4, -2), (3, -1),
+                                       (2, None)])
+def test_slice_variables(start, end):
+    model = ResNet18(n_classes=10)
+    key = jax.random.PRNGKey(0)
+
+    variables = model.init(key, jnp.ones((1, 224, 224, 3)))
+    sliced_vars = slice_variables(variables, start, end)
+    sliced_model = Sequential(model.layers[start:end])
+
+    # Need the correct number of input channels for slice:
+    first = variables['params'][f'layers_{start}']['ConvBlock_0']['Conv_0']['kernel']
+    slice_inp = jnp.ones((1, 224, 224, first.shape[2]))
+    exp_sliced_vars = sliced_model.init(key, slice_inp)
+
+    assert set(sliced_vars['params'].keys()) == set(exp_sliced_vars['params'].keys())
+    assert set(sliced_vars['batch_stats'].keys()) == set(
+        exp_sliced_vars['batch_stats'].keys())
+
+    assert jax.tree_map(jnp.shape,
+                        sliced_vars) == jax.tree_map(jnp.shape, exp_sliced_vars)
+
+    sliced_model.apply(sliced_vars, slice_inp, mutable=False)
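As a companion to the diff above, here is a minimal sketch of how the new `Sequential`/`slice_variables` API could be combined for transfer learning, following the README example and the slicing test. It is illustrative only and not part of the change; the import paths and the `[:-2]` slice (dropping the global-average-pool and `Dense` layers produced by the `ResNet` factory) are assumptions based on the files shown here.

```python
import jax.numpy as jnp

from jax_resnet import pretrained_resnest  # as used in the README example
from jax_resnet.common import Sequential, slice_variables

# Load a pretrained ResNeSt-50 and its variables (README usage).
ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()

# Drop the final global-average-pool and Dense layers to get a feature
# extractor; slice_variables keeps the matching params/batch_stats entries.
backbone = Sequential(model.layers[:-2])
backbone_vars = slice_variables(variables, end=-2)

# mutable=False keeps `batch_stats` frozen, so BatchNorm uses running stats.
features = backbone.apply(backbone_vars, jnp.ones((1, 224, 224, 3)), mutable=False)
```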