Runpod · 2y ago · 54 replies
galakurpismo3

Can't use GPU with JAX in serverless endpoint

Hi, I'm trying to run a serverless worker that performs point tracking on a video. It works, but I think it is running on the CPU.

I've read that the telemetry in the UI isn't reliable, but the container logs suggest the same thing (I've attached an image of what they say). The NVIDIA GPU is detected, but I think there is a problem with JAX.

I check the device with the function shown in the first image, and the output I get is shown in the second image.
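For anyone reproducing this, a text version of such a device check might look like the sketch below. It is an assumption, not the function from the screenshot: it only calls `jax.default_backend()` and `jax.devices()`, and also records `CUDA_VISIBLE_DEVICES` and `LD_LIBRARY_PATH` so everything lands in one log entry. The name `jax_device_report` is hypothetical.

```python
import os


def jax_device_report() -> dict:
    """Collect what JAX sees at startup so it can be printed to the container logs."""
    report = {
        "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES"),
        "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH"),
        "jax_version": None,
        "backend": None,
        "devices": [],
        "error": None,
    }
    try:
        import jax  # imported here so a broken install is reported, not raised
    except ImportError as exc:
        report["error"] = f"import failed: {exc!r}"
        return report

    report["jax_version"] = jax.__version__
    try:
        # Expected to be "gpu" if the CUDA backend initialised, "cpu" if JAX fell back.
        report["backend"] = jax.default_backend()
        report["devices"] = [str(d) for d in jax.devices()]
    except Exception as exc:  # backend initialisation errors surface here
        report["error"] = f"backend init failed: {exc!r}"
    return report


if __name__ == "__main__":
    info = jax_device_report()
    for key, value in info.items():
        print(f"{key}: {value}")
    if info["backend"] != "gpu":
        print("WARNING: JAX is not running on the GPU")
```

Calling it once when the worker starts, before the handler runs, puts the result in the container logs, which sidesteps the unreliable UI telemetry.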


In my Dockerfile, I use this base image:
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04

I run this command to install the JAX build that is supposed to work with CUDA 11.8:
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
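One way to confirm which jaxlib actually ended up in the image is to read its installed version. The sketch below assumes the wheel naming used by the jax_cuda_releases index at the time, where CUDA builds of jaxlib carry a local version suffix such as `+cuda11.cudnn86` and the CPU-only wheel does not; newer JAX releases package CUDA support differently, so the check only applies to wheels of that style. Both helper names are hypothetical.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def jaxlib_build_kind(version: Optional[str]) -> str:
    """Classify a jaxlib version string as "cuda", "cpu" or "missing".

    Assumption: CUDA wheels look like "0.4.13+cuda11.cudnn86",
    while CPU-only wheels have no "+cuda..." local suffix.
    """
    if version is None:
        return "missing"
    _, _, local = version.partition("+")
    return "cuda" if local.startswith("cuda") else "cpu"


if __name__ == "__main__":
    jax_v = installed_version("jax")
    jaxlib_v = installed_version("jaxlib")
    print(f"jax={jax_v} jaxlib={jaxlib_v} -> {jaxlib_build_kind(jaxlib_v)}")
```

If this reports `cpu` even though the CUDA install command ran, a later `pip install` step may have replaced jaxlib with the CPU wheel.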

Then I install the packages from requirements.txt (JAX is not installed again there) and run the remaining setup steps.
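To double-check that requirements.txt does not name jax or jaxlib directly (for example with a pin that would pull a different wheel), a small scan like the one below could be run during the build. It is a hypothetical helper, not a pip feature; it normalises names roughly the way PEP 503 does and skips comments and pip options such as `-f`.

```python
import re
from typing import Iterable, List

# Distribution names that could replace the CUDA-enabled JAX install if pip resolved them again.
JAX_PACKAGES = {"jax", "jaxlib"}

# Leading project name of a requirement line, e.g. "jax" in "jax[cuda11_pip]>=0.4".
_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def jax_lines(requirements: Iterable[str]) -> List[str]:
    """Return the requirement lines that name jax or jaxlib directly."""
    hits = []
    for line in requirements:
        stripped = line.split("#", 1)[0].strip()  # drop inline and full-line comments
        if not stripped or stripped.startswith("-"):
            continue  # skip blank lines and pip options such as -f / -r
        match = _NAME_RE.match(stripped)
        if match is None:
            continue
        name = re.sub(r"[-_.]+", "-", match.group(1)).lower()  # PEP 503-style normalisation
        if name in JAX_PACKAGES:
            hits.append(stripped)
    return hits


if __name__ == "__main__":
    with open("requirements.txt") as fh:
        print(jax_lines(fh) or "no direct jax/jaxlib entries")
```

This only catches direct entries: a package that depends on jax can still upgrade it transitively, so it is worth also inspecting `pip list` output in the built image.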

Finally, I set the CUDA library path:
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
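To confirm that the directories added to `LD_LIBRARY_PATH` exist inside the running container and actually contain CUDA libraries, a sketch like this could be run at worker startup. The library prefixes in `CUDA_LIBS` are an assumption for a CUDA 11.8 / cuDNN 8 image, and `scan_ld_library_path` is a hypothetical helper.

```python
import glob
import os
from typing import Dict, List, Optional

# Assumed library prefixes for a CUDA 11.8 / cuDNN 8 image; adjust to what your jaxlib loads.
CUDA_LIBS = ("libcudart.so", "libcublas.so", "libcudnn.so", "libcupti.so")


def scan_ld_library_path(value: Optional[str] = None) -> Dict[str, List[str]]:
    """Map each LD_LIBRARY_PATH entry to the CUDA library prefixes found in it."""
    if value is None:
        value = os.environ.get("LD_LIBRARY_PATH", "")
    found: Dict[str, List[str]] = {}
    for directory in filter(None, value.split(os.pathsep)):  # ignore empty entries
        hits = [
            lib
            for lib in CUDA_LIBS
            if glob.glob(os.path.join(directory, lib + "*"))  # e.g. libcudart.so.11.0
        ]
        found[directory] = hits  # empty list: directory missing or holds none of these
    return found


if __name__ == "__main__":
    for directory, libs in scan_ld_library_path().items():
        print(f"{directory}: {libs or 'no CUDA libraries found'}")
```

An empty list only means none of those files sit in that directory; the dynamic linker can still find a library through its cache (`ldconfig -p`) or through paths that are not listed here.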


I still can't get it to run on the GPU. If anyone can tell me where the problem might be, it would be extremely helpful. Thank you!
[Image 1: device-check function]
[Image 2: output of the device check]