Can't use the GPU with JAX in a serverless endpoint
Hi, I'm trying to run a serverless worker that performs point tracking on a video. It works, but I think it is running on the CPU.
I've read that the telemetry in the UI isn't reliable, but the container logs indicate the same thing. There is an image of what the logs say: it finds the NVIDIA GPU, but I think there is a problem with JAX.
I use the function in the first image to check the device:
And the output I get is in the second image:
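The device-check function and its output are only available as images. As a minimal sketch of an equivalent check (not the original function; `describe_jax_backend` is a hypothetical name), the following reports which backend JAX selected and which devices it sees. If it reports `cpu` even though the GPU is visible, JAX has most likely fallen back to the CPU after failing to initialize CUDA. In recent JAX versions, setting the environment variable `JAX_PLATFORMS=cuda` before importing JAX should make that failure raise an error instead of falling back silently, which tends to expose the underlying CUDA/cuDNN message.

```python
import importlib.util


def describe_jax_backend():
    """Return (backend_name, device_strings) for the current JAX install.

    backend_name is what jax.default_backend() reports, e.g. "cpu" or "gpu".
    If JAX is not importable at all, ("unavailable", []) is returned instead.
    """
    if importlib.util.find_spec("jax") is None:
        return "unavailable", []
    import jax  # imported lazily so the helper still runs without JAX installed

    backend = jax.default_backend()
    devices = [str(d) for d in jax.devices()]  # devices of the default backend
    return backend, devices


if __name__ == "__main__":
    backend, devices = describe_jax_backend()
    print("default backend:", backend)
    print("devices:", devices)
    if backend != "gpu":
        print("JAX is not using the GPU; check the jaxlib build and CUDA libraries.")
```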
In my Dockerfile, I set this as the base image:
```dockerfile
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
```

I'm running this command to install the JAX version that is supposed to work with CUDA 11.8:
```dockerfile
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

Then I install requirements.txt (I don't install JAX again there) and do some other setup.
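One thing that may be worth ruling out (a guess, not something the logs confirm) is that `jaxlib` was replaced by a CPU-only build later in the build, for example if a package in requirements.txt depends on a different jax/jaxlib version. For the older wheels served from the jax_cuda_releases index, CUDA builds of jaxlib carry a local version suffix such as `+cuda11.cudnn86`, so a quick check inside the container could look like the sketch below. `is_cuda_jaxlib` and `installed_jaxlib_version` are hypothetical helper names, and the suffix heuristic is an assumption that does not hold for newer JAX releases, which package CUDA support differently.

```python
from importlib.metadata import PackageNotFoundError, version


def is_cuda_jaxlib(jaxlib_version):
    """Heuristic (assumption): older CUDA builds of jaxlib carry a local
    version tag such as "+cuda11.cudnn86"; CPU builds carry no "+cuda" tag."""
    _, _, local_tag = jaxlib_version.partition("+")
    return local_tag.startswith("cuda")


def installed_jaxlib_version():
    """Return the installed jaxlib version string, or None if it is missing."""
    try:
        return version("jaxlib")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_jaxlib_version()
    if v is None:
        print("jaxlib is not installed")
    elif is_cuda_jaxlib(v):
        print(f"jaxlib {v} looks like a CUDA build")
    else:
        print(f"jaxlib {v} looks like a CPU-only build; it may have been reinstalled")
```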
And finally I do this to set the library path for CUDA:
```dockerfile
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
```

I still can't get it to run on the GPU. If someone could tell me where the problem might be, it would be extremely helpful. Thank you.
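To see whether the dynamic loader can resolve the CUDA runtime and cuDNN with the `LD_LIBRARY_PATH` set above, a small ctypes check run inside the container might help. The sonames below (`libcudart.so.11.0`, `libcudnn.so.8`) are assumptions based on the CUDA 11.8 / cuDNN 8 base image, and `check_loadable` is a hypothetical helper. Since `jax[cuda11_pip]` installs the CUDA libraries as pip wheels that jaxlib may locate on its own, a failure here is a hint rather than proof of the problem.

```python
import ctypes

# Assumed sonames for the CUDA 11.x runtime and cuDNN 8, matching the base image tag.
DEFAULT_SONAMES = ("libcudart.so.11.0", "libcudnn.so.8")


def check_loadable(sonames=DEFAULT_SONAMES):
    """Try to dlopen each shared library by soname, letting the dynamic
    loader search as usual (LD_LIBRARY_PATH, the ldconfig cache, etc.).

    Returns {soname: None} on success or {soname: error message} on failure.
    """
    results = {}
    for soname in sonames:
        try:
            ctypes.CDLL(soname)
            results[soname] = None
        except OSError as exc:
            results[soname] = str(exc)
    return results


if __name__ == "__main__":
    import os

    print("LD_LIBRARY_PATH =", os.environ.get("LD_LIBRARY_PATH", "<unset>"))
    for soname, error in check_loadable().items():
        print(f"{soname}: {'ok' if error is None else 'NOT FOUND -> ' + error}")
```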


