Skip to content

Commit

Permalink
Add test_bases from QuantumOpticsBase and reorganize a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakyle committed Dec 6, 2024
1 parent f22adb7 commit 791017c
Show file tree
Hide file tree
Showing 12 changed files with 211 additions and 143 deletions.
36 changes: 31 additions & 5 deletions src/QuantumInterface.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,39 @@
module QuantumInterface

import Base: ==, +, -, *, /, ^, length, one, exp, conj, conj!, transpose, copy
import LinearAlgebra: tr, ishermitian, norm, normalize, normalize!
import Base: show, summary
import SparseArrays: sparse, spzeros, AbstractSparseMatrix # TODO move to an extension
##
# Basis specific
##

"""
basis(a)
Return the basis of an object.
If it's ambiguous, e.g. if an operator has a different left and right basis,
an [`IncompatibleBases`](@ref) error is thrown.
"""
function basis end

"""
Exception that should be raised for an illegal algebraic operation.
"""
mutable struct IncompatibleBases <: Exception end


##
# Standard methods
##

function apply! end

function dagger end

"""
directsum(x, y, z...)
Direct sum of the given objects. Alternatively, the unicode
symbol ⊕ (\\oplus) can be used.
"""
function directsum end
const = directsum
directsum() = GenericBasis(0)
Expand Down Expand Up @@ -86,8 +111,9 @@ function squeeze end
function wigner end


include("bases.jl")
include("abstract_types.jl")
include("bases.jl")
include("show.jl")

include("linalg.jl")
include("tensor.jl")
Expand Down
32 changes: 15 additions & 17 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
"""
Abstract base class for all specialized bases.
The Basis class is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information all subclasses must
implement a shape variable which indicates the dimension of the used
Hilbert space. For a spin-1/2 Hilbert space this would be the
vector `[2]`. A system composed of two spins would then have a
shape vector `[2 2]`.
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
class.
"""
abstract type Basis end

"""
Abstract base class for `Bra` and `Ket` states.
Expand Down Expand Up @@ -38,20 +53,3 @@ A_{br_1,br_2} = B_{bl_1,bl_2} S_{(bl_1,bl_2) ↔ (br_1,br_2)}
```
"""
abstract type AbstractSuperOperator{B1,B2} end

function summary(stream::IO, x::AbstractOperator)
print(stream, "$(typeof(x).name.name)(dim=$(length(x.basis_l))x$(length(x.basis_r)))\n")
if samebases(x)
print(stream, " basis: ")
show(stream, basis(x))
else
print(stream, " basis left: ")
show(stream, x.basis_l)
print(stream, "\n basis right: ")
show(stream, x.basis_r)
end
end

show(stream::IO, x::AbstractOperator) = summary(stream, x)

traceout!(s::StateVector, i) = ptrace(s,i)
96 changes: 3 additions & 93 deletions src/bases.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,10 @@
"""
Abstract base class for all specialized bases.
The Basis class is meant to specify a basis of the Hilbert space of the
studied system. Besides basis specific information all subclasses must
implement a shape variable which indicates the dimension of the used
Hilbert space. For a spin-1/2 Hilbert space this would be the
vector `[2]`. A system composed of two spins would then have a
shape vector `[2 2]`.
Composite systems can be defined with help of the [`CompositeBasis`](@ref)
class.
"""
abstract type Basis end

"""
length(b::Basis)
Total dimension of the Hilbert space.
"""
Base.length(b::Basis) = prod(b.shape)

"""
basis(a)
Return the basis of an object.
If it's ambiguous, e.g. if an operator has a different left and right basis,
an [`IncompatibleBases`](@ref) error is thrown.
"""
function basis end


"""
GenericBasis(N)
Expand Down Expand Up @@ -137,11 +111,6 @@ function equal_bases(a, b)
return true
end

"""
Exception that should be raised for an illegal algebraic operation.
"""
mutable struct IncompatibleBases <: Exception end

const BASES_CHECK = Ref(true)

"""
Expand Down Expand Up @@ -366,9 +335,9 @@ SumBasis(shape, bases::Vector) = (tmp = (bases...,); SumBasis(shape, tmp))
SumBasis(bases::Vector) = SumBasis((bases...,))
SumBasis(bases::Basis...) = SumBasis((bases...,))

==(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape)
==(b1::SumBasis, b2::SumBasis) = false
length(b::SumBasis) = sum(b.shape)
Base.:(==)(b1::T, b2::T) where T<:SumBasis = equal_shape(b1.shape, b2.shape)
Base.:(==)(b1::SumBasis, b2::SumBasis) = false
Base.length(b::SumBasis) = sum(b.shape)

Check warning on line 340 in src/bases.jl

View check run for this annotation

Codecov / codecov/patch

src/bases.jl#L338-L340

Added lines #L338 - L340 were not covered by tests

