Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check that axes start with 1 for AbstractRange operations #30950

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
using .Base: Indices, OneTo, AxesStartStyle, AxesStart1, AxesStartAny
using .Base: tail, to_shape, isoperator, promote_typejoin, require_one_based_indexing,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__
Expand Down Expand Up @@ -490,7 +491,11 @@ _bcsm(a::Number, b::Number) = a == b || b == 1
# (We may not want to define general promotion rules between, say, OneTo and Slice, but if
# we get here we know the axes are at least consistent for the purposes of broadcasting)
axistype(a::T, b::T) where T = a
axistype(a, b) = UnitRange{Int}(a)
axistype(a, b) = _axistype(AxesStartStyle(a), AxesStartStyle(b), a, b)
_axistype(::AxesStart1, ::AxesStart1, a, b) = OneTo{Int}(a)
_axistype(::AxesStartAny, ::AxesStart1, a, b) = a # if we get here, b has length 1
_axistype(::AxesStart1, ::AxesStartAny, a, b) = error("mismatched axes, or specialize `Broadcast.axistype` for ", typeof(a), " and ", typeof(b), " axes")
_axistype(::AxesStartAny, ::AxesStartAny, a, b) = error("mismatched axes, or specialize `Broadcast.axistype` for ", typeof(a), " and ", typeof(b), " axes")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unsure about these erroring methods. Here's an interesting and troubling case:

julia> using OffsetArrays, Test

julia> x = [1]
1-element Array{Int64,1}:
 1

julia> y = OffsetArray([1], -1:-1)
1-element OffsetArray(::Array{Int64,1}, -1:-1) with eltype Int64 with indices -1:-1:
 1

julia> @inferred(x.*y)
1-element OffsetArray(::Array{Int64,1}, -1:-1) with eltype Int64 with indices -1:-1:
 1

julia> @inferred(y.*x)
ERROR: mismatched axes, or specialize `Broadcast.axistype` for Base.OneTo{Int64} and Base.IdentityUnitRange{UnitRange{Int64}} axes
Stacktrace:
 [1] error(::String, ::Type, ::String, ::Type, ::String) at ./error.jl:42
 [2] _axistype(::Base.AxesStart1, ::Base.AxesStartAny, ::Base.OneTo{Int64}, ::Base.IdentityUnitRange{UnitRange{Int64}}) at ./broadcast.jl:497
 [3] axistype(::Base.OneTo{Int64}, ::Base.IdentityUnitRange{UnitRange{Int64}}) at ./broadcast.jl:494
 [4] _bcs1 at ./broadcast.jl:485 [inlined]
 [5] _bcs at ./broadcast.jl:479 [inlined]
 [6] broadcast_shape at ./broadcast.jl:473 [inlined]
 [7] combine_axes at ./broadcast.jl:468 [inlined]
 [8] instantiate at ./broadcast.jl:256 [inlined]
 [9] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Nothing,typeof(*),Tuple{OffsetArray{Int64,1,Array{Int64,1}},Array{Int64,1}}}) at ./broadcast.jl:802
 [10] _materialize_broadcasted(::Function, ::OffsetArray{Int64,1,Array{Int64,1}}, ::Vararg{Any,N} where N) at /home/tim/src/julia-branch/usr/share/julia/stdlib/v1.2/Test/src/Test.jl:1279
 [11] top-level scope at REPL[5]:1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fascinating. So the larger question is: which axis should be used as the "basis" for the output and which axis is being "extruded" over that basis?

If we have any non-singleton dimension then it's a no-brainer: that's our output axis and all other singleton axes are discarded as they're being "extruded". We demand that the axes values match, so multiple non-singleton dimensions are by definition identical or an error. (We do promote the axis type, though, such that an OffsetArray that happens to be 1-based will still return a similarly non-offset OffsetArray for type stability.)

But when we have multiple singleton dimensions it's not clear. On master, we just use the last-provided axis for the given dimension, which isn't great. It's not really a one-based vs. offset thing — the same happens for combinations of offset arrays:

julia> axes(OffsetArray([0], 1) .+ OffsetArray([0], 2))
(Base.IdentityUnitRange(3:3),)

julia> axes(OffsetArray([0], 2) .+ OffsetArray([0], 1))
(Base.IdentityUnitRange(2:2),)

I think the one-based vs. offset axis thing is a red herring — we don't really know which is going to matter in the output. So I say we just find some consistent (symmetric) rule that can generalize to multiple singleton offset axes, too. Some possibilities:

  • Always use a one-based output for any combination of singleton axes. I don't like this because then OffsetArray(rand(1), -1) .+ OffsetArray(rand(1), -1) would no longer be offset. So let's ax this candidate.
  • Always demand that combinations of singleton axes match. At first glance I'd think this would be massively breaking. It does help, though, that we completely ignore the pseudo-axes beyond the dimensionality of the array (e.g., we don't care what axes(fill(0), 2) returns, so scalars and zero-dimensional arrays could still broadcast with anything). It may be interesting to try out, but it could make for a really annoying trap where an expression mostly works until you end up with a one-vector instead of a two-vector.
  • Use some symmetric value-based rule that picks the result based on the axes values themselves. Possible candidates include:
    • Just pick the axis with the minimum or maximum element
    • Pick the axis with the element closest to 1 (biasing the comparison to favor either axes below or above 1 so there aren't ties). Perhaps the arc of the Julian universe bends towards 1-based arrays.
    • Pick the axis with the element farthest from 1 (again, biased to favor either above/below). The reasoning here would be that someone chose to have that offset; they should be allowed to keep it.
    • Pick the axis that matches the most other axes in the expression. In case of ties (which would be common for binary broadcasts), use one of the other rules.


## Check that all arguments are broadcast compatible with shape
# comparing one input against a shape
Expand Down Expand Up @@ -1002,15 +1007,20 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange) = r
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen) = r
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange) = r

broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) =
(require_one_based_indexing(r); range(-first(r), step=-step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r))

broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) =
(require_one_based_indexing(r); range(x + first(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) =
(require_one_based_indexing(r); range(first(r) + x, length=length(r)))
# For #18336 we need to prevent promotion of the step type:
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) =
(require_one_based_indexing(r); range(first(r) + x, step=step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) =
(require_one_based_indexing(r); range(x + first(r), step=step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T =
StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T =
Expand All @@ -1019,9 +1029,12 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2

broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) =
(require_one_based_indexing(r); range(first(r)-x, length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) =
(require_one_based_indexing(r); range(first(r)-x, step=step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) =
(require_one_based_indexing(r); range(x-first(r), step=-step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T =
StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T =
Expand All @@ -1030,22 +1043,26 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2

broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) =
(require_one_based_indexing(r); range(x*first(r), step=x*step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} =
StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len)
# separate in case of noncommutative multiplication
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) =
(require_one_based_indexing(r); range(first(r)*x, step=step(r)*x, length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} =
StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len)

broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) =
(require_one_based_indexing(r); range(first(r)/x, step=step(r)/x, length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} =
StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len)

broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) =
(require_one_based_indexing(r); range(x\first(r), step=x\step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len)

Expand Down
16 changes: 16 additions & 0 deletions base/indices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ struct Slice{T<:AbstractUnitRange} <: AbstractUnitRange{Int}
indices::T
end
Slice(S::Slice) = S
AxesStartStyle(::Type{<:Slice}) = AxesStartAny()
axes(S::Slice) = (IdentityUnitRange(S.indices),)
unsafe_indices(S::Slice) = (IdentityUnitRange(S.indices),)
axes1(S::Slice) = IdentityUnitRange(S.indices)
Expand All @@ -333,6 +334,13 @@ getindex(S::Slice, i::AbstractUnitRange{<:Integer}) = (@_inline_meta; @boundsche
getindex(S::Slice, i::StepRange{<:Integer}) = (@_inline_meta; @boundscheck checkbounds(S, i); i)
show(io::IO, r::Slice) = print(io, "Base.Slice(", r.indices, ")")
iterate(S::Slice, s...) = iterate(S.indices, s...)
_convert(::AxesStart1, ::AxesStartAny, ::Type{T}, S::Slice) where {T<:AbstractRange} =
(require_one_based_indexing(S); T(S))
function _convert(::AxesStartAny, ::AxesStart1, ::Type{T}, r::AbstractUnitRange) where {T<:Slice}
throwerr(r) = (@_noinline_meta; throw(ArgumentError("`convert($T, r)` requires a range with first element 1, got $(first(r))")))
first(r) == 1 || throwerr(r)
return T(r)
end


"""
Expand All @@ -346,6 +354,7 @@ struct IdentityUnitRange{T<:AbstractUnitRange} <: AbstractUnitRange{Int}
indices::T
end
IdentityUnitRange(S::IdentityUnitRange) = S
AxesStartStyle(::Type{<:IdentityUnitRange}) = AxesStartAny()
# IdentityUnitRanges are offset and thus have offset axes, so they are their own axes... but
# we need to strip the wholedim marker because we don't know how they'll be used
axes(S::IdentityUnitRange) = (S,)
Expand All @@ -365,6 +374,13 @@ getindex(S::IdentityUnitRange, i::AbstractUnitRange{<:Integer}) = (@_inline_meta
getindex(S::IdentityUnitRange, i::StepRange{<:Integer}) = (@_inline_meta; @boundscheck checkbounds(S, i); i)
show(io::IO, r::IdentityUnitRange) = print(io, "Base.IdentityUnitRange(", r.indices, ")")
iterate(S::IdentityUnitRange, s...) = iterate(S.indices, s...)
_convert(::AxesStart1, ::AxesStartAny, ::Type{T}, S::IdentityUnitRange) where {T<:AbstractRange} =
(require_one_based_indexing(S); T(S))
function _convert(::AxesStartAny, ::AxesStart1, ::Type{T}, r::AbstractUnitRange) where {T<:IdentityUnitRange}
throwerr(r) = (@_noinline_meta; throw(ArgumentError("`convert($T, r)` requires a range with first element 1, got $(first(r))")))
first(r) == 1 || throwerr(r)
return T(r)
end

"""
LinearIndices(A::AbstractArray)
Expand Down
60 changes: 46 additions & 14 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,20 @@ abstract type AbstractRange{T} <: AbstractArray{T,1} end
RangeStepStyle(::Type{<:AbstractRange}) = RangeStepIrregular()
RangeStepStyle(::Type{<:AbstractRange{<:Integer}}) = RangeStepRegular()

convert(::Type{T}, r::AbstractRange) where {T<:AbstractRange} = r isa T ? r : T(r)
AxesStartStyle(::Type{<:AbstractRange}) = AxesStart1() # opt-out of AxesStart1 for "weird" range types
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some ways I'd prefer "opt in" on this trait, but I think it's too breaking. So I think it's better to mark AbstractRange types that violate the expectation of having unit indexing.

AxesStartStyle(r::AbstractRange) = AxesStartStyle(typeof(r))

convert(::Type{AbstractRange}, r::AbstractRange) = r
convert(::Type{T}, r::T) where {T<:AbstractRange} = r
convert(::Type{T}, r::AbstractRange) where {T<:AbstractRange} = _convert(AxesStartStyle(T), AxesStartStyle(r), T, r)
_convert(::AxesStart1, ::AxesStart1, ::Type{T}, r::AbstractRange) where {T<:AbstractRange} = T(r)
_convert(::AxesStartStyle, ::AxesStartStyle, ::Type{T}, r::AbstractRange) where {T<:AbstractRange} =
throw(MethodError(convert, (T, r)))

require_one_based_indexing(r::AbstractRange) = _require_one_based_indexing(AxesStartStyle(r), r)
_require_one_based_indexing(::AxesStartStyle, r) =
!has_offset_axes(r) || throw(ArgumentError("offset arrays are not supported but got an array with index other than 1"))
_require_one_based_indexing(::AxesStart1, r) = true

## ordinal ranges

Expand All @@ -157,6 +170,8 @@ type can represent values smaller than `oneunit(Float64)`.
"""
abstract type OrdinalRange{T,S} <: AbstractRange{T} end

convert(::Type{OrdinalRange{T,S}}, r::OrdinalRange{T,S}) where {T,S} = r
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not strictly necessary for this PR, I think, but I noticed their lack and thought it might be good for disambiguating.


"""
AbstractUnitRange{T} <: OrdinalRange{T, T}

Expand All @@ -165,6 +180,8 @@ Supertype for ranges with a step size of [`oneunit(T)`](@ref) with elements of t
"""
abstract type AbstractUnitRange{T} <: OrdinalRange{T,T} end

convert(::Type{AbstractUnitRange{T}}, r::AbstractUnitRange{T}) where {T} = r

"""
StepRange{T, S} <: OrdinalRange{T, S}

Expand Down Expand Up @@ -307,15 +324,15 @@ be 1.
struct OneTo{T<:Integer} <: AbstractUnitRange{T}
stop::T
OneTo{T}(stop) where {T<:Integer} = new(max(zero(T), stop))
function OneTo{T}(r::AbstractRange) where {T<:Integer}
throwstart(r) = (@_noinline_meta; throw(ArgumentError("first element must be 1, got $(first(r))")))
throwstep(r) = (@_noinline_meta; throw(ArgumentError("step must be 1, got $(step(r))")))
first(r) == 1 || throwstart(r)
step(r) == 1 || throwstep(r)
return new(max(zero(T), last(r)))
end
end
OneTo(stop::T) where {T<:Integer} = OneTo{T}(stop)
function OneTo{T}(r::AbstractRange) where {T<:Integer}
throwstart(r) = (@_noinline_meta; throw(ArgumentError("first element must be 1, got $(first(r))")))
throwstep(r) = (@_noinline_meta; throw(ArgumentError("step must be 1, got $(step(r))")))
first(r) == 1 || throwstart(r)
step(r) == 1 || throwstep(r)
return OneTo{T}(last(r))
end
OneTo(r::AbstractRange{T}) where {T<:Integer} = OneTo{T}(r)

## Step ranges parameterized by length
Expand Down Expand Up @@ -713,10 +730,14 @@ show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), '
show(io::IO, r::UnitRange) = print(io, repr(first(r)), ':', repr(last(r)))
show(io::IO, r::OneTo) = print(io, "Base.OneTo(", r.stop, ")")

range_axes_first_same(r, s) = _range_axes_first_same(AxesStartStyle(r), AxesStartStyle(s), r, s)
_range_axes_first_same(::AxesStart1, ::AxesStart1, r, s) = true
_range_axes_first_same(::AxesStartStyle, ::AxesStartStyle, r, s) = first(axes1(r)) == first(axes1(s))

==(r::T, s::T) where {T<:AbstractRange} =
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s))
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s)) & range_axes_first_same(r, s)
==(r::OrdinalRange, s::OrdinalRange) =
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s))
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s)) & range_axes_first_same(r, s)
==(r::T, s::T) where {T<:Union{StepRangeLen,LinRange}} =
(first(r) == first(s)) & (length(r) == length(s)) & (last(r) == last(s))
==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T} =
Expand All @@ -727,6 +748,7 @@ function ==(r::AbstractRange, s::AbstractRange)
if lr != length(s)
return false
end
range_axes_first_same(r, s) || return false
yr, ys = iterate(r), iterate(s)
while yr !== nothing
yr[1] == ys[1] || return false
Expand Down Expand Up @@ -849,7 +871,7 @@ end

## linear operations on ranges ##

-(r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
-(r::OrdinalRange) = (require_one_based_indexing(r); range(-first(r), step=-step(r), length=length(r)))
-(r::StepRangeLen{T,R,S}) where {T,R,S} =
StepRangeLen{T,R,S}(-r.ref, -r.step, length(r), r.offset)
-(r::LinRange) = LinRange(-r.start, -r.stop, length(r))
Expand All @@ -861,6 +883,9 @@ el_same(::Type{T}, a::Type{<:AbstractArray{T,n}}, b::Type{<:AbstractArray{S,n}})
el_same(::Type{T}, a::Type{<:AbstractArray{S,n}}, b::Type{<:AbstractArray{T,n}}) where {T,S,n} = b
el_same(::Type, a, b) = promote_typejoin(a, b)

# promote_rule and more constructors
# Note: construction is distinct from conversion, convert should check require_one_based_indexing(r)
# but construction should not.
promote_rule(a::Type{UnitRange{T1}}, b::Type{UnitRange{T2}}) where {T1,T2} =
el_same(promote_type(T1,T2), a, b)
UnitRange{T}(r::UnitRange{T}) where {T<:Real} = r
Expand Down Expand Up @@ -944,7 +969,10 @@ end
Array{T,1}(r::AbstractRange{T}) where {T} = vcat(r)
collect(r::AbstractRange) = vcat(r)

reverse(r::OrdinalRange) = (:)(last(r), -step(r), first(r))
function reverse(r::OrdinalRange)
require_one_based_indexing(r)
(:)(last(r), -step(r), first(r))
end
function reverse(r::StepRangeLen)
# If `r` is empty, `length(r) - r.offset + 1 will be nonpositive hence
# invalid. As `reverse(r)` is also empty, any offset would work so we keep
Expand All @@ -964,8 +992,11 @@ sort!(r::AbstractUnitRange) = r

sort(r::AbstractRange) = issorted(r) ? r : reverse(r)

sortperm(r::AbstractUnitRange) = 1:length(r)
sortperm(r::AbstractRange) = issorted(r) ? (1:1:length(r)) : (length(r):-1:1)
sortperm(r::AbstractUnitRange) = (require_one_based_indexing(r); 1:length(r))
function sortperm(r::AbstractRange)
require_one_based_indexing(r)
issorted(r) ? (1:1:length(r)) : (length(r):-1:1)
end

function sum(r::AbstractRange{<:Real})
l = length(r)
Expand Down Expand Up @@ -1004,6 +1035,7 @@ function _define_range_op(@nospecialize f)
r1l = length(r1)
(r1l == length(r2) ||
throw(DimensionMismatch("argument dimensions must match")))
require_one_based_indexing(r1, r2)
range($f(first(r1), first(r2)), step=$f(step(r1), step(r2)), length=r1l)
end

Expand Down
34 changes: 13 additions & 21 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@

# for reductions that expand 0 dims to 1
reduced_index(i::OneTo) = OneTo(1)
reduced_index(i::Union{Slice, IdentityUnitRange}) = first(i):first(i)
reduced_index(i::Slice) = Slice(first(i):first(i))
reduced_index(i::IdentityUnitRange) = IdentityUnitRange(first(i):first(i))
reduced_index(i::AbstractUnitRange) =
throw(ArgumentError(
"""
Expand Down Expand Up @@ -43,33 +44,24 @@ function reduced_indices0(inds::Indices{N}, d::Int) where N
end
end

function reduced_indices(inds::Indices{N}, region) where N
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially started modifying this code to fix a problem but then transitioned to the changes above. Still, this seemed like old code in need of some cleanup.

rinds = [inds...]
function check_reduced_region(region, N)
for i in region
isa(i, Integer) || throw(ArgumentError("reduced dimension(s) must be integers"))
d = Int(i)
if d < 1
throw(ArgumentError("region dimension(s) must be ≥ 1, got $d"))
elseif d <= N
rinds[d] = reduced_index(rinds[d])
if i < 1
throw(ArgumentError("region dimension(s) must be ≥ 1, got $i"))
end
end
tuple(rinds...)::typeof(inds)
return nothing
end

function reduced_indices(inds::Indices{N}, region) where N
check_reduced_region(region, N)
ntuple(i->in(i, region) ? reduced_index(inds[i]) : inds[i], Val(N))::typeof(inds)
end

function reduced_indices0(inds::Indices{N}, region) where N
rinds = [inds...]
for i in region
isa(i, Integer) || throw(ArgumentError("reduced dimension(s) must be integers"))
d = Int(i)
if d < 1
throw(ArgumentError("region dimension(s) must be ≥ 1, got $d"))
elseif d <= N
rind = rinds[d]
rinds[d] = isempty(rind) ? rind : reduced_index(rind)
end
end
tuple(rinds...)::typeof(inds)
check_reduced_region(region, N)
ntuple(i->in(i, region) ? (r = inds[i]; isempty(r) ? r : reduced_index(r)) : inds[i], Val(N))::typeof(inds)
end

###### Generic reduction functions #####
Expand Down
2 changes: 1 addition & 1 deletion base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function axes(a::ReinterpretArray{T,N,S} where {N}) where {T,S}
paxs = axes(a.parent)
f, l = first(paxs[1]), length(paxs[1])
size1 = div(l*sizeof(S), sizeof(T))
tuple(oftype(paxs[1], f:f+size1-1), tail(paxs)...)
tuple(typeof(paxs[1])(f:f+size1-1), tail(paxs)...)
end
axes(a::ReinterpretArray{T,0}) where {T} = ()

Expand Down
15 changes: 15 additions & 0 deletions base/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,18 @@ struct RangeStepRegular <: RangeStepStyle end # range with regular step
struct RangeStepIrregular <: RangeStepStyle end # range with rounding error

RangeStepStyle(instance) = RangeStepStyle(typeof(instance))

# trait that allows skipping of axes-checking on abstract range types (risks overflow on `length`)
"""
AxesStartStyle(instance)
AxesStartStyle(T::Type)

Indicate the value that `axes(instance)` starts with. Containers that return `AxesStart1()`
must have `axes(instance)` start with 1 (e.g., `Base.OneTo` axes). Such containers may
bypass axes checks for certain operations (e.g., range comparisons to avoid risk of overflow).
`AxesStartAny()` indicates that one cannot count on the axes starting with 1, and that
an explicit check is required.
"""
abstract type AxesStartStyle end
struct AxesStart1 <: AxesStartStyle end
struct AxesStartAny <: AxesStartStyle end
Loading