Skip to content

Commit

Permalink
Specify num_classes = 2 for my custom dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
woctezuma committed Aug 13, 2020
1 parent 40ed80c commit 6ecf58f
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions models/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,10 @@ def build(args):
num_classes = 20 if args.dataset_file != 'coco' else 91
if args.dataset_file == "coco_panoptic":
num_classes = 250
if args.dataset_file == 'custom':
# "You should always use num_classes = max_id + 1 where max_id is the highest class ID that you have in your dataset."
# Reference: https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
num_classes = 2
device = torch.device(args.device)

backbone = build_backbone(args)
Expand Down

0 comments on commit 6ecf58f

Please sign in to comment.