Can't use GPU with JAX in serverless endpoint
Hi, I'm trying to run a serverless worker to perform point tracking on a video. It works OK, but I think it is running on the CPU.
I read that the telemetry in the UI isn't reliable, but the container logs indicate the same. There is an image of what the logs say: it finds the NVIDIA GPU, but I think there are problems with JAX.
I use the function on the first image to check the device:
And the outputs I get are on the second image:
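(Neither image is reproduced here; reconstructed from the log lines further down in the thread, the check is roughly a sketch like this, with the function name and exact calls being assumptions:)

import ctypes
import os

import jax


def check_device():
    # Which backend did JAX pick, and does it expose any GPU devices?
    # (The real function also logs a "Found device: cuda:0" line; where that
    # comes from isn't shown in the thread, so it is omitted here.)
    devices = jax.devices()
    if not any(d.platform in ("gpu", "cuda") for d in devices):
        print("JAX is not using the GPU. Check your JAX installation and environment configuration.")
    print("JAX backend:", jax.default_backend())

    # Environment variables that affect GPU discovery
    print("CUDA_VISIBLE_DEVICES:", os.environ.get("CUDA_VISIBLE_DEVICES"))
    print("LD_LIBRARY_PATH:", os.environ.get("LD_LIBRARY_PATH"))

    # Check that the CUDA runtime and cuDNN shared libraries can be loaded
    ctypes.CDLL("libcudart.so")
    print("libcudart.so loaded successfully.")
    ctypes.CDLL("libcudnn.so")
    print("libcudnn.so loaded successfully.")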
In my Dockerfile, I'm setting this as base image:
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
I'm running this command to install the JAX version that is supposed to work with CUDA 11.8.
RUN pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Then I install requirements.txt (I don't install JAX again there) and do some other setup; the full sequence is sketched below.
And finally I do this to set the library path for CUDA:
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
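Putting those pieces together, the relevant part of the Dockerfile looks roughly like this (a simplified sketch, not the literal file; the apt/Python setup and the entrypoint name are placeholders):

FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04

# Python and pip (install if the base image doesn't already provide them)
RUN apt-get update && apt-get install -y python3 python3-pip && rm -rf /var/lib/apt/lists/*

# JAX wheels built against CUDA 11.8 (the cuda11_pip extra pulls CUDA/cuDNN in via pip)
RUN pip3 install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Project dependencies (JAX is not listed again in requirements.txt)
COPY requirements.txt .
RUN pip3 install -r requirements.txt

# Library path for the CUDA toolkit that ships with the base image
ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH

COPY . .
# Entrypoint (handler.py is an assumed name)
CMD ["python3", "-u", "handler.py"]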
I still can't get it to work on the GPU. If someone could tell me where the problem might be, it would be extremely helpful. Thank you.

36 Replies
Unknown User•14mo ago
Message Not Public
try adding this to your Dockerfile
Unknown User•14mo ago
Message Not Public
I'm also wondering why you use CUDA 11.8 rather than 12.1
Yeah, I tried both CUDA 12 and 11.8
@galakurpismo3 any use case I might try? I'd like to make a better JAX template, though I'd need to understand how you test it
Do I have to run this command in a shell inside the worker container? Or how does it work?
you would probably need to add it to the Docker container
but the container is running on the serverless endpoint right?
workers are basically pods
OK, I'll run that command from the Python code at the beginning and add your suggestion too
tried to run:
pip install --upgrade "jax[cuda12_local]"
okay, in the Dockerfile, right?
or try using this as the base: https://github.com/NVIDIA/JAX-Toolbox
okay, I'll try yes, thank you
Hi, I think that worked, but there is a new error now, related to cuDNN I think. These are the logs:
--- Starting Serverless Worker | Version 1.6.0 ---
{"requestId": "cbeb73b4-8679-43d1-aaa0-8c68101e76ac-e1", "message": "Started.", "level": "INFO"}
Get inside input_fn
xla_bridge.py :889 Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
xla_bridge.py :889 Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
inference.py :172 Found device: cuda:0
inference.py :176 JAX is not using the GPU. Check your JAX installation and environment configuration.
inference.py :177 JAX backend: gpu
inference.py :182 CUDA_VISIBLE_DEVICES: 0,1
inference.py :183 LD_LIBRARY_PATH: /opt/venv/lib/python3.9/site-packages/cv2/../../lib64:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
inference.py :187 libcudart.so loaded successfully.
inference.py :189 libcudnn.so loaded successfully.
inference.py :143 Read and resized video, number of frames: 107
E0716 cuda_dnn.cc:535 Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0716 cuda_dnn.cc:539 Memory usage: 84536328192 bytes free, 84986691584 bytes total.
E0716 cuda_dnn.cc:535 Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
E0716 cuda_dnn.cc:539 Memory usage: 84536328192 bytes free, 84986691584 bytes total.
inference.py :162 Error during processing: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
{"requestId": "cbeb73b4-8679-43d1-aaa0-8c68101e76ac-e1", "message": "Finished.", "level": "INFO"}
I've tried with 24GB GPU and 80GB GPU.
I'm using this base image:
FROM nvidia/cuda:12.0.0-cudnn8-devel-ubuntu20.04
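(Side note on those logs: CUDNN_STATUS_INTERNAL_ERROR while plenty of memory is reported free is sometimes tied to XLA preallocating most of the GPU at startup. Not a confirmed fix for this case, just a cheap experiment, set before JAX is imported:)

import os

# Assumption: experiment only. By default XLA reserves a large fraction of GPU
# memory up front; disabling preallocation sometimes avoids cuDNN handle errors.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # import JAX only after the variable is set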
Unknown User•14mo ago
Message Not Public
It looks like an issue with VS Code there; I don't know if it would be related. I've tried with all GPUs and I get the same error every time
Unknown User•14mo ago
Message Not Public
Yeah, I tried with all GPUs now
Unknown User•14mo ago
Message Not Public
It's CUDA 12.0; with this base image, I think it installs cuDNN 8.8:
https://hub.docker.com/layers/nvidia/cuda/12.0.0-cudnn8-runtime-ubuntu20.04/images/sha256-7d0f83420618c3b337d02cfa8243b8e4a7e002ee4b436dd5c70f71cee176f4a0?context=explore
And for JAX I do this to install it:
RUN pip install --upgrade "jax[cuda12_local]"
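(For what it's worth, my understanding is that the cuda12_local extra expects the CUDA toolkit and cuDNN to come from the image itself, which is why the -cudnn8-devel base matters; the pip-bundled alternative, analogous to the cuda11_pip line earlier, would be:)
RUN pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html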
Unknown User•14mo ago
Message Not Public
what does this mean?
Unknown User•14mo ago
Message Not Public
aah okay, I'll try 11.8 too, thank you
Unknown User•14mo ago
Message Not Public
I'll try this, I'll tell you if it works, thanks a lot for helping
nvidia/cuda:12.1.0-cudnn8-devel-ubuntu20.04
@galakurpismo3 is your worker open source?
I can share it with you but it's not simple to test, I'll try to share a simplified version
hi, here is a simple version of the worker:
https://github.com/galakurpi/yekar_coaches_point_tracking_simple
To test it, send the video link I have in this code, in that same format:
import requests

url = 'https://api.runpod.ai/v2/sd1ylpcd55dj12/run'
data = {
    'input': {
        'video_url': 'https://drive.google.com/uc?export=download&id=1SER_MwYt0XyOHOX0UbN30iyMCmeWE-dd'
    }
}
headers = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer <RUNPOD API KEY MISSING>'  # If authentication is needed
}
response = requests.post(url, json=data, headers=headers)
print(response.json())
thank you
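(For context on the worker side, the repo presumably follows the standard RunPod serverless handler pattern; a rough sketch, not the exact repo code, with input_fn borrowed from the log lines above and stubbed out:)

import runpod


def input_fn(video_url):
    # Stand-in for the real pipeline in inference.py: download the video,
    # run the JAX point-tracking model, return the tracked points.
    return {"video_url": video_url}


def handler(event):
    # RunPod delivers the client payload under event["input"]
    return input_fn(event["input"]["video_url"])


# Starts the worker loop (this is what logs "Starting Serverless Worker")
runpod.serverless.start({"handler": handler})

(On the client snippet: /run is the asynchronous endpoint, so the printed response should be a job id and status rather than the tracking result; the output is then fetched from the status endpoint, or /runsync can be used for quick tests.)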
let me know if you test anything or need anything
btw, did you make sure to filter the CUDA version on the machines in serverless?
Unknown User•14mo ago
Message Not Public
Actually, no, sorry, but the logs showed that CUDA 12.1 was running
But I'll try again with that
Unknown User•14mo ago
Message Not Public
I tried filtering for CUDA 12.1 and nothing changed