-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ability to pass a user context in JIT mode #6313
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about evaluate()
, compile_jit()
, and anything else that does jitting implicitly?
this means that handlers or error buffers can be overridden per call to realize now in a thread-safe manner
This should be explicitly documented.
This PR almost certainly will need testing for WebAssembly, since it's a weird JIT case.
On the whole I don't see anything wrong with this, and I should be happy to expand the ability of JIT in this way, but part of me feels like if we're going to be attacking the limitations of the JIT runtime situation, it would be nice to find a way to re-hook everything in a generic way, but that's definitely scope creep...
(Withholding approval pending green buildbots)
@@ -349,6 +349,8 @@ class Pipeline { | |||
*/ | |||
void compile_jit(const Target &target = get_jit_target_from_environment()); | |||
|
|||
// TODO: deprecate all of these and replace with versions that take a JITUserContext |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want to land this PR, we absolutely deprecate these as part of it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about making new ones that accept any function that takes a T * as the first arg, where T * must be implicitly convertible to a JITUserContext *? Then you could have signatures that accept a derived class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You shouldn't need a template for that. Given struct Derived : public JITUserContext
, passing a Derived*
should be acceptable anywhere a JITUserContext*
is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, but I don't want to pass a Derived, I want to pass a function pointer that takes a Derived * as the first argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, right. Well... meh. Adding more template usage to our public API is something I'm not wild about -- what would it look like?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm honestly not too wild about the idea now, because it would only be correct if the user passes a context object of the matching type to the next realize call. If they don't it's an implicit downcast to the wrong type. Let's just add versions that take function pointers that take JITUserContext *
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Builds are green. I'm inclined to deprecate these as a follow-up PR, because that will involve a lot of code tweaking in tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm inclined to deprecate these as a follow-up PR
SGTM, re-reviewing now
compile_jit doesn't need to know, because this just changes the handlers struct that gets used at runtime, but evaluate and infer_input_bounds conceivably do. If there's agreement that this approach is sound, I'll do that. I agree it would be good to make hooking generic instead of there being a blessed set of overridable functions, but I wanted to defer that to a later PR because the design is a little tricky. This PR is just exposing an existing thing. |
but with comments explaining why they are the way they are
Now, if we customize the JITUserContext and provide custom handlers for things like obtaining a custom device context/stream/..., how do we do the same in |
|
Yeah, that seems to work fine. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with nits regarding documentation.
(If the followup deprecation work isn't going to happen right away, we should create a tracking issue so it isn't overlooked.)
Realization Func::realize(JITUserContext *context, | ||
std::vector<int32_t> sizes, | ||
const Target &target, | ||
const ParamMap ¶m_map) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(Idly wondering if we could someday roll param_map
into the JITUserContext...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ParamMap holds the arguments to the realize call and thus is a completely separate concept from the user context. Arguably the user context could be passed in the ParamMap, but that doesn't seem an improvement to me. Normally Params are retrieved from global variables, but that is not thread safe. It is a bit of a silly design in the first place, but it is really convenient for the just hacking up Halide code so it is totally a thing... In order to pass an arbitrary set of arguments through realize to a JITted call, one has to use some sort of keyed dynamic data structure. That is what ParamMap is.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, the runtime design I'm looking at does make the user context part of the arguments. But until the design doc is done, there's not much point in discussing it. If it goes through, it will reverse this change entirely. (The goal is to make the Halide compiler able to pass through a flexible contract from the outside caller to the runtime called from Halide generated code. This is done by possibly adding arguments to the runtime calls at codegen time.)
@@ -1114,7 +1142,7 @@ class Func { | |||
|
|||
/** Get a struct containing the currently set custom functions | |||
* used by JIT. */ | |||
const Internal::JITHandlers &jit_handlers(); | |||
JITHandlers &jit_handlers(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This returns a mutable ref, so I assume it's ok to just mutate the contents... we should probably explicitly document the rules for doing so. (eg, when do changes I make take effect?)
src/Func.h
Outdated
@@ -829,6 +829,14 @@ class Func { | |||
Realization realize(std::vector<int32_t> sizes = {}, const Target &target = Target(), | |||
const ParamMap ¶m_map = ParamMap::empty_map()); | |||
|
|||
/** Same as above, but takes a custom user-provided context to be | |||
* passed to runtime functions. This can be used to pass state to | |||
* runtime overrides in a thread-safe manner. */ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it legal to pass nullptr
for context? Should be documented.
JIT mode has long hijacked user_context for its own purposes. It uses it to store overrides of runtime functions, and it uses it to store an error buffer for accumulating error messages. This means user_context isn't available in JIT mode, which leaves no good way to pass state to custom overrides in a thread-safe manner.
This PR promotes the existing JITUserContext struct to the public namespace, and lets you pass one per call to realize, instead of hiding it inside the realize implementation. By inheriting from this struct and passing a subclass in as the context pointer, you can pass additional state to your runtime overrides! This is demonstrated in the new test.
Runtime overrides in JIT mode will now be expected to take a JITUserContext * as the first arg instead of a void *. The existing set_custom_foo methods that expect function pointers with void * first args are left in place, but I think we should deprecate them and add new ones that expect a JITUserContext *.
I didn't make any of the base class (JITUserContext) state private or anything, so as a side-effect, this means that handlers or error buffers can be overridden per call to realize now in a thread-safe manner.