Skip to content

Commit

Permalink
fix bugs & update results
Browse files Browse the repository at this point in the history
  • Loading branch information
GengDavid committed Feb 25, 2022
1 parent 03a98a1 commit 218c149
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 84 deletions.
24 changes: 21 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# CyCTR-PyTorch
This is a PyTorch re-implementation of the NeurIPS 2021 paper "[Few-Shot Segmentation via Cycle-Consistent Transformer](https://proceedings.neurips.cc/paper/2021/file/b8b12f949378552c21f28deff8ba8eb6-Paper.pdf)".

# News
(Feb. 2022) Fixed some bugs and updated some results.

# Usage

### Requirements
Expand Down Expand Up @@ -68,7 +71,6 @@ For example,
```

### Test Only
+ Download checkpoints from [here](https://drive.google.com/drive/folders/1P3Qo7Zz_257z9gnVb7wroV7acaFYinkw?usp=sharing)
+ Modify `config` file (specify checkpoint path)
+ Run the following command:
```
Expand All @@ -80,10 +82,26 @@ For example,
sh test.sh pascal split0_resnet50
```

Results on 1-shot Pascal-5^i
Results on 1-shot Pascal-5^i with ResNet50 backbone ([checkpoints](https://drive.google.com/drive/folders/1fqIYfWz6vjxRsOrRGV5v9GpWAy1ZZ6Bo?usp=sharing))
| Model | Split-0 | Split-1 | Split-2 | Split-3 | Mean |
|--------------------|---------|---------|---------|---------|-------|
| CyCTR_resnet50 | 65.7 | 71.0 | 59.5 | 59.7 | 64.0 |

Results on 5-shot Pascal-5^i with ResNet50 backbone ([checkpoints](https://drive.google.com/drive/folders/1xD3PJKrnm2FnUlJjOBn8x0mv8GWvhFKW?usp=sharing))
| Model | Split-0 | Split-1 | Split-2 | Split-3 | Mean |
|--------------------|---------|---------|---------|---------|-------|
| CyCTR_resnet50 | 69.3 | 73.5 | 63.8 | 63.5 | 67.5 |

Results on 1-shot Pascal-5^i with ResNet101 backbone ([checkpoints](https://drive.google.com/drive/folders/1DRUz8NNukK5Aflt_uotihhom4XB7u4Bf?usp=sharing))
| Model | Split-0 | Split-1 | Split-2 | Split-3 | Mean |
|--------------------|---------|---------|---------|---------|-------|
| CyCTR_resnet101 | 67.2 | 71.1 | 57.6 | 59.0 | 63.7 |

Results on 5-shot Pascal-5^i with ResNet101 backbone ([checkpoints](https://drive.google.com/drive/folders/1lU2KDOPeOibNXWEbMQ7O-euFBI134fx_?usp=sharing))
| Model | Split-0 | Split-1 | Split-2 | Split-3 | Mean |
|--------------------|---------|---------|---------|---------|-------|
| CyCTR_resnet50 | 67.8 | 72.7 | 58.0 | 57.9 | 64.1 |
| CyCTR_resnet101 | 71.0 | 75.0 | 58.5 | 65.0 | 67.4 |


# Acknowledgement

Expand Down
2 changes: 1 addition & 1 deletion config/coco/coco_split0_resnet50.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ TRAIN:
train_h: 473
train_w: 473
val_size: 473
hidden_dims: 384
hidden_dims: 256 # DFattn friendly
scale_min: 0.8 # minimum random scale
scale_max: 1.25 # maximum random scale
rotate_min: -10 # minimum random rotate
Expand Down
2 changes: 1 addition & 1 deletion config/pascal/pascal_split0_resnet50.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ TRAIN:
train_h: 473
train_w: 473
val_size: 473
hidden_dims: 384
hidden_dims: 256 # DFattn friendly
scale_min: 0.9 # minimum random scale
scale_max: 1.1 # maximum random scale
rotate_min: -10 # minimum random rotate
Expand Down
14 changes: 7 additions & 7 deletions model/CyCTR.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(self, layers=50, classes=2, shot=1, reduce_dim=384, \
nn.Dropout2d(p=drop_out),
)

self.high_avg_pool = nn.AdaptiveAvgPool1d(reduce_dim)
self.high_avg_pool = nn.Identity()

prior_channel = 1
self.qry_merge_feat = nn.Sequential(
Expand All @@ -64,7 +64,7 @@ def __init__(self, layers=50, classes=2, shot=1, reduce_dim=384, \
nn.ReLU(inplace=True),
nn.Conv2d(reduce_dim, reduce_dim, kernel_size=1, bias=False)
)
self.transformer = CyCTransformer(embed_dims=reduce_dim, num_points=9)
self.transformer = CyCTransformer(embed_dims=reduce_dim, shot=self.shot, num_points=9)
self.merge_multi_lvl_reduce = nn.Sequential(
nn.Conv2d(reduce_dim*self.trans_multi_lvl, reduce_dim, kernel_size=1, padding=0, bias=False),
nn.ReLU(inplace=True),
Expand Down Expand Up @@ -138,7 +138,7 @@ def print_params(self):
return repr_str


def forward(self, x, s_x=torch.FloatTensor(1,1,3,473,473).cuda(), s_y=torch.FloatTensor(1,1,473,473).cuda(), y=None):
def forward(self, x, s_x=torch.FloatTensor(1,1,3,473,473).cuda(), s_y=torch.FloatTensor(1,1,473,473).cuda(), y=None, padding_mask=None, s_padding_mask=None):
batch_size, _, h, w = x.size()
assert (h-1) % 8 == 0 and (w-1) % 8 == 0
img_size = x.size()[-2:]
Expand Down Expand Up @@ -187,7 +187,7 @@ def forward(self, x, s_x=torch.FloatTensor(1,1,3,473,473).cuda(), s_y=torch.Floa
aug_supp_feat = torch.cat(to_merge_fts, dim=1)
aug_supp_feat = self.supp_merge_feat(aug_supp_feat)

query_feat_list = self.transformer(query_feat, y.float(), aug_supp_feat, s_y.clone().float())
query_feat_list = self.transformer(query_feat, padding_mask.float(), aug_supp_feat, s_y.clone().float(), s_padding_mask.float())
fused_query_feat = []
for lvl, qry_feat in enumerate(query_feat_list):
if lvl == 0:
Expand Down Expand Up @@ -251,11 +251,11 @@ def generate_prior(self, query_feat_high, supp_feat_high, s_y, fts_size):
tmp_mask = F.interpolate(tmp_mask, size=(fts_size[0], fts_size[1]), mode='bilinear', align_corners=True)

tmp_supp_feat = supp_feat_high[:,st,...] * tmp_mask
q = self.high_avg_pool(query_feat_high.flatten(2).transpose(-2, -1)) # [bs, h*w, 256]
s = self.high_avg_pool(tmp_supp_feat.flatten(2).transpose(-2, -1)) # [bs, h*w, 256]
q = self.high_avg_pool(query_feat_high.flatten(2).transpose(-2, -1)) # [bs, h*w, c]
s = self.high_avg_pool(tmp_supp_feat.flatten(2).transpose(-2, -1)) # [bs, h*w, c]

tmp_query = q
tmp_query = tmp_query.contiguous().permute(0, 2, 1) # [bs, 256, h*w]
tmp_query = tmp_query.contiguous().permute(0, 2, 1) # [bs, c, h*w]
tmp_query_norm = torch.norm(tmp_query, 2, 1, True)

tmp_supp = s
Expand Down
88 changes: 58 additions & 30 deletions model/cyc_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch.nn.functional as F
import cv2
import math
import random
from model.ops.modules import MSDeformAttn
from model.positional_encoding import SinePositionalEncoding

Expand Down Expand Up @@ -57,7 +58,6 @@ def forward(self, x, residual=None):
class MyCrossAttention(nn.Module):
def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
assert num_heads==1, "currently only implement num_heads==1"
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
Expand All @@ -69,42 +69,56 @@ def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0.
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=False)
self.proj_drop = nn.Dropout(proj_drop)
self.ass_drop = nn.Dropout(0.1)

self.drop_prob = 0.1


def forward(self, q, k, v, supp_valid_mask=None, supp_mask=None, cyc=True):
B, N, C = q.shape
N_s = k.size(1)

q = self.q_fc(q)
q = self.q_fc(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

k = self.k_fc(k)
v = self.v_fc(v)
k = self.k_fc(k).reshape(B, N_s, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = self.v_fc(v).reshape(B, N_s, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

attn = (q @ k.transpose(-2, -1)) * self.scale # [bs, n, n]
if supp_valid_mask is not None:
supp_valid_mask = supp_valid_mask.unsqueeze(1).repeat(1, self.num_heads, 1) # [bs, nH, n]

attn = (q @ k.transpose(-2, -1)) * self.scale # [bs, nH, n, n]

if supp_mask is not None and cyc==True:
k2q_sim_idx = attn.max(1)[1] # [bs, n]
association = []
for hd_id in range(self.num_heads):
attn_single_hd = attn[:, hd_id, ...]
k2q_sim_idx = attn_single_hd.max(1)[1] # [bs, n]

q2k_sim_idx = attn.max(2)[1] # [bs, n]
q2k_sim_idx = attn_single_hd.max(2)[1] # [bs, n]

re_map_idx = torch.gather(q2k_sim_idx, 1, k2q_sim_idx)
re_map_mask = torch.gather(supp_mask, 1, re_map_idx)

association = (supp_mask == re_map_mask).to(attn.device) # [bs, n], True means matched position in supp

re_map_idx = torch.gather(q2k_sim_idx, 1, k2q_sim_idx)
re_map_mask = torch.gather(supp_mask, 1, re_map_idx)

asso_single_head = (supp_mask == re_map_mask).to(attn.device) # [bs, n], True means matched position in supp
association.append(asso_single_head.unsqueeze(1))
association = torch.cat(association, dim=1) # [bs, nH, ns]

if cyc:
supp_valid_mask[association==False] = 1.
inconsistent = ~association
inconsistent = inconsistent.float()
inconsistent = self.ass_drop(inconsistent)
supp_valid_mask[inconsistent>0] = 1.


if supp_valid_mask is not None:
supp_valid_mask = supp_valid_mask.unsqueeze(1).float()
supp_valid_mask = supp_valid_mask.unsqueeze(-2).float() # [bs, nH, 1, ns]
supp_valid_mask = supp_valid_mask * -10000.0
attn = attn + supp_valid_mask
attn = attn + supp_valid_mask

attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)

Expand All @@ -114,7 +128,7 @@ def forward(self, q, k, v, supp_valid_mask=None, supp_mask=None, cyc=True):
class CyCTransformer(nn.Module):
def __init__(self,
embed_dims=384,
num_heads=1,
num_heads=8,
num_layers=2,
num_levels=1,
num_points=9,
Expand All @@ -136,6 +150,7 @@ def __init__(self,
self.shot = shot
self.use_cross = True
self.use_self = True
self.use_cyc = True

self.rand_fg_num = rand_fg_num * shot
self.rand_bg_num = rand_bg_num * shot
Expand All @@ -148,7 +163,7 @@ def __init__(self,
for l_id in range(self.num_layers):
if self.use_cross:
self.cross_layers.append(
MyCrossAttention(embed_dims, attn_drop=self.dropout, proj_drop=self.dropout),
MyCrossAttention(embed_dims, num_heads=12 if embed_dims%12==0 else self.num_heads, attn_drop=self.dropout, proj_drop=self.dropout),
)
self.layer_norms.append(nn.LayerNorm(embed_dims))
if self.use_ffn:
Expand All @@ -157,7 +172,7 @@ def __init__(self,

if self.use_self:
self.qry_self_layers.append(
MSDeformAttn(embed_dims, num_levels, num_heads, num_points)
MSDeformAttn(embed_dims, num_levels, 12 if embed_dims%12==0 else self.num_heads, num_points)
)
self.layer_norms.append(nn.LayerNorm(embed_dims))

Expand Down Expand Up @@ -243,19 +258,22 @@ def get_qry_flatten_input(self, x, qry_masks):

return src_flatten, qry_valid_masks_flatten, pos_embed_flatten, spatial_shapes, level_start_index

def get_supp_flatten_input(self, s_x, supp_mask):
def get_supp_flatten_input(self, s_x, supp_mask, s_padding_mask):
s_x_flatten = []
supp_valid_mask = []
supp_obj_mask = []
supp_mask = F.interpolate(supp_mask, size=s_x.shape[-2:], mode='nearest').squeeze(1) # [bs*shot, h, w]
supp_mask = supp_mask.view(-1, self.shot, s_x.size(2), s_x.size(3))

s_padding_mask = F.interpolate(s_padding_mask, size=s_x.shape[-2:], mode='nearest').squeeze(1) # [bs*shot, h, w]
s_padding_mask = s_padding_mask.view(-1, self.shot, s_x.size(2), s_x.size(3))
s_x = s_x.view(-1, self.shot, s_x.size(1), s_x.size(2), s_x.size(3))

for st_id in range(s_x.size(1)):
supp_valid_mask_s = []
supp_obj_mask_s = []
for img_id in range(s_x.size(0)):
supp_valid_mask_s.append(supp_mask[img_id, st_id, ...]==255)
supp_valid_mask_s.append(s_padding_mask[img_id, st_id, ...]==255)
obj_mask = supp_mask[img_id, st_id, ...]==1
if obj_mask.sum() == 0: # To avoid NaN
obj_mask[obj_mask.size(0)//2-1:obj_mask.size(0)//2+1, obj_mask.size(1)//2-1:obj_mask.size(1)//2+1] = True
Expand Down Expand Up @@ -284,6 +302,15 @@ def get_supp_flatten_input(self, s_x, supp_mask):
return s_x_flatten, supp_valid_mask, supp_mask_flatten

def sparse_sampling(self, s_x, supp_mask, supp_valid_mask):
if self.training:
scale_min = 0.6
scale_max = 4.0 if self.shot==1 else 1.4
sampling_scale = random.uniform(scale_min, scale_max)
rand_fg_num = int(self.rand_fg_num*sampling_scale)
rand_bg_num = int(self.rand_bg_num*sampling_scale)
else:
rand_fg_num = self.rand_fg_num
rand_bg_num = self.rand_bg_num
assert supp_mask is not None
re_arrange_k = []
re_arrange_mask = []
Expand All @@ -296,23 +323,23 @@ def sparse_sampling(self, s_x, supp_mask, supp_valid_mask):
fg_k = k_b[supp_mask_b] # [num_fg, c]
bg_k = k_b[supp_mask_b==False] # [num_bg, c]

if num_fg<self.rand_fg_num:
rest_num = self.rand_fg_num+self.rand_bg_num-num_fg
if num_fg<rand_fg_num:
rest_num = rand_fg_num+rand_bg_num-num_fg
bg_select_idx = torch.randperm(num_bg)[:rest_num]
re_k = torch.cat([fg_k, bg_k[bg_select_idx]], dim=0)
re_mask = torch.cat([supp_mask_b[supp_mask_b==True], supp_mask_b[bg_select_idx]], dim=0)
re_valid_mask = torch.cat([supp_valid_mask[b_id][supp_mask_b==True], supp_valid_mask[b_id][bg_select_idx]], dim=0)

elif num_bg<self.rand_bg_num:
rest_num = self.rand_fg_num+self.rand_bg_num-num_bg
elif num_bg<rand_bg_num:
rest_num = rand_fg_num+rand_bg_num-num_bg
fg_select_idx = torch.randperm(num_fg)[:rest_num]
re_k = torch.cat([fg_k[fg_select_idx], bg_k], dim=0)
re_mask = torch.cat([supp_mask_b[fg_select_idx], supp_mask_b[supp_mask_b==False]], dim=0)
re_valid_mask = torch.cat([supp_valid_mask[b_id][fg_select_idx], supp_valid_mask[b_id][supp_mask_b==False]], dim=0)

else:
fg_select_idx = torch.randperm(num_fg)[:self.rand_fg_num]
bg_select_idx = torch.randperm(num_bg)[:self.rand_bg_num]
fg_select_idx = torch.randperm(num_fg)[:rand_fg_num]
bg_select_idx = torch.randperm(num_bg)[:rand_bg_num]
re_k = torch.cat([fg_k[fg_select_idx], bg_k[bg_select_idx]], dim=0)
re_mask = torch.cat([supp_mask_b[fg_select_idx], supp_mask_b[bg_select_idx]], dim=0)
re_valid_mask = torch.cat([supp_valid_mask[b_id][fg_select_idx], supp_valid_mask[b_id][bg_select_idx]], dim=0)
Expand All @@ -327,7 +354,7 @@ def sparse_sampling(self, s_x, supp_mask, supp_valid_mask):

return k, supp_mask, supp_valid_mask

def forward(self, x, qry_masks, s_x, supp_mask):
def forward(self, x, qry_masks, s_x, supp_mask, s_padding_mask):
if not isinstance(x, list):
x = [x]
if not isinstance(qry_masks, list):
Expand All @@ -338,7 +365,7 @@ def forward(self, x, qry_masks, s_x, supp_mask):

x_flatten, qry_valid_masks_flatten, pos_embed_flatten, spatial_shapes, level_start_index = self.get_qry_flatten_input(x, qry_masks)

s_x, supp_valid_mask, supp_mask_flatten = self.get_supp_flatten_input(s_x, supp_mask.clone())
s_x, supp_valid_mask, supp_mask_flatten = self.get_supp_flatten_input(s_x, supp_mask.clone(), s_padding_mask.clone())

reference_points = self.get_reference_points(spatial_shapes, device=x_flatten.device)

Expand All @@ -362,7 +389,8 @@ def forward(self, x, qry_masks, s_x, supp_mask):
if self.use_cross:
k, sampled_mask, sampled_valid_mask = self.sparse_sampling(s_x, supp_mask_flatten, supp_valid_mask) if self.training or l_id==0 else (k, sampled_mask, sampled_valid_mask)
v = k.clone()
cross_out = self.cross_layers[l_id](q, k, v, sampled_valid_mask, sampled_mask)
cross_out = self.cross_layers[l_id](q, k, v, sampled_valid_mask, sampled_mask, cyc=self.use_cyc)

q = cross_out + q
q = self.layer_norms[ln_id](q)
ln_id += 1
Expand Down
6 changes: 6 additions & 0 deletions model/ops/modules/ms_deform_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,17 @@ def forward(self, query, reference_points, input_flatten, input_spatial_shapes,

value = self.value_proj(input_flatten)
if input_padding_mask is not None:
drp_mask = torch.ones_like(input_padding_mask).float()
drp_mask = F.dropout(drp_mask, p=0.1, training=self.training)
drp_mask = drp_mask==0
value = value.masked_fill(drp_mask[..., None], float(0))

value = value.masked_fill(input_padding_mask[..., None], float(0))
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
attention_weights = F.dropout(attention_weights, p=0.1, training=self.training)
# N, Len_q, n_heads, n_levels, n_points, 2
if reference_points.shape[-1] == 2:
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
Expand Down
10 changes: 6 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,25 +169,27 @@ def validate(val_loader, model, criterion):
end = time.time()
if args.split != 999:
if args.use_coco:
test_num = 20000
test_num = 5000
else:
test_num = 5000
test_num = 1000
else:
test_num = len(val_loader)
assert test_num % args.batch_size_val == 0
iter_num = 0
total_time = 0
for e in range(20):
for i, (input, target, s_input, s_mask, subcls, ori_label) in enumerate(val_loader):
for i, (input, target, s_input, s_mask, padding_mask, s_padding_mask, subcls, ori_label) in enumerate(val_loader):
if (iter_num-1) * args.batch_size_val >= test_num:
break
iter_num += 1
data_time.update(time.time() - end)
input = input.cuda(non_blocking=True)
target = target.cuda(non_blocking=True)
padding_mask = padding_mask.cuda(non_blocking=True)
s_padding_mask = s_padding_mask.cuda(non_blocking=True)
ori_label = ori_label.cuda(non_blocking=True)
start_time = time.time()
output = model(s_x=s_input, s_y=s_mask, x=input, y=target)
output = model(s_x=s_input, s_y=s_mask, x=input, y=target, padding_mask=padding_mask, s_padding_mask=s_padding_mask)
total_time = total_time + 1
model_time.update(time.time() - start_time)

Expand Down
Loading

0 comments on commit 218c149

Please sign in to comment.