Skip to content

Commit

Permalink
optimizer: clean up query interfaces (#43324)
Browse files Browse the repository at this point in the history
- unify `compact_exprtype` and `argextype`
- remove redundant arguments
- unify `is_known_call` definitions and improve the precision of
  `is_known_call(..., ::IRCode)` (by using `singleton_type`)
  • Loading branch information
aviatesk authored Dec 9, 2021
1 parent f60bfd1 commit 63f5b8a
Show file tree
Hide file tree
Showing 9 changed files with 186 additions and 217 deletions.
159 changes: 135 additions & 24 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
return false
end
if isa(stmt, GotoIfNot)
t = argextype(stmt.cond, ir, ir.sptypes)
t = argextype(stmt.cond, ir)
return !(t Bool)
end
if isa(stmt, Expr)
Expand All @@ -195,6 +195,127 @@ function stmt_affects_purity(@nospecialize(stmt), ir)
return true
end

"""
stmt_effect_free(stmt, rt, src::Union{IRCode,IncrementalCompact})
Determine whether a `stmt` is "side-effect-free", i.e. may be removed if it has no uses.
"""
function stmt_effect_free(@nospecialize(stmt), @nospecialize(rt), src::Union{IRCode,IncrementalCompact})
isa(stmt, PiNode) && return true
isa(stmt, PhiNode) && return true
isa(stmt, ReturnNode) && return false
isa(stmt, GotoNode) && return false
isa(stmt, GotoIfNot) && return false
isa(stmt, Slot) && return false # Slots shouldn't occur in the IR at this point, but let's be defensive here
isa(stmt, GlobalRef) && return isdefined(stmt.mod, stmt.name)
if isa(stmt, Expr)
(; head, args) = stmt
if head === :static_parameter
etyp = (isa(src, IRCode) ? src.sptypes : src.ir.sptypes)[args[1]::Int]
# if we aren't certain enough about the type, it might be an UndefVarError at runtime
return isa(etyp, Const)
end
if head === :call
f = argextype(args[1], src)
f = singleton_type(f)
f === nothing && return false
is_return_type(f) && return true
if isa(f, IntrinsicFunction)
intrinsic_effect_free_if_nothrow(f) || return false
return intrinsic_nothrow(f,
Any[argextype(args[i], src) for i = 2:length(args)])
end
contains_is(_PURE_BUILTINS, f) && return true
contains_is(_PURE_OR_ERROR_BUILTINS, f) || return false
rt === Bottom && return false
return _builtin_nothrow(f, Any[argextype(args[i], src) for i = 2:length(args)], rt)
elseif head === :new
typ = argextype(args[1], src)
# `Expr(:new)` of unknown type could raise arbitrary TypeError.
typ, isexact = instanceof_tfunc(typ)
isexact || return false
isconcretedispatch(typ) || return false
typ = typ::DataType
fieldcount(typ) >= length(args) - 1 || return false
for fld_idx in 1:(length(args) - 1)
eT = argextype(args[fld_idx + 1], src)
fT = fieldtype(typ, fld_idx)
eT fT || return false
end
return true
elseif head === :new_opaque_closure
length(args) < 5 && return false
typ = argextype(args[1], src)
typ, isexact = instanceof_tfunc(typ)
isexact || return false
typ Tuple || return false
isva = argextype(args[2], src)
rt_lb = argextype(args[3], src)
rt_ub = argextype(args[4], src)
src = argextype(args[5], src)
if !(isva Bool && rt_lb Type && rt_ub Type && src Method)
return false
end
return true
elseif head === :isdefined || head === :the_exception || head === :copyast || head === :inbounds || head === :boundscheck
return true
else
# e.g. :loopinfo
return false
end
end
return true
end

"""
argextype(x, src::Union{IRCode,IncrementalCompact}) -> t
argextype(x, src::CodeInfo, sptypes::Vector{Any}) -> t
Return the type of value `x` in the context of inferred source `src`.
Note that `t` might be an extended lattice element.
Use `widenconst(t)` to get the native Julia type of `x`.
"""
argextype(@nospecialize(x), ir::IRCode, sptypes::Vector{Any} = ir.sptypes) =
argextype(x, ir, sptypes, ir.argtypes)
function argextype(@nospecialize(x), compact::IncrementalCompact, sptypes::Vector{Any} = compact.ir.sptypes)
isa(x, AnySSAValue) && return types(compact)[x]
return argextype(x, compact, sptypes, compact.ir.argtypes)
end
argextype(@nospecialize(x), src::CodeInfo, sptypes::Vector{Any}) = argextype(x, src, sptypes, src.slottypes::Vector{Any})
function argextype(
@nospecialize(x), src::Union{IRCode,IncrementalCompact,CodeInfo},
sptypes::Vector{Any}, slottypes::Vector{Any})
if isa(x, Expr)
if x.head === :static_parameter
return sptypes[x.args[1]::Int]
elseif x.head === :boundscheck
return Bool
elseif x.head === :copyast
return argextype(x.args[1], src, sptypes, slottypes)
end
@assert false "argextype only works on argument-position values"
elseif isa(x, SlotNumber)
return slottypes[x.id]
elseif isa(x, TypedSlot)
return x.typ
elseif isa(x, SSAValue)
return abstract_eval_ssavalue(x, src)
elseif isa(x, Argument)
return slottypes[x.n]
elseif isa(x, QuoteNode)
return Const(x.value)
elseif isa(x, GlobalRef)
return abstract_eval_global(x.mod, x.name)
elseif isa(x, PhiNode)
return Any
elseif isa(x, PiNode)
return x.typ
else
return Const(x)
end
end
abstract_eval_ssavalue(s::SSAValue, src::Union{IRCode,IncrementalCompact}) = types(src)[s]

# compute inlining cost and sideeffects
function finish(interp::AbstractInterpreter, opt::OptimizationState, params::OptimizationParams, ir::IRCode, @nospecialize(result))
(; src, linfo) = opt
Expand All @@ -214,7 +335,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState, params::Opt
for i in 1:length(ir.stmts)
node = ir.stmts[i]
stmt = node[:inst]
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir, ir.sptypes)
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir)
proven_pure = false
break
end
Expand Down Expand Up @@ -432,20 +553,19 @@ plus_saturate(x::Int, y::Int) = max(x, y, x+y)
isknowntype(@nospecialize T) = (T === Union{}) || isa(T, Const) || isconcretetype(widenconst(T))

