Accelerating Scientific Computing with JAX-LBM

Mehdi Ataei

Hesam Salehipour


Every time you get in a modern airplane, sit in a new car, or witness a SpaceX launch, chances are that those designs were perfected using virtual testing based on Computational Fluid Dynamics (CFD). CFD is also used extensively in climate modeling and weather simulation, allowing scientists to predict storms and cyclones. Although CFD is a wonder of the modern computing age, it has deep historical roots that trace back to the pioneering work of Archimedes, Da Vinci, Lord Kelvin, and Navier and Stokes.

The mathematics developed around the study of fluid mechanics is notoriously complex. Solving the governing Navier-Stokes equations analytically remains so challenging that a Millennium Prize is still on offer for proving whether smooth solutions to these fundamental equations of nature always exist in three dimensions. With CFD and modern digital computers, you can instead explore flow characteristics to your heart’s content with the click of a mouse. The catch is that these computations are expensive, especially when a design goes through many iterations. Tuning CFD algorithms can also be daunting, because they have traditionally been implemented in low-level programming languages that demand significant effort to achieve adequate performance and efficiency. While the resulting algorithms are undeniably powerful, developing and improving them is laborious and, more importantly, requires expertise in high-performance computing (HPC).

Graphics Processing Units (GPUs), originally developed to accelerate computer graphics and games, have become ubiquitous in recent years. In the age of the cloud, we now have access to multiple nodes, each containing multiple GPUs and CPUs, in a distributed configuration. Leveraging these computing environments for CFD algorithms has traditionally been a difficult and painstaking undertaking, with few ready-made solutions available. JAX-LBM aims to change this paradigm completely.

With all the excitement and focus on AI and machine learning (ML), massive investment has gone into improving computation for ML. Companies like Google and NVIDIA have created a host of libraries and languages that are now also useful in other domains, such as simulation. JAX is one such technology: it marries an optimizing compiler (XLA) that accelerates linear algebra with automatic differentiation. JAX also provides an easy-to-use, NumPy-like environment for algorithm development, in which user code is automatically translated into highly efficient parallel computations for single-node or distributed cluster architectures. The Lattice Boltzmann Method (LBM), for its part, is a CFD approach that is well suited to massive scalability on modern, highly parallel hardware with many computing elements. At each grid point, LBM performs local calculations that depend only on information stored at that point and its immediate neighbors, which reduces the need to move data between computing elements. Implementing LBM in JAX therefore seemed a natural fit, combining an algorithm that parallelizes well with a highly extensible and performant interface for CFD research.
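To illustrate the locality described above, here is a minimal single-device sketch of one LBM time step in JAX. It assumes a 2D D2Q9 lattice, the simple BGK collision operator, and a periodic box; the KBC collision model used for the DrivAer and airfoil simulations is more involved, and the names `lbm_step`, `equilibrium`, and `macroscopic` are hypothetical, not JAX-LBM's API.

```python
import jax
import jax.numpy as jnp

# Sketch assumption: D2Q9 lattice + BGK collision in a periodic box
# (illustrative only, not JAX-LBM's KBC solver).
VELOCITIES = [(0, 0), (1, 0), (0, 1), (-1, 0), (0, -1),
              (1, 1), (-1, 1), (-1, -1), (1, -1)]
C = jnp.array(VELOCITIES, dtype=jnp.float32)                  # (9, 2)
W = jnp.array([4 / 9] + [1 / 9] * 4 + [1 / 36] * 4,
              dtype=jnp.float32)                              # (9,)


def macroscopic(f):
    """Density (nx, ny) and velocity (nx, ny, 2) from populations f (nx, ny, 9)."""
    rho = f.sum(axis=-1)
    u = jnp.einsum("xyq,qd->xyd", f, C) / rho[..., None]
    return rho, u


def equilibrium(rho, u):
    """Second-order equilibrium; uses only values stored at the same node."""
    cu = jnp.einsum("qd,xyd->xyq", C, u)
    usq = jnp.sum(u * u, axis=-1, keepdims=True)
    return rho[..., None] * W * (1 + 3 * cu + 4.5 * cu**2 - 1.5 * usq)


@jax.jit
def lbm_step(f, omega):
    # Collision: purely local, every node is updated independently.
    rho, u = macroscopic(f)
    f_post = f - omega * (f - equilibrium(rho, u))
    # Streaming: population q moves one cell along its lattice velocity,
    # so each node only needs data from its immediate neighbors.
    streamed = [
        jnp.roll(f_post[..., q], shift=(cx, cy), axis=(0, 1))
        for q, (cx, cy) in enumerate(VELOCITIES)
    ]
    return jnp.stack(streamed, axis=-1)
```

A simulation would initialize `f = equilibrium(rho0, u0)` from an initial density and velocity field and then call `f = lbm_step(f, omega)` in a loop. Every operation is either node-local (collision) or a shift by one cell (streaming), which is what makes LBM map so well onto many parallel computing elements.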

Within Autodesk Research, the Simulation, Optimization, and Systems (SOS) group has been leading a collaboration with NVIDIA to explore the use of JAX for LBM simulations on NVIDIA’s latest GPUs. In a recent GTC presentation, Autodesk researchers demonstrated a series of innovative techniques to improve the scalability of their LBM implementation. Specifically, they used explicit collectives provided by JAX for the “streaming” step of the LBM simulation, in which data moves between neighboring grid points. This gave them finer control over how the simulation is parallelized across multiple GPUs, while JAX’s recently introduced shard_map Single Program Multiple Data (SPMD) parallelism let them keep the NumPy-like interface for the rest of the codebase. This approach frees domain experts, who are typically not HPC experts, from implementing the parallelization themselves and lets them work directly with NumPy-like JAX arrays, shortening the development process from several months to just a few weeks.
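The post does not show JAX-LBM's streaming code, so the following is only a hypothetical sketch of the general idea: a 2D field for one population is split along x across devices with `shard_map`, and the single boundary row that must cross devices is exchanged explicitly with the `jax.lax.ppermute` collective (our assumed choice of collective). The names `stream_local` and `make_sharded_stream` are illustrative, and the `XLA_FLAGS` line only emulates four devices on a CPU for demonstration.

```python
import os

# Demo only: emulate 4 devices on a CPU-only machine (must precede the JAX import).
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4")

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

devices = jax.devices()
n_dev = len(devices)
mesh = Mesh(np.array(devices), axis_names=("x",))

# Ring permutations (periodic domain): send to the right or to the left neighbor.
to_right = [(i, (i + 1) % n_dev) for i in range(n_dev)]
to_left = [(i, (i - 1) % n_dev) for i in range(n_dev)]


def stream_local(block, cx, cy):
    """Hypothetical helper: shift one population by (cx, cy) on an x-sharded block.

    Only one boundary row crosses devices; it is exchanged with ppermute.
    """
    if cx == 1:
        # Our new first row is the left neighbor's last row.
        halo = jax.lax.ppermute(block[-1:], axis_name="x", perm=to_right)
        block = jnp.concatenate([halo, block[:-1]], axis=0)
    elif cx == -1:
        # Our new last row is the right neighbor's first row.
        halo = jax.lax.ppermute(block[:1], axis_name="x", perm=to_left)
        block = jnp.concatenate([block[1:], halo], axis=0)
    # The y axis is not sharded, so a local roll is enough.
    return jnp.roll(block, cy, axis=1)


def make_sharded_stream(cx, cy):
    """Hypothetical helper: jit-compiled, x-sharded streaming for one direction."""
    return jax.jit(shard_map(
        partial(stream_local, cx=cx, cy=cy),
        mesh=mesh, in_specs=P("x", None), out_specs=P("x", None)))
```

The result matches a global `jnp.roll` of the whole field, but each device only ever receives one row from a neighbor. In a 3D solver sharded along one axis, the same pattern would move a single boundary plane per step, which is the kind of explicit control over communication the presentation describes.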

DrivAer model in a wind tunnel using KBC Lattice Boltzmann Simulation with approx. 317 million voxels

Flow over a NACA airfoil using KBC Lattice Boltzmann Simulation with approx. 100 million voxels


To evaluate the effectiveness of these techniques, the team ran weak and strong scaling tests. Weak scaling tests increase the size of the simulation together with the number of processors, keeping the workload per processor constant. Strong scaling tests keep the total problem size fixed and divide it among an increasing number of processors. The results show excellent weak scaling and high efficiency, enabling simulations with billions of voxels on a single DGX machine.
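The two kinds of tests measure different notions of parallel efficiency. The helpers in this sketch are generic definitions, not the team's benchmarking code; the MLUPS throughput metric (million lattice updates per second) is a common LBM convention that we add as an assumption, since the post does not say which metric was reported.

```python
def weak_scaling_efficiency(t_1: float, t_n: float) -> float:
    """Weak scaling: the problem grows with the number of processors, so the
    ideal wall-clock time t_n stays equal to the single-processor time t_1."""
    return t_1 / t_n


def strong_scaling_efficiency(t_1: float, t_n: float, n: int) -> float:
    """Strong scaling: a fixed problem is split over n processors, so the
    ideal wall-clock time is t_1 / n."""
    return t_1 / (n * t_n)


def mlups(num_voxels: int, num_steps: int, seconds: float) -> float:
    """Million lattice-node updates per second (assumed throughput metric)."""
    return num_voxels * num_steps / (seconds * 1e6)
```

For instance, with hypothetical timings of 8 s on one processor and 2 s on eight, the strong-scaling efficiency is 8 / (8 × 2) = 0.5, meaning half of the ideal speedup was achieved.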

The benefit of these improvements is clear: the ability to run more realistic and accurate simulations. As a showcase, the team simulated the DrivAer vehicle model and a NACA airfoil in a wind tunnel, using hundreds of millions of voxels. Remarkably, these intricate simulations can be completed in just a few minutes on a single DGX machine. Such rapid analysis is essential in the design process, as it lets engineers evaluate the aerodynamics of their designs in a controlled setting and iterate quickly to gain deeper insight.

Looking ahead, the fusion of JAX and LBM, or JAX-LBM, is expected to pave the way for ground-breaking research in generative design of multi-physics systems and physics-based AI, because it integrates seamlessly with the wealth of optimization and machine learning tools built on JAX, such as Flax and Haiku. This integration further expands the capabilities of LBM simulations and opens CFD research to advanced machine learning techniques that have yet to be explored.
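One reason this integration matters is that JAX can differentiate through an entire simulation. The sketch below is a hypothetical toy, not JAX-LBM: it swaps in a 1D D1Q3 lattice Boltzmann model of pure diffusion on a periodic line and uses `jax.grad` to differentiate a scalar objective through the whole time loop (`jax.lax.scan`) with respect to the relaxation rate ω. The names `simulate` and `loss` are illustrative.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy: D1Q3 BGK model of pure diffusion (not the KBC solver in the post).
W = jnp.array([2 / 3, 1 / 6, 1 / 6])  # weights for lattice velocities 0, +1, -1
SHIFTS = (0, 1, -1)


def simulate(omega, rho0, num_steps=20):
    """Run num_steps collide-and-stream updates and return the final density."""
    f0 = W[:, None] * rho0[None, :]  # start from equilibrium, shape (3, nx)

    def step(f, _):
        rho = f.sum(axis=0)
        f = f - omega * (f - W[:, None] * rho)                             # collision
        f = jnp.stack([jnp.roll(f[q], s) for q, s in enumerate(SHIFTS)])  # streaming
        return f, None

    f_final, _ = jax.lax.scan(step, f0, None, length=num_steps)
    return f_final.sum(axis=0)


def loss(omega, rho0):
    # Toy objective: how concentrated the density remains after the simulation.
    rho = simulate(omega, rho0)
    return jnp.sum(rho ** 2)


# Gradient of the objective with respect to omega, through every time step.
grad_fn = jax.jit(jax.grad(loss))
```

Starting from a single spike, e.g. `rho0 = jnp.zeros(64).at[32].set(1.0)`, `grad_fn(1.0, rho0)` returns a positive value: a larger ω means less diffusion, so the density stays more concentrated. In a design loop, the same mechanism could supply gradients to an optimizer with respect to many parameters rather than a single scalar.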

Mehdi Ataei is a Senior Research Scientist and Hesam Salehipour is a Principal Computational Physics Research Scientist at Autodesk. 
