fix: 🐛 fix sar saved model export #635
Conversation
Thanks for the PR! I added some comments, let's avoid breaking changes in this PR if possible 🙏
if kwargs.get('training') and gt is None:
    raise ValueError('Need to provide labels during training for teacher forcing')
This needs to stay, was it blocking the savedmodel export?
@@ -150,7 +149,8 @@ def call(
     # shape (N, rnn_units + 1) -> (N, vocab_size + 1)
     logits = self.output_dense(logits, **kwargs)
     # update symbol with predicted logits for t+1 step
-    if kwargs.get('training'):
+    if kwargs.get('training') and gt is not None:
+        # 'Need to provide labels during training for teacher forcing'
no need to duplicate comment
@@ -150,7 +149,8 @@ def call(
     # shape (N, rnn_units + 1) -> (N, vocab_size + 1)
     logits = self.output_dense(logits, **kwargs)
     # update symbol with predicted logits for t+1 step
-    if kwargs.get('training'):
+    if kwargs.get('training') and gt is not None:
can be reverted since we already ensure that the gt is provided in training mode above
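The point above can be sketched in a framework-free way: if `gt` is validated once at the top of `call`, the inner teacher-forcing branch only needs to test `training`. This is a simplified illustration, not doctr's actual code; `next_symbol` and its argument names are hypothetical.

```python
def next_symbol(training, gt_symbol, predicted_symbol):
    """Pick the decoder input for step t+1 (simplified sketch).

    With teacher forcing, the ground-truth symbol is fed back during
    training; at inference, the model's own prediction is used. Because
    the upfront check guarantees gt_symbol is present whenever training
    is True, the inner branch can safely test only `training`.
    """
    # Upfront validation (the check the reviewer wants to keep):
    if training and gt_symbol is None:
        raise ValueError("Need to provide labels during training for teacher forcing")
    # Inner branch: no need to re-check gt_symbol here.
    return gt_symbol if training else predicted_symbol
```

Under this structure, reverting the inner condition to a plain `if kwargs.get('training'):` is safe, since the `ValueError` above already rules out a missing `gt` in training mode.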
-    word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()]
-    return list(zip(word_values, probs.numpy().tolist()))
+    return {
+        "decoded_strings_pred": decoded_strings_pred,
+        "probs": probs,
+    }
We can't change the output signature of the call method in this PR or we'll break the high-level API
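The compatibility concern can be sketched as follows (a simplified illustration with hypothetical helper names, not doctr's actual API): the existing contract returns a list of `(word, probability)` pairs, so switching to a dict changes what every downstream caller receives.

```python
def postprocess_pairs(decoded_strings, probs):
    """Existing contract (sketch): a list of (word, probability) tuples.

    Mirrors the original post-processing: byte strings are decoded to
    str and zipped with their confidence scores.
    """
    word_values = [word.decode() for word in decoded_strings]
    return list(zip(word_values, probs))


def postprocess_dict(decoded_strings, probs):
    """Proposed change (sketch): a dict of raw tensors/values.

    Convenient for SavedModel signatures, but any caller that unpacks
    (word, prob) pairs would break, which is why the reviewer asks to
    keep the original output shape in this PR.
    """
    return {"decoded_strings_pred": decoded_strings, "probs": probs}
```

Keeping the pair-list output in `call` and handling serving-specific output formats at export time would avoid the breaking change.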
@khalidMindee is this PR still relevant? If so, could you review FG's comments and update the PR accordingly so we can potentially merge it?
@khalidMindee Do you want to take this further? :) Otherwise, I think we are on track to ensure that every model can be exported to ONNX (#789). If that works for you, we could close this PR. :)
Ah great @felixdittrich92, that's fine by me to close the PR.
This PR solves #602 for SAR recognition models to allow export as a saved model.