.github/workflows/jio.yaml — 4 changes: 2 additions & 2 deletions

```diff
@@ -53,7 +53,7 @@ jobs:
   build:
     needs: metadata
     strategy:
-      fail-fast: true
+      fail-fast: false
       matrix:
         ARCHITECTURE: [amd64, arm64]
     runs-on: [self-hosted, "${{ matrix.ARCHITECTURE }}", "small"]
@@ -67,7 +67,7 @@ jobs:
       ARCHITECTURE: ${{ matrix.ARCHITECTURE }}
       ARTIFACT_NAME: artifact-jio-build
       BADGE_FILENAME: badge-jio-build
-      BASE_IMAGE: nvcr.io/nvidia/cuda-dl-base:25.06-cuda12.9-devel-ubuntu24.04
+      BASE_IMAGE: nvcr.io/nvidia/cuda-dl-base:25.11-cuda13.0-devel-ubuntu24.04
       BUILD_DATE: ${{ needs.metadata.outputs.BUILD_DATE }}
       CONTAINER_NAME: jio
       DOCKERFILE: jax-inference-offloading/dockerfile/oss.dockerfile
```
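On the `fail-fast` flip: with `fail-fast: false`, a failure in one matrix leg no longer cancels the other, so the amd64 and arm64 image builds each run to completion independently. A minimal standalone sketch of the same pattern (the `steps` body is hypothetical, not from this workflow):

```yaml
jobs:
  build:
    strategy:
      fail-fast: false        # an arm64 failure no longer cancels the amd64 leg
      matrix:
        ARCHITECTURE: [amd64, arm64]
    runs-on: [self-hosted, "${{ matrix.ARCHITECTURE }}", "small"]
    steps:
      - run: echo "building for ${{ matrix.ARCHITECTURE }}"   # hypothetical step
```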
jax-inference-offloading/dockerfile/oss.dockerfile — 20 changes: 11 additions & 9 deletions

```diff
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-ARG BASE_IMAGE=nvcr.io/nvidia/cuda-dl-base:25.06-cuda12.9-devel-ubuntu24.04
+ARG BASE_IMAGE=nvcr.io/nvidia/cuda-dl-base:25.11-cuda13.0-devel-ubuntu24.04
 ARG URL_JIO=https://github.com/NVIDIA/JAX-Toolbox.git
 ARG REF_JIO=main
 ARG URL_TUNIX=https://github.com/google/tunix.git
@@ -76,7 +76,9 @@ EOF
 RUN <<"EOF" bash -ex -o pipefail
 mkdir -p /opt/pip-tools.d
 pip freeze | grep wheel >> /opt/pip-tools.d/overrides.in
-echo "jax[cuda12_local]" >> /opt/pip-tools.d/requirements.in
+if [[ $(uname -m) == "x86_64" ]]; then
+  echo "vllm @ https://github.com/vllm-project/vllm/releases/download/v0.11.2/vllm-0.11.2+cu130-cp38-abi3-manylinux1_x86_64.whl" >> /opt/pip-tools.d/requirements.in
+fi
 echo "-e file://${SRC_PATH_JIO}" >> /opt/pip-tools.d/requirements.in
 echo "-e file://${SRC_PATH_TUNIX}" >> /opt/pip-tools.d/requirements.in
 cat "${SRC_PATH_JIO}/examples/requirements.in" >> /opt/pip-tools.d/requirements.in
@@ -90,20 +92,20 @@ FROM mealkit AS final
 
 # Finalize installation
 RUN <<"EOF" bash -ex -o pipefail
-export PIP_INDEX_URL=https://download.pytorch.org/whl/cu129
-export PIP_EXTRA_INDEX_URL="https://flashinfer.ai/whl/cu129 https://pypi.org/simple"
+export PIP_INDEX_URL=https://download.pytorch.org/whl/cu130
+export PIP_EXTRA_INDEX_URL="https://flashinfer.ai/whl/cu130 https://pypi.org/simple"
 pushd /opt/pip-tools.d
 pip-compile -o requirements.txt $(ls requirements*.in) --constraint overrides.in
 # remove cuda wheels from install list since the container already has them
 sed -i 's/^nvidia-/# nvidia-/g' requirements.txt
 sed -i 's/# nvidia-nvshmem/nvidia-nvshmem/g' requirements.txt
 pip install --no-deps --src /opt -r requirements.txt
 # make pip happy about the missing torch dependencies
-pip-mark-installed nvidia-cublas-cu12 nvidia-cuda-cupti-cu12 \
-  nvidia-cuda-nvrtc-cu12 nvidia-cuda-runtime-cu12 nvidia-cudnn-cu12 \
-  nvidia-cufft-cu12 nvidia-cufile-cu12 nvidia-curand-cu12 nvidia-cusolver-cu12 \
-  nvidia-cusparse-cu12 nvidia-cusparselt-cu12 nvidia-nccl-cu12 \
-  nvidia-nvjitlink-cu12 nvidia-nvtx-cu12
+pip-mark-installed nvidia-cublas-cu13 nvidia-cuda-cupti-cu13 \
+  nvidia-cuda-nvrtc-cu13 nvidia-cuda-runtime-cu13 nvidia-cudnn-cu13 \
+  nvidia-cufft-cu13 nvidia-cufile-cu13 nvidia-curand-cu13 nvidia-cusolver-cu13 \
+  nvidia-cusparse-cu13 nvidia-cusparselt-cu13 nvidia-nccl-cu13 \
+  nvidia-nvjitlink-cu13 nvidia-nvtx-cu13
 popd
 rm -rf ~/.cache/*
 EOF
```
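Two notes on the hunks above. The vLLM wheel is gated on `uname -m` because the pinned release asset is published for x86_64 only (the filename ends in `manylinux1_x86_64`), so arm64 builds skip it. And the sed pair works in two passes: the first comments out every `nvidia-*` wheel (the base image already ships the CUDA libraries), the second re-enables only `nvidia-nvshmem`. A hypothetical demonstration of the net effect (package versions are placeholders, not the real pins):

```bash
#!/usr/bin/env bash
# Fabricate a small compiled requirements file; versions are illustrative only.
cat > requirements.txt <<'REQS'
nvidia-cublas-cu13==13.0.*
nvidia-nvshmem-cu13==3.*
torch==2.*
REQS

sed -i 's/^nvidia-/# nvidia-/g' requirements.txt              # comment out all CUDA wheels
sed -i 's/# nvidia-nvshmem/nvidia-nvshmem/g' requirements.txt # restore nvshmem only

cat requirements.txt
# -> "# nvidia-cublas-cu13==13.0.*"   (skipped: provided by the base image)
# -> "nvidia-nvshmem-cu13==3.*"       (restored: still installed via pip)
# -> "torch==2.*"                     (untouched)
```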
jax-inference-offloading/setup.py — 4 changes: 2 additions & 2 deletions

```diff
@@ -57,13 +57,13 @@ def run(self):
     version='0.0.1',
     packages=['jax_inference_offloading'],
     install_requires=[
-        'cupy-cuda12x',
+        'cupy-cuda13x',
         'cloudpickle',
         'flax',
         'grpcio==1.76.*',
         'protobuf==6.33.*',
         'huggingface-hub',
-        'jax==0.8.0',
+        'jax[cuda13_local]==0.8.0',
         'jaxtyping',
         'kagglehub',
         'vllm==0.11.2',
```
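For what the new extra buys: the `cuda13_local` extra installs JAX's CUDA 13 plugin packages while expecting the CUDA toolkit itself to come from the environment, which matches the cuda-dl-base image pinned above (and also explains why the `jax[cuda12_local]` echo was dropped from the dockerfile: the pin now lives here). A hedged smoke test, assuming a GPU host whose image already ships the CUDA 13 toolkit; the printed output is indicative only:

```bash
pip install "jax[cuda13_local]==0.8.0" cupy-cuda13x
python -c 'import jax; print(jax.devices())'
# expected (indicative): [CudaDevice(id=0), ...]
```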