Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions vllm/patches/vllm_xpu_worker_skip_profile.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
--- a/vllm/v1/worker/xpu_worker.py 2026-03-31 22:53:27.530257234 +0900
+++ b/vllm/v1/worker/xpu_worker.py 2026-03-31 22:27:53.255789217 +0900
@@ -67,63 +67,111 @@
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how many
KV blocks may be allocated without OOMs.
+
The engine will first conduct a profiling of the existing memory usage.
Then, it calculates the maximum possible number of GPU and CPU blocks
that can be allocated with the remaining free memory.
+
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
+ fallback_profile = os.getenv("VLLM_FALLBACK_PROFILE", "0") == "1"
+
+ if fallback_profile:
+ return self._determine_available_memory_fallback()
+ else:
+ return self._determine_available_memory_default()
+
+ def _determine_available_memory_fallback(self) -> int:
+ """Upstream profiling method using memory_allocated()."""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.xpu.empty_cache()
- torch.xpu.reset_peak_memory_stats()

- free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
- current_allocated_bytes = torch.xpu.memory_allocated()
- msg = (
- "Before memory profiling run, "
- f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, "
- f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, "
- f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
- )
- logger.info(msg)
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()

- free_gpu_memory, _ = self.xpu_get_mem_info()
+ # Calculate the number of blocks that can be allocated with the
+ # profiled peak memory.
+ torch.xpu.synchronize()
+ used_memory = torch.xpu.memory_allocated()
+ total_gpu_memory = torch.xpu.get_device_properties(self.local_rank).total_memory
+ free_gpu_memory = total_gpu_memory - used_memory
+
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
- assert self.init_gpu_memory > free_gpu_memory, (
+ peak_memory = self.init_gpu_memory - free_gpu_memory
+ assert peak_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance."
)

- # Get the peak memory allocation recorded by torch
- peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"]
-
torch.xpu.empty_cache()
- torch_allocated_bytes = torch.xpu.memory_stats()["allocated_bytes.all.current"]
- total_allocated_bytes = self.xpu_get_mem_info()[1] - self.xpu_get_mem_info()[0]

- non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
- if non_torch_allocations > 0:
- peak_memory += non_torch_allocations
available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_memory
)

- msg = (
- "After memory profiling run, "
- f"peak memory usage is {peak_memory / 1024**2:.2f} MB,"
- f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, "
- f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, "
- f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB."
+ return int(available_kv_cache_memory)
+
+ def _determine_available_memory_default(self) -> int:
+ """Custom profiling method using peak memory stats."""
+ # Profile the memory usage of the model and get the maximum number of
+ # cache blocks that can be allocated with the remaining free memory.
+ torch.xpu.empty_cache()
+ torch.xpu.synchronize()
+
+ skip_profile = os.getenv("VLLM_SKIP_PROFILE_RUN", "0") == "1"
+
+ if not skip_profile:
+ torch.xpu.reset_peak_memory_stats()
+
+ # Execute a forward pass with dummy inputs to profile the memory usage
+ # of the model.
+ if "Qwen3ASR" in self.model_runner.model_config.architecture:
+ pass
+ else:
+ self.model_runner.profile_run()
+
+ # Calculate the number of blocks that can be allocated with the
+ # profiled peak memory.
+ torch.xpu.synchronize()
+ total_gpu_memory = torch.xpu.get_device_properties(self.local_rank).total_memory
+
+ # NOTE(woosuk): Here we assume that the other processes using the same
+ # GPU did not change their memory usage during the profiling.
+ stats = torch.xpu.memory_stats()
+ peak_allocated = stats.get("allocated_bytes.all.peak", 0)
+ else:
+ total_gpu_memory = torch.xpu.get_device_properties(self.local_rank).total_memory
+ # Skip profile_run — estimate peak from current allocation + 20% overhead
+ used_memory = torch.xpu.memory_allocated()
+ peak_allocated = int(used_memory * 1.2)
+ print(f"\n[VLLM_SKIP_PROFILE_RUN] Skipping profile_run, estimating peak from allocated memory")
+
+ current_reserved = torch.xpu.memory_reserved()
+
+ fragmentation_bytes = current_reserved - peak_allocated
+ fragmentation_gb = fragmentation_bytes / 1024**3
+ peak_gb = peak_allocated / 1024**3
+ reserved_gb = current_reserved / 1024**3
+ model_memory = self.model_runner.model_memory_usage / 1024 ** 3
+
+ print(f"\n[Memory Profiling Analysis]")
+ print(f" > Peak Allocated (Real Need) : {peak_gb:.2f} GB")
+ print(f" > Model memory usage : {model_memory:.2f} GB")
+ print(f" > Current Reserved (Footprint): {reserved_gb:.2f} GB")
+ print(f" > Fragmentation (Wasted) : {fragmentation_gb:.2f} GB")
+
+ torch.xpu.empty_cache()
+
+ available_kv_cache_memory = (
+ total_gpu_memory * self.cache_config.gpu_memory_utilization - peak_allocated
)
- logger.info(msg)

return int(available_kv_cache_memory)

@@ -161,9 +209,9 @@
)

