Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mat index bug #249

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ComponentArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export fastindices # Deprecated
include("lazyarray.jl")

include("axis.jl")
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, ViewAxis, FlatAxis
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, Shaped1DAxis, ViewAxis, FlatAxis

include("componentarray.jl")
export ComponentArray, ComponentVector, ComponentMatrix, getaxes, getdata, valkeys
Expand Down
19 changes: 15 additions & 4 deletions src/axis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,20 @@ example)
"""
struct ShapedAxis{Shape} <: AbstractAxis{nothing} end
@inline ShapedAxis(Shape) = ShapedAxis{Shape}()
ShapedAxis(::Tuple{<:Int}) = FlatAxis()
# ShapedAxis(::Tuple{<:Int}) = FlatAxis()

struct Shaped1DAxis{Shape} <: AbstractAxis{nothing} end
ShapedAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()
Shaped1DAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()

const Shape = ShapedAxis

unshape(ax) = ax
unshape(ax::ShapedAxis) = Axis(indexmap(ax))
unshape(ax::Shaped1DAxis) = Axis(indexmap(ax))

Base.size(::ShapedAxis{Shape}) where {Shape} = Shape
Base.size(::Shaped1DAxis{Shape}) where {Shape} = Shape



Expand Down Expand Up @@ -133,9 +139,9 @@ Axis(::Number) = NullAxis()
Axis(::NamedTuple{()}) = FlatAxis()
Axis(x) = FlatAxis()

const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, Shaped1DAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}, Shaped1DAxis} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, Shaped1DAxis} where {IdxMap}


Base.merge(axs::Vararg{Axis}) = Axis(merge(indexmap.(axs)...))
Expand All @@ -149,6 +155,10 @@ reindex(i, offset) = i .+ offset
reindex(ax::FlatAxis, _) = ax
reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax)))
reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax))
function reindex(ax::ViewAxis{OldInds,IdxMap,Ax}, offset) where {OldInds,IdxMap,Ax<:Shaped1DAxis}
NewInds = viewindex(ax) .+ offset
return ViewAxis(NewInds, Ax())
end

# Get AbstractAxis index
@inline Base.getindex(::AbstractAxis, idx) = ComponentIndex(idx)
Expand All @@ -175,6 +185,7 @@ end

_maybe_view_axis(inds, ax::Axis) = ViewAxis(inds, ax)
_maybe_view_axis(inds, ::NullAxis) = inds[1]
_maybe_view_axis(inds, ax::Union{ShapedAxis,Shaped1DAxis}) = ViewAxis(inds, ax)

struct CombinedAxis{C,A} <: AbstractUnitRange{Int}
component_axis::C
Expand Down
3 changes: 2 additions & 1 deletion src/compat/static_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ end

_maybe_SArray(x::SubArray, ::Val{N}, ::FlatAxis) where {N} = SVector{N}(x)
_maybe_SArray(x::Base.ReshapedArray, ::Val, ::ShapedAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, ::Val, ::Shaped1DAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, vals...) = x

@generated function static_getproperty(ca::ComponentVector, ::Val{s}) where {s}
Expand Down Expand Up @@ -32,4 +33,4 @@ macro static_unpack(expr)
push!(out.args, :($esc_name = static_getproperty($parent_var_name, $(Val(name)))))
end
return out
end
end
5 changes: 3 additions & 2 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ ComponentArray{T}(::UndefInitializer, ax::Axes) where {T,Axes<:Tuple} =

# Entry from data array and AbstractAxis types dispatches to correct shapes and partitions
# then packs up axes into a tuple for inner constructor
ComponentArray(data, ::FlatAxis...) = data
# ComponentArray(data, ::FlatAxis...) = data
ComponentArray(data, ::Union{FlatAxis,Shaped1DAxis}...) = data
ComponentArray(data, ax::NotShapedOrPartitionedAxis...) = ComponentArray(data, ax)
ComponentArray(data, ax::NotPartitionedAxis...) = ComponentArray(maybe_reshape(data, ax...), unshape.(ax)...)
function ComponentArray(data, ax::AbstractAxis...)
Expand Down Expand Up @@ -232,7 +233,7 @@ end
# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
shapes = filter_by_type(ShapedAxis, axs...) .|> size
shapes = filter_by_type(Union{ShapedAxis,Shaped1DAxis}, axs...) .|> size
shapes = reduce((tup, s) -> (tup..., s...), shapes)
return reshape(data, shapes)
end
Expand Down
4 changes: 3 additions & 1 deletion src/componentindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ struct ComponentIndex{Idx, Ax<:AbstractAxis}
ax::Ax
end
ComponentIndex(idx) = ComponentIndex(idx, FlatAxis())
ComponentIndex(idx::CartesianIndex) = ComponentIndex(idx, ShapedAxis((1,)))
ComponentIndex(idx::AbstractArray{<:Integer}) = ComponentIndex(idx, ShapedAxis(size(idx)))
ComponentIndex(idx::Int) = ComponentIndex(idx, NullAxis())
ComponentIndex(vax::ViewAxis{Inds,IdxMap,Ax}) where {Inds,IdxMap,Ax} = ComponentIndex(Inds, vax.ax)

Expand Down Expand Up @@ -44,4 +46,4 @@ function _getindex_keep(ax::AbstractAxis, sym::Symbol)
end
new_ax = reindex(new_ax, -first(idx)+1)
return ComponentIndex(idx, new_ax)
end
end
2 changes: 2 additions & 0 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Base.show(io::IO, ::PartitionedAxis{PartSz, IdxMap, Ax}) where {PartSz, IdxMap,

Base.show(io::IO, ::ShapedAxis{Shape}) where {Shape} =
print(io, "ShapedAxis($Shape)")
Base.show(io::IO, ::Shaped1DAxis{Shape}) where {Shape} =
print(io, "Shaped1DAxis($Shape)")

Base.show(io::IO, ::MIME"text/plain", ::ViewAxis{Inds, IdxMap, Ax}) where {Inds, IdxMap, Ax} =
print(io, "ViewAxis($Inds, $(Ax()))")
Expand Down
41 changes: 31 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@ using OffsetArrays
using Test
using Unitful

# Convert abstract unit range to a ViewAxis with ShapeAxis.
r2v(r::AbstractUnitRange) = ViewAxis(r, ShapedAxis(size(r)))

## Test setup
c = (a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45])
nt = (a = 100, b = [4, 1.3], c = c)
nt2 = (a = 5, b = [(a = (a = 20, b = 1), b = 0), (a = (a = 33, b = 1), b = 0)], c = (a = (a = 2, b = [1, 2]), b = [1.0 2.0; 5 6]))

ax = Axis(a = 1, b = 2:3, c = ViewAxis(4:10, (a = ViewAxis(1:3, (a = 1, b = 2:3)), b = 4:7)))
ax_c = (a = ViewAxis(1:3, (a = 1, b = 2:3)), b = 4:7)
ax = Axis(a = 1, b = r2v(2:3), c = ViewAxis(4:10, (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7))))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. I guess I don't understand why we'd need to wrap this in a new type. In the example you gave in the issue, b is a vector element, not an nx1 matrix. I don't think we should introduce a new ShapedAxis1d type if all vector elements are behaving incorrectly. We should just fix the issue of adjoint/transposition that's broken directly. So I guess specifically, the ShapedAxis needs to be created from the FlatAxis during the adjoint operation, not beforehand. And for that, we can use the normal ShapedAxis without having to introduce a new type.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie My bad, just saw this comment.
Hmm... not sure I follow, but fwiw I think everything is fine with the transpose—the issue comes up later on.
The problem with the test case in issue #249, aka

ni = 4
nj = 2
nk = 3

X = ComponentVector(a=0.0, b=zeros(Float64, nk), c=zeros(Float64, ni, nj))
Y = ComponentVector(d=zeros(Float64, ni, nj))
J = Y.*X'

appears to be due to the fact that the shape of a unit range (which is used for vector components like X[:b] in the example) isn't understood by ComponentArrays.maybe_reshape. The error I get:

julia> J[:d, :b]
ERROR: DimensionMismatch: new dimensions (4, 2) must be consistent with array size 24
Stacktrace:
 [1] (::Base.var"#throw_dmrsa#328")(dims::Tuple{Int64, Int64}, len::Int64)
   @ Base ./reshapedarray.jl:41
 [2] reshape
   @ ./reshapedarray.jl:45 [inlined]
 [3] maybe_reshape
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/componentarray.jl:237 [inlined]
 [4] ComponentArray
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/componentarray.jl:52 [inlined]
 [5] macro expansion
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:0 [inlined]
 [6] _getindex(::typeof(getindex), ::ComponentMatrix{Float64, Matrix{Float64}, Tuple{Axis{…}, Axis{…}}}, ::Val{:d}, ::Val{:b})
   @ ComponentArrays ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:119
 [7] getindex
   @ ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:103 [inlined]
 [8] getindex(::ComponentMatrix{Float64, Matrix{Float64}, Tuple{Axis{…}, Axis{…}}}, ::Symbol, ::Symbol)
   @ ComponentArrays ~/projects/componentarrays_bugs/dev/ComponentArrays/src/array_interface.jl:102
 [9] top-level scope
   @ REPL[16]:1
Some type information was truncated. Use `show(err)` to see complete types.

shell> 

In maybe_reshape:

# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
    shapes = filter_by_type(ShapedAxis, axs...) .|> size
    shapes = reduce((tup, s) -> (tup..., s...), shapes)
    return reshape(data, shapes)
end

during J[:d, :b], axs is

axs = (ShapedAxis((4, 2)), FlatAxis())

i.e. the unit range is converted to a FlatAxis, which is then ignored when defining shapes in maybe_reshape.

The axes of the transpose X' look fine to me:

julia> getaxes(X')
(FlatAxis(), Axis(a = 1, b = 2:4, c = ViewAxis(5:12, ShapedAxis((4, 2)))))

julia> 

The point of the Shaped1DAxis is to have a type that can be a part of the NotShapedOrPartitionedAxis Union defined in axis.jl.
This allows it to skip the problematic and unnecessary maybe_reshape for vector components.
If there was a way to dispatch on ShapedAxis{Shape} when Shape was a length-1 Tuple then I don't think we'd need Shaped1DAxis, but I don't know of a way to do that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie Another polite ping re: this PR. :-)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie Another polite ping re: this PR.

ax_c = (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7))

a = Float64[100, 4, 1.3, 1, 1, 4.4, 0.4, 2, 1, 45]
sq_mat = collect(reshape(1:9, 3, 3))
Expand All @@ -36,6 +38,9 @@ caa = ComponentArray(a = ca, b = sq_mat)

_a, _b, _c = Val.((:a, :b, :c))

ca3 = ComponentArray(a=1, b=[2, 3, 4, 5], c=reshape(6:11, 3, 2))
cmat3 = ca3 .* ca3'
cmat3check = (1:11) .* (1:11)'

## Tests
@testset "Allocations and Inference" begin
Expand Down Expand Up @@ -132,7 +137,7 @@ end
for T in [Int64, Int32, Float64, Float32, ComplexF64, ComplexF32]
@test ComponentArray(a = T[]) == ComponentVector{T}(a = T[])
@test ComponentArray(a = T[], b = T[]) == ComponentVector{T}(a = T[], b = T[])
@test ComponentArray(a = T[], b = (;)) == ComponentVector{T}(a = T[], b = T[])
@test_broken ComponentArray(a = T[], b = (;)) == ComponentVector{T}(a = T[], b = T[])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's important we don't break this one because the behavior is relied upon in some some large simulation projects. This allows ComponentArrays that match a nested model structure to be initialized when one of the internal models doesn't have integrable state.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha, thanks. I'll take a look at it again when I get a chance.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonniedie Another polite ping. :-)

@test ComponentArray(a = Any[one(Int32)], b=T[]) == ComponentVector{T}(a = [one(T)], b = T[])
end
@test ComponentArray(NamedTuple()) == ComponentVector{Any}()
Expand Down Expand Up @@ -285,6 +290,17 @@ end
# OffsetArrays' type piracy without introducing type piracy
# ourselves because `() isa Tuple{N, <:CombinedAxis} where {N}`
# @test reshape(a, axes(ca)...) isa Vector{Float64}

# Issue #248: Indexing ComponentMatrix with FlatAxis components
@test cmat3[:a, :a] == cmat3check[1, 1]
@test cmat3[:a, :b] == cmat3check[1, 2:5]
@test cmat3[:a, :c] == reshape(cmat3check[1, 6:11], 3, 2)
@test cmat3[:b, :a] == cmat3check[2:5, 1]
@test cmat3[:b, :b] == cmat3check[2:5, 2:5]
@test cmat3[:b, :c] == reshape(cmat3check[2:5, 6:11], 4, 3, 2)
@test cmat3[:c, :a] == reshape(cmat3check[6:11, 1], 3, 2)
@test cmat3[:c, :b] == reshape(cmat3check[6:11, 2:5], 3, 2, 4)
@test cmat3[:c, :c] == reshape(cmat3check[6:11, 6:11], 3, 2, 3, 2)
end

@testset "Set" begin
Expand Down Expand Up @@ -327,7 +343,7 @@ end
temp = deepcopy(cmat)
@test all((temp[:c, :c][:a, :a] .= 0) .== 0)

A = ComponentArray(zeros(Int, 4, 4), Axis(x = 1:4), Axis(x = 1:4))
A = ComponentArray(zeros(Int, 4, 4), Axis(x = r2v(1:4)), Axis(x = r2v(1:4)))
A[1, :] .= 1
@test A[1, :] == ComponentVector(x = ones(Int, 4))
end
Expand All @@ -337,12 +353,17 @@ end
ca = ComponentArray(a = 1, b = 2, c = [3, 4], d = (a = [5, 6, 7], b = 8))
cmat = ca * ca'

cidx = reshape((1:(2*3)) .+ 2, 2, 3)
ca2 = ComponentArray(a = 1, b = 2, c = cidx, d = (a = [9, 10, 11], b = 12))

@testset "ComponentIndex" begin
ax = getaxes(ca)[1]
@test ax[:a] == ax[1] == ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis())
@test ax[:c] == ax[3:4] == ComponentArrays.ComponentIndex(3:4, FlatAxis())
@test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = 1:3, b = 4))
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = 2:3))
@test ax[:c] == ax[3:4] == ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4)))
@test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4))
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3)))
ax2 = getaxes(ca2)[1]
@test ax2[(:a, :c)] == ax2[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3)))))
end

@testset "KeepIndex" begin
Expand All @@ -353,14 +374,14 @@ end

@test ca[KeepIndex(1:2)] == ComponentArray(a = 1, b = 2)
@test ca[KeepIndex(1:3)] == ComponentArray([1, 2, 3], Axis(a = 1, b = 2)) # Drops c axis
@test ca[KeepIndex(2:5)] == ComponentArray([2, 3, 4, 5], Axis(b = 1, c = 2:3))
@test ca[KeepIndex(2:5)] == ComponentArray([2, 3, 4, 5], Axis(b = 1, c = r2v(2:3)))
@test ca[KeepIndex(3:end)] == ComponentArray(c = [3, 4], d = (a = [5, 6, 7], b = 8))

@test ca[KeepIndex(:)] == ca

@test cmat[KeepIndex(:a), KeepIndex(:b)] == ComponentArray(fill(2, 1, 1), Axis(a = 1), Axis(b = 1))
@test cmat[KeepIndex(:), KeepIndex(:c)] == ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = 1:2))
@test cmat[KeepIndex(2:5), 1:2] == ComponentArray((2:5) * (1:2)', Axis(b = 1, c = 2:3), FlatAxis())
@test cmat[KeepIndex(:), KeepIndex(:c)] == ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = r2v(1:2)))
@test cmat[KeepIndex(2:5), 1:2] == ComponentArray((2:5) * (1:2)', Axis(b = 1, c = r2v(2:3)), ShapedAxis(size(1:2)))
@test cmat[KeepIndex(2), KeepIndex(3)] == ComponentArray(fill(2 * 3, 1, 1), Axis(b = 1), FlatAxis())
@test cmat[KeepIndex(2), 3] == ComponentArray(b = 2 * 3)
end
Expand Down
Loading