Skip to content

Commit

Permalink
Merge 884e5ea into 3ee4bb1
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian Atol authored May 6, 2022
2 parents 3ee4bb1 + 884e5ea commit 0c08729
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 118 deletions.
180 changes: 137 additions & 43 deletions base/compiler/ssair/inlining.jl

Large diffs are not rendered by default.

115 changes: 77 additions & 38 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,16 +631,13 @@ mutable struct IncrementalCompact
perm = my_sortperm(Int[code.new_nodes.info[i].pos for i in 1:length(code.new_nodes)])
new_len = length(code.stmts) + length(code.new_nodes)
ssa_rename = Any[SSAValue(i) for i = 1:new_len]
new_new_used_ssas = Vector{Int}()
late_fixup = Vector{Int}()
bb_rename = Vector{Int}()
new_new_nodes = NewNodeStream()
pending_nodes = NewNodeStream()
pending_perm = Int[]
return new(code, parent.result,
parent.result_bbs, ssa_rename, bb_rename, bb_rename, parent.used_ssas,
late_fixup, perm, 1,
new_new_nodes, new_new_used_ssas, pending_nodes, pending_perm,
parent.late_fixup, perm, 1,
parent.new_new_nodes, parent.new_new_used_ssas, pending_nodes, pending_perm,
1, result_offset, parent.active_result_bb, false, false, false)
end
end
Expand Down Expand Up @@ -1470,62 +1467,104 @@ function maybe_erase_unused!(
return false
end

function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any})
# Result of a fixup step: the rewritten `node` plus a flag telling the caller
# whether that node still needs a later fixup. The fixup routines in this file
# set `needs_fixup` when a `NewSSAValue` was left in place rather than being
# turned into a concrete `SSAValue`.
struct FixedNode
# The rewritten statement, value, or value list.
node::Any
# `true` when `node` still has to be revisited by a later fixup.
needs_fixup::Bool
# `@nospecialize` keeps the constructor from specializing on the type of `node`.
FixedNode(@nospecialize(node), needs_fixup::Bool) = new(node, needs_fixup)
end

function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{Any}, reify_new_nodes::Bool)
values = Vector{Any}(undef, length(old_values))
needs_fixup = false
for i = 1:length(old_values)
isassigned(old_values, i) || continue
val = old_values[i]
if isa(val, Union{OldSSAValue, NewSSAValue})
val = fixup_node(compact, val)
if isa(val, OldSSAValue)
val = compact.ssa_rename[val.id]
if isa(val, SSAValue)
compact.used_ssas[val.id] += 1
end
elseif isa(val, NewSSAValue)
if reify_new_nodes
val = SSAValue(length(compact.result) + val.id)
else
needs_fixup = true
end
end
values[i] = val
end
values
return FixedNode(values, needs_fixup)
end