# global all_reduce needed for overall oneccl warm up
- torch.distributed.all_reduce(
- torch.zeros(1).xpu(), group=get_world_group().device_group
- )
+# torch.distributed.all_reduce(
+# torch.zeros(1).xpu(), group=get_world_group().device_group
+# )

# Set random seed.
set_random_seed(self.model_config.seed)
117 changes: 117 additions & 0 deletions vllm/scripts/lunar_lake_serve.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#!/bin/bash
# ==============================================================================
# Lunar Lake vLLM Serving Script
# ------------------------------------------------------------------------------
# Launch vLLM on Intel Core Ultra (Lunar Lake) with Arc 140V iGPU.
# Configures memory-aware settings for shared LPDDR5x memory.
#
# Usage:
#   ./lunar_lake_serve.sh <model_path_or_name> [extra vllm args...]
#
# Examples:
#   ./lunar_lake_serve.sh Qwen/Qwen3-8B --quantization fp8
#   ./lunar_lake_serve.sh /models/DeepSeek-R1-Distill-Qwen-7B --quantization int4
#   ./lunar_lake_serve.sh Qwen/Qwen3.5-35B-A3B --quantization int4 --max-model-len 8192
# ==============================================================================

set -euo pipefail

RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'

# === Source oneAPI ===
# setvars.sh is not `set -u`-clean, so relax the shell options around it.
if [ -f /opt/intel/oneapi/setvars.sh ]; then
    set +euo pipefail
    source /opt/intel/oneapi/setvars.sh --force 2>/dev/null || true
    set -euo pipefail
fi

# === Fix MKL library path ===
# PyTorch bundles MKL stubs with relative RPATHs that break in venvs.
# Preload the real oneAPI MKL libraries to avoid "Cannot load libmkl_core.so" errors.
if [ -n "${MKLROOT:-}" ] && [ -f "$MKLROOT/lib/libmkl_core.so.2" ]; then
    export LD_PRELOAD="${MKLROOT}/lib/libmkl_core.so.2:${MKLROOT}/lib/libmkl_intel_thread.so.2:${MKLROOT}/lib/libmkl_intel_lp64.so.2${LD_PRELOAD:+:$LD_PRELOAD}"
    echo -e "${GREEN}[Lunar Lake vLLM]${NC} MKL preloaded from $MKLROOT"
fi

