Skip to content

Commit

Permalink
Cache binding pointer in GlobalRef
Browse files Browse the repository at this point in the history
On certain workloads, profiling shows a surprisingly high fraction of
inference time spent looking up bindings in modules. Bindings use
a hash table, so they're expected to be quite fast, but certainly
not zero. A big contributor to the problem is that we do basically
treat it as zero, looking up bindings for GlobalRefs multiple times
for each statement (e.g. in `isconst`, `isdefined`, to get the types,
etc). This PR attempts to improve the situation by adding an extra
field to GlobalRef that caches this lookup. This field is not serialized
and if not set, we fallback to the previous behavior. I would expect
the memory overhead to be relatively small, since we do intern GlobalRefs
in memory to only have one per binding (rather than one per use).

 # Benchmarks

The benchmarks look quite promising. Consider this artifical example
(though it's actually not all that far fetched, given some of the
generated code we see):

```
using Core.Intrinsics: add_int
const ONE = 1
@eval function f(x, y)
	z = 0
	$([:(z = add_int(x, add_int(z, ONE))) for _ = 1:10000]...)
	return add_int(z, y)
end
g(y) = f(ONE, y)
```

On master:
```
julia> @time @code_typed g(1)
  1.427227 seconds (1.31 M allocations: 58.809 MiB, 5.73% gc time, 99.96% compilation time)
CodeInfo(
1 ─ %1 = invoke Main.f(Main.ONE::Int64, y::Int64)::Int64
└──      return %1
) => Int64
```

On this PR:
```
julia> @time @code_typed g(1)
  0.503151 seconds (1.19 M allocations: 53.641 MiB, 5.10% gc time, 33.59% compilation time)
CodeInfo(
1 ─ %1 = invoke Main.f(Main.ONE::Int64, y::Int64)::Int64
└──      return %1
) => Int64
```

I don't expect the same speedup on other workloads, but there should be
a few % speedup on most workloads still.

 # Future extensions

The other motivation for this is that with a view towards #40399,
we will want to more clearly define when bindings get resolved. The
idea here would then be that binding resolution replaces generic
`GlobalRefs` by GlobalRefs with a set binding cache, and any
unresolved bindings would be treated conservatively by inference
and optimization.
  • Loading branch information