function fixup_node(compact::IncrementalCompact, @nospecialize(stmt))
function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool)
if isa(stmt, PhiNode)
return PhiNode(stmt.edges, fixup_phinode_values!(compact, stmt.values))
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
return FixedNode(PhiNode(stmt.edges, node), needs_fixup)
elseif isa(stmt, PhiCNode)
return PhiCNode(fixup_phinode_values!(compact, stmt.values))
(;node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
return FixedNode(PhiCNode(node), needs_fixup)
elseif isa(stmt, NewSSAValue)
return SSAValue(length(compact.result) + stmt.id)
elseif isa(stmt, OldSSAValue)
val = compact.ssa_rename[stmt.id]
if isa(val, SSAValue)
# If `val.id` is greater than the length of `compact.result` or
# `compact.used_ssas`, this SSA value is in `new_new_nodes`, so
# don't count the use
compact.used_ssas[val.id] += 1
if reify_new_nodes
return FixedNode(SSAValue(length(compact.result) + stmt.id), false)
else
return FixedNode(stmt, true)
end
return val
elseif isa(stmt, OldSSAValue)
return FixedNode(compact.ssa_rename[stmt.id], false)
else
urs = userefs(stmt)
needs_fixup = false
for ur in urs
val = ur[]
if isa(val, Union{NewSSAValue, OldSSAValue})
ur[] = fixup_node(compact, val)
if isa(val, NewSSAValue)
if reify_new_nodes
val = SSAValue(length(compact.result) + val.id)
else
needs_fixup = true
end
elseif isa(val, OldSSAValue)
val = compact.ssa_rename[val.id]
end
if isa(val, SSAValue) && val.id <= length(compact.used_ssas)
# If `val.id` is greater than the length of `compact.result` or
# `compact.used_ssas`, this SSA value is in `new_new_nodes`, so
# don't count the use
compact.used_ssas[val.id] += 1
end
ur[] = val
end
return urs[]
return FixedNode(urs[], needs_fixup)
end
end

function just_fixup!(compact::IncrementalCompact)
resize!(compact.used_ssas, length(compact.result))
append!(compact.used_ssas, compact.new_new_used_ssas)
empty!(compact.new_new_used_ssas)
for idx in compact.late_fixup
function just_fixup!(compact::IncrementalCompact, new_new_nodes_offset::Union{Int, Nothing} = nothing, late_fixup_offset::Union{Int, Nothing}=nothing)
if new_new_nodes_offset === late_fixup_offset === nothing # only do this appending in non_dce_finish!
resize!(compact.used_ssas, length(compact.result))
append!(compact.used_ssas, compact.new_new_used_ssas)
empty!(compact.new_new_used_ssas)
end
off = late_fixup_offset === nothing ? 1 : (late_fixup_offset+1)
set_off = off
for i in off:length(compact.late_fixup)
idx = compact.late_fixup[i]
stmt = compact.result[idx][:inst]
new_stmt = fixup_node(compact, stmt)
(stmt === new_stmt) || (compact.result[idx][:inst] = new_stmt)
end
for idx in 1:length(compact.new_new_nodes)
node = compact.new_new_nodes.stmts[idx]
stmt = node[:inst]
new_stmt = fixup_node(compact, stmt)
if new_stmt !== stmt
node[:inst] = new_stmt
(;node, needs_fixup) = fixup_node(compact, stmt, late_fixup_offset === nothing)
(stmt === node) || (compact.result[idx][:inst] = node)
if needs_fixup
compact.late_fixup[set_off] = idx
set_off += 1
end
end
if late_fixup_offset !== nothing
resize!(compact.late_fixup, set_off-1)
end
off = new_new_nodes_offset === nothing ? 1 : (new_new_nodes_offset+1)
for idx in off:length(compact.new_new_nodes)
new_node = compact.new_new_nodes.stmts[idx]
stmt = new_node[:inst]
(;node) = fixup_node(compact, stmt, late_fixup_offset === nothing)
if node !== stmt
new_node[:inst] = node
end
end
end
Expand Down
82 changes: 82 additions & 0 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,86 @@ function perform_lifting!(compact::IncrementalCompact,
return stmt_val # N.B. should never happen
end

# Try to fold the call `Core._svec_ref(boundscheck, vec, i)` found at statement
# `idx` of `compact` into the value it would produce.
#
# The fold only happens when `i` is inferred as a constant `Int >= 1` and `vec`
# has one of the following shapes:
#   * a literal `SimpleVector` with at least `i` elements: the statement is
#     replaced by the `i`-th element;
#   * an SSA value defined by a `Core.svec(...)` call with at least `i`
#     arguments: the statement is replaced by the `i`-th argument;
#   * an SSA value defined by `Core._compute_sparams(m, T)` where `m` is a
#     constant `Method` with signature `Tuple{Type{X{P}}} where P`, `T` is
#     defined by `Core.apply_type(Y, p)` with `Y` being `X{Q} where Q`, and
#     `i == 1`: the statement is replaced by `p`.
# In every other case the statement is left untouched.
#
# Arguments:
#   compact - the `IncrementalCompact` being processed
#   idx     - index of the `_svec_ref` statement in `compact`
#   stmt    - the call expression at `idx`
#
# Always returns `nothing`; the only effect is the possible `compact[idx] = ...`.
function lift_svec_ref!(compact, idx, stmt)
    # Expected shape: [_svec_ref, boundscheck, vec, index].
    length(stmt.args) == 4 || return

    vec = stmt.args[3]
    val = stmt.args[4]
    valT = argextype(val, compact)
    (isa(valT, Const) && isa(valT.val, Int)) || return
    valI = valT.val
    (1 <= valI) || return

    if isa(vec, SimpleVector)
        # The bound is the length of the vector, not of the index operand.
        if valI <= length(vec)
            # NOTE(review): the element is stored into the IR as-is; confirm
            # whether values that are not self-quoting need to be quoted first.
            compact[idx] = vec[valI]
        end
        return
    end

    isa(vec, SSAValue) || return

    # TODO: We could do the whole lifting machinery here, but really all
    # we want to do is clean this up when it got inserted by inlining.
    def = compact[vec]
    if is_known_call(def, Core.svec, compact)
        # def.args[1] is the callee, so element `i` lives at position `i + 1`.
        nargs = length(def.args)
        if valI <= nargs-1
            compact[idx] = def.args[valI+1]
        end
        return
    elseif is_known_call(def, Core._compute_sparams, compact)
        # For now, just pattern match the benchmark case
        # TODO: More general structural analysis of the intersection

        # Check the arity before touching def.args[2] / def.args[3].
        length(def.args) == 3 || return
        # The pattern below has exactly one type variable, so only the first
        # static parameter can be recovered.
        valI == 1 || return

        m = argextype(def.args[2], compact)
        isa(m, Const) || return
        m = m.val
        isa(m, Method) || return

        # Method signature must be `Tuple{Type{X{tvar}}} where tvar`.
        sig = m.sig
        isa(sig, UnionAll) || return
        tvar = sig.var
        sig = sig.body
        isa(sig, DataType) || return
        sig.name === Tuple.name || return
        length(sig.parameters) == 1 || return

        arg = sig.parameters[1]
        isa(arg, DataType) || return
        arg.name === typename(Type) || return
        arg = arg.parameters[1]
        isa(arg, DataType) || return

        # The single call argument must come from `Core.apply_type(Y, p)`.
        rarg = def.args[3]
        isa(rarg, SSAValue) || return
        argdef = compact[rarg]
        is_known_call(argdef, Core.apply_type, compact) || return
        length(argdef.args) == 3 || return

        applyT = argextype(argdef.args[2], compact)
        isa(applyT, Const) || return
        applyT = applyT.val
        isa(applyT, UnionAll) || return
        applyTvar = applyT.var
        applyTbody = applyT.body
        isa(applyTbody, DataType) || return

        # `Y` must be the same type constructor as `X`, each with one parameter
        # that is exactly its own type variable.
        applyTbody.name === arg.name || return
        length(applyTbody.parameters) == length(arg.parameters) == 1 || return
        applyTbody.parameters[1] === applyTvar || return
        arg.parameters[1] === tvar || return

        compact[idx] = argdef.args[3]
        return
    end
    return
end

# NOTE we use `IdSet{Int}` instead of `BitSet` in these passes since they work on IR after inlining,
# which can be very large sometimes, and program counters in question are often very sparse
const SPCSet = IdSet{Int}
Expand Down Expand Up @@ -814,6 +894,8 @@ function sroa_pass!(ir::IRCode)
else # TODO: This isn't the best place to put these
if is_known_call(stmt, typeassert, compact)
canonicalize_typeassert!(compact, idx, stmt)
elseif is_known_call(stmt, Core._svec_ref, compact)
lift_svec_ref!(compact, idx, stmt)
elseif is_known_call(stmt, (===), compact)
lift_comparison!(===, compact, idx, stmt, lifting_cache)
elseif is_known_call(stmt, isa, compact)
Expand Down
8 changes: 1 addition & 7 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -672,13 +672,7 @@ end

# SimpleVector

# Return element `i` (1-based) of the `SimpleVector` `v`.
# Throws `BoundsError` when bounds checking is active and `i` is outside
# `1:length(v)`; the runtime routine itself is called with a 0-based index.
function getindex(v::SimpleVector, i::Int)
    @boundscheck (1 <= i <= length(v)) || throw(BoundsError(v,i))
    zero_based = i - 1
    return ccall(:jl_svec_ref, Any, (Any, Int), v, zero_based)
end

# Index into a `SimpleVector` by forwarding to the `Core._svec_ref` builtin.
# The `:boundscheck` expression is spliced in (hence `@eval`) as the builtin's
# first argument, followed by the vector and the 1-based index.
@eval getindex(v::SimpleVector, i::Int) = Core._svec_ref($(Expr(:boundscheck)), v, i)
# Number of elements in the `SimpleVector` `v`, as reported by the runtime's
# `jl_svec_len`.
length(v::SimpleVector) = ccall(:jl_svec_len, Int, (Any,), v)
Expand Down
4 changes: 4 additions & 0 deletions src/builtin_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ DECLARE_BUILTIN(_typevar);
DECLARE_BUILTIN(donotdelete);
DECLARE_BUILTIN(getglobal);
DECLARE_BUILTIN(setglobal);
DECLARE_BUILTIN(_compute_sparams);
DECLARE_BUILTIN(_svec_ref);

JL_CALLABLE(jl_f_invoke_kwsorter);
#ifdef DEFINE_BUILTIN_GLOBALS
Expand All @@ -73,6 +75,8 @@ JL_CALLABLE(jl_f_get_binding_type);
JL_CALLABLE(jl_f_set_binding_type);
JL_CALLABLE(jl_f_donotdelete);
JL_CALLABLE(jl_f_setglobal);
JL_CALLABLE(jl_f__compute_sparams);
JL_CALLABLE(jl_f__svec_ref);

#ifdef __cplusplus
}
Expand Down
32 changes: 32 additions & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1591,6 +1591,36 @@ JL_CALLABLE(jl_f_donotdelete)
return jl_nothing;
}

