-
Notifications
You must be signed in to change notification settings - Fork 10.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add an api to trace an function and run it with XLA compilation. #23868
Conversation
@swift-ci please test tensorflow |
@@ -218,7 +218,8 @@ private class TraceContext { | |||
|
|||
/// Execute the trace graph function, and return the list of output tensors | |||
/// from the trace execution. These output tensors are owned by the caller. | |||
func execute(traceeInputs: [_AnyTensorHandle]) -> [CTensorHandle] { | |||
func execute( | |||
traceeInputs: [_AnyTensorHandle], useXla: Bool = false) -> [CTensorHandle] { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Xla
-> XLA
// Trace the given function to generate a TF graph and return a closure | ||
// that can be used to launch the graph. | ||
public func _graph<In : TensorGroup, Out : TensorGroup>( | ||
_ fn: (In) -> Out, useXla: Bool = false |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Xla
-> XLA
) -> (In) -> Out { | ||
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in | ||
let wrappedFn = { | ||
(inputs: [CTensorHandle]) -> [CTensorHandle] in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this could be moved to the previous line without going over the column limit.
This PR adds an API to trace a function that takes a
TensorGroup
and returns aTensorGroup
. The new API also takes an optionaluseXLA
argument that enables XLA compilation when executing the traced graph.This is one of the several PRs needed for https://bugs.swift.org/browse/TF-406