Move RNN to layers.py and make it stateless. #97
base: master
Conversation
objax/nn/layers.py (Outdated)
```
@@ -327,6 +327,63 @@ def __call__(self, x: JaxArray) -> JaxArray:
        self.avg.value += (self.avg.value - x) * (self.momentum - 1)
        return self.avg.value


class RNN(Module):
```
I think the name `RNN` is too generic. Pretty much any type of recurrent block (LSTM, GRU, ...) could be called an RNN. Is there a better name for it?
Also, `RNN` refers to the architecture, not to the cell. Here's what TF/Keras does: https://www.tensorflow.org/api_docs/python/tf/compat/v1/nn/rnn_cell/RNNCell
Not sure what PyTorch does.
This is a specific RNN architecture that operates across time (so not a cell). I would call it something like `SimpleRNN`, and make sure it replicates Keras' `SimpleRNN` functionality with default arguments:
https://www.tensorflow.org/api_docs/python/tf/keras/layers/SimpleRNN
`RNN` you could reserve for an object that takes an `RNNCell` and performs a scan across time.
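For illustration, a minimal sketch of that split; the `RNNCell`/`run_rnn` names and the `jax.lax.scan`-based wrapper are assumptions for discussion, not objax API:

```python
import jax


class RNNCell:
    """Hypothetical one-step cell: maps (state, x) -> (new_state, output)."""

    def __call__(self, state, x):
        raise NotImplementedError


def run_rnn(cell, inputs, initial_state):
    """Scan a cell across time (axis 0 of `inputs`)."""
    def step(state, x):
        new_state, y = cell(state, x)
        return new_state, y  # (carry, per-step output)

    final_state, outputs = jax.lax.scan(step, initial_state, inputs)
    return outputs, final_state
```

This keeps the cell reusable: an LSTM or GRU cell would plug into the same wrapper.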
Changed the name to `SimpleRNN`.
objax/nn/layers.py (Outdated)
```
        self.output_layer = Linear(self.nstate, self.num_outputs)

    def __call__(self, inputs: JaxArray, only_return_final=False) -> JaxArray:
```
Suggest adding a `get_initial_state` method and an optional `initial_state` argument here.
I added an optional `initial_state` argument to the `call()` method.
Can you clarify what the `get_initial_state()` method would do, considering that the state is initialized during every `call()` (unless explicitly passed in through the optional argument)?
There are two reasons to have a `get_initial_state`: one, the caller wants to know if this layer is recurrent, without checking for some general instance type; two, the caller wants to know the shapes etc. of the state, without running `__call__`. This is useful for many reasons, like creating buffers for storing state.
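For illustration, such a method could be as small as the sketch below; `nstate` follows this thread's discussion, the rest is an assumption:

```python
import jax.numpy as jn


def get_initial_state(self):
    # Purely functional: returns a fresh zero state of the right shape
    # without running __call__ or storing anything on the instance.
    return jn.zeros(self.nstate)
```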
Just to clarify, does `get_init_state` really act like a `create_init_state`? Or is there an `init_state` stored inside the instance?
No; it's a purely functional thing that returns some arrays.
As far as I understood from some of the Keras code, `get_initial_state` simply returns a zero array of the appropriate shape (e.g. https://github.com/tensorflow/tensorflow/blob/fcc4b966f1265f466e82617020af93670141b009/tensorflow/python/keras/layers/recurrent.py#L1948).
It's still not very clear how useful it is. Could you point us to some example of how it's actually used (either in TensorFlow or any other framework)?
To know the shape of the state, it would be better to just read `rnn_cell_layer.nstate`, or maybe have a helper method `get_state_shape`.
Using `get_initial_state` as a way to determine whether a layer is an RNN seems a little weird. I don't see how `getattr(layer, 'get_initial_state')` is better than `isinstance(layer, RNNCell)`. If there is a need to determine whether a layer is an RNN cell, I think it's better to just make all RNN cells inherit from some base class and do an `isinstance` check.
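A sketch of that base-class alternative (all names here are hypothetical, not objax API):

```python
import objax


class BaseRNNCell(objax.Module):
    """Hypothetical marker base class for all recurrent cells."""


class SimpleRNNCell(BaseRNNCell):
    ...  # a concrete cell implementation would go here


def is_recurrent(layer) -> bool:
    # Explicit type check instead of getattr(layer, 'get_initial_state', None).
    return isinstance(layer, BaseRNNCell)
```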
objax/nn/layers.py (Outdated)
```
            only_return_final: return only the last output if ``True``, or all output otherwise.

        Returns:
            Output tensor with dimensions ``N * batch_size, vocabulary_size``.
```
Is `vocabulary_size` the right terminology for RNNs? Perhaps you mean `nout` here?
Also, why is `batch_size` included here? I thought you don't consider `batch_size` in these layers?
Changed `vocabulary_size` -> `nout`.
I include `batch_size` because we can process a batch of input data.
@david-berthelot do other layers "know" about batch dimensions? Does this one need to?
(From David on another PR: no, layers don't know about batch dimensions, so this one shouldn't either. Instead, add a unit test with this object and `Vectorized`.)
Needs a unit test.
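A rough sketch of such a test, assuming the stateless `SimpleRNN` discussed in this thread returns `(outputs, state)`; the constructor arguments and the `objax.Vectorize` usage here are assumptions to be checked against the final API:

```python
import jax.numpy as jn
import objax
from objax.nn import SimpleRNN  # assumed import path after this PR


def test_simple_rnn_vectorize():
    rnn = SimpleRNN(nin=8, nstate=16)    # hypothetical signature
    x = objax.random.normal((4, 10, 8))  # 4 sequences of 10 steps, 8 features

    # Vectorize maps the per-example call over the leading batch axis.
    vec = objax.Vectorize(lambda xi: rnn(xi)[0], rnn.vars(), batch_axis=(0,))
    batched = vec(x)

    # Compare against calling the layer one example at a time.
    expected = jn.stack([rnn(x[i])[0] for i in range(x.shape[0])])
    assert jn.allclose(batched, expected, atol=1e-5)
```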
What makes the `RNN` stand out in this lib, for me, is the code's readability and simplicity. Anyone can easily extend it.
```
                jn.dot(x, self.w_xh.value)
                + jn.dot(state, self.w_hh.value)
```
`num_inputs` could be zero: essentially empty inputs, but internal states continue to evolve over time.
Not sure if we should use two weight matrices, or one that acts on the concatenated `[h, x]`.
Typically it's more efficient to act on one concatenated `[h, x]`, but it depends on the system and sizes. At some point you could make this an `__init__` mode parameter like Keras does. For now I'd suggest using the concatenated format (sketched below).
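A sketch of the concatenated update for a single example; the `w`/`b` names and shapes (`w` of shape `(nstate + nin, nstate)`, `b` of shape `(nstate,)`) are placeholders, not the PR's code:

```python
import jax.numpy as jn


def simple_rnn_step(state, x, w, b):
    """One step with a single weight matrix acting on [h, x].

    In the layer these would be TrainVar values; plain arrays here
    keep the sketch self-contained.
    """
    hx = jn.concatenate([state, x])  # [h, x]
    return jn.tanh(hx.dot(w) + b)    # x.dot(y) style, per the nit below
```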
Another nit: use `x.dot(y)` rather than `jn.dot(x, y)`, since we might as well take advantage of object-oriented APIs.
```
                + jn.dot(state, self.w_hh.value)
                + self.b_h.value
            )
            y = self.output_layer(state)
```
Do we need `output_layer`, or can we directly return the internal states `h` and let the user do further transforms on them?
I opted for having an `output_layer`.
Question: why? This is something the user can do themselves afterwards, right? So is there any purpose to adding an `output_layer`?
I would drop the output layer; it forces a decision on the user about what type of output they'd want.
initial_state to the constructor, and output RNN state when call() returns.
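Pulling the thread's suggestions together (no output layer, optional initial state, state returned to the caller), the per-example `__call__` might look roughly like the sketch below; `w`, `b`, and the shapes are assumptions, not the PR's actual code:

```python
import jax.numpy as jn


def __call__(self, inputs, initial_state=None, only_return_final=False):
    """Run across time; return raw hidden states and the final state."""
    state = initial_state if initial_state is not None else jn.zeros(self.nstate)
    outputs = []
    for x in inputs:  # iterate over the time axis
        state = jn.tanh(jn.concatenate([state, x]).dot(self.w.value) + self.b.value)
        outputs.append(state)
    if only_return_final:
        return outputs[-1], state
    return jn.stack(outputs), state
```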
Force-pushed from 619be74 to efcb605.
I am OOO and will return next week.
```
        if only_return_final:
            return y, state
        else:
```
No need for `else`.
```
        if only_return_final:
            return y, state
        else:
            return jn.concatenate(outputs, axis=0), state
```
Should it be `jn.stack`?
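Combining the two comments, the tail of `__call__` could simply be (assuming each entry of `outputs` is a per-step array of the same shape, so stacking adds the time axis):

```python
if only_return_final:
    return y, state
return jn.stack(outputs), state  # no else needed after the early return
```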
```
    def __call__(self, inputs: JaxArray, initial_state: JaxArray = None,
                 only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
```
One argument per line if they don't all fit on one line.
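i.e., something like this (imports shown only to make the fragment self-contained):

```python
from typing import Tuple

from objax.typing import JaxArray


def __call__(self,
             inputs: JaxArray,
             initial_state: JaxArray = None,
             only_return_final: bool = False) -> Tuple[JaxArray, JaxArray]:
    ...
```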
```diff
     def loss(x, label):  # sum(label * log(softmax(logit)))
-        logit = model(x)
+        logits, _ = model(x)
         return objax.functional.loss.cross_entropy_logits(logit, label).mean()
```
`logits = model(x)[0]`
```diff
     outputs = [vocab[prefix[0]]]
     get_input = lambda: one_hot(jn.array([outputs[-1]]).reshape(1, 1), len(vocab))
     for y in prefix[1:]:  # Warmup state with prefix
         model(get_input())
         outputs.append(vocab[y])
     for _ in range(num_predicts):  # Predict num_predicts steps
-        Y = model(get_input())
+        Y, _ = model(get_input())
```
- Uppercase is for global constants; use lowercase identifiers for variables, please.
- Also, rather than doing two assigns, the better way is to just assign what you use: `Y = model(get_input())[0]`
```
<<<<<<< HEAD:examples/text_generation/shakespeare_rnn.py
from objax.nn import SimpleRNN
=======
from objax.nn import RNN
>>>>>>> 2c04d4e (Move RNN to layers.py and make it stateless.):examples/rnn/shakespeare.py
```
Your commit contains an unresolved merge.
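If the `SimpleRNN` rename from earlier in this thread is the intended final state, resolving the conflict would keep just:

```python
from objax.nn import SimpleRNN
```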