Keno committed Sep 15, 2022
1 parent 1916d35 commit 7677e64
Show file tree
Hide file tree
Showing 15 changed files with 117 additions and 26 deletions.
4 changes: 3 additions & 1 deletion base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ eval(Core, quote
end
LineInfoNode(mod::Module, @nospecialize(method), file::Symbol, line::Int32, inlined_at::Int32) =
$(Expr(:new, :LineInfoNode, :mod, :method, :file, :line, :inlined_at))
GlobalRef(m::Module, s::Symbol) = $(Expr(:new, :GlobalRef, :m, :s))
GlobalRef(m::Module, s::Symbol, binding::Ptr{Nothing}) = $(Expr(:new, :GlobalRef, :m, :s, :binding))
SlotNumber(n::Int) = $(Expr(:new, :SlotNumber, :n))
TypedSlot(n::Int, @nospecialize(t)) = $(Expr(:new, :TypedSlot, :n, :t))
PhiNode(edges::Array{Int32, 1}, values::Array{Any, 1}) = $(Expr(:new, :PhiNode, :edges, :values))
Expand Down Expand Up @@ -812,6 +812,8 @@ Unsigned(x::Union{Float16, Float32, Float64, Bool}) = UInt(x)
Integer(x::Integer) = x
Integer(x::Union{Float16, Float32, Float64}) = Int(x)

GlobalRef(m::Module, s::Symbol) = GlobalRef(m, s, bitcast(Ptr{Nothing}, 0))

# Binding for the julia parser, called as
#
# Core._parse(text, filename, lineno, offset, options)
Expand Down
23 changes: 15 additions & 8 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1986,7 +1986,7 @@ function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(
return sv.argtypes[e.n]
end
elseif isa(e, GlobalRef)
return abstract_eval_global(interp, e.mod, e.name, sv)
return abstract_eval_globalref(interp, e, sv)
end

return Const(e)
Expand Down Expand Up @@ -2260,17 +2260,24 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
return rt
end

function abstract_eval_global(M::Module, s::Symbol)
if isdefined(M, s) && isconst(M, s)
return Const(getglobal(M, s))
function isdefined_globalref(g::GlobalRef)
g.binding != C_NULL && return ccall(:jl_binding_boundp, Cint, (Ptr{Cvoid},), g.binding) != 0
return isdefined(g.mod, g.name)
end

function abstract_eval_globalref(g::GlobalRef)
if isdefined_globalref(g) && isconst(g)
g.binding != C_NULL && return Const(ccall(:jl_binding_value, Any, (Ptr{Cvoid},), g.binding))
return Const(getglobal(g.mod, g.name))
end
ty = ccall(:jl_binding_type, Any, (Any, Any), M, s)
ty = ccall(:jl_binding_type, Any, (Any, Any), g.mod, g.name)
ty === nothing && return Any
return ty
end
abstract_eval_global(M::Module, s::Symbol) = abstract_eval_globalref(GlobalRef(M, s))

function abstract_eval_global(interp::AbstractInterpreter, M::Module, s::Symbol, frame::Union{InferenceState, IRCode})
rt = abstract_eval_global(M, s)
function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, frame::Union{InferenceState, IRCode})
rt = abstract_eval_globalref(g)
consistent = inaccessiblememonly = ALWAYS_FALSE
nothrow = false
if isa(rt, Const)
Expand All @@ -2281,7 +2288,7 @@ function abstract_eval_global(interp::AbstractInterpreter, M::Module, s::Symbol,
else
nothrow = true
end
elseif isdefined(M,s)
elseif isdefined_globalref(g)
nothrow = true
end
merge_effects!(interp, frame, Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly))
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ function argextype(
elseif isa(x, QuoteNode)
return Const(x.value)
elseif isa(x, GlobalRef)
return abstract_eval_global(x.mod, x.name)
return abstract_eval_globalref(x)
elseif isa(x, PhiNode)
return Any
elseif isa(x, PiNode)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/slot2ssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ function typ_for_val(@nospecialize(x), ci::CodeInfo, sptypes::Vector{Any}, idx::
end
return (ci.ssavaluetypes::Vector{Any})[idx]
end
isa(x, GlobalRef) && return abstract_eval_global(x.mod, x.name)
isa(x, GlobalRef) && return abstract_eval_globalref(x)
isa(x, SSAValue) && return (ci.ssavaluetypes::Vector{Any})[x.id]
isa(x, Argument) && return slottypes[x.n]
isa(x, NewSSAValue) && return DelayedTyp(x)
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ function is_throw_call(e::Expr)
if e.head === :call
f = e.args[1]
if isa(f, GlobalRef)
ff = abstract_eval_global(f.mod, f.name)
ff = abstract_eval_globalref(f)
if isa(ff, Const) && ff.val === Core.throw
return true
end
Expand Down
2 changes: 2 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ length(a::Array) = arraylen(a)
eval(:(getindex(A::Array, i1::Int) = arrayref($(Expr(:boundscheck)), A, i1)))
eval(:(getindex(A::Array, i1::Int, i2::Int, I::Int...) = (@inline; arrayref($(Expr(:boundscheck)), A, i1, i2, I...))))

==(a::GlobalRef, b::GlobalRef) = a.mod === b.mod && a.name === b.name

"""
AbstractSet{T}
Expand Down
5 changes: 5 additions & 0 deletions base/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ Determine whether a global is declared `const` in a given module `m`.
isconst(m::Module, s::Symbol) =
ccall(:jl_is_const, Cint, (Any, Any), m, s) != 0

function isconst(g::GlobalRef)
g.binding != C_NULL && return ccall(:jl_binding_is_const, Cint, (Ptr{Cvoid},), g.binding) != 0
return isconst(g.mod, g.name)
end

"""
isconst(t::DataType, s::Union{Int,Symbol}) -> Bool
Expand Down
5 changes: 3 additions & 2 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1838,9 +1838,10 @@ function allow_macroname(ex)
end
end

function is_core_macro(arg, macro_name::AbstractString)
arg === GlobalRef(Core, Symbol(macro_name))
function is_core_macro(arg::GlobalRef, macro_name::AbstractString)
arg == GlobalRef(Core, Symbol(macro_name))
end
is_core_macro(@nospecialize(arg), macro_name::AbstractString) = false

# symbol for IOContext flag signaling whether "begin" is treated
# as an ordinary symbol, which is true in indexing expressions.
Expand Down
1 change: 1 addition & 0 deletions src/clangsa/GCChecker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,7 @@ bool GCChecker::isGCTrackedType(QualType QT) {
Name.endswith_lower("jl_method_match_t") ||
Name.endswith_lower("jl_vararg_t") ||
Name.endswith_lower("jl_opaque_closure_t") ||
Name.endswith_lower("jl_globalref_t") ||
// Probably not technically true for these, but let's allow it
Name.endswith_lower("typemap_intersection_env") ||
Name.endswith_lower("interpreter_state") ||
Expand Down
4 changes: 4 additions & 0 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2084,6 +2084,10 @@ static void jl_deserialize_struct(jl_serializer_state *s, jl_value_t *v) JL_GC_D
entry->min_world = 1;
entry->max_world = 0;
}
} else if (dt == jl_globalref_type) {
jl_globalref_t *r = (jl_globalref_t*)v;
jl_binding_t *b = jl_get_binding_if_bound(r->mod, r->name);
r->bnd_cache = b && b->value ? b : NULL;
}
}

Expand Down
13 changes: 12 additions & 1 deletion src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,17 @@ jl_value_t *jl_eval_global_var(jl_module_t *m, jl_sym_t *e)
return v;
}

