Skip to content

Commit

Permalink
[Fix] Fix typo of NAFNet (#1557)
Browse files Browse the repository at this point in the history
* Update nafnet_c64eb11128mb1db1111_8xb8-lr1e-3-400k_gopro.py

* Update nafnet_c64eb2248mb12db2222_8xb8-lr1e-3-400k_sidd.py

* Update nafnet_net.py

* Update nafnet_net.py

* Update test_nafnet.py

* Update test_nafnet.py

---------

Co-authored-by: Z-Fran <[email protected]>
  • Loading branch information
ydengbi and Z-Fran authored Feb 23, 2023
1 parent 0acc032 commit a6d75e9
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
type='BaseEditModel',
generator=dict(
type='NAFNetLocal',
img_channel=3,
img_channels=3,
mid_channels=64,
enc_blk_nums=[1, 1, 1, 28],
middle_blk_num=1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
type='BaseEditModel',
generator=dict(
type='NAFNet',
img_channel=3,
img_channels=3,
mid_channels=64,
enc_blk_nums=[2, 2, 4, 8],
middle_blk_num=12,
Expand Down
6 changes: 3 additions & 3 deletions mmedit/models/editors/nafnet/nafnet_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ class NAFNet(BaseModule):
"""

def __init__(self,
img_channel=3,
img_channels=3,
mid_channels=16,
middle_blk_num=1,
enc_blk_nums=[],
dec_blk_nums=[]):
super().__init__()

self.intro = nn.Conv2d(
in_channels=img_channel,
in_channels=img_channels,
out_channels=mid_channels,
kernel_size=3,
padding=1,
Expand All @@ -42,7 +42,7 @@ def __init__(self,
bias=True)
self.ending = nn.Conv2d(
in_channels=mid_channels,
out_channels=img_channel,
out_channels=img_channels,
kernel_size=3,
padding=1,
stride=1,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_models/test_editors/test_nafnet/test_nafnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
def test_nafnet():

model = NAFNet(
img_channel=3,
img_channels=3,
mid_channels=64,
enc_blk_nums=[2, 2, 4, 8],
middle_blk_num=12,
Expand Down Expand Up @@ -39,7 +39,7 @@ def test_nafnet():
def test_nafnet_local():

model = NAFNetLocal(
img_channel=3,
img_channels=3,
mid_channels=64,
enc_blk_nums=[1, 1, 1, 28],
middle_blk_num=1,
Expand Down

0 comments on commit a6d75e9

Please sign in to comment.