diff --git a/benchmark/bench.py b/benchmark/bench.py
index ac336d904a..07a481c312 100644
--- a/benchmark/bench.py
+++ b/benchmark/bench.py
@@ -9,7 +9,7 @@
 import torch.distributed as dist
 import yaml
 
-from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.algorithm import ALGORITHM_TYPE
 from trinity.common.constants import MODEL_PATH_ENV_VAR, SyncStyle
 from trinity.utils.dlc_utils import get_dlc_env_vars
diff --git a/benchmark/plugins/guru_math/reward.py b/benchmark/plugins/guru_math/reward.py
index d30d60de44..f273e4f5e1 100644
--- a/benchmark/plugins/guru_math/reward.py
+++ b/benchmark/plugins/guru_math/reward.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
+from trinity.common.rewards import REWARD_FUNCTIONS
 from trinity.common.rewards.math_reward import MathBoxedRewardFn
-from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
 
 
 @REWARD_FUNCTIONS.register_module("math_boxed_reward_naive_dapo")
diff --git a/benchmark/reports/gsm8k.md b/benchmark/reports/gsm8k.md
index 3ebfeb57f8..0c43309255 100644
--- a/benchmark/reports/gsm8k.md
+++ b/benchmark/reports/gsm8k.md
@@ -174,12 +174,11 @@
 from typing import List, Optional
 import openai
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow
+from trinity.common.workflows.workflow import Task, Workflow
 from verl.utils.reward_score import gsm8k
 
 
-@WORKFLOWS.register_module("verl_gsm8k_workflow")
 class VerlGSM8kWorkflow(Workflow):
     can_reset: bool = True
     can_repeat: bool = True
diff --git a/docs/sphinx_doc/source/tutorial/develop_algorithm.md b/docs/sphinx_doc/source/tutorial/develop_algorithm.md
index 41b6a184b9..89b0f913b4 100644
--- a/docs/sphinx_doc/source/tutorial/develop_algorithm.md
+++ b/docs/sphinx_doc/source/tutorial/develop_algorithm.md
@@ -47,9 +47,8 @@ For convenience, Trinity-RFT provides an abstract class {class}`trinity.algorith
 Here's an implementation example for the OPMD algorithm's advantage function:
 
 ```python
-from trinity.algorithm.advantage_fn import ADVANTAGE_FN, GroupAdvantage
+from trinity.algorithm.advantage_fn import GroupAdvantage
 
 
-@ADVANTAGE_FN.register_module("opmd")
 class OPMDGroupAdvantage(GroupAdvantage):
     """OPMD Group Advantage computation"""
@@ -90,7 +89,7 @@ class OPMDGroupAdvantage(GroupAdvantage):
         return {"opmd_baseline": "mean", "tau": 1.0}
 ```
 
-After implementation, you need to register this module through {class}`trinity.algorithm.ADVANTAGE_FN`. Once registered, the module can be configured in the configuration file using the registered name.
+After implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/__init__.py`. Once registered, the module can be configured in the configuration file using the registered name.
 
 #### Step 1.2: Implement `PolicyLossFn`
 
@@ -100,13 +99,12 @@ Developers need to implement the {class}`trinity.algorithm.PolicyLossFn` interfa
 - `__call__`: Calculates the loss based on input parameters. Unlike `AdvantageFn`, the input parameters here are all `torch.Tensor`. This interface automatically scans the parameter list of the `__call__` method and converts it to the corresponding fields in the experience data. Therefore, please write all tensor names needed for loss calculation directly in the parameter list, rather than selecting parameters from `kwargs`.
 - `default_args`: Returns default initialization parameters in dictionary form, which will be used by default when users don't specify initialization parameters in the configuration file.
 
-Similarly, after implementation, you need to register this module through {class}`trinity.algorithm.POLICY_LOSS_FN`.
+Similarly, after implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/policy_loss_fn/__init__.py`.
 
 Here's an implementation example for the OPMD algorithm's Policy Loss Fn. Since OPMD's Policy Loss only requires logprob, action_mask, and advantages, only these three items are specified in the parameter list of the `__call__` method:
 
 ```python
-@POLICY_LOSS_FN.register_module("opmd")
 class OPMDPolicyLossFn(PolicyLossFn):
     def __init__(self, tau: float = 1.0) -> None:
         self.tau = tau
@@ -134,7 +132,7 @@ class OPMDPolicyLossFn(PolicyLossFn):
 
 The above steps implement the components needed for the algorithm, but these components are scattered and need to be configured in multiple places to take effect.
 
-To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in {class}`trinity.algorithm.ALGORITHM_TYPE`, enabling one-click configuration.
+To simplify configuration, Trinity-RFT provides {class}`trinity.algorithm.AlgorithmType` to describe a complete algorithm and registers it in `trinity/algorithm/__init__.py`, enabling one-click configuration.
 
 The `AlgorithmType` class includes the following attributes and methods:
 
@@ -145,14 +143,13 @@ The `AlgorithmType` class includes the following attributes and methods:
 - `schema`: The format of experience data corresponding to the algorithm
 - `default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE`
 
-Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`.
+Similarly, after implementation, you need to register this module in the `default_mapping` of `trinity/algorithm/__init__.py`.
 
 Below is the implementation for the OPMD algorithm. Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`. The dictionary returned by the `default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss.
 
 ```python
-@ALGORITHM_TYPE.register_module("opmd")
 class OPMDAlgorithm(AlgorithmType):
     """OPMD algorithm."""
diff --git a/docs/sphinx_doc/source/tutorial/develop_operator.md b/docs/sphinx_doc/source/tutorial/develop_operator.md
index 394b3fb175..640624775a 100644
--- a/docs/sphinx_doc/source/tutorial/develop_operator.md
+++ b/docs/sphinx_doc/source/tutorial/develop_operator.md
@@ -40,11 +40,10 @@ class ExperienceOperator(ABC):
 Here is an implementation of a simple operator that filters out experiences with rewards below a certain threshold:
 
 ```python
-from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
+from trinity.buffer.operators import ExperienceOperator
 from trinity.common.experience import Experience
 
 
-@EXPERIENCE_OPERATORS.register_module("reward_filter")
 class RewardFilter(ExperienceOperator):
 
     def __init__(self, threshold: float = 0.0) -> None:
diff --git a/docs/sphinx_doc/source/tutorial/develop_overview.md b/docs/sphinx_doc/source/tutorial/develop_overview.md
index 91eecc2d60..c6256f6ae1 100644
--- a/docs/sphinx_doc/source/tutorial/develop_overview.md
+++ b/docs/sphinx_doc/source/tutorial/develop_overview.md
@@ -17,13 +17,23 @@ The table below lists the main functions of each extension interface, its target
 Trinity-RFT provides a modular development approach, allowing you to flexibly add custom modules without modifying the framework code.
 You can place your module code in the `trinity/plugins` directory. Trinity-RFT will automatically load all Python files in that directory at runtime and register the custom modules within them.
 Trinity-RFT also supports specifying other directories at runtime by setting the `--plugin-dir` option, for example: `trinity run --config --plugin-dir `.
+Alternatively, you can use the relative path to the custom module in the YAML configuration file, for example: `default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow'`.
 ```
 
 For modules you plan to contribute to Trinity-RFT, please follow these steps:
 
 1. Implement your code in the appropriate directory, such as `trinity/common/workflows` for `Workflow`, `trinity/algorithm` for `Algorithm`, and `trinity/buffer/operators` for `Operator`.
-2. Register your module in the corresponding `__init__.py` file of the directory.
+2. Register your module in the corresponding mapping dictionary in the `__init__.py` file of the directory.
+   For example, if you want to register a new workflow class `ExampleWorkflow`, you need to modify the `default_mapping` dictionary of `WORKFLOWS` in the `trinity/common/workflows/__init__.py` file:
+   ```python
+   WORKFLOWS: Registry = Registry(
+       "workflows",
+       default_mapping={
+           "example_workflow": "trinity.common.workflows.workflow.ExampleWorkflow",
+       },
+   )
+   ```
 3. Add tests for your module in the `tests` directory, following the naming conventions and structure of existing tests.
diff --git a/docs/sphinx_doc/source/tutorial/develop_selector.md b/docs/sphinx_doc/source/tutorial/develop_selector.md
index d7ecdcb293..a63da515a1 100644
--- a/docs/sphinx_doc/source/tutorial/develop_selector.md
+++ b/docs/sphinx_doc/source/tutorial/develop_selector.md
@@ -60,7 +60,6 @@ To create a new selector, inherit from `BaseSelector` and implement the followin
 This selector focuses on samples whose predicted performance is closest to a target (e.g., 90% success rate), effectively choosing "just right" difficulty tasks.
 
 ```python
-@SELECTORS.register_module("difficulty_based")
 class DifficultyBasedSelector(BaseSelector):
     def __init__(self, data_source, config: TaskSelectorConfig) -> None:
         super().__init__(data_source, config)
@@ -125,7 +124,15 @@ class DifficultyBasedSelector(BaseSelector):
         self.current_index = state_dict.get("current_index", 0)
 ```
 
-> 🔁 After defining your class, use `@SELECTORS.register_module("your_name")` so it can be referenced by name in configs.
+> 🔁 After defining your class, remember to register it in the `default_mapping` of `trinity/buffer/selector/__init__.py` so it can be referenced by name in configs.
+```python
+SELECTORS = Registry(
+    "selectors",
+    default_mapping={
+        "difficulty_based": "trinity.buffer.selector.selector.DifficultyBasedSelector",
+    },
+)
+```
@@ -152,7 +159,6 @@ The operator must output a metric under the key `trinity.common.constants.SELECT
 #### Example: Pass Rate Calculator
 
 ```python
-@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
 class PassRateCalculator(ExperienceOperator):
     def __init__(self, **kwargs):
         pass
@@ -194,7 +200,7 @@ After implementing your selector and operator, register them in the config file
 data_processor:
   experience_pipeline:
     operators:
-      - name: pass_rate_calculator # Must match @register_module name
+      - name: pass_rate_calculator
 ```
 
 #### Configure the Taskset with Your Selector
@@ -207,7 +213,7 @@ buffer:
     storage_type: file
     path: ./path/to/tasks
     task_selector:
-      selector_type: difficulty_based # Matches @register_module name
+      selector_type: difficulty_based
       feature_keys: ["correct", "uncertainty"]
       kwargs:
         m: 16
diff --git a/docs/sphinx_doc/source/tutorial/develop_workflow.md b/docs/sphinx_doc/source/tutorial/develop_workflow.md
index 68a6ab77a2..f8b4551edd 100644
--- a/docs/sphinx_doc/source/tutorial/develop_workflow.md
+++ b/docs/sphinx_doc/source/tutorial/develop_workflow.md
@@ -176,28 +176,16 @@ class ExampleWorkflow(Workflow):
 
 #### Registering Your Workflow
 
-Register your workflow using the `WORKFLOWS.register_module` decorator.
+Register your workflow using the `default_mapping` in `trinity/common/workflows/__init__.py`. Ensure the name does not conflict with existing workflows.
 
 ```python
-# import some packages
-from trinity.common.workflows.workflow import WORKFLOWS
-
-@WORKFLOWS.register_module("example_workflow")
-class ExampleWorkflow(Workflow):
-    pass
-```
-
-For workflows that are prepared to be contributed to Trinity-RFT project, you need to place the above code in `trinity/common/workflows` folder, e.g., `trinity/common/workflows/example_workflow.py`. And add the following line to `trinity/common/workflows/__init__.py`:
-
-```python
-# existing import lines
-from trinity.common.workflows.example_workflow import ExampleWorkflow
-
-__all__ = [
-    # existing __all__ lines
-    "ExampleWorkflow",
-]
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "example_workflow": "trinity.common.workflows.workflow.ExampleWorkflow",
+    },
+)
 ```
 
 #### Performance Optimization
@@ -212,7 +200,6 @@ The `can_reset` is a class property that indicates whether the workflow supports
 The `reset` method accepts a `Task` parameter and resets the workflow's internal state based on the new task.
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_reset: bool = True
@@ -234,7 +221,6 @@ The `can_repeat` is a class property that indicates whether the workflow support
 The `set_repeat_times` method accepts two parameters: `repeat_times` specifies the number of times to execute within the `run` method, and `run_id_base` is an integer used to identify the first run ID in multiple runs (this parameter is used in multi-turn interaction scenarios; for tasks that can be completed with a single model call, this can be ignored).
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_repeat: bool = True
     # some code
@@ -275,7 +261,6 @@ class ExampleWorkflow(Workflow):
 #### Full Code Example
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_reset: bool = True
     can_repeat: bool = True
@@ -359,7 +344,6 @@ trinity run --config
 The example above mainly targets synchronous mode. If your workflow needs to use asynchronous methods (e.g., asynchronous API), you can set `is_async` to `True`, then implement the `run_async` method. In this case, you no longer need to implement the `run` method, and the initialization parameter `auxiliary_models` will also change to `List[openai.AsyncOpenAI]`, while other methods and properties remain changed.
 
 ```python
-@WORKFLOWS.register_module("example_workflow_async")
 class ExampleWorkflowAsync(Workflow):
     is_async: bool = True
@@ -386,7 +370,6 @@ explorer:
 ```
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
 
     def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index e4fc932a36..7cd545c9f8 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -47,7 +47,6 @@ The path to expert data is passed to `buffer.trainer_input.auxiliary_buffers.sft
 In `trinity/algorithm/algorithm.py`, we introduce a new algorithm type `MIX`.
 
 ```python
-@ALGORITHM_TYPE.register_module("mix")
 class MIXAlgorithm(AlgorithmType):
     """MIX algorithm."""
@@ -159,7 +158,6 @@ Here we use the `custom_fields` argument of `Experiences.gather_experiences` to
 We define a `MixPolicyLoss` class in `trinity/algorithm/policy_loss_fn/mix_policy_loss.py`, which computes the sum of two loss terms regarding usual and expert experiences, respectively.
 
 ```python
-@POLICY_LOSS_FN.register_module("mix")
 class MIXPolicyLossFn(PolicyLossFn):
     def __init__(
         self,
diff --git a/docs/sphinx_doc/source/tutorial/example_multi_turn.md b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
index aa20529e4f..e30c52c488 100644
--- a/docs/sphinx_doc/source/tutorial/example_multi_turn.md
+++ b/docs/sphinx_doc/source/tutorial/example_multi_turn.md
@@ -126,28 +126,14 @@ class AlfworldWorkflow(MultiTurnWorkflow):
         return self.generate_env_inference_samples(env, rollout_n)
 ```
 
-Also, remember to register your workflow:
+Also, remember to register your workflow in the `default_mapping` of `trinity/common/workflows/__init__.py`.
 
 ```python
-@WORKFLOWS.register_module("alfworld_workflow")
-class AlfworldWorkflow(MultiTurnWorkflow):
-    """A workflow for alfworld task."""
-    ...
-```
-
-and include it in the init file `trinity/common/workflows/__init__.py`
-
-```diff
- # -*- coding: utf-8 -*-
- """Workflow module"""
- from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
-+from trinity.common.workflows.envs.alfworld.alfworld_workflow import AlfworldWorkflow
-
- __all__ = [
-     "WORKFLOWS",
-     "SimpleWorkflow",
-     "MathWorkflow",
-+    "AlfworldWorkflow",
- ]
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "alfworld_workflow": "trinity.common.workflows.envs.alfworld.alfworld_workflow.AlfworldWorkflow",
+    },
+)
 ```
 
 Then you are all set! It should be pretty simple😄, and the training processes in both environments converge.
diff --git a/docs/sphinx_doc/source/tutorial/example_step_wise.md b/docs/sphinx_doc/source/tutorial/example_step_wise.md
index 8463b3ad77..ba31c9f91f 100644
--- a/docs/sphinx_doc/source/tutorial/example_step_wise.md
+++ b/docs/sphinx_doc/source/tutorial/example_step_wise.md
@@ -49,28 +49,14 @@ class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):
         return self.final_reward
 ```
 
-Also, remember to register your workflow:
+Also, remember to register your workflow in the `default_mapping` of `trinity/common/workflows/__init__.py`.
 
 ```python
-@WORKFLOWS.register_module("step_wise_alfworld_workflow")
-class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):
-    """A step-wise workflow for alfworld task."""
-    ...
-```
-
-and include it in the init file `trinity/common/workflows/__init__.py`
-
-```diff
- # -*- coding: utf-8 -*-
- """Workflow module"""
- from trinity.common.workflows.workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
-+from trinity.common.workflows.envs.alfworld.alfworld_workflow import StepWiseAlfworldWorkflow
-
- __all__ = [
-     "WORKFLOWS",
-     "SimpleWorkflow",
-     "MathWorkflow",
-+    "StepWiseAlfworldWorkflow",
- ]
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "step_wise_alfworld_workflow": "trinity.common.workflows.step_wise_workflow.StepWiseAlfworldWorkflow",
+    },
+)
 ```
 
 ### Other Configuration
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_algorithm.md b/docs/sphinx_doc/source_zh/tutorial/develop_algorithm.md
index d003d6e2fd..1767aa8fc7 100644
--- a/docs/sphinx_doc/source_zh/tutorial/develop_algorithm.md
+++ b/docs/sphinx_doc/source_zh/tutorial/develop_algorithm.md
@@ -42,9 +42,8 @@ OPMD 与 PPO 算法的主要区别在于优势值和策略损失的计算。OPMD
 以下是 OPMD 算法优势函数的实现示例:
 
 ```python
-from trinity.algorithm.advantage_fn import ADVANTAGE_FN, GroupAdvantage
+from trinity.algorithm.advantage_fn import GroupAdvantage
 
 
-@ADVANTAGE_FN.register_module("opmd")
 class OPMDGroupAdvantage(GroupAdvantage):
     """OPMD Group Advantage computation"""
@@ -85,7 +84,7 @@ class OPMDGroupAdvantage(GroupAdvantage):
         return {"opmd_baseline": "mean", "tau": 1.0}
 ```
 
-实现后,你需要通过 {class}`trinity.algorithm.ADVANTAGE_FN` 注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。
+实现后,你需要在 `trinity/algorithm/__init__.py` 中的 `default_mapping` 中注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。
 
 #### 步骤 1.2:实现 `PolicyLossFn`
 
@@ -94,12 +93,11 @@ class OPMDGroupAdvantage(GroupAdvantage):
 - `__call__`:根据输入参数计算损失。与 `AdvantageFn` 不同,这里的输入参数均为 `torch.Tensor`。该接口会自动扫描 `__call__` 方法的参数列表,并将其转换为 experience 数据中的对应字段。因此,请直接在参数列表中写出损失计算所需的所有张量名称,而不是从 `kwargs` 中选择参数。
 - `default_args`:返回默认初始化参数(字典形式),当用户未在配置文件中指定初始化参数时,默认使用此方法返回的参数。
 
-同样,实现后需要通过 {class}`trinity.algorithm.POLICY_LOSS_FN` 注册此模块。
+同样,实现后需要在 `trinity/algorithm/policy_loss_fn/__init__.py` 中的 `default_mapping` 中注册此模块。
 
 以下是 OPMD 算法策略损失函数的实现示例。由于 OPMD 的策略损失仅需 logprob、action_mask 和 advantages,因此 `__call__` 方法的参数列表中仅指定这三个项:
 
 ```python
-@POLICY_LOSS_FN.register_module("opmd")
 class OPMDPolicyLossFn(PolicyLossFn):
     def __init__(
         self, backend: str = "verl", tau: float = 1.0, loss_agg_mode: str = "token-mean"
     ) -> None:
@@ -131,7 +129,7 @@ class OPMDPolicyLossFn(PolicyLossFn):
 
 上述步骤实现了算法所需的组件,但这些组件是分散的,需要在多个地方配置才能生效。
 
-为简化配置,Trinity-RFT 提供了 {class}`trinity.algorithm.AlgorithmType` 来描述完整算法,并在 {class}`trinity.algorithm.ALGORITHM_TYPE` 中注册,实现一键配置。
+为简化配置,Trinity-RFT 提供了 {class}`trinity.algorithm.AlgorithmType` 来描述完整算法,并在 `trinity/algorithm/__init__.py` 中注册,实现一键配置。
 
 `AlgorithmType` 类包含以下属性和方法:
 
@@ -142,14 +140,13 @@ class OPMDPolicyLossFn(PolicyLossFn):
 - `schema`:算法对应的 experience 数据格式
 - `default_config`:获取算法的默认配置,将覆盖 `ALGORITHM_TYPE` 中同名属性
 
-同样,实现后需要通过 `ALGORITHM_TYPE` 注册此模块。
+同样,实现后需要在 `trinity/algorithm/__init__.py` 中的 `default_mapping` 中注册此模块。
 
 以下是 OPMD 算法的实现。
 由于 OPMD 算法不需要使用 Critic 模型,`use_critic` 设置为 `False`。
 `default_config` 方法返回的字典表明 OPMD 将使用步骤 1 中实现的 `opmd` 类型的 `AdvantageFn` 和 `PolicyLossFn`,不会对奖励应用 KL 惩罚,但在计算最终损失时会添加 `k2` 类型的 KL 损失。
 
 ```python
-@ALGORITHM_TYPE.register_module("opmd")
 class OPMDAlgorithm(AlgorithmType):
     """OPMD algorithm."""
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_operator.md b/docs/sphinx_doc/source_zh/tutorial/develop_operator.md
index bb95b45f87..6523ded045 100644
--- a/docs/sphinx_doc/source_zh/tutorial/develop_operator.md
+++ b/docs/sphinx_doc/source_zh/tutorial/develop_operator.md
@@ -41,11 +41,10 @@ class ExperienceOperator(ABC):
 以下是一个简单数据处理算子的实现示例,该算子过滤掉奖励低于某一阈值的 experience:
 
 ```python
-from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
+from trinity.buffer.operators import ExperienceOperator
 from trinity.common.experience import Experience
 
 
-@EXPERIENCE_OPERATORS.register_module("reward_filter")
 class RewardFilter(ExperienceOperator):
 
     def __init__(self, threshold: float = 0.0) -> None:
@@ -57,7 +56,15 @@ class RewardFilter(ExperienceOperator):
         return filtered_exps, metrics
 ```
 
-实现后,你需要通过 {class}`trinity.buffer.operators.EXPERIENCE_OPERATORS` 注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。
+实现后,你需要在 `trinity/buffer/operators/__init__.py` 中的 `default_mapping` 中注册此模块。注册后,该模块可在配置文件中使用注册名称进行配置。
+```python
+EXPERIENCE_OPERATORS = Registry(
+    "experience_operators",
+    default_mapping={
+        "reward_filter": "trinity.buffer.operators.filters.reward_filter.RewardFilter",
+    },
+)
+```
 
 ### 步骤 2:使用此算子
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_overview.md b/docs/sphinx_doc/source_zh/tutorial/develop_overview.md
index 2ce151198f..8619562419 100644
--- a/docs/sphinx_doc/source_zh/tutorial/develop_overview.md
+++ b/docs/sphinx_doc/source_zh/tutorial/develop_overview.md
@@ -17,13 +17,22 @@ Trinity-RFT 将 RL 训练过程拆分为了三个模块:**Explorer**、**Train
 Trinity-RFT 提供了插件化的开发方式,可以在不修改框架代码的前提下,灵活地添加自定义模块。
 开发者可以将自己编写的模块代码放在 `trinity/plugins` 目录下。Trinity-RFT 会在运行时自动加载该目录下的所有 Python 文件,并注册其中的自定义模块。
 Trinity-RFT 也支持在运行时通过设置 `--plugin-dir` 选项来指定其他目录,例如:`trinity run --config --plugin-dir `。
+另外,你也可以使用相对路径来在 YAML 配置文件中指定自定义模块,例如:`default_workflow_type: 'examples.agentscope_frozenlake.workflow.FrozenLakeWorkflow'`。
 ```
 
 对于准备向 Trinity-RFT 提交的模块,请遵循以下步骤:
 
 1. 在适当目录中实现你的代码,例如 `trinity/common/workflows` 用于 `Workflow`,`trinity/algorithm` 用于 `Algorithm`,`trinity/buffer/operators` 用于 `Operator`。
-2. 在目录对应的 `__init__.py` 文件中注册你的模块。
+2. 在目录对应的 `__init__.py` 文件中的 `default_mapping` 字典中注册你的模块。例如,对于新的 `ExampleWorkflow` 类,你需要在 `trinity/common/workflows/__init__.py` 文件中的 `WORKFLOWS` 中添加你的模块:
+   ```python
+   WORKFLOWS: Registry = Registry(
+       "workflows",
+       default_mapping={
+           "example_workflow": "trinity.common.workflows.workflow.ExampleWorkflow",
+       },
+   )
+   ```
 3. 在 `tests` 目录中为你的模块添加测试,遵循现有测试的命名约定和结构。
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md
index 1f92f05d4c..1d08b42508 100644
--- a/docs/sphinx_doc/source_zh/tutorial/develop_selector.md
+++ b/docs/sphinx_doc/source_zh/tutorial/develop_selector.md
@@ -58,7 +58,6 @@
 该选择器聚焦于模型预测表现最接近目标值的样本(例如 90% 成功率),从而挑选出“难度适中”的任务。
 
 ```python
-@SELECTORS.register_module("difficulty_based")
 class DifficultyBasedSelector(BaseSelector):
     def __init__(self, data_source, config: TaskSelectorConfig) -> None:
         super().__init__(data_source, config)
@@ -123,8 +122,15 @@ class DifficultyBasedSelector(BaseSelector):
         self.current_index = state_dict.get("current_index", 0)
 ```
 
-> 🔁 定义完类后,请使用 `@SELECTORS.register_module("your_name")` 注册,以便在配置文件中通过名称引用。
-
+> 🔁 定义完类后,请在 `trinity/buffer/selector/__init__.py` 中的 `default_mapping` 中注册,以便在配置文件中通过名称引用。
+```python
+SELECTORS = Registry(
+    "selectors",
+    default_mapping={
+        "difficulty_based": "trinity.buffer.selector.selector.DifficultyBasedSelector",
+    },
+)
+```
 
 ### ✅ 步骤 2:实现反馈操作器(Feedback Operator)
@@ -150,7 +156,6 @@ class DifficultyBasedSelector(BaseSelector):
 #### 示例:通过率计算器(Pass Rate Calculator)
 
 ```python
-@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator")
 class PassRateCalculator(ExperienceOperator):
     def __init__(self, **kwargs):
         pass
@@ -192,7 +197,7 @@ class PassRateCalculator(ExperienceOperator):
 data_processor:
   experience_pipeline:
     operators:
-      - name: pass_rate_calculator # 必须与 @register_module 名称一致
+      - name: pass_rate_calculator
 ```
 
 #### 为任务集配置你的选择器
@@ -205,7 +210,7 @@ buffer:
     storage_type: file
     path: ./path/to/tasks
     task_selector:
-      selector_type: difficulty_based # 必须与 @register_module 名称匹配
+      selector_type: difficulty_based
       feature_keys: ["correct", "uncertainty"]
       kwargs:
         m: 16
diff --git a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
index 8e27d69e51..a3bb025cfc 100644
--- a/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
+++ b/docs/sphinx_doc/source_zh/tutorial/develop_workflow.md
@@ -170,27 +170,15 @@ class ExampleWorkflow(Workflow):
 
 #### 注册你的工作流
 
-为了让 Trinity-RFT 能够通过配置文件中的名称自动找到你的工作流,你需要使用 `WORKFLOWS.register_module` 装饰器注册。
+为了让 Trinity-RFT 能够通过配置文件中的名称自动找到你的工作流,你需要在 `trinity/common/workflows/__init__.py` 中的 `default_mapping` 中注册。
 
 ```python
-# import some packages
-from trinity.common.workflows.workflow import WORKFLOWS
-
-@WORKFLOWS.register_module("example_workflow")
-class ExampleWorkflow(Workflow):
-    pass
-```
-
-对于准备贡献给 Trinity-RFT 项目的模块,你需要将上述代码放入 `trinity/common/workflows` 文件夹中,例如 `trinity/common/workflows/example_workflow.py`。并在 `trinity/common/workflows/__init__.py` 中添加以下行:
-
-```python
-# existing import lines
-from trinity.common.workflows.example_workflow import ExampleWorkflow
-
-__all__ = [
-    # existing __all__ lines
-    "ExampleWorkflow",
-]
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "example_workflow": "trinity.common.workflows.workflow.ExampleWorkflow",
+    },
+)
 ```
 
 #### 性能调优
@@ -207,7 +195,6 @@ __all__ = [
 `reset` 方法接受一个新的 `Task` 实例,并使用该实例更新工作流的状态。
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_reset: bool = True
@@ -229,7 +216,6 @@ class ExampleWorkflow(Workflow):
 `set_repeat_times` 方法接受两个参数:`repeat_times` 指定了在 `run` 方法内需要执行的次数,`run_id_base` 是一个整数,用于标识多次运行中第一次的运行 ID,之后各次的 ID 基于此递增(该参数用于多轮交互场景,单次模型调用即可完成的任务可以忽略该项)。
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_repeat: bool = True
     # some code
@@ -270,7 +256,6 @@ class ExampleWorkflow(Workflow):
 #### 完整代码示例
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
     can_reset: bool = True
     can_repeat: bool = True
@@ -354,7 +339,6 @@ trinity run --config
 本节样例主要针对同步模式,如果你的工作流需要使用异步方法(例如异步 API),你可以将 `is_async` 属性设置为 `True`,然后实现 `run_async` 方法,在这种情况下不再需要实现 `run` 方法,并且初始化参数 `auxiliary_models` 也会自动变为 `List[openai.AsyncOpenAI]` 类型,其余方法和属性保持不变。
 
 ```python
-@WORKFLOWS.register_module("example_workflow_async")
 class ExampleWorkflowAsync(Workflow):
     is_async: bool = True
@@ -384,7 +368,6 @@ explorer:
 ```
 
 ```python
-@WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
 
     def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
index 9f9b6f3b1c..b5027bae56 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_mix_algo.md
@@ -41,7 +41,6 @@ $$
 在 `trinity/algorithm/algorithm.py` 中,我们引入一个新的算法类型 `MIX`。
 
 ```python
-@ALGORITHM_TYPE.register_module("mix")
 class MIXAlgorithm(AlgorithmType):
     """MIX algorithm."""
@@ -150,7 +149,6 @@ class MixSampleStrategy(SampleStrategy):
 我们在 `trinity/algorithm/policy_loss_fn/mix_policy_loss.py` 中定义一个 `MixPolicyLoss` 类,它分别计算关于普通 experience 和专家 experience 的两个 losses 之和。
 
 ```python
-@POLICY_LOSS_FN.register_module("mix")
 class MIXPolicyLossFn(PolicyLossFn):
     def __init__(
         self,
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_multi_turn.md b/docs/sphinx_doc/source_zh/tutorial/example_multi_turn.md
index c204dc7804..a2efbce398 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_multi_turn.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_multi_turn.md
@@ -126,28 +126,15 @@ class AlfworldWorkflow(MultiTurnWorkflow):
         return self.generate_env_inference_samples(env, rollout_n)
 ```
 
-同时,记得注册你的 workflow:
-```python
-@WORKFLOWS.register_module("alfworld_workflow")
-class AlfworldWorkflow(MultiTurnWorkflow):
-    """A workflow for alfworld task."""
-    ...
-```
-
-并在初始化文件 `trinity/common/workflows/__init__.py` 中包含它:
+同时,记得在 `trinity/common/workflows/__init__.py` 中的 `default_mapping` 中注册你的 workflow。
 
-```diff
- # -*- coding: utf-8 -*-
- """Workflow module"""
- from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
-+from .envs.alfworld.alfworld_workflow import AlfworldWorkflow
-
- __all__ = [
-     "WORKFLOWS",
-     "SimpleWorkflow",
-     "MathWorkflow",
-+    "AlfworldWorkflow",
- ]
+```python
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "alfworld_workflow": "trinity.common.workflows.envs.alfworld.alfworld_workflow.AlfworldWorkflow",
+    },
+)
 ```
 
 这样就完成了!整个过程非常简单😄,并且在这两个环境中的训练过程都能收敛。
diff --git a/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md b/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
index a796d5d296..8030f6e06e 100644
--- a/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
+++ b/docs/sphinx_doc/source_zh/tutorial/example_step_wise.md
@@ -48,28 +48,14 @@ class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):
         return self.final_reward
 ```
 
-同时,请记得注册你的工作流:
+同时,请记得在 `trinity/common/workflows/__init__.py` 中的 `default_mapping` 中注册你的工作流:
 
 ```python
-@WORKFLOWS.register_module("step_wise_alfworld_workflow")
-class StepWiseAlfworldWorkflow(RewardPropagationWorkflow):
-    """A step-wise workflow for alfworld task."""
-    ...
-```
-
-并将其添加到初始化文件 `trinity/common/workflows/__init__.py` 中:
-
-```diff
- # -*- coding: utf-8 -*-
- """Workflow module"""
- from .workflow import WORKFLOWS, MathWorkflow, SimpleWorkflow
-+from .envs.alfworld.alfworld_workflow import StepWiseAlfworldWorkflow
-
- __all__ = [
-     "WORKFLOWS",
-     "SimpleWorkflow",
-     "MathWorkflow",
-+    "StepWiseAlfworldWorkflow",
- ]
+WORKFLOWS = Registry(
+    "workflows",
+    default_mapping={
+        "step_wise_alfworld_workflow": "trinity.common.workflows.step_wise_workflow.StepWiseAlfworldWorkflow",
+    },
+)
 ```
 
 ### 其他配置
diff --git a/examples/bots/README.md b/examples/bots/README.md
index d3e5ec19dd..7ba9072d14 100644
--- a/examples/bots/README.md
+++ b/examples/bots/README.md
@@ -55,7 +55,7 @@ Remember to update `task_selector.feature_keys` in `bots.yaml`.
 ##### Step 3: Training
 Launch training by executing:
 ```bash
-trinity run --config examples/bots/bots.yaml --plugin-dir examples/bots/workflow
+trinity run --config examples/bots/bots.yaml
 ```
 
 The improvement over random selection baseline can be stably obtained 🤖🤖🤖.
diff --git a/examples/bots/README_zh.md b/examples/bots/README_zh.md
index 292336e334..3a40e9d2b3 100644
--- a/examples/bots/README_zh.md
+++ b/examples/bots/README_zh.md
@@ -53,7 +53,7 @@ python examples/bots/ref_eval_collect.py \
 ##### 第三步:训练
 执行以下命令启动训练:
 ```bash
-trinity run --config examples/bots/bots.yaml --plugin-dir examples/bots/workflow
+trinity run --config examples/bots/bots.yaml
 ```
 
 相比随机选择基线的提升可以被稳定地观察到🤖🤖🤖.
diff --git a/examples/bots/bots.yaml b/examples/bots/bots.yaml
index a794b03e85..b0bebf05f4 100644
--- a/examples/bots/bots.yaml
+++ b/examples/bots/bots.yaml
@@ -50,7 +50,7 @@ buffer:
         response_key: 'reward_model.ground_truth'
       rollout_args:
         temperature: 1.0
-    default_workflow_type: 'bots_math_boxed_workflow'
+    default_workflow_type: 'examples.bots.workflow.bots_math_boxed_workflow.BOTSMathBoxedWorkflow'
   trainer_input:
     experience_buffer:
       name: exp_buffer
diff --git a/examples/bots/random.yaml b/examples/bots/random.yaml
index cb6958004c..fd18862909 100644
--- a/examples/bots/random.yaml
+++ b/examples/bots/random.yaml
@@ -38,7 +38,7 @@ buffer:
         response_key: 'reward_model.ground_truth'
      rollout_args:
         temperature: 1.0
-    default_workflow_type: 'bots_math_boxed_workflow'
+    default_workflow_type: 'examples.bots.workflow.bots_math_boxed_workflow.BOTSMathBoxedWorkflow'
   trainer_input:
     experience_buffer:
       name: exp_buffer
diff --git a/examples/bots/workflow/bots_math_boxed_reward.py b/examples/bots/workflow/bots_math_boxed_reward.py
index c49c5a36a2..47cc498e86 100644
--- a/examples/bots/workflow/bots_math_boxed_reward.py
+++ b/examples/bots/workflow/bots_math_boxed_reward.py
@@ -1,10 +1,10 @@
 from typing import Optional
 
+from examples.bots.workflow.bots_reward import compute_score_bots
 from trinity.common.rewards.eval_utils import validate_think_pattern
-from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn
+from trinity.common.rewards.reward_fn import RewardFn
 
 
-@REWARD_FUNCTIONS.register_module("bots_math_boxed_reward")
 class BOTSMathBoxedRewardFn(RewardFn):
     """A reward function that rewards for math task for BOTS."""
@@ -22,8 +22,6 @@ def __call__(  # type: ignore
         format_score_coef: Optional[float] = 0.1,
         **kwargs,
     ) -> dict[str, float]:
-        from trinity.plugins.bots_reward import compute_score_bots
-
         accuracy_score = compute_score_bots(response, truth)
 
         format_score = 0.0
diff --git a/examples/bots/workflow/bots_math_boxed_workflow.py b/examples/bots/workflow/bots_math_boxed_workflow.py
index 90ce2a3f2b..63c0d8f4d5 100644
--- a/examples/bots/workflow/bots_math_boxed_workflow.py
+++ b/examples/bots/workflow/bots_math_boxed_workflow.py
@@ -3,18 +3,16 @@
 import os
 from typing import List, Union
 
+from examples.bots.workflow.bots_math_boxed_reward import BOTSMathBoxedRewardFn
 from trinity.common.experience import Experience
 from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow, Task
-from trinity.common.workflows.workflow import WORKFLOWS
 
 
-@WORKFLOWS.register_module("bots_math_boxed_workflow")
 class BOTSMathBoxedWorkflow(MathBoxedWorkflow):
     """A workflow for math tasks that give answers in boxed format for BOTS."""
 
     def reset(self, task: Task):
         super().reset(task)
-        from trinity.plugins.bots_math_boxed_reward import BOTSMathBoxedRewardFn
 
         self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args)
         self.task_desc = nested_query(self.format_args.prompt_key, self.raw_task)
@@ -25,14 +23,11 @@ def format_messages(self):
         return self.task_desc
 
 
-@WORKFLOWS.register_module("bots_ref_eval_collect_math_boxed_workflow")
 class BOTSRefEvalCollectMathBoxedWorkflow(MathBoxedWorkflow):
     """A reference evaluation collection workflow for math tasks that give answers in boxed format for BOTS."""
 
     def reset(self, task: Task):
         super().reset(task)
-        from trinity.plugins.bots_math_boxed_reward import BOTSMathBoxedRewardFn
-
         self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args)
         self.task_desc = nested_query(self.format_args.prompt_key, self.raw_task)
         self.truth = nested_query(self.format_args.response_key, self.raw_task)
diff --git a/examples/learn_to_ask/README.md b/examples/learn_to_ask/README.md
index 01c024afd1..1b300fe228 100644
--- a/examples/learn_to_ask/README.md
+++ b/examples/learn_to_ask/README.md
@@ -109,7 +109,7 @@ explorer:
 Then, launch training:
 
 ```bash
-trinity run --config examples/learn_to_ask/train.yaml --plugin-dir examples/learn_to_ask/workflow
+trinity run --config examples/learn_to_ask/train.yaml
 ```
 
 ---
diff --git a/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py b/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py
index ec34da0d46..60d7755d8c 100644
--- a/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py
+++ b/examples/learn_to_ask/data_prepare/3_rollout_then_evaluate.py
@@ -37,7 +37,9 @@ def init_llm(model_path):
 
 def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path, rollout_repeat=3):
-    from trinity.plugins.prompt_learn2ask import rollout_prompt_med as rollout_prompt
+    from examples.learn_to_ask.workflow.prompt_learn2ask import (
+        rollout_prompt_med as rollout_prompt,
+    )
 
     with open(input_file_path, "r") as lines:
         sample_list = [json.loads(line.strip()) for line in lines]
@@ -68,7 +70,9 @@ def rollout(llm, tokenizer, sampling_params, input_file_path, output_file_path,
 
 def eval_sample(llm, tokenizer, sampling_params, input_file_path, output_file_path):
-    from trinity.plugins.prompt_learn2ask import reward_prompt_med as grader_prompt
+    from examples.learn_to_ask.workflow.prompt_learn2ask import (
+        reward_prompt_med as grader_prompt,
+    )
 
     print(f"input_file_path: {input_file_path}")
     print(f"output_file_path: {output_file_path}")
diff --git a/examples/learn_to_ask/train.yaml b/examples/learn_to_ask/train.yaml
index 28edac55db..9aee01b4a9 100644
--- a/examples/learn_to_ask/train.yaml
+++ b/examples/learn_to_ask/train.yaml
@@ -41,7 +41,7 @@ buffer:
         train_mode: "Ra+Rs"
         fusion_mode: "default"
     eval_tasksets: [ ]
-    default_workflow_type: learn2ask_workflow
+    default_workflow_type: examples.learn_to_ask.workflow.workflow_learn2ask.Learn2AskWorkflow
   trainer_input:
     experience_buffer:
       name: experience_buffer
diff --git a/examples/learn_to_ask/workflow/workflow_learn2ask.py b/examples/learn_to_ask/workflow/workflow_learn2ask.py
index f7f28ffb64..6e5cb6e0da 100644
--- a/examples/learn_to_ask/workflow/workflow_learn2ask.py
+++ b/examples/learn_to_ask/workflow/workflow_learn2ask.py
@@ -11,7 +11,7 @@
 
 from trinity.common.experience import Experience
 from trinity.common.models.model import ModelWrapper
-from trinity.common.workflows import WORKFLOWS, SimpleWorkflow, Task
+from trinity.common.workflows.workflow import SimpleWorkflow, Task
 from trinity.utils.log import get_logger
 
 logger = get_logger(__name__)
@@ -28,7 +28,6 @@
 """
 
 
-@WORKFLOWS.register_module("learn2ask_workflow")
 class Learn2AskWorkflow(SimpleWorkflow):
     """A workflow for Elem training with local model."""
 
@@ -56,11 +55,11 @@ def resettable(self):
 
     def reset(self, task: Task):
         if self.train_mode == "Ra":
             # we have a different system prompt for this training mode.
-            from trinity.plugins.prompt_learn2ask import (
+            from examples.learn_to_ask.workflow.prompt_learn2ask import (
                 rollout_prompt_med_Ra as system_prompt,
             )
         else:
             # other modes use the same system prompt
-            from trinity.plugins.prompt_learn2ask import (
+            from examples.learn_to_ask.workflow.prompt_learn2ask import (
                 rollout_prompt_med as system_prompt,
             )
@@ -129,7 +128,9 @@ def run(self) -> List[Experience]:
         return responses
 
     def llm_reward(self, response):
-        from trinity.plugins.prompt_learn2ask import reward_prompt_med as reward_prompt
+        from examples.learn_to_ask.workflow.prompt_learn2ask import (
+            reward_prompt_med as reward_prompt,
+        )
 
         history = self.merge_msg_list(self.task_desc + [{"role": "assistant", "content": response}])
         messages = [
diff --git a/scripts/context_length_test/context_length.yaml b/scripts/context_length_test/context_length.yaml
index e7133de8dd..27611dc5e6 100644
--- a/scripts/context_length_test/context_length.yaml
+++ b/scripts/context_length_test/context_length.yaml
@@ -46,7 +46,7 @@ buffer:
       prompt_len: ${model.max_prompt_tokens}
       max_model_len: ${model.max_model_len}
     eval_tasksets: []
-    default_workflow_type: synthetic_exp_workflow
+    default_workflow_type: scripts.context_length_test.workflow.synthetic_exp_workflow.SyntheticExpWorkflow
     default_reward_fn_type: math_reward
   trainer_input:
     experience_buffer:
diff --git a/scripts/context_length_test/workflow/synthetic_exp_workflow.py b/scripts/context_length_test/workflow/synthetic_exp_workflow.py
index 5147acec65..dad377ecc3 100644
--- a/scripts/context_length_test/workflow/synthetic_exp_workflow.py
+++ b/scripts/context_length_test/workflow/synthetic_exp_workflow.py
@@ -1,10 +1,9 @@
 import torch
 
 from trinity.common.experience import Experience
-from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task
+from trinity.common.workflows.workflow import SimpleWorkflow, Task
 
 
-@WORKFLOWS.register_module("synthetic_exp_workflow")
 class SyntheticExpWorkflow(SimpleWorkflow):
     def reset(self, task: Task):
         self.workflow_args = task.workflow_args
diff --git a/tests/algorithm/kl_fn_test.py b/tests/algorithm/kl_fn_test.py
index f3771c5f7c..4f8bc3393d 100644
--- a/tests/algorithm/kl_fn_test.py
+++ b/tests/algorithm/kl_fn_test.py
@@ -5,7 +5,7 @@
 
 import torch
 
-from trinity.algorithm.kl_fn.kl_fn import KL_FN
+from trinity.algorithm.kl_fn import KL_FN
 
 
 class KLFnTest(unittest.TestCase):
diff --git a/tests/algorithm/policy_loss_test.py b/tests/algorithm/policy_loss_test.py
index d4ddbbf87c..bac8e69322 100644
--- a/tests/algorithm/policy_loss_test.py
+++ b/tests/algorithm/policy_loss_test.py
@@ -6,7 +6,7 @@
 import torch
 from verl import DataProto
 
-from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN
+from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN
 
 
 class VerlPolicyLossTest(unittest.TestCase):
diff --git a/tests/buffer/formatter_test.py b/tests/buffer/formatter_test.py
index 92b6616555..e5381a7249 100644
--- a/tests/buffer/formatter_test.py
+++ b/tests/buffer/formatter_test.py
@@ -8,7 +8,7 @@
     get_unittest_dataset_config,
     get_vision_language_model_path,
 )
-from trinity.buffer.schema.formatter import FORMATTER
+from trinity.buffer.schema import FORMATTER
 from trinity.common.config import FormatConfig, StorageConfig
 from trinity.common.constants import PromptType
 from trinity.common.experience import Experience
diff --git a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py
index aa8706c27d..7aec80ed23 100644
--- a/tests/explorer/scheduler_test.py
+++ b/tests/explorer/scheduler_test.py
@@ -13,8 +13,7 @@
 from trinity.common.constants import StorageType, SyncStyle
 from trinity.common.experience import EID, Experience
 from trinity.common.models.model import InferenceModel, ModelWrapper
-from trinity.common.workflows import Task
-from trinity.common.workflows.workflow import WORKFLOWS, Workflow
+from trinity.common.workflows import WORKFLOWS, Task, Workflow
 from trinity.explorer.scheduler import Scheduler
diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py
index b0df0369a5..16bb0b7bbb 100644
--- a/tests/explorer/workflow_test.py
+++ b/tests/explorer/workflow_test.py
@@ -18,16 +18,12 @@
 from trinity.common.experience import EID, Experience
 from trinity.common.models import create_inference_models
 from trinity.common.models.model import ModelWrapper
-from trinity.common.rewards import RMGalleryFn
-from trinity.common.workflows import (
-    WORKFLOWS,
-    MathBoxedWorkflow,
-    MathEvalWorkflow,
-    MathRMWorkflow,
-    MathWorkflow,
-    Workflow,
-)
-from trinity.common.workflows.workflow import MultiTurnWorkflow, Task
+from trinity.common.rewards.reward_fn import RMGalleryFn
+from trinity.common.workflows import WORKFLOWS, Workflow
+from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow
+from trinity.common.workflows.eval_workflow import MathEvalWorkflow
+from trinity.common.workflows.math_rm_workflow import MathRMWorkflow
+from trinity.common.workflows.workflow import MathWorkflow, MultiTurnWorkflow, Task
 from trinity.explorer.workflow_runner import WorkflowRunner
@@ -553,7 +549,6 @@ async def monitor_routine():
 
 
 class TestAgentScopeWorkflowAdapter(unittest.IsolatedAsyncioTestCase):
-    @unittest.skip("Waiting for agentscope>=0.1.6")
     async def test_adapter(self):
         try:
             from agentscope.model import TrinityChatModel
diff --git a/tests/manager/synchronizer_test.py b/tests/manager/synchronizer_test.py
index 596bf5b839..114e396aca 100644
--- a/tests/manager/synchronizer_test.py
+++ b/tests/manager/synchronizer_test.py
@@ -21,7 +21,7 @@
     get_template_config,
     get_unittest_dataset_config,
 )
-from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.algorithm import ALGORITHM_TYPE
 from trinity.cli.launcher import both, explore, train
 from trinity.common.config import Config, ExperienceBufferConfig
 from trinity.common.constants import StorageType, SyncMethod, SyncStyle
diff --git a/tests/utils/plugin_test.py b/tests/utils/plugin_test.py
index 61e5665ae9..3bda6754ae 100644
--- a/tests/utils/plugin_test.py
+++ b/tests/utils/plugin_test.py
@@ -55,8 +55,8 @@ class TestPluginLoader(unittest.TestCase):
     @parameterized.expand(PLUGIN_DIR_PARAMS)
     def test_load_plugins_local(self, plugin_dir):
         if os.path.isabs(plugin_dir):
-            my_workflow_cls = WORKFLOWS.get("my_workflow")
-            self.assertIsNone(my_workflow_cls)
+            with self.assertRaises(ValueError):
+                my_workflow_cls = WORKFLOWS.get("my_workflow")
         os.environ[PLUGIN_DIRS_ENV_VAR] = plugin_dir
         try:
             load_plugins()
@@ -84,15 +84,15 @@ def test_load_plugins_remote(self, plugin_dir):
             ignore_reinit_error=True,
             runtime_env={"env_vars": {PLUGIN_DIRS_ENV_VAR: plugin_dir}},
         )
-        my_workflow_cls = WORKFLOWS.get("my_workflow")
         # disable plugin and use custom class from registry
         remote_plugin = ray.remote(PluginActor).remote(config, enable_load_plugins=False)
-        remote_plugin.run.remote(my_workflow_cls)
         with self.assertRaises(ray.exceptions.ActorDiedError):
+            # During initialization, enable_workflow=True (default) will try to get "my_workflow"
             ray.get(remote_plugin.__ray_ready__.remote())
         # enable plugin
         remote_plugin = ray.remote(PluginActor).remote(config)
+        my_workflow_cls = WORKFLOWS.get("my_workflow")
         remote_res = ray.get(remote_plugin.run.remote(my_workflow_cls))
         self.assertEqual(remote_res[0], "Hello world")
         self.assertEqual(remote_res[1], "Hi")
diff --git a/tests/utils/registry_test.py b/tests/utils/registry_test.py
index 4bdfff1e12..e028cc2a0f 100644
--- a/tests/utils/registry_test.py
+++ b/tests/utils/registry_test.py
@@ -1,6 +1,48 @@
+# -*- coding: utf-8 -*-
+"""Test cases for workflows registry mapping."""
 import unittest
 
 import ray
+import torch
+
+from trinity.algorithm import (
+    ADVANTAGE_FN,
+    ALGORITHM_TYPE,
+    ENTROPY_LOSS_FN,
+    KL_FN,
+    POLICY_LOSS_FN,
+    SAMPLE_STRATEGY,
+    AdvantageFn,
+    AlgorithmType,
+    EntropyLossFn,
+    KLFn,
+    PolicyLossFn,
+    SampleStrategy,
+)
+from trinity.buffer.buffer_reader import BufferReader
+from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator
+from trinity.buffer.reader import READER
+from trinity.buffer.schema import FORMATTER, SQL_SCHEMA
+from trinity.buffer.selector import SELECTORS, BaseSelector
+from trinity.buffer.storage import PRIORITY_FUNC
+from trinity.buffer.storage.queue import PriorityFunction
+from trinity.common.rewards import REWARD_FUNCTIONS, RewardFn
+from trinity.common.workflows import WORKFLOWS, Workflow
+from trinity.utils.monitor import MONITOR, Monitor
+
+
+@ENTROPY_LOSS_FN.register_module("dummy_entropy_loss_fn")
+class DummyEntropyLossFn(EntropyLossFn):
+    def __init__(self, entropy_coef: float):
+        self.entropy_coef = entropy_coef
+
+    def __call__(
+        self,
+        entropy,
+        action_mask,
+        **kwargs,
+    ):
+        return torch.tensor(0.0), {}
 
 
 class ImportUtils:
@@ -15,7 +57,7 @@ def run(self):
         assert res[1] == "0"
 
 
-class TestRegistry(unittest.TestCase):
+class TestRegistryWithRay(unittest.TestCase):
     def setUp(self):
         ray.init(ignore_reinit_error=True)
 
@@ -27,3 +69,257 @@ def test_dynamic_import(self):
         ImportUtils().run()
         # test remote import
         ray.get(ray.remote(ImportUtils).remote().run.remote())
+
+
+class TestRegistry(unittest.TestCase):
+    """Test registry functionality."""
+
+    def test_common_module_registry_mapping(self):
+        """Test registry mapping in common module"""
+        # test workflow
+        workflow_names = list(WORKFLOWS._default_mapping.keys())
+        for workflow_name in workflow_names:
+            with self.subTest(workflow_name=workflow_name):
+                workflow_cls = WORKFLOWS.get(workflow_name)
+                self.assertIsNotNone(
+                    workflow_cls, f"{workflow_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(workflow_cls, Workflow),
+                    f"{workflow_name} should be a subclass of Workflow",
+                )
+        with self.assertRaises(ValueError):
+            WORKFLOWS.get("non_existent_workflow")
+
+        # test reward function
+        reward_fn_names = list(REWARD_FUNCTIONS._default_mapping.keys())
+        for reward_fn_name in reward_fn_names:
+            with self.subTest(reward_fn_name=reward_fn_name):
+                reward_fn_cls = REWARD_FUNCTIONS.get(reward_fn_name)
+                self.assertIsNotNone(
+                    reward_fn_cls, f"{reward_fn_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(reward_fn_cls, RewardFn),
+                    f"{reward_fn_name} should be a subclass of RewardFn",
+                )
+        with self.assertRaises(ValueError):
+            REWARD_FUNCTIONS.get("non_existent_reward_fn")
+
+    def test_algorithm_registry_mapping(self):
+        """Test registry mapping in algorithm module"""
+        # test algorithm
+        algorithm_names = list(ALGORITHM_TYPE._default_mapping.keys())
+        for algorithm_name in algorithm_names:
+            with self.subTest(algorithm_name=algorithm_name):
+                algorithm_cls = ALGORITHM_TYPE.get(algorithm_name)
+                self.assertIsNotNone(
+                    algorithm_cls, f"{algorithm_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(algorithm_cls, AlgorithmType),
+                    f"{algorithm_name} should be a subclass of AlgorithmType",
+                )
+        with self.assertRaises(ValueError):
+            ALGORITHM_TYPE.get("non_existent_algorithm")
+
+        # test advantage function
+        advantage_fn_names = list(ADVANTAGE_FN._default_mapping.keys())
+        for advantage_fn_name in advantage_fn_names:
+            with self.subTest(advantage_fn_name=advantage_fn_name):
+                advantage_fn_cls = ADVANTAGE_FN.get(advantage_fn_name)
+                self.assertIsNotNone(
+                    advantage_fn_cls, f"{advantage_fn_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(advantage_fn_cls, AdvantageFn),
+                    f"{advantage_fn_name} should be a subclass of AdvantageFn",
+                )
+        with self.assertRaises(ValueError):
+            ADVANTAGE_FN.get("non_existent_advantage_fn")
+
+        # test entropy loss function
+        entropy_loss_fn_names = list(ENTROPY_LOSS_FN._default_mapping.keys())
+        for entropy_loss_fn_name in entropy_loss_fn_names:
+            with self.subTest(entropy_loss_fn_name=entropy_loss_fn_name):
+                entropy_loss_fn_cls = ENTROPY_LOSS_FN.get(entropy_loss_fn_name)
+                self.assertIsNotNone(
+                    entropy_loss_fn_cls,
+                    f"{entropy_loss_fn_name} should be retrievable from registry",
+                )
+                self.assertTrue(
+                    issubclass(entropy_loss_fn_cls, EntropyLossFn),
+                    f"{entropy_loss_fn_name} should be a subclass of EntropyLossFn",
+                )
+        with self.assertRaises(ValueError):
+            ENTROPY_LOSS_FN.get("non_existent_entropy_loss_fn")
+
+        # test kl function
+        kl_fn_names = list(KL_FN._default_mapping.keys())
+        for kl_fn_name in kl_fn_names:
+            with self.subTest(kl_fn_name=kl_fn_name):
+                kl_fn_cls = KL_FN.get(kl_fn_name)
+                self.assertIsNotNone(kl_fn_cls, f"{kl_fn_name} should be retrievable from registry")
+                self.assertTrue(
+                    issubclass(kl_fn_cls, KLFn), f"{kl_fn_name} should be a subclass of KLFn"
+                )
+        with self.assertRaises(ValueError):
+            KL_FN.get("non_existent_kl_fn")
+
+        # test policy loss function
+        policy_loss_fn_names = list(POLICY_LOSS_FN._default_mapping.keys())
+        for policy_loss_fn_name in policy_loss_fn_names:
+            with self.subTest(policy_loss_fn_name=policy_loss_fn_name):
+                policy_loss_fn_cls = POLICY_LOSS_FN.get(policy_loss_fn_name)
+                self.assertIsNotNone(
+                    policy_loss_fn_cls, f"{policy_loss_fn_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(policy_loss_fn_cls, PolicyLossFn),
+                    f"{policy_loss_fn_name} should be a subclass of PolicyLossFn",
+                )
+        with self.assertRaises(ValueError):
+            POLICY_LOSS_FN.get("non_existent_policy_loss_fn")
+
+        # test sample strategy
+        sample_strategy_names = list(SAMPLE_STRATEGY._default_mapping.keys())
+        for sample_strategy_name in sample_strategy_names:
+            with self.subTest(sample_strategy_name=sample_strategy_name):
+                sample_strategy_cls = SAMPLE_STRATEGY.get(sample_strategy_name)
+                self.assertIsNotNone(
+                    sample_strategy_cls,
+                    f"{sample_strategy_name} should be retrievable from registry",
+                )
+                self.assertTrue(
+                    issubclass(sample_strategy_cls, SampleStrategy),
+                    f"{sample_strategy_name} should be a subclass of SampleStrategy",
+                )
+        with self.assertRaises(ValueError):
+            SAMPLE_STRATEGY.get("non_existent_sample_strategy")
+
+    def test_buffer_module_registry_mapping(self):
+        """Test registry mapping in buffer module"""
+        # test experience operator
+        operator_names = list(EXPERIENCE_OPERATORS._default_mapping.keys())
+        for operator_name in operator_names:
+            with self.subTest(operator_name=operator_name):
+                operator_cls = EXPERIENCE_OPERATORS.get(operator_name)
+                self.assertIsNotNone(
+                    operator_cls, f"{operator_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(operator_cls, ExperienceOperator),
+                    f"{operator_name} should be a subclass of ExperienceOperator",
+                )
+        with self.assertRaises(ValueError):
+            EXPERIENCE_OPERATORS.get("non_existent_operator")
+
+        # test reader
+        reader_names = list(READER._default_mapping.keys())
+        for reader_name in reader_names:
+            with self.subTest(reader_name=reader_name):
+                reader_cls = READER.get(reader_name)
+                self.assertIsNotNone(
+                    reader_cls, f"{reader_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(reader_cls, BufferReader),
+                    f"{reader_name} should be a subclass of BufferReader",
+                )
+        with self.assertRaises(ValueError):
+            READER.get("non_existent_reader")
+
+        # test formatter
+        formatter_names = list(FORMATTER._default_mapping.keys())
+        for formatter_name in formatter_names:
+            with self.subTest(formatter_name=formatter_name):
+                formatter_cls = FORMATTER.get(formatter_name)
+                self.assertIsNotNone(
+                    formatter_cls, f"{formatter_name} should be retrievable from registry"
+                )
+        with self.assertRaises(ValueError):
+            FORMATTER.get("non_existent_formatter")
+
+        # test sql schema
+        schema_names = list(SQL_SCHEMA._default_mapping.keys())
+        for schema_name in schema_names:
+            with self.subTest(schema_name=schema_name):
+                schema_cls = SQL_SCHEMA.get(schema_name)
+                self.assertIsNotNone(
+                    schema_cls, f"{schema_name} should be retrievable from registry"
+                )
+        with self.assertRaises(ValueError):
+            SQL_SCHEMA.get("non_existent_schema")
+
+        # test selector
+        selector_names = list(SELECTORS._default_mapping.keys())
+        for selector_name in selector_names:
+            with self.subTest(selector_name=selector_name):
+                selector_cls = SELECTORS.get(selector_name)
+                self.assertIsNotNone(
+                    selector_cls, f"{selector_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(selector_cls, BaseSelector),
+                    f"{selector_name} should be a subclass of BaseSelector",
+                )
+        with self.assertRaises(ValueError):
+            SELECTORS.get("non_existent_selector")
+
+        # test priority function
+        priority_fn_names = list(PRIORITY_FUNC._default_mapping.keys())
+
+        for priority_fn_name in priority_fn_names:
+            with self.subTest(priority_fn_name=priority_fn_name):
+                priority_fn_cls = PRIORITY_FUNC.get(priority_fn_name)
+                self.assertIsNotNone(
+                    priority_fn_cls, f"{priority_fn_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(priority_fn_cls, PriorityFunction),
+                    f"{priority_fn_name} should be a subclass of PriorityFunction",
+                )
+        with self.assertRaises(ValueError):
+            PRIORITY_FUNC.get("non_existent_priority_fn")
+
+    def test_utils_module_registry_mapping(self):
+        """Test registry mapping in utils module"""
+        # test monitor
+        monitor_names = list(MONITOR._default_mapping.keys())
+        for monitor_name in monitor_names:
+            with self.subTest(monitor_name=monitor_name):
+                monitor_cls = MONITOR.get(monitor_name)
+                self.assertIsNotNone(
+                    monitor_cls, f"{monitor_name} should be retrievable from registry"
+                )
+                self.assertTrue(
+                    issubclass(monitor_cls, Monitor),
+                    f"{monitor_name} should be a subclass of Monitor",
+                )
+        with self.assertRaises(ValueError):
+            MONITOR.get("non_existent_monitor")
+
+    def test_register_module(self):
+        """Test register module functionality"""
+        # Test that the registered class can be retrieved from registry
+        retrieved_cls = ENTROPY_LOSS_FN.get("dummy_entropy_loss_fn")
+        self.assertIsNotNone(
+            retrieved_cls, "dummy_entropy_loss_fn should be retrievable from registry"
+        )
+        self.assertTrue(
+            issubclass(retrieved_cls, EntropyLossFn),
+            "dummy_entropy_loss_fn should be a subclass of EntropyLossFn",
+        )
+        self.assertEqual(
+            retrieved_cls, DummyEntropyLossFn, "Retrieved class should be DummyEntropyLossFn"
+        )
+
+        # Test that the registered class can be instantiated and used
+        instance = retrieved_cls(entropy_coef=0.1)
+        self.assertIsInstance(instance, EntropyLossFn)
+        self.assertEqual(instance.entropy_coef, 0.1)
+
+        # Test that the instance can be called (basic functionality)
+        loss, metrics = instance(entropy=torch.tensor(1.0), action_mask=torch.tensor(1.0))
+        self.assertEqual(loss.item(), 0.0)
+        self.assertIsInstance(metrics, dict)
diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py
index 667aa10d74..3b347b26e5 100644
--- a/trinity/algorithm/__init__.py
+++ b/trinity/algorithm/__init__.py
@@ -1,9 +1,34 @@
 from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn
-from trinity.algorithm.algorithm import ALGORITHM_TYPE, AlgorithmType
+from trinity.algorithm.algorithm import AlgorithmType
 from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn
 from trinity.algorithm.kl_fn import KL_FN, KLFn
 from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
 from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
+from trinity.utils.registry import Registry
+
+ALGORITHM_TYPE: Registry = Registry(
+    "algorithm",
+    default_mapping={
+        "sft": "trinity.algorithm.algorithm.SFTAlgorithm",
+        "ppo": "trinity.algorithm.algorithm.PPOAlgorithm",
+        "grpo": "trinity.algorithm.algorithm.GRPOAlgorithm",
+        "reinforceplusplus": "trinity.algorithm.algorithm.ReinforcePlusPlusAlgorithm",
+        "rloo": "trinity.algorithm.algorithm.RLOOAlgorithm",
+        "opmd": "trinity.algorithm.algorithm.OPMDAlgorithm",
+        "asymre": "trinity.algorithm.algorithm.AsymREAlgorithm",
+        "dpo": "trinity.algorithm.algorithm.DPOAlgorithm",
+        "topr": "trinity.algorithm.algorithm.TOPRAlgorithm",
+        "cispo": "trinity.algorithm.algorithm.CISPOAlgorithm",
+        "gspo": "trinity.algorithm.algorithm.GSPOAlgorithm",
+        "sapo": "trinity.algorithm.algorithm.SAPOAlgorithm",
+        "mix": "trinity.algorithm.algorithm.MIXAlgorithm",
+        "mix_chord": "trinity.algorithm.algorithm.MIXCHORDAlgorithm",
+        "raft": "trinity.algorithm.algorithm.RAFTAlgorithm",
+        "sppo": "trinity.algorithm.algorithm.sPPOAlgorithm",
+        "rec": "trinity.algorithm.algorithm.RECAlgorithm",
+        "multi_step_grpo": "trinity.algorithm.algorithm.MultiStepGRPOAlgorithm",
+    },
+)
 
 __all__ = [
     "ALGORITHM_TYPE",
diff --git a/trinity/algorithm/advantage_fn/__init__.py b/trinity/algorithm/advantage_fn/__init__.py
index f8349062f1..741a68fa6f 100644
--- a/trinity/algorithm/advantage_fn/__init__.py
+++ b/trinity/algorithm/advantage_fn/__init__.py
@@ -1,43 +1,27 @@
-from trinity.algorithm.advantage_fn.advantage_fn import (
-    ADVANTAGE_FN,
-    AdvantageFn,
-    GroupAdvantage,
-)
-from trinity.algorithm.advantage_fn.asymre_advantage import ASYMREAdvantageFn
-from trinity.algorithm.advantage_fn.grpo_advantage import (
-    GRPOAdvantageFn,
-    GRPOGroupedAdvantage,
-)
-from trinity.algorithm.advantage_fn.multi_step_grpo_advantage import (
-    StepWiseGRPOAdvantageFn,
-)
-from trinity.algorithm.advantage_fn.opmd_advantage import (
-    OPMDAdvantageFn,
-    OPMDGroupAdvantage,
-)
-from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn
-from trinity.algorithm.advantage_fn.rec_advantage import RECGroupedAdvantage
-from trinity.algorithm.advantage_fn.reinforce_advantage import REINFORCEGroupAdvantage
-from trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage import (
-    REINFORCEPLUSPLUSAdvantageFn,
+from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage
+from trinity.utils.registry import Registry
+
+ADVANTAGE_FN: Registry = Registry(
+    "advantage_fn",
+    default_mapping={
+        "ppo": "trinity.algorithm.advantage_fn.ppo_advantage.PPOAdvantageFn",
+        "grpo": "trinity.algorithm.advantage_fn.grpo_advantage.GRPOGroupedAdvantage",
+        "grpo_verl": "trinity.algorithm.advantage_fn.grpo_advantage.GRPOAdvantageFn",
+        "step_wise_grpo": "trinity.algorithm.advantage_fn.multi_step_grpo_advantage.StepWiseGRPOAdvantageFn",
+        "reinforceplusplus": "trinity.algorithm.advantage_fn.reinforce_plus_plus_advantage.REINFORCEPLUSPLUSAdvantageFn",
+        "reinforce": "trinity.algorithm.advantage_fn.reinforce_advantage.REINFORCEGroupAdvantage",
+        "remax": "trinity.algorithm.advantage_fn.remax_advantage.REMAXAdvantageFn",
+        "rloo": "trinity.algorithm.advantage_fn.rloo_advantage.RLOOAdvantageFn",
+        "opmd": "trinity.algorithm.advantage_fn.opmd_advantage.OPMDGroupAdvantage",
+        "opmd_verl": "trinity.algorithm.advantage_fn.opmd_advantage.OPMDAdvantageFn",
+        "asymre": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREGroupAdvantage",
+        "asymre_verl": "trinity.algorithm.advantage_fn.asymre_advantage.ASYMREAdvantageFn",
+        "rec": "trinity.algorithm.advantage_fn.rec_advantage.RECGroupedAdvantage",
+    },
 )
-from trinity.algorithm.advantage_fn.remax_advantage import REMAXAdvantageFn
-from trinity.algorithm.advantage_fn.rloo_advantage import RLOOAdvantageFn
 
 __all__ = [
     "ADVANTAGE_FN",
     "AdvantageFn",
     "GroupAdvantage",
-    "PPOAdvantageFn",
-    "GRPOAdvantageFn",
-    "GRPOGroupedAdvantage",
-    "StepWiseGRPOAdvantageFn",
-    "REINFORCEPLUSPLUSAdvantageFn",
-    "REMAXAdvantageFn",
-    "RLOOAdvantageFn",
-    "OPMDAdvantageFn",
-    "OPMDGroupAdvantage",
-    "REINFORCEGroupAdvantage",
-    "ASYMREAdvantageFn",
-    "RECGroupedAdvantage",
 ]
diff --git a/trinity/algorithm/advantage_fn/advantage_fn.py b/trinity/algorithm/advantage_fn/advantage_fn.py
index 6810feda72..6deee4d257 100644
--- a/trinity/algorithm/advantage_fn/advantage_fn.py
+++ b/trinity/algorithm/advantage_fn/advantage_fn.py
@@ -4,9 +4,6 @@
 from trinity.buffer.operators import ExperienceOperator
 from trinity.common.experience import Experience
 from trinity.utils.monitor import gather_metrics
-from trinity.utils.registry import Registry
-
-ADVANTAGE_FN = Registry("advantage_fn")
 
 
 class AdvantageFn(ABC):
diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py
index 81aa6659ff..52f2314272 100644
--- a/trinity/algorithm/advantage_fn/asymre_advantage.py
+++ b/trinity/algorithm/advantage_fn/asymre_advantage.py
@@ -6,17 +6,12 @@
 import torch
 from verl import DataProto
 
-from trinity.algorithm.advantage_fn.advantage_fn import (
-    ADVANTAGE_FN,
-    AdvantageFn,
-    GroupAdvantage,
-)
+from trinity.algorithm.advantage_fn import AdvantageFn, GroupAdvantage
 from trinity.common.experience import Experience, group_by
 from trinity.utils.annotations import Deprecated
 
 
 @Deprecated
-@ADVANTAGE_FN.register_module("asymre_verl")
 class ASYMREAdvantageFn(AdvantageFn):
     """AsymRE advantage computation"""
 
@@ -87,7 +82,6 @@ def default_args(cls) -> Dict:
     }
 
 
-@ADVANTAGE_FN.register_module("asymre")
 class ASYMREGroupAdvantage(GroupAdvantage):
     """asymre Group Advantage computation"""
diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py index a438c02869..7d1c58977d 100644 --- a/trinity/algorithm/advantage_fn/grpo_advantage.py +++ b/trinity/algorithm/advantage_fn/grpo_advantage.py @@ -8,18 +8,13 @@ import torch from verl import DataProto -from trinity.algorithm.advantage_fn.advantage_fn import ( - ADVANTAGE_FN, - AdvantageFn, - GroupAdvantage, -) +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage from trinity.common.experience import Experience, group_by from trinity.utils.annotations import Deprecated from trinity.utils.monitor import gather_metrics @Deprecated -@ADVANTAGE_FN.register_module("grpo_verl") class GRPOAdvantageFn(AdvantageFn): """GRPO advantage computation""" @@ -91,7 +86,6 @@ def default_args(cls) -> Dict: } -@ADVANTAGE_FN.register_module("grpo") class GRPOGroupedAdvantage(GroupAdvantage): """An advantage class that calculates GRPO advantages.""" diff --git a/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py b/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py index 3c11daf203..72dea0ac22 100644 --- a/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py +++ b/trinity/algorithm/advantage_fn/multi_step_grpo_advantage.py @@ -4,13 +4,12 @@ import torch -from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn from trinity.buffer.operators import ExperienceOperator from trinity.common.experience import Experience, group_by from trinity.utils.monitor import gather_metrics -@ADVANTAGE_FN.register_module("step_wise_grpo") class StepWiseGRPOAdvantageFn(AdvantageFn, ExperienceOperator): """ An advantage function that broadcasts advantages from the last step to previous steps. 
diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py index 82bea6c90d..d5e9203e3c 100644 --- a/trinity/algorithm/advantage_fn/opmd_advantage.py +++ b/trinity/algorithm/advantage_fn/opmd_advantage.py @@ -6,17 +6,12 @@ import torch from verl import DataProto -from trinity.algorithm.advantage_fn.advantage_fn import ( - ADVANTAGE_FN, - AdvantageFn, - GroupAdvantage, -) +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage from trinity.common.experience import Experience, group_by from trinity.utils.annotations import Deprecated @Deprecated -@ADVANTAGE_FN.register_module("opmd_verl") class OPMDAdvantageFn(AdvantageFn): """OPMD advantage computation""" @@ -103,7 +98,6 @@ def default_args(cls) -> Dict: } -@ADVANTAGE_FN.register_module("opmd") class OPMDGroupAdvantage(GroupAdvantage): """OPMD Group Advantage computation""" diff --git a/trinity/algorithm/advantage_fn/ppo_advantage.py b/trinity/algorithm/advantage_fn/ppo_advantage.py index 31fda4454c..99c787375c 100644 --- a/trinity/algorithm/advantage_fn/ppo_advantage.py +++ b/trinity/algorithm/advantage_fn/ppo_advantage.py @@ -8,11 +8,10 @@ import torch from verl import DataProto -from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn from trinity.algorithm.utils import masked_whiten -@ADVANTAGE_FN.register_module("ppo") class PPOAdvantageFn(AdvantageFn): def __init__( self, diff --git a/trinity/algorithm/advantage_fn/rec_advantage.py b/trinity/algorithm/advantage_fn/rec_advantage.py index 140e3975cb..0a99137c00 100644 --- a/trinity/algorithm/advantage_fn/rec_advantage.py +++ b/trinity/algorithm/advantage_fn/rec_advantage.py @@ -5,11 +5,10 @@ import torch -from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, GroupAdvantage +from trinity.algorithm.advantage_fn.advantage_fn import GroupAdvantage from trinity.common.experience import Experience, group_by -@ADVANTAGE_FN.register_module("rec") class RECGroupedAdvantage(GroupAdvantage): """An advantage class that calculates REC advantages.""" diff --git a/trinity/algorithm/advantage_fn/reinforce_advantage.py b/trinity/algorithm/advantage_fn/reinforce_advantage.py index 8c06451eda..bc8c6546b4 100644 --- a/trinity/algorithm/advantage_fn/reinforce_advantage.py +++ b/trinity/algorithm/advantage_fn/reinforce_advantage.py @@ -4,11 +4,10 @@ import torch -from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN, GroupAdvantage +from trinity.algorithm.advantage_fn.advantage_fn import GroupAdvantage from trinity.common.experience import Experience, group_by -@ADVANTAGE_FN.register_module("reinforce") class REINFORCEGroupAdvantage(GroupAdvantage): """Reinforce Group Advantage computation""" diff --git a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py index eb63c3605b..77e735cafa 100644 --- a/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py +++ b/trinity/algorithm/advantage_fn/reinforce_plus_plus_advantage.py @@ -8,11 +8,10 @@ import torch from verl import DataProto -from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn import AdvantageFn from trinity.algorithm.utils import masked_whiten -@ADVANTAGE_FN.register_module("reinforceplusplus") class REINFORCEPLUSPLUSAdvantageFn(AdvantageFn): def __init__(self, gamma: float = 1.0) -> None: self.gamma = gamma diff --git 
a/trinity/algorithm/advantage_fn/remax_advantage.py b/trinity/algorithm/advantage_fn/remax_advantage.py index 07f92d91a0..d71d7b8310 100644 --- a/trinity/algorithm/advantage_fn/remax_advantage.py +++ b/trinity/algorithm/advantage_fn/remax_advantage.py @@ -8,10 +8,9 @@ import torch from verl import DataProto -from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn import AdvantageFn -@ADVANTAGE_FN.register_module("remax") class REMAXAdvantageFn(AdvantageFn): def __init__(self) -> None: pass diff --git a/trinity/algorithm/advantage_fn/rloo_advantage.py b/trinity/algorithm/advantage_fn/rloo_advantage.py index 5cc079e687..71914114d1 100644 --- a/trinity/algorithm/advantage_fn/rloo_advantage.py +++ b/trinity/algorithm/advantage_fn/rloo_advantage.py @@ -9,10 +9,9 @@ import torch from verl import DataProto -from trinity.algorithm.advantage_fn import ADVANTAGE_FN, AdvantageFn +from trinity.algorithm.advantage_fn import AdvantageFn -@ADVANTAGE_FN.register_module("rloo") class RLOOAdvantageFn(AdvantageFn): def __init__(self) -> None: pass diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 93f76676f8..b08da64cdc 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -7,12 +7,9 @@ from trinity.common.config import Config from trinity.common.constants import SyncMethod from trinity.utils.log import get_logger -from trinity.utils.registry import Registry logger = get_logger(__name__) -ALGORITHM_TYPE = Registry("algorithm") - class ConstantMeta(ABCMeta): def __setattr__(cls, name, value): @@ -48,7 +45,6 @@ def check_config(cls, config: Config) -> None: pass -@ALGORITHM_TYPE.register_module("sft") class SFTAlgorithm(AlgorithmType): """SFT Algorithm.""" @@ -68,7 +64,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("ppo") class PPOAlgorithm(AlgorithmType): """PPO Algorithm.""" @@ -91,7 +86,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("grpo") class GRPOAlgorithm(AlgorithmType): """GRPO algorithm.""" @@ -114,7 +108,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("reinforceplusplus") class ReinforcePlusPlusAlgorithm(AlgorithmType): """Reinforce++ algorithm.""" @@ -137,7 +130,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("rloo") class RLOOAlgorithm(AlgorithmType): """RLOO algorithm.""" @@ -160,7 +152,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("opmd") class OPMDAlgorithm(AlgorithmType): """OPMD algorithm.""" @@ -183,7 +174,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("asymre") class AsymREAlgorithm(AlgorithmType): """AsymRE algorithm.""" @@ -206,7 +196,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("dpo") class DPOAlgorithm(AlgorithmType): """DPO algorithm.""" @@ -250,7 +239,6 @@ def check_config(cls, config: Config) -> None: logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`") -@ALGORITHM_TYPE.register_module("topr") class TOPRAlgorithm(AlgorithmType): """TOPR algorithm. See https://arxiv.org/pdf/2503.14286v1""" @@ -273,7 +261,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("cispo") class CISPOAlgorithm(AlgorithmType): """CISPO algorithm. See https://arxiv.org/abs/2506.13585""" @@ -296,7 +283,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("gspo") class GSPOAlgorithm(AlgorithmType): """GSPO algorithm. 
See https://arxiv.org/pdf/2507.18071""" @@ -319,7 +305,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("sapo") class SAPOAlgorithm(AlgorithmType): """SAPO (Soft Adaptive Policy Optimization) algorithm. @@ -346,7 +331,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("mix") class MIXAlgorithm(AlgorithmType): """MIX algorithm.""" @@ -368,7 +352,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("mix_chord") class MIXCHORDAlgorithm(AlgorithmType): """MIX algorithm.""" @@ -390,7 +373,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("raft") class RAFTAlgorithm(AlgorithmType): """RAFT Algorithm. This algorithm is conceptually similar to Supervised Fine-Tuning (SFT) @@ -413,7 +395,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("sppo") class sPPOAlgorithm(AlgorithmType): """sPPO Algorithm.""" @@ -436,7 +417,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("rec") class RECAlgorithm(AlgorithmType): """REC Algorithm.""" @@ -459,7 +439,6 @@ def default_config(cls) -> Dict: } -@ALGORITHM_TYPE.register_module("multi_step_grpo") class MultiStepGRPOAlgorithm(AlgorithmType): """Multi-Step GRPO Algorithm.""" diff --git a/trinity/algorithm/entropy_loss_fn/__init__.py b/trinity/algorithm/entropy_loss_fn/__init__.py index d932b94fde..0d3e1c9735 100644 --- a/trinity/algorithm/entropy_loss_fn/__init__.py +++ b/trinity/algorithm/entropy_loss_fn/__init__.py @@ -1,6 +1,13 @@ -from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ( - ENTROPY_LOSS_FN, - EntropyLossFn, +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import EntropyLossFn +from trinity.utils.registry import Registry + +ENTROPY_LOSS_FN: Registry = Registry( + "entropy_loss_fn", + default_mapping={ + "default": "trinity.algorithm.entropy_loss_fn.entropy_loss_fn.DefaultEntropyLossFn", + "mix": "trinity.algorithm.entropy_loss_fn.entropy_loss_fn.MixEntropyLossFn", + "none": "trinity.algorithm.entropy_loss_fn.entropy_loss_fn.DummyEntropyLossFn", + }, ) __all__ = [ diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index d5bf55dc1b..75069c3739 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -4,9 +4,6 @@ import torch from trinity.algorithm.utils import aggregate_loss -from trinity.utils.registry import Registry - -ENTROPY_LOSS_FN = Registry("entropy_loss_fn") class EntropyLossFn(ABC): @@ -40,7 +37,6 @@ def default_args(cls) -> Dict: return {"entropy_coef": 0.0} -@ENTROPY_LOSS_FN.register_module("default") class DefaultEntropyLossFn(EntropyLossFn): """ Basic entropy loss function. @@ -60,7 +56,6 @@ def __call__( return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()} -@ENTROPY_LOSS_FN.register_module("mix") class MixEntropyLossFn(EntropyLossFn): """ Basic entropy loss function for mix algorithm. @@ -88,7 +83,6 @@ def __call__( return entropy_loss * self.entropy_coef, {"entropy_loss": entropy_loss.detach().item()} -@ENTROPY_LOSS_FN.register_module("none") class DummyEntropyLossFn(EntropyLossFn): """ Dummy entropy loss function. 
diff --git a/trinity/algorithm/kl_fn/__init__.py b/trinity/algorithm/kl_fn/__init__.py
index 875c620442..2426f07991 100644
--- a/trinity/algorithm/kl_fn/__init__.py
+++ b/trinity/algorithm/kl_fn/__init__.py
@@ -1,3 +1,17 @@
-from trinity.algorithm.kl_fn.kl_fn import KL_FN, KLFn
+from trinity.algorithm.kl_fn.kl_fn import KLFn
+from trinity.utils.registry import Registry
+
+KL_FN: Registry = Registry(
+    "kl_fn",
+    default_mapping={
+        "none": "trinity.algorithm.kl_fn.kl_fn.DummyKLFn",
+        "k1": "trinity.algorithm.kl_fn.kl_fn.K1Fn",
+        "k2": "trinity.algorithm.kl_fn.kl_fn.K2Fn",
+        "k3": "trinity.algorithm.kl_fn.kl_fn.K3Fn",
+        "low_var_kl": "trinity.algorithm.kl_fn.kl_fn.LowVarKLFn",
+        "abs": "trinity.algorithm.kl_fn.kl_fn.AbsFn",
+        "corrected_k3": "trinity.algorithm.kl_fn.kl_fn.CorrectedK3Fn",
+    },
+)
 
 __all__ = ["KLFn", "KL_FN"]
diff --git a/trinity/algorithm/kl_fn/kl_fn.py b/trinity/algorithm/kl_fn/kl_fn.py
index 54e1639608..312f530329 100644
--- a/trinity/algorithm/kl_fn/kl_fn.py
+++ b/trinity/algorithm/kl_fn/kl_fn.py
@@ -12,9 +12,6 @@
 import torch
 
 from trinity.algorithm.utils import aggregate_loss, masked_mean
-from trinity.utils.registry import Registry
-
-KL_FN = Registry("kl_fn")
 
 
 class KLFn(ABC):
@@ -122,7 +119,6 @@ def default_args(cls):
         return {"adaptive": False, "kl_coef": 0.001}
 
 
-@KL_FN.register_module("none")
 class DummyKLFn(KLFn):
     """
     Dummy KL function.
@@ -152,7 +148,6 @@ def calculate_kl_loss(
         return torch.tensor(0.0), {}
 
 
-@KL_FN.register_module("k1")
 class K1Fn(KLFn):
     """
     KL K1 function.
@@ -167,7 +162,6 @@ def calculate_kl(
         return logprob - ref_logprob
 
 
-@KL_FN.register_module("k2")
 class K2Fn(KLFn):
     """
     KL K2 function.
@@ -182,7 +176,6 @@ def calculate_kl(
         return (logprob - ref_logprob).square() * 0.5
 
 
-@KL_FN.register_module("k3")
 class K3Fn(KLFn):
     """
    KL K3 function.
@@ -198,7 +191,6 @@ def calculate_kl(
         return logr.exp() - 1 - logr
 
 
-@KL_FN.register_module("low_var_kl")
 class LowVarKLFn(KLFn):
     """
     Low Variance KL function.
@@ -217,7 +209,6 @@ def calculate_kl(
         return torch.clamp(kld, min=-10, max=10)
 
 
-@KL_FN.register_module("abs")
 class AbsFn(KLFn):
     """
     KL Abs function.
@@ -232,7 +223,6 @@ def calculate_kl(
         return torch.abs(logprob - ref_logprob)
 
 
-@KL_FN.register_module("corrected_k3")
 class CorrectedK3Fn(KLFn):
     """
     Corrected K3 function with importance sampling.
diff --git a/trinity/algorithm/policy_loss_fn/__init__.py b/trinity/algorithm/policy_loss_fn/__init__.py
index 124a23fff3..5d30e9f4f6 100644
--- a/trinity/algorithm/policy_loss_fn/__init__.py
+++ b/trinity/algorithm/policy_loss_fn/__init__.py
@@ -1,36 +1,27 @@
-from trinity.algorithm.policy_loss_fn.chord_policy_loss import (
-    MIXCHORDPolicyLossFn,
-    SFTISLossFn,
-    SFTPhiLossFn,
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
+from trinity.utils.registry import Registry
+
+POLICY_LOSS_FN: Registry = Registry(
+    "policy_loss_fn",
+    default_mapping={
+        "ppo": "trinity.algorithm.policy_loss_fn.ppo_policy_loss.PPOPolicyLossFn",
+        "opmd": "trinity.algorithm.policy_loss_fn.opmd_policy_loss.OPMDPolicyLossFn",
+        "dpo": "trinity.algorithm.policy_loss_fn.dpo_loss.DPOLossFn",
+        "sft": "trinity.algorithm.policy_loss_fn.sft_loss.SFTLossFn",
+        "mix": "trinity.algorithm.policy_loss_fn.mix_policy_loss.MIXPolicyLossFn",
+        "gspo": "trinity.algorithm.policy_loss_fn.gspo_policy_loss.GSPOLossFn",
+        "topr": "trinity.algorithm.policy_loss_fn.topr_policy_loss.TOPRPolicyLossFn",
+        "cispo": "trinity.algorithm.policy_loss_fn.cispo_policy_loss.CISPOPolicyLossFn",
+        "sft_is": "trinity.algorithm.policy_loss_fn.chord_policy_loss.SFTISLossFn",
+        "sft_phi": "trinity.algorithm.policy_loss_fn.chord_policy_loss.SFTPhiLossFn",
+        "mix_chord": "trinity.algorithm.policy_loss_fn.chord_policy_loss.MIXCHORDPolicyLossFn",
+        "sppo": "trinity.algorithm.policy_loss_fn.sppo_loss_fn.sPPOPolicyLossFn",
+        "rec": "trinity.algorithm.policy_loss_fn.rec_policy_loss.RECPolicyLossFn",
+        "sapo": "trinity.algorithm.policy_loss_fn.sapo_policy_loss.SAPOPolicyLossFn",
+    },
 )
-from trinity.algorithm.policy_loss_fn.cispo_policy_loss import CISPOPolicyLossFn
-from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn
-from trinity.algorithm.policy_loss_fn.gspo_policy_loss import GSPOLossFn
-from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn
-from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn
-from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
-from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
-from trinity.algorithm.policy_loss_fn.rec_policy_loss import RECPolicyLossFn
-from trinity.algorithm.policy_loss_fn.sapo_policy_loss import SAPOPolicyLossFn
-from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
-from trinity.algorithm.policy_loss_fn.sppo_loss_fn import sPPOPolicyLossFn
-from trinity.algorithm.policy_loss_fn.topr_policy_loss import TOPRPolicyLossFn
 
 __all__ = [
     "POLICY_LOSS_FN",
     "PolicyLossFn",
-    "PPOPolicyLossFn",
-    "OPMDPolicyLossFn",
-    "DPOLossFn",
-    "SFTLossFn",
-    "MIXPolicyLossFn",
-    "GSPOLossFn",
-    "TOPRPolicyLossFn",
-    "CISPOPolicyLossFn",
-    "MIXCHORDPolicyLossFn",
-    "SFTISLossFn",
-    "SFTPhiLossFn",
-    "sPPOPolicyLossFn",
-    "RECPolicyLossFn",
-    "SAPOPolicyLossFn",
 ]
diff --git a/trinity/algorithm/policy_loss_fn/chord_policy_loss.py b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py
index 1a6a893041..6653b4d225 100644
--- a/trinity/algorithm/policy_loss_fn/chord_policy_loss.py
+++ b/trinity/algorithm/policy_loss_fn/chord_policy_loss.py
@@ -5,7 +5,7 @@
 import torch
 
-from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn
+from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn
 from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn
 from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn
 from
trinity.algorithm.utils import aggregate_loss @@ -31,7 +31,6 @@ def mu_schedule_function( return decayed_mu -@POLICY_LOSS_FN.register_module("sft_is") class SFTISLossFn(PolicyLossFn): """ SFT loss with importance sampling @@ -68,7 +67,6 @@ def phi_function(token_prob): return token_prob * (1 - token_prob) -@POLICY_LOSS_FN.register_module("sft_phi") class SFTPhiLossFn(PolicyLossFn): """ SFT loss with transformed phi function @@ -107,7 +105,6 @@ def default_args(cls): } -@POLICY_LOSS_FN.register_module("mix_chord") class MIXCHORDPolicyLossFn(PolicyLossFn): """Implements a mixed policy loss combining GRPO and SFT losses. diff --git a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py index 6da3d07526..45d4b2277b 100644 --- a/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/cispo_policy_loss.py @@ -6,11 +6,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import aggregate_loss, masked_mean -@POLICY_LOSS_FN.register_module("cispo") class CISPOPolicyLossFn(PolicyLossFn): def __init__( self, diff --git a/trinity/algorithm/policy_loss_fn/dpo_loss.py b/trinity/algorithm/policy_loss_fn/dpo_loss.py index 0858cb7002..d1b3a2122d 100644 --- a/trinity/algorithm/policy_loss_fn/dpo_loss.py +++ b/trinity/algorithm/policy_loss_fn/dpo_loss.py @@ -5,11 +5,10 @@ import torch import torch.nn.functional as F -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import masked_sum -@POLICY_LOSS_FN.register_module("dpo") class DPOLossFn(PolicyLossFn): def __init__( self, diff --git a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py index 14e76dc02b..5b30a6a6af 100644 --- a/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/gspo_policy_loss.py @@ -7,11 +7,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import aggregate_loss, masked_mean -@POLICY_LOSS_FN.register_module("gspo") class GSPOLossFn(PolicyLossFn): def __init__( self, diff --git a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py index ae3e3ffb84..dc5fcc65b5 100644 --- a/trinity/algorithm/policy_loss_fn/mix_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/mix_policy_loss.py @@ -4,12 +4,11 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn -@POLICY_LOSS_FN.register_module("mix") class MIXPolicyLossFn(PolicyLossFn): """Implements a mixed policy loss combining GRPO and SFT losses. 
diff --git a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py index b83a960bba..528e2e4ecf 100644 --- a/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/opmd_policy_loss.py @@ -4,11 +4,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import aggregate_loss -@POLICY_LOSS_FN.register_module("opmd") class OPMDPolicyLossFn(PolicyLossFn): def __init__( self, backend: str = "verl", tau: float = 1.0, loss_agg_mode: str = "token-mean" diff --git a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py index aa6025252e..8cf73d5dbc 100644 --- a/trinity/algorithm/policy_loss_fn/policy_loss_fn.py +++ b/trinity/algorithm/policy_loss_fn/policy_loss_fn.py @@ -5,9 +5,6 @@ import torch from trinity.algorithm.key_mapper import ALL_MAPPERS -from trinity.utils.registry import Registry - -POLICY_LOSS_FN = Registry("policy_loss_fn") class PolicyLossFnMeta(ABCMeta): diff --git a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py index d5e17a83e9..86f91c96a8 100644 --- a/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/ppo_policy_loss.py @@ -7,11 +7,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import aggregate_loss, masked_mean -@POLICY_LOSS_FN.register_module("ppo") class PPOPolicyLossFn(PolicyLossFn): def __init__( self, diff --git a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py index ec0e623c9f..5d7c547e62 100644 --- a/trinity/algorithm/policy_loss_fn/rec_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/rec_policy_loss.py @@ -5,11 +5,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import masked_mean -@POLICY_LOSS_FN.register_module("rec") class RECPolicyLossFn(PolicyLossFn): def __init__( self, diff --git a/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py b/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py index 7d5c4a9598..285a24e338 100644 --- a/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/sapo_policy_loss.py @@ -9,11 +9,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import aggregate_loss, masked_mean -@POLICY_LOSS_FN.register_module("sapo") class SAPOPolicyLossFn(PolicyLossFn): def __init__( self, diff --git a/trinity/algorithm/policy_loss_fn/sft_loss.py b/trinity/algorithm/policy_loss_fn/sft_loss.py index bd81d15380..e677e02613 100644 --- a/trinity/algorithm/policy_loss_fn/sft_loss.py +++ b/trinity/algorithm/policy_loss_fn/sft_loss.py @@ -4,11 +4,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import aggregate_loss 
-@POLICY_LOSS_FN.register_module("sft") class SFTLossFn(PolicyLossFn): def __init__(self, backend: str = "verl", loss_agg_mode: str = "token-mean") -> None: super().__init__(backend=backend) diff --git a/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py b/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py index 068a201c26..b9776815ad 100644 --- a/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py +++ b/trinity/algorithm/policy_loss_fn/sppo_loss_fn.py @@ -6,11 +6,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import aggregate_loss, masked_mean -@POLICY_LOSS_FN.register_module("sppo") class sPPOPolicyLossFn(PolicyLossFn): def __init__( self, diff --git a/trinity/algorithm/policy_loss_fn/topr_policy_loss.py b/trinity/algorithm/policy_loss_fn/topr_policy_loss.py index cb6500754a..a12a19e779 100644 --- a/trinity/algorithm/policy_loss_fn/topr_policy_loss.py +++ b/trinity/algorithm/policy_loss_fn/topr_policy_loss.py @@ -5,11 +5,10 @@ import torch -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.policy_loss_fn.policy_loss_fn import PolicyLossFn from trinity.algorithm.utils import aggregate_loss, masked_mean -@POLICY_LOSS_FN.register_module("topr") class TOPRPolicyLossFn(PolicyLossFn): def __init__( self, diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py index cd4b9e0d66..9e2700fb4a 100644 --- a/trinity/algorithm/sample_strategy/__init__.py +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -1,15 +1,16 @@ -from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy -from trinity.algorithm.sample_strategy.sample_strategy import ( - SAMPLE_STRATEGY, - DefaultSampleStrategy, - SampleStrategy, - WarmupSampleStrategy, +from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy +from trinity.utils.registry import Registry + +SAMPLE_STRATEGY: Registry = Registry( + "sample_strategy", + default_mapping={ + "default": "trinity.algorithm.sample_strategy.sample_strategy.DefaultSampleStrategy", + "warmup": "trinity.algorithm.sample_strategy.sample_strategy.WarmupSampleStrategy", + "mix": "trinity.algorithm.sample_strategy.mix_sample_strategy.MixSampleStrategy", + }, ) __all__ = [ "SAMPLE_STRATEGY", "SampleStrategy", - "DefaultSampleStrategy", - "WarmupSampleStrategy", - "MixSampleStrategy", ] diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 141213c342..65acc44d3c 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -4,10 +4,7 @@ import torch -from trinity.algorithm.sample_strategy.sample_strategy import ( - SAMPLE_STRATEGY, - SampleStrategy, -) +from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig @@ -15,7 +12,6 @@ from trinity.utils.timer import Timer -@SAMPLE_STRATEGY.register_module("mix") class MixSampleStrategy(SampleStrategy): """The default sample strategy.""" diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 27b021146f..e15c3e0b0b 100644 --- 
a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -7,11 +7,8 @@ from trinity.common.experience import Experience, Experiences from trinity.utils.annotations import Deprecated from trinity.utils.monitor import gather_metrics -from trinity.utils.registry import Registry from trinity.utils.timer import Timer -SAMPLE_STRATEGY = Registry("sample_strategy") - class SampleStrategy(ABC): def __init__(self, buffer_config: BufferConfig, **kwargs) -> None: @@ -52,7 +49,6 @@ def load_state_dict(self, state_dict: dict) -> None: """Load the state dict of the sample strategy.""" -@SAMPLE_STRATEGY.register_module("default") class DefaultSampleStrategy(SampleStrategy): def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) @@ -81,7 +77,6 @@ def load_state_dict(self, state_dict: dict) -> None: @Deprecated -@SAMPLE_STRATEGY.register_module("warmup") class WarmupSampleStrategy(DefaultSampleStrategy): """The warmup sample strategy. Deprecated, keep this class for backward compatibility only. diff --git a/trinity/buffer/operators/__init__.py b/trinity/buffer/operators/__init__.py index 4153c049b2..258d4d76b5 100644 --- a/trinity/buffer/operators/__init__.py +++ b/trinity/buffer/operators/__init__.py @@ -1,18 +1,18 @@ -from trinity.buffer.operators.data_juicer_operator import DataJuicerOperator -from trinity.buffer.operators.experience_operator import ( - EXPERIENCE_OPERATORS, - ExperienceOperator, +from trinity.buffer.operators.experience_operator import ExperienceOperator +from trinity.utils.registry import Registry + +EXPERIENCE_OPERATORS: Registry = Registry( + "experience_operators", + default_mapping={ + "reward_filter": "trinity.buffer.operators.filters.reward_filter.RewardFilter", + "reward_std_filter": "trinity.buffer.operators.filters.reward_filter.RewardSTDFilter", + "reward_shaping_mapper": "trinity.buffer.operators.mappers.reward_shaping_mapper.RewardShapingMapper", + "pass_rate_calculator": "trinity.buffer.operators.mappers.pass_rate_calculator.PassRateCalculator", + "data_juicer": "trinity.buffer.operators.data_juicer_operator.DataJuicerOperator", + }, ) -from trinity.buffer.operators.filters.reward_filter import RewardFilter, RewardSTDFilter -from trinity.buffer.operators.mappers.pass_rate_calculator import PassRateCalculator -from trinity.buffer.operators.mappers.reward_shaping_mapper import RewardShapingMapper __all__ = [ "ExperienceOperator", "EXPERIENCE_OPERATORS", - "RewardFilter", - "RewardSTDFilter", - "RewardShapingMapper", - "PassRateCalculator", - "DataJuicerOperator", ] diff --git a/trinity/buffer/operators/data_juicer_operator.py b/trinity/buffer/operators/data_juicer_operator.py index d651885fea..4287f79287 100644 --- a/trinity/buffer/operators/data_juicer_operator.py +++ b/trinity/buffer/operators/data_juicer_operator.py @@ -1,15 +1,11 @@ from typing import Dict, List, Optional, Tuple -from trinity.buffer.operators.experience_operator import ( - EXPERIENCE_OPERATORS, - ExperienceOperator, -) +from trinity.buffer.operators.experience_operator import ExperienceOperator from trinity.common.config import DataJuicerServiceConfig from trinity.common.experience import Experience from trinity.service.data_juicer.client import DataJuicerClient -@EXPERIENCE_OPERATORS.register_module("data_juicer") class DataJuicerOperator(ExperienceOperator): def __init__( self, diff --git a/trinity/buffer/operators/experience_operator.py b/trinity/buffer/operators/experience_operator.py index 
f53995f5d6..7834dcd2be 100644 --- a/trinity/buffer/operators/experience_operator.py +++ b/trinity/buffer/operators/experience_operator.py @@ -5,9 +5,6 @@ from trinity.common.config import OperatorConfig from trinity.common.experience import Experience -from trinity.utils.registry import Registry - -EXPERIENCE_OPERATORS = Registry("experience_operators") class ExperienceOperator(ABC): @@ -37,6 +34,9 @@ def create_operators(cls, operator_configs: List[OperatorConfig]) -> List[Experi Returns: List[ExperienceOperator]: List of instantiated ExperienceOperator objects. """ + # Import here to avoid circular import + from trinity.buffer.operators import EXPERIENCE_OPERATORS + operators = [] for config in operator_configs: operator_class = EXPERIENCE_OPERATORS.get(config.name) diff --git a/trinity/buffer/operators/filters/reward_filter.py b/trinity/buffer/operators/filters/reward_filter.py index 95e68960b8..dc5bd92e7e 100644 --- a/trinity/buffer/operators/filters/reward_filter.py +++ b/trinity/buffer/operators/filters/reward_filter.py @@ -2,11 +2,10 @@ import numpy as np -from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator +from trinity.buffer.operators import ExperienceOperator from trinity.common.experience import Experience, group_by -@EXPERIENCE_OPERATORS.register_module("reward_filter") class RewardFilter(ExperienceOperator): """ Filter experiences based on the reward value. @@ -24,7 +23,6 @@ def process(self, exps: List[Experience]) -> Tuple[List[Experience], dict]: return filtered_exps, metrics -@EXPERIENCE_OPERATORS.register_module("reward_std_filter") class RewardSTDFilter(ExperienceOperator): """ Filter experiences based on the standard deviation of rewards within each group. diff --git a/trinity/buffer/operators/mappers/pass_rate_calculator.py b/trinity/buffer/operators/mappers/pass_rate_calculator.py index a743c9c122..31dce63d17 100644 --- a/trinity/buffer/operators/mappers/pass_rate_calculator.py +++ b/trinity/buffer/operators/mappers/pass_rate_calculator.py @@ -3,15 +3,11 @@ import numpy as np -from trinity.buffer.operators.experience_operator import ( - EXPERIENCE_OPERATORS, - ExperienceOperator, -) +from trinity.buffer.operators.experience_operator import ExperienceOperator from trinity.common.constants import SELECTOR_METRIC from trinity.common.experience import Experience -@EXPERIENCE_OPERATORS.register_module("pass_rate_calculator") class PassRateCalculator(ExperienceOperator): def __init__(self, **kwargs): pass diff --git a/trinity/buffer/operators/mappers/reward_shaping_mapper.py b/trinity/buffer/operators/mappers/reward_shaping_mapper.py index 86098ea1bc..a1d56839ff 100644 --- a/trinity/buffer/operators/mappers/reward_shaping_mapper.py +++ b/trinity/buffer/operators/mappers/reward_shaping_mapper.py @@ -1,11 +1,10 @@ from typing import Dict, List, Optional, Tuple -from trinity.buffer.operators import EXPERIENCE_OPERATORS, ExperienceOperator +from trinity.buffer.operators import ExperienceOperator from trinity.common.constants import OpType from trinity.common.experience import Experience -@EXPERIENCE_OPERATORS.register_module("reward_shaping_mapper") class RewardShapingMapper(ExperienceOperator): """Re-shaping the existing rewards of experiences based on rules or other advanced methods. 
diff --git a/trinity/buffer/reader/__init__.py b/trinity/buffer/reader/__init__.py index 49d3bbe3d2..b6968a7158 100644 --- a/trinity/buffer/reader/__init__.py +++ b/trinity/buffer/reader/__init__.py @@ -1,6 +1,12 @@ -from trinity.buffer.reader.file_reader import FileReader -from trinity.buffer.reader.queue_reader import QueueReader -from trinity.buffer.reader.reader import READER -from trinity.buffer.reader.sql_reader import SQLReader +from trinity.utils.registry import Registry -__all__ = ["READER", "FileReader", "QueueReader", "SQLReader"] +READER = Registry( + "reader", + default_mapping={ + "file": "trinity.buffer.reader.file_reader.FileReader", + "queue": "trinity.buffer.reader.queue_reader.QueueReader", + "sql": "trinity.buffer.reader.sql_reader.SQLReader", + }, +) + +__all__ = ["READER"] diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 8fa3d4c03d..6cb35e58a7 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -6,8 +6,7 @@ from datasets import Dataset, load_dataset from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.reader.reader import READER -from trinity.buffer.schema.formatter import FORMATTER +from trinity.buffer.schema import FORMATTER from trinity.common.config import StorageConfig @@ -93,7 +92,6 @@ async def read_async(self, batch_size: Optional[int] = None): raise StopAsyncIteration from e -@READER.register_module("file") class FileReader(BaseFileReader): """Provide a unified interface for Experience and Task file readers.""" diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index e8debdada3..b743c20b8e 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -5,13 +5,11 @@ import ray from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.reader.reader import READER from trinity.buffer.storage.queue import QueueStorage from trinity.common.config import StorageConfig from trinity.common.constants import StorageType -@READER.register_module("queue") class QueueReader(BufferReader): """Reader of the Queue buffer.""" diff --git a/trinity/buffer/reader/reader.py b/trinity/buffer/reader/reader.py deleted file mode 100644 index 63da7b48e3..0000000000 --- a/trinity/buffer/reader/reader.py +++ /dev/null @@ -1,3 +0,0 @@ -from trinity.utils.registry import Registry - -READER = Registry("reader") diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index d13feeed7f..0d7943f8dd 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -5,13 +5,11 @@ import ray from trinity.buffer.buffer_reader import BufferReader -from trinity.buffer.reader.reader import READER from trinity.buffer.storage.sql import SQLStorage from trinity.common.config import StorageConfig from trinity.common.constants import StorageType -@READER.register_module("sql") class SQLReader(BufferReader): """Reader of the SQL buffer.""" diff --git a/trinity/buffer/schema/__init__.py b/trinity/buffer/schema/__init__.py index 5fdf581dbc..f9f60f3bcd 100644 --- a/trinity/buffer/schema/__init__.py +++ b/trinity/buffer/schema/__init__.py @@ -1,4 +1,23 @@ -from trinity.buffer.schema.formatter import FORMATTER from trinity.buffer.schema.sql_schema import init_engine +from trinity.utils.registry import Registry -__all__ = ["init_engine", "FORMATTER"] +FORMATTER: Registry = Registry( + "formatter", + { + "task": 
"trinity.buffer.schema.formatter.TaskFormatter", + "sft": "trinity.buffer.schema.formatter.SFTFormatter", + "dpo": "trinity.buffer.schema.formatter.DPOFormatter", + }, +) + +SQL_SCHEMA: Registry = Registry( + "sql_schema", + { + "task": "trinity.buffer.schema.sql_schema.TaskModel", + "experience": "trinity.buffer.schema.sql_schema.ExperienceModel", + "sft": "trinity.buffer.schema.sql_schema.SFTDataModel", + "dpo": "trinity.buffer.schema.sql_schema.DPODataModel", + }, +) + +__all__ = ["init_engine", "FORMATTER", "SQL_SCHEMA"] diff --git a/trinity/buffer/schema/formatter.py b/trinity/buffer/schema/formatter.py index 8e3fe4e69a..039074521b 100644 --- a/trinity/buffer/schema/formatter.py +++ b/trinity/buffer/schema/formatter.py @@ -9,11 +9,8 @@ from trinity.common.experience import Experience from trinity.common.models.utils import get_action_mask_method from trinity.common.rewards import REWARD_FUNCTIONS -from trinity.common.workflows.workflow import WORKFLOWS, Task +from trinity.common.workflows import WORKFLOWS, Task from trinity.utils.log import get_logger -from trinity.utils.registry import Registry - -FORMATTER = Registry("formatter") class ExperienceFormatter(ABC): @@ -22,7 +19,6 @@ def format(self, sample: Dict) -> Experience: """Format a raw sample dict into an experience.""" -@FORMATTER.register_module("task") class TaskFormatter: """Formatter for task data. @@ -69,7 +65,6 @@ def format(self, sample: Dict) -> Task: ) -@FORMATTER.register_module("sft") class SFTFormatter(ExperienceFormatter): """Formatter for SFT data, supporting both message list and plaintext formats. @@ -288,7 +283,6 @@ def format(self, sample: Dict) -> Experience: return self._messages_to_experience(messages, tools, mm_data) -@FORMATTER.register_module("dpo") class DPOFormatter(ExperienceFormatter): """Formatter for DPO plaintext data. 
diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index e0df0b1e8e..3a7ae3f105 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -9,14 +9,10 @@ from trinity.common.experience import Experience from trinity.utils.log import get_logger -from trinity.utils.registry import Registry - -SQL_SCHEMA = Registry("sql_schema") Base = declarative_base() -@SQL_SCHEMA.register_module("task") class TaskModel(Base): # type: ignore """Model for storing tasks in SQLAlchemy.""" @@ -30,7 +26,6 @@ def from_dict(cls, dict: Dict): return cls(raw_task=dict) -@SQL_SCHEMA.register_module("experience") class ExperienceModel(Base): # type: ignore """SQLAlchemy model for Experience.""" @@ -63,7 +58,6 @@ def from_experience(cls, experience: Experience): ) -@SQL_SCHEMA.register_module("sft") class SFTDataModel(Base): # type: ignore """SQLAlchemy model for SFT data.""" @@ -86,7 +80,6 @@ def from_experience(cls, experience: Experience): ) -@SQL_SCHEMA.register_module("dpo") class DPODataModel(Base): # type: ignore """SQLAlchemy model for DPO data.""" @@ -119,6 +112,8 @@ def init_engine(db_url: str, table_name, schema_type: Optional[str]) -> Tuple: if schema_type is None: schema_type = "task" + from trinity.buffer.schema import SQL_SCHEMA + base_class = SQL_SCHEMA.get(schema_type) table_attrs = { diff --git a/trinity/buffer/selector/__init__.py b/trinity/buffer/selector/__init__.py index 1b84348a3b..3e69c615aa 100644 --- a/trinity/buffer/selector/__init__.py +++ b/trinity/buffer/selector/__init__.py @@ -1,5 +1,18 @@ -from trinity.buffer.selector.selector import SELECTORS +from trinity.buffer.selector.selector import BaseSelector +from trinity.utils.registry import Registry + +SELECTORS = Registry( + "selectors", + default_mapping={ + "sequential": "trinity.buffer.selector.selector.SequentialSelector", + "shuffle": "trinity.buffer.selector.selector.ShuffleSelector", + "random": "trinity.buffer.selector.selector.RandomSelector", + "offline_easy2hard": "trinity.buffer.selector.selector.OfflineEasy2HardSelector", + "difficulty_based": "trinity.buffer.selector.selector.DifficultyBasedSelector", + }, +) __all__ = [ + "BaseSelector", "SELECTORS", ] diff --git a/trinity/buffer/selector/selector.py b/trinity/buffer/selector/selector.py index cc04a573ae..8782244f55 100644 --- a/trinity/buffer/selector/selector.py +++ b/trinity/buffer/selector/selector.py @@ -9,9 +9,6 @@ from trinity.common.config import TaskSelectorConfig from trinity.utils.annotations import Experimental from trinity.utils.log import get_logger -from trinity.utils.registry import Registry - -SELECTORS = Registry("selectors") @Experimental @@ -78,7 +75,6 @@ def load_state_dict(self, state_dict: Dict) -> None: raise NotImplementedError -@SELECTORS.register_module("sequential") class SequentialSelector(BaseSelector): """ Selects data sequentially in fixed order across epochs. @@ -112,7 +108,6 @@ def load_state_dict(self, state_dict): self.current_index = state_dict.get("current_index", 0) -@SELECTORS.register_module("shuffle") class ShuffleSelector(BaseSelector): """ Shuffles dataset once per epoch and iterates through it sequentially. @@ -169,7 +164,6 @@ def load_state_dict(self, state_dict): self.orders = self._get_orders() -@SELECTORS.register_module("random") class RandomSelector(BaseSelector): """ Uniformly samples batches randomly with replacement *per batch*. 
@@ -207,7 +201,6 @@ def load_state_dict(self, state_dict): self.current_index = state_dict.get("current_index", 0) -@SELECTORS.register_module("offline_easy2hard") class OfflineEasy2HardSelector(BaseSelector): """ Selects samples in an 'easy-to-hard' curriculum based on pre-defined difficulty features. @@ -292,7 +285,6 @@ def load_state_dict(self, state_dict): self.current_index = state_dict.get("current_index", 0) -@SELECTORS.register_module("difficulty_based") class DifficultyBasedSelector(BaseSelector): """ Adaptive difficulty-based selector using probabilistic modeling of sample difficulty. diff --git a/trinity/buffer/storage/__init__.py b/trinity/buffer/storage/__init__.py index e69de29bb2..d03fbd9f82 100644 --- a/trinity/buffer/storage/__init__.py +++ b/trinity/buffer/storage/__init__.py @@ -0,0 +1,15 @@ +from trinity.buffer.storage.queue import PriorityFunction +from trinity.utils.registry import Registry + +PRIORITY_FUNC = Registry( + "priority_fn", + default_mapping={ + "linear_decay": "trinity.buffer.storage.queue.LinearDecayPriority", + "decay_limit_randomization": "trinity.buffer.storage.queue.LinearDecayUseCountControlPriority", + }, +) + +__all__ = [ + "PriorityFunction", + "PRIORITY_FUNC", +] diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 523cde5b18..3f1c7268b6 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -15,7 +15,6 @@ from trinity.common.constants import StorageType from trinity.common.experience import Experience from trinity.utils.log import get_logger -from trinity.utils.registry import Registry def is_database_url(path: str) -> bool: @@ -26,9 +25,6 @@ def is_json_file(path: str) -> bool: return path.endswith(".json") or path.endswith(".jsonl") -PRIORITY_FUNC = Registry("priority_fn") - - class PriorityFunction(ABC): """ Each priority_fn, @@ -53,7 +49,6 @@ def default_config(cls) -> Dict: """Return the default config.""" -@PRIORITY_FUNC.register_module("linear_decay") class LinearDecayPriority(PriorityFunction): """Calculate priority by linear decay. @@ -75,7 +70,6 @@ def default_config(cls) -> Dict: } -@PRIORITY_FUNC.register_module("decay_limit_randomization") class LinearDecayUseCountControlPriority(PriorityFunction): """Calculate priority by linear decay, use count control, and randomization. @@ -198,6 +192,8 @@ def __init__( priority_fn (`str`): Name of the function to use for determining item priority. kwargs: Additional keyword arguments for the priority function. 
""" + from trinity.buffer.storage import PRIORITY_FUNC + self.capacity = capacity self.item_count = 0 self.priority_groups = SortedDict() # Maps priority -> deque of items diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index e3068fd896..fb21373cda 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -9,8 +9,8 @@ from sqlalchemy import asc, desc from sqlalchemy.orm import sessionmaker -from trinity.buffer.schema import init_engine -from trinity.buffer.schema.formatter import FORMATTER, TaskFormatter +from trinity.buffer.schema import FORMATTER, init_engine +from trinity.buffer.schema.formatter import TaskFormatter from trinity.buffer.utils import retry_session from trinity.common.config import StorageConfig from trinity.common.experience import Experience diff --git a/trinity/common/config.py b/trinity/common/config.py index 7c5a9c5a56..137dc2d066 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -844,7 +844,6 @@ def _update_config_from_ray_cluster(self) -> None: """Update config if `node_num` or `gpu_per_node` are not set.""" if self.cluster.node_num is not None and self.cluster.gpu_per_node is not None: return - # init ray cluster to detect node_num and gpu_per_node was_initialized = ray.is_initialized() if not was_initialized: @@ -991,7 +990,7 @@ def _check_trainer_input(self) -> None: f"Auto set `buffer.trainer_input.experience_buffer.path` to {experience_buffer.path}" ) - from trinity.algorithm.algorithm import ALGORITHM_TYPE + from trinity.algorithm import ALGORITHM_TYPE experience_buffer.schema_type = ALGORITHM_TYPE.get(self.algorithm.algorithm_type).schema experience_buffer.batch_size = self.buffer.train_batch_size @@ -1103,12 +1102,12 @@ def _check_buffer(self) -> None: # noqa: C901 def _check_algorithm(self) -> None: from trinity.algorithm import ( ADVANTAGE_FN, + ALGORITHM_TYPE, ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN, SAMPLE_STRATEGY, ) - from trinity.algorithm.algorithm import ALGORITHM_TYPE algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type) algorithm.check_config(self) diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 810b0d0d55..15220b18dc 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -3,7 +3,6 @@ import ray import torch import torch.distributed -from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader from trinity.common.models.vllm_patch.worker_patch import patch_vllm_prompt_logprobs from trinity.manager.synchronizer import Synchronizer @@ -14,6 +13,8 @@ class WorkerExtension: def apply_patches(self): """Apply necessary patches to vLLM.""" + from verl.utils.vllm_utils import patch_vllm_moe_model_weight_loader + patch_vllm_moe_model_weight_loader(self.model_runner.model) patch_vllm_prompt_logprobs(self.model_runner) diff --git a/trinity/common/rewards/__init__.py b/trinity/common/rewards/__init__.py index ad36b8103a..c129b6b3bf 100644 --- a/trinity/common/rewards/__init__.py +++ b/trinity/common/rewards/__init__.py @@ -1,25 +1,24 @@ # -*- coding: utf-8 -*- """Reward functions for RFT""" -# isort: off -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn, RMGalleryFn +from trinity.common.rewards.reward_fn import RewardFn +from trinity.utils.registry import Registry -from trinity.common.rewards.accuracy_reward import AccuracyReward -from trinity.common.rewards.countdown_reward import CountDownRewardFn -from trinity.common.rewards.dapo_reward import MathDAPORewardFn 
-from trinity.common.rewards.format_reward import FormatReward -from trinity.common.rewards.math_reward import MathBoxedRewardFn, MathRewardFn - -# isort: on +REWARD_FUNCTIONS = Registry( + "reward_functions", + default_mapping={ + "rm_gallery_reward": "trinity.common.rewards.reward_fn.RMGalleryFn", + "math_reward": "trinity.common.rewards.math_reward.MathRewardFn", + "math_boxed_reward": "trinity.common.rewards.math_reward.MathBoxedRewardFn", + "format_reward": "trinity.common.rewards.format_reward.FormatReward", + "countdown_reward": "trinity.common.rewards.countdown_reward.CountDownRewardFn", + "accuracy_reward": "trinity.common.rewards.accuracy_reward.AccuracyReward", + "math_dapo_reward": "trinity.common.rewards.dapo_reward.MathDAPORewardFn", + }, +) __all__ = [ "RewardFn", "RMGalleryFn", "REWARD_FUNCTIONS", - "AccuracyReward", - "CountDownRewardFn", - "FormatReward", - "MathRewardFn", - "MathBoxedRewardFn", - "MathDAPORewardFn", ] diff --git a/trinity/common/rewards/accuracy_reward.py b/trinity/common/rewards/accuracy_reward.py index 9bc1b14e5c..196102cdcf 100644 --- a/trinity/common/rewards/accuracy_reward.py +++ b/trinity/common/rewards/accuracy_reward.py @@ -6,11 +6,10 @@ from math_verify import LatexExtractionConfig from trinity.common.rewards.eval_utils import parse_with_timeout, verify_with_timeout -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.common.rewards.reward_fn import RewardFn from trinity.utils.log import get_logger -@REWARD_FUNCTIONS.register_module("accuracy_reward") class AccuracyReward(RewardFn): """A reward function that rewards correct answers. Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py diff --git a/trinity/common/rewards/countdown_reward.py b/trinity/common/rewards/countdown_reward.py index 1ec20ea68a..9f63353a34 100644 --- a/trinity/common/rewards/countdown_reward.py +++ b/trinity/common/rewards/countdown_reward.py @@ -7,10 +7,9 @@ extract_solution, validate_equation, ) -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.common.rewards.reward_fn import RewardFn -@REWARD_FUNCTIONS.register_module("countdown_reward") class CountDownRewardFn(RewardFn): """A reward function that rewards for countdown task. 
Ref: Jiayi-Pan/TinyZero verl/utils/reward_score/countdown.py diff --git a/trinity/common/rewards/dapo_reward.py b/trinity/common/rewards/dapo_reward.py index 85438e4333..ff2783e223 100644 --- a/trinity/common/rewards/dapo_reward.py +++ b/trinity/common/rewards/dapo_reward.py @@ -5,10 +5,9 @@ import torch from trinity.common.rewards.naive_dapo_score import compute_score -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.common.rewards.reward_fn import RewardFn -@REWARD_FUNCTIONS.register_module("math_dapo_reward") class MathDAPORewardFn(RewardFn): """A reward function that follows the definition in DAPO for math task.""" diff --git a/trinity/common/rewards/format_reward.py b/trinity/common/rewards/format_reward.py index ad9b203f82..bc4835e5be 100644 --- a/trinity/common/rewards/format_reward.py +++ b/trinity/common/rewards/format_reward.py @@ -3,10 +3,9 @@ import re from typing import Optional -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.common.rewards.reward_fn import RewardFn -@REWARD_FUNCTIONS.register_module("format_reward") class FormatReward(RewardFn): """A reward function that checks if the reasoning process is enclosed within and tags, while the final answer is enclosed within and tags. Ref: https://github.com/huggingface/open-r1/blob/main/src/open_r1/rewards.py diff --git a/trinity/common/rewards/math_reward.py b/trinity/common/rewards/math_reward.py index b6b65416f6..3018261e73 100644 --- a/trinity/common/rewards/math_reward.py +++ b/trinity/common/rewards/math_reward.py @@ -9,10 +9,9 @@ validate_think_pattern, ) from trinity.common.rewards.format_reward import FormatReward -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn +from trinity.common.rewards.reward_fn import RewardFn -@REWARD_FUNCTIONS.register_module("math_reward") class MathRewardFn(RewardFn): """A reward function that rewards for math task.""" @@ -40,7 +39,6 @@ def __call__( # type: ignore return {**accuracy_score, **format_score} -@REWARD_FUNCTIONS.register_module("math_boxed_reward") class MathBoxedRewardFn(RewardFn): """A reward function that rewards for math task.""" diff --git a/trinity/common/rewards/naive_dapo_score.py b/trinity/common/rewards/naive_dapo_score.py index 8f7ba2c73d..fd791c854d 100644 --- a/trinity/common/rewards/naive_dapo_score.py +++ b/trinity/common/rewards/naive_dapo_score.py @@ -12,8 +12,6 @@ import sympy from pylatexenc import latex2text from sympy.parsing import sympy_parser -from verl.utils.reward_score.prime_math import math_normalize -from verl.utils.reward_score.prime_math.grader import math_equal # Constants for normalization SUBSTITUTIONS = [ @@ -376,6 +374,7 @@ def grade_answer(given_answer: str, ground_truth: str) -> tuple[bool, str]: """ if given_answer is None: return False + from verl.utils.reward_score.prime_math import math_normalize ground_truth_normalized_mathd = math_normalize.normalize_answer(ground_truth) given_answer_normalized_mathd = math_normalize.normalize_answer(given_answer) @@ -478,6 +477,8 @@ def compute_score(solution_str: str, ground_truth: str) -> float: Returns: Reward score (1.0 for correct, 0.0 for incorrect) """ + from verl.utils.reward_score.prime_math.grader import math_equal + # First assert intended generation and gt type model_output = str(solution_str) ground_truth = str(ground_truth) diff --git a/trinity/common/rewards/reward_fn.py b/trinity/common/rewards/reward_fn.py index bcf9f97f8a..5eacc98914 100644 --- a/trinity/common/rewards/reward_fn.py +++ 
b/trinity/common/rewards/reward_fn.py @@ -5,9 +5,6 @@ from trinity.common.experience import Experience from trinity.common.rewards.utils import to_rm_gallery_messages -from trinity.utils.registry import Registry - -REWARD_FUNCTIONS = Registry("reward_functions") class RewardFn(ABC): @@ -22,7 +19,6 @@ def __call__(self, **kwargs) -> Dict[str, float]: pass -@REWARD_FUNCTIONS.register_module("rm_gallery_reward") class RMGalleryFn(RewardFn): """Reward Function from RMGallery. https://github.com/modelscope/RM-Gallery diff --git a/trinity/common/workflows/__init__.py b/trinity/common/workflows/__init__.py index 847a3d8707..4aeca67f5b 100644 --- a/trinity/common/workflows/__init__.py +++ b/trinity/common/workflows/__init__.py @@ -1,99 +1,54 @@ # -*- coding: utf-8 -*- """Workflow module""" -from trinity.common.workflows.agentscope.react.react_workflow import ( - AgentScopeReActWorkflow, -) -from trinity.common.workflows.agentscope_workflow import AgentScopeWorkflowAdapter -from trinity.common.workflows.customized_math_workflows import ( - AsyncMathBoxedWorkflow, - MathBoxedWorkflow, -) -from trinity.common.workflows.customized_toolcall_workflows import ToolCallWorkflow -from trinity.common.workflows.envs.agentscope.agentscopev0_react_workflow import ( # will be deprecated soon - AgentScopeV0ReactMathWorkflow, -) -from trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow import ( - AgentScopeReactMathWorkflow, -) -from trinity.common.workflows.envs.agentscope.agentscopev1_search_workflow import ( - AgentScopeV1ReactSearchWorkflow, -) -from trinity.common.workflows.envs.alfworld.alfworld_workflow import ( - AlfworldWorkflow, - StepWiseAlfworldWorkflow, -) -from trinity.common.workflows.envs.alfworld.RAFT_alfworld_workflow import ( - RAFTAlfworldWorkflow, -) -from trinity.common.workflows.envs.alfworld.RAFT_reflect_alfworld_workflow import ( - RAFTReflectAlfworldWorkflow, -) -from trinity.common.workflows.envs.email_searcher.workflow import EmailSearchWorkflow -from trinity.common.workflows.envs.frozen_lake.workflow import FrozenLakeWorkflow -from trinity.common.workflows.envs.sciworld.sciworld_workflow import SciWorldWorkflow -from trinity.common.workflows.envs.webshop.webshop_workflow import WebShopWorkflow -from trinity.common.workflows.eval_workflow import ( - AsyncMathEvalWorkflow, - MathEvalWorkflow, -) -from trinity.common.workflows.math_rm_workflow import ( - AsyncMathRMWorkflow, - MathRMWorkflow, -) -from trinity.common.workflows.math_ruler_workflow import ( - AsyncMathRULERWorkflow, - MathRULERWorkflow, -) -from trinity.common.workflows.math_trainable_ruler_workflow import ( - MathTrainableRULERWorkflow, -) -from trinity.common.workflows.rubric_judge_workflow import RubricJudgeWorkflow -from trinity.common.workflows.simple_mm_workflow import ( - AsyncSimpleMMWorkflow, - SimpleMMWorkflow, -) -from trinity.common.workflows.workflow import ( - WORKFLOWS, - AsyncMathWorkflow, - AsyncSimpleWorkflow, - MathWorkflow, - SimpleWorkflow, - Task, - Workflow, +from trinity.common.workflows.workflow import Task, Workflow +from trinity.utils.registry import Registry + +WORKFLOWS: Registry = Registry( + "workflows", + default_mapping={ + # simple/math + "simple_workflow": "trinity.common.workflows.workflow.SimpleWorkflow", + "async_simple_workflow": "trinity.common.workflows.workflow.AsyncSimpleWorkflow", + "math_workflow": "trinity.common.workflows.workflow.MathWorkflow", + "async_math_workflow": "trinity.common.workflows.workflow.AsyncMathWorkflow", + "math_boxed_workflow": 
"trinity.common.workflows.customized_math_workflows.MathBoxedWorkflow", + "async_math_boxed_workflow": "trinity.common.workflows.customized_math_workflows.AsyncMathBoxedWorkflow", + "math_eval_workflow": "trinity.common.workflows.eval_workflow.MathEvalWorkflow", + "async_math_eval_workflow": "trinity.common.workflows.eval_workflow.AsyncMathEvalWorkflow", + "math_rm_workflow": "trinity.common.workflows.math_rm_workflow.MathRMWorkflow", + "async_math_rm_workflow": "trinity.common.workflows.math_rm_workflow.AsyncMathRMWorkflow", + # tool_call + "tool_call_workflow": "trinity.common.workflows.customized_toolcall_workflows.ToolCallWorkflow", + # agentscope + "agentscope_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow", + "agentscope_workflow_adapter": "trinity.common.workflows.agentscope_workflow.AgentScopeWorkflowAdapter", + "agentscope_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev1_react_workflow.AgentScopeReactMathWorkflow", + "as_react_workflow": "trinity.common.workflows.agentscope.react.react_workflow.AgentScopeReActWorkflow", + "agentscopev0_react_math_workflow": "trinity.common.workflows.envs.agentscope.agentscopev0_react_workflow.AgentScopeV0ReactMathWorkflow", + "agentscope_v1_react_search_workflow": "trinity.common.workflows.envs.agentscope.agentscopev1_search_workflow.AgentScopeV1ReactSearchWorkflow", + "email_search_workflow": "trinity.common.workflows.envs.email_searcher.workflow.EmailSearchWorkflow", + # concatenated multi-turn + "alfworld_workflow": "trinity.common.workflows.envs.alfworld.alfworld_workflow.AlfworldWorkflow", + "RAFT_alfworld_workflow": "trinity.common.workflows.envs.alfworld.RAFT_alfworld_workflow.RAFTAlfworldWorkflow", + "RAFT_reflect_alfworld_workflow": "trinity.common.workflows.envs.alfworld.RAFT_reflect_alfworld_workflow.RAFTReflectAlfworldWorkflow", + "frozen_lake_workflow": "trinity.common.workflows.envs.frozen_lake.workflow.FrozenLakeWorkflow", + "sciworld_workflow": "trinity.common.workflows.envs.sciworld.sciworld_workflow.SciWorldWorkflow", + "webshop_workflow": "trinity.common.workflows.envs.webshop.webshop_workflow.WebShopWorkflow", + # general multi-turn + "step_wise_alfworld_workflow": "trinity.common.workflows.envs.alfworld.alfworld_workflow.StepWiseAlfworldWorkflow", + # ruler/llm_as_a_judge + "async_math_ruler_workflow": "trinity.common.workflows.math_ruler_workflow.AsyncMathRULERWorkflow", + "math_ruler_workflow": "trinity.common.workflows.math_ruler_workflow.MathRULERWorkflow", + "math_trainable_ruler_workflow": "trinity.common.workflows.math_trainable_ruler_workflow.MathTrainableRULERWorkflow", + "rubric_judge_workflow": "trinity.common.workflows.rubric_judge_workflow.RubricJudgeWorkflow", + # others + "simple_mm_workflow": "trinity.common.workflows.simple_mm_workflow.SimpleMMWorkflow", + "async_simple_mm_workflow": "trinity.common.workflows.simple_mm_workflow.AsyncSimpleMMWorkflow", + }, ) __all__ = [ "Task", "Workflow", "WORKFLOWS", - "AsyncSimpleWorkflow", - "SimpleWorkflow", - "AsyncMathWorkflow", - "MathWorkflow", - "WebShopWorkflow", - "AlfworldWorkflow", - "StepWiseAlfworldWorkflow", - "RAFTAlfworldWorkflow", - "RAFTReflectAlfworldWorkflow", - "SciWorldWorkflow", - "AsyncMathBoxedWorkflow", - "MathBoxedWorkflow", - "AsyncMathRMWorkflow", - "MathRMWorkflow", - "ToolCallWorkflow", - "AsyncMathEvalWorkflow", - "MathEvalWorkflow", - "AgentScopeV0ReactMathWorkflow", # will be deprecated soon - "AgentScopeReactMathWorkflow", - "AgentScopeV1ReactSearchWorkflow", - 
"AgentScopeReActWorkflow", - "EmailSearchWorkflow", - "AsyncMathRULERWorkflow", - "MathRULERWorkflow", - "MathTrainableRULERWorkflow", - "AsyncSimpleMMWorkflow", - "SimpleMMWorkflow", - "RubricJudgeWorkflow", - "AgentScopeWorkflowAdapter", - "FrozenLakeWorkflow", ] diff --git a/trinity/common/workflows/agentscope/react/react_workflow.py b/trinity/common/workflows/agentscope/react/react_workflow.py index a6dbca28e3..f4c8c4375d 100644 --- a/trinity/common/workflows/agentscope/react/react_workflow.py +++ b/trinity/common/workflows/agentscope/react/react_workflow.py @@ -9,12 +9,11 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow +from trinity.common.workflows.workflow import Task, Workflow from .templates import TEMPLATE_MAP -@WORKFLOWS.register_module("as_react_workflow") class AgentScopeReActWorkflow(Workflow): is_async: bool = True diff --git a/trinity/common/workflows/agentscope/react/templates.py b/trinity/common/workflows/agentscope/react/templates.py index 754ce7eb2b..0b2ba2fd3f 100644 --- a/trinity/common/workflows/agentscope/react/templates.py +++ b/trinity/common/workflows/agentscope/react/templates.py @@ -3,7 +3,8 @@ from pydantic import BaseModel, Field -from trinity.common.rewards import MathBoxedRewardFn, RewardFn +from trinity.common.rewards import RewardFn +from trinity.common.rewards.math_reward import MathBoxedRewardFn # For GSM8K task GSM8KSystemPrompt = """You are an agent specialized in solving math problems with tools. Please solve the math problem given to you. You can write and execute Python code to perform calculation or verify your answer. You should return your final answer within \\boxed{{}}.""" @@ -22,7 +23,7 @@ def __call__( # type: ignore [override] truth: str, format_score_coef: float = 0.1, **kwargs, - ) -> dict[str, float]: + ) -> Dict[str, float]: # parse GSM8K truth if isinstance(truth, str) and "####" in truth: truth = truth.split("####")[1].strip() diff --git a/trinity/common/workflows/agentscope_workflow.py b/trinity/common/workflows/agentscope_workflow.py index 03f3689d72..8fc28fbe32 100644 --- a/trinity/common/workflows/agentscope_workflow.py +++ b/trinity/common/workflows/agentscope_workflow.py @@ -4,10 +4,9 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow +from trinity.common.workflows.workflow import Task, Workflow -@WORKFLOWS.register_module("agentscope_workflow_adapter") class AgentScopeWorkflowAdapter(Workflow): """Adapter to wrap a agentscope trainable workflow function into a Trinity Workflow.""" diff --git a/trinity/common/workflows/customized_math_workflows.py b/trinity/common/workflows/customized_math_workflows.py index 65bed1ac0d..47b33ff3f9 100644 --- a/trinity/common/workflows/customized_math_workflows.py +++ b/trinity/common/workflows/customized_math_workflows.py @@ -5,10 +5,9 @@ from trinity.common.experience import Experience from trinity.common.rewards.math_reward import MathBoxedRewardFn -from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.common.workflows.workflow import SimpleWorkflow, Task -@WORKFLOWS.register_module("math_boxed_workflow") class MathBoxedWorkflow(SimpleWorkflow): """A workflow for math tasks that give answers in boxed format.""" @@ -93,7 +92,6 @@ def run(self) -> List[Experience]: return responses 
-@WORKFLOWS.register_module("async_math_boxed_workflow") class AsyncMathBoxedWorkflow(MathBoxedWorkflow): is_async: bool = True diff --git a/trinity/common/workflows/customized_toolcall_workflows.py b/trinity/common/workflows/customized_toolcall_workflows.py index 2da3daa281..938c027b03 100644 --- a/trinity/common/workflows/customized_toolcall_workflows.py +++ b/trinity/common/workflows/customized_toolcall_workflows.py @@ -11,7 +11,7 @@ from typing import List from trinity.common.experience import Experience -from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.common.workflows.workflow import SimpleWorkflow, Task # Adapted from https://github.com/NVlabs/Tool-N1 qwen_tool_prompts = """# Tool @@ -207,7 +207,6 @@ def compute_toolcall_reward( return float(res[0]) -@WORKFLOWS.register_module("toolcall_workflow") class ToolCallWorkflow(SimpleWorkflow): """ A workflow for toolcall tasks. diff --git a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py index a44acedf1f..fab5403dcd 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev0_react_workflow.py @@ -7,12 +7,11 @@ from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathBoxedRewardFn -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow +from trinity.common.workflows.workflow import Task, Workflow from trinity.utils.annotations import Deprecated @Deprecated -@WORKFLOWS.register_module("agentscopev0_react_math_workflow") class AgentScopeV0ReactMathWorkflow(Workflow): """ This workflow serves as an example of how to use the agentscope framework within the trinity workflow. diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py index dbc73150bc..c8cc0dc155 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_react_workflow.py @@ -7,10 +7,9 @@ from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathBoxedRewardFn -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow +from trinity.common.workflows.workflow import Task, Workflow -@WORKFLOWS.register_module("agentscope_react_math_workflow") class AgentScopeReactMathWorkflow(Workflow): """ This workflow serves as an example of how to use the agentscope framework within the trinity workflow. diff --git a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py index 2585113303..8d2231b0e7 100644 --- a/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py +++ b/trinity/common/workflows/envs/agentscope/agentscopev1_search_workflow.py @@ -8,10 +8,9 @@ import openai from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow +from trinity.common.workflows.workflow import Task, Workflow -@WORKFLOWS.register_module("agentscope_v1_react_search_workflow") class AgentScopeV1ReactSearchWorkflow(Workflow): """ This workflow serves as an example of how to use the agentscope framework within the trinity workflow. 
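Taken together, the hunks above swap decorator-based registration for the lazily resolved `default_mapping` on `WORKFLOWS`. As a rough usage sketch (based only on the mapping above and the `Registry.get` changes later in this patch), looking a workflow up by its short name should now trigger the import of just that workflow's module:

```python
# Usage sketch: resolving a built-in workflow through the new lazy WORKFLOWS
# registry. The key/path pair comes from the default_mapping added above; the
# target module is only imported when the key is first requested.
from trinity.common.workflows import WORKFLOWS

# "math_boxed_workflow" is mapped to
# trinity.common.workflows.customized_math_workflows.MathBoxedWorkflow
workflow_cls = WORKFLOWS.get("math_boxed_workflow")
print(workflow_cls.__name__)  # MathBoxedWorkflow
```

Most mapping keys reuse the names previously passed to `register_module`, so existing configuration files should keep resolving to the same classes; one visible exception is `ToolCallWorkflow`, whose removed decorator used `toolcall_workflow` while the new mapping registers it as `tool_call_workflow`.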
diff --git a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py index e9e3f30206..15050dc0eb 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_alfworld_workflow.py @@ -13,10 +13,9 @@ process_messages_to_experience, validate_trajectory_format, ) -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow +from trinity.common.workflows.workflow import Task, Workflow -@WORKFLOWS.register_module("RAFT_alfworld_workflow") class RAFTAlfworldWorkflow(Workflow): """ RAFT workflow for alfworld using trajectory context. diff --git a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py index 81aed995d9..589fb9a8e6 100644 --- a/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/RAFT_reflect_alfworld_workflow.py @@ -18,10 +18,9 @@ save_task_data, validate_trajectory_format, ) -from trinity.common.workflows.workflow import WORKFLOWS, Task +from trinity.common.workflows.workflow import Task -@WORKFLOWS.register_module("RAFT_reflect_alfworld_workflow") class RAFTReflectAlfworldWorkflow(RAFTAlfworldWorkflow): """ RAFT workflow for alfworld using trajectory context. diff --git a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py index 2ae075a68c..7388e156a1 100644 --- a/trinity/common/workflows/envs/alfworld/alfworld_workflow.py +++ b/trinity/common/workflows/envs/alfworld/alfworld_workflow.py @@ -4,7 +4,7 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.workflows.step_wise_workflow import RewardPropagationWorkflow -from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task +from trinity.common.workflows.workflow import MultiTurnWorkflow, Task EXAMPLE_PROMPT = """ Observation: @@ -93,7 +93,6 @@ def parse_action(response): return "" -@WORKFLOWS.register_module("alfworld_workflow") class AlfworldWorkflow(MultiTurnWorkflow): """A workflow for alfworld task.""" @@ -176,7 +175,6 @@ def create_environment(game_file): return await self.generate_env_inference_samples(env) -@WORKFLOWS.register_module("step_wise_alfworld_workflow") class StepWiseAlfworldWorkflow(RewardPropagationWorkflow): """ An Alfworld workflow refactored to use the RewardPropagationWorkflow base class. diff --git a/trinity/common/workflows/envs/email_searcher/workflow.py b/trinity/common/workflows/envs/email_searcher/workflow.py index 0df651c72d..0a243ef507 100644 --- a/trinity/common/workflows/envs/email_searcher/workflow.py +++ b/trinity/common/workflows/envs/email_searcher/workflow.py @@ -11,7 +11,7 @@ QueryModel, judge_correctness, ) -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow +from trinity.common.workflows.workflow import Task, Workflow SYSTEM_PROMPT = """You are an email search agent. You are given a user query and a list of tools you can use to search the user's email. Use the tools to search the user's emails and find the answer to the user's query. You may take up to {max_turns} turns to find the answer, so if your first seach doesn't find the answer, you can try with different keywords. Always describe what you see and plan your next steps clearly. When taking actions, explain what you're doing and why. 
When the answer to the task is found, call `generate_response` to finish the process. Only call `generate_response` when answer is found. You should not respond any next steps in `generate_response`. Complete all steps and then call `generate_response`. @@ -21,7 +21,6 @@ """ -@WORKFLOWS.register_module("email_search_workflow") class EmailSearchWorkflow(Workflow): """ Multi-turn Email Search workflow (ReAct-style tool use). diff --git a/trinity/common/workflows/envs/frozen_lake/workflow.py b/trinity/common/workflows/envs/frozen_lake/workflow.py index c7a13c17fd..604b50282d 100644 --- a/trinity/common/workflows/envs/frozen_lake/workflow.py +++ b/trinity/common/workflows/envs/frozen_lake/workflow.py @@ -22,10 +22,9 @@ generate_random_map, get_goal_position, ) -from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task +from trinity.common.workflows.workflow import MultiTurnWorkflow, Task -@WORKFLOWS.register_module("frozen_lake_workflow") class FrozenLakeWorkflow(MultiTurnWorkflow): """ FrozenLake environment for multi-step workflows. diff --git a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py index 4c2a7417a9..f0862b6f38 100644 --- a/trinity/common/workflows/envs/sciworld/sciworld_workflow.py +++ b/trinity/common/workflows/envs/sciworld/sciworld_workflow.py @@ -4,7 +4,7 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task +from trinity.common.workflows.workflow import MultiTurnWorkflow, Task SCIWORLD_SYSTEM_PROMPT = """ You are an agent, you job is to do some scientific experiment in a virtual test-based environments. @@ -55,7 +55,6 @@ def parse_action(response): return "" -@WORKFLOWS.register_module("sciworld_workflow") class SciWorldWorkflow(MultiTurnWorkflow): """A workflow for sciworld task.""" diff --git a/trinity/common/workflows/envs/webshop/webshop_workflow.py b/trinity/common/workflows/envs/webshop/webshop_workflow.py index 4cef8fc456..48a0f464f7 100644 --- a/trinity/common/workflows/envs/webshop/webshop_workflow.py +++ b/trinity/common/workflows/envs/webshop/webshop_workflow.py @@ -3,7 +3,7 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, MultiTurnWorkflow, Task +from trinity.common.workflows.workflow import MultiTurnWorkflow, Task SPARSE_REWARD = False @@ -177,7 +177,6 @@ def validate_action(action, available_actions): ) -@WORKFLOWS.register_module("webshop_workflow") class WebShopWorkflow(MultiTurnWorkflow): """A workflow for webshop task.""" diff --git a/trinity/common/workflows/eval_workflow.py b/trinity/common/workflows/eval_workflow.py index 837bc12881..0fcb115b6f 100644 --- a/trinity/common/workflows/eval_workflow.py +++ b/trinity/common/workflows/eval_workflow.py @@ -10,10 +10,9 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.qwen25_eval import verify_math_answer -from trinity.common.workflows.workflow import WORKFLOWS, Task, Workflow +from trinity.common.workflows.workflow import Task, Workflow -@WORKFLOWS.register_module("math_eval_workflow") class MathEvalWorkflow(Workflow): """ A workflow for standard math evaluation. 
@@ -79,7 +78,6 @@ def run(self) -> List[Experience]: return responses -@WORKFLOWS.register_module("async_math_eval_workflow") class AsyncMathEvalWorkflow(MathEvalWorkflow): is_async: bool = True diff --git a/trinity/common/workflows/math_rm_workflow.py b/trinity/common/workflows/math_rm_workflow.py index aca5586717..a5213e0034 100644 --- a/trinity/common/workflows/math_rm_workflow.py +++ b/trinity/common/workflows/math_rm_workflow.py @@ -7,10 +7,9 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.common.workflows.workflow import SimpleWorkflow, Task -@WORKFLOWS.register_module("math_rm_workflow") class MathRMWorkflow(SimpleWorkflow): """A workflow for math tasks as introduced in DeepSeek-R1.""" @@ -53,7 +52,6 @@ def run(self) -> List[Experience]: return responses -@WORKFLOWS.register_module("async_math_rm_workflow") class AsyncMathRMWorkflow(MathRMWorkflow): is_async: bool = True diff --git a/trinity/common/workflows/math_ruler_workflow.py b/trinity/common/workflows/math_ruler_workflow.py index fec1bc72a4..42848dd0d7 100644 --- a/trinity/common/workflows/math_ruler_workflow.py +++ b/trinity/common/workflows/math_ruler_workflow.py @@ -8,10 +8,9 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathRewardFn -from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.common.workflows.workflow import SimpleWorkflow, Task -@WORKFLOWS.register_module("math_ruler_workflow") class MathRULERWorkflow(SimpleWorkflow): """A workflow for math with RULER reward function. @@ -158,7 +157,6 @@ def get_ruler_scores( return False, [0.0 for _ in range(num_responses)] -@WORKFLOWS.register_module("async_math_ruler_workflow") class AsyncMathRULERWorkflow(MathRULERWorkflow): is_async: bool = True diff --git a/trinity/common/workflows/math_trainable_ruler_workflow.py b/trinity/common/workflows/math_trainable_ruler_workflow.py index a60c8d7093..d43cbe4aed 100644 --- a/trinity/common/workflows/math_trainable_ruler_workflow.py +++ b/trinity/common/workflows/math_trainable_ruler_workflow.py @@ -10,13 +10,12 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.math_reward import MathRewardFn -from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.common.workflows.workflow import SimpleWorkflow, Task # the probability that the ground truth is assumed to be available for RL PROBABILITY_GROUND_TRUTH_AVAILABLE = 0.2 -@WORKFLOWS.register_module("math_trainable_ruler_workflow") class MathTrainableRULERWorkflow(SimpleWorkflow): """A workflow for math, where the policy model itself serves as a RULER reward model. Modified from `MathRULERWorkflow`. 
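With the built-in decorators removed, an out-of-tree workflow has two plausible integration paths under this patch. The sketch below is illustrative only: `MyPluginWorkflow` and its dotted path are invented, and it assumes the `register_module` decorator used by the deleted lines above remains part of the `Registry` API (these hunks do not touch it).

```python
# Hypothetical plugin workflow; the class name and the dotted path in the final
# comment are examples, not part of this patch.
from typing import List

from trinity.common.experience import Experience
from trinity.common.workflows import WORKFLOWS, Workflow


@WORKFLOWS.register_module("my_plugin_workflow")  # assumed-unchanged decorator API
class MyPluginWorkflow(Workflow):
    def run(self) -> List[Experience]:
        # A real workflow would call the model and build Experience objects here.
        return []


# Alternatively, skip registration and reference the class by its full dotted
# path; the patched Registry.get falls back to importing "pkg.module.Class":
# WORKFLOWS.get("my_plugins.workflows.MyPluginWorkflow")
```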
diff --git a/trinity/common/workflows/rubric_judge_workflow.py b/trinity/common/workflows/rubric_judge_workflow.py index 12b73366b7..2311803cac 100644 --- a/trinity/common/workflows/rubric_judge_workflow.py +++ b/trinity/common/workflows/rubric_judge_workflow.py @@ -7,10 +7,9 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper -from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.common.workflows.workflow import SimpleWorkflow, Task -@WORKFLOWS.register_module("rubric_judge_workflow") class RubricJudgeWorkflow(SimpleWorkflow): """A workflow using LLM-as-a-judge and rubrics to get the reward. diff --git a/trinity/common/workflows/simple_mm_workflow.py b/trinity/common/workflows/simple_mm_workflow.py index 2bb3f857e2..97044e0cf1 100644 --- a/trinity/common/workflows/simple_mm_workflow.py +++ b/trinity/common/workflows/simple_mm_workflow.py @@ -5,10 +5,9 @@ from trinity.common.experience import Experience from trinity.common.models.model import ModelWrapper from trinity.common.rewards.reward_fn import RewardFn -from trinity.common.workflows.workflow import WORKFLOWS, SimpleWorkflow, Task +from trinity.common.workflows.workflow import SimpleWorkflow, Task -@WORKFLOWS.register_module("simple_mm_workflow") class SimpleMMWorkflow(SimpleWorkflow): """A workflow for simple single-round task.""" @@ -80,7 +79,6 @@ def run(self) -> List[Experience]: return responses -@WORKFLOWS.register_module("async_simple_mm_workflow") class AsyncSimpleMMWorkflow(SimpleMMWorkflow): is_async: bool = True diff --git a/trinity/common/workflows/workflow.py b/trinity/common/workflows/workflow.py index 90f52a2505..7d9e16e689 100644 --- a/trinity/common/workflows/workflow.py +++ b/trinity/common/workflows/workflow.py @@ -4,19 +4,17 @@ from __future__ import annotations from dataclasses import asdict, dataclass, field -from typing import Any, List, Optional, Type, Union - -import openai +from typing import TYPE_CHECKING, Any, List, Optional, Type, Union from trinity.common.config import FormatConfig, GenerationConfig from trinity.common.experience import Experience -from trinity.common.models.model import ModelWrapper -from trinity.common.rewards.math_reward import MathRewardFn from trinity.common.rewards.reward_fn import RewardFn from trinity.utils.log import get_logger -from trinity.utils.registry import Registry -WORKFLOWS = Registry("workflows") +if TYPE_CHECKING: + import openai + + from trinity.common.models.model import ModelWrapper @dataclass @@ -250,7 +248,6 @@ def format_messages(self): return messages -@WORKFLOWS.register_module("simple_workflow") class SimpleWorkflow(BaseSimpleWorkflow): """A workflow for simple single-round task.""" @@ -282,7 +279,6 @@ def run(self) -> List[Experience]: return responses -@WORKFLOWS.register_module("async_simple_workflow") class AsyncSimpleWorkflow(BaseSimpleWorkflow): is_async: bool = True @@ -311,7 +307,6 @@ async def run_async(self) -> List[Experience]: return responses -@WORKFLOWS.register_module("math_workflow") class MathWorkflow(SimpleWorkflow): """A workflow for math tasks as introduced in DeepSeek-R1.""" @@ -330,6 +325,8 @@ def __init__( ) def reset(self, task: Task): + from trinity.common.rewards.math_reward import MathRewardFn + if task.reward_fn is None: task.reward_fn = MathRewardFn if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None: @@ -341,6 +338,5 @@ def reset(self, task: Task): super().reset(task) 
-@WORKFLOWS.register_module("async_math_workflow") class AsyncMathWorkflow(AsyncSimpleWorkflow, MathWorkflow): pass diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index d4b705e85a..0f74cc6af5 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -7,18 +7,18 @@ import streamlit as st import yaml -from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN -from trinity.algorithm.algorithm import ALGORITHM_TYPE -from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN -from trinity.algorithm.kl_fn.kl_fn import KL_FN -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN -from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY +from trinity.algorithm import ALGORITHM_TYPE +from trinity.algorithm.advantage_fn import ADVANTAGE_FN +from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN +from trinity.algorithm.kl_fn import KL_FN +from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN +from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY from trinity.common.constants import StorageType +from trinity.manager.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.buffer_config_manager import ( get_train_batch_size, parse_priority_fn_args, ) -from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.trainer_config_manager import use_critic from trinity.utils.plugin_loader import load_plugins diff --git a/trinity/manager/config_registry/algorithm_config_manager.py b/trinity/manager/config_registry/algorithm_config_manager.py index ca1270ab36..da3f94e969 100644 --- a/trinity/manager/config_registry/algorithm_config_manager.py +++ b/trinity/manager/config_registry/algorithm_config_manager.py @@ -1,26 +1,23 @@ import streamlit as st -from trinity.algorithm.advantage_fn import ( - ADVANTAGE_FN, - GRPOAdvantageFn, - OPMDAdvantageFn, - PPOAdvantageFn, -) -from trinity.algorithm.algorithm import ALGORITHM_TYPE, GRPOAlgorithm -from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ( - ENTROPY_LOSS_FN, - EntropyLossFn, -) -from trinity.algorithm.kl_fn.kl_fn import KL_FN, KLFn -from trinity.algorithm.policy_loss_fn import ( - POLICY_LOSS_FN, - DPOLossFn, - MIXPolicyLossFn, - OPMDPolicyLossFn, - PPOPolicyLossFn, - SFTLossFn, -) -from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, MixSampleStrategy +from trinity.algorithm import ALGORITHM_TYPE +from trinity.algorithm.advantage_fn import ADVANTAGE_FN +from trinity.algorithm.advantage_fn.grpo_advantage import GRPOAdvantageFn +from trinity.algorithm.advantage_fn.opmd_advantage import OPMDAdvantageFn +from trinity.algorithm.advantage_fn.ppo_advantage import PPOAdvantageFn +from trinity.algorithm.algorithm import GRPOAlgorithm +from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import EntropyLossFn +from trinity.algorithm.kl_fn import KL_FN +from trinity.algorithm.kl_fn.kl_fn import KLFn +from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN +from trinity.algorithm.policy_loss_fn.dpo_loss import DPOLossFn +from trinity.algorithm.policy_loss_fn.mix_policy_loss import MIXPolicyLossFn +from trinity.algorithm.policy_loss_fn.opmd_policy_loss import OPMDPolicyLossFn +from trinity.algorithm.policy_loss_fn.ppo_policy_loss import PPOPolicyLossFn +from trinity.algorithm.policy_loss_fn.sft_loss import SFTLossFn +from 
trinity.algorithm.sample_strategy import SAMPLE_STRATEGY +from trinity.algorithm.sample_strategy.mix_sample_strategy import MixSampleStrategy from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num from trinity.utils.registry import Registry diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py index 92351095f6..06cf45ab59 100644 --- a/trinity/manager/config_registry/buffer_config_manager.py +++ b/trinity/manager/config_registry/buffer_config_manager.py @@ -3,10 +3,10 @@ import pandas as pd import streamlit as st -from trinity.buffer.storage.queue import PRIORITY_FUNC +from trinity.buffer.storage import PRIORITY_FUNC from trinity.common.constants import PromptType, StorageType -from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS -from trinity.common.workflows.workflow import WORKFLOWS +from trinity.common.rewards import REWARD_FUNCTIONS +from trinity.common.workflows import WORKFLOWS from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py index 060851c154..3ca9ff3284 100644 --- a/trinity/manager/config_registry/trainer_config_manager.py +++ b/trinity/manager/config_registry/trainer_config_manager.py @@ -1,6 +1,6 @@ import streamlit as st -from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.algorithm import ALGORITHM_TYPE from trinity.manager.config_registry.buffer_config_manager import ( get_train_batch_size_per_gpu, ) diff --git a/trinity/trainer/verl/megatron_actor.py b/trinity/trainer/verl/megatron_actor.py index 48849c9037..105cc5456c 100644 --- a/trinity/trainer/verl/megatron_actor.py +++ b/trinity/trainer/verl/megatron_actor.py @@ -42,12 +42,8 @@ from verl.workers.megatron_workers import MegatronPPOActor as OldMegatronPPOActor from verl.workers.megatron_workers import logger -from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ( - ENTROPY_LOSS_FN, - DummyEntropyLossFn, -) -from trinity.algorithm.kl_fn.kl_fn import KL_FN -from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN +from trinity.algorithm import ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn from trinity.algorithm.utils import prefix_metrics from trinity.common.config import AlgorithmConfig diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 603f54c98e..16c0525327 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -30,8 +30,7 @@ from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.metric import reduce_metrics -from trinity.algorithm import ADVANTAGE_FN, KL_FN, SAMPLE_STRATEGY -from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN, SAMPLE_STRATEGY from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config from trinity.common.constants import SaveStrategy diff --git a/trinity/utils/lora_utils.py b/trinity/utils/lora_utils.py index 3e881016aa..bddce2684c 100644 --- a/trinity/utils/lora_utils.py +++ b/trinity/utils/lora_utils.py @@ -1,8 +1,3 @@ -import torch -from peft import LoraConfig, TaskType, get_peft_model -from transformers import AutoConfig, AutoModelForCausalLM - - def 
create_dummy_lora( model_path: str, checkpoint_job_dir: str, @@ -10,6 +5,10 @@ def create_dummy_lora( lora_alpha: int, target_modules: str, ) -> str: + import torch + from peft import LoraConfig, TaskType, get_peft_model + from transformers import AutoConfig, AutoModelForCausalLM + config = AutoConfig.from_pretrained(model_path) model = AutoModelForCausalLM.from_config(config) lora_config = { diff --git a/trinity/utils/monitor.py b/trinity/utils/monitor.py index 4ddfadcf9d..73b64229fb 100644 --- a/trinity/utils/monitor.py +++ b/trinity/utils/monitor.py @@ -22,7 +22,14 @@ from trinity.utils.log import get_logger from trinity.utils.registry import Registry -MONITOR = Registry("monitor") +MONITOR = Registry( + "monitor", + default_mapping={ + "tensorboard": "trinity.utils.monitor.TensorboardMonitor", + "wandb": "trinity.utils.monitor.WandbMonitor", + "mlflow": "trinity.utils.monitor.MlflowMonitor", + }, +) def gather_metrics( @@ -98,7 +105,6 @@ def default_args(cls) -> Dict: return {} -@MONITOR.register_module("tensorboard") class TensorboardMonitor(Monitor): def __init__( self, project: str, group: str, name: str, role: str, config: Config = None @@ -121,7 +127,6 @@ def close(self) -> None: self.logger.close() -@MONITOR.register_module("wandb") class WandbMonitor(Monitor): """Monitor with Weights & Biases. @@ -172,7 +177,6 @@ def default_args(cls) -> Dict: } -@MONITOR.register_module("mlflow") class MlflowMonitor(Monitor): """Monitor with MLflow. diff --git a/trinity/utils/registry.py b/trinity/utils/registry.py index d7c0858c8f..10e8aabcfc 100644 --- a/trinity/utils/registry.py +++ b/trinity/utils/registry.py @@ -7,13 +7,15 @@ class Registry(object): """A class for registry.""" - def __init__(self, name: str): + def __init__(self, name: str, default_mapping: dict = {}): """ Args: name (`str`): The name of the registry. + default_mapping (`dict`): Default mapping from module names to module paths (strings). """ self._name = name self._modules = {} + self._default_mapping = default_mapping self.logger = get_logger() @property @@ -49,8 +51,19 @@ def get(self, module_key) -> Any: """ module = self._modules.get(module_key, None) if module is None: - # try to dynamic import - if isinstance(module_key, str) and "." in module_key: + # try to get from default mapping + if module_key in self._default_mapping: + module_path, class_name = self._default_mapping[module_key].rsplit(".", 1) + try: + module = self._dynamic_import(module_path, class_name) + except Exception: + self.logger.error( + f"Failed to dynamically import {class_name} from {module_path}:\n" + + traceback.format_exc() + ) + raise ImportError(f"Cannot dynamically import {class_name} from {module_path}") + # try to get from string path + elif isinstance(module_key, str) and "." in module_key: module_path, class_name = module_key.rsplit(".", 1) try: module = self._dynamic_import(module_path, class_name) @@ -61,6 +74,11 @@ def get(self, module_key) -> Any: ) raise ImportError(f"Cannot dynamically import {class_name} from {module_path}") self._register_module(module_name=module_key, module_cls=module) + elif module_key is None: + self.logger.info("Empty module key, return None") + return None + else: + raise ValueError(f"Invalid module key: {module_key}") return module def _register_module(self, module_name=None, module_cls=None, force=False):
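For readability, the lookup order implemented by the patched `Registry.get` can be paraphrased as a small standalone function. This is a sketch, assuming `_dynamic_import` behaves like `importlib.import_module` plus `getattr`; the real method additionally logs failures and, at least for dotted-path keys, caches the result through `_register_module`.

```python
# Standalone paraphrase of the patched Registry.get resolution order.
import importlib
from typing import Any, Optional


def resolve(modules: dict, default_mapping: dict, key: Optional[str]) -> Any:
    if key in modules:                       # 1) explicitly registered module
        return modules[key]
    if key in default_mapping:               # 2) known short name -> lazy import
        path, cls_name = default_mapping[key].rsplit(".", 1)
        return getattr(importlib.import_module(path), cls_name)
    if isinstance(key, str) and "." in key:  # 3) raw "pkg.module.Class" key
        path, cls_name = key.rsplit(".", 1)
        return getattr(importlib.import_module(path), cls_name)
    if key is None:                          # 4) empty key resolves to nothing
        return None
    raise ValueError(f"Invalid module key: {key}")
```

Combined with the deferred `peft`/`transformers` imports in `lora_utils.py` and the deferred `verl` imports in `naive_dapo_score.py`, the apparent goal is that importing a Trinity package no longer drags in every optional dependency: a class (and whatever it imports) is loaded only when its registry key is actually requested.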