diff --git a/pytorch3d/ops/perspective_n_points.py b/pytorch3d/ops/perspective_n_points.py
index 5452f58d8..e468bd3c2 100644
--- a/pytorch3d/ops/perspective_n_points.py
+++ b/pytorch3d/ops/perspective_n_points.py
@@ -66,6 +66,10 @@ def _build_M(y, alphas, weight):
     def prepad(t, v):
         return F.pad(t, (1, 0), value=v)
 
+    if weight is not None:
+        # weight the alphas in order to get a correctly weighted version of M
+        alphas = alphas * weight[:, :, None]
+
     # outer left-multiply by alphas
     def lm_alphas(t):
         return torch.matmul(alphas[..., None], t).reshape(bs, n, 12)
@@ -82,9 +86,6 @@ def lm_alphas(t):
         dim=-1,
     ).reshape(bs, -1, 12)
 
-    if weight is not None:
-        M = M * weight.repeat(1, 2)[:, :, None]
-
     return M
 
 
diff --git a/tests/test_perspective_n_points.py b/tests/test_perspective_n_points.py
index c46dbf759..55a8d83ee 100644
--- a/tests/test_perspective_n_points.py
+++ b/tests/test_perspective_n_points.py
@@ -24,6 +24,21 @@ def setUp(self) -> None:
         super().setUp()
         torch.manual_seed(42)
 
+    @classmethod
+    def _generate_epnp_test_from_2d(cls, y):
+        """
+        Instantiate random x_world, x_cam, R, T given a set of input
+        2D projections y.
+        """
+        batch_size = y.shape[0]
+        x_cam = torch.cat((y, torch.rand_like(y[:, :, :1]) * 2.0 + 3.5), dim=2)
+        x_cam[:, :, :2] *= x_cam[:, :, 2:]  # unproject
+        R = rotation_conversions.random_rotations(batch_size).to(y)
+        T = torch.randn_like(R[:, :1, :])
+        T[:, :, 2] = (T[:, :, 2] + 3.0).clamp(2.0)
+        x_world = torch.matmul(x_cam - T, R.transpose(1, 2))
+        return x_cam, x_world, R, T
+
     def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=False):
         sol = perspective_n_points.efficient_pnp(
             x_world, y.expand_as(x_world[:, :, :2]), skip_quadratic_eq=skip_q
@@ -45,16 +60,16 @@ def _run_and_print(self, x_world, y, R, T, print_stats, skip_q, check_output=Fal
             )
 
             self.assertClose(err_2d, sol.err_2d, msg=assert_msg)
-            self.assertTrue((err_2d < 1e-4).all(), msg=assert_msg)
+            self.assertTrue((err_2d < 5e-4).all(), msg=assert_msg)
 
             def norm_fn(t):
                 return t.norm(dim=-1)
 
             self.assertNormsClose(
-                T, sol.T[:, None, :], rtol=3e-3, norm_fn=norm_fn, msg=assert_msg
+                T, sol.T[:, None, :], rtol=4e-3, norm_fn=norm_fn, msg=assert_msg
             )
             self.assertNormsClose(
-                R_quat, R_est_quat, rtol=3e-4, norm_fn=norm_fn, msg=assert_msg
+                R_quat, R_est_quat, rtol=3e-3, norm_fn=norm_fn, msg=assert_msg
             )
 
         if print_stats:
@@ -71,12 +86,9 @@ def norm_fn(t):
                 print("T_hat | T_gt\n", T_gt)
 
     def _testcase_from_2d(self, y, print_stats, benchmark, skip_q=False):
-        x_cam = torch.cat((y, torch.rand_like(y[:, :1]) * 2.0 + 3.5), dim=1)
-        x_cam[:, :2] *= x_cam[:, 2:]  # unproject
-
-        R = rotation_conversions.random_rotations(16).to(y)
-        T = torch.randn_like(R[:, :1, :])
-        x_world = torch.matmul(x_cam - T, R.transpose(1, 2))
+        x_cam, x_world, R, T = TestPerspectiveNPoints._generate_epnp_test_from_2d(
+            y[None].repeat(16, 1, 1)
+        )
 
         if print_stats:
             print("Run without noise")
@@ -129,3 +141,45 @@ def test_perspective_n_points(self, print_stats=False):
                     benchmark=False,
                     skip_q=skip_q,
                 )
+
+    def test_weighted_perspective_n_points(self, batch_size=16, num_pts=200):
+        # instantiate random x_world and y
+        y = torch.randn((batch_size, num_pts, 2)).cuda() / 3.0
+        x_cam, x_world, R, T = TestPerspectiveNPoints._generate_epnp_test_from_2d(y)
+
+        # randomly drop 50% of the rows
+        weights = (torch.rand_like(x_world[:, :, 0]) > 0.5).float()
+
+        # make sure we retain at least 6 points for each case
+        weights[:, :6] = 1.0
+
+        # fill ignored y with trash to ensure that we get different
+        # solution in case the weighting is wrong
+        y = y + (1 - weights[:, :, None]) * 100.0
+
+        def norm_fn(t):
+            return t.norm(dim=-1)
+
+        for skip_quadratic_eq in (True, False):
+            # get the solution for the 0/1 weighted case
+            sol = perspective_n_points.efficient_pnp(
+                x_world, y, skip_quadratic_eq=skip_quadratic_eq, weights=weights
+            )
+            sol_R_quat = rotation_conversions.matrix_to_quaternion(sol.R)
+            sol_T = sol.T
+
+            # check that running only on points with non-zero weights ends in the
+            # same place as running the 0/1 weighted version (same solver settings)
+            for i in range(batch_size):
+                ok = weights[i] > 0
+                x_world_ok = x_world[i, ok][None]
+                y_ok = y[i, ok][None]
+                sol_ok = perspective_n_points.efficient_pnp(
+                    x_world_ok, y_ok, skip_quadratic_eq=skip_quadratic_eq
+                )
+                R_est_quat_ok = rotation_conversions.matrix_to_quaternion(sol_ok.R)
+
+                self.assertNormsClose(sol_T[i], sol_ok.T[0], rtol=3e-3, norm_fn=norm_fn)
+                self.assertNormsClose(
+                    sol_R_quat[i], R_est_quat_ok[0], rtol=3e-4, norm_fn=norm_fn
+                )