From 13a4b5711fd10d164724a8b2febafb8b5e471b21 Mon Sep 17 00:00:00 2001 From: Sonbol Yazdanbakhsh Date: Tue, 3 Feb 2026 22:50:12 +0000 Subject: [PATCH 1/4] Add RCCL and training warmup for HYBRID_SHARD stability - Add warmup.py with RCCL communicator warmup, manual param sync, and training collectives warmup functions - Integrate RCCL warmup in build_fsdp_model before FSDP init - Add training warmup in main() before training loop - Add TORCH_DIST_INIT_TIMEOUT support for dist.init_process_group and dist.new_group calls - Add cuda.synchronize() before gradient clipping to prevent race - Improve exception handling in finally block - Update set_env_variables.sh with DOCKER_ENV_VARS array for automatic env propagation to Docker containers - Remove duplicate logging in local_launch.sh (logs now only in logs/ directory via master_launch.sh) - Add warmup config params to shampoo_opt_multi_node.yaml Co-authored-by: Cursor --- config/multi_node/shampoo_opt_multi_node.yaml | 8 +- scripts/multi_node/local_launch.sh | 57 ++-- scripts/multi_node/set_env_variables.sh | 137 ++++++++-- src/aorta/training/fsdp_trainer.py | 130 ++++++++- src/aorta/utils/__init__.py | 10 + src/aorta/utils/warmup.py | 255 ++++++++++++++++++ 6 files changed, 542 insertions(+), 55 deletions(-) create mode 100644 src/aorta/utils/warmup.py diff --git a/config/multi_node/shampoo_opt_multi_node.yaml b/config/multi_node/shampoo_opt_multi_node.yaml index 88fc42bc..edefedd0 100644 --- a/config/multi_node/shampoo_opt_multi_node.yaml +++ b/config/multi_node/shampoo_opt_multi_node.yaml @@ -56,6 +56,12 @@ fsdp: forward_prefetch: true sync_module_states: true param_init_device: meta + # RCCL warmup settings to avoid race conditions during FSDP init + rccl_warmup_iterations: 5 + skip_rccl_warmup: false + # Training warmup: forward/backward/optimizer steps before main loop + training_warmup_steps: 1 + skip_training_warmup: false distributed: backend: nccl @@ -94,7 +100,7 @@ dataloader: pin_memory: true profiling: - enabled: true + enabled: false wait: 2 warmup: 2 active: 6 diff --git a/scripts/multi_node/local_launch.sh b/scripts/multi_node/local_launch.sh index 2eeff181..c4ccf014 100755 --- a/scripts/multi_node/local_launch.sh +++ b/scripts/multi_node/local_launch.sh @@ -1,6 +1,9 @@ #!/bin/bash # Multi-node local launch script for GEMM training # Runs on each node with single channel/thread configuration +# +# NCCL/RCCL environment variables are sourced from set_env_variables.sh +# Edit that file to change NCCL configuration - no need to modify this script. if [[ $# -lt 11 ]]; then echo "Usage: $0 [ENABLE_ROCPROF] [ROCPROF_STATS] [ROCPROF_INPUT] [DOCKER_CONTAINER]" @@ -23,6 +26,20 @@ ROCPROF_STATS="${13:-false}" ROCPROF_INPUT="${14:-}" DOCKER_CONTAINER="${15:-training-overlap-bugs-rocm70_9-1}" +# Source environment variables (should already be sourced by config_node.sh, but ensure it's loaded) +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +if [[ -f "$SCRIPT_DIR/set_env_variables.sh" ]]; then + source "$SCRIPT_DIR/set_env_variables.sh" +fi + +# Override channel/thread settings from command line arguments +export NCCL_MAX_NCHANNELS="${CHANNELS}" +export RCCL_THREADS_PER_BLOCK="${THREADS}" + +# Set AMD_LOG_LEVEL_FILE to experiment directory (will be converted to Docker path later) +# This ensures AMD logs go to the experiment folder instead of current directory +export AMD_LOG_LEVEL_FILE="${EXPERIMENT_DIR}/${THREADS}thread_${CHANNELS}channels/trace_amd_node${NODE_RANK}.log" + echo "==========================================" echo "Local Launch Configuration" echo "==========================================" @@ -37,7 +54,6 @@ echo "Experiment Dir: $EXPERIMENT_DIR" echo "Config File: $CONFIG_FILE" echo "Channels: $CHANNELS" echo "Threads: $THREADS" -echo "Docker Container: $DOCKER_CONTAINER" echo "rocprof enabled: $ENABLE_ROCPROF" echo "==========================================" echo "" @@ -60,25 +76,25 @@ else CONFIG_FILE_DOCKER="$CONFIG_FILE" fi -# Log file -LOG_FILE="${OUTPUT_DIR}/node_${NODE_RANK}_output.log" +# Convert AMD_LOG_LEVEL_FILE to Docker path +export AMD_LOG_LEVEL_FILE=$(echo "$AMD_LOG_LEVEL_FILE" | sed "s|^${AORTA_ROOT_FROM_EXP}|/workspace/aorta|") # Function to log with timestamp log() { local message="$1" local timestamp=$(date '+%Y-%m-%d %H:%M:%S') - echo "[${timestamp}] [Node ${NODE_RANK}] ${message}" | tee -a "${LOG_FILE}" + echo "[${timestamp}] [Node ${NODE_RANK}] ${message}" } # Cleanup function cleanup() { echo "" - echo "=== Caught interrupt signal ===" | tee -a "${LOG_FILE}" + echo "=== Caught interrupt signal ===" log "Cleaning up training processes on node ${NODE_RANK}..." # Try to kill processes inside Docker container - docker exec "$DOCKER_CONTAINER" pkill -9 -f "train.py" 2>/dev/null || true - docker exec "$DOCKER_CONTAINER" pkill -9 -f "torchrun" 2>/dev/null || true + docker exec training-overlap-bugs-rocm70_9-1 pkill -9 -f "train.py" 2>/dev/null || true + docker exec training-overlap-bugs-rocm70_9-1 pkill -9 -f "torchrun" 2>/dev/null || true # Also try on host (in case anything leaked) sudo pkill -9 -f "train.py" 2>/dev/null || true @@ -109,14 +125,18 @@ BASE_CMD="torchrun --nnodes ${NNODES} --node_rank ${NODE_RANK} --nproc_per_node BASE_OVERRIDES="--override profiling.tensorboard=false" # Build docker exec prefix with environment variables -DOCKER_EXEC="docker exec \ - -e RCCL_THREADS_PER_BLOCK=${THREADS} \ - -e NCCL_MAX_NCHANNELS=${CHANNELS} \ - -e HSA_ENABLE_SDMA=0 \ - -e PYTORCH_ROCM_PROFILER_ENABLE_TRACING=1 \ - ${DOCKER_CONTAINER}" +# All NCCL/RCCL variables are defined in set_env_variables.sh +DOCKER_ENV_FLAGS=$(build_docker_env_flags) +DOCKER_EXEC="docker exec ${DOCKER_ENV_FLAGS} ${DOCKER_CONTAINER}" + +# Log which env vars are being passed +log "Docker environment variables:" +for var in "${DOCKER_ENV_VARS[@]}"; do + log " ${var}=${!var}" +done # Run with or without rocprofv3 +# Note: Output is already captured by master_launch.sh's redirection, no need for tee if [ "${ENABLE_ROCPROF}" = "true" ]; then ROCPROF_DIR="${OUTPUT_DIR}/rocprof_traces/node_${NODE_RANK}" mkdir -p "${ROCPROF_DIR}" @@ -125,8 +145,7 @@ if [ "${ENABLE_ROCPROF}" = "true" ]; then log "Using rocprofv3 input file: ${ROCPROF_INPUT}" ${DOCKER_EXEC} bash -c "rocprofv3 -i ${ROCPROF_INPUT} -d ${ROCPROF_DIR} -- \ ${BASE_CMD} ${BASE_OVERRIDES} \ - --override training.output_dir=${OUTPUT_DIR_DOCKER}" \ - 2>&1 | tee -a "${LOG_FILE}" + --override training.output_dir=${OUTPUT_DIR_DOCKER}" 2>&1 else ROCPROF_ARGS="--kernel-trace" if [ "${ROCPROF_STATS}" = "true" ]; then @@ -136,18 +155,16 @@ if [ "${ENABLE_ROCPROF}" = "true" ]; then log "Running with rocprofv3 kernel tracing inside Docker" ${DOCKER_EXEC} bash -c "rocprofv3 ${ROCPROF_ARGS} -d ${ROCPROF_DIR} -- \ ${BASE_CMD} ${BASE_OVERRIDES} \ - --override training.output_dir=${OUTPUT_DIR_DOCKER}" \ - 2>&1 | tee -a "${LOG_FILE}" + --override training.output_dir=${OUTPUT_DIR_DOCKER}" 2>&1 fi else log "Running inside Docker container" log "Command: ${BASE_CMD} ${BASE_OVERRIDES} --override training.output_dir=${OUTPUT_DIR_DOCKER}" ${DOCKER_EXEC} bash -c "${BASE_CMD} ${BASE_OVERRIDES} \ - --override training.output_dir=${OUTPUT_DIR_DOCKER}" \ - 2>&1 | tee -a "${LOG_FILE}" + --override training.output_dir=${OUTPUT_DIR_DOCKER}" 2>&1 fi -EXIT_CODE=${PIPESTATUS[0]} +EXIT_CODE=$? END_TIME=$(date +%s) DURATION=$((END_TIME - START_TIME)) diff --git a/scripts/multi_node/set_env_variables.sh b/scripts/multi_node/set_env_variables.sh index 3e9c070a..67b2cba6 100755 --- a/scripts/multi_node/set_env_variables.sh +++ b/scripts/multi_node/set_env_variables.sh @@ -1,42 +1,133 @@ #!/bin/bash +# ============================================================================= # Global NCCL/RCCL environment variables for multi-node training -# Based on DLRM_set_env_variables.sh +# Configured for MI350X cluster +# +# This file is the SINGLE SOURCE OF TRUTH for all NCCL/RCCL configuration. +# Edit variables here - local_launch.sh will automatically pick them up. +# +# NOTE: When adding a new environment variable, you MUST also add its name +# to the DOCKER_ENV_VARS array below, otherwise it won't be passed +# to the Docker container. +# ============================================================================= -# NCCL Debug Settings (use INFO for debugging network issues) -export NCCL_DEBUG=INFO -export NCCL_DEBUG_SUBSYS=INIT,NET -# Try disabling IB if InfiniBand is not properly configured -export NCCL_IB_DISABLE=1 +# ----------------------------------------------------------------------------- +# NCCL Debug Settings +# ----------------------------------------------------------------------------- +export NCCL_DEBUG=WARN +export NCCL_DEBUG_SUBSYS= # Options: COLL,INIT,NET (empty = none) -# IB/RNIC Configuration (commented out when IB is disabled) -# export NCCL_IB_HCA=bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re6,bnxt_re7 -# export NCCL_IB_GID_INDEX=3 +# ----------------------------------------------------------------------------- +# RCCL-Specific Settings (ROCm) +# ----------------------------------------------------------------------------- +export RCCL_DIRECT_ALLGATHER_DISABLE=1 # Disable direct allgather +export RCCL_MSCCL_ENABLE=0 # Disable MSCCL +export RCCL_THREADS_PER_BLOCK=256 # Threads per block (override via --threads) + +# ----------------------------------------------------------------------------- +# IB/RNIC Configuration for MI350X +# ----------------------------------------------------------------------------- +export NCCL_IB_HCA=bnxt_re0,bnxt_re1,bnxt_re2,bnxt_re3,bnxt_re4,bnxt_re5,bnxt_re6,bnxt_re7 +export NCCL_IB_GID_INDEX=3 export NCCL_NCHANNELS_PER_NET_PEER=8 +# ----------------------------------------------------------------------------- # HSA Settings for ROCm +# ----------------------------------------------------------------------------- export HSA_ENABLE_IPC_MODE_LEGACY=1 +export HSA_ENABLE_SDMA=0 # Disable SDMA for stability -# NCCL Protocol +# ----------------------------------------------------------------------------- +# NCCL Protocol and Channels +# ----------------------------------------------------------------------------- export NCCL_PROTO=Simple +#export NCCL_MIN_NCHANNELS=40 +export NCCL_MAX_NCHANNELS=56 # Override via --channels -# Channel Configuration (can be overridden by sweep parameters) -export NCCL_MIN_NCHANNELS=40 -export NCCL_MAX_NCHANNELS=40 - -# Network Interface -# Change this to match your network interface: eth0, ib0, enp49s0f0np0, etc. -# Temporarily commented out for auto-detection: -# export NCCL_SOCKET_IFNAME=enp193s0f0 +# ----------------------------------------------------------------------------- +# Network Interface for MI350X cluster +# ----------------------------------------------------------------------------- +export NCCL_SOCKET_IFNAME=enp49s0f0np0,fenic0 +# ----------------------------------------------------------------------------- +# Timeout and Error Handling +# ----------------------------------------------------------------------------- +export NCCL_TIMEOUT_MS=12000 # 12 second timeout (legacy, not used by PyTorch) +export NCCL_TIMEOUT=100 # 300 second (5 min) timeout - first backward can be slow due to JIT/init +export TORCH_DIST_INIT_TIMEOUT=150 # Match collective timeout for consistency +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 +export TORCH_NCCL_TRACE_BUFFER_SIZE=10000 +export TORCH_NCCL_DUMP_ON_TIMEOUT=1 # Critical for hang debugging +#export AMD_LOG_LEVEL=5 +# AMD_LOG_LEVEL_FILE is set dynamically in local_launch.sh to point to experiment directory +# Default fallback (will be overridden): +#export AMD_LOG_LEVEL_FILE=trace_amd.log +# ----------------------------------------------------------------------------- # PyTorch ROCm Profiler +# ----------------------------------------------------------------------------- export PYTORCH_ROCM_PROFILER_ENABLE_TRACING=1 -# Optional: Force non-overlap for debugging +# ----------------------------------------------------------------------------- +# List of environment variables to pass to Docker container +# Add/remove variables here to control what gets passed through +# ----------------------------------------------------------------------------- +DOCKER_ENV_VARS=( + # NCCL Debug + NCCL_DEBUG + NCCL_DEBUG_SUBSYS + # RCCL + RCCL_DIRECT_ALLGATHER_DISABLE + RCCL_MSCCL_ENABLE + RCCL_THREADS_PER_BLOCK + # IB/RNIC + NCCL_IB_HCA + NCCL_IB_GID_INDEX + NCCL_NCHANNELS_PER_NET_PEER + # HSA + HSA_ENABLE_IPC_MODE_LEGACY + HSA_ENABLE_SDMA + # Protocol/Channels + NCCL_PROTO + NCCL_MIN_NCHANNELS + NCCL_MAX_NCHANNELS + # Network + NCCL_SOCKET_IFNAME + # Timeout/Error Handling + NCCL_TIMEOUT_MS + NCCL_TIMEOUT + TORCH_DIST_INIT_TIMEOUT + TORCH_NCCL_ASYNC_ERROR_HANDLING + TORCH_NCCL_TRACE_BUFFER_SIZE + TORCH_NCCL_DUMP_ON_TIMEOUT + # AMD Logging + AMD_LOG_LEVEL + AMD_LOG_LEVEL_FILE + # Profiler + PYTORCH_ROCM_PROFILER_ENABLE_TRACING +) +export DOCKER_ENV_VARS + +# ----------------------------------------------------------------------------- +# Helper function: Build docker -e flags from DOCKER_ENV_VARS +# Usage: DOCKER_ENV_FLAGS=$(build_docker_env_flags) +# ----------------------------------------------------------------------------- +build_docker_env_flags() { + local flags="" + for var in "${DOCKER_ENV_VARS[@]}"; do + local value="${!var}" + flags+=" -e ${var}=${value}" + done + echo "$flags" +} +export -f build_docker_env_flags + +# ============================================================================= +# Optional settings (uncomment to enable) +# ============================================================================= + +# Force non-overlap for debugging (single HW queue) # export GPU_MAX_HW_QUEUES=1 # unset TORCH_NCCL_HIGH_PRIORITY -# Optional: Disable SDMA for testing -# export HSA_ENABLE_SDMA=0 - -# Optional: Disable IB for Ethernet-only testing +# Disable IB for Ethernet-only testing # export NCCL_IB_DISABLE=1 diff --git a/src/aorta/training/fsdp_trainer.py b/src/aorta/training/fsdp_trainer.py index eea19231..a22f9159 100644 --- a/src/aorta/training/fsdp_trainer.py +++ b/src/aorta/training/fsdp_trainer.py @@ -10,6 +10,7 @@ import signal import subprocess from dataclasses import dataclass, field +from datetime import timedelta from pathlib import Path from typing import Any, Dict, Generator, Iterable, Optional from functools import partial @@ -27,7 +28,17 @@ from aorta.data import SyntheticDatasetConfig, create_dataloader from aorta.models import ModelConfig, RankingTransformerModel from aorta.profiling.stream_profiler import StreamProfiler -from aorta.utils import detect_accelerator, get_device, get_distributed_backend, load_config, merge_cli_overrides, setup_logging +from aorta.utils import ( + detect_accelerator, + get_device, + get_distributed_backend, + load_config, + manual_sync_params, + merge_cli_overrides, + setup_logging, + warmup_rccl_communicators, + warmup_training_collectives, +) log = logging.getLogger(__name__) @@ -73,6 +84,17 @@ class FSDPConfig: # For HYBRID_SHARD: GPUs per node (None = auto-detect from LOCAL_WORLD_SIZE env var) # Only set this if auto-detection fails or you want to override hybrid_shard_gpus_per_node: Optional[int] = None + # Number of warmup operations to perform on RCCL communicators before FSDP init + # This helps avoid race conditions in inter-node RDMA setup + # Higher values provide more stability but increase startup time + rccl_warmup_iterations: int = 10 + # Skip RCCL warmup entirely (for testing race conditions) + skip_rccl_warmup: bool = False + # Number of training warmup steps (forward/backward/optimizer) before main loop + # This exercises all collectives to ensure RCCL is fully established + training_warmup_steps: int = 1 + # Skip training warmup entirely + skip_training_warmup: bool = False @dataclass @@ -206,7 +228,9 @@ def dataclass_fields(cls) -> Iterable[Any]: def init_distributed(training_cfg: TrainingConfig, log_level: str) -> Dict[str, Any]: backend = get_distributed_backend() - dist.init_process_group(backend=backend) + # Use timeout from environment variable (set in set_env_variables.sh) + timeout_seconds = int(os.environ.get("TORCH_DIST_INIT_TIMEOUT", "600")) + dist.init_process_group(backend=backend, timeout=timedelta(seconds=timeout_seconds)) rank = dist.get_rank() world_size = dist.get_world_size() local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", 0))) @@ -250,11 +274,48 @@ def build_fsdp_model( # Create process groups for hybrid_shard strategy process_group = None + shard_group = None + replicate_group = None + if sharding == ShardingStrategy.HYBRID_SHARD: - process_group = _create_hybrid_shard_process_groups(fsdp_cfg.hybrid_shard_gpus_per_node) - if process_group is not None: + result = _create_hybrid_shard_process_groups(fsdp_cfg.hybrid_shard_gpus_per_node) + if result is not None: + shard_group, replicate_group = result + process_group = (shard_group, replicate_group) + + # Warmup RCCL communicators BEFORE FSDP initialization + # This ensures inter-node communicators are fully established before + # the _sync_params_and_buffers broadcasts that can cause hangs + if fsdp_cfg.skip_rccl_warmup: + log.warning("SKIPPING RCCL warmup (skip_rccl_warmup=True) - may cause hangs or race conditions") + else: + warmup_rccl_communicators( + shard_group, + replicate_group, + device, + num_warmup_ops=fsdp_cfg.rccl_warmup_iterations, + ) log.info("Created custom process groups for HYBRID_SHARD strategy") + # Ensure GPU operations are complete before FSDP wrapping + # This helps prevent race conditions with inter-node communicators + torch.cuda.synchronize() + dist.barrier() + + # For HYBRID_SHARD with sync_module_states, we disable automatic sync and do it + # manually with explicit barriers to avoid RCCL race conditions + use_sync_module_states = fsdp_cfg.sync_module_states + needs_manual_sync = False + + if sharding == ShardingStrategy.HYBRID_SHARD and fsdp_cfg.sync_module_states: + use_sync_module_states = False + needs_manual_sync = True + log.info( + "Disabling sync_module_states for HYBRID_SHARD - will sync manually after wrapping" + ) + + log.info("Starting FSDP model wrapping with sync_module_states=%s", use_sync_module_states) + fsdp_model = FSDP( model.to(device), sharding_strategy=sharding, @@ -265,8 +326,13 @@ def build_fsdp_model( limit_all_gathers=fsdp_cfg.limit_all_gathers, forward_prefetch=fsdp_cfg.forward_prefetch, device_id=torch.cuda.current_device(), - sync_module_states=fsdp_cfg.sync_module_states, + sync_module_states=use_sync_module_states, ) + + # Manual parameter synchronization for HYBRID_SHARD + if needs_manual_sync and replicate_group is not None: + manual_sync_params(fsdp_model, replicate_group) + if compile_cfg.enabled: fsdp_model = _maybe_compile(fsdp_model, compile_cfg) return fsdp_model @@ -290,6 +356,10 @@ def _create_hybrid_shard_process_groups(gpus_per_node: Optional[int] = None): rank = dist.get_rank() local_rank = int(os.environ.get("LOCAL_RANK", 0)) + # Get timeout from environment - same as init_process_group + timeout_seconds = int(os.environ.get("TORCH_DIST_INIT_TIMEOUT", "600")) + group_timeout = timedelta(seconds=timeout_seconds) + # Auto-detect GPUs per node from environment if not provided if gpus_per_node is None: # torchrun sets LOCAL_WORLD_SIZE to the number of processes per node @@ -320,21 +390,21 @@ def _create_hybrid_shard_process_groups(gpus_per_node: Optional[int] = None): return None log.info( - "Creating HYBRID_SHARD process groups | rank=%d world_size=%d num_nodes=%d gpus_per_node=%d node_id=%d", - rank, world_size, num_nodes, gpus_per_node, node_id + "Creating HYBRID_SHARD process groups | rank=%d world_size=%d num_nodes=%d gpus_per_node=%d node_id=%d timeout=%ds", + rank, world_size, num_nodes, gpus_per_node, node_id, timeout_seconds ) # Intra-node groups: shard within each node for i in range(num_nodes): ranks_in_node = list(range(i * gpus_per_node, (i + 1) * gpus_per_node)) - group = dist.new_group(ranks=ranks_in_node) + group = dist.new_group(ranks=ranks_in_node, timeout=group_timeout) if i == node_id: my_shard_group = group # Inter-node groups: replicate across nodes (same local_rank) for local_r in range(gpus_per_node): ranks_across_nodes = [node * gpus_per_node + local_r for node in range(num_nodes)] - group = dist.new_group(ranks=ranks_across_nodes) + group = dist.new_group(ranks=ranks_across_nodes, timeout=group_timeout) if local_r == local_rank: my_replicate_group = group @@ -502,6 +572,10 @@ def training_loop( else: loss.backward() + # Synchronize before gradient clipping to ensure backward is complete + # This prevents race conditions between FSDP gradient reduction and clipping + torch.cuda.synchronize() + grad_norm = None if training_cfg.grad_clip_norm is not None and training_cfg.grad_clip_norm > 0: with profiler.range("aux", f"epoch{epoch}_step{step}_grad_clip"): @@ -845,6 +919,33 @@ def main(args: Optional[argparse.Namespace] = None, *, enable_rocm_metrics: bool profiler = StreamProfiler(env["device"]) + # Training warmup: run a few forward/backward/optimizer steps to warm up collectives + if not fsdp_cfg.skip_training_warmup and fsdp_cfg.training_warmup_steps > 0: + log.info("Starting training warmup with %d steps...", fsdp_cfg.training_warmup_steps) + # Determine autocast dtype for warmup + mp_mode = training_cfg.mixed_precision.lower() + if mp_mode == "fp16": + warmup_autocast_dtype = torch.float16 + warmup_scaler = torch.cuda.amp.GradScaler() + elif mp_mode == "bf16": + warmup_autocast_dtype = torch.bfloat16 + warmup_scaler = None + else: + warmup_autocast_dtype = None + warmup_scaler = None + + warmup_training_collectives( + model=model, + optimizer=optimizer, + dataloader=dataloader, + device=env["device"], + autocast_dtype=warmup_autocast_dtype, + scaler=warmup_scaler, + loss_fn=compute_loss, + num_warmup_steps=fsdp_cfg.training_warmup_steps, + ) + log.info("Training warmup complete") + try: training_loop( model, @@ -858,8 +959,15 @@ def main(args: Optional[argparse.Namespace] = None, *, enable_rocm_metrics: bool profiler_cfg, ) finally: - dist.barrier() - dist.destroy_process_group() + if dist.is_initialized(): + try: + dist.barrier() + except Exception as e: + log.warning("Barrier failed during cleanup: %s", e) + try: + dist.destroy_process_group() + except Exception as e: + log.warning("destroy_process_group failed: %s", e) __all__ = ["main", "main_cli"] diff --git a/src/aorta/utils/__init__.py b/src/aorta/utils/__init__.py index 7a42f6dc..d840ba09 100644 --- a/src/aorta/utils/__init__.py +++ b/src/aorta/utils/__init__.py @@ -6,6 +6,7 @@ - GPU timing utilities - Configuration loading - Logging setup +- Distributed training warmup """ from .config import load_config, merge_cli_overrides @@ -43,6 +44,11 @@ StreamTimer, TimingContext, ) +from .warmup import ( + manual_sync_params, + warmup_rccl_communicators, + warmup_training_collectives, +) __all__ = [ # Config @@ -78,4 +84,8 @@ "TimingContext", "StreamTimer", "CPUTimer", + # Warmup + "warmup_rccl_communicators", + "manual_sync_params", + "warmup_training_collectives", ] diff --git a/src/aorta/utils/warmup.py b/src/aorta/utils/warmup.py new file mode 100644 index 00000000..5ce766d3 --- /dev/null +++ b/src/aorta/utils/warmup.py @@ -0,0 +1,255 @@ +"""Distributed training warmup utilities. + +This module provides functions to warm up RCCL communicators and training +collectives before the main training loop starts. These help avoid race +conditions in RCCL/RDMA during FSDP initialization. + +Design Notes: + - Training warmup is reduced to 1 step (from 3) to preserve timing + variability while still exercising the collectives. + - RCCL warmup in build_fsdp_model handles communicator initialization + separately from training collectives warmup. + - Set skip_training_warmup=True to maximize timing variability for + race condition testing. +""" + +import logging +from typing import Callable, Dict, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +log = logging.getLogger(__name__) + + +def warmup_rccl_communicators( + shard_group: Optional[dist.ProcessGroup], + replicate_group: Optional[dist.ProcessGroup], + device: torch.device, + num_warmup_ops: int = 5, +) -> None: + """ + Warm up RCCL communicators with small operations before heavy FSDP usage. + + This ensures inter-node communicators are fully established before the + _sync_params_and_buffers broadcasts. The race condition in RCCL/RoCE RDMA + setup can cause hangs during FSDP initialization if broadcasts are issued + before the communicators are ready. + + Args: + shard_group: Intra-node shard process group (may be None) + replicate_group: Inter-node replicate process group (may be None) + device: GPU device to use for warmup tensors + num_warmup_ops: Number of warmup operations to perform (default: 5) + """ + rank = dist.get_rank() + # Use a larger tensor for more thorough warmup + warmup_tensor = torch.ones(8192, device=device, dtype=torch.float32) + + log.info("Starting RCCL communicator warmup with %d iterations (rank=%d)...", num_warmup_ops, rank) + log.info("Warmup will test: all_reduce, broadcast, reduce_scatter, all_gather") + + # First, warmup the global world group + log.info("Warming up global world group...") + for i in range(num_warmup_ops): + dist.all_reduce(warmup_tensor) + dist.broadcast(warmup_tensor, src=0) + torch.cuda.synchronize() + + dist.barrier() + log.info("Global world group warmup complete") + + # Then warmup the shard and replicate groups + # IMPORTANT: Must test ALL collective types that FSDP uses: + # - reduce_scatter (FSDP backward gradient reduction) + # - all_gather (FSDP forward parameter gathering) + # - all_reduce (HYBRID_SHARD inter-node gradient sync) + # - broadcast (parameter sync) + for i in range(num_warmup_ops): + # Warmup intra-node shard group + if shard_group is not None: + shard_size = dist.get_world_size(shard_group) + + # all_reduce (used in some FSDP operations) + dist.all_reduce(warmup_tensor, group=shard_group) + + # reduce_scatter - CRITICAL: this is what FSDP backward uses! + # Create input tensor that's shard_size times larger + rs_input = torch.ones(8192 * shard_size, device=device, dtype=torch.float32) + rs_output = torch.empty(8192, device=device, dtype=torch.float32) + dist.reduce_scatter_tensor(rs_output, rs_input, group=shard_group) + + # all_gather - used in FSDP forward + ag_input = torch.ones(8192, device=device, dtype=torch.float32) + ag_output = torch.empty(8192 * shard_size, device=device, dtype=torch.float32) + dist.all_gather_into_tensor(ag_output, ag_input, group=shard_group) + + # broadcast from first rank in shard group + shard_ranks = dist.get_process_group_ranks(shard_group) + dist.broadcast(warmup_tensor, src=shard_ranks[0], group=shard_group) + + # Warmup inter-node replicate group (this is where the race condition occurs) + if replicate_group is not None: + # Get the ranks in this replicate group and use the first one as source + # Note: dist.get_process_group_ranks returns global ranks in the group + group_ranks = dist.get_process_group_ranks(replicate_group) + src_global_rank = group_ranks[0] # First rank in the group + dist.broadcast(warmup_tensor, src=src_global_rank, group=replicate_group) + dist.all_reduce(warmup_tensor, group=replicate_group) + + # Synchronize GPU and global barrier between iterations + torch.cuda.synchronize() + dist.barrier() + + # Final synchronization with extra delay + torch.cuda.synchronize() + dist.barrier() + torch.cuda.synchronize() + dist.barrier() + + log.info("RCCL communicator warmup complete (rank=%d)", rank) + + +def manual_sync_params( + model: FSDP, + replicate_group: Optional[dist.ProcessGroup], +) -> None: + """ + Manually synchronize FSDP parameters from the first rank in each replicate group. + + This replaces the automatic sync_module_states with controlled synchronization + to avoid race conditions in RCCL/RDMA during FSDP initialization. Parameters + are broadcast from the first rank in each replicate group to ensure consistency. + + Args: + model: The FSDP-wrapped model + replicate_group: Inter-node replicate process group for broadcasting + """ + rank = dist.get_rank() + + log.info("Starting manual parameter synchronization (rank=%d)...", rank) + + # Synchronize before param sync + torch.cuda.synchronize() + dist.barrier() + + # Determine the source rank for this replicate group + # Each replicate group contains ranks with the same local_rank across nodes + # e.g., group for local_rank 2: [2, 10, 18] - we broadcast from rank 2 (first in group) + src_global_rank = None + if replicate_group is not None: + group_ranks = dist.get_process_group_ranks(replicate_group) + src_global_rank = group_ranks[0] # First rank in the group + log.info("Manual sync: replicate group ranks=%s, src_rank=%d", group_ranks, src_global_rank) + + param_count = 0 + with torch.no_grad(): + for name, param in model.named_parameters(): + if param.is_meta: + log.debug("Skipping meta parameter: %s", name) + continue + + # Broadcast from the first rank within this replicate group + if replicate_group is not None and src_global_rank is not None: + dist.broadcast(param.data, src=src_global_rank, group=replicate_group) + + param_count += 1 + + # Periodic sync to prevent overwhelming the network + if param_count % 10 == 0: + torch.cuda.synchronize() + + # Final barrier to ensure all ranks complete + torch.cuda.synchronize() + dist.barrier() + + log.info("Manual parameter synchronization complete (rank=%d, params=%d)", rank, param_count) + + +def warmup_training_collectives( + model: nn.Module, + optimizer: torch.optim.Optimizer, + dataloader, + device: torch.device, + autocast_dtype: Optional[torch.dtype], + scaler: Optional[torch.cuda.amp.GradScaler], + loss_fn: Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor], + num_warmup_steps: int = 3, +) -> None: + """ + Warm up training collectives by running dummy forward/backward/optimizer steps. + + This exercises all the collective operations used during training (all-gather, + reduce-scatter, all-reduce) to ensure RCCL communicators are fully established + before the main training loop starts. + + Args: + model: The model (FSDP-wrapped) + optimizer: The optimizer + dataloader: Training dataloader + device: GPU device + autocast_dtype: Mixed precision dtype (or None) + scaler: Gradient scaler for fp16 (or None) + loss_fn: Loss function that takes (scores, batch) and returns loss tensor + num_warmup_steps: Number of warmup steps to run + """ + rank = dist.get_rank() if dist.is_initialized() else 0 + + # Get an iterator from the dataloader + data_iter = iter(dataloader) + + for warmup_step in range(num_warmup_steps): + try: + cpu_batch = next(data_iter) + except StopIteration: + # Restart iterator if dataloader is exhausted + data_iter = iter(dataloader) + cpu_batch = next(data_iter) + + # Move batch to device + batch = {k: v.to(device, non_blocking=True) if hasattr(v, 'to') else v + for k, v in cpu_batch.items()} + torch.cuda.synchronize() + + # Forward pass + optimizer.zero_grad(set_to_none=True) + if autocast_dtype: + with torch.autocast(device_type="cuda", dtype=autocast_dtype): + scores = model(batch) + loss = loss_fn(scores, batch) + else: + scores = model(batch) + loss = loss_fn(scores, batch) + + # Backward pass + if scaler is not None: + scaler.scale(loss).backward() + else: + loss.backward() + + # Optimizer step + if scaler is not None: + scaler.step(optimizer) + scaler.update() + else: + optimizer.step() + + # Synchronize all ranks after each warmup step + torch.cuda.synchronize() + dist.barrier() + + log.debug("Warmup step %d complete (rank=%d, loss=%.4f)", warmup_step, rank, loss.item()) + + # Reset optimizer state after warmup to not affect actual training + optimizer.zero_grad(set_to_none=True) + torch.cuda.synchronize() + dist.barrier() + + +__all__ = [ + "warmup_rccl_communicators", + "manual_sync_params", + "warmup_training_collectives", +] From bc511a6088066072671ae2dc645ca8f4e4140df7 Mon Sep 17 00:00:00 2001 From: Vivek Agrawal <197589114+amd-vivekag@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:43:27 +0530 Subject: [PATCH 2/4] removes hard coding on docker container name Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- scripts/multi_node/local_launch.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/multi_node/local_launch.sh b/scripts/multi_node/local_launch.sh index c4ccf014..83af10ba 100755 --- a/scripts/multi_node/local_launch.sh +++ b/scripts/multi_node/local_launch.sh @@ -93,8 +93,8 @@ cleanup() { log "Cleaning up training processes on node ${NODE_RANK}..." # Try to kill processes inside Docker container - docker exec training-overlap-bugs-rocm70_9-1 pkill -9 -f "train.py" 2>/dev/null || true - docker exec training-overlap-bugs-rocm70_9-1 pkill -9 -f "torchrun" 2>/dev/null || true + docker exec "${DOCKER_CONTAINER}" pkill -9 -f "train.py" 2>/dev/null || true + docker exec "${DOCKER_CONTAINER}" pkill -9 -f "torchrun" 2>/dev/null || true # Also try on host (in case anything leaked) sudo pkill -9 -f "train.py" 2>/dev/null || true From 8623e1517ee9906fe40f35f81f7af9c1b891284d Mon Sep 17 00:00:00 2001 From: Vivek Agrawal <197589114+amd-vivekag@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:49:37 +0530 Subject: [PATCH 3/4] adds a log if both shard_group and replicate_group are None Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/aorta/utils/warmup.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/aorta/utils/warmup.py b/src/aorta/utils/warmup.py index 5ce766d3..019f5428 100644 --- a/src/aorta/utils/warmup.py +++ b/src/aorta/utils/warmup.py @@ -67,6 +67,14 @@ def warmup_rccl_communicators( # - all_gather (FSDP forward parameter gathering) # - all_reduce (HYBRID_SHARD inter-node gradient sync) # - broadcast (parameter sync) + if shard_group is None and replicate_group is None: + log.warning( + "No shard_group or replicate_group provided; only the global world " + "process group will be warmed up. Subsequent iterations will run " + "only barriers and CUDA synchronizations without group-specific " + "collective operations (rank=%d).", + rank, + ) for i in range(num_warmup_ops): # Warmup intra-node shard group if shard_group is not None: From 6d36362077f9a146c1c63b39cfb63da086a505f5 Mon Sep 17 00:00:00 2001 From: Vivek Agrawal <197589114+amd-vivekag@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:50:27 +0530 Subject: [PATCH 4/4] makes comment consistent with the code Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- scripts/multi_node/set_env_variables.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/multi_node/set_env_variables.sh b/scripts/multi_node/set_env_variables.sh index 67b2cba6..3503a2cb 100755 --- a/scripts/multi_node/set_env_variables.sh +++ b/scripts/multi_node/set_env_variables.sh @@ -53,7 +53,7 @@ export NCCL_SOCKET_IFNAME=enp49s0f0np0,fenic0 # Timeout and Error Handling # ----------------------------------------------------------------------------- export NCCL_TIMEOUT_MS=12000 # 12 second timeout (legacy, not used by PyTorch) -export NCCL_TIMEOUT=100 # 300 second (5 min) timeout - first backward can be slow due to JIT/init +export NCCL_TIMEOUT=100 # 100 second timeout - first backward can be slow due to JIT/init export TORCH_DIST_INIT_TIMEOUT=150 # Match collective timeout for consistency export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 export TORCH_NCCL_TRACE_BUFFER_SIZE=10000