Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support K-Net #1289

Merged
merged 24 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e164374
knet first commit
MengzhangLI Feb 8, 2022
6b3fbea
fix import error in knet
MengzhangLI Feb 9, 2022
ff6daef
remove kernel update head from decoder head
MengzhangLI Feb 15, 2022
3192ecd
[Feature] Add kenerl updation for some decoder heads.
MengzhangLI Feb 17, 2022
a8cf3d5
[Feature] Add kenerl updation for some decoder heads.
MengzhangLI Feb 17, 2022
3315730
directly use forward_feature && modify other 3 decoder heads
MengzhangLI Feb 18, 2022
606e090
remover kernel_update attr
MengzhangLI Feb 18, 2022
9f8fb98
delete unnecessary variables in forward function
MengzhangLI Feb 18, 2022
a2629c3
delete kernel update function
MengzhangLI Feb 22, 2022
725f10e
delete kernel update function
MengzhangLI Feb 22, 2022
fb1fb20
refactor kernel_generate_head
MengzhangLI Feb 22, 2022
9729dca
delete kernel_generate_head
MengzhangLI Feb 22, 2022
5dc6d9b
add unit test & comments in knet.py
MengzhangLI Feb 25, 2022
c4fedd6
add copyright to fix lint error
MengzhangLI Feb 25, 2022
3d3aedb
Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation…
MengzhangLI Feb 25, 2022
54c9c75
modify config names of knet
MengzhangLI Feb 25, 2022
a60e60c
rename swin-l 640
MengzhangLI Feb 25, 2022
59fd4d5
update upstream master branch
MengzhangLI Mar 1, 2022
498d9fb
Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation…
MengzhangLI Mar 4, 2022
2af6585
Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation…
MengzhangLI Mar 7, 2022
9c0e72d
upload models&logs and refactor knet_head.py
MengzhangLI Mar 7, 2022
1e26707
modify docstrings and add some ut
MengzhangLI Mar 7, 2022
8b815e6
add url, modify docstring and add loss ut
MengzhangLI Mar 8, 2022
d0a7c08
modify docstrings
MengzhangLI Mar 10, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
_base_ = [
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
num_stages = 3
conv_kernel_size = 1
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='IterativeDecodeHead',
num_stages=num_stages,
kernel_update_head=[
dict(
type='KernelUpdateHead',
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=2048,
in_channels=512,
out_channels=512,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
mask_upsample_stride=2,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=dict(
conv_cfg=dict(type='Conv2d'), act_cfg=None),
kernel_updator_cfg=dict(
type='KernelUpdator',
in_channels=256,
feat_channels=256,
out_channels=256,
input_feat_shape=3,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
],
kernel_generate_head=dict(
type='ASPPHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
# learning policy
lr_config = dict(
_delete_=True,
policy='step',
warmup='linear',
warmup_iters=1000,
warmup_ratio=0.001,
step=[60000, 72000],
by_epoch=False)
# In K-Net implementation we use batch size 2 per GPU as default
data = dict(samples_per_gpu=2, workers_per_gpu=2)
95 changes: 95 additions & 0 deletions configs/knet/knet_s3_fcn_r50-d8_4x2_512x512_80k_adamw_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
_base_ = [
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
num_stages = 3
conv_kernel_size = 1
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='IterativeDecodeHead',
num_stages=num_stages,
kernel_update_head=[
dict(
type='KernelUpdateHead',
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=2048,
in_channels=512,
out_channels=512,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
mask_upsample_stride=2,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=dict(
conv_cfg=dict(type='Conv2d'), act_cfg=None),
kernel_updator_cfg=dict(
type='KernelUpdator',
in_channels=256,
feat_channels=256,
out_channels=256,
input_feat_shape=3,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
],
kernel_generate_head=dict(
type='FCNHead',
in_channels=2048,
in_index=3,
channels=512,
num_convs=2,
concat_input=True,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
# learning policy
lr_config = dict(
_delete_=True,
policy='step',
warmup='linear',
warmup_iters=1000,
warmup_ratio=0.001,
step=[60000, 72000],
by_epoch=False)
# In K-Net implementation we use batch size 2 per GPU as default
data = dict(samples_per_gpu=2, workers_per_gpu=2)
94 changes: 94 additions & 0 deletions configs/knet/knet_s3_pspnet_r50-d8_4x2_512x512_80k_adamw_ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
_base_ = [
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
num_stages = 3
conv_kernel_size = 1
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='IterativeDecodeHead',
num_stages=num_stages,
kernel_update_head=[
dict(
type='KernelUpdateHead',
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=2048,
in_channels=512,
out_channels=512,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
mask_upsample_stride=2,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=dict(
conv_cfg=dict(type='Conv2d'), act_cfg=None),
kernel_updator_cfg=dict(
type='KernelUpdator',
in_channels=256,
feat_channels=256,
out_channels=256,
input_feat_shape=3,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
],
kernel_generate_head=dict(
type='PSPHead',
in_channels=2048,
in_index=3,
channels=512,
pool_scales=(1, 2, 3, 6),
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
# learning policy
lr_config = dict(
_delete_=True,
policy='step',
warmup='linear',
warmup_iters=1000,
warmup_ratio=0.001,
step=[60000, 72000],
by_epoch=False)
# In K-Net implementation we use batch size 2 per GPU as default
data = dict(samples_per_gpu=2, workers_per_gpu=2)
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
_base_ = [
'../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]

# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
num_stages = 3
conv_kernel_size = 1

model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 1, 1),
strides=(1, 2, 2, 2),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='IterativeDecodeHead',
num_stages=num_stages,
kernel_update_head=[
dict(
type='KernelUpdateHead',
num_classes=150,
num_ffn_fcs=2,
num_heads=8,
num_mask_fcs=1,
feedforward_channels=2048,
in_channels=256,
out_channels=256,
dropout=0.0,
conv_kernel_size=conv_kernel_size,
mask_upsample_stride=2,
ffn_act_cfg=dict(type='ReLU', inplace=True),
with_ffn=True,
feat_transform_cfg=dict(
conv_cfg=dict(type='Conv2d'), act_cfg=None),
kernel_updator_cfg=dict(
type='KernelUpdator',
in_channels=256,
feat_channels=256,
out_channels=256,
input_feat_shape=3,
act_cfg=dict(type='ReLU', inplace=True),
norm_cfg=dict(type='LN'))) for _ in range(num_stages)
],
kernel_generate_head=dict(
type='UPerHead',
in_channels=[256, 512, 1024, 2048],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=256,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=150,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
# optimizer
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
# learning policy
lr_config = dict(
_delete_=True,
policy='step',
warmup='linear',
warmup_iters=1000,
warmup_ratio=0.001,
step=[60000, 72000],
by_epoch=False)
# In K-Net implementation we use batch size 2 per GPU as default
data = dict(samples_per_gpu=2, workers_per_gpu=2)
Loading