Skip to content

Commit

Permalink
optimizer: refactor SROA pass
Browse files Browse the repository at this point in the history
- use `BitSet` instead of `IdSet{Int}`
- reduce # of dynamic allocations
- separate some computations into individual functions
  • Loading branch information
aviatesk committed Nov 27, 2021
1 parent 2b1ece9 commit 6c4e203
Showing 1 changed file with 91 additions and 82 deletions.
173 changes: 91 additions & 82 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ SSADefUse() = SSADefUse(Int[], Int[], Int[])

compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses)

function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
function try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr)
field = stmt.args[3]
# fields are usually literals, handle them manually
if isa(field, QuoteNode)
field = field.value
elseif isa(field, Int)
# try to resolve other constants, e.g. global reference
else
field = compact_exprtype(compact, field)
field = isa(ir, IncrementalCompact) ? compact_exprtype(ir, field) : argextype(field, ir)
if isa(field, Const)
field = field.val
else
Expand All @@ -42,8 +42,8 @@ function try_compute_field_stmt(compact::IncrementalCompact, stmt::Expr)
return field
end

function try_compute_fieldidx_stmt(compact::IncrementalCompact, stmt::Expr, typ::DataType)
field = try_compute_field_stmt(compact, stmt)
function try_compute_fieldidx_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr, typ::DataType)
field = try_compute_field_stmt(ir, stmt)
return try_compute_fieldidx(typ, field)
end

Expand Down Expand Up @@ -112,6 +112,13 @@ function find_def_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{Int},
return def, stmtblock, curblock
end

function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint))
if isa(val, Union{OldSSAValue, SSAValue})
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
end
return walk_to_defs(compact, val, typeconstraint)
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
while true
Expand Down Expand Up @@ -152,7 +159,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint = types(compact)[defssa]))
@nospecialize(typeconstraint))
callback = function (@nospecialize(pi), @nospecialize(idx))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
Expand All @@ -164,20 +171,16 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
end

"""
walk_to_defs(compact, val, intermediaries)
walk_to_defs(compact, val, typeconstraint)
Starting at `val` walk use-def chains to get all the leaves feeding into
this val (pruning those leaves rules out by path conditions).
Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
(pruning those leaves rules out by path conditions).
"""
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), visited_phinodes::Vector{AnySSAValue}=AnySSAValue[])
isa(defssa, AnySSAValue) || return Any[defssa]
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint))
visited_phinodes = AnySSAValue[]
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
def = compact[defssa]
isa(def, PhiNode) || return Any[defssa]
# Step 2: Figure out what the struct is defined as
## Track definitions through PiNode/PhiNode
found_def = false
## Track which PhiNodes, SSAValue intermediaries
## we forwarded through.
isa(def, PhiNode) || return Any[defssa], visited_phinodes
visited_constraints = IdDict{AnySSAValue, Any}()
worklist_defs = AnySSAValue[]
worklist_constraints = Any[]
Expand Down Expand Up @@ -239,10 +242,10 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
push!(leaves, defssa)
end
end
leaves
return leaves, visited_phinodes
end

function process_immutable_preserve(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
function process_immutable_preserve!(new_preserves::Vector{Any}, compact::IncrementalCompact, def::Expr)
for arg in (isexpr(def, :new) ? def.args : def.args[2:end])
if !isbitstype(widenconst(compact_exprtype(compact, arg)))
push!(new_preserves, arg)
Expand Down Expand Up @@ -449,13 +452,10 @@ function lift_comparison!(compact::IncrementalCompact,
return
end

if isa(val, Union{OldSSAValue, SSAValue})
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
end

visited_phinodes = AnySSAValue[]
leaves = walk_to_defs(compact, val, typeconstraint, visited_phinodes)
valtyp = widenconst(compact_exprtype(compact, val))
isa(valtyp, Union) || return # bail out if there won't be a good chance for lifting

leaves, visited_phinodes = collect_leaves(compact, val, valtyp)
length(leaves) 1 && return # bail out if we don't have multiple leaves

# Let's check if we evaluate the comparison for each one of the leaves
Expand All @@ -476,10 +476,6 @@ function lift_comparison!(compact::IncrementalCompact,
visited_phinodes, cmp, lifting_cache, Bool,
lifted_leaves::IdDict{Any, Union{Nothing,LiftedValue}}, val)::LiftedValue

# global assertion_counter
# assertion_counter::Int += 1
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), lifted_val), nothing, 0, true)
# return
compact[idx] = lifted_val.x
end