"""
directsum(b1::Basis, b2::Basis)
Expand All @@ -393,62 +362,3 @@ function directsum(b1::SumBasis, b2::SumBasis)
bases = [b1.bases...;b2.bases...]
return SumBasis(shape, (bases...,))
end

embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops)

##
# show methods
##

function show(stream::IO, x::GenericBasis)
if length(x.shape) == 1
write(stream, "Basis(dim=$(x.shape[1]))")
else
s = replace(string(x.shape), " " => "")
write(stream, "Basis(shape=$s)")
end
end

function show(stream::IO, x::CompositeBasis)
write(stream, "[")
for i in 1:length(x.bases)
show(stream, x.bases[i])
if i != length(x.bases)
write(stream, "")
end
end
write(stream, "]")
end

function show(stream::IO, x::SpinBasis)
d = denominator(x.spinnumber)
n = numerator(x.spinnumber)
if d == 1
write(stream, "Spin($n)")
else
write(stream, "Spin($n/$d)")
end
end

function show(stream::IO, x::FockBasis)
if iszero(x.offset)
write(stream, "Fock(cutoff=$(x.N))")
else
write(stream, "Fock(cutoff=$(x.N), offset=$(x.offset))")
end
end

function show(stream::IO, x::NLevelBasis)
write(stream, "NLevel(N=$(x.N))")
end

function show(stream::IO, x::SumBasis)
write(stream, "[")
for i in 1:length(x.bases)
show(stream, x.bases[i])
if i != length(x.bases)
write(stream, "")
end
end
write(stream, "]")
end
6 changes: 4 additions & 2 deletions src/embed_permute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
ops_sb = [x[2] for x in idxop_sb]

for (idxsb, opsb) in zip(indices_sb, ops_sb)
(opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases())
(opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases())
(opsb.basis_l == basis_l.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12
(opsb.basis_r == basis_r.bases[idxsb]) || throw(IncompatibleBases()) # FIXME issue #12

Check warning on line 71 in src/embed_permute.jl

View check run for this annotation

Codecov / codecov/patch

src/embed_permute.jl#L70-L71

Added lines #L70 - L71 were not covered by tests
end

S = length(operators) > 0 ? mapreduce(eltype, promote_type, operators) : Any
Expand All @@ -83,6 +83,8 @@ function embed(basis_l::CompositeBasis, basis_r::CompositeBasis,
return embed_op
end

embed(b::SumBasis, indices, ops) = embed(b, b, indices, ops)

Check warning on line 86 in src/embed_permute.jl

View check run for this annotation

Codecov / codecov/patch

src/embed_permute.jl#L86

Added line #L86 was not covered by tests

permutesystems(a::AbstractOperator, perm) = arithmetic_unary_error("Permutations of subsystems", a)

nsubsystems(s::AbstractKet) = nsubsystems(basis(s))
Expand Down
4 changes: 2 additions & 2 deletions src/identityoperator.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x)
Base.one(x::Union{<:Basis,<:AbstractOperator}) = identityoperator(x)

Check warning on line 1 in src/identityoperator.jl

View check run for this annotation

Codecov / codecov/patch

src/identityoperator.jl#L1

Added line #L1 was not covered by tests

"""
identityoperator(a::Basis[, b::Basis])
Expand All @@ -22,4 +22,4 @@ identityoperator(::Type{T}, ::Type{Any}, b1::Basis, b2::Basis) where T<:Abstract
identityoperator(b1::Basis, b2::Basis) = identityoperator(ComplexF64, b1, b2)

"""Prepare the identity superoperator over a given space."""
function identitysuperoperator end
function identitysuperoperator end
37 changes: 21 additions & 16 deletions src/julia_base.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import Base: +, -, *, /, ^, length, exp, conj, conj!, adjoint, transpose, copy

# Common error messages
arithmetic_unary_error(funcname, x::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this type of operator: $(typeof(x)).\nTry to convert to another operator type first with e.g. dense() or sparse()."))
arithmetic_binary_error(funcname, a::AbstractOperator, b::AbstractOperator) = throw(ArgumentError("$funcname is not defined for this combination of types of operators: $(typeof(a)), $(typeof(b)).\nTry to convert to a common operator type first with e.g. dense() or sparse()."))
Expand All @@ -8,33 +10,33 @@ addnumbererror() = throw(ArgumentError("Can't add or subtract a number and an op
# States
##

-(a::T) where {T<:StateVector} = T(a.basis, -a.data)
-(a::T) where {T<:StateVector} = T(a.basis, -a.data) # FIXME issue #12

Check warning on line 13 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L13

Added line #L13 was not covered by tests
*(a::StateVector, b::Number) = b*a
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data))
length(a::StateVector) = length(a.basis)::Int
basis(a::StateVector) = a.basis
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data)) # FIXME issue #12
length(a::StateVector) = length(a.basis)::Int # FIXME issue #12
basis(a::StateVector) = a.basis # FIXME issue #12

Check warning on line 17 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L15-L17

Added lines #L15 - L17 were not covered by tests
directsum(x::StateVector...) = reduce(directsum, x)
adjoint(a::StateVector) = dagger(a)

Check warning on line 19 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L19

Added line #L19 was not covered by tests



# Array-like functions
Base.size(x::StateVector) = size(x.data)
@inline Base.axes(x::StateVector) = axes(x.data)
Base.size(x::StateVector) = size(x.data) # FIXME issue #12
@inline Base.axes(x::StateVector) = axes(x.data) # FIXME issue #12

