This is not an officially supported Google product.
Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. Its name comes from the contraction of Object and JAX -- a popular high-performance framework. Objax is designed by researchers for researchers with a focus on simplicity and understandability. Its users should be able to easily read, understand, extend, and modify it to fit their needs.
This is the developer repository of Objax. It contains very little user documentation; for the full documentation, go to objax.readthedocs.io.
You can find READMEs in the subdirectories of this project, for example:
You can install Objax using pip as follows:
pip install --upgrade objax
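To confirm what pip actually installed, a small sketch using only the standard library (`importlib.metadata`, Python 3.8+) can print the installed versions of Objax and its JAX dependencies. The helper name `installed_version` is hypothetical.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


# Objax depends on jax and jaxlib; their versions matter for GPU support.
for pkg in ("objax", "jax", "jaxlib"):
    print(pkg, installed_version(pkg))
```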
Objax supports GPUs but assumes that you already have a compatible version of CUDA installed. Here are the extra steps required to install a CUDA-enabled jaxlib (jaxlib releases require CUDA 11.2 or newer):
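RELEASE_URL="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
# Reinstall jaxlib with CUDA support, matching the version of jax already installed
JAX_VERSION=$(python3 -c 'import jax; print(jax.__version__)')
pip uninstall -y jaxlib
# Quote the requirement so shells such as zsh do not try to expand the brackets
pip install -f "$RELEASE_URL" "jax[cuda]==$JAX_VERSION"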
For more installation options, see https://github.com/google/jax#pip-installation-gpu-cuda
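After reinstalling jaxlib, a quick sketch like the following can check whether JAX picked up the GPU. It only uses `jax.default_backend()` and `jax.devices()`; on a machine without a working CUDA setup, the backend is expected to fall back to the CPU.

```python
import jax

# Typically "gpu" when the CUDA-enabled jaxlib found a usable GPU, otherwise "cpu".
backend = jax.default_backend()
print("Default backend:", backend)

# List every device visible on the default backend.
for device in jax.devices():
    print(device.platform, device)
```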
Here is a useful option:
# Prevent JAX from taking the whole GPU memory
# (useful if you want to run several programs on a single GPU)
export XLA_PYTHON_CLIENT_PREALLOCATE=false
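If you prefer to configure this from inside a script rather than the shell, a sketch is to set the same environment variable from Python. This assumes it runs before JAX initializes its GPU backend, so the safest place is before the first `import jax` in the process. `XLA_PYTHON_CLIENT_MEM_FRACTION`, shown commented out, is a JAX setting that caps the preallocated fraction instead of disabling preallocation.

```python
import os

# Must be set before JAX initializes the GPU backend (i.e. before importing jax).
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Alternative: keep preallocation but limit it to a fraction of GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

import jax  # noqa: E402  (imported after the environment is configured)

print("Devices:", jax.device_count())
```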
You can test your installation by running the code below:
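import jax
import objax

# Counts devices on JAX's default backend (the GPUs if CUDA is set up, otherwise the CPU)
print(f'Number of devices: {jax.device_count()}')

# Dense layer: 4 input features -> 5 output features
x = objax.random.normal(shape=(100, 4))
m = objax.nn.Linear(nin=4, nout=5)
print('Matrix product shape', m(x).shape)  # (100, 5)

# 2D convolution on NCHW input: 3 -> 4 channels, 3x3 kernel, 'SAME' padding by default
x = objax.random.normal(shape=(100, 3, 32, 32))
m = objax.nn.Conv2D(nin=3, nout=4, k=3)
print('Conv2D return shape', m(x).shape)  # (100, 4, 32, 32)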
If you get errors running this code with CUDA, your CUDA or cuDNN installation most likely has issues.
To run the examples, clone the code repository and change into the examples directory:
git clone https://github.com/google/objax.git
cd objax/examples
To cite this repository:
@software{objax2020github,
author = {{Objax Developers}},
title = {{Objax}},
url = {https://github.com/google/objax},
version = {1.2.0},
year = {2020},
}
Here is information about development setup and a guide on adding new code.