Skip to content

Commit

Permalink
Update SparseML Integration to V6.1 (#26)
Browse files Browse the repository at this point in the history
* SparseML integration

* Add SparseML dependancy

* Update: add missing files

* Update requirements.txt

* Update: sparseml-nightly support

* Update: remove model versioning

* Partial update for multi-stage recipes

* Update: multi-stage recipe support

* Update: remove sparseml dep

* Fix: multi-stage recipe handeling

* Fix: multi stage support

* Fix: non-recipe runs

* Add: legacy hyperparam files

* Fix: add copy-paste to hyps

* Fix: nit

* apply structure fixes
  • Loading branch information
KSGulin authored and Benjamin committed Apr 8, 2022
1 parent 011e7df commit a32b970
Show file tree
Hide file tree
Showing 20 changed files with 758 additions and 103 deletions.
39 changes: 39 additions & 0 deletions data/hyps/hyp.finetune.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Hyperparameters for VOC finetuning
# python train.py --batch 64 --weights yolov5m.pt --data voc.yaml --img 512 --epochs 50
# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials


# Hyperparameter Evolution Results
# Generations: 306
# P R mAP.5 mAP.5:.95 box obj cls
# Metrics: 0.6 0.936 0.896 0.684 0.0115 0.00805 0.00146

lr0: 0.0032
lrf: 0.12
momentum: 0.843
weight_decay: 0.00036
warmup_epochs: 2.0
warmup_momentum: 0.5
warmup_bias_lr: 0.05
box: 0.0296
cls: 0.243
cls_pw: 0.631
obj: 0.301
obj_pw: 0.911
iou_t: 0.2
anchor_t: 2.91
# anchors: 3.63
fl_gamma: 0.0
hsv_h: 0.0138
hsv_s: 0.664
hsv_v: 0.464
degrees: 0.373
translate: 0.245
scale: 0.898
shear: 0.602
perspective: 0.0
flipud: 0.00856
fliplr: 0.5
mosaic: 1.0
mixup: 0.243
copy_paste: 0.0
34 changes: 34 additions & 0 deletions data/hyps/hyp.scratch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Hyperparameters for COCO training from scratch
# python train.py --batch 40 --cfg yolov5m.yaml --weights '' --data coco.yaml --img 640 --epochs 300
# See tutorials for hyperparameter evolution https://github.com/ultralytics/yolov5#tutorials


lr0: 0.01 # initial learning rate (SGD=1E-2, Adam=1E-3)
lrf: 0.2 # final OneCycleLR learning rate (lr0 * lrf)
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay 5e-4
warmup_epochs: 3.0 # warmup epochs (fractions ok)
warmup_momentum: 0.8 # warmup initial momentum
warmup_bias_lr: 0.1 # warmup initial bias lr
box: 0.05 # box loss gain
cls: 0.5 # cls loss gain
cls_pw: 1.0 # cls BCELoss positive_weight
obj: 1.0 # obj loss gain (scale with pixels)
obj_pw: 1.0 # obj BCELoss positive_weight
iou_t: 0.20 # IoU training threshold
anchor_t: 4.0 # anchor-multiple threshold
# anchors: 3 # anchors per output layer (0 to ignore)
fl_gamma: 0.0 # focal loss gamma (efficientDet default gamma=1.5)
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.5 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
perspective: 0.0 # image perspective (+/- fraction), range 0-0.001
flipud: 0.0 # image flip up-down (probability)
fliplr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mixup: 0.0 # image mixup (probability)
copy_paste: 0.0
3 changes: 2 additions & 1 deletion detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
from utils.plots import Annotator, colors, save_one_box
from utils.torch_utils import select_device, time_sync
from export import load_checkpoint


@torch.no_grad()
Expand Down Expand Up @@ -89,7 +90,7 @@ def run(weights=ROOT / 'yolov5s.pt', # model.pt path(s)

# Load model
device = select_device(device)
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
model, extras = load_checkpoint(type_='val', weights=weights, device=device) # load FP32 model
stride, names, pt = model.stride, model.names, model.pt
imgsz = check_img_size(imgsz, s=stride) # check image size

Expand Down
173 changes: 156 additions & 17 deletions export.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"""

import argparse
from copy import deepcopy
import json
import os
import platform
Expand All @@ -57,20 +58,26 @@
import torch.nn as nn
from torch.utils.mobile_optimizer import optimize_for_mobile

from sparseml.pytorch.utils import ModuleExporter
from sparseml.pytorch.sparsification.quantization import skip_onnx_input_quantize

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative

from models.common import Conv
from models.common import Conv, DetectMultiBackend
from models.experimental import attempt_load
from models.yolo import Detect
from models.yolo import Detect, Model
from utils.activations import SiLU
from utils.datasets import LoadImages
from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
file_size, print_args, url2file)
from utils.torch_utils import select_device
file_size, print_args, url2file, intersect_dicts)
from utils.torch_utils import select_device, torch_distributed_zero_first, is_parallel
from utils.downloads import attempt_download
from utils.sparse import SparseMLWrapper, check_download_sparsezoo_weights



def export_formats():
Expand Down Expand Up @@ -118,14 +125,33 @@ def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorst
LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
f = file.with_suffix('.onnx')

torch.onnx.export(model, im, f, verbose=False, opset_version=opset,
training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
do_constant_folding=not train,
input_names=['images'],
output_names=['output'],
dynamic_axes={'images': {0: 'batch', 2: 'height', 3: 'width'}, # shape(1,3,640,640)
'output': {0: 'batch', 1: 'anchors'} # shape(1,25200,85)
} if dynamic else None)
# export through SparseML so quantized and pruned graphs can be corrected
save_dir = f.parent.absolute()
save_name = str(f).split(os.path.sep)[-1]

# get the number of outputs so we know how to name and change dynamic axes
# nested outputs can be returned if model is exported with dynamic
def _count_outputs(outputs):
count = 0
if isinstance(outputs, list) or isinstance(outputs, tuple):
for out in outputs:
count += _count_outputs(out)
else:
count += 1
return count

outputs = model(im)
num_outputs = _count_outputs(outputs)
input_names = ['input']
output_names = [f'out_{i}' for i in range(num_outputs)]
dynamic_axes = {k: {0: 'batch'} for k in (input_names + output_names)} if dynamic else None
exporter = ModuleExporter(model, save_dir)
exporter.export_onnx(im, name=save_name, convert_qat=True,
input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes)
try:
skip_onnx_input_quantize(f, f)
except:
pass

# Checks
model_onnx = onnx.load(f) # load onnx model
Expand Down Expand Up @@ -407,14 +433,123 @@ def export_tfjs(keras_model, im, file, prefix=colorstr('TensorFlow.js:')):
except Exception as e:
LOGGER.info(f'\n{prefix} export failure: {e}')

def create_checkpoint(epoch, model, optimizer, ema, sparseml_wrapper, **kwargs):
pickle = not sparseml_wrapper.qat_active(epoch) # qat does not support pickled exports
ckpt_model = deepcopy(model.module if is_parallel(model) else model).float()
yaml = ckpt_model.yaml
if not pickle:
ckpt_model = ckpt_model.state_dict()

return {'epoch': epoch,
'model': ckpt_model,
'optimizer': optimizer.state_dict(),
'yaml': yaml,
'hyp': model.hyp,
**ema.state_dict(pickle),
**sparseml_wrapper.state_dict(),
**kwargs}

def load_checkpoint(
type_,
weights,
device,
cfg=None,
hyp=None,
nc=None,
data=None,
dnn=False,
half = False,
recipe=None,
resume=None,
rank=-1
):
with torch_distributed_zero_first(rank):
# download if not found locally or from sparsezoo if stub
weights = attempt_download(weights) or check_download_sparsezoo_weights(weights)
ckpt = torch.load(weights[0] if isinstance(weights, list) or isinstance(weights, tuple)
else weights, map_location="cpu") # load checkpoint
start_epoch = ckpt['epoch'] + 1 if 'epoch' in ckpt else 0
pickled = isinstance(ckpt['model'], nn.Module)
train_type = type_ == 'train'
ensemble_type = type_ == 'ensemble'
val_type = type_ =='val'

if pickled and ensemble_type:
cfg = None
if ensemble_type:
model = attempt_load(weights, map_location=device) # load ensemble using pickled
state_dict = model.state_dict()
elif val_type:
model = DetectMultiBackend(weights, device=device, dnn=dnn, data=data, fp16=half)
state_dict = model.model.state_dict()
else:
# load model from config and weights
cfg = cfg or (ckpt['yaml'] if 'yaml' in ckpt else None) or \
(ckpt['model'].yaml if pickled else None)
model = Model(cfg, ch=3, nc=ckpt['nc'] if ('nc' in ckpt and not nc) else nc,
anchors=hyp.get('anchors') if hyp else None).to(device)
model_key = 'ema' if (not train_type and 'ema' in ckpt and ckpt['ema']) else 'model'
state_dict = ckpt[model_key].float().state_dict() if pickled else ckpt[model_key]
if val_type:
model = DetectMultiBackend(model=model, device=device, dnn=dnn, data=data, fp16=half)

# turn gradients for params back on in case they were removed
for p in model.parameters():
p.requires_grad = True

# load sparseml recipe for applying pruning and quantization
checkpoint_recipe = train_recipe = None
if resume:
train_recipe = ckpt['recipe'] if ('recipe' in ckpt) else None
elif ckpt['recipe'] or recipe:
train_recipe, checkpoint_recipe = recipe, ckpt['recipe']

sparseml_wrapper = SparseMLWrapper(model.model if val_type else model, checkpoint_recipe, train_recipe)
exclude_anchors = train_type and (cfg or hyp.get('anchors')) and not resume
loaded = False

sparseml_wrapper.apply_checkpoint_structure(float("inf"))
if train_type:
# intialize the recipe for training and restore the weights before if no quantized weights
quantized_state_dict = any([name.endswith('.zero_point') for name in state_dict.keys()])
if not quantized_state_dict:
state_dict = load_state_dict(model, state_dict, train=True, exclude_anchors=exclude_anchors)
loaded = True
sparseml_wrapper.initialize(start_epoch)

if not loaded:
state_dict = load_state_dict(model, state_dict, train=train_type, exclude_anchors=exclude_anchors)

model.float()
report = 'Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)

return model, {
'ckpt': ckpt,
'state_dict': state_dict,
'sparseml_wrapper': sparseml_wrapper,
'report': report,
}


def load_state_dict(model, state_dict, train, exclude_anchors):
# fix older state_dict names not porting to the new model setup
state_dict = {key if not key.startswith("module.") else key[7:]: val for key, val in state_dict.items()}

if train:
# load any missing weights from the model
state_dict = intersect_dicts(state_dict, model.state_dict(), exclude=['anchor'] if exclude_anchors else [])

model.load_state_dict(state_dict, strict=not train) # load

return state_dict

@torch.no_grad()
def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
weights=ROOT / 'yolov5s.pt', # weights path
imgsz=(640, 640), # image (height, width)
batch_size=1, # batch size
device='cpu', # cuda device, i.e. 0 or 0,1,2,3 or cpu
include=('torchscript', 'onnx'), # include formats
include=('onnx'), # include formats
half=False, # FP16 half-precision export
inplace=False, # set YOLOv5 Detect() inplace=True
train=False, # model.train() mode
Expand All @@ -430,7 +565,8 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
topk_per_class=100, # TF.js NMS: topk per class to keep
topk_all=100, # TF.js NMS: topk for all classes to keep
iou_thres=0.45, # TF.js NMS: IoU threshold
conf_thres=0.25 # TF.js NMS: confidence threshold
conf_thres=0.25, # TF.js NMS: confidence threshold
remove_grid=False,
):
t = time.time()
include = [x.lower() for x in include] # to lowercase
Expand All @@ -443,8 +579,9 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
# Load PyTorch model
device = select_device(device)
assert not (device.type == 'cpu' and half), '--half only compatible with GPU export, i.e. use --device 0'
model = attempt_load(weights, map_location=device, inplace=True, fuse=True) # load FP32 model
nc, names = model.nc, model.names # number of classes, class names
model, extras = load_checkpoint(type_='ensemble', weights=weights, device=device) # load FP32 model
sparseml_wrapper = extras['sparseml_wrapper']
nc, names = extras["ckpt"]["nc"], model.names # number of classes, class names

# Checks
imgsz *= 2 if len(imgsz) == 1 else 1 # expand
Expand All @@ -469,6 +606,7 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
m.onnx_dynamic = dynamic
if hasattr(m, 'forward_export'):
m.forward = m.forward_export # assign custom forward (optional)
model.model[-1].export = not remove_grid # set Detect() layer grid export

for _ in range(2):
y = model(im) # dry runs
Expand Down Expand Up @@ -541,6 +679,7 @@ def parse_opt():
parser.add_argument('--topk-all', type=int, default=100, help='TF.js NMS: topk for all classes to keep')
parser.add_argument('--iou-thres', type=float, default=0.45, help='TF.js NMS: IoU threshold')
parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
parser.add_argument("--remove-grid", action="store_true", help="remove export of Detect() layer grid")
parser.add_argument('--include', nargs='+',
default=['torchscript', 'onnx'],
help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
Expand All @@ -556,4 +695,4 @@ def main(opt):

if __name__ == "__main__":
opt = parse_opt()
main(opt)
main(opt)
Loading

0 comments on commit a32b970

Please sign in to comment.