Expand Down Expand Up @@ -576,6 +572,10 @@ function perform_lifting!(compact::IncrementalCompact,
return stmt_val # N.B. should never happen
end

# NOTE we use `IdSet{Int}` instead of `BitSet` for `sroa_pass!` since it works on IR after inlining,
# which can be very large sometimes, and analyzed program counters are often very sparse
const SPCSet = IdSet{Int}

"""
sroa_pass!(ir::IRCode) -> newir::IRCode
Expand All @@ -596,17 +596,16 @@ a result of succeeding dead code elimination.
"""
function sroa_pass!(ir::IRCode)
compact = IncrementalCompact(ir)
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations
lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}()
for ((_, idx), stmt) in compact
# check whether this statement is `getfield` / `setfield!` (or other "interesting" statement)
isa(stmt, Expr) || continue
result_t = compact_exprtype(compact, SSAValue(idx))
is_setfield = false
field_ordering = :unspecified
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
if is_known_call(stmt, setfield!, compact)
is_setfield = true
4 <= length(stmt.args) <= 5 || continue
is_setfield = true
if length(stmt.args) == 5
field_ordering = compact_exprtype(compact, stmt.args[5])
end
Expand All @@ -624,7 +623,7 @@ function sroa_pass!(ir::IRCode)
old_preserves = stmt.args[(6+nccallargs):end]
for (pidx, preserved_arg) in enumerate(old_preserves)
isa(preserved_arg, SSAValue) || continue
let intermediaries = IdSet{Int}()
let intermediaries = SPCSet()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
Expand All @@ -634,7 +633,7 @@ function sroa_pass!(ir::IRCode)
defidx = def.id
def = compact[defidx]
if is_tuple_call(compact, def)
process_immutable_preserve(new_preserves, compact, def)
process_immutable_preserve!(new_preserves, compact, def)
old_preserves[pidx] = nothing
continue
elseif isexpr(def, :new)
Expand All @@ -643,14 +642,17 @@ function sroa_pass!(ir::IRCode)
typ = unwrap_unionall(typ)
end
if typ isa DataType && !ismutabletype(typ)
process_immutable_preserve(new_preserves, compact, def)
process_immutable_preserve!(new_preserves, compact, def)
old_preserves[pidx] = nothing
continue
end
else
continue
end
mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse()))
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, defidx, (SPCSet(), SSADefUse()))
push!(defuse.ccall_preserve_uses, idx)
union!(mid, intermediaries)
end
Expand All @@ -675,10 +677,15 @@ function sroa_pass!(ir::IRCode)
else
continue
end

# analyze this `getfield` / `setfield!` call

field = try_compute_field_stmt(compact, stmt)
field === nothing && continue

struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, stmt.args[2])))
val = stmt.args[2]

struct_typ = unwrap_unionall(widenconst(compact_exprtype(compact, val)))
if isa(struct_typ, Union) && struct_typ <: Tuple
struct_typ = unswitchtupleunion(struct_typ)
end
Expand All @@ -689,19 +696,21 @@ function sroa_pass!(ir::IRCode)
continue
end

def, typeconstraint = stmt.args[2], struct_typ

# analyze this mutable struct here for the later pass
if ismutabletype(struct_typ)
isa(def, SSAValue) || continue
let intermediaries = IdSet{Int}()
isa(val, SSAValue) || continue
let intermediaries = SPCSet()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
def = simple_walk(compact, def, callback)
def = simple_walk(compact, val, callback)
# Mutable stuff here
isa(def, SSAValue) || continue
mid, defuse = get!(defuses, def.id, (IdSet{Int}(), SSADefUse()))
if defuses === nothing
defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}()
end
mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse()))
if is_setfield
push!(defuse.defs, idx)
else
Expand All @@ -711,32 +720,28 @@ function sroa_pass!(ir::IRCode)
end
continue
elseif is_setfield
continue
continue # invalid `setfield!` call, but just ignore here
end

# perform SROA on immutable structs here on

if isa(def, Union{OldSSAValue, SSAValue})
def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
end

visited_phinodes = AnySSAValue[]
leaves = walk_to_defs(compact, def, typeconstraint, visited_phinodes)

isempty(leaves) && continue

field = try_compute_fieldidx(struct_typ, field)
field === nothing && continue

r = lift_leaves(compact, result_t, field, leaves)
r === nothing && continue
lifted_leaves, any_undef = r
leaves, visited_phinodes = collect_leaves(compact, val, struct_typ)
isempty(leaves) && continue

