diff --git a/games/game_werewolves/game.py b/games/game_werewolves/game.py index e7e08e6..b2355b8 100644 --- a/games/game_werewolves/game.py +++ b/games/game_werewolves/game.py @@ -19,7 +19,9 @@ from structured_model import ( get_seer_model, get_hunter_model, ) -from prompt import EnglishPrompts as Prompts +from prompt import ( + EnglishPrompts as Prompts, +) # pylint: disable=no-name-in-module # Uncomment the following line to use Chinese prompts # from prompt import ChinesePrompts as Prompts diff --git a/games/game_werewolves/utils.py b/games/game_werewolves/utils.py index d168b18..3430697 100644 --- a/games/game_werewolves/utils.py +++ b/games/game_werewolves/utils.py @@ -6,7 +6,9 @@ from typing import Any import numpy as np from agentscope.agent import AgentBase, ReActAgent from agentscope.message import Msg -from prompt import EnglishPrompts as Prompts +from prompt import ( + EnglishPrompts as Prompts, +) # pylint: disable=no-name-in-module MAX_GAME_ROUND = 30 MAX_DISCUSSION_ROUND = 3 diff --git a/tuner/README.md b/tuner/README.md index 902c6c4..e89f901 100644 --- a/tuner/README.md +++ b/tuner/README.md @@ -22,6 +22,6 @@ AgentScope Tuner requires: - `agentscope>=1.0.12` - `trinity-rft>=0.4.1` -AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/modelscope/Trinity-RFT). -Please refer to the [Trinity-RFT installation guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) +AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT). +Please refer to the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) for detailed instructions on how to set up the environment. diff --git a/tuner/README_zh.md b/tuner/README_zh.md index 8247b8c..f2c1abb 100644 --- a/tuner/README_zh.md +++ b/tuner/README_zh.md @@ -21,6 +21,6 @@ AgentScope Tuner 需要: - `agentscope>=1.0.12` - `trinity-rft>=0.4.1` -AgentScope Tuner 基于 [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) 构建。 -请参考 [Trinity-RFT 安装指南](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) +AgentScope Tuner 基于 [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) 构建。 +请参考 [Trinity-RFT 安装指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) 获取详细的安装方法。 diff --git a/tuner/email_search/README.md b/tuner/email_search/README.md new file mode 100644 index 0000000..d976b0d --- /dev/null +++ b/tuner/email_search/README.md @@ -0,0 +1,279 @@ +# Training Email Search Agent with RL using AgentScope-Tuner + +This example demonstrates how to implement reinforcement fine-tuning for the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT). + +## Task Setting + +The agent's goal is to answer user queries by searching through an email inbox. The agent needs to: +- Understand the user's question +- Search for relevant emails using keywords +- Read email contents to extract information +- Provide accurate answers with proper source citations + +**Agent Type**: The agent (`EmailSearchAgent`) extends `ReActAgent`, which follows a reasoning-acting loop to solve tasks iteratively. + +**Environment**: The environment is a SQLite database containing emails from the Enron Email dataset. 
Each task provides:
+- `question`: The user's email search query
+- `inbox_address`: The email inbox to search
+- `query_date`: The date context for the query
+- `answer`: The expected answer (ground truth), only for reward calculation
+- `message_ids`: IDs of relevant emails containing the answer, only for reward calculation
+
+**Available Tools**:
+- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
+- `read_email`: Read the full content of a specific email by message_id.
+- `generate_response`: Provide the final structured answer with sources (inherited from the ReAct agent).
+
+## Dataset Preparation
+
+The dataset contains email queries based on the [Enron Email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:
+
+```bash
+python prepare_data.py
+```
+
+If you want to use a different database path, modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Either way, remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database before moving to the next step:
+
+```bash
+export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
+```
+
+This will create a SQLite database and datasets:
+
+```
+/path/to/enron_emails_dataset/
+  ├── data
+  │   └── enron_emails.db   # Email database
+  ├── train.parquet         # Training samples
+  └── test.parquet          # Test samples
+```
+
+Each sample looks like:
+
+```json
+{
+    "id": 0,
+    "question": "Were there any variances detected for hour 6 on 3/9/01?",
+    "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
+    "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
+    "how_realistic": 0.800000011920929,
+    "inbox_address": "pete.davis@enron.com",
+    "query_date": "2001-03-16"
+}
+```
+
+## Code Implementation
+
+This section provides a high-level overview of the code implementation. For details, please refer to the source code.
+
+### Agent Workflow
+
+The workflow function `run_email_search_agent` implements the agent-environment interaction loop:
+
+```python
+async def run_email_search_agent(
+    task: Dict,
+    model: ChatModelBase,
+    auxiliary_models: Dict[str, ChatModelBase],
+) -> WorkflowOutput:
+    # Parse task and create agent
+    agent = EmailSearchAgent(
+        name="email_search_agent",
+        sys_prompt=system_prompt,
+        model=model,
+        max_iters=max_turns,
+    )
+
+    # Run the agent with structured output
+    response = await agent.reply(
+        msg=Msg("user", question, role="user"),
+        structured_model=AnswerModel,
+    )
+
+    return WorkflowOutput(response=response)
+```
+
+The agent follows a ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and source message IDs.
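
+The structured output is validated against the `AnswerModel` schema defined in [`_utils.py`](./_utils.py). A trimmed sketch of the model (field descriptions shortened here):
+
+```python
+from typing import List
+
+from pydantic import BaseModel, Field
+
+
+class AnswerModel(BaseModel):
+    """Schema of the agent's final structured response."""
+
+    answer: str = Field(
+        description="The answer, or 'I don't know' if it cannot be found.",
+    )
+    sources: List[str] = Field(
+        description="Message ids of the emails that support the answer.",
+    )
+```
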
+
+### Judge Function
+
+The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:
+
+```python
+async def email_search_judge(
+    task: Dict,
+    response: Msg,
+    auxiliary_models: Dict[str, ChatModelBase],
+) -> JudgeOutput:
+    # Extract answer and sources from response
+    answer = answer_and_sources.get("answer")
+    sources = answer_and_sources.get("sources", [])
+
+    # Judge correctness using LLM-as-a-Judge
+    judge_model = auxiliary_models.get('judge') or list(auxiliary_models.values())[0]
+    judge_response = await judge_correctness(
+        answer, query, judge_model
+    )
+
+    # Calculate reward based on:
+    # - Answer correctness (accuracy: -1.0 to 1.0)
+    # - Source correctness (format: partial rewards)
+    # - Efficiency (bonus for fewer turns, correct sources)
+    result = {"accuracy": ..., "format": ...}  # calculated based on judge_response
+
+    return JudgeOutput(
+        reward=sum(result.values()),
+        metrics=metrics,
+    )
+```
+
+The reward function considers:
+- **Answer correctness**: Evaluated by LLM-as-a-Judge, comparing the agent's answer with the ground truth
+- **Source correctness**: Whether the agent cited the correct email message IDs
+- **Efficiency**: Bonus rewards for finding/reading the correct email and for taking fewer turns
+
+See [`main.py`](./main.py) and [`_email_search_agent.py`](./_email_search_agent.py) for implementation details.
+
+## How to Run
+
+### Prerequisites
+
+- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
+  * Note: For the 30B judge model, you need a GPU with at least 40 GB of memory; you can also run the model across multiple GPUs by setting `tensor_parallel_size > 1` to reduce per-GPU memory usage (by default, `tensor_parallel_size=2`).
+- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
+- Download the model checkpoints (example):
+
+  ```bash
+  huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
+  huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # judge model
+  ```
+
+### Configuration
+
+Adjust the configuration file ([`config.yaml`](./config.yaml)) based on your hardware. Key configuration sections include:
+
+- **TunerModelConfig**: Set `model_path` to your model checkpoint path
+- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
+- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
+- **Auxiliary Models**: Configure judge model settings for LLM-as-a-Judge
+
+For full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
+
+### Start-Up Commands
+
+1. Prepare the dataset:
+
+   ```bash
+   python prepare_data.py
+   export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
+   ```
+
+2. Set up a [Ray](https://github.com/ray-project/ray) cluster:
+
+   ```bash
+   ray start --head
+   ```
+
+3. Run the training script:
+
+   ```bash
+   python main.py
+   ```
+
+## Experimental Results
+
+### Quantitative Results
+
+The training results show improvements in agent performance over training iterations. 
Key metrics include:
+
+- **Train reward**: The average reward on training samples, which increases as the agent learns better strategies
+- **Rollout accuracy**: The average accuracy on rollout samples, which improves alongside the reward
+
+![Training Rewards](./critic_reward_mean.png)
+
+![Rollout Accuracy](./rollout_accuracy_mean.png)
+
+
+### Concrete Example
+
+An example of the agent's behavior is shown below:
+
+**Query:** "What do the color codes mean in the curve assessment?"
+
+We show the last several turns of the agent's trajectory:
+
+The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:
+
+**Tool call:**
+```json
+{
+  "type": "tool_use",
+  "name": "search_emails",
+  "input": {
+    "inbox_address": "steven.kean@enron.com",
+    "query_date": "2001-04-03",
+    "keywords": ["curve", "assessment"]
+  }
+}
+```
+
+**Tool result:**
+```json
+{
+  "type": "tool_result",
+  "name": "search_emails",
+  "output": [
+    {
+      "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
+      "snippet": "... curves and Senior Management as necessary.\\n\\nCurve Assessment\\n Green: \\n - curve is reasonable\\n - small P&L ..."
+    },
+    {
+      "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
+      "snippet": "... curves and Senior Management as necessary.\\n\\nCurve Assessment\\n Green: \\n - curve is reasonable\\n - small P&L ..."
+    }
+  ]
+}
+```
+
+After finding relevant emails, the agent uses the `read_email` tool to read the full content of both emails:
+
+**Tool call:**
+```json
+{
+  "type": "tool_use",
+  "name": "read_email",
+  "input": {
+    "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
+  }
+}
+```
+
+**Tool call:**
+```json
+{
+  "type": "tool_use",
+  "name": "read_email",
+  "input": {
+    "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
+  }
+}
+```
+
+After reading the emails, the agent generates the final structured response:
+
+**Tool call:**
+```json
+{
+  "type": "tool_use",
+  "name": "generate_response",
+  "input": {
+    "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
+    "sources": [
+      "<2654330.1075846153519.JavaMail.evans@thyme>",
+      "<12499967.1075846153495.JavaMail.evans@thyme>"
+    ]
+  }
+}
+```
+
+The judge evaluates the answer as correct based on the ground truth.
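
+For intuition, the final reward for a successful trace like this one is assembled from the rubric in [`main.py`](./main.py) roughly as follows (the source counts come from the trace above; the turn count is assumed for illustration):
+
+```python
+# Sketch of the reward arithmetic for a correct answer, assuming the
+# ground-truth email is among the two cited sources and the episode
+# used 6 of the allowed 10 turns (the turn count here is hypothetical).
+accuracy = 1.0                     # answer judged correct by the LLM judge
+format_reward = (
+    1.0                            # base reward for a correct answer
+    + 0.3                          # cited sources include the right email
+    + 0.1 / 2                      # source bonus, split over 2 cited sources
+    + 0.1 * (1 - 6 / 10)           # bonus for finishing in fewer turns
+)                                  # = 1.39
+reward = accuracy + format_reward  # = 2.39
+```
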
diff --git a/tuner/email_search/README_zh.md b/tuner/email_search/README_zh.md
new file mode 100644
index 0000000..f5e409c
--- /dev/null
+++ b/tuner/email_search/README_zh.md
@@ -0,0 +1,279 @@
+# 使用 AgentScope-Tuner 训练邮件搜索智能体
+
+本示例展示如何使用 AgentScope-Tuner 对邮件搜索任务(灵感来自 [ART](https://openpipe.ai/blog/art-e-mail-agent))进行强化微调,其 RFT 功能由 [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) 提供支持。
+
+## 任务设定
+
+智能体的目标是通过搜索邮件收件箱来回答用户查询。智能体需要:
+- 理解用户的问题
+- 使用关键词搜索相关邮件
+- 阅读邮件内容以提取信息
+- 提供准确的答案并附上适当的来源引用
+
+**智能体类型**:智能体(`EmailSearchAgent`)继承自 `ReActAgent`,遵循推理-行动循环来迭代解决任务。
+
+**环境**:环境是一个包含来自 Enron 邮件数据集的邮件的 SQLite 数据库。每个任务提供:
+- `question`:用户的邮件搜索查询
+- `inbox_address`:要搜索的邮件收件箱
+- `query_date`:查询的日期上下文
+- `answer`:预期答案(真实值),仅用于奖励计算
+- `message_ids`:包含答案的相关邮件 ID,仅用于奖励计算
+
+**可用工具**:
+- `search_emails`:通过关键词、收件箱地址和日期范围查找邮件。返回邮件摘要列表(message_id 和片段)。
+- `read_email`:通过 message_id 读取特定邮件的完整内容。
+- `generate_response`:提供带有来源的最终结构化答案(继承自 ReAct 智能体)。
+
+## 数据集准备
+
+数据集包含基于 [Enron 邮件数据集](https://huggingface.co/datasets/corbt/enron-emails) 的邮件查询。运行数据准备脚本以生成邮件数据库和数据集:
+
+```bash
+python prepare_data.py
+```
+
+如果你想使用新的数据库路径,可以修改 [`prepare_data.py`](./prepare_data.py) 中的 `DEFAULT_DB_PATH`。同时,请记住在进入下一步之前设置环境变量 `DEFAULT_EMAIL_DB_PATH` 指向数据库路径:
+
+```bash
+export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
+```
+
+这将创建一个 SQLite 数据库和数据集:
+
+```
+/path/to/enron_emails_dataset/
+  ├── data
+  │   └── enron_emails.db   # 邮件数据库
+  ├── train.parquet         # 训练样本
+  └── test.parquet          # 测试样本
+```
+
+每个样本如下所示:
+
+```json
+{
+    "id": 0,
+    "question": "Were there any variances detected for hour 6 on 3/9/01?",
+    "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
+    "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
+    "how_realistic": 0.800000011920929,
+    "inbox_address": "pete.davis@enron.com",
+    "query_date": "2001-03-16"
+}
+```
+
+## 代码实现
+
+本节提供代码实现的高级概览。详细实现请参考源代码。
+
+### 智能体工作流
+
+工作流函数 `run_email_search_agent` 实现智能体-环境交互循环:
+
+```python
+async def run_email_search_agent(
+    task: Dict,
+    model: ChatModelBase,
+    auxiliary_models: Dict[str, ChatModelBase],
+) -> WorkflowOutput:
+    # 解析任务并创建智能体
+    agent = EmailSearchAgent(
+        name="email_search_agent",
+        sys_prompt=system_prompt,
+        model=model,
+        max_iters=max_turns,
+    )
+
+    # 使用结构化输出运行智能体
+    response = await agent.reply(
+        msg=Msg("user", question, role="user"),
+        structured_model=AnswerModel,
+    )
+
+    return WorkflowOutput(response=response)
+```
+
+智能体遵循 ReAct 模式:它对任务进行推理,调用工具搜索和阅读邮件,最后生成包含答案和来源消息 ID 的结构化响应。
+
+### 评判函数
+
+评判函数 `email_search_judge` 使用 LLM-as-a-Judge 实现奖励计算:
+
+```python
+async def email_search_judge(
+    task: Dict,
+    response: Msg,
+    auxiliary_models: Dict[str, ChatModelBase],
+) -> JudgeOutput:
+    # 从响应中提取答案和来源
+    answer = answer_and_sources.get("answer")
+    sources = answer_and_sources.get("sources", [])
+
+    # 使用 LLM-as-a-Judge 评判正确性
+    judge_model = auxiliary_models.get('judge') or list(auxiliary_models.values())[0]
+    judge_response = await judge_correctness(
+        answer, query, judge_model
+    )
+
+    # 基于以下因素计算奖励:
+    # - 答案正确性(准确度:-1.0 到 1.0)
+    # - 来源正确性(格式:部分奖励)
+    # - 效率(对更少轮次、正确来源的奖励)
+    result = {"accuracy": ..., "format": ...}  # 基于 judge_response 计算
+
+    return JudgeOutput(
+        reward=sum(result.values()),
+        metrics=metrics,
+    )
+```
+
+奖励函数考虑以下因素:
+- **答案正确性**:通过 LLM-as-a-Judge 比较智能体的答案与真实值进行评估
+- **来源正确性**:智能体是否引用了正确的邮件消息 ID
+- **效率**:对找到/阅读正确邮件和更少轮次的奖励
+
+详细实现请参考 [`main.py`](./main.py) 和 [`_email_search_agent.py`](./_email_search_agent.py)。
+
+## 运行方法
+
+### 前置要求
+
+- 至少 
4 张 NVIDIA GPU,CUDA 版本 ≥ 12.8
+  * 注意:对于 30B 评判模型,需要使用至少 40 GB 显存的 GPU;你也可以通过设置 `tensor_parallel_size > 1` 在多张 GPU 上运行模型以降低单卡显存占用(默认情况下,`tensor_parallel_size=2`)。
+- 按照 Trinity-RFT [安装指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) 从源码安装最新版本
+- 下载模型检查点(示例):
+
+  ```bash
+  huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
+  huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # 评判模型
+  ```
+
+### 配置
+
+根据你的硬件调整配置文件([`config.yaml`](./config.yaml))。关键配置部分包括:
+
+- **TunerModelConfig**:将 `model_path` 设置为你的模型检查点路径
+- **AlgorithmConfig**:配置 RL 算法参数(例如,`multi_step_grpo`、学习率、策略损失函数)
+- **DatasetConfig**:数据集路径在 `main.py` 中创建 `DatasetConfig` 对象时指定
+- **辅助模型**:为 LLM-as-a-Judge 配置评判模型设置
+
+完整配置详情请参考 [Trinity-RFT 配置指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html)。
+
+### 启动命令
+
+1. 准备数据集:
+
+   ```bash
+   python prepare_data.py
+   export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
+   ```
+
+2. 启动 [Ray](https://github.com/ray-project/ray) 集群:
+
+   ```bash
+   ray start --head
+   ```
+
+3. 运行训练脚本:
+
+   ```bash
+   python main.py
+   ```
+
+## 实验结果
+
+### 定量结果
+
+训练结果显示智能体性能随训练迭代次数提升。关键指标包括:
+
+- **训练奖励**:训练样本上的平均奖励随着智能体学习更好的策略而增加
+- **Rollout 准确度**:Rollout 样本上的平均准确度随着智能体学习更好的策略而增加
+
+![Training Rewards](./critic_reward_mean.png)
+
+![Rollout Accuracy](./rollout_accuracy_mean.png)
+
+
+### 具体示例
+
+智能体行为示例如下:
+
+**查询:** "What do the color codes mean in the curve assessment?"
+
+我们展示智能体响应的最后几轮:
+
+智能体执行多次搜索尝试以找到相关邮件。经过一些不成功的搜索后,智能体尝试:
+
+**工具调用:**
+```json
+{
+  "type": "tool_use",
+  "name": "search_emails",
+  "input": {
+    "inbox_address": "steven.kean@enron.com",
+    "query_date": "2001-04-03",
+    "keywords": ["curve", "assessment"]
+  }
+}
+```
+
+**工具结果:**
+```json
+{
+  "type": "tool_result",
+  "name": "search_emails",
+  "output": [
+    {
+      "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
+      "snippet": "... curves and Senior Management as necessary.\\n\\nCurve Assessment\\n Green: \\n - curve is reasonable\\n - small P&L ..."
+    },
+    {
+      "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
+      "snippet": "... curves and Senior Management as necessary.\\n\\nCurve Assessment\\n Green: \\n - curve is reasonable\\n - small P&L ..."
+ } + ] +} +``` + +找到相关邮件后,智能体使用 `read_email` 工具读取两封邮件的完整内容: + +**工具调用:** +```json +{ + "type": "tool_use", + "name": "read_email", + "input": { + "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>" + } +} +``` + +**工具调用:** +```json +{ + "type": "tool_use", + "name": "read_email", + "input": { + "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>" + } +} +``` + +阅读邮件后,智能体生成最终的结构化响应: + +**工具调用:** +```json +{ + "type": "tool_use", + "name": "generate_response", + "input": { + "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.", + "sources": [ + "<2654330.1075846153519.JavaMail.evans@thyme>", + "<12499967.1075846153495.JavaMail.evans@thyme>" + ] + } +} +``` + +评判器评估上面的答案为正确。 diff --git a/tuner/email_search/_email_search_agent.py b/tuner/email_search/_email_search_agent.py new file mode 100644 index 0000000..d548666 --- /dev/null +++ b/tuner/email_search/_email_search_agent.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +"""Adapted from Trinity-RFT""" +import json +import traceback +from dataclasses import asdict +from datetime import datetime, timedelta +from typing import Any +from _utils import ( # pylint: disable=E0611 + read_email_tool, + search_emails_tool, +) +from agentscope import logger +from agentscope.agent import ReActAgent +from agentscope.message import TextBlock +from agentscope.tool import Toolkit, ToolResponse + + +def pre_reasoning_hook(_self: Any, _kwargs: Any) -> dict[str, Any] | None: + """Pre-reasoning hook to remove tool_choice from kwargs.""" + _kwargs.pop("tool_choice", None) + return _kwargs + + +class EmailSearchAgent(ReActAgent): + """ + A customized ReAct agent with pre-defined tools for + email search and reading. + Ref: https://github.com/OpenPipe/ART + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.message_id_list = ( + [] + ) # List to store message IDs found during search + self.ever_read_message_ids = ( + [] + ) # List to store message IDs that have been read + toolkit = Toolkit() + toolkit.register_tool_function(self.search_emails) + toolkit.register_tool_function(self.read_email) + super().__init__(*args, toolkit=toolkit, **kwargs) + + self.register_instance_hook( + "pre_reasoning", + "tool_choice_hook", + pre_reasoning_hook, + ) + + async def reset(self) -> None: + """Reset agent state for a new rollout/episode.""" + self.message_id_list.clear() + self.ever_read_message_ids.clear() + await self.memory.clear() + + def search_emails( + self, + inbox_address: str, + query_date: str, + keywords: list[str], + **_kwargs: Any, + ) -> ToolResponse: + """ + Search the user's email inbox for emails that match the given keywords. + + Args: + inbox_address: The user's email address. + query_date: The date of the query in 'YYYY-MM-DD' format. + keywords: Keywords to search for in the user's email inbox. + + Returns: + ToolResponse: + A ToolResponse object containing a list of TextBlock objects + in the `content` field. On success, the text field of the + TextBlock contains a JSON string representing a list of email + summaries (e.g., message_id, snippet) matching the search + criteria. Each email summary is converted to a dictionary via + `asdict`. On failure, the text indicates an error message. 
+ """ + + try: + next_day = ( + datetime.strptime(query_date, "%Y-%m-%d") + timedelta(days=1) + ).strftime( + "%Y-%m-%d", + ) + res = search_emails_tool( + inbox=inbox_address, + sent_before=next_day, + keywords=keywords, + ) + + self.message_id_list.extend([r.message_id for r in res]) + + return ToolResponse( + content=[ + TextBlock( + type="text", + text=json.dumps([asdict(r) for r in res]), + ), + ], + ) + except Exception as e: + logger.info( + "Error in search_emails: %s, traceback: %s", + e, + traceback.format_exc(), + ) + return ToolResponse( + content=[ + TextBlock( + type="text", + text=( + f"Error: Failed to search emails.\n" + f"Error message: {e}" + ), + ), + ], + ) + + def read_email(self, message_id: str, **_kwargs: Any) -> ToolResponse: + """ + Read the content of an email from the user's email inbox. + Returns the email content. + + Args: + message_id (str): The unique identifier of the email to read. + + Returns: + ToolResponse: + A ToolResponse object containing the email content or an + error message if the email is not found. + """ + + try: + email_content = read_email_tool(message_id) + + self.ever_read_message_ids.append(message_id) + + if email_content is None: + return ToolResponse( + content=[ + TextBlock( + type="text", + text=( + f"Error: Email (message_id = {message_id}) " + f"not found." + ), + ), + ], + ) + return ToolResponse( + content=[ + TextBlock( + type="text", + text=json.dumps(email_content.model_dump()), + ), + ], + ) + except Exception as e: + logger.info( + "Error in read_email: %s, traceback: %s", + e, + traceback.format_exc(), + ) + return ToolResponse( + content=[ + TextBlock( + type="text", + text=( + f"Error: Failed to read email.\n" + f"Error message: {e}" + ), + ), + ], + ) diff --git a/tuner/email_search/_utils.py b/tuner/email_search/_utils.py new file mode 100644 index 0000000..67cde62 --- /dev/null +++ b/tuner/email_search/_utils.py @@ -0,0 +1,328 @@ +# -*- coding: utf-8 -*- +""" +This file defines Dataclass and tool implementations. +Modified from https://github.com/OpenPipe/ART/blob/art-e/ +""" +import datetime +import os +import sqlite3 +from dataclasses import dataclass +from typing import Any, List, Optional +from pydantic import BaseModel, Field, field_validator +from agentscope import logger + +DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH") +conn = None + + +def get_conn() -> sqlite3.Connection: + """Get or create a database connection.""" + global conn + if conn is None: + conn = sqlite3.connect( + f"file:{DEFAULT_DB_PATH}?mode=ro", + uri=True, + check_same_thread=False, + ) + return conn + + +class QueryModel(BaseModel): + """Model for email search query.""" + + id: int + question: str + answer: str + message_ids: List[str] # message_ids (strings) of referenced emails + how_realistic: float + inbox_address: str + query_date: str + + @field_validator("query_date", mode="before") + @classmethod + def format_date(cls, v: Any) -> str: + """Format date to string if it's a datetime object.""" + if isinstance(v, datetime.datetime): + return v.strftime("%Y-%m-%d") + return v + + +class AnswerModel(BaseModel): + """Model for agent's answer with sources.""" + + answer: str = Field( + description=( + "It should be called with the answer and the sources. " + "If you cannot find the answer, you should return " + "'I don't know' with an empty list of sources." + ), + ) + sources: List[str] = Field( + description=( + "a list of message ids that are relevant to the query. " + "Usually there will be only one. 
If you cannot find the " + "answer, you should return an empty list." + ), + ) + + +class Email(BaseModel): + """Model representing an email.""" + + message_id: str + date: str # ISO 8601 string 'YYYY-MM-DD HH:MM:SS' + subject: Optional[str] = None + from_address: Optional[str] = None + to_addresses: List[str] = Field(default_factory=list) + cc_addresses: List[str] = Field(default_factory=list) + bcc_addresses: List[str] = Field(default_factory=list) + body: Optional[str] = None + file_name: Optional[str] = None + + +@dataclass +class SearchResult: + """Result from email search.""" + + message_id: str + snippet: str + + +class FinalRubric(BaseModel): + """Rubric for evaluating agent performance.""" + + answer_correct: bool = False + sources_correct: bool = False + num_turns: int = 0 + attempted_answer: bool = False + ever_found_right_email: bool = False + ever_read_right_email: bool = False + cant_parse_tool_call: bool = False + bad_tool_call_name: bool = False + bad_tool_call_args: bool = False + ran_out_of_turns: bool = False + returned_i_dont_know: bool = False + num_sources: int = 0 + ever_tried_to_read_invalid_email: bool = False + prompt_tokens: int = 0 + completion_tokens: int = 0 + + +# Define tools for agent + + +def search_emails_tool( + inbox: str, + keywords: List[str], + from_addr: Optional[str] = None, + to_addr: Optional[str] = None, + sent_after: Optional[str] = None, + sent_before: Optional[str] = None, + max_results: int = 10, +) -> List[SearchResult]: + """ + Searches the email database based on keywords, inbox, + sender, recipient, and date range. + + Args: + inbox: The email address of the user performing the search. + Results include emails sent from or to (inc. cc/bcc) + this address. + keywords: A list of keywords that must all appear in the + subject or body. + from_addr: Optional email address to filter emails sent *from*. + to_addr: Optional email address to filter emails sent *to* + (inc. cc/bcc). + sent_after: Optional date string 'YYYY-MM-DD'. Filters for + emails sent on or after this date. + sent_before: Optional date string 'YYYY-MM-DD'. Filters for + emails sent before this date. + max_results: The maximum number of results to return. + Cannot exceed 10. + + Returns: + A list of SearchResult objects, each containing 'message_id' + and 'snippet'. Returns an empty list if no results are found + or an error occurs. + """ + # Initialize sql and params + sql: Optional[str] = None + params: List[str | int] = [] + + cursor = get_conn().cursor() + + # --- Build Query --- + where_clauses: List[str] = [] + + # 1. Keywords (FTS) + if not keywords: + raise ValueError("No keywords provided for search.") + + if max_results > 10: + raise ValueError("max_results must be less than or equal to 10.") + + # FTS5 default is AND, so just join keywords. Escape quotes for safety. + fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords) + where_clauses.append("emails_fts MATCH ?") + params.append(fts_query) + + # 2. Inbox filter (must be from OR to/cc/bcc the inbox user) + # Use the composite index idx_recipients_address_email here + where_clauses.append( + """ + (e.from_address = ? OR EXISTS ( + SELECT 1 FROM recipients r_inbox + WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id + )) + """, + ) + params.extend([inbox, inbox]) + + # 3. Optional From filter + if from_addr: + where_clauses.append("e.from_address = ?") + params.append(from_addr) + + # 4. 
Optional To filter (includes to, cc, bcc)
+    # Use composite index idx_recipients_address_email
+    if to_addr:
+        where_clauses.append(
+            """
+            EXISTS (
+                SELECT 1 FROM recipients r_to
+                WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
+            )
+            """,
+        )
+        params.append(to_addr)
+
+    # 5. Optional Sent After filter
+    if sent_after:
+        # Assumes date format 'YYYY-MM-DD'
+        # Compare against the start of the day
+        where_clauses.append("e.date >= ?")
+        params.append(f"{sent_after} 00:00:00")
+
+    # 6. Optional Sent Before filter
+    if sent_before:
+        # Assumes date format 'YYYY-MM-DD'
+        # Compare against the start of the day (exclusive)
+        where_clauses.append("e.date < ?")
+        params.append(f"{sent_before} 00:00:00")
+
+    # --- Construct Final Query ---
+    # snippet(<fts table>, <column>, <text before match>, <text after match>,
+    #         <ellipsis>, <max tokens>)
+    # -1 means highlight across all columns (subject, body)
+    sql = f"""
+    SELECT
+        e.message_id,
+        snippet(emails_fts, -1, '', '', ' ... ', 15) as snippet
+    FROM
+        emails e JOIN emails_fts fts ON e.id = fts.rowid
+    WHERE
+        {" AND ".join(where_clauses)}
+    ORDER BY
+        e.date DESC -- Order by date for relevance
+    LIMIT ?;
+    """
+    params.append(max_results)
+
+    # --- Execute and Fetch ---
+    logger.debug("Executing SQL: %s", sql)
+    logger.debug("With params: %s", params)
+    cursor.execute(sql, params)
+    results = cursor.fetchall()
+
+    # Format results
+    formatted_results = [
+        SearchResult(message_id=row[0], snippet=row[1]) for row in results
+    ]
+    logger.info("Search found %d results.", len(formatted_results))
+    return formatted_results
+
+
+def read_email_tool(message_id: str) -> Optional[Email]:
+    """
+    Retrieves a single email by its message_id from the database.
+
+    Args:
+        message_id: The unique identifier of the email to retrieve.
+
+    Returns:
+        An Email object containing the details of the found email,
+        or None if the email is not found or an error occurs.
+    
+ """ + cursor = get_conn().cursor() + + # --- Query for Email Core Details --- + email_sql = """ + SELECT id, message_id, date, subject, from_address, body, file_name + FROM emails + WHERE message_id = ?; + """ + cursor.execute(email_sql, (message_id,)) + email_row = cursor.fetchone() + + if not email_row: + logger.warning("Email with message_id '%s' not found.", message_id) + return None + + email_pk_id, msg_id, date, subject, from_addr, body, file_name = email_row + + # DEBUG + logger.info("[read_email_tool] input_message_id=%s", message_id) + logger.info( + "[read_email_tool] db: id=%s, message_id=%s", + email_pk_id, + msg_id, + ) + + # search for recipients by emails.id (rather than message_id) + recipients_sql = """ + SELECT recipient_address, recipient_type + FROM recipients + WHERE email_id = ?; + """ + cursor.execute(recipients_sql, (email_pk_id,)) + recipient_rows = cursor.fetchall() + + to_addresses: List[str] = [] + cc_addresses: List[str] = [] + bcc_addresses: List[str] = [] + + for addr, rtype in recipient_rows: + type_lower = rtype.lower() + if type_lower == "to": + to_addresses.append(addr) + elif type_lower == "cc": + cc_addresses.append(addr) + elif type_lower == "bcc": + bcc_addresses.append(addr) + + # --- Construct Email Object --- + email_obj = Email( + message_id=msg_id, # Convert to string to match Pydantic model + date=date, + subject=subject, + from_address=from_addr, + to_addresses=to_addresses, + cc_addresses=cc_addresses, + bcc_addresses=bcc_addresses, + body=body, + file_name=file_name, + ) + + return email_obj + + +__all__ = [ + "QueryModel", + "AnswerModel", + "FinalRubric", + "Email", + "SearchResult", + "search_emails_tool", + "read_email_tool", + "get_conn", +] diff --git a/tuner/email_search/config.yaml b/tuner/email_search/config.yaml new file mode 100644 index 0000000..76fbf11 --- /dev/null +++ b/tuner/email_search/config.yaml @@ -0,0 +1,72 @@ +project: "AgentScope" # Project name +name: "Email_search" # Experiment name +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints +algorithm: + algorithm_type: multi_step_grpo # GRPO series for multi-step scenario + repeat_times: 8 # Number of rollouts per prompt for advantage estimation + optimizer: + lr: 1e-6 # Learning rate + policy_loss_fn: "rec" # Policy loss function + policy_loss_fn_args: # Policy loss function arguments + epsilon_low: 0.2 + epsilon_high: 0.2 + clip_mode: "one-side" + weight: "none" + temp: 1.0 + regularizer: "none" + regularizer_coef: 0.0 + kl_loss_fn: 'k2' # KL divergence loss function + kl_loss_fn_args: + kl_coef: 0.0 # KL divergence coefficient + advantage_fn_args: + std_cal_level: 'batch' # Advantage normalization level +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507} # Base model path + max_response_tokens: 4096 # Max tokens per response + max_model_len: 20480 # Max context length +buffer: + total_epochs: 10 # Total training epochs + batch_size: 64 # Batch size per explore step + train_batch_size: 2560 # 64*8*5, total experiences per training step + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + replay_buffer: + enable: true # Enable experience replay + priority_fn: 'decay_limit_randomization' + priority_fn_args: + decay: 2.0 + use_count_limit: 3 + sigma: 2.0 +explorer: + eval_interval: 10 + max_repeat_times_per_runner: 1 # Max repeat times per runner + max_timeout: 3600 # Max timeout for each rollout (seconds) + rollout_model: + enable_history: true # 
Enable conversation history
+    enable_openai_api: true # Enable OpenAI-compatible API
+    enable_auto_tool_choice: true # Enable automatic tool selection
+    tool_call_parser: hermes # Parser for tool calls
+    engine_num: 4 # Number of vLLM engines for rollout model
+    tensor_parallel_size: 1 # TP size per engine for rollout model
+    enable_prefix_caching: false # Disable prefix caching
+  auxiliary_models:
+    - name: judge
+      model_path: Qwen/Qwen3-30B-A3B-Instruct-2507 # Judge model path
+      engine_num: 1 # Number of vLLM engines for judge model
+      tensor_parallel_size: 2 # TP size per engine for judge model
+      enable_thinking: false # Disable thinking/reasoning mode
+      max_prompt_tokens: 2048 # Max tokens for prompt
+      max_response_tokens: 128 # Max tokens for response
+      max_model_len: 2500 # Max model context length
+synchronizer:
+  sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
+  sync_interval: 5 # Sync every N steps
+  sync_timeout: 3600 # Timeout for synchronization (seconds)
+trainer:
+  save_interval: 100 # Save checkpoint every N steps
+  grad_clip: 1.0 # Gradient clipping value
+  use_dynamic_bsz: true # Use dynamic batch size
+  max_token_len_per_gpu: 16384 # Max token length per GPU
+  ulysses_sequence_parallel_size: 1 # Sequence parallel size for Ulysses
diff --git a/tuner/email_search/critic_reward_mean.png b/tuner/email_search/critic_reward_mean.png
new file mode 100644
index 0000000..334fe39
Binary files /dev/null and b/tuner/email_search/critic_reward_mean.png differ
diff --git a/tuner/email_search/main.py b/tuner/email_search/main.py
new file mode 100644
index 0000000..aab4253
--- /dev/null
+++ b/tuner/email_search/main.py
@@ -0,0 +1,379 @@
+# -*- coding: utf-8 -*-
+"""Example of training an Email Search agent with Trinity-RFT."""
+import os
+from typing import Dict
+from _email_search_agent import EmailSearchAgent
+from _utils import (  # pylint: disable=E0611
+    AnswerModel,
+    FinalRubric,
+    QueryModel,
+)
+from agentscope import logger
+from agentscope.formatter import OpenAIChatFormatter
+from agentscope.message import Msg
+from agentscope.tuner import (
+    TunerModelConfig,
+    DatasetConfig,
+    JudgeOutput,
+    WorkflowOutput,
+    AlgorithmConfig,
+    tune,
+)
+from agentscope.model import ChatModelBase
+
+
+SYSTEM_PROMPT = """You are an email search agent. You are given a user query
+and a list of tools you can use to search the user's email. Use the tools to
+search the user's emails and find the answer to the user's query. You may take
+up to {max_turns} turns to find the answer, so if your first search doesn't
+find the answer, you can try again with different keywords.
+
+Always describe what you see and plan your next steps clearly. When taking
+actions, explain what you're doing and why. When the answer to the task is
+found, call `generate_response` to finish the process. Only call
+`generate_response` once the answer is found, and do not describe any further
+steps in it. Complete all steps and then call `generate_response`.
+
+User's email address is {inbox_address}
+
+Today's date is {query_date}
+
+"""
+
+
+async def run_email_search_agent(
+    task: Dict,
+    model: ChatModelBase,
+    auxiliary_models: Dict[str, ChatModelBase],
+) -> WorkflowOutput:  # noqa: PLR0915
+    """A workflow function using the Email Search agent to solve tasks.
+
+    Args:
+        task (Dict): The task to be solved.
+            Should contain fields from QueryModel.
+        model (ChatModelBase): The language model to use.
+        auxiliary_models (Dict[str, ChatModelBase]): Additional chat
+            models; used here for LLM-as-a-Judge.
+
+    Returns:
+        WorkflowOutput: The output containing the agent's response.
+    
+ """ + assert len(auxiliary_models) > 0, "LLM-as-a-Judge is required" + + # Parse task data + query = QueryModel.model_validate(task) + question = task.get("question", task.get("task_desc", "")) + + # Get workflow arguments with defaults + workflow_args = task.get("workflow_args", {}) + max_turns = int(workflow_args.get("max_turns", 10)) + + # Format system prompt + system_prompt = SYSTEM_PROMPT.format( + max_turns=max_turns, + inbox_address=query.inbox_address, + query_date=query.query_date, + ) + + # Create EmailSearchAgent + agent = EmailSearchAgent( + name="email_search_agent", + sys_prompt=system_prompt, + model=model, + formatter=OpenAIChatFormatter(), + max_iters=max_turns, + ) + + # Reset agent state for a new rollout + await agent.reset() + + # Run the agent with structured output + response = await agent.reply( + msg=Msg("user", question, role="user"), + structured_model=AnswerModel, + ) + + # Extract answer and sources from response metadata + answer_and_sources = response.metadata or {} + if not answer_and_sources: + # Fallback: try to parse from content + answer_and_sources = { + "answer": response.get_text_content() or "", + "sources": [], + } + + # Store agent state for judge function + # We'll pass this through the response metadata + response_metadata = { + "answer_and_sources": answer_and_sources, + "query": query.model_dump(), + "message_id_list": agent.message_id_list, + "ever_read_message_ids": agent.ever_read_message_ids, + # Estimate actual_turns from memory length + "actual_turns": ( + max(1, (len(agent.memory.content) - 1) // 2) + if len(agent.memory.content) > 1 + else 1 + ), + } + + # Update response metadata + if response.metadata is None: + response.metadata = {} + response.metadata.update(response_metadata) + + return WorkflowOutput( + response=response, + ) + + +def _calculate_partial_rewards(rubric: FinalRubric) -> float: + """Calculate partial rewards based on rubric.""" + partial_rewards = 0.0 + partial_rewards += 0.1 if rubric.ever_found_right_email else 0 + partial_rewards += 0.1 if rubric.ever_read_right_email else 0 + partial_rewards += 0.1 if rubric.sources_correct else 0 + return partial_rewards + + +def _calculate_correct_answer_reward( + rubric: FinalRubric, + max_turns: int, +) -> float: + """Calculate reward for correct answers.""" + reward = 1.0 + reward += 0.3 if rubric.sources_correct else 0 + reward += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0 + reward += 0.1 * (1 - rubric.num_turns / max_turns) + return reward + + +def _initialize_rubric( + answer: str, + sources: list[str], + actual_turns: int, + query: QueryModel, + message_id_list: list[str], + ever_read_message_ids: list[str], +) -> FinalRubric: + """Initialize and populate rubric with basic information.""" + rubric = FinalRubric() + rubric.attempted_answer = answer is not None and answer != "" + rubric.returned_i_dont_know = answer == "I don't know" + rubric.num_sources = len(sources) + rubric.num_turns = actual_turns + + if len(query.message_ids) > 0: + rubric.ever_found_right_email = query.message_ids[0] in message_id_list + rubric.ever_read_right_email = ( + query.message_ids[0] in ever_read_message_ids + ) + rubric.sources_correct = query.message_ids[0] in sources + return rubric + + +async def email_search_judge( + task: Dict, + response: Msg, + auxiliary_models: Dict[str, ChatModelBase], +) -> JudgeOutput: + """A judge function to calculate reward based on agent's response. + + Args: + task (Dict): The task information for the corresponding workflow. 
+ response (Msg): The response generated by the corresponding workflow. + auxiliary_models (Dict[str, ChatModelBase]): + A dictionary of additional chat models available for LLM-as-a-Judge + usage. The keys are model names, and the values are the + corresponding ChatModelBase instances. + + Returns: + JudgeOutput: The reward value assigned by the judge function. + """ + # Extract metadata from response + metadata = response.metadata or {} + answer_and_sources = metadata.get("answer_and_sources", {}) + query_dict = metadata.get("query", {}) + message_id_list = metadata.get("message_id_list", []) + ever_read_message_ids = metadata.get("ever_read_message_ids", []) + actual_turns = metadata.get("actual_turns", 0) + + # Parse query model + if not query_dict: + query_dict = task + query = QueryModel.model_validate(query_dict) + + # Get arguments + workflow_args = task.get("workflow_args", {}) + max_turns = int(workflow_args.get("max_turns", 10)) + + # Extract answer and sources + try: + answer = answer_and_sources.get("answer", None) + sources = answer_and_sources.get("sources", []) + except Exception: + result = {"accuracy": 0.0, "format": -1.0} + return JudgeOutput( + reward=sum(result.values()), + metrics=result, + ) + + if answer is None: + result = {"accuracy": 0.0, "format": -1.0} + return JudgeOutput( + reward=sum(result.values()), + metrics=result, + ) + + # Initialize rubric + rubric = _initialize_rubric( + answer, + sources, + actual_turns, + query, + message_id_list, + ever_read_message_ids, + ) + + # Judge correctness using LLM-as-a-Judge + try: + judge_model = ( + auxiliary_models.get("judge") or list(auxiliary_models.values())[0] + if auxiliary_models + else None + ) + judge_response = await judge_correctness( + answer, + query, + judge_model, + ) + rubric.answer_correct = judge_response + except Exception as e: + logger.error("Error judging correctness: %s", e) + rubric.answer_correct = False + + # Calculate rewards + partial_rewards = _calculate_partial_rewards(rubric) + + if rubric.attempted_answer and not rubric.answer_correct: + result = {"accuracy": -1.0, "format": partial_rewards} + elif rubric.returned_i_dont_know or rubric.ran_out_of_turns: + result = {"accuracy": 0.0, "format": partial_rewards} + elif rubric.answer_correct: + reward = _calculate_correct_answer_reward(rubric, max_turns) + result = {"accuracy": 1.0, "format": reward} + else: + result = {"accuracy": 0.0, "format": 0.0} + + metrics = result.copy() + metrics.update({"actual_turns": actual_turns}) + + return JudgeOutput( + reward=sum(result.values()), + metrics=metrics, + ) + + +# LLM-as-a-judge + + +async def judge_correctness( + answer: str, + query: QueryModel, + judge: ChatModelBase, +) -> bool: + """Use an LLM to decide whether *answer* matches *query.answer*. + + Returns a boolean *accept* flag used for scoring. + """ + + system_prompt = """You are given a question, the reference answer +(labelled **Reference answer**), and an answer generated by an AI assistant +(labelled **AI answer**). + +Follow these steps to decide whether the AI answer should be accepted: +1. Identify EXACTLY what information the **question** is asking for + (e.g. who, what, when, where, why, how, quantity, etc.). +2. From the **Reference answer**, extract ONLY the facts that are required + to directly satisfy the information need identified in step 1. Treat all + other facts as non-essential context. +3. Verify that every essential fact from step 2 appears in the **AI answer** + with the same meaning. 
Differences in wording, order, or additional
+   non-conflicting details are allowed.
+4. If any essential fact is missing or contradicted in the **AI answer**,
+   then *accept* must be **false**. Otherwise *accept* must be **true**.
+
+Important: Do NOT penalise the **AI answer** for omitting non-essential
+facts that appear in the **Reference answer**. The answer should only be
+rejected for errors or omissions in the information explicitly requested by
+the question.
+
+Return your judgement **accept** as either **true** or **false**. Do not
+return any other text or formatting.
+"""
+    prompt = (
+        f"Question: {query.question}\n"
+        f"Reference answer: {query.answer}\n"
+        f"AI answer: {answer}"
+    )
+
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": prompt},
+    ]
+    chat_response = await judge(messages)
+
+    # Extract text content from ChatResponse
+    result_parts = []
+    for block in chat_response.content:
+        if isinstance(block, dict) and block.get("type") == "text":
+            result_parts.append(str(block.get("text", "")))
+    result = "".join(result_parts)
+    logger.info("LLM judge response: %s", result)
+
+    return "true" in result.lower()
+
+
+# End of LLM-as-a-judge
+
+
+if __name__ == "__main__":
+    config_path = os.path.join(
+        os.path.dirname(__file__),
+        "config.yaml",
+    )
+    dataset = DatasetConfig(
+        path="/path/to/enron_emails_dataset",
+        split="train",
+    )
+    tuner_model = TunerModelConfig(
+        model_path="Qwen/Qwen3-4B-Instruct-2507",
+        max_model_len=20480,
+        max_tokens=4096,
+        inference_engine_num=4,
+        reasoning_parser=None,
+    )
+    aux_models = {
+        "judge": TunerModelConfig(
+            model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
+            max_model_len=2500,
+            max_tokens=2048,
+            inference_engine_num=1,
+            tensor_parallel_size=2,
+            tool_call_parser=None,
+            reasoning_parser=None,
+        ),
+    }
+    algorithm = AlgorithmConfig(
+        algorithm_type="multi_step_grpo",
+        group_size=8,
+        batch_size=64,
+        learning_rate=1e-6,
+    )
+    tune(
+        workflow_func=run_email_search_agent,
+        judge_func=email_search_judge,
+        train_dataset=dataset,
+        model=tuner_model,
+        auxiliary_models=aux_models,
+        algorithm=algorithm,
+        config_path=config_path,
+    )
diff --git a/tuner/email_search/prepare_data.py b/tuner/email_search/prepare_data.py
new file mode 100644
index 0000000..d987d25
--- /dev/null
+++ b/tuner/email_search/prepare_data.py
@@ -0,0 +1,357 @@
+# -*- coding: utf-8 -*-
+"""
+Prepare data for training.
+Modified from OpenPipe/ART +""" + +import logging +import os +import sqlite3 +from datetime import datetime +from datasets import Dataset, Features, Sequence, Value, load_dataset +from tqdm import tqdm + + +# Resolve paths relative to this file +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +# Database will live in "../data/enron_emails.db" relative to project root +DEFAULT_DB_PATH = os.path.join(BASE_DIR, "..", "..", "data", "enron_emails.db") + +DEFAULT_REPO_ID = "corbt/enron-emails" + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) + + +# --- Database Schema --- +SQL_CREATE_TABLES = """ +DROP TABLE IF EXISTS recipients; +DROP TABLE IF EXISTS emails_fts; +DROP TABLE IF EXISTS emails; + +CREATE TABLE emails ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + message_id TEXT UNIQUE, + subject TEXT, + from_address TEXT, + date TEXT, -- Store as ISO 8601 string 'YYYY-MM-DD HH:MM:SS' + body TEXT, + file_name TEXT +); + +CREATE TABLE recipients ( + email_id INTEGER, + recipient_address TEXT, + recipient_type TEXT, -- 'to', 'cc', 'bcc' + FOREIGN KEY(email_id) REFERENCES emails(id) ON DELETE CASCADE +); +""" + +SQL_CREATE_INDEXES_TRIGGERS = """ +CREATE INDEX idx_emails_from ON emails(from_address); +CREATE INDEX idx_emails_date ON emails(date); +CREATE INDEX idx_emails_message_id ON emails(message_id); +CREATE INDEX idx_recipients_address ON recipients(recipient_address); +CREATE INDEX idx_recipients_type ON recipients(recipient_type); +CREATE INDEX idx_recipients_email_id ON recipients(email_id); +CREATE INDEX idx_recipients_address_email ON recipients( + recipient_address, email_id +); + +CREATE VIRTUAL TABLE emails_fts USING fts5( + subject, + body, + content='emails', + content_rowid='id' +); + +CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN + INSERT INTO emails_fts (rowid, subject, body) + VALUES (new.id, new.subject, new.body); +END; + +CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN + DELETE FROM emails_fts WHERE rowid=old.id; +END; + +CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN + UPDATE emails_fts SET subject=new.subject, body=new.body + WHERE rowid=old.id; +END; + +INSERT INTO emails_fts (rowid, subject, body) +SELECT id, subject, body FROM emails; +""" + + +# --- Functions --- + + +def download_dataset(repo_id: str) -> Dataset: + """Downloads the dataset from Hugging Face Hub.""" + logging.info( + "Attempting to download dataset from Hugging Face Hub: %s", + repo_id, + ) + expected_features = Features( + { + "message_id": Value("string"), + "subject": Value("string"), + "from": Value("string"), + "to": Sequence(Value("string")), + "cc": Sequence(Value("string")), + "bcc": Sequence(Value("string")), + "date": Value("timestamp[us]"), + "body": Value("string"), + "file_name": Value("string"), + }, + ) + dataset_obj = load_dataset( + repo_id, + features=expected_features, + split="train", + ) + # Basic type check remains useful + if not isinstance(dataset_obj, Dataset): + raise TypeError(f"Expected Dataset, got {type(dataset_obj)}") + logging.info( + "Successfully loaded dataset '%s' with %d records.", + repo_id, + len(dataset_obj), + ) + return dataset_obj + + +def create_database(db_path: str) -> None: + """Creates the SQLite database and tables.""" + logging.info("Creating SQLite database and tables at: %s", db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.executescript(SQL_CREATE_TABLES) + conn.commit() + conn.close() + logging.info("Database tables created successfully.") + + +def 
_should_skip_email( + body: str, + message_id: str, + to_list: list[str], + cc_list: list[str], + bcc_list: list[str], +) -> bool: + """Check if email should be skipped based on filters.""" + if len(body) > 5000: + logging.debug( + "Skipping email %s: Body length > 5000 characters.", + message_id, + ) + return True + + total_recipients = len(to_list) + len(cc_list) + len(bcc_list) + if total_recipients > 30: + logging.debug( + "Skipping email %s: Total recipients (%d) > 30.", + message_id, + total_recipients, + ) + return True + return False + + +def _prepare_recipient_data( + email_pk_id: int, + to_list: list[str], + cc_list: list[str], + bcc_list: list[str], +) -> list[tuple[int, str, str]]: + """Prepare recipient data for database insertion.""" + recipient_data = [] + for addr in to_list: + recipient_data.append((email_pk_id, addr, "to")) + for addr in cc_list: + recipient_data.append((email_pk_id, addr, "cc")) + for addr in bcc_list: + recipient_data.append((email_pk_id, addr, "bcc")) + return recipient_data + + +def populate_database(db_path: str, dataset: Dataset) -> None: + """Populates the database with data from the Hugging Face dataset.""" + logging.info("Populating database %s...", db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # --- Performance Pragmas --- + conn.execute("PRAGMA synchronous = OFF;") + conn.execute("PRAGMA journal_mode = MEMORY;") + + record_count = 0 + skipped_count = 0 + duplicate_count = 0 + processed_emails = set() + + conn.execute("BEGIN TRANSACTION;") + + for email_data in tqdm(dataset, desc="Inserting emails"): + assert isinstance(email_data, dict) + message_id = email_data["message_id"] + subject = email_data["subject"] + from_address = email_data["from"] + date_obj: datetime = email_data["date"] + body = email_data["body"] + file_name = email_data["file_name"] + to_list_raw = email_data["to"] + cc_list_raw = email_data["cc"] + bcc_list_raw = email_data["bcc"] + + date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S") + to_list = [str(addr) for addr in to_list_raw if addr] + cc_list = [str(addr) for addr in cc_list_raw if addr] + bcc_list = [str(addr) for addr in bcc_list_raw if addr] + + if _should_skip_email(body, message_id, to_list, cc_list, bcc_list): + skipped_count += 1 + continue + + email_key = (subject, body, from_address) + if email_key in processed_emails: + logging.debug( + "Skipping duplicate email (Subject: %s..., From: %s)", + subject[:50], + from_address, + ) + duplicate_count += 1 + continue + processed_emails.add(email_key) + + cursor.execute( + """ + INSERT INTO emails ( + message_id, subject, from_address, date, body, file_name + ) + VALUES (?, ?, ?, ?, ?, ?) + """, + (message_id, subject, from_address, date_str, body, file_name), + ) + email_pk_id = cursor.lastrowid + if email_pk_id is None: + logging.warning( + "Failed to get email ID after insert for message_id: %s", + message_id, + ) + continue + + recipient_data = _prepare_recipient_data( + email_pk_id, + to_list, + cc_list, + bcc_list, + ) + + if recipient_data: + cursor.executemany( + """ + INSERT INTO recipients ( + email_id, recipient_address, recipient_type + ) + VALUES (?, ?, ?) 
+ """, + recipient_data, + ) + record_count += 1 + + conn.commit() + conn.close() + logging.info("Successfully inserted %d email records.", record_count) + if skipped_count > 0: + logging.info( + "Skipped %d email records due to length or recipient limits.", + skipped_count, + ) + if duplicate_count > 0: + logging.info( + "Skipped %d duplicate email records " + "(based on subject, body, from).", + duplicate_count, + ) + + +def create_indexes_and_triggers(db_path: str) -> None: + """Creates indexes and triggers on the populated database.""" + logging.info("Creating indexes and triggers for database: %s...", db_path) + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS) + conn.commit() + conn.close() + logging.info("Indexes and triggers created successfully.") + + +def generate_database( + repo_id: str = DEFAULT_REPO_ID, + db_path: str = DEFAULT_DB_PATH, + overwrite: bool = False, +) -> None: + """ + Generates the SQLite database from the specified Hugging Face dataset. + Simplified version without extensive error handling. + + Args: + repo_id: The Hugging Face repository ID for the dataset. + db_path: The path where the SQLite database file should be + created. + overwrite: If True, any existing database file at db_path will + be removed. + """ + logging.info( + "Starting database generation for repo '%s' at '%s'", + repo_id, + db_path, + ) + logging.info("Overwrite existing database: %s", overwrite) + + db_dir = os.path.dirname(db_path) + if db_dir and not os.path.exists(db_dir): + logging.info("Creating data directory: %s", db_dir) + os.makedirs(db_dir) + + if overwrite and os.path.exists(db_path): + logging.warning("Removing existing database file: %s", db_path) + os.remove(db_path) + elif not overwrite and os.path.exists(db_path): + # If not overwriting and file exists, subsequent steps might fail + # or behave unexpectedly. We are removing the explicit error here + # as requested. + logging.warning( + "Database file %s exists and overwrite is False. " + "Assuming file is already generated.", + db_path, + ) + return + + # 1. Download dataset + dataset = download_dataset(repo_id) + + # 2. Create database schema (Tables only) + # Note: This will fail if overwrite=False and the file exists with + # incompatible schema/data. + create_database(db_path) + + # 3. Populate database + populate_database(db_path, dataset) + + # 4. Create Indexes and Triggers + create_indexes_and_triggers(db_path) + + logging.info("Database generation process completed for %s.", db_path) + logging.info( + "Please set the environment variable DEFAULT_EMAIL_DB_PATH " + "to this path.", + ) + + +if __name__ == "__main__": + generate_database(overwrite=True) diff --git a/tuner/email_search/rollout_accuracy_mean.png b/tuner/email_search/rollout_accuracy_mean.png new file mode 100644 index 0000000..63e8d16 Binary files /dev/null and b/tuner/email_search/rollout_accuracy_mean.png differ diff --git a/tuner/frozen_lake/README.md b/tuner/frozen_lake/README.md new file mode 100644 index 0000000..dca3b94 --- /dev/null +++ b/tuner/frozen_lake/README.md @@ -0,0 +1,271 @@ +# Training FrozenLake Agent with RL using AgentScope-Tuner + +## Summary + +This example demonstrates how to use AgentScope-Tuner to implement reinforcement fine-tuning for the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task using [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT). 
The agent learns to navigate a frozen lake grid from a starting position to a goal, avoiding holes, through multi-step interactions with the environment.
+
+## Task Setting
+
+### Agent Goal
+The agent's objective is to navigate from the starting position (S) to the goal position (G) on a frozen lake grid without falling into holes (H). The agent must:
+- Plan a path through frozen tiles (F) to reach the goal
+- Avoid holes that terminate the episode with zero reward
+- Complete the task within a limited number of steps
+
+### Agent Type
+The agent is implemented as a **ReActAgent** (Reasoning and Acting Agent) that:
+- Observes the current state of the frozen lake grid
+- Reasons about the best action to take
+- Executes actions (Up, Down, Left, Right) to move through the environment
+- Maintains internal state across multiple steps in an episode
+
+### Environment
+The environment is based on Gymnasium's FrozenLake environment, wrapped to provide:
+- **Grid-based navigation**: Randomly generated maps with configurable size (2x2 to 6x6)
+- **Tile types**:
+ - `S`: Start position
+ - `F`: Frozen tile (safe to walk on)
+ - `H`: Hole (terminates episode with reward 0)
+ - `G`: Goal (terminates episode with reward +1.0)
+- **Action space**: Discrete actions (Up, Down, Left, Right)
+- **Reward structure**:
+ - +1.0 for reaching the goal
+ - 0.0 for falling into a hole or failing to reach the goal
+- **Observations**: Text-based grid representation showing current player position
+
+The agent does not use external tools. It interacts directly with the environment through:
+- `env.reset(task)`: Initialize environment with task parameters
+- `env.step(action)`: Execute action and receive observation, reward, and done flag
+- `env.render()`: Get text representation of current state
+
+
+## Dataset Preparation
+
+The dataset contains task parameters for generating FrozenLake environments. Each sample specifies:
+- `seed`: Random seed for reproducible map generation
+- `size`: Grid size, randomly sampled by `np.random.randint` from 2 (inclusive) to `map_max_size` (exclusive)
+- `p`: Probability that a tile is frozen (vs. being a hole), randomly sampled from 0.6 to 0.85
+- `index`: Sample index
+- `uid`: Unique identifier combining seed, size, and p
+
+Run the data preparation script to generate training and test datasets:
+
+```bash
+python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
+```
+
+This will create parquet files in the specified directory:
+
+```
+/path/to/frozenlake_dataset/
+ ├── train.parquet # 10000 training samples
+ └── test.parquet # 100 test samples
+```
+
+Each sample looks like:
+
+```json
+{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
+```
+
+**Note**: Map solvability is enforced at map-generation time rather than in the data script: `generate_random_map` (in `_utils.py`, invoked when the environment is reset) resamples the layout until there is a valid path from start to goal within the maximum allowed steps (`env_max_steps=8`), so no unsolvable task is ever produced.
+
+## Code Implementation
+
+This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.
+
+### High-level Overview
+
+The implementation consists of three main components:
+
+1. **Agent** (`FrozenLakeAgent`): Extends `ReActAgent` to handle multi-step navigation
+2. **Environment** (`FrozenLakeEnv`): Wraps Gymnasium's FrozenLake environment
+3. **Workflow** (`run_frozen_lake`): Orchestrates the agent-environment interaction loop
+
+### Agent Workflow
+
+The workflow function `run_frozen_lake` implements the agent-environment interaction loop:
+
+```python
+async def run_frozen_lake(
+ task: Dict,
+ model: ChatModelBase,
+ auxiliary_models: Dict[str, ChatModelBase],
+) -> WorkflowOutput:
+ # ...
+
+ # Create agent and environment
+ agent = FrozenLakeAgent(model=model, ...)
+ env = FrozenLakeEnv(...)
+ observation, _ = env.reset(task)
+ rewards = []
+ # ...
+
+ # Agent-environment interaction loop
+ for _ in range(max_steps):
+ response = await agent.reply(
+ msg=Msg("user", agent.get_prompt(observation), role="user"),
+ )
+ action = agent.get_action(response)
+ observation, reward, done, _ = env.step(action)
+ rewards.append(reward)
+ if done:
+ break
+
+ # ...
+ final_reward = sum(rewards)
+ final_response = Msg("assistant", response_content, role="assistant")
+
+ return WorkflowOutput(
+ reward=final_reward,
+ response=final_response,
+ metrics={
+ "env_steps": float(step_count),
+ "env_done": float(done),
+ },
+ )
+```
+
+**Key characteristics:**
+- Multi-step interaction: The agent takes multiple actions in a single episode, unlike single-turn QA tasks
+- State tracking: The agent maintains internal state (current step, last action, last observation) across steps
+- Error handling: Invalid actions or agent errors are caught and handled gracefully
+
+### Reward Function
+
+No separate judge function is needed. The reward comes directly from the environment:
+- 1.0: Agent successfully reaches the goal (G)
+- 0.0: Agent falls into a hole (H) or fails to reach the goal within the maximum steps
+
+The reward is computed as the sum of step rewards throughout the episode. The workflow returns:
+- `reward`: Final cumulative reward
+- `response`: Final response message containing observation, total reward, steps taken, and termination reason
+- `metrics`: Additional metrics including `env_steps` (number of steps taken) and `env_done` (whether episode completed)
+
+### Implementation Details
+
+The environment (`FrozenLakeEnv`) wraps Gymnasium's FrozenLake and provides:
+- `reset(task)`: Initialize the environment with task parameters
+- `step(action)`: Execute an action and return (observation, reward, done, info)
+- `render()`: Return a text representation of the current state
+
+The agent (`FrozenLakeAgent`) extends `ReActAgent` and provides:
+- `reply(msg)`: Reply to a message and return an action (inherited from AgentScope)
+- `get_prompt(observation)`: Generate a prompt from the current observation
+- `get_action(response)`: Parse the model's response to extract an action (Up/Down/Left/Right)
+- `update_state(action, observation)`: Update internal state after each step
+
+See [_frozenlake_env.py](./_frozenlake_env.py) and [_frozenlake_agent.py](./_frozenlake_agent.py) for implementation details.
+
+### Step 4: Use `tune` to train the workflow
+
+```python
+from agentscope.tuner import tune, DatasetConfig
+
+if __name__ == "__main__":
+ config_path = os.path.join(
+ os.path.dirname(__file__),
+ "config.yaml",
+ )
+ dataset = DatasetConfig(
+ path="/path/to/frozenlake_dataset",
+ name="default",
+ split="train",
+ )
+ tune(
+ workflow_func=run_frozen_lake,
+ train_dataset=dataset,
+ config_path=config_path,
+ )
+```
+
+See [config.yaml](./config.yaml) for the training configuration. For full configuration details, see [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
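+
+Before launching a full training run, it can be useful to sanity-check the environment wrapper by hand. The following minimal sketch (not part of the shipped code) builds one task in the same format as the dataset samples above, resets the environment with it, and takes a single step. It assumes gymnasium is installed and that you run it from the `tuner/frozen_lake/` directory so that `_frozenlake_env.py` is importable:
+
+```python
+from _frozenlake_env import FrozenLakeEnv
+
+# One task specification, in the same format as a row of train.parquet
+task = {"seed": 12345, "size": 5, "p": 0.75}
+
+env = FrozenLakeEnv(max_steps=8, size=5, p=0.75, seed=12345)
+observation, _ = env.reset(task)
+print(observation)  # text grid: P (player), _ (frozen), O (hole), G (goal)
+
+# Actions are lowercase strings in this wrapper ("up", "down", "left", "right")
+observation, reward, done, info = env.step("down")
+print(reward, done, info["action_is_effective"])
+```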
+
+---
+
+## How to Run
+
+### Prerequisites
+
+- At least 2 NVIDIA GPUs with CUDA 12.8 or newer
+- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
+- Install gymnasium for the FrozenLake environment:
+
+ ```bash
+ pip install gymnasium[toy_text]
+ ```
+
+- Download the model checkpoint (example):
+
+ ```bash
+ huggingface-cli download Qwen/Qwen2.5-3B-Instruct
+ ```
+
+### Step 1: Prepare the Dataset
+
+```bash
+python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
+```
+
+Update the dataset path in `main.py` to point to your generated dataset directory.
+
+### Step 2: Configure the Training
+
+Key configuration options are set in the code, including:
+
+**Algorithm Configuration** (`AlgorithmConfig`):
+- `algorithm_type`: `multi_step_grpo` (Group Relative Policy Optimization for multi-step tasks)
+- `group_size`: Number of rollouts sampled per task for advantage estimation (default: 16)
+- `batch_size`: Batch size for training (default: 32)
+- `learning_rate`: Learning rate (default: 1e-6)
+
+**Model Configuration** (`TunerModelConfig`):
+- `model_path`: Path to the base model (e.g., `Qwen/Qwen2.5-3B-Instruct`)
+- `max_model_len`: Maximum model context length (default: 25600)
+- `max_tokens`: Maximum tokens for response generation (default: 2048)
+- `inference_engine_num`: Number of inference engines (default: 6, using 6 GPUs for inference)
+
+**Dataset Configuration** (`DatasetConfig`):
+- `path`: Path to the dataset (default: `/path/to/frozenlake`)
+- `split`: Split of the dataset (default: `train`)
+
+Adjust these parameters based on your hardware resources and training requirements. Other parameters can be specified in [config.yaml](./config.yaml).
+
+### Step 3: Set Up Ray Cluster
+
+Set up a [Ray](https://github.com/ray-project/ray) cluster:
+
+```bash
+ray start --head
+# for a multi-node setup, run the following command on each worker node
+# ray start --address=<head_node_ip:port>
+```
+
+### Step 4: Run the Training Script
+
+```bash
+python main.py
+```
+
+Training will start, and you can monitor its progress through the logs. Checkpoints are saved every `trainer.save_interval` steps.
+
+## Experimental Results
+
+### Training Reward Curve
+
+The reward curve during training shows the agent's learning progress:
+
+![reward](./critic_rewards_mean.png)
+
+The training reward typically increases over epochs as the agent learns to navigate the frozen lake more effectively.
+
+### Example Agent Output
+
+An example of agent output is given below:
+```
+From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.
+
+The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.
+
+Let's take the action ```Left```.
+```
\ No newline at end of file
diff --git a/tuner/frozen_lake/README_zh.md b/tuner/frozen_lake/README_zh.md
new file mode 100644
index 0000000..6b54c6f
--- /dev/null
+++ b/tuner/frozen_lake/README_zh.md
@@ -0,0 +1,250 @@
+# 使用 AgentScope-Tuner 训练 FrozenLake Agent
+
+## 摘要
+
+本示例展示如何使用 AgentScope-Tuner 配合 [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) 对 [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) 任务进行强化微调。智能体需要在冰湖网格中从起点走到终点,避开坑洞,并在有限步数内完成任务。
+
+## 任务设定
+
+### 智能体目标
+智能体要在冰湖网格上从起点 (S) 抵达终点 (G),同时:
+- 规划路径经过冰面 (F) 到达终点
+- 避开会结束回合且奖励为 0 的坑洞 (H)
+- 在限定步数内完成任务
+
+### 智能体类型
+智能体实现为 **ReActAgent**,它的行为包括:
+- 观察当前冰湖网格状态
+- 推理下一步最优动作
+- 执行动作(上、下、左、右)在环境中移动
+- 在多步交互中维护内部状态
+
+### 环境
+环境基于 Gymnasium 的 FrozenLake,并提供:
+- **网格导航**:随机生成 2x2 至 6x6 的地图
+- **格子类型**:
+ - `S`:起点
+ - `F`:冰面(可通行)
+ - `H`:坑洞(奖励 0,结束回合)
+ - `G`:终点(奖励 +1.0,结束回合)
+- **动作空间**:离散动作(上、下、左、右)
+- **奖励设计**:
+ - 到达终点 +1.0
+ - 掉入坑洞或未在最大步数内到达终点为 0.0
+- **观测**:返回当前玩家位置的文本网格表示
+
+智能体不使用外部工具,直接通过以下接口与环境交互:
+- `env.reset(task)`:根据任务参数初始化环境
+- `env.step(action)`:执行动作,返回观测、奖励和结束标志
+- `env.render()`:返回当前状态的文本表示
+
+## 数据集准备
+
+数据集包含用于生成 FrozenLake 环境的任务参数,每个样本包含:
+- `seed`:随机种子,保证地图可复现
+- `size`:网格大小(由 `np.random.randint` 在 2(含)到 `map_max_size`(不含)之间随机采样)
+- `p`:格子为冰面的概率(0.6 到 0.85 之间随机),其余为坑洞
+- `index`:样本索引
+- `uid`:由 seed、size、p 组合而成的唯一 ID
+
+运行数据准备脚本生成训练集与测试集:
+
+```bash
+python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
+```
+
+生成的目录结构示例:
+```
+/path/to/frozenlake_dataset/
+ ├── train.parquet # 10000 条训练样本
+ └── test.parquet # 100 条测试样本
+```
+
+样本示例:
+```json
+{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
+```
+
+**注意**:地图可解性在地图生成阶段保证,而非由数据脚本过滤:`_utils.py` 中的 `generate_random_map`(在环境重置时调用)会重新采样地图,直到在最大步数 (`env_max_steps=8`) 内存在从起点到终点的可行路径,因此不会产生无解任务。
+
+## 代码实现
+
+本节提供代码实现的高级概览。详细实现请参考源代码。
+
+### 高级概览
+实现由三部分组成:
+1. **Agent** (`FrozenLakeAgent`):继承 `ReActAgent`,负责多步交互
+2. **环境** (`FrozenLakeEnv`):封装 Gymnasium FrozenLake
+3. **工作流** (`run_frozen_lake`):组织智能体与环境的交互流程
+
+### 工作流
+`run_frozen_lake` 实现多步交互流程:
+
+```python
+async def run_frozen_lake(
+ task: Dict,
+ model: ChatModelBase,
+ auxiliary_models: Dict[str, ChatModelBase],
+) -> WorkflowOutput:
+ # ...
+
+ # 创建智能体和环境
+ agent = FrozenLakeAgent(model=model, ...)
+ env = FrozenLakeEnv(...)
+ observation, _ = env.reset(task)
+ rewards = []
+ # ...
+
+ # 智能体-环境交互循环
+ for _ in range(max_steps):
+ response = await agent.reply(
+ msg=Msg("user", agent.get_prompt(observation), role="user"),
+ )
+ action = agent.get_action(response)
+ observation, reward, done, _ = env.step(action)
+ rewards.append(reward)
+ if done:
+ break
+
+ # ...
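+ # 将各步奖励求和作为最终奖励,并构造包含环境信息的最终响应消息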
+ final_reward = sum(rewards)
+ final_response = Msg("assistant", response_content, role="assistant")
+
+ return WorkflowOutput(
+ reward=final_reward,
+ response=final_response,
+ metrics={"env_steps": float(step_count), "env_done": float(done)},
+ )
+```
+
+**关键特性:**
+- 多步交互:单次 episode 内多次动作,不是单轮 QA
+- 状态跟踪:记录当前步、上次动作与观测
+- 错误处理:无效动作或异常会被捕获并处理
+
+### 奖励函数
+无需额外 judge,奖励由环境直接给出:
+- 1.0:到达终点
+- 0.0:掉入坑洞或超步数未达终点
+
+工作流返回:
+- `reward`:累计奖励
+- `response`:包含观测、总奖励、步数、终止原因的最终回复
+- `metrics`:`env_steps`(步数)、`env_done`(是否结束)
+
+### 实现细节
+
+环境 (`FrozenLakeEnv`) 封装了 Gymnasium 的 FrozenLake,提供:
+- `reset(task)`: 使用任务参数初始化环境
+- `step(action)`: 执行动作并返回 (observation, reward, done, info)
+- `render()`: 返回当前状态的文本表示
+
+智能体 (`FrozenLakeAgent`) 继承 `ReActAgent`,提供:
+- `reply(msg)`: 回复消息并返回动作(继承自 AgentScope)
+- `get_prompt(observation)`: 从当前观测生成提示
+- `get_action(response)`: 解析模型响应以提取动作(Up/Down/Left/Right)
+- `update_state(action, observation)`: 在每步后更新内部状态
+
+详细实现请参考 [_frozenlake_env.py](./_frozenlake_env.py) 和 [_frozenlake_agent.py](./_frozenlake_agent.py)。
+
+### 步骤 4:使用 `tune` 训练工作流
+
+```python
+from agentscope.tuner import tune, DatasetConfig
+
+if __name__ == "__main__":
+ config_path = os.path.join(
+ os.path.dirname(__file__),
+ "config.yaml",
+ )
+ dataset = DatasetConfig(
+ path="/path/to/frozenlake_dataset",
+ name="default",
+ split="train",
+ )
+ tune(
+ workflow_func=run_frozen_lake,
+ train_dataset=dataset,
+ config_path=config_path,
+ )
+```
+
+训练配置请参考 [config.yaml](./config.yaml)。完整配置详情请参考 [Trinity-RFT 配置指南](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)。
+
+---
+
+## 运行方法
+
+### 依赖
+- 至少 2 张 NVIDIA GPU,CUDA 版本 ≥ 12.8
+- 按 [Trinity-RFT 安装指南](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) 从源码安装
+- 安装 gymnasium 冰湖环境:
+ ```bash
+ pip install gymnasium[toy_text]
+ ```
+- 下载模型权重(示例):
+ ```bash
+ huggingface-cli download Qwen/Qwen2.5-3B-Instruct
+ ```
+
+### 步骤 1:准备数据集
+```bash
+python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
+```
+将 `main.py` 中的数据集路径改为你的生成目录。
+
+### 步骤 2:配置训练
+
+关键配置可在代码中设置,包括:
+
+**算法配置** (`AlgorithmConfig`):
+- `algorithm_type`: `multi_step_grpo`(用于多步任务的组相对策略优化)
+- `group_size`: 每个任务采样的 rollout 数量,用于优势估计(默认 16)
+- `batch_size`: 批大小(默认 32)
+- `learning_rate`: 学习率(默认 1e-6)
+
+**模型配置** (`TunerModelConfig`):
+- `model_path`: 基础模型路径(如 `Qwen/Qwen2.5-3B-Instruct`)
+- `max_model_len`: 最大上下文长度(默认 25600)
+- `max_tokens`: 响应最大生成长度(默认 2048)
+- `inference_engine_num`: 推理引擎数量(默认 6,表示用 6 个 GPU 进行推理)
+
+**数据集配置** (`DatasetConfig`):
+- `path`: 数据集路径(默认 `/path/to/frozenlake`)
+- `split`: 数据集分片(默认 `train`)
+
+可根据硬件资源和训练需求调整这些参数。其他参数可在 [config.yaml](./config.yaml) 中指定。
+
+### 步骤 3:设置 Ray 集群
+
+设置 [Ray](https://github.com/ray-project/ray) 集群:
+```bash
+ray start --head
+# 对于多节点设置,在每个工作节点上运行以下命令
+# ray start --address=<head_node_ip:port>
+```
+
+### 步骤 4:运行训练脚本
+```bash
+python main.py
+```
+训练将开始,可通过日志监控进度。检查点将每 `trainer.save_interval` 步保存一次。
+
+## 实验结果
+
+### 训练奖励曲线
+
+训练过程中的奖励曲线显示智能体的学习进度:
+
+![reward](./critic_rewards_mean.png)
+
+随着训练轮次增加,智能体逐渐学会更高效地在冰湖中导航,训练奖励通常随之上升。
+
+### 智能体输出示例
+
+智能体输出示例如下:
+```
+From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.
+
+The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal.
Let's move left first. + +Let's take the action ```Left```. +``` diff --git a/tuner/frozen_lake/_frozenlake_agent.py b/tuner/frozen_lake/_frozenlake_agent.py new file mode 100644 index 0000000..634aa74 --- /dev/null +++ b/tuner/frozen_lake/_frozenlake_agent.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +"""Adapted from Trinity-RFT""" +import re +from _utils import SYSTEM_PROMPT, FrozenLakeAction # pylint: disable=E0611 +from agentscope.agent import ReActAgent +from agentscope.formatter import OpenAIChatFormatter +from agentscope.message import Msg +from agentscope.model import OpenAIChatModel + + +INVALID_ACTION = "still" +VALID_ACTIONS = { + "left": 1, + "down": 2, + "right": 3, + "up": 4, +} + + +class FrozenLakeAgent(ReActAgent): + """Agent for FrozenLake environment.""" + + def __init__(self, model: OpenAIChatModel, max_steps: int = 20): + super().__init__( + name="frozenlake_agent", + model=model, + sys_prompt=SYSTEM_PROMPT, + formatter=OpenAIChatFormatter(), + max_iters=1, + ) + self.response_structure = FrozenLakeAction + self.current_step = 0 + self.last_action = None + self.last_observation = None + self.max_steps = max_steps + + def get_prompt(self, observation: str) -> str: + """Get prompt for the agent based on current observation.""" + prompt = ( + f"Current Observation ({self.current_step}): \n" + + observation + + "\n" + + ( + "You have not achieved the goal, P has not reached G yet. " + "Please give the next action." + ) + ) + if self.current_step > 0 and self.last_action is not None: + if self.last_observation == observation: + prompt += ( + "\nYour last response is invalid. " + "Your position didn't change at all. " + "You may need to recheck your thinking process, " + "action outputted, and the format of response. " + "Remember, you should only output the NEXT ACTION " + "at each iteration in the ``` ```. " + "For example, if you want to move up, " + "you should output ```Up```." + ) + + if ( + self.max_steps is not None + and self.max_steps - self.current_step > 0 + ): + remaining = self.max_steps - self.current_step + prompt += ( + f"\nThe maximum number of steps remaining is {remaining}." 
+ ) + + return prompt + + def get_action(self, msg: Msg) -> str: + """Extract action from agent response message.""" + response: str = ( + msg.content + if isinstance(msg.content, str) + else msg.content[0].get("text") + ) + action = INVALID_ACTION + + matches = re.findall(r"```(.*?)```", response, re.DOTALL) + + if matches: + last_match_content = matches[-1].strip() + action = last_match_content.lower() + if action not in VALID_ACTIONS: + action = INVALID_ACTION + + return action + + def update_state(self, action: str, observation: str) -> None: + """Update agent state with action and observation.""" + self.last_action = action + self.last_observation = observation + self.current_step += 1 + + async def reset(self) -> None: + """Reset agent state for a new episode.""" + self.current_step = 0 + self.last_action = None + self.last_observation = None + await self.memory.clear() diff --git a/tuner/frozen_lake/_frozenlake_env.py b/tuner/frozen_lake/_frozenlake_env.py new file mode 100644 index 0000000..8f874d6 --- /dev/null +++ b/tuner/frozen_lake/_frozenlake_env.py @@ -0,0 +1,316 @@ +# -*- coding: utf-8 -*- +"""Adapted from Trinity-RFT""" +import copy +from typing import Dict, Optional, Tuple, Union +import numpy as np + +try: + from gymnasium.envs.toy_text.frozen_lake import ( + FrozenLakeEnv as GymFrozenLakeEnv, + ) +except ImportError: + GymFrozenLakeEnv = object +from _utils import ( # pylint: disable=E0611 + generate_random_map, + get_goal_position, +) + + +class FrozenLakeEnv(GymFrozenLakeEnv): + """FrozenLake environment wrapper.""" + + # Map gym state in integer + MAP_LOOKUP = { + b"P": 0, + b"F": 1, + b"H": 2, + b"G": 3, + } + + # Define rules to transform to rendered text observation of the environment + GRID_LOOKUP = { + 0: " P \t", # player + 1: " _ \t", # frozen + 2: " O \t", # hole + 3: " G \t", # goal + 4: " X \t", # player fall into hole + 5: " √ \t", # player on goal + } + + ACTION_LOOKUP = { + "still": 0, + "left": 1, + "down": 2, + "right": 3, + "up": 4, + } + + INVALID_ACTION = 0 + PENALTY_FOR_INVALID = -1 + + def __init__( + self, + max_steps: int = 8, + desc: Optional[str] = None, + is_slippery: bool = False, + size: int = 8, + p: float = 0.8, + seed: int = 42, + ): + self.max_steps = max_steps or 8 + self.desc: Union[str, np.ndarray, None] = desc + self.is_slippery = is_slippery + self.size = size + self.p = p + self.seed = seed + self.render_mode: Optional[str] = None + try: + import gymnasium as gym + except ImportError as e: + error_message = ( + "Gymnasium is not installed. " + "Please install gymnasium first before " + "running the frozen_lake workflow. 
" + f"Error: {str(e)}" + ) + raise ImportError(error_message) from e + + if self.desc is None: + random_map, goal_position = generate_random_map( + size=self.size, + p=self.p, + seed=self.seed, + max_steps=self.max_steps, + ) + else: + random_map = np.asarray(copy.deepcopy(self.desc), dtype="c") + goal_position = get_goal_position(random_map) + + self.goal_position = goal_position + + super().__init__( + desc=random_map[:], + is_slippery=self.is_slippery, + ) + assert isinstance(self.desc, np.ndarray) + self.action_space = gym.spaces.Discrete(4, start=1) + + self.map_kwargs = { + "size": size, + "p": p, + } + self.env_kwargs = { + "is_slippery": is_slippery, + "desc": copy.deepcopy(desc), + "seed": seed, + } + + self.action_map = { + 1: 0, # left + 2: 1, # down + 3: 2, # right + 4: 3, # up + } + + def _get_player_position(self) -> Tuple[int, int]: + return (self.s // self.ncol, self.s % self.ncol) # (row, col) + + def step(self, action: str) -> Tuple[str, float, bool, Dict]: + """Execute a step in the environment. + + Maps custom action to gymnasium FrozenLakeEnv action and + takes the step. Checks if the action is effective (whether + player moves in the env). + + Args: + action: The action to take. + + Returns: + Tuple of (observation, reward, done, info). + """ + if self.success(): + obs = self.render(mode="tiny_rgb_array") + assert isinstance(obs, str) + return obs, 1.0, True, {"action_is_effective": False} + + action_id: int = self.ACTION_LOOKUP.get(action.lower(), 0) + + if not action_id: + action_id = self.INVALID_ACTION + + if ( + action_id == self.INVALID_ACTION + or action_id not in self.action_map + ): + obs = self.render(mode="tiny_rgb_array") + assert isinstance(obs, str) + return obs, 0.0, False, {"action_is_effective": False} + + prev_player_position = int(self.s) + + # Call parent class step method + # Note: GymFrozenLakeEnv is imported at module level + player_pos, reward, done, _, _ = super().step( + self.action_map[action_id], + ) + + obs = self.render(mode="tiny_rgb_array") + assert isinstance(obs, str) + return ( + obs, + float(reward), + bool(done), + {"action_is_effective": prev_player_position != int(player_pos)}, + ) + + def render( + self, + mode: str = "tiny_rgb_array", + ) -> str | list[str] | np.ndarray: + """Render the environment. + + Args: + mode: Rendering mode. Options: "tiny_rgb_array", "list", + "state", "rgb_array", "ansi". + + Returns: + Rendered observation based on the mode. 
+ """ + assert mode in [ + "tiny_rgb_array", + "list", + "state", + "rgb_array", + "ansi", + ] + if mode in ["rgb_array", "ansi"]: + prev_render_mode = self.render_mode + self.render_mode = mode + obs = super().render() + self.render_mode = prev_render_mode + return obs + assert isinstance(self.desc, np.ndarray) + room_state = copy.deepcopy(self.desc) + + # replace the position of start 'S' with 'F' + position_S = np.where(room_state == b"S") + room_state[position_S] = b"F" + + # replace the position of the player with 'P' + position_P = self._get_player_position() + room_state[position_P] = b"P" + + if mode == "state": + # transform 'S', 'F', 'H', 'G' to numpy integer array + room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state) + # add player in hole or player on goal + if self.desc[position_P] == b"H": + room_state[position_P] = 4 + elif self.desc[position_P] == b"G": + room_state[position_P] = 5 + return room_state + + room_state = self.render(mode="state").tolist() + assert isinstance(room_state, list) + + if mode == "list": + + def lookup_list(cell: int) -> str: + return self.GRID_LOOKUP.get(cell, "?").strip("\t").strip() + + return [ + " ".join(lookup_list(cell) for cell in row) + for row in room_state + ] + + if mode == "tiny_rgb_array": + + def lookup_tiny(cell: int) -> str: + return self.GRID_LOOKUP.get(cell, "?") + + result = "\n".join( + "".join(lookup_tiny(cell) for cell in row) + for row in room_state + ) + return result + + # Default return for other modes + return "" + + def reset( + self, + task: Optional[Dict] = None, + ) -> tuple[str, Dict]: + """Reset the environment with optional task parameters.""" + task = task or {} + # Update parameters from task if provided + size = task.get("size", self.map_kwargs["size"]) + p = task.get("p", self.map_kwargs["p"]) + seed = task.get("seed", self.env_kwargs["seed"]) + is_slippery = task.get( + "is_slippery", + self.env_kwargs["is_slippery"], + ) + desc = task.get("desc", self.env_kwargs.get("desc")) + + # Update instance variables + self.size = size + self.p = p + self.seed = seed + self.is_slippery = is_slippery + self.map_kwargs["size"] = size + self.map_kwargs["p"] = p + self.env_kwargs["seed"] = seed + self.env_kwargs["is_slippery"] = is_slippery + if desc is not None: + self.env_kwargs["desc"] = copy.deepcopy(desc) + + if desc is None: + random_map, goal_position = generate_random_map( + size=size, + p=p, + seed=seed, + max_steps=self.max_steps, + ) + else: + random_map = np.asarray(copy.deepcopy(desc), dtype="c") + goal_position = get_goal_position(random_map) + + self.goal_position = goal_position + self.desc = random_map[:] + + # Reinitialize parent class with new map + try: + import gymnasium as gym + + super().__init__( + desc=random_map[:], + is_slippery=self.is_slippery, + ) + assert isinstance(self.desc, np.ndarray) + self.action_space = gym.spaces.Discrete(4, start=1) + except ImportError as e: + error_message = ( + "Gymnasium is not installed. " + "Please install gymnasium first before " + "running the frozen_lake workflow. 
" + f"Error: {str(e)}" + ) + raise ImportError(error_message) from e + + super().reset(seed=self.seed) + obs = self.render(mode="tiny_rgb_array") + assert isinstance(obs, str) + return obs, {} + + def finished(self) -> bool: + """Check if the episode is finished (goal or hole).""" + player_pos = self._get_player_position() + assert isinstance(self.desc, np.ndarray) + return self.desc[player_pos] in b"GH" # type: ignore + + def success(self) -> bool: + """Check if the agent has reached the goal (G).""" + player_pos = self._get_player_position() + assert isinstance(self.desc, np.ndarray) + return self.desc[player_pos] in b"G" diff --git a/tuner/frozen_lake/_utils.py b/tuner/frozen_lake/_utils.py new file mode 100644 index 0000000..7be8f50 --- /dev/null +++ b/tuner/frozen_lake/_utils.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- +""" +Utils for the FrozenLake environment. +Modified from rllm +""" + +from typing import Literal, Optional, Tuple +import numpy as np +from pydantic import BaseModel, Field + +# Map gym state in integer +MAP_LOOKUP = { + b"P": 0, + b"F": 1, + b"H": 2, + b"G": 3, +} + +# Define rules to transform to rendered text observation of the environment +GRID_LOOKUP = { + 0: " P \t", # player + 1: " _ \t", # frozen + 2: " O \t", # hole + 3: " G \t", # goal + 4: " X \t", # player fall into hole + 5: " √ \t", # player on goal +} + +ACTION_LOOKUP = { + 0: "None", + 1: "Left", + 2: "Down", + 3: "Right", + 4: "Up", +} + +# Prompting format inspired by the RAGEN project +SYSTEM_PROMPT = """You are Qwen, created by Alibaba Cloud. \ +You are a helpful assistant. You are walking on a frozen lake. + +FrozenLake Quick Guide +Goal: Reach the goal (G). Player (P) and Goal (G) must overlap. + +Symbols: +_ Frozen | O Hole | G Goal | P Player + +Rules: +1. Avoid falling into holes (O). +2. Frozen tiles are slippery, you may move perpendicular to + your intended direction. + +Valid Action (separated by | ): +Up | Down | Left | Right + +Rewards: +Fall into hole: 0 +Reach goal: +1.0 + +You will be provided the current observation, please decide on +the next Action. +You should show your thought process and then input the final +action in ``` ```. +You should only output the NEXT ACTION at each iteration in +the ``` ```. For example, if you want to move up, you should +output ```Up```. +You should plan ahead and need to achieve it in minimum number +of steps. +You should be aware that frozen tiles can be slippery, but the +chance is small and you should not overthink it. + +Please show your thinking process and put the final action in +``` ```. In every turn, the final action MUST be one of Up, +Down, Left, Right. +""" + + +class FrozenLakeAction(BaseModel): + """Action model for FrozenLake environment.""" + + action: Literal["Up", "Down", "Left", "Right"] = Field( + description=( + "The action to take in the FrozenLake environment, " + "must be one of Up, Down, Left, Right" + ), + ) + + +def is_valid(board: list[list[str]], max_size: int, max_steps: int) -> bool: + """DFS to check that it's a valid path. + + Args: + board: The board representation as a list of lists. + max_size: Maximum size of the board. + max_steps: Maximum number of steps allowed. + + Returns: + True if there's a valid path from start to goal within max_steps, + False otherwise. 
+ """ + frontier, discovered = [], set() + # find the start point + start_r, start_c = np.where(np.array(board) == "S") + frontier.append((start_r[0], start_c[0], 0)) # row, col steps + # dfs to check if there is a path from start to goal + while frontier: + r, c, steps = frontier.pop() + if steps > max_steps: + continue + + if (r, c) not in discovered: + discovered.add((r, c)) + directions = [(1, 0), (0, 1), (-1, 0), (0, -1)] + for x, y in directions: + r_new = r + x + c_new = c + y + if ( + r_new < 0 + or r_new >= max_size + or c_new < 0 + or c_new >= max_size + ): # noqa: PLR2004 + continue + if board[r_new][c_new] == "G": + return True + if board[r_new][c_new] != "H": + frontier.append((r_new, c_new, steps + 1)) + return False + + +def generate_random_map( + size: int = 8, + p: float = 0.8, + seed: int = 0, + max_steps: int = 5, +) -> Tuple[list[str], Tuple[int, int]]: + """Generates a random valid map (one that has a path from start to goal). + + Args: + size: Size of each side of the grid. + p: Probability that a tile is frozen. + seed: Seed to ensure the generation of reproducible maps. + max_steps: Maximum number of steps allowed. + + Returns: + A tuple containing a random valid map and the goal position (row, col). + """ + valid = False + board: list[list[str]] = [] # initialize to make pyright happy + + try: + from gymnasium.utils import seeding + + np_random, _ = seeding.np_random(seed) + except ImportError as exc: + raise ImportError( + "Gymnasium is not installed. " + "Please install gymnasium first before " + "running the frozen_lake workflow.", + ) from exc + + # generate random start and end points + while not valid: + p = min(1, p) + board = np_random.choice( + ["F", "H"], + (size, size), + p=[p, 1 - p], + ).tolist() + + while True: + start_r = int(np_random.integers(0, size)) + start_c = int(np_random.integers(0, size)) + goal_r = int(np_random.integers(0, size)) + goal_c = int(np_random.integers(0, size)) + + # Ensure start and goal are different positions + if (start_r, start_c) != (goal_r, goal_c): + break + + board[start_r][start_c] = "S" + board[goal_r][goal_c] = "G" + + valid = is_valid(board, size, max_steps) + return ["".join(x) for x in board], (goal_r, goal_c) + + +def get_goal_position( + random_map: np.ndarray, +) -> Optional[Tuple[int, int]]: + """Get the goal position from a random map. + + Args: + random_map: The map as a numpy array. + + Returns: + Tuple of (row, col) if goal found, None otherwise. 
+ """ + positions = np.argwhere(random_map == b"G") + if positions.size == 0: + return None # G not found + return tuple(positions[0]) # returns (row, col) + + +__all__ = [ + "SYSTEM_PROMPT", + "FrozenLakeAction", + "generate_random_map", + "get_goal_position", +] diff --git a/tuner/frozen_lake/config.yaml b/tuner/frozen_lake/config.yaml new file mode 100644 index 0000000..e9f128b --- /dev/null +++ b/tuner/frozen_lake/config.yaml @@ -0,0 +1,53 @@ +project: "AgentScope" # Project name +name: "FrozenLake" # Experiment name +checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints +algorithm: + algorithm_type: multi_step_grpo # GRPO series for multi-step scenario + repeat_times: 16 # Number of rollouts per prompt for advantage estimation + kl_loss_fn: "low_var_kl" + kl_loss_fn_args: + kl_coef: 0 # KL divergence coefficient + advantage_fn_args: + epsilon: 1e-6 # Small value for numerical stability + std_threshold: 0.0001 # Threshold for standard deviation + optimizer: + lr: 1e-6 # Learning rate +model: + model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} # Base model path + max_prompt_tokens: 23552 # Max tokens for prompt + max_response_tokens: 2048 # Max tokens per response + max_model_len: 25600 # Max context length + temperature: 1.0 # Sampling temperature +buffer: + total_epochs: 5 # Total training epochs + batch_size: 32 # Batch size per explore step + train_batch_size: 1024 # Total experiences per training step + trainer_input: + experience_buffer: + name: experience_buffer + storage_type: queue + max_read_timeout: 7200 # Max timeout for reading from buffer (seconds) + replay_buffer: + enable: true # Enable experience replay + priority_fn: linear_decay # Priority function for replay buffer + priority_fn_args: + decay: 0.1 # Decay rate for priority function +explorer: + runner_per_model: 16 # Number of runners per model + rollout_model: + engine_num: 6 # Number of vLLM engines for rollout model + tensor_parallel_size: 1 # TP size per engine for rollout model + enable_openai_api: true # Enable OpenAI-compatible API + enable_history: true # Enable conversation history + enable_auto_tool_choice: true # Enable automatic tool selection + tool_call_parser: hermes # Parser for tool calls +trainer: + save_interval: 100 # Save checkpoint every N steps + use_dynamic_bsz: true # Use dynamic batch size + grad_clip: 1.0 # Gradient clipping value + max_token_len_per_gpu: 25600 # Max token length per GPU + ulysses_sequence_parallel_size: 2 # Sequence parallel size for Ulysses +synchronizer: + sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer + sync_interval: 1 # Sync every N steps + sync_timeout: 1200 # Timeout for synchronization (seconds) diff --git a/tuner/frozen_lake/critic_rewards_mean.png b/tuner/frozen_lake/critic_rewards_mean.png new file mode 100644 index 0000000..df013c6 Binary files /dev/null and b/tuner/frozen_lake/critic_rewards_mean.png differ diff --git a/tuner/frozen_lake/get_frozenlake_data.py b/tuner/frozen_lake/get_frozenlake_data.py new file mode 100644 index 0000000..ea61d5a --- /dev/null +++ b/tuner/frozen_lake/get_frozenlake_data.py @@ -0,0 +1,131 @@ +# -*- coding: utf-8 -*- +""" +Modified from rllm +""" +import argparse +import os + +import numpy as np +import pandas as pd + + +DEFAULT_DATA_PATH = os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "..", + "data", + "frozenlake", +) + + +def save_dataset_to_local( + data_path: str, + data: list[dict], + split: str = "default", +) 
-> str:
+ """Save dataset directly to local data_path.
+
+ Args:
+ data_path: Path to save the dataset
+ data: List of dictionaries containing the dataset examples
+ split: Split name (e.g., 'train', 'test', 'default')
+
+ Returns:
+ str: Path to the saved parquet file
+ """
+ os.makedirs(data_path, exist_ok=True)
+
+ # Convert to DataFrame and save
+ data_df = pd.DataFrame(data)
+ dataset_path = os.path.join(data_path, f"{split}.parquet")
+ data_df.to_parquet(dataset_path)
+
+ print(
+ f"Saved FrozenLake dataset split '{split}' "
+ f"with {len(data)} examples at {dataset_path}. "
+ f"Make sure to point the dataset path in main.py "
+ f"to {data_path}.",
+ )
+
+ return dataset_path
+
+
+def prepare_frozenlake_data(
+ data_path: str,
+ train_size: int = 10000,
+ test_size: int = 100,
+ map_max_size: int = 6,
+) -> tuple[list[dict], list[dict]]:
+ """
+ Prepare and save FrozenLake datasets for training and testing.
+
+ Args:
+ data_path (str): Path to save the dataset
+ train_size (int): Number of training examples to generate
+ test_size (int): Number of test examples to generate
+ map_max_size (int): Maximum size of the map
+
+ Returns:
+ tuple: (train_data, test_data) - Lists of data dictionaries
+ """
+ # Set random seed for reproducibility
+ np.random.seed(42)
+
+ # Generate random parameters for train and test sets
+ train_seeds = np.random.randint(0, 100000, size=train_size)
+ test_seeds = np.random.randint(0, 100000, size=test_size)
+ train_sizes = np.random.randint(2, map_max_size, size=train_size)
+ test_sizes = np.random.randint(2, map_max_size, size=test_size)
+ train_ps = np.random.uniform(0.6, 0.85, size=train_size)
+ test_ps = np.random.uniform(0.6, 0.85, size=test_size)
+
+ def frozenlake_process_fn(
+ seed: int,
+ size: int,
+ p: float,
+ idx: int,
+ ) -> dict:
+ """Process function to create FrozenLake task instances."""
+ return {
+ "seed": seed,
+ "size": size,
+ "p": p,
+ "index": idx,
+ "uid": f"{seed}_{size}_{p}",
+ }
+
+ # Create train and test data
+ train_data_list = [
+ frozenlake_process_fn(seed, train_sizes[idx], train_ps[idx], idx)
+ for idx, seed in enumerate(train_seeds)
+ ]
+ test_data_list = [
+ frozenlake_process_fn(seed, test_sizes[idx], test_ps[idx], idx)
+ for idx, seed in enumerate(test_seeds)
+ ]
+
+ # Save datasets directly to local DATA_PATH
+ save_dataset_to_local(data_path, train_data_list, "train")
+ save_dataset_to_local(data_path, test_data_list, "test")
+
+ return train_data_list, test_data_list
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
+ parser.add_argument("--train_size", type=int, default=10000)
+ parser.add_argument("--test_size", type=int, default=100)
+ parser.add_argument("--map_max_size", type=int, default=6)
+ args = parser.parse_args()
+
+ train_data, test_data = prepare_frozenlake_data(
+ data_path=args.local_dir,
+ train_size=args.train_size,
+ test_size=args.test_size,
+ map_max_size=args.map_max_size,
+ )
+
+ print(f"Train dataset: {len(train_data)} examples")
+ print(f"Test dataset: {len(test_data)} examples")
+ print("Sample train example:", train_data[0])
+ print("Sample test example:", test_data[0])
diff --git a/tuner/frozen_lake/main.py b/tuner/frozen_lake/main.py
new file mode 100644
index 0000000..3f90451
--- /dev/null
+++ b/tuner/frozen_lake/main.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+"""Example of training a FrozenLake agent with Trinity-RFT."""
+import os
+from typing import Dict
+from _frozenlake_agent import FrozenLakeAgent
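+# _frozenlake_agent and _frozenlake_env are local modules in this example directory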
+from _frozenlake_env import FrozenLakeEnv +from agentscope.message import Msg +from agentscope.tuner import ( + tune, + WorkflowOutput, + DatasetConfig, + TunerModelConfig, + AlgorithmConfig, +) +from agentscope.model import ChatModelBase + + +async def run_frozen_lake( + task: Dict, + model: ChatModelBase, + auxiliary_models: Dict[str, ChatModelBase], +) -> WorkflowOutput: + """A workflow function using the FrozenLake agent to solve tasks. + + Args: + task (Dict): The task to be solved, containing environment parameters + like size, p, seed, is_slippery, etc. + model (ChatModelBase): The language model to use. + + Returns: + WorkflowOutput: The workflow output containing the reward, response and + metrics. + """ + + assert len(auxiliary_models) == 0, "No auxiliary models are needed" + + # Extract workflow arguments from task or use defaults + workflow_args = task.get("workflow_args", {}) + if not workflow_args: + workflow_args = task + + env_max_steps = workflow_args.get("env_max_steps", 8) + agent_max_steps = workflow_args.get("agent_max_steps", 10) + is_slippery = workflow_args.get("is_slippery", False) + desc = workflow_args.get("desc", None) + + # Extract task-specific arguments (for environment generation) + size = task.get("size", 8) + p = task.get("p", 0.8) + seed = task.get("seed", 42) + + # Initialize agent and environment + agent = FrozenLakeAgent(model=model, max_steps=agent_max_steps) + env = FrozenLakeEnv( + max_steps=env_max_steps, + desc=desc, + is_slippery=is_slippery, + size=size, + p=p, + seed=seed, + ) + + # Reset environment with task parameters + observation, _ = env.reset(task) + observation_str = str(observation) + rewards = [] + step_count = 0 + done = False + terminate_reason = None + + # Run agent-environment interaction loop + for _ in range(agent_max_steps): + step_count += 1 + try: + # get prompt + prompt = agent.get_prompt(observation_str) + + response = await agent.reply(msg=Msg("user", prompt, role="user")) + + # record action and observation + action = agent.get_action(response) + agent.update_state(action=action, observation=observation_str) + + except Exception as e: + terminate_reason = f"agent_error: {str(e)}" + break + + # environment step + observation, reward, done, _ = env.step(action) + observation_str = str(observation) + rewards.append(reward) + + if done: + terminate_reason = "success" if env.success() else "hole" + break + + if terminate_reason is None: + terminate_reason = "max_steps_reached" + + final_reward = sum(rewards) + final_observation = observation_str + + # Create response message with environment information + response_content = ( + f"Final observation:\n{final_observation}\n" + f"Total reward: {final_reward}\n" + f"Steps taken: {step_count}\n" + f"Terminate reason: {terminate_reason}" + ) + + final_response = Msg("assistant", response_content, role="assistant") + + return WorkflowOutput( + reward=final_reward, + response=final_response, + metrics={ + "env_steps": float(step_count), + "env_done": float(done), + }, + ) + + +if __name__ == "__main__": + dataset = DatasetConfig( + path="/path/to/frozenlake", + split="train", + ) + tuner_model = TunerModelConfig( + model_path="Qwen/Qwen2.5-3B-Instruct", + max_model_len=25600, + max_tokens=2048, + inference_engine_num=6, + reasoning_parser=None, + ) + algorithm = AlgorithmConfig( + algorithm_type="multi_step_grpo", + group_size=16, + batch_size=32, + learning_rate=1e-6, + ) + config_path = os.path.join( + os.path.dirname(__file__), + "config.yaml", + ) # define some default parameters 
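+ # Launch reinforcement fine-tuning with the workflow, model,
+ # dataset, and algorithm configured above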
+ tune( + workflow_func=run_frozen_lake, + model=tuner_model, + train_dataset=dataset, + algorithm=algorithm, + config_path=config_path, + ) diff --git a/tuner/learn_to_ask/README.md b/tuner/learn_to_ask/README.md index 430303c..b53f384 100644 --- a/tuner/learn_to_ask/README.md +++ b/tuner/learn_to_ask/README.md @@ -313,7 +313,7 @@ python tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py \ ``` > ⚠️ **Note**: Your trained model must be converted to **Hugging Face format** first. -> See: [Converting FSDP Checkpoints Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html) +> See: [Converting FSDP Checkpoints Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/faq.html) --- diff --git a/tuner/math_agent/README.md b/tuner/math_agent/README.md index 00759b8..8035b0b 100644 --- a/tuner/math_agent/README.md +++ b/tuner/math_agent/README.md @@ -248,10 +248,10 @@ if __name__ == "__main__": Here, we use `DatasetConfig` to load the training dataset, `TunerModelConfig` to initialize the trainable model, and `AlgorithmConfig` to specify the RL algorithm and its hyperparameters. > Note: -> The `tune` function is based on [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) and it converts the input parameters into a YAML configuration internally. +> The `tune` function is based on [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) and it converts the input parameters into a YAML configuration internally. > Advanced users can ignore `model`, `train_dataset`, `algorithm` arguments and provide a configuration file path pointing to a YAML file using the `config_path` argument instead (see [config.yaml](./config.yaml) for an example). > We recommend using the configuration file approach for fine-grained control over the training process and leveraging advanced features provided by Trinity-RFT. -> You can refer to the Trinity-RFT [Configuration Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html) for more details on configuration options. +> You can refer to the Trinity-RFT [Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html) for more details on configuration options. The checkpoint and logs will automatically be saved to the `checkpoints/AgentScope` directory under the current working directory and each run will be saved in a sub-directory suffixed with the current timestamp. You can find the tensorboard logs inside `monitor/tensorboard` of the checkpoint directory. @@ -342,7 +342,7 @@ After implementing the workflow function, follow these steps to run the training - At least 2 NVIDIA GPUs with CUDA 12.8 or newer. - Adjust the configuration file ([config.yaml](./config.yaml)) based on your hardware. - - Follow the Trinity-RFT [installation guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code. + - Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code. 
- Download the GSM8K dataset and Qwen/Qwen3-0.6B model checkpoints (example): ```bash diff --git a/tuner/math_agent/README_zh.md b/tuner/math_agent/README_zh.md index f82b037..e3a8abf 100644 --- a/tuner/math_agent/README_zh.md +++ b/tuner/math_agent/README_zh.md @@ -245,10 +245,10 @@ if __name__ == "__main__": 这里用 `DatasetConfig` 加载训练数据集,`TunerModelConfig` 初始化可训练模型,`AlgorithmConfig` 指定 RL 算法及其超参数。 > 注意: -> `tune` 函数基于 [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) 实现,会将输入参数自动转为 YAML 配置。 +> `tune` 函数基于 [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) 实现,会将输入参数自动转为 YAML 配置。 > 高级用户可忽略 `model`、`train_dataset`、`algorithm` 参数,直接用 `config_path` 指定 YAML 配置文件(见 [config.yaml](./config.yaml) 示例)。 > 推荐用配置文件方式实现更细粒度的训练控制,充分利用 Trinity-RFT 的高级特性。 -> 详细配置说明见 Trinity-RFT [配置指南](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html)。 +> 详细配置说明见 Trinity-RFT [配置指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html)。 训练产生的 checkpoint 和日志信息会自动保存在当前目录下的 `checkpoints/AgentScope` 中,每次运行会新建带时间戳的子目录。 TensorBoard 日志在 checkpoint 目录下的 `monitor/tensorboard` 中。 @@ -335,7 +335,7 @@ if __name__ == "__main__": - 至少 2 块 NVIDIA GPU,CUDA 12.8 或更高。 - 根据硬件调整配置文件([config.yaml](./config.yaml))。 - - 按 Trinity-RFT [安装指南](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) 从源码安装最新版。 + - 按 Trinity-RFT [安装指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) 从源码安装最新版。 - 下载 GSM8K 数据集和 Qwen/Qwen3-0.6B 模型权重(示例): ```bash diff --git a/tuner/math_agent/config.yaml b/tuner/math_agent/config.yaml index 9d98203..98a214d 100644 --- a/tuner/math_agent/config.yaml +++ b/tuner/math_agent/config.yaml @@ -1,4 +1,4 @@ -# Please refer to https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html for detailed explanation of each field. +# Please refer to https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html for detailed explanation of each field. project: AgentScope name: GSM8K-Qwen3-0.6B # directory to save checkpoints, default to ./checkpoints if TRINITY_CHECKPOINT_ROOT_DIR not set