-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathstates.jl
280 lines (234 loc) · 10.2 KB
/
states.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import Base: ==, +, -, *, /, length, copy
import LinearAlgebra: norm, normalize, normalize!
"""
Abstract base class for [`Bra`](@ref) and [`Ket`](@ref) states.

A state vector holds the coefficients of an abstract state with respect to a
particular basis. The coefficients are kept in the `data` field and the basis
in the `basis` field.
"""
abstract type StateVector{BT<:Basis,DT<:AbstractVector} end
"""
    Bra(b::Basis[, data])

Bra state defined by its coefficients `data` with respect to the basis `b`.
The length of `data` must equal `length(b)`, otherwise a `DimensionMismatch`
is thrown.
"""
mutable struct Bra{B<:Basis,T<:AbstractVector} <: StateVector{B,T}
    basis::B
    data::T
    function Bra{B,T}(b::B, data::T) where {B<:Basis,T<:AbstractVector}
        if length(b) != length(data)
            throw(DimensionMismatch("Tried to assign data of length $(length(data)) to Hilbert space of size $(length(b))"))
        end
        return new(b, data)
    end
end
"""
    Ket(b::Basis[, data])

Ket state defined by its coefficients `data` with respect to the basis `b`.
The length of `data` must equal `length(b)`, otherwise a `DimensionMismatch`
is thrown.
"""
mutable struct Ket{B<:Basis,T<:AbstractVector} <: StateVector{B,T}
    basis::B
    data::T
    function Ket{B,T}(b::B, data::T) where {B<:Basis,T<:AbstractVector}
        if length(b) != length(data)
            throw(DimensionMismatch("Tried to assign data of length $(length(data)) to Hilbert space of size $(length(b))"))
        end
        return new(b, data)
    end
end
# Outer constructors: the data-type parameter is inferred from the supplied
# coefficient vector.
function Bra{B}(b::B, data::T) where {B<:Basis,T}
    return Bra{B,T}(b, data)
end
function Ket{B}(b::B, data::T) where {B<:Basis,T}
    return Ket{B,T}(b, data)
end
function Bra(b::B, data::T) where {B<:Basis,T}
    return Bra{B,T}(b, data)
end
function Ket(b::B, data::T) where {B<:Basis,T}
    return Ket{B,T}(b, data)
end

# Without explicit data the state starts as the zero vector with ComplexF64 entries.
function Bra{B}(b::B) where B<:Basis
    coefficients = zeros(ComplexF64, length(b))
    return Bra{B}(b, coefficients)
end
function Ket{B}(b::B) where B<:Basis
    coefficients = zeros(ComplexF64, length(b))
    return Ket{B}(b, coefficients)
end
function Bra(b::Basis)
    coefficients = zeros(ComplexF64, length(b))
    return Bra(b, coefficients)
end
function Ket(b::Basis)
    coefficients = zeros(ComplexF64, length(b))
    return Ket(b, coefficients)
end
# Copy a state: same basis object, independent copy of the coefficient vector.
# Generic fallback for StateVector subtypes other than Ket/Bra.
copy(a::T) where {T<:StateVector} = T(a.basis, copy(a.data))
# `copy` of some containers (e.g. a view) returns a different array type, which the
# `Ket{B,T}`/`Bra{B,T}` inner constructors reject (they only accept `data::T`).
# Going through the outer constructors re-infers the data-type parameter instead.
copy(a::Ket) = Ket(a.basis, copy(a.data))
copy(a::Bra) = Bra(a.basis, copy(a.data))
# The length of a state is the dimension of its basis.
length(a::StateVector) = length(a.basis)::Int
# Basis in which the coefficients of the state are expressed.
basis(a::StateVector) = a.basis
# Two kets (or two bras) are equal only if `samebases` holds for them and their
# coefficient vectors are equal; states whose basis types differ are never equal.
function Base.:(==)(x::Ket{B}, y::Ket{B}) where {B<:Basis}
    samebases(x, y) || return false
    return x.data == y.data
end
function Base.:(==)(x::Bra{B}, y::Bra{B}) where {B<:Basis}
    samebases(x, y) || return false
    return x.data == y.data
end
Base.:(==)(::Ket, ::Ket) = false
Base.:(==)(::Bra, ::Bra) = false

# Approximate equality: same rule, with `isapprox` (and its keywords) on the data.
function Base.isapprox(x::Ket{B}, y::Ket{B}; kwargs...) where {B<:Basis}
    samebases(x, y) || return false
    return isapprox(x.data, y.data; kwargs...)
end
function Base.isapprox(x::Bra{B}, y::Bra{B}; kwargs...) where {B<:Basis}
    samebases(x, y) || return false
    return isapprox(x.data, y.data; kwargs...)
