Skip to content

Commit

Permalink
WIP: Allow generated functions to return a CodeInstance
Browse files Browse the repository at this point in the history
This PR allows generated functions to return a `CodeInstance` containing
optimized IR, allowing them to bypass inference and directly adding
inferred code into the ordinary course of execution. This is an enabling
capability for various external compiler implementations that may want
to provide compilation results to the Julia runtime.

As a minimal demonstrator of this capability, this adds a
Cassette-like `with_new_compiler` higher-order function, which
will compile/execute its arguments with the currently loaded `Compiler`
package. Unlike `@activate Compiler[:codegen]`, this change is not
global and the cache is fully partitioned. This by itself is a very
useful feature when developing Compiler code to be able to test
the full end-to-end codegen behavior before the changes are capable
of fully self-hosting.

A key enabler for this was the recent merging of #54899. This PR
includes a hacky version of the second TODO left at the end of
that PR, just to make everthing work end-to-end.

This PR is working end-to-end, but all three parts of it (the CodeInstance
return from generated functions, the `with_new_compiler` feature,
and the interpreter integration) need some additional cleanup. This
PR is mostly intended as a discussion point for what that additional
work needs to be.
  • Loading branch information
Keno committed Nov 22, 2024
1 parent 1bf2ef9 commit 0d66579
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 10 deletions.
15 changes: 15 additions & 0 deletions Compiler/extras/CompilerDevTools/Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.12.0-DEV"
manifest_format = "2.0"
project_hash = "84f495a1bf065c95f732a48af36dd0cd2cefb9d5"

[[deps.Compiler]]
path = "../.."
uuid = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
version = "0.0.2"

[[deps.CompilerDevTools]]
path = "."
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"
version = "0.0.0"
5 changes: 5 additions & 0 deletions Compiler/extras/CompilerDevTools/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name = "CompilerDevTools"
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"

[deps]
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
62 changes: 62 additions & 0 deletions Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module CompilerDevTools

using Compiler
using Core.IR

include(joinpath(dirname(pathof(Compiler)), "..", "test", "newinterp.jl"))

@newinterp SplitCacheInterp

function generate_codeinst(world::UInt, #=source=#::LineNumberNode, this, fargtypes)
tt = Base.to_tuple_type(fargtypes)
match = Base._which(tt; raise=false, world)
match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError
mi = Compiler.specialize_method(match)
interp = SplitCacheInterp(; world)
new_compiler_ci = Compiler.typeinf_ext(interp, mi, Compiler.SOURCE_MODE_ABI)

# Construct a CodeInstance that matches `with_new_compiler` and forwards
# to new_compiler_ci

src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
src.slotnames = Symbol[:self, :args]
src.slotflags = fill(zero(UInt8), 2)
src.slottypes = Any[this, fargtypes]
src.isva = true
src.nargs = 2

code = Any[]
ssavaluetypes = Any[]
ncalleeargs = length(fargtypes)
for i = 1:ncalleeargs
push!(code, Expr(:call, getfield, Core.Argument(2), i))
push!(ssavaluetypes, fargtypes[i])
end
push!(code, Expr(:invoke, new_compiler_ci, (SSAValue(i) for i = 1:ncalleeargs)...))
push!(ssavaluetypes, new_compiler_ci.rettype)
push!(code, ReturnNode(SSAValue(ncalleeargs+1)))
push!(ssavaluetypes, Union{})
src.code = code
src.ssavaluetypes = ssavaluetypes

return CodeInstance(
mi, nothing, new_compiler_ci.rettype, new_compiler_ci.exctype,
isdefined(new_compiler_ci, :rettype_const) ? new_compiler_ci.rettype_const : nothing,
src,
isdefined(new_compiler_ci, :rettype_const) ? Int32(0x02) : Int32(0x00),
new_compiler_ci.min_world, new_compiler_ci.max_world,
new_compiler_ci.ipo_purity_bits, nothing, new_compiler_ci.relocatability,
nothing, Core.svec(new_compiler_ci))
end

function refresh()
@eval function with_new_compiler(args...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, generate_codeinst))
end
end
refresh()

export with_new_compiler

end
4 changes: 2 additions & 2 deletions Compiler/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,12 @@ end
function call_get_staged(mi::MethodInstance, world::UInt, cache_ci::RefValue{CodeInstance})
token = @_gc_preserve_begin cache_ci
cache_ci_ptr = pointer_from_objref(cache_ci)
src = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{CodeInstance}), mi, world, cache_ci_ptr)
src = ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{CodeInstance}), mi, world, cache_ci_ptr)
@_gc_preserve_end token
return src
end
function call_get_staged(mi::MethodInstance, world::UInt, ::Nothing)
return ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL)
return ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL)
end

