Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] fix the wrong weight reference bug in BaseTransformerLayer #1418

Merged
merged 1 commit into from
Nov 2, 2021

Conversation

gaotongxiao
Copy link
Collaborator

@gaotongxiao gaotongxiao commented Oct 20, 2021

Fixes the wrong function/weight reference bug in BaseTransformerLayer when batch_first is True.

Motivation

There are some cases that users clone a model with copy.deepcopy to create, for example, a ModuleList. However, things become tricky when it comes to BaseTransformerLayer. If we create a ModuleList on GPU in this way with batch_first=True and run the forward function, an error will prompt saying the weights are not on GPU.

baselayer = BaseTransformerLayer(
    operation_order=('self_attn', 'ffn'),
    batch_first=True,
    attn_cfgs=dict(
        type='MultiheadAttention',
        embed_dims=256,
        num_heads=8,
    ),
)
baselayers = ModuleList([copy.deepcopy(baselayer) for _ in range(2)])
baselayers.to('cuda')
x = torch.rand(2, 10, 256).cuda()
out = baselayers[0](x)

The error:

RuntimeError: Tensor for argument #3 'mat2' is on CPU, but expected it to be on GPU (while checking arguments for addmm)

It turned out to be the problem of forward_wrapper, which somehow stores the function pointer of self.attn.forward during the initialization and would always call this address after transpose operations. Such logic works well for this module but not for modules deepcopied from it, which will still call the forward function of the original module. Essentially they are still using the weights of the original module to conduct computations!

if self.batch_first:
def _bnc_to_nbc(forward):
"""Because the dataflow('key', 'query', 'value') of
``torch.nn.MultiheadAttention`` is (num_query, batch,
embed_dims), We should adjust the shape of dataflow from
batch_first (batch, num_query, embed_dims) to num_query_first
(num_query ,batch, embed_dims), and recover ``attn_output``
from num_query_first to batch_first."""
def forward_wrapper(**kwargs):
convert_keys = ('key', 'query', 'value')
for key in kwargs.keys():
if key in convert_keys:
kwargs[key] = kwargs[key].transpose(0, 1)
attn_output, attn_output_weights = forward(**kwargs)
return attn_output.transpose(0, 1), attn_output_weights
return forward_wrapper
self.attn.forward = _bnc_to_nbc(self.attn.forward)

Modification

Hardcoding a function pointer in the initializer is not a good practice. I move the transpose logic to forward and it should be much safer. A unit test for this case is also added.

BC-breaking (Optional)

No

@zhouzaida zhouzaida requested a review from jshilong October 20, 2021 09:30
Copy link
Collaborator

@jshilong jshilong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@ZwwWayne ZwwWayne merged commit c522b47 into open-mmlab:master Nov 2, 2021
@gaotongxiao gaotongxiao deleted the fix_transformer branch November 3, 2021 01:24
zhouzaida pushed a commit that referenced this pull request Nov 3, 2021
@zhouzaida zhouzaida mentioned this pull request Nov 7, 2021
13 tasks
zhouzaida added a commit that referenced this pull request Apr 16, 2022
* [Feature] Add roiaware pool3d ops from mmdet3d (#1382)

* add ops (roiaware pool3d) in mmdet3d

* refactor code

* fix typo

Co-authored-by: zhouzaida <[email protected]>

* [Feature] Add iou3d op from mmdet3d (#1356)

* add ops (iou3d) in mmdet3d

* add unit test

* refactor code

* refactor code

* refactor code

* refactor code

* refactor code

Co-authored-by: zhouzaida <[email protected]>

* [Fix] Update test data for test_iou3d (#1427)

* Update test data for test_iou3d

* delete blank lines

Co-authored-by: Zaida Zhou <[email protected]>

* [Feature] Add group points ops from mmdet3d (#1415)

* add op (group points) and its related ops (ball query and knn) in mmdet3d

* refactor code

* fix typo

* refactor code

* fix typo

* refactor code

* make input contiguous

Co-authored-by: zhouzaida <[email protected]>

* add mmdet3d op (#1425)

Co-authored-by: zhouzaida <[email protected]>

* [Feature] Loading objects from different backends and dumping objects to different backends (#1330)

* [Feature] Choose storage backend by the prefix of filepath

* refactor FileClient and add unittest

* support loading from different backends

* polish docstring

* fix unittet

* rename attribute str_like_obj to is_str_like_obj

* add infer_client method

* add check_exist method

* rename var client to file_client

* polish docstring

* add join_paths method

* remove join_paths and add _format_path

* enhance unittest

* refactor unittest

* singleton pattern

* fix test_clientio.py

* deprecate CephBackend

* enhance docstring

* refactor unittest for petrel

* refactor unittest for disk backend

* update io.md

* add concat_paths method

* improve docstring

* improve docstring

* add isdir and copyfile for file backend

* delete copyfile and add get_local_path

* remove isdir method of petrel

* fix typo

* add comment and polish docstring

* polish docstring

* rename _path_mapping to _map_path

* polish docstring and fix typo

* refactor get_local_path

* add list_dir_or_file for FileClient

* add list_dir_or_file for PetrelBackend

* fix windows ci

* Add return docstring

* polish docstring

* fix typo

* fix typo

* deprecate the conversion from Path to str

* add docs for loading checkpoints with FileClient

* refactor map_path

* add _ensure_methods to ensure methods have been implemented

* fix list_dir_or_file

* rename _ensure_method_implemented to has_method

* Add CI for pytorch 1.10 (#1431)

* [Feature] Upload checkpoints and logs to ceph (#1375)

* [Feature] Choose storage backend by the prefix of filepath

* refactor FileClient and add unittest

* support loading from different backends

* polish docstring

* fix unittet

* rename attribute str_like_obj to is_str_like_obj

* [Docs] Upload checkpoint to petrel oss

* add infer_client method

* Support uploading checkpoint to petrel oss

* add check_exist method

* refactor CheckpointHook

* support uploading logs to ceph

* rename var client to file_client

* polish docstring

* enhance load_from_ceph

* refactor load_from_ceph

* refactor TextLoggerHook

* change the meaning of out_dir argument

* fix test_checkpoint_hook.py

* add join_paths method

* remove join_paths and add _format_path

* enhance unittest

* refactor unittest

* add a unittest for EvalHook when file backend is petrel

* singleton pattern

* fix test_clientio.py

* deprecate CephBackend

* add warning in load_from_ceph

* fix type of out_suffix

* enhance docstring

* refactor unittest for petrel

* refactor unittest for disk backend

* update io.md

* add concat_paths method

* fix CI

* mock check_exist

* improve docstring

* improve docstring

* improve docstring

* improve docstring

* add isdir and copyfile for file backend

* delete copyfile and add get_local_path

* remove isdir method of petrel

* fix typo

* rename check_exists to exists

* refactor code and polish docstring

* fix windows ci

* add comment and polish docstring

* polish docstring

* polish docstring

* rename _path_mapping to _map_path

* polish docstring and fix typo

* refactor get_local_path

* add list_dir_or_file for FileClient

* add list_dir_or_file for PetrelBackend

* fix windows ci

* Add return docstring

* polish docstring

* fix typo

* fix typo

* fix typo

* fix error when mocking PetrelBackend

* deprecate the conversion from Path to str

* add docs for loading checkpoints with FileClient

* rename keep_log to keep_local

* refactor map_path

* add _ensure_methods to ensure methods have been implemented

* fix list_dir_or_file

* rename _ensure_method_implemented to has_method

* refactor

* polish information

* format information

* bump version to v1.3.16 (#1430)

* [Fix]: Update test data of test_tin_shift (#1426)

* Update test data of test_tin_shift

* Delete tmp.engine

* add pytest raises asserterror test

* raise valueerror, update test log

* add more comment

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <[email protected]>

Co-authored-by: Zaida Zhou <[email protected]>

* fix the wrong function reference bug in BaseTransformerLayer when batch_first is True (#1418)

* [Docs] Add mmcv itself in the docs list (#1441)

* Add mmcv itself in the docs list

* modify link of docs

* [Improve] improve checkpoint loading log (#1446)

* [Feature] Support SigmoidFocalLoss with Cambricon MLU backend (#1346)

* [Feature] Support SigmoidFocalLoss with Cambricon MLU backend

* refactor MMCV_WITH_MLU macro define

* refactor NFU_ALIGN_SIZE, PAD_DOWN and split_pipeline_num

* delete extra fool proofing in cpp

* [Feature] Support SigmoidFocalLossBackward with Cambricon MLU backend

* fix macro definition in SigmoidFocalLoss

* refactor mlu files into clang-format

* refactor sigmoid focal loss test

* refactor Sigmoid Focal Loss file structure.

* fix python lint error

* fix import torch_mlu error type

* fix lint

* refactor clang format style to google

Co-authored-by: zhouzaida <[email protected]>

* [Feature] Support RoiAlign With Cambricon MLU Backend (#1429)

* [Feature] Support NMS with cambricon MLU backend (#1467)

* [Feature] Support BBoxOverlaps with cambricon MLU backend (#1507)

* [Refactor] Format C++ code

* [Refactor] include common_mlu_helper in pytorch_mlu_helper and refactor build condition

* [Improve] Improve the performance of roialign, nms and focalloss with MLU backend (#1572)

* [Improve] Improve the performance of roialign with MLU backend

* replace CHECK_MLU with CHECK_MLU_INPUT

* [Improve] Improve the perf of nms and focallosssigmoid with MLU backend

* [Improve] Improve the performance of roialign with MLU backend (#1741)

* [Feature] Support tin_shift with cambricon MLU backend (#1696)

* [Feature] Support tin_shift with cambricon MLU backend

* [fix] Add the assertion of batch_size in tin_shift.py

* [fix] fix the param check of tin_shift in cambricon code

* [fix] Fix lint failure.

* [fix] Fix source file lint failure.

* Update mmcv/ops/tin_shift.py

[Refactor] Modify the code in mmcv/ops/tin_shift.py.

Co-authored-by: Zaida Zhou <[email protected]>

Co-authored-by: budefei <[email protected]>
Co-authored-by: budefei <[email protected]>
Co-authored-by: Zaida Zhou <[email protected]>

* resolve conflicts and fix lint

* fix mmcv.utils.__init__

* fix mmcv.utils.__init__

* Fix lints and change FLAG

* fix setup and refine

* remove a redundant line

* remove an unnecessary 'f'

* fix compilation error

Co-authored-by: dingchang <[email protected]>
Co-authored-by: zhouzaida <[email protected]>
Co-authored-by: q.yao <[email protected]>
Co-authored-by: Zaida Zhou <[email protected]>
Co-authored-by: pc <[email protected]>
Co-authored-by: Wenwei Zhang <[email protected]>
Co-authored-by: q.yao <[email protected]>
Co-authored-by: Tong Gao <[email protected]>
Co-authored-by: Yuxin Liu <[email protected]>
Co-authored-by: zihanchang11 <[email protected]>
Co-authored-by: shlrao <[email protected]>
Co-authored-by: zhouchenyang <[email protected]>
Co-authored-by: Mrxiaofei <[email protected]>
Co-authored-by: budefei <[email protected]>
Co-authored-by: budefei <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants