Add examples for frozenlake and emailsearch (#94)
@@ -19,7 +19,9 @@ from structured_model import (
     get_seer_model,
     get_hunter_model,
 )
-from prompt import EnglishPrompts as Prompts
+from prompt import (
+    EnglishPrompts as Prompts,
+)  # pylint: disable=no-name-in-module

 # Uncomment the following line to use Chinese prompts
 # from prompt import ChinesePrompts as Prompts
@@ -6,7 +6,9 @@ from typing import Any
 import numpy as np
 from agentscope.agent import AgentBase, ReActAgent
 from agentscope.message import Msg
-from prompt import EnglishPrompts as Prompts
+from prompt import (
+    EnglishPrompts as Prompts,
+)  # pylint: disable=no-name-in-module

 MAX_GAME_ROUND = 30
 MAX_DISCUSSION_ROUND = 3
@@ -22,6 +22,6 @@ AgentScope Tuner requires:
 - `agentscope>=1.0.12`
 - `trinity-rft>=0.4.1`

-AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/modelscope/Trinity-RFT).
-Please refer to the [Trinity-RFT installation guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
+AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
+Please refer to the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
 for detailed instructions on how to set up the environment.
@@ -21,6 +21,6 @@ AgentScope Tuner requires:
 - `agentscope>=1.0.12`
 - `trinity-rft>=0.4.1`

-AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/modelscope/Trinity-RFT).
-Please refer to the [Trinity-RFT installation guide](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
+AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
+Please refer to the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
 for detailed installation instructions.
tuner/email_search/README.md | 279 lines (new file)
@@ -0,0 +1,279 @@
# Training Email Search Agent with RL using AgentScope-Tuner

This example demonstrates how to implement reinforcement fine-tuning for the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).

## Task Setting

The agent's goal is to answer user queries by searching through an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email contents to extract information
- Provide accurate answers with proper source citations

**Agent Type**: The agent (`EmailSearchAgent`) extends `ReActAgent`, which follows a reasoning-acting loop to solve tasks iteratively.

**Environment**: The environment is a SQLite database containing emails from the Enron Email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context for the query
- `answer`: The expected answer (ground truth), used only for reward calculation
- `message_ids`: IDs of the relevant emails containing the answer, used only for reward calculation

**Available Tools** (signatures quoted below):
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from the ReAct agent).
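
For reference, the first two tools are methods on `EmailSearchAgent`; their signatures, as defined in [`_email_search_agent.py`](./_email_search_agent.py), are:

```python
def search_emails(
    self,
    inbox_address: str,   # the user's email address
    query_date: str,      # 'YYYY-MM-DD'; only emails sent on or before this date are returned
    keywords: list[str],  # all keywords must match (FTS5 AND semantics)
    **_kwargs: Any,
) -> ToolResponse: ...

def read_email(self, message_id: str, **_kwargs: Any) -> ToolResponse: ...
```
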
## Dataset Preparation

The dataset contains email queries based on the [Enron Email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:

```bash
python prepare_data.py
```

If you want to use a different database path, modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also, remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database path before moving on to the next step:

```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```

This will create a SQLite database and datasets:

```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db    # Email database
├── train.parquet          # Training samples
└── test.parquet           # Test samples
```

Each sample looks like:

```json
{
  "id": 0,
  "question": "Were there any variances detected for hour 6 on 3/9/01?",
  "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
  "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
  "how_realistic": 0.800000011920929,
  "inbox_address": "pete.davis@enron.com",
  "query_date": "2001-03-16"
}
```
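
To sanity-check the generated splits, a quick way (assuming `pandas` and `pyarrow` are installed) is:

```python
import pandas as pd

# Load the generated training split and inspect one sample.
df = pd.read_parquet("/path/to/enron_emails_dataset/train.parquet")
print(len(df), "training samples")
print(df.iloc[0]["question"])
```
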
## Code Implementation

This section provides a high-level overview of the implementation; for details, please refer to the source code.

### Agent Workflow

The workflow function `run_email_search_agent` implements the agent-environment interaction loop:

```python
async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # Parse task and create agent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        max_iters=max_turns,
    )

    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )

    return WorkflowOutput(response=response)
```

The agent follows a ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and the source message IDs.

### Judge Function

The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:

```python
async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    # Extract answer and sources from response
    answer = answer_and_sources.get("answer")
    sources = answer_and_sources.get("sources", [])

    # Judge correctness using LLM-as-a-Judge
    judge_model = auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
    judge_response = await judge_correctness(
        answer, query, judge_model
    )

    # Calculate reward based on:
    # - Answer correctness (accuracy: -1.0 to 1.0)
    # - Source correctness (format: partial rewards)
    # - Efficiency (bonus for fewer turns, correct sources)
    result = {"accuracy": ..., "format": ...}  # calculated from judge_response

    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )
```

The reward function considers the following components (a worked example follows the list):
- **Answer correctness**: Evaluated by LLM-as-a-Judge, comparing the agent's answer with the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: Bonus rewards for finding/reading the correct email and for taking fewer turns
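
Plugging representative numbers into the scoring logic in [`main.py`](./main.py), a fully correct rollout earns:

```python
# Correct answer, 1 correctly cited source, finished in 4 of max 10 turns:
reward = 1.0                   # base reward for a correct answer
reward += 0.3                  # cited sources include the right message_id
reward += 0.1 / 1              # small bonus for citing few sources
reward += 0.1 * (1 - 4 / 10)   # bonus for finishing in fewer turns
# format reward = 1.46; plus accuracy = 1.0, total reward = 2.46
```
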
See [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py) for implementation details.

## How to Run

### Prerequisites

- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
  * Note: the 30B judge model needs a GPU with at least 40 GB of memory; you can also shard it across multiple GPUs with `tensor_parallel_size > 1` to reduce per-GPU memory usage (by default, `tensor_parallel_size=2`).
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source
- Download the model checkpoints (example):

```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507  # judge model
```
### Configuration

Adjust the configuration file ([`config.yaml`](./config.yaml)) to your hardware. Key configuration sections include (an excerpt follows the list):

- **TunerModelConfig**: Set `model_path` to your model checkpoint path
- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
- **Auxiliary Models**: Configure judge model settings for LLM-as-a-Judge
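
For example, the judge model used for LLM-as-a-Judge is declared under `auxiliary_models` in this example's [`config.yaml`](./config.yaml) (abridged):

```yaml
auxiliary_models:
  - name: judge
    model_path: Qwen/Qwen3-30B-A3B-Instruct-2507  # Judge model path
    engine_num: 1            # Number of vLLM engines for judge model
    tensor_parallel_size: 2  # TP size per engine for judge model
    enable_thinking: false   # Disable thinking/reasoning mode
```
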
For full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).

### Start-Up Commands

1. Prepare the dataset:

```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```

2. Set up a [Ray](https://github.com/ray-project/ray) cluster:

```bash
ray start --head
```

3. Run the training script:

```bash
python main.py
```

## Experimental Results

### Quantitative Results

The training results show improvements in agent performance over training iterations. Key metrics include:

- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples rises accordingly

![train_reward](./critic_reward_mean.png)

![rollout_accuracy](./rollout_accuracy_mean.png)

### Concrete Example

An example of the agent's behavior is shown below.

**Query:** "What do the color codes mean in the curve assessment?"

We show the last several turns of the agent's responses.

The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:

**Tool call:**
```json
{
  "type": "tool_use",
  "name": "search_emails",
  "input": {
    "inbox_address": "steven.kean@enron.com",
    "query_date": "2001-04-03",
    "keywords": ["curve", "assessment"]
  }
}
```

**Tool result:**
```json
{
  "type": "tool_result",
  "name": "search_emails",
  "output": [
    {
      "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
      "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
    },
    {
      "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
      "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
    }
  ]
}
```

After finding relevant emails, the agent uses the `read_email` tool to read the full content of both emails:

**Tool call:**
```json
{
  "type": "tool_use",
  "name": "read_email",
  "input": {
    "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
  }
}
```

**Tool call:**
```json
{
  "type": "tool_use",
  "name": "read_email",
  "input": {
    "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
  }
}
```

After reading the emails, the agent generates the final structured response:

**Tool call:**
```json
{
  "type": "tool_use",
  "name": "generate_response",
  "input": {
    "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
    "sources": [
      "<2654330.1075846153519.JavaMail.evans@thyme>",
      "<12499967.1075846153495.JavaMail.evans@thyme>"
    ]
  }
}
```

The judge evaluates the answer as correct based on the ground truth.
tuner/email_search/README_zh.md | 279 lines (new file)
@@ -0,0 +1,279 @@
# Training an Email Search Agent with AgentScope-Tuner

This example shows how to apply reinforcement fine-tuning to the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).

## Task Setting

The agent's goal is to answer user queries by searching an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email contents to extract information
- Provide accurate answers with proper source citations

**Agent Type**: The agent (`EmailSearchAgent`) inherits from `ReActAgent` and follows a reasoning-acting loop to solve tasks iteratively.

**Environment**: The environment is a SQLite database containing emails from the Enron email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context for the query
- `answer`: The expected answer (ground truth), used only for reward calculation
- `message_ids`: IDs of the relevant emails containing the answer, used only for reward calculation

**Available Tools**:
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from the ReAct agent).

## Dataset Preparation

The dataset contains email queries based on the [Enron email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:

```bash
python prepare_data.py
```

If you want to use a different database path, modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also, remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database path before moving on to the next step:

```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```

This will create a SQLite database and datasets:

```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db    # Email database
├── train.parquet          # Training samples
└── test.parquet           # Test samples
```

Each sample looks like:

```json
{
  "id": 0,
  "question": "Were there any variances detected for hour 6 on 3/9/01?",
  "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
  "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
  "how_realistic": 0.800000011920929,
  "inbox_address": "pete.davis@enron.com",
  "query_date": "2001-03-16"
}
```

## Code Implementation

This section provides a high-level overview of the implementation; for details, please refer to the source code.

### Agent Workflow

The workflow function `run_email_search_agent` implements the agent-environment interaction loop:

```python
async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # Parse the task and create the agent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        max_iters=max_turns,
    )

    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )

    return WorkflowOutput(response=response)
```

The agent follows the ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and the source message IDs.

### Judge Function

The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:

```python
async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    # Extract the answer and sources from the response
    answer = answer_and_sources.get("answer")
    sources = answer_and_sources.get("sources", [])

    # Judge correctness using LLM-as-a-Judge
    judge_model = auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
    judge_response = await judge_correctness(
        answer, query, judge_model
    )

    # Calculate the reward based on:
    # - Answer correctness (accuracy: -1.0 to 1.0)
    # - Source correctness (format: partial rewards)
    # - Efficiency (bonus for fewer turns, correct sources)
    result = {"accuracy": ..., "format": ...}  # calculated from judge_response

    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )
```

The reward function considers:
- **Answer correctness**: Evaluated by LLM-as-a-Judge, comparing the agent's answer against the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: Bonus rewards for finding/reading the correct email and for taking fewer turns

See [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py) for implementation details.

## How to Run

### Prerequisites

- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
  * Note: the 30B judge model needs a GPU with at least 40 GB of memory; you can also shard it across multiple GPUs with `tensor_parallel_size > 1` to reduce per-GPU memory usage (by default, `tensor_parallel_size=2`).
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) to install the latest version from source
- Download the model checkpoints (example):

```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507  # judge model
```

### Configuration

Adjust the configuration file ([`config.yaml`](./config.yaml)) to your hardware. Key configuration sections include:

- **TunerModelConfig**: Set `model_path` to your model checkpoint path
- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
- **Auxiliary Models**: Configure judge model settings for LLM-as-a-Judge

For full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html).

### Start-Up Commands

1. Prepare the dataset:

```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```

2. Start a [Ray](https://github.com/ray-project/ray) cluster:

```bash
ray start --head
```

3. Run the training script:

```bash
python main.py
```

## Experimental Results

### Quantitative Results

The training results show improvements in agent performance over training iterations. Key metrics include:

- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples rises accordingly

![train_reward](./critic_reward_mean.png)

![rollout_accuracy](./rollout_accuracy_mean.png)

### Concrete Example

An example of the agent's behavior is shown below.

**Query:** "What do the color codes mean in the curve assessment?"

We show the last several turns of the agent's responses.

The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:

**Tool call:**
```json
{
  "type": "tool_use",
  "name": "search_emails",
  "input": {
    "inbox_address": "steven.kean@enron.com",
    "query_date": "2001-04-03",
    "keywords": ["curve", "assessment"]
  }
}
```

**Tool result:**
```json
{
  "type": "tool_result",
  "name": "search_emails",
  "output": [
    {
      "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
      "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
    },
    {
      "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
      "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
    }
  ]
}
```

After finding relevant emails, the agent uses the `read_email` tool to read the full content of both emails:

**Tool call:**
```json
{
  "type": "tool_use",
  "name": "read_email",
  "input": {
    "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
  }
}
```

**Tool call:**
```json
{
  "type": "tool_use",
  "name": "read_email",
  "input": {
    "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
  }
}
```

After reading the emails, the agent generates the final structured response:

**Tool call:**
```json
{
  "type": "tool_use",
  "name": "generate_response",
  "input": {
    "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
    "sources": [
      "<2654330.1075846153519.JavaMail.evans@thyme>",
      "<12499967.1075846153495.JavaMail.evans@thyme>"
    ]
  }
}
```

The judge evaluates the above answer as correct.
tuner/email_search/_email_search_agent.py | 175 lines (new file)
@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import json
import traceback
from dataclasses import asdict
from datetime import datetime, timedelta
from typing import Any

from _utils import (  # pylint: disable=E0611
    read_email_tool,
    search_emails_tool,
)
from agentscope import logger
from agentscope.agent import ReActAgent
from agentscope.message import TextBlock
from agentscope.tool import Toolkit, ToolResponse


def pre_reasoning_hook(_self: Any, _kwargs: Any) -> dict[str, Any] | None:
    """Pre-reasoning hook to remove tool_choice from kwargs."""
    _kwargs.pop("tool_choice", None)
    return _kwargs


class EmailSearchAgent(ReActAgent):
    """
    A customized ReAct agent with pre-defined tools for
    email search and reading.
    Ref: https://github.com/OpenPipe/ART
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # List to store message IDs found during search
        self.message_id_list = []
        # List to store message IDs that have been read
        self.ever_read_message_ids = []
        toolkit = Toolkit()
        toolkit.register_tool_function(self.search_emails)
        toolkit.register_tool_function(self.read_email)
        super().__init__(*args, toolkit=toolkit, **kwargs)

        self.register_instance_hook(
            "pre_reasoning",
            "tool_choice_hook",
            pre_reasoning_hook,
        )

    async def reset(self) -> None:
        """Reset agent state for a new rollout/episode."""
        self.message_id_list.clear()
        self.ever_read_message_ids.clear()
        await self.memory.clear()

    def search_emails(
        self,
        inbox_address: str,
        query_date: str,
        keywords: list[str],
        **_kwargs: Any,
    ) -> ToolResponse:
        """
        Search the user's email inbox for emails that match the given keywords.

        Args:
            inbox_address: The user's email address.
            query_date: The date of the query in 'YYYY-MM-DD' format.
            keywords: Keywords to search for in the user's email inbox.

        Returns:
            ToolResponse:
                A ToolResponse object containing a list of TextBlock objects
                in the `content` field. On success, the text field of the
                TextBlock contains a JSON string representing a list of email
                summaries (e.g., message_id, snippet) matching the search
                criteria. Each email summary is converted to a dictionary via
                `asdict`. On failure, the text indicates an error message.
        """
        try:
            next_day = (
                datetime.strptime(query_date, "%Y-%m-%d") + timedelta(days=1)
            ).strftime(
                "%Y-%m-%d",
            )
            res = search_emails_tool(
                inbox=inbox_address,
                sent_before=next_day,
                keywords=keywords,
            )

            self.message_id_list.extend([r.message_id for r in res])

            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=json.dumps([asdict(r) for r in res]),
                    ),
                ],
            )
        except Exception as e:
            logger.info(
                "Error in search_emails: %s, traceback: %s",
                e,
                traceback.format_exc(),
            )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=(
                            f"Error: Failed to search emails.\n"
                            f"Error message: {e}"
                        ),
                    ),
                ],
            )

    def read_email(self, message_id: str, **_kwargs: Any) -> ToolResponse:
        """
        Read the content of an email from the user's email inbox.
        Returns the email content.

        Args:
            message_id (str): The unique identifier of the email to read.

        Returns:
            ToolResponse:
                A ToolResponse object containing the email content or an
                error message if the email is not found.
        """
        try:
            email_content = read_email_tool(message_id)

            self.ever_read_message_ids.append(message_id)

            if email_content is None:
                return ToolResponse(
                    content=[
                        TextBlock(
                            type="text",
                            text=(
                                f"Error: Email (message_id = {message_id}) "
                                f"not found."
                            ),
                        ),
                    ],
                )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=json.dumps(email_content.model_dump()),
                    ),
                ],
            )
        except Exception as e:
            logger.info(
                "Error in read_email: %s, traceback: %s",
                e,
                traceback.format_exc(),
            )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=(
                            f"Error: Failed to read email.\n"
                            f"Error message: {e}"
                        ),
                    ),
                ],
            )
tuner/email_search/_utils.py | 328 lines (new file)
@@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
"""
This file defines dataclasses and tool implementations.
Modified from https://github.com/OpenPipe/ART/blob/art-e/
"""
import datetime
import os
import sqlite3
from dataclasses import dataclass
from typing import Any, List, Optional

from pydantic import BaseModel, Field, field_validator

from agentscope import logger

DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH")
conn = None


def get_conn() -> sqlite3.Connection:
    """Get or create a database connection."""
    global conn
    if conn is None:
        conn = sqlite3.connect(
            f"file:{DEFAULT_DB_PATH}?mode=ro",
            uri=True,
            check_same_thread=False,
        )
    return conn


class QueryModel(BaseModel):
    """Model for an email search query."""

    id: int
    question: str
    answer: str
    message_ids: List[str]  # message_ids (strings) of referenced emails
    how_realistic: float
    inbox_address: str
    query_date: str

    @field_validator("query_date", mode="before")
    @classmethod
    def format_date(cls, v: Any) -> str:
        """Format date to string if it's a datetime object."""
        if isinstance(v, datetime.datetime):
            return v.strftime("%Y-%m-%d")
        return v


class AnswerModel(BaseModel):
    """Model for the agent's answer with sources."""

    answer: str = Field(
        description=(
            "It should be called with the answer and the sources. "
            "If you cannot find the answer, you should return "
            "'I don't know' with an empty list of sources."
        ),
    )
    sources: List[str] = Field(
        description=(
            "a list of message ids that are relevant to the query. "
            "Usually there will be only one. If you cannot find the "
            "answer, you should return an empty list."
        ),
    )


class Email(BaseModel):
    """Model representing an email."""

    message_id: str
    date: str  # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    subject: Optional[str] = None
    from_address: Optional[str] = None
    to_addresses: List[str] = Field(default_factory=list)
    cc_addresses: List[str] = Field(default_factory=list)
    bcc_addresses: List[str] = Field(default_factory=list)
    body: Optional[str] = None
    file_name: Optional[str] = None


@dataclass
class SearchResult:
    """Result from an email search."""

    message_id: str
    snippet: str


class FinalRubric(BaseModel):
    """Rubric for evaluating agent performance."""

    answer_correct: bool = False
    sources_correct: bool = False
    num_turns: int = 0
    attempted_answer: bool = False
    ever_found_right_email: bool = False
    ever_read_right_email: bool = False
    cant_parse_tool_call: bool = False
    bad_tool_call_name: bool = False
    bad_tool_call_args: bool = False
    ran_out_of_turns: bool = False
    returned_i_dont_know: bool = False
    num_sources: int = 0
    ever_tried_to_read_invalid_email: bool = False
    prompt_tokens: int = 0
    completion_tokens: int = 0


# Define tools for the agent


def search_emails_tool(
    inbox: str,
    keywords: List[str],
    from_addr: Optional[str] = None,
    to_addr: Optional[str] = None,
    sent_after: Optional[str] = None,
    sent_before: Optional[str] = None,
    max_results: int = 10,
) -> List[SearchResult]:
    """
    Searches the email database based on keywords, inbox,
    sender, recipient, and date range.

    Args:
        inbox: The email address of the user performing the search.
            Results include emails sent from or to (inc. cc/bcc)
            this address.
        keywords: A list of keywords that must all appear in the
            subject or body.
        from_addr: Optional email address to filter emails sent *from*.
        to_addr: Optional email address to filter emails sent *to*
            (inc. cc/bcc).
        sent_after: Optional date string 'YYYY-MM-DD'. Filters for
            emails sent on or after this date.
        sent_before: Optional date string 'YYYY-MM-DD'. Filters for
            emails sent before this date.
        max_results: The maximum number of results to return.
            Cannot exceed 10.

    Returns:
        A list of SearchResult objects, each containing 'message_id'
        and 'snippet'. Returns an empty list if no results are found
        or an error occurs.
    """
    # Initialize sql and params
    sql: Optional[str] = None
    params: List[str | int] = []

    cursor = get_conn().cursor()

    # --- Build Query ---
    where_clauses: List[str] = []

    # 1. Keywords (FTS)
    if not keywords:
        raise ValueError("No keywords provided for search.")

    if max_results > 10:
        raise ValueError("max_results must be less than or equal to 10.")

    # FTS5 default is AND, so just join keywords. Escape quotes for safety.
    fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
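    # e.g. keywords ["curve", "assessment"] become the FTS5 query
    # '"curve" "assessment"', which requires both terms to match.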
    where_clauses.append("emails_fts MATCH ?")
    params.append(fts_query)

    # 2. Inbox filter (must be from OR to/cc/bcc the inbox user)
    # Use the composite index idx_recipients_address_email here
    where_clauses.append(
        """
        (e.from_address = ? OR EXISTS (
            SELECT 1 FROM recipients r_inbox
            WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id
        ))
        """,
    )
    params.extend([inbox, inbox])

    # 3. Optional From filter
    if from_addr:
        where_clauses.append("e.from_address = ?")
        params.append(from_addr)

    # 4. Optional To filter (includes to, cc, bcc)
    # Use composite index idx_recipients_address_email
    if to_addr:
        where_clauses.append(
            """
            EXISTS (
                SELECT 1 FROM recipients r_to
                WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
            )
            """,
        )
        params.append(to_addr)

    # 5. Optional Sent After filter
    if sent_after:
        # Assumes date format 'YYYY-MM-DD'
        # Compare against the start of the day
        where_clauses.append("e.date >= ?")
        params.append(f"{sent_after} 00:00:00")

    # 6. Optional Sent Before filter
    if sent_before:
        # Assumes date format 'YYYY-MM-DD'
        # Compare against the start of the day (exclusive)
        where_clauses.append("e.date < ?")
        params.append(f"{sent_before} 00:00:00")

    # --- Construct Final Query ---
    # snippet(<table>, <column_index>, <highlight_start>,
    #         <highlight_end>, <ellipsis>, <tokens>)
    # -1 means highlight across all columns (subject, body)
    sql = f"""
        SELECT
            e.message_id,
            snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
        FROM
            emails e JOIN emails_fts fts ON e.id = fts.rowid
        WHERE
            {" AND ".join(where_clauses)}
        ORDER BY
            e.date DESC -- Order by date for relevance
        LIMIT ?;
    """
    params.append(max_results)

    # --- Execute and Fetch ---
    logger.debug("Executing SQL: %s", sql)
    logger.debug("With params: %s", params)
    cursor.execute(sql, params)
    results = cursor.fetchall()

    # Format results
    formatted_results = [
        SearchResult(message_id=row[0], snippet=row[1]) for row in results
    ]
    logger.info("Search found %d results.", len(formatted_results))
    return formatted_results


def read_email_tool(message_id: str) -> Optional[Email]:
    """
    Retrieves a single email by its message_id from the database.

    Args:
        message_id: The unique identifier of the email to retrieve.

    Returns:
        An Email object containing the details of the found email,
        or None if the email is not found or an error occurs.
    """
    cursor = get_conn().cursor()

    # --- Query for Email Core Details ---
    email_sql = """
        SELECT id, message_id, date, subject, from_address, body, file_name
        FROM emails
        WHERE message_id = ?;
    """
    cursor.execute(email_sql, (message_id,))
    email_row = cursor.fetchone()

    if not email_row:
        logger.warning("Email with message_id '%s' not found.", message_id)
        return None

    email_pk_id, msg_id, date, subject, from_addr, body, file_name = email_row

    # DEBUG
    logger.info("[read_email_tool] input_message_id=%s", message_id)
    logger.info(
        "[read_email_tool] db: id=%s, message_id=%s",
        email_pk_id,
        msg_id,
    )

    # Search for recipients by emails.id (rather than message_id)
    recipients_sql = """
        SELECT recipient_address, recipient_type
        FROM recipients
        WHERE email_id = ?;
    """
    cursor.execute(recipients_sql, (email_pk_id,))
    recipient_rows = cursor.fetchall()

    to_addresses: List[str] = []
    cc_addresses: List[str] = []
    bcc_addresses: List[str] = []

    for addr, rtype in recipient_rows:
        type_lower = rtype.lower()
        if type_lower == "to":
            to_addresses.append(addr)
        elif type_lower == "cc":
            cc_addresses.append(addr)
        elif type_lower == "bcc":
            bcc_addresses.append(addr)

    # --- Construct Email Object ---
    email_obj = Email(
        message_id=msg_id,
        date=date,
        subject=subject,
        from_address=from_addr,
        to_addresses=to_addresses,
        cc_addresses=cc_addresses,
        bcc_addresses=bcc_addresses,
        body=body,
        file_name=file_name,
    )

    return email_obj


__all__ = [
    "QueryModel",
    "AnswerModel",
    "FinalRubric",
    "Email",
    "SearchResult",
    "search_emails_tool",
    "read_email_tool",
    "get_conn",
]
tuner/email_search/config.yaml | 72 lines (new file)
@@ -0,0 +1,72 @@
project: "AgentScope"  # Project name
name: "Email_search"  # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}  # Directory to save model checkpoints
algorithm:
  algorithm_type: multi_step_grpo  # GRPO series for multi-step scenario
  repeat_times: 8  # Number of rollouts per prompt for advantage estimation
  optimizer:
    lr: 1e-6  # Learning rate
  policy_loss_fn: "rec"  # Policy loss function
  policy_loss_fn_args:  # Policy loss function arguments
    epsilon_low: 0.2
    epsilon_high: 0.2
    clip_mode: "one-side"
    weight: "none"
    temp: 1.0
    regularizer: "none"
    regularizer_coef: 0.0
  kl_loss_fn: 'k2'  # KL divergence loss function
  kl_loss_fn_args:
    kl_coef: 0.0  # KL divergence coefficient
  advantage_fn_args:
    std_cal_level: 'batch'  # Advantage normalization level
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}  # Base model path
  max_response_tokens: 4096  # Max tokens per response
  max_model_len: 20480  # Max context length
buffer:
  total_epochs: 10  # Total training epochs
  batch_size: 64  # Batch size per explore step
  train_batch_size: 2560  # 64*8*5, total experiences per training step
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
      replay_buffer:
        enable: true  # Enable experience replay
        priority_fn: 'decay_limit_randomization'
        priority_fn_args:
          decay: 2.0
          use_count_limit: 3
          sigma: 2.0
explorer:
  eval_interval: 10
  max_repeat_times_per_runner: 1  # Max repeat times per runner
  max_timeout: 3600  # Max timeout for each rollout (seconds)
  rollout_model:
    enable_history: true  # Enable conversation history
    enable_openai_api: true  # Enable OpenAI-compatible API
    enable_auto_tool_choice: true  # Enable automatic tool selection
    tool_call_parser: hermes  # Parser for tool calls
    engine_num: 4  # Number of vLLM engines for rollout model
    tensor_parallel_size: 1  # TP size per engine for rollout model
    enable_prefix_caching: false  # Disable prefix caching
  auxiliary_models:
    - name: judge
      model_path: Qwen/Qwen3-30B-A3B-Instruct-2507  # Judge model path
      engine_num: 1  # Number of vLLM engines for judge model
      tensor_parallel_size: 2  # TP size per engine for judge model
      enable_thinking: false  # Disable thinking/reasoning mode
      max_prompt_tokens: 2048  # Max tokens for prompt
      max_response_tokens: 128  # Max tokens for response
      max_model_len: 2500  # Max model context length
synchronizer:
  sync_style: dynamic_by_explorer  # Sync triggered dynamically by explorer
  sync_interval: 5  # Sync every N steps
  sync_timeout: 3600  # Timeout for synchronization (seconds)
trainer:
  save_interval: 100  # Save checkpoint every N steps
  grad_clip: 1.0  # Gradient clipping value
  use_dynamic_bsz: true  # Use dynamic batch size
  max_token_len_per_gpu: 16384  # Max token length per GPU
  ulysses_sequence_parallel_size: 1  # Sequence parallel size for Ulysses
tuner/email_search/critic_reward_mean.png | BIN (new binary file, 470 KiB)
tuner/email_search/main.py | 379 lines (new file)
@@ -0,0 +1,379 @@
# -*- coding: utf-8 -*-
"""Example of training an Email Search agent with Trinity-RFT."""
import os
from typing import Dict

from _email_search_agent import EmailSearchAgent
from _utils import (  # pylint: disable=E0611
    AnswerModel,
    FinalRubric,
    QueryModel,
)
from agentscope import logger
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.tuner import (
    TunerModelConfig,
    DatasetConfig,
    JudgeOutput,
    WorkflowOutput,
    AlgorithmConfig,
    tune,
)
from agentscope.model import ChatModelBase


SYSTEM_PROMPT = """You are an email search agent. You are given a user query
and a list of tools you can use to search the user's email. Use the tools to
search the user's emails and find the answer to the user's query. You may take
up to {max_turns} turns to find the answer, so if your first search doesn't
find the answer, you can try with different keywords.

Always describe what you see and plan your next steps clearly. When taking
actions, explain what you're doing and why. When the answer to the task is
found, call `generate_response` to finish the process. Only call
`generate_response` when the answer is found. You should not respond with any
next steps in `generate_response`. Complete all steps and then call
`generate_response`.

The user's email address is {inbox_address}

Today's date is {query_date}

"""


async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:  # noqa: PLR0915
    """A workflow function using the Email Search agent to solve tasks.

    Args:
        task (Dict): The task to be solved.
            Should contain fields from QueryModel.
        model (ChatModelBase): The language model to use.
        auxiliary_models (Dict[str, ChatModelBase]): Additional models;
            at least one is required for LLM-as-a-Judge.

    Returns:
        WorkflowOutput: The output containing the agent's response.
    """
    assert len(auxiliary_models) > 0, "LLM-as-a-Judge is required"

    # Parse task data
    query = QueryModel.model_validate(task)
    question = task.get("question", task.get("task_desc", ""))

    # Get workflow arguments with defaults
    workflow_args = task.get("workflow_args", {})
    max_turns = int(workflow_args.get("max_turns", 10))

    # Format system prompt
    system_prompt = SYSTEM_PROMPT.format(
        max_turns=max_turns,
        inbox_address=query.inbox_address,
        query_date=query.query_date,
    )

    # Create EmailSearchAgent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        formatter=OpenAIChatFormatter(),
        max_iters=max_turns,
    )

    # Reset agent state for a new rollout
    await agent.reset()

    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )

    # Extract answer and sources from response metadata
    answer_and_sources = response.metadata or {}
    if not answer_and_sources:
        # Fallback: try to parse from content
        answer_and_sources = {
            "answer": response.get_text_content() or "",
            "sources": [],
        }

    # Store agent state for the judge function
    # We'll pass this through the response metadata
    response_metadata = {
        "answer_and_sources": answer_and_sources,
        "query": query.model_dump(),
        "message_id_list": agent.message_id_list,
        "ever_read_message_ids": agent.ever_read_message_ids,
        # Estimate actual_turns from memory length
        "actual_turns": (
            max(1, (len(agent.memory.content) - 1) // 2)
            if len(agent.memory.content) > 1
            else 1
        ),
    }

    # Update response metadata
    if response.metadata is None:
        response.metadata = {}
    response.metadata.update(response_metadata)

    return WorkflowOutput(
        response=response,
    )


def _calculate_partial_rewards(rubric: FinalRubric) -> float:
    """Calculate partial rewards based on the rubric."""
    partial_rewards = 0.0
    partial_rewards += 0.1 if rubric.ever_found_right_email else 0
    partial_rewards += 0.1 if rubric.ever_read_right_email else 0
    partial_rewards += 0.1 if rubric.sources_correct else 0
    return partial_rewards


def _calculate_correct_answer_reward(
    rubric: FinalRubric,
    max_turns: int,
) -> float:
    """Calculate reward for correct answers."""
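    # e.g. with sources_correct=True, num_sources=1, num_turns=4, max_turns=10:
    # 1.0 + 0.3 + 0.1 + 0.1 * (1 - 0.4) = 1.46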
reward = 1.0
|
||||||
|
reward += 0.3 if rubric.sources_correct else 0
|
||||||
|
reward += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0
|
||||||
|
reward += 0.1 * (1 - rubric.num_turns / max_turns)
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
def _initialize_rubric(
|
||||||
|
answer: str,
|
||||||
|
sources: list[str],
|
||||||
|
actual_turns: int,
|
||||||
|
query: QueryModel,
|
||||||
|
message_id_list: list[str],
|
||||||
|
ever_read_message_ids: list[str],
|
||||||
|
) -> FinalRubric:
|
||||||
|
"""Initialize and populate rubric with basic information."""
|
||||||
|
rubric = FinalRubric()
|
||||||
|
rubric.attempted_answer = answer is not None and answer != ""
|
||||||
|
rubric.returned_i_dont_know = answer == "I don't know"
|
||||||
|
rubric.num_sources = len(sources)
|
||||||
|
rubric.num_turns = actual_turns
|
||||||
|
|
||||||
|
if len(query.message_ids) > 0:
|
||||||
|
rubric.ever_found_right_email = query.message_ids[0] in message_id_list
|
||||||
|
rubric.ever_read_right_email = (
|
||||||
|
query.message_ids[0] in ever_read_message_ids
|
||||||
|
)
|
||||||
|
rubric.sources_correct = query.message_ids[0] in sources
|
||||||
|
return rubric
|
||||||
|
|
||||||
|
|
||||||
|
async def email_search_judge(
|
||||||
|
task: Dict,
|
||||||
|
response: Msg,
|
||||||
|
auxiliary_models: Dict[str, ChatModelBase],
|
||||||
|
) -> JudgeOutput:
|
||||||
|
"""A judge function to calculate reward based on agent's response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task (Dict): The task information for the corresponding workflow.
|
||||||
|
response (Msg): The response generated by the corresponding workflow.
|
||||||
|
auxiliary_models (Dict[str, ChatModelBase]):
|
||||||
|
A dictionary of additional chat models available for LLM-as-a-Judge
|
||||||
|
usage. The keys are model names, and the values are the
|
||||||
|
corresponding ChatModelBase instances.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JudgeOutput: The reward value assigned by the judge function.
|
||||||
|
"""
|
||||||
|
# Extract metadata from response
|
||||||
|
metadata = response.metadata or {}
|
||||||
|
answer_and_sources = metadata.get("answer_and_sources", {})
|
||||||
|
query_dict = metadata.get("query", {})
|
||||||
|
message_id_list = metadata.get("message_id_list", [])
|
||||||
|
ever_read_message_ids = metadata.get("ever_read_message_ids", [])
|
||||||
|
actual_turns = metadata.get("actual_turns", 0)
|
||||||
|
|
||||||
|
# Parse query model
|
||||||
|
if not query_dict:
|
||||||
|
query_dict = task
|
||||||
|
query = QueryModel.model_validate(query_dict)
|
||||||
|
|
||||||
|
# Get arguments
|
||||||
|
workflow_args = task.get("workflow_args", {})
|
||||||
|
max_turns = int(workflow_args.get("max_turns", 10))
|
||||||
|
|
||||||
|
# Extract answer and sources
|
||||||
|
try:
|
||||||
|
answer = answer_and_sources.get("answer", None)
|
||||||
|
sources = answer_and_sources.get("sources", [])
|
||||||
|
except Exception:
|
||||||
|
result = {"accuracy": 0.0, "format": -1.0}
|
||||||
|
return JudgeOutput(
|
||||||
|
reward=sum(result.values()),
|
||||||
|
metrics=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
if answer is None:
|
||||||
|
result = {"accuracy": 0.0, "format": -1.0}
|
||||||
|
return JudgeOutput(
|
||||||
|
reward=sum(result.values()),
|
||||||
|
metrics=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize rubric
|
||||||
|
rubric = _initialize_rubric(
|
||||||
|
answer,
|
||||||
|
sources,
|
||||||
|
actual_turns,
|
||||||
|
query,
|
||||||
|
message_id_list,
|
||||||
|
ever_read_message_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Judge correctness using LLM-as-a-Judge
|
||||||
|
try:
|
||||||
|
judge_model = (
|
||||||
|
auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
|
||||||
|
if auxiliary_models
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
judge_response = await judge_correctness(
|
||||||
|
answer,
|
||||||
|
query,
|
||||||
|
judge_model,
|
||||||
|
)
|
||||||
|
rubric.answer_correct = judge_response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error judging correctness: %s", e)
|
||||||
|
rubric.answer_correct = False
|
||||||
|
|
||||||
|
# Calculate rewards
|
||||||
|
partial_rewards = _calculate_partial_rewards(rubric)
|
||||||
|
|
||||||
|
if rubric.attempted_answer and not rubric.answer_correct:
|
||||||
|
result = {"accuracy": -1.0, "format": partial_rewards}
|
||||||
|
elif rubric.returned_i_dont_know or rubric.ran_out_of_turns:
|
||||||
|
result = {"accuracy": 0.0, "format": partial_rewards}
|
||||||
|
elif rubric.answer_correct:
|
||||||
|
reward = _calculate_correct_answer_reward(rubric, max_turns)
|
||||||
|
result = {"accuracy": 1.0, "format": reward}
|
||||||
|
else:
|
||||||
|
result = {"accuracy": 0.0, "format": 0.0}
|
||||||
|
|
||||||
|
metrics = result.copy()
|
||||||
|
metrics.update({"actual_turns": actual_turns})
|
||||||
|
|
||||||
|
return JudgeOutput(
|
||||||
|
reward=sum(result.values()),
|
||||||
|
metrics=metrics,
|
||||||
|
)


# LLM-as-a-judge


async def judge_correctness(
    answer: str,
    query: QueryModel,
    judge: ChatModelBase,
) -> bool:
    """Use an LLM to decide whether *answer* matches *query.answer*.

    Returns a boolean *accept* flag used for scoring.
    """
    system_prompt = """You are given a question, the reference answer
(labelled **Reference answer**), and an answer generated by an AI assistant
(labelled **AI answer**).

Follow these steps to decide whether the AI answer should be accepted:
1. Identify EXACTLY what information the **question** is asking for
(e.g. who, what, when, where, why, how, quantity, etc.).
2. From the **Reference answer**, extract ONLY the facts that are required
to directly satisfy the information need identified in step 1. Treat all
other facts as non-essential context.
3. Verify that every essential fact from step 2 appears in the **AI answer**
with the same meaning. Differences in wording, order, or additional
non-conflicting details are allowed.
4. If any essential fact is missing or contradicted in the **AI answer**,
then *accept* must be **false**. Otherwise *accept* must be **true**.

Important: Do NOT penalise the **AI answer** for omitting non-essential
facts that appear in the **Reference answer**. The answer should only be
rejected for errors or omissions in the information explicitly requested by
the question.

Return your judgement **accept** as either **true** or **false**. Do not
return any other text or formatting.
"""
    prompt = (
        f"Question: {query.question}\n"
        f"Reference answer: {query.answer}\n"
        f"AI answer: {answer}"
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    chat_response = await judge(messages)

    # Extract text content from ChatResponse
    result_parts = []
    for block in chat_response.content:
        if isinstance(block, dict) and block.get("type") == "text":
            result_parts.append(str(block.get("text", "")))
    result = "".join(result_parts)
    logger.info("LLM judge response: %s", result)

    return "true" in result.lower()


# End of LLM-as-a-judge


if __name__ == "__main__":
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )
    dataset = DatasetConfig(
        path="/path/to/enron_emails_dataset",
        split="train",
    )
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen3-4B-Instruct-2507",
        max_model_len=20480,
        max_tokens=4096,
        inference_engine_num=4,
        reasoning_parser=None,
    )
    aux_models = {
        "judge": TunerModelConfig(
            model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
            max_model_len=2500,
            max_tokens=2048,
            inference_engine_num=1,
            tensor_parallel_size=2,
            tool_call_parser=None,
            reasoning_parser=None,
        ),
    }
    algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        batch_size=64,
        learning_rate=1e-6,
    )
    tune(
        workflow_func=run_email_search_agent,
        judge_func=email_search_judge,
        train_dataset=dataset,
        model=tuner_model,
        auxiliary_models=aux_models,
        algorithm=algorithm,
        config_path=config_path,
    )
357
tuner/email_search/prepare_data.py
Normal file
@@ -0,0 +1,357 @@
# -*- coding: utf-8 -*-
"""
Prepare data for training.
Modified from OpenPipe/ART
"""

import logging
import os
import sqlite3
from datetime import datetime

from datasets import Dataset, Features, Sequence, Value, load_dataset
from tqdm import tqdm


# Resolve paths relative to this file
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Database will live in "../data/enron_emails.db" relative to project root
DEFAULT_DB_PATH = os.path.join(BASE_DIR, "..", "..", "data", "enron_emails.db")

DEFAULT_REPO_ID = "corbt/enron-emails"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


# --- Database Schema ---
SQL_CREATE_TABLES = """
DROP TABLE IF EXISTS recipients;
DROP TABLE IF EXISTS emails_fts;
DROP TABLE IF EXISTS emails;

CREATE TABLE emails (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    message_id TEXT UNIQUE,
    subject TEXT,
    from_address TEXT,
    date TEXT,  -- Store as ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    body TEXT,
    file_name TEXT
);

CREATE TABLE recipients (
    email_id INTEGER,
    recipient_address TEXT,
    recipient_type TEXT,  -- 'to', 'cc', 'bcc'
    FOREIGN KEY(email_id) REFERENCES emails(id) ON DELETE CASCADE
);
"""

SQL_CREATE_INDEXES_TRIGGERS = """
CREATE INDEX idx_emails_from ON emails(from_address);
CREATE INDEX idx_emails_date ON emails(date);
CREATE INDEX idx_emails_message_id ON emails(message_id);
CREATE INDEX idx_recipients_address ON recipients(recipient_address);
CREATE INDEX idx_recipients_type ON recipients(recipient_type);
CREATE INDEX idx_recipients_email_id ON recipients(email_id);
CREATE INDEX idx_recipients_address_email ON recipients(
    recipient_address, email_id
);

CREATE VIRTUAL TABLE emails_fts USING fts5(
    subject,
    body,
    content='emails',
    content_rowid='id'
);

CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
    INSERT INTO emails_fts (rowid, subject, body)
    VALUES (new.id, new.subject, new.body);
END;

CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
    DELETE FROM emails_fts WHERE rowid=old.id;
END;

CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
    UPDATE emails_fts SET subject=new.subject, body=new.body
    WHERE rowid=old.id;
END;

INSERT INTO emails_fts (rowid, subject, body)
SELECT id, subject, body FROM emails;
"""


# --- Functions ---


def download_dataset(repo_id: str) -> Dataset:
    """Downloads the dataset from Hugging Face Hub."""
    logging.info(
        "Attempting to download dataset from Hugging Face Hub: %s",
        repo_id,
    )
    expected_features = Features(
        {
            "message_id": Value("string"),
            "subject": Value("string"),
            "from": Value("string"),
            "to": Sequence(Value("string")),
            "cc": Sequence(Value("string")),
            "bcc": Sequence(Value("string")),
            "date": Value("timestamp[us]"),
            "body": Value("string"),
            "file_name": Value("string"),
        },
    )
    dataset_obj = load_dataset(
        repo_id,
        features=expected_features,
        split="train",
    )
    # Basic type check remains useful
    if not isinstance(dataset_obj, Dataset):
        raise TypeError(f"Expected Dataset, got {type(dataset_obj)}")
    logging.info(
        "Successfully loaded dataset '%s' with %d records.",
        repo_id,
        len(dataset_obj),
    )
    return dataset_obj


def create_database(db_path: str) -> None:
    """Creates the SQLite database and tables."""
    logging.info("Creating SQLite database and tables at: %s", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_TABLES)
    conn.commit()
    conn.close()
    logging.info("Database tables created successfully.")


def _should_skip_email(
    body: str,
    message_id: str,
    to_list: list[str],
    cc_list: list[str],
    bcc_list: list[str],
) -> bool:
    """Check if an email should be skipped based on the filters."""
    if len(body) > 5000:
        logging.debug(
            "Skipping email %s: Body length > 5000 characters.",
            message_id,
        )
        return True

    total_recipients = len(to_list) + len(cc_list) + len(bcc_list)
    if total_recipients > 30:
        logging.debug(
            "Skipping email %s: Total recipients (%d) > 30.",
            message_id,
            total_recipients,
        )
        return True
    return False


def _prepare_recipient_data(
    email_pk_id: int,
    to_list: list[str],
    cc_list: list[str],
    bcc_list: list[str],
) -> list[tuple[int, str, str]]:
    """Prepare recipient data for database insertion."""
    recipient_data = []
    for addr in to_list:
        recipient_data.append((email_pk_id, addr, "to"))
    for addr in cc_list:
        recipient_data.append((email_pk_id, addr, "cc"))
    for addr in bcc_list:
        recipient_data.append((email_pk_id, addr, "bcc"))
    return recipient_data


def populate_database(db_path: str, dataset: Dataset) -> None:
    """Populates the database with data from the Hugging Face dataset."""
    logging.info("Populating database %s...", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # --- Performance Pragmas ---
    conn.execute("PRAGMA synchronous = OFF;")
    conn.execute("PRAGMA journal_mode = MEMORY;")

    record_count = 0
    skipped_count = 0
    duplicate_count = 0
    processed_emails = set()

    conn.execute("BEGIN TRANSACTION;")

    for email_data in tqdm(dataset, desc="Inserting emails"):
        assert isinstance(email_data, dict)
        message_id = email_data["message_id"]
        subject = email_data["subject"]
        from_address = email_data["from"]
        date_obj: datetime = email_data["date"]
        body = email_data["body"]
        file_name = email_data["file_name"]
        to_list_raw = email_data["to"]
        cc_list_raw = email_data["cc"]
        bcc_list_raw = email_data["bcc"]

        date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")
        to_list = [str(addr) for addr in to_list_raw if addr]
        cc_list = [str(addr) for addr in cc_list_raw if addr]
        bcc_list = [str(addr) for addr in bcc_list_raw if addr]

        if _should_skip_email(body, message_id, to_list, cc_list, bcc_list):
            skipped_count += 1
            continue

        email_key = (subject, body, from_address)
        if email_key in processed_emails:
            logging.debug(
                "Skipping duplicate email (Subject: %s..., From: %s)",
                subject[:50],
                from_address,
            )
            duplicate_count += 1
            continue
        processed_emails.add(email_key)

        cursor.execute(
            """
            INSERT INTO emails (
                message_id, subject, from_address, date, body, file_name
            )
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (message_id, subject, from_address, date_str, body, file_name),
        )
        email_pk_id = cursor.lastrowid
        if email_pk_id is None:
            logging.warning(
                "Failed to get email ID after insert for message_id: %s",
                message_id,
            )
            continue

        recipient_data = _prepare_recipient_data(
            email_pk_id,
            to_list,
            cc_list,
            bcc_list,
        )

        if recipient_data:
            cursor.executemany(
                """
                INSERT INTO recipients (
                    email_id, recipient_address, recipient_type
                )
                VALUES (?, ?, ?)
                """,
                recipient_data,
            )
        record_count += 1

    conn.commit()
    conn.close()
    logging.info("Successfully inserted %d email records.", record_count)
    if skipped_count > 0:
        logging.info(
            "Skipped %d email records due to length or recipient limits.",
            skipped_count,
        )
    if duplicate_count > 0:
        logging.info(
            "Skipped %d duplicate email records "
            "(based on subject, body, from).",
            duplicate_count,
        )


def create_indexes_and_triggers(db_path: str) -> None:
    """Creates indexes and triggers on the populated database."""
    logging.info("Creating indexes and triggers for database: %s...", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
    conn.commit()
    conn.close()
    logging.info("Indexes and triggers created successfully.")


def generate_database(
    repo_id: str = DEFAULT_REPO_ID,
    db_path: str = DEFAULT_DB_PATH,
    overwrite: bool = False,
) -> None:
    """
    Generates the SQLite database from the specified Hugging Face dataset.
    Simplified version without extensive error handling.

    Args:
        repo_id: The Hugging Face repository ID for the dataset.
        db_path: The path where the SQLite database file should be
            created.
        overwrite: If True, any existing database file at db_path will
            be removed.
    """
    logging.info(
        "Starting database generation for repo '%s' at '%s'",
        repo_id,
        db_path,
    )
    logging.info("Overwrite existing database: %s", overwrite)

    db_dir = os.path.dirname(db_path)
    if db_dir and not os.path.exists(db_dir):
        logging.info("Creating data directory: %s", db_dir)
        os.makedirs(db_dir)

    if overwrite and os.path.exists(db_path):
        logging.warning("Removing existing database file: %s", db_path)
        os.remove(db_path)
    elif not overwrite and os.path.exists(db_path):
        # If the file already exists and overwrite is False, leave it
        # untouched and assume it was generated by a previous run.
        logging.warning(
            "Database file %s exists and overwrite is False. "
            "Assuming file is already generated.",
            db_path,
        )
        return

    # 1. Download dataset
    dataset = download_dataset(repo_id)

    # 2. Create database schema (tables only); any pre-existing file
    # was handled above.
    create_database(db_path)

    # 3. Populate database
    populate_database(db_path, dataset)

    # 4. Create indexes and triggers
    create_indexes_and_triggers(db_path)

    logging.info("Database generation process completed for %s.", db_path)
    logging.info(
        "Please set the environment variable DEFAULT_EMAIL_DB_PATH "
        "to this path.",
    )


if __name__ == "__main__":
    generate_database(overwrite=True)
BIN
tuner/email_search/rollout_accuracy_mean.png
Normal file
Binary file not shown. (Size: 442 KiB)
271
tuner/frozen_lake/README.md
Normal file
@@ -0,0 +1,271 @@
# Training FrozenLake Agent with RL using AgentScope-Tuner

## Summary

This example demonstrates how to use AgentScope-Tuner to implement reinforcement fine-tuning for the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task using [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT). The agent learns to navigate a frozen lake grid from a starting position to a goal while avoiding holes through multi-step interactions with the environment.

## Task Setting

### Agent Goal
The agent's objective is to navigate from the starting position (S) to the goal position (G) on a frozen lake grid without falling into holes (H). The agent must:
- Plan a path through frozen tiles (F) to reach the goal
- Avoid holes that terminate the episode with zero reward
- Complete the task within a limited number of steps

### Agent Type
The agent is implemented as a **ReActAgent** (Reasoning and Acting Agent) that:
- Observes the current state of the frozen lake grid
- Reasons about the best action to take
- Executes actions (Up, Down, Left, Right) to move through the environment
- Maintains internal state across multiple steps in an episode

### Environment
The environment is based on Gymnasium's FrozenLake environment, wrapped to provide:
- **Grid-based navigation**: Randomly generated maps with configurable size (2x2 to 6x6)
- **Tile types**:
  - `S`: Start position
  - `F`: Frozen tile (safe to walk on)
  - `H`: Hole (terminates episode with reward 0)
  - `G`: Goal (terminates episode with reward +1.0)
- **Action space**: Discrete actions (Up, Down, Left, Right)
- **Reward structure**:
  - +1.0 for reaching the goal
  - 0.0 for falling into a hole or failing to reach the goal
- **Observations**: Text-based grid representation showing current player position

The agent does not use external tools. It interacts directly with the environment through:
- `env.reset(task)`: Initialize environment with task parameters
- `env.step(action)`: Execute action and receive observation, reward, and done flag
- `env.render()`: Get text representation of current state


## Dataset Preparation

The dataset contains task parameters for generating FrozenLake environments. Each sample specifies:
- `seed`: Random seed for reproducible map generation
- `size`: Grid size (randomly sampled from 2 to `map_max_size`, e.g., 4x4, 6x6)
- `p`: Probability that a tile is frozen (vs. being a hole), randomly sampled from 0.6 to 0.85
- `index`: Sample index
- `uid`: Unique identifier combining seed, size, and p

Run the data preparation script to generate training and test datasets:

```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```

This will create parquet files in the specified directory:

```
/path/to/frozenlake_dataset/
├── train.parquet # 10000 training samples
└── test.parquet  # 100 test samples
```

Each sample looks like:

```json
{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
```

**Note**: The data preparation script ensures that all generated maps have a valid path from start to goal within the maximum allowed steps (`env_max_steps=8`), filtering out unsolvable tasks.
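
For intuition, a minimal sketch of how such a sample could be drawn is shown below; the fields and ranges follow the list above, while the helper name `make_sample` and the use of pandas for writing parquet are illustrative assumptions, not the actual `get_frozenlake_data.py`:

```python
import random

import pandas as pd  # assumed here only for writing the parquet file


def make_sample(index: int, map_max_size: int = 6) -> dict:
    """Draw one task-parameter sample (solvability is filtered separately)."""
    seed = random.randint(0, 2**31 - 1)
    size = random.randint(2, map_max_size)
    p = round(random.uniform(0.6, 0.85), 2)
    return {
        "seed": seed,
        "size": size,
        "p": p,
        "index": index,
        "uid": f"{seed}_{size}_{p}",
    }


samples = [make_sample(i) for i in range(10_000)]
pd.DataFrame(samples).to_parquet("train.parquet")
```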

## Code Implementation

This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.

### High-level Overview

The implementation consists of three main components:

1. **Agent** (`FrozenLakeAgent`): Extends `ReActAgent` to handle multi-step navigation
2. **Environment** (`FrozenLakeEnv`): Wraps Gymnasium's FrozenLake environment
3. **Workflow** (`run_frozen_lake`): Orchestrates the agent-environment interaction loop

### Agent Workflow

The workflow function `run_frozen_lake` implements the agent-environment interaction loop:

```python
async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # ...

    # Create agent and environment
    agent = FrozenLakeAgent(model=model, ...)
    env = FrozenLakeEnv(...)
    observation, _ = env.reset(task)
    rewards = []
    # ...

    # Agent-environment interaction loop
    for _ in range(max_steps):
        response = await agent.reply(msg=Msg("user", agent.get_prompt(observation), role="user"))
        action = agent.get_action(response)
        observation, reward, done, _ = env.step(action)
        rewards.append(reward)
        if done:
            break

    # ...
    final_reward = sum(rewards)
    final_response = Msg("assistant", response_content, role="assistant")

    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={
            "env_steps": float(step_count),
            "env_done": float(done),
        },
    )
```

**Key characteristics:**
- Multi-step interaction: The agent takes multiple actions in a single episode, unlike single-turn QA tasks
- State tracking: The agent maintains internal state (current step, last action, last observation) across steps
- Error handling: Invalid actions or agent errors are caught and handled gracefully

### Reward Function

No separate judge function is needed. The reward comes directly from the environment:
- 1.0: Agent successfully reaches the goal (G)
- 0.0: Agent falls into a hole (H) or fails to reach the goal within the maximum steps

The reward is computed as the sum of step rewards throughout the episode. The workflow returns:
- `reward`: Final cumulative reward
- `response`: Final response message containing observation, total reward, steps taken, and termination reason
- `metrics`: Additional metrics including `env_steps` (number of steps taken) and `env_done` (whether episode completed)

### Implementation Details

The environment (`FrozenLakeEnv`) wraps Gymnasium's FrozenLake and provides:
- `reset(task)`: Initialize the environment with task parameters
- `step(action)`: Execute an action and return (observation, reward, done, info)
- `render()`: Return a text representation of the current state

The agent (`FrozenLakeAgent`) extends `ReActAgent` and provides:
- `reply(msg)`: Reply to a message and return an action (inherited from AgentScope)
- `get_prompt(observation)`: Generate a prompt from the current observation
- `get_action(response)`: Parse the model's response to extract an action (Up/Down/Left/Right); see the sketch below
- `update_state(action, observation)`: Update internal state after each step

See [frozenlake_env.py](./frozenlake_env.py) and [frozenlake_agent.py](./frozenlake_agent.py) for implementation details.
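
For a concrete view of the parsing step referenced above, the sketch below condenses the action extraction from `_frozenlake_agent.py` in this commit: the last triple-backtick span of the response is taken as the action, and anything unrecognized falls back to a no-op (`parse_action` is a standalone rename of the method, for illustration only):

```python
import re

VALID_ACTIONS = {"left", "down", "right", "up"}
INVALID_ACTION = "still"  # treated as an ineffective no-op by the environment


def parse_action(response: str) -> str:
    """Return the last ``` ```-delimited action, lowercased, or the no-op."""
    matches = re.findall(r"```(.*?)```", response, re.DOTALL)
    if not matches:
        return INVALID_ACTION
    action = matches[-1].strip().lower()
    return action if action in VALID_ACTIONS else INVALID_ACTION


assert parse_action("I will go up. ```Up```") == "up"
assert parse_action("no fenced action here") == "still"
```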

### Step 4: Use `tune` to train the workflow

```python
from agentscope.tuner import tune, DatasetConfig

if __name__ == "__main__":
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )
    dataset = DatasetConfig(
        path="/path/to/frozenlake_dataset",
        name="default",
        split="train",
    )
    tune(
        workflow_func=run_frozen_lake,
        train_dataset=dataset,
        config_path=config_path,
    )
```

See [config.yaml](./config.yaml) for the training configuration. For full configuration details, see [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).

---

## How to Run

### Prerequisites

- At least 2 NVIDIA GPUs with CUDA 12.8 or newer
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
- Install gymnasium for the FrozenLake environment:

```bash
pip install gymnasium[toy_text]
```

- Download the model checkpoint (example):

```bash
huggingface-cli download Qwen/Qwen2.5-3B-Instruct
```

### Step 1: Prepare the Dataset

```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```

Update the dataset path in `main.py` to point to your generated dataset directory.

### Step 2: Configure the Training

Key configuration can be set in the code, including:

**Algorithm Configuration** (`AlgorithmConfig`):
- `algorithm_type`: `multi_step_grpo` (Group Relative Policy Optimization for multi-step tasks)
- `group_size`: Number of rollouts per task used for advantage estimation (default: 16)
- `batch_size`: Batch size for training (default: 32)
- `learning_rate`: Learning rate (default: 1e-6)

**Model Configuration** (`TunerModelConfig`):
- `model_path`: Path to the base model (e.g., `Qwen/Qwen2.5-3B-Instruct`)
- `max_model_len`: Maximum model context length (default: 25600)
- `max_tokens`: Maximum tokens for response generation (default: 2048)
- `inference_engine_num`: Number of inference engines (default: 6, using 6 GPUs for inference)

**Dataset Configuration** (`DatasetConfig`):
- `path`: Path to the dataset (default: `/path/to/frozenlake`)
- `split`: Split of the dataset (default: `train`)

Adjust these parameters based on your hardware resources and training requirements. Other parameters can be specified in [config.yaml](./config.yaml). A sketch of how these objects might be assembled is shown below.
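
For illustration, here is a minimal sketch assembling these configs with the defaults listed above; the import path for `AlgorithmConfig` and `TunerModelConfig` is assumed to mirror that of `tune` and `DatasetConfig`, and the actual `main.py` may differ:

```python
from agentscope.tuner import AlgorithmConfig, TunerModelConfig

algorithm = AlgorithmConfig(
    algorithm_type="multi_step_grpo",
    group_size=16,  # rollouts per task for advantage estimation
    batch_size=32,
    learning_rate=1e-6,
)
tuner_model = TunerModelConfig(
    model_path="Qwen/Qwen2.5-3B-Instruct",
    max_model_len=25600,
    max_tokens=2048,
    inference_engine_num=6,  # one inference engine per GPU
)
```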

### Step 3: Set Up Ray Cluster

Set up a [Ray](https://github.com/ray-project/ray) cluster:

```bash
ray start --head
# for multi-node setup, run the following command on worker nodes
# ray start --address=<master_address>
```

### Step 4: Run the Training Script

```bash
python main.py
```

The training will start and you can monitor the progress through the logs. Checkpoints will be saved once every `trainer.save_interval` steps.

## Experimental Results

### Training Reward Curve

The reward curve during training shows the agent's learning progress:

![Training Reward Curve](./frozenlake_reward_curve.png)

The training reward typically increases over epochs as the agent learns to navigate the frozen lake more effectively.

### Example Agent Output

An example of agent output is given below:
```
From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.

The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.

Let's take the action ```Left```.
```
250
tuner/frozen_lake/README_zh.md
Normal file
@@ -0,0 +1,250 @@
# Training a FrozenLake Agent with AgentScope-Tuner

## Summary

This example shows how to use AgentScope-Tuner together with [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) to apply reinforcement fine-tuning to the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task. The agent must walk from the start to the goal on a frozen lake grid, avoid holes, and finish within a limited number of steps.

## Task Setting

### Agent Goal
The agent must travel from the start (S) to the goal (G) on the frozen lake grid while it:
- Plans a path across frozen tiles (F) to reach the goal
- Avoids holes (H), which end the episode with reward 0
- Completes the task within the step limit

### Agent Type
The agent is implemented as a **ReActAgent** whose behavior includes:
- Observing the current state of the frozen lake grid
- Reasoning about the best next action
- Executing actions (Up, Down, Left, Right) to move through the environment
- Maintaining internal state across multi-step interactions

### Environment
The environment is based on Gymnasium's FrozenLake and provides:
- **Grid navigation**: randomly generated maps from 2x2 to 6x6
- **Tile types**:
  - `S`: start position
  - `F`: frozen tile (walkable)
  - `H`: hole (reward 0, ends the episode)
  - `G`: goal (reward +1.0, ends the episode)
- **Action space**: discrete actions (Up, Down, Left, Right)
- **Reward design**:
  - +1.0 for reaching the goal
  - 0.0 for falling into a hole or not reaching the goal within the step limit
- **Observations**: a text grid representation showing the current player position

The agent uses no external tools and interacts with the environment directly through:
- `env.reset(task)`: initialize the environment from task parameters
- `env.step(action)`: execute an action and return the observation, reward, and done flag
- `env.render()`: return a text representation of the current state

## Dataset Preparation

The dataset contains task parameters for generating FrozenLake environments. Each sample includes:
- `seed`: random seed for reproducible maps
- `size`: grid size (random between 2 and `map_max_size`, e.g., 4x4, 6x6)
- `p`: probability that a tile is frozen (random between 0.6 and 0.85); the rest are holes
- `index`: sample index
- `uid`: a unique ID composed of seed, size, and p

Run the data preparation script to generate the training and test sets:

```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```

Example of the generated directory layout:
```
/path/to/frozenlake_dataset/
├── train.parquet # 10000 training samples
└── test.parquet  # 100 test samples
```

Sample example:
```json
{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
```

**Note**: The script filters out unsolvable maps, ensuring that a feasible path from start to goal exists within the maximum number of steps (`env_max_steps=8`).

## Code Implementation

This section gives a high-level overview of the implementation. Please refer to the source code for details.

### High-level Overview
The implementation consists of three parts:
1. **Agent** (`FrozenLakeAgent`): extends `ReActAgent` and handles multi-step interaction
2. **Environment** (`FrozenLakeEnv`): wraps Gymnasium's FrozenLake
3. **Workflow** (`run_frozen_lake`): orchestrates the agent-environment interaction

### Workflow
`run_frozen_lake` implements the multi-step interaction loop:

```python
async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # ...

    # Create agent and environment
    agent = FrozenLakeAgent(model=model, ...)
    env = FrozenLakeEnv(...)
    observation, _ = env.reset(task)
    rewards = []
    # ...

    # Agent-environment interaction loop
    for _ in range(max_steps):
        response = await agent.reply(msg=Msg("user", agent.get_prompt(observation), role="user"))
        action = agent.get_action(response)
        observation, reward, done, _ = env.step(action)
        rewards.append(reward)
        if done:
            break

    # ...
    final_reward = sum(rewards)
    final_response = Msg("assistant", response_content, role="assistant")

    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={"env_steps": float(step_count), "env_done": float(done)},
    )
```

**Key characteristics:**
- Multi-step interaction: multiple actions within a single episode, not single-turn QA
- State tracking: records the current step, last action, and last observation
- Error handling: invalid actions and exceptions are caught and handled

### Reward Function
No extra judge is needed; the reward comes directly from the environment:
- 1.0: the goal is reached
- 0.0: the agent falls into a hole or runs out of steps

The workflow returns:
- `reward`: the cumulative reward
- `response`: the final reply containing the observation, total reward, step count, and termination reason
- `metrics`: `env_steps` (step count) and `env_done` (whether the episode finished)

### Implementation Details

The environment (`FrozenLakeEnv`) wraps Gymnasium's FrozenLake and provides:
- `reset(task)`: initialize the environment with task parameters
- `step(action)`: execute an action and return (observation, reward, done, info)
- `render()`: return a text representation of the current state

The agent (`FrozenLakeAgent`) extends `ReActAgent` and provides:
- `reply(msg)`: reply to a message and return an action (inherited from AgentScope)
- `get_prompt(observation)`: generate a prompt from the current observation
- `get_action(response)`: parse the model response to extract an action (Up/Down/Left/Right)
- `update_state(action, observation)`: update internal state after each step

See [frozenlake_env.py](./frozenlake_env.py) and [frozenlake_agent.py](./frozenlake_agent.py) for implementation details.

### Step 4: Use `tune` to train the workflow

```python
from agentscope.tuner import tune, DatasetConfig

if __name__ == "__main__":
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )
    dataset = DatasetConfig(
        path="/path/to/frozenlake_dataset",
        name="default",
        split="train",
    )
    tune(
        workflow_func=run_frozen_lake,
        train_dataset=dataset,
        config_path=config_path,
    )
```

See [config.yaml](./config.yaml) for the training configuration. For full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).

---

## How to Run

### Prerequisites
- At least 2 NVIDIA GPUs with CUDA ≥ 12.8
- Install the latest version from source following the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
- Install gymnasium for the FrozenLake environment:
```bash
pip install gymnasium[toy_text]
```
- Download the model checkpoint (example):
```bash
huggingface-cli download Qwen/Qwen2.5-3B-Instruct
```

### Step 1: Prepare the Dataset
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
Update the dataset path in `main.py` to your generated directory.

### Step 2: Configure the Training

Key configuration can be set in the code, including:

**Algorithm Configuration** (`AlgorithmConfig`):
- `algorithm_type`: `multi_step_grpo` (Group Relative Policy Optimization for multi-step tasks)
- `group_size`: number of rollouts per task used for advantage estimation (default: 16)
- `batch_size`: batch size (default: 32)
- `learning_rate`: learning rate (default: 1e-6)

**Model Configuration** (`TunerModelConfig`):
- `model_path`: path to the base model (e.g., `Qwen/Qwen2.5-3B-Instruct`)
- `max_model_len`: maximum context length (default: 25600)
- `max_tokens`: maximum response generation length (default: 2048)
- `inference_engine_num`: number of inference engines (default: 6, i.e., 6 GPUs for inference)

**Dataset Configuration** (`DatasetConfig`):
- `path`: dataset path (default: `/path/to/frozenlake`)
- `split`: dataset split (default: `train`)

Adjust these parameters based on your hardware resources and training requirements. Other parameters can be specified in [config.yaml](./config.yaml).

### Step 3: Set Up Ray Cluster

Set up a [Ray](https://github.com/ray-project/ray) cluster:
```bash
ray start --head
# for multi-node setup, run the following command on worker nodes
# ray start --address=<master_address>
```

### Step 4: Run the Training Script
```bash
python main.py
```
Training will start, and progress can be monitored through the logs. Checkpoints are saved every `trainer.save_interval` steps.

## Experimental Results

### Training Reward Curve

The reward curve during training shows the agent's learning progress:

![Training Reward Curve](./frozenlake_reward_curve.png)

The training reward typically increases over epochs as the agent learns to navigate the frozen lake more effectively.

### Example Agent Output

An example of agent output is given below:
```
From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.

The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.

Let's take the action ```Left```.
```
102
tuner/frozen_lake/_frozenlake_agent.py
Normal file
@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import re

from _utils import SYSTEM_PROMPT, FrozenLakeAction  # pylint: disable=E0611
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel


INVALID_ACTION = "still"
VALID_ACTIONS = {
    "left": 1,
    "down": 2,
    "right": 3,
    "up": 4,
}


class FrozenLakeAgent(ReActAgent):
    """Agent for the FrozenLake environment."""

    def __init__(self, model: OpenAIChatModel, max_steps: int = 20):
        super().__init__(
            name="frozenlake_agent",
            model=model,
            sys_prompt=SYSTEM_PROMPT,
            formatter=OpenAIChatFormatter(),
            max_iters=1,
        )
        self.response_structure = FrozenLakeAction
        self.current_step = 0
        self.last_action = None
        self.last_observation = None
        self.max_steps = max_steps

    def get_prompt(self, observation: str) -> str:
        """Get a prompt for the agent based on the current observation."""
        prompt = (
            f"Current Observation ({self.current_step}): \n"
            + observation
            + "\n"
            + (
                "You have not achieved the goal, P has not reached G yet. "
                "Please give the next action."
            )
        )
        if self.current_step > 0 and self.last_action is not None:
            if self.last_observation == observation:
                prompt += (
                    "\nYour last response is invalid. "
                    "Your position didn't change at all. "
                    "You may need to recheck your thinking process, "
                    "the action you output, and the format of your response. "
                    "Remember, you should only output the NEXT ACTION "
                    "at each iteration in the ``` ```. "
                    "For example, if you want to move up, "
                    "you should output ```Up```."
                )

        if (
            self.max_steps is not None
            and self.max_steps - self.current_step > 0
        ):
            remaining = self.max_steps - self.current_step
            prompt += (
                f"\nThe maximum number of steps remaining is {remaining}."
            )

        return prompt

    def get_action(self, msg: Msg) -> str:
        """Extract the action from the agent's response message."""
        response: str = (
            msg.content
            if isinstance(msg.content, str)
            else msg.content[0].get("text")
        )
        action = INVALID_ACTION

        matches = re.findall(r"```(.*?)```", response, re.DOTALL)

        if matches:
            last_match_content = matches[-1].strip()
            action = last_match_content.lower()
            if action not in VALID_ACTIONS:
                action = INVALID_ACTION

        return action

    def update_state(self, action: str, observation: str) -> None:
        """Update agent state with the latest action and observation."""
        self.last_action = action
        self.last_observation = observation
        self.current_step += 1

    async def reset(self) -> None:
        """Reset agent state for a new episode."""
        self.current_step = 0
        self.last_action = None
        self.last_observation = None
        await self.memory.clear()
316
tuner/frozen_lake/_frozenlake_env.py
Normal file
@@ -0,0 +1,316 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import copy
from typing import Dict, Optional, Tuple, Union

import numpy as np

try:
    from gymnasium.envs.toy_text.frozen_lake import (
        FrozenLakeEnv as GymFrozenLakeEnv,
    )
except ImportError:
    GymFrozenLakeEnv = object
from _utils import (  # pylint: disable=E0611
    generate_random_map,
    get_goal_position,
)


class FrozenLakeEnv(GymFrozenLakeEnv):
    """FrozenLake environment wrapper."""

    # Map gym state to an integer
    MAP_LOOKUP = {
        b"P": 0,
        b"F": 1,
        b"H": 2,
        b"G": 3,
    }

    # Rules to transform the state into a rendered text observation
    GRID_LOOKUP = {
        0: " P \t",  # player
        1: " _ \t",  # frozen
        2: " O \t",  # hole
        3: " G \t",  # goal
        4: " X \t",  # player fell into a hole
        5: " √ \t",  # player on goal
    }

    ACTION_LOOKUP = {
        "still": 0,
        "left": 1,
        "down": 2,
        "right": 3,
        "up": 4,
    }

    INVALID_ACTION = 0
    PENALTY_FOR_INVALID = -1

    def __init__(
        self,
        max_steps: int = 8,
        desc: Optional[str] = None,
        is_slippery: bool = False,
        size: int = 8,
        p: float = 0.8,
        seed: int = 42,
    ):
        self.max_steps = max_steps or 8
        self.desc: Union[str, np.ndarray, None] = desc
        self.is_slippery = is_slippery
        self.size = size
        self.p = p
        self.seed = seed
        self.render_mode: Optional[str] = None
        try:
            import gymnasium as gym
        except ImportError as e:
            error_message = (
                "Gymnasium is not installed. "
                "Please install gymnasium first before "
                "running the frozen_lake workflow. "
                f"Error: {str(e)}"
            )
            raise ImportError(error_message) from e

        if self.desc is None:
            random_map, goal_position = generate_random_map(
                size=self.size,
                p=self.p,
                seed=self.seed,
                max_steps=self.max_steps,
            )
        else:
            random_map = np.asarray(copy.deepcopy(self.desc), dtype="c")
            goal_position = get_goal_position(random_map)

        self.goal_position = goal_position

        super().__init__(
            desc=random_map[:],
            is_slippery=self.is_slippery,
        )
        assert isinstance(self.desc, np.ndarray)
        self.action_space = gym.spaces.Discrete(4, start=1)

        self.map_kwargs = {
            "size": size,
            "p": p,
        }
        self.env_kwargs = {
            "is_slippery": is_slippery,
            "desc": copy.deepcopy(desc),
            "seed": seed,
        }

        self.action_map = {
            1: 0,  # left
            2: 1,  # down
            3: 2,  # right
            4: 3,  # up
        }

    def _get_player_position(self) -> Tuple[int, int]:
        return (self.s // self.ncol, self.s % self.ncol)  # (row, col)

    def step(self, action: str) -> Tuple[str, float, bool, Dict]:
        """Execute a step in the environment.

        Maps the custom action to a gymnasium FrozenLakeEnv action and
        takes the step. Checks whether the action is effective (whether
        the player actually moves in the env).

        Args:
            action: The action to take.

        Returns:
            Tuple of (observation, reward, done, info).
        """
        if self.success():
            obs = self.render(mode="tiny_rgb_array")
            assert isinstance(obs, str)
            return obs, 1.0, True, {"action_is_effective": False}

        action_id: int = self.ACTION_LOOKUP.get(action.lower(), 0)

        if not action_id:
            action_id = self.INVALID_ACTION

        if (
            action_id == self.INVALID_ACTION
            or action_id not in self.action_map
        ):
            obs = self.render(mode="tiny_rgb_array")
            assert isinstance(obs, str)
            return obs, 0.0, False, {"action_is_effective": False}

        prev_player_position = int(self.s)

        # Call the parent class step method
        # Note: GymFrozenLakeEnv is imported at module level
        player_pos, reward, done, _, _ = super().step(
            self.action_map[action_id],
        )

        obs = self.render(mode="tiny_rgb_array")
        assert isinstance(obs, str)
        return (
            obs,
            float(reward),
            bool(done),
            {"action_is_effective": prev_player_position != int(player_pos)},
        )

    def render(
        self,
        mode: str = "tiny_rgb_array",
    ) -> str | list[str] | np.ndarray:
        """Render the environment.

        Args:
            mode: Rendering mode. Options: "tiny_rgb_array", "list",
                "state", "rgb_array", "ansi".

        Returns:
            Rendered observation based on the mode.
        """
        assert mode in [
            "tiny_rgb_array",
            "list",
            "state",
            "rgb_array",
            "ansi",
        ]
        if mode in ["rgb_array", "ansi"]:
            prev_render_mode = self.render_mode
            self.render_mode = mode
            obs = super().render()
            self.render_mode = prev_render_mode
            return obs
        assert isinstance(self.desc, np.ndarray)
        room_state = copy.deepcopy(self.desc)

        # replace the position of start 'S' with 'F'
        position_S = np.where(room_state == b"S")
        room_state[position_S] = b"F"

        # replace the position of the player with 'P'
        position_P = self._get_player_position()
        room_state[position_P] = b"P"

        if mode == "state":
            # transform 'S', 'F', 'H', 'G' to a numpy integer array
            room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state)
            # mark player in hole or player on goal
            if self.desc[position_P] == b"H":
                room_state[position_P] = 4
            elif self.desc[position_P] == b"G":
                room_state[position_P] = 5
            return room_state

        room_state = self.render(mode="state").tolist()
        assert isinstance(room_state, list)

        if mode == "list":

            def lookup_list(cell: int) -> str:
                return self.GRID_LOOKUP.get(cell, "?").strip("\t").strip()

            return [
                " ".join(lookup_list(cell) for cell in row)
                for row in room_state
            ]

        if mode == "tiny_rgb_array":

            def lookup_tiny(cell: int) -> str:
                return self.GRID_LOOKUP.get(cell, "?")

            result = "\n".join(
                "".join(lookup_tiny(cell) for cell in row)
                for row in room_state
            )
            return result

        # Default return for other modes
        return ""

    def reset(
        self,
        task: Optional[Dict] = None,
    ) -> tuple[str, Dict]:
        """Reset the environment with optional task parameters."""
        task = task or {}
        # Update parameters from the task if provided
        size = task.get("size", self.map_kwargs["size"])
        p = task.get("p", self.map_kwargs["p"])
        seed = task.get("seed", self.env_kwargs["seed"])
        is_slippery = task.get(
            "is_slippery",
            self.env_kwargs["is_slippery"],
        )
        desc = task.get("desc", self.env_kwargs.get("desc"))

        # Update instance variables
        self.size = size
        self.p = p
        self.seed = seed
        self.is_slippery = is_slippery
        self.map_kwargs["size"] = size
        self.map_kwargs["p"] = p
        self.env_kwargs["seed"] = seed
        self.env_kwargs["is_slippery"] = is_slippery
        if desc is not None:
            self.env_kwargs["desc"] = copy.deepcopy(desc)

        if desc is None:
            random_map, goal_position = generate_random_map(
                size=size,
                p=p,
                seed=seed,
                max_steps=self.max_steps,
            )
        else:
            random_map = np.asarray(copy.deepcopy(desc), dtype="c")
            goal_position = get_goal_position(random_map)

        self.goal_position = goal_position
        self.desc = random_map[:]

        # Reinitialize the parent class with the new map
        try:
            import gymnasium as gym

            super().__init__(
                desc=random_map[:],
                is_slippery=self.is_slippery,
            )
            assert isinstance(self.desc, np.ndarray)
            self.action_space = gym.spaces.Discrete(4, start=1)
        except ImportError as e:
            error_message = (
                "Gymnasium is not installed. "
                "Please install gymnasium first before "
                "running the frozen_lake workflow. "
                f"Error: {str(e)}"
            )
            raise ImportError(error_message) from e

        super().reset(seed=self.seed)
        obs = self.render(mode="tiny_rgb_array")
        assert isinstance(obs, str)
        return obs, {}

    def finished(self) -> bool:
        """Check if the episode is finished (goal or hole)."""
        player_pos = self._get_player_position()
        assert isinstance(self.desc, np.ndarray)
        return self.desc[player_pos] in b"GH"  # type: ignore

    def success(self) -> bool:
        """Check if the agent has reached the goal (G)."""
        player_pos = self._get_player_position()
        assert isinstance(self.desc, np.ndarray)
        return self.desc[player_pos] in b"G"
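

# Example usage (illustrative sketch):
#
#     env = FrozenLakeEnv(max_steps=8)
#     obs, _ = env.reset({"seed": 42, "size": 4, "p": 0.8})
#     obs, reward, done, info = env.step("Left")
#     # ``info["action_is_effective"]`` is False when the move was invalid
#     # or the player did not actually change position.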
209
tuner/frozen_lake/_utils.py
Normal file
@@ -0,0 +1,209 @@
# -*- coding: utf-8 -*-
"""
Utils for the FrozenLake environment.
Modified from rllm
"""

from typing import Literal, Optional, Tuple

import numpy as np
from pydantic import BaseModel, Field

# Map gym state to an integer
MAP_LOOKUP = {
    b"P": 0,
    b"F": 1,
    b"H": 2,
    b"G": 3,
}

# Rules to transform the state into a rendered text observation
GRID_LOOKUP = {
    0: " P \t",  # player
    1: " _ \t",  # frozen
    2: " O \t",  # hole
    3: " G \t",  # goal
    4: " X \t",  # player fell into a hole
    5: " √ \t",  # player on goal
}

ACTION_LOOKUP = {
    0: "None",
    1: "Left",
    2: "Down",
    3: "Right",
    4: "Up",
}

# Prompting format inspired by the RAGEN project
SYSTEM_PROMPT = """You are Qwen, created by Alibaba Cloud. \
You are a helpful assistant. You are walking on a frozen lake.

FrozenLake Quick Guide
Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.

Symbols:
_ Frozen | O Hole | G Goal | P Player

Rules:
1. Avoid falling into holes (O).
2. Frozen tiles are slippery, you may move perpendicular to
your intended direction.

Valid Action (separated by | ):
Up | Down | Left | Right

Rewards:
Fall into hole: 0
Reach goal: +1.0

You will be provided the current observation, please decide on
the next Action.
You should show your thought process and then put the final
action in ``` ```.
You should only output the NEXT ACTION at each iteration in
the ``` ```. For example, if you want to move up, you should
output ```Up```.
You should plan ahead and aim to achieve the goal in the minimum
number of steps.
You should be aware that frozen tiles can be slippery, but the
chance is small and you should not overthink it.

Please show your thinking process and put the final action in
``` ```. In every turn, the final action MUST be one of Up,
Down, Left, Right.
"""


class FrozenLakeAction(BaseModel):
    """Action model for the FrozenLake environment."""

    action: Literal["Up", "Down", "Left", "Right"] = Field(
        description=(
            "The action to take in the FrozenLake environment, "
            "must be one of Up, Down, Left, Right"
        ),
    )


def is_valid(board: list[list[str]], max_size: int, max_steps: int) -> bool:
    """DFS to check that the board has a valid path.

    Args:
        board: The board representation as a list of lists.
        max_size: Maximum size of the board.
        max_steps: Maximum number of steps allowed.

    Returns:
        True if there's a valid path from start to goal within max_steps,
        False otherwise.
    """
    frontier, discovered = [], set()
    # find the start point
    start_r, start_c = np.where(np.array(board) == "S")
    frontier.append((start_r[0], start_c[0], 0))  # (row, col, steps)
    # DFS to check if there is a path from start to goal
    while frontier:
        r, c, steps = frontier.pop()
        if steps > max_steps:
            continue

        if (r, c) not in discovered:
            discovered.add((r, c))
            directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
            for x, y in directions:
                r_new = r + x
                c_new = c + y
                if (
                    r_new < 0
                    or r_new >= max_size
                    or c_new < 0
                    or c_new >= max_size
                ):  # noqa: PLR2004
                    continue
                if board[r_new][c_new] == "G":
                    return True
                if board[r_new][c_new] != "H":
                    frontier.append((r_new, c_new, steps + 1))
    return False


def generate_random_map(
    size: int = 8,
    p: float = 0.8,
    seed: int = 0,
    max_steps: int = 5,
) -> Tuple[list[str], Tuple[int, int]]:
    """Generates a random valid map (one that has a path from start to goal).

    Args:
        size: Size of each side of the grid.
        p: Probability that a tile is frozen.
        seed: Seed to ensure the generation of reproducible maps.
        max_steps: Maximum number of steps allowed.

    Returns:
        A tuple containing a random valid map and the goal position
        (row, col).
    """
    valid = False
    board: list[list[str]] = []  # initialize to make pyright happy

    try:
        from gymnasium.utils import seeding

        np_random, _ = seeding.np_random(seed)
    except ImportError as exc:
        raise ImportError(
            "Gymnasium is not installed. "
            "Please install gymnasium first before "
            "running the frozen_lake workflow.",
        ) from exc

    # generate random start and end points
    while not valid:
        p = min(1, p)
        board = np_random.choice(
            ["F", "H"],
            (size, size),
            p=[p, 1 - p],
        ).tolist()

        while True:
            start_r = int(np_random.integers(0, size))
            start_c = int(np_random.integers(0, size))
            goal_r = int(np_random.integers(0, size))
            goal_c = int(np_random.integers(0, size))

            # Ensure start and goal are different positions
            if (start_r, start_c) != (goal_r, goal_c):
                break

        board[start_r][start_c] = "S"
        board[goal_r][goal_c] = "G"

        valid = is_valid(board, size, max_steps)
    return ["".join(x) for x in board], (goal_r, goal_c)
|
||||||
|
|
||||||
|
|
||||||
|
def get_goal_position(
    random_map: np.ndarray,
) -> Optional[Tuple[int, int]]:
    """Get the goal position from a random map.

    Args:
        random_map: The map as a numpy array.

    Returns:
        Tuple of (row, col) if goal found, None otherwise.
    """
    positions = np.argwhere(random_map == b"G")
    if positions.size == 0:
        return None  # G not found
    return tuple(positions[0])  # returns (row, col)


__all__ = [
    "SYSTEM_PROMPT",
    "FrozenLakeAction",
    "generate_random_map",
    "get_goal_position",
]
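As a quick local check of the helpers above, the generated `desc` can be fed directly into gymnasium's `FrozenLake-v1` environment. A minimal sketch, assuming `gymnasium` is installed and the module above is importable:

```python
# Minimal sketch: generate a reproducible map and verify the goal position.
import gymnasium as gym

desc, goal = generate_random_map(size=4, p=0.8, seed=7, max_steps=5)
print(desc)  # e.g. ["FFHF", "SFFF", "HFFG", "FFFH"]
print(goal)  # (row, col) of "G"

env = gym.make("FrozenLake-v1", desc=desc, is_slippery=False)
env.reset(seed=7)
# gymnasium stores the layout as a bytes array, which is why
# get_goal_position compares against b"G" above
assert get_goal_position(env.unwrapped.desc) == goal
```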
53
tuner/frozen_lake/config.yaml
Normal file
@@ -0,0 +1,53 @@
project: "AgentScope"  # Project name
name: "FrozenLake"  # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}  # Directory to save model checkpoints
algorithm:
  algorithm_type: multi_step_grpo  # GRPO series for multi-step scenario
  repeat_times: 16  # Number of rollouts per prompt for advantage estimation
  kl_loss_fn: "low_var_kl"
  kl_loss_fn_args:
    kl_coef: 0  # KL divergence coefficient
  advantage_fn_args:
    epsilon: 1e-6  # Small value for numerical stability
    std_threshold: 0.0001  # Threshold for standard deviation
  optimizer:
    lr: 1e-6  # Learning rate
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct}  # Base model path
  max_prompt_tokens: 23552  # Max tokens for prompt
  max_response_tokens: 2048  # Max tokens per response
  max_model_len: 25600  # Max context length
  temperature: 1.0  # Sampling temperature
buffer:
  total_epochs: 5  # Total training epochs
  batch_size: 32  # Batch size per explore step
  train_batch_size: 1024  # Total experiences per training step
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
      max_read_timeout: 7200  # Max timeout for reading from buffer (seconds)
    replay_buffer:
      enable: true  # Enable experience replay
      priority_fn: linear_decay  # Priority function for replay buffer
      priority_fn_args:
        decay: 0.1  # Decay rate for priority function
explorer:
  runner_per_model: 16  # Number of runners per model
  rollout_model:
    engine_num: 6  # Number of vLLM engines for rollout model
    tensor_parallel_size: 1  # TP size per engine for rollout model
    enable_openai_api: true  # Enable OpenAI-compatible API
    enable_history: true  # Enable conversation history
    enable_auto_tool_choice: true  # Enable automatic tool selection
    tool_call_parser: hermes  # Parser for tool calls
trainer:
  save_interval: 100  # Save checkpoint every N steps
  use_dynamic_bsz: true  # Use dynamic batch size
  grad_clip: 1.0  # Gradient clipping value
  max_token_len_per_gpu: 25600  # Max token length per GPU
  ulysses_sequence_parallel_size: 2  # Sequence parallel size for Ulysses
synchronizer:
  sync_style: dynamic_by_explorer  # Sync triggered dynamically by explorer
  sync_interval: 1  # Sync every N steps
  sync_timeout: 1200  # Timeout for synchronization (seconds)
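The `${oc.env:VAR,default}` entries are OmegaConf environment-variable interpolations, so the config can be resolved and inspected locally before launching a run. A minimal sketch, assuming the `omegaconf` package is available (this resolver syntax is OmegaConf's):

```python
# Minimal sketch: resolve the ${oc.env:...} interpolations in config.yaml.
import os
from omegaconf import OmegaConf

os.environ["TRINITY_MODEL_PATH"] = "Qwen/Qwen2.5-3B-Instruct"  # overrides the YAML default
cfg = OmegaConf.to_container(OmegaConf.load("tuner/frozen_lake/config.yaml"), resolve=True)
print(cfg["model"]["model_path"])  # Qwen/Qwen2.5-3B-Instruct
print(cfg["checkpoint_root_dir"])  # ./checkpoints unless TRINITY_CHECKPOINT_ROOT_DIR is set
```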
BIN
tuner/frozen_lake/critic_rewards_mean.png
Normal file
Binary file not shown. (After: 62 KiB)
131
tuner/frozen_lake/get_frozenlake_data.py
Normal file
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""
Dataset preparation for FrozenLake. Modified from rllm.
"""
import argparse
import os

import numpy as np
import pandas as pd


DEFAULT_DATA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "..",
    "data",
    "frozenlake",
)


def save_dataset_to_local(
    data_path: str,
    data: list[dict],
    split: str = "default",
) -> str:
    """Save dataset directly to local data_path.

    Args:
        data_path: Path to save the dataset
        data: List of dictionaries containing the dataset examples
        split: Split name (e.g., 'train', 'test', 'default')

    Returns:
        str: Path to the saved parquet file
    """
    os.makedirs(data_path, exist_ok=True)

    # Convert to DataFrame and save
    data_df = pd.DataFrame(data)
    dataset_path = os.path.join(data_path, f"{split}.parquet")
    data_df.to_parquet(dataset_path)

    print(
        f"Saved dataset frozenlake split '{split}' "
        f"with {len(data)} examples at {dataset_path}. "
        f"Make sure to set the environment variable "
        f"<TRINITY_TASKSET_PATH> to {data_path}.",
    )

    return dataset_path


def prepare_frozenlake_data(
    data_path: str,
    train_size: int = 10000,
    test_size: int = 100,
    map_max_size: int = 6,
) -> tuple[list[dict], list[dict]]:
    """
    Prepare and save FrozenLake datasets for training and testing.

    Args:
        data_path (str): Path to save the dataset
        train_size (int): Number of training examples to generate
        test_size (int): Number of test examples to generate
        map_max_size (int): Exclusive upper bound on the sampled grid size

    Returns:
        tuple: (train_data, test_data) - Lists of data dictionaries
    """
    # Set random seed for reproducibility
    np.random.seed(42)

    # Generate random parameters for train and test sets
    train_seeds = np.random.randint(0, 100000, size=train_size)
    test_seeds = np.random.randint(0, 100000, size=test_size)
    train_sizes = np.random.randint(2, map_max_size, size=train_size)
    test_sizes = np.random.randint(2, map_max_size, size=test_size)
    train_ps = np.random.uniform(0.6, 0.85, size=train_size)
    test_ps = np.random.uniform(0.6, 0.85, size=test_size)

    def frozenlake_process_fn(
        seed: int,
        size: int,
        p: float,
        idx: int,
    ) -> dict:
        """Process function to create FrozenLake task instances."""
        return {
            "seed": seed,
            "size": size,
            "p": p,
            "index": idx,
            "uid": f"{seed}_{size}_{p}",
        }

    # Create train and test data
    train_data_list = [
        frozenlake_process_fn(seed, train_sizes[idx], train_ps[idx], idx)
        for idx, seed in enumerate(train_seeds)
    ]
    test_data_list = [
        frozenlake_process_fn(seed, test_sizes[idx], test_ps[idx], idx)
        for idx, seed in enumerate(test_seeds)
    ]

    # Save datasets directly to local DATA_PATH
    save_dataset_to_local(data_path, train_data_list, "train")
    save_dataset_to_local(data_path, test_data_list, "test")

    return train_data_list, test_data_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
    parser.add_argument("--train_size", type=int, default=10000)
    parser.add_argument("--test_size", type=int, default=100)
    parser.add_argument("--map_max_size", type=int, default=6)
    args = parser.parse_args()

    train_data, test_data = prepare_frozenlake_data(
        data_path=args.local_dir,
        train_size=args.train_size,
        test_size=args.test_size,
        map_max_size=args.map_max_size,
    )

    print(f"Train dataset: {len(train_data)} examples")
    print(f"Test dataset: {len(test_data)} examples")
    print("Sample train example:", train_data[0])
    print("Sample test example:", test_data[0])
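After running the script, the generated task set can be sanity-checked by reading the parquet files back. A quick sketch, assuming the default `--local_dir` and that you run from the repository root:

```python
# Quick sanity check of the generated FrozenLake task set.
import pandas as pd

train_df = pd.read_parquet("tuner/data/frozenlake/train.parquet")
print(train_df.shape)              # (10000, 5) with the default sizes
print(train_df.iloc[0].to_dict())  # keys: seed, size, p, index, uid
```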
151
tuner/frozen_lake/main.py
Normal file
@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
"""Example of training a FrozenLake agent with Trinity-RFT."""
import os
from typing import Dict

from _frozenlake_agent import FrozenLakeAgent
from _frozenlake_env import FrozenLakeEnv
from agentscope.message import Msg
from agentscope.tuner import (
    tune,
    WorkflowOutput,
    DatasetConfig,
    TunerModelConfig,
    AlgorithmConfig,
)
from agentscope.model import ChatModelBase


async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    """A workflow function using the FrozenLake agent to solve tasks.

    Args:
        task (Dict): The task to be solved, containing environment parameters
            like size, p, seed, is_slippery, etc.
        model (ChatModelBase): The language model to use.
        auxiliary_models (Dict[str, ChatModelBase]): Auxiliary models; unused
            in this workflow.

    Returns:
        WorkflowOutput: The workflow output containing the reward, response and
            metrics.
    """
    assert len(auxiliary_models) == 0, "No auxiliary models are needed"

    # Extract workflow arguments from task or use defaults
    workflow_args = task.get("workflow_args", {})
    if not workflow_args:
        workflow_args = task

    env_max_steps = workflow_args.get("env_max_steps", 8)
    agent_max_steps = workflow_args.get("agent_max_steps", 10)
    is_slippery = workflow_args.get("is_slippery", False)
    desc = workflow_args.get("desc", None)

    # Extract task-specific arguments (for environment generation)
    size = task.get("size", 8)
    p = task.get("p", 0.8)
    seed = task.get("seed", 42)

    # Initialize agent and environment
    agent = FrozenLakeAgent(model=model, max_steps=agent_max_steps)
    env = FrozenLakeEnv(
        max_steps=env_max_steps,
        desc=desc,
        is_slippery=is_slippery,
        size=size,
        p=p,
        seed=seed,
    )

    # Reset environment with task parameters
    observation, _ = env.reset(task)
    observation_str = str(observation)
    rewards = []
    step_count = 0
    done = False
    terminate_reason = None

    # Run agent-environment interaction loop
    for _ in range(agent_max_steps):
        step_count += 1
        try:
            # get prompt
            prompt = agent.get_prompt(observation_str)

            response = await agent.reply(msg=Msg("user", prompt, role="user"))

            # record action and observation
            action = agent.get_action(response)
            agent.update_state(action=action, observation=observation_str)

        except Exception as e:
            terminate_reason = f"agent_error: {str(e)}"
            break

        # environment step
        observation, reward, done, _ = env.step(action)
        observation_str = str(observation)
        rewards.append(reward)

        if done:
            terminate_reason = "success" if env.success() else "hole"
            break

    if terminate_reason is None:
        terminate_reason = "max_steps_reached"

    final_reward = sum(rewards)
    final_observation = observation_str

    # Create response message with environment information
    response_content = (
        f"Final observation:\n{final_observation}\n"
        f"Total reward: {final_reward}\n"
        f"Steps taken: {step_count}\n"
        f"Terminate reason: {terminate_reason}"
    )

    final_response = Msg("assistant", response_content, role="assistant")

    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={
            "env_steps": float(step_count),
            "env_done": float(done),
        },
    )


if __name__ == "__main__":
    dataset = DatasetConfig(
        path="/path/to/frozenlake",
        split="train",
    )
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen2.5-3B-Instruct",
        max_model_len=25600,
        max_tokens=2048,
        inference_engine_num=6,
        reasoning_parser=None,
    )
    algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=16,
        batch_size=32,
        learning_rate=1e-6,
    )
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )  # define some default parameters
    tune(
        workflow_func=run_frozen_lake,
        model=tuner_model,
        train_dataset=dataset,
        algorithm=algorithm,
        config_path=config_path,
    )
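To launch this example, the environment variables referenced by config.yaml should point at real locations, and the `/path/to/frozenlake` placeholder in `DatasetConfig` must be edited to the directory produced by get_frozenlake_data.py. A hedged sketch of a launch helper; all paths are illustrative placeholders:

```python
# Sketch of a launch helper; paths are placeholders, not the official entry point.
import os
import subprocess

os.environ["TRINITY_MODEL_PATH"] = "Qwen/Qwen2.5-3B-Instruct"
os.environ["TRINITY_CHECKPOINT_ROOT_DIR"] = "./checkpoints"

# Remember: main.py's DatasetConfig(path=...) must point at the generated parquet dir.
subprocess.run(["python", "tuner/frozen_lake/main.py"], check=True)
```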
@@ -313,7 +313,7 @@ python tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py \
```

> ⚠️ **Note**: Your trained model must be converted to **Hugging Face format** first.
> See: [Converting FSDP Checkpoints Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html)
> See: [Converting FSDP Checkpoints Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/faq.html)

---
@@ -248,10 +248,10 @@ if __name__ == "__main__":
Here, we use `DatasetConfig` to load the training dataset, `TunerModelConfig` to initialize the trainable model, and `AlgorithmConfig` to specify the RL algorithm and its hyperparameters.

> Note:
> The `tune` function is based on [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) and it converts the input parameters into a YAML configuration internally.
> The `tune` function is based on [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) and it converts the input parameters into a YAML configuration internally.
> Advanced users can ignore `model`, `train_dataset`, `algorithm` arguments and provide a configuration file path pointing to a YAML file using the `config_path` argument instead (see [config.yaml](./config.yaml) for an example).
> We recommend using the configuration file approach for fine-grained control over the training process and leveraging advanced features provided by Trinity-RFT.
> You can refer to the Trinity-RFT [Configuration Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html) for more details on configuration options.
> You can refer to the Trinity-RFT [Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html) for more details on configuration options.

The checkpoint and logs will automatically be saved to the `checkpoints/AgentScope` directory under the current working directory and each run will be saved in a sub-directory suffixed with the current timestamp.
You can find the tensorboard logs inside `monitor/tensorboard` of the checkpoint directory.
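Concretely, the configuration-file route described in the note above reduces to something like the following sketch, where `my_workflow` is a placeholder for any workflow function with the `(task, model, auxiliary_models)` signature:

```python
# Sketch of the config-file approach: the YAML drives all training settings.
from typing import Dict

from agentscope.model import ChatModelBase
from agentscope.tuner import tune, WorkflowOutput


async def my_workflow(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    ...  # placeholder: any workflow body returning a WorkflowOutput


tune(workflow_func=my_workflow, config_path="config.yaml")
```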
@@ -342,7 +342,7 @@ After implementing the workflow function, follow these steps to run the training

- At least 2 NVIDIA GPUs with CUDA 12.8 or newer.
- Adjust the configuration file ([config.yaml](./config.yaml)) based on your hardware.
- Follow the Trinity-RFT [installation guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code.
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code.
- Download the GSM8K dataset and Qwen/Qwen3-0.6B model checkpoints (example):

```bash
@@ -245,10 +245,10 @@ if __name__ == "__main__":
Here, `DatasetConfig` loads the training dataset, `TunerModelConfig` initializes the trainable model, and `AlgorithmConfig` specifies the RL algorithm and its hyperparameters.

> Note:
> The `tune` function is built on [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) and automatically converts the input parameters into a YAML configuration.
> The `tune` function is built on [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) and automatically converts the input parameters into a YAML configuration.
> Advanced users can skip the `model`, `train_dataset`, and `algorithm` arguments and instead point `config_path` to a YAML configuration file (see [config.yaml](./config.yaml) for an example).
> The configuration-file approach is recommended for finer-grained control over training and for leveraging Trinity-RFT's advanced features.
> For detailed configuration options, see the Trinity-RFT [Configuration Guide](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html).
> For detailed configuration options, see the Trinity-RFT [Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html).

Checkpoints and logs are automatically saved under `checkpoints/AgentScope` in the current working directory, with a new timestamped sub-directory created for each run.
TensorBoard logs are located in `monitor/tensorboard` under the checkpoint directory.
@@ -335,7 +335,7 @@ if __name__ == "__main__":

- At least 2 NVIDIA GPUs with CUDA 12.8 or newer.
- Adjust the configuration file ([config.yaml](./config.yaml)) to match your hardware.
- Follow the Trinity-RFT [installation guide](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) to install the latest version from source.
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) to install the latest version from source.
- Download the GSM8K dataset and the Qwen/Qwen3-0.6B model checkpoints (example):

```bash
@@ -1,4 +1,4 @@
# Please refer to https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html for detailed explanation of each field.
# Please refer to https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html for detailed explanation of each field.
project: AgentScope
name: GSM8K-Qwen3-0.6B
# directory to save checkpoints, default to ./checkpoints if TRINITY_CHECKPOINT_ROOT_DIR not set