Implement Transfer Learning API #2

Merged
merged 8 commits into from
Mar 23, 2021


n2cholas
Owner

@n2cholas n2cholas commented Mar 13, 2021

Closes #1. Implements a Sequential combinator as well as a slice operation for easily extracting portions of a model.

TODO:

  • Write more thorough tests
  • Write documentation

@n2cholas
Owner Author

@cgarciae any thoughts on this initial design?

@cgarciae

Hey @n2cholas! We had a discussion in this PR about ways to approach this: poets-ai/elegy#169

My thoughts:

  • Slicing is simple and probably the way to go with Flax.
  • elegy.Module enables you to do Transfer Learning in a more flexible and easier fashion, but I don't want to push it too hard because the ecosystem seems to be unifying around Flax. I might write a blog post about the approach we use.

@n2cholas
Owner Author

Elegy's approach looks very flexible, I'll definitely give it a try in the future. A blog post would be very nice!

Since this project's scope is ResNet-style architectures, I'll iterate on this slice API a bit and stick to it.

I'm curious to see how Flax will address this problem. nn.compact gives a great boost for ease of implementation and readability, but somewhat sacrifices re-usability of the code (since, for now, you have to treat the Module like a black box).

I actually quite liked the combinator approach of jax.experimental.stax: it made dataflow explicit and would enable simple, arbitrary model surgery.
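
To illustrate the combinator idea mentioned above, here is a minimal pure-Python sketch. It is not jax.experimental.stax itself: the `scale`, `shift`, and `serial` helpers are hypothetical stand-ins, and the only thing borrowed from stax is the convention that a layer is an `(init_fn, apply_fn)` pair and that a serial combinator keeps one parameter entry per layer.

```python
# Simplified stand-in for the stax convention: a layer is an
# (init_fn, apply_fn) pair. All names below are hypothetical.

def scale(factor):
    """Layer that multiplies its input by a stored factor."""
    def init_fn():
        return {"factor": factor}

    def apply_fn(params, x):
        return x * params["factor"]

    return init_fn, apply_fn


def shift(offset):
    """Layer that adds a stored offset to its input."""
    def init_fn():
        return {"offset": offset}

    def apply_fn(params, x):
        return x + params["offset"]

    return init_fn, apply_fn


def serial(*layers):
    """Compose layers in order; params is a list with one entry per layer."""
    init_fns, apply_fns = zip(*layers)

    def init_fn():
        return [init() for init in init_fns]

    def apply_fn(params, x):
        for apply, layer_params in zip(apply_fns, params):
            x = apply(layer_params, x)
        return x

    return init_fn, apply_fn


layers = [scale(2.0), shift(1.0), scale(10.0)]
init_fn, apply_fn = serial(*layers)
params = init_fn()
full_out = apply_fn(params, 3.0)

# "Model surgery": keep the first two layers and reuse their parameters.
_, backbone_apply = serial(*layers[:2])
backbone_out = backbone_apply(params[:2], 3.0)
```

Because the layer list and the parameter list line up index by index, extracting part of the model is just slicing both with the same range.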

@n2cholas n2cholas marked this pull request as ready for review March 23, 2021 16:05
@n2cholas n2cholas marked this pull request as draft March 23, 2021 16:07
@n2cholas n2cholas marked this pull request as ready for review March 23, 2021 17:29
@n2cholas
Owner Author

@cgarciae we'll provide a Sequential module and a slice_variables method. It's trivial to slice a sequential model yourself (sliced_model = Sequential(model.layers[start:end])), and slice_variables will give you the corresponding variables dict (sliced_variables = slice_variables(variables, start, end)).
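
A rough sketch of how the two pieces described above could fit together follows. It is not the code merged in this PR: the toy `Sequential` class holds plain callables instead of Flax modules, the variables are plain nested dicts, and the assumption that per-layer entries are keyed `layers_<i>` and get renumbered from zero after slicing is mine. Only the call shapes `Sequential(model.layers[start:end])` and `slice_variables(variables, start, end)` come from the comment.

```python
class Sequential:
    """Toy stand-in for a Sequential module: applies its layers in order."""

    def __init__(self, layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def slice_variables(variables, start=0, end=None):
    """Return the variables for layers [start, end), renumbered from 0.

    Assumes each collection maps "layers_<i>" to that layer's variables
    (an assumption for this sketch). Layers without variables simply have
    no entry. The layer count is inferred from the highest index that has
    variables, which only matters for negative or omitted bounds.
    """
    prefix = "layers_"
    indices = [
        int(name[len(prefix):])
        for entries in variables.values()
        for name in entries
        if name.startswith(prefix)
    ]
    num_layers = max(indices) + 1 if indices else 0
    # Normalize None and negative bounds the same way list slicing does.
    start, end, _ = slice(start, end).indices(num_layers)
    return {
        collection: {
            f"{prefix}{i - start}": entries[f"{prefix}{i}"]
            for i in range(start, end)
            if f"{prefix}{i}" in entries
        }
        for collection, entries in variables.items()
    }


model = Sequential([
    lambda x: x + 1,
    lambda x: x * 2,
    lambda x: x - 3,
    lambda x: x * x,
])
variables = {
    "params": {"layers_0": {"w": 1}, "layers_1": {"w": 2}, "layers_3": {"w": 4}},
    "batch_stats": {"layers_1": {"mean": 0.0}},
}

start, end = 1, 3
sliced_model = Sequential(model.layers[start:end])
sliced_variables = slice_variables(variables, start, end)
```

Renumbering the keys keeps the sliced variables aligned with the positions of the layers in the new, shorter model.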

Once Flax has its own Sequential module (PR), I'll switch to that.

Thanks again for opening this up!

@n2cholas n2cholas merged commit 5577791 into main Mar 23, 2021
@n2cholas n2cholas deleted the sequential branch March 23, 2021 17:39
Development

Successfully merging this pull request may close these issues.

Transfer Learning API
2 participants