Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add insertdims method which is inverse to dropdims #45793

Merged
merged 29 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
9903e42
Add insertdims
roflmaostc Jun 23, 2022
0f93207
Update abstractarraymath.jl
roflmaostc Jun 23, 2022
bd6e45e
Fix some whitespaces and doctest [skip ci]
roflmaostc Jun 24, 2022
f4087a2
Handle merge [skip ci]
roflmaostc Jun 24, 2022
50ae007
Merge branch 'JuliaLang:master' into master
roflmaostc May 30, 2024
4f53621
Fix bug in _foldoneto call
roflmaostc May 30, 2024
1196b8f
Add test for multiple singleton dimensions at one dim
roflmaostc May 31, 2024
6f986dc
Update base/abstractarraymath.jl
roflmaostc Jul 25, 2024
16e0203
Update base/abstractarraymath.jl
roflmaostc Jul 25, 2024
cf23193
Update base/abstractarraymath.jl
roflmaostc Jul 25, 2024
6057e22
Add docs
roflmaostc Jul 25, 2024
84d0693
Update base/abstractarraymath.jl
roflmaostc Jul 25, 2024
aa51242
Merge branch 'JuliaLang:master' into master
roflmaostc Jul 25, 2024
8648d84
Add news
roflmaostc Jul 25, 2024
080c83c
Fix merge mistake
roflmaostc Jul 26, 2024
ec383d9
Rebase [skip ci]
roflmaostc Jul 27, 2024
66d0561
Update test/arrayops.jl [skip ci]
roflmaostc Jul 27, 2024
bdb3258
Update base/abstractarraymath.jl
roflmaostc Jul 27, 2024
0ddee55
Update comment in code [skip ci]
roflmaostc Jul 27, 2024
76b31f7
Remove parts in docstring [skip ci]
roflmaostc Jul 27, 2024
490b103
Adapt docstring
roflmaostc Jul 27, 2024
98283db
Remove trailing whitespace
roflmaostc Jul 27, 2024
c321d51
Remove trailing whitespace in test [skip ci]
roflmaostc Jul 27, 2024
2a9e3ac
Update base/abstractarraymath.jl [skip ci]
roflmaostc Jul 27, 2024
d7cb35c
Rewrite doc [skip ci]
roflmaostc Jul 27, 2024
514b4fc
Update test/arrayops.jl
roflmaostc Aug 1, 2024
4f07e18
Update test/arrayops.jl
roflmaostc Aug 1, 2024
cfa1465
Update base/abstractarraymath.jl
roflmaostc Aug 1, 2024
9f174b7
Update base/abstractarraymath.jl
roflmaostc Aug 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 39 additions & 56 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,80 +95,63 @@ _dropdims(A::AbstractArray, dim::Integer) = _dropdims(A, (Int(dim),))


"""
insertdims(A; dims)
inserdims(A; dims)
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved

Return an array with the same data as `A`, but with singleton dimensions specified by
`dims` inserted.
The dimensions of `A` and `dims` must be contiguous.
If dimensions occur multiple times in `dims`, several singleton dimensions are inserted.
Inverse of [`dropdims`](@ref); return an array with new singleton dimensions
at every dimension in `dims`.

Repeated dimensions are forbidden and the largest entry in `dims` must be
smaller than the dimensionality of the array and the length of `dims` together.

The result shares the same underlying data as `A`, such that the
result is mutable if and only if `A` is mutable, and setting elements of one
alters the values of the other.

See also: [`reshape`](@ref), [`dropdims`](@ref), [`vec`](@ref).

See also: [`dropdims`](@ref), [`reshape`](@ref), [`vec`](@ref).
# Examples
```jldoctest
julia> a = [1 2; 3 4]
2×2 Matrix{Int64}:
1 2
3 4

julia> b = insertdims(a, dims=(1,1))
1×1×2×2 Array{Int64, 4}:
[:, :, 1, 1] =
1

[:, :, 2, 1] =
3

[:, :, 1, 2] =
2

[:, :, 2, 2] =
4

julia> b = insertdims(a, dims=(1,2))
1×2×1×2 Array{Int64, 4}:
[:, :, 1, 1] =
1 3

[:, :, 1, 2] =
2 4
julia> x = [1 2 3; 4 5 6]
2×3 Matrix{Int64}:
1 2 3
4 5 6

julia> b = insertdims(a, dims=(1,3))
1×2×2×1 Array{Int64, 4}:
[:, :, 1, 1] =
1 3
julia> insertdims(x, dims=3)
2×3×1 Array{Int64, 3}:
[:, :, 1] =
1 2 3
4 5 6

[:, :, 2, 1] =
2 4
julia> insertdims(x, dims=(1,2,5)) == reshape(x, 1, 1, 2, 3, 1)
true

julia> b[1,1,1,1] = 5; a
2 Matrix{Int64}:
5 2
3 4
julia> dropdims(insertdims(x, dims=(1,2,5)), dims=(1,2,5))
3 Matrix{Int64}:
1 2 3
4 5 6
```
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved

!!! compat "Julia 1.12"
Requires Julia 1.12 or later.
"""
insertdims(A; dims) = _insertdims(A, dims)
function _insertdims(A::AbstractArray{T, N}, dims::Tuple{Vararg{Int64, M}}) where {T, N, M}
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
maximum(dims) ≤ ndims(A)+1 || throw(ArgumentError("The largest entry in dims must be ≤ ndims(A) + 1."))
1 ≤ minimum(dims) || throw(ArgumentError("The smallest entry in dims must be ≥ 1."))
issorted(dims) || throw(ArgumentError("dims=$(dims) are not sorted"))

