Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC: Better inference of _apply() (splatting) #20343

Merged
merged 3 commits into from
Feb 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 107 additions & 56 deletions base/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@ immutable InferenceParams
MAX_TUPLE_DEPTH::Int
MAX_TUPLE_SPLAT::Int
MAX_UNION_SPLITTING::Int
MAX_APPLY_UNION_ENUM::Int

# reasonable defaults
function InferenceParams(world::UInt;
inlining::Bool = inlining_enabled(),
tupletype_len::Int = 15,
tuple_depth::Int = 4,
tuple_splat::Int = 16,
union_splitting::Int = 4)
union_splitting::Int = 4,
apply_union_enum::Int = 8)
return new(world, inlining, tupletype_len,
tuple_depth, tuple_splat, union_splitting)
tuple_depth, tuple_splat, union_splitting, apply_union_enum)
end
end

Expand Down Expand Up @@ -1380,67 +1382,116 @@ function abstract_evals_to_constant(ex, c::ANY, vtypes, sv)
return isa(av,Const) && av.val === c
end

# `types` is an array of inferred types for expressions in `args`.
# if an expression constructs a container (e.g. `svec(x,y,z)`),
# refine its type to an array of element types. returns an array of
# arrays of types, or `nothing`.
function precise_container_types(args, types, vtypes::VarTable, sv)
n = length(args)
assert(n == length(types))
result = Vector{Any}(n)
for i = 1:n
ai = args[i]
ti = types[i]
tti = widenconst(ti)
tti = unwrap_unionall(tti)
if isa(ti, Const) && (isa(ti.val, SimpleVector) || isa(ti.val, Tuple))
result[i] = Any[ abstract_eval_constant(x) for x in ti.val ]
elseif isa(ai, Expr) && ai.head === :call && (abstract_evals_to_constant(ai.args[1], svec, vtypes, sv) ||
abstract_evals_to_constant(ai.args[1], tuple, vtypes, sv))
aa = ai.args
result[i] = Any[ (isa(aa[j],Expr) ? aa[j].typ : abstract_eval(aa[j],vtypes,sv)) for j=2:length(aa) ]
if _any(isvarargtype, result[i])
return nothing
end
elseif isa(tti, Union)
return nothing
elseif isa(tti,DataType) && tti <: Tuple
if i == n
if isvatuple(tti) && length(tti.parameters) == 1
result[i] = Any[Vararg{unwrapva(tti.parameters[1])}]
else
result[i] = tti.parameters
end
elseif isknownlength(tti)
result[i] = tti.parameters
else
return nothing
end
elseif tti <: AbstractArray && i == n
result[i] = Any[Vararg{eltype(tti)}]
# `typ` is the inferred type for expression `arg`.
# if the expression constructs a container (e.g. `svec(x,y,z)`),
# refine its type to an array of element types.
# Union of Tuples of the same length is converted to Tuple of Unions.
# returns an array of types, or `nothing`.
function precise_container_type(arg, typ, vtypes::VarTable, sv)
tti = widenconst(typ)
tti = unwrap_unionall(tti)
if isa(typ, Const) && (isa(typ.val, SimpleVector) || isa(typ.val, Tuple))
return Any[ abstract_eval_constant(x) for x in typ.val ]
elseif isa(arg, Expr) && arg.head === :call && (abstract_evals_to_constant(arg.args[1], svec, vtypes, sv) ||
abstract_evals_to_constant(arg.args[1], tuple, vtypes, sv))
aa = arg.args
result = Any[ (isa(aa[j],Expr) ? aa[j].typ : abstract_eval(aa[j],vtypes,sv)) for j=2:length(aa) ]
if _any(isvarargtype, result)
return Any[Vararg{Any}]
end
return result
elseif isa(tti, Union)
utis = uniontypes(tti)
if _any(t -> !isa(t,DataType) || !(t <: Tuple) || !isknownlength(t), utis)
return Any[Vararg{Any}]
end
result = Any[utis[1].parameters...]
for t in utis[2:end]
if length(t.parameters) != length(result)
return Any[Vararg{Any}]
end
for j in 1:length(t.parameters)
result[j] = tmerge(result[j], t.parameters[j])
end
end
return result
elseif isa(tti,DataType) && tti <: Tuple
if isvatuple(tti) && length(tti.parameters) == 1
return Any[Vararg{unwrapva(tti.parameters[1])}]
else
return nothing
return tti.parameters
end
elseif tti <: Array
return Any[Vararg{eltype(tti)}]
else
return Any[Vararg{abstract_iteration(tti, vtypes, sv)}]
end
end

# simulate iteration protocol on container type up to fixpoint
function abstract_iteration(itertype, vtypes::VarTable, sv)
if !isdefined(Main, :Base) || !isdefined(Main.Base, :start) || !isdefined(Main.Base, :next)
return Any
end
statetype = abstract_call(Main.Base.start, (), Any[Const(Main.Base.start), itertype], vtypes, sv)
valtype = Bottom
while valtype !== Any
nt = abstract_call(Main.Base.next, (), Any[Const(Main.Base.next), itertype, statetype], vtypes, sv)
if !isa(nt, DataType) || !(nt <: Tuple) || isvatuple(nt) || length(nt.parameters) != 2
return Any
end
if nt.parameters[1] <: valtype && nt.parameters[2] <: statetype
break
end
valtype = tmerge(valtype, nt.parameters[1])
statetype = tmerge(statetype, nt.parameters[2])
end
return result
return valtype
end

