From 7e99d2937a1cb952a91b2fed97e80cdbd713edec Mon Sep 17 00:00:00 2001 From: gavingavin99 <266622011+gavingavin99@users.noreply.github.com> Date: Tue, 9 Jun 2026 19:44:05 +0800 Subject: [PATCH 1/2] fix(smooth) fix typos and change default smooth config --- configs/Hy3/ptq/fp8/Hy3_smooth.yaml | 8 ++++---- scripts/ptq/run_smooth_for_HY3.sh | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/configs/Hy3/ptq/fp8/Hy3_smooth.yaml b/configs/Hy3/ptq/fp8/Hy3_smooth.yaml index ea6b6ea2..2eb3212e 100644 --- a/configs/Hy3/ptq/fp8/Hy3_smooth.yaml +++ b/configs/Hy3/ptq/fp8/Hy3_smooth.yaml @@ -27,8 +27,8 @@ collect_moe: true ema_momentum: 0.9 # ========== Phase 2: smooth flavours & fixed alpha ========== -smooth_qk: true -smooth_vo: true +smooth_qk: false +smooth_vo: false smooth_down: true alpha_qk: 0.6 alpha_vo: 0.5 @@ -56,5 +56,5 @@ alpha_smooth_search_mode: default # per-tensor-act-first mode parameters (only effective when mode=per-tensor-act-first) # alpha_act_mul_min: 0.1 # alpha_act_mul_max: 1.0 -# alpha_smooth_min: 1e-6 -# alpha_smooth_max: 1e6 +# alpha_smooth_min: 0.000001 +# alpha_smooth_max: 100000 diff --git a/scripts/ptq/run_smooth_for_HY3.sh b/scripts/ptq/run_smooth_for_HY3.sh index d745f642..26f116e8 100755 --- a/scripts/ptq/run_smooth_for_HY3.sh +++ b/scripts/ptq/run_smooth_for_HY3.sh @@ -33,8 +33,6 @@ done # -------- Environment Variables -------- # Allow function serialization for apply_model in vLLM v1 engine export VLLM_ALLOW_INSECURE_SERIALIZATION=1 -# Enable MoE expert statistics collection -export VLLM_MOE_COLLECT_STATS=1 # Force Ray to reload code (disable code caching) export RAY_DEDUP_LOGS=0 # Force Python to not use bytecode cache @@ -43,7 +41,6 @@ export PYTHONDONTWRITEBYTECODE=1 export MAX_NUM_BATCHED_TOKENS=32768 export VLLM_ENABLE_CHUNKED_PREFILL=1 -export MOE_MODE=fused export VLLM_ATTENTION_BACKEND=FLASHINFER export ASYNC_SCHEDULING=1 export VLLM_ENABLE_PREFIX_CACHING=1 @@ -54,7 +51,6 @@ export PRECISIONMODE=HF export VLLM_MOE_COLLECT_SMOOTH_STATS=1 export VLLM_MOE_COLLECT_ALPHA_SEARCH=1 -export PYTHONPATH=/cfs_cloud_code/gavinlee/work/open_source_smooth/AngelSlim # -------- Phase 1: Collect Smooth Stats + Alpha Search -------- if [ "$SKIP_CALIBRATE" = false ]; then echo "========================================" @@ -72,6 +68,10 @@ if [ "$SKIP_CONVERT" = false ]; then python3 tools/smooth/convert_smooth_weights.py -c "$CONFIG" fi +# revert +unset VLLM_MOE_COLLECT_SMOOTH_STATS +unset VLLM_MOE_COLLECT_ALPHA_SEARCH + echo "========================================" echo "Done." echo "========================================" From a541d7703e04706b6c380ea16b14c605c559fd03 Mon Sep 17 00:00:00 2001 From: gavingavin99 <266622011+gavingavin99@users.noreply.github.com> Date: Wed, 10 Jun 2026 14:52:43 +0800 Subject: [PATCH 2/2] fix(smooth): update convert smooth weight for hy_v3 in transformer --- tools/smooth/convert_smooth_weights.py | 371 ++++++++++--------------- 1 file changed, 153 insertions(+), 218 deletions(-) diff --git a/tools/smooth/convert_smooth_weights.py b/tools/smooth/convert_smooth_weights.py index 0eee8fe4..e39469d5 100644 --- a/tools/smooth/convert_smooth_weights.py +++ b/tools/smooth/convert_smooth_weights.py @@ -36,10 +36,7 @@ import argparse import json -import math import os -import shutil -from concurrent.futures import ThreadPoolExecutor, as_completed import torch from transformers import AutoModelForCausalLM, AutoTokenizer @@ -251,64 +248,89 @@ def parse_args(): # --------------------------------------------------------------------------- -# MTP / extra weight recovery: load tensors dropped by AutoModelForCausalLM +# Post-save: append source keys that save_pretrained dropped (e.g. MTP layers) # --------------------------------------------------------------------------- -def load_missing_tensors(model_path: str, loaded_keys: set) -> dict: +def _read_weight_map(model_dir: str) -> tuple[dict, bool]: """ - Load weights that were NOT loaded by AutoModelForCausalLM (e.g. MTP layers). + Read the safetensors weight_map of a model directory. - Typical case: models with Multi-Token Prediction heads whose extra layers - are not registered in the HF architecture and get dropped with - "Some weights of the model checkpoint were not used" warning. - - Args: - model_path: original model directory path - loaded_keys: model.state_dict().keys() (keys already loaded by HF) - - Returns: - dict[str, torch.Tensor] -- missing key -> CPU tensor mapping; - empty dict if nothing is missing. + Returns ``(weight_map, is_sharded)`` where ``weight_map`` maps + ``tensor_key -> shard_filename``. Handles both the sharded layout + (``model.safetensors.index.json``) and the single-file layout + (``model.safetensors``). Returns ``({}, False)`` if neither is found. """ - try: - from safetensors import safe_open - except ImportError: - print(" [WARNING] safetensors not available, cannot supplement missing weights") - return {} - - index_path = os.path.join(model_path, "model.safetensors.index.json") - single_path = os.path.join(model_path, "model.safetensors") + index_path = os.path.join(model_dir, "model.safetensors.index.json") + single_path = os.path.join(model_dir, "model.safetensors") if os.path.exists(index_path): with open(index_path, "r") as f: index = json.load(f) - weight_map: dict[str, str] = index["weight_map"] - elif os.path.exists(single_path): + return dict(index["weight_map"]), True + + if os.path.exists(single_path): + from safetensors import safe_open + with safe_open(single_path, framework="pt", device="cpu") as f_st: - weight_map = {k: "model.safetensors" for k in f_st.keys()} - else: - print(" [WARNING] No safetensors index/file found in model_path, skipping") - return {} + return {k: "model.safetensors" for k in f_st.keys()}, False + + return {}, False + + +def append_missing_keys_from_source( + model_path: str, save_path: str, shard_size_gb: float = 8.0 +) -> None: + """ + Compare the ORIGINAL checkpoint's index with the just-saved index and + append any keys that ``save_pretrained`` dropped back into ``save_path``. + + Typical case: Multi-Token-Prediction (MTP) layers (e.g. + ``model.layers.80.*`` for HY_V3) that the HF architecture never registers, + so they are absent from ``model.state_dict()`` and therefore not written by + ``save_pretrained``. + + The missing tensors are loaded from the source shards (each containing + source shard is opened once) into memory, then written into one or more + ``model-appended-from-source-XYZ.safetensors`` files inside ``save_path``. + When the total appended size exceeds ``shard_size_gb``, the output is split + across multiple files so no single appended shard exceeds that budget. The + saved ``model.safetensors.index.json`` is then updated to reference them + (an index is synthesized if the save produced a single-file layout). + """ + from safetensors import safe_open + from safetensors.torch import save_file + + # 1. Read source & saved weight maps + src_weight_map, _ = _read_weight_map(model_path) + if not src_weight_map: + print(" [Append][WARNING] No safetensors found in model_path, skipping") + return + + saved_weight_map, saved_is_sharded = _read_weight_map(save_path) + if not saved_weight_map: + print(" [Append][WARNING] No safetensors found in save_path, skipping") + return - missing_keys = set(weight_map.keys()) - loaded_keys + # 2. Compute keys present in source but missing from the saved model + missing_keys = set(src_weight_map.keys()) - set(saved_weight_map.keys()) if not missing_keys: - print(" [MTP] No missing keys, model parameters are complete.") - return {} + print(" [Append] No missing keys; saved model matches source key set.") + return print( - f" [MTP] Found {len(missing_keys)} weight keys not loaded by HF " - f"(e.g. MTP layers), supplementing:" + f" [Append] {len(missing_keys)} key(s) in source but missing from saved " + f"model (e.g. MTP layers), appending:" ) for k in sorted(missing_keys)[:8]: print(f" {k}") if len(missing_keys) > 8: print(f" ... total {len(missing_keys)}") - # Group by shard file, open each shard only once + # 3. Load the missing tensors from the source shards (one open per shard). shard_to_keys: dict[str, list[str]] = {} for k in missing_keys: - shard_to_keys.setdefault(weight_map[k], []).append(k) + shard_to_keys.setdefault(src_weight_map[k], []).append(k) missing_tensors: dict[str, torch.Tensor] = {} for shard_file, keys in sorted(shard_to_keys.items()): @@ -319,160 +341,83 @@ def load_missing_tensors(model_path: str, loaded_keys: set) -> dict: if not t.is_contiguous(): t = t.contiguous() missing_tensors[k] = t - print(f" Loaded {len(keys)} missing keys from {shard_file}") - - return missing_tensors - - -# --------------------------------------------------------------------------- -# Parallel safetensors save -# --------------------------------------------------------------------------- - - -def save_model_parallel( - model: torch.nn.Module, - save_path: str, - shard_size_gb: float = 4.0, - num_workers: int = 4, - extra_state_dict: dict | None = None, -) -> None: - """ - Save model weights as safetensors with parallel shard writing. - - Much faster than save_pretrained (which serializes sequentially). - - Flow: - 1. Collect state_dict, split into shards by shard_size_gb - 2. ThreadPoolExecutor writes shards concurrently - 3. Generate model.safetensors.index.json - 4. Single shard -> write model.safetensors directly (no index) - - Args: - model: transformed model - save_path: output directory (must exist) - shard_size_gb: target size per shard (GiB) - num_workers: concurrent writer threads - extra_state_dict: extra weights to merge (e.g. MTP layers from - load_missing_tensors()) - """ - try: - from safetensors.torch import save_file as st_save_file - except ImportError: - raise ImportError( - "safetensors is required for parallel saving. " "Install with: pip install safetensors" - ) + print(f" Loaded {len(keys)} missing key(s) from {shard_file}") + + # 4. Split the missing tensors into size-bounded groups, then write one + # safetensors file per group (a single tensor larger than the budget + # still goes into its own file on its own). + size_budget = int(shard_size_gb * 1024**3) + groups: list[list[str]] = [] + current: list[str] = [] + current_bytes = 0 + for k in sorted(missing_tensors.keys()): + t = missing_tensors[k] + t_bytes = t.numel() * t.element_size() + if current and current_bytes + t_bytes > size_budget: + groups.append(current) + current = [] + current_bytes = 0 + current.append(k) + current_bytes += t_bytes + if current: + groups.append(current) + + n_groups = len(groups) + appended_weight_map: dict[str, str] = {} + appended_bytes = 0 + for gi, keys in enumerate(groups): + # 1-based, zero-padded index in the classic HF sharding style. + out_shard = f"model-appended-{gi + 1:05d}-of-{n_groups:05d}.safetensors" + group_tensors = {k: missing_tensors[k] for k in keys} + save_file(group_tensors, os.path.join(save_path, out_shard), metadata={"format": "pt"}) + group_bytes = sum(t.numel() * t.element_size() for t in group_tensors.values()) + appended_bytes += group_bytes + for k in keys: + appended_weight_map[k] = out_shard + print(f" Wrote {len(keys)} key(s) ({group_bytes / 1024**3:.2f} GiB) -> {out_shard}") - shard_size_bytes = int(shard_size_gb * 1024**3) - - # 1. Collect state_dict (all contiguous cpu tensors) - print(" [Save] Collecting state_dict ...") - raw_state_dict = model.state_dict() - state_dict: dict[str, torch.Tensor] = {} - for name, t in raw_state_dict.items(): - if t.device.type != "cpu": - t = t.cpu() - if not t.is_contiguous(): - t = t.contiguous() - state_dict[name] = t - - # Merge extra_state_dict (e.g. MTP layers) - if extra_state_dict: - n_extra = 0 - for name, t in extra_state_dict.items(): - if name in state_dict: - print(f" [Save][WARNING] key {name!r} already in state_dict, extra ignored") - continue - if t.device.type != "cpu": - t = t.cpu() - if not t.is_contiguous(): - t = t.contiguous() - state_dict[name] = t - n_extra += 1 - print(f" [Save] Supplemented {n_extra} extra weights (MTP etc.) into state_dict") - - total_bytes = sum(t.nbytes for t in state_dict.values()) - total_gb = total_bytes / 1024**3 - n_shards = max(1, math.ceil(total_bytes / shard_size_bytes)) print( - f" [Save] Total params: {total_gb:.2f} GiB, splitting into {n_shards} shards " - f"(each <= {shard_size_gb} GiB)" + f" [Append] Wrote {len(appended_weight_map)} tensor(s) " + f"({appended_bytes / 1024**3:.2f} GiB) across {n_groups} appended file(s)" ) - # 2. Assign tensors to shards (greedy, no cross-shard splitting) - shard_dicts: list[dict[str, torch.Tensor]] = [] - cur_shard: dict[str, torch.Tensor] = {} - cur_bytes = 0 - - for name, tensor in state_dict.items(): - nb = tensor.nbytes - if cur_bytes + nb > shard_size_bytes and cur_shard: - shard_dicts.append(cur_shard) - cur_shard = {} - cur_bytes = 0 - cur_shard[name] = tensor - cur_bytes += nb - - if cur_shard: - shard_dicts.append(cur_shard) - - n_shards = len(shard_dicts) - - # 3. Write shard files concurrently - if n_shards == 1: - out_file = os.path.join(save_path, "model.safetensors") - print(f" [Save] Single shard -> {out_file}") - st_save_file(shard_dicts[0], out_file) - print(f" [Save] Done: {out_file}") - return - - pad = len(str(n_shards)) - shard_filenames = [ - f"model-{str(i + 1).zfill(pad)}-of-{str(n_shards).zfill(pad)}.safetensors" - for i in range(n_shards) - ] - - def _write_shard(idx: int) -> tuple[int, str]: - fname = shard_filenames[idx] - out_file = os.path.join(save_path, fname) - shard_bytes = sum(t.nbytes for t in shard_dicts[idx].values()) - print( - f" [Save] [{idx+1}/{n_shards}] Writing {fname} " - f"({shard_bytes / 1024**3:.2f} GiB, " - f"{len(shard_dicts[idx])} tensors) ..." - ) - st_save_file(shard_dicts[idx], out_file) - return idx, fname - - effective_workers = min(num_workers, n_shards) - print(f" [Save] Writing {n_shards} shards concurrently, workers={effective_workers}") - - results_list: list[tuple[int, str]] = [] - with ThreadPoolExecutor(max_workers=effective_workers) as executor: - futures = {executor.submit(_write_shard, i): i for i in range(n_shards)} - for fut in as_completed(futures): - exc = fut.exception() - if exc is not None: - raise RuntimeError(f"Shard {futures[fut]} write failed: {exc}") from exc - results_list.append(fut.result()) - - results_list.sort() - - # 4. Generate model.safetensors.index.json - weight_map: dict[str, str] = {} - for idx, fname in results_list: - for key in shard_dicts[idx]: - weight_map[key] = fname - - index = { - "metadata": {"total_size": str(total_bytes)}, - "weight_map": weight_map, - } + # 5. Update / synthesize model.safetensors.index.json so the appended keys + # are discoverable. weight_map must reference every tensor file. index_path = os.path.join(save_path, "model.safetensors.index.json") - with open(index_path, "w", encoding="utf-8") as f: - json.dump(index, f, indent=2, ensure_ascii=False) - - print(f" [Save] Index written: {index_path}") - print(f" [Save] All {n_shards} shards written successfully") + if saved_is_sharded: + with open(index_path, "r") as f: + index = json.load(f) + else: + # Single-file save -> build an index that maps existing keys to + # model.safetensors plus the new keys to the appended shards. Compute + # the existing total_size from the single file. + single_path = os.path.join(save_path, "model.safetensors") + existing_bytes = 0 + with safe_open(single_path, framework="pt", device="cpu") as f_st: + for k in f_st.keys(): + t = f_st.get_slice(k) + shape = t.get_shape() + n = 1 + for d in shape: + n *= d + existing_bytes += ( + n + * torch.empty( + 0, dtype=getattr(torch, t.get_dtype().lower(), torch.float32) + ).element_size() + ) + index = { + "metadata": {"total_size": existing_bytes}, + "weight_map": dict(saved_weight_map), + } + + index["weight_map"].update(appended_weight_map) + index.setdefault("metadata", {}) + index["metadata"]["total_size"] = int(index["metadata"].get("total_size", 0)) + appended_bytes + + with open(index_path, "w") as f: + json.dump(index, f, indent=2) + print(f" [Append] Updated index: {index_path} ({len(index['weight_map'])} keys)") # --------------------------------------------------------------------------- @@ -669,40 +614,30 @@ def main(): print(f"\n[Step 4] Saving transformed model to {args.save_path} ...") os.makedirs(args.save_path, exist_ok=True) - # Supplement MTP / extra weights dropped by AutoModelForCausalLM - print("\n[Step 4.1] Detecting and supplementing extra weights (MTP etc.)...") - loaded_keys = set(model.state_dict().keys()) - extra_state_dict = load_missing_tensors(args.model_path, loaded_keys) - - # Parallel safetensors shard writing - save_model_parallel( - model, + # Save with transformers' standard save_pretrained. It serializes to + # safetensors, shards according to max_shard_size, and writes + # model.safetensors.index.json + config.json automatically. By default + # (save_original_format=True) it also reverts the load-time weight + # conversion, so fused params like experts.gate_up_proj are split back to + # the original per-expert checkpoint layout. + model.save_pretrained( args.save_path, - shard_size_gb=args.shard_size_gb, - num_workers=args.save_workers, - extra_state_dict=extra_state_dict if extra_state_dict else None, + max_shard_size=f"{args.shard_size_gb}GB", + safe_serialization=True, ) + print(f" [Save] Model saved via save_pretrained to {args.save_path}") - # Copy config / tokenizer / non-weight files from original model directory - _NON_WEIGHT_EXTS = {".json", ".txt", ".py", ".model", ".tiktoken"} - _SKIP_FILES = {"model.safetensors.index.json"} - copied = [] - for fname in os.listdir(args.model_path): - if fname in _SKIP_FILES: - continue - if fname.startswith("model-") and fname.endswith(".safetensors"): - continue - if fname == "model.safetensors": - continue - ext = os.path.splitext(fname)[1].lower() - if ext in _NON_WEIGHT_EXTS: - src = os.path.join(args.model_path, fname) - dst = os.path.join(args.save_path, fname) - if not os.path.exists(dst): - shutil.copy2(src, dst) - copied.append(fname) - if copied: - print(f" [Save] Copied config files: {copied}") + # ------------------------------------------------------------------ + # Step 4.1: Append keys present in the source but dropped on save + # ------------------------------------------------------------------ + # save_pretrained only writes tensors that exist in model.state_dict(). + # Weights the HF architecture never registers (e.g. MTP layers such as + # model.layers.80.* for HY_V3) are therefore absent from the saved model. + # Compare the source vs saved safetensors index and append the gap. + print("\n[Step 4.1] Appending source keys missing from the saved model ...") + append_missing_keys_from_source( + args.model_path, args.save_path, shard_size_gb=args.shard_size_gb + ) if tokenizer is not None: tokenizer.save_pretrained(args.save_path)