Skip to content

Commit

Permalink
Update test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
shenqq377 authored Oct 16, 2022
1 parent 6a16c72 commit ac08d41
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main(_run, _config, _log):
torch.set_num_threads(1)

_log.info(f'Create model...')
model = FewShotSeg()
model = FewShotSeg(alpha=_config['alpha'])
model.cuda()
model.load_state_dict(torch.load(_config['reload_model_path'], map_location='cpu'))

Expand Down Expand Up @@ -127,7 +127,7 @@ def main(_run, _config, _log):
query_pred_s = []
for i in range(query_image_s.shape[0]):
_pred_s, _ = model([support_image_s], [support_fg_mask_s], [query_image_s[[i]]],
train=False) # C x 2 x H x W
train=False, n_iters=_config['n_iters']) # C x 2 x H x W
query_pred_s.append(_pred_s)
query_pred_s = torch.cat(query_pred_s, dim=0)
query_pred_s = query_pred_s.argmax(dim=1).cpu() # C x H x W
Expand Down

0 comments on commit ac08d41

Please sign in to comment.