Glossary term
Infrastructure and Serving
A JAX function that partitions a computation to run across multiple accelerator chips. The user passes a function to pjit, which returns a function with equivalent semantics, compiled into an XLA computation that runs across multiple devices (such as GPUs or TPU cores).
pjit relies on the SPMD partitioner, so users can shard computations without rewriting them.
As of March 2023, pjit has been merged with jit. Refer to Distributed arrays and automatic parallelization for more details.
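Because pjit now lives in jit, a sharded call can be sketched with jax.jit and explicit shardings. The following is a minimal illustration rather than a canonical recipe: the function `layer`, the mesh axis name "data", and the array shapes are assumptions made up for this example, and it runs on whatever devices are available (possibly a single CPU).

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1-D device mesh over all available devices (may be just one CPU).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))  # "data" is an illustrative axis name

# Split the batch axis of x across the "data" mesh axis; replicate w everywhere.
x_sharding = NamedSharding(mesh, P("data", None))
w_sharding = NamedSharding(mesh, P(None, None))

def layer(x, w):
    # Ordinary single-device code; nothing here mentions devices or sharding.
    return jnp.tanh(x @ w)

# jax.jit with in_shardings/out_shardings plays the role pjit used to play.
sharded_layer = jax.jit(
    layer,
    in_shardings=(x_sharding, w_sharding),
    out_shardings=x_sharding,
)

batch = 8 * len(devices)  # keep the batch divisible by the number of devices
x = jax.device_put(jnp.ones((batch, 4)), x_sharding)
w = jax.device_put(jnp.ones((4, 2)), w_sharding)

y = sharded_layer(x, w)      # runs across the mesh
y_reference = layer(x, w)    # same function, no sharding annotations
```

The wrapped function is called exactly like the original, and its result should match the unsharded call up to floating-point tolerance, which is what "equivalent semantics" means in the definition above.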
Created for this library
A research engineer uses pjit in JAX to declare how parameters and activations are sharded across a device mesh.
An ML platform team uses pjit to scale training of very large transformer models across many TPU chips.
A research team uses pjit to express its partitioning strategy cleanly in JAX while still iterating quickly during development.
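As a rough sketch of the first use case, parameters and activations can each be given their own sharding on a two-axis device mesh. Everything specific here is hypothetical: the axis names "data" and "model", the helper `spec`, the two-layer `mlp`, and the shapes. The "model" axis has size 1 so the example also runs on a single device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 2-D mesh: every device on the "data" axis, "model" axis of size 1.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

def spec(*axes):
    # Illustrative helper: build a NamedSharding on the mesh above.
    return NamedSharding(mesh, P(*axes))

def mlp(params, x):
    h = jnp.maximum(x @ params["w1"], 0.0)
    # Constrain the hidden activation: batch over "data", features over "model".
    h = jax.lax.with_sharding_constraint(h, spec("data", "model"))
    return h @ params["w2"]

# Parameters are split along the hidden dimension; inputs along the batch.
param_shardings = {"w1": spec(None, "model"), "w2": spec("model", None)}
sharded_mlp = jax.jit(
    mlp,
    in_shardings=(param_shardings, spec("data", None)),
    out_shardings=spec("data", None),
)

n_data = devices.shape[0]
params = {"w1": jnp.full((4, 8), 0.5), "w2": jnp.full((8, 2), 0.25)}
params = jax.device_put(params, param_shardings)
x = jax.device_put(jnp.ones((2 * n_data, 4)), spec("data", None))

out = sharded_mlp(params, x)
```

Changing the partitioning strategy means editing only the PartitionSpec arguments and the mesh shape; the body of `mlp` stays as written.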
Definition source: Google for Developers Machine Learning Glossary | Creative Commons Attribution 4.0 License