# JAX Tutorial

## Agenda

1. [Introduction][1].
2. [General array and algebra routines (NumPy-like)][2].
    1. Array creation basics.
    2. Array algebra basics.
    3. Array indexing and mutation.
3. Random number generators.
4. Device placement management.
5. Composable transformations.
    1. jit
    2. grad
    3. vmap
    4. pmap
    5. checkify
6. Control flow and static arguments.
7. Tree utils.
8. Common gotchas.
9. Debugging (`jax.debug.print`).
10. Custom types in JAX.
11. Custom primitives.

[1]: #1-Introduction
[2]: #2-Arrays-Basics

## 1 Introduction

### What tools did ancestors use?

- Low-level programming languages like Fortran or C/C++. Drawbacks: many bugs, slow development, and a need to know hardware or OS internals.
- High-level programming languages like Matlab. A little better: faster, interactive development, but slow execution.

### What is NumPy?

- Performant numerical and linear algebra routines for Python.
- Delegates to common BLAS/LAPACK implementations (e.g. Intel MKL, now oneMKL within oneAPI).
- Inspired by Matlab and Fortran.
- Implemented as a Python library with a set of extension modules written in plain C.

**NB** [NumPy][1] is about 20 years old! It has had a great impact on Python itself (see the [buffer protocol][4] or the `@` operator) and on the entire area of scientific computing (see [Matplotlib][2] or [SciPy][3]).

[1]: https://github.com/numpy/numpy
[2]: https://github.com/matplotlib/matplotlib
[3]: https://github.com/scipy/scipy
[4]: https://docs.python.org/3/c-api/buffer.html

### What is JAX?

- "Next-generation" NumPy.
- Just-in-time (JIT) compilation: code is optimized and compiled for the actual hardware (via the XLA compiler).
- First-class support for math accelerators (GPUs and TPUs).
- Automatic differentiation out of the box (e.g. `jax.grad`, `jax.jvp` or `jax.vjp`).
- Distributed execution out of the box (`jax.distributed` and `jax.pmap`).
- Composable transformations: e.g. `jax.vmap(jax.jit(jax.vmap(jax.grad(...))))` just works, and the transformation machinery is extensible.

### What alternatives exist?

- Use high-level Python libraries. TensorFlow, PyTorch, Theano, Numba, CuPy are well-known examples.
- We can still write Python extensions in "low-level" programming languages like C/C++, Rust or Fortran.
- Do everything in another "high-level" language like Julia, R, Matlab, etc.

### What is the future?

- One can see different tools converging to a common set of ideas and base technologies. For example, Numba, Julia, TensorFlow, JAX and flang (LLVM's Fortran) are all built on the LLVM compiler infrastructure.
- Common execution runtimes exist for different frameworks, e.g. IREE or ONNX Runtime.
- Different libraries and languages influence each other (e.g. JAX inspired [functorch][1]).
- There are also experimental proof-of-concept languages like [Dex][2].

[1]: https://github.com/pytorch/functorch
[2]: https://github.com/google-research/dex-lang

## 2 Arrays Basics

In general, one can install NumPy and JAX as follows (for advanced installation options, see the [documentation][1]).

[1]: https://github.com/google/jax#installation

In [None]:
!pip install numpy

In [None]:
!pip install jax

From here onwards we use the imports below.

In [None]:
import numpy as np

In [None]:
import jax.numpy as jnp

### 2.1 Array Creation Basics

The simplest way to create an array is to pass a Python list to `np.array`.

![alt](img/np_array.png)

In [None]:
arr = np.array([1, 2, 3])
arr

In [None]:
arr.ndim

In [None]:
arr.shape

In [None]:
np.zeros(3)

In [None]:
np.ones((2, 3))

We can do the same in JAX! We only need to replace `np` with `jnp`.
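For example, the creation routines above look like this with `jax.numpy` (a minimal sketch; only the module prefix changes):

```python
import jax.numpy as jnp

arr = jnp.array([1, 2, 3])
print(arr.ndim, arr.shape)  # 1 (3,)

zeros = jnp.zeros(3)        # [0. 0. 0.]
ones = jnp.ones((2, 3))     # a 2x3 matrix of ones
print(zeros, ones, sep='\n')
```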

### 2.2 Array Algebra Basics

![](img/np_sub_mult_divide.png)

In [None]:
data = np.array([1, 2])
ones = np.ones_like(data)

In [None]:
data - ones

In [None]:
data * ones

In [None]:
data / ones

Again, we can do the same in JAX as follows.

In [None]:
data = jnp.array([1, 2])
ones = jnp.ones_like(data)
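With the JAX arrays in place, the element-wise operations are spelled exactly as in NumPy. A short sketch:

```python
import jax.numpy as jnp

data = jnp.array([1, 2])
ones = jnp.ones_like(data)  # integer ones, same dtype as `data`

diff = data - ones   # [0 1]
prod = data * ones   # [1 2]
quot = data / ones   # true division returns floats: [1. 2.]
print(diff, prod, quot)
```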

In [None]:
jnp.recarray  # AttributeError: there is no analogue of numpy.recarray in JAX.

TODO: Gaussian elimination.

TODO: Laplace operator with Dirichlet boundary conditions.
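A possible sketch for these two TODO items, assuming the 1D problem $-u'' = f$ on $(0, 1)$ with $u(0) = u(1) = 0$ and the standard second-order finite-difference Laplacian. The helper name `laplacian_1d` is hypothetical; the linear system is solved with `jnp.linalg.solve`, which to our knowledge uses an LU factorization, i.e. Gaussian elimination with partial pivoting.

```python
import jax.numpy as jnp


def laplacian_1d(n, h):
    """Finite-difference Laplacian on n interior nodes; zero Dirichlet values
    at both ends are implied by dropping the boundary nodes."""
    main = -2.0 * jnp.ones(n)
    off = jnp.ones(n - 1)
    return (jnp.diag(main) + jnp.diag(off, k=1) + jnp.diag(off, k=-1)) / h**2


n = 99
h = 1.0 / (n + 1)
x = jnp.linspace(h, 1.0 - h, n)            # interior grid nodes
f = jnp.pi**2 * jnp.sin(jnp.pi * x)        # right-hand side of -u'' = f
u = jnp.linalg.solve(-laplacian_1d(n, h), f)

u_exact = jnp.sin(jnp.pi * x)              # analytic solution for this f
print(jnp.abs(u - u_exact).max())          # small discretization error
```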

TODO: A common difference between NumPy and JAX is data types!
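The best-known difference is the default floating-point precision: NumPy creates `float64` arrays, while JAX creates `float32` arrays unless 64-bit mode is enabled (via the `jax_enable_x64` configuration flag). A minimal check:

```python
import numpy as np
import jax.numpy as jnp

np_x = np.array([1.0, 2.0])
jnp_x = jnp.array([1.0, 2.0])
print(np_x.dtype)   # float64
print(jnp_x.dtype)  # float32 with the default configuration

# To get 64-bit types in JAX, enable x64 mode at startup, before creating arrays:
# import jax; jax.config.update('jax_enable_x64', True)
```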

### 2.3 Array Indexing Basics
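Reading elements uses the same indexing syntax in NumPy and JAX. Writing differs: JAX arrays are immutable, so in-place assignment such as `jarr[0, 1] = 10` raises a `TypeError`, and the functional `.at[...]` API returns an updated copy instead. A short sketch:

```python
import numpy as np
import jax.numpy as jnp

arr = np.arange(6).reshape(2, 3)
print(arr[0, 1], arr[:, 1], arr[arr > 2])  # element, column, boolean mask
arr[0, 1] = 10                             # NumPy mutates in place

jarr = jnp.arange(6).reshape(2, 3)
print(jarr[0, 1], jarr[:, 1])              # reading works the same way

# Functional updates: each call returns a new array, `jarr` stays unchanged.
jarr_set = jarr.at[0, 1].set(10)
jarr_add = jarr.at[:, 1].add(100)
print(jarr, jarr_set, jarr_add, sep='\n')
```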

## 3 (Pseudo) Random Number Generators aka (P)RNG

Let's draw samples from a normal distribution with NumPy.

In [None]:
np.random.normal(size=(2, 3))

JAX has no such stateful API. The issue is that an RNG carries a hidden state, while at the same time we expect every computation to be

- reproducible,
- parallelizable,
- vectorizable.

Consider a linear congruential generator (LCG), defined by the recurrence

$$
    x_{n + 1} = a x_n + b \quad (\mathrm{mod} M).
$$
Its state is a single number $x_n$, while $a$, $b$ and $M$ are fixed parameters. Note that the LCG is a bad random number generator. If you need speed, consider `xoroshiro`-like RNGs.

**Problem.** How can we generate $K$ random numbers in parallel on two or more devices? Skip $K$ numbers and then take the next $K$? Somehow replicate the RNG and seed the copies with different initial seeds? How does such seeding influence the quality of the pseudorandom numbers?
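For an LCG specifically, the "skip $K$ numbers" option is cheap: $k$ steps of $x \mapsto a x + b$ form another affine map, which can be computed in $O(\log k)$ by repeated squaring. The sketch below uses plain Python integers and one commonly used 32-bit parameter pair; the helper names `lcg_step`, `lcg_jump` and `block` are hypothetical. Each worker jumps to its own block of $K$ numbers, and the concatenated blocks reproduce the sequential stream exactly.

```python
M = 2**32
A, B = 1664525, 1013904223  # a commonly used 32-bit LCG parameter pair


def lcg_step(x):
    """One LCG step; the whole state is the integer x."""
    return (A * x + B) % M


def lcg_jump(x, k):
    """Advance the state by k steps in O(log k) by composing affine maps."""
    acc_mul, acc_add = 1, 0   # accumulated map, starts as the identity x -> x
    cur_mul, cur_add = A, B   # map for 2**i steps, starts as a single step
    while k > 0:
        if k & 1:
            # All these maps are powers of one map, so they commute.
            acc_mul, acc_add = (cur_mul * acc_mul) % M, (cur_mul * acc_add + cur_add) % M
        # Square the current map: compose it with itself.
        cur_mul, cur_add = (cur_mul * cur_mul) % M, (cur_mul * cur_add + cur_add) % M
        k >>= 1
    return (acc_mul * x + acc_add) % M


def block(seed, start, size):
    """Numbers start .. start + size - 1 of the stream, mapped to [0, 1)."""
    x = lcg_jump(seed, start)
    out = []
    for _ in range(size):
        x = lcg_step(x)
        out.append(x / M)
    return out


seed, K, workers = 42, 4, 3
parallel = [block(seed, w * K, K) for w in range(workers)]  # one block per worker
sequential = block(seed, 0, workers * K)
print(sum(parallel, []) == sequential)  # True
```

This answers only the first question; it says nothing about the statistical quality of the blocks, which is exactly where the LCG is weak.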

In NumPy, we can instead sample using an explicit random state.

In [None]:
rng = np.random.RandomState(42)
rng

In [None]:
rng.normal(size=(2, 3))

In [None]:
rng.normal(size=(2, 3))

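Almost the same is true for JAX, with one important difference: a key is never updated in place, so sampling twice with the same key returns the same numbers.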

In [None]:
import jax

In [None]:
key = jax.random.PRNGKey(42)
key

In [None]:
jax.random.normal(key, (2, 3))

In [None]:
jax.random.normal(key, (2, 3))

To get fresh random numbers, we explicitly split the key as follows.

In [None]:
key, subkey = jax.random.split(key)
jax.random.normal(subkey, (2, 3))

In [None]:
key, *subkeys = jax.random.split(key, 9)
for i, subkey in enumerate(subkeys, 1):
    print(f'sample #{i}:', jax.random.normal(subkey))
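Because keys are explicit values, sampling also composes with vectorization: instead of the Python loop above, we can map a sampler over a batch of split keys with `jax.vmap` (covered in more detail later). A minimal sketch; the result is reproducible because it depends only on the keys.

```python
import jax

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, 8)  # 8 independent keys


def sample(k):
    return jax.random.normal(k, (3,))


samples = jax.vmap(sample)(keys)  # one row of 3 samples per key
again = jax.vmap(sample)(keys)    # same keys give the same numbers
print(samples.shape)              # (8, 3)
```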

## 4 Device Placement Management
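A minimal sketch of the basic placement API, assuming a recent JAX release: `jax.devices()` lists the devices of the default backend, `jax.default_backend()` names that backend, and `jax.device_put` commits an array to a chosen device, so later computations on it run there.

```python
import jax
import jax.numpy as jnp

devices = jax.devices()        # e.g. a single CPU device on a laptop
print(jax.default_backend())   # 'cpu', 'gpu' or 'tpu'
print(devices)

x = jnp.arange(4.0)
x_dev = jax.device_put(x, devices[0])  # explicitly place x on the first device
y = 2 * x_dev                          # runs on the device that holds x_dev
print(y)
```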

## 5 Composable Transformations

The most powerful and non-trivial feature of JAX is that transformations compose and are easy both to use and to implement. Say we write two transformations, `jax.jit` and `jax.grad`. Each of them alone can be implemented in a fairly standard way (e.g. JIT compilation with `jax.jit` can be implemented as a direct traversal of the expression tree). But implementing both so that they work together is more complicated. Moreover, the difficulty grows exponentially with the number of different transformations. The JAX authors solve the issue with the so-called tagless-final encoding, or just tagless final (see [Kiselyov's paper][1]).
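To make the idea tangible, here is a toy tagless-final sketch in plain Python. It only illustrates the encoding and is not how JAX is implemented internally: a program is written once against an abstract interpreter `I` (the hypothetical methods `lit`, `add`, `mul`), and each interpreter gives it a different meaning. `Dual` is a forward-mode differentiation "transformation" that wraps any other interpreter, so transformations nest for free.

```python
class Eval:
    """Interpreter that computes ordinary numbers."""
    def lit(self, v):
        return v

    def add(self, x, y):
        return x + y

    def mul(self, x, y):
        return x * y


class Show:
    """Interpreter that pretty-prints the program instead of running it."""
    def lit(self, v):
        return str(v)

    def add(self, x, y):
        return f'({x} + {y})'

    def mul(self, x, y):
        return f'({x} * {y})'


class Dual:
    """Forward-mode differentiation layered on top of any base interpreter.

    Values are (primal, tangent) pairs of base-interpreter values, so
    Dual(Dual(Eval())) yields second derivatives without extra code.
    """
    def __init__(self, base):
        self.base = base

    def lit(self, v):
        return (self.base.lit(v), self.base.lit(0))  # constants have zero tangent

    def add(self, x, y):
        b = self.base
        return (b.add(x[0], y[0]), b.add(x[1], y[1]))

    def mul(self, x, y):
        b = self.base
        # Product rule: (x * y)' = x * y' + x' * y.
        return (b.mul(x[0], y[0]), b.add(b.mul(x[0], y[1]), b.mul(x[1], y[0])))


def program(I, x):
    """f(x) = x * x + 3 * x, written once against an abstract interpreter I."""
    return I.add(I.mul(x, x), I.mul(I.lit(3), x))


print(program(Eval(), 2))                          # 10
print(program(Show(), 'x'))                        # ((x * x) + (3 * x))
print(program(Dual(Eval()), (2, 1)))               # (10, 7): f(2), f'(2)
print(program(Dual(Dual(Eval())), ((2, 1), (1, 0))))  # ((10, 7), (7, 2)): includes f''(2)
```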

In this section we will try some of the transformations that JAX provides in practice.

[1]: https://okmij.org/ftp/tagless-final/#tagless-final

### 5.1 JIT

This is probably the most famous transformation. It traces JAX code with abstract values (shapes and dtypes instead of concrete numbers) and compiles the result into native code that runs directly on CPU, GPU, or TPU.

In [None]:
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

In [None]:
inp = np.arange(100.)
%timeit selu(inp).block_until_ready()

In [None]:
inp = jnp.arange(100.)
%timeit selu(inp).block_until_ready()

In [None]:
selu_jit = jax.jit(selu)
inp = jnp.arange(100.)
%timeit selu_jit(inp).block_until_ready()

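**NOTE** Do not forget to wait until the computation finishes. JAX dispatches work asynchronously, so without `block_until_ready()` we would measure only the dispatch time.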
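To see what "abstract evaluation" means, we can print the jaxpr (the intermediate representation that `jax.jit` traces) with `jax.make_jaxpr`, and count how often Python actually runs the function body. The counter `trace_count` is a hypothetical helper for illustration: a Python side effect executes only while tracing, and the compiled code is reused for inputs with the same shape and dtype.

```python
import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


print(jax.make_jaxpr(selu)(jnp.arange(3.)))  # the traced program

trace_count = 0


def selu_traced(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return selu(x)


selu_traced_jit = jax.jit(selu_traced)
selu_traced_jit(jnp.arange(3.))      # first call: trace and compile
selu_traced_jit(jnp.arange(3.) + 1)  # same shape and dtype: cached code
selu_traced_jit(jnp.arange(5.))      # new shape: traced and compiled again
print(trace_count)                   # 2
```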

**NOTE** It may be hard to see a speed-up for a simple function, but for large ones like large neural networks (e.g. BERT or GPT) the speed-up can be significant, even in comparison to other deep learning frameworks. The difference can be up to 20%!

Usually, the benefits of JIT become visible for compositions of functions or for large inputs. Let's implement a part of a gradient descent optimizer with L2 regularization. Namely, we want to implement the update rule
$$
    p_{n + 1} = p_{n} - u_{n},
$$
where $p$ denotes the optimization parameters and $u$ is the final update to the parameters, calculated as
$$
    u_{n} = \alpha \cdot (g_n + \lambda p_{n}).
$$
Here $g_n$ is the gradient of the objective function, $\alpha$ is the learning rate and $\lambda$ is the L2 regularization coefficient.

**NOTE** See `optax` library which implements optimizers in similar way in JAX.

In [None]:
def sgd(params, updates, learning_rate):
    # Scale the update by the learning rate: alpha * (...).
    return learning_rate * updates

In [None]:
def weighted(params, updates, l2):
    # Add the L2 term to the gradient: g + lambda * p.
    return updates + l2 * params

In [None]:
from functools import partial


def make_sgd(learning_rate=1e-1, l2=1e-5):
    if l2 == 0.0:
        return partial(sgd, learning_rate=learning_rate)
    
    def fn(params, updates):
        updates = weighted(params, updates, l2)
        updates = sgd(params, updates, learning_rate)
        return updates
    
    return fn

In [None]:
opt_fn = make_sgd()

In [None]:
params = np.random.normal(size=(1000, ))
updates = np.random.normal(size=(1000, ))

In [None]:
%timeit opt_fn(params, updates)

In [None]:
opt_fn_jit = jax.jit(opt_fn)

In [None]:
params = jnp.array(params)
updates = jnp.array(updates)

In [None]:
%timeit opt_fn_jit(params, updates).block_until_ready()

### 5.2 Differentiation

Automatic differentiation is a widespread feature of numerical frameworks. JAX has advanced functionality for evaluating gradients of functions.

In [None]:
def objective(params):
    return (params ** 2).sum()


def objective_grad(params):
    # Analytical gradient of sum(p ** 2), for comparison with jax.grad.
    return 2 * params

In [None]:
xs = jnp.array([1, 2], jnp.float32)

In [None]:
objective(xs)

In [None]:
jax.grad(objective)(xs)

In [None]:
objective_grad(xs)

There is a "shortcut" to evaluate function value and its gradients at once.

In [None]:
value, grad = jax.value_and_grad(objective)(xs)
print('value:', value)
print('grad:', grad)

The second-order derivative (the Hessian).

In [None]:
jax.jacobian(jax.grad(objective))(xs)

And the third-order derivative.

In [None]:
jax.jacobian(jax.jacobian(jax.grad(objective)))(xs)

And compile the resulting function.

In [None]:
jax.jit(jax.jacobian(jax.jacobian(jax.grad(objective))))(xs)

JAX also has routines for evaluating Jacobian-vector products (`jax.jvp`, forward mode) and vector-Jacobian products (`jax.vjp`, reverse mode). The latter is the core of the well-known backpropagation algorithm.
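A small sketch of both routines on the same `objective`: `jax.jvp` pushes a tangent vector `v` forward and returns the directional derivative, while `jax.vjp` returns a function that pulls a cotangent of the output back to the inputs. For a scalar output and a cotangent of one, the pullback equals the gradient.

```python
import jax
import jax.numpy as jnp


def objective(params):
    return (params ** 2).sum()


xs = jnp.array([1.0, 2.0])

# Forward mode: directional derivative of objective at xs along v.
v = jnp.array([1.0, 0.0])
value, tangent = jax.jvp(objective, (xs,), (v,))
print(value, tangent)  # 5.0 2.0

# Reverse mode: pull the output cotangent back to the input.
value, vjp_fn = jax.vjp(objective, xs)
(cotangent,) = vjp_fn(jnp.ones_like(value))
print(cotangent)       # [2. 4.], same as jax.grad(objective)(xs)
```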

### 5.3 Vectorization

The vectorization transformation `jax.vmap` is the JAX counterpart of the classical `numpy.vectorize`. The idea is to map a function over some axis of its input arrays.

In this trivial example we compute the sums of matrix columns and rows with the built-in `sum` method and with `jax.vmap`.

In [None]:
mat = jnp.diag(jnp.arange(10.))[:5]
mat

In [None]:
mat.sum(axis=0)

In [None]:
jax.vmap(lambda x: x.sum(), in_axes=(1,))(mat)  # Map over columns, like axis=0.

In [None]:
mat.sum(axis=1)

In [None]:
jax.vmap(lambda x: x.sum(), in_axes=(0,))(mat)  # Map over rows, like axis=1.
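The `in_axes` argument also accepts `None` for inputs that should not be mapped. This lets us batch a function written for a single example; here a hypothetical `dot` for one pair of vectors is applied to every row of `us` against one shared vector `v`.

```python
import jax
import jax.numpy as jnp


def dot(u, v):
    # Written for a single pair of vectors.
    return jnp.dot(u, v)


us = jnp.arange(12.).reshape(4, 3)  # a batch of 4 vectors
v = jnp.ones(3)                     # one vector shared by the whole batch

batched = jax.vmap(dot, in_axes=(0, None))(us, v)  # map over rows of `us` only
print(batched)  # [ 3. 12. 21. 30.]
print(us @ v)   # the same result with a matrix-vector product
```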

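Now we can implement `jax.jacobian` ourselves: a reverse-mode Jacobian is just `jax.vjp` vectorized with `jax.vmap` over the rows of an identity matrix. Let's check it on the Hessian of `objective`.

In [None]:
jax.jacobian(jax.grad(objective))(xs)  # Built-in, for reference.

In [None]:
_, vjp_fn = jax.vjp(jax.grad(objective), xs)
jax.vmap(vjp_fn)(jnp.eye(xs.size, dtype=xs.dtype))[0]  # Row i is e_i^T J.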

### 5.4 Checkify

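A jitted function cannot raise an exception that depends on the values of its array arguments: during tracing the values are abstract, so Python cannot evaluate the condition.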

In [None]:
def double(value):
    value = jnp.asarray(value)
    if (value < 0).any():
        raise ValueError('Expected non-negative values.')
    return 2 * value

In [None]:
jax.jit(double)(1.0)  # Error: raises ConcretizationTypeError!

But there is an experimental transformation, `checkify`, that checks assertions on traced values.

In [None]:
from jax.experimental.checkify import check, checkify

In [None]:
def double(value):
    value = jnp.asarray(value)
    check((value >= 0).all(), 'Expected non-negative values.')
    return 2 * value

In [None]:
err, res = jax.jit(checkify(double))(1.0)
err.throw()  # OK

In [None]:
err, res = jax.jit(checkify(double))(-1.0)
err.throw()  # Error: raises an exception with 'Expected non-negative values.'

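The `checkify` transformation threads an `Error` value through the computation (in the spirit of an error monad), which lets us check assertions inside jitted code and report failures after the function returns.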
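Besides explicit `check` calls, `checkify` can insert some checks automatically. A sketch, assuming the `errors=checkify.float_checks` option (NaN and division checks) and `Error.get()`, which returns `None` when nothing failed and a message otherwise:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def log_fn(x):
    return jnp.log(x)


# Ask checkify to add automatic NaN and division-by-zero checks.
checked_log = jax.jit(checkify.checkify(log_fn, errors=checkify.float_checks))

err_ok, out_ok = checked_log(jnp.float32(2.0))
print(err_ok.get())   # None: no error was recorded

err_bad, out_bad = checked_log(jnp.float32(-1.0))  # log(-1) produces NaN
print(err_bad.get())  # a message about the NaN
```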

## 6 Control Flow and Static Arguments

In [None]:
from functools import partial

In [None]:
@partial(jax.jit, static_argnums=(0, 1, 3, 4))
def minimize(obj, opt, params, max_iter=20, verbose=False):    
    for i in range(max_iter):
        value, updates = jax.value_and_grad(obj)(params)
        params -= opt(params, updates)
        if verbose:
            jax.debug.print('[{i}] value={value}', i=i, value=value)
    return value, params

In [None]:
value, params = minimize(objective, opt_fn, jnp.ones(2))
value.item()

In [None]:
@partial(jax.jit, static_argnums=(0, 1, 3, 4))
def minimize(obj, opt, params, max_iter=20, verbose=False):
    def body_fn(index, state):
        _, params = state
        value, updates = jax.value_and_grad(obj)(params)
        params -= opt(params, updates)
        if verbose:  # `verbose` is static, so this is a plain Python branch.
            jax.debug.print('[{i}] value={value}', i=index, value=value)
        return value, params
    # The loop carry must keep the same type, so start with a typed scalar.
    init = (jnp.zeros((), params.dtype), params)
    value, params = jax.lax.fori_loop(0, max_iter, body_fn, init)
    return value, params

In [None]:
value, params = minimize(objective, opt_fn, jnp.ones(2))
value.item()

In [None]:
@partial(jax.jit, static_argnums=(0, 1, 3, 4))
def minimize(obj, opt, params, max_iter=20, verbose=False):
    def cond_fn(state):
        index, value, _ = state
        # Stop when the objective is small enough or the budget is exhausted.
        return (value >= 1e-1) & (index < max_iter)

    def body_fn(state):
        index, _, params = state
        value, updates = jax.value_and_grad(obj)(params)
        params -= opt(params, updates)
        if verbose:
            jax.debug.print('[{i}] value={value}', i=index, value=value)
        return index + 1, value, params

    init = (jnp.int32(0), jnp.array(jnp.inf, params.dtype), params)
    _, value, params = jax.lax.while_loop(cond_fn, body_fn, init)
    return value, params

In [None]:
value, params = minimize(objective, opt_fn, jnp.ones(2))
value.item()

In [None]:
print(minimize.lower(objective, opt_fn, jnp.ones(2)).as_text())

## 7 Tree Utils
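A minimal sketch of the `jax.tree_util` basics on a nested container of parameters (the dictionary layout is just an example): `tree_flatten` splits a pytree into a list of leaves plus a structure description, and `tree_map` applies a function leaf-wise while preserving the structure, also across several trees with the same structure, which is how an optimizer update can be written for arbitrary parameter containers.

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3), 'extra': [jnp.arange(2.0)]}

leaves, treedef = jax.tree_util.tree_flatten(params)
print(len(leaves), treedef)  # 3 leaves; dict keys are visited in sorted order

scaled = jax.tree_util.tree_map(lambda x: 2 * x, params)  # same structure
grads = jax.tree_util.tree_map(jnp.ones_like, params)     # fake gradients
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)   # round trip
```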

## 8 Common Gotchas

## 9 Debugging
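A short sketch of the difference between `print` and `jax.debug.print` (already used in the control-flow examples) inside a jitted function: the built-in `print` runs once at trace time and shows an abstract tracer, while `jax.debug.print` is staged into the compiled program and prints concrete values on every call.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    y = 2 * jnp.sin(x)
    print('print sees a tracer:', y)                    # trace time only
    jax.debug.print('debug.print sees values: {}', y)   # every call
    return y


out = f(jnp.arange(3.0))
out = f(jnp.arange(3.0) + 1)  # only the jax.debug.print line appears again
```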

## 10 Custom Types in JAX

Sometimes we want to write high-level code without messy indexing. In this case we usually define our own types. The issue is that, by default, JAX transformations fail on user-defined types.

Consider the canonical decomposition of a matrix, which generalizes to the CANDECOMP/PARAFAC (or simply CP) decomposition for tensors of arbitrary dimension $d$.

$$
    A = U \cdot V^T
$$

We want to calculate an element of a matrix $A$ represented in CP format. In Einstein notation, the calculation reduces to the following.

$$
    A_{ij} = U_{ik} V_{jk}.
$$
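The formula maps directly onto `jnp.einsum`: the repeated index $k$ (the rank) is summed over. A small sketch with arbitrary factor matrices `U` and `V`; note that a single element needs only a dot product of one row of `U` with one row of `V`, which is what `getitem` below exploits.

```python
import jax.numpy as jnp

U = jnp.arange(6.0).reshape(3, 2)        # shape (3, rank=2)
V = jnp.arange(6.0, 12.0).reshape(3, 2)  # shape (3, rank=2)

A = jnp.einsum('ik,jk->ij', U, V)        # A_ij = U_ik V_jk, summed over k
print(jnp.allclose(A, U @ V.T))          # True

i, j = 0, 2
print(A[i, j], U[i] @ V[j])              # one element costs O(rank): 11.0 11.0
```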

In [None]:
from dataclasses import dataclass

In [None]:
@dataclass
class CP:
    
    cores: tuple[np.ndarray, ...]
        
    shape: tuple[int, ...]
        
    rank: int
        
    @property
    def ndim(self) -> int:
        return len(self.shape)

In [None]:
cp = CP(cores=[jnp.ones((3, 2)), jnp.ones((3, 2))],
        shape=(3, 3),
        rank=2)

In [None]:
def getitem(cp: CP, index):
    ix, jx = index
    row = cp.cores[0][ix]
    col = cp.cores[1][jx]
    return row @ col

In [None]:
getitem(cp, (0, 0))  # OK

In [None]:
jax.jit(getitem)(cp, (0, 0))  # Error: throws TypeError exception!

JAX works with trees of arrays (pytrees). So we need to tell JAX how to flatten our data structure into a tree and how to unflatten it back.

In [None]:
from jax.tree_util import register_pytree_node_class

In [None]:
@register_pytree_node_class
class Canonical:
    """Class Canonical is a container for tensor represented with
    CP-decomposition.
    """

    def __init__(self, cores: list[jax.Array], shape: tuple[int, ...],
                 rank: int):
        self.shape = tuple(shape)
        self.rank = rank
        self.cores = cores

    def __repr__(self) -> str:
        params = f'ndim={self.ndim}, shape={self.shape}, rank={self.rank}'
        return f'{self.__class__.__name__}({params})'

    @property
    def ndim(self) -> int:
        return len(self.shape)
    
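    def tree_flatten(self):
        # Children are the arrays (leaves); aux data is static metadata
        # and must be hashable, so use a tuple rather than a dict.
        return tuple(self.cores), (self.shape, self.rank)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        shape, rank = aux_data
        return cls(list(children), shape, rank)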

In [None]:
cp = Canonical(cores=[jnp.ones((3, 2)), jnp.ones((3, 2))],
               shape=(3, 3),
               rank=2)

In [None]:
jax.jit(getitem)(cp, (0, 0))  # OK

Why does it work now? A function wrapped with `jax.jit` flattens its arguments into leaves (arrays) and a tree structure under the hood, and unflattens them back when tracing.

In [None]:
leaves, treedef = jax.tree_util.tree_flatten(cp)

In [None]:
jax.tree_util.tree_unflatten(treedef, leaves)

## 11 Custom Primitives