optimized implementation for hybrid matching and training logs #12

Open · wants to merge 2 commits into master
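This PR reworks the loss path so that the one-to-one and one-to-many (hybrid) matching of H-DETR is handled inside a single criterion call, and so that each call also reports how long the cost-matrix construction, the linear assignment, and the loss computation took. The sketch below is illustrative only; the stub stands in for the modified `SetCriterion.forward` and shows how the training loop consumes the new four-value return and accumulates per-epoch timings.

```python
# Illustrative sketch, not code from this PR: a stand-in for the modified
# SetCriterion.forward, which now returns
# (loss_dict, matching_time, assign_time, loss_time).
import time

def criterion_stub(outputs, targets, k_one2many=0):
    start = time.time()
    loss_dict = {"loss_ce": 0.0, "loss_bbox": 0.0}  # placeholder losses
    loss_time = time.time() - start
    matching_time, assign_time = 0.0, 0.0           # placeholder timings
    return loss_dict, matching_time, assign_time, loss_time

# Accumulate per-epoch timings the same way train_one_epoch does in this diff.
time_for_matching = time_for_assign = time_for_loss = 0.0
for outputs, targets in [({}, []), ({}, [])]:       # stand-in for the data loader
    loss_dict, matching_time, assign_time, loss_time = criterion_stub(outputs, targets)
    time_for_matching += matching_time
    time_for_assign += assign_time
    time_for_loss += loss_time
```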
52 changes: 38 additions & 14 deletions engine.py
@@ -1,3 +1,4 @@

# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
@@ -26,13 +27,12 @@
from datasets.coco_eval import CocoEvaluator
from datasets.panoptic_eval import PanopticEvaluator
from datasets.data_prefetcher import data_prefetcher
import time

scaler = torch.cuda.amp.GradScaler()


def train_hybrid(outputs, targets, k_one2many, criterion, lambda_one2many):
# one-to-one-loss
loss_dict = criterion(outputs, targets)
multi_targets = copy.deepcopy(targets)
# repeat the targets
for target in multi_targets:
@@ -44,14 +44,20 @@ def train_hybrid(outputs, targets, k_one2many, criterion, lambda_one2many):
outputs_one2many["pred_boxes"] = outputs["pred_boxes_one2many"]
outputs_one2many["aux_outputs"] = outputs["aux_outputs_one2many"]

# one-to-many loss
loss_dict_one2many = criterion(outputs_one2many, multi_targets)
for key, value in loss_dict_one2many.items():
if key + "_one2many" in loss_dict.keys():
loss_dict[key + "_one2many"] += value * lambda_one2many
else:
loss_dict[key + "_one2many"] = value * lambda_one2many
return loss_dict
# one-to-one first
(loss_dict, matching_time, assign_time, loss_time,) = criterion(
outputs=outputs,
targets=targets,
outputs_one2many=outputs_one2many,
multi_targets=multi_targets,
k_one2many=k_one2many,
)
return (
loss_dict,
matching_time,
assign_time,
loss_time,
)


def train_one_epoch(
@@ -83,6 +89,11 @@ def train_one_epoch(
prefetcher = data_prefetcher(data_loader, device, prefetch=True)
samples, targets = prefetcher.next()

time_for_matching = 0.0
time_for_assign = 0.0
time_for_loss = 0.0

start_time = time.time()
# for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header):
with torch.cuda.amp.autocast() if use_fp16 else torch.cuda.amp.autocast(
@@ -93,11 +104,16 @@
outputs = model(samples)

if k_one2many > 0:
loss_dict = train_hybrid(
loss_dict, matching_time, assign_time, loss_time = train_hybrid(
outputs, targets, k_one2many, criterion, lambda_one2many
)
else:
loss_dict = criterion(outputs, targets)
loss_dict, matching_time, assign_time, loss_time = criterion(
outputs, targets, k_one2many=0
)
time_for_matching += matching_time
time_for_assign += assign_time
time_for_loss += loss_time
weight_dict = criterion.weight_dict
losses = sum(
loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
@@ -133,7 +149,7 @@ def train_one_epoch(
model.parameters(), max_norm
)
else:
grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
grad_total_norm = utils.get_total_grad_norm(model.parameters())

if use_fp16:
scaler.step(optimizer)
@@ -155,6 +171,12 @@
wandb.log(loss_dict)
except:
pass
end_time = time.time()
total_time_cost = end_time - start_time
print("total time cost for an epoch is:", total_time_cost)
print("time for matching part for an epoch is:", time_for_matching)
print("time for linear assign part for an epoch is:", time_for_assign)
print("time for loss part for an epoch is:", time_for_loss)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
@@ -205,7 +227,9 @@ def evaluate(
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

outputs = model(samples)
loss_dict = criterion(outputs, targets)
loss_dict, eval_matching_time, eval_assign_time, eval_loss_time = criterion(
outputs, targets, k_one2many=0
)
weight_dict = criterion.weight_dict

# reduce losses over all GPUs for logging purposes
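An aside on the timings added in engine.py (not part of the diff): `time.time()` measures wall-clock time around CUDA kernels that are launched asynchronously, so per-section timings can attribute GPU work to the wrong section. If more precise attribution is wanted, a minimal synchronized-timer sketch could look like the following; the helper name `timed` is illustrative, not something introduced by this PR.

```python
# Sketch only (assumption: synchronizing before reading the clock gives more
# accurate GPU timings); not part of this PR.
import time
import torch

def timed(fn, *args, **kwargs):
    """Run fn and return (result, elapsed_seconds), synchronizing CUDA so the
    elapsed time includes the GPU work launched inside fn."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.time() - start
```

For example, `outputs, forward_time = timed(model, samples)` would time the forward pass including the asynchronous kernels it launches.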
146 changes: 131 additions & 15 deletions models/deformable_detr.py
@@ -1,8 +1,4 @@
# ------------------------------------------------------------------------
# H-DETR
# Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved.
# Licensed under the MIT-style license found in the LICENSE file in the root directory
# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
@@ -18,7 +14,7 @@
import torch.nn.functional as F
from torch import nn
import math

import time
from util import box_ops
from util.misc import (
NestedTensor,
@@ -461,24 +457,81 @@ def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
assert loss in loss_map, f"do you really want to compute {loss} loss?"
return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)

def forward(self, outputs, targets):
def forward(
self, outputs, targets, outputs_one2many=None, multi_targets=None, k_one2many=0
):
""" This performs the loss computation.
Parameters:
outputs: dict of tensors, see the output specification of the model for the format
targets: list of dicts, such that len(targets) == batch_size.
The expected keys in each dict depends on the losses applied, see each loss' doc
"""
start = time.time()
matching_time = 0.0
assign_time = 0.0
whole_targets = copy.deepcopy(targets)
outputs_without_aux = {
k: v
for k, v in outputs.items()
if k != "aux_outputs" and k != "enc_outputs"
}

# Retrieve the matching between the outputs of the last layer and the targets
indices = self.matcher(outputs_without_aux, targets)
num_one2one_queries = outputs_without_aux["pred_logits"].shape[1]

if outputs_one2many != None:
outputs["pred_logits"] = torch.cat(
[outputs["pred_logits"], outputs_one2many["pred_logits"]], dim=1
)
outputs["pred_boxes"] = torch.cat(
[outputs["pred_boxes"], outputs_one2many["pred_boxes"]], dim=1
)
# Retrieve the matching between the outputs of the last layer and the targets
indices, cur_match_time, cur_assign_time, cost_matrix = self.matcher(
outputs_without_aux, targets, full_outputs=outputs
)
else:
# Retrieve the matching between the outputs of the last layer and the targets
indices, cur_match_time, cur_assign_time, cost_matrix = self.matcher(
outputs_without_aux, targets
)
matching_time += cur_match_time
assign_time += cur_assign_time

if outputs_one2many != None:
outputs_without_aux_one2many = {
k: v
for k, v in outputs_one2many.items()
if k != "aux_outputs" and k != "enc_outputs"
}
indices_one2many, cur_match_time, cur_assign_time, _ = self.matcher(
outputs_without_aux_one2many,
multi_targets,
cost_matrix=cost_matrix,
k_one2many=k_one2many,
single_targets=targets,
)
matching_time += cur_match_time
assign_time += cur_assign_time

for b in range(len(targets)):
cur_gt_num = targets[b]["labels"].shape[0]

indices[b] = (
torch.cat(
[indices[b][0], indices_one2many[b][0] + num_one2one_queries]
),
torch.cat([indices[b][1], indices_one2many[b][1] + cur_gt_num]),
)

whole_targets[b]["boxes"] = torch.cat(
[targets[b]["boxes"], multi_targets[b]["boxes"]], dim=0
)
whole_targets[b]["labels"] = torch.cat(
[targets[b]["labels"], multi_targets[b]["labels"]], dim=0
)

# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["labels"]) for t in targets)
num_boxes = sum(len(t["labels"]) for t in whole_targets)
num_boxes = torch.as_tensor(
[num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device
)
@@ -491,13 +544,72 @@ def forward(self, outputs, targets):
for loss in self.losses:
kwargs = {}
losses.update(
self.get_loss(loss, outputs, targets, indices, num_boxes, **kwargs)
self.get_loss(
loss, outputs, whole_targets, indices, num_boxes, **kwargs
)
)

# In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
if "aux_outputs" in outputs:
for i, aux_outputs in enumerate(outputs["aux_outputs"]):
indices = self.matcher(aux_outputs, targets)
if outputs_one2many != None:
new_aux_outputs = dict()
cur_one2many_aux_outputs = outputs_one2many["aux_outputs"][i]
new_aux_outputs["pred_logits"] = torch.cat(
[
aux_outputs["pred_logits"],
cur_one2many_aux_outputs["pred_logits"],
],
dim=1,
)
new_aux_outputs["pred_boxes"] = torch.cat(
[
aux_outputs["pred_boxes"],
cur_one2many_aux_outputs["pred_boxes"],
],
dim=1,
)

(
indices,
cur_match_time,
cur_assign_time,
cost_matrix,
) = self.matcher(aux_outputs, targets, full_outputs=new_aux_outputs)
aux_outputs = new_aux_outputs
else:
(
indices,
cur_match_time,
cur_assign_time,
cost_matrix,
) = self.matcher(aux_outputs, targets)
matching_time += cur_match_time
assign_time += cur_assign_time
if outputs_one2many != None:
indices_one2many, cur_match_time, cur_assign_time, _ = self.matcher(
cur_one2many_aux_outputs,
multi_targets,
cost_matrix=cost_matrix,
k_one2many=k_one2many,
single_targets=targets,
)
matching_time += cur_match_time
assign_time += cur_assign_time
for b in range(len(targets)):
cur_gt_num = targets[b]["labels"].shape[0]
indices[b] = (
torch.cat(
[
indices[b][0],
indices_one2many[b][0] + num_one2one_queries,
]
),
torch.cat(
[indices[b][1], indices_one2many[b][1] + cur_gt_num]
),
)

for loss in self.losses:
if loss == "masks":
# Intermediate masks losses are too costly to compute, we ignore them.
@@ -507,7 +619,7 @@ def forward(self, outputs, targets):
# Logging is enabled only for the last layer
kwargs["log"] = False
l_dict = self.get_loss(
loss, aux_outputs, targets, indices, num_boxes, **kwargs
loss, aux_outputs, whole_targets, indices, num_boxes, **kwargs
)
l_dict = {k + f"_{i}": v for k, v in l_dict.items()}
losses.update(l_dict)
@@ -517,7 +629,11 @@ def forward(self, outputs, targets):
bin_targets = copy.deepcopy(targets)
for bt in bin_targets:
bt["labels"] = torch.zeros_like(bt["labels"])
indices = self.matcher(enc_outputs, bin_targets)
indices, cur_match_time, cur_assign_time, _ = self.matcher(
enc_outputs, bin_targets
)
matching_time += cur_match_time
assign_time += cur_assign_time
for loss in self.losses:
if loss == "masks":
# Intermediate masks losses are too costly to compute, we ignore them.
@@ -531,8 +647,8 @@
)
l_dict = {k + f"_enc": v for k, v in l_dict.items()}
losses.update(l_dict)

return losses
end = time.time()
return losses, matching_time, assign_time, end - start


class PostProcess(nn.Module):
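The matcher-side changes that the revised `SetCriterion` relies on (returning timing information plus a cost matrix, and accepting the extra keyword arguments `full_outputs`, `cost_matrix`, `k_one2many`, and `single_targets`) live in a file not included above. The sketch below only illustrates the assumed return convention `(indices, matching_time, assign_time, cost_matrix)` using a toy L1 cost; it is not the matcher implementation from this PR.

```python
# Illustrative sketch of the assumed matcher return convention in this PR:
# (indices, matching_time, assign_time, cost_matrix). Toy L1 cost only.
import time
import torch
from scipy.optimize import linear_sum_assignment

@torch.no_grad()
def matcher_sketch(pred_boxes, tgt_boxes):
    # Build the cost matrix (here simply the pairwise L1 distance between boxes).
    match_start = time.time()
    cost_matrix = torch.cdist(pred_boxes, tgt_boxes, p=1)
    matching_time = time.time() - match_start

    # Time the Hungarian (linear sum assignment) step separately.
    assign_start = time.time()
    row, col = linear_sum_assignment(cost_matrix.cpu().numpy())
    assign_time = time.time() - assign_start

    indices = (
        torch.as_tensor(row, dtype=torch.int64),
        torch.as_tensor(col, dtype=torch.int64),
    )
    return indices, matching_time, assign_time, cost_matrix
```

For example, `matcher_sketch(torch.rand(300, 4), torch.rand(5, 4))` matches 300 predicted boxes against 5 targets and reports both timings; `indices` pairs prediction indices with target indices for a single image, whereas the real matcher returns one such pair per batch element.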