Install JAX on Linux based on CUDA 10.2

Jimmy (xiaoke) Shen
2 min read · Dec 28, 2022


CUDA version I have

I have CUDA 10.2 on my Ubuntu Linux machine, as reported by nvidia-smi:

nvidia-smi
Tue Dec 27 13:29:57 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01 Driver Version: 440.33.01 CUDA Version: 10.2 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|===============================+======================+======================|
| 0 TITAN Xp Off | 00000000:03:00.0 Off | N/A |
| 23% 39C P8 9W / 250W | 16MiB / 12187MiB | 0% Default |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|=============================================================================|
| 0 2879 G /usr/lib/xorg/Xorg 14MiB |
+-----------------------------------------------------------------------------+
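If you want to script this check instead of reading the table by eye, the sketch below pulls the "CUDA Version" field out of nvidia-smi output. It assumes only that the header contains the text `CUDA Version: X.Y`, as in the output above; the helper names `parse_driver_cuda_version` and `query_driver_cuda_version` are hypothetical. Keep in mind that this field is what the NVIDIA driver reports, which is not necessarily the same as the CUDA toolkit version installed on the machine.

```python
import re
import subprocess
from typing import Optional, Tuple

# Matches the "CUDA Version: X.Y" field printed in the nvidia-smi header.
_CUDA_VERSION_RE = re.compile(r"CUDA Version:\s*(\d+)\.(\d+)")


def parse_driver_cuda_version(smi_output: str) -> Optional[Tuple[int, int]]:
    """Return (major, minor) parsed from nvidia-smi text, or None if absent."""
    match = _CUDA_VERSION_RE.search(smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def query_driver_cuda_version() -> Optional[Tuple[int, int]]:
    """Run nvidia-smi and parse its header; None if the tool is unavailable or fails."""
    try:
        completed = subprocess.run(
            ["nvidia-smi"], capture_output=True, text=True, check=True
        )
    except (OSError, subprocess.CalledProcessError):
        return None
    return parse_driver_cuda_version(completed.stdout)


if __name__ == "__main__":
    # Header line copied from the nvidia-smi output above.
    header = "| NVIDIA-SMI 440.33.01 Driver Version: 440.33.01 CUDA Version: 10.2 |"
    print(parse_driver_cuda_version(header))  # (10, 2)
    print(query_driver_cuda_version())        # None on machines without nvidia-smi
```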

Installation commands

I used the following two commands from the official JAX installation instructions [1], and they worked:

pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
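To confirm what pip actually installed without importing JAX (and triggering its GPU initialization), you can read the package metadata directly. This is a minimal sketch using the standard-library `importlib.metadata` module; the helper name `installed_versions` is hypothetical, and it simply reports `None` for a package that is not installed.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(
    packages: Iterable[str] = ("jax", "jaxlib"),
) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Both `jax` and `jaxlib` are listed because they are separate distributions: `jax` is the Python front end, while `jaxlib` holds the compiled backend that the `[cuda]` extra swaps for a GPU build.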

Check which device JAX runs on

After installation, start Python in a terminal and check which backend JAX picked up:

Python 3.9.12 (main, Jun  1 2022, 11:38:51) 
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2022-12-27 13:25:36.837549: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-12-27 13:25:37.013214: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-12-27 13:25:37.014182: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
>>> jax.__version__
'0.4.1'
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
gpu
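The session above queries the backend through `jax.lib.xla_bridge`. As a sketch of the same check using JAX's top-level functions `jax.default_backend()` and `jax.devices()`, and of actually running a small compiled computation on that backend, something like the following should work; `describe_backend` and `sum_of_squares` are hypothetical helper names chosen for illustration.

```python
import jax
import jax.numpy as jnp


def describe_backend() -> str:
    """Return the platform name of JAX's default backend, e.g. 'cpu' or 'gpu'."""
    return jax.default_backend()


@jax.jit
def sum_of_squares(x: jnp.ndarray) -> jnp.ndarray:
    # XLA compiles this for the default backend on the first call.
    return jnp.sum(x * x)


if __name__ == "__main__":
    print("default backend:", describe_backend())
    print("devices:", jax.devices())
    x = jnp.arange(4.0)  # [0., 1., 2., 3.]
    print(float(sum_of_squares(x)))  # 0 + 1 + 4 + 9 = 14.0
```

On a working GPU install the first line should print `gpu`, matching the `xla_bridge` check; on a CPU-only install it prints `cpu` and the computation still runs.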

Thanks for reading!

Reference

[1] https://github.com/google/jax#installation
