Add hmr head for human mesh estimation task. (open-mmlab#162)
* Add the HMR head and discriminator for SMPL parameters. Add test code and test data.
* Modify the code according to review.
1 parent 6604a31 · commit 273367d
Showing 7 changed files with 526 additions and 0 deletions.
@@ -0,0 +1,3 @@
from .hmr_head import MeshHMRHead

__all__ = ['MeshHMRHead']
@@ -0,0 +1,286 @@
# ------------------------------------------------------------------------------
# Adapted from https://github.com/akanazawa/hmr
# Original licence: Copyright (c) 2018 akanazawa, under the MIT License.
# ------------------------------------------------------------------------------

from abc import abstractmethod

import torch
import torch.nn as nn

from .geometric_layers import batch_rodrigues


class BaseDiscriminator(nn.Module):
    """Base linear module for SMPL parameter discriminator.

    Args:
        fc_layers (Tuple): Tuple of neuron counts,
            such as (9, 32, 32, 1)
        use_dropout (Tuple): Tuple of bools defining whether to use dropout
            for each layer, such as (True, True, False)
        drop_prob (Tuple): Tuple of floats defining the drop probability,
            such as (0.5, 0.5, 0)
        use_activation (Tuple): Tuple of bools defining whether to use the
            activation function, such as (True, True, False)
    """

    def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
        super().__init__()
        self.fc_layers = fc_layers
        self.use_dropout = use_dropout
        self.drop_prob = drop_prob
        self.use_activation = use_activation
        self._check()
        self.create_layers()

    def _check(self):
        """Check input to avoid ValueError."""
        if not isinstance(self.fc_layers, tuple):
            raise TypeError(f'fc_layers require tuple, '
                            f'get {type(self.fc_layers)}')

        if not isinstance(self.use_dropout, tuple):
            raise TypeError(f'use_dropout require tuple, '
                            f'get {type(self.use_dropout)}')

        if not isinstance(self.drop_prob, tuple):
            raise TypeError(f'drop_prob require tuple, '
                            f'get {type(self.drop_prob)}')

        if not isinstance(self.use_activation, tuple):
            raise TypeError(f'use_activation require tuple, '
                            f'get {type(self.use_activation)}')

        l_fc_layer = len(self.fc_layers)
        l_use_drop = len(self.use_dropout)
        l_drop_prob = len(self.drop_prob)
        l_use_activation = len(self.use_activation)

        pass_check = (
            l_fc_layer >= 2 and l_use_drop < l_fc_layer
            and l_drop_prob < l_fc_layer and l_use_activation < l_fc_layer
            and l_drop_prob == l_use_drop)

        if not pass_check:
            msg = 'Wrong BaseDiscriminator parameters!'
            raise ValueError(msg)

    def create_layers(self):
        """Create layers."""
        l_fc_layer = len(self.fc_layers)
        l_use_drop = len(self.use_dropout)
        l_use_activation = len(self.use_activation)

        self.fc_blocks = nn.Sequential()

        for i in range(l_fc_layer - 1):
            self.fc_blocks.add_module(
                name=f'regressor_fc_{i}',
                module=nn.Linear(
                    in_features=self.fc_layers[i],
                    out_features=self.fc_layers[i + 1]))

            if i < l_use_activation and self.use_activation[i]:
                self.fc_blocks.add_module(
                    name=f'regressor_af_{i}', module=nn.ReLU())

            if i < l_use_drop and self.use_dropout[i]:
                self.fc_blocks.add_module(
                    name=f'regressor_fc_dropout_{i}',
                    module=nn.Dropout(p=self.drop_prob[i]))

    @abstractmethod
    def forward(self, inputs):
        """Forward function."""
        msg = 'the base class [BaseDiscriminator] is not callable!'
        raise NotImplementedError(msg)


class ShapeDiscriminator(BaseDiscriminator):
    """Discriminator for SMPL shape parameters. The input is
    (batch_size x 10).

    Args:
        fc_layers (Tuple): Tuple of neuron counts,
            such as (10, 5, 1)
        use_dropout (Tuple): Tuple of bools defining whether to use dropout
            for each layer, such as (True, False)
        drop_prob (Tuple): Tuple of floats defining the drop probability,
            such as (0.5, 0)
        use_activation (Tuple): Tuple of bools defining whether to use the
            activation function, such as (True, False)
    """

    def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
        if fc_layers[-1] != 1:
            msg = f'the neuron count of the last layer ' \
                  f'must be 1, but got {fc_layers[-1]}'
            raise ValueError(msg)

        super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

    def forward(self, inputs):
        """Forward function."""
        return self.fc_blocks(inputs)


class PoseDiscriminator(nn.Module):
    """Discriminator for the SMPL pose parameters of each joint. It is
    composed of one discriminator per joint. The input is
    (batch_size x joint_count x 9).

    Args:
        channels (Tuple): Tuple of channel numbers,
            such as (9, 32, 32, 1)
        joint_count (int): Number of joints, such as 23
    """

    def __init__(self, channels, joint_count):
        super().__init__()
        if channels[-1] != 1:
            msg = f'the neuron count of the last layer ' \
                  f'must be 1, but got {channels[-1]}'
            raise ValueError(msg)
        self.joint_count = joint_count

        self.conv_blocks = nn.Sequential()
        len_channels = len(channels)
        for idx in range(len_channels - 2):
            self.conv_blocks.add_module(
                name=f'conv_{idx}',
                module=nn.Conv2d(
                    in_channels=channels[idx],
                    out_channels=channels[idx + 1],
                    kernel_size=1,
                    stride=1))

        self.fc_layer = nn.ModuleList()
        for idx in range(joint_count):
            self.fc_layer.append(
                nn.Linear(
                    in_features=channels[len_channels - 2], out_features=1))

    def forward(self, inputs):
        """Forward function.

        The input is (batch_size x joint_count x 9).
        """
        # shape: batch_size x 9 x 1 x joint_count
        inputs = inputs.transpose(1, 2).unsqueeze(2).contiguous()
        # shape: batch_size x c x 1 x joint_count
        internal_outputs = self.conv_blocks(inputs)
        outputs = []
        for idx in range(self.joint_count):
            outputs.append(self.fc_layer[idx](internal_outputs[:, :, 0, idx]))

        return torch.cat(outputs, 1), internal_outputs


class FullPoseDiscriminator(BaseDiscriminator):
    """Discriminator for the SMPL pose parameters of all joints.

    Args:
        fc_layers (Tuple): Tuple of neuron counts,
            such as (736, 1024, 1024, 1)
        use_dropout (Tuple): Tuple of bools defining whether to use dropout
            for each layer, such as (True, True, False)
        drop_prob (Tuple): Tuple of floats defining the drop probability,
            such as (0.5, 0.5, 0)
        use_activation (Tuple): Tuple of bools defining whether to use the
            activation function, such as (True, True, False)
    """

    def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
        if fc_layers[-1] != 1:
            msg = f'the neuron count of the last layer must be 1,' \
                  f' but got {fc_layers[-1]}'
            raise ValueError(msg)

        super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

    def forward(self, inputs):
        """Forward function."""
        return self.fc_blocks(inputs)


class SMPLDiscriminator(nn.Module):
    """Discriminator for SMPL pose and shape parameters. It is composed of a
    discriminator for SMPL shape parameters, a discriminator for the SMPL pose
    parameters of all joints, and a discriminator for the SMPL pose parameters
    of each joint.

    Args:
        beta_channel (tuple of int): Tuple of neuron counts of the
            discriminator of shape parameters. Defaults to (10, 5, 1)
        per_joint_channel (tuple of int): Tuple of neuron counts of the
            discriminator of each joint. Defaults to (9, 32, 32, 1)
        full_pose_channel (tuple of int): Tuple of neuron counts of the
            discriminator of the full pose. Defaults to (23*32, 1024, 1024, 1)
    """

    def __init__(self,
                 beta_channel=(10, 5, 1),
                 per_joint_channel=(9, 32, 32, 1),
                 full_pose_channel=(23 * 32, 1024, 1024, 1)):
        super().__init__()
        self.joint_count = 23
        # The count of SMPL shape parameters is 10.
        assert beta_channel[0] == 10
        # Use a 3 x 3 rotation matrix as the pose parameters
        # of each joint, so the input channel is 9.
        assert per_joint_channel[0] == 9
        assert self.joint_count * per_joint_channel[-2] \
            == full_pose_channel[0]

        self.beta_channel = beta_channel
        self.per_joint_channel = per_joint_channel
        self.full_pose_channel = full_pose_channel
        self._create_sub_modules()

    def _create_sub_modules(self):
        """Create sub discriminators."""

        # create theta discriminator for each joint
        self.pose_discriminator = PoseDiscriminator(self.per_joint_channel,
                                                    self.joint_count)

        # create full pose discriminator for all joints
        fc_layers = self.full_pose_channel
        use_dropout = tuple([False] * (len(fc_layers) - 1))
        drop_prob = tuple([0.5] * (len(fc_layers) - 1))
        use_activation = tuple([True] * (len(fc_layers) - 2) + [False])

        self.full_pose_discriminator = FullPoseDiscriminator(
            fc_layers, use_dropout, drop_prob, use_activation)

        # create shape discriminator for betas
        fc_layers = self.beta_channel
        use_dropout = tuple([False] * (len(fc_layers) - 1))
        drop_prob = tuple([0.5] * (len(fc_layers) - 1))
        use_activation = tuple([True] * (len(fc_layers) - 2) + [False])
        self.shape_discriminator = ShapeDiscriminator(fc_layers, use_dropout,
                                                      drop_prob,
                                                      use_activation)

    def forward(self, thetas):
        """Forward function."""
        cams, poses, shapes = thetas

        batch_size = poses.shape[0]
        shape_disc_value = self.shape_discriminator(shapes)

        # The first rotation matrix is the global rotation
        # and is NOT used in the discriminator.
        if poses.dim() == 2:
            rotate_matrixs = \
                batch_rodrigues(poses.contiguous().view(-1, 3)
                                ).view(batch_size, 24, 9)[:, 1:, :]
        else:
            rotate_matrixs = poses.contiguous().view(batch_size, 24,
                                                     9)[:, 1:, :].contiguous()
        pose_disc_value, pose_inter_disc_value \
            = self.pose_discriminator(rotate_matrixs)
        full_pose_disc_value = self.full_pose_discriminator(
            pose_inter_disc_value.contiguous().view(batch_size, -1))
        return torch.cat(
            (pose_disc_value, full_pose_disc_value, shape_disc_value), 1)
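
Not part of the committed diff: a minimal usage sketch for the SMPLDiscriminator defined above, assuming the class is in scope (its final import path is not visible in this view, and the camera tensor shape is a hypothetical placeholder since the discriminator never uses it). It illustrates the (cams, poses, shapes) input triple and the output layout of 23 per-joint scores, one full-pose score and one shape score.

import torch

# SMPLDiscriminator is assumed to be importable from the file above;
# the exact package path is not shown in this diff.
disc = SMPLDiscriminator()                # defaults: (10, 5, 1), (9, 32, 32, 1), (736, 1024, 1024, 1)
batch_size = 4
cams = torch.zeros(batch_size, 3)         # camera parameters; unpacked but not used by the discriminator
poses = torch.zeros(batch_size, 24 * 3)   # axis-angle pose, 24 joints x 3 (zeros = rest pose)
shapes = torch.zeros(batch_size, 10)      # SMPL shape coefficients (betas)
scores = disc((cams, poses, shapes))
# 23 per-joint scores + 1 full-pose score + 1 shape score
assert scores.shape == (batch_size, 25)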
@@ -0,0 +1,67 @@
import torch
from torch.nn import functional as F


def rot6d_to_rotmat(x):
    """Convert 6D rotation representation to 3x3 rotation matrix.

    Based on Zhou et al., "On the Continuity of Rotation
    Representations in Neural Networks", CVPR 2019.

    Input:
        (B, 6) Batch of 6-D rotation representations
    Output:
        (B, 3, 3) Batch of corresponding rotation matrices
    """
    x = x.view(-1, 3, 2)
    a1 = x[:, :, 0]
    a2 = x[:, :, 1]
    b1 = F.normalize(a1)
    b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
    b3 = torch.cross(b1, b2)
    return torch.stack((b1, b2, b3), dim=-1)


def batch_rodrigues(theta):
    """Convert axis-angle representation to rotation matrix.

    Args:
        theta: size = [B, 3]
    Returns:
        Rotation matrix corresponding to the axis-angle input
        -- size = [B, 3, 3]
    """
    l2norm = torch.norm(theta + 1e-8, p=2, dim=1)
    angle = torch.unsqueeze(l2norm, -1)
    normalized = torch.div(theta, angle)
    angle = angle * 0.5
    v_cos = torch.cos(angle)
    v_sin = torch.sin(angle)
    quat = torch.cat([v_cos, v_sin * normalized], dim=1)
    return quat_to_rotmat(quat)


def quat_to_rotmat(quat):
    """Convert quaternion coefficients to rotation matrix.

    Args:
        quat: size = [B, 4]  4 <===> (w, x, y, z)
    Returns:
        Rotation matrix corresponding to the quaternion
        -- size = [B, 3, 3]
    """
    norm_quat = quat
    norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
    w, x, y, z = norm_quat[:, 0], norm_quat[:, 1],\
        norm_quat[:, 2], norm_quat[:, 3]

    B = quat.size(0)

    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z

    rotMat = torch.stack([
        w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
        w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
        w2 - x2 - y2 + z2
    ],
                         dim=1).view(B, 3, 3)
    return rotMat
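
Also not part of the diff: a quick sanity-check sketch for the conversion helpers above, assuming batch_rodrigues and rot6d_to_rotmat from this file are in scope. A valid rotation matrix satisfies R^T R = I and det(R) = 1, which both converters should produce.

import math

import torch

theta = torch.tensor([[0.0, 0.0, math.pi / 2]])   # 90-degree rotation about the z-axis
R = batch_rodrigues(theta)                        # (1, 3, 3)
identity = torch.matmul(R.transpose(1, 2), R)     # R^T R should be the identity
assert torch.allclose(identity, torch.eye(3).expand(1, 3, 3), atol=1e-5)

x6d = torch.randn(2, 6)                           # two random 6-D rotation representations
R6 = rot6d_to_rotmat(x6d)                         # (2, 3, 3), orthonormal by construction
assert torch.allclose(torch.det(R6), torch.ones(2), atol=1e-4)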