function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
slottypes::Vector{Any}, union_penalties::Bool,
params::OptimizationParams, error_path::Bool = false)
union_penalties::Bool, params::OptimizationParams, error_path::Bool = false)
head = ex.head
if is_meta_expr_head(head)
return 0
elseif head === :call
farg = ex.args[1]
ftyp = argextype(farg, src, sptypes, slottypes)
ftyp = argextype(farg, src, sptypes)
if ftyp === IntrinsicFunction && farg isa SSAValue
# if this comes from code that was already inlined into another function,
# Consts have been widened. try to recover in simple cases.
farg = isa(src, CodeInfo) ? src.code[farg.id] : src.stmts[farg.id][:inst]
if isa(farg, GlobalRef) || isa(farg, QuoteNode) || isa(farg, IntrinsicFunction) || isexpr(farg, :static_parameter)
ftyp = argextype(farg, src, sptypes, slottypes)
ftyp = argextype(farg, src, sptypes)
end
end
f = singleton_type(ftyp)
Expand All @@ -467,15 +587,15 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
# return plus_saturate(argcost, isknowntype(extyp) ? 1 : params.inline_nonleaf_penalty)
return 0
elseif (f === Core.arrayref || f === Core.const_arrayref || f === Core.arrayset) && length(ex.args) >= 3
atyp = argextype(ex.args[3], src, sptypes, slottypes)
atyp = argextype(ex.args[3], src, sptypes)
return isknowntype(atyp) ? 4 : error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes, slottypes)))
elseif f === typeassert && isconstType(widenconst(argextype(ex.args[3], src, sptypes)))
return 1
elseif f === Core.isa
# If we're in a union context, we penalize type computations
# on union types. In such cases, it is usually better to perform
# union splitting on the outside.
if union_penalties && isa(argextype(ex.args[2], src, sptypes, slottypes), Union)
if union_penalties && isa(argextype(ex.args[2], src, sptypes), Union)
return params.inline_nonleaf_penalty
end
end
Expand All @@ -487,7 +607,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
end
return T_FFUNC_COST[fidx]
end
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes, slottypes)
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes)
if extyp === Union{}
return 0
end
Expand All @@ -498,7 +618,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
# run-time of the function, we omit them from
# consideration. This way, non-inlined error branches do not
# prevent inlining.
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes, slottypes)
extyp = line == -1 ? Any : argextype(SSAValue(line), src, sptypes)
return extyp === Union{} ? 0 : 20
elseif head === :(=)
if ex.args[1] isa GlobalRef
Expand All @@ -508,7 +628,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
end
a = ex.args[2]
if a isa Expr
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, slottypes, union_penalties, params, error_path))
cost = plus_saturate(cost, statement_cost(a, -1, src, sptypes, union_penalties, params, error_path))
end
return cost
elseif head === :copyast
Expand All @@ -524,11 +644,11 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
end

