Merged
29 changes: 14 additions & 15 deletions docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -2,6 +2,18 @@

The following describes the main config file for Trinity-RFT, using `scripts/config/countdown.yaml` as an example.


## Monitor

```yaml
monitor:
project: "Trinity-RFT-countdown"
name: "qwen2.5-1.5B-countdown"
```

- `monitor.project`: The project name. It must be set manually.
- `monitor.name`: The name of the experiment. It must be set manually.
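
These two fields map onto the `MonitorConfig` dataclass in `trinity/common/config.py`. A minimal standalone sketch of how such a config object carries the values (the defaults mirror the dataclass; everything else here is illustrative, not Trinity-RFT's API):

```python
from dataclasses import dataclass

@dataclass
class MonitorConfig:
    # Defaults mirror trinity/common/config.py; both fields
    # should be overridden per experiment.
    project: str = "trinity"
    name: str = "rft"

# Values from the example config above.
cfg = MonitorConfig(project="Trinity-RFT-countdown", name="qwen2.5-1.5B-countdown")
print(f"{cfg.project}/{cfg.name}")  # → Trinity-RFT-countdown/qwen2.5-1.5B-countdown
```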

## Data

<!-- The `data` configuration specifies the data used for training. It includes the total number of epochs, the batch size, the path to the dataset, the default workflow type, the default reward function type, and the format configuration. -->
@@ -53,15 +65,13 @@ model:
max_prompt_tokens: 256
max_response_tokens: 1024
checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
load_checkpoint: true
```

- `model.model_path`: The path to the model checkpoint. It must be set manually.
- `model.critic_model_path`: The path to the critic model checkpoint. If not set, it defaults to `model.model_path`.
- `model.max_prompt_tokens`: The maximum number of tokens in the prompt. Default is `2048`. It should be set manually.
- `model.max_response_tokens`: The maximum number of tokens in the response. Default is `2048`. It should be set manually.
- `model.checkpoint_path`: The path to the checkpoint of the model. It must be set manually.
- `model.load_checkpoint`: Whether to load the checkpoint of the model. Default is `true`.
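
The two token limits jointly bound the longest sequence a rollout can produce, so it is worth checking them against the model's context window before launching a run. A standalone sketch (the 4096-token context window is an assumed value for illustration; check your model card for the real one):

```python
# Values from the example config above.
max_prompt_tokens = 256
max_response_tokens = 1024

# Hypothetical context window, for illustration only.
model_context_window = 4096

total = max_prompt_tokens + max_response_tokens
assert total <= model_context_window, (
    f"prompt + response budget ({total}) exceeds the context window"
)
print(total)  # → 1280
```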

## Cluster

@@ -149,19 +159,6 @@ synchronizer:
- `synchronizer.sync_method`: The synchronization method. Supported values are `online` and `offline`. Default is `online`.
- `synchronizer.sync_iteration_interval`: The number of iterations between two synchronizations. Default is `10`. It should be set manually.
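
The effect of `sync_iteration_interval` can be sketched as a simple modulus check in the training loop (illustrative pseudologic under the default interval of `10`, not Trinity-RFT's actual implementation):

```python
def should_sync(iteration: int, sync_iteration_interval: int = 10) -> bool:
    """Return True on iterations where explorer and trainer would
    synchronize model weights, i.e. every `sync_iteration_interval` steps."""
    return iteration > 0 and iteration % sync_iteration_interval == 0

# With the default interval of 10, iterations 10, 20, 30, ... trigger a sync.
print([i for i in range(1, 31) if should_sync(i)])  # → [10, 20, 30]
```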

## Monitor

```yaml
monitor:
cache_root_dir: ""
project: "Trinity-RFT-countdown"
name: "qwen2.5-1.5B-countdown"
```

- `monitor.cache_root_dir`: The root directory of the cache. Default is `os.path.join(model.checkpoint_path, ".cache")`.
- `monitor.project`: The project name. It must be set manually.
- `monitor.name`: The name of the experiment. It must be set manually.

## Trainer

```yaml
@@ -386,6 +383,7 @@ trainer:
- `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`: Batch size for one GPU in one forward pass.
- `actor_rollout_ref.actor.grad_clip`: Gradient clip for actor model training.
- `actor_rollout_ref.actor.clip_ratio`: Clip ratio used to compute the PPO policy loss.
- `actor_rollout_ref.actor.entropy_coeff`: Entropy coefficient used in the policy loss.
- `actor_rollout_ref.actor.use_kl_loss`: Whether to add a KL loss term; set to `true` for GRPO.
- `actor_rollout_ref.actor.kl_loss_coef`: Coefficient of the KL loss used in GRPO; the KL penalty type (`kl`, `abs`, `mse`, or `low_var_kl`) is configured via `actor_rollout_ref.actor.kl_loss_type`.
- `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
@@ -412,6 +410,7 @@ trainer:

- `algorithm`: Training algorithm settings.

- `trainer.balance_batch`: Whether to balance batch size between GPUs during training.
- `trainer.save_freq`: Frequency of saving checkpoints.
- `trainer.resume_mode`: Resume mode for training. Supported values are `disable`, `auto`, and `resume_path`.
- `trainer.resume_from_path`: Path to resume from; used when `trainer.resume_mode` is `resume_path`.
1 change: 0 additions & 1 deletion scripts/config/alfworld.yaml
@@ -12,7 +12,6 @@ model:
max_prompt_tokens: 4096
max_response_tokens: 16384
checkpoint_path: 'checkpoints/ALFWORLD_RFT'
load_checkpoint: true
cluster:
node_num: 1
gpu_per_node: 8
1 change: 0 additions & 1 deletion scripts/config/countdown.yaml
@@ -14,7 +14,6 @@ model:
max_prompt_tokens: 256
max_response_tokens: 1024
checkpoint_path: 'checkpoints/qwen2.5-1.5B-countdown'
load_checkpoint: true
cluster:
node_num: 1
gpu_per_node: 8
1 change: 0 additions & 1 deletion scripts/config/dpo.yaml
@@ -13,7 +13,6 @@ model:
max_prompt_tokens: 1792
max_response_tokens: 256
checkpoint_path: 'checkpoints/trinity_dpo'
load_checkpoint: true
cluster:
node_num: 1
gpu_per_node: 8
1 change: 0 additions & 1 deletion scripts/config/gsm8k.yaml
@@ -25,7 +25,6 @@ model:
max_prompt_tokens: 256
max_response_tokens: 1024
checkpoint_path: '/PATH/TO/CHECKPOINT/'
load_checkpoint: true
cluster:
node_num: 1
gpu_per_node: 8
1 change: 0 additions & 1 deletion scripts/config/gsm8k_opmd.yaml
@@ -11,7 +11,6 @@ model:
max_prompt_tokens: 256
max_response_tokens: 1024
checkpoint_path: '{path to checkpoints}/test-opmd-gsm8k/qwen2.5-1.5B-gsm8k-opmd-kl_0.001-entropy_0-tau_4-beta1_0.0-beta2_0.95-lr_2e-6-sync10'
load_checkpoint: false
cluster:
node_num: 1
gpu_per_node: 8
1 change: 0 additions & 1 deletion scripts/config/webshop.yaml
@@ -12,7 +12,6 @@ model:
max_prompt_tokens: 4096
max_response_tokens: 16384
checkpoint_path: 'checkpoints/WEBSHOP_RFT'
load_checkpoint: true
cluster:
node_num: 1
gpu_per_node: 8
2 changes: 0 additions & 2 deletions tests/common/tmp/template_config.yaml
@@ -27,7 +27,6 @@ model:
max_prompt_tokens: 2048
max_response_tokens: 2048
checkpoint_path: ''
load_checkpoint: true
cluster:
node_num: 1
gpu_per_node: 8
@@ -60,7 +59,6 @@ trainer:
trainer_config_path: tests/common/tmp/template_verl_config.yaml
monitor:
project: unittest
group: test
name: test
synchronizer:
sync_method: offline
1 change: 0 additions & 1 deletion tests/test_data/template.yaml
@@ -5,7 +5,6 @@ model:
max_prompt_tokens: 2048
max_response_tokens: 2048
checkpoint_path: ''
load_checkpoint: true
cluster:
node_num: 1
gpu_per_node: 8
3 changes: 0 additions & 3 deletions trinity/common/config.py
@@ -83,7 +83,6 @@ class ModelConfig:
max_response_tokens: int = 2048
# The checkpoint directory, contains a latest dir link and multiple checkpoint dirs.
checkpoint_path: str = ""
load_checkpoint: bool = True


@dataclass
@@ -201,8 +200,6 @@ class MonitorConfig:
# TODO: add more
project: str = "trinity"
name: str = "rft"
group: str = ""
run_id: str = ""

# ! DO NOT SET
# the root directory for cache and meta files, automatically generated