diff --git a/src/atoms/IndexAtom.jl b/src/atoms/IndexAtom.jl index e3b3d76ac..6c99f5679 100644 --- a/src/atoms/IndexAtom.jl +++ b/src/atoms/IndexAtom.jl @@ -114,10 +114,7 @@ function Base.getindex(x::AbstractExpr, I::AbstractVector{Bool}) end # All rows and columns -function Base.getindex(x::AbstractExpr, ::Colon, ::Colon) - rows, cols = size(x) - return getindex(x, 1:rows, 1:cols) -end +Base.getindex(x::AbstractExpr, ::Colon, ::Colon) = x # All rows for this column(s) function Base.getindex(x::AbstractExpr, ::Colon, col) diff --git a/src/atoms/VcatAtom.jl b/src/atoms/VcatAtom.jl index 3b6d34df3..3a1c20a15 100644 --- a/src/atoms/VcatAtom.jl +++ b/src/atoms/VcatAtom.jl @@ -59,3 +59,71 @@ function Base.vcat(args::Union{AbstractExpr,Value}...) end return VcatAtom(args...) end + +function Base.getindex( + x::VcatAtom, + rows::AbstractVector{<:Real}, + cols::AbstractVector{<:Real}, +) + idx = 0 + rows = collect(rows) # make a mutable copy + keep_children = () + for c in x.children + # here are the row indices into `x` that point to `c` + I = idx .+ (1:size(c, 1)) + if issubset(rows, I) + # if all the row indices we want are in this one child, we can early exit + if rows == I && cols == 1:size(c, 2) + return c + else + return c[rows.-idx, cols] + end + elseif !isdisjoint(rows, I) + # we have some but not all rows in this child, so keep it + keep_children = (keep_children..., c) + idx += size(c, 1) + else # we can drop this child! + # let's update `rows` to account for the removal + l = last(I) + for i in eachindex(rows) + if rows[i] >= l + rows[i] -= length(I) + end + end + end + end + # If we are here, the indices span multiple children. + # We can't necessarily index each separately, since they may be out of order. + # So we will defer to an `IndexAtom` on the remaining children + remaining = VcatAtom(keep_children...) + return IndexAtom(remaining, rows, cols) +end + +# linear indexing: very similar to row-indexing above, but with linear indices +function Base.getindex(x::VcatAtom, inds::AbstractVector{<:Real}) + idx = 0 + inds = collect(inds) + keep_children = () + for c in x.children + I = idx .+ (1:length(c)) + if issubset(inds, I) + if inds == I + return c + else + return c[inds.-idx] + end + elseif !isdisjoint(inds, I) + keep_children = (keep_children..., c) + idx += length(c) + else + l = last(I) + for i in eachindex(inds) + if inds[i] >= l + inds[i] -= length(I) + end + end + end + end + remaining = VcatAtom(keep_children...) + return IndexAtom(remaining, inds) +end diff --git a/test/test_atoms.jl b/test/test_atoms.jl index dd2ab90be..2c373c74f 100644 --- a/test/test_atoms.jl +++ b/test/test_atoms.jl @@ -477,9 +477,8 @@ function test_IndexAtom() _test_atom(target) do context return Variable(2)[:, 1] end - _test_atom(target) do context - return Variable(2)[:, :] - end + x = Variable(2) + @test x[:, :] === x target = """ variables: x1, x2, x3 minobjective: [1.0 * x1, 1.0 * x3] @@ -814,6 +813,42 @@ function test_VcatAtom() return end +function test_VcatAtom_getindex() + x = Variable() + sq = square(x) + for v in [vcat(x, sq, -sq), vcat(vcat(x, sq), -sq)] + @test isequal(v[1], x) + @test isequal(v[2], sq) + @test isequal(v[:, 1], v[1:3, 1]) + @test isequal(v[1:3, :], v[1:3, 1:1]) + @test v[2:-1:1].children[1] isa Convex.VcatAtom + @test v[2:-1:1].children[1].children == (x, sq) + @test vexity(v[1]) isa Convex.AffineVexity + @test vexity(v[2]) isa Convex.ConvexVexity + @test vexity(v[3]) isa Convex.ConcaveVexity + @test vexity(v[1:2]) isa Convex.ConvexVexity + @test vexity(v[:, 1]) isa Convex.NotDcp + end + x = Variable(2, 2) + sq = square(x) + for v in [vcat(x, sq, -sq), vcat(vcat(x, sq), -sq)] + @test isequal(v[:, :], v) + @test isequal(v[1:2, 1:2], x) + @test v[1, 1:2] isa Convex.IndexAtom + @test v[1:3, 1:2] isa Convex.IndexAtom + @test v[1:3, 1:2].children[1] isa Convex.VcatAtom + @test isequal(v[3:4, 1:2], sq) + @test v[3:-1:1, 1].children[1] isa Convex.VcatAtom + @test v[4:-1:1, :].children[1].children == (x, sq) + @test vexity(v[1]) isa Convex.AffineVexity + @test vexity(v[3:4, 1:2]) isa Convex.ConvexVexity + @test vexity(v[5:6, 1]) isa Convex.ConcaveVexity + @test vexity(v[5:6, 1:2]) isa Convex.ConcaveVexity + @test vexity(v[1:3, 1]) isa Convex.ConvexVexity + end + return +end + ### exp_+_sdp_cone/LogDetAtom function test_LogDetAtom()