Binary file added exp1_logs.txt
Binary file not shown.
Binary file added exp2_logs.txt
Binary file not shown.
Binary file added exp_balanced_peak.ptz
Binary file not shown.
Binary file added exp_deep_stable.ptz
Binary file not shown.
123 changes: 123 additions & 0 deletions exp_deep_stable.txt
@@ -0,0 +1,123 @@
model_params: 7214100
step:0 val_loss:6.9391 val_bpb:3.0801 time:0ms
step:1/20000 train_loss:6.9395 time:1867ms
step:200/20000 train_loss:3.4674 time:356549ms
step:400/20000 train_loss:2.9106 time:713415ms
step:600/20000 train_loss:2.6936 time:1070507ms
step:800/20000 train_loss:2.6614 time:1427634ms
step:1000/20000 train_loss:2.5387 time:1784223ms
step:1000 val_loss:2.5594 val_bpb:1.1361 time:1784250ms
step:1200/20000 train_loss:2.5454 time:2141147ms
step:1400/20000 train_loss:2.5426 time:2498112ms
step:1600/20000 train_loss:2.4487 time:2855086ms
step:1800/20000 train_loss:2.3379 time:3212126ms
step:2000/20000 train_loss:2.4562 time:3568629ms
step:2000 val_loss:2.4289 val_bpb:1.0781 time:3568656ms
step:2200/20000 train_loss:2.4138 time:3925446ms
step:2400/20000 train_loss:2.4391 time:4282379ms
step:2600/20000 train_loss:2.4551 time:4639388ms
step:2800/20000 train_loss:2.5381 time:4996325ms
step:3000/20000 train_loss:2.3298 time:5352934ms
step:3000 val_loss:2.3803 val_bpb:1.0566 time:5352961ms
step:3200/20000 train_loss:2.3655 time:5710027ms
step:3400/20000 train_loss:2.3240 time:6066993ms
step:3600/20000 train_loss:2.3648 time:6423945ms
step:3800/20000 train_loss:2.3503 time:6780408ms
step:4000/20000 train_loss:2.2927 time:7137404ms
step:4000 val_loss:2.3472 val_bpb:1.0419 time:7137430ms
step:4200/20000 train_loss:2.2679 time:7494473ms
step:4400/20000 train_loss:2.3620 time:7851536ms
step:4600/20000 train_loss:2.2581 time:8208549ms
step:4800/20000 train_loss:2.3169 time:8565316ms
step:5000/20000 train_loss:2.2840 time:8922606ms
step:5000 val_loss:2.3301 val_bpb:1.0343 time:8922633ms
step:5200/20000 train_loss:2.3659 time:9279791ms
step:5400/20000 train_loss:2.3127 time:9636974ms
step:5600/20000 train_loss:2.3703 time:9994176ms
step:5800/20000 train_loss:2.3227 time:10350840ms
step:6000/20000 train_loss:2.3040 time:10708007ms
step:6000 val_loss:2.3168 val_bpb:1.0284 time:10708034ms
step:6200/20000 train_loss:2.2733 time:11065082ms
step:6400/20000 train_loss:2.2992 time:11422134ms
step:6600/20000 train_loss:2.3462 time:11778809ms
step:6800/20000 train_loss:2.3268 time:12135985ms
step:7000/20000 train_loss:2.3172 time:12493125ms
step:7000 val_loss:2.3070 val_bpb:1.0240 time:12493152ms
step:7200/20000 train_loss:2.3400 time:12850075ms
step:7400/20000 train_loss:2.3022 time:13207262ms
step:7600/20000 train_loss:2.3172 time:13563930ms
step:7800/20000 train_loss:2.2923 time:13921208ms
step:8000/20000 train_loss:2.4725 time:14278302ms
step:8000 val_loss:2.2961 val_bpb:1.0192 time:14278328ms
step:8200/20000 train_loss:2.2816 time:14635460ms
step:8400/20000 train_loss:2.2776 time:14992494ms
step:8600/20000 train_loss:2.3197 time:15349042ms
step:8800/20000 train_loss:2.2828 time:15706095ms
step:9000/20000 train_loss:2.3846 time:16063234ms
step:9000 val_loss:2.2874 val_bpb:1.0153 time:16063261ms
step:9200/20000 train_loss:2.2827 time:16420410ms
step:9400/20000 train_loss:2.2857 time:16777125ms
step:9600/20000 train_loss:2.3054 time:17134224ms
step:9800/20000 train_loss:2.2765 time:17491346ms
step:10000/20000 train_loss:2.2645 time:17848514ms
step:10000 val_loss:2.2813 val_bpb:1.0126 time:17848541ms
step:10200/20000 train_loss:2.2942 time:18205640ms
step:10400/20000 train_loss:2.2714 time:18562294ms
step:10600/20000 train_loss:2.2175 time:18919345ms
step:10800/20000 train_loss:2.2286 time:19276431ms
step:11000/20000 train_loss:2.3317 time:19633544ms
step:11000 val_loss:2.2760 val_bpb:1.0103 time:19633571ms
step:11200/20000 train_loss:2.2325 time:19990561ms
step:11400/20000 train_loss:2.3167 time:20347219ms
step:11600/20000 train_loss:2.2806 time:20704380ms
step:11800/20000 train_loss:2.2680 time:21061526ms
step:12000/20000 train_loss:2.2483 time:21418563ms
step:12000 val_loss:2.2716 val_bpb:1.0083 time:21418590ms
step:12200/20000 train_loss:2.2979 time:21775193ms
step:12400/20000 train_loss:2.4286 time:22132106ms
step:12600/20000 train_loss:2.2667 time:22489177ms
step:12800/20000 train_loss:2.3347 time:22846141ms
step:13000/20000 train_loss:2.3092 time:23203249ms
step:13000 val_loss:2.2664 val_bpb:1.0060 time:23203275ms
step:13200/20000 train_loss:2.2674 time:23559741ms
step:13400/20000 train_loss:2.1219 time:23916757ms
step:13600/20000 train_loss:2.3351 time:24273831ms
step:13800/20000 train_loss:2.1970 time:24630769ms
step:14000/20000 train_loss:2.2588 time:24987888ms
step:14000 val_loss:2.2648 val_bpb:1.0053 time:24987915ms
step:14200/20000 train_loss:2.2442 time:25344451ms
step:14400/20000 train_loss:2.2547 time:25701380ms
step:14600/20000 train_loss:2.3532 time:26058340ms
step:14800/20000 train_loss:2.2392 time:26415282ms
step:15000/20000 train_loss:2.2752 time:26771797ms
step:15000 val_loss:2.2610 val_bpb:1.0036 time:26771824ms
step:15200/20000 train_loss:2.2623 time:27128714ms
step:15400/20000 train_loss:2.2178 time:27485680ms
step:15600/20000 train_loss:2.2268 time:27842632ms
step:15800/20000 train_loss:2.2129 time:28199690ms
step:16000/20000 train_loss:2.1928 time:28556155ms
step:16000 val_loss:2.2559 val_bpb:1.0013 time:28556182ms
step:16200/20000 train_loss:2.0934 time:28913073ms
step:16400/20000 train_loss:2.2089 time:29270064ms
step:16600/20000 train_loss:2.2424 time:29626980ms
step:16800/20000 train_loss:2.3310 time:29984060ms
step:17000/20000 train_loss:2.1701 time:30340626ms
step:17000 val_loss:2.2533 val_bpb:1.0002 time:30340652ms
step:17200/20000 train_loss:2.2980 time:30697720ms
step:17400/20000 train_loss:2.2495 time:31054901ms
step:17600/20000 train_loss:2.2457 time:31412053ms
step:17800/20000 train_loss:2.2284 time:31768663ms
step:18000/20000 train_loss:2.2611 time:32125810ms
step:18000 val_loss:2.2502 val_bpb:0.9988 time:32125836ms
step:18200/20000 train_loss:2.2404 time:32483151ms
step:18400/20000 train_loss:2.2179 time:32840284ms
step:18600/20000 train_loss:2.2812 time:33197342ms
step:18800/20000 train_loss:2.2552 time:33553864ms
step:19000/20000 train_loss:2.3222 time:33910761ms
step:19000 val_loss:2.2481 val_bpb:0.9979 time:33910788ms
step:19200/20000 train_loss:2.2750 time:34267930ms
step:19400/20000 train_loss:2.3267 time:34625031ms
step:19600/20000 train_loss:2.2458 time:34982029ms
step:19800/20000 train_loss:2.2294 time:35338887ms
step:20000/20000 train_loss:2.2933 time:35696205ms
step:20000 val_loss:2.2451 val_bpb:0.9965 time:35696231ms
75 changes: 75 additions & 0 deletions exp_large.txt
@@ -0,0 +1,75 @@
model_params: 4722704
step:0 val_loss:6.9313 val_bpb:3.0767 time:0ms
step:1/100000 train_loss:6.9316 time:1366ms
step:200/100000 train_loss:3.3504 time:252792ms
step:400/100000 train_loss:2.7117 time:505701ms
step:600/100000 train_loss:2.7676 time:758655ms
step:800/100000 train_loss:2.6034 time:1011610ms
step:1000/100000 train_loss:2.6137 time:1264517ms
step:1000 val_loss:2.5906 val_bpb:1.1499 time:1264544ms
step:1200/100000 train_loss:2.5432 time:1517485ms
step:1400/100000 train_loss:2.5791 time:1770502ms
step:1600/100000 train_loss:2.4673 time:2023463ms
step:1800/100000 train_loss:2.5169 time:2276429ms
step:2000/100000 train_loss:2.4716 time:2529429ms
step:2000 val_loss:2.4860 val_bpb:1.1035 time:2529456ms
step:2200/100000 train_loss:2.4066 time:2782329ms
step:2400/100000 train_loss:2.4410 time:3035326ms
step:2600/100000 train_loss:2.5189 time:3288278ms
step:2800/100000 train_loss:2.4805 time:3541314ms
step:3000/100000 train_loss:2.3908 time:3794405ms
step:3000 val_loss:2.4431 val_bpb:1.0844 time:3794432ms
step:3200/100000 train_loss:2.4809 time:4047561ms
step:3400/100000 train_loss:2.4615 time:4301295ms
step:3600/100000 train_loss:2.4051 time:4554080ms
step:3800/100000 train_loss:2.4890 time:4807127ms
step:4000/100000 train_loss:2.3803 time:5060054ms
step:4000 val_loss:2.4200 val_bpb:1.0742 time:5060082ms
step:4200/100000 train_loss:2.4535 time:5313409ms
step:4400/100000 train_loss:2.4609 time:5566296ms
step:4600/100000 train_loss:2.3022 time:5819221ms
step:4800/100000 train_loss:2.4082 time:6072129ms
step:5000/100000 train_loss:2.4053 time:6325152ms
step:5000 val_loss:2.4035 val_bpb:1.0668 time:6325179ms
step:5200/100000 train_loss:2.4255 time:6578051ms
step:5400/100000 train_loss:2.4298 time:6830983ms
step:5600/100000 train_loss:2.4312 time:7083927ms
step:5800/100000 train_loss:2.4098 time:7336854ms
step:6000/100000 train_loss:2.5350 time:7589790ms
step:6000 val_loss:2.3963 val_bpb:1.0636 time:7589817ms
step:6200/100000 train_loss:2.3915 time:7842747ms
step:6400/100000 train_loss:2.4177 time:8095560ms
step:6600/100000 train_loss:2.3729 time:8348483ms
step:6800/100000 train_loss:2.4950 time:8601352ms
step:7000/100000 train_loss:2.3784 time:8854296ms
step:7000 val_loss:2.3845 val_bpb:1.0584 time:8854323ms
step:7200/100000 train_loss:2.4101 time:9107260ms
step:7400/100000 train_loss:2.3841 time:9360233ms
step:7600/100000 train_loss:2.3276 time:9613224ms
step:7800/100000 train_loss:2.3774 time:9866197ms
step:8000/100000 train_loss:2.3701 time:10119195ms
step:8000 val_loss:2.3764 val_bpb:1.0548 time:10119222ms
step:8200/100000 train_loss:2.3478 time:10372158ms
step:8400/100000 train_loss:2.3402 time:10625650ms
step:8600/100000 train_loss:2.3790 time:10878633ms
step:8800/100000 train_loss:2.3785 time:11131675ms
step:9000/100000 train_loss:2.3425 time:11384741ms
step:9000 val_loss:2.3722 val_bpb:1.0530 time:11384768ms
step:9200/100000 train_loss:2.3309 time:11637707ms
step:9400/100000 train_loss:2.3777 time:11890682ms
step:9600/100000 train_loss:2.4305 time:12143596ms
step:9800/100000 train_loss:2.3388 time:12396581ms
step:10000/100000 train_loss:2.2825 time:12649568ms
step:10000 val_loss:2.3684 val_bpb:1.0513 time:12649595ms
step:10200/100000 train_loss:2.4388 time:12902507ms
step:10400/100000 train_loss:2.3939 time:13155443ms
step:10600/100000 train_loss:2.3045 time:13408383ms
step:10800/100000 train_loss:2.3808 time:13661295ms
step:11000/100000 train_loss:2.3669 time:13914261ms
step:11000 val_loss:2.3632 val_bpb:1.0490 time:13914288ms
step:11200/100000 train_loss:2.3381 time:14167332ms
step:11400/100000 train_loss:2.3324 time:14420201ms
step:11600/100000 train_loss:2.3328 time:14673080ms
step:11800/100000 train_loss:2.3933 time:14926023ms
step:12000/100000 train_loss:2.2873 time:15179072ms
step:12000 val_loss:2.3438 val_bpb:1.0404 time:15179099ms
39 changes: 39 additions & 0 deletions exp_xl.txt
@@ -0,0 +1,39 @@
model_params: 7214100
step:0 val_loss:6.9392 val_bpb:3.0801 time:0ms
step:1/150000 train_loss:6.9393 time:3665ms
step:200/150000 train_loss:3.4129 time:720568ms
step:400/150000 train_loss:2.7769 time:1441803ms
step:600/150000 train_loss:2.5164 time:2163050ms
step:800/150000 train_loss:2.5384 time:2884328ms
step:1000/150000 train_loss:2.4318 time:3605217ms
step:1000 val_loss:2.4864 val_bpb:1.1036 time:3605277ms
step:1200/150000 train_loss:2.4116 time:4326409ms
step:1400/150000 train_loss:2.4039 time:5047843ms
step:1600/150000 train_loss:2.4459 time:5768990ms
step:1800/150000 train_loss:2.4444 time:6490127ms
step:2000/150000 train_loss:2.3412 time:7211520ms
step:2000 val_loss:2.3684 val_bpb:1.0513 time:7211580ms
step:2200/150000 train_loss:2.3277 time:7931124ms
step:2400/150000 train_loss:2.3313 time:8652261ms
step:2600/150000 train_loss:2.4536 time:9373368ms
step:2800/150000 train_loss:2.3634 time:10094473ms
step:3000/150000 train_loss:2.3563 time:10813916ms
step:3000 val_loss:2.3277 val_bpb:1.0332 time:10813976ms
step:3200/150000 train_loss:2.3075 time:11535027ms
step:3400/150000 train_loss:2.3187 time:12256350ms
step:3600/150000 train_loss:2.2674 time:12977761ms
step:3800/150000 train_loss:2.3659 time:13699086ms
step:4000/150000 train_loss:2.3068 time:14420163ms
step:4000 val_loss:2.3065 val_bpb:1.0238 time:14420223ms
step:4200/150000 train_loss:2.2895 time:15141323ms
step:4400/150000 train_loss:2.3260 time:15862396ms
step:4600/150000 train_loss:2.2911 time:16583460ms
step:4800/150000 train_loss:2.2846 time:17302885ms
step:5000/150000 train_loss:2.2688 time:18024135ms
step:5000 val_loss:2.2900 val_bpb:1.0165 time:18024195ms
step:5200/150000 train_loss:2.3081 time:18743467ms
step:5400/150000 train_loss:2.3612 time:19464628ms
step:5600/150000 train_loss:2.2872 time:20185876ms
step:5800/150000 train_loss:2.3083 time:20907058ms
step:6000/150000 train_loss:2.2619 time:21628023ms
step:6000 val_loss:2.2666 val_bpb:1.0061 time:21628083ms
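
Note (not part of the PR): the three *.txt logs above share one plain-text format — "step:<i>/<total> train_loss:<x> time:<ms>ms" for training lines and "step:<i> val_loss:<x> val_bpb:<x> time:<ms>ms" for validation lines. A minimal parsing sketch, assuming exactly the line format shown above (the helper name is illustrative):

import re

def parse_val_curve(path):
    # Return (step, val_bpb) pairs from one of the log files above.
    pattern = re.compile(r"^step:(\d+) val_loss:[\d.]+ val_bpb:([\d.]+)")
    curve = []
    with open(path) as f:
        for line in f:
            m = pattern.match(line)
            if m:
                curve.append((int(m.group(1)), float(m.group(2))))
    return curve

# Example: last validation point of the deep-stable run
print(parse_val_curve("exp_deep_stable.txt")[-1])  # (20000, 0.9965)
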
Binary file added exp_xl_model.ptz
Binary file not shown.
Binary file added final_model.int8.ptz
Binary file not shown.
111 changes: 111 additions & 0 deletions infer.py
@@ -0,0 +1,111 @@
import torch
import torch.nn.functional as F
import sentencepiece as spm
import zlib
import io
import os

from train_gpt_optimized import dequantize_state_dict_int8, GPT, Hyperparameters

# Configure hyperparameters to match the fast training run
os.environ["NUM_LAYERS"] = "4"
os.environ["MODEL_DIM"] = "256"
os.environ["NUM_HEADS"] = "4"
os.environ["NUM_KV_HEADS"] = "4"

# Load tokenizer
sp = spm.SentencePieceProcessor()
sp.load("./data/tokenizers/fineweb_1024_bpe.model")

# Load and decompress model weights
print("Loading and dequantizing model...")
with open("final_model.int8.ptz", "rb") as f:
q_obj = torch.load(io.BytesIO(zlib.decompress(f.read())), map_location="cpu")

state_dict = dequantize_state_dict_int8(q_obj)

# Initialize model
args = Hyperparameters()
args.num_steps = 4
args.model_dim = 256
args.num_heads = 4
args.num_kv_heads = 4
model = GPT(args).bfloat16()
model.load_state_dict(state_dict)
model.eval()
print("Model loaded successfully!")

def patched_forward(self, input_ids):
    # Inference-time forward pass: RMS-norm the token embeddings, then apply
    # the blocks, cycling through the unique blocks for num_steps iterations.
    x = F.rms_norm(self.tok_emb(input_ids), (self.args.model_dim,))
    x0 = x
    for i in range(self.args.num_steps):
        block_idx = i % self.args.num_unique_blocks
        x = self.unique_blocks[block_idx](x, x0)

    x = self.final_norm(x)  # [bsz, seq_len, dim]
    # Project to the vocabulary (reusing tok_emb.weight when embeddings are tied)
    # and soft-cap the logits with tanh.
    logits_proj = F.linear(x, self.tok_emb.weight) if self.args.tie_embeddings else self.lm_head(x)
    logits = self.args.logit_softcap * torch.tanh(logits_proj / self.args.logit_softcap)
    return logits

GPT.forward = patched_forward

def sample_logits(logits, temperature=0.8, top_k=40, top_p=0.9):
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)

    # Top-k: keep only the k most likely tokens and renormalize.
    if top_k > 0:
        values, indices = torch.topk(probs, top_k)
        probs_filtered = torch.zeros_like(probs)
        probs_filtered.scatter_(0, indices, values)
        probs = probs_filtered / probs_filtered.sum()

    # Top-p (nucleus): drop the tail once cumulative probability exceeds top_p.
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=0)

    cutoff = cumulative_probs > top_p
    if torch.any(cutoff):
        cutoff_idx = torch.where(cutoff)[0][0]
        # Always keep at least the most likely token, so the distribution
        # cannot collapse to all zeros (multinomial would fail otherwise).
        cutoff_idx = max(int(cutoff_idx), 1)
        sorted_probs[cutoff_idx:] = 0
        if sorted_probs.sum() > 0:
            sorted_probs /= sorted_probs.sum()
        probs = torch.zeros_like(probs)
        probs.scatter_(0, sorted_indices, sorted_probs)

    return torch.multinomial(probs, 1).item()

def generate(prompt, max_tokens=80, temperature=0.8, top_k=40, top_p=0.9, repetition_penalty=1.2):
    tokens = sp.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)
    generated = tokens.clone()

    for _ in range(max_tokens):
        with torch.no_grad():
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                logits = model(generated)[0, -1].float()

        # Repetition penalty: push down logits of tokens already in the context
        # (divide positive logits, multiply negative ones).
        for token_id in set(generated[0].tolist()):
            if logits[token_id] < 0:
                logits[token_id] *= repetition_penalty
            else:
                logits[token_id] /= repetition_penalty

        next_token = sample_logits(logits, temperature, top_k, top_p)
        generated = torch.cat([generated, torch.tensor([[next_token]])], dim=1)

    return sp.decode(generated[0].tolist())

# Test prompts
prompts = [
    "The future of AI is",
    "Once upon a time",
    "India is known for",
    "The meaning of life is"
]

for p in prompts:
    print("\nPROMPT:", p)
    print(generate(p))
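
Note (not part of the PR): infer.py expects final_model.int8.ptz to be a zlib-compressed torch.save payload that dequantize_state_dict_int8 (defined in train_gpt_optimized, which is not shown in this diff) expands back into a full state dict. A rough sketch of how such a file could be produced, assuming a simple per-tensor symmetric int8 scheme with one float scale per tensor — the actual helper in train_gpt_optimized may use a different layout:

import io
import zlib
import torch

def quantize_state_dict_int8(state_dict):
    # Hypothetical inverse of dequantize_state_dict_int8: store each float
    # tensor as int8 values plus a single float32 scale.
    q_obj = {}
    for name, tensor in state_dict.items():
        t = tensor.float()
        scale = t.abs().max().clamp(min=1e-8) / 127.0
        q_obj[name] = {"int8": torch.round(t / scale).to(torch.int8), "scale": scale}
    return q_obj

def save_compressed(q_obj, path):
    # torch.save into an in-memory buffer, then zlib-compress to disk.
    buf = io.BytesIO()
    torch.save(q_obj, buf)
    with open(path, "wb") as f:
        f.write(zlib.compress(buf.getvalue()))

# save_compressed(quantize_state_dict_int8(model.state_dict()), "final_model.int8.ptz")
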
