From a2aeaccd14148164272c29c3511a568669025e95 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sat, 29 Jan 2022 21:06:51 -0500 Subject: [PATCH 1/4] Add `completebasecase` protocol --- src/combinators.jl | 46 ++++++++++++++++++++++++++++++++++++++-------- src/core.jl | 5 +++++ src/processes.jl | 8 +++++--- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/src/combinators.jl b/src/combinators.jl index 9e3a48eaae..7610fa872a 100644 --- a/src/combinators.jl +++ b/src/combinators.jl @@ -217,34 +217,60 @@ julia> foldxt(rf, Map(identity), 1:4; basesize = 1, init = OnInit(() -> [])) 4 ``` """ -struct AdHocRF{OnInit,Start,Next,Complete,Combine} <: _Function +struct AdHocRF{OnInit,Start,Next,CompleteBasecase,Complete,Combine} <: _Function oninit::OnInit start::Start next::Next + completebasecase::CompleteBasecase complete::Complete combine::Combine - AdHocRF{OnInit,Start,Next,Complete,Combine}( + function AdHocRF{OnInit,Start,Next,CompleteBasecase,Complete,Combine}( oninit, start, next, + completebasecase, complete, combine, - ) where {OnInit,Start,Next,Complete,Combine} = - new{OnInit,Start,Next,Complete,Combine}(oninit, start, next, complete, combine) + ) where {OnInit,Start,Next,CompleteBasecase,Complete,Combine} + return new{OnInit,Start,Next,CompleteBasecase,Complete,Combine}( + oninit, + start, + next, + completebasecase, + complete, + combine, + ) + end end -AdHocRF(oninit, start, op, complete, combine) = - AdHocRF{_typeof(oninit),_typeof(start),_typeof(op),_typeof(complete),_typeof(combine)}( +# Capture T::Type as Type{T} +function AdHocRF(oninit, start, op, completebasecase, complete, combine) + return AdHocRF{ + _typeof(oninit), + _typeof(start), + _typeof(op), + _typeof(completebasecase), + _typeof(complete), + _typeof(combine), + }( oninit, start, op, + completebasecase, complete, combine, ) +end -AdHocRF(op; oninit = nothing, start = identity, complete = identity, combine = nothing) = - AdHocRF(oninit, start, op, complete, combine) +AdHocRF( + op; + oninit = nothing, + start = identity, + completebasecase = identity, + complete = identity, + combine = nothing, +) = AdHocRF(oninit, start, op, completebasecase, complete, combine) AdHocRF(op::AdHocRF; kwargs...) = setproperties(op, values(kwargs)) @@ -260,6 +286,7 @@ AdHocRF(op::AdHocRF; kwargs...) = setproperties(op, values(kwargs)) @inline start(rf::AdHocRF, init) = rf.start(initialize(init, rf.next)) @inline next(rf::AdHocRF, acc, x) = rf.next(acc, x) @inline complete(rf::AdHocRF, acc) = rf.complete(acc) +@inline completebasecase(rf::AdHocRF, acc) = rf.completebasecase(acc) @inline combine(rf::AdHocRF, a, b) = something(rf.combine, rf.next)(a, b) _asmonoid(rf::AdHocRF) = @set rf.next = _asmonoid(rf.next) @@ -267,11 +294,14 @@ Completing(rf::AdHocRF) = rf wheninit(oninit, op) = AdHocRF(op; oninit = oninit) whenstart(start, op) = AdHocRF(op; start = start) +whencompletebasecase(completebasecase, op) = + AdHocRF(op; completebasecase = completebasecase) whencomplete(complete, op) = AdHocRF(op; complete = complete) whencombine(combine, op) = AdHocRF(op; combine = combine) wheninit(oninit) = op -> wheninit(oninit, op) whenstart(start) = op -> whenstart(start, op) +whencompletebasecase(completebasecase) = op -> whencompletebasecase(completebasecase, op) whencomplete(complete) = op -> whencomplete(complete, op) whencombine(combine) = op -> whencombine(combine, op) diff --git a/src/core.jl b/src/core.jl index 531f2362a1..967ec8a71c 100644 --- a/src/core.jl +++ b/src/core.jl @@ -287,6 +287,7 @@ ensurerf(f) = BottomRF(f) # `Completing` etc. start(rf::BottomRF, result) = start(inner(rf), result) @inline next(rf::BottomRF, result, input) = next(inner(rf), result, input) +@inline completebasecase(rf::BottomRF, result) = completebasecase(inner(rf), result) complete(rf::BottomRF, result) = complete(inner(rf), result) combine(rf::BottomRF, a, b) = combine(inner(rf), a, b) @@ -495,6 +496,10 @@ real-world examples. # done(rf, result) +completebasecase(_, result) = result +completebasecase(rf::RF, result) where {RF <: AbstractReduction} = + completebasecase(inner(rf), result) + """ Transducers.complete(rf::R_{X}, state) diff --git a/src/processes.jl b/src/processes.jl index faa88210d8..1d2a6a658a 100644 --- a/src/processes.jl +++ b/src/processes.jl @@ -349,12 +349,14 @@ function simple_transduce(xform, f, init, coll) end """ - foldl_nocomplete(rf, init, coll) + foldl_basecase(rf, init, coll) Call [`__foldl__`](@ref) without calling [`complete`](@ref). """ -@inline foldl_nocomplete(rf::RF, init, coll) where {RF} = - __foldl__(skipcomplete(rf), init, coll) +@inline foldl_basecase(rf::RF, init, coll) where {RF} = + completebasecase(rf, __foldl__(skipcomplete(rf), init, coll)) + +const foldl_nocomplete = foldl_basecase """ foldxl(step, xf::Transducer, reducible; init, simd) :: T From 141580432bda12c855d7308b10cd3ed602338ee9 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sun, 30 Jan 2022 12:48:48 -0500 Subject: [PATCH 2/4] Notes on completebasecase --- src/combinators.jl | 32 ++++++++++++++++++++++++++++++++ src/core.jl | 13 +++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/combinators.jl b/src/combinators.jl index 7610fa872a..b313f55ada 100644 --- a/src/combinators.jl +++ b/src/combinators.jl @@ -417,3 +417,35 @@ julia> foldxt(averaging2, Filter(isodd), 1:50; basesize = 1) ``` """ (wheninit, whenstart, whencomplete, whencombine) + +""" + whencompletebasecase(completebasecase, rf) -> rf′ + whencompletebasecase(completebasecase) -> rf -> rf′ + +Add [`completebasecase`](@ref) protocol to arbitrary reducing function. + +The function `completebasecase` is used as follows in the basecase +implementation of `reduce` as follows: + +```julia +init′ = oninit() +acc = start(init′) +for x in collection + acc += rf(acc, x) +end +result = completebasecase(acc) +return result +``` + +The `result₁` from basecase 1 and `result₂` from basecase 2 are combined +using [`combine`](@ref) protcol: + +```julia +combine(result₁, result₂) +``` + +!!! note + + This function is an internal experimental interface for FoldsCUDA. +""" +whencompletebasecase diff --git a/src/core.jl b/src/core.jl index 967ec8a71c..b81a1dba6b 100644 --- a/src/core.jl +++ b/src/core.jl @@ -496,6 +496,19 @@ real-world examples. # done(rf, result) +""" + Transducers.completebasecase(rf, state) + +Process basecase result `state` before merged by [`combine`](@ref). + +For example, on GPU, this function can be used to translate mutable states to +immutable values for exchanging them through (un-GC-managed) memory. See +[`whencompletebasecase`](@ref). + +!!! note + + This function is an internal experimental interface for FoldsCUDA. +""" completebasecase(_, result) = result completebasecase(rf::RF, result) where {RF <: AbstractReduction} = completebasecase(inner(rf), result) From 65e6c7ab3a39d72ea8f71c9853666d7fb2891b9f Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sun, 30 Jan 2022 13:30:54 -0500 Subject: [PATCH 3/4] Test `whencompletebasecase` --- test/test_adhocrf.jl | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/test/test_adhocrf.jl b/test/test_adhocrf.jl index 83decb9f8c..5669cfad18 100644 --- a/test/test_adhocrf.jl +++ b/test/test_adhocrf.jl @@ -2,8 +2,17 @@ module TestAdHocRF using Test using Transducers -using Transducers: wheninit, whencomplete, whencombine, complete, whenstart +using Transducers: + complete, + foldl_basecase, + start, + whencombine, + whencomplete, + whencompletebasecase, + wheninit, + whenstart using MicroCollections: EmptyVector +using StaticArrays: MVector, SVector @testset "setters" begin rf = nothing |> wheninit(1) |> whenstart(2) |> whencomplete(3) |> whencombine(4) @@ -27,6 +36,27 @@ end @test foldxt(collector!!, Filter(isodd), 1:5; basesize = 1) == 1:2:5 end +counter(n::Integer) = counter(Val(Int(n))) +function counter(::Val{n}) where {n} + init() = zero(MVector{n,Int}) + function inc!(b, i) + @inbounds b[max(begin, min(i, end))] += 1 + b + end + completebasecase(b) = SVector(b) + combine(h, b) = h .+ b + return inc! |> + wheninit(init) |> + whencompletebasecase(completebasecase) |> + whencombine(combine) +end + +@testset "counter" begin + @test foldxl(counter(10), 1:10)::MVector == ones(10) + rf = counter(10) + @test foldl_basecase(rf, start(rf, Init)::MVector, 1:10)::SVector == ones(10) +end + getoninit(rf) = rf.oninit @testset "inference" begin From 2601be1a1bf7e11280d2f489bec5d74ea0da0b47 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Sun, 30 Jan 2022 14:41:17 -0500 Subject: [PATCH 4/4] Fix syntax for old Julia --- test/test_adhocrf.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_adhocrf.jl b/test/test_adhocrf.jl index 5669cfad18..a5a3629f76 100644 --- a/test/test_adhocrf.jl +++ b/test/test_adhocrf.jl @@ -40,7 +40,7 @@ counter(n::Integer) = counter(Val(Int(n))) function counter(::Val{n}) where {n} init() = zero(MVector{n,Int}) function inc!(b, i) - @inbounds b[max(begin, min(i, end))] += 1 + @inbounds b[max(1, min(i, n))] += 1 b end completebasecase(b) = SVector(b)