
Penzai 0.1.3 (+ V2 NN API!)

danieldjohnson released this on 28 Jun 01:57

New features:

  • V2 neural network API (penzai.experimental.v2)
    • A redesign of Penzai's neural network system, which introduces first-class mutable state and variable sharing, and removes boilerplate.
    • The V2 API documentation describes the differences from the original system and explains how to migrate.
    • We plan to replace the original neural network system with this V2 API in Penzai release 0.2.0.
  • Llama, Mistral, and GPT-NeoX / Pythia support
    • The pretrained transformer implementation has been generalized, and now supports Llama, Mistral, and GPT-NeoX / Pythia pretrained models.
    • (This implementation is specific to the V2 neural network API.)
  • Other features:
    • New LayerStack combinator, which uses jax.lax.scan to efficiently repeat layers with the same structure
    • Named arrays can now be updated using .at[...].set(...) operations. For now, only positional indexing is supported (with broadcasting over named axes).
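To illustrate the technique LayerStack is described as using, here is a minimal plain-JAX sketch. It does not use Penzai's LayerStack API. The parameters of several layers with identical structure are stacked along a leading "layer" axis, and jax.lax.scan applies one layer per step, so the layer body is traced once instead of once per layer. The helpers `dense_layer` and `repeated_layers` are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def dense_layer(x, params):
    # One hypothetical layer: affine transform followed by a ReLU.
    w, b = params
    return jax.nn.relu(x @ w + b)


def repeated_layers(x, stacked_params):
    # stacked_params is a tuple (ws, bs) whose arrays carry a leading layer axis.
    # scan slices that axis, so every layer must have the same shapes and dtypes.
    def step(carry, layer_params):
        return dense_layer(carry, layer_params), None

    out, _ = jax.lax.scan(step, x, stacked_params)
    return out


num_layers, width = 4, 8
keys = jax.random.split(jax.random.PRNGKey(0), num_layers)
ws = jnp.stack([jax.random.normal(k, (width, width)) * 0.1 for k in keys])
bs = jnp.zeros((num_layers, width))
x = jnp.ones((2, width))

y = repeated_layers(x, (ws, bs))
```

The scanned version computes the same result as an unrolled Python loop over the layers. The difference is that the compiled program contains a single loop body rather than one copy per layer.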
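The `.at[...].set(...)` syntax for named arrays matches JAX's functional-update syntax for ordinary arrays. The sketch below shows that pattern on a plain jax.numpy array, not on a Penzai NamedArray. It assumes, without confirmation from these notes, that named arrays follow the same non-mutating semantics, where the update returns a new array and leaves the original unchanged.

```python
import jax.numpy as jnp

x = jnp.zeros((3, 4))

# Positional index: set every element of row 1 (the scalar broadcasts over the row).
y = x.at[1].set(5.0)

# Positional index: set a single element of the updated copy.
z = y.at[0, 2].set(-1.0)

# x is never modified in place; each .at[...].set(...) call returns a new array.
```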

Bug fixes and improvements:

  • Fixed issue where unit test discovery was not picking up tests in subdirectories (#38)
  • Fixed issue where adding a NamedArray to a JAX array would not correctly lift the JAX array to a NamedArray (#37)

Documentation changes:

  • Added documentation of the V2 API, along with instructions on how to migrate.
  • Added a "How-To Guide" for common tasks (V2 API only)