From b517b32abb525950c95d76de1c0672ed7db1807d Mon Sep 17 00:00:00 2001 From: Matt Bauman Date: Fri, 26 Jun 2020 04:21:35 -0500 Subject: [PATCH] Fix pointer to no longer assume contiguity (#36405) * Fix pointer to no longer assume contiguity --- base/abstractarray.jl | 23 +++-- base/permuteddimsarray.jl | 1 + base/subarray.jl | 17 +--- stdlib/LinearAlgebra/src/adjtrans.jl | 3 + test/abstractarray.jl | 141 +++++++++++++++++++++++++++ 5 files changed, 163 insertions(+), 22 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 1d6dc4c36a0c0c..1fdf83cbc504cd 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1007,7 +1007,14 @@ end pointer(x::AbstractArray{T}) where {T} = unsafe_convert(Ptr{T}, x) function pointer(x::AbstractArray{T}, i::Integer) where T @_inline_meta - unsafe_convert(Ptr{T}, x) + (i - first(LinearIndices(x)))*elsize(x) + unsafe_convert(Ptr{T}, x) + _memory_offset(x, i) +end + +# The distance from pointer(x) to the element at x[I...] in bytes +_memory_offset(x::DenseArray, I...) = (_to_linear_index(x, I...) - first(LinearIndices(x)))*elsize(x) +function _memory_offset(x::AbstractArray, I...) + J = _to_subscript_indices(x, I...) + return sum(map((i, s, o)->s*(i-o), J, strides(x), Tuple(first(CartesianIndices(x)))))*elsize(x) end ## Approach: @@ -1078,10 +1085,10 @@ function _getindex(::IndexLinear, A::AbstractArray, I::Vararg{Int,M}) where M @inbounds r = getindex(A, _to_linear_index(A, I...)) r end -_to_linear_index(A::AbstractArray, i::Int) = i -_to_linear_index(A::AbstractVector, i::Int, I::Int...) = i +_to_linear_index(A::AbstractArray, i::Integer) = i +_to_linear_index(A::AbstractVector, i::Integer, I::Integer...) = i _to_linear_index(A::AbstractArray) = 1 -_to_linear_index(A::AbstractArray, I::Int...) = (@_inline_meta; _sub2ind(A, I...)) +_to_linear_index(A::AbstractArray, I::Integer...) = (@_inline_meta; _sub2ind(A, I...)) ## IndexCartesian Scalar indexing: Canonical method is full dimensionality of Ints function _getindex(::IndexCartesian, A::AbstractArray, I::Vararg{Int,M}) where M @@ -1094,12 +1101,12 @@ function _getindex(::IndexCartesian, A::AbstractArray{T,N}, I::Vararg{Int, N}) w @_propagate_inbounds_meta getindex(A, I...) end -_to_subscript_indices(A::AbstractArray, i::Int) = (@_inline_meta; _unsafe_ind2sub(A, i)) +_to_subscript_indices(A::AbstractArray, i::Integer) = (@_inline_meta; _unsafe_ind2sub(A, i)) _to_subscript_indices(A::AbstractArray{T,N}) where {T,N} = (@_inline_meta; fill_to_length((), 1, Val(N))) _to_subscript_indices(A::AbstractArray{T,0}) where {T} = () -_to_subscript_indices(A::AbstractArray{T,0}, i::Int) where {T} = () -_to_subscript_indices(A::AbstractArray{T,0}, I::Int...) where {T} = () -function _to_subscript_indices(A::AbstractArray{T,N}, I::Int...) where {T,N} +_to_subscript_indices(A::AbstractArray{T,0}, i::Integer) where {T} = () +_to_subscript_indices(A::AbstractArray{T,0}, I::Integer...) where {T} = () +function _to_subscript_indices(A::AbstractArray{T,N}, I::Integer...) where {T,N} @_inline_meta J, Jrem = IteratorsMD.split(I, Val(N)) _to_subscript_indices(A, J, Jrem) diff --git a/base/permuteddimsarray.jl b/base/permuteddimsarray.jl index 348d4ad449a538..7ac87df0ad6870 100644 --- a/base/permuteddimsarray.jl +++ b/base/permuteddimsarray.jl @@ -64,6 +64,7 @@ function Base.strides(A::PermutedDimsArray{T,N,perm}) where {T,N,perm} s = strides(parent(A)) ntuple(d->s[perm[d]], Val(N)) end +Base.elsize(::Type{<:PermutedDimsArray{<:Any, <:Any, <:Any, <:Any, P}}) where {P} = Base.elsize(P) @inline function Base.getindex(A::PermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}) where {T,N,perm,iperm} @boundscheck checkbounds(A, I...) diff --git a/base/subarray.jl b/base/subarray.jl index ba891183924bbb..a4cd5920157ad1 100644 --- a/base/subarray.jl +++ b/base/subarray.jl @@ -398,23 +398,12 @@ find_extended_inds(::ScalarIndex, I...) = (@_inline_meta; find_extended_inds(I.. find_extended_inds(i1, I...) = (@_inline_meta; (i1, find_extended_inds(I...)...)) find_extended_inds() = () -unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {T,N,P} = - unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T) +function unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {T,N,P} + return unsafe_convert(Ptr{T}, V.parent) + _memory_offset(V.parent, map(first, V.indices)...) +end pointer(V::FastSubArray, i::Int) = pointer(V.parent, V.offset1 + V.stride1*i) pointer(V::FastContiguousSubArray, i::Int) = pointer(V.parent, V.offset1 + i) -pointer(V::SubArray, i::Int) = _pointer(V, i) -_pointer(V::SubArray{<:Any,1}, i::Int) = pointer(V, (i,)) -_pointer(V::SubArray, i::Int) = pointer(V, Base._ind2sub(axes(V), i)) - -function pointer(V::SubArray{T,N,<:Array,<:Tuple{Vararg{RangeIndex}}}, is::Tuple{Vararg{Int}}) where {T,N} - index = first_index(V) - strds = strides(V) - for d = 1:length(is) - index += (is[d]-1)*strds[d] - end - return pointer(V.parent, index) -end # indices are taken from the range/vector # Since bounds-checking is performance-critical and uses diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index e3952d02618364..458beea92604e1 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -208,6 +208,9 @@ Base.strides(A::Transpose{<:Any, <:StridedMatrix}) = reverse(strides(A.parent)) Base.unsafe_convert(::Type{Ptr{T}}, A::Adjoint{<:Real, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent) Base.unsafe_convert(::Type{Ptr{T}}, A::Transpose{<:Any, <:StridedVecOrMat}) where {T} = Base.unsafe_convert(Ptr{T}, A.parent) +Base.elsize(::Type{<:Adjoint{<:Real, P}}) where {P<:StridedVecOrMat} = Base.elsize(P) +Base.elsize(::Type{<:Transpose{<:Any, P}}) where {P<:StridedVecOrMat} = Base.elsize(P) + # for vectors, the semantics of the wrapped and unwrapped types differ # so attempt to maintain both the parent and wrapper type insofar as possible similar(A::AdjOrTransAbsVec) = wrapperop(A)(similar(A.parent)) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index a35667e28ba72b..0089aac4bc62fc 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -978,3 +978,144 @@ end @test Core.sizeof(arrayOfUInt48) == 24 end end + +struct Strider{T,N} <: AbstractArray{T,N} + data::Vector{T} + offset::Int + strides::NTuple{N,Int} + size::NTuple{N,Int} +end +function Strider{T}(strides::NTuple{N}, size::NTuple{N}) where {T,N} + offset = 1-sum(strides .* (strides .< 0) .* (size .- 1)) + data = Array{T}(undef, sum(abs.(strides) .* (size .- 1)) + 1) + return Strider{T, N, Vector{T}}(data, offset, strides, size) +end +function Strider(vec::AbstractArray{T}, strides::NTuple{N}, size::NTuple{N}) where {T,N} + offset = 1-sum(strides .* (strides .< 0) .* (size .- 1)) + @assert length(vec) >= sum(abs.(strides) .* (size .- 1)) + 1 + return Strider{T, N}(vec, offset, strides, size) +end +Base.size(S::Strider) = S.size +function Base.getindex(S::Strider{<:Any,N}, I::Vararg{Int,N}) where {N} + return S.data[sum(S.strides .* (I .- 1)) + S.offset] +end +Base.strides(S::Strider) = S.strides +Base.elsize(::Type{<:Strider{T}}) where {T} = Base.elsize(Vector{T}) +Base.unsafe_convert(::Type{Ptr{T}}, S::Strider{T}) where {T} = pointer(S.data, S.offset) + +@testset "Simple 3d strided views and permutes" for sz in ((5, 3, 2), (7, 11, 13)) + A = collect(reshape(1:prod(sz), sz)) + S = Strider(vec(A), strides(A), sz) + @test pointer(A) == pointer(S) + for i in 1:prod(sz) + @test pointer(A, i) == pointer(S, i) + @test A[i] == S[i] + end + for idxs in ((1:sz[1], 1:sz[2], 1:sz[3]), + (1:sz[1], 2:2:sz[2], sz[3]:-1:1), + (2:2:sz[1]-1, sz[2]:-1:1, sz[3]:-2:2), + (sz[1]:-1:1, sz[2]:-1:1, sz[3]:-1:1), + (sz[1]-1:-3:1, sz[2]:-2:3, 1:sz[3]),) + Ai = A[idxs...] + Av = view(A, idxs...) + Sv = view(S, idxs...) + Ss = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs)) + @test pointer(Av) == pointer(Sv) == pointer(Ss) + for i in 1:length(Av) + @test pointer(Av, i) == pointer(Sv, i) == pointer(Ss, i) + @test Ai[i] == Av[i] == Sv[i] == Ss[i] + end + for perm in ((3, 2, 1), (2, 1, 3), (3, 1, 2)) + P = permutedims(A, perm) + Ap = Base.PermutedDimsArray(A, perm) + Sp = Base.PermutedDimsArray(S, perm) + Ps = Strider{Int, 3}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)]) + @test pointer(Ap) == pointer(Sp) == pointer(Ps) + for i in 1:length(Ap) + # This is intentionally disabled due to ambiguity + @test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) + @test P[i] == Ap[i] == Sp[i] == Ps[i] + end + Pv = view(P, idxs[collect(perm)]...) + Pi = P[idxs[collect(perm)]...] + Apv = view(Ap, idxs[collect(perm)]...) + Spv = view(Sp, idxs[collect(perm)]...) + Pvs = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv)) + @test pointer(Apv) == pointer(Spv) == pointer(Pvs) + for i in 1:length(Apv) + @test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) + @test Pi[i] == Pv[i] == Apv[i] == Spv[i] == Pvs[i] + end + Vp = permutedims(Av, perm) + Ip = permutedims(Ai, perm) + Avp = Base.PermutedDimsArray(Av, perm) + Svp = Base.PermutedDimsArray(Sv, perm) + @test pointer(Avp) == pointer(Svp) + for i in 1:length(Avp) + # This is intentionally disabled due to ambiguity + @test_broken pointer(Avp, i) == pointer(Svp, i) + @test Ip[i] == Vp[i] == Avp[i] == Svp[i] + end + end + end +end + +@testset "simple 2d strided views, permutes, transposes" for sz in ((5, 3), (7, 11)) + A = collect(reshape(1:prod(sz), sz)) + S = Strider(vec(A), strides(A), sz) + @test pointer(A) == pointer(S) + for i in 1:prod(sz) + @test pointer(A, i) == pointer(S, i) + @test A[i] == S[i] + end + for idxs in ((1:sz[1], 1:sz[2]), + (1:sz[1], 2:2:sz[2]), + (2:2:sz[1]-1, sz[2]:-1:1), + (sz[1]:-1:1, sz[2]:-1:1), + (sz[1]-1:-3:1, sz[2]:-2:3),) + Av = view(A, idxs...) + Sv = view(S, idxs...) + Ss = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs)) + @test pointer(Av) == pointer(Sv) == pointer(Ss) + for i in 1:length(Av) + @test pointer(Av, i) == pointer(Sv, i) == pointer(Ss, i) + @test Av[i] == Sv[i] == Ss[i] + end + perm = (2, 1) + P = permutedims(A, perm) + Ap = Base.PermutedDimsArray(A, perm) + At = transpose(A) + Aa = adjoint(A) + Sp = Base.PermutedDimsArray(S, perm) + Ps = Strider{Int, 2}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)]) + @test pointer(Ap) == pointer(Sp) == pointer(Ps) == pointer(At) == pointer(Aa) + for i in 1:length(Ap) + # This is intentionally disabled due to ambiguity + @test_broken pointer(Ap, i) == pointer(Sp, i) == pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) + @test pointer(Ps, i) == pointer(At, i) == pointer(Aa, i) + @test P[i] == Ap[i] == Sp[i] == Ps[i] == At[i] == Aa[i] + end + Pv = view(P, idxs[collect(perm)]...) + Apv = view(Ap, idxs[collect(perm)]...) + Atv = view(At, idxs[collect(perm)]...) + Ata = view(Aa, idxs[collect(perm)]...) + Spv = view(Sp, idxs[collect(perm)]...) + Pvs = Strider{Int, 2}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Apv), size(Apv)) + @test pointer(Apv) == pointer(Spv) == pointer(Pvs) == pointer(Atv) == pointer(Ata) + for i in 1:length(Apv) + @test pointer(Apv, i) == pointer(Spv, i) == pointer(Pvs, i) == pointer(Atv, i) == pointer(Ata, i) + @test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i] + end + Vp = permutedims(Av, perm) + Avp = Base.PermutedDimsArray(Av, perm) + Avt = transpose(Av) + Ava = adjoint(Av) + Svp = Base.PermutedDimsArray(Sv, perm) + @test pointer(Avp) == pointer(Svp) == pointer(Avt) == pointer(Ava) + for i in 1:length(Avp) + # This is intentionally disabled due to ambiguity + @test_broken pointer(Avp, i) == pointer(Svp, i) == pointer(Avt, i) == pointer(Ava, i) + @test Vp[i] == Avp[i] == Svp[i] == Avt[i] == Ava[i] + end + end +end