# run_stress_test.py — TFC stress-test telemetry harness
# (119 lines, 98 loc, 4.62 KB)
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from safetensors.torch import load_file
from colorama import Fore, Style, init
import os, sys, time
from threading import Thread
# Enable ANSI color handling (needed for colorama on Windows terminals).
init()
print(Fore.CYAN + "--- TFC STRESS TEST TELEMETRY ---" + Style.RESET_ALL)
# Architecture to instantiate; trained weights are loaded from local files below,
# only the config/tokenizer come from the hub.
MODEL_ID = "Qwen/Qwen1.5-MoE-A2.7B"
SYSTEM_PATH = "./TFC_System"
# Per-expert weight shards: layer_{L}_actor_{E}.safetensors (see hook below).
CAST_FOLDER = os.path.join(SYSTEM_PATH, "The_Cast")
# Everything except the experts (embeddings, attention, routers, norms).
DIRECTOR_FILE = os.path.join(SYSTEM_PATH, "The_Director.safetensors")
# --- GLOBAL TELEMETRY STATS ---
# Mutated by the router hook during generation; reset at the start of each
# query in the main loop so the report covers a single turn.
stats = {
"activations": 0, # Total router decisions
"unique_experts": set(), # Unique experts called (layer, idx)
"new_loads": 0 # Experts fetched from disk this turn
}
# 1. INITIALIZE SKELETON
# from_config builds the model with randomly initialized weights — no download.
# NOTE(review): older transformers versions spell this kwarg `torch_dtype`;
# confirm the installed version accepts `dtype` for from_config.
print("1. Initializing Skeleton...")
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, dtype=torch.float32)
# 2. LOAD DIRECTOR
# strict=False: the Director file intentionally omits the expert weights,
# which stay randomly initialized until the hook injects them on demand.
if not os.path.exists(DIRECTOR_FILE): sys.exit("Director missing!")
print("2. Loading Director...")
state = load_file(DIRECTOR_FILE)
model.load_state_dict({k: v.to(dtype=torch.float32) for k, v in state.items()}, strict=False)
del state
model.tie_weights()
# 3. SETUP
# Tracks which experts are already resident in RAM, per layer; persists
# across queries (experts are loaded once and kept).
loaded_experts = {i: set() for i in range(config.num_hidden_layers)}
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"3. Moving to {device.upper()}...")
model.to(device)
# 4. TELEMETRY HOOK
def hook(module, args):
    """Forward pre-hook on each MoE MLP block.

    Runs the router ahead of the real forward pass, records telemetry into
    the global ``stats`` dict, and lazily injects any expert weights that
    are not yet resident in RAM from the per-expert safetensors files.

    Args:
        module: the MoE MLP submodule; carries ``layer_id_tracker`` (attached
            at registration time), the ``gate`` router, and an ``experts``
            container with stacked ``gate_up_proj`` / ``down_proj`` tensors.
        args: positional args destined for the forward pass; ``args[0]`` is
            the hidden-states tensor fed to the router.
    """
    layer_idx = module.layer_id_tracker
    hidden_states = args[0]
    # Run Router (some router impls return (logits, aux); keep logits only).
    out = module.gate(hidden_states)
    logits = out[0] if isinstance(out, tuple) else out
    # Unique experts needed across all tokens in this batch.
    # NOTE(review): k=4 assumes num_experts_per_tok == 4 for this checkpoint;
    # confirm against config if the model is ever swapped.
    indices = torch.topk(logits, k=4, dim=-1).indices.view(-1).unique().tolist()
    # Update Telemetry
    stats["activations"] += 1
    for i in indices:
        stats["unique_experts"].add((layer_idx, i))
    # Injection Logic: fetch only experts not already resident.
    missing = [i for i in indices if i not in loaded_experts[layer_idx]]
    for idx in missing:
        fpath = os.path.join(CAST_FOLDER, f"layer_{layer_idx}_actor_{idx}.safetensors")
        try:
            wd = load_file(fpath)
            module.experts.gate_up_proj.data[idx] = wd["gate_up_proj"].to(device=device, dtype=torch.float32)
            module.experts.down_proj.data[idx] = wd["down_proj"].to(device=device, dtype=torch.float32)
        except (OSError, KeyError) as e:
            # Best-effort: a missing/corrupt expert file leaves the randomly
            # initialized weights in place rather than crashing generation,
            # but the failure is no longer silent.
            print(f"[TFC warn] expert load failed: {fpath} ({e})")
            continue
        loaded_experts[layer_idx].add(idx)
        stats["new_loads"] += 1  # count only SUCCESSFUL loads (was incremented before the attempt)
# 4. ATTACH TELEMETRY HOOKS
# Tag each MoE block with its layer index (read back inside `hook`) and
# register the pre-hook so routing/injection runs before every forward pass.
for i, layer in enumerate(model.model.layers):
    layer.mlp.layer_id_tracker = i
    layer.mlp.register_forward_pre_hook(hook)
print(Fore.GREEN + "=== SYSTEM READY FOR STRESS TESTING ===" + Style.RESET_ALL)
# 5. LOOP
# 5. INTERACTIVE STRESS-TEST LOOP
# Each turn: reset per-turn telemetry, stream a greedy generation from a
# worker thread, then report expert usage and disk activity for that query.
while True:
    try:
        q = input(Fore.YELLOW + "\nMy Lord: " + Style.RESET_ALL)
        if q.lower() in ["exit", "quit"]:
            break
        # RESET per-turn counters (loaded_experts persists across turns by design).
        stats["activations"] = 0
        stats["unique_experts"] = set()
        stats["new_loads"] = 0
        start_time = time.time()
        # GENERATE: chat-format the prompt and run generate() on a worker
        # thread so this thread can drain the streamer as tokens arrive.
        txt = tokenizer.apply_chat_template([{"role":"user", "content":q}], tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(txt, return_tensors="pt").to(device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        kwargs = dict(inputs, streamer=streamer, max_new_tokens=250, do_sample=False, pad_token_id=tokenizer.eos_token_id)
        gen_thread = Thread(target=model.generate, kwargs=kwargs)
        gen_thread.start()
        print(Fore.WHITE + "TFC: ", end="")
        for t in streamer:
            print(t, end="", flush=True)
        print("")
        # Join before reading telemetry: the streamer can drain slightly
        # before the generate() thread finishes, leaving stats mid-update.
        gen_thread.join()
        # REPORT
        end_time = time.time()
        duration = end_time - start_time
        total_possible_experts = config.num_hidden_layers * config.num_experts
        used_count = len(stats["unique_experts"])
        total_loaded = sum(len(v) for v in loaded_experts.values())
        print(Fore.MAGENTA + "\n--- TELEMETRY REPORT ---")
        print(f"⏱ Time: {duration:.2f}s")
        print(f"🧠 Brain Used: {used_count} unique experts ({used_count/total_possible_experts:.1%} of total capacity)")
        print(f"💾 RAM Injection: {stats['new_loads']} new files loaded from disk")
        print(f"🗄 Total Active: {total_loaded} experts currently residing in RAM")
        print(f"⚡ Complexity: {stats['activations']} routing decisions made")
        print(Style.RESET_ALL)
    except KeyboardInterrupt:
        break