Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entropy-regularised Gromov-Wasserstein #165

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2ef3e2b
first attempt at gromov-wasserstein
zsteve Oct 2, 2021
11efd8c
update
zsteve Mar 8, 2022
3273976
Merge branch 'master' into gromov
zsteve Mar 8, 2022
0956c3b
fixed computation of entropic gromov-wasserstein
zsteve Mar 8, 2022
c22d7e7
fixed computation of entropic gromov-wasserstein
zsteve Mar 8, 2022
ff1a92c
Merge branch 'gromov' of https://github.com/JuliaOptimalTransport/Opt…
zsteve Mar 8, 2022
267dfad
exports and tests
zsteve Mar 8, 2022
21609b0
formatting
zsteve Mar 12, 2022
9699e04
Update test/gpu/simple_gpu.jl
zsteve Mar 12, 2022
8510397
update docstrings
zsteve Mar 12, 2022
2f2428f
Merge branch 'gromov' of https://github.com/JuliaOptimalTransport/Opt…
zsteve Mar 12, 2022
20d5885
delete cache file
zsteve Mar 12, 2022
df41c28
add docs and format
zsteve Mar 12, 2022
a7c1a38
remove unnecessary Logging import
zsteve Mar 12, 2022
19e4cab
fix missing power of 2
zsteve Mar 13, 2022
56c4f9b
pull changes from master
zsteve Aug 28, 2022
6e3ac4c
update version number
zsteve Aug 28, 2022
5c376ae
add docs workflow
zsteve Dec 20, 2022
af2a493
add Gromov-Wasserstein to readme
zsteve Jan 25, 2023
6bc3127
bump Julia ver for CI
zsteve Jan 25, 2023
a806f0f
minor edit to runtests
zsteve Jan 25, 2023
f704397
Update .github/workflows/CI.yml
zsteve Jan 27, 2023
71351b9
Update test/runtests.jl
zsteve Jan 27, 2023
f2acc56
delete junk files/dirs
zsteve Jan 27, 2023
0635305
revert runtests.jl
zsteve Jan 27, 2023
c3efe5a
avoid unnecessary allocations
zsteve Jan 27, 2023
39f0b36
format
zsteve Jan 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimalTransport"
uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
authors = ["zsteve <[email protected]>"]
version = "0.3.19"
version = "0.3.20"

[deps]
ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31"
Expand Down
57 changes: 57 additions & 0 deletions src/#gromov.jl#
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Gromov-Wasserstein solver

abstract type EntropicGromovWasserstein end

struct EntropicGromovWassersteinGibbs <: EntropicGromovWasserstein
alg_step::Sinkhorn
end

function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real,
alg::EntropicGromovWasserstein = EntropicGromovWassersteinGibbs(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...)
T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν)))
C = similar(Cμ, T, size(μ, 1), size(ν, 1))
tmp = similar(C)
plan = similar(C)
@. plan = μ * ν'
plan_prev = similar(C)
plan_prev .= plan
norm_plan = sum(plan)

_atol = atol === nothing ? 0 : atol
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol

function get_new_cost!(C, plan, tmp, Cμ, Cν)
A_batched_mul_B!(tmp, Cμ, plan)
A_batched_mul_B!(C, tmp, -4Cν)
# seems to be a missing factor of 4 (or something like that...) compared to the POT implementation?
# added the factor of 4 here to ensure reproducibility for the same value of ε.
# https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247
end

get_new_cost!(C, plan, tmp, Cμ, Cν)
to_check_step = check_convergence

isconverged = false
for iter in 1:maxiter
# perform Sinkhorn algorithm
solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...)
solve!(solver)
# compute optimal transport plan
plan = sinkhorn_plan(solver)

to_check_step -= 1
if to_check_step == 0 || iter == maxiter
# reset counter
to_check_step = check_convergence
isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan)
if isconverged
@debug "$Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged"
break
end
plan_prev .= plan
end
get_new_cost!(C, plan, tmp, Cμ, Cν)
end

return plan
end
5 changes: 5 additions & 0 deletions src/OptimalTransport.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ using LinearAlgebra
using IterativeSolvers
using LogExpFunctions: LogExpFunctions
using NNlib: NNlib
using Logging

export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
export SinkhornBarycenterGibbs
export QuadraticOTNewton
export EntropicGromovWassersteinSinkhorn

export sinkhorn, sinkhorn2
export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
export sinkhorn_unbalanced, sinkhorn_unbalanced2
export sinkhorn_divergence
export quadreg
export entropic_gromov_wasserstein

include("utils.jl")

Expand All @@ -42,4 +45,6 @@ include("quadratic_newton.jl")

include("dual/entropic_dual.jl")

include("gromov.jl")

