Merged
benchmark/bench.py (2 changes: 1 addition & 1 deletion)
@@ -9,7 +9,7 @@
import torch.distributed as dist
import yaml

-from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.algorithm import ALGORITHM_TYPE
from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle
from trinity.utils.dlc_utils import get_dlc_env_vars

benchmark/plugins/guru_math/reward.py (2 changes: 1 addition & 1 deletion)
@@ -1,7 +1,7 @@
from typing import Optional

+from trinity.common.rewards import REWARD_FUNCTIONS
from trinity.common.rewards.math_reward import MathBoxedRewardFn
-from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS


@REWARD_FUNCTIONS.register_module("math_boxed_reward_naive_dapo")
benchmark/reports/gsm8k.md (3 changes: 1 addition & 2 deletions)
@@ -174,12 +174,11 @@ from typing import List, Optional
import openai
from trinity.common.experience import Experience
from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+from trinity.common.workflows.workflow import Task, Workflow

from verl.utils.reward_score import gsm8k


@WORKFLOWS.register_module("verl_gsm8k_workflow")
class VerlGSM8kWorkflow(Workflow):
can_reset: bool = True
can_repeat: bool = True
docs/sphinx_doc/source/tutorial/develop_algorithm.md (13 changes: 5 additions & 8 deletions)
@@ -47,9 +47,8 @@ For convenience, Trinity-RFT provides an abstract class {class}`trinity.algorith
Here's an implementation example for the OPMD algorithm's advantage function:

```python
-from trinity.algorithm.advantage_fn import ADVANTAGE_FN, GroupAdvantage
+from trinity.algorithm.advantage_fn import GroupAdvantage

-@ADVANTAGE_FN.register_module("opmd")
class OPMDGroupAdvantage(GroupAdvantage):
"""OPMD Group Advantage computation"""

@@ -90,7 +89,7 @@ class OPMDGroupAdvantage(GroupAdvantage):
return {"opmd_baseline": "mean", "tau": 1.0}
```

-After implementation, you need to register this module through {class}`trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using the registered name.
+After implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/__init__.py`. Once registered, the module can be configured in the configuration file using the registered name.
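For orientation, a `default_mapping` entry of this kind might look roughly like the sketch below. It is illustrative only and not part of the change shown here; the registry arguments and the dotted module path of `OPMDGroupAdvantage` are assumptions modeled on the `WORKFLOWS` example in `develop_overview.md`.

```python
# Hypothetical excerpt of trinity/algorithm/__init__.py. `Registry` is the same
# helper class used in the WORKFLOWS example; the string name and the dotted
# module path below are assumed for illustration, not taken from this PR.
ADVANTAGE_FN = Registry(
    "advantage_fn",
    default_mapping={
        "opmd": "trinity.algorithm.advantage_fn.opmd_advantage.OPMDGroupAdvantage",
    },
)
```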


#### Step 1.2: Implement `PolicyLossFn`
@@ -100,13 +99,12 @@ Developers need to implement the {class}`trinity.algorithm.PolicyLossFn` interfa
- `__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. This interface automatically scans the parameter list of the `__call__` method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`.
- `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file.

-Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.POLICY_LOSS_FN`.
+Similarly, after implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/policy_loss_fn/__init__.py`.
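A corresponding entry could be sketched as follows (again illustrative and not part of the change shown here; the dotted module path of `OPMDPolicyLossFn` is an assumption):

```python
# Hypothetical excerpt of trinity/algorithm/policy_loss_fn/__init__.py; the
# registry arguments and module path are assumed for illustration.
POLICY_LOSS_FN = Registry(
    "policy_loss_fn",
    default_mapping={
        "opmd": "trinity.algorithm.policy_loss_fn.opmd_policy_loss.OPMDPolicyLossFn",
    },
)
```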

Here's an implementation example for the OPMD algorithm's Policy Loss Fn. Since OPMD's Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the `__call__` method:


```python
-@POLICY_LOSS_FN.register_module("opmd")
class OPMDPolicyLossFn(PolicyLossFn):
def __init__(self, tau: float = 1.0) -> None:
self.tau = tau
@@ -134,7 +132,7 @@

The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.

-To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
+To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in `trinity/algorithm/__init__.py`, enabling one-click configuration.

The `AlgorithmType` class includes the following attributes and methods:

@@ -145,14 +143,13 @@ The `AlgorithmType` class includes the following attributes and methods:
- `schema`: The format of experience data corresponding to the algorithm
- `default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE`

-Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`.
+Similarly, after implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/__init__.py`.
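As a rough sketch of such an entry (illustrative only; the dotted path assumes `OPMDAlgorithm` lives in `trinity/algorithm/algorithm.py`, the module mentioned elsewhere in these tutorials for algorithm types):

```python
# Hypothetical excerpt of trinity/algorithm/__init__.py; the registry arguments
# and the module path of OPMDAlgorithm are assumed for illustration.
ALGORITHM_TYPE = Registry(
    "algorithm",
    default_mapping={
        "opmd": "trinity.algorithm.algorithm.OPMDAlgorithm",
    },
)
```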

Below is the implementation for the OPMD algorithm.
Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`.
The dictionary returned by the `default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss.

```python
-@ALGORITHM_TYPE.register_module("opmd")
class OPMDAlgorithm(AlgorithmType):
"""OPMD algorithm."""

docs/sphinx_doc/source/tutorial/develop_operator.md (3 changes: 1 addition & 2 deletions)
@@ -40,11 +40,10 @@ class ExperienceOperator(ABC):
Here is an implementation of a simple operator that filters out experiences with rewards below a certain threshold:

```python
-from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
+from trinity.buffer.operators import ExperienceOperator
from trinity.common.experience import Experience


@EXPERIENCE_OPERATORS.register_module("reward_filter")
class RewardFilter(ExperienceOperator):

def __init__(self, threshold: float = 0.0) -> None:
docs/sphinx_doc/source/tutorial/develop_overview.md (12 changes: 11 additions & 1 deletion)
@@ -17,13 +17,23 @@ The table below lists the main functions of each extension interface, its target
Trinity-RFT provides a modular development approach, allowing you to flexibly add custom modules without modifying the framework code.
You can place your module code in the `trinity/plugins` directory. Trinity-RFT will automatically load all Python files in that directory at runtime and register the custom modules within them.
Trinity-RFT also supports specifying other directories at runtime by setting the `--plugin-dir` option, for example: `trinity run --config <config_file> --plugin-dir <your_plugin_dir>`.
+Alternatively, you can use the relative path to the custom module in the YAML configuration file, for example: `default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow'`.
```

For modules you plan to contribute to Trinity-RFT, please follow these steps:

1. Implement your code in the appropriate directory, such as `trinity/common/workflows` for `Workflow`, `trinity/algorithm` for `Algorithm`, and `trinity/buffer/operators` for `Operator`.

-2. Register your module in the corresponding `__init__.py` file of the directory.
+2. Register your module in the corresponding mapping dictionary in the `__init__.py` file of the directory.
+For example, if you want to register a new workflow class `ExampleWorkflow`, you need to modify the `default_mapping` dictionary of `WORKFLOWS` in the `trinity/common/workflows/__init__.py` file:
+```python
+WORKFLOWS: Registry = Registry(
+    "workflows",
+    default_mapping={
+        "example_workflow": "trinity.common.workflows.workflow.ExampleWorkflow",
+    },
+)
+```
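Once such an entry exists, the module can be referenced by its registered short name in configuration files (for example, something like `default_workflow_type: example_workflow`), as an alternative to the dotted-path form mentioned in the note above; the exact configuration key depends on where the workflow is used.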

3. Add tests for your module in the `tests` directory, following the naming conventions and structure of existing tests.

docs/sphinx_doc/source/tutorial/develop_selector.md (16 changes: 11 additions & 5 deletions)
@@ -60,7 +60,6 @@ To create a new selector, inherit from `BaseSelector` and implement the followin
This selector focuses on samples whose predicted performance is closest to a target (e.g., 90% success rate), effectively choosing "just right" difficulty tasks.

```python
-@SELECTORS.register_module("difficulty_based")
class DifficultyBasedSelector(BaseSelector):
def __init__(self, data_source, config: TaskSelectorConfig) -> None:
super().__init__(data_source, config)
@@ -125,7 +124,15 @@ class DifficultyBasedSelector(BaseSelector):
self.current_index = state_dict.get("current_index", 0)
```

-> 🔁 After defining your class, use `@SELECTORS.register_module("your_name")` so it can be referenced by name in configs.
+> 🔁 After defining your class, remember to register it in the `default_mapping` of `trinity/buffer/selector/__init__.py` so it can be referenced by name in configs.
+```python
+SELECTORS = Registry(
+    "selectors",
+    default_mapping={
+        "difficulty_based": "trinity.buffer.selector.selector.DifficultyBasedSelector",
+    },
+)
+```



@@ -152,7 +159,6 @@ The operator must output a metric under the key `trinity.common.constants.SELECT
#### Example: Pass Rate Calculator

```python
@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
class PassRateCalculator(ExperienceOperator):
def __init__(self, **kwargs):
pass
@@ -194,7 +200,7 @@ After implementing your selector and operator, register them in the config file.
data_processor:
experience_pipeline:
operators:
-      - name: pass_rate_calculator # Must match @register_module name
+      - name: pass_rate_calculator
```

#### Configure the Taskset with Your Selector
@@ -207,7 +213,7 @@ buffer:
storage_type: file
path: ./path/to/tasks
task_selector:
-selector_type: difficulty_based # Matches @register_module name
+selector_type: difficulty_based
feature_keys: ["correct", "uncertainty"]
kwargs:
m: 16
docs/sphinx_doc/source/tutorial/develop_workflow.md (31 changes: 7 additions & 24 deletions)
@@ -176,28 +176,16 @@ class ExampleWorkflow(Workflow):

#### Registering Your Workflow

-Register your workflow using the `WORKFLOWS.register_module` decorator.
+Register your workflow using the `default_mapping` in `trinity/common/workflows/__init__.py`.
Ensure the name does not conflict with existing workflows.

```python
-# import some packages
-from trinity.common.workflows.workflow import WORKFLOWS
-
-@WORKFLOWS.register_module("example_workflow")
-class ExampleWorkflow(Workflow):
-    pass
-```
-
-For workflows that are prepared to be contributed to Trinity-RFT project, you need to place the above code in `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`. And add the following line to `trinity/common/workflows/__init__.py`:
-
-```python
-# existing import lines
-from trinity.common.workflows.example_workflow import ExampleWorkflow
-
-__all__ = [
-    # existing __all__ lines
-    "ExampleWorkflow",
-]
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "example_workflow": "trinity.common.workflows.workflow.ExampleWorkflow",
+    },
+)
```

#### Performance Optimization
@@ -212,7 +200,6 @@ The `can_reset` is a class property that indicates whether the workflow supports
The `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task.

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True

@@ -234,7 +221,6 @@ The `can_repeat` is a class property that indicates whether the workflow support
The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of times to execute within the `run` method, and `run_id_base` is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored).

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_repeat: bool = True
# some code
@@ -275,7 +261,6 @@ class ExampleWorkflow(Workflow):
#### Full Code Example

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):
can_reset: bool = True
can_repeat: bool = True
@@ -359,7 +344,6 @@ trinity run --config <your_yaml_file>
The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., an asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, and the initialization parameter `auxiliary_models` will also change to `List[openai.AsyncOpenAI]`, while other methods and properties remain unchanged.

```python
@WORKFLOWS.register_module("example_workflow_async")
class ExampleWorkflowAsync(Workflow):

is_async: bool = True
@@ -386,7 +370,6 @@ explorer:
```

```python
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
docs/sphinx_doc/source/tutorial/example_mix_algo.md (2 changes: 0 additions & 2 deletions)
@@ -47,7 +47,6 @@ The path to expert data is passed to `buffer.trainer_input.auxiliary_buffers.sft_dataset`
In `trinity/algorithm/algorithm.py`, we introduce a new algorithm type `MIX`.

```python
@ALGORITHM_TYPE.register_module("mix")
class MIXAlgorithm(AlgorithmType):
"""MIX algorithm."""

@@ -159,7 +158,6 @@ Here we use the `custom_fields` argument of `Experiences.gather_experiences` to
We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively.

```python
@POLICY_LOSS_FN.register_module("mix")
class MIXPolicyLossFn(PolicyLossFn):
def __init__(
self,
docs/sphinx_doc/source/tutorial/example_multi_turn.md (28 changes: 7 additions & 21 deletions)
@@ -126,28 +126,14 @@ class AlfworldWorkflow(MultiTurnWorkflow):
return self.generate_env_inference_samples(env, rollout_n)
```

-Also, remember to register your workflow:
+Also, remember to register your workflow in the `default_mapping` of `trinity/common/workflows/__init__.py`.
```python
@WORKFLOWS.register_module("alfworld_workflow")
class AlfworldWorkflow(MultiTurnWorkflow):
"""A workflow for alfworld task."""
...
```

and include it in the init file `trinity/common/workflows/__init__.py`

```diff
# -*- coding: utf-8 -*-
"""Workflow module"""
from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
+from trinity.common.workflows.envs.alfworld.alfworld_workflow import AlfworldWorkflow

__all__ = [
"WORKFLOWS",
"SimpleWorkflow",
"MathWorkflow",
+ "AlfworldWorkflow",
]
WORKFLOWS = Registry(
"workflows",
default_mapping={
"alfworld_workflow": "trinity.common.workflows.envs.alfworld.alfworld_workflow.AlfworldWorkflow",
},
)
```

Then you are all set! It should be pretty simple😄, and the training processes in both environments converge.
docs/sphinx_doc/source/tutorial/example_step_wise.md (28 changes: 7 additions & 21 deletions)
@@ -49,28 +49,14 @@ class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):
return self.final_reward
```

-Also, remember to register your workflow:
+Also, remember to register your workflow in the `default_mapping` of `trinity/common/workflows/__init__.py`.
```python
@WORKFLOWS.register_module("step_wise_alfworld_workflow")
class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):
"""A step-wise workflow for alfworld task."""
...
```

and include it in the init file `trinity/common/workflows/__init__.py`

```diff
# -*- coding: utf-8 -*-
"""Workflow module"""
from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
+from trinity.common.workflows.envs.alfworld.alfworld_workflow import StepWiseAlfworldWorkflow

__all__ = [
"WORKFLOWS",
"SimpleWorkflow",
"MathWorkflow",
+ "StepWiseAlfworldWorkflow",
]
WORKFLOWS = Registry(
"workflows",
default_mapping={
"step_wise_alfworld_workflow": "trinity.common.workflows.step_wise_workflow.StepWiseAlfworldWorkflow",
},
)
```

### Other Configuration