Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove buggy linearization pass #604

Merged
merged 3 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions src/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,6 @@
throw(UndefVarError(frame.framecode.src.slotnames[slot.id]))
end

function lookup_expr(frame, e::Expr)
head = e.head
head === :the_exception && return frame.framedata.last_exception[]
if head === :static_parameter
arg = e.args[1]::Int
if isassigned(frame.framedata.sparams, arg)
return frame.framedata.sparams[arg]
else
syms = sparam_syms(frame.framecode.scope::Method)
throw(UndefVarError(syms[arg]))
end
end
head === :boundscheck && length(e.args) == 0 && return true
error("invalid lookup expr ", e)
end

"""
rhs = @lookup(frame, node)
rhs = @lookup(mod, frame, node)
Expand Down Expand Up @@ -67,6 +51,32 @@
end
end

function lookup_expr(frame, e::Expr)
head = e.head
head === :the_exception && return frame.framedata.last_exception[]
if head === :static_parameter
arg = e.args[1]::Int
if isassigned(frame.framedata.sparams, arg)
return frame.framedata.sparams[arg]
else
syms = sparam_syms(frame.framecode.scope::Method)
throw(UndefVarError(syms[arg]))
end
end
head === :boundscheck && length(e.args) == 0 && return true
if head === :call
f = @lookup frame e.args[1]

Check warning on line 68 in src/interpret.jl

View check run for this annotation

Codecov / codecov/patch

src/interpret.jl#L68

Added line #L68 was not covered by tests
if (@static VERSION < v"1.11.0-DEV.1180" && true) && f === Core.svec
# work around for a linearization bug in Julia (https://github.com/JuliaLang/julia/pull/52497)
return f(Any[@lookup(frame, e.args[i]) for i in 2:length(e.args)]...)
elseif f === Core.tuple
# handling for ccall literal syntax
return f(Any[@lookup(frame, e.args[i]) for i in 2:length(e.args)]...)
end
end
error("invalid lookup expr ", e)

Check warning on line 77 in src/interpret.jl

View check run for this annotation

Codecov / codecov/patch

src/interpret.jl#L77

Added line #L77 was not covered by tests
end

# This is used only for new struct/abstract/primitive nodes.
# The most important issue is that in these expressions, :call Exprs can be nested,
# and hence our re-use of the `callargs` field of Frame would introduce
Expand All @@ -91,18 +101,26 @@
if ex.head === :call
f = ex.args[1]
if f === Core.svec
return Core.svec(ex.args[2:end]...)
popfirst!(ex.args)
return Core.svec(ex.args...)
elseif f === Core.apply_type
return Core.apply_type(ex.args[2:end]...)
elseif f === Core.typeof
return Core.typeof(ex.args[2])
elseif f === Base.getproperty
popfirst!(ex.args)
return Core.apply_type(ex.args...)
elseif f === typeof && length(ex.args) == 2
return typeof(ex.args[2])

Check warning on line 110 in src/interpret.jl

View check run for this annotation

Codecov / codecov/patch

src/interpret.jl#L110

Added line #L110 was not covered by tests
elseif f === typeassert && length(ex.args) == 3
return typeassert(ex.args[2], ex.args[3])
elseif f === Base.getproperty && length(ex.args) == 3
return Base.getproperty(ex.args[2], ex.args[3])
elseif f === Core.Compiler.Val && length(ex.args) == 2
return Core.Compiler.Val(ex.args[2])
elseif f === Val && length(ex.args) == 2
return Val(ex.args[2])
else
Base.invokelatest(error, "unknown call f ", f)
Base.invokelatest(error, "unknown call f introduced by ccall lowering ", f)

Check warning on line 120 in src/interpret.jl

View check run for this annotation

Codecov / codecov/patch

src/interpret.jl#L120

Added line #L120 was not covered by tests
end
else
error("unknown expr ", ex)
return lookup_expr(frame, ex)
end
elseif isa(node, Int) || isa(node, Number) # Number is slow, requires subtyping
return node
Expand Down
132 changes: 6 additions & 126 deletions src/optimize.jl
Original file line number Diff line number Diff line change
@@ -1,94 +1,5 @@
const calllike = (:call, :foreigncall)

const compiled_calls = Dict{Any,Any}()

function extract_inner_call!(stmt::Expr, idx, once::Bool=false)
(stmt.head === :toplevel || stmt.head === :thunk) && return nothing
once |= stmt.head ∈ calllike
for (i, a) in enumerate(stmt.args)
isa(a, Expr) || continue
# Make sure we don't "damage" special syntax that requires literals
if i == 1 && stmt.head === :foreigncall
continue
end
if i == 2 && stmt.head === :call && stmt.args[1] === :cglobal
continue
end
ret = extract_inner_call!(a, idx, once) # doing this first extracts innermost calls
ret !== nothing && return ret
iscalllike = a.head ∈ calllike
if once && iscalllike
stmt.args[i] = NewSSAValue(idx)
return a
end
end
return nothing
end

function replace_ssa(stmt::Expr, ssalookup)
return Expr(stmt.head, Any[
if isa(a, SSAValue)
SSAValue(ssalookup[a.id])
elseif isa(a, NewSSAValue)
SSAValue(a.id)
elseif isa(a, Expr)
replace_ssa(a, ssalookup)
else
a
end
for a in stmt.args
]...)
end

function renumber_ssa!(stmts::Vector{Any}, ssalookup)
# When updating jumps, when lines get split into multiple lines
# (see "Un-nest :call expressions" below), we need to jump to the first of them.
# Consequently we use the previous "old-code" offset and add one.
# Fixes #455.
jumplookup(l, idx) = idx > 1 ? l[idx-1] + 1 : idx

for (i, stmt) in enumerate(stmts)
if isa(stmt, GotoNode)
stmts[i] = GotoNode(jumplookup(ssalookup, stmt.label))
elseif isa(stmt, SSAValue)
stmts[i] = SSAValue(ssalookup[stmt.id])
elseif isa(stmt, NewSSAValue)
stmts[i] = SSAValue(stmt.id)
elseif isexpr(stmt, :enter)
stmt.args[end] = jumplookup(ssalookup, stmt.args[1]::Int)
elseif isa(stmt, Expr)
stmts[i] = replace_ssa(stmt, ssalookup)
elseif isa(stmt, GotoIfNot)
cond = stmt.cond
if isa(cond, SSAValue)
cond = SSAValue(ssalookup[cond.id])
end
stmts[i] = GotoIfNot(cond, jumplookup(ssalookup, stmt.dest))
elseif isa(stmt, ReturnNode)
val = stmt.val
if isa(val, SSAValue)
stmts[i] = ReturnNode(SSAValue(ssalookup[val.id]))
end
elseif @static (isdefined(Core.IR, :EnterNode) && true) && isa(stmt, Core.IR.EnterNode)
stmts[i] = Core.IR.EnterNode(jumplookup(ssalookup, stmt.catch_dest))
end
end
return stmts
end

function compute_ssa_mapping_delete_statements!(code::CodeInfo, stmts::Vector{Int})
stmts = unique!(sort!(stmts))
ssalookup = collect(1:length(codelocs(code)))
cnt = 1
for i in 1:length(stmts)
start = stmts[i] + 1
stop = i == length(stmts) ? length(codelocs(code)) : stmts[i+1]
ssalookup[start:stop] .-= cnt
cnt += 1
end
return ssalookup
end

# Pre-frame-construction lookup
function lookup_stmt(stmts, arg)
if isa(arg, SSAValue)
Expand Down Expand Up @@ -179,7 +90,8 @@ function optimize!(code::CodeInfo, scope)

# Replace :llvmcall and :foreigncall with compiled variants. See
# https://github.com/JuliaDebug/JuliaInterpreter.jl/issues/13#issuecomment-464880123
foreigncalls_idx = Int[]
# Insert the foreigncall wrappers at the updated idxs
methodtables = Vector{Union{Compiled,DispatchableMethod}}(undef, length(code.code))
for (idx, stmt) in enumerate(code.code)
# Foregincalls can be rhs of assignments
if isexpr(stmt, :(=))
Expand All @@ -192,47 +104,16 @@ function optimize!(code::CodeInfo, scope)
if (arg1 === :llvmcall || lookup_stmt(code.code, arg1) === Base.llvmcall) && isempty(sparams) && scope isa Method
# Call via `invokelatest` to avoid compiling it until we need it
Base.invokelatest(build_compiled_llvmcall!, stmt, code, idx, evalmod)
push!(foreigncalls_idx, idx)
methodtables[idx] = Compiled()
end
elseif stmt.head === :foreigncall && scope isa Method
# Call via `invokelatest` to avoid compiling it until we need it
Base.invokelatest(build_compiled_foreigncall!, stmt, code, sparams, evalmod)
push!(foreigncalls_idx, idx)
methodtables[idx] = Compiled()
end
end
end

## Un-nest :call expressions (so that there will be only one :call per line)
# This will allow us to re-use args-buffers rather than having to allocate new ones each time.
old_code, old_codelocs = code.code, codelocs(code)
code.code = new_code = eltype(old_code)[]
code.codelocs = new_codelocs = Int32[]
ssainc = fill(1, length(old_code))
for (i, stmt) in enumerate(old_code)
loc = old_codelocs[i]
if isa(stmt, Expr)
inner = extract_inner_call!(stmt, length(new_code)+1)
while inner !== nothing
push!(new_code, inner)
push!(new_codelocs, loc)
ssainc[i] += 1
inner = extract_inner_call!(stmt, length(new_code)+1)
end
end
push!(new_code, stmt)
push!(new_codelocs, loc)
end
# Fix all the SSAValues and GotoNodes
ssalookup = cumsum(ssainc)
renumber_ssa!(new_code, ssalookup)
code.ssavaluetypes = length(new_code)

# Insert the foreigncall wrappers at the updated idxs
methodtables = Vector{Union{Compiled,DispatchableMethod}}(undef, length(code.code))
for idx in foreigncalls_idx
methodtables[ssalookup[idx]] = Compiled()
end

return code, methodtables
end

Expand All @@ -255,7 +136,7 @@ function parametric_type_to_expr(@nospecialize(t::Type))
return t
end

function build_compiled_llvmcall!(stmt::Expr, code, idx, evalmod)
function build_compiled_llvmcall!(stmt::Expr, code::CodeInfo, idx::Int, evalmod::Module)
# Run a mini-interpreter to extract the types
framecode = FrameCode(CompiledCalls, code; optimize=false)
frame = Frame(framecode, prepare_framedata(framecode, []))
Expand Down Expand Up @@ -292,9 +173,8 @@ function build_compiled_llvmcall!(stmt::Expr, code, idx, evalmod)
append!(stmt.args, args)
end


# Handle :llvmcall & :foreigncall (issue #28)
function build_compiled_foreigncall!(stmt::Expr, code, sparams::Vector{Symbol}, evalmod)
function build_compiled_foreigncall!(stmt::Expr, code::CodeInfo, sparams::Vector{Symbol}, evalmod::Module)
TVal = evalmod == Core.Compiler ? Core.Compiler.Val : Val
cfunc, RetType, ArgType = lookup_stmt(code.code, stmt.args[1]), stmt.args[2], stmt.args[3]::SimpleVector

Expand Down
5 changes: 0 additions & 5 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@ which will cause all calls to be evaluated via the interpreter.
struct Compiled end
Base.similar(::Compiled, sz) = Compiled() # to support similar(stack, 0)

# A type used transiently in renumbering CodeInfo SSAValues (to distinguish a new SSAValue from an old one)
struct NewSSAValue
id::Int
end

# Our own replacements for Core types. We need to do this to ensure we can tell the difference
# between "data" (Core types) and "code" (our types) if we step into Core.Compiler
struct SSAValue
Expand Down
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ function scan_ssa_use!(used::BitSet, @nospecialize(stmt))
while iterval !== nothing
useref, state = iterval
val = Core.Compiler.getindex(useref)
if (@static VERSION < v"1.11.0-DEV.1180" && true) && isexpr(val, :call)
# work around for a linearization bug in Julia (https://github.com/JuliaLang/julia/pull/52497)
scan_ssa_use!(used, val)
end
if isa(val, SSAValue)
push!(used, val.id)
end
Expand Down
Loading