Check warning on line 25 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
Base.ndims(x::StateVector) = 1
Base.ndims(::Type{<:StateVector}) = 1
Base.eltype(x::StateVector) = eltype(x.data)
Base.eltype(x::StateVector) = eltype(x.data) # FIXME issue #12

Check warning on line 28 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L28

Added line #L28 was not covered by tests

# Broadcasting
Base.broadcastable(x::StateVector) = x

Base.adjoint(a::StateVector) = dagger(a)


##
# Operators
##

length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l)
basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1])
length(a::AbstractOperator) = length(a.basis_l)::Int*length(a.basis_r)::Int # FIXME issue #12
basis(a::AbstractOperator) = (check_samebases(a); a.basis_l) # FIXME issue #12
basis(a::AbstractSuperOperator) = (check_samebases(a); a.basis_l[1]) # FIXME issue #12

Check warning on line 39 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L37-L39

Added lines #L37 - L39 were not covered by tests

# Ensure scalar broadcasting
Base.broadcastable(x::AbstractOperator) = Ref(x)
Expand All @@ -60,14 +62,17 @@ Operator exponential.
"""
exp(op::AbstractOperator) = throw(ArgumentError("exp() is not defined for this type of operator: $(typeof(op)).\nTry to convert to dense operator first with dense()."))

Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r))
Base.size(op::AbstractOperator) = (length(op.basis_l),length(op.basis_r)) # FIXME issue #12

Check warning on line 65 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L65

Added line #L65 was not covered by tests
function Base.size(op::AbstractOperator, i::Int)
i < 1 && throw(ErrorException("dimension index is < 1"))
i > 2 && return 1
i==1 ? length(op.basis_l) : length(op.basis_r)
i==1 ? length(op.basis_l) : length(op.basis_r) # FIXME issue #12

Check warning on line 69 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L69

Added line #L69 was not covered by tests
end

Base.adjoint(a::AbstractOperator) = dagger(a)
adjoint(a::AbstractOperator) = dagger(a)

Check warning on line 72 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L72

Added line #L72 was not covered by tests

transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a)

Check warning on line 74 in src/julia_base.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_base.jl#L74

Added line #L74 was not covered by tests


conj(a::AbstractOperator) = arithmetic_unary_error("Complex conjugate", a)
conj!(a::AbstractOperator) = conj(a::AbstractOperator)
6 changes: 4 additions & 2 deletions src/julia_linalg.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import LinearAlgebra: tr, ishermitian, norm, normalize, normalize!

"""
ishermitian(op::AbstractOperator)
Expand All @@ -17,7 +19,7 @@ tr(x::AbstractOperator) = arithmetic_unary_error("Trace", x)
Norm of the given bra or ket state.
"""
norm(x::StateVector) = norm(x.data)
norm(x::StateVector) = norm(x.data) # FIXME issue #12

Check warning on line 22 in src/julia_linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_linalg.jl#L22

Added line #L22 was not covered by tests

"""
normalize(x::StateVector)
Expand All @@ -31,7 +33,7 @@ normalize(x::StateVector) = x/norm(x)
In-place normalization of the given bra or ket so that `norm(x)` is one.
"""
normalize!(x::StateVector) = (normalize!(x.data); x)
normalize!(x::StateVector) = (normalize!(x.data); x) # FIXME issue #12

Check warning on line 36 in src/julia_linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/julia_linalg.jl#L36

Added line #L36 was not covered by tests

"""
normalize(op)
Expand Down
10 changes: 5 additions & 5 deletions src/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool
samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool
check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r)
multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l)
samebases(a::AbstractOperator) = samebases(a.basis_l, a.basis_r)::Bool # FIXME issue #12
samebases(a::AbstractOperator, b::AbstractOperator) = samebases(a.basis_l, b.basis_l)::Bool && samebases(a.basis_r, b.basis_r)::Bool # FIXME issue #12
check_samebases(a::Union{AbstractOperator, AbstractSuperOperator}) = check_samebases(a.basis_l, a.basis_r) # FIXME issue #12
multiplicable(a::AbstractOperator, b::AbstractOperator) = multiplicable(a.basis_r, b.basis_l) # FIXME issue #12

Check warning on line 4 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L1-L4

Added lines #L1 - L4 were not covered by tests
dagger(a::AbstractOperator) = arithmetic_unary_error("Hermitian conjugate", a)
transpose(a::AbstractOperator) = arithmetic_unary_error("Transpose", a)
directsum(a::AbstractOperator...) = reduce(directsum, a)
ptrace(a::AbstractOperator, index) = arithmetic_unary_error("Partial trace", a)
_index_complement(b::CompositeBasis, indices) = complement(length(b.bases), indices)
reduced(a, indices) = ptrace(a, _index_complement(basis(a), indices))
traceout!(s::StateVector, i) = ptrace(s,i)

Check warning on line 10 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L10

Added line #L10 was not covered by tests
Loading

0 comments on commit 791017c

Please sign in to comment.