end
57 changes: 57 additions & 0 deletions src/gromov.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Gromov-Wasserstein solver

abstract type EntropicGromovWasserstein end

struct EntropicGromovWassersteinSinkhorn <: EntropicGromovWasserstein
zsteve marked this conversation as resolved.
Show resolved Hide resolved
alg_step::Sinkhorn
end

function entropic_gromov_wasserstein(μ::AbstractVector, ν::AbstractVector, Cμ::AbstractMatrix, Cν::AbstractMatrix, ε::Real,
alg::EntropicGromovWasserstein = EntropicGromovWassersteinSinkhorn(SinkhornGibbs()); atol = nothing, rtol = nothing, check_convergence = 10, maxiter::Int=1_000, kwargs...)
zsteve marked this conversation as resolved.
Show resolved Hide resolved
T = float(Base.promote_eltype(μ, one(eltype(Cμ)) / ε, eltype(Cν)))
C = similar(Cμ, T, size(μ, 1), size(ν, 1))
tmp = similar(C)
plan = similar(C)
@. plan = μ * ν'
plan_prev = similar(C)
plan_prev .= plan
norm_plan = sum(plan)

_atol = atol === nothing ? 0 : atol
_rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol

function get_new_cost!(C, plan, tmp, Cμ, Cν)
A_batched_mul_B!(tmp, Cμ, plan)
A_batched_mul_B!(C, tmp, -4Cν)
zsteve marked this conversation as resolved.
Show resolved Hide resolved
# seems to be a missing factor of 4 (or something like that...) compared to the POT implementation?
# added the factor of 4 here to ensure reproducibility for the same value of ε.
# https://github.com/PythonOT/POT/blob/9412f0ad1c0003e659b7d779bf8b6728e0e5e60f/ot/gromov.py#L247
end

get_new_cost!(C, plan, tmp, Cμ, Cν)
to_check_step = check_convergence

isconverged = false
for iter in 1:maxiter
# perform Sinkhorn algorithm
solver = build_solver(μ, ν, C, ε, alg.alg_step; kwargs...)
solve!(solver)
# compute optimal transport plan
plan = sinkhorn_plan(solver)

to_check_step -= 1
if to_check_step == 0 || iter == maxiter
# reset counter
to_check_step = check_convergence
isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

norm_plan is never updated it seems but always set to sum(plan) of the initial randomly initialized plan?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also avoid allocations here by writing:

Suggested change
isconverged = sum(abs, plan - plan_prev) < max(_atol, _rtol * norm_plan)
plan_prev .-= plan # used as a temporary array here to reduce allocations
isconverged = sum(abs, plan_prev) < max(_atol, _rtol * norm_plan)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

norm_plan is never updated it seems but always set to sum(plan) of the initial randomly initialized plan?

The initial plan is taken to be the independent coupling and here we only consider the balanced problem, so norm_plan should not change. I agree however this is a special case of the unbalanced problem where it would not be constant.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe also avoid allocations here by writing:

Good catch, done

if isconverged
@debug "Gromov Wasserstein with $(solver.alg) ($iter/$maxiter): converged"
break
end
plan_prev .= plan
end
get_new_cost!(C, plan, tmp, Cμ, Cν)
end

return plan
end
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[deps]
zsteve marked this conversation as resolved.
Show resolved Hide resolved
OptimalTransport = "7e02d93a-ae51-4f58-b602-d97af76e3b33"
31 changes: 31 additions & 0 deletions test/gromov.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using OptimalTransport

using Distances
using PythonOT: PythonOT

using Random
using Test
using LinearAlgebra

const POT = PythonOT

Random.seed!(100)

@testset "gromov.jl" begin
@testset "entropic_gromov_wasserstein" begin
M, N = 250, 200

μ = fill(1/M, M)
zsteve marked this conversation as resolved.
Show resolved Hide resolved
μ_spt = rand(M)
ν = fill(1/N, N)
zsteve marked this conversation as resolved.
Show resolved Hide resolved
ν_spt = rand(N)

Cμ = pairwise(SqEuclidean(), μ_spt)
Cν = pairwise(SqEuclidean(), ν_spt)

γ = entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01; check_convergence = 10)
zsteve marked this conversation as resolved.
Show resolved Hide resolved
γ_pot = PythonOT.entropic_gromov_wasserstein(μ, ν, Cμ, Cν, 0.01)

@test γ ≈ γ_pot rtol = 1e-6
end
end
Empty file added test/gromov.jl~
Empty file.
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ const GROUP = get(ENV, "GROUP", "All")
@safetestset "Quadratically regularized OT" begin
include("quadratic.jl")
end

@safetestset "Gromov-Wasserstein OT" begin
include("gromov.jl")
end
end

# CUDA requires Julia >= 1.6
Expand Down