jl_value_t *jl_eval_globalref(jl_globalref_t *g)
{
if (g->bnd_cache) {
jl_value_t *v = g->bnd_cache->value;
if (v == NULL)
jl_undefined_var_error(g->name);
return v;
}
return jl_eval_global_var(g->mod, g->name);
}

static int jl_source_nslots(jl_code_info_t *src) JL_NOTSAFEPOINT
{
return jl_array_len(src->slotflags);
Expand Down Expand Up @@ -190,7 +201,7 @@ static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
return jl_quotenode_value(e);
}
if (jl_is_globalref(e)) {
return jl_eval_global_var(jl_globalref_mod(e), jl_globalref_name(e));
return jl_eval_globalref((jl_globalref_t*)e);
}
if (jl_is_symbol(e)) { // bare symbols appear in toplevel exprs not wrapped in `thunk`
return jl_eval_global_var(s->module, (jl_sym_t*)e);
Expand Down
12 changes: 6 additions & 6 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2392,12 +2392,6 @@ void jl_init_types(void) JL_GC_DISABLED
jl_svec(1, jl_slotnumber_type),
jl_emptysvec, 0, 0, 1);

jl_globalref_type =
jl_new_datatype(jl_symbol("GlobalRef"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(2, "mod", "name"),
jl_svec(2, jl_module_type, jl_symbol_type),
jl_emptysvec, 0, 0, 2);

jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
Expand Down Expand Up @@ -2694,6 +2688,12 @@ void jl_init_types(void) JL_GC_DISABLED

jl_value_t *pointer_void = jl_apply_type1((jl_value_t*)jl_pointer_type, (jl_value_t*)jl_nothing_type);

jl_globalref_type =
jl_new_datatype(jl_symbol("GlobalRef"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(3, "mod", "name", "binding"),
jl_svec(3, jl_module_type, jl_symbol_type, pointer_void),
jl_emptysvec, 0, 0, 3);

tv = jl_svec2(tvar("A"), tvar("R"));
jl_opaque_closure_type = (jl_unionall_t*)jl_new_datatype(jl_symbol("OpaqueClosure"), core, jl_function_type, tv,
jl_perm_symsvec(5, "captures", "world", "source", "invoke", "specptr"),
Expand Down
10 changes: 10 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,13 @@ typedef struct _jl_module_t {
jl_mutex_t lock;
} jl_module_t;

typedef struct {
jl_module_t *mod;
jl_sym_t *name;
// Not serialized. Caches the value of jl_get_binding(mod, name).
jl_binding_t *bnd_cache;
} jl_globalref_t;

// one Type-to-Value entry
typedef struct _jl_typemap_entry_t {
JL_DATA_TYPE
Expand Down Expand Up @@ -1616,6 +1623,7 @@ JL_DLLEXPORT int jl_get_module_max_methods(jl_module_t *m);
// get binding for reading
JL_DLLEXPORT jl_binding_t *jl_get_binding(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *var);
JL_DLLEXPORT jl_binding_t *jl_get_binding_or_error(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT jl_binding_t *jl_get_binding_if_bound(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT jl_value_t *jl_module_globalref(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT jl_value_t *jl_binding_type(jl_module_t *m, jl_sym_t *var);
// get binding for assignment
Expand All @@ -1626,6 +1634,8 @@ JL_DLLEXPORT int jl_boundp(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT int jl_defines_or_exports_p(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT int jl_binding_resolved_p(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT int jl_is_const(jl_module_t *m, jl_sym_t *var);
JL_DLLEXPORT int jl_binding_is_const(jl_binding_t *b);
JL_DLLEXPORT int jl_binding_boundp(jl_binding_t *b);
JL_DLLEXPORT jl_value_t *jl_get_global(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *var);
JL_DLLEXPORT void jl_set_global(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var, jl_value_t *val JL_ROOTED_ARGUMENT);
JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var, jl_value_t *val JL_ROOTED_ARGUMENT);
Expand Down
54 changes: 49 additions & 5 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,20 @@ static jl_binding_t *jl_get_binding_(jl_module_t *m, jl_sym_t *var, modstack_t *
return b;
}

JL_DLLEXPORT jl_binding_t *jl_get_binding_if_bound(jl_module_t *m, jl_sym_t *var)
{
JL_LOCK(&m->lock);
jl_binding_t *b = _jl_get_module_binding(m, var);
JL_UNLOCK(&m->lock);
if (b == HT_NOTFOUND || b->owner == NULL) {
return NULL;
}
if (b->owner != m || b->name != var)
return jl_get_binding_if_bound(b->owner, b->name);
return b;
}


// get owner of binding when accessing m.var, without resolving the binding
JL_DLLEXPORT jl_value_t *jl_binding_owner(jl_module_t *m, jl_sym_t *var)
{
Expand Down Expand Up @@ -410,17 +424,29 @@ JL_DLLEXPORT jl_binding_t *jl_get_binding_or_error(jl_module_t *m, jl_sym_t *var
return b;
}

JL_DLLEXPORT jl_globalref_t *jl_new_globalref(jl_module_t *mod, jl_sym_t *name, jl_binding_t *b)
{
jl_task_t *ct = jl_current_task;
jl_globalref_t *g = (jl_globalref_t *)jl_gc_alloc(ct->ptls, sizeof(jl_globalref_t), jl_globalref_type);
g->mod = mod;
jl_gc_wb(g, g->mod);
g->name = name;
g->bnd_cache = b;
return g;
}

JL_DLLEXPORT jl_value_t *jl_module_globalref(jl_module_t *m, jl_sym_t *var)
{
JL_LOCK(&m->lock);
jl_binding_t *b = (jl_binding_t*)ptrhash_get(&m->bindings, var);
jl_binding_t *b = _jl_get_module_binding(m, var);
if (b == HT_NOTFOUND) {
JL_UNLOCK(&m->lock);
return jl_new_struct(jl_globalref_type, m, var);
return (jl_value_t *)jl_new_globalref(m, var, NULL);
}
jl_value_t *globalref = jl_atomic_load_relaxed(&b->globalref);
if (globalref == NULL) {
jl_value_t *newref = jl_new_struct(jl_globalref_type, m, var);
jl_value_t *newref = (jl_value_t *)jl_new_globalref(m, var,
!b->owner ? NULL : b->owner == m ? b : _jl_get_module_binding(b->owner, b->name));
if (jl_atomic_cmpswap_relaxed(&b->globalref, &globalref, newref)) {
JL_GC_PROMISE_ROOTED(newref);
globalref = newref;
Expand Down Expand Up @@ -662,12 +688,18 @@ JL_DLLEXPORT jl_binding_t *jl_get_module_binding(jl_module_t *m JL_PROPAGATES_RO
return b == HT_NOTFOUND ? NULL : b;
}


JL_DLLEXPORT jl_value_t *jl_binding_value(jl_binding_t *b JL_PROPAGATES_ROOT)
{
return b->value;
}

JL_DLLEXPORT jl_value_t *jl_get_global(jl_module_t *m, jl_sym_t *var)
{
jl_binding_t *b = jl_get_binding(m, var);
if (b == NULL) return NULL;
if (b->deprecated) jl_binding_deprecation_warning(m, b);
return b->value;
return jl_binding_value(b);
}

JL_DLLEXPORT void jl_set_global(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var, jl_value_t *val JL_ROOTED_ARGUMENT)
Expand Down Expand Up @@ -696,10 +728,22 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var
jl_symbol_name(bp->name));
}

JL_DLLEXPORT int jl_binding_is_const(jl_binding_t *b)
{
assert(b);
return b->constp;
}

JL_DLLEXPORT int jl_binding_boundp(jl_binding_t *b)
{
assert(b);
return b->value != 0;
}

JL_DLLEXPORT int jl_is_const(jl_module_t *m, jl_sym_t *var)
{
jl_binding_t *b = jl_get_binding(m, var);
return b && b->constp;
return b && jl_binding_is_const(b);
}

// set the deprecated flag for a binding:
Expand Down
4 changes: 4 additions & 0 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,10 @@ static void jl_serialize_value__(jl_serializer_state *s, jl_value_t *v, int recu
jl_serialize_value(s, tn->partial);
}
else if (t->layout->nfields > 0) {
if (jl_typeis(v, jl_globalref_type)) {
// Don't save the cached binding reference in staticdata
((jl_globalref_t*)v)->bnd_cache = NULL;
}
char *data = (char*)jl_data_ptr(v);
size_t i, np = t->layout->npointers;
for (i = 0; i < np; i++) {
Expand Down

0 comments on commit 7677e64

Please sign in to comment.