Skip to content

Commit

Permalink
remove precompile mutation step from staticdata (#48309)
Browse files Browse the repository at this point in the history
Make sure things are properly ordered here, so that when serializing,
nothing is mutating the system at the same time.

Fix #48047
  • Loading branch information
vtjnash authored Jan 19, 2023
1 parent 1c5fa2b commit 87b8896
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 104 deletions.
18 changes: 13 additions & 5 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,9 @@ jl_code_info_t *jl_type_infer(jl_method_instance_t *mi, size_t world, int force)
return NULL;
jl_task_t *ct = jl_current_task;
if (ct->reentrant_inference == (uint16_t)-1) {
// TODO: We should avoid attempting to re-inter inference here at all
// and turn on this warning, but that requires further refactoring
// of the precompile code, so for now just catch that case here.
//jl_printf(JL_STDERR, "ERROR: Attempted to enter inference while writing out image.");
return NULL;
// We must avoid attempting to re-enter inference here
assert(0 && "attempted to enter inference while writing out image");
abort();
}
if (ct->reentrant_inference > 2)
return NULL;
Expand Down Expand Up @@ -487,6 +485,7 @@ int foreach_mtable_in_module(
// this is the original/primary binding for the type (name/wrapper)
jl_methtable_t *mt = tn->mt;
if (mt != NULL && (jl_value_t*)mt != jl_nothing && mt != jl_type_type_mt && mt != jl_nonfunction_mt) {
assert(mt->module == m);
if (!visit(mt, env))
return 0;
}
Expand All @@ -500,6 +499,15 @@ int foreach_mtable_in_module(
return 0;
}
}
else if (jl_is_mtable(v)) {
jl_methtable_t *mt = (jl_methtable_t*)v;
if (mt->module == m && mt->name == name) {
// this is probably an external method table here, so let's
// assume so as there is no way to precisely distinguish them
if (!visit(mt, env))
return 0;
}
}
}
}
table = jl_atomic_load_relaxed(&m->bindings);
Expand Down
4 changes: 4 additions & 0 deletions src/precompile_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ static int compile_all_collect__(jl_typemap_entry_t *ml, void *env)
{
jl_array_t *allmeths = (jl_array_t*)env;
jl_method_t *m = ml->func.method;
if (m->external_mt)
return 1;
if (m->source) {
// method has a non-generated definition; can be compiled generically
jl_array_ptr_1d_push(allmeths, (jl_value_t*)m);
Expand Down Expand Up @@ -204,6 +206,8 @@ static int precompile_enq_specialization_(jl_method_instance_t *mi, void *closur
static int precompile_enq_all_specializations__(jl_typemap_entry_t *def, void *closure)
{
jl_method_t *m = def->func.method;
if (m->external_mt)
return 1;
if ((m->name == jl_symbol("__init__") || m->ccallable) && jl_is_dispatch_tupletype(m->sig)) {
// ensure `__init__()` and @ccallables get strongly-hinted, specialized, and compiled
jl_method_instance_t *mi = jl_specializations_get_linfo(m, m->sig, jl_emptysvec);
Expand Down
112 changes: 61 additions & 51 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -2170,50 +2170,52 @@ JL_DLLEXPORT jl_value_t *jl_as_global_root(jl_value_t *val JL_MAYBE_UNROOTED)
}

static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *newly_inferred, uint64_t worklist_key,
/* outputs */ jl_array_t **extext_methods,
jl_array_t **new_specializations, jl_array_t **method_roots_list,
jl_array_t **ext_targets, jl_array_t **edges)
/* outputs */ jl_array_t **extext_methods, jl_array_t **new_specializations,
jl_array_t **method_roots_list, jl_array_t **ext_targets, jl_array_t **edges)
{
// extext_methods: [method1, ...], worklist-owned "extending external" methods added to functions owned by modules outside the worklist
// ext_targets: [invokesig1, callee1, matches1, ...] non-worklist callees of worklist-owned methods
// ordinary dispatch: invokesig=NULL, callee is MethodInstance
// `invoke` dispatch: invokesig is signature, callee is MethodInstance
// abstract call: callee is signature
// edges: [caller1, ext_targets_indexes1, ...] for worklist-owned methods calling external methods

assert(edges_map == NULL);
JL_GC_PUSH1(&edges_map);

// Save the inferred code from newly inferred, external methods
htable_new(&external_mis, 0); // we need external_mis until after `jl_collect_edges` finishes
// Save the inferred code from newly inferred, external methods
*new_specializations = queue_external_cis(newly_inferred);
// Collect the new method roots
htable_t methods_with_newspecs;
htable_new(&methods_with_newspecs, 0);
jl_collect_methods(&methods_with_newspecs, *new_specializations);
*method_roots_list = jl_alloc_vec_any(0);
jl_collect_new_roots(*method_roots_list, &methods_with_newspecs, worklist_key);
htable_free(&methods_with_newspecs);

// Collect method extensions and edges data
edges_map = jl_alloc_vec_any(0);
JL_GC_PUSH1(&edges_map);
if (edges)
edges_map = jl_alloc_vec_any(0);
*extext_methods = jl_alloc_vec_any(0);
jl_collect_methtable_from_mod(jl_type_type_mt, *extext_methods);
jl_collect_methtable_from_mod(jl_nonfunction_mt, *extext_methods);
size_t i, len = jl_array_len(mod_array);
for (i = 0; i < len; i++) {
jl_module_t *m = (jl_module_t*)jl_array_ptr_ref(mod_array, i);
assert(jl_is_module(m));
if (m->parent == m) // some toplevel modules (really just Base) aren't actually
jl_collect_extext_methods_from_mod(*extext_methods, m);
}
jl_collect_methtable_from_mod(*extext_methods, jl_type_type_mt);
jl_collect_missing_backedges(jl_type_type_mt);
jl_collect_methtable_from_mod(*extext_methods, jl_nonfunction_mt);
jl_collect_missing_backedges(jl_nonfunction_mt);
// jl_collect_extext_methods_from_mod and jl_collect_missing_backedges also accumulate data in callers_with_edges.
// Process this to extract `edges` and `ext_targets`.
*ext_targets = jl_alloc_vec_any(0);
*edges = jl_alloc_vec_any(0);
jl_collect_edges(*edges, *ext_targets);

if (edges) {
jl_collect_missing_backedges(jl_type_type_mt);
jl_collect_missing_backedges(jl_nonfunction_mt);
// jl_collect_extext_methods_from_mod and jl_collect_missing_backedges also accumulate data in callers_with_edges.
// Process this to extract `edges` and `ext_targets`.
*ext_targets = jl_alloc_vec_any(0);
*edges = jl_alloc_vec_any(0);
*method_roots_list = jl_alloc_vec_any(0);
// Collect the new method roots
htable_t methods_with_newspecs;
htable_new(&methods_with_newspecs, 0);
jl_collect_methods(&methods_with_newspecs, *new_specializations);
jl_collect_new_roots(*method_roots_list, &methods_with_newspecs, worklist_key);
htable_free(&methods_with_newspecs);
jl_collect_edges(*edges, *ext_targets);
}
htable_free(&external_mis);
assert(edges_map == NULL); // jl_collect_edges clears this when done

Expand Down Expand Up @@ -2501,9 +2503,8 @@ static void jl_save_system_image_to_stream(ios_t *f,
jl_gc_enable(en);
}

static void jl_write_header_for_incremental(ios_t *f, jl_array_t *worklist, jl_array_t **mod_array, jl_array_t **udeps, int64_t *srctextpos, int64_t *checksumpos)
static void jl_write_header_for_incremental(ios_t *f, jl_array_t *worklist, jl_array_t *mod_array, jl_array_t **udeps, int64_t *srctextpos, int64_t *checksumpos)
{
*mod_array = jl_get_loaded_modules(); // __toplevel__ modules loaded in this session (from Base.loaded_modules_array)
assert(jl_precompile_toplevel_module == NULL);
jl_precompile_toplevel_module = (jl_module_t*)jl_array_ptr_ref(worklist, jl_array_len(worklist)-1);

Expand All @@ -2519,7 +2520,7 @@ static void jl_write_header_for_incremental(ios_t *f, jl_array_t *worklist, jl_a
// write description of requirements for loading (modules that must be pre-loaded if initialization is to succeed)
// this can return errors during deserialize,
// best to keep it early (before any actual initialization)
write_mod_list(f, *mod_array);
write_mod_list(f, mod_array);
}

JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *worklist, bool_t emit_split,
Expand Down Expand Up @@ -2550,49 +2551,58 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
int64_t checksumpos_ff = 0;
int64_t datastartpos = 0;
JL_GC_PUSH6(&mod_array, &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges);
if (worklist) {
jl_write_header_for_incremental(f, worklist, &mod_array, udeps, srctextpos, &checksumpos);
if (emit_split) {
checksumpos_ff = write_header(ff, 1);
write_uint8(ff, jl_cache_flags());
write_mod_list(ff, mod_array);
} else {
checksumpos_ff = checksumpos;
}
{
// make sure we don't run any Julia code concurrently after this point
jl_gc_enable_finalizers(ct, 0);
assert(ct->reentrant_inference == 0);
ct->reentrant_inference = (uint16_t)-1;
}
jl_prepare_serialization_data(mod_array, newly_inferred, jl_worklist_key(worklist), &extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges);

if (worklist) {
mod_array = jl_get_loaded_modules(); // __toplevel__ modules loaded in this session (from Base.loaded_modules_array)
// Generate _native_data`
if (jl_options.outputo || jl_options.outputbc || jl_options.outputunoptbc || jl_options.outputasm) {
jl_prepare_serialization_data(mod_array, newly_inferred, jl_worklist_key(worklist),
&extext_methods, &new_specializations, NULL, NULL, NULL);
jl_precompile_toplevel_module = (jl_module_t*)jl_array_ptr_ref(worklist, jl_array_len(worklist)-1);
*_native_data = jl_precompile_worklist(worklist, extext_methods, new_specializations);
jl_precompile_toplevel_module = NULL;
extext_methods = NULL;
new_specializations = NULL;
}
jl_write_header_for_incremental(f, worklist, mod_array, udeps, srctextpos, &checksumpos);
if (emit_split) {
checksumpos_ff = write_header(ff, 1);
write_uint8(ff, jl_cache_flags());
write_mod_list(ff, mod_array);
}
else {
checksumpos_ff = checksumpos;
}
}
else {
*_native_data = jl_precompile(jl_options.compile_enabled == JL_OPTIONS_COMPILE_ALL);
}

// Make sure we don't run any Julia code concurrently after this point
// since it will invalidate our serialization preparations
jl_gc_enable_finalizers(ct, 0);
assert(ct->reentrant_inference == 0);
ct->reentrant_inference = (uint16_t)-1;
if (worklist) {
jl_prepare_serialization_data(mod_array, newly_inferred, jl_worklist_key(worklist),
&extext_methods, &new_specializations, &method_roots_list, &ext_targets, &edges);
if (!emit_split) {
write_int32(f, 0); // No clone_targets
write_padding(f, LLT_ALIGN(ios_pos(f), JL_CACHE_BYTE_ALIGNMENT) - ios_pos(f));
} else {
}
else {
write_padding(ff, LLT_ALIGN(ios_pos(ff), JL_CACHE_BYTE_ALIGNMENT) - ios_pos(ff));
}
datastartpos = ios_pos(ff);
} else {
*_native_data = jl_precompile(jl_options.compile_enabled == JL_OPTIONS_COMPILE_ALL);
}
native_functions = *_native_data;
jl_save_system_image_to_stream(ff, worklist, extext_methods, new_specializations, method_roots_list, ext_targets, edges);
native_functions = NULL;
if (worklist) {
// Re-enable running julia code for postoutput hooks, atexit, etc.
jl_gc_enable_finalizers(ct, 1);
ct->reentrant_inference = 0;
jl_precompile_toplevel_module = NULL;
}
// make sure we don't run any Julia code concurrently before this point
// Re-enable running julia code for postoutput hooks, atexit, etc.
jl_gc_enable_finalizers(ct, 1);
ct->reentrant_inference = 0;
jl_precompile_toplevel_module = NULL;

if (worklist) {
// Go back and update the checksum in the header
Expand Down
60 changes: 12 additions & 48 deletions src/staticdata_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,12 @@ static void jl_collect_methods(htable_t *mset, jl_array_t *new_specializations)
}
}

static void jl_collect_new_roots(jl_array_t *roots, htable_t *mset, uint64_t key)
static void jl_collect_new_roots(jl_array_t *roots, const htable_t *mset, uint64_t key)
{
size_t i, sz = mset->size;
int nwithkey;
jl_method_t *m;
void **table = mset->table;
void *const *table = mset->table;
jl_array_t *newroots = NULL;
JL_GC_PUSH1(&newroots);
for (i = 0; i < sz; i += 2) {
Expand Down Expand Up @@ -369,6 +369,8 @@ static int jl_collect_methcache_from_mod(jl_typemap_entry_t *ml, void *closure)
if (s && !jl_object_in_image((jl_value_t*)m->module)) {
jl_array_ptr_1d_push(s, (jl_value_t*)m);
}
if (edges_map == NULL)
return 1;
jl_svec_t *specializations = m->specializations;
size_t i, l = jl_svec_len(specializations);
for (i = 0; i < l; i++) {
Expand All @@ -379,60 +381,22 @@ static int jl_collect_methcache_from_mod(jl_typemap_entry_t *ml, void *closure)
return 1;
}

static void jl_collect_methtable_from_mod(jl_array_t *s, jl_methtable_t *mt)
static int jl_collect_methtable_from_mod(jl_methtable_t *mt, void *env)
{
jl_typemap_visitor(mt->defs, jl_collect_methcache_from_mod, (void*)s);
if (!jl_object_in_image((jl_value_t*)mt))
env = NULL; // do not collect any methods from here
jl_typemap_visitor(jl_atomic_load_relaxed(&mt->defs), jl_collect_methcache_from_mod, env);
if (env && edges_map)
jl_collect_missing_backedges(mt);
return 1;
}

// Collect methods of external functions defined by modules in the worklist
// "extext" = "extending external"
// Also collect relevant backedges
static void jl_collect_extext_methods_from_mod(jl_array_t *s, jl_module_t *m)
{
if (s && !jl_object_in_image((jl_value_t*)m))
s = NULL; // do not collect any methods
jl_svec_t *table = jl_atomic_load_relaxed(&m->bindings);
for (size_t i = 0; i < jl_svec_len(table); i++) {
jl_binding_t *b = (jl_binding_t*)jl_svec_ref(table, i);
if ((void*)b == jl_nothing)
break;
jl_sym_t *name = b->globalref->name;
if (b->owner == b && b->value && b->constp) {
jl_value_t *bv = jl_unwrap_unionall(b->value);
if (jl_is_datatype(bv)) {
jl_typename_t *tn = ((jl_datatype_t*)bv)->name;
if (tn->module == m && tn->name == name && tn->wrapper == b->value) {
jl_methtable_t *mt = tn->mt;
if (mt != NULL &&
(jl_value_t*)mt != jl_nothing &&
(mt != jl_type_type_mt && mt != jl_nonfunction_mt)) {
assert(mt->module == tn->module);
jl_collect_methtable_from_mod(s, mt);
if (s)
jl_collect_missing_backedges(mt);
}
}
}
else if (jl_is_module(b->value)) {
jl_module_t *child = (jl_module_t*)b->value;
if (child != m && child->parent == m && child->name == name) {
// this is the original/primary binding for the submodule
jl_collect_extext_methods_from_mod(s, (jl_module_t*)b->value);
}
}
else if (jl_is_mtable(b->value)) {
jl_methtable_t *mt = (jl_methtable_t*)b->value;
if (mt->module == m && mt->name == name) {
// this is probably an external method table, so let's assume so
// as there is no way to precisely distinguish them,
// and the rest of this serializer does not bother
// to handle any method tables specially
jl_collect_methtable_from_mod(s, (jl_methtable_t*)bv);
}
}
}
table = jl_atomic_load_relaxed(&m->bindings);
}
foreach_mtable_in_module(m, jl_collect_methtable_from_mod, s);
}

static void jl_record_edges(jl_method_instance_t *caller, arraylist_t *wq, jl_array_t *edges)
Expand Down

0 comments on commit 87b8896

Please sign in to comment.