# n is the amount of the dims already inserted
ax_n = Base._foldoneto(((ds, n, dims), _) ->
dims != Tuple(()) && n == first(dims) ?
((ds..., Base.OneTo(1)), n, Base.tail(dims)) :
((ds..., axes(A,n)), n+1, dims),
((), 1, dims), Val(ndims(A) + length(dims)))

# we need only the new shape and not n
reshape(A, ax_n[1])
for i in eachindex(dims)
1 ≤ dims[i] || throw(ArgumentError("the smallest entry in dims must be ≥ 1."))
dims[i] ≤ N+M || throw(ArgumentError("the largest entry in dims must be not larger than the dimension of the array and the length of dims added"))
LilithHafner marked this conversation as resolved.
Show resolved Hide resolved
for j = 1:i-1
dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique"))
end
end

# acc is a tuple, where the first entry is the final shape
# the second entry off acc is a counter for the axes of A
inds= Base._foldoneto((acc, i) ->
i ∈ dims
? ((acc[1]..., Base.OneTo(1)), acc[2])
: ((acc[1]..., axes(A, acc[2])), acc[2] + 1),
((), 1), Val(N+M))
LilithHafner marked this conversation as resolved.
Show resolved Hide resolved
new_shape = inds[1]
return reshape(A, new_shape)
end
_insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),))

Expand Down
24 changes: 19 additions & 5 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,17 +311,31 @@ end

a = rand(8, 7)
@test @inferred(insertdims(a, dims=1)) == @inferred(insertdims(a, dims=(1,))) == reshape(a, (1, 8, 7))
@test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 1, 2, 3))) == reshape(a, (1, 1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 2, 2, 3))) == reshape(a, (1, 8, 1, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 2, 3, 3))) == reshape(a, (1, 8, 1, 7, 1, 1))
@test @inferred(insertdims(a, dims=3)) == @inferred(insertdims(a, dims=(3,))) == reshape(a, (8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3))) == reshape(a, (1, 8, 1, 7))
@test @inferred(insertdims(a, dims=(1, 2, 3))) == reshape(a, (1, 1, 1, 8, 7))
@test @inferred(insertdims(a, dims=(1, 4))) == reshape(a, (1, 8, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 5))) == reshape(a, (1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 2, 4, 6))) == reshape(a, (1, 1, 8, 1, 7, 1))
@test @inferred(insertdims(a, dims=(1, 3, 4, 6))) == reshape(a, (1, 8, 1, 1, 7, 1))
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
@test @inferred(insertdims(a, dims=(1, 3, 5, 6))) == reshape(a, (1, 8, 1, 7, 1, 1))

@test_throws ArgumentError insertdims(a, dims=(1, 1, 2, 3))
@test_throws ArgumentError insertdims(a, dims=(1, 2, 2, 3))
@test_throws ArgumentError insertdims(a, dims=(1, 2, 3, 3))
@test_throws UndefKeywordError insertdims(a)
@test_throws ArgumentError insertdims(a, dims=0)
@test_throws ArgumentError insertdims(a, dims=(1, 2, 1))
@test_throws ArgumentError insertdims(a, dims=4)
@test_throws ArgumentError insertdims(a, dims=6)

# insertdims and dropdims are inverses
b = rand(1,1,1,5,1,1,7)
for dims in [1, (1,), 2, (2,), 3, (3,), (1,3), (1,2,3), (1,2), (1,3,5), (1,2,5,6), (1,3,5,6), (1,3,5,6)]
roflmaostc marked this conversation as resolved.
Show resolved Hide resolved
@test dropdims(insertdims(a; dims); dims) == a
@test insertdims(dropdims(b; dims); dims) == b
end

sz = (5,8,7)
A = reshape(1:prod(sz),sz...)
@test A[2:6] == [2:6;]
Expand Down