Implement Transfer Learning API #2

Merged 8 commits on Mar 23, 2021
13 changes: 11 additions & 2 deletions README.md
@@ -38,8 +38,7 @@ ResNeSt50, variables = pretrained_resnest(50)
 model = ResNeSt50()
 out = model.apply(variables,
                   jnp.ones((32, 224, 224, 3)),  # ImageNet sized inputs.
-                  mutable=False,  # Ensure `batch_stats` aren't updated.
-                  train=False)  # Use running mean/var for batchnorm.
+                  mutable=False)  # Ensure `batch_stats` aren't updated.
 ```
 
 You must install PyTorch yourself
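
Note: the dropped `train` kwarg reflects the new behavior in `ConvBlock` (see `common.py` below): train vs. eval mode is now inferred from whether the `batch_stats` collection is marked mutable in `apply`. A minimal sketch of the two call modes, using standard Flax `apply` semantics (the `new_state` name is illustrative):

```python
import jax.numpy as jnp
from jax_resnet import pretrained_resnest

ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()
x = jnp.ones((32, 224, 224, 3))

# Inference: `batch_stats` stays immutable, so BatchNorm uses running stats.
out = model.apply(variables, x, mutable=False)

# Training: marking `batch_stats` mutable switches BatchNorm to batch
# statistics and returns the updated state alongside the output.
out, new_state = model.apply(variables, x, mutable=['batch_stats'])
```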
@@ -52,6 +51,16 @@ match exactly. Feel free to use it via `pretrained_resnetd` (should be fine for
 transfer learning). You must install fast.ai yourself
 ([instructions](https://docs.fast.ai/)) to use this function.
 
+### Transfer Learning
+
+To extract a subset of the model, you can use
+`Sequential(model.layers[start:end])`.
+
+The `slice_variables` function (found in
+[`common.py`](https://github.com/n2cholas/jax-resnet/blob/main/jax_resnet/common.py))
+allows you to extract the corresponding subset of the variables dict. Check out
+that docstring for more information.
+
 ## References
 
 - [Deep Residual Learning for Image Recognition. Kaiming He, Xiangyu Zhang,
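
To make the new README section concrete, here is a sketch of the intended transfer-learning workflow. The `Sequential`/`slice_variables` pairing comes straight from this PR; the `[:-1]`/`end=-1` split assumes the classification head is the last indexed layer, and the import path for `slice_variables` is taken from the linked `common.py`:

```python
import jax.numpy as jnp
from jax_resnet import pretrained_resnest
from jax_resnet.common import Sequential, slice_variables

ResNeSt50, variables = pretrained_resnest(50)
model = ResNeSt50()

# Keep everything up to (but excluding) the final Dense head.
backbone = Sequential(model.layers[:-1])
backbone_vars = slice_variables(variables, end=-1)

# Extract features for a downstream task; eval mode since batch_stats
# is immutable here.
features = backbone.apply(backbone_vars, jnp.ones((1, 224, 224, 3)),
                          mutable=False)
```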
71 changes: 67 additions & 4 deletions jax_resnet/common.py
@@ -1,7 +1,10 @@
 from functools import partial
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import (Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple,
+                    Union)
 
-from flax import linen as nn
+import flax
+import flax.linen as nn
+import jax.numpy as jnp
 
 ModuleDef = Callable[..., Callable]
 
@@ -21,7 +24,7 @@ class ConvBlock(nn.Module):
     force_conv_bias: bool = False
 
     @nn.compact
-    def __call__(self, x, train: bool = True):
+    def __call__(self, x):
         x = self.conv_cls(
             self.n_filters,
             self.kernel_size,
@@ -33,8 +36,68 @@ def __call__(self, x, train: bool = True):
         if self.norm_cls:
             scale_init = (nn.initializers.zeros
                           if self.is_last else nn.initializers.ones)
-            x = self.norm_cls(use_running_average=not train, scale_init=scale_init)(x)
+            mutable = self.is_mutable_collection('batch_stats')
+            x = self.norm_cls(use_running_average=not mutable, scale_init=scale_init)(x)
 
         if not self.is_last:
             x = self.activation(x)
         return x
+
+
+class Sequential(nn.Module):
+    layers: Sequence[Union[nn.Module, Callable[[jnp.ndarray], jnp.ndarray]]]
+
+    @nn.compact
+    def __call__(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+def slice_variables(variables: Mapping[str, Any],
+                    start: int = 0,
+                    end: Optional[int] = None) -> flax.core.FrozenDict:
+    """Returns the variables dict corresponding to a sliced model.
+
+    You can retrieve the model corresponding to the sliced variables via
+    `Sequential(model.layers[start:end])`.
+
+    The variables mapping should have the same structure as a Sequential
+    model's variable dict (based on Flax):
+
+    ```python
+    variables = {
+        'group1': ['layers_a', 'layers_b', ...],
+        'group2': ['layers_a', 'layers_b', ...],
+        ...,
+    }
+    ```
+
+    Typically, `'group1'` and `'group2'` would be `'params'` and
+    `'batch_stats'`, but they don't have to be. `a, b, ...` correspond to the
+    integer indices of the layers.
+
+    Args:
+        variables: A mapping (typically a flax.core.FrozenDict) containing the
+            model parameters and state.
+        start: integer indicating the first layer to keep.
+        end: integer indicating the first layer to exclude (can be negative,
+            has the same semantics as negative list indexing).
+
+    Returns:
+        A flax.core.FrozenDict with the subset of parameters/state requested.
+    """
+    last_ind = max(int(s.split('_')[-1]) for s in variables['params'])
+    if end is None:
+        end = last_ind + 1
+    elif end < 0:
+        end += last_ind + 1
+
+    sliced_variables: Dict[str, Any] = {}
+    for k, var_dict in variables.items():  # usually params and batch_stats
+        sliced_variables[k] = {
+            f'layers_{i-start}': var_dict[f'layers_{i}']
+            for i in range(start, end)
+            if f'layers_{i}' in var_dict
+        }
+
+    return flax.core.freeze(sliced_variables)
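
To make the re-indexing above concrete, a toy example (group names and values invented for illustration):

```python
import flax
from jax_resnet.common import slice_variables

toy = flax.core.freeze({
    'params': {'layers_0': {'w': 0}, 'layers_1': {'w': 1}, 'layers_2': {'w': 2}},
    'batch_stats': {'layers_1': {'mean': 0.0}},  # not every layer has state
})

sliced = slice_variables(toy, start=1, end=3)
# {'params': {'layers_0': {'w': 1}, 'layers_1': {'w': 2}},
#  'batch_stats': {'layers_0': {'mean': 0.0}}}
```

The `if f'layers_{i}' in var_dict` guard is what lets variable-free layers (like the plain pooling functions in the pretrained models) simply leave gaps in the index space.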
51 changes: 25 additions & 26 deletions jax_resnet/pretrained.py
@@ -37,13 +37,12 @@ def pretrained_resnet(size: int) -> Tuple[ModuleDef, Mapping]:
     add_bn = _get_add_bn(pt2jax)
 
     def bname(num):
-        return f'ResNetBottleneckBlock_{num}'
+        return f'layers_{num}'
 
-    pt2jax['conv1.weight'] = ('params', 'ResNetStem_0', 'ConvBlock_0', 'Conv_0',
-                              'kernel')
-    add_bn('bn1', ('ResNetStem_0', 'ConvBlock_0', 'BatchNorm_0'))
+    pt2jax['conv1.weight'] = ('params', 'layers_0', 'ConvBlock_0', 'Conv_0', 'kernel')
+    add_bn('bn1', ('layers_0', 'ConvBlock_0', 'BatchNorm_0'))
 
-    b_ind = 0  # block_ind
+    b_ind = 2  # block_ind
     for b, n_blocks in enumerate(resnet.STAGE_SIZES[size], 1):
         for i in range(n_blocks):
             for j in range(3):
@@ -64,8 +63,9 @@ def bname(num):
 
         b_ind += 1
 
-    pt2jax['fc.weight'] = ('params', 'Dense_0', 'kernel')
-    pt2jax['fc.bias'] = ('params', 'Dense_0', 'bias')
+    b_ind += 1
+    pt2jax['fc.weight'] = ('params', bname(b_ind), 'kernel')
+    pt2jax['fc.bias'] = ('params', bname(b_ind), 'bias')
 
     variables = _pytorch_to_jax_params(pt2jax, state_dict, ('fc.weight',))
     model_cls = partial(getattr(resnet, f'ResNet{size}'), n_classes=1000)
@@ -106,12 +106,12 @@ def add_convblock(pt_layer, jax_layer):
         add_bn(f'{pt_layer}.1', (*jax_layer, 'BatchNorm_0'))
 
     def bname(num):
-        return f'ResNetDBottleneckBlock_{num}'
+        return f'layers_{num}'
 
     for i in range(3):
-        add_convblock(i, ('ResNetDStem_0', f'ConvBlock_{i}'))
+        add_convblock(i, ('layers_0', f'ConvBlock_{i}'))
 
-    b_ind = 0  # block_ind
+    b_ind = 2  # block_ind
     for b, n_blocks in enumerate(resnet.STAGE_SIZES[size], 4):
         for i in range(n_blocks):
             for j in range(3):
@@ -126,8 +126,9 @@ def bname(num):
 
         b_ind += 1
 
-    pt2jax['11.weight'] = ('params', 'Dense_0', 'kernel')
-    pt2jax['11.bias'] = ('params', 'Dense_0', 'bias')
+    b_ind += 1
+    pt2jax['11.weight'] = ('params', bname(b_ind), 'kernel')
+    pt2jax['11.bias'] = ('params', bname(b_ind), 'bias')
 
     variables = _pytorch_to_jax_params(pt2jax, state_dict, ('11.weight',))
     model_cls = partial(getattr(resnet, f'ResNetD{size}'), n_classes=1000)
@@ -159,20 +160,17 @@ def pretrained_resnest(size: int) -> Tuple[ModuleDef, Mapping]:
     add_bn = _get_add_bn(pt2jax)
 
     def bname(num):
-        return f'ResNeStBottleneckBlock_{num}'
+        return f'layers_{num}'
 
     # Stem
-    pt2jax['conv1.0.weight'] = ('params', 'ResNeStStem_0', 'ConvBlock_0', 'Conv_0',
-                                'kernel')
-    add_bn('conv1.1', ('ResNeStStem_0', 'ConvBlock_0', 'BatchNorm_0'))
-    pt2jax['conv1.3.weight'] = ('params', 'ResNeStStem_0', 'ConvBlock_1', 'Conv_0',
-                                'kernel')
-    add_bn('conv1.4', ('ResNeStStem_0', 'ConvBlock_1', 'BatchNorm_0'))
-    pt2jax['conv1.6.weight'] = ('params', 'ResNeStStem_0', 'ConvBlock_2', 'Conv_0',
-                                'kernel')
-    add_bn('bn1', ('ResNeStStem_0', 'ConvBlock_2', 'BatchNorm_0'))
+    pt2jax['conv1.0.weight'] = ('params', 'layers_0', 'ConvBlock_0', 'Conv_0', 'kernel')
+    add_bn('conv1.1', ('layers_0', 'ConvBlock_0', 'BatchNorm_0'))
+    pt2jax['conv1.3.weight'] = ('params', 'layers_0', 'ConvBlock_1', 'Conv_0', 'kernel')
+    add_bn('conv1.4', ('layers_0', 'ConvBlock_1', 'BatchNorm_0'))
+    pt2jax['conv1.6.weight'] = ('params', 'layers_0', 'ConvBlock_2', 'Conv_0', 'kernel')
+    add_bn('bn1', ('layers_0', 'ConvBlock_2', 'BatchNorm_0'))
 
-    b_ind = 0  # block_ind
+    b_ind = 2  # block_ind
     for b, n_blocks in enumerate(resnet.STAGE_SIZES[size], 1):
         for i in range(n_blocks):
             pt2jax[f'layer{b}.{i}.conv1.weight'] = ('params', bname(b_ind),
@@ -215,8 +213,9 @@ def bname(num):
 
         b_ind += 1
 
-    pt2jax['fc.weight'] = ('params', 'Dense_0', 'kernel')
-    pt2jax['fc.bias'] = ('params', 'Dense_0', 'bias')
+    b_ind += 1
+    pt2jax['fc.weight'] = ('params', bname(b_ind), 'kernel')
+    pt2jax['fc.bias'] = ('params', bname(b_ind), 'bias')
 
     variables = _pytorch_to_jax_params(pt2jax, state_dict, ('fc.weight',))
 
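These renames flatten the parameter tree to match the new `Sequential` layout: the stem sits at `layers_0`, stateless layers occupy indices without variables (presumably the stem max pool at index 1, which is why `b_ind` now starts at 2, and the global average pool before the head, which is why there is an extra `b_ind += 1`), and the final Dense head takes the last index instead of a top-level `Dense_0`. A quick sanity-check sketch, assuming the standard `[3, 4, 6, 3]` stage sizes for size 50 (i.e. 16 bottleneck blocks):

```python
from jax_resnet.pretrained import pretrained_resnet

# Requires PyTorch/torchvision for the weight download (see README).
ResNet50, variables = pretrained_resnet(50)
keys = sorted(variables['params'], key=lambda s: int(s.split('_')[-1]))
print(keys)
# Expected: ['layers_0', 'layers_2', ..., 'layers_17', 'layers_19'];
# layers_1 (max pool) and layers_18 (average pool) hold no parameters,
# and layers_19 is the Dense head.
```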