Automatic differentiation with JAX

Main features

How to compute the gradient of your objective?

Random numbers in JAX

Summary
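
Two of the topics listed above, computing the gradient of an objective and generating random numbers, can be illustrated with a minimal sketch. The least-squares `objective`, the array shapes, and the seed are assumptions chosen for illustration only; they are not taken from the original material. The calls used (`jax.grad`, `jax.random.PRNGKey`, `jax.random.split`, `jax.random.normal`) are standard JAX functions.

```python
import jax
import jax.numpy as jnp


def objective(w, x, y):
    """Hypothetical objective: mean squared error of a linear model."""
    residual = x @ w - y
    return jnp.mean(residual ** 2)


# JAX random numbers are driven by explicit keys rather than global state.
key = jax.random.PRNGKey(0)            # seed value is an arbitrary choice
key_x, key_w = jax.random.split(key)   # derive independent subkeys

x = jax.random.normal(key_x, (8, 3))   # illustrative inputs
w_true = jnp.array([1.0, -2.0, 0.5])   # illustrative "true" weights
y = x @ w_true                         # noiseless targets
w = jax.random.normal(key_w, (3,))     # random starting point

# jax.grad returns a function that computes the gradient with respect to
# the first argument (w) by default.
grad_fn = jax.grad(objective)
g = grad_fn(w, x, y)

# Analytic gradient of the same objective, for comparison: (2/n) X^T (Xw - y).
g_manual = (2.0 / x.shape[0]) * (x.T @ (x @ w - y))
```

Reusing the same key reproduces the same samples, which is why the sketch splits the key before drawing `x` and `w` instead of passing one key to both calls.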