A redesign of Penzai's neural network system (the "V2" API), which introduces first-class mutable state and variable sharing and removes boilerplate.
You can read more about the differences, and how to migrate, in the V2 API documentation.
We plan to replace the original neural network system with this V2 API in Penzai release 0.2.0.
Llama, Mistral, and GPT-NeoX / Pythia support
The pretrained transformer implementation has been generalized and now supports Llama, Mistral, and GPT-NeoX / Pythia pretrained models.
(This implementation is specific to the V2 neural network API.)
Other features:
New LayerStack combinator, which uses jax.lax.scan to efficiently repeat layers that share the same structure.
Named arrays can now be updated using .at[...].set(...) operations. For now, only positional indexing is supported (with broadcasting over named axes).
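The idea behind a scan-based layer stack can be sketched in plain JAX. This is not Penzai's LayerStack implementation; it is a minimal sketch that assumes a hypothetical toy `layer` function and stores the parameters of every layer stacked along a leading "layers" axis. `jax.lax.scan` then slices that axis one layer at a time, so the layer body is traced once no matter how deep the stack is.

```python
import jax
import jax.numpy as jnp


def layer(h, params):
    # Hypothetical toy layer (affine map + tanh), used only for illustration.
    w, b = params
    return jnp.tanh(h @ w + b)


def init_stacked_params(key, num_layers, width):
    # Parameters for all layers, stacked along a leading "layers" axis.
    kw, kb = jax.random.split(key)
    w = jax.random.normal(kw, (num_layers, width, width)) * 0.1
    b = jax.random.normal(kb, (num_layers, width)) * 0.1
    return w, b


def apply_stack_scan(stacked_params, h):
    # lax.scan slices the leading axis of every leaf in stacked_params
    # and threads the activations through as the loop carry.
    def body(carry, params):
        return layer(carry, params), None

    out, _ = jax.lax.scan(body, h, stacked_params)
    return out


def apply_stack_loop(stacked_params, h):
    # Reference version: an unrolled Python loop over the same layers.
    w, b = stacked_params
    for i in range(w.shape[0]):
        h = layer(h, (w[i], b[i]))
    return h


params = init_stacked_params(jax.random.PRNGKey(0), num_layers=4, width=8)
x = jnp.ones((2, 8))
y_scan = apply_stack_scan(params, x)
y_loop = apply_stack_loop(params, x)
print(jnp.allclose(y_scan, y_loop, atol=1e-5))
```

Both versions compute the same result; the scan version keeps compile time roughly constant as `num_layers` grows, which is why a scan is a natural fit for repeating identically structured layers.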
Bug fixes and improvements:
Fixed an issue where unit test discovery did not pick up tests in subdirectories (#38).
Fixed an issue where adding a NamedArray to a JAX array did not correctly lift the JAX array to a NamedArray (#37).
Documentation changes:
Added documentation of the V2 API, along with instructions on how to migrate.
Added a "How-To Guide" for common tasks (V2 API only).