Skip to content

Commit

Permalink
docs: More docstrings and comments for gate kernels (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
kshyatt-aws authored Jan 21, 2025
1 parent 0261648 commit 7b23bee
Show file tree
Hide file tree
Showing 2 changed files with 230 additions and 12 deletions.
7 changes: 7 additions & 0 deletions docs/src/internals.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,11 @@ BraketSimulator._combine_operations
BraketSimulator._prepare_program
BraketSimulator._get_measured_qubits
BraketSimulator._compute_results
BraketSimulator.flip_bit
BraketSimulator.flip_bits
BraketSimulator.pad_bit
BraketSimulator.pad_bits
BraketSimulator.matrix_rep
BraketSimulator.endian_qubits
BraketSimulator.get_amps_and_qubits
```
235 changes: 223 additions & 12 deletions src/gate_kernels.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,80 @@
"""
pad_bit(amp_index::Integer, bit::Integer)
Insert a `0` at location `bit` of `amp_index` (in its bits representation).
The first valid value of `bit` is **zero**.
# Examples
```jldoctest
julia> amp_index = 10
10
julia> digits(amp_index, base=2, pad=6)
6-element Vector{Int64}:
0
1
0
1
0
0
julia> amp_index = BraketSimulator.pad_bit(amp_index, 2);
julia> digits(amp_index, base=2, pad=7)
7-element Vector{Int64}:
0
1
0
0
1
0
0
```
"""
@inline function pad_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer,Tj<:Integer}
left = (amp_index >> bit) << bit
right = amp_index - left
return (left << one(Ti)) right
end
"""
pad_bits(amp_index::Integer, to_pad)
Insert a `0` in `amp_index` at each location `bit` in the collection `to_pad`.
The first valid value of any `bit` is **zero**.
# Examples
```jldoctest
julia> amp_index = 10
10
julia> digits(amp_index, base=2, pad=6)
6-element Vector{Int64}:
0
1
0
1
0
0
julia> amp_index = BraketSimulator.pad_bits(amp_index, (2, 4));
julia> digits(amp_index, base=2, pad=8)
8-element Vector{Int64}:
0
1
0
0
0
1
0
0
```
!!! note
The indices in `pad_bits` aren't adjusted based on previous indices -- this can be seen in the above example,
where the bit at index 4 is different **before** and **after** inserting a bit at index 2.
"""
function pad_bits(ix::Ti, to_pad)::Ti where {Ti<:Integer}
padded_ix = ix
for bit in to_pad
Expand All @@ -11,18 +83,139 @@ function pad_bits(ix::Ti, to_pad)::Ti where {Ti<:Integer}
return padded_ix
end

@inline function flip_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer,Tj<:Integer}
"""
flip_bit(amp_index::Integer, bit::Integer)
Flip the `bit`-th bit of `amp_index`, so that 0 becomes 1 and 1 becomes 0.
The first valid value of `bit` is **zero**.
# Examples
```jldoctest
julia> amp_index = 10
10
julia> digits(amp_index, base=2, pad=6)
6-element Vector{Int64}:
0
1
0
1
0
0
julia> amp_index = BraketSimulator.flip_bit(amp_index, 1)
8
julia> digits(amp_index, base=2, pad=6)
6-element Vector{Int64}:
0
0
0
1
0
0
```
"""
@inline function flip_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer, Tj<:Integer}
return amp_index (one(Ti) << bit)
end

"""
flip_bits(amp_index::Integer, to_flip)
Flip the `bit`-th bit of `amp_index` for every `bit` in `to_flip`,
so that 0 becomes 1 and 1 becomes 0.
The first valid value of `bit` is **zero**.
# Examples
```jldoctest
julia> amp_index = 10
10
julia> digits(amp_index, base=2, pad=6)
6-element Vector{Int64}:
0
1
0
1
0
0
julia> amp_index = BraketSimulator.flip_bits(amp_index, (1, 3, 2))
4
julia> digits(amp_index, base=2, pad=6)
6-element Vector{Int64}:
0
0
1
0
0
0
```
"""
function flip_bits(ix::Ti, to_flip)::Ti where {Ti<:Integer}
flipped_ix = ix
for bit in to_flip
flipped_ix = flip_bit(flipped_ix, bit)
end
return flipped_ix
end
"""
endian_qubits(n_qubits::Int, qubit::Int)
Rotate the qubit index `qubit` to match what Braket expects with the
correct endianness. This has to be done because Braket and Julia have different
[endianness](https://en.wikipedia.org/wiki/Endianness).
!!! note
The first valid value for `qubit` is **zero**, since qubits are zero-indexed.
# Examples
```jldoctest
julia> qubit = 2
2
julia> n_qubits = 5
5
julia> BraketSimulator.endian_qubits(n_qubits, qubit)
2
julia> qubit = 3
3
julia> BraketSimulator.endian_qubits(n_qubits, qubit)
1
```
"""
@inline endian_qubits(n_qubits::Int, qubit::Int) = n_qubits - 1 - qubit
"""
endian_qubits(n_qubits::Int, qubits::Int...)
Rotate each qubit index in `qubits` to match what Braket expects with the
correct endianness. This has to be done because Braket and Julia have different
[endianness](https://en.wikipedia.org/wiki/Endianness).
!!! note
The first valid value for any element of `qubits` is **zero**,
since qubits are zero-indexed.
"""
@inline endian_qubits(n_qubits::Int, qubits::Int...) = n_qubits .- 1 .- qubits
"""
get_amps_and_qubits(state_vec::AbstractStateVector, qubits::Int...)
Get the total number of amplitudes of `state_vec` (its length) and use this
to apply [`endian_qubits`](@ref) to `qubits`. This is a convenience function
to automate several common operations.
!!! note
The first valid value for any element of `qubits` is **zero**,
since qubits are zero-indexed.
"""
@inline function get_amps_and_qubits(state_vec::AbstractStateVector, qubits::Int...)
n_amps = length(state_vec)
n_qubits = Int(log2(n_amps))
Expand Down Expand Up @@ -269,6 +462,12 @@ matrix_rep_raw(::ZZ, ϕ) = (θ = ϕ/2.0; return Diagonal(SVector{4}(exp(-im * θ
# 1/√2 * (IX - XY)
matrix_rep_raw(g::ECR) = SMatrix{4,4}(1/√2 * [0 1 0 im; 1 0 -im 0; 0 im 0 1; -im 0 1 0])
matrix_rep_raw(g::Unitary) = g.matrix
"""
matrix_rep(g::Gate)
Convert `g` into its matrix form, applying its argument values and any
exponent it is raised to.
"""
function matrix_rep(g::Gate)
n = g.pow_exponent
iszero(n) && matrix_rep_raw(I(), qubit_count(g))
Expand Down Expand Up @@ -311,12 +510,16 @@ function apply_gate!(
is_small_target = flipper < CHUNK_SIZE
g_00, g_10, g_01, g_11 = g_matrix
Threads.@threads for chunk_index = 0:n_chunks-1
# my_amps is the group of amplitude generators which this `Task` will process
# first_amp is the leading index in the group
# of amplitude generators which this `Task` will process
first_amp = n_chunks > 1 ? chunk_index*CHUNK_SIZE : 0
# amp_block is the total size of the block this `Task` will process
amp_block = n_chunks > 1 ? CHUNK_SIZE : n_tasks
lower_ix = pad_bit(first_amp, endian_qubit) + 1
higher_ix = lower_ix + flipper
for task_amp = 0:amp_block-1
# this avoids hitting an index pair already "touched" earlier in the block
# if 2 ^ qubit_index is smaller than the block size
if is_small_target && div(task_amp, flipper) > 0 && mod(task_amp, flipper) == 0
lower_ix = higher_ix
higher_ix = lower_ix + flipper
Expand Down Expand Up @@ -372,12 +575,13 @@ function apply_gate!(
return
end

# single controlled single target unitaries like CZ, CV, CPhaseShift
function apply_controlled_gate!(
g_matrix::Union{SMatrix{2,2,T}, Diagonal{T,SVector{2,T}}, Matrix{T}},
c_bit::Bool,
c_bit::Bool, # the bit-value to control on (0 or 1)
state_vec::AbstractStateVector{T},
control::Int,
target::Int,
control::Int, # the qubit to control on
target::Int, # the qubit to target
) where {T<:Complex}
n_amps, (endian_control, endian_target) =
get_amps_and_qubits(state_vec, control, target)
Expand All @@ -400,13 +604,14 @@ function apply_controlled_gate!(
end
return
end
# single controlled two target unitaries like CSWAP
function apply_controlled_gate!(
g_matrix::Union{SMatrix{4, 4, T}, Diagonal{T, SVector{4, T}}, Matrix{T}},
c_bit::Bool,
c_bit::Bool, # the bit-value to control on (0 or 1)
state_vec::AbstractStateVector{T},
control::Int,
target_1::Int,
target_2::Int,
control::Int, # the qubit to control on
target_1::Int, # the first qubit to target
target_2::Int, # the second qubit to target
) where {T<:Complex}
n_amps, (endian_control, endian_t1, endian_t2) = get_amps_and_qubits(state_vec, control, target_1, target_2)
small_t = min(endian_control, endian_t1, endian_t2)
Expand All @@ -429,11 +634,11 @@ function apply_controlled_gate!(
end
return
end
# doubly controlled unitaries
# doubly controlled single target unitaries like CCNot
function apply_controlled_gate!(
g_matrix::Union{SMatrix{2, 2, T}, Diagonal{T, SVector{2, T}}, Matrix{T}},
c1_bit::Bool,
c2_bit::Bool,
c1_bit::Bool, # the bit-value to control on (0 or 1) for the first control qubit
c2_bit::Bool, # the bit-value to control on (0 or 1) for the second control qubit
state_vec::AbstractStateVector{T},
control_1::Int,
control_2::Int,
Expand Down Expand Up @@ -463,6 +668,10 @@ function apply_controlled_gate!(
end
return
end
# these are "intermediate" dispatch methods which turn a `Gate` into the appropriate
# static matrix and dispatch to the appropriate kernel to apply it
# the `:conj` versions are there for *density matrices*, and apply the conjugated
# (but *not* transposed) version of the gate matrix.
for (V, f) in ((true, :conj), (false, :identity))
@eval begin
apply_gate!(::Val{$V}, gate::Control{G, B}, state_vec::AbstractStateVector{T}, qubits::Int...) where {T<:Complex, G<:Gate, B} = apply_controlled_gate!(Val($V), Val(B), gate, gate.g ^ gate.pow_exponent, state_vec, gate.bitvals, qubits...)
Expand Down Expand Up @@ -539,6 +748,8 @@ function apply_gate!(
apply_gate!(Diagonal(SVector{2^N, ComplexF64}(g_matrix)), state_vec, qubits...)
end

# fallback method for arbitrary unitaries with `NQ` targets
# such as a Unitary on 5 qubits
function apply_gate!(
g_matrix::Union{SMatrix{N, N, T}, Diagonal{T, SVector{N, T}}},
state_vec::AbstractStateVector{T},
Expand Down

0 comments on commit 7b23bee

Please sign in to comment.