coco dataset support, automatic aws weight upload (#54)
* ignore neptune logs

* ignore neptune logs

* add s3 weight upload to readme

* add s3 file upload util

* integrate automatic s3 upload

* add missing boto3 requirement

* update logging

* better exception handling

* update s3 weight dir

* add coco styled data support for training

* add sahi dependency

* update sahi version

* fix wandb logging error

* fix category based ap logging

* minor fix
fcakyon authored Nov 9, 2021
1 parent c5b5562 commit 07c9cab
Showing 8 changed files with 142 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -43,6 +43,7 @@ data/*
pycocotools/*
results*.txt
gcp_test*.sh
+.neptune

# Datasets -------------------------------------------------------------------------------------------------------------
coco/
29 changes: 27 additions & 2 deletions README.md
@@ -162,7 +162,7 @@ You can call `yolov5 train`, `yolov5 detect`, `yolov5 val` and `yolov5 export` c
<details open>
<summary>Training</summary>

-Finetune one of the pretrained YOLOv5 models using your custom `data.yaml`:
+- Finetune one of the pretrained YOLOv5 models using your custom `data.yaml`:

```bash
$ yolov5 train --data data.yaml --weights yolov5s.pt --batch-size 16 --img 640
@@ -171,12 +171,37 @@ $ yolov5 train --data data.yaml --weights yolov5s.pt --batch-size 16 --img 640
yolov5x.pt 2
```

-Visualize your experiments via [Neptune.AI](https://neptune.ai/):
+- Start a training using a COCO formatted dataset:
+
+```yaml
+# data.yml
+train_json_path: "train.json"
+train_image_dir: "train_image_dir/"
+val_json_path: "val.json"
+val_image_dir: "val_image_dir/"
+```
+```bash
+$ yolov5 train --data data.yaml --weights yolov5s.pt
+```
+
+- Visualize your experiments via [Neptune.AI](https://neptune.ai/):

```bash
$ yolov5 train --data data.yaml --weights yolov5s.pt --neptune_project NAMESPACE/PROJECT_NAME --neptune_token YOUR_NEPTUNE_TOKEN
```

+- Automatically upload weights to AWS S3 (with Neptune.AI artifact tracking integration):
+
+```bash
+export AWS_ACCESS_KEY_ID=YOUR_KEY
+export AWS_SECRET_ACCESS_KEY=YOUR_KEY
+```
+
+```bash
+$ yolov5 train --data data.yaml --weights yolov5s.pt --s3_dir YOUR_S3_FOLDER_DIRECTORY
+```

</details>

<details open>
4 changes: 4 additions & 0 deletions requirements.txt
@@ -36,3 +36,7 @@ seaborn>=0.11.0
thop # FLOPs computation
# CLI
fire
+# AWS
+boto3>=1.19.1
+# coco to yolov5 conversion
+sahi>=0.8.8
48 changes: 47 additions & 1 deletion yolov5/train.py
@@ -15,6 +15,7 @@
import time
from copy import deepcopy
from pathlib import Path
+from shutil import copyfile

import numpy as np
import torch
@@ -46,6 +47,7 @@
from yolov5.utils.metrics import fitness
from yolov5.utils.loggers import Loggers
from yolov5.utils.callbacks import Callbacks
+from yolov5.utils.aws import upload_file_to_s3

LOGGER = logging.getLogger(__name__)
LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
@@ -62,6 +64,27 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
        Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
        opt.resume, opt.noval, opt.nosave, opt.workers, opt.freeze

+    # coco to yolov5 conversion
+    is_coco_data = False
+    with open(data, errors='ignore') as f:
+        data_dict = yaml.safe_load(f)  # load data dict
+        if "train_json_path" in data_dict:
+            is_coco_data = True
+    if is_coco_data:
+        from sahi.utils.coco import export_coco_as_yolov5_via_yml
+        data = export_coco_as_yolov5_via_yml(yml_path=data, output_dir=save_dir / 'data')
+        opt.data = data
+
+        w = save_dir / 'data' / 'coco'  # coco dir
+        w.mkdir(parents=True, exist_ok=True)  # make dir
+
+        # copy train.json/val.json and coco_data.yml into data/coco/ folder
+        copyfile(data, str(w / Path(data).name))
+        if "train_json_path" in data_dict and Path(data_dict["train_json_path"]).is_file():
+            copyfile(data_dict["train_json_path"], str(w / Path(data_dict["train_json_path"]).name))
+        if "val_json_path" in data_dict and Path(data_dict["val_json_path"]).is_file():
+            copyfile(data_dict["val_json_path"], str(w / Path(data_dict["val_json_path"]).name))

    # Directories
    w = save_dir / 'weights'  # weights dir
    (w.parent if evolve else w).mkdir(parents=True, exist_ok=True)  # make dir
@@ -393,6 +416,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
                if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
                    torch.save(ckpt, w / f'epoch{epoch}.pt')
                del ckpt

+                # upload best model to aws s3
+                if opt.s3_dir:
+                    s3_file = str(Path(best.parents[1].name) / "weights" / "best.pt")
+                    LOGGER.info(f"{colorstr('aws:')} Uploading best weight to AWS S3...")
+                    result = upload_file_to_s3(local_file=str(best), s3_dir=opt.s3_dir, s3_file=s3_file)
+                    s3_path = str(Path(opt.s3_dir) / s3_file)
+                    if result:
+                        LOGGER.info(f"{colorstr('aws:')} Best weight has been successfully uploaded to {s3_path}")

                callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)

            # Stop Single-GPU
@@ -434,6 +467,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
                    if is_coco:
                        callbacks.run('on_fit_epoch_end', list(mloss) + list(results) + lr, epoch, best_fitness, fi)

+        # upload best model to aws s3
+        if opt.s3_dir:
+            s3_dir = opt.s3_dir
+            s3_file = str(Path(best.parents[1].name) / "weights" / "best.pt")
+            LOGGER.info(f"{colorstr('aws:')} Uploading best weight to AWS S3...")
+            result = upload_file_to_s3(local_file=str(best), s3_dir=s3_dir, s3_file=s3_file)
+            s3_path = "s3://" + str(Path(s3_dir.replace("s3://","")) / s3_file)
+            if result:
+                LOGGER.info(f"{colorstr('aws:')} Best weight has been successfully uploaded to {s3_path}")

        callbacks.run('on_train_end', last, best, plots, epoch, results)
        LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

@@ -481,14 +524,17 @@ def parse_opt(known=False):

    # Weights & Biases arguments
    parser.add_argument('--entity', default=None, help='W&B: Entity')
-    parser.add_argument('--upload_dataset', action='store_true', help='W&B: Upload dataset as artifact table')
    parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
    parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')

    # Neptune AI arguments
    parser.add_argument('--neptune_token', type=str, default="", help='neptune.ai api token')
    parser.add_argument('--neptune_project', type=str, default="", help='https://docs.neptune.ai/api-reference/neptune')

+    # AWS arguments
+    parser.add_argument('--s3_dir', type=str, default="", help='aws s3 folder directory to upload best weight and dataset')
+    parser.add_argument('--upload_dataset', action='store_true', help='upload dataset to aws s3')
+
    opt = parser.parse_known_args()[0] if known else parser.parse_args()
    return opt

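For reference, the S3 destination used by the upload blocks above is derived from the run directory name. A minimal sketch of the path arithmetic, mirroring the end-of-training block (the run directory and bucket are hypothetical):

```python
from pathlib import Path

best = Path("runs/train/exp3/weights/best.pt")  # hypothetical run output
s3_dir = "s3://my-bucket/yolov5"                # hypothetical --s3_dir value

# key relative to --s3_dir: '<run_name>/weights/best.pt'
s3_file = str(Path(best.parents[1].name) / "weights" / "best.pt")

# final URI, tolerating an optional 's3://' prefix on --s3_dir
s3_path = "s3://" + str(Path(s3_dir.replace("s3://", "")) / s3_file)
print(s3_path)  # s3://my-bucket/yolov5/exp3/weights/best.pt
```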
42 changes: 42 additions & 0 deletions yolov5/utils/aws.py
@@ -0,0 +1,42 @@
+import os
+from pathlib import Path
+import logging
+import boto3
+from botocore.exceptions import NoCredentialsError
+from yolov5.utils.general import colorstr
+
+
+AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
+AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
+
+LOGGER = logging.getLogger(__name__)
+
+
+def parse_s3_uri(s3_uri):
+    # strip 's3://'
+    if s3_uri.startswith("s3://"):
+        s3_uri = s3_uri[5:]
+    # parse bucket and key
+    s3_components = s3_uri.split("/")
+    bucket = s3_components[0]
+    s3_key = ""
+    if len(s3_components) > 1:
+        s3_key = "/".join(s3_components[1:])
+    return bucket, s3_key
+
+def upload_file_to_s3(local_file, s3_dir, s3_file):
+    s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY_ID,
+                      aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
+    # parse s3 uri
+    bucket, s3_key = parse_s3_uri(s3_dir)
+    # upload to s3
+    try:
+        s3_path = str(Path(s3_key) / s3_file)
+        s3.upload_file(local_file, bucket, s3_path)
+        return True
+    except FileNotFoundError:
+        print(f"{colorstr('aws:')} S3 upload failed because local file not found: {local_file}")
+        return False
+    except NoCredentialsError:
+        print(f"{colorstr('aws:')} AWS credentials are not set. Please configure aws via CLI or set required ENV variables.")
+        return False
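A quick usage sketch of the helpers above (the bucket and keys are hypothetical, and valid AWS credentials are assumed to be present in the environment):

```python
from yolov5.utils.aws import parse_s3_uri, upload_file_to_s3

# the 's3://' scheme is optional; parse_s3_uri strips it before splitting
bucket, key = parse_s3_uri("s3://my-bucket/yolov5/runs")
print(bucket, key)  # my-bucket yolov5/runs

# uploads best.pt to s3://my-bucket/yolov5/runs/exp3/weights/best.pt
ok = upload_file_to_s3(
    local_file="runs/train/exp3/weights/best.pt",
    s3_dir="s3://my-bucket/yolov5/runs",
    s3_file="exp3/weights/best.pt",
)
print(ok)  # True on success, False on a missing file or credentials
```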
2 changes: 1 addition & 1 deletion yolov5/utils/datasets.py
@@ -419,7 +419,7 @@ def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, r
            cache, exists = np.load(cache_path, allow_pickle=True).item(), True  # load dict
            assert cache['version'] == self.cache_version  # same version
            assert cache['hash'] == get_hash(self.label_files + self.img_files)  # same hash
-        except:
+        except FileNotFoundError:
            cache, exists = self.cache_labels(cache_path, prefix), False  # cache

        # Display cache
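The narrowed handler matches the exception `np.load` raises when the label cache file does not exist yet; a minimal illustration with a hypothetical cache path:

```python
import numpy as np

try:
    cache = np.load("labels.cache", allow_pickle=True).item()  # hypothetical path
except FileNotFoundError:
    cache = None  # no cache on disk yet, so labels get (re)cached
```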
23 changes: 19 additions & 4 deletions yolov5/utils/loggers/__init__.py
@@ -4,6 +4,7 @@
"""

import os
+from pathlib import Path
import warnings
from threading import Thread

@@ -63,7 +64,8 @@ def __init__(self, save_dir=None, weights=None, opt=None, hyp=None, logger=None,
            self.class_name_keys = ['metrics/' + name + '_mAP_50' for name in class_names]
        else:
            self.class_name_keys = ['val/' + name + '_mAP_50' for name in class_names]

+        self.s3_weight_folder = None if not opt.s3_dir else "s3://" + str(Path(opt.s3_dir.replace("s3://","")) / save_dir.name / "weights")

        # Message
        if not wandb:
            prefix = colorstr('Weights & Biases: ')
@@ -163,6 +165,9 @@ def on_model_save(self, last, epoch, final_epoch, best_fitness, fi):
        if self.wandb:
            if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
                self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
+        if self.neptune and self.neptune.neptune_run and self.s3_weight_folder is not None:
+            if not final_epoch and best_fitness == fi:
+                self.neptune.neptune_run["weights"].track_files(self.s3_weight_folder)

    def on_train_end(self, last, best, plots, epoch, results):
        # Callback runs on training end
@@ -174,11 +179,18 @@ def on_train_end(self, last, best, plots, epoch, results):
        if self.tb:
            import cv2
            for f in files:
-                if f.name != "results.html":
+                if f.suffix != ".html":
                    self.tb.add_image(f.stem, cv2.imread(str(f))[..., ::-1], epoch, dataformats='HWC')

        if self.wandb:
-            self.wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
+            results = []
+            for f in files:
+                if f.suffix == ".html":
+                    results.append(wandb.Html(str(f)))
+                else:
+                    results.append(wandb.Image(str(f), caption=f.name))
+
+            self.wandb.log({"Results": results})
            # Calling wandb.log. TODO: Refactor this into WandbLogger.log_model
            if not self.opt.evolve:
                wandb.log_artifact(str(best if best.exists() else last), type='model',
@@ -191,9 +203,12 @@ def on_train_end(self, last, best, plots, epoch, results):

        if self.neptune and self.neptune.neptune_run:
            for f in files:
-                if f.name == "results.html":
+                if f.suffix == ".html":
                    self.neptune.neptune_run['Results/{}'.format(f)].upload(neptune.types.File(str(f)))
                else:
                    self.neptune.neptune_run['Results/{}'.format(f)].log(neptune.types.File(str(f)))

+            if self.s3_weight_folder is not None:
+                self.neptune.neptune_run["weights"].track_files(self.s3_weight_folder)

            self.neptune.finish_run()
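The `s3_weight_folder` tracked above resolves to the run's weight prefix on S3; a short sketch of the expression from `__init__`, with hypothetical values:

```python
from pathlib import Path

s3_dir = "s3://my-bucket/yolov5"  # hypothetical opt.s3_dir
save_dir_name = "exp3"            # hypothetical save_dir.name

# normalize away an optional 's3://' prefix, then append '<run_name>/weights'
s3_weight_folder = "s3://" + str(Path(s3_dir.replace("s3://", "")) / save_dir_name / "weights")
print(s3_weight_folder)  # s3://my-bucket/yolov5/exp3/weights
```

Neptune then records this folder as an artifact via `neptune_run["weights"].track_files(s3_weight_folder)`, so the run references the uploaded weights without storing them again.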
2 changes: 1 addition & 1 deletion yolov5/val.py
@@ -316,7 +316,7 @@ def run(data,
print(f"Results saved to {colorstr('bold', save_dir)}{s}")
maps = np.zeros(nc) + map
for i, c in enumerate(ap_class):
maps[c] = ap[i]
maps[c] = ap50[i]
return (mp, mr, map50, map, *(loss.cpu() / len(dataloader)).tolist()), maps, t


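Context for the one-line fix above: in `val.run`, `ap` holds per-class AP averaged over IoU 0.5:0.95 while `ap50` holds AP at IoU 0.5, so the returned per-class `maps` array now agrees with the per-category `*_mAP_50` keys built in `yolov5/utils/loggers/__init__.py` (see the `class_name_keys` hunk above).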
