From 7b23beedb06866c3fa75c2df67acf48856f0a829 Mon Sep 17 00:00:00 2001 From: Katharine Hyatt <67932820+kshyatt-aws@users.noreply.github.com> Date: Tue, 21 Jan 2025 11:01:25 -0500 Subject: [PATCH] docs: More docstrings and comments for gate kernels (#62) --- docs/src/internals.md | 7 ++ src/gate_kernels.jl | 235 +++++++++++++++++++++++++++++++++++++++--- 2 files changed, 230 insertions(+), 12 deletions(-) diff --git a/docs/src/internals.md b/docs/src/internals.md index 218fc33..fa844ed 100644 --- a/docs/src/internals.md +++ b/docs/src/internals.md @@ -11,4 +11,11 @@ BraketSimulator._combine_operations BraketSimulator._prepare_program BraketSimulator._get_measured_qubits BraketSimulator._compute_results +BraketSimulator.flip_bit +BraketSimulator.flip_bits +BraketSimulator.pad_bit +BraketSimulator.pad_bits +BraketSimulator.matrix_rep +BraketSimulator.endian_qubits +BraketSimulator.get_amps_and_qubits ``` diff --git a/src/gate_kernels.jl b/src/gate_kernels.jl index afeaca6..af3274c 100644 --- a/src/gate_kernels.jl +++ b/src/gate_kernels.jl @@ -1,8 +1,80 @@ +""" + pad_bit(amp_index::Integer, bit::Integer) + +Insert a `0` at location `bit` of `amp_index` (in its bits representation). +The first valid value of `bit` is **zero**. + +# Examples +```jldoctest +julia> amp_index = 10 +10 + +julia> digits(amp_index, base=2, pad=6) +6-element Vector{Int64}: + 0 + 1 + 0 + 1 + 0 + 0 + +julia> amp_index = BraketSimulator.pad_bit(amp_index, 2); + +julia> digits(amp_index, base=2, pad=7) +7-element Vector{Int64}: + 0 + 1 + 0 + 0 + 1 + 0 + 0 +``` +""" @inline function pad_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer,Tj<:Integer} left = (amp_index >> bit) << bit right = amp_index - left return (left << one(Ti)) ⊻ right end +""" + pad_bits(amp_index::Integer, to_pad) + +Insert a `0` in `amp_index` at each location `bit` in the collection `to_pad`. +The first valid value of any `bit` is **zero**. + +# Examples +```jldoctest +julia> amp_index = 10 +10 + +julia> digits(amp_index, base=2, pad=6) +6-element Vector{Int64}: + 0 + 1 + 0 + 1 + 0 + 0 + +julia> amp_index = BraketSimulator.pad_bits(amp_index, (2, 4)); + +julia> digits(amp_index, base=2, pad=8) +8-element Vector{Int64}: + 0 + 1 + 0 + 0 + 0 + 1 + 0 + 0 +``` + +!!! note + + The indices in `pad_bits` aren't adjusted based on previous indices -- this can be seen in the above example, + where the bit at index 4 is different **before** and **after** inserting a bit at index 2. +""" function pad_bits(ix::Ti, to_pad)::Ti where {Ti<:Integer} padded_ix = ix for bit in to_pad @@ -11,9 +83,77 @@ function pad_bits(ix::Ti, to_pad)::Ti where {Ti<:Integer} return padded_ix end -@inline function flip_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer,Tj<:Integer} +""" + flip_bit(amp_index::Integer, bit::Integer) + +Flip the `bit`-th bit of `amp_index`, so that 0 becomes 1 and 1 becomes 0. +The first valid value of `bit` is **zero**. + +# Examples +```jldoctest +julia> amp_index = 10 +10 + +julia> digits(amp_index, base=2, pad=6) +6-element Vector{Int64}: + 0 + 1 + 0 + 1 + 0 + 0 + +julia> amp_index = BraketSimulator.flip_bit(amp_index, 1) +8 + +julia> digits(amp_index, base=2, pad=6) +6-element Vector{Int64}: + 0 + 0 + 0 + 1 + 0 + 0 +``` +""" +@inline function flip_bit(amp_index::Ti, bit::Tj)::Ti where {Ti<:Integer, Tj<:Integer} return amp_index ⊻ (one(Ti) << bit) end + +""" + flip_bits(amp_index::Integer, to_flip) + +Flip the `bit`-th bit of `amp_index` for every `bit` in `to_flip`, +so that 0 becomes 1 and 1 becomes 0. +The first valid value of `bit` is **zero**. + +# Examples +```jldoctest +julia> amp_index = 10 +10 + +julia> digits(amp_index, base=2, pad=6) +6-element Vector{Int64}: + 0 + 1 + 0 + 1 + 0 + 0 + +julia> amp_index = BraketSimulator.flip_bits(amp_index, (1, 3, 2)) +4 + +julia> digits(amp_index, base=2, pad=6) +6-element Vector{Int64}: + 0 + 0 + 1 + 0 + 0 + 0 +``` +""" function flip_bits(ix::Ti, to_flip)::Ti where {Ti<:Integer} flipped_ix = ix for bit in to_flip @@ -21,8 +161,61 @@ function flip_bits(ix::Ti, to_flip)::Ti where {Ti<:Integer} end return flipped_ix end +""" + endian_qubits(n_qubits::Int, qubit::Int) + +Rotate the qubit index `qubit` to match what Braket expects with the +correct endianness. This has to be done because Braket and Julia have different +[endianness](https://en.wikipedia.org/wiki/Endianness). + +!!! note + + The first valid value for `qubit` is **zero**, since qubits are zero-indexed. + +# Examples +```jldoctest +julia> qubit = 2 +2 + +julia> n_qubits = 5 +5 + +julia> BraketSimulator.endian_qubits(n_qubits, qubit) +2 + +julia> qubit = 3 +3 + +julia> BraketSimulator.endian_qubits(n_qubits, qubit) +1 +``` +""" @inline endian_qubits(n_qubits::Int, qubit::Int) = n_qubits - 1 - qubit +""" + endian_qubits(n_qubits::Int, qubits::Int...) + +Rotate each qubit index in `qubits` to match what Braket expects with the +correct endianness. This has to be done because Braket and Julia have different +[endianness](https://en.wikipedia.org/wiki/Endianness). + +!!! note + + The first valid value for any element of `qubits` is **zero**, + since qubits are zero-indexed. +""" @inline endian_qubits(n_qubits::Int, qubits::Int...) = n_qubits .- 1 .- qubits +""" + get_amps_and_qubits(state_vec::AbstractStateVector, qubits::Int...) + +Get the total number of amplitudes of `state_vec` (its length) and use this +to apply [`endian_qubits`](@ref) to `qubits`. This is a convenience function +to automate several common operations. + +!!! note + + The first valid value for any element of `qubits` is **zero**, + since qubits are zero-indexed. +""" @inline function get_amps_and_qubits(state_vec::AbstractStateVector, qubits::Int...) n_amps = length(state_vec) n_qubits = Int(log2(n_amps)) @@ -269,6 +462,12 @@ matrix_rep_raw(::ZZ, ϕ) = (θ = ϕ/2.0; return Diagonal(SVector{4}(exp(-im * θ # 1/√2 * (IX - XY) matrix_rep_raw(g::ECR) = SMatrix{4,4}(1/√2 * [0 1 0 im; 1 0 -im 0; 0 im 0 1; -im 0 1 0]) matrix_rep_raw(g::Unitary) = g.matrix +""" + matrix_rep(g::Gate) + +Convert `g` into its matrix form, applying its argument values and any +exponent it is raised to. +""" function matrix_rep(g::Gate) n = g.pow_exponent iszero(n) && matrix_rep_raw(I(), qubit_count(g)) @@ -311,12 +510,16 @@ function apply_gate!( is_small_target = flipper < CHUNK_SIZE g_00, g_10, g_01, g_11 = g_matrix Threads.@threads for chunk_index = 0:n_chunks-1 - # my_amps is the group of amplitude generators which this `Task` will process + # first_amp is the leading index in the group + # of amplitude generators which this `Task` will process first_amp = n_chunks > 1 ? chunk_index*CHUNK_SIZE : 0 + # amp_block is the total size of the block this `Task` will process amp_block = n_chunks > 1 ? CHUNK_SIZE : n_tasks lower_ix = pad_bit(first_amp, endian_qubit) + 1 higher_ix = lower_ix + flipper for task_amp = 0:amp_block-1 + # this avoids hitting an index pair already "touched" earlier in the block + # if 2 ^ qubit_index is smaller than the block size if is_small_target && div(task_amp, flipper) > 0 && mod(task_amp, flipper) == 0 lower_ix = higher_ix higher_ix = lower_ix + flipper @@ -372,12 +575,13 @@ function apply_gate!( return end +# single controlled single target unitaries like CZ, CV, CPhaseShift function apply_controlled_gate!( g_matrix::Union{SMatrix{2,2,T}, Diagonal{T,SVector{2,T}}, Matrix{T}}, - c_bit::Bool, + c_bit::Bool, # the bit-value to control on (0 or 1) state_vec::AbstractStateVector{T}, - control::Int, - target::Int, + control::Int, # the qubit to control on + target::Int, # the qubit to target ) where {T<:Complex} n_amps, (endian_control, endian_target) = get_amps_and_qubits(state_vec, control, target) @@ -400,13 +604,14 @@ function apply_controlled_gate!( end return end +# single controlled two target unitaries like CSWAP function apply_controlled_gate!( g_matrix::Union{SMatrix{4, 4, T}, Diagonal{T, SVector{4, T}}, Matrix{T}}, - c_bit::Bool, + c_bit::Bool, # the bit-value to control on (0 or 1) state_vec::AbstractStateVector{T}, - control::Int, - target_1::Int, - target_2::Int, + control::Int, # the qubit to control on + target_1::Int, # the first qubit to target + target_2::Int, # the second qubit to target ) where {T<:Complex} n_amps, (endian_control, endian_t1, endian_t2) = get_amps_and_qubits(state_vec, control, target_1, target_2) small_t = min(endian_control, endian_t1, endian_t2) @@ -429,11 +634,11 @@ function apply_controlled_gate!( end return end -# doubly controlled unitaries +# doubly controlled single target unitaries like CCNot function apply_controlled_gate!( g_matrix::Union{SMatrix{2, 2, T}, Diagonal{T, SVector{2, T}}, Matrix{T}}, - c1_bit::Bool, - c2_bit::Bool, + c1_bit::Bool, # the bit-value to control on (0 or 1) for the first control qubit + c2_bit::Bool, # the bit-value to control on (0 or 1) for the second control qubit state_vec::AbstractStateVector{T}, control_1::Int, control_2::Int, @@ -463,6 +668,10 @@ function apply_controlled_gate!( end return end +# these are "intermediate" dispatch methods which turn a `Gate` into the appropriate +# static matrix and dispatch to the appropriate kernel to apply it +# the `:conj` versions are there for *density matrices*, and apply the conjugated +# (but *not* transposed) version of the gate matrix. for (V, f) in ((true, :conj), (false, :identity)) @eval begin apply_gate!(::Val{$V}, gate::Control{G, B}, state_vec::AbstractStateVector{T}, qubits::Int...) where {T<:Complex, G<:Gate, B} = apply_controlled_gate!(Val($V), Val(B), gate, gate.g ^ gate.pow_exponent, state_vec, gate.bitvals, qubits...) @@ -539,6 +748,8 @@ function apply_gate!( apply_gate!(Diagonal(SVector{2^N, ComplexF64}(g_matrix)), state_vec, qubits...) end +# fallback method for arbitrary unitaries with `NQ` targets +# such as a Unitary on 5 qubits function apply_gate!( g_matrix::Union{SMatrix{N, N, T}, Diagonal{T, SVector{N, T}}}, state_vec::AbstractStateVector{T},