ideas for improving inference and optimizer performance #33326

Closed
JeffBezanson opened this issue Sep 19, 2019 · 10 comments
Labels
compiler:inference Type inference compiler:latency Compiler latency compiler:optimizer Optimization passes (mostly in base/compiler/ssair/)


@JeffBezanson
Member

Many of our latency issues (e.g. time to first plot) primarily consist of inference time. This issue is for documenting and tracking ideas for improving this. I'll try to put ~one idea per comment.

@JeffBezanson JeffBezanson added compiler:inference Type inference compiler:optimizer Optimization passes (mostly in base/compiler/ssair/) compiler:latency Compiler latency labels Sep 19, 2019
@JeffBezanson
Member Author

Idea 1: customize max_methods per module.

The first 4 on this line in base/compiler/params.jl (the max_methods parameter) is really important:

                   #=inline_tupleret_bonus, max_methods, union_splitting, apply_union_enum=#
                   400, 4, 4, 8,

When an inferred call site has multiple possible targets (due to type imprecision), max_methods limits the number of methods we examine. So the amount of work done by inference can theoretically be max_methods^n, where n is the length of a call chain. Usually the blowup is not that bad, but on rare occasions it is, and we end up inferring tons of methods on very general types, ultimately getting no really useful information. Reducing max_methods reliably speeds up inference, with by far the best results from max_methods=1 (unsurprisingly!).

Setting max_methods=1 globally often results in worse type info. But a lot of code is probably fine with it. The idea here is to allow setting max_methods=1 for all code in a given module, e.g. Plots.jl, similar (but orthogonal) to the per-module @nospecialize we already have.
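As a rough illustration of the max_methods^n bound, here is a toy model (not Core.Compiler code) that counts how many methods an abstract call chain could force inference to visit if every call site matches exactly `max_methods` candidates. The function name `worst_case_inferences` is hypothetical.

```julia
# Toy worst-case model: each abstract call site fans out to `max_methods`
# candidate methods, and each of those makes one further abstract call,
# down to a call chain of length `n`. Illustrates max_methods^n growth only.
function worst_case_inferences(max_methods::Int, n::Int)
    total = 0
    frontier = 1                  # call sites at the current depth
    for depth in 1:n
        frontier *= max_methods   # each site may match up to max_methods methods
        total += frontier         # every matched method has to be inferred
    end
    return total
end

for m in (1, 2, 4)
    println("max_methods = $m, chain length 6: ", worst_case_inferences(m, 6), " inferences")
end
```

Under this model a chain of length 6 visits 6 methods with max_methods=1, 126 with max_methods=2, and 5460 with max_methods=4, which is why a per-module max_methods=1 could pay off so much for code that doesn't need the extra precision.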

@JeffBezanson
Member Author

Idea 2: speed up matching_methods

The inner loop of inference is looking up the set of method matches for a call site. Our procedure for doing that is fairly bulky. The majority of lookups (~60%) are actually for concrete types with a single method match, which could be handled by a much more efficient table lookup like what we use for method dispatch.

#33261 tries to implement that optimization. However, the problem is that we need to both determine the matching method and the range of world ages for which that lookup is correct. That might require looking at multiple table entries, so some extra work is needed.
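A minimal sketch of such a fast path, assuming a dictionary keyed on the concrete signature type. Each cached entry carries the world-age range in which the match is valid, which is the extra bookkeeping described above. `FastEntry`, `fast_cache`, `slow_matching_methods`, and `fast_lookup` are illustrative names, not Julia internals, and the stand-in slow path simply treats the match as valid from the queried world onward.

```julia
# Hypothetical fast path for the common case: concrete signature, single match.
struct FastEntry
    method::Symbol     # stand-in for the matched Method
    min_world::UInt    # first world age in which this match is valid
    max_world::UInt    # last world age in which this match is valid
end

const fast_cache = Dict{Type, FastEntry}()

# Stand-in for the general (slow) matching procedure. A real version would
# bound the world range by inspecting neighboring method-table entries.
function slow_matching_methods(sig::Type, world::UInt)
    return FastEntry(Symbol("method_for_", sig), world, typemax(UInt))
end

function fast_lookup(sig::Type, world::UInt)
    if isconcretetype(sig)
        e = get(fast_cache, sig, nothing)
        # A cached entry is only usable if `world` lies inside its validity range.
        if e !== nothing && e.min_world <= world <= e.max_world
            return e
        end
    end
    e = slow_matching_methods(sig, world)
    isconcretetype(sig) && (fast_cache[sig] = e)   # only cache concrete signatures
    return e
end
```

Abstract signatures always fall through to the slow path, so the table only has to handle the easy (~60%) case.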

@JeffBezanson
Member Author

Idea 3: cache matching_methods between inference and inlining

Inference calls _methods_by_ftype to determine matching methods for each call site. Then we do inlining, which calls it again for each call site to see if it can inline anything. We could add a per-statement array to cache those results.
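A small sketch of the per-statement cache, under the assumption that both passes index statements the same way. `MatchResult`, `StmtMatchCache`, `expensive_lookup`, and `find_matches` are hypothetical names; `expensive_lookup` stands in for the `_methods_by_ftype` query.

```julia
# Per-statement cache shared by inference and inlining (illustrative only).
struct MatchResult
    matches::Vector{Symbol}   # stand-in for the matched methods
end

mutable struct StmtMatchCache
    results::Vector{Union{Nothing, MatchResult}}   # one slot per statement
    misses::Int                                    # expensive lookups performed
end
StmtMatchCache(nstmts::Int) =
    StmtMatchCache(Vector{Union{Nothing, MatchResult}}(nothing, nstmts), 0)

# Stand-in for the expensive method-matching query.
expensive_lookup(sig) = MatchResult([Symbol("m_", sig)])

function find_matches(cache::StmtMatchCache, idx::Int, sig)
    r = cache.results[idx]
    if r === nothing
        cache.misses += 1
        r = expensive_lookup(sig)
        cache.results[idx] = r
    end
    return r
end

sigs = [:f_Int, :g_Float64, :h_Any]
cache = StmtMatchCache(length(sigs))
for (i, s) in enumerate(sigs); find_matches(cache, i, s); end   # inference pass
for (i, s) in enumerate(sigs); find_matches(cache, i, s); end   # inlining pass reuses results
println("expensive lookups: ", cache.misses)
```

With the cache, the two passes perform three lookups instead of six.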

@JeffBezanson
Member Author

Idea 4: bail out of long abstract call chains

As discussed in idea 1 above, inference sometimes goes on a wild goose chase where f calls g calls h and so on, for a very long chain, and potentially all on Any and resulting in nothing useful. We could potentially detect that situation and bail out. I imagine we would throw back to the last "useful" function (e.g. one whose signature passes jl_isa_compileable_sig), recording Any as the type of the problematic call site.

This is a little tricky. Most importantly, note that the logic

if call_depth > max
    return Any
end

does not work, since it would make inference context-sensitive: we'd infer a different type for a function based on where in the call chain it occurs. So some care is needed.
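To see why the naive `call_depth > max` rule is context-sensitive, here is a toy inferencer over a hypothetical call chain f → g → h → leaf. `callgraph` and `infer_rt` are illustrative, not Core.Compiler code.

```julia
# Toy call graph: each function calls at most one other; `leaf` returns an Int.
const callgraph = Dict{Symbol, Union{Symbol, Nothing}}(
    :f => :g, :g => :h, :h => :leaf, :leaf => nothing)

function infer_rt(fn::Symbol, depth::Int, max_depth::Int)
    depth > max_depth && return Any        # the naive depth cutoff
    callee = callgraph[fn]
    callee === nothing && return Int       # leaf: concrete result
    return infer_rt(callee, depth + 1, max_depth)
end

println(infer_rt(:h, 1, 2))   # h inferred as an entry point
println(infer_rt(:h, 3, 2))   # h reached at depth 3, e.g. via f -> g -> h
```

The same function `h` comes out as `Int` in one context and `Any` in the other, so whichever answer got cached for `h` would be wrong for the other caller.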

@JeffBezanson
Member Author

Idea 5: elide recursion checks in abstract_call_method

For each call site, we check whether it is recursive and therefore might need to be limited. That involves stepping back through the inference call stack. However, we might just be trying to look up some straightforward thing for which we already have a good cached inference result. It seems we should be able to elide the recursion check in that case. I haven't thought too much about this one though.
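A sketch of the ordering change, assuming (as the idea suggests, though it hasn't been worked out) that a completed cached result can be reused without the recursion check. `Frame`, `result_cache`, `stack_walks`, `is_recursive`, and `abstract_call` are hypothetical names.

```julia
struct Frame
    sig::Any
end

const result_cache = Dict{Any, Type}()   # signature => inferred return type
const stack_walks = Ref(0)                # number of recursion checks performed

function is_recursive(stack::Vector{Frame}, sig)
    stack_walks[] += 1
    return any(fr -> fr.sig == sig, stack)   # step back through the call stack
end

function abstract_call(stack::Vector{Frame}, sig)
    # Assumption: a completed cached result makes the recursion check unnecessary.
    cached = get(result_cache, sig, nothing)
    cached !== nothing && return cached
    is_recursive(stack, sig) && return Any   # would trigger limiting in real inference
    push!(stack, Frame(sig))
    rt = Int                                 # stand-in for actually inferring `sig`
    pop!(stack)
    result_cache[sig] = rt
    return rt
end
```

Only the first call for a given signature pays for the stack walk; later calls return straight from the cache.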

@vtjnash
Member

vtjnash commented Sep 19, 2019

Idea 6: avoid inference work that only matters for compilation

Inference always completely analyzes the callee function, under the assumption that we may later need to optimize or inline it. However, in many cases, we don't end up optimizing or inlining it. Just computing the return type is often easy and can be substantially cheaper.

@JeffBezanson
Copy link
Member Author

JeffBezanson commented Sep 19, 2019

Idea 7: be able to save inference results in .ji files

E.g. #31466 / #32705 (EDIT @timholy) works towards this. This could save a lot of inference time for precompiled packages.

@JeffBezanson
Member Author

Backedges part 1: reducing graph density

When f calls g, and inference on f depends on information about g, a "backedge" is stored from g to f, so that if g changes we can invalidate the inference result.

Backedges need to be stored in .ji files and processed on load, which takes time, and invalidating too many functions leads to costly re-inference and recompilation.

The idea here is to reduce the number of backedges somehow. For example, find call sites that are not "important" (e.g. they only happen on an error path) and infer them as Any to avoid storing a backedge.
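A toy backedge graph showing the idea: calls flagged as unimportant (here, an error path) are inferred as Any and record no backedge, so they can never trigger invalidation. `backedges`, `record_call!`, `collect_invalidations`, and the `on_error_path` flag are hypothetical.

```julia
# backedges[g] lists the callers whose inference depended on g.
const backedges = Dict{Symbol, Vector{Symbol}}()

function record_call!(caller::Symbol, callee::Symbol; on_error_path::Bool = false)
    on_error_path && return   # inferred as Any: no backedge stored
    push!(get!(Vector{Symbol}, backedges, callee), caller)
    return
end

# Everything that must be re-inferred when `changed` gets a new method.
function collect_invalidations(changed::Symbol, seen::Set{Symbol} = Set{Symbol}())
    for caller in get(backedges, changed, Symbol[])
        caller in seen && continue
        push!(seen, caller)
        collect_invalidations(caller, seen)   # callers of callers are invalidated too
    end
    return seen
end

record_call!(:f, :g)
record_call!(:h, :g)
record_call!(:f, :show_error; on_error_path = true)
```

Adding a method to `g` invalidates `f` and `h`, while a new method of `show_error` invalidates nothing.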

@JeffBezanson
Member Author

Backedges part 2: precision

Currently any new method causes invalidation via any backedges with an overlapping signature. This could be much more precise if it could take inference results into account. For example, if all we infer about f(x) is that it returns an AbstractArray, and f(x) is dynamically dispatched, then new methods of f should have no effect as long as they also return AbstractArray.
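A sketch of what a return-type-aware check might look like, assuming each backedge records what inference concluded about the call and whether the call is dynamically dispatched. `Backedge` and `needs_invalidation` are hypothetical names.

```julia
struct Backedge
    caller::Symbol
    inferred_rt::Type   # what the caller's inference assumed about the call
    dynamic::Bool       # true if the call is dynamically dispatched
end

function needs_invalidation(edge::Backedge, new_method_rt::Type)
    # Static or inlined calls bind to a specific method, so stay conservative.
    edge.dynamic || return true
    # A dynamic call is only affected if the new method breaks the inferred type.
    return !(new_method_rt <: edge.inferred_rt)
end

edge = Backedge(:caller_of_f, AbstractArray, true)
needs_invalidation(edge, Vector{Int})   # still an AbstractArray: no invalidation
needs_invalidation(edge, Int)           # breaks the inferred type: invalidate
```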

@timholy
Member

timholy commented May 5, 2020

Backedges part 3: limit the propagation of invalidation

If a given MethodInstance gets invalidated, all of its callers (direct and indirect) also get invalidated. In many cases this seems wasteful: in particular, if the return type inference is the same and the MethodInstance was not inline_worthy, then it seems that it should be possible (in principle) to just update the immediate callers to call the new version, and thus break the (sometimes very long) chain of invalidations.

(Cross-post from Slack; I don't want to lose it.) Here's a call chain of 3 methods. We change the lowest one, and if we could rewrite the call from the middle method to call a different instance of the lowest method, we could stop there. This comes up frequently when implementing "fallbacks" and "specializations"; here I've implemented an `O(N)` fallback and an `O(1)` specialization:
julia> @noinline function countelements(iter)
           # In general, we just have to count them, O(N)
           n = 0
           for item in iter
               n += 1
           end
           return n
       end
countelements (generic function with 1 method)

julia> @noinline domath(x) = countelements(x)*5.0 - 1.5
domath (generic function with 1 method)

julia> doubleit(x) = 2*domath(x)
doubleit (generic function with 1 method)

julia> x = [1, 2, 3]
3-element Array{Int64,1}:
 1
 2
 3

julia> doubleit(x)
27.0

julia> @code_llvm doubleit(x)
;  @ REPL[3]:1 within `doubleit'
define double @julia_doubleit_214(%jl_value_t* nonnull align 16 dereferenceable(40)) {
top:
  %1 = call double @j_domath_215(%jl_value_t* nonnull %0)
; ┌ @ promotion.jl:312 within `*' @ float.jl:405
   %2 = fmul double %1, 2.000000e+00
; └
  ret double %2
}

Now watch the invalidations when we define the O(1) algorithm:

julia> unsafe_store!(cglobal(:jl_debug_method_invalidation, Cint), 1)
Ptr{Int32} @0x00007f07d4824180

julia> @noinline countelements(iter::AbstractArray) = length(iter)
 domath(Array{Int64, 1}) (in Main)
  doubleit(Array{Int64, 1}) (in Main)
>> Main.countelements(...) Tuple{typeof(Main.countelements), AbstractArray{T, N} where N where T}
countelements (generic function with 2 methods)

julia> doubleit(x)
27.0

julia> @code_llvm doubleit(x)
;  @ REPL[3]:1 within `doubleit'
define double @julia_doubleit_249(%jl_value_t* nonnull align 16 dereferenceable(40)) {
top:
  %1 = call double @j_domath_250(%jl_value_t* nonnull %0)
; ┌ @ promotion.jl:312 within `*' @ float.jl:405
   %2 = fmul double %1, 2.000000e+00
; └
  ret double %2
}

You can see doubleit got recompiled when all we really had to do was change a single call in domath.

My preliminary sense is that this comes up quite a lot in practice. Currently we get a bunch of invalidations from loading FixedPointNumbers that stem from convert(Type{<:Bool}, ::Bool). Now, that might be fixed by other means, but in any event it would seem that once you fix the direct callers of this method you should be able to stop there.
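A toy model of the proposed cutoff, applied to the `countelements` → `domath` → `doubleit` chain above. The rule sketched here is the one described: if the changed method's inferred return type is unchanged and it was not inlined into its callers, only the direct callers need their call retargeted. `MI`, `mis`, `propagate`, and the `:patch`/`:invalidate` actions are hypothetical; the model also assumes an acyclic call graph.

```julia
struct MI
    name::Symbol
    callers::Vector{Symbol}
    inlined_into_callers::Bool
end

const mis = Dict(
    :countelements => MI(:countelements, [:domath], false),   # @noinline
    :domath        => MI(:domath, [:doubleit], false),        # @noinline
    :doubleit      => MI(:doubleit, Symbol[], false),
)

# `changed` got a new method; `rt_unchanged` says whether its inferred return
# type stayed the same. Returns the action taken for each affected caller.
function propagate(changed::Symbol, rt_unchanged::Bool)
    actions = Dict{Symbol, Symbol}()
    mi = mis[changed]
    for c in mi.callers
        if rt_unchanged && !mi.inlined_into_callers
            actions[c] = :patch   # retarget the call in the direct caller and stop
        else
            actions[c] = :invalidate
            # Conservatively assume the caller's own return type may have changed.
            merge!(actions, propagate(c, false))
        end
    end
    return actions
end
```

With `rt_unchanged = true`, only `domath` is patched and `doubleit` is left alone; with `false`, the whole chain is invalidated as it is today.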

3 participants