Skip to content

Commit

Permalink
update README
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 1, 2025
1 parent 862210b commit ae744f0
Showing 1 changed file with 51 additions and 2 deletions.
53 changes: 51 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@


# Physical units and unit-aware mathematical system in JAX
# ``BrainUnit``: physical units and unit-aware mathematical system for brain dynamics and AI4Science

<p align="center">
<img alt="Header image of brainunit." src="https://github.com/chaobrain/brainunit/blob/main/docs/_static/brainunit.png" width=50%>
Expand All @@ -20,7 +20,56 @@
</p>


[``brainunit``](https://github.com/chaor/brainunit) provides physical units and unit-aware mathematical system in JAX for brain dynamics and AI4Science
## Motivation


[``brainunit``](https://github.com/chaor/brainunit) provides physical units and unit-aware mathematical system in JAX for brain dynamics and AI4Science.

It is initially designed to enable unit-aware computations in brain dynamics modeling (see our [ecosystem](https://ecosystem-for-brain-dynamics.readthedocs.io/)).

However, its features and capacities can be applied to general domains for scientific computing and AI for science.
We also provide ample examples and tutorials to help users integrate ``brainunit`` into their projects
(see [Unit-aware computation ecosystem](#unit-aware-computation-ecosystem) in the below).


## Features


The uniqueness of ``Brainunit`` lies in that it brings physical units handling and AI-driven computation together in a seamless way:

- It provides over 2,000 commonly used physical units and constants.
- It implements over 500 unit-aware mathematical functions.
- Its physical units and unit-aware functions are fully compatible with JAX, including autograd, JIT, vecterization, parallelization, and others.


A quick example:

```python

import brainunit as u

# Define a physical quantity
x = 3.0 * u.meter
x
# [out] 3. * meter

# autograd
f = lambda x: x ** 3
u.autograd.grad(f)(x)
# [out] 27. * meter2


# JIT
import jax
jax.jit(f)(x)
# [out] 27. * klitre

# vmap
jax.vmap(f)(u.math.arange(0. * u.mV, 10. * u.mV, 1. * u.mV))
# [out] ArrayImpl([ 0., 1., 8., 27., 64., 125., 216., 343., 512., 729.],
# dtype=float32) * mvolt3
```



## Installation
Expand Down

0 comments on commit ae744f0

Please sign in to comment.