
When I training I have a problem. #4

Closed
wntg opened this issue Nov 23, 2022 · 5 comments
Labels
good first issue Good for newcomers

Comments

@wntg

wntg commented Nov 23, 2022

-- Process 5 terminated with the following error:
Traceback (most recent call last):
File "/root/miniconda3/envs/xclip/lib/python3.7/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
fn(i, *args)
File "/root/workspace/UniFormerV2-main/slowfast/utils/multiprocessing.py", line 60, in run
ret = func(cfg)
File "/root/workspace/UniFormerV2-main/tools/train_net.py", line 489, in train
train_loader, model, optimizer, loss_scaler, train_meter, cur_epoch, cfg, writer
File "/root/workspace/UniFormerV2-main/tools/train_net.py", line 105, in train_epoch
loss_scaler(loss, optimizer, clip_grad=cfg.SOLVER.CLIP_GRADIENT, parameters=model.parameters(), create_graph=is_second_order)
File "/root/miniconda3/envs/xclip/lib/python3.7/site-packages/timm/utils/cuda.py", line 43, in __call__
self._scaler.scale(loss).backward(create_graph=create_graph)
File "/root/miniconda3/envs/xclip/lib/python3.7/site-packages/torch/_tensor.py", line 396, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/root/miniconda3/envs/xclip/lib/python3.7/site-packages/torch/autograd/__init__.py", line 175, in backward
allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.HalfTensor [4, 768, 8, 14, 14]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
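For context, this error class can be reproduced with a minimal, hypothetical PyTorch snippet (not from UniFormerV2): an operation saves its output for the backward pass, and that saved tensor is then edited in place before backward() runs.

import torch

x = torch.randn(4, requires_grad=True)
y = torch.exp(x)    # autograd saves y, since dy/dx = y
y.add_(1)           # in-place update bumps y's version counter from 0 to 1
y.sum().backward()  # RuntimeError: ... modified by an inplace operation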

@Andy1621
Collaborator

Andy1621 commented Nov 23, 2022

It might be caused by a different version of PyTorch or CUDA. You can try adding clone() here:

tmp_feats = self.dpe[j](tmp_feats).view(N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous()

Line 264 then becomes:

tmp_feats = self.dpe[j](tmp_feats.clone()).view(N, C, T_down, L - 1).permute(3, 0, 2, 1).contiguous()
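For intuition, here is a toy version of the same pattern (an assumed illustration, not the repo's code) showing why the clone() avoids the error: the in-place work happens on a fresh copy, so the tensor autograd saved earlier is never modified.

import torch

x = torch.randn(4, requires_grad=True)
y = torch.exp(x)
z = y.clone()       # z owns its own storage; mutating it leaves the saved y intact
z.add_(1)
z.sum().backward()  # runs fine: exp's saved output was never touched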

@Andy1621 Andy1621 pinned this issue Nov 23, 2022
@Andy1621 Andy1621 unpinned this issue Nov 23, 2022
@Andy1621 Andy1621 added the good first issue Good for newcomers label Nov 23, 2022
@Andy1621
Collaborator

@wntg Hi! Have you solved the problem?

@Andy1621
Collaborator

As there has been no further activity, I am closing the issue; don't hesitate to reopen it if necessary.

@xiezexun

May I ask, is there a link for vit_b16.pth? The pre-trained weights I downloaded do not match the model's parameters.

@Andy1621
Collaborator
