Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an api to trace an function and run it with XLA compilation. #23868

Merged
merged 1 commit into from
Apr 8, 2019

Conversation

bgogul
Copy link
Contributor

@bgogul bgogul commented Apr 8, 2019

This PR adds an API to trace a function that takes a TensorGroup and returns a TensorGroup. The new API also takes an optional useXLA argument that enables XLA compilation when executing the traced graph.

This is one of the several PRs needed for https://bugs.swift.org/browse/TF-406

@bgogul bgogul requested a review from pschuh April 8, 2019 21:40
@bgogul
Copy link
Contributor Author

bgogul commented Apr 8, 2019

@swift-ci please test tensorflow

@bgogul bgogul merged commit b71fb98 into swiftlang:tensorflow Apr 8, 2019
@bgogul bgogul deleted the xla_compile branch April 8, 2019 22:53
@@ -218,7 +218,8 @@ private class TraceContext {

/// Execute the trace graph function, and return the list of output tensors
/// from the trace execution. These output tensors are owned by the caller.
func execute(traceeInputs: [_AnyTensorHandle]) -> [CTensorHandle] {
func execute(
traceeInputs: [_AnyTensorHandle], useXla: Bool = false) -> [CTensorHandle] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Xla -> XLA

// Trace the given function to generate a TF graph and return a closure
// that can be used to launch the graph.
public func _graph<In : TensorGroup, Out : TensorGroup>(
_ fn: (In) -> Out, useXla: Bool = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Xla -> XLA

) -> (In) -> Out {
let traceContext: TraceContext = withoutActuallyEscaping(fn) { escapableFn in
let wrappedFn = {
(inputs: [CTensorHandle]) -> [CTensorHandle] in
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this could be moved to the previous line without going over the column limit.

bgogul added a commit to bgogul/swift that referenced this pull request Apr 12, 2019
rxwei pushed a commit that referenced this pull request Apr 13, 2019
rxwei pushed a commit to rxwei/swift that referenced this pull request May 11, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants