Skip to content

Commit

Permalink
Update base_head.py (#1843)
Browse files Browse the repository at this point in the history
Co-authored-by: lupeng <[email protected]>
  • Loading branch information
xinxinxinxu and Ben-Louis authored Dec 1, 2022
1 parent 92d5a5b commit eaca6f2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
15 changes: 12 additions & 3 deletions mmpose/models/heads/base_head.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta, abstractmethod
from typing import List, Tuple, Union
from typing import List, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -63,9 +64,17 @@ def _get_in_channels(self) -> Union[int, List[int]]:

return in_channels

def _transform_inputs(self, feats: Tuple[Tensor]
) -> Union[Tensor, Tuple[Tensor]]:
def _transform_inputs(
self,
feats: Union[Tensor, Sequence[Tensor]],
) -> Union[Tensor, Tuple[Tensor]]:
"""Transform multi scale features into the network input."""
if not isinstance(feats, Sequence):
warnings.warn(f'the input of {self._get_name()} is a tensor '
f'instead of a tuple or list. The argument '
f'`input_transform` will be ignored.')
return feats

if self.input_transform == 'resize_concat':
inputs = [feats[i] for i in self.input_index]
resized_inputs = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,31 @@ def test_predict(self):
self.assertEqual(preds[0].keypoints.shape,
batch_data_samples[0].gt_instances.keypoints.shape)

# input transform: output heatmap
# input transform: none
head = HeatmapHead(
in_channels=[16, 32],
out_channels=17,
input_transform='resize_concat',
input_index=[0, 1],
deconv_out_channels=(256, 256),
deconv_kernel_sizes=(4, 4),
conv_out_channels=(256, ),
conv_kernel_sizes=(1, ),
decoder=decoder_cfg)
feats = self._get_feats(batch_size=2, feat_shapes=[(48, 16, 12)])[0]
batch_data_samples = get_packed_inputs(batch_size=2)['data_samples']
with self.assertWarnsRegex(
Warning,
'the input of HeatmapHead is a tensor instead of a tuple '
'or list. The argument `input_transform` will be ignored.'):
preds = head.predict(feats, batch_data_samples)

self.assertTrue(len(preds), 2)
self.assertIsInstance(preds[0], InstanceData)
self.assertEqual(preds[0].keypoints.shape,
batch_data_samples[0].gt_instances.keypoints.shape)

# output heatmap
head = HeatmapHead(
in_channels=[16, 32],
out_channels=17,
Expand Down

0 comments on commit eaca6f2

Please sign in to comment.