
Issues about running "Ours_training.py" and "Ours_Asys_testing.py" #3

Open
SeunghanYu opened this issue Mar 6, 2024 · 3 comments


@SeunghanYu

SeunghanYu commented Mar 6, 2024

Hi, @MCC-WH.

I've been working with Ours_training.py and ran into a couple of issues that prevented the code from running properly, so I had to make some modifications on my end.

First Issue:

In the WarmupCos_Scheduler class, within the step(self) method, the index into the lr_schedule list exceeded 1. To address this, I modified the code as follows:

def step(self):
    current_lr = self.lr_schedule[self.iter]
    for param_group in self.optimizer.param_groups:
        param_group['lr'] = current_lr
    self.iter += 1
    return current_lr
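For reference, here is a minimal, self-contained sketch of a linear-warmup plus cosine-decay schedule that exposes the same `step()` contract. The class name `WarmupCosSchedule`, its arguments, and the choice to clamp at the last entry instead of raising `IndexError` are assumptions for illustration, not the repository's `WarmupCos_Scheduler`. The stand-in optimizer only mimics the `param_groups` list of dicts that PyTorch optimizers expose.

```python
import math


class WarmupCosSchedule:
    """Hypothetical warmup + cosine LR schedule (not the repo's WarmupCos_Scheduler)."""

    def __init__(self, optimizer, base_lr, final_lr, warmup_iters, total_iters, warmup_start_lr=0.0):
        self.optimizer = optimizer
        # Linear warmup from warmup_start_lr toward base_lr over warmup_iters steps.
        warmup = [
            warmup_start_lr + (base_lr - warmup_start_lr) * i / max(warmup_iters, 1)
            for i in range(warmup_iters)
        ]
        # Cosine decay from base_lr toward final_lr over the remaining steps.
        decay_iters = total_iters - warmup_iters
        cosine = [
            final_lr + 0.5 * (base_lr - final_lr) * (1 + math.cos(math.pi * i / max(decay_iters, 1)))
            for i in range(decay_iters)
        ]
        self.lr_schedule = warmup + cosine
        self.iter = 0

    def step(self):
        # Clamp the index so extra calls reuse the last LR instead of raising
        # IndexError (an assumption about the desired behavior).
        idx = min(self.iter, len(self.lr_schedule) - 1)
        current_lr = self.lr_schedule[idx]
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = current_lr
        self.iter += 1
        return current_lr


class _FakeOptimizer:
    # Stand-in with the same param_groups shape as torch.optim optimizers.
    def __init__(self):
        self.param_groups = [{'lr': 0.0}]


opt = _FakeOptimizer()
sched = WarmupCosSchedule(opt, base_lr=1e-3, final_lr=1e-5, warmup_iters=2, total_iters=6)
lrs = [sched.step() for _ in range(8)]  # two more calls than the schedule length
```

Clamping keeps the last learning rate if `step()` is called more times than the schedule has entries, for example when the total iteration count was computed from a different loader length than the one actually iterated.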

Second Issue:

In the main function, anchor_features appeared to be retrieved unnecessarily in the loop:

for idx, (images, features, anchor_features) in enumerate(metric_logger.log_every(train_loader, print_freq, header)):

Therefore, I adjusted the code as follows:

for idx, (images, features) in enumerate(val_metric_logger.log_every(val_loader, print_freq, '>> Val Epoch: [{}]'.format(epoch))):
    ...
    distill = model(images, features)
    ...
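A minimal illustration of the failure mode, assuming the loader yields two-element `(images, features)` batches: unpacking three names from each batch raises `ValueError`, while starred unpacking tolerates an optional extra element. The `batches` list is a stand-in for a `DataLoader`; whether the repository's dataset ever returns `anchor_features` is not confirmed here.

```python
# Stand-in for a DataLoader that yields (images, features) pairs.
batches = [("img0", "feat0"), ("img1", "feat1")]

# Original pattern: expects three items per batch, so it fails on pairs.
try:
    for idx, (images, features, anchor_features) in enumerate(batches):
        pass
    error = None
except ValueError as exc:
    error = str(exc)

# Tolerant pattern: bind the first two items and collect any extras in `rest`.
seen = []
for idx, (images, features, *rest) in enumerate(batches):
    seen.append((images, features, rest))
```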

After making these adjustments, the training proceeded without further issues.

However, after training finished and I moved on to Ours_Asys_testing.py, I ran into another problem when using the R101-DELG.pth file you shared on Google Drive.

Issue:

When attempting to load the state dictionary with:

db_net.load_state_dict(os.path.join(get_data_root(), 'R101-DELG.pth'), strict=True)

I found that the loaded state_dict contained keys that do not exist in db_net:
Unexpected in loaded state_dict: {'attention.att_conv1.weight', 'reduction.weight', 'reduction.bias'}

After removing these keys and running the code again, I got an mAP of only about 4.
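A hedged sketch of a more transparent way to drop unexpected keys: compare the checkpoint against the model's own keys and report both directions before loading. It assumes the `.pth` file holds a plain name-to-tensor dict; `split_state_dict` is a hypothetical helper, and the commented lines use the standard `torch.load` and `nn.Module.load_state_dict` calls rather than any repository-specific loader. A non-empty `missing` list means those layers keep whatever values they had before loading, which by itself can depress retrieval accuracy.

```python
def split_state_dict(ckpt_state, model_keys):
    """Split a checkpoint dict against a model's parameter names.

    ckpt_state: mapping of parameter name -> value (tensors in practice).
    model_keys: iterable of names, e.g. model.state_dict().keys() in PyTorch.
    Returns (kept, unexpected, missing):
      kept       - entries whose names exist in the model,
      unexpected - sorted names present only in the checkpoint,
      missing    - sorted names present only in the model.
    """
    model_keys = set(model_keys)
    kept = {k: v for k, v in ckpt_state.items() if k in model_keys}
    unexpected = sorted(set(ckpt_state) - model_keys)
    missing = sorted(model_keys - set(ckpt_state))
    return kept, unexpected, missing


# Typical PyTorch usage (not executed here):
#   state = torch.load(path, map_location="cpu")
#   kept, unexpected, missing = split_state_dict(state, db_net.state_dict().keys())
#   print("unexpected:", unexpected, "missing:", missing)
#   db_net.load_state_dict(kept, strict=not missing)  # strict only if nothing is missing

# Toy example using the key names from the error message above.
ckpt = {
    "whiten.weight": 1,
    "whiten.bias": 2,
    "attention.att_conv1.weight": 3,
    "reduction.weight": 4,
    "reduction.bias": 5,
}
model_keys = ["whiten.weight", "whiten.bias", "backbone.block1.0.weight"]
kept, unexpected, missing = split_state_dict(ckpt, model_keys)
print("unexpected:", unexpected)
print("missing:", missing)
```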

Even with these modifications, the results I'm getting are unsatisfactory.
Is there an updated version of the code that addresses these issues?
Additionally, could there be a problem with the provided R101-DELG.pth file?

Looking forward to your reply. Thank you in advance!

@MCC-WH
Owner

MCC-WH commented Mar 6, 2024

I apologize for the errors; they are likely my oversight when organizing the code. The correct DELG network structure definition can be found here: DELG.

@MCC-WH
Owner

MCC-WH commented Mar 6, 2024

I'm currently in an important internship and won't have time to fix bugs in the code for a while. If you run into issues, feel free to ask here and I'll get back to you when I see it. :)

@SeunghanYu
Author

Hello, @MCC-WH.

I modified SSP/networks/R101-DELG.py, basing it on Token/networks/RetrievalNet.py:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from .pooling import GeM
from utils import resnet_block_dilation

eps_fea_norm = 1e-5
eps_l2_norm = 1e-10


class ArcFace(nn.Module):
    def __init__(self, in_features, out_features, s=64.0, m=0.50, eps=1e-6):
        super(ArcFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.eps = eps

        self.s = s
        self.m = m

        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.threshold = math.pi - self.m

    def forward(self, input, label):
        cos_theta = F.linear(F.normalize(input, dim=-1), F.normalize(self.weight, dim=-1))
        theta = torch.acos(torch.clamp(cos_theta, -1.0 + self.eps, 1.0 - self.eps))

        one_hot = torch.zeros(cos_theta.size()).to(input.device)
        one_hot.scatter_(1, label.view(-1, 1), 1)

        selected = torch.where(theta > self.threshold, torch.zeros_like(one_hot), one_hot).bool()

        output = torch.cos(torch.where(selected, theta + self.m, theta))
        output *= self.s
        return output

class ResNet(nn.Module):
    def __init__(self, name: str, train_backbone: bool, dilation_block5: bool):
        super(ResNet, self).__init__()
        net_in = getattr(torchvision.models, name)(pretrained=True)
        if name.startswith('resnet'):
            features = list(net_in.children())[:-2]
        else:
            raise ValueError('Unsupported or unknown architecture: {}!'.format(name))
        features = nn.Sequential(*features)
        self.outputdim_block5 = 2048
        self.outputdim_block4 = 1024
        self.block1 = features[:4]
        self.block2 = features[4]
        self.block3 = features[5]
        self.block4 = features[6]
        self.block5 = features[7]
        if dilation_block5:
            self.block5 = resnet_block_dilation(self.block5, dilation=2)
        if not train_backbone:
            for param in self.parameters():
                param.requires_grad_(False)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x
    
class ResNet_STAGE45(nn.Module):
    def __init__(self, name: str, train_backbone: bool, dilation_block5: bool):
        super(ResNet_STAGE45, self).__init__()
        net_in = getattr(torchvision.models, name)(pretrained=True)
        if name.startswith('resnet'):
            features = list(net_in.children())[:-2]
        else:
            raise ValueError('Unsupported or unknown architecture: {}!'.format(name))
        features = nn.Sequential(*features)
        self.outputdim_block5 = 2048
        self.outputdim_block4 = 1024
        self.block1 = features[:4]
        self.block2 = features[4]
        self.block3 = features[5]
        self.block4 = features[6]
        self.block5 = features[7]
        if dilation_block5:
            self.block5 = resnet_block_dilation(self.block5, dilation=2)
        if not train_backbone:
            for param in self.parameters():
                param.requires_grad_(False)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x4 = self.block4(x)
        x5 = self.block5(x4)
        return x4, x5

class Spatial_Attention(nn.Module):
    def __init__(self, input_dim):
        super(Spatial_Attention, self).__init__()
        self.att_conv1 = nn.Conv2d(input_dim, 1, kernel_size=(1, 1), padding=0, stride=1, bias=False)
        self.att_act2 = nn.Softplus(beta=1, threshold=20)
        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        att_score = self.att_act2(self.att_conv1(x))
        return att_score


class DELG(nn.Module):
    def __init__(self, outputdim=2048, reduction_dim=128, classifier_num=1024):
        super(DELG, self).__init__()

        self.backbone = ResNet_STAGE45(name='resnet101', train_backbone=True, dilation_block5=False)
        self.pooling = GeM(p=3.0)
        self.whiten = nn.Conv2d(self.backbone.outputdim_block5, outputdim, kernel_size=(1, 1), stride=1, padding=0, bias=True)
        self.outputdim = outputdim
        self.classifier_block5 = ArcFace(in_features=outputdim, out_features=classifier_num, s=math.sqrt(self.outputdim), m=0.2)
        self.classifier_block4 = ArcFace(in_features=reduction_dim, out_features=classifier_num, s=math.sqrt(reduction_dim), m=0.1)
        self.attention = Spatial_Attention(input_dim=1024)
        self.reduction = nn.Conv2d(self.backbone.outputdim_block4, reduction_dim, kernel_size=1, padding=0, stride=1, bias=True)

    def _init_input_proj(self, weight, bias):
        self.reduction.weight.data = weight
        self.reduction.bias.data = bias

    @torch.no_grad()
    def forward_test(self, x):
        x4, x5 = self.backbone(x)
        global_feature = F.normalize(self.pooling(x5), p=2.0, dim=1)
        global_feature = self.whiten(global_feature).squeeze(-1).squeeze(-1)
        global_feature = F.normalize(global_feature, p=2.0, dim=-1)
        return global_feature
    
    def forward(self, x):
        x4, x5 = self.backbone(x)
        global_feature = F.normalize(self.pooling(x5), p=2.0, dim=1)
        global_feature = self.whiten(global_feature).squeeze(-1).squeeze(-1)
        global_feature = F.normalize(global_feature, p=2.0, dim=-1)
        return global_feature
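To make `forward_test` easier to follow, here is a framework-free sketch of the same global-descriptor steps on a toy feature map: GeM pooling, L2 normalization, a linear whitening (a 1x1 convolution on a 1x1 map is a matrix-vector product plus bias), and a second L2 normalization. The GeM formula with a clamp `eps` is the common definition and is an assumption about what `.pooling.GeM` does; the helper names and the identity whitening weights are made up for illustration.

```python
import math


def gem(feature_map, p=3.0, eps=1e-6):
    """GeM pooling over spatial positions (assumed common definition).

    feature_map: list of channels, each a flat list of spatial activations.
    Returns one value per channel: mean(clamp(x, eps) ** p) ** (1 / p).
    """
    pooled = []
    for channel in feature_map:
        powered = [max(v, eps) ** p for v in channel]
        pooled.append((sum(powered) / len(powered)) ** (1.0 / p))
    return pooled


def l2_normalize(vec, eps=1e-12):
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]


def whiten(vec, weight, bias):
    # Equivalent of the 1x1 Conv2d applied to a pooled [C, 1, 1] tensor.
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def global_descriptor(feature_map, weight, bias, p=3.0):
    # Mirrors forward_test: GeM -> L2 norm -> whiten -> L2 norm.
    x = l2_normalize(gem(feature_map, p=p))
    x = whiten(x, weight, bias)
    return l2_normalize(x)


# Toy "block5" output: 2 channels x 4 spatial positions.
fmap = [[1.0, 1.0, 1.0, 1.0], [0.0, 2.0, 0.0, 2.0]]
W = [[1.0, 0.0], [0.0, 1.0]]  # identity whitening (made-up weights)
b = [0.0, 0.0]
desc = global_descriptor(fmap, W, b)
```

With p = 3, GeM sits between average pooling (p = 1) and max pooling (p approaching infinity), which is why the sparse second channel still pools to a larger value than the uniform first one.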

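As a reading of the `ArcFace` module above (a summary of the code, not a statement from the authors), the logit it produces for sample $i$ and class $j$ is given below, where $x_i$ is the input feature, $w_j$ the $j$-th weight row, $y_i$ the label, and $\theta_{ij}$ is computed from $\cos\theta_{ij}$ clamped to $[-1+\epsilon,\, 1-\epsilon]$:

$$
\cos\theta_{ij} = \frac{x_i^\top w_j}{\lVert x_i \rVert\,\lVert w_j \rVert},
\qquad
\text{logit}_{ij} =
\begin{cases}
s\,\cos(\theta_{ij} + m) & \text{if } j = y_i \text{ and } \theta_{ij} \le \pi - m,\\
s\,\cos\theta_{ij} & \text{otherwise.}
\end{cases}
$$

In the `DELG` constructor the global head uses $s = \sqrt{2048} \approx 45.25$, $m = 0.2$ and the local head uses $s = \sqrt{128} \approx 11.31$, $m = 0.1$. Neither classifier is called in `forward` or `forward_test`, so they do not affect the test-time descriptor.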
After making these modifications, I reran the entire process from the beginning.

Upon executing Ours_Asys_testing.py, I observed a significant drop in performance, as detailed below:

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:08<00:00,  7.85it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 4993/4993 [13:59<00:00,  5.95it/s]
>> Test Dataset: roxford5k *** Feature Type: GeM asys >>
>> whiten: mAP Medium: 4.79, Hard: 6.26
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 70/70 [00:06<00:00, 11.51it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 6322/6322 [18:11<00:00,  5.79it/s]
>> Test Dataset: rparis6k *** Feature Type: GeM asys >>
>> whiten: mAP Medium: 2.55, Hard: 5.4
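For context on these numbers: in the Revisited Oxford/Paris protocol, the Medium setting counts easy and hard images as positives, while Hard counts only hard images as positives and treats easy images as junk; junk images are removed from the ranking before AP is computed, and the printed values appear to be percentages. The sketch below is modeled on `compute_ap` from the revisitop evaluation code; whether Ours_Asys_testing.py calls exactly this routine is an assumption, and `mean_ap` is a hypothetical helper.

```python
def average_precision(ranks, num_positives):
    """AP in the style of revisitop's compute_ap.

    ranks: sorted 0-based positions of the positives in the ranked list,
           counted after junk images have been removed.
    num_positives: total number of positives for the query.
    Precision is averaged just before and at each hit (trapezoidal rule
    over recall steps of 1 / num_positives).
    """
    ap = 0.0
    recall_step = 1.0 / num_positives
    for j, rank in enumerate(ranks):
        precision_0 = 1.0 if rank == 0 else j / rank
        precision_1 = (j + 1) / (rank + 1)
        ap += (precision_0 + precision_1) * recall_step / 2.0
    return ap


def mean_ap(per_query):
    # per_query: list of (ranks, num_positives) tuples, one per query.
    return sum(average_precision(r, n) for r, n in per_query) / len(per_query)


perfect = average_precision([0, 1], 2)  # both positives ranked first
late = average_precision([1], 1)        # single positive ranked second
overall = mean_ap([([0, 1], 2), ([1], 1)])
```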

These results were obtained with timm 0.9.16.
Are there any known performance issues with specific versions of timm?
Could the timm version be causing this performance drop?
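The DELG definition posted above builds its backbone from `torchvision.models` rather than timm, so reporting the torch and torchvision versions alongside timm may help narrow this down. A small helper using the standard-library `importlib.metadata` (Python 3.8+); the function name `package_versions` and the fake package name in the example are illustrative.

```python
from importlib import metadata


def package_versions(names):
    """Return {package: installed version string, or None if not installed}."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


report = package_versions(["torch", "torchvision", "timm", "not-a-real-package-xyz123"])
for name, version in report.items():
    print(f"{name}: {version or 'not installed'}")
```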

Additionally, are there any plans to release the corrected code after your internship ends?

Looking forward to your reply. Thank you in advance!

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants