Miscellaneous GAIL improvements and refactoring #133

Merged
merged 23 commits into from
Jan 8, 2020

Conversation

@qxcv qxcv commented Nov 30, 2019

This PR includes a set of changes that make GAIL more flexible & easier to use. The DAgger PR (#128) and the BC PR (#125) should be merged before this one, since the branch I'm pulling from is based on both of those. New changes here:

  • Allows a custom discriminator model constructor to be passed to GAIL.
  • Allows control over where make_summary_writer writes its logs & makes AdversarialTrainer responsible for passing in that directory.
  • Moves the unique output directory generation code that was in make_summary_writer into init_trainer so that the resulting output directory can be used for logs, and not just TB summaries.
  • Removes the four build_ methods from DiscrimNet (build_{train_reward,test_reward,summaries,disc_loss}, IIRC) and replaces them with one build_graph method. Also merges the two build_ methods in AdversarialTrainer.
  • Adds an SB Logger to AdversarialTrainer that records discriminator stats at each update. At the moment this only works for GAIL; the extension to AIRL is straightforward, but I don't have time to test it manually, so I haven't done it myself.
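As a rough illustration of the first bullet, the constructor-plus-kwargs pattern could look like the sketch below. It is framework-free and does not reproduce the actual `imitation` API: `ToyDiscrim`, `build_linear_discrim`, and the `scale` parameter are hypothetical names, and only the argument names `build_discrim_net` and `build_discrim_net_kwargs` are taken from this PR.

```python
from typing import Any, Callable, Dict, Optional, Sequence

# Hypothetical builder signature: (obs, act, **kwargs) -> logit.
DiscrimBuilder = Callable[..., float]


def build_linear_discrim(
    obs: Sequence[float], act: Sequence[float], *, scale: float = 1.0
) -> float:
    """Toy stand-in for a discriminator network: a scaled sum of the inputs."""
    return scale * (sum(obs) + sum(act))


class ToyDiscrim:
    """Holds a user-supplied builder plus its kwargs (illustrative only)."""

    def __init__(
        self,
        build_discrim_net: DiscrimBuilder = build_linear_discrim,
        build_discrim_net_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._build = build_discrim_net
        # Copy so later mutation of the caller's dict has no effect here.
        self._kwargs = dict(build_discrim_net_kwargs or {})

    def logits(self, obs: Sequence[float], act: Sequence[float]) -> float:
        # Forward the stored kwargs to whichever builder was supplied.
        return self._build(obs, act, **self._kwargs)


default_discrim = ToyDiscrim()
scaled_discrim = ToyDiscrim(build_discrim_net_kwargs={"scale": 2.0})
```

The point of the pattern is that callers swap in a different builder, or tweak its keyword arguments, without subclassing the trainer.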

Originally I was going to refactor DiscrimNet entirely so that it passes around Keras models instead of construction functions and kwargs, but I don't have time to do so at the moment (I've added it to the wishlist in #31).

@qxcv qxcv changed the title GAIL reward net refactor Custom reward nets for GAIL Nov 30, 2019
@qxcv qxcv changed the title Custom reward nets for GAIL Miscellaneous GAIL improvements and refactoring Dec 1, 2019
codecov bot commented Dec 1, 2019

Codecov Report

Merging #133 into master will decrease coverage by <.01%.
The diff coverage is 100%.


@@            Coverage Diff             @@
##           master     #133      +/-   ##
==========================================
- Coverage   87.14%   87.13%   -0.01%     
==========================================
  Files          60       60              
  Lines        4332     4329       -3     
==========================================
- Hits         3775     3772       -3     
  Misses        557      557

Impacted Files                          Coverage Δ
src/imitation/rewards/discrim_net.py    97.9% <100%> (-0.05%) ⬇️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Member

@shwang shwang left a comment

Thanks for the refactoring / ease-of-use changes, Sam. I have several questions about what to keep and what to drop when we merge this with our custom logger in #135.

src/imitation/algorithms/adversarial.py (outdated)
mean_stats = {
    k: np.mean(v) for k, v in stat_dict_accum.items()
}
return mean_stats
Member

Do you have any code that uses the mean_stats return value? I think that in a later commit I can cleanly merge all of the mean-accumulating logic here with the mean-accumulating custom logger from #135.

The mean-accumulating logger gets to write the means to an SBLogger (whereas the means aren't logged here), so I'm wondering if you think we will still need to return mean_stats later on.
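For readers following along, here is a minimal sketch of the mean-accumulating logic being discussed. It only mirrors the shape of the snippet above: `mean_of_stats` and the sample stat names are hypothetical, and `statistics.fmean` stands in for `np.mean` so the example has no dependencies.

```python
from collections import defaultdict
from statistics import fmean
from typing import Dict, Iterable, List, Mapping


def mean_of_stats(stat_dicts: Iterable[Mapping[str, float]]) -> Dict[str, float]:
    """Accumulate per-update stats by key, then average each key."""
    stat_dict_accum: Dict[str, List[float]] = defaultdict(list)
    for stats in stat_dicts:
        for key, value in stats.items():
            stat_dict_accum[key].append(value)
    # Same dict comprehension as in the diff, with fmean in place of np.mean.
    return {k: fmean(v) for k, v in stat_dict_accum.items()}


# Two made-up discriminator updates (stat names are illustrative).
updates = [
    {"disc_loss": 0.5, "disc_acc": 0.5},
    {"disc_loss": 0.25, "disc_acc": 1.0},
]
mean_stats = mean_of_stats(updates)
```

If a logger already accumulates and writes these means, the return value becomes redundant, which is the question raised here.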

Member Author

I do, but I won't need it once #135 is merged. It's fine to take out.

Member Author

(I just added a comment saying you can remove it; I'll let you do the honours when you merge #135)

log_fmts = [
    make_output_format(s, disc_log_dir) for s in log_fmt_strs
]
self._disc_logger = Logger(disc_log_dir, log_fmts)
Member

FYI, when I merge with #135 later, I probably won't use this second logger. Instead, I'll use a "discrim" context.

This is fine by me for this PR though.
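The "discrim" context from #135 is not shown in this thread, so the following is only a guess at the general idea: a single logger whose keys are namespaced by a context manager, rather than a second logger instance. `PrefixLogger`, its `context()` method, and the `"/"` key separator are all assumptions made for illustration, not the design in #135 or the Stable Baselines logger API.

```python
import contextlib
from typing import Dict, Iterator, List


class PrefixLogger:
    """Single logger whose keys are namespaced by the active context."""

    def __init__(self) -> None:
        self.records: Dict[str, float] = {}
        self._prefixes: List[str] = []

    @contextlib.contextmanager
    def context(self, name: str) -> Iterator[None]:
        # Push the prefix for the duration of the with-block.
        self._prefixes.append(name)
        try:
            yield
        finally:
            self._prefixes.pop()

    def logkv(self, key: str, value: float) -> None:
        # Join any active prefixes with the key, e.g. "discrim/loss".
        full_key = "/".join(self._prefixes + [key])
        self.records[full_key] = value


logger = PrefixLogger()
logger.logkv("ep_reward", 1.5)
with logger.context("discrim"):
    logger.logkv("loss", 0.25)
```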

@qxcv qxcv force-pushed the gail-reward-net-refactor branch from 3e49a29 to e4be624 Compare January 1, 2020 23:22
@qxcv qxcv requested a review from shwang January 1, 2020 23:32
    construction of the discriminator network, and a `tf.Tensor`
    representing the desired discriminator logits.
build_discrim_net_kwargs: optional extra keyword arguments for
    `build_discrim_net()`.
Member

Now that we have the build_discrim_net_kwargs pattern, could we make build_mlp_discrim_net into a function rather than a class?

Not sure if pytype would be happy with Callable + arbitrary kwargs. (It looks like newer versions of Python will have a more precise Protocol type for defining callable types: https://stackoverflow.com/a/57840786/1091722)
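A sketch of what the suggestion might look like: a plain function replaces the functor class, and a `typing.Protocol` (available in the standard library from Python 3.8) describes "callable plus arbitrary kwargs". The names `DiscrimBuilder`, `build_toy_discrim_net`, `call_builder`, and `n_layers` are hypothetical, and whether pytype accepts this typing is not verified here.

```python
import functools
from typing import Any, Protocol, Sequence


class DiscrimBuilder(Protocol):
    """Callable taking observations and actions plus arbitrary kwargs."""

    def __call__(
        self, obs: Sequence[float], act: Sequence[float], **kwargs: Any
    ) -> float:
        ...


def build_toy_discrim_net(
    obs: Sequence[float], act: Sequence[float], *, n_layers: int = 2, **kwargs: Any
) -> float:
    """Plain function instead of a functor class; returns a fake 'logit'."""
    return float(n_layers) * (sum(obs) + sum(act))


def call_builder(
    builder: DiscrimBuilder,
    obs: Sequence[float],
    act: Sequence[float],
    **builder_kwargs: Any,
) -> float:
    # Configuration travels as kwargs rather than as functor attributes.
    return builder(obs, act, **builder_kwargs)


# functools.partial covers the case where configuration must be pre-bound.
three_layer = functools.partial(build_toy_discrim_net, n_layers=3)

logit_default = call_builder(build_toy_discrim_net, [1.0], [2.0])
logit_kwargs = call_builder(build_toy_discrim_net, [1.0], [2.0], n_layers=3)
logit_partial = call_builder(three_layer, [1.0], [2.0])
```

Passing kwargs separately and pre-binding them with `functools.partial` give the same result here, so the functor class adds nothing the function cannot do.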

Member

@shwang shwang left a comment

I'd like to get rid of the functor class, but otherwise this is looking good.

qxcv commented Jan 8, 2020

Done! Sorry for the delay, I forgot that I still had a change to make.

@qxcv qxcv requested a review from shwang January 8, 2020 00:47
Member

@shwang shwang left a comment

LGTM, thanks

@shwang shwang merged commit 7f01e0e into master Jan 8, 2020
@shwang shwang deleted the gail-reward-net-refactor branch January 8, 2020 01:33