NB NumPy is about 20 years old! It has had a great impact on Python itself (see the buffer protocol or the @ operator) and on an entire area of scientific computing (see Matplotlib or SciPy).
JAX, in turn, brings automatic differentiation (jax.grad, jax.jvp, or jax.vjp), parallel and distributed execution (jax.distributed and jax.pmap), and composable transformations (e.g. jax.vmap(jax.jit(jax.vmap(jax.grad(...)))) can be easily implemented in an extensible way).
In general one can install NumPy and JAX as follows (for advanced installation options see the documentation).
!pip install numpy
!pip install jax
From here onwards we use the imports below.
import numpy as np
import jax.numpy as jnp
The simplest way to create an array is a call to np.array.
arr = np.array([1, 2, 3])
arr
arr.ndim
arr.shape
np.zeros(3)
np.ones((2, 3))
We can do the same in JAX! We only need to replace np with jnp.
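For example (a small illustration mirroring the NumPy calls above):
jnp.array([1, 2, 3])
jnp.zeros(3)
jnp.ones((2, 3))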
data = np.array([1, 2])
ones = np.ones_like(data)
data - ones
data * ones
data / ones
Again, we can do the same in JAX as follows.
data = jnp.array([1, 2])
ones = jnp.ones_like(data)
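The element-wise arithmetic then looks exactly the same as in the NumPy cells above:
data - ones
data * ones
data / ones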
jnp.recarray  # There is no analogue of numpy.recarray in JAX.
Gaussian elimination
Laplace operator with Dirichlet conditions
TODO: A common difference is the data types!
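One concrete instance of that difference (an illustration, assuming default configuration): NumPy promotes Python floats to 64 bits, while JAX uses 32-bit floats by default unless the jax_enable_x64 flag is enabled.
np.array([1.0]).dtype   # dtype('float64')
jnp.array([1.0]).dtype  # dtype('float32') unless jax_enable_x64 is set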
Let's draw samples from a normal distribution with NumPy.
np.random.normal(size=(2, 3))
This does not work in JAX. The issue is that an RNG has state, but at the same time we expect any computation in JAX to be a pure function without side effects.
Consider a linear congruential generator (LCG), defined by
$$ x_{n + 1} = a x_n + b \quad (\mathrm{mod}\ M). $$
Its state is the current value $x_n$; the numbers $a$, $b$, and $M$ are fixed parameters. Note that an LCG is a bad random number generator; if you need speed, consider xoroshiro-like RNGs.
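As a sketch (with constants chosen only for illustration), an LCG with explicit state is simply a pure function from the current state to the next one:
M = 2 ** 31               # modulus (illustrative choice)
A, B = 1103515245, 12345  # multiplier and increment (illustrative choice)

def lcg_step(x):
    # Pure step: the returned value is both the new state and the sample.
    return (A * x + B) % M

state = 42
for _ in range(3):
    state = lcg_step(state)
    print(state)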
Problem. How do we generate $K$ random numbers in parallel on two or more devices? Skip $K$ numbers and then take the next $K$? Somehow replicate the RNGs and seed them with different initial seeds? How does such seeding influence the quality of the pseudorandom sequence?
Use an explicit random state in NumPy for sampling.
rng = np.random.RandomState(42)
rng
rng.normal(size=(2, 3))
rng.normal(size=(2, 3))
Almost the same is true for JAX.
import jax
key = jax.random.PRNGKey(42)
key
jax.random.normal(key, (2, 3))
jax.random.normal(key, (2, 3))
However, we can split the state of the RNG as follows.
key, subkey = jax.random.split(key)
jax.random.normal(subkey, (2, 3))
key, *subkeys = jax.random.split(key, 9)
for i, subkey in enumerate(subkeys, 1):
    print(f'sample #{i}:', jax.random.normal(subkey))
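Splitting also hints at an answer to the parallelism problem above: give every worker its own independent subkey and map over them. A minimal sketch, using jax.vmap in place of multiple devices:
keys = jax.random.split(jax.random.PRNGKey(0), 4)  # one independent key per worker
jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)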
The most powerful and non-trivial feature of JAX is that composing transformations is easy both to use and to implement. Say we write the transformations jax.jit and jax.grad. Each of them can be implemented in a fairly straightforward way (e.g. JIT compilation with jax.jit can be implemented as a direct tree traversal). But implementing both transformations together becomes more complicated, and the difficulty grows exponentially as the number of different transformations increases. The JAX authors solve this issue with so-called tagless final encoding, or just tagless final (see Kiselyov's paper).
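For instance, a sketch on a toy function: per-example gradients obtained by nesting jax.grad inside jax.vmap inside jax.jit.
f = lambda x: jnp.sum(x ** 2)  # toy scalar function
batched_grad = jax.jit(jax.vmap(jax.grad(f)))
batched_grad(jnp.arange(6.0).reshape(3, 2))  # one gradient per row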
In this section we will try out in practice some of the transformations that JAX offers.
Probably, this is the most famous transformation. It evaluates JAX code abstractly (by tracing) and produces native code which can be run directly on CPU, GPU, or TPU.
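To see what that abstract evaluation looks like, one can inspect the traced representation with jax.make_jaxpr (a small illustration):
jax.make_jaxpr(lambda x: jnp.sin(x) * 2.0)(1.0)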
def selu(x, alpha=1.67, lambda_=1.05):
    return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
inp = np.arange(100.)
%timeit selu(inp)
inp = jnp.arange(100.)
%timeit selu(inp).block_until_ready()
selu_jit = jax.jit(selu)
inp = jnp.arange(100.)
%timeit selu_jit(inp).block_until_ready()
NOTE Do not forget to wait until the computation actually finishes (hence block_until_ready): JAX dispatches work asynchronously.
NOTE It may be hard to see a speed-up on a simple function, but on large ones, such as large neural networks (e.g. BERT or GPT), the speed-up can be significant even in comparison to other deep learning frameworks. The difference is up to 20%!
Usually, the benefits of JIT become visible with composition of functions or on large inputs. Let's try to implement a part of a gradient descent optimizer with L2 regularization. Namely, we want to apply the update rule $$ p_{n + 1} = p_{n} - u_{n}, $$ where $p$ are the optimization parameters and $u_{n}$ is the final update to the parameters, calculated as $$ u_{n} = \alpha \cdot (g_{n} + \lambda \cdot p_{n}). $$ Here $g_n$ is the gradient of the original objective, $\alpha$ is the learning rate, and $\lambda$ is the L2 regularization coefficient.
NOTE See the optax library, which implements optimizers in a similar way in JAX.
from functools import partial

def sgd(params, updates, learning_rate):
    return learning_rate * updates

def weighted(params, updates, l2):
    # L2 regularization: add lambda * params to the gradient-based updates.
    return updates + l2 * params

def make_sgd(learning_rate=1e-1, l2=1e-5):
    if l2 == 0.0:
        return partial(sgd, learning_rate=learning_rate)

    def fn(params, updates):
        updates = weighted(params, updates, l2)
        updates = sgd(params, updates, learning_rate)
        return updates
    return fn
opt_fn = make_sgd()
params = np.random.normal(size=(1000, ))
updates = np.random.normal(size=(1000, ))
%timeit opt_fn(params, updates)
opt_fn_jit = jax.jit(opt_fn)
params = jnp.array(params)
updates = jnp.array(updates)
%timeit opt_fn_jit(params, updates).block_until_ready()
Automatic differentiation is a widespread feature of numerical frameworks, and JAX has advanced functionality for evaluating gradients of functions.
def objective(params):
    return (params ** 2).sum()

def objective_grad(params):
    # Analytic gradient of sum(params**2).
    return 2 * params
xs = jnp.array([1, 2], jnp.float32)
objective(xs)
jax.grad(objective)(xs)
objective_grad(xs)
There is a "shortcut" to evaluate function value and its gradients at once.
value, grad = jax.value_and_grad(objective)(xs)
print('value:', value)
print('grad:', grad)
The second-order derivative.
jax.jacobian(jax.grad(objective))(xs)
And the third-order derivative.
jax.jacobian(jax.jacobian(jax.grad(objective)))(xs)
And compile the resulting function.
jax.jit(jax.jacobian(jax.jacobian(jax.grad(objective))))(xs)
Also, JAX has routines for evaluating the Jacobian-vector product (jax.jvp) and the vector-Jacobian product (jax.vjp). The latter is essentially the well-known backpropagation algorithm.
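A brief sketch of both on the objective defined above:
# Jacobian-vector product: directional derivative of objective at xs along v.
v = jnp.array([1.0, 0.0])
value, tangent = jax.jvp(objective, (xs,), (v,))

# Vector-Jacobian product: pull a cotangent back through the computation.
value, vjp_fn = jax.vjp(objective, xs)
(cotangent,) = vjp_fn(jnp.ones_like(value))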
The vectorization transformation is the JAX counterpart of the classical numpy.vectorize. The idea is to apply a function to a tensor along some axis.
In this trivial example we will find the sums of the matrix rows and columns with the built-in sum method and with jax.vmap.
mat = jnp.diag(jnp.arange(10.))[:5]
mat
mat.sum(axis=0)  # Column sums: reduce over rows.
jax.vmap(lambda x: x.sum(), in_axes=(1,))(mat)  # Map over columns and sum each.
mat.sum(axis=1)  # Row sums: reduce over columns.
jax.vmap(lambda x: x.sum(), in_axes=(0,))(mat)  # Map over rows and sum each.
Now we can reproduce jax.jacobian ourselves with jax.vmap, as follows. For the scalar objective the Jacobian is just the gradient, so the two computations below should agree.
jax.grad(objective)(xs)
# Jacobian via vmap: map a JVP over the basis vectors (here it equals the gradient).
jax.vmap(lambda v: jax.jvp(objective, (xs,), (v,))[1])(jnp.eye(xs.size, dtype=xs.dtype))
JAX does not allow raising exceptions that depend on traced values inside a JIT-compiled function.
def double(value):
    value = jnp.asarray(value)
    if (value < 0).any():
        raise ValueError('Expected non-negative values.')
    return 2 * value

jax.jit(double)(1.0)  # Error: throws a ConcretizationTypeError exception!
But there is an experimental feature, checkify, for checking assertions on function arguments.
from jax.experimental.checkify import check, checkify

def double(value):
    value = jnp.asarray(value)
    check((value >= 0).all(), 'Expected non-negative values.')
    return 2 * value

err, res = jax.jit(checkify(double))(1.0)
err.throw()  # OK

err, res = jax.jit(checkify(double))(-1.0)
err.throw()  # Error: ValueError: Expected non-negative values!
The checkify transformation wraps the computation in an Error monad, which is what enables us to check assertions.
from functools import partial

@partial(jax.jit, static_argnums=(0, 1, 4))
def minimize(obj, opt, params, max_iter=20, verbose=False):
    # A plain Python loop is unrolled during tracing.
    for i in range(max_iter):
        value, updates = jax.value_and_grad(obj)(params)
        params -= opt(params, updates)
        if verbose:
            jax.debug.print('[{i}] value={value}', i=i, value=value)
    return value, params
value, params = minimize(objective, opt_fn, jnp.ones(2))
value.item()
@partial(jax.jit, static_argnums=(0, 1, 4))
def minimize(obj, opt, params, max_iter=20, verbose=False):
    # jax.lax.fori_loop keeps the loop inside the compiled computation.
    def body_fn(index, state):
        _, params = state
        value, updates = jax.value_and_grad(obj)(params)
        params -= opt(params, updates)
        if verbose:
            jax.debug.print('[{i}] value={value}', i=index, value=value)
        return value, params

    value, params = jax.lax.fori_loop(0, max_iter, body_fn, (jnp.array(0.0), params))
    return value, params
value, params = minimize(objective, opt_fn, jnp.ones(2))
value.item()
@partial(jax.jit, static_argnums=(0, 1, 4))
def minimize(obj, opt, params, max_iter=20, verbose=False):
    # jax.lax.while_loop iterates until the objective drops below a threshold.
    def cond_fn(state):
        value, _ = state
        jax.debug.print('value={value}', value=value)
        return value >= 1e-1

    def body_fn(state):
        _, params = state
        value, updates = jax.value_and_grad(obj)(params)
        params -= opt(params, updates)
        if verbose:
            jax.debug.print('value={value}', value=value)
        return value, params

    value, params = jax.lax.while_loop(cond_fn, body_fn, (jnp.array(jnp.inf), params))
    return value, params
value, params = minimize(objective, opt_fn, jnp.ones(2))
value.item()
print(minimize.lower(objective, opt_fn, jnp.ones(2)).as_text())
Sometimes we want to write high-level code without messy indexing. In this case we usually define our own types. The issue is that user-defined types cause JAX transformations to fail by default.
Consider the canonical decomposition, which can be generalized to CANDECOMP/PARAFAC (or simply CP) in the case of an arbitrary number of dimensions $d$:
$$ A = U \cdot V^T. $$
We want to calculate an element of a matrix $A$ represented in the CP format. In Einstein notation the calculation reduces to
$$ A_{ij} = U_{ik} V_{jk}. $$
from dataclasses import dataclass
@dataclass
class CP:
    cores: tuple[np.ndarray, ...]
    shape: tuple[int, ...]
    rank: int

    @property
    def ndim(self) -> int:
        return len(self.shape)

cp = CP(cores=[jnp.ones((3, 2)), jnp.ones((3, 2))],
        shape=(3, 3),
        rank=2)
def getitem(cp: CP, index):
    ix, jx = index
    row = cp.cores[0][ix]
    col = cp.cores[1][jx]
    return row @ col
getitem(cp, (0, 0)) # OK
jax.jit(getitem)(cp, (0, 0)) # Error: throws TypeError exception!
JAX works with trees (pytrees). So we need to tell JAX how to serialize our data structure into a tree and deserialize it from a tree.
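For instance, built-in containers (lists, tuples, dicts) are already pytrees and can be flattened into leaves plus a tree definition:
leaves, treedef = jax.tree_util.tree_flatten({'a': 1, 'b': (2, 3)})
leaves   # [1, 2, 3]
treedef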
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Canonical:
    """Class Canonical is a container for a tensor represented with the
    CP decomposition.
    """

    def __init__(self, cores: list[jax.Array], shape: tuple[int, ...],
                 rank: int):
        self.shape = tuple(shape)
        self.rank = rank
        self.cores = cores

    def __repr__(self) -> str:
        params = f'ndim={self.ndim}, shape={self.shape}, rank={self.rank}'
        return f'{self.__class__.__name__}({params})'

    @property
    def ndim(self) -> int:
        return len(self.shape)

    def tree_flatten(self):
        return self.cores, {'shape': self.shape, 'rank': self.rank}

    @classmethod
    def tree_unflatten(cls, treedef, leaves):
        return cls(leaves, **treedef)

cp = Canonical(cores=[jnp.ones((3, 2)), jnp.ones((3, 2))],
               shape=(3, 3),
               rank=2)
jax.jit(getitem)(cp, (0, 0)) # OK
Why? A function wrapped with jax.jit performs the serialization to a tree and deserialization from a tree under the hood.
leaves, treedef = jax.tree_util.tree_flatten(cp)
jax.tree_util.tree_unflatten(treedef, leaves)