
Strange behaviour with mps M1 #10178

Closed
jgoo9410 opened this issue Nov 16, 2022 · 38 comments
Labels: bug (Something isn't working), Stale (stale and scheduled for closing soon)

Comments

@jgoo9410

Search before asking

  • I have searched the YOLOv5 issues and found no similar bug report.

YOLOv5 Component

Detection

Bug

Using the yolov5 Python module and targeting mps as the desired device (model = yolov5.load('best_1.pt', 'mps')) appears to yield inaccurate results.

model = yolov5.load('best_1.pt', 'mps')

results = model(image)
scores = results.pred[0][:, 4]      # confidence scores
categories = results.pred[0][:, 5]  # class indices
boxes = results.pred[0][:, :4]      # xyxy pixel coordinates
scores = scores.cpu().numpy()
classes = categories.cpu().numpy()

boxes = boxes.tolist()
classes = classes.tolist()
scores = scores.tolist()

The detections are correct in the sense that there is a detectable object in frame. The problem is that the locations of the boxes are often, but not always, incorrect. Some of the time the boxes are shifted to the left by about 30% of the image width; other times the boxes are exactly correct. When the boxes are incorrectly located they don't jitter, but track with the object, just offset by a distance.

My intuition tells me this is a rounding error somewhere. I'm imagining it's within the nightly builds of torch or torchvision.

>>> torch.__version__
'1.14.0.dev20221116'
>>> torchvision.__version__
'0.15.0.dev20221116'
>>> sys.version
'3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]'

Has anyone else come across this issue? It would be great to be able to utilise the fast inference speed of the M1 GPU.

Targeting the cpu (default) yields the correct results (model = yolov5.load('best_1.pt')).
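
For anyone trying to reproduce this, here is a minimal side-by-side check. This is a sketch using the yolov5 pip package API shown above; the image path is a placeholder, and the subtraction assumes both runs return the same detections in the same order:

import yolov5

image = 'frame_0001.jpg'  # placeholder test frame

model_cpu = yolov5.load('best_1.pt')         # default device: cpu
model_mps = yolov5.load('best_1.pt', 'mps')  # Apple Silicon GPU

boxes_cpu = model_cpu(image).pred[0][:, :4]
boxes_mps = model_mps(image).pred[0][:, :4].cpu()

# A healthy run agrees to within a pixel or so; a large constant
# x-offset on the mps boxes reproduces the behaviour described here.
print((boxes_cpu - boxes_mps).abs().max())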

Environment

No response

Minimal Reproducible Example

No response

Additional

No response

Are you willing to submit a PR?

  • Yes I'd like to help by submitting a PR!
jgoo9410 added the bug label on Nov 16, 2022
@glenn-jocher (Member)

@jgoo9410 yes there are silent errors in MPS inference, likely in the Detect() head. If you can help debug and trace the source of the differences that would help. I compared feature outputs into Detect and I believe they were identical. Perhaps anchor/grid tensors on different devices or dtypes might be the cause.
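
A quick way to probe that hypothesis is to print the devices and dtypes of the Detect() head's grid and anchor tensors after moving the model. This is a sketch; the exact attribute path to the Detect() head depends on the wrapper classes in use (cf. the _apply() snippet quoted later in this thread), and anchor_grid may be a list or a buffer depending on version:

m = model.model.model[-1]  # Detect() head; adjust path for AutoShape/DetectMultiBackend wrappers
print('stride:', m.stride.device, m.stride.dtype)
for i, (g, ag) in enumerate(zip(m.grid, m.anchor_grid)):
    print(f'layer {i}: grid {g.device}/{g.dtype}, anchor_grid {ag.device}/{ag.dtype}')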

@jgoo9410 (Author)

Hmm. I've had a look at similar bugs on here, but this one seems different. In some of the other bugs, using mps appears to generate multiple incorrect detections. In my case the detection is 'correct', just displaced. Interestingly though, if I allow the detections to continue, the majority of them end up being correct both in detection and in location, even in areas where objects were previously displaced. The problem persists over a few hundred frames, until eventually the detection bbox and the object converge and the detection becomes correct. The point is that there is definitely a pattern, and it appears only certain detections are bringing the issue to the fore.

GDPR prevents me from uploading the footage here, but I'd be happy to share it privately.

I'm happy to put in a shift to try and find the source of the inaccuracy, although I suspect it's not an inherent yolov5 issue.

@glenn-jocher (Member) commented Nov 17, 2022

@jgoo9410 one clue is that classification inference with MPS works correctly (same result as CPU), so this is why I say that the difference is likely in the Detect grids or anchor devices/dtypes.

NMS itself has been converted to CPU when MPS is used, as MPS torch ops are not fully supported there, so I don't think NMS plays a part in the difference.

I think you're right though that it's likely a torch bug rather than a YOLOv5 bug, but I do know we handle Detect grids/anchors a little strangely, i.e. using a custom _apply() function here to make sure they respect module.to() ops.

yolov5/models/common.py, lines 646 to 656 in a9f895d:

def _apply(self, fn):
    # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
    self = super()._apply(fn)
    if self.pt:
        m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
        m.stride = fn(m.stride)
        m.grid = list(map(fn, m.grid))
        if isinstance(m.anchor_grid, list):
            m.anchor_grid = list(map(fn, m.anchor_grid))
    return self

@jgoo9410 (Author)

Okay @glenn-jocher, thanks for the clue.

It will take me a couple of days to get my head around the process but I'll report back.

@glenn-jocher (Member)

@jgoo9410 great!

@jgoo9410 (Author) commented Nov 20, 2022

@glenn-jocher, debugging is probably a slightly grandiose term for what I've been doing, but I've been making as many comparisons as I can. Not sure if any of them are going to be useful, but here are the key ones:

In the _make_grid function in yolo.py, I found there to be no difference in either 'grid' or 'anchor_grid' between the two devices.

Following the program flow, I've been examining the outputs function by function, and I find that the first deviation comes from the output of the forward function (in yolo.py).

The outputs are as follows:

CPU:

(tensor([[[4.71834e+00, 4.22245e+00, 1.26789e+01, 1.28928e+01, 4.03146e-06, 9.99985e-01],
[1.03612e+01, 3.50429e+00, 2.33117e+01, 7.36838e+00, 2.45448e-05, 9.99982e-01],
[1.94862e+01, 3.23583e+00, 2.55511e+01, 6.56572e+00, 2.87676e-05, 9.99980e-01],
[2.72205e+01, 3.93437e+00, 1.42811e+01, 1.17292e+01, 4.48620e-06, 9.99984e-01],
[3.02114e+00, 9.46344e+00, 7.45484e+00, 2.16289e+01, 1.05020e-05, 9.99985e-01],
[9.84933e+00, 8.52242e+00, 1.89697e+01, 1.60340e+01, 1.86524e-05, 9.99981e-01],
[2.01119e+01, 7.98669e+00, 2.10757e+01, 1.46526e+01, 8.20274e-06, 9.99977e-01],
[2.91984e+01, 9.24670e+00, 8.25948e+00, 1.99892e+01, 8.63184e-06, 9.99984e-01],
[2.34308e+00, 1.98044e+01, 5.97782e+00, 2.38734e+01, 3.16952e-06, 9.99985e-01],
[8.89093e+00, 2.04105e+01, 1.55121e+01, 1.90317e+01, 2.45687e-06, 9.99982e-01],
[2.07845e+01, 2.16434e+01, 1.65778e+01, 1.59136e+01, 2.60486e-07, 9.99982e-01],
[2.94944e+01, 2.10536e+01, 7.02836e+00, 2.10696e+01, 1.27519e-06, 9.99986e-01],
[3.08340e+00, 2.69844e+01, 9.07867e+00, 1.72686e+01, 4.19950e-07, 9.99985e-01],
[1.10167e+01, 2.85149e+01, 1.78563e+01, 1.05998e+01, 6.80164e-07, 9.99984e-01],
[2.03293e+01, 2.89626e+01, 1.92784e+01, 9.26143e+00, 4.13052e-07, 9.99984e-01],
[2.83967e+01, 2.83278e+01, 1.13331e+01, 1.47721e+01, 2.16046e-07, 9.99986e-01],
[4.80459e+00, 4.42266e+00, 1.13291e+01, 1.70981e+01, 6.67889e-07, 9.99985e-01],
[1.05240e+01, 4.18377e+00, 2.06738e+01, 1.11106e+01, 2.10076e-06, 9.99985e-01],
[1.93080e+01, 3.69241e+00, 2.27657e+01, 1.09791e+01, 2.17662e-06, 9.99986e-01],
[2.69854e+01, 4.17255e+00, 1.30957e+01, 1.65031e+01, 7.30851e-07, 9.99986e-01],
[3.52818e+00, 9.32290e+00, 7.53948e+00, 2.45389e+01, 6.24792e-06, 9.99985e-01],
[1.02905e+01, 8.58594e+00, 1.80759e+01, 1.85537e+01, 7.56650e-06, 9.99986e-01],
[2.05525e+01, 8.02950e+00, 1.91249e+01, 1.80997e+01, 3.25458e-06, 9.99988e-01],
[2.90049e+01, 9.30921e+00, 8.40057e+00, 2.35844e+01, 4.59429e-06, 9.99987e-01],
[3.06975e+00, 1.95570e+01, 6.98802e+00, 2.74083e+01, 1.54343e-06, 9.99985e-01],
[9.63404e+00, 2.03523e+01, 1.56002e+01, 2.29159e+01, 1.22175e-06, 9.99987e-01],
[2.14424e+01, 2.15031e+01, 1.60534e+01, 2.20463e+01, 1.08042e-07, 9.99988e-01],
[2.93999e+01, 2.09949e+01, 8.05966e+00, 2.56492e+01, 6.07952e-07, 9.99986e-01],
[3.55364e+00, 2.64519e+01, 9.89420e+00, 2.28044e+01, 8.36822e-08, 9.99986e-01],
[1.13259e+01, 2.75735e+01, 1.73786e+01, 1.65808e+01, 1.10241e-07, 9.99986e-01],
[2.05342e+01, 2.80423e+01, 1.82515e+01, 1.60191e+01, 6.33876e-08, 9.99987e-01],
[2.85252e+01, 2.78957e+01, 1.18571e+01, 2.16438e+01, 5.42176e-08, 9.99987e-01],
[4.80942e+00, 4.16223e+00, 1.63935e+01, 1.33569e+01, 6.45103e-07, 9.99984e-01],
[1.05064e+01, 3.81908e+00, 2.37167e+01, 8.39516e+00, 5.73223e-06, 9.99984e-01],
[1.89762e+01, 3.60396e+00, 2.68293e+01, 8.30921e+00, 6.77903e-06, 9.99982e-01],
[2.63150e+01, 4.01624e+00, 1.86717e+01, 1.28549e+01, 9.55182e-07, 9.99981e-01],
[4.10259e+00, 9.49324e+00, 1.20973e+01, 2.07147e+01, 7.19049e-07, 9.99985e-01],
[1.02990e+01, 8.62792e+00, 2.02154e+01, 1.65284e+01, 8.23237e-06, 9.99983e-01],
[2.02926e+01, 8.32926e+00, 2.26817e+01, 1.60422e+01, 3.37735e-06, 9.99981e-01],
[2.84680e+01, 9.58594e+00, 1.37744e+01, 2.07298e+01, 8.39323e-07, 9.99983e-01],
[3.61726e+00, 1.97145e+01, 1.14012e+01, 2.31716e+01, 1.49972e-07, 9.99982e-01],
[9.60742e+00, 2.02335e+01, 1.78748e+01, 2.03734e+01, 1.49022e-06, 9.99977e-01],
[2.09821e+01, 2.20158e+01, 1.92812e+01, 1.99618e+01, 1.19832e-07, 9.99974e-01],
[2.87747e+01, 2.11591e+01, 1.29945e+01, 2.28179e+01, 1.16483e-07, 9.99981e-01],
[3.74387e+00, 2.67293e+01, 1.52742e+01, 1.83953e+01, 3.91533e-08, 9.99983e-01],
[1.12974e+01, 2.78776e+01, 2.13183e+01, 1.27016e+01, 2.07226e-07, 9.99982e-01],
[2.02751e+01, 2.85259e+01, 2.26451e+01, 1.22173e+01, 1.09836e-07, 9.99981e-01],
[2.79748e+01, 2.82007e+01, 1.74276e+01, 1.71713e+01, 3.50622e-08, 9.99981e-01],
[6.81140e+00, 6.82047e+00, 1.70426e+01, 2.39586e+01, 2.12621e-08, 9.99977e-01],
[2.33509e+01, 6.62131e+00, 1.83138e+01, 2.25639e+01, 3.49461e-08, 9.99978e-01],
[7.03914e+00, 2.26433e+01, 1.65877e+01, 2.44514e+01, 1.74505e-07, 9.99979e-01],
[2.45234e+01, 2.32200e+01, 1.82174e+01, 2.36797e+01, 2.66098e-07, 9.99979e-01],
[6.52683e+00, 6.97478e+00, 2.00686e+01, 1.69333e+01, 6.58365e-08, 9.99985e-01],
[2.28293e+01, 6.71701e+00, 2.19644e+01, 1.69828e+01, 2.11255e-07, 9.99985e-01],
[6.85469e+00, 2.26485e+01, 2.01960e+01, 1.86878e+01, 1.82255e-07, 9.99986e-01],
[2.40645e+01, 2.32000e+01, 2.21281e+01, 1.90851e+01, 3.47547e-07, 9.99985e-01],
[6.62242e+00, 7.83097e+00, 4.39705e+01, 7.22210e+01, 2.97236e-08, 9.99991e-01],
[2.30084e+01, 7.74025e+00, 4.52561e+01, 6.65716e+01, 3.76250e-08, 9.99990e-01],
[7.33187e+00, 2.41173e+01, 3.66571e+01, 6.03408e+01, 6.69018e-08, 9.99991e-01],
[2.44606e+01, 2.47291e+01, 3.77499e+01, 5.81968e+01, 7.96966e-08, 9.99990e-01],
[2.09533e+01, 9.70077e+00, 2.13152e+01, 1.91494e+01, 5.96689e-06, 9.99988e-01],
[1.42876e+01, 1.29688e+01, 3.62461e+01, 2.44363e+01, 7.55122e-05, 9.99984e-01],
[1.22043e+01, 1.26066e+01, 3.50026e+01, 1.94199e+01, 2.45772e-05, 9.99987e-01],
[2.41711e+00, 8.06363e+00, 2.07774e+01, 1.73894e+01, 3.48774e-06, 9.99990e-01],
[1.85576e+01, 1.60251e+01, 2.13179e+01, 2.88633e+01, 1.52165e-06, 9.99991e-01],
[1.38048e+01, 1.78118e+01, 3.53022e+01, 3.34384e+01, 4.21912e-03, 9.99978e-01],
[1.34321e+01, 1.78993e+01, 3.42232e+01, 3.13680e+01, 3.83457e-04, 9.99978e-01],
[5.81598e+00, 1.52368e+01, 2.28645e+01, 2.64963e+01, 3.84153e-07, 9.99989e-01],
[1.98673e+01, 1.96954e+01, 1.70390e+01, 2.77495e+01, 3.06072e-07, 9.99991e-01],
[1.37948e+01, 1.87801e+01, 3.48635e+01, 3.30073e+01, 3.85941e-04, 9.99984e-01],
[1.25275e+01, 1.98425e+01, 3.47314e+01, 3.24159e+01, 2.18887e-04, 9.99982e-01],
[4.46415e+00, 2.07202e+01, 2.22401e+01, 2.55084e+01, 4.89855e-07, 9.99990e-01],
[2.29885e+01, 3.00200e+01, 1.29935e+01, 1.36424e+01, 7.76265e-07, 9.99989e-01],
[1.37503e+01, 2.90878e+01, 3.32264e+01, 1.26388e+01, 7.65563e-07, 9.99986e-01],
[1.17604e+01, 3.08870e+01, 3.27302e+01, 1.01528e+01, 2.25509e-05, 9.99982e-01],
[2.73846e+00, 3.25087e+01, 2.00864e+01, 1.04758e+01, 2.27550e-06, 9.99989e-01],
[2.06341e+01, 9.50592e+00, 2.25581e+01, 2.46319e+01, 1.25005e-06, 9.99983e-01],
[1.41923e+01, 1.32703e+01, 3.56189e+01, 2.74909e+01, 2.70573e-05, 9.99979e-01],
[1.20043e+01, 1.28975e+01, 3.50201e+01, 2.47001e+01, 3.92011e-06, 9.99984e-01],
[2.17026e+00, 9.84475e+00, 2.04956e+01, 2.15374e+01, 6.86600e-07, 9.99985e-01],
[1.83380e+01, 1.57491e+01, 2.46273e+01, 3.57117e+01, 1.51068e-06, 9.99982e-01],
[1.35792e+01, 1.76493e+01, 3.76662e+01, 3.75401e+01, 2.58717e-03, 9.99981e-01],
[1.28589e+01, 1.78031e+01, 3.76197e+01, 3.70722e+01, 2.79134e-04, 9.99982e-01],
[5.58244e+00, 1.65438e+01, 2.22310e+01, 2.95319e+01, 2.56864e-07, 9.99989e-01],
[1.93052e+01, 1.90950e+01, 2.04148e+01, 3.26341e+01, 1.60082e-07, 9.99989e-01],
[1.32516e+01, 1.83992e+01, 3.75025e+01, 3.62630e+01, 2.33269e-04, 9.99984e-01],
[1.20746e+01, 1.95123e+01, 3.72511e+01, 3.54711e+01, 1.24246e-04, 9.99986e-01],
[4.07372e+00, 2.17201e+01, 2.06674e+01, 2.81361e+01, 2.05289e-07, 9.99988e-01],
[2.25850e+01, 2.83755e+01, 1.64697e+01, 1.93651e+01, 1.16195e-07, 9.99990e-01],
[1.36647e+01, 2.76363e+01, 3.36292e+01, 1.73706e+01, 1.84094e-07, 9.99987e-01],
[1.12774e+01, 2.93658e+01, 3.27155e+01, 1.42193e+01, 2.84572e-06, 9.99986e-01],
[2.63078e+00, 3.18044e+01, 1.90255e+01, 1.54871e+01, 3.69856e-07, 9.99988e-01],
[2.01046e+01, 9.51214e+00, 2.91591e+01, 1.93930e+01, 1.90600e-06, 9.99989e-01],
[1.42296e+01, 1.33591e+01, 3.90780e+01, 2.51482e+01, 3.79085e-05, 9.99992e-01],
[1.24922e+01, 1.30163e+01, 3.85068e+01, 2.17499e+01, 7.13624e-06, 9.99990e-01],
[3.45833e+00, 9.17277e+00, 2.58789e+01, 1.93841e+01, 1.02238e-06, 9.99989e-01],
[1.81832e+01, 1.71320e+01, 2.89085e+01, 3.03261e+01, 6.05350e-07, 9.99987e-01],
[1.34964e+01, 1.78379e+01, 3.90230e+01, 3.46369e+01, 2.87768e-03, 9.99986e-01],
[1.27847e+01, 1.77773e+01, 3.99271e+01, 3.25086e+01, 2.35727e-04, 9.99989e-01],
[6.07129e+00, 1.60557e+01, 2.61183e+01, 2.85357e+01, 1.80877e-07, 9.99989e-01],
[1.88122e+01, 2.00261e+01, 2.61396e+01, 2.93050e+01, 6.68130e-08, 9.99969e-01],
[1.33913e+01, 1.85451e+01, 3.90149e+01, 3.33746e+01, 2.99304e-04, 9.99980e-01],
[1.22835e+01, 1.93078e+01, 3.83251e+01, 3.33460e+01, 1.51183e-04, 9.99981e-01],
[4.93227e+00, 2.11337e+01, 2.45216e+01, 2.75695e+01, 1.80397e-07, 9.99985e-01],
[2.18406e+01, 2.95422e+01, 2.20381e+01, 1.59293e+01, 1.19330e-07, 9.99981e-01],
[1.36276e+01, 2.84898e+01, 3.60094e+01, 1.50632e+01, 4.29174e-07, 9.99978e-01],
[1.17048e+01, 3.02172e+01, 3.47991e+01, 1.23526e+01, 6.84682e-06, 9.99980e-01],
[3.09385e+00, 3.26134e+01, 2.30715e+01, 1.32225e+01, 4.59888e-07, 9.99985e-01],
[1.32627e+01, 1.74690e+01, 3.76246e+01, 3.44120e+01, 2.14252e-05, 9.99986e-01],
[1.30707e+01, 1.72933e+01, 3.67107e+01, 3.52413e+01, 3.15300e-05, 9.99982e-01],
[1.38775e+01, 1.73037e+01, 3.69381e+01, 3.40177e+01, 6.73318e-06, 9.99984e-01],
[1.26083e+01, 1.71350e+01, 3.57648e+01, 3.40946e+01, 1.82978e-07, 9.99984e-01],
[1.28156e+01, 1.73255e+01, 3.86070e+01, 3.55889e+01, 2.12197e-05, 9.99985e-01],
[1.31892e+01, 1.72638e+01, 3.81440e+01, 3.59751e+01, 2.38625e-05, 9.99984e-01],
[1.32672e+01, 1.69642e+01, 3.75706e+01, 3.51907e+01, 5.27713e-06, 9.99985e-01],
[1.27958e+01, 1.72852e+01, 3.62158e+01, 3.36891e+01, 3.92006e-07, 9.99982e-01],
[1.27027e+01, 1.75363e+01, 4.29028e+01, 4.25550e+01, 1.09735e-05, 9.99980e-01],
[1.21690e+01, 1.81968e+01, 4.06619e+01, 4.21756e+01, 9.61098e-06, 9.99980e-01],
[1.35517e+01, 1.78563e+01, 4.51561e+01, 4.30401e+01, 9.74589e-06, 9.99980e-01],
[1.19142e+01, 1.85640e+01, 4.31189e+01, 4.55345e+01, 5.59298e-07, 9.99982e-01],
[1.31700e+01, 1.74878e+01, 4.16828e+01, 3.73903e+01, 2.57319e-03, 9.99984e-01],
[1.02095e+01, 1.90038e+01, 4.28035e+01, 4.46768e+01, 2.12460e-04, 9.99983e-01],
[1.14219e+01, 2.09939e+01, 3.97878e+02, 3.72341e+02, 2.70091e-04, 9.99983e-01],
[2.33756e+01, 2.18953e+01, 4.54776e+01, 4.21486e+01, 4.50637e-05, 9.99985e-01],
[2.43078e+01, 2.13468e+01, 4.45326e+01, 4.37532e+01, 8.45956e-06, 9.99982e-01],
[2.26223e+01, 2.22370e+01, 4.45061e+01, 4.12869e+01, 1.91312e-05, 9.99984e-01],
[2.45282e+01, 2.14509e+01, 4.37606e+01, 4.20492e+01, 3.06529e-07, 9.99984e-01],
[2.39035e+01, 2.16280e+01, 4.67028e+01, 4.34007e+01, 4.44824e-05, 9.99986e-01],
[2.41069e+01, 2.13104e+01, 4.63457e+01, 4.46423e+01, 6.65346e-06, 9.99985e-01],
[2.32327e+01, 2.18131e+01, 4.55805e+01, 4.20596e+01, 1.67143e-05, 9.99986e-01],
[2.44200e+01, 2.17605e+01, 4.45806e+01, 4.24237e+01, 5.66683e-07, 9.99982e-01],
[2.40461e+01, 2.20263e+01, 4.99831e+01, 4.95756e+01, 2.57011e-05, 9.99980e-01],
[2.54670e+01, 2.27546e+01, 4.89100e+01, 5.22317e+01, 7.52182e-06, 9.99980e-01],
[2.27149e+01, 2.27715e+01, 5.19827e+01, 5.11854e+01, 3.10184e-05, 9.99980e-01],
[2.53886e+01, 2.33424e+01, 5.12241e+01, 5.45938e+01, 1.26335e-06, 9.99981e-01],
[2.37149e+01, 2.22820e+01, 5.16435e+01, 4.63234e+01, 1.63436e-03, 9.99984e-01],
[2.73978e+01, 2.40050e+01, 5.40506e+01, 5.72263e+01, 1.81253e-04, 9.99983e-01],
[2.53104e+01, 2.62951e+01, 4.91241e+02, 4.58225e+02, 2.41970e-04, 9.99983e-01]]]), None)

(tensor([[[7.56392e+00, 7.16428e+00, 1.50140e+01, 1.45258e+01, 1.16115e-05, 9.99985e-01],
[1.42932e+01, 6.79053e+00, 3.08830e+01, 1.17587e+01, 3.14116e-06, 9.99989e-01],
[1.99927e+01, 6.90281e+00, 3.61015e+01, 8.99111e+00, 2.48195e-07, 9.99987e-01],
...,
[5.47316e+02, 3.60982e+02, 4.96466e+02, 3.46931e+02, 1.26217e-05, 9.99985e-01],
[5.93443e+02, 3.60176e+02, 4.88920e+02, 3.43098e+02, 2.37296e-05, 9.99985e-01],
[6.43792e+02, 3.58729e+02, 4.73254e+02, 3.79098e+02, 7.81242e-06, 9.99986e-01]]]), None)

MPS:
(tensor([[[ 1.20000e+01, 1.20000e+01, 0.00000e+00, 5.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.00000e+01, 1.20000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 2.80000e+01, 1.20000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 3.60000e+01, 1.20000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 1.20000e+01, 4.00000e+00, 4.00000e+01, 5.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.00000e+01, 2.00000e+01, 4.00000e+01, 5.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.80000e+01, 2.00000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 3.60000e+01, 2.00000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 1.20000e+01, 2.80000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 2.00000e+01, 2.80000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 2.80000e+01, 1.20000e+01, 4.00000e+01, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 3.60000e+01, 1.20000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 1.20000e+01, 3.60000e+01, 0.00000e+00, 5.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.00000e+01, 3.60000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 2.80000e+01, 2.00000e+01, 4.00000e+01, 5.20000e+01, 0.00000e+00, 1.00000e+00],
[ 3.60000e+01, 3.60000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 1.20000e+01, 1.20000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 2.00000e+01, 1.20000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 2.80000e+01, 1.20000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[ 3.60000e+01, 1.20000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 1.20000e+01, 4.00000e+00, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 2.00000e+01, 2.00000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 2.80000e+01, 2.00000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 3.60000e+01, 2.00000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 1.20000e+01, 1.20000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 2.00000e+01, 1.20000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 2.80000e+01, 1.20000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 3.60000e+01, 1.20000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 1.20000e+01, 2.00000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 2.00000e+01, 2.00000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 2.80000e+01, 2.00000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[ 3.60000e+01, 3.60000e+01, 0.00000e+00, 1.20000e+02, 0.00000e+00, 0.00000e+00],
[ 1.20000e+01, 1.20000e+01, 0.00000e+00, 9.20000e+01, 0.00000e+00, 0.00000e+00],
[ 2.00000e+01, 1.20000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 2.80000e+01, 1.20000e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 3.60000e+01, 1.20000e+01, 0.00000e+00, 9.20000e+01, 0.00000e+00, 1.00000e+00],
[ 1.20000e+01, 2.00000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.00000e+01, 2.00000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 0.00000e+00],
[ 2.80000e+01, 2.00000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.00000e+01, 2.00000e+01, 1.32000e+02, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 1.20000e+01, 1.20000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.00000e+01, 1.20000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 0.00000e+00],
[ 2.80000e+01, 1.20000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 0.00000e+00],
[ 2.00000e+01, 1.20000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 0.00000e+00],
[ 1.20000e+01, 3.60000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.00000e+01, 2.00000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 0.00000e+00],
[ 2.80000e+01, 3.60000e+01, 1.32000e+02, 9.20000e+01, 0.00000e+00, 0.00000e+00],
[ 3.60000e+01, 2.00000e+01, 0.00000e+00, 9.20000e+01, 0.00000e+00, 1.00000e+00],
[ 2.40000e+01, 2.40000e+01, 1.20000e+02, 2.44000e+02, 0.00000e+00, 0.00000e+00],
[ 4.00000e+01, 2.40000e+01, 1.20000e+02, 2.44000e+02, 0.00000e+00, 0.00000e+00],
[ 2.40000e+01, 8.00000e+00, 1.20000e+02, 2.44000e+02, 0.00000e+00, 0.00000e+00],
[ 8.00000e+00, 8.00000e+00, 1.20000e+02, 2.44000e+02, 0.00000e+00, 0.00000e+00],
[ 2.40000e+01, 2.40000e+01, 2.48000e+02, 1.80000e+02, 0.00000e+00, 0.00000e+00],
[ 8.00000e+00, -8.00000e+00, 2.48000e+02, 1.80000e+02, 0.00000e+00, 0.00000e+00],
[-8.00000e+00, 8.00000e+00, 2.48000e+02, 1.80000e+02, 0.00000e+00, 0.00000e+00],
[ 8.00000e+00, 8.00000e+00, 2.48000e+02, 1.80000e+02, 0.00000e+00, 0.00000e+00],
[ 2.40000e+01, -8.00000e+00, 2.36000e+02, 4.76000e+02, 0.00000e+00, 0.00000e+00],
[ 8.00000e+00, -8.00000e+00, 2.36000e+02, 4.76000e+02, 0.00000e+00, 0.00000e+00],
[-8.00000e+00, 8.00000e+00, 2.36000e+02, 4.76000e+02, 0.00000e+00, 1.00000e+00],
[ 8.00000e+00, 8.00000e+00, 2.36000e+02, 4.76000e+02, 0.00000e+00, 0.00000e+00],
[ 1.75422e+01, 1.44578e+01, 0.00000e+00, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 1.44578e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[-1.73494e+00, 1.44578e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 1.44578e+01, 0.00000e+00, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 1.75422e+01, 2.40964e+01, 4.81928e+01, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 2.40964e+01, 4.81928e+01, 6.26506e+01, 0.00000e+00, 0.00000e+00],
[-1.73494e+00, 2.40964e+01, 4.81928e+01, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 4.81928e+00, 0.00000e+00, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 1.75422e+01, 1.44578e+01, 4.81928e+01, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 1.44578e+01, 4.81928e+01, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[-1.73494e+00, 1.44578e+01, 0.00000e+00, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 1.44578e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 1.75422e+01, 2.40964e+01, 4.81928e+01, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 2.40964e+01, 4.81928e+01, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[-1.73494e+00, 2.40964e+01, 4.81928e+01, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[-1.13735e+01, 4.33735e+01, 0.00000e+00, 6.26506e+01, 0.00000e+00, 1.00000e+00],
[ 1.75422e+01, 1.44578e+01, 0.00000e+00, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 1.44578e+01, 0.00000e+00, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[-1.73494e+00, 1.44578e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 1.44578e+01, 0.00000e+00, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 1.75422e+01, 2.40964e+01, 7.71084e+01, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 2.40964e+01, 7.71084e+01, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[-1.73494e+00, 2.40964e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 2.40964e+01, 0.00000e+00, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 1.75422e+01, 1.44578e+01, 7.71084e+01, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 1.44578e+01, 7.71084e+01, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[-1.73494e+00, 1.44578e+01, 0.00000e+00, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 1.44578e+01, 0.00000e+00, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 1.75422e+01, 2.40964e+01, 7.71084e+01, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 2.40964e+01, 7.71084e+01, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[-1.73494e+00, 2.40964e+01, 7.71084e+01, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[-1.13735e+01, 4.33735e+01, 0.00000e+00, 1.44578e+02, 0.00000e+00, 0.00000e+00],
[ 1.75422e+01, 1.44578e+01, 0.00000e+00, 1.10843e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 1.44578e+01, 1.59036e+02, 1.10843e+02, 0.00000e+00, 0.00000e+00],
[-1.73494e+00, 1.44578e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 1.44578e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[ 1.75422e+01, 2.40964e+01, 1.59036e+02, 1.10843e+02, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 2.40964e+01, 1.59036e+02, 1.10843e+02, 0.00000e+00, 0.00000e+00],
[ 1.75422e+01, 2.40964e+01, 0.00000e+00, 1.10843e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 2.40964e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 1.75422e+01, 1.44578e+01, 1.59036e+02, 1.10843e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 1.44578e+01, 1.59036e+02, 1.10843e+02, 0.00000e+00, 0.00000e+00],
[-1.73494e+00, 1.44578e+01, 0.00000e+00, 1.10843e+02, 0.00000e+00, 0.00000e+00],
[ 7.90361e+00, 1.44578e+01, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[ 1.75422e+01, 2.40964e+01, 1.59036e+02, 1.10843e+02, 0.00000e+00, 1.00000e+00],
[ 7.90361e+00, 2.40964e+01, 1.59036e+02, 1.10843e+02, 0.00000e+00, 1.00000e+00],
[-1.73494e+00, 2.40964e+01, 1.59036e+02, 1.10843e+02, 0.00000e+00, 1.00000e+00],
[-1.13735e+01, 4.33735e+01, 0.00000e+00, 1.10843e+02, 0.00000e+00, 1.00000e+00],
[ 3.08434e+00, 2.89157e+01, 1.44578e+02, 2.93976e+02, 0.00000e+00, 0.00000e+00],
[ 2.23614e+01, 2.89157e+01, 1.44578e+02, 2.93976e+02, 0.00000e+00, 0.00000e+00],
[ 3.08434e+00, 9.63855e+00, 1.44578e+02, 2.93976e+02, 0.00000e+00, 0.00000e+00],
[ 2.23614e+01, 9.63855e+00, 1.44578e+02, 2.93976e+02, 0.00000e+00, 0.00000e+00],
[ 3.08434e+00, 2.89157e+01, 2.98795e+02, 2.16867e+02, 0.00000e+00, 0.00000e+00],
[ 2.23614e+01, -9.63855e+00, 2.98795e+02, 2.16867e+02, 0.00000e+00, 0.00000e+00],
[ 3.08434e+00, 9.63855e+00, 2.98795e+02, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[ 2.23614e+01, 9.63855e+00, 2.98795e+02, 2.16867e+02, 0.00000e+00, 0.00000e+00],
[ 3.08434e+00, -9.63855e+00, 2.84337e+02, 5.73494e+02, 0.00000e+00, 1.00000e+00],
[ 2.23614e+01, -9.63855e+00, 2.84337e+02, 5.73494e+02, 0.00000e+00, 0.00000e+00],
[ 4.16386e+01, 9.63855e+00, 2.84337e+02, 5.73494e+02, 0.00000e+00, 1.00000e+00],
[ 2.23614e+01, 9.63855e+00, 2.84337e+02, 5.73494e+02, 0.00000e+00, 1.00000e+00],
[-2.58313e+01, -1.92771e+01, 5.59036e+02, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[-2.58313e+01, -1.92771e+01, 7.51807e+02, 9.54217e+02, 0.00000e+00, 1.00000e+00],
[-2.58313e+01, -1.92771e+01, 1.79759e+03, 0.00000e+00, 0.00000e+00, 1.00000e+00],
[ 3.58209e+01, 3.58209e+01, 1.79104e+02, 3.64179e+02, 0.00000e+00, 0.00000e+00],
[ 1.19403e+01, 3.58209e+01, 1.79104e+02, 3.64179e+02, 0.00000e+00, 0.00000e+00],
[ 3.58209e+01, 1.19403e+01, 1.79104e+02, 3.64179e+02, 0.00000e+00, 0.00000e+00],
[ 1.19403e+01, 1.19403e+01, 1.79104e+02, 3.64179e+02, 0.00000e+00, 0.00000e+00],
[ 3.58209e+01, 3.58209e+01, 3.70149e+02, 2.68657e+02, 0.00000e+00, 0.00000e+00],
[ 1.19403e+01, -1.19403e+01, 3.70149e+02, 2.68657e+02, 0.00000e+00, 0.00000e+00],
[-1.19403e+01, 1.19403e+01, 3.70149e+02, 2.68657e+02, 0.00000e+00, 0.00000e+00],
[ 1.19403e+01, 1.19403e+01, 3.70149e+02, 2.68657e+02, 0.00000e+00, 0.00000e+00],
[-1.19403e+01, -1.19403e+01, 3.52239e+02, 7.10448e+02, 0.00000e+00, 1.00000e+00],
[ 1.19403e+01, -1.19403e+01, 3.52239e+02, 7.10448e+02, 0.00000e+00, 1.00000e+00],
[-1.19403e+01, 1.19403e+01, 3.52239e+02, 7.10448e+02, 0.00000e+00, 0.00000e+00],
[ 1.19403e+01, 1.19403e+01, 3.52239e+02, 7.10448e+02, 0.00000e+00, 1.00000e+00],
[ 7.16418e+01, -2.38806e+01, 6.92537e+02, 0.00000e+00, 0.00000e+00, 0.00000e+00],
[ 7.16418e+01, -2.38806e+01, 9.31343e+02, 1.18209e+03, 0.00000e+00, 1.00000e+00],
[ 7.16418e+01, -2.38806e+01, 2.22687e+03, 0.00000e+00, 0.00000e+00, 1.00000e+00]]], device='mps:0'), None)

(tensor([[[7.56392e+00, 7.16428e+00, 1.50140e+01, 1.45258e+01, 1.16116e-05, 9.99985e-01],
[1.42932e+01, 6.79053e+00, 3.08830e+01, 1.17588e+01, 3.14117e-06, 9.99989e-01],
[1.99927e+01, 6.90281e+00, 3.61015e+01, 8.99110e+00, 2.48196e-07, 9.99987e-01],
...,
[6.37456e+02, 3.57896e+02, 2.67473e+01, 2.94569e+01, 3.69597e-11, 9.99986e-01],
[6.49423e+02, 3.54734e+02, 2.67179e+01, 3.61348e+01, 2.51123e-08, 9.99984e-01],
[6.63320e+02, 3.52777e+02, 8.54136e+00, 3.87246e+01, 1.53809e-07, 9.99989e-01]]], device='mps:0'), None)

Any thoughts? Is this a red herring?

@glenn-jocher (Member)

@jgoo9410 not sure what you mean by the first forward function output. You mean the very first convolution in the model?

@jgoo9410 (Author) commented Nov 21, 2022

@glenn-jocher

def forward(self, x, augment=False, profile=False, visualize=False):
    if augment:
        return self._forward_augment(x)  # augmented inference, None
    return self._forward_once(x, profile, visualize)  # single-scale inference, train # - not the same!

@glenn-jocher (Member)

@jgoo9410 oh sorry, forward() and _forward_once() are only called once for the model as a whole; their output is the same as the output of the whole model.

If I were you I'd print the inputs and values in Detect() with --device cpu and --device mps and start debugging there.
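
One way to do that without editing yolo.py is to register a forward hook on every module; this is a sketch using the standard torch hook API, where model is whatever DetectionModel instance is being debugged:

import torch

def log_stats(name):
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor):
            print(f'{name}: mean={output.float().mean().item():.6f}')
    return hook

for name, module in model.named_modules():
    module.register_forward_hook(log_stats(name))

# Run the same image through a cpu copy and an mps copy of the model,
# then diff the two logs to find the first layer whose stats diverge.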

@jgoo9410 (Author)

@glenn-jocher Okay, I'll take another look. Do you have a flowchart I could reference?

@glenn-jocher (Member) commented Nov 22, 2022

@jgoo9410 👋 Hello! Thanks for asking about YOLOv5 🚀 architecture visualization. We've made visualizing YOLOv5 🚀 architectures super easy. There are 3 main ways:

model.yaml

Each model has a corresponding yaml file that displays the model architecture. Here is YOLOv5s, defined by yolov5s.yaml:

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

TensorBoard Graph

Simply start training a model, and then view the TensorBoard Graph for an interactive view of the model architecture. This example shows YOLOv5s viewed in our Colab notebook.

# Tensorboard
%load_ext tensorboard
%tensorboard --logdir runs/train

# Train YOLOv5s on COCO128 for 3 epochs
python train.py --weights yolov5s.pt --epochs 3

[screenshot: TensorBoard graph of the YOLOv5s model]

Netron viewer

Use https://netron.app to view exported ONNX models:

python export.py --weights yolov5s.pt --include onnx --simplify

[screenshot: exported yolov5s.onnx viewed in Netron]

Good luck 🍀 and let us know if you have any other questions!

@jgoo9410 (Author)

@glenn-jocher that'll be a yes then. Thanks.

@jgoo9410 (Author) commented Nov 23, 2022

@glenn-jocher
using:

def _apply(self, fn):
        from pprint import pprint
        torch.set_printoptions(profile="full")
        # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
        self = super()._apply(fn)
        if self.pt:
            m = self.model.model.model[-1] if self.dmb else self.model.model[-1]  # Detect()
            pprint(vars(m))
            m.stride = fn(m.stride)
            m.grid = list(map(fn, m.grid))
            if isinstance(m.anchor_grid, list):
                m.anchor_grid = list(map(fn, m.anchor_grid))
        return self

I've compared the state of m for both cpu and mps and there is no difference other than the parameter 'device' being present when running mps.

e.g.
CPU: [373., 326.]]]]])],
GPU: [373., 326.]]]]], device='mps:0')],

I've also individually compared the states of m.stride, m.grid and m.anchor_grid at the end of the function, no difference.

Any other ideas?

@glenn-jocher (Member)

@jgoo9410 really strange. I think the difference is somewhere inside Detect(), so maybe print inference forward-pass feature values using both devices at different stages inside Detect(), e.g. print(x[i].mean()) all throughout Detect() forward:

yolov5/models/yolo.py, lines 56 to 79 in 7398d2d:

def forward(self, x):
    z = []  # inference output
    for i in range(self.nl):
        x[i] = self.m[i](x[i])  # conv
        bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
        x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
        if not self.training:  # inference
            if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
            if isinstance(self, Segment):  # (boxes + masks)
                xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
            else:  # Detect (boxes only)
                xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                y = torch.cat((xy, wh, conf), 4)
            z.append(y.view(bs, self.na * nx * ny, self.no))
    return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

@jgoo9410 (Author) commented Nov 23, 2022

@glenn-jocher printing x reveals a difference. Forcing an error and tracing back through the sequence of function calls takes me back to AutoShape.forward(). Examining the variables within AutoShape.forward(), I noticed something strange.

class AutoShape(nn.Module):
    
    ....

    @smart_inference_mode()
    def forward(self, ims, size=640, augment=False, profile=False):
        # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
        #   file:        ims = 'data/images/zidane.jpg'  # str or PosixPath
        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
        #   numpy:           = np.zeros((640,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images

        dt = (Profile(), Profile(), Profile())
        with dt[0]:
            if isinstance(size, int):  # expand
                size = (size, size)
            p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device)  # param
            autocast = self.amp and (p.device.type != 'cpu')  # Automatic Mixed Precision (AMP) inference
            if isinstance(ims, torch.Tensor):  # torch
                with amp.autocast(autocast):
                    return self.model(ims.to(p.device).type_as(p), augment=augment)  # inference

            # Pre-process
            n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims])  # number, list of images
            shape0, shape1, files = [], [], []  # image and inference shapes, filenames
            for i, im in enumerate(ims):
                f = f'image{i}'  # filename
                if isinstance(im, (str, Path)):  # filename or uri
                    im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
                    im = np.asarray(exif_transpose(im))
                elif isinstance(im, Image.Image):  # PIL Image
                    im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
                files.append(Path(f).with_suffix('.jpg').name)
                if im.shape[0] < 5:  # image in CHW
                    im = im.transpose((1, 2, 0))  # reverse dataloader .transpose(2, 0, 1)
                im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)  # enforce 3ch input
                s = im.shape[:2]  # HWC
                shape0.append(s)  # image shape
                g = max(size) / max(s)  # gain
                shape1.append([y * g for y in s])
                ims[i] = im if im.data.contiguous else np.ascontiguousarray(im)  # update
            shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size  # inf shape
            x = [letterbox(im, shape1, auto=False)[0] for im in ims]  # pad
            x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2)))  # stack and BHWC to BCHW
            x = torch.from_numpy(x).to(p.device).type_as(p) / 255  # uint8 to fp16/32

        with amp.autocast(autocast):
            # Inference
            with dt[1]:
                pprint(x)  # <<<<<<<<<<<<<<<< here (requires `from pprint import pprint`)
                y = self.model(x, augment=augment)  # forward

When printing x on a CPU run:

tensor([[[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9922, 0.9922, 0.9882, 0.9882, 0.9804, 0.9765, 0.9765, 0.9686, 0.9608, 0.9529, 0.9529, 0.9490, 0.9373, 0.9373, 0.9373, 0.9333, 0.9294, 0.9333, 0.9412, 0.9412, 0.9373, 0.9333, 0.9333, 0.9333, 0.9216, 0.9216, 0.9216, 0.9137, 0.8941, 0.8980, 0.8941, 0.8863, 0.8863, 0.8784, 0.8745, 0.8745, 0.8706, 0.8706, 0.8667, 0.8588, 0.8784, 0.8784, 0.8588, 0.8667, 0.8667, 0.8667, 0.8667, 0.8549, 0.8510, 0.8588, 0.8549, 0.8392, 0.8471, 0.8431, 0.8314, 0.8431, 0.8235, 0.8275, 0.8314, 0.8314, 0.8275, 0.8275, 0.8235, 0.8235, 0.8196, 0.8196, 0.8314, 0.8314, 0.8235, 0.8235, 0.8196, 0.8196, 0.8196, 0.8157,.......

When printing x on an mps run:

tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,

In the 40,000 subsequent lines, no values other than 1 and 0 are present for the mps run.

Looks like a rounding process is taking place to the wrong number of sig figs, or something even more strange.

In between initiating the detection from my program, and it reaching the function above, it passes through torch/nn/modules/module.py "forward_call()" and torch/autograd/grad_mode.py "func()".

My money is on the weirdness coming from somewhere in there, but I might need someone more experienced with these libraries to give me a hand if you want me to trace the problem any further.
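
One thing that can be isolated cheaply is the preprocessing step itself, since the all-0/1 values appear right after the uint8 -> device -> cast -> /255 line in the AutoShape.forward() snippet above. A minimal standalone sketch with synthetic data:

import numpy as np
import torch

im = np.random.randint(0, 256, (1, 3, 64, 64), dtype=np.uint8)

x_cpu = torch.from_numpy(im).float() / 255
x_mps = torch.from_numpy(im).to('mps').float() / 255  # also worth trying .half() to mimic AMP

# If this passes, the cast/divide is healthy and the corruption happens
# later; if it fails, the model itself is off the hook entirely.
print(torch.allclose(x_cpu, x_mps.cpu(), atol=1e-6))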

@glenn-jocher (Member)

@jgoo9410 seems like autocast is not behaving well with MPS then. But I don't think the problem is autocast, because python detect.py --device mps does not use autocast and still has problems.
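
That said, since the AutoShape.forward() path quoted above gates on self.amp, one cheap experiment to rule AMP in or out for this particular code path is a sketch like this, assuming a loaded AutoShape model:

model.amp = False       # disable Automatic Mixed Precision for this wrapper
results = model(image)  # re-run on mps and compare boxes against a cpu run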

@jgoo9410 (Author) commented Dec 1, 2022

I'm afraid I'm at the limit of my current understanding in relation to this issue, and therefore can't comment.

I'm sure there are lots of other priorities with yolov5 at the moment, but being able to use mps really would be a game changer in terms of its use on such ubiquitous hardware.

If there is any other way I could contribute to help get to the bottom of this issue, let me know.

@glenn-jocher (Member)

@jgoo9410 I'll take a look today, hold on.

@glenn-jocher (Member)

@jgoo9410 also yes, this is a semi-priority. The confusion lies in the fact that torch itself does not yet fully support MPS. Some modules we rely on, like torchvision NMS and others, are not yet supported, so I've taken a wait-and-see approach until there is better support.

@jgoo9410 (Author) commented Dec 1, 2022

I hear you. Had it not worked at all I'd probably be resigned to waiting, but it so nearly works, and I've witnessed the performance increase. I can't go back to cpu; I've developed a taste for the good stuff!

@glenn-jocher (Member) commented Dec 1, 2022

Yeah, it's almost working, and the performance increase is pretty dramatic, so when it does work that'll be great for all of us Apple hardware people :)

@jgoo9410 (Author) commented Dec 1, 2022

Probably time to take out some shares in Apple.

@glenn-jocher (Member)

@jgoo9410 yes, same situation now. We need aten::_unique2, aten::sort.values_stable and NMS, which are in various stages of support in pytorch/pytorch#77764, so I'd say contribute a thumbs up or comment on those on the torch issue and sit back and wait a bit.

@jgoo9410 (Author) commented Dec 1, 2022

Okay, I'll see if I can support the devs working on those features in some way.

I assume as a workaround, you have implemented a CPU version of the missing mps components? If so, is it obvious to you which one is misbehaving?

@glenn-jocher (Member)

@jgoo9410 torch itself has a fallback to revert to CPU, enabled with PYTORCH_ENABLE_MPS_FALLBACK=1. You can see that YOLOv5 classification is producing identical results on CPU and MPS with this:

PYTORCH_ENABLE_MPS_FALLBACK=1 python classify/predict.py --device cpu
PYTORCH_ENABLE_MPS_FALLBACK=1 python classify/predict.py --device mps

But detection is not. This is why I think there may be an issue with the Detect() head, because the rest of the detection model has much in common with the classification version:

PYTORCH_ENABLE_MPS_FALLBACK=1 python detect.py --device cpu
PYTORCH_ENABLE_MPS_FALLBACK=1 python detect.py --device mps
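
For anyone driving this from a script rather than the CLI: the fallback flag is an environment variable, so to be safe it should be set before torch is imported. A sketch of the usual pattern:

import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import torch  # unsupported MPS ops now fall back to CPU instead of raising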

@jgoo9410 (Author) commented Dec 1, 2022

Right, so if it's not an implemented feature it will fall back to using the CPU version of that feature? Perhaps then the datatype of some mps function is being 'cast' improperly during the transition, causing the output I listed above?

@glenn-jocher (Member)

I can debug this very simply. If I place print(x[i].mean()) at the beginning of the Detect.forward method and print(y.mean()) at the end, I can see that x[i] is identical for CPU and MPS, but y is not, so it's likely that the grids/anchors are not transferring properly between devices.

    def forward(self, x):
        z = []  # inference output
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            print(x[i].mean())  # < ---- PRINT RESULT BEFORE GRIDS/ANCHORS ---------------------

            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                if isinstance(self, Segment):  # (boxes + masks)
                    xy, wh, conf, mask = x[i].split((2, 2, self.nc + 1, self.no - self.nc - 5), 4)
                    xy = (xy.sigmoid() * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh.sigmoid() * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf.sigmoid(), mask), 4)
                else:  # Detect (boxes only)
                    xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, self.na * nx * ny, self.no))

                print(y.mean())  # < ---- PRINT RESULT AFTER GRIDS/ANCHORS (inside the inference branch, where y exists)

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)

@glenn-jocher (Member)

Looks like self.anchor_grid[i] is different on the two devices.

@glenn-jocher (Member)

self.stride also, which I think is used to calculate self.anchor_grid, so it's probably the origin of the problem.

@jgoo9410 (Author) commented Dec 1, 2022

Is self.anchor_grid dynamic? Could you not just calculate it using a normal cpu run and then copy the value over to the mps run on the basis that it's going to use the cpu for that part anyway?

@glenn-jocher (Member) commented Dec 1, 2022

@jgoo9410 anchor_grid depends on image size.

There seem to be some bugs to work out in MPS. If I run this simple snippet I get erroneous output on the last term: the .mean() op is failing to run on the correct index. Seems like a PyTorch bug.

            if not self.training:  # inference
                if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                print(i, self.stride, self.stride[i], self.stride[i].mean())

                0 tensor([ 8., 16., 32.], device='mps:0') tensor(8., device='mps:0') tensor(8., device='mps:0')
                1 tensor([ 8., 16., 32.], device='mps:0') tensor(16., device='mps:0') tensor(8., device='mps:0')
                2 tensor([ 8., 16., 32.], device='mps:0') tensor(32., device='mps:0') tensor(8., device='mps:0')
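
The same pattern can be reproduced standalone, with no YOLOv5 code involved; a sketch (on affected torch builds the last column comes out wrong):

import torch

stride = torch.tensor([8., 16., 32.], device='mps')
for i in range(3):
    # expected: (0, 8., 8.), (1, 16., 16.), (2, 32., 32.)
    print(i, stride[i], stride[i].mean())
# The buggy output above shows .mean() returning 8. for every i,
# i.e. the reduction is reading element 0 regardless of the index.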

@jgoo9410 (Author) commented Dec 1, 2022

@glenn-jocher okay. For my specific use case, I'm performing detections on a video, so my images are all the same size. In that case, if I were able to use a cpu run to get the anchor_grid and then apply it to the mps run, in theory it should work, assuming the issue is exclusively with the above?

@glenn-jocher (Member)

@jgoo9410 you can experiment to see if you can find a solution in this area.
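
A sketch of that experiment, assuming a fixed inference size so the grids only need to be built once. The size and attribute path are illustrative (cf. the _apply() snippet earlier in the thread), and note the grids are plain lists, not registered buffers, so .to() won't move them by itself:

det = model.model.model[-1]  # Detect() head; adjust path for your wrappers
det.cpu()                    # build the grids with cpu arithmetic
h, w = 384, 640              # hypothetical fixed inference size
for i in range(det.nl):
    ny, nx = int(h // det.stride[i]), int(w // det.stride[i])
    det.grid[i], det.anchor_grid[i] = det._make_grid(nx, ny, i)
det.to('mps')                # parameters/buffers move; the lists need moving explicitly
det.grid = [g.to('mps') for g in det.grid]
det.anchor_grid = [a.to('mps') for a in det.anchor_grid]
# As long as each x[i].shape[2:4] matches, Detect.forward() reuses these grids.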

@github-actions bot (Contributor) commented Jan 1, 2023

👋 Hello, this issue has been automatically marked as stale because it has not had recent activity. Please note it will be closed if no further activity occurs.


Feel free to inform us of any other issues you discover or feature requests that come to mind in the future. Pull Requests (PRs) are also always welcome!

Thank you for your contributions to YOLOv5 🚀 and Vision AI ⭐!

github-actions bot added the Stale label on Jan 1, 2023
github-actions bot closed this as not planned (won't fix / can't repro / duplicate / stale) on Jan 11, 2023
@jgoo9410 (Author) commented Jan 20, 2023

So, not much progress in finding the source of the issue, but I have another symptom:

It appears that the whole issue is in the first x coordinate in 'boxes'. If I use the difference in y (y2 - y1) to determine the size of the square (this won't work for rectangles, obviously), and then anchor the square using the second x coordinate (x2), I get perfect tracking (see the sketch at the end of this comment).

Interestingly, when the first x value is incorrect, it is always pinned to the centre of the image, exactly 50% of the resolution.

It may be the case that if I used a video that was 'portrait' rather than 'landscape' the issue would be with the Y coordinate, as I assume they are calculated identically.

def detect(image):
    results = model(image)
    scores = results.pred[0][:,4]
    categories = results.pred[0][:, 5]
    boxes = results.pred[0][:, :4]
    scores = scores.cpu().numpy()
    classes = categories.cpu().numpy()

Example of the error:

Everything is working fine at this point and detections are being located correctly:

[[540.2986450195312, 182.97451782226562, 572.0088500976562, 218.41183471679688], [364.5255432128906, 45.33429718017578, 394.2867736816406, 75.33317565917969]]
[[562.913818359375, 182.84805297851562, 595.349609375, 215.48165893554688], [386.10003662109375, 41.88550567626953, 416.5657958984375, 70.09605407714844]]
[[406.08209228515625, 44.18846130371094, 432.9619140625, 70.15641784667969], [589.3087768554688, 177.3563690185547, 619.9940795898438, 210.57054138183594]]
[[428.22216796875, 42.18955993652344, 454.005615234375, 70.91654968261719], [612.3798217773438, 182.76608276367188, 639.4381713867188, 214.57382202148438]]
[[441.1624755859375, 40.16451644897461, 466.58294677734375, 68.98831176757812], [626.9257202148438, 187.1055908203125, 639.8721313476562, 221.32794189453125]]

Here is where the issue starts. You can see that the first x value is being pinned to 352, which is 50% of the width.

[[352.0, 40.07295227050781, 484.5635070800781, 69.09397888183594]]
[[352.0, 45.35044860839844, 505.8199768066406, 72.57444763183594]]
[[352.0, 47.0360107421875, 524.5860595703125, 74.59780883789062]]
[[352.0, 44.714088439941406, 540.20849609375, 69.66326141357422]]

I could understand if this was an overflow or rounding, or something of that nature, but I would have expected the second x value to have suffered in the same way, which it clearly hasn't.

Anything about this jumping out to you @glenn-jocher?

Here is a video of the issue: https://imgur.com/a/vkaEzbi
Sadly I don't own the footage so I can't post it, but I've recorded the detections without the footage.
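
A minimal sketch of the workaround described above, assuming roughly square objects and treating the y extent as reliable:

def fix_boxes(boxes):
    # Rebuild x1 from x2 and the box height instead of trusting the mps x1.
    fixed = []
    for x1, y1, x2, y2 in boxes:
        side = y2 - y1                         # y coords appear unaffected
        fixed.append([x2 - side, y1, x2, y2])  # re-anchor the box on x2
    return fixed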

@glenn-jocher (Member)

@jgoo9410 This is indeed strange behavior. The fact that the first x value is consistently pinned at 50% of the width hints at a potential bug in a coordinate calculation or transformation. Unfortunately, without access to the actual video footage, pinpointing the exact cause is challenging. I would recommend continuing to investigate and possibly reaching out to the YOLO community or the Ultralytics team for further insight.

@jgoo9410 (Author)

@glenn-jocher This is quite an old issue, but I have more information on it. It appears to be unrelated to Ultralytics, and is actually related to torch and torchvision.

The issue was present when I created this post, obviously. Around August, an inadvertent update of torch and torchvision made the issue disappear and all was well. In the last month or so, the latest version of torch has reintroduced the issue.

I can confirm that with torch==2.0.1 and torchvision==0.15.2, the issue is not present. Those are the versions I'm sticking with for now.

I know this is only half the picture, but hopefully someone can tell you a version it definitely isn't working with, to help with further investigation.
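
A small guard for the known-good combination reported here (purely illustrative; pinning itself is done with pip install torch==2.0.1 torchvision==0.15.2):

import torch
import torchvision

good = torch.__version__.startswith('2.0.1') and torchvision.__version__.startswith('0.15.2')
if not good:
    print('warning: torch/torchvision combination untested for MPS detection:',
          torch.__version__, torchvision.__version__)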

@glenn-jocher (Member)

Thanks for sharing this valuable information, @jgoo9410. It's great to have additional context on this issue and the specific versions of torch and torchvision where the problem arises. This will be helpful for others encountering similar issues and for ongoing investigation. We appreciate your contribution to the community!
