, Associate Applied Machine Learning Specialist, Vector Institute
, Software Engineer, NVIDIA
, Senior Research Scientist, Autodesk Research
, Researcher, Google Research
, Senior Research Scientist, InstaDeep
JAX is a high-performance framework for Python and NumPy programs. It achieves speed through just-in-time compilation with XLA, supports automatic differentiation, and provides composable transformations for automatic parallelization, making it one of the easiest-to-use frameworks for large-scale neural network training.
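As a minimal sketch of what "composable transformations" means in practice (assuming a standard JAX install; the toy loss function below is illustrative, not taken from the session):

# Composable JAX transformations: grad builds the derivative,
# jit compiles it with XLA, vmap vectorizes it over a batch.
import jax
import jax.numpy as jnp

def loss(params, x):
    # Toy quadratic loss standing in for a real model.
    w, b = params
    pred = w * x + b
    return jnp.mean((pred - 1.0) ** 2)

grad_fn = jax.jit(jax.grad(loss))
batched_grads = jax.vmap(grad_fn, in_axes=(None, 0))

params = (2.0, -1.0)
xs = jnp.arange(4.0)
print(batched_grads(params, xs))  # per-example gradients, compiled once

The same composition pattern extends to multi-device parallelization, which is where the GPU story in this session picks up.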
Learn about the current state of JAX on GPUs and why JAX and GPUs are such a great pairing. We'll cover how to contribute to GPU support in JAX, share the latest performance results and capabilities for generative AI workloads such as large-scale language models and diffusion models, discuss scientific computing applications, and highlight the latest JAX features for the new Hopper architecture. Where possible, we'll also offer glimpses into what's in store for JAX support on GPUs.
If you're using JAX today or considering it, want to contribute to JAX, or simply want to stay abreast of the latest developments in the newest accelerated framework on the block, this is an important session for you.