diff --git a/.github/workflows/main b/.github/workflows/main new file mode 100644 index 00000000..e69de29b diff --git a/Project.toml b/Project.toml index 3c57e214..0ccb544c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.20" +version = "0.3.21" [deps] ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" diff --git a/README.md b/README.md index c66d42f7..48b58ddc 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ [![Coveralls](https://coveralls.io/repos/github/JuliaOptimalTransport/OptimalTransport.jl/badge.svg?branch=master)](https://coveralls.io/github/JuliaOptimalTransport/OptimalTransport.jl?branch=master) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) -This package provides some [Julia](https://julialang.org/) implementations of algorithms for computational [optimal transport](https://optimaltransport.github.io/), including the Earth-Mover's (Wasserstein) distance, Sinkhorn algorithm for entropically regularized optimal transport as well as some variants or extensions. +This package provides some [Julia](https://julialang.org/) implementations of algorithms for computational [optimal transport](https://optimaltransport.github.io/), including the Earth-Mover's (Wasserstein) distance, the Sinkhorn algorithm for entropically regularized optimal transport, as well as variants and extensions such as unbalanced transport and Gromov-Wasserstein matching. Notably, OptimalTransport.jl provides GPU acceleration through [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl/) and [NNlibCUDA.jl](https://github.com/FluxML/NNlibCUDA.jl). 
diff --git a/docs/src/index.md b/docs/src/index.md index 2e80ae22..4813456f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -68,6 +68,15 @@ Currently the following algorithms for solving quadratically regularised optimal QuadraticOTNewton ``` +## Gromov-Wasserstein optimal transport + +```@docs +entropic_gromov_wasserstein +``` + +Currently, only entropy-regularised Gromov-Wasserstein is supported. For exact computations, we refer the user to +[PythonOT](https://github.com/JuliaOptimalTransport/PythonOT.jl) to access functionality from the [Python Optimal Transport library](https://pythonot.github.io/). + ## Dual ```@docs diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index bbf0a29a..e29f51de 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -17,12 +17,14 @@ using NNlib: NNlib export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling export SinkhornBarycenterGibbs export QuadraticOTNewton +export EntropicGromovWassersteinSinkhorn export sinkhorn, sinkhorn2 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter export sinkhorn_unbalanced, sinkhorn_unbalanced2 export sinkhorn_divergence, sinkhorn_divergence_unbalanced export quadreg +export entropic_gromov_wasserstein include("utils.jl") @@ -42,4 +44,6 @@ include("quadratic_newton.jl") include("dual/entropic_dual.jl") +include("gromov.jl") + end diff --git a/src/gromov.jl b/src/gromov.jl new file mode 100644 index 00000000..9c69a116 --- /dev/null +++ b/src/gromov.jl @@ -0,0 +1,93 @@ +# Gromov-Wasserstein solver + +abstract type EntropicGromovWasserstein end + +struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein + alg_step::Sinkhorn +end + +""" + entropic_gromov_wasserstein( + μ, ν, Cμ, Cν, ε, alg=EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); + atol = nothing, rtol = nothing, check_convergence = 10, maxiter = 1_000, kwargs... 
+ ) + +Computes the transport plan for the entropically regularized Gromov-Wasserstein optimal transport problem with source and target +marginals `μ` and `ν` and corresponding cost matrices `Cμ` and `Cν`. That is, we seek a local minimizer `γ` of +```math + \\inf_{\\gamma \\in \\Pi(\\mu, \\nu)} \\sum_{i, j, i', j'} |C^{(\\mu)}_{i,i'} - C^{(\\nu)}_{j,j'}|^2 \\gamma_{i,j} \\gamma_{i',j'} + \\varepsilon \\Omega(\\gamma), +``` +where ``\\Omega(\\gamma)`` is the entropic regularization term, see e.g. [`sinkhorn`](@ref). + +This function employs the iterative method described in (Section 10.6.4, [^PC19]), which solves a series of Sinkhorn iteration sub-problems to arrive at a solution. Note that the Gromov-Wasserstein problem is non-convex owing to the cross-terms in the +objective function, and thus in general one is only guaranteed to arrive at a local optimum. + +Every `check_convergence` steps, the current iterate of `γ` is compared with `γ_prev` (the iterate from `check_convergence` steps earlier). +The quantity ``\\| \\gamma - \\gamma_\\text{prev} \\|_1`` is compared against `atol` and `rtol`. + +[^PC19]: Peyré, G. and Cuturi, M., 2019. Computational optimal transport: With applications to data science. Foundations and Trends® in Machine Learning, 11(5-6), pp.355-607. + +See also: [`sinkhorn`](@ref) +""" +function entropic_gromov_wasserstein( + μ::AbstractVector, + ν::AbstractVector, + Cμ::AbstractMatrix, + Cν::AbstractMatrix, + ε::Real, + alg::EntropicGromovWasserstein=EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); + atol=nothing, + rtol=nothing, + check_convergence=10, + maxiter::Int=1_000, + kwargs..., +) + T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν))) + C = similar(Cμ, T, size(μ, 1), size(ν, 1)) + tmp = similar(C) + plan = similar(C) + @. plan = μ * ν' + plan_prev = similar(C) + plan_prev .= plan + norm_plan = sum(plan) + + _atol = atol === nothing ? 0 : atol + _rtol = rtol === nothing ? (_atol > zero(_atol) ? 
zero(T) : sqrt(eps(T))) : rtol + + function get_new_cost!(C, plan, tmp, Cμ, Cν) + A_batched_mul_B!(tmp, Cμ, plan) + lmul!(-4, tmp) + return A_batched_mul_B!(C, tmp, Cν) + # seems to be a missing factor of 4 (or something like that...) compared to the POT implementation? + # added the factor of 4 here to ensure reproducibility for the same value of ε. + # https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247 + end + + get_new_cost!(C, plan, tmp, Cμ, Cν) + to_check_step = check_convergence + + isconverged = false + for iter in 1:maxiter + # perform Sinkhorn algorithm + solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...) + solve!(solver) + # compute optimal transport plan + plan = sinkhorn_plan(solver) + + to_check_step -= 1 + if to_check_step == 0 || iter == maxiter + # reset counter + to_check_step = check_convergence + plan_prev .-= plan + isconverged = sum(abs, plan_prev) < max(_atol, _rtol * norm_plan) + if isconverged + @debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged" + break + end + plan_prev .= plan + end + get_new_cost!(C, plan, tmp, Cμ, Cν) + end + + return plan +end diff --git a/test/gromov.jl b/test/gromov.jl new file mode 100644 index 00000000..c65ece78 --- /dev/null +++ b/test/gromov.jl @@ -0,0 +1,31 @@ +using OptimalTransport + +using Distances +using PythonOT: PythonOT + +using Random +using Test +using LinearAlgebra + +const POT = PythonOT + +Random.seed!(100) + +@testset "gromov.jl" begin + @testset "entropic_gromov_wasserstein" begin + M, N = 250, 200 + + μ = fill(1 / M, M) + μ_spt = rand(M) + ν = fill(1 / N, N) + ν_spt = rand(N) + + Cμ = pairwise(SqEuclidean(), μ_spt) + Cν = pairwise(SqEuclidean(), ν_spt) + + γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence=10) + γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01) + + @test γ ≈ γ_pot rtol = 1e-6 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 66314dfa..ffce6ef0 100644 --- 
a/test/runtests.jl +++ b/test/runtests.jl @@ -36,6 +36,10 @@ const GROUP = get(ENV, "GROUP", "All") @safetestset "Quadratically regularized OT" begin include("quadratic.jl") end + + @safetestset "Gromov-Wasserstein OT" begin + include("gromov.jl") + end end # CUDA requires Julia >= 1.6