Install JAX on Linux with CUDA 10.2
Dec 28, 2022
CUDA version I have
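I have CUDA 10.2 on my Ubuntu Linux machine (NVIDIA driver 440.33.01, TITAN Xp), as reported by nvidia-smi. The "CUDA Version" field printed by nvidia-smi is the highest CUDA version the driver supports. It is not necessarily the version of the CUDA toolkit that is installed.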
nvidia-smi
Tue Dec 27 13:29:57 2022
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 10.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|===============================+======================+======================|
|   0  TITAN Xp            Off  | 00000000:03:00.0 Off |                  N/A |
| 23%   39C    P8     9W / 250W |     16MiB / 12187MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      2879      G   /usr/lib/xorg/Xorg                            14MiB |
+-----------------------------------------------------------------------------+
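To read the same two numbers from a script, here is a minimal Python sketch. It runs nvidia-smi and extracts the driver version and the CUDA version from the header line. The function names are my own. The regular expression assumes a header laid out like the one above, and other nvidia-smi releases may format it differently.

```python
import re
import subprocess
from typing import Optional, Tuple

# Assumes a header line like:
# "| NVIDIA-SMI 440.33.01    Driver Version: 440.33.01    CUDA Version: 10.2     |"
_HEADER_RE = re.compile(
    r"Driver Version:\s*(?P<driver>[\d.]+)\s+CUDA Version:\s*(?P<cuda>[\d.]+)"
)


def parse_nvidia_smi(text: str) -> Optional[Tuple[str, str]]:
    """Return (driver_version, max_cuda_version_supported_by_driver), or None."""
    match = _HEADER_RE.search(text)
    if match is None:
        return None
    return match.group("driver"), match.group("cuda")


def query_nvidia_smi() -> Optional[Tuple[str, str]]:
    """Run nvidia-smi and parse its header; None if the tool is missing or fails."""
    try:
        out = subprocess.run(
            ["nvidia-smi"], capture_output=True, text=True, check=True
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        return None
    return parse_nvidia_smi(out)


if __name__ == "__main__":
    print(query_nvidia_smi())
```

On the machine above this should print the driver and CUDA versions from the header as a tuple.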
Installation commands
I used the following two commands [1], and the installation worked:
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
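The -f URL points pip at JAX's CUDA release index. To see which CUDA version the installed jaxlib wheel targets, you can read its version string. Below is a minimal sketch. It assumes that CUDA builds of jaxlib carry a local version label such as +cuda11.cudnn86 and that CPU-only builds do not. That naming scheme is my assumption and is not stated in the commands above. The helper names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

# Assumption: CUDA jaxlib wheels use a local version label like "0.4.1+cuda11.cudnn86".
_CUDA_TAG_RE = re.compile(r"\+cuda(?P<cuda>\d+)(?:\.cudnn(?P<cudnn>\d+))?")


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def cuda_major_from_jaxlib(ver: str) -> Optional[str]:
    """Extract the CUDA major version from a jaxlib version string, if present."""
    match = _CUDA_TAG_RE.search(ver)
    return match.group("cuda") if match else None


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib"):
        print(pkg, installed_version(pkg))
    jaxlib_ver = installed_version("jaxlib")
    if jaxlib_ver is not None:
        print("jaxlib built for CUDA:", cuda_major_from_jaxlib(jaxlib_ver))
```

Suppose this reports 11 on a machine whose driver reports CUDA 10.2. That mismatch would be consistent with the libcudart.so.11.0 warnings in the session below.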
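Check which device JAX runs on
After installation, start Python in a terminal and check which backend JAX picked. In my session, importing jax printed several warnings about a missing libcudart.so.11.0 (the CUDA 11 runtime library), but the backend was still reported as gpu: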
Python 3.9.12 (main, Jun 1 2022, 11:38:51)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
2022-12-27 13:25:36.837549: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-12-27 13:25:37.013214: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-12-27 13:25:37.014182: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
>>> jax.__version__
'0.4.1'
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
gpu
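The session above can be turned into a small diagnostic script that performs two checks. First, it tries to load the CUDA runtime named in the warnings. ctypes.CDLL uses dlopen on Linux, which is the call that failed in those messages. Second, it asks JAX for its default backend and visible devices through the public helpers jax.default_backend() and jax.devices(), without importing xla_bridge from jax.lib. This is a sketch, and the function names are my own. Whether the library loads depends on your loader path (for example LD_LIBRARY_PATH).

```python
import ctypes
from typing import Dict, List


def can_dlopen(libname: str) -> bool:
    """Try to load a shared library the way the dynamic loader would (dlopen on Linux)."""
    try:
        ctypes.CDLL(libname)
    except OSError:
        return False
    return True


def describe_jax_backend() -> Dict[str, object]:
    """Report JAX's default backend and visible devices, if JAX is installed."""
    try:
        import jax
    except ImportError:
        return {"installed": False}
    devices: List[str] = [f"{d.platform}:{d.device_kind}" for d in jax.devices()]
    return {
        "installed": True,
        "version": jax.__version__,
        "backend": jax.default_backend(),  # e.g. "gpu" or "cpu"
        "devices": devices,
    }


if __name__ == "__main__":
    # Library name taken from the warnings printed when importing jax above.
    print("libcudart.so.11.0 loadable:", can_dlopen("libcudart.so.11.0"))
    print(describe_jax_backend())
```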
Thanks for reading!