-
-
Notifications
You must be signed in to change notification settings - Fork 16.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Update train.py * Update train.py * Update train.py * Update train.py * Create train.py
- Loading branch information
1 parent
0070995
commit ca5b10b
Showing
1 changed file
with
16 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -146,8 +146,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |
|
||
# Results | ||
if ckpt.get('training_results') is not None: | ||
with open(results_file, 'w') as file: | ||
file.write(ckpt['training_results']) # write results.txt | ||
results_file.write_text(ckpt['training_results']) # write results.txt | ||
|
||
# Epochs | ||
start_epoch = ckpt['epoch'] + 1 | ||
|
@@ -354,7 +353,7 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |
|
||
# Write | ||
with open(results_file, 'a') as f: | ||
f.write(s + '%10.4g' * 7 % results + '\n') # P, R, [email protected], [email protected], val_loss(box, obj, cls) | ||
f.write(s + '%10.4g' * 7 % results + '\n') # append metrics, val_loss | ||
if len(opt.name) and opt.bucket: | ||
os.system('gsutil cp %s gs://%s/results/results%s.txt' % (results_file, opt.bucket, opt.name)) | ||
|
||
|
@@ -375,15 +374,13 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |
best_fitness = fi | ||
|
||
# Save model | ||
save = (not opt.nosave) or (final_epoch and not opt.evolve) | ||
if save: | ||
with open(results_file, 'r') as f: # create checkpoint | ||
ckpt = {'epoch': epoch, | ||
'best_fitness': best_fitness, | ||
'training_results': f.read(), | ||
'model': ema.ema, | ||
'optimizer': None if final_epoch else optimizer.state_dict(), | ||
'wandb_id': wandb_run.id if wandb else None} | ||
if (not opt.nosave) or (final_epoch and not opt.evolve): # if save | ||
ckpt = {'epoch': epoch, | ||
'best_fitness': best_fitness, | ||
'training_results': results_file.read_text(), | ||
'model': ema.ema, | ||
'optimizer': None if final_epoch else optimizer.state_dict(), | ||
'wandb_id': wandb_run.id if wandb else None} | ||
|
||
# Save last, best and delete | ||
torch.save(ckpt, last) | ||
|
@@ -396,9 +393,9 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |
if rank in [-1, 0]: | ||
# Strip optimizers | ||
final = best if best.exists() else last # final model | ||
for f in [last, best]: | ||
for f in last, best: | ||
if f.exists(): | ||
strip_optimizer(f) # strip optimizers | ||
strip_optimizer(f) | ||
if opt.bucket: | ||
os.system(f'gsutil cp {final} gs://{opt.bucket}/weights') # upload | ||
|
||
|
@@ -415,17 +412,17 @@ def train(hyp, opt, device, tb_writer=None, wandb=None): | |
# Test best.pt | ||
logger.info('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1, (time.time() - t0) / 3600)) | ||
if opt.data.endswith('coco.yaml') and nc == 80: # if COCO | ||
for conf, iou, save_json in ([0.25, 0.45, False], [0.001, 0.65, True]): # speed, mAP tests | ||
for m in (last, best) if best.exists() else (last): # speed, mAP tests | ||
results, _, _ = test.test(opt.data, | ||
batch_size=batch_size * 2, | ||
imgsz=imgsz_test, | ||
conf_thres=conf, | ||
iou_thres=iou, | ||
model=attempt_load(final, device).half(), | ||
conf_thres=0.001, | ||
iou_thres=0.7, | ||
model=attempt_load(m, device).half(), | ||
single_cls=opt.single_cls, | ||
dataloader=testloader, | ||
save_dir=save_dir, | ||
save_json=save_json, | ||
save_json=True, | ||
plots=False) | ||
|
||
else: | ||
|