diff --git a/.github/container/test-jax.sh b/.github/container/test-jax.sh index 73aab9fd0..32e3c7990 100755 --- a/.github/container/test-jax.sh +++ b/.github/container/test-jax.sh @@ -146,6 +146,9 @@ fi # nvidia-*-cu1X wheels from PyPI to run tests, use the local installations FLAGS+=("--//jaxlib/tools:add_pypi_cuda_wheel_deps=false") +# Always use KMD as not all combinations of UMD!=KMD are supported +FLAGS+=("--@cuda_driver//:include_cuda_umd_libs=false") + # Default parallelism: at least 10GB per test, no more than 4 tests per GPU. DEFAULT_JOBS_PER_GPU=$(( GPU_MEMORIES_MIB[0] / 10000)) if (( DEFAULT_JOBS_PER_GPU > 4 )); then DEFAULT_JOBS_PER_GPU=4; fi