Skip to content

Commit

Permalink
Merge pull request #1647 from tillns/weighted_procrustes
Browse files Browse the repository at this point in the history
Weighted Procrustes Analysis
  • Loading branch information
mikedh authored Aug 4, 2022
2 parents 563ed96 + 8793323 commit 5e098c6
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions trimesh/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,21 +180,28 @@ def key_points(m, count):

def procrustes(a,
b,
weights=None,
reflection=True,
translation=True,
scale=True,
return_cost=True):
"""
Perform Procrustes' analysis subject to constraints. Finds the
transformation T mapping a to b which minimizes the square sum
distances between Ta and b, also called the cost.
distances between Ta and b, also called the cost. Optionally
specify different weights for the points in a to minimize the
weighted square sum distances between Ta and b. This can
improve transformation robustness on noisy data if the points'
probability distribution is known.
Parameters
----------
a : (n,3) float
List of points in space
b : (n,3) float
List of points in space
weights : (n,) float
List of floats representing how much weight is assigned to each point of a
reflection : bool
If the transformation is allowed reflections
translation : bool
Expand All @@ -216,32 +223,42 @@ def procrustes(a,

a = np.asanyarray(a, dtype=np.float64)
b = np.asanyarray(b, dtype=np.float64)
weights = np.ones(len(a)) if weights is None else weights
w = np.asanyarray(weights, dtype=np.float64)
w_normed = w / np.sum(w)
w_mat = np.diag(w)
if not util.is_shape(a, (-1, 3)) or not util.is_shape(b, (-1, 3)):
raise ValueError('points must be (n,3)!')

if len(a) != len(b):
raise ValueError('a and b must contain same number of points!')

if len(w) != len(a):
raise ValueError("weights must have same length as a and b!")

# Remove translation component
if translation:
acenter = a.mean(axis=0)
# acenter is a weighted average of the individual points.
acenter = np.sum(np.expand_dims(w_normed, axis=1) * a, axis=0)
bcenter = b.mean(axis=0)
else:
acenter = np.zeros(a.shape[1])
bcenter = np.zeros(b.shape[1])

# Remove scale component
if scale:
ascale = np.sqrt(((a - acenter)**2).sum() / len(a))
# ascale is the square root of weighted average of the squared difference between each point and acenter.
ascale = np.sqrt(np.sum(((a - acenter)**2) * np.expand_dims(w_normed, axis=1)))
bscale = np.sqrt(((b - bcenter)**2).sum() / len(b))
else:
ascale = 1
bscale = 1

# Use SVD to find optimal orthogonal matrix R
# constrained to det(R) = 1 if necessary.
# w_mat is multiplied with the centered and scaled a, such that the points can be weighted differently.
u, s, vh = np.linalg.svd(
np.dot(((b - bcenter) / bscale).T, ((a - acenter) / ascale)))
np.dot(((b - bcenter) / bscale).T, (w_mat.dot((a - acenter) / ascale))))

if reflection:
R = np.dot(u, vh)
Expand Down

0 comments on commit 5e098c6

Please sign in to comment.