From 500efb5caf1c4cd5e46822c3273d07d8a544ef9a Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Sat, 16 Jul 2022 23:46:23 +0200 Subject: [PATCH] Link fuse() to AutoShape() for Hub models (#8599) --- hubconf.py | 3 +-- models/common.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/hubconf.py b/hubconf.py index df585f8cb411..6bb9484a856d 100644 --- a/hubconf.py +++ b/hubconf.py @@ -36,7 +36,6 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo if not verbose: LOGGER.setLevel(logging.WARNING) - check_requirements(exclude=('tensorboard', 'thop', 'opencv-python')) name = Path(name) path = name.with_suffix('.pt') if name.suffix == '' and not name.is_dir() else name # checkpoint path @@ -44,7 +43,7 @@ def _create(name, pretrained=True, channels=3, classes=80, autoshape=True, verbo device = select_device(device) if pretrained and channels == 3 and classes == 80: - model = DetectMultiBackend(path, device=device) # download/load FP32 model + model = DetectMultiBackend(path, device=device, fuse=autoshape) # download/load FP32 model # model = models.experimental.attempt_load(path, map_location=device) # download/load FP32 model else: cfg = list((Path(__file__).parent / 'models').rglob(f'{path.stem}.yaml'))[0] # model.yaml path diff --git a/models/common.py b/models/common.py index fb5ac3a6f5a4..5ea1c307f034 100644 --- a/models/common.py +++ b/models/common.py @@ -305,7 +305,7 @@ def forward(self, x): class DetectMultiBackend(nn.Module): # YOLOv5 MultiBackend class for python inference on various backends - def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False): + def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, data=None, fp16=False, fuse=True): # Usage: # PyTorch: weights = *.pt # TorchScript: *.torchscript @@ -331,7 +331,7 @@ def __init__(self, weights='yolov5s.pt', device=torch.device('cpu'), dnn=False, names = yaml.safe_load(f)['names'] if pt: # PyTorch - model = attempt_load(weights if isinstance(weights, list) else w, device=device) + model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse) stride = max(int(model.stride.max()), 32) # model stride names = model.module.names if hasattr(model, 'module') else model.names # get class names model.half() if fp16 else model.float()