JAX Tutorial

Agenda

  1. Introduction.
  2. General array and algebra routines (NumPy-like).
    1. Array creation routines.
    2. Array indexing.
    3. Array mutation.
  3. Random number generators.
  4. Device placement management.
  5. Composable transformations.
    1. jit
    2. vmap
    3. pmap
    4. checkify
  6. Control flow and static arguments.
  7. Tree utils.
  8. Custom types in JAX.
  9. Custom primitives.
  10. Common gotchas.
  11. Debugging (id_print).

1 Introduction

What tools did our ancestors use?

What is NumPy?

NB NumPy is about 20 years old! It has had a great impact on Python (see the buffer protocol or the @ operator) and on an entire area of scientific computing (see Matplotlib or SciPy).

What is JAX?

What alternatives exist?

What is the future?

2 Array Basics

In general, one can install NumPy and JAX with pip as shown below (for advanced installation options, such as GPU or TPU support, see the documentation).

From here onwards we use the imports below.
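A minimal setup might look like this (the pip command covers only the CPU build):

```python
# CPU-only installation (see the JAX documentation for GPU/TPU builds):
#   pip install numpy jax
import numpy as np

import jax
import jax.numpy as jnp
```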

2.1 Array Creation Basics

The simplest way to create an array is to pass a Python list to np.array.

We can do the same in JAX! We need only to replace np with jnp.
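For instance, a small vector and a matrix can be created like this (the values are arbitrary):

```python
# NumPy
a = np.array([1.0, 2.0, 3.0])
m = np.eye(3)

# JAX: the same calls with np replaced by jnp
a_jax = jnp.array([1.0, 2.0, 3.0])
m_jax = jnp.eye(3)

print(type(a), type(a_jax))  # numpy.ndarray vs. a JAX device array
```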

2.2 Array Algebra Basics

Again, we can do the same in JAX as follows.

Typical examples are Gaussian elimination (solving a linear system) and the Laplace operator with Dirichlet boundary conditions; a small sketch of the former follows.
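A minimal sketch, assuming we simply call the built-in solver rather than writing elimination by hand:

```python
# A small symmetric system A x = b (the values are arbitrary).
A = np.array([[4.0, 1.0],
              [1.0, 3.0]])
b = np.array([1.0, 2.0])

x_np = np.linalg.solve(A, b)                          # NumPy
x_jax = jnp.linalg.solve(jnp.array(A), jnp.array(b))  # JAX, same API

print(np.allclose(x_np, np.asarray(x_jax), atol=1e-6))  # True
```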

A common difference is the default data type: NumPy uses float64 by default, whereas JAX uses float32 unless 64-bit mode is enabled.
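A quick way to see this difference:

```python
print(np.array([1.0, 2.0]).dtype)   # float64: NumPy default
print(jnp.array([1.0, 2.0]).dtype)  # float32: JAX default
# 64-bit mode can be enabled explicitly:
#   jax.config.update("jax_enable_x64", True)
```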

2.3 Array Indexing Basics
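Reads use the familiar NumPy-style indexing, but JAX arrays are immutable, so in-place updates go through the .at syntax; a minimal sketch:

```python
x_np = np.arange(5)
x_np[0] = 10                    # NumPy arrays can be mutated in place

x_jax = jnp.arange(5)
# x_jax[0] = 10 would raise an error: JAX arrays are immutable.
x_jax = x_jax.at[0].set(10)     # returns a new, updated array instead

print(x_jax[1:3], x_jax[::-1])  # slicing works as in NumPy
```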

3 (Pseudo) Random Number Generators aka (P)RNG

Let's draw samples from a normal distribution with NumPy.
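For example, with the global (hidden) state of the legacy API:

```python
np.random.seed(42)  # global, hidden RNG state
samples = np.random.normal(loc=0.0, scale=1.0, size=5)
print(samples)
```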

This does not work in JAX. The issue is that an RNG has internal state, while at the same time JAX expects every computation to be pure (free of side effects).

Consider a linear congruential generator (LCG), which is defined by

$$ x_{n + 1} = a x_n + b \quad (\mathrm{mod}\ M). $$

Its parameters are the numbers $a$, $b$, and the modulus $M$; its state is the current value $x_n$. Note that an LCG is a bad random number generator: if you need speed and quality, consider xoroshiro-like RNGs.
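A minimal pure implementation might look like this (the constants are only illustrative):

```python
def lcg_step(x, a=1103515245, b=12345, m=2**31):
    """One LCG step: the new state is also the next sample."""
    return (a * x + b) % m

x = 42                   # seed
samples = []
for _ in range(3):
    x = lcg_step(x)      # the caller threads the state explicitly
    samples.append(x)
print(samples)
```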

Problem. How do we generate $K$ random numbers in parallel on two or more devices? Skip $K$ numbers and then take the next $K$? Somehow replicate the RNG and seed each copy with a different initial seed? How does such seeding influence the quality of the pseudorandom sequence?

In NumPy one can use an explicit random state for sampling.
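For example, with an explicit Generator object:

```python
rng = np.random.default_rng(seed=42)  # explicit, local RNG state
print(rng.normal(size=3))
print(rng.normal(size=3))             # the state advances inside rng
```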

Almost the same is true for JAX: sampling functions take an explicit random state (a key).

However, we can split the state (the key) of the RNG as follows.
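A minimal sketch with explicit keys:

```python
key = jax.random.PRNGKey(42)         # explicit, immutable RNG state
key, subkey = jax.random.split(key)  # derive an independent subkey
x = jax.random.normal(subkey, shape=(3,))

# Reusing the same key reproduces exactly the same numbers.
y = jax.random.normal(subkey, shape=(3,))
print(jnp.allclose(x, y))            # True
```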

4 Device Placement Management

5 Composable Transformations

The most powerful and non-trivial feature of JAX is its composition of transformations, which is easy both to use and to implement. Say we write the transformations jax.jit and jax.grad. Each can be implemented in a fairly straightforward way on its own (e.g. JIT compilation with jax.jit can be implemented as a direct traversal of the expression tree). But implementing both transformations together becomes more complicated. Moreover, the difficulty grows exponentially as the number of different transformations increases. The JAX authors solve this issue with the so-called tagless-final encoding, or simply tagless final (see Kiselyov's paper).

In this section we will try out, in practice, some of the transformations that JAX provides.

5.1 JIT

Probably, this is the most famous transformation. It evaluates JAX code abstractly and produces native code which can be run directly on CPU, GPU, or TPU.

NOTE Do not forget to wait until the computation actually finishes: JAX dispatches work asynchronously, so call block_until_ready() when measuring time.
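A minimal timing sketch (the function itself is arbitrary):

```python
import time

def f(x):
    return (jnp.sin(x) @ jnp.cos(x)).sum()

x = jnp.ones((1000, 1000))
f_jit = jax.jit(f)
f_jit(x).block_until_ready()    # warm-up call triggers compilation

start = time.perf_counter()
f(x).block_until_ready()        # eager version
print("eager:", time.perf_counter() - start)

start = time.perf_counter()
f_jit(x).block_until_ready()    # compiled version
print("jit:", time.perf_counter() - start)
```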

NOTE It may be hard to see a speed-up for a simple function, but on large ones, like large neural networks (e.g. BERT or GPT), the speed-up can be significant even in comparison to other deep learning frameworks. The difference is up to 20%!

Usually, the benefits of JIT become visible with compositions of functions or large inputs. Let's try to implement a part of a gradient descent optimizer with L2 regularization. Namely, we want to update the parameters as $$ p_{n + 1} = p_{n} - u_{n}, $$ where $p$ are the optimization parameters and $u_{n}$ is the final update to the parameters, calculated as $$ u_{n} = \alpha (g_n + \lambda p_{n}). $$ Here $g_n$ is the gradient of the original objective, $\alpha$ is the learning rate, and $\lambda$ is the L2 regularization coefficient.
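A possible sketch of this update rule (the function and argument names are ours):

```python
def sgd_l2_update(params, grads, lr=1e-2, weight_decay=1e-4):
    # u_n = alpha * (g_n + lambda * p_n);  p_{n+1} = p_n - u_n
    updates = lr * (grads + weight_decay * params)
    return params - updates

update_jit = jax.jit(sgd_l2_update)

params = jnp.ones(3)
grads = jnp.array([0.1, -0.2, 0.3])
params = update_jit(params, grads)
print(params)
```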

NOTE See the optax library, which implements optimizers in a similar way in JAX.

5.2 Differentiation

Automatic differentiation is a widespread feature of numerical frameworks. JAX has advanced functionality for evaluating gradients of functions.

There is a "shortcut" to evaluate function value and its gradients at once.

The second-order derivative.

And the third-order derivative.

And compile the resulting function.
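A sketch of these three steps for a scalar function (the function is arbitrary):

```python
def f(x):
    return x ** 3 + 2.0 * x

df = jax.grad(f)             # first-order:  3 x^2 + 2
d2f = jax.grad(jax.grad(f))  # second-order: 6 x
d3f = jax.grad(d2f)          # third-order:  6

d3f_jit = jax.jit(d3f)       # compile the resulting function
print(df(2.0), d2f(2.0), d3f_jit(2.0))  # 14.0 12.0 6.0
```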

Also, JAX has routines for evaluating the Jacobian-vector product, jax.jvp, and the vector-Jacobian product, jax.vjp. The latter is well known as the backpropagation algorithm.
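A minimal sketch for a small vector-valued function (the function is arbitrary):

```python
def f(x):
    return jnp.stack([x[0] * x[1], x[0] ** 2])

x = jnp.array([2.0, 3.0])
v = jnp.array([1.0, 0.0])

# Jacobian-vector product (forward mode): J(x) @ v
y, jv = jax.jvp(f, (x,), (v,))

# Vector-Jacobian product (reverse mode, i.e. backpropagation): v^T @ J(x)
y, vjp_fn = jax.vjp(f, x)
(vj,) = vjp_fn(v)

print(jv, vj)
```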

5.3 Vectorization

The vectorization transformation jax.vmap is a JAX variant of the classical numpy.vectorize. The idea is to apply a function to a tensor along some axis.

In this trivial example we will find the sums of matrix rows and columns with the built-in sum and jax.vmap.
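A sketch with a small matrix:

```python
x = jnp.arange(12.0).reshape(3, 4)

row_sums = jax.vmap(jnp.sum, in_axes=0)(x)  # apply jnp.sum to every row
col_sums = jax.vmap(jnp.sum, in_axes=1)(x)  # apply jnp.sum to every column

# Compare with the plain axis-wise reductions.
print(jnp.allclose(row_sums, x.sum(axis=1)),
      jnp.allclose(col_sums, x.sum(axis=0)))
```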

Now we can implement an analogue of jax.jacobian as follows.
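One possible sketch, building a forward-mode Jacobian out of jax.jvp and jax.vmap (the helper name and the test function are ours):

```python
def my_jacobian(f):
    def jac(x):
        basis = jnp.eye(x.size, dtype=x.dtype)
        # Push every basis vector through the linearization of f at x.
        push = lambda v: jax.jvp(f, (x,), (v,))[1]
        return jax.vmap(push)(basis).T  # rows: outputs of f, columns: inputs
    return jac

f = lambda x: jnp.stack([x[0] * x[1], x[0] ** 2])
x = jnp.array([2.0, 3.0])
print(jnp.allclose(my_jacobian(f)(x), jax.jacobian(f)(x)))  # True
```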

5.4 Checkify

JAX does not allow throwing exceptions inside a JIT-compiled function.

But there is an experimental feature, checkify, for checking assertions on function arguments.

The checkify transformation wraps the computation in an error monad, which enables us to check assertions.
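A minimal sketch (the function safe_log is ours):

```python
from jax.experimental import checkify

def safe_log(x):
    # A functional assertion: recorded as an error value instead of raising.
    checkify.check(jnp.all(x > 0), "x must be positive")
    return jnp.log(x)

checked_log = jax.jit(checkify.checkify(safe_log))

err, out = checked_log(jnp.array([1.0, 2.0]))
err.throw()                                    # no-op: the check passed

err, out = checked_log(jnp.array([-1.0, 2.0]))
print(err.get())                               # "x must be positive"
# err.throw() would raise a Python exception here, outside of JIT.
```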

6 Control Flow and Static Arguments

7 Tree Utils

8 Common Gotchas

9 Debugging

10 Custom Types in JAX

Sometimes we want to write high-level code without messy indexing. In this case we usually define our own types. The issue is that, by default, user-defined types cause JAX transformations to fail.

Consider the canonical decomposition, which can be generalized to CANDECOMP/PARAFAC, or simply CP, in the case of an arbitrary number of dimensions $d$.

$$ A = U \cdot V^T $$

We want to calculate an element of a matrix $A$ represented in CP format. In Einstein notation, the calculation reduces to the following.

$$ A_{ij} = U_{ik} V_{jk}. $$

JAX works with trees (pytrees). So we need to tell JAX how to serialize (flatten) our data structure into a tree and deserialize (unflatten) it back from a tree.

Why? A function wrapped with jax.jit performs this serialization to and from a tree under the hood.
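A minimal sketch of such a type and its registration (the class name and field names are ours):

```python
class CP:
    """A matrix A = U @ V.T stored in CP (low-rank) format."""

    def __init__(self, U, V):
        self.U = U
        self.V = V

# Tell JAX how to flatten a CP into a tree of arrays and back.
def cp_flatten(cp):
    children = (cp.U, cp.V)  # dynamic leaves (arrays)
    aux_data = None          # no static metadata
    return children, aux_data

def cp_unflatten(aux_data, children):
    return CP(*children)

jax.tree_util.register_pytree_node(CP, cp_flatten, cp_unflatten)

@jax.jit
def element(cp, i, j):
    # A_ij = U_ik V_jk in Einstein notation.
    return jnp.einsum('k,k->', cp.U[i], cp.V[j])

key_u, key_v = jax.random.split(jax.random.PRNGKey(0))
a = CP(jax.random.normal(key_u, (5, 3)), jax.random.normal(key_v, (4, 3)))
print(element(a, 2, 1))
```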

11 Custom Primitives