Skip to content

Commit

Permalink
generalize to complex case and linear indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
ericphanson authored and odow committed May 5, 2024
1 parent cc20976 commit 288f874
Showing 1 changed file with 49 additions and 50 deletions.
99 changes: 49 additions & 50 deletions src/atoms/IndexAtom.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,60 +42,59 @@ function evaluate(x::IndexAtom)
return output(result)
end

function new_conic_form!(context::Context{T}, x::IndexAtom) where {T}
obj = conic_form!(context, only(AbstractTrees.children(x)))
m = length(x)
n = length(x.children[1])
function _index_real!(
context::Context{T},
obj_size::Tuple,
obj_tape::Union{SparseTape{T},SPARSE_VECTOR{T}},
x::IndexAtom,
) where {T}
if x.inds === nothing
sz = length(x.cols) * length(x.rows)
if !iscomplex(x) # only real case handled here, for now
obj = x.children[1]
obj_tape = conic_form!(context, obj)
linear_indices =
LinearIndices(CartesianIndices(size(obj)))[x.rows, x.cols]
# Here, we are in the real case, so `obj_tape` is either a `SparseTape{T}`, or a `Vector{T}`.
# In the latter case, we can handle it directly
if obj_tape isa Vector{T}
return obj_tape[vec(linear_indices)]
end
# Ok, in this case we have actual work to do. We will construct an auxiliary variable `out`,
# which we will return, and we will constrain it to the values we want.
# This speeds up formulation since we reduce the problem size, and send what we have over to MOI already.
out = Variable(sz)
out_tape = conic_form!(context, out)
for (i, I) in enumerate(linear_indices)
# For each index, we constrain an element of `out` via ScalarAffineFunction to the indexed value.
saf = to_saf(obj_tape, I)
push!(
saf.terms,
MOI.ScalarAffineTerm(T(-1), out_tape.variables[i]),
)
MOI.add_constraint(context.model, saf, MOI.EqualTo(T(0)))
end
return out_tape
else
J = Vector{Int}(undef, sz)
k = 1
num_rows = x.children[1].size[1]
for c in x.cols
for r in x.rows
J[k] = num_rows * (convert(Int, c) - 1) + convert(Int, r)
k += 1
end
end
index_matrix = create_sparse(T, collect(1:sz), J, one(T), m, n)
end
linear_indices =
LinearIndices(CartesianIndices(obj_size))[x.rows, x.cols]
else
index_matrix = create_sparse(
T,
collect(1:length(x.inds)),
collect(x.inds),
one(T),
m,
n,
linear_indices = collect(x.inds)
end
sz = length(linear_indices)

# Here, we are in the real case, so `obj_tape` is either a `SparseTape{T}`, or a `SPARSE_VECTOR`.
# In the latter case, we can handle it directly
if obj_tape isa SPARSE_VECTOR
return obj_tape[vec(linear_indices)]
end
# Ok, in this case we have actual work to do. We will construct an auxiliary variable `out`,
# which we will return, and we will constrain it to the values we want.
# This speeds up formulation since we reduce the problem size, and send what we have over to MOI already.
out = Variable(sz)
out_tape = conic_form!(context, out)
for (i, I) in enumerate(linear_indices)
# For each index, we constrain an element of `out` via ScalarAffineFunction to the indexed value.
saf = to_saf(obj_tape, I)
push!(saf.terms, MOI.ScalarAffineTerm(T(-1), out_tape.variables[i]))
MOI.Utilities.normalize_and_add_constraint(
context.model,
saf,
MOI.EqualTo(T(0)),
)
end
return operate(add_operation, T, sign(x), index_matrix, obj)
return out_tape
end

function new_conic_form!(context::Context{T}, x::IndexAtom) where {T}
input = x.children[1]
if !iscomplex(x) # real case
input_tape = conic_form!(context, input)
return _index_real!(context, size(input), input_tape, x)
else # complex case
input_tape = conic_form!(context, input)
re = _index_real!(context, size(input), real(input_tape), x)
im = _index_real!(context, size(input), imag(input_tape), x)
if re isa SPARSE_VECTOR
@assert im isa SPARSE_VECTOR
return ComplexStructOfVec(re, im)
else
return ComplexTape(re, im)
end
end
end

function Base.getindex(
Expand Down

0 comments on commit 288f874

Please sign in to comment.