// Builtin `_compute_sparams(m, args...)`.
// Intersects the tuple type built from the arguments that follow `m` with the
// signature of method `m` and returns the environment svec filled in by
// `jl_type_intersection_env` (it starts out as `jl_emptysvec`).
// Raises a type error if the first argument is not a method.
JL_CALLABLE(jl_f__compute_sparams)
{
// NOTE(review): presumably this only enforces a minimum of one argument, yet
// `args[1]` is read below — confirm callers always pass at least two.
JL_NARGSV(_compute_sparams, 1);
jl_method_t *m = (jl_method_t*)args[0];
JL_TYPECHK(_compute_sparams, method, (jl_value_t*)m);
// Tuple type of the call arguments: args[1] plus the values starting at
// args[2]; `nargs-1` is passed as the element count.
jl_datatype_t *tt = jl_inst_arg_tuple_type(args[1], &args[2], nargs-1, 1);
jl_svec_t *env = jl_emptysvec;
// Keep `env` and `tt` rooted while the intersection runs.
JL_GC_PUSH2(&env, &tt);
jl_type_intersection_env((jl_value_t*)tt, m->sig, &env);
JL_GC_POP();
return (jl_value_t*)env;
}

JL_CALLABLE(jl_f__svec_ref)
{
JL_NARGS(_svec_ref, 3, 3);
jl_value_t *b = args[0];
jl_svec_t *s = (jl_svec_t*)args[1];
jl_value_t *i = (jl_value_t*)args[2];
JL_TYPECHK(_svec_ref, bool, b);
JL_TYPECHK(_svec_ref, simplevector, (jl_value_t*)s);
JL_TYPECHK(_svec_ref, long, i);
ssize_t idx = jl_unbox_long(i);
size_t len = jl_svec_len(s);
if (idx < 1 || idx > len) {
jl_bounds_error_int((jl_value_t*)s, idx);
}
return jl_svec_ref(s, idx-1);
}

