diff --git a/src/identity.jl b/src/identity.jl index 7c300fa..04dde86 100644 --- a/src/identity.jl +++ b/src/identity.jl @@ -76,12 +76,12 @@ end Base.IndexStyle(::Type{<:IdentityMultiple}) = IndexLinear() Base.size(𝐼::IdentityMultiple) = (𝐼.n, 𝐼.n) -function Base.getindex(𝐼::IdentityMultiple, inds...) +function Base.getindex(𝐼::IdentityMultiple, inds::Integer...) any(idx -> idx > 𝐼.n, inds) && throw(BoundsError(𝐼, inds)) return getindex(𝐼.M, inds...) end -function Base.getindex(𝐼::IdentityMultiple{T}, ind) where {T} +function Base.getindex(𝐼::IdentityMultiple{T}, ind::Integer) where {T} if 1 ≤ ind ≤ 𝐼.n^2 return rem(ind - 1, 𝐼.n + 1) == 0 ? 𝐼.M.λ : zero(T) else @@ -94,8 +94,6 @@ function Base.setindex!(::IdentityMultiple, ::Any, inds...) end Base.:(-)(𝐼::IdentityMultiple) = IdentityMultiple(-𝐼.M, 𝐼.n) -Base.:(+)(𝐼::IdentityMultiple, M::AbstractMatrix) = 𝐼.M + M -Base.:(+)(M::AbstractMatrix, 𝐼::IdentityMultiple) = M + 𝐼.M Base.:(*)(x::Number, 𝐼::IdentityMultiple) = IdentityMultiple(x * 𝐼.M, 𝐼.n) Base.:(*)(𝐼::IdentityMultiple, x::Number) = IdentityMultiple(x * 𝐼.M, 𝐼.n) Base.:(/)(𝐼::IdentityMultiple, x::Number) = IdentityMultiple(𝐼.M / x, 𝐼.n) @@ -193,8 +191,4 @@ end # callable identity matrix given the scaling factor and the size IdentityMultiple(λ::Number, n::Int) = IdentityMultiple(λ * I, n) -function LinearAlgebra.Hermitian(𝐼::IdentityMultiple) - return Hermitian(Diagonal(fill(𝐼.M.λ, 𝐼.n))) -end - Base.exp(𝐼::IdentityMultiple) = IdentityMultiple(exp(𝐼.M.λ), 𝐼.n) diff --git a/test/identity.jl b/test/identity.jl index adf0737..fb5b33b 100644 --- a/test/identity.jl +++ b/test/identity.jl @@ -76,6 +76,7 @@ end end @testset "Specific methods for IdentityMultiple" begin + @test Diagonal(Id(2)) == Diagonal([1.0, 1]) @test Hermitian(Id(2)) == Hermitian([1.0 0; 0 1]) @test exp(Id(3, 1)) == Id(3, exp(1)) end