-
-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Mat index bug #249
base: main
Are you sure you want to change the base?
Mat index bug #249
Changes from 9 commits
90226ad
808177f
7b3767c
571d349
def0aca
0db92ae
f6e4620
01d7c6e
0d9d5d2
b43853b
bec718c
ab68e9a
f60730a
0c99a38
f2ec1b7
55165ed
6d20899
0ba4bb1
a9154d9
afed1e5
577e2a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,14 +9,16 @@ using OffsetArrays | |
using Test | ||
using Unitful | ||
|
||
# Convert abstract unit range to a ViewAxis with ShapeAxis. | ||
r2v(r::AbstractUnitRange) = ViewAxis(r, ShapedAxis(size(r))) | ||
|
||
## Test setup | ||
c = (a = (a = 1, b = [1.0, 4.4]), b = [0.4, 2, 1, 45]) | ||
nt = (a = 100, b = [4, 1.3], c = c) | ||
nt2 = (a = 5, b = [(a = (a = 20, b = 1), b = 0), (a = (a = 33, b = 1), b = 0)], c = (a = (a = 2, b = [1, 2]), b = [1.0 2.0; 5 6])) | ||
|
||
ax = Axis(a = 1, b = 2:3, c = ViewAxis(4:10, (a = ViewAxis(1:3, (a = 1, b = 2:3)), b = 4:7))) | ||
ax_c = (a = ViewAxis(1:3, (a = 1, b = 2:3)), b = 4:7) | ||
ax = Axis(a = 1, b = r2v(2:3), c = ViewAxis(4:10, (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7)))) | ||
ax_c = (a = ViewAxis(1:3, (a = 1, b = r2v(2:3))), b = r2v(4:7)) | ||
|
||
a = Float64[100, 4, 1.3, 1, 1, 4.4, 0.4, 2, 1, 45] | ||
sq_mat = collect(reshape(1:9, 3, 3)) | ||
|
@@ -36,6 +38,9 @@ caa = ComponentArray(a = ca, b = sq_mat) | |
|
||
_a, _b, _c = Val.((:a, :b, :c)) | ||
|
||
ca3 = ComponentArray(a=1, b=[2, 3, 4, 5], c=reshape(6:11, 3, 2)) | ||
cmat3 = ca3 .* ca3' | ||
cmat3check = (1:11) .* (1:11)' | ||
|
||
## Tests | ||
@testset "Allocations and Inference" begin | ||
|
@@ -132,7 +137,7 @@ end | |
for T in [Int64, Int32, Float64, Float32, ComplexF64, ComplexF32] | ||
@test ComponentArray(a = T[]) == ComponentVector{T}(a = T[]) | ||
@test ComponentArray(a = T[], b = T[]) == ComponentVector{T}(a = T[], b = T[]) | ||
@test ComponentArray(a = T[], b = (;)) == ComponentVector{T}(a = T[], b = T[]) | ||
@test_broken ComponentArray(a = T[], b = (;)) == ComponentVector{T}(a = T[], b = T[]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's important we don't break this one because the behavior is relied upon in some some large simulation projects. This allows There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Gotcha, thanks. I'll take a look at it again when I get a chance. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jonniedie Another polite ping. :-) |
||
@test ComponentArray(a = Any[one(Int32)], b=T[]) == ComponentVector{T}(a = [one(T)], b = T[]) | ||
end | ||
@test ComponentArray(NamedTuple()) == ComponentVector{Any}() | ||
|
@@ -285,6 +290,17 @@ end | |
# OffsetArrays' type piracy without introducing type piracy | ||
# ourselves because `() isa Tuple{N, <:CombinedAxis} where {N}` | ||
# @test reshape(a, axes(ca)...) isa Vector{Float64} | ||
|
||
# Issue #248: Indexing ComponentMatrix with FlatAxis components | ||
@test cmat3[:a, :a] == cmat3check[1, 1] | ||
@test cmat3[:a, :b] == cmat3check[1, 2:5] | ||
@test cmat3[:a, :c] == reshape(cmat3check[1, 6:11], 3, 2) | ||
@test cmat3[:b, :a] == cmat3check[2:5, 1] | ||
@test cmat3[:b, :b] == cmat3check[2:5, 2:5] | ||
@test cmat3[:b, :c] == reshape(cmat3check[2:5, 6:11], 4, 3, 2) | ||
@test cmat3[:c, :a] == reshape(cmat3check[6:11, 1], 3, 2) | ||
@test cmat3[:c, :b] == reshape(cmat3check[6:11, 2:5], 3, 2, 4) | ||
@test cmat3[:c, :c] == reshape(cmat3check[6:11, 6:11], 3, 2, 3, 2) | ||
end | ||
|
||
@testset "Set" begin | ||
|
@@ -327,7 +343,7 @@ end | |
temp = deepcopy(cmat) | ||
@test all((temp[:c, :c][:a, :a] .= 0) .== 0) | ||
|
||
A = ComponentArray(zeros(Int, 4, 4), Axis(x = 1:4), Axis(x = 1:4)) | ||
A = ComponentArray(zeros(Int, 4, 4), Axis(x = r2v(1:4)), Axis(x = r2v(1:4))) | ||
A[1, :] .= 1 | ||
@test A[1, :] == ComponentVector(x = ones(Int, 4)) | ||
end | ||
|
@@ -337,12 +353,17 @@ end | |
ca = ComponentArray(a = 1, b = 2, c = [3, 4], d = (a = [5, 6, 7], b = 8)) | ||
cmat = ca * ca' | ||
|
||
cidx = reshape((1:(2*3)) .+ 2, 2, 3) | ||
ca2 = ComponentArray(a = 1, b = 2, c = cidx, d = (a = [9, 10, 11], b = 12)) | ||
|
||
@testset "ComponentIndex" begin | ||
ax = getaxes(ca)[1] | ||
@test ax[:a] == ax[1] == ComponentArrays.ComponentIndex(1, ComponentArrays.NullAxis()) | ||
@test ax[:c] == ax[3:4] == ComponentArrays.ComponentIndex(3:4, FlatAxis()) | ||
@test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = 1:3, b = 4)) | ||
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = 2:3)) | ||
@test ax[:c] == ax[3:4] == ComponentArrays.ComponentIndex(3:4, ShapedAxis(size(3:4))) | ||
@test ax[:d] == ComponentArrays.ComponentIndex(5:8, Axis(a = r2v(1:3), b = 4)) | ||
@test ax[(:a, :c)] == ax[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3, 4], Axis(a = 1, c = r2v(2:3))) | ||
ax2 = getaxes(ca2)[1] | ||
@test ax2[(:a, :c)] == ax2[[:a, :c]] == ComponentArrays.ComponentIndex([1, 3:8...], Axis(a = 1, c = ViewAxis(2:7, ShapedAxis((2,3))))) | ||
end | ||
|
||
@testset "KeepIndex" begin | ||
|
@@ -353,14 +374,14 @@ end | |
|
||
@test ca[KeepIndex(1:2)] == ComponentArray(a = 1, b = 2) | ||
@test ca[KeepIndex(1:3)] == ComponentArray([1, 2, 3], Axis(a = 1, b = 2)) # Drops c axis | ||
@test ca[KeepIndex(2:5)] == ComponentArray([2, 3, 4, 5], Axis(b = 1, c = 2:3)) | ||
@test ca[KeepIndex(2:5)] == ComponentArray([2, 3, 4, 5], Axis(b = 1, c = r2v(2:3))) | ||
@test ca[KeepIndex(3:end)] == ComponentArray(c = [3, 4], d = (a = [5, 6, 7], b = 8)) | ||
|
||
@test ca[KeepIndex(:)] == ca | ||
|
||
@test cmat[KeepIndex(:a), KeepIndex(:b)] == ComponentArray(fill(2, 1, 1), Axis(a = 1), Axis(b = 1)) | ||
@test cmat[KeepIndex(:), KeepIndex(:c)] == ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = 1:2)) | ||
@test cmat[KeepIndex(2:5), 1:2] == ComponentArray((2:5) * (1:2)', Axis(b = 1, c = 2:3), FlatAxis()) | ||
@test cmat[KeepIndex(:), KeepIndex(:c)] == ComponentArray((1:8) * (3:4)', getaxes(ca)[1], Axis(c = r2v(1:2))) | ||
@test cmat[KeepIndex(2:5), 1:2] == ComponentArray((2:5) * (1:2)', Axis(b = 1, c = r2v(2:3)), ShapedAxis(size(1:2))) | ||
@test cmat[KeepIndex(2), KeepIndex(3)] == ComponentArray(fill(2 * 3, 1, 1), Axis(b = 1), FlatAxis()) | ||
@test cmat[KeepIndex(2), 3] == ComponentArray(b = 2 * 3) | ||
end | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm. I guess I don't understand why we'd need to wrap this in a new type. In the example you gave in the issue,
b
is a vector element, not annx1
matrix. I don't think we should introduce a newShapedAxis1d
type if all vector elements are behaving incorrectly. We should just fix the issue of adjoint/transposition that's broken directly. So I guess specifically, theShapedAxis
needs to be created from theFlatAxis
during the adjoint operation, not beforehand. And for that, we can use the normalShapedAxis
without having to introduce a new type.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jonniedie My bad, just saw this comment.
Hmm... not sure I follow, but fwiw I think everything is fine with the transpose—the issue comes up later on.
The problem with the test case in issue #249, aka
appears to be due to the fact that the shape of a unit range (which is used for vector components like
X[:b]
in the example) isn't understood byComponentArrays.maybe_reshape
. The error I get:In
maybe_reshape
:during
J[:d, :b]
,axs
isi.e. the unit range is converted to a
FlatAxis
, which is then ignored when definingshapes
inmaybe_reshape
.The axes of the transpose
X'
look fine to me:The point of the
Shaped1DAxis
is to have a type that can be a part of theNotShapedOrPartitionedAxis
Union
defined inaxis.jl
.This allows it to skip the problematic and unnecessary
maybe_reshape
for vector components.If there was a way to dispatch on
ShapedAxis{Shape}
whenShape
was a length-1Tuple
then I don't think we'd needShaped1DAxis
, but I don't know of a way to do that.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jonniedie Another polite ping re: this PR. :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jonniedie Another polite ping re: this PR.