forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ComputeLoss() class (ultralytics#1950)
- Loading branch information
1 parent
862874c
commit 5105585
Showing
3 changed files
with
138 additions
and
122 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,7 +29,7 @@ | |
fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \ | ||
check_requirements, print_mutation, set_logging, one_cycle, colorstr | ||
from utils.google_utils import attempt_download | ||
from utils.loss import compute_loss | ||
from utils.loss import ComputeLoss | ||
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution | ||
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first | ||
|
||
|
@@ -227,6 +227,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |
results = (0, 0, 0, 0, 0, 0, 0) # P, R, [email protected], [email protected], val_loss(box, obj, cls) | ||
scheduler.last_epoch = start_epoch - 1 # do not move | ||
scaler = amp.GradScaler(enabled=cuda) | ||
compute_loss = ComputeLoss(model) # init loss class | ||
logger.info(f'Image sizes {imgsz} train, {imgsz_test} test\n' | ||
f'Using {dataloader.num_workers} dataloader workers\n' | ||
f'Logging results to {save_dir}\n' | ||
|
@@ -286,7 +287,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |
# Forward | ||
with amp.autocast(enabled=cuda): | ||
pred = model(imgs) # forward | ||
loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size | ||
loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size | ||
if rank != -1: | ||
loss *= opt.world_size # gradient averaged between devices in DDP mode | ||
if opt.quad: | ||
|
@@ -344,7 +345,8 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |
dataloader=testloader, | ||
save_dir=save_dir, | ||
plots=plots and final_epoch, | ||
log_imgs=opt.log_imgs if wandb else 0) | ||
log_imgs=opt.log_imgs if wandb else 0, | ||
compute_loss=compute_loss) | ||
|
||
# Write | ||
with open(results_file, 'a') as f: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters