diff --git a/base/operators.jl b/base/operators.jl index f0647be1b65ada..20e65707ad59dc 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -902,6 +902,9 @@ julia> [1:5;] .|> (x -> x^2) |> sum |> inv """ |>(x, f) = f(x) +_stable_typeof(x) = typeof(x) +_stable_typeof(::Type{T}) where {T} = @isdefined(T) ? Type{T} : DataType + """ f = Returns(value) @@ -928,7 +931,7 @@ julia> f.value struct Returns{V} <: Function value::V Returns{V}(value) where {V} = new{V}(value) - Returns(value) = new{Core.Typeof(value)}(value) + Returns(value) = new{_stable_typeof(value)}(value) end (obj::Returns)(@nospecialize(args...); @nospecialize(kw...)) = obj.value @@ -1014,7 +1017,19 @@ struct ComposedFunction{O,I} <: Function ComposedFunction(outer, inner) = new{Core.Typeof(outer),Core.Typeof(inner)}(outer, inner) end -(c::ComposedFunction)(x...; kw...) = c.outer(c.inner(x...; kw...)) +function (c::ComposedFunction)(x...; kw...) + fs = unwrap_composed(c) + call_composed(fs[1](x...; kw...), tail(fs)...) +end +unwrap_composed(c::ComposedFunction) = (unwrap_composed(c.inner)..., unwrap_composed(c.outer)...) +unwrap_composed(c) = (maybeconstructor(c),) +call_composed(x, f, fs...) = (@inline; call_composed(f(x), fs...)) +call_composed(x, f) = f(x) + +struct Constructor{F} <: Function end +(::Constructor{F})(args...; kw...) where {F} = (@inline; F(args...; kw...)) +maybeconstructor(::Type{F}) where {F} = Constructor{F}() +maybeconstructor(f) = f ∘(f) = f ∘(f, g) = ComposedFunction(f, g) @@ -1078,8 +1093,8 @@ struct Fix1{F,T} <: Function f::F x::T - Fix1(f::F, x::T) where {F,T} = new{F,T}(f, x) - Fix1(f::Type{F}, x::T) where {F,T} = new{Type{F},T}(f, x) + Fix1(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x) + Fix1(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x) end (f::Fix1)(y) = f.f(f.x, y) @@ -1095,8 +1110,8 @@ struct Fix2{F,T} <: Function f::F x::T - Fix2(f::F, x::T) where {F,T} = new{F,T}(f, x) - Fix2(f::Type{F}, x::T) where {F,T} = new{Type{F},T}(f, x) + Fix2(f::F, x) where {F} = new{F,_stable_typeof(x)}(f, x) + Fix2(f::Type{F}, x) where {F} = new{Type{F},_stable_typeof(x)}(f, x) end (f::Fix2)(y) = f.f(y, f.x) diff --git a/base/reduce.jl b/base/reduce.jl index 45284d884a2791..7f0ee2382b68f4 100644 --- a/base/reduce.jl +++ b/base/reduce.jl @@ -140,17 +140,25 @@ what is returned is `itr′` and op′ = (xfₙ ∘ ... ∘ xf₂ ∘ xf₁)(op) """ -_xfadjoint(op, itr) = (op, itr) -_xfadjoint(op, itr::Generator) = - if itr.f === identity - _xfadjoint(op, itr.iter) - else - _xfadjoint(MappingRF(itr.f, op), itr.iter) - end -_xfadjoint(op, itr::Filter) = - _xfadjoint(FilteringRF(itr.flt, op), itr.itr) -_xfadjoint(op, itr::Flatten) = - _xfadjoint(FlatteningRF(op), itr.it) +function _xfadjoint(op, itr) + itr′, wrap = _xfadjoint_unwrap(itr) + wrap(op), itr′ +end + +_xfadjoint_unwrap(itr) = itr, identity +function _xfadjoint_unwrap(itr::Generator) + itr′, wrap = _xfadjoint_unwrap(itr.iter) + itr.f === identity && return itr′, wrap + return itr′, wrap ∘ Fix1(MappingRF, itr.f) +end +function _xfadjoint_unwrap(itr::Filter) + itr′, wrap = _xfadjoint_unwrap(itr.itr) + return itr′, wrap ∘ Fix1(FilteringRF, itr.flt) +end +function _xfadjoint_unwrap(itr::Flatten) + itr′, wrap = _xfadjoint_unwrap(itr.it) + return itr′, wrap ∘ FlatteningRF +end """ mapfoldl(f, op, itr; [init]) diff --git a/test/operators.jl b/test/operators.jl index a1e27d0e1cd7b6..5e505391afd5a6 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -175,6 +175,15 @@ Base.promote_rule(::Type{T19714}, ::Type{Int}) = T19714 end +@testset "Nested ComposedFunction's stability" begin + f(x) = (1, 1, x...) + g = (f ∘ (f ∘ f)) ∘ (f ∘ f ∘ f) + @test (@inferred (g∘g)(1)) == ntuple(Returns(1), 25) + @test (@inferred g(1)) == ntuple(Returns(1), 13) + h = (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ (-) ∘ sum + @test (@inferred h((1, 2, 3); init = 0.0)) == 6.0 +end + @testset "function negation" begin str = randstring(20) @test filter(!isuppercase, str) == replace(str, r"[A-Z]" => "") @@ -308,4 +317,7 @@ end val = [1,2,3] @test Returns(val)(1) === val @test sprint(show, Returns(1.0)) == "Returns{Float64}(1.0)" + + illtype = Vector{Core._typevar(:T, Union{}, Any)} + @test Returns(illtype) == Returns{DataType}(illtype) end diff --git a/test/reduce.jl b/test/reduce.jl index db8c97f2f80cab..78988dbdc4225b 100644 --- a/test/reduce.jl +++ b/test/reduce.jl @@ -677,3 +677,16 @@ end @test mapreduce(+, +, oa, oa) == 2len end end + +# issue #45748 +@testset "foldl's stability for nested Iterators" begin + a = Iterators.flatten((1:3, 1:3)) + b = (2i for i in a if i > 0) + c = Base.Generator(Float64, b) + d = (sin(i) for i in c if i > 0) + @test @inferred(sum(d)) == sum(collect(d)) + @test @inferred(extrema(d)) == extrema(collect(d)) + @test @inferred(maximum(c)) == maximum(collect(c)) + @test @inferred(prod(b)) == prod(collect(b)) + @test @inferred(minimum(a)) == minimum(collect(a)) +end