A simplified version of the geom-median Python library, updated for higher performance on PyTorch and with full type hints. Thanks to @themachinefan!
pip install torch-geometric-median
This library exports a single function, geometric_median, which takes a tensor of shape (N, D), where N is the number of samples and D is the size of each sample, and returns the geometric median of the points in the tensor.
import torch
from torch_geometric_median import geometric_median
# Create a tensor of points
points = torch.tensor([
    [0.0, 0.0],
    [1.0, 1.0],
    [2.0, 2.0],
    [3.0, 3.0],
    [4.0, 4.0],
])
# Compute the geometric median
median = geometric_median(points).median
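For intuition, the geometric median is the point that minimizes the sum of Euclidean distances to the samples. The sketch below uses only the standard library, and `sum_of_distances` is a hypothetical helper rather than part of this library. It checks that, for the five collinear points above, the middle point [2.0, 2.0] has a lower objective than a few nearby candidates.

```python
import math


def sum_of_distances(candidate, points):
    """Sum of Euclidean distances from `candidate` to each point (hypothetical helper)."""
    return sum(math.dist(candidate, p) for p in points)


points = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]

# For collinear points the geometric median coincides with the 1-D median,
# so the middle point should give the smallest objective.
best = sum_of_distances([2.0, 2.0], points)
nearby = [sum_of_distances(c, points) for c in ([1.5, 1.5], [2.5, 2.5], [2.0, 2.5])]
```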
Like the original geom-median library, this library supports backpropagation through the geometric median computation.
points.requires_grad_(True)  # track gradients with respect to the input points
median = geometric_median(points).median
torch.linalg.norm(median).backward()
# The gradient of the median with respect to the input points is now in `points.grad`
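To illustrate what differentiating through a geometric median computation can look like, here is a minimal sketch that unrolls a fixed number of Weiszfeld-style updates in plain PyTorch. It is an assumption-laden illustration, not this library's implementation: `unrolled_median`, `steps`, and `eps` are hypothetical names, and the library may compute its gradients differently.

```python
import torch


def unrolled_median(points: torch.Tensor, steps: int = 50, eps: float = 1e-8) -> torch.Tensor:
    """Differentiable, unrolled Weiszfeld-style steps (illustrative only)."""
    median = points.mean(dim=0)  # start from the centroid
    for _ in range(steps):
        # Distance from the current estimate to every point, clamped to avoid division by zero.
        dists = torch.linalg.norm(points - median, dim=1).clamp_min(eps)
        coef = 1.0 / dists
        # New estimate: average of the points weighted by inverse distance.
        median = (coef[:, None] * points).sum(dim=0) / coef.sum()
    return median


points = torch.tensor(
    [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [3.0, 3.0]],
    requires_grad=True,
)
median = unrolled_median(points)
torch.linalg.norm(median).backward()
grad = points.grad  # same shape as `points`
```

Because every update is an ordinary tensor operation, autograd records the whole loop and `points.grad` ends up with one gradient row per input point.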
The geometric_median function also supports a few extra options:
- maxiter: The maximum number of iterations to run the optimization for. Default is 100.
- ftol: If objective value does not improve by at least this ftol fraction, terminate the algorithm. Default 1e-20.
- weights: A tensor of shape (N,) containing the weights for each point, where N is the number of samples. Default is None, which means all points are weighted equally.
- show_progress: If True, show a progress bar for the optimization. Default is False.
- log_objective_values: If True, log the objective value at each iteration under the key objective_values_log. Default is False.
median = geometric_median(
    points,
    maxiter=1000,
    ftol=1e-10,
    weights=torch.tensor([1.0, 2.0, 1.0, 1.0, 1.0]),
    show_progress=True,
    log_objective_values=True,
).median
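To make the roles of maxiter, ftol, and weights more concrete, the following is a dependency-free sketch of a weighted Weiszfeld iteration, a standard method for computing geometric medians. It is a sketch under stated assumptions, not this library's code: the function name `weiszfeld`, the `eps` clamp, and the exact stopping test are hypothetical choices made for illustration.

```python
import math


def weiszfeld(points, weights=None, maxiter=100, ftol=1e-20, eps=1e-12):
    """Illustrative weighted Weiszfeld iteration; NOT this library's implementation."""
    n, d = len(points), len(points[0])
    weights = [1.0] * n if weights is None else list(weights)

    def objective(m):
        # Weighted sum of Euclidean distances from m to every point.
        return sum(w * math.dist(m, p) for w, p in zip(weights, points))

    # Start from the weighted mean of the points.
    total = sum(weights)
    median = [sum(w * p[j] for w, p in zip(weights, points)) / total for j in range(d)]
    prev = objective(median)
    for _ in range(maxiter):
        # Re-weight each point by w_i / distance, clamping tiny distances with eps.
        coef = [w / max(math.dist(median, p), eps) for w, p in zip(weights, points)]
        norm = sum(coef)
        median = [sum(c * p[j] for c, p in zip(coef, points)) / norm for j in range(d)]
        cur = objective(median)
        # Stop once the relative improvement is no larger than ftol.
        if abs(prev - cur) <= ftol * prev:
            break
        prev = cur
    return median


points = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
uniform = weiszfeld(points)
weighted = weiszfeld(points, weights=[1.0, 1.0, 1.0, 1.0, 10.0])
```

With equal weights the result stays at the middle point, while giving the last point more than half of the total weight pulls the result onto that point. A larger ftol stops the loop earlier at the cost of precision, and maxiter caps the work when the objective keeps improving slowly.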
It appears that the original geom-median library is no longer maintained, and as pointed out by @themachinefan, the original library is not very performant on PyTorch. This library is a repackaging of @themachinefan's improvements to the original geom-median library, simplifying the code to support only PyTorch, improving PyTorch performance, and adding full type hints.
This library is a repackaging of the work done by the authors of the original geom-median library and by @themachinefan in their PR, and as such, all credit goes to these incredible authors. If you use this library, you should cite the original geom-median paper.
This library is licensed under a GPL license, as per the original geom-median library.
Contributions are welcome! Please open an issue or a PR if you have any suggestions or improvements. This library uses PDM for dependency management, Ruff for linting, Pyright for type-checking, and Pytest for tests.
To contribute to the repo, first install dependencies with pdm install. Tests are run with pdm run pytest. Formatting is done with pdm run ruff format and linting with pdm run ruff check. Type-checking is done with pdm run pyright. Please ensure that all tests pass, and that the code is formatted, linted, and type-checked before opening a PR.