JAX (software)

JAX is a machine learning framework for transforming numerical functions, developed by Google with contributions from Nvidia.[2][3][4] It is described as bringing together a modified version of autograd (automatic differentiation of a function to obtain its gradient function) and OpenXLA's XLA (Accelerated Linear Algebra).

It is designed to follow the structure and workflow of NumPy as closely as possible and works with various existing frameworks such as TensorFlow and PyTorch.
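As an illustration of this NumPy-like interface (a minimal sketch, not code from the original article), the jax.numpy module mirrors NumPy's API, while jax.grad transforms a function into one that computes its gradient:

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # jnp mirrors the NumPy API, so this reads like ordinary NumPy code
    return jnp.sum(x ** 2)

x = jnp.array([1.0, 2.0, 3.0])

# jax.grad transforms the function into one computing its gradient:
# d/dx sum(x^2) = 2x
grad_fn = jax.grad(sum_of_squares)
print(grad_fn(x))  # [2. 4. 6.]
```

Because transformations like grad operate on ordinary Python functions, existing NumPy-style code often needs only the import swapped to jax.numpy.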

The below code demonstrates the jit function's optimization through fusion.
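One possible sketch of such a demonstration (the selu function here is an illustrative stand-in, not the article's original code): jax.jit compiles the function with XLA, which can fuse the chain of elementwise operations into a single kernel instead of executing them one by one.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # a chain of elementwise operations that XLA can fuse into one kernel
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jit compiles the function on first call; later calls reuse the compiled code
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
print(selu_jit(x))
```

The first invocation triggers tracing and compilation; subsequent calls on arrays of the same shape and dtype run the cached, fused kernel.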

The below code demonstrates the pmap function's parallelization for matrix multiplication.
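A minimal sketch of such a pmap-parallelized matrix multiplication (the shapes here are illustrative): pmap maps a function over the leading axis of its inputs, placing one slice on each available device, so the leading dimension must equal jax.device_count().

```python
import jax
import jax.numpy as jnp

# pmap replicates the function across devices; each device multiplies one
# pair of matrices from the leading axis
pmatmul = jax.pmap(jnp.matmul)

n = jax.device_count()
xs = jnp.ones((n, 4, 4))
ys = jnp.ones((n, 4, 4))

out = pmatmul(xs, ys)
print(out.shape)  # (n, 4, 4), one 4x4 product per device
```

On a machine with a single device this still runs (with n = 1); on multi-device hardware the per-device products are computed in parallel.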

[Illustration video of vectorized addition]