Add examples for frozenlake and emailsearch (#94)

This commit is contained in:
Yuchang Sun
2026-01-19 12:25:13 +08:00
committed by GitHub
parent 3821fb04ac
commit 654c35127a
26 changed files with 3370 additions and 14 deletions


@@ -0,0 +1,279 @@
# Training an Email Search Agent with RL using AgentScope-Tuner
This example demonstrates how to implement reinforcement fine-tuning for the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
## Task Setting
The agent's goal is to answer user queries by searching through an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email contents to extract information
- Provide accurate answers with proper source citations
**Agent Type**: The agent (`EmailSearchAgent`) extends `ReActAgent`, which follows a reasoning-acting loop to solve tasks iteratively.
**Environment**: The environment is a SQLite database containing emails from the Enron Email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context for the query
- `answer`: The expected answer (ground truth), only for reward calculation
- `message_ids`: IDs of relevant emails containing the answer, only for reward calculation
**Available Tools** (registered in the agent's constructor, as shown below):
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from ReAct agent).
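The search and read tools are methods on the agent itself and are registered in its constructor; a condensed view from [`email_search_agent.py`](./email_search_agent.py):
```python
class EmailSearchAgent(ReActAgent):
    """A ReAct agent with pre-defined tools for email search and reading."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.message_id_list = []        # message IDs found during search
        self.ever_read_message_ids = []  # message IDs that have been read
        toolkit = Toolkit()
        toolkit.register_tool_function(self.search_emails)
        toolkit.register_tool_function(self.read_email)
        super().__init__(*args, toolkit=toolkit, **kwargs)
```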
## Dataset Preparation
The dataset contains email queries based on the [Enron Email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:
```bash
python prepare_data.py
```
If you want to use a different database path, modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also, remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database before moving to the next step:
```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
This will create a SQLite database and datasets:
```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db   # Email database
├── train.parquet         # Training samples
└── test.parquet          # Test samples
```
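To sanity-check the generated database, you can open it read-only and count the indexed emails. This minimal snippet uses only the standard library; the `emails` table name comes from the schema in [`prepare_data.py`](./prepare_data.py):
```python
import os
import sqlite3

# Open the database read-only, mirroring how the example's tools connect to it
db_path = os.environ["DEFAULT_EMAIL_DB_PATH"]
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
print(conn.execute("SELECT COUNT(*) FROM emails").fetchone()[0])
```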
Each sample looks like:
```json
{
"id": 0,
"question": "Were there any variances detected for hour 6 on 3/9/01?",
"answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
"message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
"how_realistic": 0.800000011920929,
"inbox_address": "pete.davis@enron.com",
"query_date": "2001-03-16"
}
```
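To take a quick look at the generated splits, you can load them with pandas (assuming `pandas` and `pyarrow` are installed):
```python
import pandas as pd

# Load the training split and preview one sample
train = pd.read_parquet("/path/to/enron_emails_dataset/train.parquet")
print(train.shape)
print(train.iloc[0]["question"], "->", train.iloc[0]["answer"])
```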
## Code Implementation
This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.
### Agent Workflow
The workflow function `run_email_search_agent` implements the agent-environment interaction loop:
```python
async def run_email_search_agent(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
# Parse task and create agent
agent = EmailSearchAgent(
name="email_search_agent",
sys_prompt=system_prompt,
model=model,
max_iters=max_turns,
)
# Run the agent with structured output
response = await agent.reply(
msg=Msg("user", question, role="user"),
structured_model=AnswerModel,
)
return WorkflowOutput(response=response)
```
The agent follows a ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and source message IDs.
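The structured output is validated against a Pydantic schema. A condensed view of `AnswerModel` is below; the full definition in the example's utilities module carries more detailed field descriptions:
```python
from pydantic import BaseModel


class AnswerModel(BaseModel):
    """Schema for the agent's final structured answer."""

    answer: str         # the answer text; "I don't know" if nothing was found
    sources: list[str]  # message_ids of the emails that support the answer
```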
### Judge Function
The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:
```python
async def email_search_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
# Extract answer and sources from response
answer = answer_and_sources.get("answer")
sources = answer_and_sources.get("sources", [])
# Judge correctness using LLM-as-a-Judge
judge_model = auxiliary_models.get('judge') or list(auxiliary_models.values())[0]
judge_response = await judge_correctness(
answer, query, judge_model
)
# Calculate reward based on:
# - Answer correctness (accuracy: -1.0 to 1.0)
# - Source correctness (format: partial rewards)
# - Efficiency (bonus for fewer turns, correct sources)
result = {"accuracy": ..., "format": ...} # calculated based on judge_response
return JudgeOutput(
reward=sum(result.values()),
metrics=metrics,
)
```
The reward function considers (a condensed sketch follows this list):
- **Answer correctness**: Evaluated by LLM-as-a-Judge comparing the agent's answer with the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: Bonus rewards for finding/reading the correct email and taking fewer turns
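The sketch below condenses the reward composition, using a stand-in for the example's `FinalRubric`; see `_calculate_partial_rewards` and `_calculate_correct_answer_reward` in [`main.py`](./main.py) for the full logic:
```python
from dataclasses import dataclass


@dataclass
class Rubric:  # condensed stand-in for the example's FinalRubric
    answer_correct: bool = False
    sources_correct: bool = False
    attempted_answer: bool = False
    ever_found_right_email: bool = False
    ever_read_right_email: bool = False
    ran_out_of_turns: bool = False
    returned_i_dont_know: bool = False
    num_sources: int = 0
    num_turns: int = 0


def compute_reward(rubric: Rubric, max_turns: int) -> dict[str, float]:
    # Partial credit for finding, reading, and citing the right email
    partial = 0.1 * rubric.ever_found_right_email
    partial += 0.1 * rubric.ever_read_right_email
    partial += 0.1 * rubric.sources_correct
    if rubric.attempted_answer and not rubric.answer_correct:
        return {"accuracy": -1.0, "format": partial}  # wrong answers are penalized
    if rubric.returned_i_dont_know or rubric.ran_out_of_turns:
        return {"accuracy": 0.0, "format": partial}
    if rubric.answer_correct:
        bonus = 1.0 + (0.3 if rubric.sources_correct else 0.0)
        if rubric.num_sources > 0:
            bonus += 0.1 / rubric.num_sources  # prefer concise citations
        bonus += 0.1 * (1 - rubric.num_turns / max_turns)  # prefer fewer turns
        return {"accuracy": 1.0, "format": bonus}
    return {"accuracy": 0.0, "format": 0.0}
```
The final reward is the sum of the metric values, e.g., `sum(compute_reward(rubric, max_turns=10).values())`.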
See [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py) for implementation details.
## How to Run
### Prerequisites
- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
* Note: For the 30B judge model, you need a GPU with at least 40 GB of memory; you can also shard the model across multiple GPUs by setting `tensor_parallel_size > 1` to reduce per-GPU memory usage (by default, `tensor_parallel_size=2`).
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
- Download the model checkpoint (example):
```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # judge model
```
### Configuration
Adjust the configuration file ([`config.yaml`](./config.yaml)) based on your hardware. Key configuration sections include:
- **TunerModelConfig**: Set `model_path` to your model checkpoint path
- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
- **Auxiliary Models**: Configure judge model settings for LLM-as-a-Judge
For full configuration details, see [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
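For instance, the `algorithm` section of the provided [`config.yaml`](./config.yaml) looks like this (excerpt):
```yaml
algorithm:
  algorithm_type: multi_step_grpo  # GRPO variant for multi-step scenarios
  repeat_times: 8                  # rollouts per prompt for advantage estimation
  optimizer:
    lr: 1e-6                       # learning rate
  policy_loss_fn: "rec"            # policy loss function
```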
### Start-Up Commands
1. Prepare the dataset:
```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
2. Set up a [Ray](https://github.com/ray-project/ray) cluster:
```bash
ray start --head
```
3. Run the training script:
```bash
python main.py
```
## Experimental Results
### Quantitative Results
The training results show improvements in agent performance over training iterations. Key metrics include:
- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples increases as the agent learns better strategies
![Training Rewards](./critic_reward_mean.png)
![Rollout Accuracy](./rollout_accuracy_mean.png)
### Concrete Example
An example of the agent's behavior is shown below:
**Query:** "What do the color codes mean in the curve assessment?"
Below are the last several turns of the agent's trajectory:
The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:
**Tool call:**
```json
{
"type": "tool_use",
"name": "search_emails",
"input": {
"inbox_address": "steven.kean@enron.com",
"query_date": "2001-04-03",
"keywords": ["curve", "assessment"]
}
}
```
**Tool result:**
```json
{
"type": "tool_result",
"name": "search_emails",
"output": [
{
"message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
"snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
},
{
"message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
"snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
}
]
}
```
After finding relevant emails, the agent uses the `read_email` tool to read the full content of both emails:
**Tool call:**
```json
{
"type": "tool_use",
"name": "read_email",
"input": {
"message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
}
}
```
**Tool call:**
```json
{
"type": "tool_use",
"name": "read_email",
"input": {
"message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
}
}
```
After reading the emails, the agent generates the final structured response:
**Tool call:**
```json
{
"type": "tool_use",
"name": "generate_response",
"input": {
"answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
"sources": [
"<2654330.1075846153519.JavaMail.evans@thyme>",
"<12499967.1075846153495.JavaMail.evans@thyme>"
]
}
}
```
The judge evaluates the answer as correct based on the ground truth.


@@ -0,0 +1,279 @@
# Training an Email Search Agent with AgentScope-Tuner
This example demonstrates how to apply reinforcement fine-tuning to the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
## Task Setting
The agent's goal is to answer user queries by searching an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email contents to extract information
- Provide accurate answers with proper source citations
**Agent Type**: The agent (`EmailSearchAgent`) extends `ReActAgent` and follows a reasoning-acting loop to solve tasks iteratively.
**Environment**: The environment is a SQLite database containing emails from the Enron Email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context for the query
- `answer`: The expected answer (ground truth), used only for reward calculation
- `message_ids`: IDs of the relevant emails containing the answer, used only for reward calculation
**Available Tools**:
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from the ReAct agent).
## Dataset Preparation
The dataset contains email queries based on the [Enron Email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:
```bash
python prepare_data.py
```
If you want to use a different database path, modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also, remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database before moving to the next step:
```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
This will create a SQLite database and datasets:
```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db   # Email database
├── train.parquet         # Training samples
└── test.parquet          # Test samples
```
Each sample looks like:
```json
{
"id": 0,
"question": "Were there any variances detected for hour 6 on 3/9/01?",
"answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
"message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
"how_realistic": 0.800000011920929,
"inbox_address": "pete.davis@enron.com",
"query_date": "2001-03-16"
}
```
## Code Implementation
This section provides a high-level overview of the code implementation. For details, please refer to the source code.
### Agent Workflow
The workflow function `run_email_search_agent` implements the agent-environment interaction loop:
```python
async def run_email_search_agent(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
# Parse task and create agent
agent = EmailSearchAgent(
name="email_search_agent",
sys_prompt=system_prompt,
model=model,
max_iters=max_turns,
)
# Run the agent with structured output
response = await agent.reply(
msg=Msg("user", question, role="user"),
structured_model=AnswerModel,
)
return WorkflowOutput(response=response)
```
The agent follows the ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and the source message IDs.
### Judge Function
The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:
```python
async def email_search_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
# Extract answer and sources from the response
answer = answer_and_sources.get("answer")
sources = answer_and_sources.get("sources", [])
# Judge correctness using LLM-as-a-Judge
judge_model = auxiliary_models.get('judge') or list(auxiliary_models.values())[0]
judge_response = await judge_correctness(
answer, query, judge_model
)
# Calculate reward based on:
# - Answer correctness (accuracy: -1.0 to 1.0)
# - Source correctness (format: partial rewards)
# - Efficiency (bonus for fewer turns and correct sources)
result = {"accuracy": ..., "format": ...}  # calculated from judge_response
return JudgeOutput(
reward=sum(result.values()),
metrics=metrics,
)
```
The reward function considers:
- **Answer correctness**: Evaluated by an LLM-as-a-Judge comparing the agent's answer against the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: Bonus rewards for finding/reading the correct email and taking fewer turns
See [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py) for implementation details.
## How to Run
### Prerequisites
- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
  * Note: For the 30B judge model, you need a GPU with at least 40 GB of memory; you can also shard the model across multiple GPUs by setting `tensor_parallel_size > 1` to reduce per-GPU memory usage (by default, `tensor_parallel_size=2`).
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) to install the latest version from source
- Download the model checkpoints (example):
```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # judge model
```
### Configuration
Adjust the configuration file ([`config.yaml`](./config.yaml)) based on your hardware. Key configuration sections include:
- **TunerModelConfig**: Set `model_path` to your model checkpoint path
- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
- **Auxiliary Models**: Configure the judge model settings for LLM-as-a-Judge
For full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html).
### Start-Up Commands
1. Prepare the dataset:
```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
2. Set up a [Ray](https://github.com/ray-project/ray) cluster:
```bash
ray start --head
```
3. Run the training script:
```bash
python main.py
```
## Experimental Results
### Quantitative Results
The training results show improvements in agent performance over training iterations. Key metrics include:
- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples increases as the agent learns better strategies
![Training Rewards](./critic_reward_mean.png)
![Rollout Accuracy](./rollout_accuracy_mean.png)
### Concrete Example
An example of the agent's behavior is shown below:
**Query:** "What do the color codes mean in the curve assessment?"
Below are the last several turns of the agent's trajectory:
The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:
**Tool call:**
```json
{
"type": "tool_use",
"name": "search_emails",
"input": {
"inbox_address": "steven.kean@enron.com",
"query_date": "2001-04-03",
"keywords": ["curve", "assessment"]
}
}
```
**Tool result:**
```json
{
"type": "tool_result",
"name": "search_emails",
"output": [
{
"message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
"snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
},
{
"message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
"snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
}
]
}
```
After finding the relevant emails, the agent uses the `read_email` tool to read the full content of both emails:
**Tool call:**
```json
{
"type": "tool_use",
"name": "read_email",
"input": {
"message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
}
}
```
**Tool call:**
```json
{
"type": "tool_use",
"name": "read_email",
"input": {
"message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
}
}
```
After reading the emails, the agent generates the final structured response:
**Tool call:**
```json
{
"type": "tool_use",
"name": "generate_response",
"input": {
"answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
"sources": [
"<2654330.1075846153519.JavaMail.evans@thyme>",
"<12499967.1075846153495.JavaMail.evans@thyme>"
]
}
}
```
The judge evaluates the above answer as correct based on the ground truth.


@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import json
import traceback
from dataclasses import asdict
from datetime import datetime, timedelta
from typing import Any
from _utils import ( # pylint: disable=E0611
read_email_tool,
search_emails_tool,
)
from agentscope import logger
from agentscope.agent import ReActAgent
from agentscope.message import TextBlock
from agentscope.tool import Toolkit, ToolResponse
def pre_reasoning_hook(_self: Any, _kwargs: Any) -> dict[str, Any] | None:
"""Pre-reasoning hook to remove tool_choice from kwargs."""
_kwargs.pop("tool_choice", None)
return _kwargs
class EmailSearchAgent(ReActAgent):
"""
A customized ReAct agent with pre-defined tools for
email search and reading.
Ref: https://github.com/OpenPipe/ART
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.message_id_list = (
[]
) # List to store message IDs found during search
self.ever_read_message_ids = (
[]
) # List to store message IDs that have been read
toolkit = Toolkit()
toolkit.register_tool_function(self.search_emails)
toolkit.register_tool_function(self.read_email)
super().__init__(*args, toolkit=toolkit, **kwargs)
self.register_instance_hook(
"pre_reasoning",
"tool_choice_hook",
pre_reasoning_hook,
)
async def reset(self) -> None:
"""Reset agent state for a new rollout/episode."""
self.message_id_list.clear()
self.ever_read_message_ids.clear()
await self.memory.clear()
def search_emails(
self,
inbox_address: str,
query_date: str,
keywords: list[str],
**_kwargs: Any,
) -> ToolResponse:
"""
Search the user's email inbox for emails that match the given keywords.
Args:
inbox_address: The user's email address.
query_date: The date of the query in 'YYYY-MM-DD' format.
keywords: Keywords to search for in the user's email inbox.
Returns:
ToolResponse:
A ToolResponse object containing a list of TextBlock objects
in the `content` field. On success, the text field of the
TextBlock contains a JSON string representing a list of email
summaries (e.g., message_id, snippet) matching the search
criteria. Each email summary is converted to a dictionary via
`asdict`. On failure, the text indicates an error message.
"""
try:
next_day = (
datetime.strptime(query_date, "%Y-%m-%d") + timedelta(days=1)
).strftime(
"%Y-%m-%d",
)
res = search_emails_tool(
inbox=inbox_address,
sent_before=next_day,
keywords=keywords,
)
self.message_id_list.extend([r.message_id for r in res])
return ToolResponse(
content=[
TextBlock(
type="text",
text=json.dumps([asdict(r) for r in res]),
),
],
)
except Exception as e:
logger.info(
"Error in search_emails: %s, traceback: %s",
e,
traceback.format_exc(),
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=(
f"Error: Failed to search emails.\n"
f"Error message: {e}"
),
),
],
)
def read_email(self, message_id: str, **_kwargs: Any) -> ToolResponse:
"""
Read the content of an email from the user's email inbox.
Returns the email content.
Args:
message_id (str): The unique identifier of the email to read.
Returns:
ToolResponse:
A ToolResponse object containing the email content or an
error message if the email is not found.
"""
try:
email_content = read_email_tool(message_id)
self.ever_read_message_ids.append(message_id)
if email_content is None:
return ToolResponse(
content=[
TextBlock(
type="text",
text=(
f"Error: Email (message_id = {message_id}) "
f"not found."
),
),
],
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=json.dumps(email_content.model_dump()),
),
],
)
except Exception as e:
logger.info(
"Error in read_email: %s, traceback: %s",
e,
traceback.format_exc(),
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=(
f"Error: Failed to read email.\n"
f"Error message: {e}"
),
),
],
)


@@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
"""
This file defines Dataclass and tool implementations.
Modified from https://github.com/OpenPipe/ART/blob/art-e/
"""
import datetime
import os
import sqlite3
from dataclasses import dataclass
from typing import Any, List, Optional
from pydantic import BaseModel, Field, field_validator
from agentscope import logger
DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH")
conn = None
def get_conn() -> sqlite3.Connection:
"""Get or create a database connection."""
global conn
if conn is None:
conn = sqlite3.connect(
f"file:{DEFAULT_DB_PATH}?mode=ro",
uri=True,
check_same_thread=False,
)
return conn
class QueryModel(BaseModel):
"""Model for email search query."""
id: int
question: str
answer: str
message_ids: List[str] # message_ids (strings) of referenced emails
how_realistic: float
inbox_address: str
query_date: str
@field_validator("query_date", mode="before")
@classmethod
def format_date(cls, v: Any) -> str:
"""Format date to string if it's a datetime object."""
if isinstance(v, datetime.datetime):
return v.strftime("%Y-%m-%d")
return v
class AnswerModel(BaseModel):
"""Model for agent's answer with sources."""
answer: str = Field(
description=(
"The answer to the user's query. "
"If you cannot find the answer, you should return "
"'I don't know' with an empty list of sources."
),
)
sources: List[str] = Field(
description=(
"a list of message ids that are relevant to the query. "
"Usually there will be only one. If you cannot find the "
"answer, you should return an empty list."
),
)
class Email(BaseModel):
"""Model representing an email."""
message_id: str
date: str # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
subject: Optional[str] = None
from_address: Optional[str] = None
to_addresses: List[str] = Field(default_factory=list)
cc_addresses: List[str] = Field(default_factory=list)
bcc_addresses: List[str] = Field(default_factory=list)
body: Optional[str] = None
file_name: Optional[str] = None
@dataclass
class SearchResult:
"""Result from email search."""
message_id: str
snippet: str
class FinalRubric(BaseModel):
"""Rubric for evaluating agent performance."""
answer_correct: bool = False
sources_correct: bool = False
num_turns: int = 0
attempted_answer: bool = False
ever_found_right_email: bool = False
ever_read_right_email: bool = False
cant_parse_tool_call: bool = False
bad_tool_call_name: bool = False
bad_tool_call_args: bool = False
ran_out_of_turns: bool = False
returned_i_dont_know: bool = False
num_sources: int = 0
ever_tried_to_read_invalid_email: bool = False
prompt_tokens: int = 0
completion_tokens: int = 0
# Define tools for agent
def search_emails_tool(
inbox: str,
keywords: List[str],
from_addr: Optional[str] = None,
to_addr: Optional[str] = None,
sent_after: Optional[str] = None,
sent_before: Optional[str] = None,
max_results: int = 10,
) -> List[SearchResult]:
"""
Searches the email database based on keywords, inbox,
sender, recipient, and date range.
Args:
inbox: The email address of the user performing the search.
Results include emails sent from or to (inc. cc/bcc)
this address.
keywords: A list of keywords that must all appear in the
subject or body.
from_addr: Optional email address to filter emails sent *from*.
to_addr: Optional email address to filter emails sent *to*
(inc. cc/bcc).
sent_after: Optional date string 'YYYY-MM-DD'. Filters for
emails sent on or after this date.
sent_before: Optional date string 'YYYY-MM-DD'. Filters for
emails sent before this date.
max_results: The maximum number of results to return.
Cannot exceed 10.
Returns:
A list of SearchResult objects, each containing 'message_id'
and 'snippet'. Returns an empty list if no results are found
or an error occurs.
"""
# Initialize sql and params
sql: Optional[str] = None
params: List[str | int] = []
cursor = get_conn().cursor()
# --- Build Query ---
where_clauses: List[str] = []
# 1. Keywords (FTS)
if not keywords:
raise ValueError("No keywords provided for search.")
if max_results > 10:
raise ValueError("max_results must be less than or equal to 10.")
# FTS5 default is AND, so just join keywords. Escape quotes for safety.
fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
where_clauses.append("emails_fts MATCH ?")
params.append(fts_query)
# 2. Inbox filter (must be from OR to/cc/bcc the inbox user)
# Use the composite index idx_recipients_address_email here
where_clauses.append(
"""
(e.from_address = ? OR EXISTS (
SELECT 1 FROM recipients r_inbox
WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id
))
""",
)
params.extend([inbox, inbox])
# 3. Optional From filter
if from_addr:
where_clauses.append("e.from_address = ?")
params.append(from_addr)
# 4. Optional To filter (includes to, cc, bcc)
# Use composite index idx_recipients_address_email
if to_addr:
where_clauses.append(
"""
EXISTS (
SELECT 1 FROM recipients r_to
WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
)
""",
)
params.append(to_addr)
# 5. Optional Sent After filter
if sent_after:
# Assumes date format 'YYYY-MM-DD'
# Compare against the start of the day
where_clauses.append("e.date >= ?")
params.append(f"{sent_after} 00:00:00")
# 6. Optional Sent Before filter
if sent_before:
# Assumes date format 'YYYY-MM-DD'
# Compare against the start of the day (exclusive)
where_clauses.append("e.date < ?")
params.append(f"{sent_before} 00:00:00")
# --- Construct Final Query ---
# snippet(<table>, <column_index>, <highlight_start>,
# <highlight_end>, <ellipsis>, <tokens>)
# -1 means highlight across all columns (subject, body)
sql = f"""
SELECT
e.message_id,
snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
FROM
emails e JOIN emails_fts fts ON e.id = fts.rowid
WHERE
{" AND ".join(where_clauses)}
ORDER BY
e.date DESC -- Most recent first
LIMIT ?;
"""
params.append(max_results)
# --- Execute and Fetch ---
logger.debug("Executing SQL: %s", sql)
logger.debug("With params: %s", params)
cursor.execute(sql, params)
results = cursor.fetchall()
# Format results
formatted_results = [
SearchResult(message_id=row[0], snippet=row[1]) for row in results
]
logger.info("Search found %d results.", len(formatted_results))
return formatted_results
def read_email_tool(message_id: str) -> Optional[Email]:
"""
Retrieves a single email by its message_id from the database.
Args:
message_id: The unique identifier of the email to retrieve.
Returns:
An Email object containing the details of the found email,
or None if the email is not found or an error occurs.
"""
cursor = get_conn().cursor()
# --- Query for Email Core Details ---
email_sql = """
SELECT id, message_id, date, subject, from_address, body, file_name
FROM emails
WHERE message_id = ?;
"""
cursor.execute(email_sql, (message_id,))
email_row = cursor.fetchone()
if not email_row:
logger.warning("Email with message_id '%s' not found.", message_id)
return None
email_pk_id, msg_id, date, subject, from_addr, body, file_name = email_row
# DEBUG
logger.info("[read_email_tool] input_message_id=%s", message_id)
logger.info(
"[read_email_tool] db: id=%s, message_id=%s",
email_pk_id,
msg_id,
)
# search for recipients by emails.id (rather than message_id)
recipients_sql = """
SELECT recipient_address, recipient_type
FROM recipients
WHERE email_id = ?;
"""
cursor.execute(recipients_sql, (email_pk_id,))
recipient_rows = cursor.fetchall()
to_addresses: List[str] = []
cc_addresses: List[str] = []
bcc_addresses: List[str] = []
for addr, rtype in recipient_rows:
type_lower = rtype.lower()
if type_lower == "to":
to_addresses.append(addr)
elif type_lower == "cc":
cc_addresses.append(addr)
elif type_lower == "bcc":
bcc_addresses.append(addr)
# --- Construct Email Object ---
email_obj = Email(
message_id=msg_id,
date=date,
subject=subject,
from_address=from_addr,
to_addresses=to_addresses,
cc_addresses=cc_addresses,
bcc_addresses=bcc_addresses,
body=body,
file_name=file_name,
)
return email_obj
__all__ = [
"QueryModel",
"AnswerModel",
"FinalRubric",
"Email",
"SearchResult",
"search_emails_tool",
"read_email_tool",
"get_conn",
]


@@ -0,0 +1,72 @@
project: "AgentScope" # Project name
name: "Email_search" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints
algorithm:
algorithm_type: multi_step_grpo # GRPO series for multi-step scenario
repeat_times: 8 # Number of rollouts per prompt for advantage estimation
optimizer:
lr: 1e-6 # Learning rate
policy_loss_fn: "rec" # Policy loss function
policy_loss_fn_args: # Policy loss function arguments
epsilon_low: 0.2
epsilon_high: 0.2
clip_mode: "one-side"
weight: "none"
temp: 1.0
regularizer: "none"
regularizer_coef: 0.0
kl_loss_fn: 'k2' # KL divergence loss function
kl_loss_fn_args:
kl_coef: 0.0 # KL divergence coefficient
advantage_fn_args:
std_cal_level: 'batch' # Advantage normalization level
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507} # Base model path
max_response_tokens: 4096 # Max tokens per response
max_model_len: 20480 # Max context length
buffer:
total_epochs: 10 # Total training epochs
batch_size: 64 # Batch size per explore step
train_batch_size: 2560 # 64*8*5, total experiences per training step
trainer_input:
experience_buffer:
name: experience_buffer
storage_type: queue
replay_buffer:
enable: true # Enable experience replay
priority_fn: 'decay_limit_randomization'
priority_fn_args:
decay: 2.0
use_count_limit: 3
sigma: 2.0
explorer:
eval_interval: 10
max_repeat_times_per_runner: 1 # Max repeat times per runner
max_timeout: 3600 # Max timeout for each rollout (seconds)
rollout_model:
enable_history: true # Enable conversation history
enable_openai_api: true # Enable OpenAI-compatible API
enable_auto_tool_choice: true # Enable automatic tool selection
tool_call_parser: hermes # Parser for tool calls
engine_num: 4 # Number of vLLM engines for rollout model
tensor_parallel_size: 1 # TP size per engine for rollout model
enable_prefix_caching: false # Disable prefix caching
auxiliary_models:
- name: judge
model_path: Qwen/Qwen3-30B-A3B-Instruct-2507 # Judge model path
engine_num: 1 # Number of vLLM engines for judge model
tensor_parallel_size: 2 # TP size per engine for judge model
enable_thinking: false # Disable thinking/reasoning mode
max_prompt_tokens: 2048 # Max tokens for prompt
max_response_tokens: 128 # Max tokens for response
max_model_len: 2500 # Max model context length
synchronizer:
sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
sync_interval: 5 # Sync every N steps
sync_timeout: 3600 # Timeout for synchronization (seconds)
trainer:
save_interval: 100 # Save checkpoint every N steps
grad_clip: 1.0 # Gradient clipping value
use_dynamic_bsz: true # Use dynamic batch size
max_token_len_per_gpu: 16384 # Max token length per GPU
ulysses_sequence_parallel_size: 1 # Sequence parallel size for Ulysses

Binary image file added (470 KiB).

tuner/email_search/main.py

@@ -0,0 +1,379 @@
# -*- coding: utf-8 -*-
"""Example of training an Email Search agent with Trinity-RFT."""
import os
from typing import Dict
from _email_search_agent import EmailSearchAgent
from _utils import ( # pylint: disable=E0611
AnswerModel,
FinalRubric,
QueryModel,
)
from agentscope import logger
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.tuner import (
TunerModelConfig,
DatasetConfig,
JudgeOutput,
WorkflowOutput,
AlgorithmConfig,
tune,
)
from agentscope.model import ChatModelBase
SYSTEM_PROMPT = """You are an email search agent. You are given a user query
and a list of tools you can use to search the user's email. Use the tools to
search the user's emails and find the answer to the user's query. You may take
up to {max_turns} turns to find the answer, so if your first search doesn't
find the answer, you can try with different keywords.
Always describe what you see and plan your next steps clearly. When taking
actions, explain what you're doing and why. When the answer to the task is
found, call `generate_response` to finish the process. Only call
`generate_response` when the answer is found. Do not describe any next steps
in `generate_response`. Complete all steps and then call `generate_response`.
User's email address is {inbox_address}
Today's date is {query_date}
"""
async def run_email_search_agent(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput: # noqa: PLR0915
"""A workflow function using the Email Search agent to solve tasks.
Args:
task (Dict): The task to be solved.
Should contain fields from QueryModel.
model (ChatModelBase): The language model to use.
Returns:
WorkflowOutput: The output containing the agent's response.
"""
assert len(auxiliary_models) > 0, "LLM-as-a-Judge is required"
# Parse task data
query = QueryModel.model_validate(task)
question = task.get("question", task.get("task_desc", ""))
# Get workflow arguments with defaults
workflow_args = task.get("workflow_args", {})
max_turns = int(workflow_args.get("max_turns", 10))
# Format system prompt
system_prompt = SYSTEM_PROMPT.format(
max_turns=max_turns,
inbox_address=query.inbox_address,
query_date=query.query_date,
)
# Create EmailSearchAgent
agent = EmailSearchAgent(
name="email_search_agent",
sys_prompt=system_prompt,
model=model,
formatter=OpenAIChatFormatter(),
max_iters=max_turns,
)
# Reset agent state for a new rollout
await agent.reset()
# Run the agent with structured output
response = await agent.reply(
msg=Msg("user", question, role="user"),
structured_model=AnswerModel,
)
# Extract answer and sources from response metadata
answer_and_sources = response.metadata or {}
if not answer_and_sources:
# Fallback: try to parse from content
answer_and_sources = {
"answer": response.get_text_content() or "",
"sources": [],
}
# Store agent state for judge function
# We'll pass this through the response metadata
response_metadata = {
"answer_and_sources": answer_and_sources,
"query": query.model_dump(),
"message_id_list": agent.message_id_list,
"ever_read_message_ids": agent.ever_read_message_ids,
# Estimate actual_turns from memory length
"actual_turns": (
max(1, (len(agent.memory.content) - 1) // 2)
if len(agent.memory.content) > 1
else 1
),
}
# Update response metadata
if response.metadata is None:
response.metadata = {}
response.metadata.update(response_metadata)
return WorkflowOutput(
response=response,
)
def _calculate_partial_rewards(rubric: FinalRubric) -> float:
"""Calculate partial rewards based on rubric."""
partial_rewards = 0.0
partial_rewards += 0.1 if rubric.ever_found_right_email else 0
partial_rewards += 0.1 if rubric.ever_read_right_email else 0
partial_rewards += 0.1 if rubric.sources_correct else 0
return partial_rewards
def _calculate_correct_answer_reward(
rubric: FinalRubric,
max_turns: int,
) -> float:
"""Calculate reward for correct answers."""
reward = 1.0
reward += 0.3 if rubric.sources_correct else 0
reward += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0
reward += 0.1 * (1 - rubric.num_turns / max_turns)
return reward
def _initialize_rubric(
answer: str,
sources: list[str],
actual_turns: int,
query: QueryModel,
message_id_list: list[str],
ever_read_message_ids: list[str],
) -> FinalRubric:
"""Initialize and populate rubric with basic information."""
rubric = FinalRubric()
rubric.attempted_answer = answer is not None and answer != ""
rubric.returned_i_dont_know = answer == "I don't know"
rubric.num_sources = len(sources)
rubric.num_turns = actual_turns
if len(query.message_ids) > 0:
rubric.ever_found_right_email = query.message_ids[0] in message_id_list
rubric.ever_read_right_email = (
query.message_ids[0] in ever_read_message_ids
)
rubric.sources_correct = query.message_ids[0] in sources
return rubric
async def email_search_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
"""A judge function to calculate reward based on agent's response.
Args:
task (Dict): The task information for the corresponding workflow.
response (Msg): The response generated by the corresponding workflow.
auxiliary_models (Dict[str, ChatModelBase]):
A dictionary of additional chat models available for LLM-as-a-Judge
usage. The keys are model names, and the values are the
corresponding ChatModelBase instances.
Returns:
JudgeOutput: The reward value assigned by the judge function.
"""
# Extract metadata from response
metadata = response.metadata or {}
answer_and_sources = metadata.get("answer_and_sources", {})
query_dict = metadata.get("query", {})
message_id_list = metadata.get("message_id_list", [])
ever_read_message_ids = metadata.get("ever_read_message_ids", [])
actual_turns = metadata.get("actual_turns", 0)
# Parse query model
if not query_dict:
query_dict = task
query = QueryModel.model_validate(query_dict)
# Get arguments
workflow_args = task.get("workflow_args", {})
max_turns = int(workflow_args.get("max_turns", 10))
# Extract answer and sources
try:
answer = answer_and_sources.get("answer", None)
sources = answer_and_sources.get("sources", [])
except Exception:
result = {"accuracy": 0.0, "format": -1.0}
return JudgeOutput(
reward=sum(result.values()),
metrics=result,
)
if answer is None:
result = {"accuracy": 0.0, "format": -1.0}
return JudgeOutput(
reward=sum(result.values()),
metrics=result,
)
# Initialize rubric
rubric = _initialize_rubric(
answer,
sources,
actual_turns,
query,
message_id_list,
ever_read_message_ids,
)
# Judge correctness using LLM-as-a-Judge
try:
judge_model = (
auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
if auxiliary_models
else None
)
judge_response = await judge_correctness(
answer,
query,
judge_model,
)
rubric.answer_correct = judge_response
except Exception as e:
logger.error("Error judging correctness: %s", e)
rubric.answer_correct = False
# Calculate rewards
partial_rewards = _calculate_partial_rewards(rubric)
if rubric.attempted_answer and not rubric.answer_correct:
result = {"accuracy": -1.0, "format": partial_rewards}
elif rubric.returned_i_dont_know or rubric.ran_out_of_turns:
result = {"accuracy": 0.0, "format": partial_rewards}
elif rubric.answer_correct:
reward = _calculate_correct_answer_reward(rubric, max_turns)
result = {"accuracy": 1.0, "format": reward}
else:
result = {"accuracy": 0.0, "format": 0.0}
metrics = result.copy()
metrics.update({"actual_turns": actual_turns})
return JudgeOutput(
reward=sum(result.values()),
metrics=metrics,
)
# LLM-as-a-judge
async def judge_correctness(
answer: str,
query: QueryModel,
judge: ChatModelBase,
) -> bool:
"""Use an LLM to decide whether *answer* matches *query.answer*.
Returns a boolean *accept* flag used for scoring.
"""
system_prompt = """You are given a question, the reference answer
(labelled **Reference answer**), and an answer generated by an AI assistant
(labelled **AI answer**).
Follow these steps to decide whether the AI answer should be accepted:
1. Identify EXACTLY what information the **question** is asking for
(e.g. who, what, when, where, why, how, quantity, etc.).
2. From the **Reference answer**, extract ONLY the facts that are required
to directly satisfy the information need identified in step 1. Treat all
other facts as non-essential context.
3. Verify that every essential fact from step 2 appears in the **AI answer**
with the same meaning. Differences in wording, order, or additional
non-conflicting details are allowed.
4. If any essential fact is missing or contradicted in the **AI answer**,
then *accept* must be **false**. Otherwise *accept* must be **true**.
Important: Do NOT penalise the **AI answer** for omitting non-essential
facts that appear in the **Reference answer**. The answer should only be
rejected for errors or omissions in the information explicitly requested by
the question.
Return your judgement **accept** as either **true** or **false**. Do not
return any other text or formatting.
"""
prompt = (
f"Question: {query.question}\n"
f"Reference answer: {query.answer}\n"
f"AI answer: {answer}"
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
chat_response = await judge(messages)
# Extract text content from ChatResponse
result_parts = []
for block in chat_response.content:
if isinstance(block, dict) and block.get("type") == "text":
result_parts.append(str(block.get("text", "")))
result = "".join(result_parts)
logger.info("LLM judge response: %s", result)
return "true" in result.lower()
# End of LLM-as-a-judge
if __name__ == "__main__":
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
)
dataset = DatasetConfig(
path="/path/to/enron_emails_dataset",
split="train",
)
tuner_model = TunerModelConfig(
model_path="Qwen/Qwen3-4B-Instruct-2507",
max_model_len=20480,
max_tokens=4096,
inference_engine_num=4,
reasoning_parser=None,
)
aux_models = {
"judge": TunerModelConfig(
model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
max_model_len=2500,
max_tokens=2048,
inference_engine_num=1,
tensor_parallel_size=2,
tool_call_parser=None,
reasoning_parser=None,
),
}
algorithm = AlgorithmConfig(
algorithm_type="multi_step_grpo",
group_size=8,
batch_size=64,
learning_rate=1e-6,
)
tune(
workflow_func=run_email_search_agent,
judge_func=email_search_judge,
train_dataset=dataset,
model=tuner_model,
auxiliary_models=aux_models,
algorithm=algorithm,
config_path=config_path,
)


@@ -0,0 +1,357 @@
# -*- coding: utf-8 -*-
"""
Prepare data for training.
Modified from OpenPipe/ART
"""
import logging
import os
import sqlite3
from datetime import datetime
from datasets import Dataset, Features, Sequence, Value, load_dataset
from tqdm import tqdm
# Resolve paths relative to this file
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Database will live in "../data/enron_emails.db" relative to project root
DEFAULT_DB_PATH = os.path.join(BASE_DIR, "..", "..", "data", "enron_emails.db")
DEFAULT_REPO_ID = "corbt/enron-emails"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
# --- Database Schema ---
SQL_CREATE_TABLES = """
DROP TABLE IF EXISTS recipients;
DROP TABLE IF EXISTS emails_fts;
DROP TABLE IF EXISTS emails;
CREATE TABLE emails (
id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id TEXT UNIQUE,
subject TEXT,
from_address TEXT,
date TEXT, -- Store as ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
body TEXT,
file_name TEXT
);
CREATE TABLE recipients (
email_id INTEGER,
recipient_address TEXT,
recipient_type TEXT, -- 'to', 'cc', 'bcc'
FOREIGN KEY(email_id) REFERENCES emails(id) ON DELETE CASCADE
);
"""
SQL_CREATE_INDEXES_TRIGGERS = """
CREATE INDEX idx_emails_from ON emails(from_address);
CREATE INDEX idx_emails_date ON emails(date);
CREATE INDEX idx_emails_message_id ON emails(message_id);
CREATE INDEX idx_recipients_address ON recipients(recipient_address);
CREATE INDEX idx_recipients_type ON recipients(recipient_type);
CREATE INDEX idx_recipients_email_id ON recipients(email_id);
CREATE INDEX idx_recipients_address_email ON recipients(
recipient_address, email_id
);
CREATE VIRTUAL TABLE emails_fts USING fts5(
subject,
body,
content='emails',
content_rowid='id'
);
CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
INSERT INTO emails_fts (rowid, subject, body)
VALUES (new.id, new.subject, new.body);
END;
CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
DELETE FROM emails_fts WHERE rowid=old.id;
END;
CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
UPDATE emails_fts SET subject=new.subject, body=new.body
WHERE rowid=old.id;
END;
INSERT INTO emails_fts (rowid, subject, body)
SELECT id, subject, body FROM emails;
"""
# --- Functions ---
def download_dataset(repo_id: str) -> Dataset:
"""Downloads the dataset from Hugging Face Hub."""
logging.info(
"Attempting to download dataset from Hugging Face Hub: %s",
repo_id,
)
expected_features = Features(
{
"message_id": Value("string"),
"subject": Value("string"),
"from": Value("string"),
"to": Sequence(Value("string")),
"cc": Sequence(Value("string")),
"bcc": Sequence(Value("string")),
"date": Value("timestamp[us]"),
"body": Value("string"),
"file_name": Value("string"),
},
)
dataset_obj = load_dataset(
repo_id,
features=expected_features,
split="train",
)
# Basic type check remains useful
if not isinstance(dataset_obj, Dataset):
raise TypeError(f"Expected Dataset, got {type(dataset_obj)}")
logging.info(
"Successfully loaded dataset '%s' with %d records.",
repo_id,
len(dataset_obj),
)
return dataset_obj
def create_database(db_path: str) -> None:
"""Creates the SQLite database and tables."""
logging.info("Creating SQLite database and tables at: %s", db_path)
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.executescript(SQL_CREATE_TABLES)
conn.commit()
conn.close()
logging.info("Database tables created successfully.")
def _should_skip_email(
body: str,
message_id: str,
to_list: list[str],
cc_list: list[str],
bcc_list: list[str],
) -> bool:
"""Check if email should be skipped based on filters."""
if len(body) > 5000:
logging.debug(
"Skipping email %s: Body length > 5000 characters.",
message_id,
)
return True
total_recipients = len(to_list) + len(cc_list) + len(bcc_list)
if total_recipients > 30:
logging.debug(
"Skipping email %s: Total recipients (%d) > 30.",
message_id,
total_recipients,
)
return True
return False
def _prepare_recipient_data(
email_pk_id: int,
to_list: list[str],
cc_list: list[str],
bcc_list: list[str],
) -> list[tuple[int, str, str]]:
"""Prepare recipient data for database insertion."""
recipient_data = []
for addr in to_list:
recipient_data.append((email_pk_id, addr, "to"))
for addr in cc_list:
recipient_data.append((email_pk_id, addr, "cc"))
for addr in bcc_list:
recipient_data.append((email_pk_id, addr, "bcc"))
return recipient_data
def populate_database(db_path: str, dataset: Dataset) -> None:
"""Populates the database with data from the Hugging Face dataset."""
logging.info("Populating database %s...", db_path)
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# --- Performance Pragmas ---
conn.execute("PRAGMA synchronous = OFF;")
conn.execute("PRAGMA journal_mode = MEMORY;")
record_count = 0
skipped_count = 0
duplicate_count = 0
processed_emails = set()
conn.execute("BEGIN TRANSACTION;")
for email_data in tqdm(dataset, desc="Inserting emails"):
assert isinstance(email_data, dict)
message_id = email_data["message_id"]
subject = email_data["subject"]
from_address = email_data["from"]
date_obj: datetime = email_data["date"]
body = email_data["body"]
file_name = email_data["file_name"]
to_list_raw = email_data["to"]
cc_list_raw = email_data["cc"]
bcc_list_raw = email_data["bcc"]
date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")
to_list = [str(addr) for addr in to_list_raw if addr]
cc_list = [str(addr) for addr in cc_list_raw if addr]
bcc_list = [str(addr) for addr in bcc_list_raw if addr]
if _should_skip_email(body, message_id, to_list, cc_list, bcc_list):
skipped_count += 1
continue
email_key = (subject, body, from_address)
if email_key in processed_emails:
logging.debug(
"Skipping duplicate email (Subject: %s..., From: %s)",
subject[:50],
from_address,
)
duplicate_count += 1
continue
processed_emails.add(email_key)
cursor.execute(
"""
INSERT INTO emails (
message_id, subject, from_address, date, body, file_name
)
VALUES (?, ?, ?, ?, ?, ?)
""",
(message_id, subject, from_address, date_str, body, file_name),
)
email_pk_id = cursor.lastrowid
if email_pk_id is None:
logging.warning(
"Failed to get email ID after insert for message_id: %s",
message_id,
)
continue
recipient_data = _prepare_recipient_data(
email_pk_id,
to_list,
cc_list,
bcc_list,
)
if recipient_data:
cursor.executemany(
"""
INSERT INTO recipients (
email_id, recipient_address, recipient_type
)
VALUES (?, ?, ?)
""",
recipient_data,
)
record_count += 1
conn.commit()
conn.close()
logging.info("Successfully inserted %d email records.", record_count)
if skipped_count > 0:
logging.info(
"Skipped %d email records due to length or recipient limits.",
skipped_count,
)
if duplicate_count > 0:
logging.info(
"Skipped %d duplicate email records "
"(based on subject, body, from).",
duplicate_count,
)
def create_indexes_and_triggers(db_path: str) -> None:
"""Creates indexes and triggers on the populated database."""
logging.info("Creating indexes and triggers for database: %s...", db_path)
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
conn.commit()
conn.close()
logging.info("Indexes and triggers created successfully.")
def generate_database(
repo_id: str = DEFAULT_REPO_ID,
db_path: str = DEFAULT_DB_PATH,
overwrite: bool = False,
) -> None:
"""
Generates the SQLite database from the specified Hugging Face dataset.
Simplified version without extensive error handling.
Args:
repo_id: The Hugging Face repository ID for the dataset.
db_path: The path where the SQLite database file should be
created.
overwrite: If True, any existing database file at db_path will
be removed.
"""
logging.info(
"Starting database generation for repo '%s' at '%s'",
repo_id,
db_path,
)
logging.info("Overwrite existing database: %s", overwrite)
db_dir = os.path.dirname(db_path)
if db_dir and not os.path.exists(db_dir):
logging.info("Creating data directory: %s", db_dir)
os.makedirs(db_dir)
if overwrite and os.path.exists(db_path):
logging.warning("Removing existing database file: %s", db_path)
os.remove(db_path)
elif not overwrite and os.path.exists(db_path):
# If not overwriting and file exists, subsequent steps might fail
# or behave unexpectedly. We are removing the explicit error here
# as requested.
logging.warning(
"Database file %s exists and overwrite is False. "
"Assuming file is already generated.",
db_path,
)
return
# 1. Download dataset
dataset = download_dataset(repo_id)
# 2. Create database schema (Tables only)
# Note: This will fail if overwrite=False and the file exists with
# incompatible schema/data.
create_database(db_path)
# 3. Populate database
populate_database(db_path, dataset)
# 4. Create Indexes and Triggers
create_indexes_and_triggers(db_path)
logging.info("Database generation process completed for %s.", db_path)
logging.info(
"Please set the environment variable DEFAULT_EMAIL_DB_PATH "
"to this path.",
)
if __name__ == "__main__":
generate_database(overwrite=True)

Binary image file added (442 KiB).