From 0d66579c8ae39cfb357fb3078b74fb653c9070e2 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 22 Nov 2024 09:18:57 +0000 Subject: [PATCH] WIP: Allow generated functions to return a `CodeInstance` This PR allows generated functions to return a `CodeInstance` containing optimized IR, allowing them to bypass inference and directly adding inferred code into the ordinary course of execution. This is an enabling capability for various external compiler implementations that may want to provide compilation results to the Julia runtime. As a minimal demonstrator of this capability, this adds a Cassette-like `with_new_compiler` higher-order function, which will compile/execute its arguments with the currently loaded `Compiler` package. Unlike `@activate Compiler[:codegen]`, this change is not global and the cache is fully partitioned. This by itself is a very useful feature when developing Compiler code to be able to test the full end-to-end codegen behavior before the changes are capable of fully self-hosting. A key enabler for this was the recent merging of #54899. This PR includes a hacky version of the second TODO left at the end of that PR, just to make everthing work end-to-end. This PR is working end-to-end, but all three parts of it (the CodeInstance return from generated functions, the `with_new_compiler` feature, and the interpreter integration) need some additional cleanup. This PR is mostly intended as a discussion point for what that additional work needs to be. --- .../extras/CompilerDevTools/Manifest.toml | 15 +++++ Compiler/extras/CompilerDevTools/Project.toml | 5 ++ .../CompilerDevTools/src/CompilerDevTools.jl | 62 +++++++++++++++++++ Compiler/src/utilities.jl | 4 +- base/reflection.jl | 6 +- src/interpreter.c | 30 ++++++++- src/julia.h | 2 +- src/method.c | 18 +++++- 8 files changed, 132 insertions(+), 10 deletions(-) create mode 100644 Compiler/extras/CompilerDevTools/Manifest.toml create mode 100644 Compiler/extras/CompilerDevTools/Project.toml create mode 100644 Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl diff --git a/Compiler/extras/CompilerDevTools/Manifest.toml b/Compiler/extras/CompilerDevTools/Manifest.toml new file mode 100644 index 0000000000000..bcc78f1ded34a --- /dev/null +++ b/Compiler/extras/CompilerDevTools/Manifest.toml @@ -0,0 +1,15 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.12.0-DEV" +manifest_format = "2.0" +project_hash = "84f495a1bf065c95f732a48af36dd0cd2cefb9d5" + +[[deps.Compiler]] +path = "../.." +uuid = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1" +version = "0.0.2" + +[[deps.CompilerDevTools]] +path = "." +uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f" +version = "0.0.0" diff --git a/Compiler/extras/CompilerDevTools/Project.toml b/Compiler/extras/CompilerDevTools/Project.toml new file mode 100644 index 0000000000000..a2749a9a56a84 --- /dev/null +++ b/Compiler/extras/CompilerDevTools/Project.toml @@ -0,0 +1,5 @@ +name = "CompilerDevTools" +uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f" + +[deps] +Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1" diff --git a/Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl b/Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl new file mode 100644 index 0000000000000..eaf812f71b93e --- /dev/null +++ b/Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl @@ -0,0 +1,62 @@ +module CompilerDevTools + +using Compiler +using Core.IR + +include(joinpath(dirname(pathof(Compiler)), "..", "test", "newinterp.jl")) + +@newinterp SplitCacheInterp + +function generate_codeinst(world::UInt, #=source=#::LineNumberNode, this, fargtypes) + tt = Base.to_tuple_type(fargtypes) + match = Base._which(tt; raise=false, world) + match === nothing && return nothing # method match failed – the fallback implementation will raise a proper MethodError + mi = Compiler.specialize_method(match) + interp = SplitCacheInterp(; world) + new_compiler_ci = Compiler.typeinf_ext(interp, mi, Compiler.SOURCE_MODE_ABI) + + # Construct a CodeInstance that matches `with_new_compiler` and forwards + # to new_compiler_ci + + src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ()) + src.slotnames = Symbol[:self, :args] + src.slotflags = fill(zero(UInt8), 2) + src.slottypes = Any[this, fargtypes] + src.isva = true + src.nargs = 2 + + code = Any[] + ssavaluetypes = Any[] + ncalleeargs = length(fargtypes) + for i = 1:ncalleeargs + push!(code, Expr(:call, getfield, Core.Argument(2), i)) + push!(ssavaluetypes, fargtypes[i]) + end + push!(code, Expr(:invoke, new_compiler_ci, (SSAValue(i) for i = 1:ncalleeargs)...)) + push!(ssavaluetypes, new_compiler_ci.rettype) + push!(code, ReturnNode(SSAValue(ncalleeargs+1))) + push!(ssavaluetypes, Union{}) + src.code = code + src.ssavaluetypes = ssavaluetypes + + return CodeInstance( + mi, nothing, new_compiler_ci.rettype, new_compiler_ci.exctype, + isdefined(new_compiler_ci, :rettype_const) ? new_compiler_ci.rettype_const : nothing, + src, + isdefined(new_compiler_ci, :rettype_const) ? Int32(0x02) : Int32(0x00), + new_compiler_ci.min_world, new_compiler_ci.max_world, + new_compiler_ci.ipo_purity_bits, nothing, new_compiler_ci.relocatability, + nothing, Core.svec(new_compiler_ci)) +end + +function refresh() + @eval function with_new_compiler(args...) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, generate_codeinst)) + end +end +refresh() + +export with_new_compiler + +end diff --git a/Compiler/src/utilities.jl b/Compiler/src/utilities.jl index 11d926f0c9d4e..9691f448651e4 100644 --- a/Compiler/src/utilities.jl +++ b/Compiler/src/utilities.jl @@ -114,12 +114,12 @@ end function call_get_staged(mi::MethodInstance, world::UInt, cache_ci::RefValue{CodeInstance}) token = @_gc_preserve_begin cache_ci cache_ci_ptr = pointer_from_objref(cache_ci) - src = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{CodeInstance}), mi, world, cache_ci_ptr) + src = ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{CodeInstance}), mi, world, cache_ci_ptr) @_gc_preserve_end token return src end function call_get_staged(mi::MethodInstance, world::UInt, ::Nothing) - return ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL) + return ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL) end function get_cached_uninferred(mi::MethodInstance, world::UInt) diff --git a/base/reflection.jl b/base/reflection.jl index 9246b4cb0ac71..31a230d8eacab 100644 --- a/base/reflection.jl +++ b/base/reflection.jl @@ -31,7 +31,11 @@ function code_lowered(@nospecialize(f), @nospecialize(t=Tuple); generated::Bool= for m in method_instances(f, t, world) if generated && hasgenerator(m) if may_invoke_generator(m) - code = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), m, world, C_NULL) + code = ccall(:jl_code_for_staged, Any, (Any, UInt, Ptr{Cvoid}), m, world, C_NULL) + if isa(code, CodeInstance) + error("Generator `@generated` method ", m, " ", + "returned an optimized result") + end else error("Could not expand generator for `@generated` method ", m, ". ", "This can happen if the provided argument types (", t, ") are ", diff --git a/src/interpreter.c b/src/interpreter.c index 49a3afed14f0c..e71d7d8b3a178 100644 --- a/src/interpreter.c +++ b/src/interpreter.c @@ -137,8 +137,28 @@ static jl_value_t *do_invoke(jl_value_t **args, size_t nargs, interpreter_state argv[i-1] = eval_value(args[i], s); jl_value_t *c = args[0]; assert(jl_is_code_instance(c) || jl_is_method_instance(c)); - jl_method_instance_t *meth = jl_is_method_instance(c) ? (jl_method_instance_t*)c : ((jl_code_instance_t*)c)->def; - jl_value_t *result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, meth); + jl_value_t *result = NULL; + if (jl_is_code_instance(c)) { + jl_code_instance_t *codeinst = (jl_code_instance_t*)c; + assert(jl_atomic_load_relaxed(&codeinst->min_world) <= jl_current_task->world_age && + jl_current_task->world_age <= jl_atomic_load_relaxed(&codeinst->max_world)); + jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke); + if (!invoke) { + jl_compile_codeinst(codeinst); + invoke = jl_atomic_load_acquire(&codeinst->invoke); + } + if (invoke) { + result = invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst); + + } else { + if (codeinst->owner != jl_nothing) { + jl_error("Failed to invoke or compile external codeinst"); + } + result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst->def); + } + } else { + result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, (jl_method_instance_t*)c); + } JL_GC_POP(); return result; } @@ -729,7 +749,11 @@ jl_value_t *jl_code_or_ci_for_interpreter(jl_method_instance_t *mi, size_t world jl_code_instance_t *uninferred = jl_cached_uninferred(cache, world); if (!uninferred) { assert(mi->def.method->generator); - src = jl_code_for_staged(mi, world, &uninferred); + ret = jl_code_for_staged(mi, world, &uninferred); + if (jl_is_code_instance(ret)) { + jl_mi_cache_insert(mi, (jl_code_instance_t*)ret); + return (jl_value_t*)ret; + } } ret = (jl_value_t*)uninferred; src = (jl_code_info_t*)jl_atomic_load_relaxed(&uninferred->inferred); diff --git a/src/julia.h b/src/julia.h index 944fd3c43a297..0a21204148146 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1854,7 +1854,7 @@ JL_DLLEXPORT jl_value_t *jl_get_binding_value_if_resolved(jl_binding_t *b JL_PRO JL_DLLEXPORT jl_value_t *jl_get_binding_value_if_resolved_and_const(jl_binding_t *b JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_value_t *jl_declare_const_gf(jl_binding_t *b, jl_module_t *mod, jl_sym_t *name); JL_DLLEXPORT jl_method_t *jl_method_def(jl_svec_t *argdata, jl_methtable_t *mt, jl_code_info_t *f, jl_module_t *module); -JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *linfo, size_t world, jl_code_instance_t **cache); +JL_DLLEXPORT jl_value_t *jl_code_for_staged(jl_method_instance_t *linfo, size_t world, jl_code_instance_t **cache); JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src); JL_DLLEXPORT size_t jl_get_world_counter(void) JL_NOTSAFEPOINT; JL_DLLEXPORT size_t jl_get_tls_world_age(void) JL_NOTSAFEPOINT; diff --git a/src/method.c b/src/method.c index 8e3bb7d0060b7..18cfdfc89f966 100644 --- a/src/method.c +++ b/src/method.c @@ -731,7 +731,7 @@ JL_DLLEXPORT jl_code_instance_t *jl_cache_uninferred(jl_method_instance_t *mi, j // Return a newly allocated CodeInfo for the function signature // effectively described by the tuple (specTypes, env, Method) inside linfo -JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t world, jl_code_instance_t **cache) +JL_DLLEXPORT jl_value_t *jl_code_for_staged(jl_method_instance_t *mi, size_t world, jl_code_instance_t **cache) { jl_code_instance_t *cache_ci = jl_atomic_load_relaxed(&mi->cache); jl_code_instance_t *uninferred_ci = jl_cached_uninferred(cache_ci, world); @@ -753,6 +753,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t assert(generator != NULL); assert(jl_is_method(def)); jl_code_info_t *func = NULL; + jl_value_t *ret = NULL; jl_value_t *ex = NULL; jl_value_t *kind = NULL; jl_code_info_t *uninferred = NULL; @@ -774,7 +775,16 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t ex = jl_call_staged(def, generator, world, mi->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt)); // do some post-processing - if (jl_is_code_info(ex)) { + if (jl_is_code_instance(ex)) { + jl_code_instance_t *ci = (jl_code_instance_t*)ex; + if (ci->owner != jl_nothing) + jl_error("CodeInstance returned from generator must have owner == nothing"); + if (ci->next) + jl_error("CodeInstance returned from generator must not be in the cache"); + ret = ex; + goto done; + } + else if (jl_is_code_info(ex)) { func = (jl_code_info_t*)ex; jl_array_t *stmts = (jl_array_t*)func->code; jl_resolve_globals_in_ir(stmts, def->module, mi->sparam_vals, 1); @@ -865,6 +875,8 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t *cache = cached_ci; } + ret = (jl_value_t*)func; +done: ct->ptls->in_pure_callback = last_in; jl_lineno = last_lineno; ct->world_age = last_age; @@ -875,7 +887,7 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t jl_rethrow(); } JL_GC_POP(); - return func; + return ret; } JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src)