Skip to content

Commit

Permalink
span no_norm to spanplus_st
Browse files Browse the repository at this point in the history
  • Loading branch information
umzi2 committed Jun 21, 2024
1 parent b4a2214 commit 49122d8
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 5 deletions.
5 changes: 5 additions & 0 deletions neosr/archs/spanplus_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,8 @@ def spanplus_xl(**kwargs):
@ARCH_REGISTRY.register()
def spanplus_s(**kwargs):
return spanplus(blocks=[2], feature_channels=32, **kwargs)


@ARCH_REGISTRY.register()
def spanplus_st(**kwargs):
return spanplus(upsampler="ps", **kwargs)
11 changes: 6 additions & 5 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ This repository is not an official modernization of [span](https://github.com/ho

Training code from [NeoSR](https://github.com/muslll/neosr)

| Name | Upscaler | blocks | feature_channels |
|-------------|--------------|-----------|------------------|
| spanplus | Dysample | [4] | 48 |
| spanplus-s | DySample | [2] | 32 |
| spanplus-xl | DySample | [4, 4, 4] | 96 |
| Name | Upscaler | blocks | feature_channels |
|-------------|---------------|-----------|------------------|
| spanplus | Dysample | [4] | 48 |
| spanplus-s | DySample | [2] | 32 |
| spanplus-xl | DySample | [4, 4, 4] | 96 |
| spanplus-st | PixelShuffle | [4] | 48 |

### Detect:
```py
Expand Down
Empty file added span_to_spanplus/__init__.py
Empty file.
206 changes: 206 additions & 0 deletions span_to_spanplus/convert_keys.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
key_to_key = {
"conv_1.sk.weight": "feats.0.sk.weight",
"conv_1.sk.bias": "feats.0.sk.bias",
"conv_1.conv.0.weight": "feats.0.conv.0.weight",
"conv_1.conv.0.bias": "feats.0.conv.0.bias",
"conv_1.conv.1.weight": "feats.0.conv.1.weight",
"conv_1.conv.1.bias": "feats.0.conv.1.bias",
"conv_1.conv.2.weight": "feats.0.conv.2.weight",
"conv_1.conv.2.bias": "feats.0.conv.2.bias",
"conv_1.eval_conv.weight": "feats.0.eval_conv.weight",
"conv_1.eval_conv.bias": "feats.0.eval_conv.bias",
"block_1.c1_r.sk.weight": "feats.1.block_1.c1_r.sk.weight",
"block_1.c1_r.sk.bias": "feats.1.block_1.c1_r.sk.bias",
"block_1.c1_r.conv.0.weight": "feats.1.block_1.c1_r.conv.0.weight",
"block_1.c1_r.conv.0.bias": "feats.1.block_1.c1_r.conv.0.bias",
"block_1.c1_r.conv.1.weight": "feats.1.block_1.c1_r.conv.1.weight",
"block_1.c1_r.conv.1.bias": "feats.1.block_1.c1_r.conv.1.bias",
"block_1.c1_r.conv.2.weight": "feats.1.block_1.c1_r.conv.2.weight",
"block_1.c1_r.conv.2.bias": "feats.1.block_1.c1_r.conv.2.bias",
"block_1.c1_r.eval_conv.weight": "feats.1.block_1.c1_r.eval_conv.weight",
"block_1.c1_r.eval_conv.bias": "feats.1.block_1.c1_r.eval_conv.bias",
"block_1.c2_r.sk.weight": "feats.1.block_1.c2_r.sk.weight",
"block_1.c2_r.sk.bias": "feats.1.block_1.c2_r.sk.bias",
"block_1.c2_r.conv.0.weight": "feats.1.block_1.c2_r.conv.0.weight",
"block_1.c2_r.conv.0.bias": "feats.1.block_1.c2_r.conv.0.bias",
"block_1.c2_r.conv.1.weight": "feats.1.block_1.c2_r.conv.1.weight",
"block_1.c2_r.conv.1.bias": "feats.1.block_1.c2_r.conv.1.bias",
"block_1.c2_r.conv.2.weight": "feats.1.block_1.c2_r.conv.2.weight",
"block_1.c2_r.conv.2.bias": "feats.1.block_1.c2_r.conv.2.bias",
"block_1.c2_r.eval_conv.weight": "feats.1.block_1.c2_r.eval_conv.weight",
"block_1.c2_r.eval_conv.bias": "feats.1.block_1.c2_r.eval_conv.bias",
"block_1.c3_r.sk.weight": "feats.1.block_1.c3_r.sk.weight",
"block_1.c3_r.sk.bias": "feats.1.block_1.c3_r.sk.bias",
"block_1.c3_r.conv.0.weight": "feats.1.block_1.c3_r.conv.0.weight",
"block_1.c3_r.conv.0.bias": "feats.1.block_1.c3_r.conv.0.bias",
"block_1.c3_r.conv.1.weight": "feats.1.block_1.c3_r.conv.1.weight",
"block_1.c3_r.conv.1.bias": "feats.1.block_1.c3_r.conv.1.bias",
"block_1.c3_r.conv.2.weight": "feats.1.block_1.c3_r.conv.2.weight",
"block_1.c3_r.conv.2.bias": "feats.1.block_1.c3_r.conv.2.bias",
"block_1.c3_r.eval_conv.weight": "feats.1.block_1.c3_r.eval_conv.weight",
"block_1.c3_r.eval_conv.bias": "feats.1.block_1.c3_r.eval_conv.bias",
"block_2.c1_r.sk.weight": "feats.1.block_n.0.c1_r.sk.weight",
"block_2.c1_r.sk.bias": "feats.1.block_n.0.c1_r.sk.bias",
"block_2.c1_r.conv.0.weight": "feats.1.block_n.0.c1_r.conv.0.weight",
"block_2.c1_r.conv.0.bias": "feats.1.block_n.0.c1_r.conv.0.bias",
"block_2.c1_r.conv.1.weight": "feats.1.block_n.0.c1_r.conv.1.weight",
"block_2.c1_r.conv.1.bias": "feats.1.block_n.0.c1_r.conv.1.bias",
"block_2.c1_r.conv.2.weight": "feats.1.block_n.0.c1_r.conv.2.weight",
"block_2.c1_r.conv.2.bias": "feats.1.block_n.0.c1_r.conv.2.bias",
"block_2.c1_r.eval_conv.weight": "feats.1.block_n.0.c1_r.eval_conv.weight",
"block_2.c1_r.eval_conv.bias": "feats.1.block_n.0.c1_r.eval_conv.bias",
"block_2.c2_r.sk.weight": "feats.1.block_n.0.c2_r.sk.weight",
"block_2.c2_r.sk.bias": "feats.1.block_n.0.c2_r.sk.bias",
"block_2.c2_r.conv.0.weight": "feats.1.block_n.0.c2_r.conv.0.weight",
"block_2.c2_r.conv.0.bias": "feats.1.block_n.0.c2_r.conv.0.bias",
"block_2.c2_r.conv.1.weight": "feats.1.block_n.0.c2_r.conv.1.weight",
"block_2.c2_r.conv.1.bias": "feats.1.block_n.0.c2_r.conv.1.bias",
"block_2.c2_r.conv.2.weight": "feats.1.block_n.0.c2_r.conv.2.weight",
"block_2.c2_r.conv.2.bias": "feats.1.block_n.0.c2_r.conv.2.bias",
"block_2.c2_r.eval_conv.weight": "feats.1.block_n.0.c2_r.eval_conv.weight",
"block_2.c2_r.eval_conv.bias": "feats.1.block_n.0.c2_r.eval_conv.bias",
"block_2.c3_r.sk.weight": "feats.1.block_n.0.c3_r.sk.weight",
"block_2.c3_r.sk.bias": "feats.1.block_n.0.c3_r.sk.bias",
"block_2.c3_r.conv.0.weight": "feats.1.block_n.0.c3_r.conv.0.weight",
"block_2.c3_r.conv.0.bias": "feats.1.block_n.0.c3_r.conv.0.bias",
"block_2.c3_r.conv.1.weight": "feats.1.block_n.0.c3_r.conv.1.weight",
"block_2.c3_r.conv.1.bias": "feats.1.block_n.0.c3_r.conv.1.bias",
"block_2.c3_r.conv.2.weight": "feats.1.block_n.0.c3_r.conv.2.weight",
"block_2.c3_r.conv.2.bias": "feats.1.block_n.0.c3_r.conv.2.bias",
"block_2.c3_r.eval_conv.weight": "feats.1.block_n.0.c3_r.eval_conv.weight",
"block_2.c3_r.eval_conv.bias": "feats.1.block_n.0.c3_r.eval_conv.bias",
"block_3.c1_r.sk.weight": "feats.1.block_n.1.c1_r.sk.weight",
"block_3.c1_r.sk.bias": "feats.1.block_n.1.c1_r.sk.bias",
"block_3.c1_r.conv.0.weight": "feats.1.block_n.1.c1_r.conv.0.weight",
"block_3.c1_r.conv.0.bias": "feats.1.block_n.1.c1_r.conv.0.bias",
"block_3.c1_r.conv.1.weight": "feats.1.block_n.1.c1_r.conv.1.weight",
"block_3.c1_r.conv.1.bias": "feats.1.block_n.1.c1_r.conv.1.bias",
"block_3.c1_r.conv.2.weight": "feats.1.block_n.1.c1_r.conv.2.weight",
"block_3.c1_r.conv.2.bias": "feats.1.block_n.1.c1_r.conv.2.bias",
"block_3.c1_r.eval_conv.weight": "feats.1.block_n.1.c1_r.eval_conv.weight",
"block_3.c1_r.eval_conv.bias": "feats.1.block_n.1.c1_r.eval_conv.bias",
"block_3.c2_r.sk.weight": "feats.1.block_n.1.c2_r.sk.weight",
"block_3.c2_r.sk.bias": "feats.1.block_n.1.c2_r.sk.bias",
"block_3.c2_r.conv.0.weight": "feats.1.block_n.1.c2_r.conv.0.weight",
"block_3.c2_r.conv.0.bias": "feats.1.block_n.1.c2_r.conv.0.bias",
"block_3.c2_r.conv.1.weight": "feats.1.block_n.1.c2_r.conv.1.weight",
"block_3.c2_r.conv.1.bias": "feats.1.block_n.1.c2_r.conv.1.bias",
"block_3.c2_r.conv.2.weight": "feats.1.block_n.1.c2_r.conv.2.weight",
"block_3.c2_r.conv.2.bias": "feats.1.block_n.1.c2_r.conv.2.bias",
"block_3.c2_r.eval_conv.weight": "feats.1.block_n.1.c2_r.eval_conv.weight",
"block_3.c2_r.eval_conv.bias": "feats.1.block_n.1.c2_r.eval_conv.bias",
"block_3.c3_r.sk.weight": "feats.1.block_n.1.c3_r.sk.weight",
"block_3.c3_r.sk.bias": "feats.1.block_n.1.c3_r.sk.bias",
"block_3.c3_r.conv.0.weight": "feats.1.block_n.1.c3_r.conv.0.weight",
"block_3.c3_r.conv.0.bias": "feats.1.block_n.1.c3_r.conv.0.bias",
"block_3.c3_r.conv.1.weight": "feats.1.block_n.1.c3_r.conv.1.weight",
"block_3.c3_r.conv.1.bias": "feats.1.block_n.1.c3_r.conv.1.bias",
"block_3.c3_r.conv.2.weight": "feats.1.block_n.1.c3_r.conv.2.weight",
"block_3.c3_r.conv.2.bias": "feats.1.block_n.1.c3_r.conv.2.bias",
"block_3.c3_r.eval_conv.weight": "feats.1.block_n.1.c3_r.eval_conv.weight",
"block_3.c3_r.eval_conv.bias": "feats.1.block_n.1.c3_r.eval_conv.bias",
"block_4.c1_r.sk.weight": "feats.1.block_n.2.c1_r.sk.weight",
"block_4.c1_r.sk.bias": "feats.1.block_n.2.c1_r.sk.bias",
"block_4.c1_r.conv.0.weight": "feats.1.block_n.2.c1_r.conv.0.weight",
"block_4.c1_r.conv.0.bias": "feats.1.block_n.2.c1_r.conv.0.bias",
"block_4.c1_r.conv.1.weight": "feats.1.block_n.2.c1_r.conv.1.weight",
"block_4.c1_r.conv.1.bias": "feats.1.block_n.2.c1_r.conv.1.bias",
"block_4.c1_r.conv.2.weight": "feats.1.block_n.2.c1_r.conv.2.weight",
"block_4.c1_r.conv.2.bias": "feats.1.block_n.2.c1_r.conv.2.bias",
"block_4.c1_r.eval_conv.weight": "feats.1.block_n.2.c1_r.eval_conv.weight",
"block_4.c1_r.eval_conv.bias": "feats.1.block_n.2.c1_r.eval_conv.bias",
"block_4.c2_r.sk.weight": "feats.1.block_n.2.c2_r.sk.weight",
"block_4.c2_r.sk.bias": "feats.1.block_n.2.c2_r.sk.bias",
"block_4.c2_r.conv.0.weight": "feats.1.block_n.2.c2_r.conv.0.weight",
"block_4.c2_r.conv.0.bias": "feats.1.block_n.2.c2_r.conv.0.bias",
"block_4.c2_r.conv.1.weight": "feats.1.block_n.2.c2_r.conv.1.weight",
"block_4.c2_r.conv.1.bias": "feats.1.block_n.2.c2_r.conv.1.bias",
"block_4.c2_r.conv.2.weight": "feats.1.block_n.2.c2_r.conv.2.weight",
"block_4.c2_r.conv.2.bias": "feats.1.block_n.2.c2_r.conv.2.bias",
"block_4.c2_r.eval_conv.weight": "feats.1.block_n.2.c2_r.eval_conv.weight",
"block_4.c2_r.eval_conv.bias": "feats.1.block_n.2.c2_r.eval_conv.bias",
"block_4.c3_r.sk.weight": "feats.1.block_n.2.c3_r.sk.weight",
"block_4.c3_r.sk.bias": "feats.1.block_n.2.c3_r.sk.bias",
"block_4.c3_r.conv.0.weight": "feats.1.block_n.2.c3_r.conv.0.weight",
"block_4.c3_r.conv.0.bias": "feats.1.block_n.2.c3_r.conv.0.bias",
"block_4.c3_r.conv.1.weight": "feats.1.block_n.2.c3_r.conv.1.weight",
"block_4.c3_r.conv.1.bias": "feats.1.block_n.2.c3_r.conv.1.bias",
"block_4.c3_r.conv.2.weight": "feats.1.block_n.2.c3_r.conv.2.weight",
"block_4.c3_r.conv.2.bias": "feats.1.block_n.2.c3_r.conv.2.bias",
"block_4.c3_r.eval_conv.weight": "feats.1.block_n.2.c3_r.eval_conv.weight",
"block_4.c3_r.eval_conv.bias": "feats.1.block_n.2.c3_r.eval_conv.bias",
"block_5.c1_r.sk.weight": "feats.1.block_n.3.c1_r.sk.weight",
"block_5.c1_r.sk.bias": "feats.1.block_n.3.c1_r.sk.bias",
"block_5.c1_r.conv.0.weight": "feats.1.block_n.3.c1_r.conv.0.weight",
"block_5.c1_r.conv.0.bias": "feats.1.block_n.3.c1_r.conv.0.bias",
"block_5.c1_r.conv.1.weight": "feats.1.block_n.3.c1_r.conv.1.weight",
"block_5.c1_r.conv.1.bias": "feats.1.block_n.3.c1_r.conv.1.bias",
"block_5.c1_r.conv.2.weight": "feats.1.block_n.3.c1_r.conv.2.weight",
"block_5.c1_r.conv.2.bias": "feats.1.block_n.3.c1_r.conv.2.bias",
"block_5.c1_r.eval_conv.weight": "feats.1.block_n.3.c1_r.eval_conv.weight",
"block_5.c1_r.eval_conv.bias": "feats.1.block_n.3.c1_r.eval_conv.bias",
"block_5.c2_r.sk.weight": "feats.1.block_n.3.c2_r.sk.weight",
"block_5.c2_r.sk.bias": "feats.1.block_n.3.c2_r.sk.bias",
"block_5.c2_r.conv.0.weight": "feats.1.block_n.3.c2_r.conv.0.weight",
"block_5.c2_r.conv.0.bias": "feats.1.block_n.3.c2_r.conv.0.bias",
"block_5.c2_r.conv.1.weight": "feats.1.block_n.3.c2_r.conv.1.weight",
"block_5.c2_r.conv.1.bias": "feats.1.block_n.3.c2_r.conv.1.bias",
"block_5.c2_r.conv.2.weight": "feats.1.block_n.3.c2_r.conv.2.weight",
"block_5.c2_r.conv.2.bias": "feats.1.block_n.3.c2_r.conv.2.bias",
"block_5.c2_r.eval_conv.weight": "feats.1.block_n.3.c2_r.eval_conv.weight",
"block_5.c2_r.eval_conv.bias": "feats.1.block_n.3.c2_r.eval_conv.bias",
"block_5.c3_r.sk.weight": "feats.1.block_n.3.c3_r.sk.weight",
"block_5.c3_r.sk.bias": "feats.1.block_n.3.c3_r.sk.bias",
"block_5.c3_r.conv.0.weight": "feats.1.block_n.3.c3_r.conv.0.weight",
"block_5.c3_r.conv.0.bias": "feats.1.block_n.3.c3_r.conv.0.bias",
"block_5.c3_r.conv.1.weight": "feats.1.block_n.3.c3_r.conv.1.weight",
"block_5.c3_r.conv.1.bias": "feats.1.block_n.3.c3_r.conv.1.bias",
"block_5.c3_r.conv.2.weight": "feats.1.block_n.3.c3_r.conv.2.weight",
"block_5.c3_r.conv.2.bias": "feats.1.block_n.3.c3_r.conv.2.bias",
"block_5.c3_r.eval_conv.weight": "feats.1.block_n.3.c3_r.eval_conv.weight",
"block_5.c3_r.eval_conv.bias": "feats.1.block_n.3.c3_r.eval_conv.bias",
"block_6.c1_r.sk.weight": "feats.1.block_end.c1_r.sk.weight",
"block_6.c1_r.sk.bias": "feats.1.block_end.c1_r.sk.bias",
"block_6.c1_r.conv.0.weight": "feats.1.block_end.c1_r.conv.0.weight",
"block_6.c1_r.conv.0.bias": "feats.1.block_end.c1_r.conv.0.bias",
"block_6.c1_r.conv.1.weight": "feats.1.block_end.c1_r.conv.1.weight",
"block_6.c1_r.conv.1.bias": "feats.1.block_end.c1_r.conv.1.bias",
"block_6.c1_r.conv.2.weight": "feats.1.block_end.c1_r.conv.2.weight",
"block_6.c1_r.conv.2.bias": "feats.1.block_end.c1_r.conv.2.bias",
"block_6.c1_r.eval_conv.weight": "feats.1.block_end.c1_r.eval_conv.weight",
"block_6.c1_r.eval_conv.bias": "feats.1.block_end.c1_r.eval_conv.bias",
"block_6.c2_r.sk.weight": "feats.1.block_end.c2_r.sk.weight",
"block_6.c2_r.sk.bias": "feats.1.block_end.c2_r.sk.bias",
"block_6.c2_r.conv.0.weight": "feats.1.block_end.c2_r.conv.0.weight",
"block_6.c2_r.conv.0.bias": "feats.1.block_end.c2_r.conv.0.bias",
"block_6.c2_r.conv.1.weight": "feats.1.block_end.c2_r.conv.1.weight",
"block_6.c2_r.conv.1.bias": "feats.1.block_end.c2_r.conv.1.bias",
"block_6.c2_r.conv.2.weight": "feats.1.block_end.c2_r.conv.2.weight",
"block_6.c2_r.conv.2.bias": "feats.1.block_end.c2_r.conv.2.bias",
"block_6.c2_r.eval_conv.weight": "feats.1.block_end.c2_r.eval_conv.weight",
"block_6.c2_r.eval_conv.bias": "feats.1.block_end.c2_r.eval_conv.bias",
"block_6.c3_r.sk.weight": "feats.1.block_end.c3_r.sk.weight",
"block_6.c3_r.sk.bias": "feats.1.block_end.c3_r.sk.bias",
"block_6.c3_r.conv.0.weight": "feats.1.block_end.c3_r.conv.0.weight",
"block_6.c3_r.conv.0.bias": "feats.1.block_end.c3_r.conv.0.bias",
"block_6.c3_r.conv.1.weight": "feats.1.block_end.c3_r.conv.1.weight",
"block_6.c3_r.conv.1.bias": "feats.1.block_end.c3_r.conv.1.bias",
"block_6.c3_r.conv.2.weight": "feats.1.block_end.c3_r.conv.2.weight",
"block_6.c3_r.conv.2.bias": "feats.1.block_end.c3_r.conv.2.bias",
"block_6.c3_r.eval_conv.weight": "feats.1.block_end.c3_r.eval_conv.weight",
"block_6.c3_r.eval_conv.bias": "feats.1.block_end.c3_r.eval_conv.bias",
"conv_cat.weight": "feats.1.conv_cat.weight",
"conv_cat.bias": "feats.1.conv_cat.bias",
"conv_2.sk.weight": "feats.1.conv_2.sk.weight",
"conv_2.sk.bias": "feats.1.conv_2.sk.bias",
"conv_2.conv.0.weight": "feats.1.conv_2.conv.0.weight",
"conv_2.conv.0.bias": "feats.1.conv_2.conv.0.bias",
"conv_2.conv.1.weight": "feats.1.conv_2.conv.1.weight",
"conv_2.conv.1.bias": "feats.1.conv_2.conv.1.bias",
"conv_2.conv.2.weight": "feats.1.conv_2.conv.2.weight",
"conv_2.conv.2.bias": "feats.1.conv_2.conv.2.bias",
"conv_2.eval_conv.weight": "feats.1.conv_2.eval_conv.weight",
"conv_2.eval_conv.bias": "feats.1.conv_2.eval_conv.bias",
"upsampler.0.weight": "upsampler.0.weight",
"upsampler.0.bias": "upsampler.0.bias",
}
20 changes: 20 additions & 0 deletions span_to_spanplus/span_to_spanplus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from convert_keys import key_to_key

input_folder = "no_norm_span.pth"
out_folder = "span_plus.pth"


def load_model(state_dict):
unwrap_keys = ["state_dict", "params_ema", "params-ema", "params", "model", "net"]
for key in unwrap_keys:
if key in state_dict and isinstance(state_dict[key], dict):
return state_dict[key]


model = torch.load(input_folder)
span_model = load_model(model)
span_model.pop("no_norm")
for i in list(span_model.keys()):
span_model[key_to_key[i]] = span_model.pop(i)
torch.save(span_model, out_folder)

0 comments on commit 49122d8

Please sign in to comment.