Just write the function in plain Python/NumPy syntax, call jax.grad, and voila! You can also wrap the function with jax.jit to speed it up via just-in-time compilation. Note that by default JAX works in single precision (float32); double precision (float64) has to be enabled by hand.
import jax
import jax.numpy as jnp

# Switch the default dtype from float32 to float64
jax.config.update("jax_enable_x64", True)
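As a quick sanity check, arrays created after enabling the flag default to double precision:
# With jax_enable_x64 on, newly created arrays default to float64.
print(jnp.ones(3).dtype)  # float64 (would be float32 without the flag)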
n = 5
x = jax.random.normal(jax.random.PRNGKey(0), (n,))
y = jax.random.normal(jax.random.PRNGKey(10), (n,))
print(x.shape, y.shape)
print(x @ y)                      # inner product of two 1-D arrays
print(x.T @ y)                    # .T is a no-op on a 1-D array, same inner product
print(jnp.outer(x, y))            # outer product
print(x[:, None].shape, y.shape)  # x reshaped into a column of shape (5, 1)
print((x[None, :] @ y)[0])        # row vector times vector, again the inner product
(5,) (5,)
0.5431455433042264
0.5431455433042264
[[-0.17401344  0.09929537 -0.43481767  0.179563    0.54544231]
 [ 0.12356473 -0.07050838  0.30875849 -0.1275054  -0.38731164]
 [-0.08535699  0.04870632 -0.21328657  0.08807916  0.26755011]
 [-0.3445655   0.19661561 -0.8609862   0.35555423  1.08003499]
 [-0.20590302  0.11749217 -0.51450206  0.21246959  0.64539969]]
(5, 1) (5,)
0.5431455433042264
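The outer product above can equivalently be written with broadcasting, multiplying a column by a row; a quick check with the same x and y:
# jnp.outer(x, y) is the same as broadcasting a column against a row.
print(jnp.allclose(jnp.outer(x, y), x[:, None] * y[None, :]))  # True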
@jax.jit  # Just-in-time compilation
def f(x, A, b):
    res = A @ x - b
    # JAX arrays are immutable: use a functional indexed update
    # (replaces the deprecated jax.ops.index_update(res, 0, 100))
    res = res.at[0].set(100)
    # y = res[res > 1]   # boolean masking does not work under jit (dynamic shape)
    # res[0] = 100       # in-place assignment is not allowed on JAX arrays
    return res @ res
gradf = jax.grad(f, argnums=0, has_aux=False)  # gradient of f with respect to its first argument x
n = 1000
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (n, n))
b = jax.random.normal(jax.random.PRNGKey(0), (n, ))
print("Check correctness", jnp.linalg.norm(gradf(x, A, b) - 2 * A.T @ (A @ x - b)))
# print(gradf(x, A, b))
print("Compare speed")
print("Analytical gradient")
# %timeit 2 * A.T @ (A @ x - b)
print("Grad function")
%timeit gradf(x, A, b).block_until_ready()
jit_gradf = jax.jit(gradf)
print("Jitted grad function")
%timeit jit_gradf(x, A, b).block_until_ready()
Check correctness 1388.1018567160188
Compare speed
Analytical gradient
Grad function
3.68 ms ± 483 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Jitted grad function
1.37 ms ± 160 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
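The correctness check is not zero, and that is expected: inside f the first residual is overwritten by the constant 100, so the first row of A drops out of the gradient. Zeroing that row in the analytical formula recovers what jax.grad returns (a small check, reusing x, A, b from above):
A0 = A.at[0].set(0.0)                     # copy of A with the first row zeroed
grad_analytic = 2 * A0.T @ (A0 @ x - b)   # analytical gradient of the modified objective
print(jnp.linalg.norm(gradf(x, A, b) - grad_analytic))  # ~0 up to floating-point error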
hess_func = jax.jit(jax.hessian(f))  # full (n, n) Hessian of f with respect to x
print("Check correctness", jnp.linalg.norm(2 * A.T @ A - hess_func(x, A, b)))
print("Time for hessian")
%timeit hess_func(x, A, b).block_until_ready()
print("Emulate hessian and check correctness",
jnp.linalg.norm(jax.jit(hess_func)(x, A, b) - jax.jacfwd(jax.jacrev(f))(x, A, b)))
print("Time of emulating hessian")
hess_emul_func = jax.jit(jax.jacfwd(jax.jacrev(f)))
%timeit hess_emul_func(x, A, b).block_until_ready()
Check correctness 0.0
Time for hessian
95.7 ms ± 4.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Emulate hessian and check correctness 0.0
Time of emulating hessian
100 ms ± 8.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
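For large n it is often unnecessary (and expensive) to materialize the full (n, n) Hessian; many algorithms only need Hessian-vector products. A minimal sketch of a forward-over-reverse product, reusing f, x, A, b and hess_func from above:
# Hessian-vector product without forming the Hessian:
# push the direction v through the gradient with forward mode (jvp over grad).
def hvp(x, v, A, b):
    return jax.jvp(lambda z: jax.grad(f)(z, A, b), (x,), (v,))[1]

v = jax.random.normal(jax.random.PRNGKey(1), (n,))
print(jnp.linalg.norm(hvp(x, v, A, b) - hess_func(x, A, b) @ v))  # ~0
Here jax.jvp differentiates the gradient along the single direction v, so memory stays O(n) instead of O(n^2).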