This repository has been archived by the owner on Jul 1, 2023. It is now read-only.

[Layer] Remove the 'context' argument from 'Layer.applied(to:in:)'. #87

Merged
merged 2 commits into tensorflow:master from no-context on Apr 15, 2019

Conversation

rxwei
Contributor

@rxwei rxwei commented Apr 14, 2019

This PR removes the context argument from Layer's applied(to:in:) method to simplify the call sites of layer application. Instead of passing contexts as an argument, we make contexts available as thread-local information, similar to how our device placement API (withDevice(_:execute:)) works.

We made this decision based on user feedback and API design principles: we want our APIs to be clear and fluent. The team is aware of potential concerns around concurrency, but has decided to start from first principles and make this API easier to use for everyone.

Internal changes

Add a ContextManager class, each instance of which manages a stack of contexts for one thread. Each thread-local singleton can be accessed through the thread-safe type property local. The most recent context can be accessed through the computed property currentContext, and pushing and popping can be done via push(_:) and popContext().

I chose not to reuse the thread-local stack for device placement in the compiler runtime code base because it is engineered differently and is tied to the TensorFlow backend.
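For readers who want a concrete picture, here is a minimal sketch of what a per-thread context stack along these lines could look like. The names mirror the description above, but the implementation details (a Foundation threadDictionary-backed singleton, a stand-in Context type) are illustrative assumptions, not the actual runtime code:

```swift
import Foundation

/// Illustrative stand-in for the PR's `Context` value.
struct Context {
    enum LearningPhase { case training, inference }
    var learningPhase: LearningPhase = .inference
}

/// A per-thread stack of contexts, sketched along the lines the PR describes.
final class ContextManager {
    private static let key = "ContextManager.local"

    /// The thread-local singleton: one lazily created instance per thread.
    static var local: ContextManager {
        if let existing = Thread.current.threadDictionary[key] as? ContextManager {
            return existing
        }
        let manager = ContextManager()
        Thread.current.threadDictionary[key] = manager
        return manager
    }

    /// The stack always holds at least one (default) context.
    private var stack: [Context] = [Context()]

    /// The most recently pushed context.
    var currentContext: Context {
        get { stack.last! }
        set { stack[stack.count - 1] = newValue }
    }

    /// Pushes a new context onto this thread's stack.
    func push(_ context: Context) { stack.append(context) }

    /// Pops the most recent context; the default context is never popped.
    func popContext() -> Context? {
        stack.count > 1 ? stack.removeLast() : nil
    }
}
```

Because the singleton lives in the thread dictionary, each thread sees its own independent stack without any locking on the hot path.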

User-visible changes

  • Remove the context argument. Change applied(to:in:) to applied(to:) in the Layer protocol and all layers.

  • Change Context from a class to a struct, since access to thread-local contexts is already by reference and involves no copies.

  • Add a thread-safe computed property .local to Context for retrieving the current context.
    For example, here's how you can check the current learning phase inside a layer:

    switch Context.local.learningPhase {
    case .training:
        ...
    case .inference:
        ...
    }

    There is a default Context, where learningPhase == .inference. To train a model, instead of creating a Context and passing it to applied(to:in:), users can now set the local learning phase and call applied(to:) directly.

    Context.local.learningPhase = .training
    let grad = model.gradient { model in
        let ŷ = model.applied(to: x)
        return softmaxCrossEntropy(logits: ŷ, labels: y)
    }
  • Add APIs withContext(_:_:), which calls a closure under a temporary context, and withLearningPhase(_:_:), which calls a closure under a temporary learning phase.
    For example, Layer.inferring(from:) is now implemented as

    func inferring(from input: Input) -> Output {
        return withLearningPhase(.inference) { applied(to: input) }
    }

Future PRs will migrate all existing models and tutorials to the new API.

@rxwei rxwei self-assigned this Apr 14, 2019
@lattner

lattner commented Apr 14, 2019

yaaaaay! :-)

@rxwei rxwei requested review from lattner and dan-zheng April 15, 2019 07:05
@rxwei rxwei changed the title [Layer] Remove the 'context' argument to Layer.applied(to:in:). [Layer] Remove the 'context' argument from 'Layer.applied(to:in:)'. Apr 15, 2019
@rxwei
Contributor Author

rxwei commented Apr 15, 2019

I'll merge and begin migrating models/tutorials/notebooks now, which I anticipate will be a lot more work than this PR. PR feedback will be welcome.

@rxwei rxwei merged commit 16d87eb into tensorflow:master Apr 15, 2019
@rxwei rxwei deleted the no-context branch April 15, 2019 07:15