
Releases: NTT123/pax

v0.2.4

04 Sep 14:16
61029aa

Add documentation for Pax.

The website: https://pax.readthedocs.io/en/main/

v0.2.3

03 Sep 05:24
  • Fix a bug where a constant random key was reused.
  • New nn.BatchNorm1D and nn.BatchNorm2D modules.
  • New utils.RngSeq module.
  • Documentation site at https://pax.readthedocs.io/en/main
  • New DCGAN example.
  • Add (experimental) mixed-precision support.
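The RngSeq idea can be sketched as a stateful sequence that hands out a fresh key on every call, which is exactly what avoids the constant-key bug fixed in this release. This is a hypothetical pure-Python illustration; the class name RngSeq comes from the release notes, but the fields and method shown here are stand-ins, not Pax's actual implementation.

```python
# Hypothetical sketch of the RngSeq idea: a stateful sequence of RNG keys.
# The internals below are illustrative, not Pax's real implementation.

class RngSeq:
    """Yield a fresh pseudo-random key on every call, from one seed."""

    def __init__(self, seed: int):
        self._seed = seed
        self._counter = 0

    def next_key(self) -> tuple:
        # Derive a new key from (seed, counter): no two calls return
        # the same key, unlike reusing one constant key everywhere.
        key = (self._seed, self._counter)
        self._counter += 1
        return key

seq = RngSeq(seed=42)
k0 = seq.next_key()
k1 = seq.next_key()
assert k0 != k1  # every call yields a distinct key
```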

v0.2.2

01 Sep 06:55
fccbe55

New features:

  • Use a single dictionary to manage all Pax-tree-related fields.
  • New MultiHeadAttention module, plus an example training a transformer language model on TPU.
  • New hk_init method, allowing a converted Haiku module to delay its parameter initialization.
  • New summary method, returning a string representation of the module tree.
  • _scan_fields and deep_scan scan fields/modules for potential bugs.
  • new VAE training example.
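The summary method described above can be pictured as a recursive walk over the module tree that renders one indented line per module. The Module class, children field, and indentation style below are all illustrative stand-ins, assumed for this sketch rather than taken from Pax's source.

```python
# Illustrative sketch (not Pax's actual code) of a ``summary``-style
# method that renders a module tree as an indented string.

class Module:
    def __init__(self, name: str):
        self.name = name
        self.children = []  # child modules, forming a tree

    def summary(self, indent: int = 0) -> str:
        # One line for this module, then recurse into children
        # with a deeper indent.
        lines = [" " * indent + self.name]
        for child in self.children:
            lines.append(child.summary(indent + 2))
        return "\n".join(lines)

net = Module("Net")
net.children = [Module("Linear"), Module("BatchNorm")]
print(net.summary())
# Net
#   Linear
#   BatchNorm
```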

v0.2.0

29 Aug 11:12
57dbb8d
  • The new version rewrites how a pax.Module manages its states and parameters. The previous version relied on type annotations, which is not ideal because type annotations are optional in Python. The new version uses a dictionary to record which fields are states, parameters, modules, or state/parameter/module subtrees.

  • All subclasses of pax.Module now have to call super().__init__() at initialization.

  • The new version also allows a Haiku module to delay its parameter initialization until it is executed for the first time.
    This is similar to how dm-haiku initialization works. hk_init is a helper method that initializes a Haiku module.

  • The freeze method converts all trainable parameters to non-trainable states.

  • Add new Haiku modules such as conv_1d_transpose, conv_2d_transpose, avg_pool, max_pool, etc.
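Two of the behaviors above can be sketched together: every pax.Module subclass must call super().__init__() so the base class can set up its field registry, and freeze moves trainable parameters into non-trainable states. The dictionaries and the Linear subclass below are hypothetical stand-ins assumed for illustration; only the super().__init__() requirement and the freeze semantics come from the release notes.

```python
# Hypothetical sketch of the v0.2.0 rules: subclasses must call
# ``super().__init__()``, and ``freeze`` turns trainable parameters
# into non-trainable states. Internals are illustrative, not Pax's.

class Module:
    def __init__(self):
        # The base constructor sets up the field registries; skipping
        # super().__init__() would leave a subclass without them.
        self._parameters = {}  # trainable
        self._states = {}      # non-trainable

    def freeze(self):
        # Move every trainable parameter into the non-trainable states.
        self._states.update(self._parameters)
        self._parameters.clear()
        return self

class Linear(Module):
    def __init__(self):
        super().__init__()  # required for all pax.Module subclasses
        self._parameters["weight"] = 1.0

m = Linear().freeze()
assert m._parameters == {} and "weight" in m._states
```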

v0.1.0: First release

28 Aug 09:39

This release includes:

  • Basic module system: pax.Module.
  • Basic leaf nodes: only two node types are supported, pax.State and pax.Parameter, both subtypes of pax.tree.Leaf.
  • Basic random state: a global rng key is stored at pax.rng.state._rng_key; call pax.next_rng_key() to generate a new rng key.
  • Basic optimizer: pax.Optimizer.
  • nn modules: nn.Linear, nn.BatchNorm, nn.LayerNorm, nn.Conv1d, nn.Conv2d.
  • Convert a dm-haiku module to a pax.Module with pax.haiku.from_haiku.
  • Convert an optax optimizer to a pax.Optimizer with pax.optim.from_optax.
  • An example training an RNN language model on TPU.
  • An example training an MNIST classifier with checkpointing.
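The two leaf kinds listed above, pax.State and pax.Parameter as subtypes of pax.tree.Leaf, boil down to tagging values as trainable or not. The sketch below uses stand-in class names mirroring the release notes; the trainable flag is an assumption made for illustration, not Pax's actual attribute.

```python
# Illustrative stand-ins (not the real Pax types) for the two leaf
# kinds in v0.1.0: Parameter (trainable) vs State (non-trainable),
# both subtypes of a common Leaf base.

class Leaf:
    def __init__(self, value):
        self.value = value

class Parameter(Leaf):
    trainable = True   # updated by the optimizer

class State(Leaf):
    trainable = False  # carried along but never optimized

weight = Parameter(0.5)
step_count = State(0)
assert weight.trainable and not step_count.trainable
```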