Among the innovations that power the popular open source TensorFlow machine learning platform are automatic differentiation (Autograd) and XLA (Accelerated Linear Algebra), an optimizing compiler for deep learning. Google JAX is another project that brings these two technologies together, and it offers significant benefits in speed and performance. When running on a GPU or TPU, JAX can take the place of NumPy in programs that call it, and run them faster. In addition, using JAX for neural networks can make it easier to add new functionality than extending a large framework like TensorFlow.
In this article, we give an overview of Google JAX, including its advantages, its limitations, and how to install it, and then try it out in Colab.
Definition of Autograd
Autograd is an automatic differentiation engine that started as a research project in Ryan Adams’ Harvard Intelligent Probabilistic Systems (HIPS) group. The engine is currently maintained but no longer actively developed. Instead, its developers are working on Google JAX, which combines Autograd with additional features such as XLA JIT compilation. Autograd automatically differentiates native Python and NumPy code, and its main application area is gradient-based optimization.
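Autograd’s own package is not used here. As a minimal sketch of the same idea, the example below uses jax.grad, JAX’s Autograd-style transformation, to differentiate an ordinary Python function written with NumPy-style calls. The function name tanh_loss and the input values are hypothetical, and the analytic derivative is computed alongside for comparison.

```python
import jax
import jax.numpy as jnp

# A plain Python function written with NumPy-style calls (hypothetical example).
def tanh_loss(w, x):
    return jnp.sum(jnp.tanh(w * x) ** 2)

# jax.grad returns a new function that computes d(tanh_loss)/dw,
# i.e. the gradient with respect to the first argument by default.
grad_loss = jax.grad(tanh_loss)

w = jnp.array([0.5, -1.0, 2.0])
x = jnp.array([1.0, 2.0, 3.0])
g = grad_loss(w, x)

# Analytic check: d/dw tanh(w*x)^2 = 2 * tanh(w*x) * (1 - tanh(w*x)^2) * x
t = jnp.tanh(w * x)
expected = 2 * t * (1 - t ** 2) * x
print(g)
print(expected)
```

In gradient-based optimization, g is what you would feed to an update rule such as w - learning_rate * g.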
TensorFlow’s tf.GradientTape API is based on a similar concept to Autograd, but its implementation is different. Autograd is written entirely in Python and computes gradients directly from the function, whereas TensorFlow’s gradient tape is written in C++ with a thin Python wrapper. TensorFlow uses backpropagation to compute the loss, estimate the gradient (the slope) of the loss, and determine the best next step.
Definition of XLA
XLA is a domain-specific compiler for linear algebra that was originally developed for TensorFlow. According to the TensorFlow documentation, XLA can accelerate TensorFlow models, improving speed and memory usage, with potentially no source code changes. For example, in a 2020 Google BERT MLPerf benchmark submission, eight Volta V100 GPUs using XLA achieved roughly a 7x performance improvement and a 5x batch-size improvement.
XLA compiles a TensorFlow graph into a sequence of computation kernels generated specifically for the given model. Because these kernels are unique to the model, they can exploit model-specific information for optimization. Within TensorFlow, XLA is also called the JIT (just-in-time) compiler. You can enable it with a flag in the @tf.function Python decorator, like this:
@tf.function(jit_compile=True)
You can also enable XLA in TensorFlow by setting the TF_XLA_FLAGS environment variable or by running the standalone tfcompile tool. Besides TensorFlow, XLA programs can be generated by Google JAX, Julia, PyTorch, and Nx.
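In JAX, the counterpart to @tf.function(jit_compile=True) is jax.jit, which traces a Python function and compiles it with XLA. The sketch below uses a hypothetical selu function. It compares the eager and compiled results and prints the start of the program JAX hands to XLA. The exact text of that lowered program varies between JAX versions, so treat it as illustrative only.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Element-wise scaled exponential linear unit (hypothetical example function).
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit traces selu on the first call and compiles it with XLA.
selu_jit = jax.jit(selu)

x = jnp.arange(-2.0, 3.0)
y_eager = selu(x)
y_jit = selu_jit(x)

# Inspect the lowered program JAX passes to XLA (format depends on the JAX version).
lowered_text = jax.jit(selu).lower(x).as_text()
print(lowered_text[:400])
```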
Getting Started with Google JAX
The JAX Quickstart I tried in Colab uses a GPU by default. You can choose a TPU instead if you prefer, but free monthly TPU usage is limited. In addition, if you want to use a Colab TPU with Google JAX, you need to initialize it separately.
To open the quickstart, click the Open in Colab button at the top of the Parallel Evaluation in JAX documentation page. This switches you to a live notebook environment. Click the Connect button in the notebook to connect to a hosted runtime. Running the quickstart on a GPU made it clear how much JAX can accelerate matrix and linear algebra operations. Further into the notebook, I saw JIT-accelerated times measured in microseconds. As you read the code, much of it may look familiar, because most of it expresses functions commonly used in deep learning.
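The quickstart’s timings follow a common pattern, sketched here under stated assumptions rather than copied from the notebook: run the function once so that compilation happens outside the timed region, then time later calls, calling block_until_ready() because JAX dispatches work asynchronously. The layer function, the matrix sizes, and the repeat count are arbitrary choices, and the numbers you get depend entirely on your hardware.

```python
import time

import jax
import jax.numpy as jnp

def layer(w, x):
    # A dense layer with a tanh activation, a typical deep learning building block.
    return jnp.tanh(x @ w)

layer_jit = jax.jit(layer)

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (512, 512))
x = jax.random.normal(key_x, (256, 512))

def time_call(fn, *args, repeats=10):
    fn(*args).block_until_ready()  # warm-up; for jitted functions this also compiles
    start = time.perf_counter()
    for _ in range(repeats):
        out = fn(*args)
    out.block_until_ready()  # wait for asynchronously dispatched work to finish
    return (time.perf_counter() - start) / repeats, out

eager_s, eager_out = time_call(layer, w, x)
jit_s, jit_out = time_call(layer_jit, w, x)
print(f"eager: {eager_s * 1e6:.1f} us/call, jit: {jit_s * 1e6:.1f} us/call")
```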
Installing JAX
How you install JAX depends on your operating system, on whether you target a CPU or a GPU, and on the TPU version if you use one. For the CPU, it’s simple. For example, to run JAX on your laptop, enter:
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
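To check that a CPU install like the one above works, you can import JAX, list the devices it sees, and run a tiny computation. This is a minimal smoke test, not an official verification procedure. On a CPU-only install, the device list should show a CPU device.

```python
import jax
import jax.numpy as jnp

print(jax.__version__)

# Devices JAX can use: CPU, GPU, or TPU depending on the installed wheel.
devices = jax.devices()
print(devices)

# Run a tiny computation on the default device.
total = jnp.arange(10).sum()
print(total)
```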
For a GPU, CUDA and cuDNN must be installed, along with a compatible NVIDIA driver, and fairly recent versions of both are required. On Linux with recent versions of CUDA and cuDNN, you can install the pre-built CUDA-compatible wheel. Otherwise, you have to build from source. JAX also provides pre-built wheels for Google Cloud TPUs. Cloud TPUs are newer than Colab TPUs and not backward compatible with them, but the Colab environment already includes JAX and the proper TPU support.
JAX API
The JAX API has three layers. At the top level, JAX implements a mirror of the NumPy API, jax.numpy. Almost anything you can do with numpy, you can do with jax.numpy. The main limitation of jax.numpy is that, unlike NumPy arrays, JAX arrays are immutable: once created, their contents cannot be changed.
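Here is a short sketch of both points. jax.numpy functions take the same names and arguments as their NumPy counterparts, but assigning to an element of a JAX array raises a TypeError instead of modifying the array. The array values are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

# Same function names and semantics as NumPy.
np_result = np.sin(x_np) @ np.cos(x_np)
jnp_result = jnp.sin(x_jnp) @ jnp.cos(x_jnp)
print(np_result, jnp_result)

# NumPy arrays can be modified in place...
x_np[0] = 10.0

# ...but JAX arrays are immutable, so item assignment raises TypeError.
try:
    x_jnp[0] = 10.0
    mutated = True
except TypeError as err:
    mutated = False
    print("TypeError:", err)
```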
The middle layer of the JAX API is jax.lax, which is stricter than the NumPy layer and often more powerful. All jax.numpy operations are ultimately expressed in terms of functions defined in jax.lax. Whereas jax.numpy implicitly promotes arguments to allow operations between mixed data types, jax.lax does not; instead, it provides explicit promotion functions. The lowest layer of the API is XLA. All jax.lax operations are Python wrappers for operations in XLA, and every JAX operation is ultimately expressed in terms of these underlying XLA operations, which is what makes JIT compilation possible.
Limitations of JAX
JAX transformations and compilation are designed to work only on functionally pure Python functions. If a function has side effects, even something as simple as a print() statement, running it multiple times can produce different side effects: on subsequent runs, print() may print something different or nothing at all. Another limitation of JAX is that in-place mutation is not allowed, because arrays are immutable. You can work around this constraint with out-of-place array updates:
updated_array = jax_array.at[1, :].set(1.0)
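Both limitations can be seen in a few lines. In the sketch below, the hypothetical scale function calls print() and appends to a Python list. Under jax.jit these run only while JAX traces the function, so they fire on the first call and again when the input shape changes, but not on a cached call. The second half uses the .at property for out-of-place updates, which leaves the original array untouched.

```python
import jax
import jax.numpy as jnp

trace_count = []  # a Python-level side effect we can inspect afterwards

@jax.jit
def scale(x):
    print("running Python body")  # executes during tracing, not on every call
    trace_count.append(1)
    return 2.0 * x

scale(jnp.ones(3))   # first call: traced and compiled, so the print fires
scale(jnp.zeros(3))  # same shape and dtype: cached, so the print does not fire
scale(jnp.ones(4))   # new shape: retraced, so the print fires again

# Out-of-place updates: each .at[...] operation returns a new array.
jax_array = jnp.zeros((3, 3))
updated_array = jax_array.at[1, :].set(1.0)
incremented = updated_array.at[0, 0].add(5.0)
print(jax_array, updated_array, incremented, sep="\n")
```

If you need output on every call of a jitted function, jax.debug.print is intended for that purpose.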
Also, NumPy defaults to double precision (float64), whereas JAX defaults to single precision (float32). If you truly need double precision, you can enable it with the jax_enable_x64 configuration option.
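Here is a minimal sketch of the precision difference, which sets the jax_enable_x64 flag through jax.config.update. The JAX documentation recommends setting this flag at startup, or through the JAX_ENABLE_X64 environment variable, rather than partway through a program. The mid-script switch below is for illustration only.

```python
import numpy as np
import jax
import jax.numpy as jnp

numpy_default = np.zeros(3).dtype  # float64
jax_default = jnp.zeros(3).dtype   # float32 unless x64 mode is already enabled

# Enable 64-bit types. Ideally do this at program startup, before creating arrays.
jax.config.update("jax_enable_x64", True)
jax_x64 = jnp.zeros(3).dtype       # float64 once x64 mode is on

print(numpy_default, jax_default, jax_x64)
```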
Using JAX for Accelerated Neural Networks
By now it should be clear that you could implement accelerated neural networks in JAX. On the other hand, why reinvent the wheel? Google Research groups and DeepMind have open-sourced several neural network libraries based on JAX:
- Flax is a complete library for training neural networks, along with examples and usage guides.
- Haiku is used for neural network modules,
- Optax is used for gradient processing and optimization,
- RLax is used for reinforcement learning (RL) algorithms, and
- Chex is used for writing reliable code and tests.
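To show what these libraries take care of, here is a deliberately small neural network written in plain JAX, with no library. It is a one-hidden-layer MLP fit to a toy sine-regression problem using jax.grad, jax.jit, and hand-written gradient descent. The layer sizes, learning rate, step count, and data are arbitrary assumptions. Flax, Haiku, and Optax provide structured replacements for the parameter handling and update rule shown here.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary choice for this toy problem

def init_params(key, sizes=(1, 16, 1)):
    # One (weights, bias) pair per layer, with small random weights.
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) * 0.5, jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]

def mlp(params, x):
    *hidden, (w_out, b_out) = params
    for w, b in hidden:
        x = jnp.tanh(x @ w + b)
    return x @ w_out + b_out

def mse_loss(params, x, y):
    return jnp.mean((mlp(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    grads = jax.grad(mse_loss)(params, x, y)
    # params is a pytree (a list of tuples); tree_map updates every array in it.
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

params = init_params(jax.random.PRNGKey(0))
x = jnp.linspace(-2.0, 2.0, 64).reshape(-1, 1)
y = jnp.sin(x)

initial_loss = float(mse_loss(params, x, y))
for _ in range(500):
    params = sgd_step(params, x, y)
final_loss = float(mse_loss(params, x, y))
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```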
Learn more about JAX
In addition to the JAX quickstart, JAX has a number of tutorials that you can run in Colab. The first tutorial shows how to use jax.numpy functions, the grad and value_and_grad functions, and the @jit decorator. Later tutorials go deeper into JIT compilation, and in the last tutorial you learn how to compile and automatically partition functions in both single-host and multi-host environments.
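The core functions from the first tutorial fit in a few lines. In this sketch the function f is hypothetical. value_and_grad returns a function’s value and its gradient in one call, argnums selects which arguments to differentiate, and grad returns only the derivative. Mathematically, f(3, 2) = 18, and its partial derivatives are 12 and 9.

```python
import jax

@jax.jit
def f(a, b):
    # Hypothetical scalar function: f(a, b) = a^2 * b
    return a ** 2 * b

# value_and_grad returns f's value together with its gradient;
# argnums=(0, 1) requests the partial derivatives with respect to both a and b.
value, (df_da, df_db) = jax.value_and_grad(f, argnums=(0, 1))(3.0, 2.0)

# grad alone returns only the derivative (with respect to the first argument by default).
df_da_only = jax.grad(f)(3.0, 2.0)
print(value, df_da, df_db, df_da_only)
```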
You can (and should) also read the JAX reference documentation, starting with the FAQ, and run the advanced tutorials, starting with the Autodiff Cookbook, on Colab. Finally, the API documentation, starting with the main JAX package, is also worth reading.