diff --git a/src/Metric.jl b/src/Metric.jl index 97dfbae42a..bbf0abebfe 100644 --- a/src/Metric.jl +++ b/src/Metric.jl @@ -144,6 +144,21 @@ end return inner(M.manifold, x, v, w) end +@traitfn function inner(B::VectorBundleFibers{<:CotangentSpaceType, MMT}, x, v, w) where {MT<:Manifold, + GT<:Metric, + MMT<:MetricManifold{MT,GT}; + !HasMetric{MT,GT}} + ginv = inverse_local_metric(B.M, x) + return dot(ginv * v, ginv * w) +end + +@traitfn function inner(B::VectorBundleFibers{<:CotangentSpaceType, MMT}, x, v, w) where {MT<:Manifold, + GT<:Metric, + MMT<:MetricManifold{MT,GT}; + HasMetric{MT,GT}} + return inner(VectorBundleFibers(B.VS, B.M.manifold), x, v, w) +end + @traitfn function norm(M::MMT, x, v) where {MT<:Manifold, GT<:Metric, MMT<:MetricManifold{MT,GT}; @@ -457,3 +472,47 @@ end HasMetric{MT,GT}} return is_tangent_vector(M.manifold, x, v; kwargs...) end + +@traitfn function flat!(M::MMT, + v::FVector{CotangentSpaceType}, + x, + w::FVector{TangentSpaceType}) where {MT<:Manifold, + GT<:Metric, + MMT<:MetricManifold{MT,GT}; + !HasMetric{MT,GT}} + g = local_metric(M, x) + copyto!(v, g*w) + return v +end + +@traitfn function flat!(M::MMT, + v::FVector{CotangentSpaceType}, + x, + w::FVector{TangentSpaceType}) where {MT<:Manifold, + GT<:Metric, + MMT<:MetricManifold{MT,GT}; + HasMetric{MT,GT}} + return flat!(M.manifold, v, x, w) +end + +@traitfn function sharp!(M::MMT, + v::FVector{TangentSpaceType}, + x, + w::FVector{CotangentSpaceType}) where {MT<:Manifold, + GT<:Metric, + MMT<:MetricManifold{MT,GT}; + !HasMetric{MT,GT}} + ginv = inverse_local_metric(M, x) + copyto!(v, ginv*w) + return v +end + +@traitfn function sharp!(M::MMT, + v::FVector{TangentSpaceType}, + x, + w::FVector{CotangentSpaceType}) where {MT<:Manifold, + GT<:Metric, + MMT<:MetricManifold{MT,GT}; + HasMetric{MT,GT}} + return flat!(M.manifold, v, x, w) +end diff --git a/src/VectorBundle.jl b/src/VectorBundle.jl index fe705d83d0..c5f12de9cf 100644 --- a/src/VectorBundle.jl +++ b/src/VectorBundle.jl @@ -219,7 +219,10 @@ function inner(B::VectorBundleFibers{<:TangentSpaceType}, x, v, w) end function inner(B::VectorBundleFibers{<:CotangentSpaceType}, x, v, w) - return inner(B.M, x, flat(B, x, v), flat(B, x, w)) + return inner(B.M, + x, + sharp(B.M, x, FVector(CotangentSpace, v)).data, + sharp(B.M, x, FVector(CotangentSpace, w)).data) end norm(B::VectorBundleFibers, x, v) = sqrt(inner(B, x, v, v)) diff --git a/test/metric_test.jl b/test/metric_test.jl index ad72cce515..eeedbfbbaf 100644 --- a/test/metric_test.jl +++ b/test/metric_test.jl @@ -1,3 +1,4 @@ +include("utils.jl") struct TestEuclidean{N} <: Manifold end struct TestEuclideanMetric <: Metric end @@ -153,6 +154,14 @@ struct BaseManifoldMetric{M} <: Metric end Manifolds.exp!(::BaseManifold, y, x, v) = y .= x + 2 * v Manifolds.log!(::BaseManifold, v, x, y) = v .= (y - x) / 2 Manifolds.project_tangent!(::BaseManifold, w, x, v) = w .= 2 .* v + function Manifolds.flat!(::BaseManifold, v::FVector{Manifolds.CotangentSpaceType}, x, w::FVector{Manifolds.TangentSpaceType}) + v.data .= 2 .* w.data + return v + end + function Manifolds.sharp!(::BaseManifold, v::FVector{Manifolds.TangentSpaceType}, x, w::FVector{Manifolds.CotangentSpaceType}) + v.data .= w.data ./ 2 + return v + end M = BaseManifold{3}() g = BaseManifoldMetric{3}() @@ -178,4 +187,13 @@ struct BaseManifoldMetric{M} <: Metric end @test injectivity_radius(MM) === injectivity_radius(M) @test is_manifold_point(MM, x) === is_manifold_point(M, x) @test is_tangent_vector(MM, x, v) === is_tangent_vector(M, x, v) + + cov = flat(M, x, FVector(TangentSpace, v)) + cow = flat(M, x, FVector(TangentSpace, w)) + @test cov.data ≈ flat(MM, x, FVector(TangentSpace, v)).data + cotspace = CotangentBundleFibers(M) + @test cov.data ≈ 2 * v + @test inner(M, x, v, w) ≈ inner(cotspace, x, cov.data, cow.data) + @test inner(MM, x, v, w) ≈ inner(cotspace, x, cov.data, cow.data) + @test sharp(M, x, cov).data ≈ v end