diff --git a/README.md b/README.md index 0df3f8fc80..797b71ec0f 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ Trinity-RFT provides functionalities for users with different backgrounds and ob ## 🚀 News +* [2025-12] Trinity-RFT has supported [tinker](https://thinkingmachines.ai/tinker/) training backend, which enables model training on devices **without GPUs**. * [2025-12] Trinity-RFT powers the medical and health business of "Taobao Shangou", enabling the AI agent to understand vague symptoms, proactively ask follow-up questions, and provide precise recommendations ([News](https://tech.china.com.cn/sx/20251201/411376.shtml)). * [2025-11] [[Release Notes](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.3)] Trinity-RFT v0.3.3 released: bug fixes. * [2025-11] Introducing [Learn-to-Ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask): a framework for training proactive dialogue agents from offline expert data ([paper](https://arxiv.org/pdf/2510.25441)). @@ -154,6 +155,10 @@ We list some algorithms supported by Trinity-RFT in the following table. For mor > [!NOTE] > This project is currently under active development. Comments and suggestions are welcome! +> +> **No GPU? No problem!** You can still try it out: +> 1. Follow the installation steps (feel free to skip GPU-specific packages like `flash-attn`) +> 2. Run the **[Tinker training example](https://github.com/modelscope/Trinity-RFT/tree/main/examples/tinker)**, which is specifically designed to work on CPU-only systems. ### Step 1: installation @@ -186,10 +191,15 @@ Choose one of the following options: conda create -n trinity python=3.12 conda activate trinity -pip install -e ".[dev]" -pip install -e ".[flash_attn]" -# if you encounter issues when installing flash-attn, try: +pip install -e ".[vllm,flash_attn]" + +# If you have no GPU, comment out the line above and uncomment this instead: +# pip install -e ".[tinker]" + +# If you encounter issues when installing flash-attn, try: # pip install flash-attn==2.8.1 --no-build-isolation + +pip install -e ".[dev]" # for development like linting and debugging ``` ###### Using venv @@ -198,10 +208,15 @@ pip install -e ".[flash_attn]" python3.10 -m venv .venv source .venv/bin/activate -pip install -e ".[dev]" -pip install -e ".[flash_attn]" -# if you encounter issues when installing flash-attn, try: +pip install -e ".[vllm,flash_attn]" + +# If you have no GPU, comment out the line above and uncomment this instead: +# pip install -e ".[tinker]" + +# If you encounter issues when installing flash-attn, try: # pip install flash-attn==2.8.1 --no-build-isolation + +pip install -e ".[dev]" # for development like linting and debugging ``` ###### Using `uv` @@ -209,7 +224,10 @@ pip install -e ".[flash_attn]" [`uv`](https://github.com/astral-sh/uv) is a modern Python package installer. 
```bash -uv sync --extra dev --extra flash_attn +uv sync --extra vllm --extra dev --extra flash_attn + +# If you have no GPU, try to use Tinker instead: +# uv sync --extra tinker --extra dev ``` diff --git a/README_zh.md b/README_zh.md index a16063f30d..9a2ce35baa 100644 --- a/README_zh.md +++ b/README_zh.md @@ -41,6 +41,7 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: ## 🚀 新闻 +* [2025-12] Trinity-RFT 已支持 [tinker](https://thinkingmachines.ai/tinker/) 训练后端,可在**无 GPU 的设备**上进行模型训练。 * [2025-12] Trinity-RFT 助力淘宝闪购医药健康业务,让 AI 智能体能够理解模糊症状、主动询问后续问题,并提供精准推荐([新闻](https://tech.china.com.cn/sx/20251201/411376.shtml))。 * [2025-11] [[发布说明](https://github.com/modelscope/Trinity-RFT/releases/tag/v0.3.3)] Trinity-RFT v0.3.3 发布:修复若干 Bug。 * [2025-11] 推出 [Learn-to-Ask](https://github.com/modelscope/Trinity-RFT/tree/main/examples/learn_to_ask):利用离线专家数据,训练具备主动问询能力的对话智能体([论文](https://arxiv.org/pdf/2510.25441)). @@ -154,6 +155,10 @@ Trinity-RFT 面向不同背景和目标的用户提供相应功能: > [!NOTE] > 本项目正处于活跃开发阶段。欢迎提出意见和建议! +> +> **没有 GPU?没问题!** 您仍然可以尝试使用: +> 1. 按照安装步骤进行操作(可跳过 `flash-attn` 等 GPU 专用的软件包) +> 2. 运行 **[Tinker 训练示例](https://github.com/modelscope/Trinity-RFT/tree/main/examples/tinker)**,该示例专为仅使用 CPU 的系统设计。 ### 第一步:安装 @@ -185,10 +190,15 @@ cd Trinity-RFT conda create -n trinity python=3.12 conda activate trinity -pip install -e ".[dev]" -pip install -e ".[flash_attn]" +pip install -e ".[vllm,flash_attn]" + +# 如果没有GPU,可以注释上一行的命令,改为使用Tinker: +# pip install -e ".[tinker]" + # 如果安装 flash-attn 时遇到问题,可尝试: # pip install flash-attn==2.8.1 --no-build-isolation + +pip install -e ".[dev]" # 用于调试和开发 ``` #### 使用 venv @@ -197,10 +207,15 @@ pip install -e ".[flash_attn]" python3.10 -m venv .venv source .venv/bin/activate -pip install -e ".[dev]" -pip install -e ".[flash_attn]" +pip install -e ".[vllm,flash_attn]" + +# 如果没有GPU,可以注释上一行的命令,改为使用Tinker: +# pip install -e ".[tinker]" + # 如果安装 flash-attn 时遇到问题,可尝试: # pip install flash-attn==2.8.1 --no-build-isolation + +pip install -e ".[dev]" # 用于调试和开发 ``` #### 使用 `uv` @@ -208,7 +223,10 @@ pip install -e ".[flash_attn]" [`uv`](https://github.com/astral-sh/uv) 是现代的 Python 包管理工具。 ```bash -uv sync --extra dev --extra flash_attn +uv sync --extra vllm --extra dev --extra flash_attn + +# 如果没有GPU,可以改为使用Tinker: +# uv sync --extra tinker --extra dev ``` ## 通过 PyPI 安装 diff --git a/docs/sphinx_doc/assets/tinker-gsm8k.png b/docs/sphinx_doc/assets/tinker-gsm8k.png new file mode 100644 index 0000000000..fe0d1145a3 Binary files /dev/null and b/docs/sphinx_doc/assets/tinker-gsm8k.png differ diff --git a/docs/sphinx_doc/source/tutorial/example_async_mode.md b/docs/sphinx_doc/source/tutorial/example_async_mode.md index f10089c001..fbecc00c68 100644 --- a/docs/sphinx_doc/source/tutorial/example_async_mode.md +++ b/docs/sphinx_doc/source/tutorial/example_async_mode.md @@ -112,6 +112,10 @@ You can run this example with the following command: bash examples/async_gsm8k/run.sh ``` +```{note} +In the current asynchronous RFT training, it is recommended to start the Trainer before starting the Explorer to avoid the situation where the Trainer cannot read the generated experience data after the Explorer process terminates prematurely. This issue will be resolved in a future version. +``` + The following plot shows the learning curve of GRPO in the asynchronous mode. > This result should be regarded merely as a baseline, since GRPO is supposed to be an on-policy algorithm. > We are continuously investigating other RL algorithms (e.g., [OPMD](./example_reasoning_advanced.md)) in the asynchronous mode. 
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index b1d08f03b0..0c858ae505 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -164,9 +164,20 @@ model: max_response_tokens: 16384 min_response_tokens: 1 enable_prompt_truncation: true + repetition_penalty: 1.0 + lora_configs: null + rope_scaling: null + rope_theta: null + tinker: + enable: false + rank: 32 + seed: null + train_mlp: true + train_attn: true + train_unembed: true ``` -- `model_path`: Path to the model being trained. +- `model_path`: Path to the model being trained. If `tinker` is enabled, this is the path to the local tokenizer. - `critic_model_path`: Optional path to a separate critic model. If empty, defaults to `model_path`. - `custom_chat_template`: Optional custom chat template in string format. If not specified, the system will use the default chat template from tokenizer. - `chat_template_path`: Optional path to the chat template file in jinja2 type; overrides `custom_chat_template` if set. If not specified, the system will use the default chat template from tokenizer. @@ -175,6 +186,24 @@ model: - `max_prompt_tokens`: Maximum number of tokens allowed in prompts. Only for `chat` and `generate` methods in `InferenceModel`. - `min_response_tokens`: Minimum number of tokens allowed in generated responses. Only for `chat` and `generate` methods in `InferenceModel`. Default is `1`. It must be less than `max_response_tokens`. - `enable_prompt_truncation`: Whether to truncate the prompt. Default is `true`. If set to `true`, the prompt will be truncated to `max_prompt_tokens` tokens; if set to `false`, the prompt will not be truncated and there is a risk that the prompt length plus response length exceeds `max_model_len`. This function does not work with openai api mode. +- `repetition_penalty`: Repetition penalty factor. Default is `1.0`. +- `lora_configs`: Optional LoRA configuration. If not specified, defaults to `null`. Currently, only one LoRA configuration is supported, and this configuration will not be applied if `tinker` is enabled. + - `name`: Name of the LoRA. Default is `None`. + - `path`: Path to the LoRA. Default is `None`. + - `base_model_name`: Name of the base model for LoRA. If not specified, defaults to `None`. + - `lora_rank`: Rank of the LoRA. Default is `32`. + - `lora_alpha`: Alpha value of the LoRA. Default is `32`. + - `lora_dtype`: Data type of the LoRA. Default is `auto`. + - `target_modules`: List of target modules for LoRA. Default is `all-linear`. +- `rope_scaling`: Optional RoPE scaling configuration in JSON format. If not specified, defaults to `null`. +- `rope_theta`: Optional RoPE theta value. If not specified, defaults to `null`. +- `tinker`: Optional Tinker configuration. Note: LoRA configuration will be ignored if Tinker is enabled. + - `enable`: Whether to enable Tinker. Default is `false`. + - `rank`: LoRA rank controlling the size of adaptation matrices. Default is `32`. + - `seed`: Random seed for Tinker. If not specified, defaults to `null`. + - `train_mlp`: Whether to train the MLP layer. Default is `true`. + - `train_attn`: Whether to train the attention layer. Default is `true`. + - `train_unembed`: Whether to train the unembedding layer. Default is `true`. 
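For readers reviewing this change, the following minimal Python sketch shows how the YAML fields documented above line up with the `TinkerConfig` dataclass that this diff adds to `trinity/common/config.py`. The snippet is illustrative only (it is not part of the patch), and it assumes the dataclass defaults shown later in the diff:

```python
# Minimal sketch: building the new `model.tinker` section programmatically.
# ModelConfig and TinkerConfig are the dataclasses touched in this diff
# (trinity/common/config.py); the concrete values below are illustrative.
from trinity.common.config import ModelConfig, TinkerConfig

model = ModelConfig(
    model_path="Qwen/Qwen3-4B-Instruct-2507",  # with Tinker enabled, this acts as the local tokenizer path
    tinker=TinkerConfig(
        enable=True,       # switch training to the Tinker backend
        rank=32,           # LoRA rank of the adaptation matrices
        seed=None,         # no fixed seed
        train_mlp=True,
        train_attn=True,
        train_unembed=True,
    ),
)

# As documented above, any `lora_configs` entry is ignored once Tinker is enabled;
# by default it is null/None.
assert model.lora_configs is None
```

In YAML, the same settings correspond exactly to the `model.tinker` block shown at the top of this section.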
```{tip} If you are using the openai API provided by Explorer, only `max_model_len` will take effect, and the value of `max_response_tokens`, `max_prompt_tokens`, and `min_response_tokens` will be ignored. When `max_tokens` is not independently specified, each API call will generate up to `max_model_len - prompt_length` tokens. Therefore, please ensure that the prompt length is less than `max_model_len` when using the API. diff --git a/docs/sphinx_doc/source/tutorial/trinity_installation.md b/docs/sphinx_doc/source/tutorial/trinity_installation.md index 2e554906d8..8523518515 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_installation.md +++ b/docs/sphinx_doc/source/tutorial/trinity_installation.md @@ -3,11 +3,18 @@ For installing Trinity-RFT, you have three options: from source (recommended), via PyPI, or using Docker. -Before installing, ensure your system meets the following requirements: +**Before you begin**, check your system setup: -- **Python**: Version 3.10 to 3.12 (inclusive) -- **CUDA**: Version >= 12.8 -- **GPUs**: At least 2 GPUs +### If you have GPUs and want to use them: +Make sure your system meets these requirements: +- **Python**: 3.10 – 3.12 +- **CUDA**: 12.8 or higher +- **GPUs**: At least 2 available + +### If you don’t have GPUs (or prefer not to use them): +You can use the `tinker` option instead, which only requires: +- **Python**: 3.11 – 3.12 +- **GPUs**: Not required --- @@ -32,10 +39,15 @@ Choose one of the following options: conda create -n trinity python=3.12 conda activate trinity -pip install -e ".[dev]" -pip install -e ".[flash_attn]" -# if you encounter issues when installing flash-attn, try: +pip install -e ".[vllm,flash_attn]" + +# If you have no GPU, comment out the line above and uncomment this instead: +# pip install -e ".[tinker]" + +# If you encounter issues when installing flash-attn, try: # pip install flash-attn==2.8.1 --no-build-isolation + +pip install -e ".[dev]" # for development like linting and debugging ``` #### Using venv @@ -44,10 +56,15 @@ pip install -e ".[flash_attn]" python3.10 -m venv .venv source .venv/bin/activate -pip install -e ".[dev]" -pip install -e ".[flash_attn]" -# if you encounter issues when installing flash-attn, try: +pip install -e ".[vllm,flash_attn]" + +# If you have no GPU, comment out the line above and uncomment this instead: +# pip install -e ".[tinker]" + +# If you encounter issues when installing flash-attn, try: # pip install flash-attn==2.8.1 --no-build-isolation + +pip install -e ".[dev]" # for development like linting and debugging ``` #### Using `uv` @@ -55,7 +72,10 @@ pip install -e ".[flash_attn]" [`uv`](https://github.com/astral-sh/uv) is a modern Python package installer. 
```bash -uv sync --extra dev --extra flash_attn +uv sync --extra vllm --extra dev --extra flash_attn + +# If you have no GPU, try to use Tinker instead: +# uv sync --extra tinker --extra dev ``` --- diff --git a/docs/sphinx_doc/source_zh/tutorial/example_async_mode.md b/docs/sphinx_doc/source_zh/tutorial/example_async_mode.md index 804fa01b49..96befe34e8 100644 --- a/docs/sphinx_doc/source_zh/tutorial/example_async_mode.md +++ b/docs/sphinx_doc/source_zh/tutorial/example_async_mode.md @@ -112,6 +112,10 @@ trainer: bash examples/async_gsm8k/run.sh ``` +```{note} +目前异步 RFT 训练中,最好需要先启动Trainer后启动Explorer,以避免在Explorer进程提前结束之后,Trainer读取不到生成的Experience数据。此问题将在未来的版本中解决。 +``` + 下图展示了 GRPO 在异步模式下的学习曲线: > 此结果仅应视为基线,因为 GRPO 本质上是一种 on-policy 算法。 > 我们正在持续研究其他在异步模式下适用的强化学习算法(例如 [OPMD](./example_reasoning_advanced.md))。 diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 0fff6cd91f..6d96bff547 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -164,9 +164,20 @@ model: max_response_tokens: 16384 min_response_tokens: 1 enable_prompt_truncation: true + repetition_penalty: 1.0 + lora_configs: null + rope_scaling: null + rope_theta: null + tinker: + enable: false + rank: 32 + seed: null + train_mlp: true + train_attn: true + train_unembed: true ``` -- `model_path`: 被训练模型的路径。 +- `model_path`: 被训练模型的路径。如果启用了`tinker`,则该路径为本地 tokenizer 的路径。 - `critic_model_path`: 可选的独立 critic 模型路径。若为空,则默认为 `model_path`。 - `custom_chat_template`: 可选的自定义 chat template 字符串格式。若未指定,系统会使用 tokenizer 的默认 chat template。 - `chat_template_path`: 可选的 chat template 文件路径,类型通常为 jinja2;若设置,则覆盖 `custom_chat_template`。若未指定,系统会使用 tokenizer 的默认 chat template。 @@ -175,6 +186,24 @@ model: - `max_response_tokens`: 模型生成的回复中允许的最大 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `min_response_tokens`: 模型生成的回复中允许的最小 token 数。仅对 `InferenceModel` 中的 `chat` 和 `generate` 方法生效。 - `enable_prompt_truncation`: 是否截断 prompt。默认为 `true`。若设置为 `true`,则 prompt 将被截断为 `max_prompt_tokens` 个 token;若设置为 `false`,则 prompt 不会被截断,存在 prompt 和 response 长度之和超过 `max_model_len` 的风险。在 OpenAI API 模式下不生效。 +- `repetition_penalty`:重复惩罚因子。默认值为 `1.0`。 +- `lora_configs`:可选的 LoRA 配置。若未指定,则默认为 `null`。目前仅支持一个 LoRA 配置,并且如果启用了`tinker`,则不会使用此LoRA配置。 + - `name`:LoRA 的名称。默认为 `None`。 + - `path`:LoRA 的路径。默认为 `None`。 + - `base_model_name`:LoRA 所基于的基础模型名称。若未指定,则默认为 `None`。 + - `lora_rank`:LoRA 的秩(rank)。默认为 `32`。 + - `lora_alpha`:LoRA 的 alpha 值。默认为 `32`。 + - `lora_dtype`:LoRA 的数据类型。默认为 `auto`。 + - `target_modules`:LoRA 的目标模块列表。默认为 `all-linear`。 +- `rope_scaling`:可选的 RoPE 缩放配置,采用 JSON 格式。若未指定,则默认为 `null`。 +- `rope_theta`:可选的 RoPE theta 值。若未指定,则默认为 `null`。 +- `tinker`:可选的 Tinker 配置。注意:若启用 Tinker,则 LoRA 配置将被忽略。 + - `enable`:是否启用 Tinker。默认为 `false`。 + - `rank`:控制适配矩阵大小的 LoRA 秩(rank)。默认为 `32`。 + - `seed`:Tinker 使用的随机种子。若未指定,则默认为 `null`。 + - `train_mlp`:是否训练 MLP 层。默认为 `true`。 + - `train_attn`:是否训练注意力层。默认为 `true`。 + - `train_unembed`:是否训练反嵌入(unembedding)层。默认为 `true`。 ```{tip} 如果使用的是 Explorer 提供的 openai API,则只有 `max_model_len` 会生效,而 `max_response_tokens`、`max_prompt_tokens` 和 `min_response_tokens` 的值将被忽略,在没有独立指定 `max_tokens` 时,每次 API 调用将生成最多 `max_model_len - prompt_length` 个 token,因此在使用时请确保 prompt 长度小于 `max_model_len`。 diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md b/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md index 24b4eefbb2..4244af0021 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md +++ 
b/docs/sphinx_doc/source_zh/tutorial/trinity_installation.md @@ -3,11 +3,18 @@ 安装 Trinity-RFT 有三种方式:源码安装(推荐)、通过 PyPI 安装,或使用 Docker。 -在安装前,请确保您的系统满足以下要求: +**开始之前**,请检查您的系统配置: -- **Python**:3.10 至 3.12(包含) -- **CUDA**:大于等于 12.8 -- **GPU**:至少 2 块 GPU +### 如果您拥有 GPU 并希望使用它们: +请确保您的系统满足以下要求: +- **Python**:3.10 – 3.12 +- **CUDA**:12.8 或更高版本 +- **GPU**:至少 2 块可用 + +### 如果您没有 GPU(或不希望使用 GPU): +您可以改用 `tinker` 选项,该选项仅需满足: +- **Python**:3.11 – 3.12 +- **GPU**:无需 --- @@ -32,10 +39,15 @@ cd Trinity-RFT conda create -n trinity python=3.12 conda activate trinity -pip install -e ".[dev]" -pip install -e ".[flash_attn]" +pip install -e ".[vllm,flash_attn]" + +# 如果没有GPU,可以注释上一行的命令,改为使用Tinker: +# pip install -e ".[tinker]" + # 如果安装 flash-attn 时遇到问题,可尝试: # pip install flash-attn==2.8.1 --no-build-isolation + +pip install -e ".[dev]" # 用于调试和开发 ``` #### 使用 venv @@ -44,10 +56,15 @@ pip install -e ".[flash_attn]" python3.10 -m venv .venv source .venv/bin/activate -pip install -e ".[dev]" -pip install -e ".[flash_attn]" +pip install -e ".[vllm,flash_attn]" + +# 如果没有GPU,可以注释上一行的命令,改为使用Tinker: +# pip install -e ".[tinker]" + # 如果安装 flash-attn 时遇到问题,可尝试: # pip install flash-attn==2.8.1 --no-build-isolation + +pip install -e ".[dev]" # 用于调试和开发 ``` #### 使用 `uv` @@ -55,7 +72,10 @@ pip install -e ".[flash_attn]" [`uv`](https://github.com/astral-sh/uv) 是现代的 Python 包管理工具。 ```bash -uv sync --extra dev --extra flash_attn +uv sync --extra vllm --extra dev --extra flash_attn + +# 如果没有GPU,可以改为使用Tinker: +# uv sync --extra tinker --extra dev ``` --- diff --git a/examples/async_gsm8k/README.md b/examples/async_gsm8k/README.md index d60f7d6a42..6c00956693 100644 --- a/examples/async_gsm8k/README.md +++ b/examples/async_gsm8k/README.md @@ -11,3 +11,6 @@ You can run this example by the following command: ```bash bash examples/async_gsm8k/run.sh ``` + +> [!NOTE] +> In the current asynchronous RFT training, it is recommended to start the Trainer before starting the Explorer to avoid the situation where the Trainer cannot read the generated experience data after the Explorer process terminates prematurely. This issue will be resolved in a future version. diff --git a/examples/async_gsm8k/run.sh b/examples/async_gsm8k/run.sh index ff9ad66bbc..6801da4118 100644 --- a/examples/async_gsm8k/run.sh +++ b/examples/async_gsm8k/run.sh @@ -1,4 +1,4 @@ #!/bin/bash -trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log & -sleep 30 trinity run --config examples/async_gsm8k/trainer.yaml 2>&1 | tee trainer.log & +sleep 30 +trinity run --config examples/async_gsm8k/explorer.yaml 2>&1 | tee explorer.log & diff --git a/examples/tinker/README.md b/examples/tinker/README.md new file mode 100644 index 0000000000..79de2c959f --- /dev/null +++ b/examples/tinker/README.md @@ -0,0 +1,245 @@ +# Trinity with Tinker Backend + +> [!NOTE] +> This example demonstrates how to use Trinity with the [Tinker](https://thinkingmachines.ai/tinker/) backend, which enables model training on devices **without GPUs**. + +## Setup Instructions + +### 1. API Key Configuration +Before starting Ray, you must set the `TRINITY_API_KEY` environment variable to your Tinker API key to enable proper access to Tinker's API: + +```bash +export TRINITY_API_KEY=your_tinker_api_key +``` + +### 2. 
Configuration File
+Configure the Tinker backend in your YAML configuration file by setting the `model.tinker` parameters as shown below:
+
+```yaml
+model:
+  tinker:
+    enable: true
+    base_model: null
+    rank: 32
+    seed: null
+    train_mlp: true
+    train_attn: true
+    train_unembed: true
+```
+
+### 3. Configuration Parameters Explained
+
+- **`tinker`**: Tinker-specific configuration section. **Important**: When Tinker is enabled, any LoRA configuration settings (`model.lora_configs`) will be ignored.
+  - **`enable`**: Whether to activate the Tinker backend. Default: `false`
+  - **`base_model`**: Path to the base model for Tinker. If not specified (`null`), it defaults to the `model_path` defined elsewhere in your config
+  - **`rank`**: The LoRA rank that controls the size of the adaptation matrices. Default: `32`
+  - **`seed`**: Random seed for reproducible Tinker operations. If not specified (`null`), no specific seed is set
+  - **`train_mlp`**: Whether to train the MLP (feed-forward) layers. Default: `true`
+  - **`train_attn`**: Whether to train the attention layers. Default: `true`
+  - **`train_unembed`**: Whether to train the unembedding (output) layer. Default: `true`
+
+
+## Usage
+
+Once configured, Trinity works with the Tinker backend just like it does with the standard veRL backend. Start training with:
+
+```bash
+trinity run --config tinker.yaml # Replace with your actual config file path
+```
+
+### Important Limitations of the Tinker Backend
+
+1. **Entropy loss** is not consistent with the veRL backend: the entropy reported by the Tinker trainer is an estimated value, so it is recommended to set `entropy_coef` to `0`.
+2. **Algorithms that require `compute_advantage_in_trainer=true` are currently NOT supported**, including:
+   - PPO (`algorithm.algorithm_type=ppo`)
+   - Reinforce++ (`algorithm.algorithm_type=reinforceplusplus`)
+   - RLOO (`algorithm.algorithm_type=rloo`)
+   - On-policy distillation (`algorithm.algorithm_type=on_policy_distill`)
+
+   Algorithms such as `algorithm.algorithm_type=grpo` are supported; support for the algorithms above will be added in the future.
+3. **Multi-stage training** is not supported yet; support will be added in a future release.
+
+> 💡 A complete example configuration file is available at [`tinker.yaml`](tinker.yaml).
+
+
+## Results on the Llama-3.2-3B Model
+
+We trained the **Llama-3.2-3B** model on the **GSM8K** dataset using both the **Tinker** and **veRL** backends. Below are the full configuration files used in our experiments.
+
+
Click to expand: Tinker Backend Configuration + +```yaml +mode: both +project: Trinity-RFT-gsm8k +group: alignment-tinker +name: tinker-llama3.2-3B-off1 +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + sample_strategy: default + kl_loss_fn_args: + kl_coef: 0.0 + optimizer: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant +data_processor: {} +model: + model_path: meta-llama/Llama-3.2-3B + max_prompt_tokens: 1024 + max_response_tokens: 2048 + custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" + tinker: + enable: true + base_model: meta-llama/Llama-3.2-3B +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + batch_size: 96 + total_epochs: 1 + explorer_input: + taskset: + name: taskset + storage_type: file + path: openai/gsm8k + split: train + subset_name: main + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + default_workflow_type: math_workflow + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + replay_buffer: + enable: false +explorer: + runner_per_model: 16 + rollout_model: + engine_num: 4 + seed: 42 + auxiliary_models: [] + eval_interval: 1000 +trainer: + save_interval: 100 + enable_preview: true + grad_clip: 1.0 + max_token_len_per_gpu: 16384 +monitor: + monitor_type: wandb +synchronizer: + sync_method: checkpoint + sync_style: fixed + sync_interval: 1 + sync_offset: 1 + sync_timeout: 1200 +``` + +
+ + +
Click to expand: veRL Backend Configuration (LoRA) + +```yaml +mode: both +project: Trinity-RFT-gsm8k +group: alignment-tinker +name: verl-llama3.2-3B-lora-off1 +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + sample_strategy: default + kl_loss_fn_args: + kl_coef: 0.0 + optimizer: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant +data_processor: {} +model: + model_path: meta-llama/Llama-3.2-3B + max_prompt_tokens: 1024 + max_response_tokens: 2048 + custom_chat_template: "{{- bos_token }}\n{%- if custom_tools is defined %}\n {%- set tools = custom_tools %}\n{%- endif %}\n{%- if not tools_in_user_message is defined %}\n {%- set tools_in_user_message = true %}\n{%- endif %}\n{%- if not date_string is defined %}\n {%- if strftime_now is defined %}\n {%- set date_string = strftime_now(\"%d %b %Y\") %}\n {%- else %}\n {%- set date_string = \"26 Jul 2024\" %}\n {%- endif %}\n{%- endif %}\n{%- if not tools is defined %}\n {%- set tools = none %}\n{%- endif %}\n\n{#- This block extracts the system message, so we can slot it into the right place. #}\n{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n{%- else %}\n {%- set system_message = \"\" %}\n{%- endif %}\n\n{#- System message #}\n{{- \"<|start_header_id|>system<|end_header_id|>\\n\\n\" }}\n{%- if tools is not none %}\n {{- \"Environment: ipython\\n\" }}\n{%- endif %}\n{{- \"Cutting Knowledge Date: December 2023\\n\" }}\n{{- \"Today Date: \" + date_string + \"\\n\\n\" }}\n{%- if tools is not none and not tools_in_user_message %}\n {{- \"You have access to the following functions. To call a function, please respond with JSON for a function call.\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' }}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n{%- endif %}\n{{- system_message }}\n{{- \"<|eot_id|>\" }}\n\n{#- Custom tools are passed in a user message with some extra guidance #}\n{%- if tools_in_user_message and not tools is none %}\n {#- Extract the first user message so we can plug it in here #}\n {%- if messages | length != 0 %}\n {%- set first_user_message = messages[0]['content']|trim %}\n {%- set messages = messages[1:] %}\n {%- else %}\n {{- raise_exception(\"Cannot put tools in the first user message when there's no first user message!\") }}\n{%- endif %}\n {{- '<|start_header_id|>user<|end_header_id|>\\n\\n' -}}\n {{- \"Given the following functions, please respond with a JSON for a function call \" }}\n {{- \"with its proper arguments that best answers the given prompt.\\n\\n\" }}\n {{- 'Respond in the format {\"name\": function name, \"parameters\": dictionary of argument name and its value}.' 
}}\n {{- \"Do not use variables.\\n\\n\" }}\n {%- for t in tools %}\n {{- t | tojson(indent=4) }}\n {{- \"\\n\\n\" }}\n {%- endfor %}\n {{- first_user_message + \"<|eot_id|>\"}}\n{%- endif %}\n\n{%- for message in messages %}\n {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}\n {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n'+ message['content'] | trim + '<|eot_id|>' }}\n {%- elif 'tool_calls' in message %}\n {%- if not message.tool_calls|length == 1 %}\n {{- raise_exception(\"This model only supports single tool-calls at once!\") }}\n {%- endif %}\n {%- set tool_call = message.tool_calls[0].function %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' -}}\n {{- '{\"name\": \"' + tool_call.name + '\", ' }}\n {{- '\"parameters\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- \"}\" }}\n {{- \"<|eot_id|>\" }}\n {%- elif message.role == \"tool\" or message.role == \"ipython\" %}\n {{- \"<|start_header_id|>ipython<|end_header_id|>\\n\\n\" }}\n {%- if message.content is mapping or message.content is iterable %}\n {{- message.content | tojson }}\n {%- else %}\n {{- message.content }}\n {%- endif %}\n {{- \"<|eot_id|>\" }}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}\n{%- endif %}\n" + lora_configs: + - name: lora + lora_rank: 32 + lora_alpha: 32 +cluster: + node_num: 1 + gpu_per_node: 8 +buffer: + batch_size: 96 + total_epochs: 1 + explorer_input: + taskset: + name: taskset + storage_type: file + path: openai/gsm8k + split: train + subset_name: main + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + default_workflow_type: math_workflow + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + replay_buffer: + enable: false +explorer: + runner_per_model: 16 + rollout_model: + engine_num: 4 + tensor_parallel_size: 1 + enforce_eager: false + enable_prefix_caching: false + enable_chunked_prefill: false + gpu_memory_utilization: 0.9 + dtype: bfloat16 + seed: 42 + enable_thinking: false + enable_history: false + enable_openai_api: false + enable_auto_tool_choice: false + tool_call_parser: null + reasoning_parser: null + auxiliary_models: [] + eval_interval: 1000 +trainer: + trainer_type: verl + save_interval: 100 + enable_preview: true + grad_clip: 1.0 + max_token_len_per_gpu: 16384 +monitor: + monitor_type: wandb +synchronizer: + sync_method: checkpoint + sync_style: fixed + sync_interval: 1 + sync_offset: 1 + sync_timeout: 1200 +``` + +
+ +### Observations + +Since Llama-3.2-3B is a base (non-instruct-tuned) model, it has limited ability to follow formatting instructions. Additionally, we trained for only **one epoch**. As a result, both backends achieved final rewards just slightly above 0.1. Nonetheless, the training curves show a clear upward trend in reward, indicating successful learning. The results are visualized below: + +![Training Rewards on GSM8K](../../docs/sphinx_doc/assets/tinker-gsm8k.png) diff --git a/examples/tinker/tinker.yaml b/examples/tinker/tinker.yaml new file mode 100644 index 0000000000..744357e745 --- /dev/null +++ b/examples/tinker/tinker.yaml @@ -0,0 +1,67 @@ +mode: both +project: Trinity-RFT-gsm8k +name: tinker-Qwen3-4B +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} +algorithm: + algorithm_type: grpo + repeat_times: 8 + sample_strategy: default + kl_loss_fn_args: + kl_coef: 0.0 + optimizer: + lr: 1.0e-05 + lr_warmup_steps_ratio: 0.0 + warmup_style: constant +data_processor: {} +model: + model_path: Qwen/Qwen3-4B-Instruct-2507 + max_prompt_tokens: 1024 + max_response_tokens: 2048 + tinker: + enable: true + base_model: Qwen/Qwen3-4B-Instruct-2507 +buffer: + batch_size: 96 + total_epochs: 1 + explorer_input: + taskset: + name: taskset + storage_type: file + path: openai/gsm8k + split: train + subset_name: main + format: + prompt_key: question + response_key: answer + rollout_args: + temperature: 1.0 + logprobs: 0 + eval_tasksets: [] + default_workflow_type: math_workflow + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + replay_buffer: + enable: false +explorer: + runner_per_model: 8 + rollout_model: + engine_num: 4 + seed: 42 + auxiliary_models: [] + eval_interval: 1000 +trainer: + save_interval: 100 + enable_preview: true + grad_clip: 1.0 + max_token_len_per_gpu: 16384 +monitor: + monitor_type: tensorboard +synchronizer: + sync_method: memory + sync_style: fixed + sync_interval: 1 + sync_timeout: 1200 +log: + level: INFO diff --git a/pyproject.toml b/pyproject.toml index f741f8151c..e92f3ba98e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,6 @@ requires-python = ">=3.10,<3.13" dependencies = [ "verl==0.5.0", "ray[default]>=2.50.0", - "vllm>=0.10.2,<=0.11.0", "tensordict", "wandb", "omegaconf", @@ -43,13 +42,16 @@ dependencies = [ "sortedcontainers", "word2number", "transformers", - "tinker", + "datasets", ] [project.scripts] trinity = "trinity.cli.launcher:main" [project.optional-dependencies] +vllm = [ + "vllm>=0.10.2,<=0.11.0", +] data = [ "py-data-juicer>=1.4.3" ] @@ -79,6 +81,9 @@ megatron = [ "transformer_engine[pytorch]==2.8.0", "mbridge>=0.13.0", ] +tinker = [ + "tinker", # tinker requires python>=3.11 +] doc = [ "sphinx", diff --git a/scripts/docker/Dockerfile b/scripts/docker/Dockerfile index b87558f123..ab21822289 100644 --- a/scripts/docker/Dockerfile +++ b/scripts/docker/Dockerfile @@ -27,7 +27,7 @@ RUN chmod 1777 /tmp && apt update && apt install -y \ # copy the Trinity-RFT dir into the workspace COPY . . -RUN pip install --upgrade pip && pip install -e .[mm,dev] && pip install flash_attn==2.8.1 --no-build-isolation +RUN pip install --upgrade pip && pip install -e .[vllm,mm,dev] && pip install flash_attn==2.8.1 --no-build-isolation # Set Env variables diff --git a/scripts/docker/Dockerfile.megatron b/scripts/docker/Dockerfile.megatron index 7659452311..41c4e88168 100644 --- a/scripts/docker/Dockerfile.megatron +++ b/scripts/docker/Dockerfile.megatron @@ -28,7 +28,7 @@ COPY . . 
# Install Trinity-RFT with Megatron RUN pip install --upgrade pip \ - && pip install -e .[mm,dev] \ + && pip install -e .[vllm,mm,dev] \ && pip install flash_attn==2.8.1 --no-build-isolation \ && pip install -e .[megatron] \ && NVCC_APPEND_FLAGS="--threads 4" APEX_PARALLEL_BUILD=8 pip install -v \ diff --git a/scripts/docker/Dockerfile.uv b/scripts/docker/Dockerfile.uv index 75f6fe12ba..01492428a5 100644 --- a/scripts/docker/Dockerfile.uv +++ b/scripts/docker/Dockerfile.uv @@ -35,7 +35,7 @@ RUN pip install uv && uv venv /opt/venv --python=python3.12 # Install minimal Trinity-RFT RUN . /opt/venv/bin/activate && \ - uv pip install -e.[mm,dev] + uv pip install -e.[vllm,mm,dev] # Install flash_attn and Megatron RUN . /opt/venv/bin/activate && \ diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py index 32ea84bdb7..c3a9af9179 100644 --- a/tests/buffer/sample_strategy_test.py +++ b/tests/buffer/sample_strategy_test.py @@ -58,7 +58,9 @@ def _init_buffer_writer_and_sample_strategy(self): async def _verify_model_version(self, step, expected_versions): batch, metrics, _ = await self.sample_strategy.sample(step=step) self.assertEqual( - batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}" + [exp.reward for exp in batch], + expected_versions, + f"Model versions mismatch at step {step}", ) self.assertEqual( metrics["sample/model_version/min"], diff --git a/tests/common/config_test.py b/tests/common/config_test.py index 0a5a57bbdb..2502dae824 100644 --- a/tests/common/config_test.py +++ b/tests/common/config_test.py @@ -45,6 +45,7 @@ def test_all_examples_are_valid(self): filename.startswith("train_") or filename.startswith("verl_") or filename.startswith("dj_") + or filename.startswith("tinker") ): print(f"Checking config: {filename}") config_path = os.path.join(example_dir, example_name, filename) diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index 7b545d52f3..b6c675f550 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -125,17 +125,13 @@ def setUp(self): self.config.algorithm.repeat_times = self.repeat_times self.config.explorer.rollout_model.enable_history = self.enable_history self.config.check_and_update() - from pprint import pprint - pprint(self.config) self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper( self.engines[0], engine_type="vllm", enable_history=self.enable_history ) - async def test_generate( - self, - ): + async def test_generate(self): await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path) @@ -567,7 +563,7 @@ async def test_logprobs_api(self): torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3) ) self.assertTrue( - torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3) + torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.5, atol=1e-2) ) self.assertFalse( torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3) @@ -616,7 +612,7 @@ async def test_logprobs_api(self): torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.3, atol=1e-3) ) self.assertTrue( - torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.3, atol=1e-3) + torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.5, atol=1e-2) ) self.assertFalse( torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.3, atol=1e-3) diff 
--git a/tests/trainer/trainer_test.py b/tests/trainer/trainer_test.py index 972071b99c..c01b55408d 100644 --- a/tests/trainer/trainer_test.py +++ b/tests/trainer/trainer_test.py @@ -980,17 +980,11 @@ async def test_serve_with_trainer(self): # noqa: C901 trainer_config = deepcopy(config) trainer_config.mode = "train" trainer_config.check_and_update() + trainer_config.trainer.max_actor_ckpt_to_keep = 10 trainer_process = multiprocessing.Process(target=run_trainer, args=(trainer_config,)) trainer_process.start() - await asyncio.sleep(5) - serve_config = deepcopy(config) - serve_config.mode = "serve" - serve_config.check_and_update() - serve_process = multiprocessing.Process(target=run_serve, args=(serve_config,)) - serve_process.start() - ray.init(ignore_reinit_error=True) while True: try: @@ -999,6 +993,11 @@ async def test_serve_with_trainer(self): # noqa: C901 except ValueError: print("waiting for trainer to start.") await asyncio.sleep(5) + serve_config = deepcopy(config) + serve_config.mode = "serve" + serve_config.check_and_update() + serve_process = multiprocessing.Process(target=run_serve, args=(serve_config,)) + serve_process.start() state_manager = StateManager( path=serve_config.checkpoint_job_dir, @@ -1322,3 +1321,38 @@ def test_trainer(self): def tearDown(self): # remove dir only when the test passed shutil.rmtree(self.config.checkpoint_job_dir) + + +class TestTinkerTrainer(BaseTrainerCase): + @unittest.skip("Require tinker API key") + def test_trainer(self): + """Test GSM8K on tinker.""" + # test both mode + self.config.algorithm.algorithm_type = "grpo" + self.config.algorithm.repeat_times = 4 + self.config.algorithm.advantage_fn = "grpo" + self.config.algorithm.advantage_fn_args = { + "epsilon": 1e-6, + } + self.config.buffer.total_epochs = 1 + self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") + self.config.model.tinker.enable = True + self.config.model.tinker.base_model = "Qwen/Qwen3-4B-Instruct-2507" + self.config.check_and_update() + both(self.config) + parser = TensorBoardParser(os.path.join(self.config.monitor.cache_dir, "tensorboard")) + rollout_metrics = parser.metric_list("rollout") + self.assertTrue(len(rollout_metrics) > 0) + pipeline_metrics = parser.metric_list("experience_pipeline") + self.assertTrue(len(pipeline_metrics) > 0) + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 4) + actor_metrics = parser.metric_list("actor") + self.assertTrue(len(actor_metrics) > 0) + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 4) + response_metrics = parser.metric_list("response_length") + self.assertTrue(len(response_metrics) > 0) + self.assertEqual(parser.metric_max_step(response_metrics[0]), 4) + + def tearDown(self): + # remove dir only when the test passed + shutil.rmtree(self.config.checkpoint_job_dir) diff --git a/trinity/algorithm/advantage_fn/asymre_advantage.py b/trinity/algorithm/advantage_fn/asymre_advantage.py index 52f2314272..5c2ccf9932 100644 --- a/trinity/algorithm/advantage_fn/asymre_advantage.py +++ b/trinity/algorithm/advantage_fn/asymre_advantage.py @@ -1,10 +1,12 @@ """AsymRE advantage computation""" from collections import defaultdict -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple import torch -from verl import DataProto + +if TYPE_CHECKING: + from verl import DataProto from trinity.algorithm.advantage_fn import AdvantageFn, GroupAdvantage from trinity.common.experience import Experience, group_by @@ -23,9 +25,9 @@ def __init__( def __call__( self, - exps: 
DataProto, + exps: "DataProto", **kwargs, - ) -> Tuple[DataProto, Dict]: + ) -> Tuple["DataProto", Dict]: """Modified from compute_grpo_outcome_advantage Compute advantage for AsymRE, operating only on Outcome reward diff --git a/trinity/algorithm/advantage_fn/grpo_advantage.py b/trinity/algorithm/advantage_fn/grpo_advantage.py index 7d1c58977d..4562a5a4b9 100644 --- a/trinity/algorithm/advantage_fn/grpo_advantage.py +++ b/trinity/algorithm/advantage_fn/grpo_advantage.py @@ -3,10 +3,12 @@ import copy from collections import defaultdict -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple import torch -from verl import DataProto + +if TYPE_CHECKING: + from verl import DataProto from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage from trinity.common.experience import Experience, group_by @@ -26,9 +28,9 @@ def __init__( def __call__( self, - exps: DataProto, + exps: "DataProto", **kwargs, - ) -> Tuple[DataProto, Dict]: + ) -> Tuple["DataProto", Dict]: """ Compute advantage for GRPO, operating only on Outcome reward (with only one scalar reward for each response). diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py index d5e9203e3c..8c0a586986 100644 --- a/trinity/algorithm/advantage_fn/opmd_advantage.py +++ b/trinity/algorithm/advantage_fn/opmd_advantage.py @@ -1,10 +1,12 @@ """OPMD advantage computation""" from collections import defaultdict -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple import torch -from verl import DataProto + +if TYPE_CHECKING: + from verl import DataProto from trinity.algorithm.advantage_fn.advantage_fn import AdvantageFn, GroupAdvantage from trinity.common.experience import Experience, group_by @@ -25,9 +27,9 @@ def __init__( def __call__( self, - exps: DataProto, + exps: "DataProto", **kwargs, - ) -> Tuple[DataProto, Dict]: + ) -> Tuple["DataProto", Dict]: """Modified from compute_grpo_outcome_advantage Compute advantage for OPMD, operating only on Outcome reward diff --git a/trinity/algorithm/key_mapper.py b/trinity/algorithm/key_mapper.py index 09c1f988a6..4813c18fb5 100644 --- a/trinity/algorithm/key_mapper.py +++ b/trinity/algorithm/key_mapper.py @@ -26,4 +26,5 @@ def from_trinity(self, key: str) -> str: "advantages": "advantages", } ), + "tinker": KeyMapper({}), } diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index 65acc44d3c..8f0f6ed3d5 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -8,7 +8,7 @@ from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig -from trinity.common.experience import CustomField, Experiences +from trinity.common.experience import CustomField, Experience from trinity.utils.timer import Timer @@ -53,7 +53,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): expert_buffer_config, ) - async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]: metrics = {} with Timer(metrics, "time/read_experience"): usual_exp_list = await self.usual_exp_buffer.read_async() @@ -82,24 +82,21 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: repr_samples = 
representative_sample(exp_list) self.set_model_version_metric(exp_list, metrics) - with Timer(metrics, "time/gather_experience"): - exps = Experiences.gather_experiences( - experiences=exp_list, - pad_token_id=self.pad_token_id, # type: ignore [arg-type] - custom_fields=[ - CustomField( - source_field="is_expert", - destination_field="expert_mask", - data_type=torch.bool, - ), - CustomField( - source_field="step", - destination_field="step", - data_type=torch.int32, - ), - ], - ) # type: ignore - return exps, metrics, repr_samples + custom_fields = [ + CustomField( + source_field="is_expert", + destination_field="expert_mask", + data_type=torch.bool, + ), + CustomField( + source_field="step", + destination_field="step", + data_type=torch.int32, + ), + ] + for exp in exp_list: + exp.custom_fields = custom_fields + return exp_list, metrics, repr_samples @classmethod def default_args(cls) -> Dict: diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 2ab63032cb..2398961ae6 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -4,7 +4,7 @@ from trinity.algorithm.sample_strategy.utils import representative_sample from trinity.buffer import get_buffer_reader from trinity.common.config import BufferConfig -from trinity.common.experience import Experience, Experiences +from trinity.common.experience import Experience from trinity.utils.annotations import Deprecated from trinity.utils.monitor import gather_metrics from trinity.utils.timer import Timer @@ -12,7 +12,7 @@ class SampleStrategy(ABC): def __init__(self, buffer_config: BufferConfig, **kwargs) -> None: - self.pad_token_id = buffer_config.pad_token_id + pass def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict): metric_list = [ @@ -23,14 +23,14 @@ def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict): metrics.update(gather_metrics(metric_list, "sample")) @abstractmethod - async def sample(self, step: int) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]: """Sample data from buffer. Args: step (`int`): The step number of current step. Returns: - `Experiences`: The sampled Experiences data. + `List[Experience]`: The sampled List[Experience] data. `Dict`: Metrics for logging. `List`: Representative data for logging. 
""" @@ -54,15 +54,13 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer) # type: ignore[arg-type] - async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int, **kwargs) -> Tuple[List[Experience], Dict, List]: metrics = {} with Timer(metrics, "time/read_experience"): exp_list = await self.exp_buffer.read_async() repr_samples = representative_sample(exp_list) self.set_model_version_metric(exp_list, metrics) - with Timer(metrics, "time/gather_experience"): - exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - return exps, metrics, repr_samples + return exp_list, metrics, repr_samples @classmethod def default_args(cls) -> dict: @@ -81,16 +79,14 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) self.max_staleness = kwargs.get("max_staleness", float("inf")) - async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: + async def sample(self, step: int, **kwargs) -> Tuple[List[Experience], Dict, List]: min_model_version = max(step - self.max_staleness, 0) metrics = {} with Timer(metrics, "time/read_experience"): exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version) repr_samples = representative_sample(exp_list) self.set_model_version_metric(exp_list, metrics) - with Timer(metrics, "time/gather_experience"): - exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore - return exps, metrics, repr_samples + return exp_list, metrics, repr_samples @Deprecated diff --git a/trinity/cli/launcher.py b/trinity/cli/launcher.py index d9c0d95771..28ba57ecc8 100644 --- a/trinity/cli/launcher.py +++ b/trinity/cli/launcher.py @@ -176,9 +176,9 @@ def run(config_path: str, dlc: bool = False, plugin_dir: str = None): raise RuntimeError("Ray is not running, please start it by `ray start --head`.") try: - from trinity.trainer.verl.utils import get_latest_hf_checkpoint_path - if config.stages: + from trinity.trainer.verl.utils import get_latest_hf_checkpoint_path + state_manager = StateManager( path=os.path.join(config.checkpoint_root_dir, config.project, config.name) ) diff --git a/trinity/common/config.py b/trinity/common/config.py index 2a43b2e235..1b59dc99d4 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -428,6 +428,16 @@ class DataProcessorConfig: ) +@dataclass +class TinkerConfig: + enable: bool = False + rank: int = 32 # lora rank + seed: Optional[int] = None + train_mlp: bool = True + train_attn: bool = True + train_unembed: bool = True + + @dataclass class ModelConfig: # source model path @@ -472,6 +482,9 @@ class ModelConfig: rope_scaling: Optional[dict] = None rope_theta: Optional[float] = None + # tinker config + tinker: TinkerConfig = field(default_factory=TinkerConfig) + @dataclass class InferenceModelConfig: @@ -1149,6 +1162,9 @@ def _check_model(self) -> None: if not model.critic_model_path: model.critic_model_path = model.model_path + if model.tinker.enable: + self._check_tinker() + # check template if model.chat_template_path is not None and model.custom_chat_template is None: try: @@ -1160,7 +1176,51 @@ def _check_model(self) -> None: ) # check max_model_len, max_prompt_tokens, max_response_tokens + self._check_model_len() + + def _check_tinker(self) -> None: + model = self.model + from trinity.algorithm import ALGORITHM_TYPE + + algorithm = 
ALGORITHM_TYPE.get(self.algorithm.algorithm_type) + if algorithm.use_critic: + raise ValueError("Critic model is not supported when using tinker!") + + import tinker + + service_client = tinker.ServiceClient() + supported_models = { + item.model_name for item in service_client.get_server_capabilities().supported_models + } + if model.model_path not in supported_models: + logger.error(f"Supported models: {supported_models}") + raise ValueError(f"{model.model_path} is not supported by tinker!") + + if ( + self.algorithm.entropy_loss_fn != "none" + and self.algorithm.entropy_loss_fn_args.get("entropy_coef", 0.0) != 0.0 + ): + logger.warning( + "The entropy in Tinker trainer is an estimated value; " + "it is recommended to set `entropy_coef` to 0." + ) + + if self.explorer.rollout_model.engine_type != "tinker": + self.explorer.rollout_model.engine_type = "tinker" + logger.warning("Rollout model engine type is set to `tinker`.") + if self.trainer.trainer_type != "tinker": + self.trainer.trainer_type = "tinker" + logger.warning("Trainer type is set to `tinker`.") + + if self.synchronizer.sync_method == SyncMethod.NCCL: + self.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + "Tinker do not support NCCL, `synchronizer.sync_method` is set to `checkpoint`." + ) + + def _check_model_len(self) -> None: + model = self.model # if all three are set, check if they are valid if ( model.max_model_len is not None @@ -1225,6 +1285,103 @@ def _check_model(self) -> None: "`enable_prompt_truncation` is set to False; please make sure the prompt is not too long and `max_model_len` is large enough, otherwise prompt length + response length may exceed `max_model_len`!" ) + def _check_explorer(self) -> None: + rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"] + length_args = [ + "max_model_len", + "max_prompt_tokens", + "max_response_tokens", + "min_response_tokens", + "enable_prompt_truncation", + ] + rope_args = ["rope_scaling", "rope_theta"] + model_args = rollout_args + length_args + rope_args + set_if_none(self.explorer.rollout_model, "model_path", self.model.model_path) + for args in model_args: + set_if_none(self.explorer.rollout_model, args, getattr(self.model, args)) + if ( + self.explorer.rollout_model.chat_template is None + and self.model.custom_chat_template is not None + ): + self.explorer.rollout_model.chat_template = self.model.custom_chat_template + for aux_model in self.explorer.auxiliary_models: + if not aux_model.model_path: + raise ValueError("auxiliary model's model_path is required.") + for args in model_args: + set_if_none(aux_model, args, getattr(self.model, args)) + + if self.explorer.rollout_model.engine_type != "tinker": + # check gpu number + rollout_gpu_num = ( + self.explorer.rollout_model.tensor_parallel_size + * self.explorer.rollout_model.engine_num + + sum( + ( + model.tensor_parallel_size * model.engine_num + for model in self.explorer.auxiliary_models + ) + ) + ) + assert self.cluster.node_num is not None + assert self.cluster.gpu_per_node is not None + total_gpu_num = self.cluster.node_num * self.cluster.gpu_per_node + if self.mode in ["explore", "bench", "serve"] and rollout_gpu_num > total_gpu_num: + raise ValueError( + f"Total GPU number ({total_gpu_num}) is less than the number of GPUs required for rollout ({rollout_gpu_num})." + ) + elif self.mode == "both" and rollout_gpu_num >= total_gpu_num: + raise ValueError( + f"Not enough GPUs for trainer in 'both' mode. 
Explorer requires {rollout_gpu_num} GPUs, but total available GPUs are {total_gpu_num}." + ) + + if self.explorer.over_rollout.ratio > 0.0: + if not (0.0 <= self.explorer.over_rollout.ratio < 1.0): + raise ValueError("over_rollout_ratio should be in [0.0, 1.0)") + if self.synchronizer.sync_style == SyncStyle.FIXED: + raise ValueError( + "over_rollout_ratio is not compatible with fixed sync_style, please set " + "`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`." + ) + + # for lora configs + if not self.model.tinker.enable and self.model.lora_configs is not None: + self.explorer.rollout_model.enable_lora = True + if len(self.model.lora_configs) > 1: + raise ValueError("Only one lora adapter is supported for now.") + if self.model.lora_configs[0].path is None: + logger.info("Creating dummy lora, since no lora_path is provided.") + lora_path = create_dummy_lora( + model_path=self.model.model_path, + checkpoint_job_dir=self.checkpoint_job_dir, + lora_rank=self.model.lora_configs[0].lora_rank, + lora_alpha=self.model.lora_configs[0].lora_alpha, + target_modules=self.model.lora_configs[0].target_modules, + ) + self.model.lora_configs[0].path = lora_path + self.explorer.rollout_model.lora_modules = [ + { + "lora_int_id": i + 1, + "lora_name": cfg.name, + "lora_path": cfg.path, + "base_model_name": cfg.base_model_name, + } + for i, cfg in enumerate(self.model.lora_configs) + ] + self.explorer.rollout_model.lora_kwargs = { + "max_loras": len(self.model.lora_configs), + "max_lora_rank": max( + ( + model_config.lora_rank + for model_config in self.model.lora_configs + if model_config.lora_rank > 0 + ), + default=0, + ), + "default_lora_path": os.path.join( + self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter" + ), # will be poped later + } + def __iter__(self): """Iterate over configs with each stage applied in order. @@ -1291,99 +1448,7 @@ def check_and_update(self) -> Config: # noqa: C901 # check explorer if self.explorer is not None: - rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"] - length_args = [ - "max_model_len", - "max_prompt_tokens", - "max_response_tokens", - "min_response_tokens", - "enable_prompt_truncation", - ] - rope_args = ["rope_scaling", "rope_theta"] - model_args = rollout_args + length_args + rope_args - for args in ["model_path"] + model_args: - set_if_none(self.explorer.rollout_model, args, getattr(self.model, args)) - if ( - self.explorer.rollout_model.chat_template is None - and self.model.custom_chat_template is not None - ): - self.explorer.rollout_model.chat_template = self.model.custom_chat_template - for aux_model in self.explorer.auxiliary_models: - if not aux_model.model_path: - raise ValueError("auxiliary model's model_path is required.") - for args in model_args: - set_if_none(aux_model, args, getattr(self.model, args)) - - # check gpu number - rollout_gpu_num = ( - self.explorer.rollout_model.tensor_parallel_size - * self.explorer.rollout_model.engine_num - + sum( - ( - model.tensor_parallel_size * model.engine_num - for model in self.explorer.auxiliary_models - ) - ) - ) - assert self.cluster.node_num is not None - assert self.cluster.gpu_per_node is not None - total_gpu_num = self.cluster.node_num * self.cluster.gpu_per_node - if self.mode in ["explore", "bench", "serve"] and rollout_gpu_num > total_gpu_num: - raise ValueError( - f"Total GPU number ({total_gpu_num}) is less than the number of GPUs required for rollout ({rollout_gpu_num})." 
- ) - elif self.mode == "both" and rollout_gpu_num >= total_gpu_num: - raise ValueError( - f"Not enough GPUs for trainer in 'both' mode. Explorer requires {rollout_gpu_num} GPUs, but total available GPUs are {total_gpu_num}." - ) - - if self.explorer.over_rollout.ratio > 0.0: - if not (0.0 <= self.explorer.over_rollout.ratio < 1.0): - raise ValueError("over_rollout_ratio should be in [0.0, 1.0)") - if self.synchronizer.sync_style == SyncStyle.FIXED: - raise ValueError( - "over_rollout_ratio is not compatible with fixed sync_style, please set " - "`synchronizer.sync_style` to `dynamic_by_explorer` or `dynamic_by_trainer`." - ) - - # for lora configs - if self.model.lora_configs is not None: - self.explorer.rollout_model.enable_lora = True - if len(self.model.lora_configs) > 1: - raise ValueError("Only one lora adapter is supported for now.") - if self.model.lora_configs[0].path is None: - logger.info("Creating dummy lora, since no lora_path is provided.") - lora_path = create_dummy_lora( - model_path=self.model.model_path, - checkpoint_job_dir=self.checkpoint_job_dir, - lora_rank=self.model.lora_configs[0].lora_rank, - lora_alpha=self.model.lora_configs[0].lora_alpha, - target_modules=self.model.lora_configs[0].target_modules, - ) - self.model.lora_configs[0].path = lora_path - self.explorer.rollout_model.lora_modules = [ - { - "lora_int_id": i + 1, - "lora_name": cfg.name, - "lora_path": cfg.path, - "base_model_name": cfg.base_model_name, - } - for i, cfg in enumerate(self.model.lora_configs) - ] - self.explorer.rollout_model.lora_kwargs = { - "max_loras": len(self.model.lora_configs), - "max_lora_rank": max( - ( - model_config.lora_rank - for model_config in self.model.lora_configs - if model_config.lora_rank > 0 - ), - default=0, - ), - "default_lora_path": os.path.join( - self.checkpoint_job_dir, "global_step_0", "actor", "lora_adapter" - ), # will be poped later - } + self._check_explorer() # check synchronizer self.synchronizer.ray_namespace = self.ray_namespace @@ -1391,14 +1456,17 @@ def check_and_update(self) -> Config: # noqa: C901 self.explorer.rollout_model.engine_num * self.explorer.rollout_model.tensor_parallel_size ) - if ( - self.mode in ["train", "explore", "bench", "serve"] - and self.synchronizer.sync_method == SyncMethod.NCCL - ): - self.synchronizer.sync_method = SyncMethod.CHECKPOINT - logger.warning( - f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`." - ) + if self.synchronizer.sync_method == SyncMethod.NCCL: + if self.mode in ["train", "explore", "bench", "serve"]: + self.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + f"`{self.mode}` mode does not support NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`." + ) + if self.model.lora_configs is not None: + self.synchronizer.sync_method = SyncMethod.CHECKPOINT + logger.warning( + "LoRA is not supported with NCCL synchronization, set `synchronizer.sync_method` to `checkpoint`." + ) self._check_interval() @@ -1450,9 +1518,11 @@ def check_and_update(self) -> Config: # noqa: C901 f"Invalid trainer.save_hf_checkpoint: {self.trainer.save_hf_checkpoint}, " "must be one of 'last', 'always', or 'never'." 
) + self.trainer.trainer_config.synchronize_config(self) + elif self.trainer.trainer_type == "tinker": + self.trainer.trainer_config = None else: raise ValueError(f"Invalid trainer type: {self.trainer_type}") - self.trainer.trainer_config.synchronize_config(self) # check service if self.service.data_juicer is not None: diff --git a/trinity/common/experience.py b/trinity/common/experience.py index 0f94e0bdd5..9fa48a59ef 100644 --- a/trinity/common/experience.py +++ b/trinity/common/experience.py @@ -136,6 +136,8 @@ class Experience: # for on-policy distillation teacher_logprobs: Optional[Tensor] = None # [resp_length] + custom_fields: List[CustomField] = field(default_factory=list) + def __init__( # noqa: C901 self, *, @@ -161,6 +163,7 @@ def __init__( # noqa: C901 rejected_messages=None, multi_modal_inputs=None, teacher_logprobs=None, + custom_fields=None, ): if action_mask is not None: experience_type = "multi_turn" @@ -250,6 +253,7 @@ def __init__( # noqa: C901 self.rejected = torch.tensor(self.rejected) if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor): self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32) + self.custom_fields = custom_fields or [] def serialize(self) -> bytes: """Serialize the experience to bytes.""" diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index 190be581cb..46958faa6c 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -45,15 +45,44 @@ def create_inference_models( from ray.util.placement_group import placement_group, placement_group_table from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - from trinity.common.models.vllm_model import vLLMRolloutModel - logger = get_logger(__name__) engine_num = config.explorer.rollout_model.engine_num tensor_parallel_size = config.explorer.rollout_model.tensor_parallel_size rollout_engines = [] if config.explorer.rollout_model.engine_type.startswith("vllm"): + from trinity.common.models.vllm_model import vLLMRolloutModel + engine_cls = vLLMRolloutModel + elif config.explorer.rollout_model.engine_type == "tinker": + from trinity.common.models.tinker_model import TinkerModel + + engine_cls = TinkerModel + namespace = ray.get_runtime_context().namespace + rollout_engines = [ + ray.remote(engine_cls) + .options( + name=f"{config.explorer.name}_rollout_model_{i}", + namespace=namespace, + ) + .remote( + config=config.explorer.rollout_model, + ) + for i in range(engine_num) + ] + auxiliary_engines = [ + ray.remote(engine_cls) + .options( + name=f"{config.explorer.name}_auxiliary_model_{i}_{j}", + namespace=namespace, + ) + .remote( + config=config.explorer.auxiliary_models[i], + ) + for i, model_config in enumerate(config.explorer.auxiliary_models) + for j in range(model_config.engine_num) + ] + return rollout_engines, auxiliary_engines else: raise ValueError(f"Unknown engine type: {config.explorer.rollout_model.engine_type}") @@ -124,7 +153,7 @@ def create_inference_models( model_config.engine_type = "vllm" model_config.bundle_indices = ",".join([str(bid) for bid in bundles_for_engine]) engines.append( - ray.remote(vLLMRolloutModel) + ray.remote(engine_cls) .options( name=f"{config.explorer.name}_auxiliary_model_{i}_{j}", num_cpus=0, diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 3fe6f2bf37..00d27b5742 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -46,6 +46,10 @@ async def prepare(self) -> None: 
"""Prepare the model before inference.""" pass + @abstractmethod + async def sync_model(self, model_version: int) -> int: + """Sync the model with the latest model_version.""" + @abstractmethod def get_model_version(self) -> int: """Get the checkpoint version.""" @@ -105,7 +109,9 @@ def __init__( enable_history (bool): Whether to enable history recording. Default to False. enable_thinking (Optional[bool]): Whether to enable thinking mode. Default to None. Only used for Qwen3 series models. """ - assert engine_type.startswith("vllm"), "Only vLLM model is supported for now." + assert ( + engine_type.startswith("vllm") or engine_type == "tinker" + ), "Only vLLM and tinker model is supported for now." self.model = model self.api_address: str = None self.openai_client: openai.OpenAI = None @@ -205,13 +211,13 @@ async def generate_mm_async( def chat(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of messages.""" lora_request = self.get_lora_request() - return ray.get(self.model.chat.remote(messages, lora_request, **kwargs)) + return ray.get(self.model.chat.remote(messages, lora_request=lora_request, **kwargs)) @_history_recorder async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]: """Generate a list of experiences from a list of messages in async.""" lora_request = await self.get_lora_request_async() - return await self.model.chat.remote(messages, lora_request, **kwargs) + return await self.model.chat.remote(messages, lora_request=lora_request, **kwargs) @_history_recorder def chat_mm( diff --git a/trinity/common/models/tinker_model.py b/trinity/common/models/tinker_model.py new file mode 100644 index 0000000000..e8b6d492a1 --- /dev/null +++ b/trinity/common/models/tinker_model.py @@ -0,0 +1,205 @@ +from typing import List, Optional, Sequence + +import ray +import tinker +import torch +from tinker import types +from torch import Tensor + +from trinity.common.config import InferenceModelConfig +from trinity.common.experience import Experience +from trinity.common.models.model import InferenceModel +from trinity.common.models.utils import get_action_mask_method +from trinity.manager.synchronizer import Synchronizer +from trinity.utils.log import get_logger + + +class TinkerModel(InferenceModel): + def __init__( + self, + config: InferenceModelConfig, + ) -> None: + self.config = config + self.model_version = -1 + self.synchronizer = Synchronizer.get_actor(namespace=ray.get_runtime_context().namespace) + self.logger = get_logger(__name__) + self.model = None + self.tokenizer = None + self.chat_template = None + if self.config.chat_template: + self.chat_template = self.config.chat_template + self.action_mask_method = get_action_mask_method(self.chat_template) + self.enable_thinking = config.enable_thinking + + async def _initialize_tokenizer(self) -> None: + """Initialize the tokenizer.""" + trainer_client = await self.service_client.create_lora_training_client_async( + base_model=self.config.model_path + ) + self.tokenizer = trainer_client.get_tokenizer() + + async def _generate_internal(self, prompt: dict, **kwargs) -> types.SampleResponse: + assert self.model is not None + sampling_params = { + "max_tokens": kwargs.get("max_tokens", self.config.max_response_tokens), + "seed": kwargs.get("seed", self.config.seed), + "temperature": kwargs.get("temperature", 1.0), + "top_k": kwargs.get("top_k", -1), + "top_p": kwargs.get("top_p", 1), + } + + return await self.model.sample_async( + 
prompt=types.ModelInput.from_ints(prompt["prompt_token_ids"]), + sampling_params=sampling_params, + num_samples=kwargs.get("n", 1), + include_prompt_logprobs=kwargs.get("include_prompt_logprobs", False), + topk_prompt_logprobs=kwargs.get("topk_prompt_logprobs", self.config.logprobs), + ) + + async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]: + """Generate responses from a prompt in async.""" + if self.tokenizer is None: + await self._initialize_tokenizer() + + # Tokenize once without truncation to check if truncation is needed + token_ids = self.tokenizer( # type: ignore + prompt, + truncation=False, + return_tensors="pt", + )[ + "input_ids" + ][0].tolist() + + # Check if truncation is needed and apply it + if self.config.enable_prompt_truncation and self.config.max_prompt_tokens is not None: + if len(token_ids) > self.config.max_prompt_tokens: + self.logger.warning( + f"Prompt was truncated to {self.config.max_prompt_tokens} tokens" + ) + token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response + return [ + Experience( + tokens=token_ids, + logprobs=torch.zeros(1, dtype=torch.float32), + prompt_length=len(token_ids) - 1, + prompt_text=self.tokenizer.decode(token_ids[:-1]), + response_text=self.tokenizer.decode(token_ids[-1]), + truncate_status="prompt_truncated", + reward=0.0, + ) + for _ in range(kwargs.get("n", 1)) + ] + + output = await self._generate_internal(prompt={"prompt_token_ids": token_ids}, **kwargs) + experiences = [ + Experience( + tokens=torch.tensor(token_ids + sequence.tokens, dtype=torch.int32), + logprobs=torch.tensor(sequence.logprobs, dtype=torch.float32), + prompt_length=len(token_ids), + prompt_text=self.tokenizer.decode(token_ids), + response_text=self.tokenizer.decode(sequence.tokens), + ) + for sequence in output.sequences + ] + + return experiences + + async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]: + """Generate experiences from a list of history chat messages in async.""" + if self.tokenizer is None: + await self._initialize_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() + if messages[-1]["role"] == "assistant": + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + continue_final_message=True, + chat_template=self.chat_template, + ) + else: + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + chat_template=self.chat_template, + enable_thinking=self.enable_thinking, + ) + return await self.generate(prompt=prompt, **kwargs) + + async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor: + """Generate logprobs for a list of tokens in async.""" + logprobs = await self.model.compute_logprobs_async(types.ModelInput(token_ids)) + return torch.tensor(logprobs[1:], dtype=torch.float32) + + async def convert_messages_to_experience( + self, + messages: List[dict], + tools: Optional[List[dict]] = None, + temperature: Optional[float] = None, + ) -> Experience: + """Convert a list of messages into an experience in async.""" + if self.tokenizer is None: + await self._initialize_tokenizer() + if self.chat_template is None: + self.chat_template = self.tokenizer.get_chat_template() + token_ids, action_mask, prompt_length = self.action_mask_method( + tokenizer=self.tokenizer, + messages=messages, + tools=tools, + chat_template=self.chat_template, + enable_thinking=self.enable_thinking, + ) # (seq_length, ), (seq_length, ) + + # Truncate tokens if they 
exceed the length limit + assert token_ids is not None + truncate_status = None + if self.config.max_model_len is not None and self.config.max_model_len > 0: + if len(token_ids) > self.config.max_model_len - 1: + truncate_status = "response_truncated" + self.logger.warning( + f"Warning: {len(token_ids)=} exceeds the length limit {(self.config.max_model_len - 1)=}" + ) + token_ids = token_ids[: self.config.max_model_len - 1] + action_mask = action_mask[: self.config.max_model_len - 1] + + temperature = temperature if temperature is not None else self.config.temperature + logprobs = await self.logprobs( + token_ids=token_ids.tolist(), temperature=temperature + ) # (seq_length - 1,) + return Experience( + tokens=token_ids, + logprobs=logprobs[prompt_length - 1 :], + prompt_length=prompt_length, + action_mask=action_mask[prompt_length:], # Exclude the prompt tokens + messages=messages, + truncate_status=truncate_status, + ) + + async def prepare(self) -> None: + """Prepare the model before inference.""" + self.service_client = tinker.ServiceClient() + self.model = await self.service_client.create_sampling_client_async( + base_model=self.config.model_path, + ) + + async def sync_model(self, model_version: int) -> int: + self.model_version = model_version + remote_sampler_path, _ = await self.synchronizer.get_model_state_dict.remote() + self.model = await self.service_client.create_sampling_client_async( + model_path=remote_sampler_path, + ) + return model_version + + def get_model_version(self) -> int: + """Get the checkpoint version.""" + return self.model_version + + def get_api_server_url(self) -> Optional[str]: + """Get the API server URL if available.""" + # TODO: tinker will support openai api later + return None + + def get_model_path(self) -> Optional[str]: + """Get the model path""" + return self.config.model_path # type: ignore [return-value] diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 767d30a5ed..e369c7c17f 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -50,6 +50,7 @@ def __init__(self, config: Config): self.last_monitored_step = self.explore_step_num self.synchronizer = Synchronizer.get_actor(config) self.config = config + self.model_type = config.explorer.rollout_model.engine_type self.models, self.auxiliary_models = create_inference_models(config) self.experience_pipeline = self._init_experience_pipeline() self.taskset = ( @@ -149,7 +150,10 @@ async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> in async def _pull_latest_weights(self): self.logger.info("Start to pull latest model weights.") - new_version = await self.synchronizer.wait_new_model_state_dict.remote(self.model_version) + new_version = await self.synchronizer.wait_new_model_state_dict.remote( + current_version=self.model_version, + no_wait=(self.config.synchronizer.sync_style != SyncStyle.FIXED), + ) if new_version > self.model_version: if self.model_version != -1: self.logger.info(f"New model weights version: {new_version}") @@ -195,7 +199,7 @@ async def prepare(self) -> None: await asyncio.gather(*run_api_ref) self.logger.info("All models are ready.") - if not self.use_nccl_sync: + if not self.use_nccl_sync and self.model_type != "tinker": if self.config.mode == "serve": # In serving mode, each engine will setup its own process group await self.setup_model_level_weight_sync_group() @@ -444,6 +448,9 @@ async def shutdown(self) -> None: self.scheduler = None if self.experience_pipeline: await 
self.experience_pipeline.close.remote() + # reserve `experience_pipeline.output` for trainer + # TODO: refactor the lifecycle of buffer actor + self._old_experience_pipeline = self.experience_pipeline self.experience_pipeline = None if self.monitor: self.monitor.close() diff --git a/trinity/manager/synchronizer.py b/trinity/manager/synchronizer.py index 8157ad088d..c0b913812e 100644 --- a/trinity/manager/synchronizer.py +++ b/trinity/manager/synchronizer.py @@ -79,7 +79,16 @@ async def _check_modules(self) -> None: pass async def _find_latest_state_dict(self) -> None: - assert self.config.trainer.trainer_type == "verl" + if self.config.trainer.trainer_type == "verl": + await self._find_verl_latest_state_dict() + elif self.config.trainer.trainer_type == "tinker": + await self._find_tinker_latest_state_dict() + else: + self.logger.warning( + "Synchronizer does not support this trainer type. Please use `verl` or `tinker`." + ) + + async def _find_verl_latest_state_dict(self) -> None: default_local_dir = self.config.checkpoint_job_dir local_latest_state_dict_iteration = os.path.join( default_local_dir, "latest_state_dict_iteration.txt" @@ -112,6 +121,33 @@ async def _find_latest_state_dict(self) -> None: await self.set_model_state_dict(model_state_dict, latest_model_version) await asyncio.sleep(1) + async def _find_tinker_latest_state_dict(self) -> None: + default_local_dir = self.config.checkpoint_job_dir + local_latest_state_dict_iteration = os.path.join( + default_local_dir, "latest_state_dict_iteration.txt" + ) + while True: + if os.path.exists(local_latest_state_dict_iteration): + try: + with open(local_latest_state_dict_iteration, "r") as f: + latest_model_version = int(f.read().strip()) + except (IOError, ValueError) as e: + self.logger.warning(f"Failed to read or parse state dict iteration file: {e}") + continue + if latest_model_version > self.model_version: + self.logger.info( + f"Synchronizer has found a new remote tinker sampler path at step {latest_model_version}." + ) + remote_path_file = os.path.join( + default_local_dir, + f"global_step_{latest_model_version}", + "remote_sampler_path.txt", + ) + with open(remote_path_file, "r") as f: + remote_sampler_path = f.read().strip() + await self.set_model_state_dict(remote_sampler_path, latest_model_version) + await asyncio.sleep(1) + async def set_trainer_status(self, status: RunningStatus): """Update the status of the trainer.""" async with self._ready_condition: @@ -192,7 +228,7 @@ async def set_model_state_dict_with_step_num( return checkpoint_step_num async def set_model_state_dict( - self, model_state_dict: Union[dict, None, Tuple[str, str]], trainer_step: int + self, model_state_dict: Union[dict, None, str, Tuple[str, str]], trainer_step: int ): """ Set the new model state and update the version. diff --git a/trinity/trainer/tinker/__init__.py b/trinity/trainer/tinker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/trinity/trainer/tinker/utils.py b/trinity/trainer/tinker/utils.py new file mode 100644 index 0000000000..544ef89768 --- /dev/null +++ b/trinity/trainer/tinker/utils.py @@ -0,0 +1,243 @@ +from logging import Logger +from typing import Any, List, Tuple + +import torch +from tinker import types + +from trinity.common.experience import Experience, split_dpo_experience_to_single_turn + + +def to_tinker_input( + experiences: List[Experience], logger: Logger +) -> Tuple[List[types.Datum], List[types.ModelInput], List[dict]]: + assert len(experiences) > 0, "No experiences provided." 
+ if experiences[0].experience_type == "dpo": + experiences = split_dpo_experience_to_single_turn(experiences) + + batch = [] + batch_input_tokens = [] + model_inputs_list = [] + for exp in experiences: + tokens = exp.tokens + input_tokens = tokens.long() + prompt_length = exp.prompt_length + total_length = len(tokens) # type: ignore + response_length = total_length - prompt_length + loss_fn_inputs = { + "weights": torch.concat( + [ + torch.zeros(prompt_length - 1, dtype=torch.float32), + exp.action_mask.float(), + ] + ), + "target_tokens": input_tokens.tolist()[1:], + } + model_inputs = { + "total_length": total_length, + "action_mask": exp.action_mask, + } + if exp.reward is not None or exp.token_level_reward is not None: + assert exp.logprobs is not None + if exp.token_level_reward is not None: + if exp.reward is not None: + logger.warning( + "Both exp.rewards and exp.token_level_rewards are provided. " + "Using exp.token_level_rewards." + ) + token_level_reward = exp.token_level_reward + else: + token_level_reward = torch.zeros(response_length, dtype=torch.float32) + token_level_reward[-1] = exp.reward + model_inputs.update( + { + "token_level_scores": token_level_reward, + "old_logprob": exp.logprobs, + } + ) + for attr in ["advantages", "returns", "teacher_logprobs"]: + if getattr(exp, attr, None) is not None: + model_inputs[attr] = getattr(exp, attr) + # TODO: if tinker supports multi-modal input, we can add it here + for custom_field in exp.custom_fields: + model_inputs[custom_field.destination_field] = torch.tensor( + exp.info[custom_field.source_field], + dtype=custom_field.data_type, + ) + + batch.append( + types.Datum( + model_input=types.ModelInput.from_ints(tokens=input_tokens.tolist()[:-1]), + loss_fn_inputs=loss_fn_inputs, + ) + ) + batch_input_tokens.append(types.ModelInput.from_ints(input_tokens.tolist())) + model_inputs_list.append(model_inputs) + return batch, batch_input_tokens, model_inputs_list + + +def compute_data_metrics(batch: List[dict[str, torch.Tensor]]) -> dict: + """ + Computes various metrics from a batch of data for PPO training. + Modified from `verl.trainer.ppo.metric_utils.compute_data_metrics`. + + This function calculates metrics related to scores, rewards, advantages, returns, values, + and sequence lengths from a batch of data. It provides statistical information (mean, max, min) + for each metric category. + + Args: + batch: A list of per-experience tensor dicts (as produced by `to_tinker_input`) with token-level scores, rewards, advantages, etc. 
+ + Returns: + A dictionary of metrics including: + - critic/score/mean, max, min: Statistics about sequence scores + - critic/rewards/mean, max, min: Statistics about sequence rewards + - critic/advantages/mean, max, min: Statistics about advantages + - critic/returns/mean, max, min: Statistics about returns + - critic/values/mean, max, min: Statistics about critic values + - critic/vf_explained_var: Explained variance of the value function + - response_length/mean, max, min, clip_ratio: Statistics about response lengths + - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths + """ + metrics = {} + + assert len(batch) > 0, "Batch is empty" + + if "token_level_rewards" in batch[0] and "token_level_scores" in batch[0]: + sequence_score = torch.tensor([data["token_level_scores"].sum() for data in batch]) + sequence_reward = torch.tensor([data["token_level_rewards"].sum() for data in batch]) + metrics.update( + { + # score + "critic/score/mean": torch.mean(sequence_score).detach().item(), + "critic/score/max": torch.max(sequence_score).detach().item(), + "critic/score/min": torch.min(sequence_score).detach().item(), + # reward + "critic/rewards/mean": torch.mean(sequence_reward).detach().item(), + "critic/rewards/max": torch.max(sequence_reward).detach().item(), + "critic/rewards/min": torch.min(sequence_reward).detach().item(), + } + ) + + response_length = torch.tensor([len(data["action_mask"]) for data in batch]).float() + token_length = torch.tensor([data["total_length"] for data in batch]).float() + prompt_length = token_length - response_length + max_response_length = max(response_length) + max_prompt_length = max(prompt_length) + metrics.update( + { + # response length + "response_length/mean": torch.mean(response_length).detach().item(), + "response_length/max": torch.max(response_length).detach().item(), + "response_length/min": torch.min(response_length).detach().item(), + "response_length/clip_ratio": torch.mean( + torch.eq(response_length, max_response_length).float() + ) + .detach() + .item(), + # prompt length + "prompt_length/mean": torch.mean(prompt_length).detach().item(), + "prompt_length/max": torch.max(prompt_length).detach().item(), + "prompt_length/min": torch.min(prompt_length).detach().item(), + "prompt_length/clip_ratio": torch.mean( + torch.eq(prompt_length, max_prompt_length).float() + ) + .detach() + .item(), + } + ) + + if "advantages" in batch[0]: + valid_adv = torch.concat([data["advantages"] for data in batch]) + metrics.update( + { + "critic/advantages/mean": torch.mean(valid_adv).detach().item(), + "critic/advantages/max": torch.max(valid_adv).detach().item(), + "critic/advantages/min": torch.min(valid_adv).detach().item(), + } + ) + if "returns" in batch[0]: + valid_returns = torch.concat([data["returns"] for data in batch]) + metrics.update( + { + "critic/returns/mean": torch.mean(valid_returns).detach().item(), + "critic/returns/max": torch.max(valid_returns).detach().item(), + "critic/returns/min": torch.min(valid_returns).detach().item(), + } + ) + + return metrics + + +def compute_timing_metrics( + batch: List[dict[str, torch.Tensor]], timing_raw: dict[str, float] +) -> dict[str, Any]: + """ + Computes timing metrics for different processing stages in PPO training. + Modified from `verl.trainer.ppo.metric_utils.compute_timing_metrics`. 
+ + This function calculates both raw timing metrics (in seconds) and per-token timing metrics + (in milliseconds) for various processing stages like generation, reference computation, + value computation, advantage computation, and model updates. + + Args: + batch: A DataProto object containing batch data with responses and attention masks. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + + Returns: + A dictionary containing: + - timing_s/{name}: Raw timing in seconds for each stage + - timing_per_token_ms/{name}: Per-token timing in milliseconds for each stage + + Note: + Different stages use different token counts for normalization: + - "gen" uses only response tokens + - Other stages ("ref", "values", "adv", "update_critic", "update_actor") use all tokens + (prompt + response) + """ + num_overall_tokens = sum(data["total_length"] for data in batch) + num_response_tokens = sum(len(data["action_mask"]) for data in batch) + + num_tokens_of_section = { + "gen": num_response_tokens, + **{ + name: num_overall_tokens + for name in ["ref", "values", "adv", "update_critic", "update_actor"] + }, + } + + return { + **{f"timing_s/{name}": value for name, value in timing_raw.items()}, + **{ + f"timing_per_token_ms/{name}": timing_raw[name] * 1000 / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys()) + }, + } + + +def compute_throughout_metrics( + batch: List[dict[str, torch.Tensor]], timing_raw: dict[str, float] +) -> dict[str, Any]: + """ + Computes throughput metrics for PPO training. + Modified from `verl.trainer.ppo.metric_utils.compute_throughout_metrics`. + + This function calculates performance metrics related to token processing speed, + including the total number of tokens processed and time per step. + + Args: + batch: A DataProto object containing batch data with meta information about token counts. + timing_raw: A dictionary mapping stage names to their execution times in seconds. + Must contain a "step" key with the total step time. 
+ + Returns: + A dictionary containing: + - perf/total_num_tokens: Total number of tokens processed in the batch + - perf/time_per_step: Time taken for the step in seconds + """ + total_num_tokens = sum(data["total_length"] for data in batch) + time = timing_raw["step"] + return { + "perf/total_num_tokens": total_num_tokens, + "perf/time_per_step": time, + } diff --git a/trinity/trainer/tinker_trainer.py b/trinity/trainer/tinker_trainer.py new file mode 100644 index 0000000000..e355630063 --- /dev/null +++ b/trinity/trainer/tinker_trainer.py @@ -0,0 +1,324 @@ +import os +from typing import Dict, List + +import ray +import tinker +import torch +from tinker import types + +from trinity.algorithm import ALGORITHM_TYPE +from trinity.algorithm.advantage_fn import ADVANTAGE_FN +from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import DummyEntropyLossFn +from trinity.algorithm.kl_fn import KL_FN +from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN +from trinity.algorithm.utils import prefix_metrics +from trinity.common.config import Config +from trinity.common.experience import Experience +from trinity.manager.synchronizer import Synchronizer +from trinity.trainer.tinker.utils import ( + compute_data_metrics, + compute_throughout_metrics, + compute_timing_metrics, + to_tinker_input, +) +from trinity.trainer.trainer import TrainEngineWrapper +from trinity.utils.log import get_logger +from trinity.utils.timer import Timer + + +class TinkerTrainerWrapper(TrainEngineWrapper): + def __init__(self, config: Config): + self.config = config + self.logger = get_logger("tinker_trainer") + self._init_algorithm() + self.synchronizer = Synchronizer.get_actor(namespace=self.config.synchronizer.ray_namespace) + + def _init_algorithm(self): + self.algorithm = ALGORITHM_TYPE.get(self.config.algorithm.algorithm_type) + algorithm_config = self.config.algorithm + if self.algorithm.compute_advantage_in_trainer: + self.advantage_fn = ADVANTAGE_FN.get(algorithm_config.advantage_fn)( + **algorithm_config.advantage_fn_args + ) + self.kl_fn = KL_FN.get(algorithm_config.kl_penalty_fn)( + **algorithm_config.kl_penalty_fn_args + ) + # TODO + raise NotImplementedError( + "`compute_advantage_in_trainer` is not implemented yet in tinker" + ) + self.loss_agg_mode = algorithm_config.loss_agg_mode + self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)( + backend="tinker", **algorithm_config.policy_loss_fn_args + ) + self.kl_loss_fn = KL_FN.get(algorithm_config.kl_loss_fn)(**algorithm_config.kl_loss_fn_args) + self.entropy_loss_fn = ENTROPY_LOSS_FN.get(algorithm_config.entropy_loss_fn)( + **algorithm_config.entropy_loss_fn_args + ) + + # EXPERIMENTAL: apply loss scale fix + self.do_fix_actor_microbatch_loss_scale = ( + self.config.trainer.fix_actor_microbatch_loss_scale + and (self.loss_agg_mode == "token-mean") + ) + + self.adam_params = types.AdamParams( + learning_rate=algorithm_config.optimizer.lr, + beta1=algorithm_config.optimizer.betas[0], + beta2=algorithm_config.optimizer.betas[1], + # eps is currently not in config + weight_decay=algorithm_config.optimizer.weight_decay, + grad_clip_norm=self.config.trainer.grad_clip, + ) + + async def prepare(self): + self.service_client = tinker.ServiceClient() + + name_prefix_list = [self.config.project, self.config.group, self.config.name] + self.tinker_checkpoint_name_prefix = "-".join( + [prefix for prefix in name_prefix_list if prefix] + ) + self.default_local_dir = 
self.config.checkpoint_job_dir + + self.local_latest_checkpointed_iteration = os.path.join( + self.config.checkpoint_job_dir, "latest_checkpointed_iteration.txt" + ) + self.local_latest_state_dict_iteration = os.path.join( + self.config.checkpoint_job_dir, "latest_state_dict_iteration.txt" + ) + + if os.path.exists(self.local_latest_checkpointed_iteration): + with open(self.local_latest_checkpointed_iteration, "r") as f: + self._train_step_num = self.latest_remote_checkpoint_step = int(f.read().strip()) + checkpoint_file_path = os.path.join( + self.default_local_dir, + f"global_step_{self._train_step_num}", + "remote_checkpoint_path.txt", + ) + with open(checkpoint_file_path, "r") as f: + self.latest_remote_checkpoint_path = f.read().strip() + self.actor_client = ( + await self.service_client.create_training_client_from_state_with_optimizer_async( + path=self.latest_remote_checkpoint_path, + ) + ) + else: + self.actor_client = await self.service_client.create_lora_training_client_async( + base_model=self.config.model.model_path, + rank=self.config.model.tinker.rank, + seed=self.config.model.tinker.seed, + train_mlp=self.config.model.tinker.train_mlp, + train_attn=self.config.model.tinker.train_attn, + train_unembed=self.config.model.tinker.train_unembed, + ) + self.latest_remote_checkpoint_step = 0 + self.latest_remote_checkpoint_path = None + self._train_step_num = 0 + + if os.path.exists(self.local_latest_state_dict_iteration): + with open(self.local_latest_state_dict_iteration, "r") as f: + self.latest_remote_sampler_step = int(f.read().strip()) + sampler_file_path = os.path.join( + self.default_local_dir, + f"global_step_{self.latest_remote_sampler_step}", + "remote_sampler_path.txt", + ) + with open(sampler_file_path, "r") as f: + self.latest_remote_sampler_path = f.read().strip() + else: + self.latest_remote_sampler_step = 0 + self.latest_remote_sampler_path = None + + self.ref_client = await self.service_client.create_sampling_client_async( + base_model=self.config.model.model_path, + ) + + @property + def train_step_num(self) -> int: + """Get the current training step number.""" + return self._train_step_num + + def _loss_func( + self, batch: list[types.Datum], logprobs: list[torch.Tensor] + ) -> tuple[torch.Tensor, dict[str, float]]: + total_loss = 0.0 + metrics = {} + assert len(self.model_inputs_list) == len( + logprobs + ), "len(self.model_inputs_list) must equal to len(logprobs)" + for model_inputs, logprob in zip(self.model_inputs_list, logprobs): + micro_batch_metrics = {} + response_mask = model_inputs["action_mask"] + logprob = logprob[-response_mask.shape[0] :] + + pg_loss, pg_loss_metrics = self.policy_loss_fn(logprob=logprob, **model_inputs) + prefix_metrics( + src_metrics=pg_loss_metrics, prefix="actor", dst_metrics=micro_batch_metrics + ) + + if self.entropy_loss_fn != DummyEntropyLossFn: + entropy = -(logprob * logprob.exp()) + else: + entropy = None + # compute entropy loss from entropy + entropy_loss, entropy_loss_metrics = self.entropy_loss_fn( # type: ignore + entropy=entropy, + **model_inputs, + loss_agg_mode=self.loss_agg_mode, + ) + prefix_metrics( + src_metrics=entropy_loss_metrics, + prefix="actor", + dst_metrics=micro_batch_metrics, + ) + + # compute kl loss + kl_loss, kl_loss_metrics = self.kl_loss_fn.calculate_kl_loss( + logprob=logprob, + ref_logprob=model_inputs["ref_logprob"], + response_mask=response_mask, + loss_agg_mode=self.loss_agg_mode, + old_logprob=model_inputs["old_logprob"], + ) + prefix_metrics( + src_metrics=kl_loss_metrics, + 
prefix="actor", + dst_metrics=micro_batch_metrics, + ) + + # compute policy loss + policy_loss = pg_loss - entropy_loss + kl_loss + loss_scale = 1.0 + if not self.do_fix_actor_microbatch_loss_scale: + loss_scale /= len(logprobs) + loss = policy_loss * loss_scale + total_loss = total_loss + loss + micro_batch_metrics["actor/final_loss"] = loss.detach().item() + + # update metrics + for key, val in micro_batch_metrics.items(): + if key not in metrics: + metrics[key] = [] + metrics[key].append(val) + + avg_metrics = {k: sum(v) / len(v) for k, v in metrics.items()} + return total_loss, avg_metrics + + async def train_step(self, batch_exps: List[Experience]) -> Dict: + """Training one step. + + Args: + batch (List[Experience]): A batch of experiences to train. + + Returns: + Dict: Metrics of the training step. + """ + batch, batch_input_tokens, model_inputs_list = to_tinker_input(batch_exps, self.logger) + self.model_inputs_list = model_inputs_list + timing_raw = {} + metrics = {} + self._train_step_num += 1 + + with Timer(timing_raw, "step"): + if self.algorithm.use_reference: # ref_logprob may not be used + import asyncio + + ref_logprobs = await asyncio.gather( + *[ + self.ref_client.compute_logprobs_async(input_tokens) + for input_tokens in batch_input_tokens + ] + ) + for model_inputs, ref_logprob in zip(model_inputs_list, ref_logprobs): + response_length = model_inputs["action_mask"].shape[0] + model_inputs["ref_logprob"] = torch.tensor(ref_logprob[-response_length:]) + + if self.algorithm.compute_advantage_in_trainer: + # TODO: following is verl format, which is not compatible with tinker + raise NotImplementedError( + "`compute_advantage_in_trainer` is not implemented yet in tinker" + ) + else: + # skip token_level_scores for sft/dpo + for model_inputs in model_inputs_list: + if "token_level_scores" in model_inputs: + assert "token_level_rewards" not in model_inputs + model_inputs["token_level_rewards"] = model_inputs["token_level_scores"] + + # update actor + with Timer(timing_raw, "update_actor"): + fwdbwd_future = await self.actor_client.forward_backward_custom_async( + batch, self._loss_func + ) + optim_future = await self.actor_client.optim_step_async(self.adam_params) + fwdbwd_result = await fwdbwd_future + optim_result = await optim_future + metrics.update(fwdbwd_result.metrics) + if optim_result.metrics: + metrics.update(optim_result.metrics) + + # collect metrics + metrics.update(compute_data_metrics(batch=self.model_inputs_list)) + timing_metrics = compute_timing_metrics(batch=self.model_inputs_list, timing_raw=timing_raw) + metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()}) + metrics.update( + compute_throughout_metrics(batch=self.model_inputs_list, timing_raw=timing_raw) + ) + + return metrics + + def save_checkpoint(self, block_until_saved: bool = False, save_as_hf: bool = False) -> None: + """Save the checkpoint.""" + if self.train_step_num == self.latest_remote_checkpoint_step: + return + self.latest_remote_checkpoint_step = self.train_step_num + checkpoint_name = f"{self.tinker_checkpoint_name_prefix}-state-{self.train_step_num}" + self.latest_remote_checkpoint_path = ( + self.actor_client.save_state(checkpoint_name).result().path + ) + local_path = os.path.join( + self.default_local_dir, + f"global_step_{self.train_step_num}", + ) + os.makedirs(local_path, exist_ok=True) + remote_checkpoint_path = os.path.join(local_path, "remote_checkpoint_path.txt") + with open(remote_checkpoint_path, "w") as f: + 
f.write(self.latest_remote_checkpoint_path) + + with open(self.local_latest_checkpointed_iteration, "w") as f: + f.write(str(self.train_step_num)) + + def sync_weight(self) -> None: + """Sync the model weight.""" + raise NotImplementedError("Tinker trainer does not support NCCL sync") + + def upload_state_dict(self) -> None: + """Upload the state dict to Synchronizer.""" + self.save_state_dict() + ray.get( + self.synchronizer.set_model_state_dict.remote( + self.latest_remote_sampler_path, self.train_step_num + ) + ) + + def save_state_dict(self) -> None: + """Only save the model state dict for Synchronizer.""" + if self.train_step_num == self.latest_remote_sampler_step: + return + self.latest_remote_sampler_step = self.train_step_num + checkpoint_name = f"{self.tinker_checkpoint_name_prefix}-sampler-{self.train_step_num}" + self.latest_remote_sampler_path = ( + self.actor_client.save_weights_for_sampler(checkpoint_name).result().path + ) + local_path = os.path.join( + self.default_local_dir, + f"global_step_{self.train_step_num}", + ) + os.makedirs(local_path, exist_ok=True) + remote_sampler_path = os.path.join(local_path, "remote_sampler_path.txt") + with open(remote_sampler_path, "w") as f: + f.write(self.latest_remote_sampler_path) + + with open(self.local_latest_state_dict_iteration, "w") as f: + f.write(str(self.train_step_num)) diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index c42fccfb26..c4901f3a2f 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -17,7 +17,7 @@ from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy from trinity.common.config import Config from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle -from trinity.common.experience import Experiences +from trinity.common.experience import Experience from trinity.manager.state_manager import StateManager from trinity.manager.synchronizer import Synchronizer from trinity.utils.log import get_logger @@ -39,7 +39,6 @@ def __init__(self, config: Config) -> None: path=config.checkpoint_job_dir, trainer_name=config.trainer.name, config=config ) trainer_state = self.state.load_trainer() - self.last_trainer_sync_step = 0 self.monitor = MONITOR.get(config.monitor.monitor_type)( project=config.project, group=self.config.group, @@ -60,15 +59,15 @@ def __init__(self, config: Config) -> None: sample_strategy_state = trainer_state.get("sample_strategy_state", {}) self.sample_strategy.load_state_dict(sample_strategy_state) self.save_interval = config.trainer.save_interval - self.last_sync_step = None + self.last_sync_step = 0 self.last_sync_time = None self.total_steps = config.trainer.total_steps or float("inf") self.save_hf_checkpoint = config.trainer.save_hf_checkpoint async def prepare(self) -> None: """Prepare the trainer.""" - self.engine.prepare() - self.last_trainer_sync_step = self.train_step_num + await self.engine.prepare() + self.last_sync_step = self.train_step_num await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING) async def train(self) -> str: @@ -109,7 +108,7 @@ async def train(self) -> str: self.logger.info("--------------------\n> Trainer finished.\n--------------------") return self.config.trainer.name - async def train_step(self, exps: Experiences) -> Dict: + async def train_step(self, exps: List[Experience]) -> Dict: """Train one step. 
Returns: @@ -119,21 +118,21 @@ async def train_step(self, exps: Experiences) -> Dict: self.logger.info(f"Training at step {self.train_step_num + 1} started.") metrics = {} with Timer(metrics, "time/train_step"): - train_metrics = self.engine.train_step(exps) + train_metrics = await self.engine.train_step(exps) self.logger.info(f"Training at step {self.train_step_num} finished.") metrics.update(train_metrics) return metrics - async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]: + async def _sample_data(self) -> Tuple[List[Experience], Dict, List[Dict]]: """Sample a batch of experiences. Returns: - Experiences: A batch of experiences. + List[Experience]: A batch of experiences. Dict: Metrics of the sampling step. List[Dict]: A list of representative samples for logging. """ batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1) - metrics["sample/task_count"] = len(set(eid.tid for eid in batch.eids)) + metrics["sample/task_count"] = len(set(exp.eid.tid for exp in batch)) return batch, metrics, repr_samples async def need_sync(self) -> bool: @@ -145,14 +144,17 @@ async def need_sync(self) -> bool: ) else: if self.config.synchronizer.sync_style == SyncStyle.DYNAMIC_BY_TRAINER: - delta = self.train_step_num - self.last_trainer_sync_step + delta = self.train_step_num - self.last_sync_step if delta >= self.config.synchronizer.sync_interval: await self.synchronizer.set_trainer_status.remote(RunningStatus.REQUIRE_SYNC) explorer_status_counts = await self.synchronizer.get_explorer_status_counts.remote() if self.config.synchronizer.sync_method == SyncMethod.NCCL: return explorer_status_counts[RunningStatus.WAITING_SYNC] > 0 else: # memory & checkpoint - return explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0 + return ( + self.last_sync_step != self.train_step_num + and explorer_status_counts[RunningStatus.REQUIRE_SYNC] > 0 + ) def need_save(self) -> bool: """Whether to save the checkpoint.""" @@ -173,7 +175,6 @@ async def sync_weight(self) -> Dict: self.logger.error("Trainer sync_weights failed.") else: self.engine.sync_weight() - self.last_trainer_sync_step = self.train_step_num elif self.config.synchronizer.sync_method == SyncMethod.CHECKPOINT: self.engine.save_state_dict() elif self.config.synchronizer.sync_method == SyncMethod.MEMORY: @@ -229,7 +230,7 @@ class TrainEngineWrapper(ABC): """A wrapper class to wrap various training engines.""" @abstractmethod - def prepare(self) -> None: + async def prepare(self) -> None: """Do some preparation before training started.""" @property @@ -238,11 +239,11 @@ def train_step_num(self) -> int: """Get the current training step number.""" @abstractmethod - def train_step(self, batch: Experiences) -> Dict: + async def train_step(self, batch_exps: List[Experience]) -> Dict: """Training one step. Args: - batch (Experiences): A batch of experiences to train. + batch_exps (List[Experience]): A batch of experiences to train. Returns: Dict: Metrics of the training step. 
@@ -271,5 +272,9 @@ def get_trainer_wrapper(config: Config) -> TrainEngineWrapper: from trinity.trainer.verl_trainer import VerlPPOTrainerWrapper return VerlPPOTrainerWrapper(config) + elif config.trainer.trainer_type == "tinker": + from trinity.trainer.tinker_trainer import TinkerTrainerWrapper + + return TinkerTrainerWrapper(config) else: raise NotImplementedError diff --git a/trinity/trainer/verl/utils.py b/trinity/trainer/verl/utils.py index 922182e643..640ee2b748 100644 --- a/trinity/trainer/verl/utils.py +++ b/trinity/trainer/verl/utils.py @@ -2,6 +2,7 @@ import os from logging import Logger +from typing import List import numpy as np import torch @@ -10,75 +11,98 @@ from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path from trinity.common.config import Config -from trinity.common.experience import Experiences - - -def to_data_proto(experiences: Experiences, logger: Logger) -> DataProto: # noqa: C901 - """Convert Experiences to verl DataProto.""" - attention_mask = experiences.attention_masks +from trinity.common.experience import ( + Experience, + gather_action_masks, + gather_attention_masks, + gather_response_attrs, + gather_token_ids, + split_dpo_experience_to_single_turn, +) + + +def to_data_proto( + experiences: List[Experience], pad_token_id: int, logger: Logger +) -> DataProto: # noqa: C901 + """Convert List[Experience] to verl DataProto.""" + assert len(experiences) > 0, "No experiences provided." + if experiences[0].experience_type == "dpo": + experiences = split_dpo_experience_to_single_turn(experiences) + max_prompt_length = max([exp.prompt_length for exp in experiences]) + max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences]) # type: ignore + + attention_mask = gather_attention_masks( + experiences, max_prompt_length, max_response_length + ).long() cumsum = torch.cumsum(attention_mask, dim=-1) position_ids = torch.clip(cumsum - 1, 0, None).long() + tokens = gather_token_ids( + experiences, max_prompt_length, max_response_length, pad_token_id + ).long() batch_dict = { - "uid": np.array([eid.tid for eid in experiences.eids]), - "unique_ids": np.array([eid.uid for eid in experiences.eids]), + "uid": np.array([exp.eid.tid for exp in experiences]), + "unique_ids": np.array([exp.eid.uid for exp in experiences]), "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks.long() - if hasattr(experiences, "action_masks") and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), + "input_ids": tokens, + "responses": tokens[:, max_prompt_length:], + "attention_mask": attention_mask, + "response_mask": gather_action_masks(experiences, max_response_length), } - if experiences.rewards is not None or experiences.token_level_rewards is not None: - assert experiences.logprobs is not None - if experiences.token_level_rewards is not None: - if experiences.rewards is not None: + have_reward = all(exp.reward is not None for exp in experiences) + have_token_level_reward = all(exp.token_level_reward is not None for exp in experiences) + if have_reward or have_token_level_reward: + assert all(exp.logprobs is not None for exp in experiences), "No logprobs provided." + if have_token_level_reward: + if have_reward: logger.warning( "Both experiences.rewards and experiences.token_level_rewards are provided. 
" "Using experiences.token_level_rewards." ) - token_level_rewards = experiences.token_level_rewards + token_level_rewards = gather_response_attrs( + experiences, "token_level_reward", max_response_length + ) else: - token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + token_level_rewards = torch.zeros(attention_mask.shape, dtype=torch.float32) eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + token_level_rewards[torch.arange(len(experiences)), eos_mask_idx] = torch.tensor( + [exp.reward for exp in experiences] + ) + token_level_rewards = token_level_rewards[:, max_prompt_length:] batch_dict.update( { "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs, # type: ignore + "old_log_probs": gather_response_attrs( + experiences, "logprobs", max_response_length + ), } ) - if experiences.advantages is not None: - batch_dict["advantages"] = experiences.advantages - if experiences.returns is not None: - batch_dict["returns"] = experiences.returns - if experiences.teacher_logprobs is not None: - batch_dict["teacher_log_probs"] = experiences.teacher_logprobs - - if experiences.multi_modal_inputs is not None: - batch_size = len(batch_dict["unique_ids"]) + + for attr in ["advantages", "returns", "teacher_logprobs"]: + if all(getattr(exp, attr, None) is not None for exp in experiences): + batch_dict[attr] = gather_response_attrs(experiences, attr, max_response_length) + + if all(exp.multi_modal_inputs is not None for exp in experiences): + keys = experiences[0].multi_modal_inputs.keys() batch_dict["multi_modal_inputs"] = np.array( - [ - {k: v[i] for k, v in experiences.multi_modal_inputs.items()} - for i in range(batch_size) - ], + [{key: exp.multi_modal_inputs[key] for key in keys} for exp in experiences], # type: ignore dtype=object, ) - if experiences.custom_fields: - for field in experiences.custom_fields: - if hasattr(experiences, field): - batch_dict[field] = getattr(experiences, field) + custom_fields_set = set(tuple(exp.custom_fields) for exp in experiences) + if len(custom_fields_set) == 1: + custom_fields = list(custom_fields_set)[0] + for custom_field in custom_fields: + batch_dict[custom_field.destination_field] = torch.tensor( + [exp.info[custom_field.source_field] for exp in experiences], + dtype=custom_field.data_type, + ) + else: + raise ValueError("Custom fields are not consistent across experiences.") return DataProto.from_single_dict(batch_dict) -def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict: +def compute_data_metrics(batch: DataProto) -> dict: """ Computes various metrics from a batch of data for PPO training. Modified from verl.trainer.ppo.metric_utils.compute_data_metrics @@ -89,7 +113,6 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict: Args: batch: A DataProto object containing batch data with token-level scores, rewards, advantages, etc. - use_critic: Whether to include critic-specific metrics. Defaults to True. 
Returns: A dictionary of metrics including: @@ -97,8 +120,8 @@ def compute_data_metrics(batch: DataProto, use_critic: bool = False) -> dict: - critic/rewards/mean, max, min: Statistics about sequence rewards - critic/advantages/mean, max, min: Statistics about advantages - critic/returns/mean, max, min: Statistics about returns - - critic/values/mean, max, min: Statistics about critic values (if use_critic=True) - - critic/vf_explained_var: Explained variance of the value function (if use_critic=True) + - critic/values/mean, max, min: Statistics about critic values + - critic/vf_explained_var: Explained variance of the value function - response_length/mean, max, min, clip_ratio: Statistics about response lengths - prompt_length/mean, max, min, clip_ratio: Statistics about prompt lengths """ diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 16c0525327..9c52af3d66 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -7,7 +7,7 @@ import os import sys from collections import defaultdict -from typing import Dict, Optional +from typing import Dict, List, Optional import ray import torch @@ -30,11 +30,11 @@ from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.metric import reduce_metrics -from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN, SAMPLE_STRATEGY +from trinity.algorithm import ADVANTAGE_FN, ALGORITHM_TYPE, KL_FN from trinity.algorithm.utils import prefix_metrics from trinity.common.config import Config from trinity.common.constants import SaveStrategy -from trinity.common.experience import Experiences +from trinity.common.experience import Experience from trinity.trainer.trainer import TrainEngineWrapper from trinity.trainer.verl.utils import compute_data_metrics, to_data_proto from trinity.utils.log import get_logger @@ -187,6 +187,7 @@ def __init__( global_config: Config, ): self.logger = get_logger(__name__, in_ray_actor=True) + self.pad_token_id = global_config.buffer.pad_token_id train_config = global_config.trainer config = OmegaConf.structured(train_config.trainer_config) # download the checkpoint from hdfs @@ -261,11 +262,6 @@ def __init__( self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)( **self.algorithm_config.kl_penalty_fn_args ) - self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)( - buffer_config=global_config.buffer, - trainer_type=global_config.trainer.trainer_type, - **global_config.algorithm.sample_strategy_args, - ) super().__init__( config, tokenizer, @@ -379,7 +375,7 @@ def init_workers(self): def train_step_num(self) -> int: return self.global_steps - def prepare(self): + async def prepare(self): self.actor_rollout_wg.setup_weight_sync_group() self.actor_rollout_wg.set_algorithm(self.algorithm_config) @@ -411,8 +407,8 @@ def save_state_dict(self): # checkpoint sync def upload_state_dict(self): # state dict sync self.actor_rollout_wg.upload_state_dict(self.global_steps) - def train_step(self, batch: Experiences) -> Dict: # noqa C901 - batch = to_data_proto(batch, self.logger) + async def train_step(self, batch_exps: List[Experience]) -> Dict: # noqa C901 + batch = to_data_proto(batch_exps, self.pad_token_id, self.logger) # type: ignore batch = self.post_process_batch(batch) metrics = {} self.global_steps += 1 @@ -476,7 +472,7 @@ def train_step(self, batch: Experiences) -> Dict: # noqa C901 metrics.update(actor_output_metrics) # collect metrics - metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + 
metrics.update(compute_data_metrics(batch=batch)) timing_metrics = compute_timing_metrics(batch=batch, timing_raw=timing_raw) metrics.update({k.replace("timing_s/", "time/"): v for k, v in timing_metrics.items()}) n_gpus = self.resource_pool_manager.get_n_gpus()