diff --git a/base/subarray.jl b/base/subarray.jl index b2923737f9364..96ae7d5be9bf0 100644 --- a/base/subarray.jl +++ b/base/subarray.jl @@ -167,13 +167,13 @@ end push!(sizeexprs, dimsizeexpr(j, Jindex, length(Jp), :V, :J)) end push!(indexexprs, :(reindex(V.indexes[$IVindex], J[$Jindex]))) - push!(Itypes, rangetype(iv, j)) + push!(Itypes, indextype(iv, j)) else # We have a linear index that spans more than one # dimension of the parent N += 1 push!(sizeexprs, dimsizeexpr(j, Jindex, length(Jp), :V, :J)) - push!(indexexprs, :(merge_indexes(V, V.indexes[$IVindex:end], size(V.parent)[$IVindex:end], J[$Jindex], $Jindex))) + push!(indexexprs, :(merge_indexes(V, V.indexes[$IVindex:end], size(V.parent)[$IVindex:end], J[$Jindex], $Jindex)::Array{Int,1})) push!(Itypes, Array{Int, 1}) Iindex_lin = length(Itypes) Jindex_lin = Jindex @@ -215,7 +215,7 @@ end quote Inew = $Inew $exfirst - SubArray{$T,$N,$PV,$It,$LD}(V.parent, Inew, $dims, f, $strideexpr) + SubArray{$T,$N,$PV,typeof(Inew),$LD}(V.parent, Inew, $dims, f, $strideexpr) end end @@ -258,11 +258,11 @@ end elseif Jindex < length(Jp) || Jindex == NV || IVindex == length(IVp) # simple indexing push!(indexexprs, :(reindex(V.indexes[$IVindex], J[$Jindex]))) - push!(Itypes, rangetype(iv, j)) + push!(Itypes, indextype(iv, j)) push!(ItypesLD, Itypes[end]) else # We have a linear index that spans more than one dimension of the parent - push!(indexexprs, :(merge_indexes(V, V.indexes[$IVindex:end], size(V.parent)[$IVindex:end], J[$Jindex], $Jindex))) + push!(indexexprs, :(merge_indexes(V, V.indexes[$IVindex:end], size(V.parent)[$IVindex:end], J[$Jindex], $Jindex)::Array{Int,1})) push!(Itypes, Array{Int, 1}) push!(ItypesLD, Itypes[end]) break @@ -302,15 +302,24 @@ end $preex Inew = $Inew $exfirst - SubArray{$T,$N,$PV,$It,$LD}(V.parent, Inew, $dims, f, $strideexpr) + SubArray{$T,$N,$PV,typeof(Inew),$LD}(V.parent, Inew, $dims, f, $strideexpr) end end -function rangetype(T1, T2) - rt = return_types(getindex, Tuple{T1, T2}) - length(rt) == 1 || error("Can't infer return type") - rt[1] -end +# These pretty much only have to get the container-type right, +# for use in nextLD. The actual index type is set by typeof(I). +indextype{A<:AbstractArray,I<:Integer}(::Type{A}, ::Type{I}) = eltype(A) +indextype{A<:UnitRange,I<:UnitRange}(::Type{A}, ::Type{I}) = A +indextype{A<:Range,I<:UnitRange}(::Type{A}, ::Type{I}) = A +indextype{A<:UnitRange,I<:Range}(::Type{A}, ::Type{I}) = I +indextype{A<:Range,I<:Range}(::Type{A}, ::Type{I}) = A +indextype{A<:AbstractArray,I<:Range}(::Type{A}, ::Type{I}) = A +indextype{A<:Range,I<:AbstractArray}(::Type{A}, ::Type{I}) = Array{Int,1} +indextype{A<:AbstractArray,I<:AbstractArray}(::Type{A}, ::Type{I}) = Array{Int,1} +indextype(::Type{Colon}, ::Type{Colon}) = Colon +indextype{I<:Integer}(::Type{Colon}, ::Type{I}) = Int +indextype{I<:AbstractArray}(::Type{Colon}, ::Type{I}) = I +indextype{A<:AbstractArray}(::Type{A}, ::Type{Colon}) = A reindex(a, b) = a[b] reindex(a::UnitRange, b::UnitRange{Int}) = range(oftype(first(a), first(a)+first(b)-1), length(b)) diff --git a/test/subarray.jl b/test/subarray.jl index 5e342e92cb7ec..fa2636cb3c99e 100644 --- a/test/subarray.jl +++ b/test/subarray.jl @@ -156,8 +156,8 @@ end function test_linear(A, B) length(A) == length(B) || error("length mismatch") isgood = true - for (iA, iB) in zip(1:length(A), 1:length(B)) - if A[iA] != B[iB] + for i = 1:length(A) + if A[i] != B[i] isgood = false break end @@ -214,7 +214,14 @@ function runtests(A::Array, I...) ldc = Base.subarray_linearindexing_dim(typeof(A), typeof(I)) ld == ldc || err_li(I, ld, ldc) # sub - S = sub(A, I...) + local S + try + S = @inferred(sub(A, I...)) + catch err + @show typeof(A) + @show I + rethrow(err) + end getLD(S) == ldc || err_li(S, ldc) if Base.iscontiguous(S) @test S.stride1 == 1 @@ -223,7 +230,13 @@ function runtests(A::Array, I...) test_cartesian(S, C) test_mixed(S, C) # slice - S = slice(A, I...) + try + S = @inferred(slice(A, I...)) + catch err + @show typeof(A) + @show I + rethrow(err) + end getLD(S) == ldc || err_li(S, ldc) test_linear(S, C) test_cartesian(S, C) @@ -258,7 +271,7 @@ function runtests(A::SubArray, I...) # sub local S try - S = sub(A, I...) + S = @inferred(sub(A, I...)) catch err @show typeof(A) @show A.indexes @@ -272,7 +285,7 @@ function runtests(A::SubArray, I...) test_mixed(S, C) # slice try - S = slice(A, I...) + S = @inferred(slice(A, I...)) catch err @show typeof(A) @show A.indexes