I’ve been struggling with installing Jax with GPU support. Something seems to go wrong each time I try. Sometimes it can’t find the GPU. Other times some dynamic libraries will be missing. More frustratingly, sometimes things get messed up when I install new libraries such as Optax. Finally, I think I’ve got it all working. So here’s how I did it.
I use conda virtual environments for Python because they keep my workstation configuration clean, especially when working with GPUs. conda manages the required CUDA toolkits and libraries for me, so I don’t have to make permanent installations on my workstation. In this tutorial, I will install Jax in a conda virtual environment, created with:
conda create -n jax python=3.9 pip
Activate the virtual environment with
conda activate jax
and proceed with the following steps.
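Before going further, it can help to confirm that the right environment is active. The following is a minimal Python sketch, not part of the original setup: check_env is a hypothetical helper, and it assumes that conda activate exports CONDA_DEFAULT_ENV with the environment name.

```python
import os
import sys


def check_env(expected_env, environ=None, version=None):
    """Return a list of problems; an empty list means the environment looks right."""
    environ = os.environ if environ is None else environ
    version = sys.version_info if version is None else version
    problems = []
    # Assumption: `conda activate <name>` exports CONDA_DEFAULT_ENV=<name>.
    active = environ.get("CONDA_DEFAULT_ENV")
    if active != expected_env:
        problems.append(f"active conda env is {active!r}, expected {expected_env!r}")
    # The environment in this post was created with python=3.9.
    if tuple(version[:2]) != (3, 9):
        problems.append(f"Python is {version[0]}.{version[1]}, expected 3.9")
    return problems


if __name__ == "__main__":
    issues = check_env("jax")
    for issue in issues:
        print("warning:", issue)
    if not issues:
        print("conda env 'jax' with Python 3.9 is active")
```

Run it with the environment’s own python; it only prints warnings and changes nothing.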
1. Installing nvcc
According to the Jax installation guide, Jax requires ptxas, which is part of the cuda-nvcc package on conda. On my workstation, the GPU driver version is 510.108.03, which supports CUDA 11.6. To keep everything consistent, I install the CUDA 11.6 build of cuda-nvcc from NVIDIA’s conda channel:
conda install cuda-nvcc -c "nvidia/label/cuda-11.6.2"
For other CUDA versions, check the labels available on the nvidia channel’s conda page.
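The driver and CUDA versions above come from the banner that nvidia-smi prints. To read them programmatically before picking a cuda-nvcc label, here is a minimal Python sketch; the regular expression assumes the banner layout shown in the comment, which may differ between driver releases, and parse_smi / read_smi are hypothetical helper names. Keep in mind that the CUDA version nvidia-smi reports is the newest one the driver supports, not an installed toolkit.

```python
import re
import shutil
import subprocess

# Assumed banner layout, e.g.
# "| NVIDIA-SMI 510.108.03   Driver Version: 510.108.03   CUDA Version: 11.6     |"
_BANNER = re.compile(r"Driver Version:\s*([\d.]+)\s+CUDA Version:\s*([\d.]+)")


def parse_smi(text):
    """Return (driver_version, cuda_version) parsed from nvidia-smi output, or None."""
    match = _BANNER.search(text)
    return (match.group(1), match.group(2)) if match else None


def read_smi():
    """Run nvidia-smi and return its stdout, or None if it is not on PATH."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
    return result.stdout


if __name__ == "__main__":
    output = read_smi()
    if output is None:
        print("nvidia-smi not found; is the NVIDIA driver installed?")
    else:
        print("driver, CUDA:", parse_smi(output))
```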
Note that the installation guide installs Jax and cuda-nvcc together with conda. That didn’t work for me; I suspect it has to do with how conda resolves the package dependencies. Instead, I install Jax with pip in a later step.
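Since Jax relies on the ptxas that comes with cuda-nvcc, it is worth checking that the ptxas on PATH is the one inside the conda environment rather than, say, a system-wide CUDA install. A minimal sketch, assuming conda activate sets CONDA_PREFIX to the environment’s root directory; ptxas_in_env is a hypothetical helper:

```python
import os
import shutil


def ptxas_in_env(which=shutil.which, environ=None):
    """Return the ptxas path if it lives inside the active conda env, else None."""
    environ = os.environ if environ is None else environ
    path = which("ptxas")
    # Assumption: `conda activate` sets CONDA_PREFIX to the env root.
    prefix = environ.get("CONDA_PREFIX")
    if not path or not prefix:
        return None
    # Append the separator so ".../envs/jax2" does not match ".../envs/jax".
    env_root = os.path.abspath(prefix) + os.sep
    return path if os.path.abspath(path).startswith(env_root) else None


if __name__ == "__main__":
    found = ptxas_in_env()
    print(found or "ptxas is not on PATH inside the active conda env")
```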
2. Installing cudnn
Jax also needs cuDNN, so install it into the same environment from conda-forge, pinned to version 8.2:
conda install cudnn=8.2 -c conda-forge
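Because the pip extra in the next step has to match this cuDNN version, it can be useful to confirm what conda actually installed. A minimal sketch that parses the output of conda list --json; it assumes each entry in that JSON array carries "name" and "version" keys, and package_version / same_minor are hypothetical helpers:

```python
import json
import shutil
import subprocess


def package_version(conda_list_json, name):
    """Return the version of `name` from `conda list --json` output, or None."""
    for pkg in json.loads(conda_list_json):
        if pkg.get("name") == name:
            return pkg.get("version")
    return None


def same_minor(version, wanted):
    """True if `version` has the same major.minor as `wanted` (8.2.x vs 8.2)."""
    return version.split(".")[:2] == wanted.split(".")[:2]


if __name__ == "__main__":
    if shutil.which("conda") is None:
        print("conda not found on PATH")
    else:
        out = subprocess.run(["conda", "list", "--json"],
                             capture_output=True, text=True, check=True).stdout
        version = package_version(out, "cudnn")
        print("cudnn", version, "matches 8.2:", bool(version) and same_minor(version, "8.2"))
```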
3. Installing Jax using pip
I have found that using pip works better for me.
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Note that I have matched the cuDNN version in the extra (cudnn82) to the cuDNN 8.2 installed in the virtual environment.
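Since the extra has to line up with the installed versions, it can be derived from them instead of typed by hand. The Python sketch below rebuilds the cuda11_cudnn82 string from a CUDA and a cuDNN version; jax_cuda_extra and pip_install_args are hypothetical helpers, and they only reproduce the cudaXX_cudnnYZ pattern used in this post. The set of extras has changed between Jax releases, so check the installation guide for the names your release accepts.

```python
import shlex

JAX_CUDA_RELEASES = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def jax_cuda_extra(cuda_version, cudnn_version):
    """Rebuild an extra like 'cuda11_cudnn82' from ('11.6', '8.2.1').

    Hypothetical helper: assumes the cudaXX_cudnnYZ naming used in this post,
    which is not guaranteed for other Jax releases.
    """
    cuda_major = cuda_version.split(".")[0]
    cudnn_major, cudnn_minor = cudnn_version.split(".")[:2]
    return f"cuda{cuda_major}_cudnn{cudnn_major}{cudnn_minor}"


def pip_install_args(cuda_version, cudnn_version):
    """Argument list equivalent to the pip command in step 3."""
    extra = jax_cuda_extra(cuda_version, cudnn_version)
    return ["pip", "install", f"jax[{extra}]", "-f", JAX_CUDA_RELEASES]


if __name__ == "__main__":
    # shlex.join quotes the brackets so the line is safe to paste into a shell.
    print(shlex.join(pip_install_args("11.6", "8.2")))
```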
4. Installing Optax using pip
Let’s face it: if you are going to do machine learning work, you will almost certainly need optimizers. Might as well install optax, the go-to optimizer library for Jax:
pip install optax
That’s it! It should work and detect the GPU properly.
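One way to confirm the GPU is detected is to ask Jax directly which backend it picked. A minimal Python sketch using jax.devices() and jax.default_backend(); it assumes a CUDA device reports its platform as "gpu", and platforms / uses_gpu are hypothetical helper names. Jax is imported inside the main block so the helpers can be reused without it.

```python
def platforms(devices):
    """Sorted, de-duplicated platform names of the given Jax devices."""
    return sorted({d.platform for d in devices})


def uses_gpu(devices, backend):
    """True if the default backend is the GPU and a GPU device is visible."""
    return backend == "gpu" and "gpu" in platforms(devices)


if __name__ == "__main__":
    try:
        import jax
    except ImportError:
        print("jax is not installed in this environment")
    else:
        devices = jax.devices()
        backend = jax.default_backend()
        print("default backend:", backend)
        print("platforms:", platforms(devices))
        if not uses_gpu(devices, backend):
            print("Jax fell back to the CPU; recheck the cuda-nvcc, cudnn and jax versions")
```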