result_t = compact_exprtype(compact, SSAValue(idx))
lifted_result = lift_leaves(compact, result_t, field, leaves)
lifted_result === nothing && continue
lifted_leaves, any_undef = lifted_result

if any_undef
result_t = make_MaybeUndef(result_t)
end

val = perform_lifting!(compact, visited_phinodes, field, lifting_cache, result_t, lifted_leaves, stmt.args[2])
val = perform_lifting!(compact,
visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val)

# Insert the undef check if necessary
if any_undef
Expand All @@ -750,28 +755,32 @@ function sroa_pass!(ir::IRCode)
@assert val !== nothing
end

# global assertion_counter
# assertion_counter::Int += 1
# insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
# continue
compact[idx] = val === nothing ? nothing : val.x
end

non_dce_finish!(compact)
# Copy the use count, `simple_dce!` may modify it and for our predicate
# below we need it consistent with the state of the IR here (after tracking
# phi node arguments, but before dce).
used_ssas = copy(compact.used_ssas)
simple_dce!(compact)
ir = complete(compact)

# Compute domtree, needed below, now that we have finished compacting the
# IR. This needs to be after we iterate through the IR with
# `IncrementalCompact` because removing dead blocks can invalidate the
# domtree.
if defuses !== nothing
# now go through analyzed mutable structs and see which ones we can eliminate
# NOTE copy the use count here, because `simple_dce!` may modify it and we need it
# consistent with the state of the IR here (after tracking `PhiNode` arguments,
# but before the DCE) for our predicate within `sroa_mutables!`
used_ssas = copy(compact.used_ssas)
simple_dce!(compact)
ir = complete(compact)
sroa_mutables!(ir, defuses, used_ssas)
return ir
else
simple_dce!(compact)
return complete(compact)
end
end

function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int})
# Compute domtree, needed below, now that we have finished compacting the IR.
# This needs to be after we iterate through the IR with `IncrementalCompact`
# because removing dead blocks can invalidate the domtree.
@timeit "domtree 2" domtree = construct_domtree(ir.cfg.blocks)

# Now go through any mutable structs and see which ones we can eliminate
for (idx, (intermediaries, defuse)) in defuses
intermediaries = collect(intermediaries)
# Check if there are any uses we did not account for. If so, the variable
Expand Down Expand Up @@ -806,12 +815,12 @@ function sroa_pass!(ir::IRCode)
# it would have been deleted. That's fine, just ignore
# the use in that case.
stmt === nothing && continue
field = try_compute_fieldidx_stmt(compact, stmt::Expr, typ)
field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ)
field === nothing && @goto skip
push!(fielddefuse[field].uses, use)
end
for use in defuse.defs
field = try_compute_fieldidx_stmt(compact, ir[SSAValue(use)]::Expr, typ)
field = try_compute_fieldidx_stmt(ir, ir[SSAValue(use)]::Expr, typ)
field === nothing && @goto skip
push!(fielddefuse[field].defs, use)
end
Expand Down Expand Up @@ -846,8 +855,9 @@ function sroa_pass!(ir::IRCode)
end
end
end
preserve_uses = IdDict{Int, Vector{Any}}((idx=>Any[] for idx in IdSet{Int}(defuse.ccall_preserve_uses)))
# Everything accounted for. Go field by field and perform idf
preserve_uses = isempty(defuse.ccall_preserve_uses) ? nothing :
IdDict{Int, Vector{Any}}((idx=>Any[] for idx in SPCSet(defuse.ccall_preserve_uses)))
for fidx in 1:ndefuse
du = fielddefuse[fidx]
ftyp = fieldtype(typ, fidx)
Expand All @@ -863,8 +873,10 @@ function sroa_pass!(ir::IRCode)
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
end
if !isbitstype(ftyp)
for (use, list) in preserve_uses
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
if preserve_uses !== nothing
for (use, list) in preserve_uses
push!(list, compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use))
end
end
end
for b in phiblocks
Expand All @@ -881,7 +893,7 @@ function sroa_pass!(ir::IRCode)
ir[SSAValue(stmt)] = nothing
end
end
isempty(defuse.ccall_preserve_uses) && continue
preserve_uses === nothing && continue
push!(intermediaries, newidx)
# Insert the new preserves
for (use, new_preserves) in preserve_uses
Expand All @@ -897,10 +909,7 @@ function sroa_pass!(ir::IRCode)

@label skip
end

return ir
end
# assertion_counter = 0

"""
canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::Expr)
Expand Down

0 comments on commit 6c4e203

Please sign in to comment.