diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cafa232c..83c60612 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: include: - os: ubuntu-latest arch: x64 - version: '1.6' + version: '1.10' steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/Project.toml b/Project.toml index b4e4ae11..d7e98872 100644 --- a/Project.toml +++ b/Project.toml @@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" QuantumInterface = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RandomMatrices = "2576dda1-a324-5b11-aa66-c48ed7e3c618" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Strided = "5e0ebb24-38b0-5f93-81fe-25c709ecae67" UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6" @@ -28,7 +29,8 @@ LinearAlgebra = "1" QuantumInterface = "0.3.3" Random = "1" RandomMatrices = "0.5" +RecursiveArrayTools = "3" SparseArrays = "1" Strided = "1, 2" UnsafeArrays = "1" -julia = "1.6" +julia = "1.10" diff --git a/src/QuantumOpticsBase.jl b/src/QuantumOpticsBase.jl index 3fd84fb5..0c931c72 100644 --- a/src/QuantumOpticsBase.jl +++ b/src/QuantumOpticsBase.jl @@ -2,6 +2,7 @@ module QuantumOpticsBase using SparseArrays, LinearAlgebra, LRUCache, Strided, UnsafeArrays, FillArrays import LinearAlgebra: mul!, rmul! +import RecursiveArrayTools import QuantumInterface: dagger, directsum, ⊕, dm, embed, nsubsystems, expect, identityoperator, identitysuperoperator, permutesystems, projector, ptrace, reduced, tensor, ⊗, variance, apply!, basis, AbstractSuperOperator diff --git a/src/operators_dense.jl b/src/operators_dense.jl index e36f17f1..c00b512b 100644 --- a/src/operators_dense.jl +++ b/src/operators_dense.jl @@ -420,37 +420,50 @@ struct OperatorStyle{BL,BR} <: DataOperatorStyle{BL,BR} end Broadcast.BroadcastStyle(::Type{<:Operator{BL,BR}}) where {BL,BR} = OperatorStyle{BL,BR}() Broadcast.BroadcastStyle(::OperatorStyle{B1,B2}, ::OperatorStyle{B3,B4}) where {B1,B2,B3,B4} = throw(IncompatibleBases()) +# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`) +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {Bl<:Basis, Br<:Basis, T<:OperatorStyle{Bl,Br}} = T() + # Out-of-place broadcasting @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:OperatorStyle{BL,BR},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) bl,br = find_basis(bcf.args) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - return Operator{BL,BR}(bl, br, copy(bc_)) + T = find_dType(bcf) + data = zeros(T, length(bl), length(br)) + @inbounds @simd for I in eachindex(bcf) + data[I] = bcf[I] + end + return Operator{BL,BR}(bl, br, data) end -find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} -function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:DataOperator}}, axes) - args_ = Tuple(a.data for a=args) - return Broadcast.Broadcasted(f, args_, axes) -end +find_basis(a::DataOperator, rest) = (a.basis_l, a.basis_r) +find_dType(a::DataOperator, rest) = eltype(a) +@inline Base.getindex(a::DataOperator, idx) = getindex(a.data, idx) +Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::DataOperator, i) = x.data[i] +Base.iterate(a::DataOperator) = iterate(a.data) +Base.iterate(a::DataOperator, idx) = iterate(a.data, idx) # In-place broadcasting @inline function Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:DataOperatorStyle{BL,BR},Axes,F,Args} axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && isa(bc.args, Tuple{<:DataOperator{BL,BR}}) # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.data + @inbounds @simd for I in eachindex(bc′) + dest′[I] = bc′[I] end - # Get the underlying data fields of operators and broadcast them as arrays - bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - copyto!(dest.data, bc_) return dest end @inline Base.copyto!(A::DataOperator{BL,BR},B::DataOperator{BL,BR}) where {BL,BR} = (copyto!(A.data,B.data); A) @inline Base.copyto!(dest::DataOperator{BL,BR}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {BL,BR,Style<:DataOperatorStyle,Axes,F,Args} = throw(IncompatibleBases()) + +# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl +Base.eltype(::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = N # ODE init +Base.any(f::Function, x::Operator; kwargs...) = any(f, x.data; kwargs...) # ODE nan checks +Base.all(f::Function, x::Operator; kwargs...) = all(f, x.data; kwargs...) +Base.fill!(x::Operator, a) = typeof(x)(x.basis_l, x.basis_r, fill!(x.data, a)) +Base.ndims(x::Type{Operator{Bl,Br,A}}) where {Bl,Br,N,A<:AbstractMatrix{N}} = ndims(A) +Base.similar(x::Operator, t) = typeof(x)(x.basis_l, x.basis_r, copy(x.data)) +RecursiveArrayTools.recursivecopy!(dest::Operator{Bl,Br,A},src::Operator{Bl,Br,A}) where {Bl,Br,A} = copyto!(dest,src) # ODE in-place equations +RecursiveArrayTools.recursivecopy(x::Operator) = copy(x) +RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Operator} = copy(x) +RecursiveArrayTools.recursivefill!(x::Operator, a) = fill!(x, a) \ No newline at end of file diff --git a/src/states.jl b/src/states.jl index b632c177..ee15cb56 100644 --- a/src/states.jl +++ b/src/states.jl @@ -180,52 +180,51 @@ Broadcast.BroadcastStyle(::Type{<:Bra{B}}) where {B} = BraStyle{B}() Broadcast.BroadcastStyle(::KetStyle{B1}, ::KetStyle{B2}) where {B1,B2} = throw(IncompatibleBases()) Broadcast.BroadcastStyle(::BraStyle{B1}, ::BraStyle{B2}) where {B1,B2} = throw(IncompatibleBases()) +# Broadcast with scalars (of use in ODE solvers checking for tolerances, e.g. `.* reltol .+ abstol`) +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:KetStyle{B}} = T() +Broadcast.BroadcastStyle(::T, ::Broadcast.DefaultArrayStyle{0}) where {B<:Basis, T<:BraStyle{B}} = T() + # Out-of-place broadcasting @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:KetStyle{B},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) b = find_basis(bcf) - return Ket{B}(b, copy(bc_)) + T = find_dType(bcf) + data = zeros(T, length(b)) + @inbounds @simd for I in eachindex(bcf) + data[I] = bcf[I] + end + return Ket{B}(b, data) end @inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:BraStyle{B},Axes,F,Args<:Tuple} bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) b = find_basis(bcf) - return Bra{B}(b, copy(bc_)) -end -find_basis(bc::Broadcast.Broadcasted) = find_basis(bc.args) -find_basis(args::Tuple) = find_basis(find_basis(args[1]), Base.tail(args)) -find_basis(x) = x -find_basis(a::StateVector, rest) = a.basis -find_basis(::Any, rest) = find_basis(rest) - -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} -function Broadcasted_restrict_f(f::BasicMathFunc, args::NTuple{N,<:T}, axes) where {T<:StateVector,N} - args_ = Tuple(a.data for a=args) - return Broadcast.Broadcasted(f, args_, axes) -end -function Broadcasted_restrict_f(f, args::Tuple, axes) - error("Cannot broadcast function `$f` on $(typeof(args))") + T = find_dType(bcf) + data = zeros(T, length(b)) + @inbounds @simd for I in eachindex(bcf) + data[I] = bcf[I] + end + return Bra{B}(b, data) end -function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{}, axes) # Defined to avoid method ambiguities - error("Cannot broadcast function `$f` on an empty set of arguments") +for f ∈ [:find_basis,:find_dType] + @eval ($f)(bc::Broadcast.Broadcasted) = ($f)(bc.args) + @eval ($f)(args::Tuple) = ($f)(($f)(args[1]), Base.tail(args)) + @eval ($f)(x) = x + @eval ($f)(::Any, rest) = ($f)(rest) end +find_basis(x::T, rest) where {T<:Union{Ket, Bra}} = x.basis +find_dType(x::T, rest) where {T<:Union{Ket, Bra}} = eltype(x) +@inline Base.getindex(x::T, idx) where {T<:Union{Ket, Bra}} = getindex(x.data, idx) +Base.@propagate_inbounds Base.Broadcast._broadcast_getindex(x::T, i) where {T<:Union{Ket, Bra}} = x.data[i] + # In-place broadcasting for Kets @inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:KetStyle{B},Axes,F,Args} axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && isa(bc.args, Tuple{<:Ket{B}}) # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.data + @inbounds @simd for I in eachindex(bc′) + dest′[I] = bc′[I] end - # Get the underlying data fields of kets and broadcast them as arrays - bcf = Broadcast.flatten(bc) - args_ = Tuple(a.data for a=bcf.args) - bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf)) - copyto!(dest.data, bc_) return dest end @inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:KetStyle{B2},Axes,F,Args} = @@ -234,20 +233,27 @@ end # In-place broadcasting for Bras @inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B,Style<:BraStyle{B},Axes,F,Args} axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc)) - # Performance optimization: broadcast!(identity, dest, A) is equivalent to copyto!(dest, A) if indices match - if bc.f === identity && isa(bc.args, Tuple{<:Bra{B}}) # only a single input argument to broadcast! - A = bc.args[1] - if axes(dest) == axes(A) - return copyto!(dest, A) - end + bc′ = Base.Broadcast.preprocess(dest, bc) + dest′ = dest.data + @inbounds @simd for I in eachindex(bc′) + dest′[I] = bc′[I] end - # Get the underlying data fields of bras and broadcast them as arrays - bcf = Broadcast.flatten(bc) - bc_ = Broadcasted_restrict_f(bcf.f, bcf.args, axes(bcf)) - copyto!(dest.data, bc_) return dest end @inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1,B2,Style<:BraStyle{B2},Axes,F,Args} = throw(IncompatibleBases()) -@inline Base.copyto!(A::T,B::T) where T<:Union{Ket, Bra} = (copyto!(A.data,B.data); A) # Can not use T<:QuantumInterface.StateVector, because StateVector does not imply the existence of a data property +@inline Base.copyto!(dest::T,src::T) where {T<:Union{Ket, Bra}} = (copyto!(dest.data,src.data); dest) # Can not use T<:QuantumInterface.StateVector, because StateVector does not imply the existence of a data property + +# A few more standard interfaces: These do not necessarily make sense for a StateVector, but enable transparent use of DifferentialEquations.jl +Base.eltype(::Type{Ket{B,A}}) where {B,N,A<:AbstractVector{N}} = N # ODE init +Base.eltype(::Type{Bra{B,A}}) where {B,N,A<:AbstractVector{N}} = N +Base.any(f::Function, x::T; kwargs...) where {T<:Union{Ket, Bra}} = any(f, x.data; kwargs...) # ODE nan checks +Base.all(f::Function, x::T; kwargs...) where {T<:Union{Ket, Bra}} = all(f, x.data; kwargs...) +Base.fill!(x::T, a) where {T<:Union{Ket, Bra}} = typeof(x)(x.basis, fill!(x.data, a)) +Base.similar(x::T, t) where {T<:Union{Ket, Bra}} = typeof(x)(x.basis, similar(x.data)) +RecursiveArrayTools.recursivecopy!(dest::Ket{B,A},src::Ket{B,A}) where {B,A} = copyto!(dest, src) # ODE in-place equations +RecursiveArrayTools.recursivecopy!(dest::Bra{B,A},src::Bra{B,A}) where {B,A} = copyto!(dest, src) +RecursiveArrayTools.recursivecopy(x::T) where {T<:Union{Ket, Bra}} = copy(x) +RecursiveArrayTools.recursivecopy(x::AbstractArray{T}) where {T<:Union{Ket, Bra}} = copy(x) +RecursiveArrayTools.recursivefill!(x::T, a) where {T<:Union{Ket, Bra}} = fill!(x, a) \ No newline at end of file diff --git a/src/superoperators.jl b/src/superoperators.jl index 8128c3e5..5161f5b5 100644 --- a/src/superoperators.jl +++ b/src/superoperators.jl @@ -287,7 +287,7 @@ end # end find_basis(a::SuperOperator, rest) = (a.basis_l, a.basis_r) -const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)} +const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*),typeof(/)} function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:SuperOperator}}, axes) args_ = Tuple(a.data for a=args) return Broadcast.Broadcasted(f, args_, axes) diff --git a/test/Project.toml b/test/Project.toml index 272e2c15..dde78927 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,7 +7,9 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" QuantumInterface = "5717a53b-5d69-4fa3-b976-0bf2f97ca1e5" +QuantumOptics = "6e0679c1-51ea-5a7c-ac74-d61b76210b0c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RandomMatrices = "2576dda1-a324-5b11-aa66-c48ed7e3c618" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" diff --git a/test/runtests.jl b/test/runtests.jl index fdaf5a4d..3a132340 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,8 @@ names = [ "test_subspace.jl", "test_state_definitions.jl", + "test_sciml_broadcast_interfaces.jl", + "test_transformations.jl", "test_metrics.jl", diff --git a/test/test_abstractdata.jl b/test/test_abstractdata.jl index 659c1fca..563afb83 100644 --- a/test/test_abstractdata.jl +++ b/test/test_abstractdata.jl @@ -340,7 +340,6 @@ op1 .= op1_ .+ 3 * op1_ bf = FockBasis(3) op3 = randtestoperator(bf) @test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3 -@test_throws ErrorException cos.(op1) #################### # Test lazy tensor # diff --git a/test/test_jet.jl b/test/test_jet.jl index 29c419ef..a1d886a2 100644 --- a/test/test_jet.jl +++ b/test/test_jet.jl @@ -35,7 +35,7 @@ using LinearAlgebra, LRUCache, Strided, StridedViews, Dates, SparseArrays, Rando AnyFrameModule(RandomMatrices)) ) @show rep - @test length(JET.get_reports(rep)) <= 24 + @test length(JET.get_reports(rep)) <= 28 @test_broken length(JET.get_reports(rep)) == 0 end end # testset diff --git a/test/test_operators_dense.jl b/test/test_operators_dense.jl index 1fbb1b41..08add7b9 100644 --- a/test/test_operators_dense.jl +++ b/test/test_operators_dense.jl @@ -382,7 +382,6 @@ op1 .= op1_ .+ 3 * op1_ bf = FockBasis(3) op3 = randoperator(bf) @test_throws QuantumOpticsBase.IncompatibleBases op1 .+ op3 -@test_throws ErrorException cos.(op1) # Dimension mismatches b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter diff --git a/test/test_operators_sparse.jl b/test/test_operators_sparse.jl index ba0ca15e..78f6ba13 100644 --- a/test/test_operators_sparse.jl +++ b/test/test_operators_sparse.jl @@ -419,7 +419,6 @@ op3 = sprandop(FockBasis(1),FockBasis(2)) op_ = copy(op1) op_ .+= op1 @test op_ == 2*op1 -@test_throws ErrorException cos.(op_) # Dimension mismatches b1, b2, b3 = NLevelBasis.((2,3,4)) # N is not a type parameter diff --git a/test/test_sciml_broadcast_interfaces.jl b/test/test_sciml_broadcast_interfaces.jl new file mode 100644 index 00000000..cf7b27b5 --- /dev/null +++ b/test/test_sciml_broadcast_interfaces.jl @@ -0,0 +1,63 @@ +using Test +using QuantumOptics +using OrdinaryDiffEq + +@testset "sciml interface" begin + +# ket ODE problem +ℋ = SpinBasis(1//2) +↓ = spindown(ℋ) +t₀, t₁ = (0.0, pi) +σx = sigmax(ℋ) +iσx = im*σx +schrod!(dψ, ψ, p, t) = QuantumOptics.mul!(dψ, iσx, ψ) + +ix = iσx.data +schrod_data!(dψ,ψ,p,t) = QuantumOptics.mul!(dψ, ix, ψ) +u0 = (↓).data + +prob! = ODEProblem(schrod!, ↓, (t₀, t₁)) +prob_data! = ODEProblem(schrod_data!, u0, (t₀, t₁)) +sol = solve(prob!, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false) +sol_data = solve(prob_data!, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false) + +@test sol[end].data ≈ sol_data[end] + +# dense operator ODE problem +σ₋ = sigmam(ℋ) +σ₊ = σ₋' +mhalfσ₊σ₋ = -σ₊*σ₋/2 +ρ0 = dm(↓) +tmp = zero(ρ0) +function lind!(dρ,ρ,p,t) + QuantumOptics.mul!(tmp, ρ, σ₊) + QuantumOptics.mul!(dρ, σ₋, ρ) + QuantumOptics.mul!(dρ, ρ, mhalfσ₊σ₋, true, true) + QuantumOptics.mul!(dρ, mhalfσ₊σ₋, ρ, true, true) + QuantumOptics.mul!(dρ, iσx, ρ, -ComplexF64(1), ComplexF64(1)) + QuantumOptics.mul!(dρ, ρ, iσx, true, true) + return dρ +end +m0 = ρ0.data +σ₋d = σ₋.data +σ₊d = σ₊.data +mhalfσ₊σ₋d = mhalfσ₊σ₋.data +tmpd = zero(m0) +function lind_data!(dρ,ρ,p,t) + QuantumOptics.mul!(tmpd, ρ, σ₊d) + QuantumOptics.mul!(dρ, σ₋d, ρ) + QuantumOptics.mul!(dρ, ρ, mhalfσ₊σ₋d, true, true) + QuantumOptics.mul!(dρ, mhalfσ₊σ₋d, ρ, true, true) + QuantumOptics.mul!(dρ, ix, ρ, -ComplexF64(1), ComplexF64(1)) + QuantumOptics.mul!(dρ, ρ, ix, true, true) + return dρ +end + +prob! = ODEProblem(lind!, ρ0, (t₀, t₁)) +prob_data! = ODEProblem(lind_data!, m0, (t₀, t₁)) +sol = solve(prob!, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false) +sol_data = solve(prob_data!, DP5(); reltol = 1.0e-8, abstol = 1.0e-10, save_everystep=false) + +@test sol[end].data ≈ sol_data[end] + +end \ No newline at end of file diff --git a/test/test_states.jl b/test/test_states.jl index 723872a6..eea67e21 100644 --- a/test/test_states.jl +++ b/test/test_states.jl @@ -166,8 +166,6 @@ psi_ .+= psi123 bra_ = copy(bra123) bra_ .= 3*bra123 @test bra_ == 3*dagger(psi123) -@test_throws ErrorException cos.(psi_) -@test_throws ErrorException cos.(bra_) end # testset diff --git a/test/test_superoperators.jl b/test/test_superoperators.jl index 204264ac..1a0c25ab 100644 --- a/test/test_superoperators.jl +++ b/test/test_superoperators.jl @@ -215,8 +215,6 @@ Ldense .+= Ldense Ldense .+= L @test isa(Ldense, DenseSuperOpType) @test isapprox(Ldense.data, 5*Ldense_.data) -@test_throws ErrorException cos.(Ldense) -@test_throws ErrorException cos.(L) b = FockBasis(20) L = liouvillian(identityoperator(b), [destroy(b)])