
Installing Jax with GPU Support

I’ve been struggling to install Jax with GPU support. Something seems to go wrong every time I try: sometimes Jax can’t find the GPU, other times some dynamic libraries are missing. More frustratingly, things sometimes break when I install other libraries such as Optax. I think I’ve finally got it all working, so here’s how I did it.

Basic environment

I use conda virtual environments for my Python setups because they keep my workstation configuration clean, especially when it comes to working with GPUs. conda lets me manage the required CUDA toolkits and libraries without making permanent installations on my workstation. In this tutorial, I will install Jax in a conda virtual environment.

conda create -n jax python=3.9 pip

Activate the virtual environment using conda activate jax and proceed with the following steps.
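Before installing anything, it can help to confirm that the `python` you are running really belongs to the activated environment. The sketch below assumes that `conda activate jax` exports `CONDA_PREFIX` (the environment's path) and `CONDA_DEFAULT_ENV` (its name, here `jax`); the function name `running_in_conda_env` is purely illustrative.

```python
import os
import sys


def running_in_conda_env(expected_name="jax"):
    """Return True if this interpreter belongs to the activated conda env.

    Assumption: `conda activate` exports CONDA_PREFIX (the env's path) and
    CONDA_DEFAULT_ENV (its name); sys.prefix is this interpreter's prefix.
    """
    prefix = os.environ.get("CONDA_PREFIX")
    if prefix is None or os.environ.get("CONDA_DEFAULT_ENV") != expected_name:
        return False
    # Compare resolved paths so symlinked prefixes still match.
    return os.path.realpath(sys.prefix) == os.path.realpath(prefix)


if __name__ == "__main__":
    print("Python", sys.version.split()[0], "(3.9 expected)")
    print("inside the 'jax' env:", running_in_conda_env())
```

If this prints `False`, a system Python (or a different environment) is probably first on your `PATH`, and every later `pip install` would land in the wrong place.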

1. Installing nvcc

According to the Jax installation guide, Jax requires ptxas, which is part of the cuda-nvcc package on conda. On my workstation, the GPU driver version is 510.108.03 and the corresponding CUDA version is 11.6. To keep everything consistent, I will install the CUDA 11.6 build of cuda-nvcc using conda as follows:

conda install cuda-nvcc -c "nvidia/label/cuda-11.6.2"

For other CUDA versions, check the conda page for the available labels.
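To check that the ptxas you just installed matches the CUDA version your driver reports, you can compare the `CUDA Version: X.Y` field in the `nvidia-smi` header with the `release X.Y` line printed by `ptxas --version`. This is a minimal sketch that assumes both tools are on `PATH` (for ptxas, that the conda package put it in the active environment's `bin`) and that their output follows those formats; the helper names are hypothetical.

```python
import re
import shutil
import subprocess


def driver_cuda_version(nvidia_smi_output):
    """Extract (major, minor) from the 'CUDA Version: X.Y' field of nvidia-smi."""
    m = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_output)
    return (int(m.group(1)), int(m.group(2))) if m else None


def ptxas_cuda_version(ptxas_version_output):
    """Extract (major, minor) from the 'release X.Y' line of `ptxas --version`."""
    m = re.search(r"release\s+(\d+)\.(\d+)", ptxas_version_output)
    return (int(m.group(1)), int(m.group(2))) if m else None


def run(cmd):
    """Return a command's stdout if the executable is on PATH, else None."""
    if shutil.which(cmd[0]) is None:
        return None
    return subprocess.run(cmd, capture_output=True, text=True, check=False).stdout


if __name__ == "__main__":
    smi = run(["nvidia-smi"])
    ptx = run(["ptxas", "--version"])
    driver = driver_cuda_version(smi) if smi else None
    toolkit = ptxas_cuda_version(ptx) if ptx else None
    print("driver reports CUDA:", driver)
    print("ptxas comes from CUDA:", toolkit)
    # Flag any mismatch; the setup above keeps both at 11.6.
    if driver and toolkit and driver != toolkit:
        print("note: ptxas and driver CUDA versions differ")
```

On the workstation described above, both lines would be expected to show `(11, 6)`.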

Note that the installation guide installs Jax and cuda-nvcc together using conda. That didn’t work for me; I suspect it has something to do with how conda resolves package dependencies.

I will be installing Jax using pip at a later step.

2. Installing cudnn

After all, why not?

conda install cudnn=8.2 -c conda-forge

Jax needs cuDNN as well, and the version installed here (8.2) determines which Jax build to pick in the next step.
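A quick way to see whether the cuDNN shared libraries actually landed in the environment is to look for `libcudnn*.so*` under `$CONDA_PREFIX/lib`. This sketch assumes Linux library naming and that the conda-forge package installs into the environment's `lib` directory; `find_cudnn_libs` is a hypothetical helper.

```python
import glob
import os


def find_cudnn_libs(prefix=None):
    """List libcudnn shared libraries under <prefix>/lib (Linux naming assumed)."""
    prefix = prefix or os.environ.get("CONDA_PREFIX", "")
    if not prefix:
        return []
    pattern = os.path.join(prefix, "lib", "libcudnn*.so*")
    return sorted(os.path.basename(p) for p in glob.glob(pattern))


if __name__ == "__main__":
    libs = find_cudnn_libs()
    print("\n".join(libs) if libs else "no libcudnn found in $CONDA_PREFIX/lib")
```

An empty result is one likely explanation for the "missing dynamic library" errors mentioned at the start.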

3. Installing Jax using pip

I have found that using pip works better for me.

pip install "jax[cuda11_cudnn82]" -f 

Note that the extra in the command, cuda11_cudnn82, matches the CUDA (11) and cuDNN (8.2) versions installed in the virtual environment.
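If you want to derive the extra from what is actually installed rather than typing it by hand, one option is to read the `CUDNN_MAJOR`/`CUDNN_MINOR` defines from `cudnn_version.h`. This is a sketch under two assumptions: that the conda cuDNN package ships that header in `$CONDA_PREFIX/include`, and that the extra follows the `cuda{major}_cudnn{major}{minor}` pattern of `cuda11_cudnn82`. Other Jax releases have used different extra names, so check the installation guide for the ones your release offers. Both function names are hypothetical.

```python
import os
import re


def cudnn_version_from_header(header_text):
    """Read (major, minor) from the CUDNN_MAJOR/CUDNN_MINOR #defines."""
    fields = dict(re.findall(r"#define\s+CUDNN_(MAJOR|MINOR)\s+(\d+)", header_text))
    if "MAJOR" not in fields or "MINOR" not in fields:
        return None
    return int(fields["MAJOR"]), int(fields["MINOR"])


def jax_cuda_extra(cuda_major, cudnn_major, cudnn_minor):
    """Build an extra in the style of 'cuda11_cudnn82' (pattern assumed)."""
    return f"cuda{cuda_major}_cudnn{cudnn_major}{cudnn_minor}"


if __name__ == "__main__":
    path = os.path.join(os.environ.get("CONDA_PREFIX", ""), "include", "cudnn_version.h")
    if os.path.exists(path):
        with open(path) as f:
            version = cudnn_version_from_header(f.read())
        if version:
            # CUDA major version hard-coded to 11 to match cuda-nvcc above.
            print("pip extra:", jax_cuda_extra(11, *version))
    else:
        print("cudnn_version.h not found at", path)
```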

4. Installing Optax using pip

Let’s face it, if you are going to do machine learning work, you’re most likely going to use optimizers. Might as well install Optax, a popular gradient processing and optimization library for Jax.

pip install optax
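One plausible way things "get messed up" here is that Optax depends on jax and jaxlib, so pip may swap the GPU jaxlib for a CPU build while resolving them. The sketch below records the installed versions so you can compare them before and after `pip install optax`. It assumes the CUDA jaxlib wheels from the `-f` index carry a local version tag containing `cuda` (for example `+cuda11.cudnn82`); newer Jax releases package GPU support differently, so treat this as specific to this kind of setup. The helper names are hypothetical.

```python
from importlib import metadata


def is_cuda_build(version):
    """True if a version string has a local tag containing 'cuda' (assumed scheme)."""
    _, _, local = version.partition("+")
    return "cuda" in local


def check_jaxlib():
    """Return (jax version, jaxlib version, is CUDA build), or None if missing."""
    try:
        jax_version = metadata.version("jax")
        jaxlib_version = metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None
    return jax_version, jaxlib_version, is_cuda_build(jaxlib_version)


if __name__ == "__main__":
    result = check_jaxlib()
    if result is None:
        print("jax/jaxlib not installed in this environment")
    else:
        jax_version, jaxlib_version, cuda = result
        print(f"jax {jax_version}, jaxlib {jaxlib_version}, CUDA build: {cuda}")
```

Running it once before and once after installing Optax makes a silent jaxlib replacement easy to spot.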

That’s it! It should work and detect the GPU properly.
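To confirm that, you can ask Jax which backend and devices it sees. This minimal sketch uses `jax.default_backend()` and `jax.devices()`; the small matrix multiply is only there to make Jax compile and run something on the GPU. The function name `report_devices` is illustrative.

```python
def report_devices():
    """Print Jax's default backend and devices; return any GPU devices found."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("jax is not importable in this environment")
        return []
    print("default backend:", jax.default_backend())
    devices = jax.devices()
    for d in devices:
        print(d.platform, d)
    gpus = [d for d in devices if d.platform == "gpu"]
    if gpus:
        # A small matmul forces compilation (via ptxas) and execution on the GPU.
        x = jnp.ones((1000, 1000))
        print("matmul sum:", float((x @ x).sum()))
    return gpus


if __name__ == "__main__":
    report_devices()
```

If the backend prints as `cpu`, Jax fell back to the CPU; the earlier checks for ptxas, cuDNN and the jaxlib build are the places to look first.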