diff --git a/examples/mliap/jax/README.md b/examples/mliap/jax/README.md index ccaa234034..5cc0c49f16 100644 --- a/examples/mliap/jax/README.md +++ b/examples/mliap/jax/README.md @@ -33,11 +33,11 @@ Use same Python dependencies as above, with some extra changes: 1. Make sure you install cupy properly! E.g. - python -m pip install cupy-cuda12x + python -m pip install cupy-cuda12x 2. Install JAX for GPU/CUDA: - python -m pip install --trusted-host storage.googleapis.com --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + python -m pip install --trusted-host storage.googleapis.com --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 3. Install cudNN: https://developer.nvidia.com/cudnn