function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
slottypes::Vector{Any}, union_penalties::Bool, params::OptimizationParams)
union_penalties::Bool, params::OptimizationParams)
thiscost = 0
dst(tgt) = isa(src, IRCode) ? first(src.cfg.blocks[tgt].stmts) : tgt
if stmt isa Expr
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, union_penalties, params,
thiscost = statement_cost(stmt, line, src, sptypes, union_penalties, params,
is_stmt_throw_block(isa(src, IRCode) ? src.stmts.flag[line] : src.ssaflags[line]))::Int
elseif stmt isa GotoNode
# loops are generally always expensive
Expand All @@ -546,7 +666,7 @@ function inline_worthy(ir::IRCode,
bodycost::Int = 0
for line = 1:length(ir.stmts)
stmt = ir.stmts[line][:inst]
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, ir.argtypes, union_penalties, params)
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, union_penalties, params)
bodycost = plus_saturate(bodycost, thiscost)
bodycost > cost_threshold && return false
end
Expand All @@ -558,7 +678,6 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI
for line = 1:length(body)
stmt = body[line]
thiscost = statement_or_branch_cost(stmt, line, src, sptypes,
src isa CodeInfo ? src.slottypes : src.argtypes,
unionpenalties, params)
cost[line] = thiscost
if thiscost > maxcost
Expand All @@ -568,14 +687,6 @@ function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeI
return maxcost
end

function is_known_call(e::Expr, @nospecialize(func), src, sptypes::Vector{Any}, slottypes::Vector{Any} = EMPTY_SLOTTYPES)
if e.head !== :call
return false
end
f = argextype(e.args[1], src, sptypes, slottypes)
return isa(f, Const) && f.val === func
end

function renumber_ir_elements!(body::Vector{Any}, changemap::Vector{Int})
return renumber_ir_elements!(body, changemap, changemap)
end
Expand Down
1 change: 0 additions & 1 deletion base/compiler/ssair/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ include("compiler/ssair/basicblock.jl")
include("compiler/ssair/domtree.jl")
include("compiler/ssair/ir.jl")
include("compiler/ssair/slot2ssa.jl")
include("compiler/ssair/queries.jl")
include("compiler/ssair/passes.jl")
include("compiler/ssair/inlining.jl")
include("compiler/ssair/verify.jl")
Expand Down
22 changes: 11 additions & 11 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
return_value = SSAValue(idx′)
inline_compact[idx′] = val
inline_compact.result[idx′][:type] =
compact_exprtype(isa(val, Argument) || isa(val, Expr) ? compact : inline_compact, val)
argextype(val, isa(val, Argument) || isa(val, Expr) ? compact : inline_compact)
break
end
inline_compact[idx′] = stmt′
Expand Down Expand Up @@ -400,7 +400,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
if isa(val, GlobalRef) || isa(val, Expr)
stmt′ = val
inline_compact.result[idx′][:type] =
compact_exprtype(isa(val, Expr) ? compact : inline_compact, val)
argextype(val, isa(val, Expr) ? compact : inline_compact)
insert_node_here!(inline_compact, NewInstruction(GotoNode(post_bb_id),
Any, compact.result[idx′][:line]),
true)
Expand Down Expand Up @@ -435,7 +435,7 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
return_value = pn.values[1]
else
return_value = insert_node_here!(compact,
NewInstruction(pn, compact_exprtype(compact, SSAValue(idx)), compact.result[idx][:line]))
NewInstruction(pn, argextype(SSAValue(idx), compact), compact.result[idx][:line]))
end
end
return_value
Expand Down Expand Up @@ -580,7 +580,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect
for aidx in 1:length(argexprs)
aexpr = argexprs[aidx]
if isa(aexpr, Expr) || isa(aexpr, GlobalRef)
ninst = effect_free(NewInstruction(aexpr, compact_exprtype(compact, aexpr), compact.result[idx][:line]))
ninst = effect_free(NewInstruction(aexpr, argextype(aexpr, compact), compact.result[idx][:line]))
argexprs[aidx] = insert_node_here!(compact, ninst)
end
end
Expand Down Expand Up @@ -886,7 +886,7 @@ function inline_splatnew!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(rt))
if nf isa Const
eargs = stmt.args
tup = eargs[2]
tt = argextype(tup, ir, ir.sptypes)
tt = argextype(tup, ir)
tnf = nfields_tfunc(tt)
# TODO: hoisting this tnf.val === nf.val check into codegen
# would enable us to almost always do this transform
Expand All @@ -908,15 +908,15 @@ end

