forked from ultralytics/yolov5
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tensor initialization on device improvements (ultralytics#6959)
* Update common.py speed improvements Eliminate .to() ops where possible for reduced data transfer overhead. Primarily affects warmup and PyTorch Hub inference. * Updates * Updates * Update detect.py * Update val.py
- Loading branch information
1 parent
42ec267
commit c049d63
Showing
2 changed files
with
4 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,7 +87,7 @@ def process_batch(detections, labels, iouv): | |
matches = matches[np.unique(matches[:, 1], return_index=True)[1]] | ||
# matches = matches[matches[:, 2].argsort()[::-1]] | ||
matches = matches[np.unique(matches[:, 0], return_index=True)[1]] | ||
matches = torch.Tensor(matches).to(iouv.device) | ||
matches = torch.from_numpy(matches).to(iouv.device) | ||
correct[matches[:, 1].long()] = matches[:, 2:3] >= iouv | ||
return correct | ||
|
||
|
@@ -155,7 +155,7 @@ def run(data, | |
cuda = device.type != 'cpu' | ||
is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt') # COCO dataset | ||
nc = 1 if single_cls else int(data['nc']) # number of classes | ||
iouv = torch.linspace(0.5, 0.95, 10).to(device) # iou vector for [email protected]:0.95 | ||
iouv = torch.linspace(0.5, 0.95, 10, device=device) # iou vector for [email protected]:0.95 | ||
niou = iouv.numel() | ||
|
||
# Dataloader | ||
|
@@ -196,7 +196,7 @@ def run(data, | |
loss += compute_loss([x.float() for x in train_out], targets)[1] # box, obj, cls | ||
|
||
# NMS | ||
targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device) # to pixels | ||
targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels | ||
lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling | ||
t3 = time_sync() | ||
out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls) | ||
|