Glossary term: JAX
Infrastructure and Serving
An array computing library that brings together XLA (Accelerated Linear Algebra) and automatic differentiation for high-performance numerical computing. JAX provides a simple, powerful API for writing accelerated numerical code with composable transformations, including:
grad (automatic differentiation)
jit (just-in-time compilation)
vmap (automatic vectorization or batching)
pmap (parallelization)
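The transformations above compose, because each one takes a function and returns a new function. The sketch below is a minimal illustration built on a hypothetical linear model `y = w * x` with a mean-squared-error loss (the model, data, and names `loss`, `fast_grad`, `batched_grad` are made up for this example). It takes the gradient with `jax.grad`, compiles it with `jax.jit`, and evaluates it for several candidate weights at once with `jax.vmap`.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss: mean squared error of the linear model y = w * x.
def loss(w, x, y):
    return jnp.mean((w * x - y) ** 2)

# grad: derivative of loss with respect to its first argument, w.
grad_loss = jax.grad(loss)

# jit: compile the gradient function with XLA.
fast_grad = jax.jit(grad_loss)

# vmap: map over a batch of weights (axis 0) while sharing x and y (None).
batched_grad = jax.vmap(fast_grad, in_axes=(0, None, None))

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])  # data generated with w = 2
ws = jnp.array([0.0, 1.0, 2.0, 3.0])

grads = batched_grad(ws, x, y)
print(grads)  # approximately [-18.67, -9.33, 0.0, 9.33]
```

Analytically the gradient here is `2 * (w - 2) * mean(x**2) = 28/3 * (w - 2)`, so it vanishes at the true weight `w = 2`, which is a quick way to sanity-check that the composed transformation is doing what it claims.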
JAX is a language for expressing and composing transformations of numerical code, analogous to Python's NumPy library but much larger in scope. (In fact, jax.numpy is an entirely rewritten implementation of the NumPy API. It is largely, though not completely, equivalent to NumPy; for example, JAX arrays are immutable.)
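A small sketch of how jax.numpy mirrors NumPy while differing in one visible way: most familiar function names carry over, but JAX arrays cannot be modified in place. The array contents are arbitrary illustration values; the `.at[...].set(...)` method is JAX's functional replacement for item assignment and returns an updated copy.

```python
import numpy as np
import jax.numpy as jnp

a_np = np.arange(5.0)
a_jx = jnp.arange(5.0)

# The same function names work in both libraries; both sums are 30.
print(np.sum(a_np ** 2), jnp.sum(a_jx ** 2))

# NumPy arrays can be updated in place.
a_np[0] = 10.0

# JAX arrays are immutable: item assignment raises a TypeError.
try:
    a_jx[0] = 10.0
except TypeError as err:
    print("in-place update rejected:", type(err).__name__)

# The functional alternative returns a new array and leaves a_jx unchanged.
b_jx = a_jx.at[0].set(10.0)
print(a_jx[0], b_jx[0])
```

Immutability is what lets transformations such as `jit` and `grad` treat array code as pure functions.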
JAX is particularly well-suited for speeding up many machine learning tasks by transforming the models and data into a form suitable for parallelism across GPU and TPU accelerator chips.
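As a rough illustration of multi-device parallelism, the sketch below uses `jax.pmap` to run the same function on one shard of data per device and `jax.lax.psum` to combine the partial results. The data and the axis name `"devices"` are arbitrary choices for this example. On a machine with a single CPU device it still runs, with a leading axis of size 1. Newer JAX releases steer new code toward `jax.jit` with sharding annotations, so treat this as a sketch of the idea rather than the recommended production pattern.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # 1 on a plain CPU-only host

# pmap requires the leading axis to match the number of devices.
data = jnp.arange(n_dev * 4, dtype=jnp.float32).reshape(n_dev, 4)

def shard_sum(x):
    # Each device sums its own shard...
    local = jnp.sum(x)
    # ...then psum adds the partial sums across all devices.
    return jax.lax.psum(local, axis_name="devices")

total = jax.pmap(shard_sum, axis_name="devices")(data)
print(total)  # the global sum, replicated once per device
```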
Flax, Optax, Pax, and many other libraries are built on the JAX infrastructure.
Examples (created for this library):
A research lab uses JAX for its high-performance training pipelines because of its strong function transformations and TPU support.
An ML platform team adopts JAX to express research models compactly while still scaling to TPU pods in production.
A startup uses JAX for experimentation on its largest model because pmap and pjit make multi-device parallelism explicit and tractable.
Definition source: Google for Developers Machine Learning Glossary | Creative Commons Attribution 4.0 License