end
Base.isapprox(::Ket, ::Ket; kwargs...) = false
Base.isapprox(::Bra, ::Bra; kwargs...) = false
# Arithmetic operations
# Addition/subtraction require both states to share the same basis type; mixing
# basis types raises `IncompatibleBases`.
+(a::Ket{B}, b::Ket{B}) where {B<:Basis} = Ket(a.basis, a.data+b.data)
+(a::Bra{B}, b::Bra{B}) where {B<:Basis} = Bra(a.basis, a.data+b.data)
+(a::Ket, b::Ket) = throw(IncompatibleBases())
+(a::Bra, b::Bra) = throw(IncompatibleBases())
-(a::Ket{B}, b::Ket{B}) where {B<:Basis} = Ket(a.basis, a.data-b.data)
-(a::Bra{B}, b::Bra{B}) where {B<:Basis} = Bra(a.basis, a.data-b.data)
-(a::Ket, b::Ket) = throw(IncompatibleBases())
-(a::Bra, b::Bra) = throw(IncompatibleBases())
# Unary minus. Generic fallback for StateVector subtypes other than Ket/Bra.
-(a::T) where {T<:StateVector} = T(a.basis, -a.data)
# Negation can change the container type (e.g. negating a view or a range yields a
# new array type), which the `Ket{B,T}`/`Bra{B,T}` inner constructors reject; rebuild
# through the outer constructors so the data-type parameter is re-inferred.
-(a::Ket) = Ket(a.basis, -a.data)
-(a::Bra) = Bra(a.basis, -a.data)
# Bra-ket product. `dagger` conjugates the coefficients when turning a ket into a
# bra, so a plain `transpose` (not `adjoint`) is used here.
*(a::Bra{B}, b::Ket{B}) where {B<:Basis} = transpose(a.data)*b.data
*(a::Bra, b::Ket) = throw(IncompatibleBases())
# Scalar multiplication and division act on the coefficient vector.
*(a::Number, b::Ket) = Ket(b.basis, a*b.data)
*(a::Number, b::Bra) = Bra(b.basis, a*b.data)
*(a::StateVector, b::Number) = b*a
/(a::Ket, b::Number) = Ket(a.basis, a.data ./ b)
/(a::Bra, b::Number) = Bra(a.basis, a.data ./ b)
"""
    dagger(x)

Hermitian conjugate: turns a bra into a ket (and vice versa) with complex
conjugated coefficients in the same basis.
"""
function dagger(x::Bra)
    return Ket(x.basis, conj(x.data))
end
function dagger(x::Ket)
    return Bra(x.basis, conj(x.data))
end

# `x'` on a state is the Hermitian conjugate.
function Base.adjoint(a::StateVector)
    return dagger(a)
end
"""
    tensor(x::Ket, y::Ket, z::Ket...)

Tensor product ``|x⟩⊗|y⟩⊗|z⟩⊗…`` of the given states.
"""
function tensor(a::Ket, b::Ket)
    # Note the argument order `kron(b.data, a.data)`: in the combined coefficient
    # vector the index of the first factor `a` runs fastest.
    product_basis = tensor(a.basis, b.basis)
    return Ket(product_basis, kron(b.data, a.data))
end
function tensor(a::Bra, b::Bra)
    product_basis = tensor(a.basis, b.basis)
    return Bra(product_basis, kron(b.data, a.data))
end

# A single state is its own tensor product.
tensor(state::StateVector) = state

# Several states (as varargs or in a vector) are combined pairwise.
function tensor(states::Ket...)
    return reduce(tensor, states)
end
function tensor(states::Bra...)
    return reduce(tensor, states)
end
function tensor(states::Vector{T}) where T<:StateVector
    return reduce(tensor, states)
end
# Normalization functions
"""
    norm(x::StateVector)

Norm of the given bra or ket state, computed as the norm of its coefficient vector.
"""
function norm(x::StateVector)
    return norm(x.data)
end

"""
    normalize(x::StateVector)

Return a new state, obtained by dividing `x` by `norm(x)`, so that its norm is one.
"""
function normalize(x::StateVector)
    magnitude = norm(x)
    return x / magnitude
end

"""
    normalize!(x::StateVector)

In-place normalization of the given bra or ket so that `norm(x)` is one.
The coefficient vector is rescaled in place and `x` itself is returned.
"""
function normalize!(x::StateVector)
    normalize!(x.data)
    return x
end
# Shared kernel for the ket and bra methods below: view the flat coefficient vector
# as an array with one axis per entry of `state.basis.shape`, permute those axes by
# `perm`, and flatten the result back into a vector.
function _permutesystems_data(state::StateVector, perm::Vector{Int})
    @assert length(state.basis.bases) == length(perm)
    @assert isperm(perm)
    as_tensor = reshape(state.data, state.basis.shape...)
    permuted = permutedims(as_tensor, perm)
    return reshape(permuted, length(permuted))
end