function call_sig(ir::IRCode, stmt::Expr)
isempty(stmt.args) && return nothing
ft = argextype(stmt.args[1], ir, ir.sptypes)
ft = argextype(stmt.args[1], ir)
has_free_typevars(ft) && return nothing
f = singleton_type(ft)
f === Core.Intrinsics.llvmcall && return nothing
f === Core.Intrinsics.cglobal && return nothing
argtypes = Vector{Any}(undef, length(stmt.args))
argtypes[1] = ft
for i = 2:length(stmt.args)
a = argextype(stmt.args[i], ir, ir.sptypes)
a = argextype(stmt.args[i], ir)
(a === Bottom || isvarargtype(a)) && return nothing
argtypes[i] = a
end
Expand Down Expand Up @@ -1025,10 +1025,10 @@ end

function narrow_opaque_closure!(ir::IRCode, stmt::Expr, @nospecialize(info), state::InliningState)
if isa(info, OpaqueClosureCreateInfo)
lbt = argextype(stmt.args[3], ir, ir.sptypes)
lbt = argextype(stmt.args[3], ir)
lb, exact = instanceof_tfunc(lbt)
exact || return
ubt = argextype(stmt.args[4], ir, ir.sptypes)
ubt = argextype(stmt.args[4], ir)
ub, exact = instanceof_tfunc(ubt)
exact || return
# Narrow opaque closure type
Expand All @@ -1046,7 +1046,7 @@ end
# For primitives, we do that right here. For proper calls, we will
# discover this when we consult the caches.
function check_effect_free!(ir::IRCode, idx::Int, @nospecialize(stmt), @nospecialize(rt))
if stmt_effect_free(stmt, rt, ir, ir.sptypes)
if stmt_effect_free(stmt, rt, ir)
ir.stmts[idx][:flag] |= IR_FLAG_EFFECT_FREE
end
end
Expand Down Expand Up @@ -1346,7 +1346,7 @@ end

function mk_tuplecall!(compact::IncrementalCompact, args::Vector{Any}, line_idx::Int32)
e = Expr(:call, TOP_TUPLE, args...)
etyp = tuple_tfunc(Any[compact_exprtype(compact, args[i]) for i in 1:length(args)])
etyp = tuple_tfunc(Any[argextype(args[i], compact) for i in 1:length(args)])
return insert_node_here!(compact, NewInstruction(e, etyp, line_idx))
end

Expand Down
11 changes: 3 additions & 8 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,7 @@ function insert_node!(ir::IRCode, pos::Int, inst::NewInstruction, attach_after::
node[:line] = something(inst.line, ir.stmts[pos][:line])
flag = inst.flag
if !inst.effect_free_computed
if stmt_effect_free(inst.stmt, inst.type, ir, ir.sptypes)
if stmt_effect_free(inst.stmt, inst.type, ir)
flag |= IR_FLAG_EFFECT_FREE
end
end
Expand Down Expand Up @@ -765,7 +765,7 @@ function insert_node_here!(compact::IncrementalCompact, inst::NewInstruction, re
resize!(compact, result_idx)
end
flag = inst.flag
if !inst.effect_free_computed && stmt_effect_free(inst.stmt, inst.type, compact, compact.ir.sptypes)
if !inst.effect_free_computed && stmt_effect_free(inst.stmt, inst.type, compact)
flag |= IR_FLAG_EFFECT_FREE
end
node = compact.result[result_idx]
Expand Down Expand Up @@ -1316,7 +1316,7 @@ function maybe_erase_unused!(
callback = null_dce_callback)
stmt = compact.result[idx][:inst]
stmt === nothing && return false
if compact_exprtype(compact, SSAValue(idx)) === Bottom
if argextype(SSAValue(idx), compact) === Bottom
effect_free = false
else
effect_free = compact.result[idx][:flag] & IR_FLAG_EFFECT_FREE != 0
Expand Down Expand Up @@ -1466,8 +1466,3 @@ function iterate(x::BBIdxIter, (idx, bb)::Tuple{Int, Int}=(1, 1))
end
return (bb, idx), (idx + 1, next_bb)
end

is_known_call(e::Expr, @nospecialize(func), ir::IRCode) =
is_known_call(e, func, ir, ir.sptypes, ir.argtypes)

argextype(@nospecialize(x), ir::IRCode) = argextype(x, ir, ir.sptypes, ir.argtypes)
Loading

2 comments on commit 63f5b8a

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected. A full report can be found here.

Please sign in to comment.