
Resume training from official yolov3 weights #2

Closed
lianuo opened this issue Sep 3, 2018 · 54 comments
Labels: bug (Something isn't working), help wanted (Extra attention is needed)

lianuo commented Sep 3, 2018

Thanks for your improvement of this YOLOv3 implementation.
I have just tested training and ran into some problems.
I followed these steps:

  1. Load the original yolov3.weights into the model.
  2. Train it on COCO2014 with your train.py.
  3. Got the following logs: precision drops quickly from 0.5 to 0.1, while recall rises to 0.35.
     (screenshot: training log)
  4. Save the weights at precision ~0.2 and run detect.py. The result looks like this:
     (image: 000000000019)
     If I do not train, the original weights give this result:
     (image: 000000000019)

I do not know whether I used the wrong parameters or something else that leads to the generation of many bounding boxes.
Could you give me some suggestions?
Thank you~

lianuo commented Sep 4, 2018

The loss is decreasing, so I wonder whether the definition of the loss leads to this problem?

@Ricardozzf

@lianuo
(image: 282555 2eda000050740616, person detections)

Same issue. I changed cls to 2 (background and person), and it looks like something is wrong with the class confidence during training.

glenn-jocher commented Sep 4, 2018

@lianuo @Ricardozzf I've not tried to continue training from the official yolov3 weights. It probably won't pick up smoothly where Joseph Redmon and company left off for a number of reasons, such as the optimizer starting with no knowledge of the previous optimizer's momentum and LR. There are also a few primary differences between my training and the official darknet training:

  • Issue Optimizer Choice: SGD vs Adam #4: train.py uses the Adam optimizer in place of SGD. I could not get SGD to converge with the yolov3 learning rate.
  • Non-Maximum Suppression (NMS) is not applied during training, so precision may appear artificially low while training: many of the False Positives (FPs) in the denominator of P = TP / (TP + FP) are eliminated during testing but not during training.
  • Issue Classification Loss: CE vs BCE #3: I use CrossEntropyLoss in place of BinaryCrossEntropyLoss for the classification loss during training. I made this change after observing better performance with CE vs BCE (I don't understand the reason for this, as darknet uses BCE). These two loss terms are on lines 162 and 163 of models.py. Note that the BCEWithLogitsLoss I use produces the same loss as BinaryCrossEntropyLoss + torch.sigmoid() on the first term, but BCEWithLogitsLoss is preferable for numerical stability reasons. If you want to try to continue training from yolov3.weights you need to use BinaryCrossEntropy or BCEWithLogitsLoss, as in the commented line below.
lcls = nM * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
# lcls = nM * BCEWithLogitsLoss2(pred_cls[mask], tcls.float())
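
For anyone toggling these lines, a quick standalone check of the numerical-stability point (a minimal sketch, not repo code): torch.nn.BCEWithLogitsLoss on raw logits gives the same value as BCELoss applied after torch.sigmoid(), but without the overflow risk:

import torch
import torch.nn as nn

logits = torch.randn(47, 80)                     # e.g. 47 masked predictions, 80 classes
targets = torch.zeros(47, 80)
targets[torch.arange(47), torch.randint(0, 80, (47,))] = 1.0   # one-hot class targets

loss_logits = nn.BCEWithLogitsLoss()(logits, targets)
loss_sigmoid = nn.BCELoss()(torch.sigmoid(logits), targets)
print(torch.allclose(loss_logits, loss_sigmoid, atol=1e-6))    # True; the logits form is numerically safer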

@lianuo how many epochs did you train this way? If you make the switch to BCE, does this help?

@Ricardozzf your results don't look good. Are you training from scratch or resuming training from the yolov3 weights like @lianuo? If you suspect the class confidence has a problem, it must be because I've swapped in CE in place of BCE. You can switch BCE back on by swapping the commented lines above. Also note that if you are training from scratch you need a significant number of epochs before things start looking good. In my training I see about 0.50 mAP on the COCO2014 validation set after 40 epochs (3 days of training on a 1080 Ti).

glenn-jocher changed the title from "about the Training" to "Resume training from official yolov3 weights" on Sep 4, 2018
lianuo commented Sep 4, 2018

@glenn-jocher Thank you for your reply!
I just tried resuming training from the official yolov3 weights with
optimizer = torch.optim.SGD(model.parameters(), lr=.0001, momentum=.9, weight_decay=5e-4, nesterov=True)
and switched to BCEWithLogitsLoss.
The precision drops to 0.18 and recall grows to 0.6, just like with the previous settings.

It is strange that when I run test.py with these trained weights I still get a high score; see the screenshot:
(screenshot: test score)
But when I run detect.py with these trained weights, the result is still not good, like this:
(image: 000000000019)

Is this because of the method used to evaluate mAP?

lianuo commented Sep 4, 2018

@glenn-jocher have you used the weights you trained (0.50 mAP on COCO2014) to test an image?
Could you share the weights or the test results on images?
It is a little strange that the score is high while the image results are not good...
Thank you so much for replying.

lianuo commented Sep 4, 2018

@Ricardozzf thanks for your information. I am not alone, haha.

Ricardozzf commented Sep 5, 2018

@glenn-jocher thanks for your reply.
I have trained the model from scratch for 14 epochs on a TITAN X. In order to make full use of the GPU, I changed batch_size from 12 to 16; the other configuration is default.
In training, the model looks good:
(screenshot: training metrics)
In testing, I use the CrowdHuman dataset, and the score is high:
(screenshot: test score)
Although the scores in training and testing are high, the results produced by detect.py are bad. One thing seems confirmed: the testing score does not match the results of detect.py.

I hope the information is useful to us.

glenn-jocher commented Sep 5, 2018

@lianuo @Ricardozzf that's a good question; I will compare my test.py and detect.py results. I am at epoch 37 training on COCO2014. If I run test.py I see this:

+ Sample [4998/5024] AP: 0.7528 (0.4926)
+ Sample [4999/5024] AP: 0.8333 (0.4927)
+ Sample [5000/5024] AP: 0.5543 (0.4927)
Mean Average Precision: 0.4927

If I then use the epoch 37 checkpoint latest.pt with detect.py I see this on my example image, which is the same problem you guys are seeing.
(image: zidane)

I'm wondering if I caused this by switching from BCE to CE. In xView when I used this code I had to increase my -conf_thresh in detect.py to ~0.99 to reduce FP. If I increase -conf_thresh to 0.99 now (and change -nms_thresh to 0.45 to match test.py) then I get this. Better, but still not quite right.

(image: zidane)

This is a bit of an apples-and-oranges comparison though. The official weights are at 160 epochs and my latest.pt is only at 37 epochs, so it's possible that training up to 160 will resolve this problem.

I don't understand why test.py is producing such a high mAP though, especially since it uses a very low -conf_thresh of 0.5. You guys are right, there is an unresolved issue somewhere. I will try and investigate more. The problem seems twofold:

  1. Issue mAP Computation in test.py #5: test.py is possibly over-reporting mAP on trained checkpoints, even though it correctly reports mAP on the official YOLOv3 weights, an odd inconsistency. This seems to be the easiest issue to resolve, so I'll look at this first.
  2. Trained weights seem to require much higher confidence thresholds (~0.99) than typically used in YOLOv3 (~0.8 commonly). This would seem to be unrelated to the CE vs BCE issue, as @lianuo trained from epoch 160 using BCE and still saw poor results.

Any ideas are appreciated as well!

glenn-jocher self-assigned this on Sep 9, 2018
glenn-jocher added the bug (Something isn't working) and help wanted (Extra attention is needed) labels on Sep 9, 2018
@glenn-jocher

@lianuo @Ricardozzf the overly high mAPs you were seeing before should be partly fixed in the latest commits, which fixed the mAP calculations (see issue #7). The official weights now produce 0.57 mAP, but the trained weights that previously gave me 0.50 mAP now return about 0.13 mAP, much more in line with the poor boxes you see in your images.

I still don't understand the actual cause of the poor training results however.

lianuo commented Sep 10, 2018

@glenn-jocher Thank you for your reply~

lianuo commented Sep 10, 2018

@glenn-jocher the loss still decreases during training. Do you think the loss function needs to be modified?

xyutao commented Sep 11, 2018

No warm-up process is used for SGD. According to the YOLO9000 paper and the official code, we need to warm up over the first 1000 iterations so that training converges better:
warmup_lr = lr * batch_size / burn_in, where lr = 1e-3, batch_size = 64 and burn_in = 1000
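
A minimal sketch of such a warm-up ramp (using the power-4 form that appears later in this thread; the exact constants are an assumption, not darknet's source):

def burn_in_lr(iteration, base_lr=1e-3, burn_in=1000, power=4):
    """Ramp the learning rate from 0 up to base_lr over the first burn_in iterations."""
    if iteration >= burn_in:
        return base_lr
    return base_lr * (iteration / burn_in) ** power

# applied before each batch, e.g.:
# for g in optimizer.param_groups:
#     g['lr'] = burn_in_lr(i)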

xyutao commented Sep 11, 2018

@glenn-jocher The usage of CrossEntropyLoss might be incorrect. The input shape is (nB, nA, nG, nG, nC), but the pytorch doc suggests it should be (nB, nC, ...). (See

p = p.view(bs, self.nA, self.bbox_attrs, nG, nG).permute(0, 1, 3, 4, 2).contiguous() # prediction
)
Besides, torch.argmax(tcls, 1) fetches C from dim=1, but the shape of tcls is actually (nB, nA, nG, nG, nC). Maybe we need to permute the dims so that C is at dim=1.

glenn-jocher commented Sep 11, 2018

@xyutao I looked into the CELoss usage, and I think this part is OK. When I start training and debug at this spot, the dimensions look good (assuming nC = 80 and assuming we have 47 targets here in the first batch of nB = 12 images). I think mask eliminates all the other dimensions:

tcls.shape
Out[2]: torch.Size([12, 3, 13, 13, 80])

tcls = tcls[mask]
tcls.shape
Out[3]: torch.Size([47, 80])

lcls = nM * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
Out[4]: tensor(206.37325, grad_fn=<MulBackward1>)

pred_cls[mask].shape
Out[5]: torch.Size([47, 80])

torch.argmax(tcls, 1).shape
Out[6]: torch.Size([47])
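
For reference, a minimal standalone check of the same point (shapes only; a sketch, not repo code): nn.CrossEntropyLoss expects (N, C) logits and (N,) class indices, which is exactly what boolean-mask indexing produces here:

import torch
import torch.nn as nn

nM, nC = 47, 80                                   # masked targets, COCO classes
pred_cls_masked = torch.randn(nM, nC)             # stands in for pred_cls[mask]
tcls_masked = torch.zeros(nM, nC)
tcls_masked[torch.arange(nM), torch.randint(0, nC, (nM,))] = 1.0   # one-hot targets

loss = nn.CrossEntropyLoss()(pred_cls_masked, torch.argmax(tcls_masked, 1))
print(loss.shape)                                 # torch.Size([]) -- a scalar, as expected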

I have linked to your comment on the SGD warmup, though; this is a good catch! Issue #4 is open on this. By the first 1000 iterations, do you mean the first 1000 batches?

xyutao commented Sep 12, 2018

@glenn-jocher Yeah, the first 1000 batches of batch_size=64.

@CF2220160244

Please help, I have the same error.
Did you guys solve this problem? Thanks!

jaelim commented Sep 18, 2018

@lianuo Hi, just wondering how you loaded the pre-trained weights. Did you add these lines of code in train.py?

    # Initialize model 
    model = Darknet(opt.cfg, opt.img_size) 
    model.load_weights(opt.weights_path)

jaelim commented Sep 18, 2018

@lianuo I found out from detect.py that you add this line:

load_weights(model, weights_path)

But now I'm getting a different error from datasets.py:
(screenshots: error traceback)

Have you encountered this problem? If so, how did you deal with it?

glenn-jocher commented Sep 18, 2018

@jaelim you resume training from a trained model (i.e. latest.pt) by setting opt.resume = True:

yolov3/train.py

Lines 50 to 53 in 68de92f

if opt.resume:
    checkpoint = torch.load('checkpoints/latest.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model'])
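
For completeness, a minimal sketch of what a fuller resume could look like if the checkpoint also stored the optimizer state (the 'optimizer' and 'epoch' keys below are assumptions, not necessarily present in latest.pt):

import torch

checkpoint = torch.load('checkpoints/latest.pt', map_location='cpu')
model.load_state_dict(checkpoint['model'])                 # model as built in train.py
if 'optimizer' in checkpoint:                              # assumption: optimizer state was saved
    optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint.get('epoch', 0) + 1               # assumption: last finished epoch was saved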

If you are seeing the error you mentioned, it is because you failed to define a proper path to an image or image folder in detect.py line 14 (no images are loaded). Make sure there are only image files in the path if you specify a folder. Also, please do not ask questions unrelated to the main issue title in this thread.

parser.add_argument('-image_folder', type=str, default='data/samples', help='path to images')

glenn-jocher commented Sep 20, 2018

@xyutao I've switched from Adam to SGD with burn-in (which ramps the learning rate up from 0 to 0.001 over the first 1000 iterations following a power-4 curve) in commit a722601:

yolov3/train.py

Lines 115 to 120 in a722601

# SGD burn-in
if (epoch == 0) & (i <= 1000):
    power = 4
    lr = 1e-3 * (i / 1000) ** power
    for g in optimizer.param_groups:
        g['lr'] = lr

Unfortunately this caused width and height loss terms to diverge when training from scratch. I saw that these are the only unbounded outputs of the network (all the rest are sigmoided), so I was forced to sigmoid them and create new width and height calculations, after which the training converged. The original and updated ones I made in this commit are:

yolov3/models.py

Lines 121 to 131 in a722601

# Width and height (yolo method)
# w = p[..., 2] # Width
# h = p[..., 3] # Height
# width = torch.exp(w.data) * self.anchor_w
# height = torch.exp(h.data) * self.anchor_h
# Width and height (power method)
w = torch.sigmoid(p[..., 2]) # Width
h = torch.sigmoid(p[..., 3]) # Height
width = ((w.data * 2) ** 2) * self.anchor_w
height = ((h.data * 2) ** 2) * self.anchor_h

If I plot both of these in MATLAB it looks like the lack of a ceiling in the original code is causing the divergence problem. It may be that the original width/height equations are incorrect. Does anyone know where to find the original darknet width and height calculations?

>> x=linspace(-3,3);
>> y1 = exp(x);
>> y2 = ((logsig(x) * 2).^2);
>> fig; plot(x,y1,'.-'); plot(x,y2,'.-'); h=gca; h.YLim=[0,5]; legend('original','updated'); xyzlabel('network output','anchor width multiple'); fcnfontsize(14)
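
The same comparison in Python/matplotlib for anyone without MATLAB (an equivalent sketch of the plot above):

import numpy as np
import matplotlib.pyplot as plt

x = np.linspace(-3, 3, 100)
y1 = np.exp(x)                          # original (yolo) method: exp(w) * anchor_w
y2 = (2 / (1 + np.exp(-x))) ** 2        # updated (power) method: (2 * sigmoid(w))**2 * anchor_w
plt.plot(x, y1, '.-', label='original')
plt.plot(x, y2, '.-', label='updated')
plt.ylim(0, 5)
plt.xlabel('network output')
plt.ylabel('anchor width multiple')
plt.legend()
plt.show()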

ultralytics deleted a comment from jaelim on Sep 21, 2018
glenn-jocher commented Sep 23, 2018

@lianuo @Ricardozzf @xyutao @CF2220160244 @jaelim I have good news. A significant bug in the loss function was found today in issue #12, namely a problem with size_average-ing the various loss terms. This caused the lconf_obj term to be 80 times too large (80 = COCO class count), causing the network to over-detect objects, which I believe was the major problem many of you saw in your training.

I fixed this in commit cf9b4cf, and after the change observed that SGD with burn-in now converges with the original YOLO width/height calculations, so I placed those back in in commit 5d402ad.

Update: Sorry guys I think I might have spoken too soon. The changes help, but resuming training from yolov3.pt still causes P and R to drop from initially high values to lower values after ~50 batches. I think we are getting closer to the source of the problem however, which I feel is in the model loss term somewhere. TODO: I also need to ignore non-best anchors with > 0.50 iou to match yolov3.
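
As a general illustration of how the reduction mode can silently rescale a term (a sketch of the mechanism only, not the exact bug in issue #12): a mean-reduced loss is divided by the total element count, so terms computed over tensors of different shapes end up on very different scales unless they are renormalized consistently:

import torch
import torch.nn as nn

nM, nC = 47, 80
logits = torch.randn(nM, nC)
targets = torch.rand(nM, nC)

loss_mean = nn.BCEWithLogitsLoss(reduction='mean')(logits, targets)   # divided by nM * nC elements
loss_sum = nn.BCEWithLogitsLoss(reduction='sum')(logits, targets)     # no division
print(loss_sum / loss_mean)                                           # nM * nC = 3760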

@deeppower

@sporterman Sorry, I haven't trained on my own dataset.

@deeppower

@glenn-jocher Thanks for your reply and great work. I have varied conf_thresh to 0.2, and mAP is 0.41 at epoch 68. There are still some problems that we need to solve.

@glenn-jocher

@deeppower yes, the training performance is unfortunately still not as good as darknet's. I tried a few epochs of multi_scale training after epoch 80 and this did not seem to help. I've tried to align everything as closely as possible to darknet, so, for example, if you resume training from the official yolov3.pt weights, the P and R values are very steady (though still dropping slightly over time). This makes me think the loss function is correct, or at least very close to the original darknet loss function. Inference works well, so the problem cannot be there; it must be in the training-only code, which could be the optimizer, LR scheduler, loss function, target-building functions, IOU function, augmentation function...

glenn-jocher commented Nov 10, 2018

@nirbenz @okanlv @deeppower @okanlv @xiao1228 Good news I think. I thought about the problem a bit and decided that the loss terms needed rebalancing. In my last plot you can see Classification is consuming the great majority of the loss, which means that it is being optimised at the expense of all the other losses. Ideally the 6 losses would be roughly equal in magnitude so that they are all optimised with equal priority.

So I made a commit that multiplied Objectness loss by 10, and divided Classification loss by 10:

yolov3/models.py

Lines 166 to 176 in e04bb75

if nM > 0:
    lx = k * MSELoss(x[mask], tx[mask])
    ly = k * MSELoss(y[mask], ty[mask])
    lw = k * MSELoss(w[mask], tw[mask])
    lh = k * MSELoss(h[mask], th[mask])
    # lconf = k * BCEWithLogitsLoss(pred_conf[mask], mask[mask].float())
    lconf = (k * 10) * BCEWithLogitsLoss(pred_conf, mask.float())
    lcls = (k / 10) * CrossEntropyLoss(pred_cls[mask], torch.argmax(tcls, 1))
    # lcls = k * BCEWithLogitsLoss(pred_cls[mask], tcls.float())

I ran this for most of the day on GCP, and after about 10 epochs I overlaid the 3 different trainings I'd done. This new approach seems vastly better, in particular at increasing Recall compared to before. I thought this was exciting enough to post the news right away; I'll have to train for another week to get to 70+ epochs and see the true effect. I'm wondering if there isn't a better way to more automatically balance these 6 equally important loss terms. They seem roughly equal now after 10 epochs, but maybe there's a way to update the balancing terms every epoch using the previous epoch's values. Any ideas?
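
One possible way to automate the balancing asked about above (purely an illustrative sketch, not something in the repo): keep a running average of each term's magnitude and divide by it, so every term contributes roughly equally:

import torch

loss_names = ['lx', 'ly', 'lw', 'lh', 'lconf', 'lcls']
running = {name: 1.0 for name in loss_names}
momentum = 0.99

def balance(losses):
    """losses: dict mapping loss name -> scalar tensor; returns a rebalanced total loss."""
    total = 0.0
    for name, value in losses.items():
        running[name] = momentum * running[name] + (1 - momentum) * float(value.detach())
        total = total + value / (running[name] + 1e-8)
    return total

# usage: loss = balance({'lx': lx, 'ly': ly, 'lw': lw, 'lh': lh, 'lconf': lconf, 'lcls': lcls})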

UPDATE 1: mAP is 0.43 (-conf_thresh 0.20) at epoch 20. Updated plots below (green).
UPDATE 2: mAP is 0.46 (-conf_thresh 0.20) at epoch 35. Updated plots below (green).
UPDATE 3: mAP is 0.46 (-conf_thresh 0.30) at epoch 49 :( Jumps in loss were observed during training, possibly due to many restarts of the preemptible GCP VM. New commit 45c5567 runs test.py after each training epoch and records the mAP to results.txt. Starting new training from scratch using PyTorch 1.0 on GCP. Will post a new comment when new results start coming in.
(plot: figure_1)

@glenn-jocher

@deeppower Yes, objectness loss is higher than before because I multiplied it by 10x now. I'm trying to balance the loss terms so they contribute equally to the gradient, or else the largest loss terms will get optimized at the expense of the smaller loss terms. It appears to be working, though my loss term multiples are rather arbitrary unfortunately.

Ideally we want to take this a step further and equalize not just the loss terms but also the target distributions, to something like zero mean and unity variance, which helps regression networks at least (not sure about object detection). Any experiments you can run on your own would help significantly; I'm just one man with one GPU here, so I can only try a finite set of things to improve the results.
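
For the target-distribution idea, a minimal sketch of zero-mean, unit-variance standardization (assuming the regression targets are available as an (N, 4) tensor; not repo code):

import torch

def standardize(targets, eps=1e-8):
    """Standardize regression targets per column; keep the stats to un-scale predictions later."""
    mean, std = targets.mean(dim=0), targets.std(dim=0)
    return (targets - mean) / (std + eps), mean, std

# at inference, map network outputs back to the original scale:
# preds = preds_normalized * std + mean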

@glenn-jocher

@okanlv I have a question for you. Now that I've defaulted to starting training from darknet53.conv.74, would it make sense to freeze those layers for some time before allowing them to change?

I was thinking I could freeze them for the first epoch perhaps, which would be 7328 batches, or for at least half an epoch. The first 1000 batches are burn-in. I feel it would make sense to do this, since the randomly initialized layers might converge much faster without the darknet53.conv.74 layers changing underneath them.
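
A minimal sketch of that freezing idea (the assumption below is that the first 75 entries of model.module_list correspond to the darknet53.conv.74 backbone; adjust the cutoff to the actual architecture):

def set_backbone_grad(model, requires_grad, cutoff=75):
    """Enable/disable gradients for the (assumed) darknet53 backbone modules."""
    for module in model.module_list[:cutoff]:
        for p in module.parameters():
            p.requires_grad = requires_grad

set_backbone_grad(model, False)   # epoch 0: backbone frozen
# ... train one epoch ...
set_backbone_grad(model, True)    # later epochs: train everything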

okanlv commented Nov 25, 2018

@glenn-jocher In the darknet repo, all layers are trained together after yolov3 is initialized with the darknet53.conv.74 weights. In this paper, the authors showed that updating the parameters of all the layers increases performance compared to updating the parameters of only the top layers (related to the fragile co-adaptation of layers mentioned in the paper). That being said, your method might also work, because there are a few differences between your approach and the experiments in the paper. If you train yolov3 with your approach, could you share the loss graphs for both your approach and the current method? It could be beneficial for further experiments.

glenn-jocher commented Nov 27, 2018

@nirbenz @okanlv @deeppower @okanlv @xiao1228 I've started running studies to improve the COCO mAP when training from darknet53.conv.74. I started with the #2 (comment) model that gets 0.46 mAP at epoch 35. The primary breakthrough there was simply rebalancing the loss terms, multiplying the objectness term (lconf = (k * 10) * ...) and dividing the classification term (lcls = (k / 10) * ...) to get that 0.46 mAP.

All tests below are run for only the first epoch. Freezing the darknet53 layers (just for the first epoch) showed slightly positive results. Further rebalancing the loss terms seems to have the biggest effect. In most ML regression problems the inputs and targets are recalibrated to zero mean and unity variance; yolov3 does this for the inputs via batch_norm layers, but does not do this for the regression targets (the bounding boxes), so I want to try this (regression problems that fail to do this have far worse performance).

Any other experiments you guys want let me know. I'll keep populating this comment as my results come in over the next week.

| Configuration | mAP (epoch 0) | Precision | Recall |
| --- | --- | --- | --- |
| default #2 (comment) | 0.168 | 0.200 | 0.175 |
| ... + weight_decay=0 | 0.169 | 0.200 | 0.176 |
| ... + darknet53 frozen | 0.172 | 0.210 | 0.179 |
| ... + lconf*16 | 0.181 | 0.214 | 0.188 |
| ... + lcls/4 | 0.231 | 0.268 | 0.243 |
| ... + dkn53 unfrozen + lconf*32 | 0.237 | 0.263 | 0.25 |
| ... + lconf*64 | 0.225 | 0.249 | 0.235 |
| ... + bbox targets normalization | | | |
| ... + additional experiments? | | | |

This is my selected configuration, + lconf*64 in the above table, and in the latest commit. darknet53 is not frozen in the first epoch, as I found this hurts later epochs. I'm now training to about 50 epochs.
mAP is 0.45 (-conf_thresh 0.30) at epoch 12
mAP is 0.48 (-conf_thresh 0.30) at epoch 17
mAP is 0.50 (-conf_thresh 0.30) at epoch 45 (jumps in losses, not sure why again)
mAP is 0.522 (-conf_thresh 0.30 at img_size 416) at epoch 62 (max mAP achieved)
(plot: coco_training_loss)