# Reorder the subsystems of a composite ket according to the permutation `perm`.
function permutesystems(state::T, perm::Vector{Int}) where T<:Ket
    new_data = _permutesystems_data(state, perm)
    return Ket(permutesystems(state.basis, perm), new_data)
end

# Reorder the subsystems of a composite bra according to the permutation `perm`.
function permutesystems(state::T, perm::Vector{Int}) where T<:Bra
    new_data = _permutesystems_data(state, perm)
    return Bra(permutesystems(state.basis, perm), new_data)
end
# Creation of basis states.
"""
    basisstate(b, index; sparse=false, dType=ComplexF64)

Basis vector specified by `index` as ket state.
For a composite system `index` can be a vector which then creates a tensor
product state ``|i_1⟩⊗|i_2⟩⊗…⊗|i_n⟩`` of the corresponding basis states.
With `sparse=true` the coefficients are stored in a sparse vector; `dType` sets
the element type.
"""
function basisstate(b::Basis, indices::Vector{Int}; sparse=false, dType=ComplexF64)
    @assert length(b.shape) == length(indices)
    # Translate the per-subsystem indices into a single column-major position.
    position = LinearIndices(tuple(b.shape...))[indices...]
    return basisstate(b, position; sparse=sparse, dType=dType)
end

function basisstate(b::Basis, index::Int; sparse=false, dType=ComplexF64)
    dimension = length(b)
    coefficients = sparse ? spzeros(dType, dimension) : zeros(dType, dimension)
    coefficients[index] = one(dType)
    return Ket(b, coefficients)
end
# Helper functions to check validity of arguments
# Throw `IncompatibleBases` unless the bra and the ket are defined on equal bases;
# returns `nothing` otherwise.
function check_multiplicable(a::Bra, b::Ket)
    a.basis != b.basis && throw(IncompatibleBases())
    return nothing
end
# Two kets (or two bras) with the same basis type: delegate to the basis comparison.
function samebases(a::Ket{B}, b::Ket{B}) where {B}
    return samebases(a.basis, b.basis)::Bool
end
function samebases(a::Bra{B}, b::Bra{B}) where {B}
    return samebases(a.basis, b.basis)::Bool
end
# Array-like functions: shape and element-type queries are forwarded to the
# coefficient vector; states are always one-dimensional.
function Base.size(x::StateVector)
    return size(x.data)
end
@inline function Base.axes(x::StateVector)
    return axes(x.data)
end
Base.ndims(::StateVector) = 1
Base.ndims(::Type{<:StateVector}) = 1
function Base.eltype(x::StateVector)
    return eltype(x.data)
end
# Broadcasting: states take part in broadcasts as themselves (no wrapping).
Base.broadcastable(x::StateVector) = x

# Custom broadcast styles, parametrized on the basis type so that broadcasts mixing
# different basis types can be rejected while resolving the style.
abstract type StateVectorStyle{BT<:Basis} <: Broadcast.BroadcastStyle end
struct KetStyle{BT<:Basis} <: StateVectorStyle{BT} end
struct BraStyle{BT<:Basis} <: StateVectorStyle{BT} end

# Style precedence rules: kets and bras map to their own styles, and combining two
# ket styles (or two bra styles) raises `IncompatibleBases`.
# NOTE(review): presumably Base's identical-style shortcut keeps same-basis
# broadcasts from reaching the combination rules below — confirm.
Broadcast.BroadcastStyle(::Type{<:Ket{BT}}) where {BT<:Basis} = KetStyle{BT}()
Broadcast.BroadcastStyle(::Type{<:Bra{BT}}) where {BT<:Basis} = BraStyle{BT}()
Broadcast.BroadcastStyle(::KetStyle{BT1}, ::KetStyle{BT2}) where {BT1<:Basis,BT2<:Basis} = throw(IncompatibleBases())
Broadcast.BroadcastStyle(::BraStyle{BT1}, ::BraStyle{BT2}) where {BT1<:Basis,BT2<:Basis} = throw(IncompatibleBases())
# Out-of-place broadcasting: flatten the broadcast tree, evaluate it on the raw
# coefficient arrays (via `Broadcasted_restrict_f`), and wrap the result in a new
# state carrying the basis found by `find_basis`.
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args<:Tuple}
    flat = Broadcast.flatten(bc)
    on_data = Broadcasted_restrict_f(flat.f, flat.args, axes(flat))
    return Ket{B}(find_basis(flat), copy(on_data))
end
@inline function Base.copy(bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args<:Tuple}
    flat = Broadcast.flatten(bc)
    on_data = Broadcasted_restrict_f(flat.f, flat.args, axes(flat))
    return Bra{B}(find_basis(flat), copy(on_data))
