Add examples for frozenlake and emailsearch (#94)

Yuchang Sun
2026-01-19 12:25:13 +08:00
committed by GitHub
parent 3821fb04ac
commit 654c35127a
26 changed files with 3370 additions and 14 deletions


@@ -22,6 +22,6 @@ AgentScope Tuner requires:
- `agentscope>=1.0.12`
- `trinity-rft>=0.4.1`
-AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/modelscope/Trinity-RFT).
-Please refer to the [Trinity-RFT installation guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
+AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
+Please refer to the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
for detailed instructions on how to set up the environment.


@@ -21,6 +21,6 @@ AgentScope Tuner requires:
- `agentscope>=1.0.12`
- `trinity-rft>=0.4.1`
-AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/modelscope/Trinity-RFT).
-Please refer to the [Trinity-RFT installation guide](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
+AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
+Please refer to the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
for detailed installation instructions.


@@ -0,0 +1,279 @@
# Training Email Search Agent with RL using AgentScope-Tuner
This example demonstrates how to implement reinforcement fine-tuning for the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
## Task Setting
The agent's goal is to answer user queries by searching through an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email contents to extract information
- Provide accurate answers with proper source citations
**Agent Type**: The agent (`EmailSearchAgent`) extends `ReActAgent`, which follows a reasoning-acting loop to solve tasks iteratively.
**Environment**: The environment is a SQLite database containing emails from the Enron Email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context for the query
- `answer`: The expected answer (ground truth), only for reward calculation
- `message_ids`: IDs of relevant emails containing the answer, only for reward calculation
**Available Tools**:
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from ReAct agent).
## Dataset Preparation
The dataset contains email queries based on the [Enron Email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:
```bash
python prepare_data.py
```
If you want to use a different database path, modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database before moving to the next step:
```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
This will create a SQLite database and datasets:
```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db   # Email database
├── train.parquet         # Training samples
└── test.parquet          # Test samples
```
Each sample looks like:
```json
{
    "id": 0,
    "question": "Were there any variances detected for hour 6 on 3/9/01?",
    "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
    "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
    "how_realistic": 0.800000011920929,
    "inbox_address": "pete.davis@enron.com",
    "query_date": "2001-03-16"
}
```
## Code Implementation
This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.
### Agent Workflow
The workflow function `run_email_search_agent` implements the agent-environment interaction loop:
```python
async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # Parse task and create agent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        max_iters=max_turns,
    )
    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )
    return WorkflowOutput(response=response)
```
The agent follows a ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and source message IDs.
### Judge Function
The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:
```python
async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    # Extract answer and sources from response
    answer = answer_and_sources.get("answer")
    sources = answer_and_sources.get("sources", [])
    # Judge correctness using LLM-as-a-Judge
    judge_model = auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
    judge_response = await judge_correctness(
        answer, query, judge_model
    )
    # Calculate reward based on:
    # - Answer correctness (accuracy: -1.0 to 1.0)
    # - Source correctness (format: partial rewards)
    # - Efficiency (bonus for fewer turns, correct sources)
    result = {"accuracy": ..., "format": ...}  # calculated based on judge_response
    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )
```
The reward function considers:
- **Answer correctness**: Evaluated by LLM-as-a-Judge comparing the agent's answer with the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: Bonus rewards for finding/reading the correct email and taking fewer turns
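Condensing the reward shaping into a single function, a simplified sketch of the logic in [`main.py`](./main.py) (the rubric fields mirror `FinalRubric`):
```python
def sketch_reward(rubric: "FinalRubric", max_turns: int) -> float:
    """Condensed sketch of the reward shaping implemented in main.py."""
    # Partial "format" credit for progress, granted even when the answer fails
    partial = 0.0
    partial += 0.1 if rubric.ever_found_right_email else 0.0
    partial += 0.1 if rubric.ever_read_right_email else 0.0
    partial += 0.1 if rubric.sources_correct else 0.0
    if rubric.attempted_answer and not rubric.answer_correct:
        return -1.0 + partial  # wrong answer: accuracy -1.0 plus partial format
    if rubric.returned_i_dont_know or rubric.ran_out_of_turns:
        return 0.0 + partial   # no answer: accuracy 0.0 plus partial format
    if rubric.answer_correct:
        fmt = 1.0                                        # base bonus for a correct answer
        fmt += 0.3 if rubric.sources_correct else 0.0    # cited the right email
        fmt += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0.0
        fmt += 0.1 * (1 - rubric.num_turns / max_turns)  # fewer turns
        return 1.0 + fmt                                 # accuracy 1.0 plus format bonus
    return 0.0
```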
See [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py) for implementation details.
## How to Run
### Prerequisites
- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
  * Note: the 30B judge model requires a GPU with at least 4080 memory; you can also shard the model across multiple GPUs by setting `tensor_parallel_size > 1` to reduce per-GPU memory usage (the default is `tensor_parallel_size=2`).
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
- Download the model checkpoints (example):
```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # judge model
```
### Configuration
Adjust the configuration file ([`config.yaml`](./config.yaml)) based on your hardware. Key configuration sections include:
- **TunerModelConfig**: Set `model_path` to your model checkpoint path
- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
- **Auxiliary Models**: Configure judge model settings for LLM-as-a-Judge
For full configuration details, see [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
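On the Python side, the judge model is wired in through `auxiliary_models` when calling `tune`; a trimmed excerpt of the configuration in [`main.py`](./main.py):
```python
from agentscope.tuner import TunerModelConfig

tuner_model = TunerModelConfig(
    model_path="Qwen/Qwen3-4B-Instruct-2507",  # policy model being trained
    max_model_len=20480,
    max_tokens=4096,
    inference_engine_num=4,
)
aux_models = {
    "judge": TunerModelConfig(  # LLM-as-a-Judge model
        model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
        max_model_len=2500,
        max_tokens=2048,
        inference_engine_num=1,
        tensor_parallel_size=2,
    ),
}
```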
### Start-Up Commands
1. Prepare the dataset:
```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
2. Set up a [Ray](https://github.com/ray-project/ray) cluster:
```bash
ray start --head
```
3. Run the training script:
```bash
python main.py
```
## Experimental Results
### Quantitative Results
The training results show improvements in agent performance over training iterations. Key metrics include:
- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples increases as the agent learns better strategies
![Training Rewards](./critic_reward_mean.png)
![Rollout Accuracy](./rollout_accuracy_mean.png)
### Concrete Example
An example of the agent's behavior is shown below:
**Query:** "What do the color codes mean in the curve assessment?"
We show the last several turns of agent responses:
The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:
**Tool call:**
```json
{
    "type": "tool_use",
    "name": "search_emails",
    "input": {
        "inbox_address": "steven.kean@enron.com",
        "query_date": "2001-04-03",
        "keywords": ["curve", "assessment"]
    }
}
```
**Tool result:**
```json
{
    "type": "tool_result",
    "name": "search_emails",
    "output": [
        {
            "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
            "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
        },
        {
            "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
            "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
        }
    ]
}
```
After finding relevant emails, the agent uses the `read_email` tool to read the full content of both emails:
**Tool call:**
```json
{
    "type": "tool_use",
    "name": "read_email",
    "input": {
        "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
    }
}
```
**Tool call:**
```json
{
    "type": "tool_use",
    "name": "read_email",
    "input": {
        "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
    }
}
```
After reading the emails, the agent generates the final structured response:
**Tool call:**
```json
{
    "type": "tool_use",
    "name": "generate_response",
    "input": {
        "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
        "sources": [
            "<2654330.1075846153519.JavaMail.evans@thyme>",
            "<12499967.1075846153495.JavaMail.evans@thyme>"
        ]
    }
}
```
The judge evaluates the answer as correct based on the ground truth.


@@ -0,0 +1,279 @@
# Training Email Search Agent with RL using AgentScope-Tuner
This example demonstrates how to implement reinforcement fine-tuning for the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
## Task Setting
The agent's goal is to answer user queries by searching through an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email contents to extract information
- Provide accurate answers with proper source citations
**Agent Type**: The agent (`EmailSearchAgent`) extends `ReActAgent`, which follows a reasoning-acting loop to solve tasks iteratively.
**Environment**: The environment is a SQLite database containing emails from the Enron Email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context for the query
- `answer`: The expected answer (ground truth), used only for reward calculation
- `message_ids`: IDs of relevant emails containing the answer, used only for reward calculation
**Available Tools**:
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from the ReAct agent).
## Dataset Preparation
The dataset contains email queries based on the [Enron Email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:
```bash
python prepare_data.py
```
If you want to use a different database path, modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database before moving to the next step:
```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
This will create a SQLite database and datasets:
```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db   # Email database
├── train.parquet         # Training samples
└── test.parquet          # Test samples
```
Each sample looks like:
```json
{
    "id": 0,
    "question": "Were there any variances detected for hour 6 on 3/9/01?",
    "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
    "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
    "how_realistic": 0.800000011920929,
    "inbox_address": "pete.davis@enron.com",
    "query_date": "2001-03-16"
}
```
## Code Implementation
This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.
### Agent Workflow
The workflow function `run_email_search_agent` implements the agent-environment interaction loop:
```python
async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # Parse task and create agent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        max_iters=max_turns,
    )
    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )
    return WorkflowOutput(response=response)
```
The agent follows a ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and the source message IDs.
### Judge Function
The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:
```python
async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    # Extract answer and sources from response
    answer = answer_and_sources.get("answer")
    sources = answer_and_sources.get("sources", [])
    # Judge correctness using LLM-as-a-Judge
    judge_model = auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
    judge_response = await judge_correctness(
        answer, query, judge_model
    )
    # Calculate reward based on:
    # - Answer correctness (accuracy: -1.0 to 1.0)
    # - Source correctness (format: partial rewards)
    # - Efficiency (bonus for fewer turns, correct sources)
    result = {"accuracy": ..., "format": ...}  # calculated based on judge_response
    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )
```
The reward function considers:
- **Answer correctness**: Evaluated by LLM-as-a-Judge comparing the agent's answer with the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: Bonus rewards for finding/reading the correct email and taking fewer turns
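Condensing the reward shaping into a single function, a simplified sketch of the logic in [`main.py`](./main.py) (the rubric fields mirror `FinalRubric`):
```python
def sketch_reward(rubric: "FinalRubric", max_turns: int) -> float:
    """Condensed sketch of the reward shaping implemented in main.py."""
    # Partial "format" credit for progress, granted even when the answer fails
    partial = 0.0
    partial += 0.1 if rubric.ever_found_right_email else 0.0
    partial += 0.1 if rubric.ever_read_right_email else 0.0
    partial += 0.1 if rubric.sources_correct else 0.0
    if rubric.attempted_answer and not rubric.answer_correct:
        return -1.0 + partial  # wrong answer: accuracy -1.0 plus partial format
    if rubric.returned_i_dont_know or rubric.ran_out_of_turns:
        return 0.0 + partial   # no answer: accuracy 0.0 plus partial format
    if rubric.answer_correct:
        fmt = 1.0                                        # base bonus for a correct answer
        fmt += 0.3 if rubric.sources_correct else 0.0    # cited the right email
        fmt += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0.0
        fmt += 0.1 * (1 - rubric.num_turns / max_turns)  # fewer turns
        return 1.0 + fmt                                 # accuracy 1.0 plus format bonus
    return 0.0
```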
See [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py) for implementation details.
## How to Run
### Prerequisites
- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
  * Note: the 30B judge model requires a GPU with at least 4080 memory; you can also shard the model across multiple GPUs by setting `tensor_parallel_size > 1` to reduce per-GPU memory usage (the default is `tensor_parallel_size=2`).
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) to install the latest version from source
- Download the model checkpoints (example):
```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # judge model
```
### Configuration
Adjust the configuration file ([`config.yaml`](./config.yaml)) based on your hardware. Key configuration sections include:
- **TunerModelConfig**: Set `model_path` to your model checkpoint path
- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
- **Auxiliary Models**: Configure the judge model settings for LLM-as-a-Judge
For full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html).
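On the Python side, the judge model is wired in through `auxiliary_models` when calling `tune`; a trimmed excerpt of the configuration in [`main.py`](./main.py):
```python
from agentscope.tuner import TunerModelConfig

tuner_model = TunerModelConfig(
    model_path="Qwen/Qwen3-4B-Instruct-2507",  # policy model being trained
    max_model_len=20480,
    max_tokens=4096,
    inference_engine_num=4,
)
aux_models = {
    "judge": TunerModelConfig(  # LLM-as-a-Judge model
        model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
        max_model_len=2500,
        max_tokens=2048,
        inference_engine_num=1,
        tensor_parallel_size=2,
    ),
}
```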
### Start-Up Commands
1. Prepare the dataset:
```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
2. Set up a [Ray](https://github.com/ray-project/ray) cluster:
```bash
ray start --head
```
3. Run the training script:
```bash
python main.py
```
## Experimental Results
### Quantitative Results
The training results show improvements in agent performance over training iterations. Key metrics include:
- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples increases as the agent learns better strategies
![Training Rewards](./critic_reward_mean.png)
![Rollout Accuracy](./rollout_accuracy_mean.png)
### Concrete Example
An example of the agent's behavior is shown below:
**Query:** "What do the color codes mean in the curve assessment?"
We show the last several turns of the agent's responses:
The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:
**Tool call:**
```json
{
    "type": "tool_use",
    "name": "search_emails",
    "input": {
        "inbox_address": "steven.kean@enron.com",
        "query_date": "2001-04-03",
        "keywords": ["curve", "assessment"]
    }
}
```
**Tool result:**
```json
{
    "type": "tool_result",
    "name": "search_emails",
    "output": [
        {
            "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
            "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
        },
        {
            "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
            "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
        }
    ]
}
```
After finding relevant emails, the agent uses the `read_email` tool to read the full content of both emails:
**Tool call:**
```json
{
    "type": "tool_use",
    "name": "read_email",
    "input": {
        "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
    }
}
```
**Tool call:**
```json
{
    "type": "tool_use",
    "name": "read_email",
    "input": {
        "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
    }
}
```
After reading the emails, the agent generates the final structured response:
**Tool call:**
```json
{
    "type": "tool_use",
    "name": "generate_response",
    "input": {
        "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
        "sources": [
            "<2654330.1075846153519.JavaMail.evans@thyme>",
            "<12499967.1075846153495.JavaMail.evans@thyme>"
        ]
    }
}
```
The judge evaluates the answer above as correct based on the ground truth.


@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import json
import traceback
from dataclasses import asdict
from datetime import datetime, timedelta
from typing import Any

from _utils import (  # pylint: disable=E0611
    read_email_tool,
    search_emails_tool,
)
from agentscope import logger
from agentscope.agent import ReActAgent
from agentscope.message import TextBlock
from agentscope.tool import Toolkit, ToolResponse


def pre_reasoning_hook(_self: Any, _kwargs: Any) -> dict[str, Any] | None:
    """Pre-reasoning hook to remove tool_choice from kwargs."""
    _kwargs.pop("tool_choice", None)
    return _kwargs


class EmailSearchAgent(ReActAgent):
    """
    A customized ReAct agent with pre-defined tools for
    email search and reading.
    Ref: https://github.com/OpenPipe/ART
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # List to store message IDs found during search
        self.message_id_list = []
        # List to store message IDs that have been read
        self.ever_read_message_ids = []
        toolkit = Toolkit()
        toolkit.register_tool_function(self.search_emails)
        toolkit.register_tool_function(self.read_email)
        super().__init__(*args, toolkit=toolkit, **kwargs)
        self.register_instance_hook(
            "pre_reasoning",
            "tool_choice_hook",
            pre_reasoning_hook,
        )

    async def reset(self) -> None:
        """Reset agent state for a new rollout/episode."""
        self.message_id_list.clear()
        self.ever_read_message_ids.clear()
        await self.memory.clear()

    def search_emails(
        self,
        inbox_address: str,
        query_date: str,
        keywords: list[str],
        **_kwargs: Any,
    ) -> ToolResponse:
        """
        Search the user's email inbox for emails that match the given keywords.

        Args:
            inbox_address: The user's email address.
            query_date: The date of the query in 'YYYY-MM-DD' format.
            keywords: Keywords to search for in the user's email inbox.

        Returns:
            ToolResponse:
                A ToolResponse object containing a list of TextBlock objects
                in the `content` field. On success, the text field of the
                TextBlock contains a JSON string representing a list of email
                summaries (e.g., message_id, snippet) matching the search
                criteria. Each email summary is converted to a dictionary via
                `asdict`. On failure, the text indicates an error message.
        """
        try:
            next_day = (
                datetime.strptime(query_date, "%Y-%m-%d") + timedelta(days=1)
            ).strftime(
                "%Y-%m-%d",
            )
            res = search_emails_tool(
                inbox=inbox_address,
                sent_before=next_day,
                keywords=keywords,
            )
            self.message_id_list.extend([r.message_id for r in res])
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=json.dumps([asdict(r) for r in res]),
                    ),
                ],
            )
        except Exception as e:
            logger.info(
                "Error in search_emails: %s, traceback: %s",
                e,
                traceback.format_exc(),
            )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=(
                            f"Error: Failed to search emails.\n"
                            f"Error message: {e}"
                        ),
                    ),
                ],
            )

    def read_email(self, message_id: str, **_kwargs: Any) -> ToolResponse:
        """
        Read the content of an email from the user's email inbox.
        Returns the email content.

        Args:
            message_id (str): The unique identifier of the email to read.

        Returns:
            ToolResponse:
                A ToolResponse object containing the email content or an
                error message if the email is not found.
        """
        try:
            email_content = read_email_tool(message_id)
            self.ever_read_message_ids.append(message_id)
            if email_content is None:
                return ToolResponse(
                    content=[
                        TextBlock(
                            type="text",
                            text=(
                                f"Error: Email (message_id = {message_id}) "
                                f"not found."
                            ),
                        ),
                    ],
                )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=json.dumps(email_content.model_dump()),
                    ),
                ],
            )
        except Exception as e:
            logger.info(
                "Error in read_email: %s, traceback: %s",
                e,
                traceback.format_exc(),
            )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=(
                            f"Error: Failed to read email.\n"
                            f"Error message: {e}"
                        ),
                    ),
                ],
            )


@@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
"""
This file defines Dataclass and tool implementations.
Modified from https://github.com/OpenPipe/ART/blob/art-e/
"""
import datetime
import os
import sqlite3
from dataclasses import dataclass
from typing import Any, List, Optional
from pydantic import BaseModel, Field, field_validator
from agentscope import logger
DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH")
conn = None
def get_conn() -> sqlite3.Connection:
"""Get or create a database connection."""
global conn
if conn is None:
conn = sqlite3.connect(
f"file:{DEFAULT_DB_PATH}?mode=ro",
uri=True,
check_same_thread=False,
)
return conn
class QueryModel(BaseModel):
"""Model for email search query."""
id: int
question: str
answer: str
message_ids: List[str] # message_ids (strings) of referenced emails
how_realistic: float
inbox_address: str
query_date: str
@field_validator("query_date", mode="before")
@classmethod
def format_date(cls, v: Any) -> str:
"""Format date to string if it's a datetime object."""
if isinstance(v, datetime.datetime):
return v.strftime("%Y-%m-%d")
return v
class AnswerModel(BaseModel):
"""Model for agent's answer with sources."""
answer: str = Field(
description=(
"It should be called with the answer and the sources. "
"If you cannot find the answer, you should return "
"'I don't know' with an empty list of sources."
),
)
sources: List[str] = Field(
description=(
"a list of message ids that are relevant to the query. "
"Usually there will be only one. If you cannot find the "
"answer, you should return an empty list."
),
)
class Email(BaseModel):
"""Model representing an email."""
message_id: str
date: str # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
subject: Optional[str] = None
from_address: Optional[str] = None
to_addresses: List[str] = Field(default_factory=list)
cc_addresses: List[str] = Field(default_factory=list)
bcc_addresses: List[str] = Field(default_factory=list)
body: Optional[str] = None
file_name: Optional[str] = None
@dataclass
class SearchResult:
"""Result from email search."""
message_id: str
snippet: str
class FinalRubric(BaseModel):
"""Rubric for evaluating agent performance."""
answer_correct: bool = False
sources_correct: bool = False
num_turns: int = 0
attempted_answer: bool = False
ever_found_right_email: bool = False
ever_read_right_email: bool = False
cant_parse_tool_call: bool = False
bad_tool_call_name: bool = False
bad_tool_call_args: bool = False
ran_out_of_turns: bool = False
returned_i_dont_know: bool = False
num_sources: int = 0
ever_tried_to_read_invalid_email: bool = False
prompt_tokens: int = 0
completion_tokens: int = 0
# Define tools for agent
def search_emails_tool(
inbox: str,
keywords: List[str],
from_addr: Optional[str] = None,
to_addr: Optional[str] = None,
sent_after: Optional[str] = None,
sent_before: Optional[str] = None,
max_results: int = 10,
) -> List[SearchResult]:
"""
Searches the email database based on keywords, inbox,
sender, recipient, and date range.
Args:
inbox: The email address of the user performing the search.
Results include emails sent from or to (inc. cc/bcc)
this address.
keywords: A list of keywords that must all appear in the
subject or body.
from_addr: Optional email address to filter emails sent *from*.
to_addr: Optional email address to filter emails sent *to*
(inc. cc/bcc).
sent_after: Optional date string 'YYYY-MM-DD'. Filters for
emails sent on or after this date.
sent_before: Optional date string 'YYYY-MM-DD'. Filters for
emails sent before this date.
max_results: The maximum number of results to return.
Cannot exceed 10.
Returns:
A list of SearchResult objects, each containing 'message_id'
and 'snippet'. Returns an empty list if no results are found
or an error occurs.
"""
# Initialize sql and params
sql: Optional[str] = None
params: List[str | int] = []
cursor = get_conn().cursor()
# --- Build Query ---
where_clauses: List[str] = []
# 1. Keywords (FTS)
if not keywords:
raise ValueError("No keywords provided for search.")
if max_results > 10:
raise ValueError("max_results must be less than or equal to 10.")
# FTS5 default is AND, so just join keywords. Escape quotes for safety.
fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
where_clauses.append("emails_fts MATCH ?")
params.append(fts_query)
# 2. Inbox filter (must be from OR to/cc/bcc the inbox user)
# Use the composite index idx_recipients_address_email here
where_clauses.append(
"""
(e.from_address = ? OR EXISTS (
SELECT 1 FROM recipients r_inbox
WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id
))
""",
)
params.extend([inbox, inbox])
# 3. Optional From filter
if from_addr:
where_clauses.append("e.from_address = ?")
params.append(from_addr)
# 4. Optional To filter (includes to, cc, bcc)
# Use composite index idx_recipients_address_email
if to_addr:
where_clauses.append(
"""
EXISTS (
SELECT 1 FROM recipients r_to
WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
)
""",
)
params.append(to_addr)
# 5. Optional Sent After filter
if sent_after:
# Assumes date format 'YYYY-MM-DD'
# Compare against the start of the day
where_clauses.append("e.date >= ?")
params.append(f"{sent_after} 00:00:00")
# 6. Optional Sent Before filter
if sent_before:
# Assumes date format 'YYYY-MM-DD'
# Compare against the start of the day (exclusive)
where_clauses.append("e.date < ?")
params.append(f"{sent_before} 00:00:00")
# --- Construct Final Query ---
# snippet(<table>, <column_index>, <highlight_start>,
# <highlight_end>, <ellipsis>, <tokens>)
# -1 means highlight across all columns (subject, body)
sql = f"""
SELECT
e.message_id,
snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
FROM
emails e JOIN emails_fts fts ON e.id = fts.rowid
WHERE
{" AND ".join(where_clauses)}
ORDER BY
e.date DESC -- Order by date for relevance
LIMIT ?;
"""
params.append(max_results)
# --- Execute and Fetch ---
logger.debug("Executing SQL: %s", sql)
logger.debug("With params: %s", params)
cursor.execute(sql, params)
results = cursor.fetchall()
# Format results
formatted_results = [
SearchResult(message_id=row[0], snippet=row[1]) for row in results
]
logger.info("Search found %d results.", len(formatted_results))
return formatted_results
def read_email_tool(message_id: str) -> Optional[Email]:
"""
Retrieves a single email by its message_id from the database.
Args:
message_id: The unique identifier of the email to retrieve.
Returns:
An Email object containing the details of the found email,
or None if the email is not found or an error occurs.
"""
cursor = get_conn().cursor()
# --- Query for Email Core Details ---
email_sql = """
SELECT id, message_id, date, subject, from_address, body, file_name
FROM emails
WHERE message_id = ?;
"""
cursor.execute(email_sql, (message_id,))
email_row = cursor.fetchone()
if not email_row:
logger.warning("Email with message_id '%s' not found.", message_id)
return None
email_pk_id, msg_id, date, subject, from_addr, body, file_name = email_row
# DEBUG
logger.info("[read_email_tool] input_message_id=%s", message_id)
logger.info(
"[read_email_tool] db: id=%s, message_id=%s",
email_pk_id,
msg_id,
)
# search for recipients by emails.id (rather than message_id)
recipients_sql = """
SELECT recipient_address, recipient_type
FROM recipients
WHERE email_id = ?;
"""
cursor.execute(recipients_sql, (email_pk_id,))
recipient_rows = cursor.fetchall()
to_addresses: List[str] = []
cc_addresses: List[str] = []
bcc_addresses: List[str] = []
for addr, rtype in recipient_rows:
type_lower = rtype.lower()
if type_lower == "to":
to_addresses.append(addr)
elif type_lower == "cc":
cc_addresses.append(addr)
elif type_lower == "bcc":
bcc_addresses.append(addr)
# --- Construct Email Object ---
email_obj = Email(
message_id=msg_id, # Convert to string to match Pydantic model
date=date,
subject=subject,
from_address=from_addr,
to_addresses=to_addresses,
cc_addresses=cc_addresses,
bcc_addresses=bcc_addresses,
body=body,
file_name=file_name,
)
return email_obj
__all__ = [
"QueryModel",
"AnswerModel",
"FinalRubric",
"Email",
"SearchResult",
"search_emails_tool",
"read_email_tool",
"get_conn",
]


@@ -0,0 +1,72 @@
project: "AgentScope" # Project name
name: "Email_search" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints
algorithm:
algorithm_type: multi_step_grpo # GRPO series for multi-step scenario
repeat_times: 8 # Number of rollouts per prompt for advantage estimation
optimizer:
lr: 1e-6 # Learning rate
policy_loss_fn: "rec" # Policy loss function
policy_loss_fn_args: # Policy loss function arguments
epsilon_low: 0.2
epsilon_high: 0.2
clip_mode: "one-side"
weight: "none"
temp: 1.0
regularizer: "none"
regularizer_coef: 0.0
kl_loss_fn: 'k2' # KL divergence loss function
kl_loss_fn_args:
kl_coef: 0.0 # KL divergence coefficient
advantage_fn_args:
std_cal_level: 'batch' # Advantage normalization level
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507} # Base model path
max_response_tokens: 4096 # Max tokens per response
max_model_len: 20480 # Max context length
buffer:
total_epochs: 10 # Total training epochs
batch_size: 64 # Batch size per explore step
train_batch_size: 2560 # 64*8*5, total experiences per training step
trainer_input:
experience_buffer:
name: experience_buffer
storage_type: queue
replay_buffer:
enable: true # Enable experience replay
priority_fn: 'decay_limit_randomization'
priority_fn_args:
decay: 2.0
use_count_limit: 3
sigma: 2.0
explorer:
eval_interval: 10
max_repeat_times_per_runner: 1 # Max repeat times per runner
max_timeout: 3600 # Max timeout for each rollout (seconds)
rollout_model:
enable_history: true # Enable conversation history
enable_openai_api: true # Enable OpenAI-compatible API
enable_auto_tool_choice: true # Enable automatic tool selection
tool_call_parser: hermes # Parser for tool calls
engine_num: 4 # Number of vLLM engines for rollout model
tensor_parallel_size: 1 # TP size per engine for rollout model
enable_prefix_caching: false # Disable prefix caching
auxiliary_models:
- name: judge
model_path: Qwen/Qwen3-30B-A3B-Instruct-2507 # Judge model path
engine_num: 1 # Number of vLLM engines for judge model
tensor_parallel_size: 2 # TP size per engine for judge model
enable_thinking: false # Disable thinking/reasoning mode
max_prompt_tokens: 2048 # Max tokens for prompt
max_response_tokens: 128 # Max tokens for response
max_model_len: 2500 # Max model context length
synchronizer:
sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
sync_interval: 5 # Sync every N steps
sync_timeout: 3600 # Timeout for synchronization (seconds)
trainer:
save_interval: 100 # Save checkpoint every N steps
grad_clip: 1.0 # Gradient clipping value
use_dynamic_bsz: true # Use dynamic batch size
max_token_len_per_gpu: 16384 # Max token length per GPU
ulysses_sequence_parallel_size: 1 # Sequence parallel size for Ulysses

(binary image added, 470 KiB)

tuner/email_search/main.py (new file, 379 lines)

@@ -0,0 +1,379 @@
# -*- coding: utf-8 -*-
"""Example of training an Email Search agent with Trinity-RFT."""
import os
from typing import Dict

from _email_search_agent import EmailSearchAgent
from _utils import (  # pylint: disable=E0611
    AnswerModel,
    FinalRubric,
    QueryModel,
)
from agentscope import logger
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.tuner import (
    TunerModelConfig,
    DatasetConfig,
    JudgeOutput,
    WorkflowOutput,
    AlgorithmConfig,
    tune,
)
from agentscope.model import ChatModelBase

SYSTEM_PROMPT = """You are an email search agent. You are given a user query
and a list of tools you can use to search the user's email. Use the tools to
search the user's emails and find the answer to the user's query. You may take
up to {max_turns} turns to find the answer, so if your first search doesn't
find the answer, you can try with different keywords.
Always describe what you see and plan your next steps clearly. When taking
actions, explain what you're doing and why. When the answer to the task is
found, call `generate_response` to finish the process. Only call
`generate_response` when the answer is found. You should not describe any
next steps in `generate_response`. Complete all steps and then call
`generate_response`.
User's email address is {inbox_address}
Today's date is {query_date}
"""


async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:  # noqa: PLR0915
    """A workflow function using the Email Search agent to solve tasks.

    Args:
        task (Dict): The task to be solved.
            Should contain fields from QueryModel.
        model (ChatModelBase): The language model to use.
        auxiliary_models (Dict[str, ChatModelBase]): Additional models,
            e.g. the judge model used for LLM-as-a-Judge.

    Returns:
        WorkflowOutput: The output containing the agent's response.
    """
    assert len(auxiliary_models) > 0, "LLM-as-a-Judge is required"
    # Parse task data
    query = QueryModel.model_validate(task)
    question = task.get("question", task.get("task_desc", ""))
    # Get workflow arguments with defaults
    workflow_args = task.get("workflow_args", {})
    max_turns = int(workflow_args.get("max_turns", 10))
    # Format system prompt
    system_prompt = SYSTEM_PROMPT.format(
        max_turns=max_turns,
        inbox_address=query.inbox_address,
        query_date=query.query_date,
    )
    # Create EmailSearchAgent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        formatter=OpenAIChatFormatter(),
        max_iters=max_turns,
    )
    # Reset agent state for a new rollout
    await agent.reset()
    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )
    # Extract answer and sources from response metadata
    answer_and_sources = response.metadata or {}
    if not answer_and_sources:
        # Fallback: try to parse from content
        answer_and_sources = {
            "answer": response.get_text_content() or "",
            "sources": [],
        }
    # Store agent state for judge function
    # We'll pass this through the response metadata
    response_metadata = {
        "answer_and_sources": answer_and_sources,
        "query": query.model_dump(),
        "message_id_list": agent.message_id_list,
        "ever_read_message_ids": agent.ever_read_message_ids,
        # Estimate actual_turns from memory length
        "actual_turns": (
            max(1, (len(agent.memory.content) - 1) // 2)
            if len(agent.memory.content) > 1
            else 1
        ),
    }
    # Update response metadata
    if response.metadata is None:
        response.metadata = {}
    response.metadata.update(response_metadata)
    return WorkflowOutput(
        response=response,
    )


def _calculate_partial_rewards(rubric: FinalRubric) -> float:
    """Calculate partial rewards based on rubric."""
    partial_rewards = 0.0
    partial_rewards += 0.1 if rubric.ever_found_right_email else 0
    partial_rewards += 0.1 if rubric.ever_read_right_email else 0
    partial_rewards += 0.1 if rubric.sources_correct else 0
    return partial_rewards


def _calculate_correct_answer_reward(
    rubric: FinalRubric,
    max_turns: int,
) -> float:
    """Calculate reward for correct answers."""
    reward = 1.0
    reward += 0.3 if rubric.sources_correct else 0
    reward += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0
    reward += 0.1 * (1 - rubric.num_turns / max_turns)
    return reward


def _initialize_rubric(
    answer: str,
    sources: list[str],
    actual_turns: int,
    query: QueryModel,
    message_id_list: list[str],
    ever_read_message_ids: list[str],
) -> FinalRubric:
    """Initialize and populate rubric with basic information."""
    rubric = FinalRubric()
    rubric.attempted_answer = answer is not None and answer != ""
    rubric.returned_i_dont_know = answer == "I don't know"
    rubric.num_sources = len(sources)
    rubric.num_turns = actual_turns
    if len(query.message_ids) > 0:
        rubric.ever_found_right_email = query.message_ids[0] in message_id_list
        rubric.ever_read_right_email = (
            query.message_ids[0] in ever_read_message_ids
        )
        rubric.sources_correct = query.message_ids[0] in sources
    return rubric


async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    """A judge function to calculate reward based on agent's response.

    Args:
        task (Dict): The task information for the corresponding workflow.
        response (Msg): The response generated by the corresponding workflow.
        auxiliary_models (Dict[str, ChatModelBase]):
            A dictionary of additional chat models available for
            LLM-as-a-Judge usage. The keys are model names, and the values
            are the corresponding ChatModelBase instances.

    Returns:
        JudgeOutput: The reward value assigned by the judge function.
    """
    # Extract metadata from response
    metadata = response.metadata or {}
    answer_and_sources = metadata.get("answer_and_sources", {})
    query_dict = metadata.get("query", {})
    message_id_list = metadata.get("message_id_list", [])
    ever_read_message_ids = metadata.get("ever_read_message_ids", [])
    actual_turns = metadata.get("actual_turns", 0)
    # Parse query model
    if not query_dict:
        query_dict = task
    query = QueryModel.model_validate(query_dict)
    # Get arguments
    workflow_args = task.get("workflow_args", {})
    max_turns = int(workflow_args.get("max_turns", 10))
    # Extract answer and sources
    try:
        answer = answer_and_sources.get("answer", None)
        sources = answer_and_sources.get("sources", [])
    except Exception:
        result = {"accuracy": 0.0, "format": -1.0}
        return JudgeOutput(
            reward=sum(result.values()),
            metrics=result,
        )
    if answer is None:
        result = {"accuracy": 0.0, "format": -1.0}
        return JudgeOutput(
            reward=sum(result.values()),
            metrics=result,
        )
    # Initialize rubric
    rubric = _initialize_rubric(
        answer,
        sources,
        actual_turns,
        query,
        message_id_list,
        ever_read_message_ids,
    )
    # Judge correctness using LLM-as-a-Judge
    try:
        judge_model = (
            auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
            if auxiliary_models
            else None
        )
        judge_response = await judge_correctness(
            answer,
            query,
            judge_model,
        )
        rubric.answer_correct = judge_response
    except Exception as e:
        logger.error("Error judging correctness: %s", e)
        rubric.answer_correct = False
    # Calculate rewards
    partial_rewards = _calculate_partial_rewards(rubric)
    if rubric.attempted_answer and not rubric.answer_correct:
        result = {"accuracy": -1.0, "format": partial_rewards}
    elif rubric.returned_i_dont_know or rubric.ran_out_of_turns:
        result = {"accuracy": 0.0, "format": partial_rewards}
    elif rubric.answer_correct:
        reward = _calculate_correct_answer_reward(rubric, max_turns)
        result = {"accuracy": 1.0, "format": reward}
    else:
        result = {"accuracy": 0.0, "format": 0.0}
    metrics = result.copy()
    metrics.update({"actual_turns": actual_turns})
    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )


# LLM-as-a-judge
async def judge_correctness(
    answer: str,
    query: QueryModel,
    judge: ChatModelBase,
) -> bool:
    """Use an LLM to decide whether *answer* matches *query.answer*.

    Returns a boolean *accept* flag used for scoring.
    """
    system_prompt = """You are given a question, the reference answer
(labelled **Reference answer**), and an answer generated by an AI assistant
(labelled **AI answer**).
Follow these steps to decide whether the AI answer should be accepted:
1. Identify EXACTLY what information the **question** is asking for
(e.g. who, what, when, where, why, how, quantity, etc.).
2. From the **Reference answer**, extract ONLY the facts that are required
to directly satisfy the information need identified in step 1. Treat all
other facts as non-essential context.
3. Verify that every essential fact from step 2 appears in the **AI answer**
with the same meaning. Differences in wording, order, or additional
non-conflicting details are allowed.
4. If any essential fact is missing or contradicted in the **AI answer**,
then *accept* must be **false**. Otherwise *accept* must be **true**.
Important: Do NOT penalise the **AI answer** for omitting non-essential
facts that appear in the **Reference answer**. The answer should only be
rejected for errors or omissions in the information explicitly requested by
the question.
Return your judgement **accept** from **true** and **false**. Do not return
any other text or formatting.
"""
    prompt = (
        f"Question: {query.question}\n"
        f"Reference answer: {query.answer}\n"
        f"AI answer: {answer}"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    chat_response = await judge(messages)
    # Extract text content from ChatResponse
    result_parts = []
    for block in chat_response.content:
        if isinstance(block, dict) and block.get("type") == "text":
            result_parts.append(str(block.get("text", "")))
    result = "".join(result_parts)
    logger.info("LLM judge response: %s", result)
    return "true" in result.lower()
# End of LLM-as-a-judge


if __name__ == "__main__":
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )
    dataset = DatasetConfig(
        path="/path/to/enron_emails_dataset",
        split="train",
    )
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen3-4B-Instruct-2507",
        max_model_len=20480,
        max_tokens=4096,
        inference_engine_num=4,
        reasoning_parser=None,
    )
    aux_models = {
        "judge": TunerModelConfig(
            model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
            max_model_len=2500,
            max_tokens=2048,
            inference_engine_num=1,
            tensor_parallel_size=2,
            tool_call_parser=None,
            reasoning_parser=None,
        ),
    }
    algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        batch_size=64,
        learning_rate=1e-6,
    )
    tune(
        workflow_func=run_email_search_agent,
        judge_func=email_search_judge,
        train_dataset=dataset,
        model=tuner_model,
        auxiliary_models=aux_models,
        algorithm=algorithm,
        config_path=config_path,
    )


@@ -0,0 +1,357 @@
# -*- coding: utf-8 -*-
"""
Prepare data for training.
Modified from OpenPipe/ART
"""
import logging
import os
import sqlite3
from datetime import datetime

from datasets import Dataset, Features, Sequence, Value, load_dataset
from tqdm import tqdm

# Resolve paths relative to this file
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Database will live in "../data/enron_emails.db" relative to project root
DEFAULT_DB_PATH = os.path.join(BASE_DIR, "..", "..", "data", "enron_emails.db")
DEFAULT_REPO_ID = "corbt/enron-emails"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# --- Database Schema ---
SQL_CREATE_TABLES = """
DROP TABLE IF EXISTS recipients;
DROP TABLE IF EXISTS emails_fts;
DROP TABLE IF EXISTS emails;

CREATE TABLE emails (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    message_id TEXT UNIQUE,
    subject TEXT,
    from_address TEXT,
    date TEXT, -- Store as ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    body TEXT,
    file_name TEXT
);

CREATE TABLE recipients (
    email_id INTEGER,
    recipient_address TEXT,
    recipient_type TEXT, -- 'to', 'cc', 'bcc'
    FOREIGN KEY(email_id) REFERENCES emails(id) ON DELETE CASCADE
);
"""

SQL_CREATE_INDEXES_TRIGGERS = """
CREATE INDEX idx_emails_from ON emails(from_address);
CREATE INDEX idx_emails_date ON emails(date);
CREATE INDEX idx_emails_message_id ON emails(message_id);
CREATE INDEX idx_recipients_address ON recipients(recipient_address);
CREATE INDEX idx_recipients_type ON recipients(recipient_type);
CREATE INDEX idx_recipients_email_id ON recipients(email_id);
CREATE INDEX idx_recipients_address_email ON recipients(
    recipient_address, email_id
);

CREATE VIRTUAL TABLE emails_fts USING fts5(
    subject,
    body,
    content='emails',
    content_rowid='id'
);

CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
    INSERT INTO emails_fts (rowid, subject, body)
    VALUES (new.id, new.subject, new.body);
END;

CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
    DELETE FROM emails_fts WHERE rowid=old.id;
END;

CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
    UPDATE emails_fts SET subject=new.subject, body=new.body
    WHERE rowid=old.id;
END;

INSERT INTO emails_fts (rowid, subject, body)
SELECT id, subject, body FROM emails;
"""


# --- Functions ---
def download_dataset(repo_id: str) -> Dataset:
    """Downloads the dataset from Hugging Face Hub."""
    logging.info(
        "Attempting to download dataset from Hugging Face Hub: %s",
        repo_id,
    )
    expected_features = Features(
        {
            "message_id": Value("string"),
            "subject": Value("string"),
            "from": Value("string"),
            "to": Sequence(Value("string")),
            "cc": Sequence(Value("string")),
            "bcc": Sequence(Value("string")),
            "date": Value("timestamp[us]"),
            "body": Value("string"),
            "file_name": Value("string"),
        },
    )
    dataset_obj = load_dataset(
        repo_id,
        features=expected_features,
        split="train",
    )
    # Basic type check remains useful
    if not isinstance(dataset_obj, Dataset):
        raise TypeError(f"Expected Dataset, got {type(dataset_obj)}")
    logging.info(
        "Successfully loaded dataset '%s' with %d records.",
        repo_id,
        len(dataset_obj),
    )
    return dataset_obj


def create_database(db_path: str) -> None:
    """Creates the SQLite database and tables."""
    logging.info("Creating SQLite database and tables at: %s", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_TABLES)
    conn.commit()
    conn.close()
    logging.info("Database tables created successfully.")


def _should_skip_email(
    body: str,
    message_id: str,
    to_list: list[str],
    cc_list: list[str],
    bcc_list: list[str],
) -> bool:
    """Check if email should be skipped based on filters."""
    if len(body) > 5000:
        logging.debug(
            "Skipping email %s: Body length > 5000 characters.",
            message_id,
        )
        return True
    total_recipients = len(to_list) + len(cc_list) + len(bcc_list)
    if total_recipients > 30:
        logging.debug(
            "Skipping email %s: Total recipients (%d) > 30.",
            message_id,
            total_recipients,
        )
        return True
    return False


def _prepare_recipient_data(
    email_pk_id: int,
    to_list: list[str],
    cc_list: list[str],
    bcc_list: list[str],
) -> list[tuple[int, str, str]]:
    """Prepare recipient data for database insertion."""
    recipient_data = []
    for addr in to_list:
        recipient_data.append((email_pk_id, addr, "to"))
    for addr in cc_list:
        recipient_data.append((email_pk_id, addr, "cc"))
    for addr in bcc_list:
        recipient_data.append((email_pk_id, addr, "bcc"))
    return recipient_data


def populate_database(db_path: str, dataset: Dataset) -> None:
    """Populates the database with data from the Hugging Face dataset."""
    logging.info("Populating database %s...", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    # --- Performance Pragmas ---
    conn.execute("PRAGMA synchronous = OFF;")
    conn.execute("PRAGMA journal_mode = MEMORY;")
    record_count = 0
    skipped_count = 0
    duplicate_count = 0
    processed_emails = set()
    conn.execute("BEGIN TRANSACTION;")
    for email_data in tqdm(dataset, desc="Inserting emails"):
        assert isinstance(email_data, dict)
        message_id = email_data["message_id"]
        subject = email_data["subject"]
        from_address = email_data["from"]
        date_obj: datetime = email_data["date"]
        body = email_data["body"]
        file_name = email_data["file_name"]
        to_list_raw = email_data["to"]
        cc_list_raw = email_data["cc"]
        bcc_list_raw = email_data["bcc"]
        date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")
        to_list = [str(addr) for addr in to_list_raw if addr]
        cc_list = [str(addr) for addr in cc_list_raw if addr]
        bcc_list = [str(addr) for addr in bcc_list_raw if addr]
        if _should_skip_email(body, message_id, to_list, cc_list, bcc_list):
            skipped_count += 1
            continue
        email_key = (subject, body, from_address)
        if email_key in processed_emails:
            logging.debug(
                "Skipping duplicate email (Subject: %s..., From: %s)",
                subject[:50],
                from_address,
            )
            duplicate_count += 1
            continue
        processed_emails.add(email_key)
        cursor.execute(
            """
            INSERT INTO emails (
                message_id, subject, from_address, date, body, file_name
            )
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (message_id, subject, from_address, date_str, body, file_name),
        )
        email_pk_id = cursor.lastrowid
        if email_pk_id is None:
            logging.warning(
                "Failed to get email ID after insert for message_id: %s",
                message_id,
            )
            continue
        recipient_data = _prepare_recipient_data(
            email_pk_id,
            to_list,
            cc_list,
            bcc_list,
        )
        if recipient_data:
            cursor.executemany(
                """
                INSERT INTO recipients (
                    email_id, recipient_address, recipient_type
                )
                VALUES (?, ?, ?)
                """,
                recipient_data,
            )
        record_count += 1
    conn.commit()
    conn.close()
    logging.info("Successfully inserted %d email records.", record_count)
    if skipped_count > 0:
        logging.info(
            "Skipped %d email records due to length or recipient limits.",
            skipped_count,
        )
    if duplicate_count > 0:
        logging.info(
            "Skipped %d duplicate email records "
            "(based on subject, body, from).",
            duplicate_count,
        )


def create_indexes_and_triggers(db_path: str) -> None:
    """Creates indexes and triggers on the populated database."""
    logging.info("Creating indexes and triggers for database: %s...", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
    conn.commit()
    conn.close()
    logging.info("Indexes and triggers created successfully.")


def generate_database(
    repo_id: str = DEFAULT_REPO_ID,
    db_path: str = DEFAULT_DB_PATH,
    overwrite: bool = False,
) -> None:
    """
    Generates the SQLite database from the specified Hugging Face dataset.
    Simplified version without extensive error handling.

    Args:
        repo_id: The Hugging Face repository ID for the dataset.
        db_path: The path where the SQLite database file should be
            created.
        overwrite: If True, any existing database file at db_path will
            be removed.
    """
    logging.info(
        "Starting database generation for repo '%s' at '%s'",
        repo_id,
        db_path,
    )
    logging.info("Overwrite existing database: %s", overwrite)
    db_dir = os.path.dirname(db_path)
    if db_dir and not os.path.exists(db_dir):
        logging.info("Creating data directory: %s", db_dir)
        os.makedirs(db_dir)
    if overwrite and os.path.exists(db_path):
        logging.warning("Removing existing database file: %s", db_path)
        os.remove(db_path)
    elif not overwrite and os.path.exists(db_path):
        # If not overwriting and the file exists, assume it was already
        # generated and skip regeneration.
        logging.warning(
            "Database file %s exists and overwrite is False. "
            "Assuming file is already generated.",
            db_path,
        )
        return
    # 1. Download dataset
    dataset = download_dataset(repo_id)
    # 2. Create database schema (Tables only)
    # Note: This will fail if overwrite=False and the file exists with
    # incompatible schema/data.
    create_database(db_path)
    # 3. Populate database
    populate_database(db_path, dataset)
    # 4. Create Indexes and Triggers
    create_indexes_and_triggers(db_path)
    logging.info("Database generation process completed for %s.", db_path)
    logging.info(
        "Please set the environment variable DEFAULT_EMAIL_DB_PATH "
        "to this path.",
    )


if __name__ == "__main__":
    generate_database(overwrite=True)

(binary image added, 442 KiB)

tuner/frozen_lake/README.md (new file, 271 lines)

@@ -0,0 +1,271 @@
# Training FrozenLake Agent with RL using AgentScope-Tuner
## Summary
This example demonstrates how to use AgentScope-Tuner to implement reinforcement fine-tuning for the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task using [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT). The agent learns to navigate a frozen lake grid from a starting position to a goal while avoiding holes through multi-step interactions with the environment.
## Task Setting
### Agent Goal
The agent's objective is to navigate from the starting position (S) to the goal position (G) on a frozen lake grid without falling into holes (H). The agent must:
- Plan a path through frozen tiles (F) to reach the goal
- Avoid holes that terminate the episode with zero reward
- Complete the task within a limited number of steps
### Agent Type
The agent is implemented as a **ReActAgent** (Reasoning and Acting Agent) that:
- Observes the current state of the frozen lake grid
- Reasons about the best action to take
- Executes actions (Up, Down, Left, Right) to move through the environment
- Maintains internal state across multiple steps in an episode
### Environment
The environment is based on Gymnasium's FrozenLake environment, wrapped to provide:
- **Grid-based navigation**: Randomly generated maps with configurable size (2x2 to 6x6)
- **Tile types**:
- `S`: Start position
- `F`: Frozen tile (safe to walk on)
- `H`: Hole (terminates episode with reward 0)
- `G`: Goal (terminates episode with reward +1.0)
- **Action space**: Discrete actions (Up, Down, Left, Right)
- **Reward structure**:
- +1.0 for reaching the goal
- 0.0 for falling into a hole or failing to reach the goal
- **Observations**: Text-based grid representation showing current player position
The agent does not use external tools. It interacts directly with the environment through:
- `env.reset(task)`: Initialize environment with task parameters
- `env.step(action)`: Execute action and receive observation, reward, and done flag
- `env.render()`: Get text representation of current state
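A minimal sketch of driving this interface directly (the module name and the string action encoding are assumptions for illustration; the real loop lives in the workflow shown later):
```python
from frozenlake_env import FrozenLakeEnv  # assumed module name

env = FrozenLakeEnv()
observation, _ = env.reset({"seed": 42, "size": 4, "p": 0.75})
done = False
while not done:
    print(env.render())  # text grid showing the current player position
    action = "Right"     # a real agent chooses this from the observation
    observation, reward, done, _ = env.step(action)
```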
## Dataset Preparation
The dataset contains task parameters for generating FrozenLake environments. Each sample specifies:
- `seed`: Random seed for reproducible map generation
- `size`: Grid size (randomly sampled from 2 to `map_max_size`, e.g., 4x4, 6x6)
- `p`: Probability that a tile is frozen (vs. being a hole), randomly sampled from 0.6 to 0.85
- `index`: Sample index
- `uid`: Unique identifier combining seed, size, and p
Run the data preparation script to generate training and test datasets:
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
This will create parquet files in the specified directory:
```
/path/to/frozenlake_dataset/
├── train.parquet # 10000 training samples
└── test.parquet # 100 test samples
```
Each sample looks like:
```json
{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
```
**Note**: The data preparation script ensures that all generated maps have a valid path from start to goal within the maximum allowed steps (`env_max_steps=8`), filtering out unsolvable tasks.
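For illustration, such a solvability check can be implemented with a bounded breadth-first search over the generated grid; the sketch below conveys the idea and is not the script's exact code:
```python
from collections import deque

def is_solvable(grid: list[str], max_steps: int) -> bool:
    """Return True if 'G' is reachable from 'S' within max_steps moves."""
    n = len(grid)
    start = next(
        (r, c) for r in range(n) for c in range(n) if grid[r][c] == "S"
    )
    queue, seen = deque([(start, 0)]), {start}
    while queue:
        (r, c), dist = queue.popleft()
        if grid[r][c] == "G":
            return True
        if dist == max_steps:
            continue  # do not expand beyond the step budget
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            # Stay on the board and never step onto a hole
            if 0 <= nr < n and 0 <= nc < n and grid[nr][nc] != "H" \
                    and (nr, nc) not in seen:
                seen.add((nr, nc))
                queue.append(((nr, nc), dist + 1))
    return False

# Example: a 4x4 map solvable within env_max_steps=8
print(is_solvable(["SFFF", "FHFH", "FFFH", "HFFG"], max_steps=8))  # True
```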
## Code Implementation
This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.
### High-level Overview
The implementation consists of three main components:
1. **Agent** (`FrozenLakeAgent`): Extends `ReActAgent` to handle multi-step navigation
2. **Environment** (`FrozenLakeEnv`): Wraps Gymnasium's FrozenLake environment
3. **Workflow** (`run_frozen_lake`): Orchestrates the agent-environment interaction loop
### Agent Workflow
The workflow function `run_frozen_lake` implements the agent-environment interaction loop:
```python
async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # ...
    # Create agent and environment
    agent = FrozenLakeAgent(model=model, ...)
    env = FrozenLakeEnv(...)
    observation, _ = env.reset(task)
    rewards = []
    # ...
    # Agent-environment interaction loop
    for _ in range(max_steps):
        response = await agent.reply(
            msg=Msg("user", agent.get_prompt(observation), role="user"),
        )
        action = agent.get_action(response)
        observation, reward, done, _ = env.step(action)
        rewards.append(reward)
        if done:
            break
    # ...
    final_reward = sum(rewards)
    final_response = Msg("assistant", response_content, role="assistant")
    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={
            "env_steps": float(step_count),
            "env_done": float(done),
        },
    )
```
**Key characteristics:**
- Multi-step interaction: The agent takes multiple actions in a single episode, unlike single-turn QA tasks
- State tracking: The agent maintains internal state (current step, last action, last observation) across steps
- Error handling: Invalid actions or agent errors are caught and handled gracefully
### Reward Function
No separate judge function is needed. The reward comes directly from the environment:
- 1.0: Agent successfully reaches the goal (G)
- 0.0: Agent falls into a hole (H) or fails to reach the goal within the maximum steps
The reward is computed as the sum of step rewards throughout the episode. The workflow returns:
- `reward`: Final cumulative reward
- `response`: Final response message containing observation, total reward, steps taken, and termination reason
- `metrics`: Additional metrics including `env_steps` (number of steps taken) and `env_done` (whether episode completed)
### Implementation Details
The environment (`FrozenLakeEnv`) wraps Gymnasium's FrozenLake and provides:
- `reset(task)`: Initialize the environment with task parameters
- `step(action)`: Execute an action and return (observation, reward, done, info)
- `render()`: Return a text representation of the current state
The agent (`FrozenLakeAgent`) extends `ReActAgent` and provides:
- `reply(msg)`: Reply to a message and return an action (inherited from AgentScope)
- `get_prompt(observation)`: Generate a prompt from the current observation
- `get_action(response)`: Parse the model's response to extract an action (Up/Down/Left/Right)
- `update_state(action, observation)`: Update internal state after each step
See [_frozenlake_env.py](./_frozenlake_env.py) and [_frozenlake_agent.py](./_frozenlake_agent.py) for implementation details.
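As a rough usage sketch, one iteration of the loop inside `run_frozen_lake` ties these methods together (`model`, `observation`, and `env` come from the surrounding workflow; note that `get_action` returns a lowercase action string, or `"still"` when the response cannot be parsed):
```python
# One step of the agent-environment loop (sketch, not standalone code).
agent = FrozenLakeAgent(model=model, max_steps=10)
prompt = agent.get_prompt(observation)  # text grid plus instructions
response = await agent.reply(msg=Msg("user", prompt, role="user"))
action = agent.get_action(response)  # e.g. "up"
agent.update_state(action=action, observation=observation)
observation, reward, done, info = env.step(action)
```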
### Step 4: Use `tune` to train the workflow
```python
from agentscope.tuner import tune, DatasetConfig
if __name__ == "__main__":
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
)
dataset = DatasetConfig(
path="/path/to/frozenlake_dataset",
name="default",
split="train",
)
tune(
workflow_func=run_frozen_lake,
train_dataset=dataset,
config_path=config_path,
)
```
See [config.yaml](./config.yaml) for the training configuration. For full configuration details, see [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
---
## How to Run
### Prerequisites
- At least 2 NVIDIA GPUs with CUDA 12.8 or newer
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
- Install gymnasium for the FrozenLake environment:
```bash
pip install "gymnasium[toy_text]"
```
- Download the model checkpoint (example):
```bash
huggingface-cli download Qwen/Qwen2.5-3B-Instruct
```
### Step 1: Prepare the Dataset
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
Update the dataset path in `main.py` to point to your generated dataset directory.
### Step 2: Configure the Training
Key configuration options are set in the code, including:
**Algorithm Configuration** (`AlgorithmConfig`):
- `algorithm_type`: `multi_step_grpo` (Group Relative Policy Optimization for multi-step tasks)
- `group_size`: Number of rollouts per task used for group-relative advantage estimation (default: 16)
- `batch_size`: Batch size for training (default: 32)
- `learning_rate`: Learning rate (default: 1e-6)
**Model Configuration** (`TunerModelConfig`):
- `model_path`: Path to the base model (e.g., `Qwen/Qwen2.5-3B-Instruct`)
- `max_model_len`: Maximum model context length (default: 25600)
- `max_tokens`: Maximum tokens for response generation (default: 2048)
- `inference_engine_num`: Number of inference engines (default: 6, using 6 GPUs for inference)
**Dataset Configuration** (`DatasetConfig`):
- `path`: Path to the dataset (default: `/path/to/frozenlake`)
- `split`: Split of the dataset (default: `train`)
Adjust these parameters based on your hardware resources and training requirements. Other parameters can be specified in [config.yaml](./config.yaml).
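For reference, a condensed sketch of how [main.py](./main.py) passes these options to `tune` (with `run_frozen_lake` defined as above and `config.yaml` providing the remaining defaults):
```python
from agentscope.tuner import (
    tune,
    DatasetConfig,
    TunerModelConfig,
    AlgorithmConfig,
)

dataset = DatasetConfig(path="/path/to/frozenlake", split="train")
tuner_model = TunerModelConfig(
    model_path="Qwen/Qwen2.5-3B-Instruct",
    max_model_len=25600,
    max_tokens=2048,
    inference_engine_num=6,
)
algorithm = AlgorithmConfig(
    algorithm_type="multi_step_grpo",
    group_size=16,  # rollouts per task for group-relative advantages
    batch_size=32,
    learning_rate=1e-6,
)
tune(
    workflow_func=run_frozen_lake,
    model=tuner_model,
    train_dataset=dataset,
    algorithm=algorithm,
    config_path="config.yaml",
)
```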
### Step 3: Set Up Ray Cluster
Set up a [Ray](https://github.com/ray-project/ray) cluster:
```bash
ray start --head
# for multi-node setup, run the following command on worker nodes
# ray start --address=<master_address>
```
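Once the head node is up, `ray status` should list the expected nodes and GPUs before you launch training.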
### Step 4: Run the Training Script
```bash
python main.py
```
Training will start, and you can monitor progress through the logs. Checkpoints are saved once every `trainer.save_interval` steps.
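As in the GSM8K example, checkpoints and logs land under `checkpoints/AgentScope` in a timestamped sub-directory, with TensorBoard logs inside its `monitor/tensorboard` folder.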
## Experimental Results
### Training Reward Curve
The reward curve during training shows the agent's learning progress:
![reward](./critic_rewards_mean.png)
The training reward typically increases over epochs as the agent learns to navigate the frozen lake more effectively.
### Example Agent Output
An example of agent output is given below:
```
From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.
The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.
Let's take the action ```Left```.
```
View File
@@ -0,0 +1,250 @@
# Training a FrozenLake Agent with AgentScope-Tuner
## Summary
This example demonstrates how to use AgentScope-Tuner with [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) for reinforcement fine-tuning on the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task. The agent must walk from the start to the goal on a frozen lake grid, avoid holes, and finish within a limited number of steps.
## Task Setting
### Agent Goal
The agent must travel from the start (S) to the goal (G) on the frozen lake grid, while:
- Planning a path across frozen tiles (F) to reach the goal
- Avoiding holes (H), which end the episode with a reward of 0
- Completing the task within the step limit
### Agent Type
The agent is implemented as a **ReActAgent**, which:
- Observes the current grid state
- Reasons about the best next action
- Executes an action (up, down, left, or right) to move in the environment
- Maintains internal state across multi-step interactions
### Environment
The environment is built on Gymnasium's FrozenLake and provides:
- **Grid navigation**: randomly generated maps from 2x2 up to 6x6
- **Tile types**:
  - `S`: start
  - `F`: frozen tile (walkable)
  - `H`: hole (reward 0, episode ends)
  - `G`: goal (reward +1.0, episode ends)
- **Action space**: discrete actions (up, down, left, right)
- **Reward design**:
  - +1.0 for reaching the goal
  - 0.0 for falling into a hole or failing to reach the goal within the maximum steps
- **Observation**: a text grid rendering of the current player position
The agent uses no external tools and interacts with the environment directly through the interfaces below (see the sketch after this list):
- `env.reset(task)`: initialize the environment with task parameters
- `env.step(action)`: execute an action and return the observation, reward, and done flag
- `env.render()`: return a text representation of the current state
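A minimal sketch of driving the environment by hand (assuming this example's `_frozenlake_env` module is importable and gymnasium is installed; a random policy stands in for the agent):
```python
import random

from _frozenlake_env import FrozenLakeEnv

# Build a small deterministic map and walk it with random actions.
env = FrozenLakeEnv(max_steps=8, size=4, p=0.8, seed=7)
observation, _ = env.reset({"size": 4, "p": 0.8, "seed": 7})
for _ in range(8):  # cap at the episode step limit
    action = random.choice(["Up", "Down", "Left", "Right"])
    observation, reward, done, info = env.step(action)
    print(observation, reward, info)
    if done:
        break
```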
## Dataset Preparation
The dataset contains the task parameters used to generate FrozenLake environments. Each sample includes:
- `seed`: random seed that makes the map reproducible
- `size`: grid size (random between 2 and `map_max_size`, e.g. 4x4, 6x6)
- `p`: probability that a tile is frozen (random between 0.6 and 0.85); the remaining tiles are holes
- `index`: sample index
- `uid`: unique ID composed of seed, size, and p
Run the data preparation script to generate the training and test sets:
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
Example of the generated directory layout:
```
/path/to/frozenlake_dataset/
├── train.parquet  # 10000 training samples
└── test.parquet   # 100 test samples
```
A sample looks like:
```json
{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
```
**Note**: the script filters out unsolvable maps, guaranteeing a valid path from start to goal within the maximum allowed steps (`env_max_steps=8`).
## Code Implementation
This section gives a high-level overview of the implementation. For details, please refer to the source code.
### High-level Overview
The implementation consists of three components:
1. **Agent** (`FrozenLakeAgent`): extends `ReActAgent` to handle multi-step interaction
2. **Environment** (`FrozenLakeEnv`): wraps Gymnasium's FrozenLake
3. **Workflow** (`run_frozen_lake`): orchestrates the agent-environment interaction
### Workflow
`run_frozen_lake` implements the multi-step interaction loop:
```python
async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # ...
    # Create the agent and environment
    agent = FrozenLakeAgent(model=model, ...)
    env = FrozenLakeEnv(...)
    observation, _ = env.reset(task)
    rewards = []
    # ...
    # Agent-environment interaction loop
    for _ in range(max_steps):
        response = await agent.reply(msg=Msg("user", agent.get_prompt(observation), role="user"))
        action = agent.get_action(response)
        observation, reward, done, _ = env.step(action)
        rewards.append(reward)
        if done:
            break
    # ...
    final_reward = sum(rewards)
    final_response = Msg("assistant", response_content, role="assistant")
    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={"env_steps": float(step_count), "env_done": float(done)},
    )
```
**Key characteristics:**
- Multi-step interaction: the agent takes multiple actions within a single episode, unlike single-turn QA tasks
- State tracking: the agent records the current step, last action, and last observation across steps
- Error handling: invalid actions and agent exceptions are caught and handled gracefully
### Reward Function
No separate judge is needed; the reward comes directly from the environment:
- 1.0: the agent reaches the goal
- 0.0: the agent falls into a hole or fails to reach the goal within the step limit
The workflow returns:
- `reward`: the cumulative reward
- `response`: the final message containing the observation, total reward, steps taken, and termination reason
- `metrics`: `env_steps` (number of steps taken) and `env_done` (whether the episode finished)
### Implementation Details
The environment (`FrozenLakeEnv`) wraps Gymnasium's FrozenLake and provides:
- `reset(task)`: initialize the environment with task parameters
- `step(action)`: execute an action and return (observation, reward, done, info)
- `render()`: return a text representation of the current state
The agent (`FrozenLakeAgent`) extends `ReActAgent` and provides:
- `reply(msg)`: reply to a message and return an action (inherited from AgentScope)
- `get_prompt(observation)`: generate a prompt from the current observation
- `get_action(response)`: parse the model's response to extract an action (Up/Down/Left/Right)
- `update_state(action, observation)`: update internal state after each step
See [_frozenlake_env.py](./_frozenlake_env.py) and [_frozenlake_agent.py](./_frozenlake_agent.py) for implementation details.
### 步骤 4使用 `tune` 训练工作流
```python
from agentscope.tuner import tune, DatasetConfig
if __name__ == "__main__":
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
)
dataset = DatasetConfig(
path="/path/to/frozenlake_dataset",
name="default",
split="train",
)
tune(
workflow_func=run_frozen_lake,
train_dataset=dataset,
config_path=config_path,
)
```
训练配置请参考 [config.yaml](./config.yaml)。完整配置详情请参考 [Trinity-RFT 配置指南](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)。
---
## How to Run
### Prerequisites
- At least 2 NVIDIA GPUs with CUDA 12.8 or newer
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
- Install gymnasium for the FrozenLake environment:
```bash
pip install "gymnasium[toy_text]"
```
- Download the model checkpoint (example):
```bash
huggingface-cli download Qwen/Qwen2.5-3B-Instruct
```
### Step 1: Prepare the Dataset
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
Update the dataset path in `main.py` to point to your generated dataset directory.
### Step 2: Configure the Training
Key configuration options are set in the code, including:
**Algorithm Configuration** (`AlgorithmConfig`):
- `algorithm_type`: `multi_step_grpo` (Group Relative Policy Optimization for multi-step tasks)
- `group_size`: number of rollouts per task used for group-relative advantage estimation (default: 16)
- `batch_size`: batch size for training (default: 32)
- `learning_rate`: learning rate (default: 1e-6)
**Model Configuration** (`TunerModelConfig`):
- `model_path`: path to the base model (e.g. `Qwen/Qwen2.5-3B-Instruct`)
- `max_model_len`: maximum model context length (default: 25600)
- `max_tokens`: maximum tokens for response generation (default: 2048)
- `inference_engine_num`: number of inference engines (default: 6, i.e. 6 GPUs for inference)
**Dataset Configuration** (`DatasetConfig`):
- `path`: path to the dataset (default: `/path/to/frozenlake`)
- `split`: dataset split (default: `train`)
Adjust these parameters based on your hardware resources and training requirements. Other parameters can be specified in [config.yaml](./config.yaml).
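Judging from the shipped [config.yaml](./config.yaml), these code-level options correspond to YAML fields as follows: `group_size` to `algorithm.repeat_times`, `learning_rate` to `algorithm.optimizer.lr`, `batch_size` to `buffer.batch_size`, `model_path` and `max_model_len` to the `model` section, and `inference_engine_num` to `explorer.rollout_model.engine_num`.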
### 步骤 3设置 Ray 集群
设置 [Ray](https://github.com/ray-project/ray) 集群:
```bash
ray start --head
# 对于多节点设置,在工作节点上运行以下命令
# ray start --address=<master_address>
```
### 步骤 4运行训练脚本
```bash
python main.py
```
训练将开始,可通过日志监控进度。检查点将每 `trainer.save_interval` 步保存一次。
## Experimental Results
### Training Reward Curve
The reward curve during training shows the agent's learning progress:
![reward](./critic_rewards_mean.png)
The training reward typically increases over epochs as the agent learns to navigate the frozen lake more effectively.
### Example Agent Output
An example of agent output is given below:
```
From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.
The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.
Let's take the action ```Left```.
```
View File
@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import re
from _utils import SYSTEM_PROMPT, FrozenLakeAction # pylint: disable=E0611
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
INVALID_ACTION = "still"
VALID_ACTIONS = {
"left": 1,
"down": 2,
"right": 3,
"up": 4,
}
class FrozenLakeAgent(ReActAgent):
"""Agent for FrozenLake environment."""
def __init__(self, model: OpenAIChatModel, max_steps: int = 20):
super().__init__(
name="frozenlake_agent",
model=model,
sys_prompt=SYSTEM_PROMPT,
formatter=OpenAIChatFormatter(),
max_iters=1,
)
self.response_structure = FrozenLakeAction
self.current_step = 0
self.last_action = None
self.last_observation = None
self.max_steps = max_steps
def get_prompt(self, observation: str) -> str:
"""Get prompt for the agent based on current observation."""
prompt = (
f"Current Observation ({self.current_step}): \n"
+ observation
+ "\n"
+ (
"You have not achieved the goal, P has not reached G yet. "
"Please give the next action."
)
)
if self.current_step > 0 and self.last_action is not None:
if self.last_observation == observation:
prompt += (
"\nYour last response is invalid. "
"Your position didn't change at all. "
"You may need to recheck your thinking process, "
"action outputted, and the format of response. "
"Remember, you should only output the NEXT ACTION "
"at each iteration in the ``` ```. "
"For example, if you want to move up, "
"you should output ```Up```."
)
if (
self.max_steps is not None
and self.max_steps - self.current_step > 0
):
remaining = self.max_steps - self.current_step
prompt += (
f"\nThe maximum number of steps remaining is {remaining}."
)
return prompt
def get_action(self, msg: Msg) -> str:
"""Extract action from agent response message."""
response: str = (
msg.content
if isinstance(msg.content, str)
else msg.content[0].get("text")
)
action = INVALID_ACTION
matches = re.findall(r"```(.*?)```", response, re.DOTALL)
if matches:
last_match_content = matches[-1].strip()
action = last_match_content.lower()
if action not in VALID_ACTIONS:
action = INVALID_ACTION
return action
def update_state(self, action: str, observation: str) -> None:
"""Update agent state with action and observation."""
self.last_action = action
self.last_observation = observation
self.current_step += 1
async def reset(self) -> None:
"""Reset agent state for a new episode."""
self.current_step = 0
self.last_action = None
self.last_observation = None
await self.memory.clear()
View File
@@ -0,0 +1,316 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import copy
from typing import Dict, Optional, Tuple, Union
import numpy as np
try:
from gymnasium.envs.toy_text.frozen_lake import (
FrozenLakeEnv as GymFrozenLakeEnv,
)
except ImportError:
GymFrozenLakeEnv = object
from _utils import ( # pylint: disable=E0611
generate_random_map,
get_goal_position,
)
class FrozenLakeEnv(GymFrozenLakeEnv):
"""FrozenLake environment wrapper."""
# Map gym state in integer
MAP_LOOKUP = {
b"P": 0,
b"F": 1,
b"H": 2,
b"G": 3,
}
# Define rules to transform to rendered text observation of the environment
GRID_LOOKUP = {
0: " P \t", # player
1: " _ \t", # frozen
2: " O \t", # hole
3: " G \t", # goal
4: " X \t", # player fall into hole
5: "\t", # player on goal
}
ACTION_LOOKUP = {
"still": 0,
"left": 1,
"down": 2,
"right": 3,
"up": 4,
}
INVALID_ACTION = 0
PENALTY_FOR_INVALID = -1
def __init__(
self,
max_steps: int = 8,
desc: Optional[str] = None,
is_slippery: bool = False,
size: int = 8,
p: float = 0.8,
seed: int = 42,
):
self.max_steps = max_steps or 8
self.desc: Union[str, np.ndarray, None] = desc
self.is_slippery = is_slippery
self.size = size
self.p = p
self.seed = seed
self.render_mode: Optional[str] = None
try:
import gymnasium as gym
except ImportError as e:
error_message = (
"Gymnasium is not installed. "
"Please install gymnasium first before "
"running the frozen_lake workflow. "
f"Error: {str(e)}"
)
raise ImportError(error_message) from e
if self.desc is None:
random_map, goal_position = generate_random_map(
size=self.size,
p=self.p,
seed=self.seed,
max_steps=self.max_steps,
)
else:
random_map = np.asarray(copy.deepcopy(self.desc), dtype="c")
goal_position = get_goal_position(random_map)
self.goal_position = goal_position
super().__init__(
desc=random_map[:],
is_slippery=self.is_slippery,
)
assert isinstance(self.desc, np.ndarray)
self.action_space = gym.spaces.Discrete(4, start=1)
self.map_kwargs = {
"size": size,
"p": p,
}
self.env_kwargs = {
"is_slippery": is_slippery,
"desc": copy.deepcopy(desc),
"seed": seed,
}
self.action_map = {
1: 0, # left
2: 1, # down
3: 2, # right
4: 3, # up
}
def _get_player_position(self) -> Tuple[int, int]:
return (self.s // self.ncol, self.s % self.ncol) # (row, col)
def step(self, action: str) -> Tuple[str, float, bool, Dict]:
"""Execute a step in the environment.
Maps custom action to gymnasium FrozenLakeEnv action and
takes the step. Checks if the action is effective (whether
player moves in the env).
Args:
action: The action to take.
Returns:
Tuple of (observation, reward, done, info).
"""
if self.success():
obs = self.render(mode="tiny_rgb_array")
assert isinstance(obs, str)
return obs, 1.0, True, {"action_is_effective": False}
action_id: int = self.ACTION_LOOKUP.get(action.lower(), 0)
if not action_id:
action_id = self.INVALID_ACTION
if (
action_id == self.INVALID_ACTION
or action_id not in self.action_map
):
obs = self.render(mode="tiny_rgb_array")
assert isinstance(obs, str)
return obs, 0.0, False, {"action_is_effective": False}
prev_player_position = int(self.s)
# Call parent class step method
# Note: GymFrozenLakeEnv is imported at module level
player_pos, reward, done, _, _ = super().step(
self.action_map[action_id],
)
obs = self.render(mode="tiny_rgb_array")
assert isinstance(obs, str)
return (
obs,
float(reward),
bool(done),
{"action_is_effective": prev_player_position != int(player_pos)},
)
def render(
self,
mode: str = "tiny_rgb_array",
) -> str | list[str] | np.ndarray:
"""Render the environment.
Args:
mode: Rendering mode. Options: "tiny_rgb_array", "list",
"state", "rgb_array", "ansi".
Returns:
Rendered observation based on the mode.
"""
assert mode in [
"tiny_rgb_array",
"list",
"state",
"rgb_array",
"ansi",
]
if mode in ["rgb_array", "ansi"]:
prev_render_mode = self.render_mode
self.render_mode = mode
obs = super().render()
self.render_mode = prev_render_mode
return obs
assert isinstance(self.desc, np.ndarray)
room_state = copy.deepcopy(self.desc)
# replace the position of start 'S' with 'F'
position_S = np.where(room_state == b"S")
room_state[position_S] = b"F"
# replace the position of the player with 'P'
position_P = self._get_player_position()
room_state[position_P] = b"P"
if mode == "state":
# transform 'S', 'F', 'H', 'G' to numpy integer array
room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state)
# add player in hole or player on goal
if self.desc[position_P] == b"H":
room_state[position_P] = 4
elif self.desc[position_P] == b"G":
room_state[position_P] = 5
return room_state
room_state = self.render(mode="state").tolist()
assert isinstance(room_state, list)
if mode == "list":
def lookup_list(cell: int) -> str:
return self.GRID_LOOKUP.get(cell, "?").strip("\t").strip()
return [
" ".join(lookup_list(cell) for cell in row)
for row in room_state
]
if mode == "tiny_rgb_array":
def lookup_tiny(cell: int) -> str:
return self.GRID_LOOKUP.get(cell, "?")
result = "\n".join(
"".join(lookup_tiny(cell) for cell in row)
for row in room_state
)
return result
# Default return for other modes
return ""
def reset(
self,
task: Optional[Dict] = None,
) -> tuple[str, Dict]:
"""Reset the environment with optional task parameters."""
task = task or {}
# Update parameters from task if provided
size = task.get("size", self.map_kwargs["size"])
p = task.get("p", self.map_kwargs["p"])
seed = task.get("seed", self.env_kwargs["seed"])
is_slippery = task.get(
"is_slippery",
self.env_kwargs["is_slippery"],
)
desc = task.get("desc", self.env_kwargs.get("desc"))
# Update instance variables
self.size = size
self.p = p
self.seed = seed
self.is_slippery = is_slippery
self.map_kwargs["size"] = size
self.map_kwargs["p"] = p
self.env_kwargs["seed"] = seed
self.env_kwargs["is_slippery"] = is_slippery
if desc is not None:
self.env_kwargs["desc"] = copy.deepcopy(desc)
if desc is None:
random_map, goal_position = generate_random_map(
size=size,
p=p,
seed=seed,
max_steps=self.max_steps,
)
else:
random_map = np.asarray(copy.deepcopy(desc), dtype="c")
goal_position = get_goal_position(random_map)
self.goal_position = goal_position
self.desc = random_map[:]
# Reinitialize parent class with new map
try:
import gymnasium as gym
super().__init__(
desc=random_map[:],
is_slippery=self.is_slippery,
)
assert isinstance(self.desc, np.ndarray)
self.action_space = gym.spaces.Discrete(4, start=1)
except ImportError as e:
error_message = (
"Gymnasium is not installed. "
"Please install gymnasium first before "
"running the frozen_lake workflow. "
f"Error: {str(e)}"
)
raise ImportError(error_message) from e
super().reset(seed=self.seed)
obs = self.render(mode="tiny_rgb_array")
assert isinstance(obs, str)
return obs, {}
def finished(self) -> bool:
"""Check if the episode is finished (goal or hole)."""
player_pos = self._get_player_position()
assert isinstance(self.desc, np.ndarray)
return self.desc[player_pos] in b"GH" # type: ignore
def success(self) -> bool:
"""Check if the agent has reached the goal (G)."""
player_pos = self._get_player_position()
assert isinstance(self.desc, np.ndarray)
return self.desc[player_pos] in b"G"
tuner/frozen_lake/_utils.py Normal file
View File
@@ -0,0 +1,209 @@
# -*- coding: utf-8 -*-
"""
Utils for the FrozenLake environment.
Modified from rllm
"""
from typing import Literal, Optional, Tuple
import numpy as np
from pydantic import BaseModel, Field
# Map gym state in integer
MAP_LOOKUP = {
b"P": 0,
b"F": 1,
b"H": 2,
b"G": 3,
}
# Define rules to transform to rendered text observation of the environment
GRID_LOOKUP = {
0: " P \t", # player
1: " _ \t", # frozen
2: " O \t", # hole
3: " G \t", # goal
4: " X \t", # player fall into hole
5: "\t", # player on goal
}
ACTION_LOOKUP = {
0: "None",
1: "Left",
2: "Down",
3: "Right",
4: "Up",
}
# Prompting format inspired by the RAGEN project
SYSTEM_PROMPT = """You are Qwen, created by Alibaba Cloud. \
You are a helpful assistant. You are walking on a frozen lake.
FrozenLake Quick Guide
Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.
Symbols:
_ Frozen | O Hole | G Goal | P Player
Rules:
1. Avoid falling into holes (O).
2. Frozen tiles are slippery, you may move perpendicular to
your intended direction.
Valid Action (separated by | ):
Up | Down | Left | Right
Rewards:
Fall into hole: 0
Reach goal: +1.0
You will be provided the current observation, please decide on
the next Action.
You should show your thought process and then input the final
action in ``` ```.
You should only output the NEXT ACTION at each iteration in
the ``` ```. For example, if you want to move up, you should
output ```Up```.
You should plan ahead and need to achieve it in minimum number
of steps.
You should be aware that frozen tiles can be slippery, but the
chance is small and you should not overthink it.
Please show your thinking process and put the final action in
``` ```. In every turn, the final action MUST be one of Up,
Down, Left, Right.
"""
class FrozenLakeAction(BaseModel):
"""Action model for FrozenLake environment."""
action: Literal["Up", "Down", "Left", "Right"] = Field(
description=(
"The action to take in the FrozenLake environment, "
"must be one of Up, Down, Left, Right"
),
)
def is_valid(board: list[list[str]], max_size: int, max_steps: int) -> bool:
"""DFS to check that it's a valid path.
Args:
board: The board representation as a list of lists.
max_size: Maximum size of the board.
max_steps: Maximum number of steps allowed.
Returns:
True if there's a valid path from start to goal within max_steps,
False otherwise.
"""
frontier, discovered = [], set()
# find the start point
start_r, start_c = np.where(np.array(board) == "S")
    frontier.append((start_r[0], start_c[0], 0))  # (row, col, steps)
# dfs to check if there is a path from start to goal
while frontier:
r, c, steps = frontier.pop()
if steps > max_steps:
continue
if (r, c) not in discovered:
discovered.add((r, c))
directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
for x, y in directions:
r_new = r + x
c_new = c + y
if (
r_new < 0
or r_new >= max_size
or c_new < 0
or c_new >= max_size
): # noqa: PLR2004
continue
if board[r_new][c_new] == "G":
return True
if board[r_new][c_new] != "H":
frontier.append((r_new, c_new, steps + 1))
return False
def generate_random_map(
size: int = 8,
p: float = 0.8,
seed: int = 0,
max_steps: int = 5,
) -> Tuple[list[str], Tuple[int, int]]:
"""Generates a random valid map (one that has a path from start to goal).
Args:
size: Size of each side of the grid.
p: Probability that a tile is frozen.
seed: Seed to ensure the generation of reproducible maps.
max_steps: Maximum number of steps allowed.
Returns:
A tuple containing a random valid map and the goal position (row, col).
"""
valid = False
board: list[list[str]] = [] # initialize to make pyright happy
try:
from gymnasium.utils import seeding
np_random, _ = seeding.np_random(seed)
except ImportError as exc:
raise ImportError(
"Gymnasium is not installed. "
"Please install gymnasium first before "
"running the frozen_lake workflow.",
) from exc
# generate random start and end points
while not valid:
p = min(1, p)
board = np_random.choice(
["F", "H"],
(size, size),
p=[p, 1 - p],
).tolist()
while True:
start_r = int(np_random.integers(0, size))
start_c = int(np_random.integers(0, size))
goal_r = int(np_random.integers(0, size))
goal_c = int(np_random.integers(0, size))
# Ensure start and goal are different positions
if (start_r, start_c) != (goal_r, goal_c):
break
board[start_r][start_c] = "S"
board[goal_r][goal_c] = "G"
valid = is_valid(board, size, max_steps)
return ["".join(x) for x in board], (goal_r, goal_c)
def get_goal_position(
random_map: np.ndarray,
) -> Optional[Tuple[int, int]]:
"""Get the goal position from a random map.
Args:
random_map: The map as a numpy array.
Returns:
Tuple of (row, col) if goal found, None otherwise.
"""
positions = np.argwhere(random_map == b"G")
if positions.size == 0:
return None # G not found
return tuple(positions[0]) # returns (row, col)
__all__ = [
"SYSTEM_PROMPT",
"FrozenLakeAction",
"generate_random_map",
"get_goal_position",
]
View File
@@ -0,0 +1,53 @@
project: "AgentScope" # Project name
name: "FrozenLake" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints
algorithm:
algorithm_type: multi_step_grpo # GRPO series for multi-step scenario
repeat_times: 16 # Number of rollouts per prompt for advantage estimation
kl_loss_fn: "low_var_kl"
kl_loss_fn_args:
kl_coef: 0 # KL divergence coefficient
advantage_fn_args:
epsilon: 1e-6 # Small value for numerical stability
std_threshold: 0.0001 # Threshold for standard deviation
optimizer:
lr: 1e-6 # Learning rate
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} # Base model path
max_prompt_tokens: 23552 # Max tokens for prompt
max_response_tokens: 2048 # Max tokens per response
max_model_len: 25600 # Max context length
temperature: 1.0 # Sampling temperature
buffer:
total_epochs: 5 # Total training epochs
batch_size: 32 # Batch size per explore step
train_batch_size: 1024 # Total experiences per training step
trainer_input:
experience_buffer:
name: experience_buffer
storage_type: queue
max_read_timeout: 7200 # Max timeout for reading from buffer (seconds)
replay_buffer:
enable: true # Enable experience replay
priority_fn: linear_decay # Priority function for replay buffer
priority_fn_args:
decay: 0.1 # Decay rate for priority function
explorer:
runner_per_model: 16 # Number of runners per model
rollout_model:
engine_num: 6 # Number of vLLM engines for rollout model
tensor_parallel_size: 1 # TP size per engine for rollout model
enable_openai_api: true # Enable OpenAI-compatible API
enable_history: true # Enable conversation history
enable_auto_tool_choice: true # Enable automatic tool selection
tool_call_parser: hermes # Parser for tool calls
trainer:
save_interval: 100 # Save checkpoint every N steps
use_dynamic_bsz: true # Use dynamic batch size
grad_clip: 1.0 # Gradient clipping value
max_token_len_per_gpu: 25600 # Max token length per GPU
ulysses_sequence_parallel_size: 2 # Sequence parallel size for Ulysses
synchronizer:
sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
sync_interval: 1 # Sync every N steps
sync_timeout: 1200 # Timeout for synchronization (seconds)
Binary file not shown
View File
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""
Modified from rllm
"""
import argparse
import os
import numpy as np
import pandas as pd
DEFAULT_DATA_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"data",
"frozenlake",
)
def save_dataset_to_local(
data_path: str,
data: list[dict],
split: str = "default",
) -> str:
"""Save dataset directly to local data_path.
Args:
data_path: Path to save the dataset
data: List of dictionaries containing the dataset examples
split: Split name (e.g., 'train', 'test', 'default')
Returns:
str: Path to the saved parquet file
"""
os.makedirs(data_path, exist_ok=True)
# Convert to DataFrame and save
data_df = pd.DataFrame(data)
dataset_path = os.path.join(data_path, f"{split}.parquet")
data_df.to_parquet(dataset_path)
print(
f"Saved dataset frozenlake split '{split}' "
f"with {len(data)} examples at {dataset_path}. "
f"Make sure to set the environment variable "
f"<TRINITY_TASKSET_PATH> to {data_path}.",
)
return dataset_path
def prepare_frozenlake_data(
data_path: str,
train_size: int = 10000,
test_size: int = 100,
map_max_size: int = 6,
) -> tuple[list[dict], list[dict]]:
"""
Prepare and save FrozenLake datasets for training and testing.
Args:
data_path (str): Path to save the dataset
train_size (int): Number of training examples to generate
test_size (int): Number of test examples to generate
map_max_size (int): Maximum size of the map
Returns:
tuple: (train_data, test_data) - Lists of data dictionaries
"""
# Set random seed for reproducibility
np.random.seed(42)
# Generate random parameters for train and test sets
train_seeds = np.random.randint(0, 100000, size=train_size)
test_seeds = np.random.randint(0, 100000, size=test_size)
train_sizes = np.random.randint(2, map_max_size, size=train_size)
test_sizes = np.random.randint(2, map_max_size, size=test_size)
train_ps = np.random.uniform(0.6, 0.85, size=train_size)
test_ps = np.random.uniform(0.6, 0.85, size=test_size)
def frozenlake_process_fn(
seed: int,
size: int,
p: float,
idx: int,
) -> dict:
"""Process function to create FrozenLake task instances."""
return {
"seed": seed,
"size": size,
"p": p,
"index": idx,
"uid": f"{seed}_{size}_{p}",
}
# Create train and test data
train_data_list = [
frozenlake_process_fn(seed, train_sizes[idx], train_ps[idx], idx)
for idx, seed in enumerate(train_seeds)
]
test_data_list = [
frozenlake_process_fn(seed, test_sizes[idx], test_ps[idx], idx)
for idx, seed in enumerate(test_seeds)
]
# Save datasets directly to local DATA_PATH
save_dataset_to_local(data_path, train_data_list, "train")
save_dataset_to_local(data_path, test_data_list, "test")
return train_data_list, test_data_list
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
parser.add_argument("--train_size", type=int, default=10000)
parser.add_argument("--test_size", type=int, default=100)
parser.add_argument("--map_max_size", type=int, default=6)
args = parser.parse_args()
train_data, test_data = prepare_frozenlake_data(
data_path=args.local_dir,
train_size=args.train_size,
test_size=args.test_size,
map_max_size=args.map_max_size,
)
print(f"Train dataset: {len(train_data)} examples")
print(f"Test dataset: {len(test_data)} examples")
print("Sample train example:", train_data[0])
print("Sample test example:", test_data[0])
tuner/frozen_lake/main.py Normal file
View File
@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
"""Example of training a FrozenLake agent with Trinity-RFT."""
import os
from typing import Dict
from _frozenlake_agent import FrozenLakeAgent
from _frozenlake_env import FrozenLakeEnv
from agentscope.message import Msg
from agentscope.tuner import (
tune,
WorkflowOutput,
DatasetConfig,
TunerModelConfig,
AlgorithmConfig,
)
from agentscope.model import ChatModelBase
async def run_frozen_lake(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
"""A workflow function using the FrozenLake agent to solve tasks.
Args:
task (Dict): The task to be solved, containing environment parameters
like size, p, seed, is_slippery, etc.
model (ChatModelBase): The language model to use.
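        auxiliary_models (Dict[str, ChatModelBase]): Auxiliary models;
            unused in this workflow and asserted to be empty.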
Returns:
WorkflowOutput: The workflow output containing the reward, response and
metrics.
"""
assert len(auxiliary_models) == 0, "No auxiliary models are needed"
# Extract workflow arguments from task or use defaults
workflow_args = task.get("workflow_args", {})
if not workflow_args:
workflow_args = task
env_max_steps = workflow_args.get("env_max_steps", 8)
agent_max_steps = workflow_args.get("agent_max_steps", 10)
is_slippery = workflow_args.get("is_slippery", False)
desc = workflow_args.get("desc", None)
# Extract task-specific arguments (for environment generation)
size = task.get("size", 8)
p = task.get("p", 0.8)
seed = task.get("seed", 42)
# Initialize agent and environment
agent = FrozenLakeAgent(model=model, max_steps=agent_max_steps)
env = FrozenLakeEnv(
max_steps=env_max_steps,
desc=desc,
is_slippery=is_slippery,
size=size,
p=p,
seed=seed,
)
# Reset environment with task parameters
observation, _ = env.reset(task)
observation_str = str(observation)
rewards = []
step_count = 0
done = False
terminate_reason = None
# Run agent-environment interaction loop
for _ in range(agent_max_steps):
step_count += 1
try:
# get prompt
prompt = agent.get_prompt(observation_str)
response = await agent.reply(msg=Msg("user", prompt, role="user"))
# record action and observation
action = agent.get_action(response)
agent.update_state(action=action, observation=observation_str)
except Exception as e:
terminate_reason = f"agent_error: {str(e)}"
break
# environment step
observation, reward, done, _ = env.step(action)
observation_str = str(observation)
rewards.append(reward)
if done:
terminate_reason = "success" if env.success() else "hole"
break
if terminate_reason is None:
terminate_reason = "max_steps_reached"
final_reward = sum(rewards)
final_observation = observation_str
# Create response message with environment information
response_content = (
f"Final observation:\n{final_observation}\n"
f"Total reward: {final_reward}\n"
f"Steps taken: {step_count}\n"
f"Terminate reason: {terminate_reason}"
)
final_response = Msg("assistant", response_content, role="assistant")
return WorkflowOutput(
reward=final_reward,
response=final_response,
metrics={
"env_steps": float(step_count),
"env_done": float(done),
},
)
if __name__ == "__main__":
dataset = DatasetConfig(
path="/path/to/frozenlake",
split="train",
)
tuner_model = TunerModelConfig(
model_path="Qwen/Qwen2.5-3B-Instruct",
max_model_len=25600,
max_tokens=2048,
inference_engine_num=6,
reasoning_parser=None,
)
algorithm = AlgorithmConfig(
algorithm_type="multi_step_grpo",
group_size=16,
batch_size=32,
learning_rate=1e-6,
)
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
) # define some default parameters
tune(
workflow_func=run_frozen_lake,
model=tuner_model,
train_dataset=dataset,
algorithm=algorithm,
config_path=config_path,
)
View File
@@ -313,7 +313,7 @@ python tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py \
```
> ⚠️ **Note**: Your trained model must be converted to **Hugging Face format** first.
> See: [Converting FSDP Checkpoints Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/faq.html)
> See: [Converting FSDP Checkpoints Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/faq.html)
---
View File
@@ -248,10 +248,10 @@ if __name__ == "__main__":
Here, we use `DatasetConfig` to load the training dataset, `TunerModelConfig` to initialize the trainable model, and `AlgorithmConfig` to specify the RL algorithm and its hyperparameters.
> Note:
> The `tune` function is based on [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) and it converts the input parameters into a YAML configuration internally.
> The `tune` function is based on [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) and it converts the input parameters into a YAML configuration internally.
> Advanced users can ignore `model`, `train_dataset`, `algorithm` arguments and provide a configuration file path pointing to a YAML file using the `config_path` argument instead (see [config.yaml](./config.yaml) for an example).
> We recommend using the configuration file approach for fine-grained control over the training process and leveraging advanced features provided by Trinity-RFT.
> You can refer to the Trinity-RFT [Configuration Guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html) for more details on configuration options.
> You can refer to the Trinity-RFT [Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html) for more details on configuration options.
The checkpoint and logs will automatically be saved to the `checkpoints/AgentScope` directory under the current working directory and each run will be saved in a sub-directory suffixed with the current timestamp.
You can find the tensorboard logs inside `monitor/tensorboard` of the checkpoint directory.
@@ -342,7 +342,7 @@ After implementing the workflow function, follow these steps to run the training
- At least 2 NVIDIA GPUs with CUDA 12.8 or newer.
- Adjust the configuration file ([config.yaml](./config.yaml)) based on your hardware.
- Follow the Trinity-RFT [installation guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code.
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code.
- Download the GSM8K dataset and Qwen/Qwen3-0.6B model checkpoints (example):
```bash
View File
@@ -245,10 +245,10 @@ if __name__ == "__main__":
Here, we use `DatasetConfig` to load the training dataset, `TunerModelConfig` to initialize the trainable model, and `AlgorithmConfig` to specify the RL algorithm and its hyperparameters.
> Note:
> The `tune` function is based on [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) and converts the input parameters into a YAML configuration internally.
> The `tune` function is based on [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) and converts the input parameters into a YAML configuration internally.
> Advanced users can skip the `model`, `train_dataset`, and `algorithm` arguments and instead provide a YAML configuration file path via the `config_path` argument (see [config.yaml](./config.yaml) for an example).
> We recommend the configuration-file approach for fine-grained control over the training process and access to the advanced features provided by Trinity-RFT.
> See the Trinity-RFT [Configuration Guide](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html) for more details on configuration options.
> See the Trinity-RFT [Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html) for more details on configuration options.
Checkpoints and logs are saved automatically under `checkpoints/AgentScope` in the current working directory; each run creates a sub-directory suffixed with the current timestamp.
TensorBoard logs can be found in `monitor/tensorboard` inside the checkpoint directory.
@@ -335,7 +335,7 @@ if __name__ == "__main__":
- At least 2 NVIDIA GPUs with CUDA 12.8 or newer.
- Adjust the configuration file ([config.yaml](./config.yaml)) based on your hardware.
- Follow the Trinity-RFT [installation guide](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) to install the latest version from source code.
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) to install the latest version from source code.
- Download the GSM8K dataset and Qwen/Qwen3-0.6B model checkpoints (example):
```bash
View File
@@ -1,4 +1,4 @@
# Please refer to https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html for detailed explanation of each field.
# Please refer to https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html for detailed explanation of each field.
project: AgentScope
name: GSM8K-Qwen3-0.6B
# directory to save checkpoints, default to ./checkpoints if TRINITY_CHECKPOINT_ROOT_DIR not set