From 47b0a97e064c3876acca9b9ab97125439fc6e8d8 Mon Sep 17 00:00:00 2001 From: David Gold Date: Sat, 1 Oct 2016 14:45:23 -0700 Subject: [PATCH] Implement lifting infrastructure --- base/exports.jl | 1 + base/nullable.jl | 124 +++++++++++++++++++++++++++++++++++++++++++++++ test/nullable.jl | 69 ++++++++++++++++++++++++++ 3 files changed, 194 insertions(+) diff --git a/base/exports.jl b/base/exports.jl index 87ca75cf646cb6..13d7a0ef50f60c 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -1326,6 +1326,7 @@ export # nullable types isnull, unsafe_get, + Lifted, # Macros # parser internal diff --git a/base/nullable.jl b/base/nullable.jl index c2d58dd968f87b..91add723b283da 100644 --- a/base/nullable.jl +++ b/base/nullable.jl @@ -217,3 +217,127 @@ function hash(x::Nullable, h::UInt) return hash(x.value, h + nullablehash_seed) end end + +""" + Lifted{F} + +A type used to represent the lifted version of a function `f::F`. + +Calling an `_f::Lifted{F}` on arguments `xs...` lowers to +`lift(_f.f, U, xs...)`, where the return type parameter `U` is chosen with the +help of type inference. +""" +immutable Lifted{F} + f::F + cache::Dict{Tuple{Vararg{DataType}}, DataType} + + (::Type{Lifted}){F}(f::F) = new{F}( + f, Dict{Tuple{Vararg{DataType}}, DataType}() + ) +end + +function (_f::Lifted{F}){F}(xs...) + f, cache = _f.f, _f.cache + signature = map(eltype, xs) + U = Base.@get!( + cache, + signature, + Core.Inference.return_type(f, Tuple{signature...}) + ) + return lift(f, U, xs...) +end + +""" + lift(f::F)::Lifted{F} + +Return a lifted version of `f`. +""" +lift(f) = Lifted(f) + +""" + lift(f, U, xs...) + +Return an empty `Nullable{U}` if any of the `xs` is null; otherwise, return the +(`Nullable`-wrapped) value of `f` applied to the values of the `xs`. + +NOTE: There are two exceptions to the above: `lift(|, Bool, x, y)` and +`lift(&, Bool, x, y)`. These methods both follow three-valued logic semantics. +""" +function lift(f, U::DataType, x) + if isnull(x) + return Nullable{U}() + else + return Nullable{U}(f(unsafe_get(x))) + end +end + +function lift(f, U::DataType, x1, x2) + if isnull(x1) | isnull(x2) + return Nullable{U}() + else + return Nullable{U}(f(unsafe_get(x1), unsafe_get(x2))) + end +end + +function lift(f, U::DataType, xs...) + if mapreduce(isnull, |, false, xs) + return Nullable{U}() + else + return Nullable{U}(f(map(unsafe_get, xs)...)) + end +end + +# Three-valued logic + +(::Lifted{&})(x::Union{Bool, Nullable{Bool}}, y::Union{Bool, Nullable{Bool}}) = + lift(&, Bool, x, y) +(::Lifted{|})(x::Union{Bool, Nullable{Bool}}, y::Union{Bool, Nullable{Bool}}) = + lift(|, Bool, x, y) + +function lift(f::typeof(&), ::Type{Bool}, x, y)::Nullable{Bool} + return ifelse( + isnull(x), + ifelse( + isnull(y), + Nullable{Bool}(), + ifelse( + unsafe_get(y), + Nullable{Bool}(), + Nullable(false) + ) + ), + ifelse( + isnull(y), + ifelse( + unsafe_get(x), + Nullable{Bool}(), + Nullable(false) + ), + Nullable(unsafe_get(x) & unsafe_get(y)) + ) + ) +end + +function lift(f::typeof(|), ::Type{Bool}, x, y)::Nullable{Bool} + return ifelse( + isnull(x), + ifelse( + isnull(y), + Nullable{Bool}(), + ifelse( + unsafe_get(y), + Nullable(true), + Nullable{Bool}() + ) + ), + ifelse( + isnull(y), + ifelse( + unsafe_get(x), + Nullable(true), + Nullable{Bool}() + ), + Nullable(unsafe_get(x) | unsafe_get(y)) + ) + ) +end diff --git a/test/nullable.jl b/test/nullable.jl index 819cd4a5871ce3..e55da36e9be69f 100644 --- a/test/nullable.jl +++ b/test/nullable.jl @@ -387,3 +387,72 @@ end # issue #11675 @test repr(Nullable()) == "Nullable{Union{}}()" + +# lifting + +f(x::Number) = 5 * x +f(x::Number, y::Number) = x + y +f(x::Number, y::Number, z::Number) = x + y * z +_f = lift(f) + +for T in setdiff(types, [Bool]) + a = one(T) + x = Nullable{T}(a) + y = Nullable{T}() + + U1 = Core.Inference.return_type(f, Tuple{T}) + @test isequal(_f(x), Nullable(f(a))) + @test isequal(_f(y), Nullable{U1}()) + + U2 = Core.Inference.return_type(f, Tuple{T, T}) + @test isequal(_f(x, x), Nullable(f(a, a))) + @test isequal(_f(x, y), Nullable{U2}()) + + U3 = Core.Inference.return_type(f, Tuple{T, T, T}) + @test isequal(_f(x, x, x), Nullable(f(a, a, a))) + @test isequal(_f(x, y, x), Nullable{U3}()) +end + +# three-valued logic + +# & truth table +v1 = lift(&, Bool, Nullable(true), Nullable(true)) +v2 = lift(&, Bool, Nullable(true), Nullable(false)) +v3 = lift(&, Bool, Nullable(true), Nullable{Bool}()) +v4 = lift(&, Bool, Nullable(false), Nullable(true)) +v5 = lift(&, Bool, Nullable(false), Nullable(false)) +v6 = lift(&, Bool, Nullable(false), Nullable{Bool}()) +v7 = lift(&, Bool, Nullable{Bool}(), Nullable(true)) +v8 = lift(&, Bool, Nullable{Bool}(), Nullable(false)) +v9 = lift(&, Bool, Nullable{Bool}(), Nullable{Bool}()) + +@test isequal(v1, Nullable(true)) +@test isequal(v2, Nullable(false)) +@test isequal(v3, Nullable{Bool}()) +@test isequal(v4, Nullable(false)) +@test isequal(v5, Nullable(false)) +@test isequal(v6, Nullable(false)) +@test isequal(v7, Nullable{Bool}()) +@test isequal(v8, Nullable(false)) +@test isequal(v9, Nullable{Bool}()) + +# | truth table +u1 = lift(|, Bool, Nullable(true), Nullable(true)) +u2 = lift(|, Bool, Nullable(true), Nullable(false)) +u3 = lift(|, Bool, Nullable(true), Nullable{Bool}()) +u4 = lift(|, Bool, Nullable(false), Nullable(true)) +u5 = lift(|, Bool, Nullable(false), Nullable(false)) +u6 = lift(|, Bool, Nullable(false), Nullable{Bool}()) +u7 = lift(|, Bool, Nullable{Bool}(), Nullable(true)) +u8 = lift(|, Bool, Nullable{Bool}(), Nullable(false)) +u9 = lift(|, Bool, Nullable{Bool}(), Nullable{Bool}()) + +@test isequal(u1, Nullable(true)) +@test isequal(u2, Nullable(true)) +@test isequal(u3, Nullable(true)) +@test isequal(u4, Nullable(true)) +@test isequal(u5, Nullable(false)) +@test isequal(u6, Nullable{Bool}()) +@test isequal(u7, Nullable(true)) +@test isequal(u8, Nullable{Bool}()) +@test isequal(u9, Nullable{Bool}())