-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodal_train.py
More file actions
44 lines (38 loc) · 1.03 KB
/
modal_train.py
File metadata and controls
44 lines (38 loc) · 1.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import modal

# Persistent Modal volume for model checkpoints; created on first run if absent.
checkpoints_vol = modal.Volume.from_name("checkpoints", create_if_missing=True)
# Modal application under which the remote training function is registered.
app = modal.App(name="train")
# Container image for the remote jobs: slim Debian base with the training
# dependencies, plus the local project source mounted at /root/project.
# numpy is pinned below 2 — presumably for compatibility with one of the
# other pinned-free packages (torch/transformers); TODO confirm the reason.
image = (
    modal.Image.debian_slim()
    .pip_install("torch", "numpy<2", "pandas", "tqdm", "transformers")
    .add_local_dir(".", "/root/project")
    # NOTE(review): the line above already copies the project root; this second
    # call looks redundant unless "data" is excluded from the first (e.g. via
    # ignore rules) — verify whether both are needed.
    .add_local_dir("data", "/root/project/data")
)
@app.function(
    gpu="H100",
    image=image,
    timeout=60*60,
    volumes={"/vol/checkpoints": checkpoints_vol}
)
def train_main(model_name: str, dataset: str):
    """Run one training job remotely on an H100 and persist its checkpoints.

    Args:
        model_name: identifier of the model architecture to train.
        dataset: identifier of the dataset to train on.
    """
    import sys

    # The project source is mounted into the image at this path; make it
    # importable before pulling in the training entry point.
    project_root = "/root/project"
    sys.path.append(project_root)

    from dlkth.train import train_workflow

    train_workflow(model_name, dataset, save_dir="/vol/checkpoints")
    # Flush checkpoint writes to the shared volume so they outlive this job.
    checkpoints_vol.commit()
@app.local_entrypoint()
def main():
    """Dispatch one remote training job per (dataset, model) pair.

    Edit the two lists below — uncomment an entry to include it in the sweep.
    """
    # Datasets to sweep over (uncomment to enable).
    datasets = [
        # "el_quijote",
        "valenciano"
        # "shakespeare"
    ]
    # Model architectures to sweep over (uncomment to enable).
    models = [
        # "bigram",
        # "rnn",
        # "rnn_baseline",
        "lstm"
        # "transformer",
    ]
    # Launch every combination remotely.
    for ds in datasets:
        for arch in models:
            train_main.remote(arch, ds)