# === Validate args ===
if [ $# -lt 1 ]; then
    echo -e "${RED}Usage: $0 <model_path_or_name> [extra vllm args...]${NC}"
    echo ""
    echo "Recommended models for Lunar Lake (32GB shared memory):"
    echo "  Small (fits easily):   Qwen/Qwen3-8B --quantization fp8"
    echo "  Medium (fits tight):   Qwen/Qwen3-14B --quantization int4"
    echo "  Large (requires int4): Qwen/Qwen3.5-35B-A3B --quantization int4 --max-model-len 8192"
    echo ""
    echo "Notes:"
    echo "  - Always use --quantization (fp8 or int4) to fit in shared memory"
    echo "  - Use --max-model-len to limit context and reduce KV cache memory"
    echo "  - INT4 is recommended for models >14B on 32GB systems"
    exit 1
fi

MODEL="$1"
shift

# === Detect available memory ===
TOTAL_MEM_GB=$(free -g | awk '/^Mem:/{print $2}')
AVAIL_MEM_GB=$(free -g | awk '/^Mem:/{print $7}')
# Older procps versions have no "available" column, leaving $7 empty; default
# to 0 so the integer comparisons below cannot abort the script under `set -e`.
TOTAL_MEM_GB=${TOTAL_MEM_GB:-0}
AVAIL_MEM_GB=${AVAIL_MEM_GB:-0}

echo -e "${GREEN}[Lunar Lake vLLM]${NC} System memory: ${TOTAL_MEM_GB}GB total, ${AVAIL_MEM_GB}GB available"
echo -e "${GREEN}[Lunar Lake vLLM]${NC} Model: $MODEL"

# === Memory warnings ===
if [ "$AVAIL_MEM_GB" -lt 8 ]; then
    echo -e "${RED}WARNING: Only ${AVAIL_MEM_GB}GB available. Close other applications.${NC}"
    echo -e "${RED}Minimum 8GB free recommended for inference.${NC}"
elif [ "$AVAIL_MEM_GB" -lt 16 ]; then
    echo -e "${YELLOW}NOTE: ${AVAIL_MEM_GB}GB available. Use INT4 quantization and limit context length.${NC}"
fi

# === Environment for iGPU ===
export VLLM_TARGET_DEVICE=xpu
export VLLM_WORKER_MULTIPROC_METHOD=spawn
export VLLM_OFFLOAD_WEIGHTS_BEFORE_QUANT=1
export VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
export PYTORCH_ALLOC_CONF="expandable_segments:True"
# USM mode for shared memory (no P2P needed)
export CCL_TOPO_P2P_ACCESS=0
# Skip profile_run() during KV cache init — the dummy forward pass hangs
# indefinitely on Lunar Lake iGPU (Xe2). Instead, estimate peak memory
# from current allocation. Requires the corresponding xpu_worker.py patch.
export VLLM_SKIP_PROFILE_RUN=1

# === CCL single-GPU workaround ===
# oneCCL's KVS init tries to resolve a network interface even for single-GPU.
# On laptops/handhelds without wired Ethernet this can fail with
# "fill_local_host_ip: can't find non-loopback interface".
# These env vars force CCL to use a local TCP transport instead.
# Both rendezvous vars honor pre-set values (consistent override behavior).
export MASTER_ADDR=${MASTER_ADDR:-127.0.0.1}
export MASTER_PORT=${MASTER_PORT:-29500}
export CCL_ZE_ENABLE=0
export CCL_ATL_TRANSPORT=ofi
export FI_PROVIDER=tcp
# Use WiFi interface if available, fallback to loopback
if ip link show wlo1 &>/dev/null; then
    export CCL_SOCKET_IFNAME=wlo1
elif ip link show wlan0 &>/dev/null; then
    export CCL_SOCKET_IFNAME=wlan0
else
    export CCL_SOCKET_IFNAME=lo
fi

echo -e "${GREEN}[Lunar Lake vLLM]${NC} Launching vLLM serve..."
echo "───────────────────────────────────────────────────────────────────────────────"

# === Launch vLLM ===
# Device is set via VLLM_TARGET_DEVICE=xpu (not a CLI flag)
# --tensor-parallel-size 1: Single GPU (integrated)
# --gpu-memory-utilization: Conservative for shared memory (leave room for OS + KV cache)
# --enforce-eager: Disable CUDA graphs (not supported on XPU)
exec vllm serve "$MODEL" \
    --tensor-parallel-size 1 \
    --gpu-memory-utilization 0.7 \
    --enforce-eager \
    "$@"