# do apply(af, fargs...), where af is a function value
function abstract_apply(af::ANY, fargs, aargtypes::Vector{Any}, vtypes::VarTable, sv)
ctypes = precise_container_types(fargs, aargtypes, vtypes, sv)
if ctypes !== nothing
# apply with known func with known tuple types
# can be collapsed to a call to the applied func
at = append_any(Any[Const(af)], ctypes...)
n = length(at)
if n-1 > sv.params.MAX_TUPLETYPE_LEN
tail = foldl((a,b)->tmerge(a,unwrapva(b)), Bottom, at[sv.params.MAX_TUPLETYPE_LEN+1:n])
at = vcat(at[1:sv.params.MAX_TUPLETYPE_LEN], Any[Vararg{widenconst(tail)}])
end
return abstract_call(af, (), at, vtypes, sv)
end
# apply known function with unknown args => f(Any...)
return abstract_call(af, (), Any[Const(af), Vararg{Any}], vtypes, sv)
res = Union{}
nargs = length(fargs)
assert(nargs == length(aargtypes))
splitunions = countunionsplit(aargtypes) <= sv.params.MAX_APPLY_UNION_ENUM
ctypes = Any[Any[]]
for i = 1:nargs
if aargtypes[i] === Any
# bail out completely and infer as f(::Any...)
# instead could keep what we got so far and just append a Vararg{Any} (by just
# using the normal logic from below), but that makes the time of the subarray
# test explode
ctypes = Any[Any[Vararg{Any}]]
break
end
ctypes´ = []
for ti in (splitunions ? uniontypes(aargtypes[i]) : Any[aargtypes[i]])
cti = precise_container_type(fargs[i], ti, vtypes, sv)
for ct in ctypes
if !isempty(ct) && isvarargtype(ct[end])
tail = foldl((a,b)->tmerge(a,unwrapva(b)), unwrapva(ct[end]), cti)
push!(ctypes´, push!(ct[1:end-1], Vararg{widenconst(tail)}))
else
push!(ctypes´, append_any(ct, cti))
end
end
end
ctypes = ctypes´
end
for ct in ctypes
if length(ct) > sv.params.MAX_TUPLETYPE_LEN
tail = foldl((a,b)->tmerge(a,unwrapva(b)), Bottom, ct[sv.params.MAX_TUPLETYPE_LEN:end])
resize!(ct, sv.params.MAX_TUPLETYPE_LEN)
ct[end] = Vararg{widenconst(tail)}
end
at = append_any(Any[Const(af)], ct)
res = tmerge(res, abstract_call(af, (), at, vtypes, sv))
if res === Any
break
end
end
return res
end

function return_type_tfunc(argtypes::ANY, vtypes::VarTable, sv::InferenceState)
Expand Down
35 changes: 35 additions & 0 deletions test/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,41 @@ g11015(::Type{Bool}, ::Bool) = 2.0
@test Int <: Base.return_types(f11015, (AT11015,))[1]
@test f11015(AT11015(true)) === 1

# better inference of apply (#20343)
f20343(::String, ::Int) = 1
f20343(::Int, ::String, ::Int, ::Int) = 1
f20343(::Int, ::Int, ::String, ::Int, ::Int, ::Int) = 1
f20343(::Union{Int,String}...) = Int8(1)
f20343(::Any...) = "no"
function g20343()
n = rand(1:3)
i = ntuple(i->n==i ? "" : 0, 2n)::Union{Tuple{String,Int},Tuple{Int,String,Int,Int},Tuple{Int,Int,String,Int,Int,Int}}
f20343(i...)
end
@test Base.return_types(g20343, ()) == [Int]
function h20343()
n = rand(1:3)
i = ntuple(i->n==i ? "" : 0, 3)::Union{Tuple{String,Int,Int},Tuple{Int,String,Int},Tuple{Int,Int,String}}
f20343(i..., i...)
end
@test all(t -> t<:Integer, Base.return_types(h20343, ()))
function i20343()
f20343([1,2,3]..., 4)
end
@test Base.return_types(i20343, ()) == [Int8]
immutable Foo20518 <: AbstractVector{Int}; end # issue #20518; inference assumed AbstractArrays
Base.getindex(::Foo20518, ::Int) = "oops" # not to lie about their element type
Base.indices(::Foo20518) = (Base.OneTo(4),)
foo20518(xs::Any...) = -1
foo20518(xs::Int...) = [0]
bar20518(xs) = sum(foo20518(xs...))
@test bar20518(Foo20518()) == -1
f19957(::Int) = Int8(1) # issue #19957, inference failure when splatting a number
f19957(::Int...) = Int16(1)
f19957(::Any...) = "no"
g19957(x) = f19957(x...)
@test all(t -> t<:Union{Int8,Int16}, Base.return_types(g19957, (Int,))) # with a full fix, this should just be Int8

# Inference for some type-level computation
fUnionAll{T}(::Type{T}) = Type{S} where S <: T
@inferred fUnionAll(Real) == Type{T} where T <: Real
Expand Down