Miscellaneous GAIL improvements and refactoring #133
Conversation
Codecov Report
@@ Coverage Diff @@
## master #133 +/- ##
==========================================
- Coverage 87.14% 87.13% -0.01%
==========================================
Files 60 60
Lines 4332 4329 -3
==========================================
- Hits 3775 3772 -3
Misses 557 557
Thanks for the refactoring / ease-of-use changes, Sam. I have several questions about what to keep and what to drop when we merge with our custom logger in #135.
mean_stats = {
    k: np.mean(v) for k, v in stat_dict_accum.items()
}
return mean_stats
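For context, the accumulate-then-average pattern this snippet finishes can be sketched as follows. The names `stat_dict_accum` and `mean_stats` mirror the snippet; the training loop and the per-batch statistics are hypothetical stand-ins, not the PR's actual trainer code:

```python
from collections import defaultdict

import numpy as np


def train_epoch(n_batches: int) -> dict:
    """Accumulate per-batch stats in lists, then reduce each to its mean."""
    stat_dict_accum = defaultdict(list)
    for i in range(n_batches):
        # Hypothetical per-batch statistics; a real trainer would compute
        # these from the discriminator update.
        batch_stats = {"disc_loss": 1.0 / (i + 1), "disc_acc": 0.5}
        for k, v in batch_stats.items():
            stat_dict_accum[k].append(v)
    # Same reduction as in the snippet above.
    mean_stats = {k: np.mean(v) for k, v in stat_dict_accum.items()}
    return mean_stats
```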
Do you have any code that uses the `mean_stats` return value? I think that in a later commit I can cleanly merge all of the mean-accumulating logic here with the mean-accumulating custom logger from #135. The mean-accumulating logger writes the means to an SBLogger (whereas the means aren't logged here), so I'm wondering whether you think we will still need to return `mean_stats` later on.
I do, but I won't need it once #135 is removed. It's fine to take out.
(I just added a comment saying you can remove it; I'll let you do the honours when you merge #135)
log_fmts = [
    make_output_format(s, disc_log_dir) for s in log_fmt_strs
]
self._disc_logger = Logger(disc_log_dir, log_fmts)
FYI: when I merge with #135 later, I probably won't use this second logger; instead I'll use a "discrim" context. This is fine by me for this PR, though.
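The "discrim" context idea could look roughly like the sketch below: rather than constructing a second `Logger` with its own output directory, keys get namespaced under an active context. This is a hypothetical illustration, not the #135 implementation; `ContextLogger` and its methods are invented names:

```python
class ContextLogger:
    """Minimal sketch of context-prefixed logging.

    Instead of a second Logger writing to its own directory, stats are
    recorded through one logger, with keys prefixed by a context such
    as "discrim".
    """

    def __init__(self):
        self.values = {}
        self._context = ""

    def set_context(self, name: str) -> None:
        self._context = name

    def record(self, key: str, value) -> None:
        prefix = f"{self._context}/" if self._context else ""
        self.values[prefix + key] = value


logger = ContextLogger()
logger.set_context("discrim")
logger.record("loss", 0.42)  # stored under the key "discrim/loss"
```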
Co-Authored-By: Steven H. Wang <[email protected]>
3e49a29 to e4be624
src/imitation/rewards/discrim_net.py (Outdated)
    construction of the discriminator network, and a `tf.Tensor`
    representing the desired discriminator logits.
build_discrim_net_kwargs: optional extra keyword arguments for
    `build_discrim_net()`.
Now that we have the `build_discrim_net_kwargs` pattern, could we make `build_mlp_discrim_net` into a function rather than a class? I'm not sure if `pytype` would be happy with `Callable` + arbitrary kwargs. (It looks like newer versions of Python will have a more precise `Protocol` type for defining callable types: https://stackoverflow.com/a/57840786/1091722)
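The `Protocol`-based callable type mentioned in the linked answer is available as `typing.Protocol` from Python 3.8 (and `typing_extensions` before that). A sketch of how it could type a plain builder function, replacing the functor class: the signature and return type here are illustrative, not the real `build_mlp_discrim_net` API (which builds TF tensors):

```python
from typing import Protocol


class BuildDiscrimNet(Protocol):
    """Callable type for discriminator builders.

    Unlike Callable[..., str], a Protocol with __call__ can declare
    named parameters and **kwargs explicitly, which type checkers
    can then verify at call sites.
    """

    def __call__(self, obs_dim: int, act_dim: int, **kwargs) -> str: ...


def build_mlp_discrim_net(obs_dim: int, act_dim: int, **kwargs) -> str:
    # Hypothetical plain function replacing the functor class; returns a
    # description string just to keep the sketch self-contained.
    hid_sizes = kwargs.get("hid_sizes", (32, 32))
    return f"mlp({obs_dim}, {act_dim}, hid={hid_sizes})"


# A function with a compatible signature satisfies the protocol.
builder: BuildDiscrimNet = build_mlp_discrim_net
```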
I would like to get rid of the functor class, but otherwise this is looking good.
Done! Sorry for the delay; I forgot that I still had a change to make.
LGTM, thanks
This PR includes a set of changes that make GAIL more flexible and easier to use. The DAgger PR (#128) and the BC PR (#125) should be merged before this one, since the branch I'm pulling from is based on both of them. New changes here:
- Makes `AdversarialTrainer` responsible for passing in the directory where `make_summary_writer` writes its logs.
- Moves `make_summary_writer` into `init_trainer` so that the resulting output directory can be used for logs, and not just TB summaries.
- Removes the `build_` methods from `DiscrimNet` (`build_{train_reward,test_reward,summaries,disc_loss}`, IIRC) and replaces them with one `build_graph` method. Also merges the two `build_` methods in `AdversarialTrainer`.
- Adds a `Logger` to `AdversarialTrainer` that records discriminator stats at each update. At the moment this only works for GAIL; the extension to AIRL is straightforward, but I don't have time to test it manually, so I haven't done it myself.

Originally I was going to refactor `DiscrimNet` entirely so that it passes around Keras models instead of construction functions and kwargs, but I don't have time to do so at the moment (I've added it to the wishlist in #31).
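The consolidation of the separate `build_` methods into one `build_graph` hook can be sketched as below. This is a simplified, hypothetical illustration: the real `DiscrimNet` builds TF tensors, and the placeholder loss/reward expressions here stand in for the actual GAIL formulas:

```python
class DiscrimNet:
    """Sketch: one build_graph() replaces the separate build_train_reward,
    build_test_reward, build_summaries and build_disc_loss methods, so
    subclasses implement a single construction hook."""

    def build_graph(self) -> None:
        # Build everything in one pass; attributes replace the values
        # previously produced by the separate build_ methods.
        self.disc_logits = self._build_discrim_logits()
        self.disc_loss = -self.disc_logits   # placeholder for the real loss
        self.train_reward = -self.disc_loss  # placeholder reward expressions
        self.test_reward = self.train_reward

    def _build_discrim_logits(self) -> float:
        raise NotImplementedError


class ConstantDiscrimNet(DiscrimNet):
    """Trivial subclass used to exercise the single hook."""

    def _build_discrim_logits(self) -> float:
        return 1.5


net = ConstantDiscrimNet()
net.build_graph()
```

One upside of this shape is that subclasses only override `_build_discrim_logits` (or its real-world equivalent), while the trainer calls one method instead of four in a required order.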