Skip to content

Commit

Permalink
Improve inferability of slicedim (#20154)
Browse files Browse the repository at this point in the history
* Add `setindex` for tuples

* Improve inferability of `slicedim`

* Remove `slicedim` specialization for `BitArray`

* Add `slicedim(::BitVector, ::Integer, ::Integer)` to restore previous behavior
  • Loading branch information
martinholters authored Jan 31, 2017
1 parent 686ffce commit 552d5e0
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 30 deletions.
2 changes: 1 addition & 1 deletion base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function slicedim(A::AbstractArray, d::Integer, i)
d >= 1 || throw(ArgumentError("dimension must be ≥ 1"))
nd = ndims(A)
d > nd && (i == 1 || throw_boundserror(A, (ntuple(k->Colon(),nd)..., ntuple(k->1,d-1-nd)..., i)))
A[( n==d ? i : indices(A,n) for n in 1:nd )...]
A[setindex(indices(A), i, d)...]
end

function flipdim(A::AbstractVector, d::Integer)
Expand Down
39 changes: 10 additions & 29 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1278,39 +1278,20 @@ end

## Data movement ##

# TODO some of this could be optimized

function slicedim(A::BitArray, d::Integer, i::Integer)
d_in = size(A)
leading = d_in[1:(d-1)]
d_out = tuple(leading..., d_in[(d+1):end]...)

M = prod(leading)
N = length(A)
stride = M * d_in[d]

B = BitArray(d_out)
index_offset = 1 + (i-1)*M

l = 1

if M == 1
for j = 0:stride:(N-stride)
B[l] = A[j + index_offset]
l += 1
end
# preserve some special behavior
function slicedim(A::BitVector, d::Integer, i::Integer)
d >= 1 || throw(ArgumentError("dimension must be ≥ 1"))
if d > 1
i == 1 || throw_boundserror(A, (:, ntuple(k->1,d-2)..., i))
A[:]
else
for j = 0:stride:(N-stride)
offs = j + index_offset
for k = 0:(M-1)
B[l] = A[offs + k]
l += 1
end
end
fill!(BitArray{0}(), A[i]) # generic slicedim would return A[i] here
end
return B
end


# TODO some of this could be optimized

function flipdim(A::BitArray, d::Integer)
nd = ndims(A)
1 d nd || throw(ArgumentError("dimension $d is not 1 ≤ $d$nd"))
Expand Down
8 changes: 8 additions & 0 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ getindex(t::Tuple, i::Real) = getfield(t, convert(Int, i))
getindex{T}(t::Tuple, r::AbstractArray{T,1}) = tuple([t[ri] for ri in r]...)
getindex(t::Tuple, b::AbstractArray{Bool,1}) = length(b) == length(t) ? getindex(t,find(b)) : throw(BoundsError(t, b))

# returns new tuple; N.B.: becomes no-op if i is out-of-bounds
setindex(x::Tuple, v, i::Integer) = _setindex((), x, v, i::Integer)
function _setindex(y::Tuple, r::Tuple, v, i::Integer)
@_inline_meta
_setindex((y..., ifelse(length(y) + 1 == i, v, first(r))), tail(r), v, i)
end
_setindex(y::Tuple, r::Tuple{}, v, i::Integer) = y

## iterating ##

start(t::Tuple) = 1
Expand Down
1 change: 1 addition & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,7 @@ B = 1.5:5.5
@test_throws ArgumentError slicedim(A,0,1)
@test slicedim(A, 3, 1) == A
@test_throws BoundsError slicedim(A, 3, 2)
@test @inferred(slicedim(A, 1, 2:2)) == collect(2:4:20)'
end
end

Expand Down

0 comments on commit 552d5e0

Please sign in to comment.