function get_cached_uninferred(mi::MethodInstance, world::UInt)
Expand Down
6 changes: 5 additions & 1 deletion base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool=
for m in method_instances(f, t, world)
if generated && hasgenerator(m)
if may_invoke_generator(m)
code = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), m, world, C_NULL)
code = ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{Cvoid}), m, world, C_NULL)
if isa(code, CodeInstance)
error("Generator `@generated` method ", m, " ",
"returned an optimized result")
end
else
error("Could not expand generator for `@generated` method ", m, ". ",
"This can happen if the provided argument types (", t, ") are ",
Expand Down
30 changes: 27 additions & 3 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,28 @@ static jl_value_t *do_invoke(jl_value_t **args, size_t nargs, interpreter_state
argv[i-1] = eval_value(args[i], s);
jl_value_t *c = args[0];
assert(jl_is_code_instance(c) || jl_is_method_instance(c));
jl_method_instance_t *meth = jl_is_method_instance(c) ? (jl_method_instance_t*)c : ((jl_code_instance_t*)c)->def;
jl_value_t *result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, meth);
jl_value_t *result = NULL;
if (jl_is_code_instance(c)) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)c;
assert(jl_atomic_load_relaxed(&codeinst->min_world) <= jl_current_task->world_age &&
jl_current_task->world_age <= jl_atomic_load_relaxed(&codeinst->max_world));
jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
if (!invoke) {
jl_compile_codeinst(codeinst);
invoke = jl_atomic_load_acquire(&codeinst->invoke);
}
if (invoke) {
result = invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst);

} else {
if (codeinst->owner != jl_nothing) {
jl_error("Failed to invoke or compile external codeinst");
}
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst->def);
}
} else {
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, (jl_method_instance_t*)c);
}
JL_GC_POP();
return result;
}
Expand Down Expand Up @@ -729,7 +749,11 @@ jl_value_t *jl_code_or_ci_for_interpreter(jl_method_instance_t *mi, size_t world
jl_code_instance_t *uninferred = jl_cached_uninferred(cache, world);
if (!uninferred) {
assert(mi->def.method->generator);
src = jl_code_for_staged(mi, world, &uninferred);
ret = jl_code_for_staged(mi, world, &uninferred);
if (jl_is_code_instance(ret)) {
jl_mi_cache_insert(mi, (jl_code_instance_t*)ret);
return (jl_value_t*)ret;
}
}
ret = (jl_value_t*)uninferred;
src = (jl_code_info_t*)jl_atomic_load_relaxed(&uninferred->inferred);
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1854,7 +1854,7 @@ JL_DLLEXPORT jl_value_t *jl_get_binding_value_if_resolved(jl_binding_t *b JL_PRO
JL_DLLEXPORT jl_value_t *jl_get_binding_value_if_resolved_and_const(jl_binding_t *b JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_declare_const_gf(jl_binding_t *b, jl_module_t *mod, jl_sym_t *name);
JL_DLLEXPORT jl_method_t *jl_method_def(jl_svec_t *argdata, jl_methtable_t *mt, jl_code_info_t *f, jl_module_t *module);
JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo, size_t world, jl_code_instance_t **cache);
JL_DLLEXPORT jl_value_t *jl_code_for_staged(jl_method_instance_t *linfo, size_t world, jl_code_instance_t **cache);
JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src);
JL_DLLEXPORT size_t jl_get_world_counter(void) JL_NOTSAFEPOINT;
JL_DLLEXPORT size_t jl_get_tls_world_age(void) JL_NOTSAFEPOINT;
Expand Down
18 changes: 15 additions & 3 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ JL_DLLEXPORT jl_code_instance_t *jl_cache_uninferred(jl_method_instance_t *mi, j

// Return a newly allocated CodeInfo for the function signature
// effectively described by the tuple (specTypes, env, Method) inside linfo
JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t world, jl_code_instance_t **cache)
JL_DLLEXPORT jl_value_t *jl_code_for_staged(jl_method_instance_t *mi, size_t world, jl_code_instance_t **cache)
{
jl_code_instance_t *cache_ci = jl_atomic_load_relaxed(&mi->cache);
jl_code_instance_t *uninferred_ci = jl_cached_uninferred(cache_ci, world);
Expand All @@ -753,6 +753,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
assert(generator != NULL);
assert(jl_is_method(def));
jl_code_info_t *func = NULL;
jl_value_t *ret = NULL;
jl_value_t *ex = NULL;
jl_value_t *kind = NULL;
jl_code_info_t *uninferred = NULL;
Expand All @@ -774,7 +775,16 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
ex = jl_call_staged(def, generator, world, mi->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt));

// do some post-processing
if (jl_is_code_info(ex)) {
if (jl_is_code_instance(ex)) {
jl_code_instance_t *ci = (jl_code_instance_t*)ex;
if (ci->owner != jl_nothing)
jl_error("CodeInstance returned from generator must have owner == nothing");
if (ci->next)
jl_error("CodeInstance returned from generator must not be in the cache");
ret = ex;
goto done;
}
else if (jl_is_code_info(ex)) {
func = (jl_code_info_t*)ex;
jl_array_t *stmts = (jl_array_t*)func->code;
jl_resolve_globals_in_ir(stmts, def->module, mi->sparam_vals, 1);
Expand Down Expand Up @@ -865,6 +875,8 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
*cache = cached_ci;
}

ret = (jl_value_t*)func;
done:
ct->ptls->in_pure_callback = last_in;
jl_lineno = last_lineno;
ct->world_age = last_age;
Expand All @@ -875,7 +887,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
jl_rethrow();
}
JL_GC_POP();
return func;
return ret;
}

JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src)
Expand Down

0 comments on commit 0d66579

Please sign in to comment.