static int equiv_field_types(jl_value_t *old, jl_value_t *ft)
{
size_t nf = jl_svec_len(ft);
Expand Down Expand Up @@ -1961,6 +1991,8 @@ void jl_init_primitives(void) JL_GC_DISABLED
jl_builtin__typebody = add_builtin_func("_typebody!", jl_f__typebody);
add_builtin_func("_equiv_typedef", jl_f__equiv_typedef);
jl_builtin_donotdelete = add_builtin_func("donotdelete", jl_f_donotdelete);
add_builtin_func("_compute_sparams", jl_f__compute_sparams);
add_builtin_func("_svec_ref", jl_f__svec_ref);

// builtin types
add_builtin("Any", (jl_value_t*)jl_any_type);
Expand Down
2 changes: 1 addition & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ static const jl_fptr_args_t id_to_fptrs[] = {
&jl_f_ifelse, &jl_f__structtype, &jl_f__abstracttype, &jl_f__primitivetype,
&jl_f__typebody, &jl_f__setsuper, &jl_f__equiv_typedef, &jl_f_get_binding_type,
&jl_f_set_binding_type, &jl_f_opaque_closure_call, &jl_f_donotdelete,
&jl_f_getglobal, &jl_f_setglobal,
&jl_f_getglobal, &jl_f_setglobal, &jl_f__compute_sparams, &jl_f__svec_ref,
NULL };

typedef struct {
Expand Down
Loading

0 comments on commit 0c08729

Please sign in to comment.