end
# Return the basis of the first state vector among the broadcast arguments
# (recursing into nested `Broadcasted` objects), or `nothing` if there is none.
# Follows the `find_aac` pattern from the Julia manual's broadcasting interface.
find_basis(bc::Broadcast.Broadcasted) = find_basis(bc.args)
find_basis(args::Tuple) = find_basis(find_basis(args[1]), Base.tail(args))
# Terminate the recursion on an exhausted argument list instead of indexing `()[1]`.
find_basis(::Tuple{}) = nothing
find_basis(x) = x
find_basis(a::StateVector, rest) = a.basis
# A basis already found inside a nested `Broadcasted` is the answer; without this
# method it would fall through to `::Any` below and be discarded.
find_basis(b::Basis, rest) = b
find_basis(::Any, rest) = find_basis(rest)
# Functions that may be broadcast over state vectors in the out-of-place path.
const BasicMathFunc = Union{typeof(+),typeof(-),typeof(*)}

# Rebuild a flattened broadcast so that it operates on the coefficient vectors of the
# given states. Only `+`, `-` and `*` are accepted; any other function raises an
# `ErrorException`.
# NOTE(review): `Vararg{<:T}` wraps `Vararg` in a UnionAll, which Julia ≥ 1.7 reports
# as deprecated — consider rewriting the signature once the intended dispatch is confirmed.
function Broadcasted_restrict_f(f::BasicMathFunc, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector
    # `map` over a tuple is inferable, unlike `Tuple(generator)`.
    args_ = map(a -> a.data, args)
    return Broadcast.Broadcasted(f, args_, axes)
end
function Broadcasted_restrict_f(f, args::Tuple{Vararg{<:T}}, axes) where T<:StateVector
    # `error` throws the ErrorException itself; no extra `throw` is needed.
    error("Cannot broadcast function `$f` on type `$T`")
end
# In-place broadcasting for Kets: `dest .= <expression of kets>` writes the result
# into `dest.data`. Broadcasting kets of a different basis type into `dest` raises
# `IncompatibleBases` (second method).
@inline function Base.copyto!(dest::Ket{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:KetStyle{B},Axes,F,Args}
    axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
    # Performance optimization: broadcast!(identity, dest, A) is a plain copy when the
    # indices match. Copy the coefficient arrays directly: dispatching on the kets
    # themselves only has a method when both share the exact same concrete type, so
    # e.g. a ComplexF64 destination with Float64 source data would hit a MethodError.
    if bc.f === identity && isa(bc.args, Tuple{<:Ket{B}}) # only a single input argument to broadcast!
        A = bc.args[1]
        if axes(dest) == axes(A)
            copyto!(dest.data, A.data)
            return dest
        end
    end
    # Get the underlying data fields of kets and broadcast them as arrays
    # (`map` over a tuple is inferable, unlike `Tuple(generator)`).
    bcf = Broadcast.flatten(bc)
    args_ = map(a -> a.data, bcf.args)
    bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf))
    copyto!(dest.data, bc_)
    return dest
end
@inline Base.copyto!(dest::Ket{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:KetStyle{B2},Axes,F,Args} =
    throw(IncompatibleBases())
# In-place broadcasting for Bras: `dest .= <expression of bras>` writes the result
# into `dest.data`. Broadcasting bras of a different basis type into `dest` raises
# `IncompatibleBases` (second method). The data handling mirrors the in-place
# broadcasting for kets, so fused expressions work for bras as they do for kets.
@inline function Base.copyto!(dest::Bra{B}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B<:Basis,Style<:BraStyle{B},Axes,F,Args}
    axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
    # Performance optimization: broadcast!(identity, dest, A) is a plain copy when the
    # indices match. Copy the coefficient arrays directly: dispatching on the bras
    # themselves only has a method when both share the exact same concrete type.
    if bc.f === identity && isa(bc.args, Tuple{<:Bra{B}}) # only a single input argument to broadcast!
        A = bc.args[1]
        if axes(dest) == axes(A)
            copyto!(dest.data, A.data)
            return dest
        end
    end
    # Get the underlying data fields of bras and broadcast them as arrays
    # (`map` over a tuple is inferable, unlike `Tuple(generator)`).
    bcf = Broadcast.flatten(bc)
    args_ = map(a -> a.data, bcf.args)
    bc_ = Broadcast.Broadcasted(bcf.f, args_, axes(bcf))
    copyto!(dest.data, bc_)
    return dest
end
@inline Base.copyto!(dest::Bra{B1}, bc::Broadcast.Broadcasted{Style,Axes,F,Args}) where {B1<:Basis,B2<:Basis,Style<:BraStyle{B2},Axes,F,Args} =
    throw(IncompatibleBases())
# Copy the coefficients of `src` into `dest` (both must have the same concrete
# state type) and return `dest`.
@inline function Base.copyto!(dest::T, src::T) where T<:StateVector
    copyto!(dest.data, src.data)
    return dest
end