Compare commits
10 Commits (9503bda45d ... 7fbfc96795)

| SHA1 |
|---|
| 7fbfc96795 |
| 679431b303 |
| df0776cac4 |
| 400c1e77bf |
| 311ddfff46 |
| b843abea93 |
| e7f1fdf7ff |
| 654c35127a |
| 3821fb04ac |
| 5855c5161b |
@@ -72,7 +72,7 @@ This is a repository that **brings together a variety of ready-to-run Python age
 │ ├── frozen_lake/ # Teach an agent to play a game requiring multiple steps
 │ ├── learn_to_ask/ # Using LLM-as-a-judge to facilitate agent tuning
 │ ├── email_search/ # Enhance the tool use ability of your agent
-│ ├── werewolf_game/ # Enhance a multi-agent application
+│ ├── werewolves/ # Enhance a multi-agent application
 │ └── data_augment/ # Data augmentation for tuning
 ├── sample_template/ # Template for new sample contributions
 └── README.md
@@ -72,7 +72,7 @@
 │ ├── frozen_lake/ # Teach an agent to play a game that requires multiple steps
 │ ├── learn_to_ask/ # Use LLM-as-a-judge to assist agent tuning
 │ ├── email_search/ # Improve the agent's tool-use ability
-│ ├── werewolf_game/ # Strengthen a multi-agent application
+│ ├── werewolves/ # Strengthen a multi-agent application
 │ └── data_augment/ # Augment data for tuning
 ├── sample_template/ # Template for new sample contributions
 └── README.md
@@ -207,10 +207,23 @@ alias_agent run --mode finance --task "Analyze Tesla's Q4 2024 financial perform
 # Data Science mode
 alias_agent run --mode ds \
   --task "Analyze the distribution of incidents across categories in 'incident_records.csv' to identify imbalances, inconsistencies, or anomalies, and determine their root cause." \
-  --files ./docs/data/incident_records.csv
+  --datasource ./docs/data/incident_records.csv
 ```

-**Note**: Files uploaded with `--files` are automatically copied to `/workspace` in the sandbox. Generated files are available in `sessions_mount_dir` subdirectories.
+#### Input/Output Management
+
+**Input:**
+
+- Use the `--datasource` parameter (with the alias `--files` kept for backward compatibility) to specify data sources in multiple formats:
+  - **Local files**: such as `./data.txt` or `/absolute/path/file.json`
+  - **Database DSN**: supports relational databases such as PostgreSQL and SQLite, with a format like `postgresql://user:password@host:port/database`
+
+  Example: `--datasource file.txt postgresql://user:password@localhost:5432/mydb`
+
+- Specified data sources are automatically profiled (analyzed), giving the model guidance for efficient data-source access.
+- Uploaded files are automatically copied to the `/workspace` directory in the sandbox.
+
+**Output:**
+
+- Generated files are stored in subdirectories of `sessions_mount_dir`, where all output results can be found.

 #### Enable Long-Term Memory Service (General Mode Only)
 To enable the long-term memory service in General mode, you need to:
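The `--datasource` DSN format shown earlier is standard URL syntax, so its parts can be inspected with the Python standard library; a quick sketch using the illustrative credentials from the example (not a real database):

```python
from urllib.parse import urlsplit

# Illustrative DSN from the example above (hypothetical credentials).
dsn = "postgresql://user:password@localhost:5432/mydb"
parts = urlsplit(dsn)

# urlsplit parses the netloc generically, so scheme, host, port,
# credentials, and database name are all recoverable.
print(parts.scheme, parts.hostname, parts.port, parts.path.lstrip("/"))
```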
@@ -371,6 +384,85 @@ After the first startup, you can log in with the superuser credentials configure
 - **Password**: As specified in `FIRST_SUPERUSER_PASSWORD`
+
+### 🌐 Basic Usage -- AgentScope Runtime Deployment
+
+Alias is now fully compatible with [AgentScope Runtime](https://github.com/agentscope-ai/agentscope-runtime/), enabling you to quickly deploy Alias as a standardized backend service. Once launched, you can invoke Alias capabilities via the accompanying AgentScope Runtime API.
+
+#### 1. Prerequisites
+
+* **Sandbox & API Keys**: Refer to the earlier sections [🐳 Sandbox Setup (Optional)](#-sandbox-setup-optional) and [🔑 API Keys Configuration](#-api-keys-configuration) to complete the basic environment setup.
+* **Environment Variables**: Copy the example environment file from the project root:
+
+  ```bash
+  cp .env.example .env
+  ```
+
+* **Start Redis**: Required for caching and session management:
+
+  ```bash
+  docker run -d -p 6379:6379 --name alias-redis redis:7-alpine
+  ```
+
+#### 2. Installation & Sandbox Launch
+
+Install the package in editable mode from the project root. This also installs the `alias_agent_runtime` CLI tool:
+
+```bash
+pip install -e .
+```
+
+To ensure proper code execution and file operations, start the sandbox server in a separate terminal:
+
+```bash
+runtime-sandbox-server --extension src/alias/runtime/alias_sandbox/alias_sandbox.py
+```
+
+#### 3. Launching the AgentScope Runtime Service
+
+You can start the service via the CLI or Python code, depending on your use case.
+
+##### Option A: Using the CLI (Recommended)
+
+Use the `alias_agent_runtime` command to launch the backend service in one step:
+
+```bash
+alias_agent_runtime --host 127.0.0.1 --port 8090 --chat-mode general
+```
+
+**Parameter Descriptions**:
+
+* `--host` / `--port`: Specify the service address and port (default port: 8090).
+* `--chat-mode`: Set the running mode. Options: `general`, `dr`, `browser`, `ds`, `finance` (default: `general`).
+* `--web-ui`: (Optional) Enable the AgentScope Runtime WebUI for a visual interaction interface. Skip this if you only need the API.
+
+> **Note**: When enabling `--web-ui` for the first time, the system automatically installs the necessary frontend dependencies. This may take a few minutes.
+
+##### Option B: Using Python Code (Recommended for Developers)
+
+If you wish to integrate or customize the launch logic in Python, you can use `AliasRunner` and `AgentApp` as shown below:
+
+```python
+from agentscope_runtime.engine.app import AgentApp
+from alias.server.runtime.runner.alias_runner import AliasRunner
+
+# 1. Initialize AliasRunner
+# default_chat_mode options: "general", "dr", "browser", "ds", "finance"
+runner = AliasRunner(
+    default_chat_mode="general",
+)
+
+# 2. Create the AgentApp instance
+agent_app = AgentApp(
+    runner=runner,
+    app_name="Alias",
+    app_description="An LLM-empowered agent built on AgentScope and AgentScope-Runtime",
+)
+
+# 3. Run the service
+# Set web_ui=True to enable the visual debugging interface
+agent_app.run(host="127.0.0.1", port=8090)
+```
+
+#### 4. Accessing the Application
+
+Once the service is running, you can access Alias via:
+
+* **Runtime API Access**: Send standard HTTP POST requests to `http://localhost:8090/process`. This is the primary method for integrating Alias into third-party frontends or backend workflows.
+* **Visual Monitoring (Optional)**: If started with the `--web-ui` flag, visit `http://localhost:5173`. This interface lets developers observe the agent's reasoning process, tool execution traces, and other debugging information.

 ## ⚖️ License
 Alias-Agent is released under the **Apache 2.0 License** – see the [LICENSE](https://github.com/agentscope-ai/agentscope-samples/blob/main/LICENSE) file for details.
@@ -208,10 +208,25 @@ alias_agent run --mode finance --task "Analyze Tesla's Q4 2024 financial perform
 # Data Science mode
 alias_agent run --mode ds \
   --task "Analyze the distribution of incidents across categories in 'incident_records.csv' to identify imbalances, inconsistencies, or anomalies, and determine their root cause." \
-  --files ./docs/data/incident_records.csv
+  --datasource ./docs/data/incident_records.csv
 ```

-**Note**: Files uploaded with `--files` are automatically copied to `/workspace` in the sandbox. Generated files can be found in subdirectories of `sessions_mount_dir`.
+#### Input/Output Management
+
+**Input:**
+
+- Use the `--datasource` parameter to specify data sources in multiple formats (`--files` is also supported for backward compatibility):
+  - **Local files**: such as `./data.txt` or `/absolute/path/file.json`
+  - **Database DSN**: supports relational databases such as PostgreSQL and SQLite, with a format like `postgresql://user:password@host:port/database`
+
+  Example: `--datasource file.txt postgresql://user:password@localhost:5432/mydb`
+
+- Specified data sources are automatically profiled (analyzed), giving the model guidance for efficient data-source access.
+- Uploaded files are automatically copied to the `/workspace` directory in the sandbox.
+
+**Output:**
+
+- Generated files are stored in subdirectories of `sessions_mount_dir`, where all output results can be found.

 #### Enable Long-Term Memory Service (General Mode Only)
 To enable the long-term memory service in General mode, you need to:
@@ -371,6 +386,83 @@ bash script/start_memory_service.sh
 - **Username**: As specified in `FIRST_SUPERUSER_USERNAME` (default: `alias`)
 - **Password**: As specified in `FIRST_SUPERUSER_PASSWORD`
+
+### 🌐 Basic Usage -- AgentScope Runtime Deployment
+
+Alias is now compatible with [AgentScope Runtime](https://github.com/agentscope-ai/agentscope-runtime/); you can use AgentScope Runtime to quickly deploy Alias as a standard backend service. Once launched, the services Alias provides can be invoked via the accompanying AgentScope Runtime API.
+
+#### 1. Prerequisites
+
+* **Sandbox Setup & API Keys**: Refer to the earlier sections [🐳 Sandbox Setup](#-沙盒设置可选) and [🔑 API Keys Configuration](#-api-密钥配置) to complete the basic environment setup.
+* **Configure Environment Variables**: Copy the example environment file from the project root:
+
+  ```bash
+  cp .env.example .env
+  ```
+
+* **Start Redis**: Required for caching and session management:
+
+  ```bash
+  docker run -d -p 6379:6379 --name alias-redis redis:7-alpine
+  ```
+
+#### 2. Installation & Sandbox Launch
+
+From the project root, install the package in editable mode; this also installs the `alias_agent_runtime` CLI tool:
+
+```bash
+pip install -e .
+```
+
+To make sure code execution and file operations work properly, start the sandbox server in a separate terminal:
+
+```bash
+runtime-sandbox-server --extension src/alias/runtime/alias_sandbox/alias_sandbox.py
+```
+
+#### 3. Launching the AgentScope Runtime Service
+
+You can start the service via the command line or Python code, depending on your use case.
+
+##### Option A: Using the CLI (Recommended)
+
+Use the `alias_agent_runtime` command to launch the backend service in one step:
+
+```bash
+alias_agent_runtime --host 127.0.0.1 --port 8090 --chat-mode general
+```
+
+**Parameter Descriptions**:
+
+* `--host` / `--port`: Specify the address and port the service runs on (default port: 8090).
+* `--chat-mode`: Set the running mode. Options: `general`, `dr`, `browser`, `ds`, `finance` (default: `general`).
+* `--web-ui`: (Optional) Enable the AgentScope Runtime WebUI for a visual interaction interface. Ignore this parameter if you only need the API.
+
+> **Note**: When starting with `--web-ui` for the first time, the system automatically installs the necessary frontend dependencies; this may take a few minutes.
+
+##### Option B: Using Python Code (Recommended for Developers)
+
+If you wish to integrate or customize the launch logic in Python code, you can follow the example below, combining `AliasRunner` and `AgentApp`:
+
+```python
+from agentscope_runtime.engine.app import AgentApp
+from alias.server.runtime.runner.alias_runner import AliasRunner
+
+# 1. Initialize AliasRunner
+# default_chat_mode options: "general", "dr", "browser", "ds", "finance"
+runner = AliasRunner(
+    default_chat_mode="general",
+)
+
+# 2. Create the AgentApp instance
+agent_app = AgentApp(
+    runner=runner,
+    app_name="Alias",
+    app_description="An LLM-empowered agent built on AgentScope and AgentScope-Runtime",
+)
+
+# 3. Run the service
+# Set web_ui=True to enable the visual debugging interface
+agent_app.run(host="127.0.0.1", port=8090)
+```
+
+#### 4. Accessing the Application
+
+Once the service is running, you can access Alias via:
+
+* **Runtime API Access**: Access `http://localhost:8090/process` via standard HTTP POST requests. This is the primary way to integrate Alias into third-party frontends or backend workflows.
+* **Visual Monitoring (Optional)**: If the service was started with the `--web-ui` flag, visit `http://localhost:5173` to access the WebUI. This interface mainly lets developers observe the agent's reasoning process, tool invocation traces, and other debugging information.

 ## ⚖️ License
@@ -45,7 +45,8 @@ dependencies = [
     "agentscope-runtime>=1.0.0",
     "aiosqlite>=0.21.0",
     "asyncpg>=0.30.0",
-    "itsdangerous>=2.2.0"
+    "itsdangerous>=2.2.0",
+    "polars>=1.37.1"
 ]

 [tool.setuptools]
@@ -65,3 +66,4 @@ dev = [

 [project.scripts]
 alias_agent = "alias.cli:main"
+alias_agent_runtime = "alias.server.alias_agent_app:main"
@@ -0,0 +1,74 @@
---
name: csv-excel-file
description: Guidelines for handling CSV/Excel files
type:
  - csv
  - excel
---

# CSV/Excel Handling Specifications

## Goals

- Safely load tabular data without crashing.
- Detect and handle messy spreadsheets (multiple blocks, missing headers, merged-cell artifacts).
- Produce reliable outputs (a clean dataframe for a clean table, or structured JSON for a messy spreadsheet) with validated types.

## Encoding, Delimiters, and Locale

- CSV encoding: Try UTF-8; if garbled, attempt common fallbacks (e.g., gbk, cp1252) based on context.
- Delimiters: Detect common separators (`,`, `\t`, `;`, `|`) during inspection.
- Locale formats: Be cautious with comma decimal separators and thousands separators.

## Inspection (always first)

- Identify file type, encoding (CSV), and sheet names (Excel) before full reads.
- Prefer small reads to preview structure:
  - CSV: `pd.read_csv(..., nrows=20)`; if the delimiter is uncertain: `sep=None, engine="python"` (small `nrows` only).
  - Excel: `pd.ExcelFile(path).sheet_names`, then `pd.read_excel(..., sheet_name=..., nrows=20)`.
- Use `df.head(n)` and `df.columns` to check for:
  - Missing/incorrect headers (e.g., columns are numeric 0..N-1)
  - "Unnamed: X" columns
  - Unexpected NaN/NaT, merged-cell artifacts
  - Multiple tables/blocks in one sheet (blank rows separating sections)

## Preprocessing

- Treat the file as messy if any of the following is present:
  - Columns contain "Unnamed:" or mostly empty column names
  - The header row appears inside the data (the first rows look like data and a later row looks like a header)
  - Multiple data blocks (large blank-row gaps, repeated header patterns)
  - Predominantly NaN/NaT in the top rows/left columns
  - Notes/metadata blocks above/beside the table (titles, footnotes, merged header areas)
- If a messy spreadsheet is detected:
  - First choice: use the `clean_messy_spreadsheet` tool to extract key tables/fields and output JSON.
  - Only fall back to manual parsing if the tool fails, returns an empty/incorrect structure, or cannot locate the target table.

## Querying

- Never load entire datasets blindly.
- Use minimal reads:
  - `nrows`, `usecols`, `dtype` (or a partial dtype mapping), `parse_dates` only when necessary.
  - Sampling: `skiprows` with a step pattern for rough profiling when the file is huge.
- For very large CSV:
  - Prefer `chunksize` iteration; aggregate/compute per chunk.
- For Excel:
  - Read only the needed `sheet_name`, and consider narrowing `usecols`/`nrows` during exploration.

## Data Quality & Type Validation

- After load/clean:
  - Validate types:
    - Numeric columns: coerce with `pd.to_numeric(errors="coerce")`
    - Datetime columns: `pd.to_datetime(errors="coerce")`
  - Report coercion fallout (how many values became NaN/NaT).
  - Standardize missing values: treat empty strings/"N/A"/"null" consistently.

## Best Practices

- Always inspect structure before processing.
- Handle encoding issues appropriately.
- Keep reads minimal; expand only after confirming layout.
- Log decisions: chosen sheet, detected header row, dropped columns/rows, dtype conversions.
- Avoid silent data loss: when dropping/cleaning, summarize what changed.
- Validate data types after loading.
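The preview-then-validate flow this guideline describes can be sketched with pandas; the semicolon-delimited sample below is hypothetical in-memory data standing in for a real upload:

```python
import io

import pandas as pd

# Hypothetical CSV with an unusual delimiter and some dirty values.
raw = "id;amount;when\n1;10.5;2024-01-02\n2;N/A;2024-01-03\n3;7.25;not-a-date\n"

# Inspection: small preview read with delimiter sniffing (python engine).
preview = pd.read_csv(io.StringIO(raw), sep=None, engine="python", nrows=20)

# Validation: coerce types and report the fallout instead of failing silently.
df = pd.read_csv(io.StringIO(raw), sep=";", na_values=["N/A", "null", ""])
df["amount"] = pd.to_numeric(df["amount"], errors="coerce")
df["when"] = pd.to_datetime(df["when"], errors="coerce")
bad_amounts = int(df["amount"].isna().sum())  # values lost to numeric coercion
bad_dates = int(df["when"].isna().sum())      # values lost to datetime coercion
```

The two counters give the "coercion fallout" report the guideline asks for before any analysis proceeds.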
@@ -0,0 +1,47 @@
---
name: image-file
description: Guidelines for handling image files
type: image
---

# Image Handling Specifications

## Goals

- Safely identify image properties and metadata without memory exhaustion.
- Accurately extract text (OCR) and visual elements (object detection/description).
- Perform necessary pre-processing (resize, normalize, crop) for downstream tasks.
- Handle multi-frame or high-resolution images efficiently.

## Inspection (Always First)

- Identify Properties: Use lightweight libraries (e.g., PIL/Pillow) to get `format`, `size` (width/height), and `mode` (RGB, RGBA, CMYK).
- Check File Size: If the image is exceptionally large (e.g., >20MB or >100MP), consider downsampling or tiling before full processing.
- Metadata/EXIF Extraction:
  - Read EXIF data for orientation, GPS tags, and timestamps.
  - Correction: Automatically apply EXIF orientation so the image is "upright" before visual analysis.

## Content Extraction & Vision

- Vision Analysis:
  - Use multimodal vision models to describe scenes, identify objects, and detect activities.
  - For complex images (e.g., infographics, UI screenshots), guide the model to focus on specific regions.
- OCR (Optical Character Recognition):
  - If text is detected, specify whether to extract "raw text" or "structured data" (like forms/tables).
  - Handle low-contrast or noisy backgrounds by applying pre-filters (grayscale, binarization).
- Format Conversion: Convert non-standard formats (e.g., HEIC, TIFF) to standard formats (JPEG/PNG) if tools require it.

## Handling Large or Complex Images

- Tiling: For ultra-high-res images (e.g., satellite maps, medical scans), split into overlapping tiles to avoid missing small details.
- Batching: Process multiple images using generators to keep memory usage stable.
- Alpha Channel: Be mindful of transparency (PNG/WebP); decide whether to discard it or composite against a solid background (e.g., white).

## Best Practices

- Safety First: Validate that the file is a genuine image (not a renamed malicious script).
- Graceful Failure: Handle corrupted files, truncated downloads, or unsupported formats with descriptive error logs.
- Efficiency: Avoid unnecessary re-encoding (e.g., multiple JPEG saves) to prevent "generation loss" or artifacts.
- Process images individually or in small batches to prevent system crashes.
- Consider memory usage when working with large or high-resolution images.
- Resource Management: Close file pointers or use context managers (`with Image.open(...) as img:`) to prevent memory leaks.
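The guideline leans on Pillow for property inspection; as a dependency-free illustration of the same "read properties without decoding pixels" idea, a PNG's dimensions live in its IHDR chunk and can be read with a few bytes of parsing (the image below is a hypothetical in-memory header built just for the demonstration):

```python
import struct
import zlib

PNG_MAGIC = b"\x89PNG\r\n\x1a\n"


def png_size(data: bytes):
    """Read width/height from the IHDR chunk without decoding any pixels."""
    if data[:8] != PNG_MAGIC:
        raise ValueError("not a PNG file")
    # Layout: 8-byte magic, 4-byte chunk length, b"IHDR", then
    # width and height as big-endian 32-bit integers.
    if data[12:16] != b"IHDR":
        raise ValueError("first chunk is not IHDR")
    return struct.unpack(">II", data[16:24])


# Build a minimal, hypothetical IHDR in memory for demonstration.
body = struct.pack(">II5B", 640, 480, 8, 2, 0, 0, 0)
chunk = struct.pack(">I", len(body)) + b"IHDR" + body + struct.pack(
    ">I", zlib.crc32(b"IHDR" + body)
)
data = PNG_MAGIC + chunk
```

The same check doubles as the "validate that the file is a genuine image" safety step: a renamed script fails the magic-bytes test immediately.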
@@ -0,0 +1,54 @@
---
name: json-file
description: Guidelines for handling JSON files
type: json
---

# JSON Handling Specifications

## Goals

- Safely parse JSON/JSONL without memory overflow.
- Discover schema structure (keys, nesting depth, data types).
- Flatten complex nested structures into tabular data when necessary.
- Handle inconsistent schemas and "dirty" JSON (e.g., trailing commas, mixed types).

## Inspection (Always First)

- Structure Discovery:
  - Determine if the root is a `list` or a `dict`.
  - Identify whether it's standard JSON or JSONL (one valid JSON object per line).
- Schema Sampling:
  - For large files, read the first few objects/lines to infer the schema.
  - Identify top-level keys and their types.
  - Detect nesting depth: if depth > 3, consider it a "deeply nested" structure.
- Size Check:
  - If the file is large (>50MB), avoid `json.load()`. Use iterative parsing or streaming.

## Processing & Extraction

- Lazy Loading (Streaming):
  - For massive JSON: use `ijson` (Python) or similar streaming parsers to yield specific paths/items.
  - For JSONL: read line by line using a generator to minimize the memory footprint.
- Flattening & Normalization:
  - Use `pandas.json_normalize` to convert nested structures into flat tables if the goal is analysis.
  - Specify `max_level` during normalization to prevent "column explosion."
- Data Filtering:
  - Extract only the required sub-trees (keys) early in the process to reduce the in-memory object size.

## Data Quality & Schema Validation

- Missing Keys: Use `.get(key, default)` or `try-except` blocks. Never assume a key exists in all objects.
- Type Coercion:
  - Validate numeric strings vs. actual numbers.
  - Standardize `null`, `""`, and `[]` consistently.
- Encoding: Default to UTF-8; check for a BOM (`utf-8-sig`) if parsing fails.
- Malformed JSON Recovery:
  - For minor syntax errors (e.g., single quotes instead of double), attempt `ast.literal_eval` or regex-based cleanup only as a fallback.

## Best Practices

- Minimal Reads: Don't load a 50MB JSON just to read one config key; use a streaming approach.
- Schema Logging: Document the detected structure (e.g., "Root is a list of 500 objects; key 'metadata' is nested").
- Error Transparency: When a JSON object in a JSONL stream is corrupted, log the line number, skip it, and continue instead of crashing the entire process.
- Avoid Over-Flattening: Be cautious with deeply nested arrays; flattening them can lead to massive row duplication.
- Strict Typing: After extraction, explicitly convert types (e.g., `pd.to_datetime`) to ensure downstream reliability.
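A minimal sketch of the "skip and log" JSONL handling described above, using only the standard library (the sample stream, with a deliberately corrupted record on line 2, is hypothetical):

```python
import io
import json


def read_jsonl(stream):
    """Stream a JSONL source line by line; skip corrupt lines but record where."""
    rows, bad_lines = [], []
    for lineno, line in enumerate(stream, start=1):
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        try:
            rows.append(json.loads(line))
        except json.JSONDecodeError:
            bad_lines.append(lineno)  # error transparency: note it, keep going
    return rows, bad_lines


# Hypothetical stream with one corrupted record on line 2.
sample = io.StringIO('{"a": 1}\n{"a": 2\n\n{"a": 3}\n')
rows, bad_lines = read_jsonl(sample)
```

Because the reader iterates rather than calling `json.load()` on the whole stream, memory stays bounded and one bad record cannot abort the run.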
@@ -0,0 +1,70 @@
---
name: database
description: Guidelines for handling databases
type: relational_db
---

# Database Handling Specifications

## Goals

- Safely explore database schemas without performance degradation.
- Construct precise, efficient SQL queries that prevent system crashes (out-of-memory and query timeouts).
- Handle dialect-specific nuances (PostgreSQL, MySQL, SQLite, etc.).
- Transform raw result sets into structured, validated data for analysis.

## Inspection

- Volume Estimation:
  - Before any `SELECT *`, always run `SELECT COUNT(*) FROM table_name` to understand the scale.
  - If a table has >1,000,000 rows, strictly use indexed columns for filtering.
- Sample Data:
  - Use `SELECT * FROM table_name LIMIT 5` to see actual data formats.

## Querying

- Safety Constraints:
  - Always use `LIMIT`: never execute a query without a `LIMIT` clause unless the row count is confirmed to be small.
  - Avoid `SELECT *`: in production-scale tables, explicitly name columns to reduce I/O and memory usage.
- Dialect & Syntax:
  - Case Sensitivity: if a column/table name contains uppercase or special characters, it MUST be quoted (e.g., `"UserTable"` in PostgreSQL, `` `UserTable` `` in MySQL).
  - Date/Time: use standard ISO strings for date filtering; be mindful of timezone-aware vs. naive columns.
- Complex Queries:
  - For `JOIN` operations, ensure the joining columns are indexed to prevent full table scans.
  - When performing `GROUP BY`, ensure the result set size is manageable.

## Data Retrieval & Transformation

- Type Mapping:
  - Ensure SQL types (e.g., `DECIMAL`, `BIGINT`, `TIMESTAMP`) are correctly mapped to Python/JSON types without precision loss.
  - Convert `NULL` values to a consistent "missing" representation (e.g., `None` or `NaN`).
- Chunked Fetching:
  - For medium-to-large exports, use `fetchmany(size)` or `OFFSET`/`LIMIT` pagination instead of fetching everything into memory at once.
- Aggregations:
  - Prefer performing calculations (SUM, AVG, COUNT) at the database level rather than pulling raw data to the client for processing.

## Error Handling & Recovery

- Timeout Management: If a query takes too long, retry with more restrictive filters or optimized joins.
- Syntax Errors: If a query fails, inspect the dialect-specific error message and re-verify the schema (it's often a misspelled column or missing quotes).

## Anti-Pattern Prevention (Avoiding "Bad" SQL)

- Index-Friendly Filters: Never wrap indexed columns in functions (e.g., `DATE()`, `UPPER()`) within the `WHERE` clause.
- Join Safety: Always verify join keys. Before joining, check whether the key has high cardinality to avoid massive intermediate result sets.
- Memory Safety:
  - Avoid `DISTINCT` and `UNION` (which performs de-duplication) on multi-million-row sets unless necessary; use `UNION ALL` if duplicates are acceptable.
  - Avoid `ORDER BY` on large non-indexed text fields.
- Wildcard Warning: Strictly avoid leading wildcards in `LIKE` patterns (e.g., `%term`) on large text columns.
- No Functions on Columns: `WHERE col = FUNC(val)` is good; `WHERE FUNC(col) = val` is bad.
- Explicit Columns: Only fetch what is necessary.
- Early Filtering: Push `WHERE` conditions as close to the base tables as possible.
- CTEs for Clarity: Use `WITH` for complex multi-step logic to improve maintainability and help the optimizer.

## Best Practices

- Always verify the database structure before querying.
- Use appropriate sampling techniques for large datasets.
- Optimize queries for efficiency based on schema inspection.
- Self-review the draft SQL against the "Anti-Pattern Prevention" list.
- Perform a silent mental `EXPLAIN` on your query. If it smells like a full table scan on a large table, refactor it before outputting.
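The inspect-then-limit-then-chunk discipline above can be sketched against an in-memory SQLite database (the table and its contents are hypothetical):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (id INTEGER PRIMARY KEY, kind TEXT)")
conn.executemany("INSERT INTO events (kind) VALUES (?)", [("a",)] * 1000)

# Inspection: estimate volume before any broad read.
(total,) = conn.execute("SELECT COUNT(*) FROM events").fetchone()

# Sample a few rows with an explicit column list and LIMIT.
sample_rows = conn.execute("SELECT id, kind FROM events LIMIT 5").fetchall()

# Chunked fetching: fetchmany keeps memory bounded on larger result sets.
cur = conn.execute("SELECT id FROM events")
seen = 0
while True:
    chunk = cur.fetchmany(256)
    if not chunk:
        break
    seen += len(chunk)
```

The same `COUNT(*)` / `LIMIT` / `fetchmany` pattern transfers directly to PostgreSQL or MySQL drivers, which expose the same DB-API cursor interface.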
@@ -0,0 +1,50 @@
|
|||||||
|
---
|
||||||
|
name: text-file
|
||||||
|
description: Guidelines for handling text files
|
||||||
|
type: text
|
||||||
|
---
|
||||||
|
|
||||||
|
# Text Files Handling Specifications
|
||||||
|
|
||||||
|
## Goals
|
||||||
|
- Safely read text files without memory exhaustion.
|
||||||
|
- Accurately detect encoding to avoid garbled characters.
|
||||||
|
- Identify underlying patterns (e.g., Log formats, Markdown structure, delimiters).
|
||||||
|
- Efficiently extract or search for specific information within large volumes of text.
|
||||||
|
|
||||||
|
## Encoding & Detection
|
||||||
|
|
||||||
|
- Encoding Strategy:
|
||||||
|
- Default to `utf-8`.
|
||||||
|
- If it fails, try `utf-8-sig` (for files with BOM), `gbk/gb18030` (for Chinese context), or `latin-1`.
|
||||||
|
- Use `chardet` or similar logic if encoding is unknown and first few bytes look non-standard.
|
||||||
|
- Line Endings: Be aware of `\n` (Unix), `\r\n` (Windows), and `\r` (Legacy Mac) when counting lines or splitting.
|
||||||
|
|
||||||
|
## Inspection
|
||||||
|
|
||||||
|
- Preview: Read the first 10-20 lines to determine:
|
||||||
|
- Content Type: Is it a log, code, prose, or a semi-structured list?
|
||||||
|
- Uniformity: Does every line follow the same format?
|
||||||
|
- Metadata: Check total file size before reading. If >50MB, treat as a "large file" and avoid full loading.
|
||||||
|
|
||||||
|
## Querying & Reading (Large Files)
|
||||||
|
|
||||||
|
- Streaming: For files exceeding memory or >50MB:
|
||||||
|
- Use `with open(path) as f: for line in f:` to process line-by-line.
|
||||||
|
- Never use `.read()` or `.readlines()` on large files.
|
||||||
|
- Random Sampling: To understand a huge file's structure, read the first N lines, the middle N lines (using `f.seek()`), and the last N lines.
|
||||||
|
- Pattern Matching: Use Regular Expressions (Regex) for targeted extraction instead of complex string slicing.
|
||||||
|
- Grep-like Search: If searching for a keyword, iterate through lines and only store/return matching lines + context.
|
||||||
|
|
||||||
|
## Data Quality
|
||||||
|
|
||||||
|
- Truncation Warning: If only a portion of the file is read, clearly state: "Displaying first X lines of Y total lines."
|
||||||
|
- Empty Lines/Comments: Decide early whether to ignore blank lines or lines starting with specific comment characters (e.g., `#`, `//`).
|
||||||
|
|
||||||
|
## Best Practices

- Resource Safety: Always use context managers (`with` statement) to ensure file handles are closed.
- Memory Consciousness: For logs and large TXT files, prioritize "find and extract" over "load and filter."
- Regex Optimization: Compile regex patterns if they are used repeatedly in a loop over millions of lines.
- Validation: After reading, verify the content isn't binary (e.g., a PDF or EXE renamed to .txt) by checking for null bytes or a high density of non-ASCII characters.
- Progress Logging: For long-running text processing, log progress every 100k lines or 10% of file size.

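The validation bullet can be sketched as a binary sniff; the 8KB sample size and 30% non-text threshold are illustrative assumptions, not standardized values:

```python
def looks_binary(path, sample_size=8192, nontext_ratio=0.30):
    """Heuristic check that a ".txt" file is actually text: null bytes or a
    high density of non-text bytes suggest a renamed binary (PDF, EXE, ...)."""
    with open(path, "rb") as f:
        chunk = f.read(sample_size)
    if not chunk:
        return False
    if b"\x00" in chunk:  # null bytes essentially never occur in plain text
        return True
    # count bytes outside printable ASCII plus common whitespace
    text_bytes = bytes(range(0x20, 0x7F)) + b"\t\n\r"
    nontext = sum(b not in text_bytes for b in chunk)
    return nontext / len(chunk) > nontext_ratio
```

Running this once after opening an unfamiliar file is cheap insurance against feeding binary garbage into line-oriented processing.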
@@ -23,6 +23,7 @@ from alias.agent.tools import AliasToolkit, share_tools
 from alias.agent.agents.common_agent_utils import (
     get_user_input_to_mem_pre_reply_hook,
 )
+from alias.agent.agents.data_source.data_source import DataSourceManager
 from .ds_agent_utils import (
     ReportGenerator,
     LLMPromptSelector,
@@ -50,7 +51,8 @@ class DataScienceAgent(AliasAgentBase):
         formatter: FormatterBase,
         memory: MemoryBase,
         toolkit: AliasToolkit,
-        sys_prompt: str = None,
+        data_manager: DataSourceManager = None,
+        sys_prompt: str = "",
         max_iters: int = 30,
         tmp_file_storage_dir: str = "/workspace",
         state_saving_dir: Optional[str] = None,
@@ -71,17 +73,16 @@ class DataScienceAgent(AliasAgentBase):
 
         set_run_ipython_cell(self.toolkit.sandbox)
 
-        self.uploaded_files: List[str] = []
-
         self.todo_list: List[Dict[str, Any]] = []
 
-        self.infer_trajectories: List[List[Msg]] = []
+        self.tmp_file_storage_dir = tmp_file_storage_dir
 
+        self.data_manager = data_manager
 
         self.detailed_report_path = os.path.join(
             tmp_file_storage_dir,
             "detailed_report.html",
         )
-        self.tmp_file_storage_dir = tmp_file_storage_dir
 
         self.todo_list_prompt = get_prompt_from_file(
             os.path.join(
@@ -91,12 +92,19 @@ class DataScienceAgent(AliasAgentBase):
             False,
         )
 
-        self._sys_prompt = get_prompt_from_file(
-            os.path.join(
-                PROMPT_DS_BASE_PATH,
-                "_agent_system_workflow_prompt.md",
-            ),
-            False,
+        self._sys_prompt = (
+            cast(
+                str,
+                get_prompt_from_file(
+                    os.path.join(
+                        PROMPT_DS_BASE_PATH,
+                        "_agent_system_workflow_prompt.md",
+                    ),
+                    False,
+                ),
+            )
+            + "\n\n"
+            + sys_prompt
         )
 
         # load prompts and initialize selector
@@ -167,7 +175,7 @@ class DataScienceAgent(AliasAgentBase):
 
         logger.info(
             f"[{self.name}] "
-            "DeepInsightAgent initialized (fully model-driven).",
+            "DataScienceAgent initialized (fully model-driven).",
         )
 
     @property
@@ -427,27 +435,58 @@ class DataScienceAgent(AliasAgentBase):
             memory_log=memory_log,
         )
 
-        response, report = await report_generator.generate_report()
+        (
+            response,
+            report_md,
+            report_html,
+        ) = await report_generator.generate_report()
+
+        if report_md:
+            md_report_path = os.path.join(
+                self.tmp_file_storage_dir,
+                "detailed_report.md",
+            )
 
-        if report:
-            # report = report.replace(self.tmp_file_storage_dir, ".")
             await self.toolkit.call_tool_function(
                 ToolUseBlock(
                     type="tool_use",
                     id=str(uuid.uuid4()),
                     name="write_file",
                     input={
-                        "path": self.detailed_report_path,
-                        "content": report,
+                        "path": md_report_path,
+                        "content": report_md,
                     },
                 ),
             )
             response = (
                 f"{response}\n\n"
-                "The detailed report has been saved to "
-                f"{self.detailed_report_path}."
+                "The detailed report (markdown version) has been saved to "
+                f"{md_report_path}."
             )
 
+        if report_html:
+            html_report_path = os.path.join(
+                self.tmp_file_storage_dir,
+                "detailed_report.html",
+            )
+
+            await self.toolkit.call_tool_function(
+                ToolUseBlock(
+                    type="tool_use",
+                    id=str(uuid.uuid4()),
+                    name="write_file",
+                    input={
+                        "path": html_report_path,
+                        "content": report_html,
+                    },
+                ),
+            )
+            response = (
+                f"{response}\n\n"
+                "The detailed report (html version) has been saved to "
+                f"{html_report_path}."
+            )
+
         kwargs["response"] = response
         structured_output = {}
@@ -698,7 +698,7 @@ class MetaPlanner(AliasAgentBase):
     ):
         """
         Directly enter the data science mode.
-        Use this when the user provides some data files and ask for processing.
+        Use this for COMPLEX, CODE-BASED data analysis.
 
         Args:
             user_query (`str`):
@@ -0,0 +1,722 @@
# -*- coding: utf-8 -*-
# pylint: disable=R1702,R0912,R0915

import os
import json
from abc import ABC, abstractmethod
from typing import Any, Dict

from loguru import logger
import pandas as pd
from sqlalchemy import inspect, text, create_engine
from agentscope.message import Msg

from alias.agent.agents.data_source._typing import SourceType
from alias.agent.agents.ds_agent_utils import (
    get_prompt_from_file,
)
from alias.agent.utils.llm_call_manager import (
    LLMCallManager,
)

class BaseDataProfiler(ABC):
    """Abstract base class for data profilers that analyze different data
    sources like csv, excel, db, etc.
    """

    _PROFILE_PROMPT_BASE_PATH = os.path.join(
        os.path.dirname(__file__),
        "built_in_prompt",
    )

    def __init__(
        self,
        path: str,
        source_type: SourceType,
        llm_call_manager: LLMCallManager,
    ):
        """Initialize the data profiler with data path, type and LLM manager.

        Args:
            path: Path to the data source file or connection string
            source_type: Enum indicating the type of data source
            llm_call_manager: Manager for handling LLM calls
        """
        self.path = path
        self.file_name = os.path.basename(path)
        self.source_type = source_type
        self.llm_call_manager = llm_call_manager

        self.source_types_2_prompts = {
            SourceType.CSV: "_profile_csv_prompt.md",
            SourceType.EXCEL: "_profile_xlsx_prompt.md",
            SourceType.IMAGE: "_profile_image_prompt.md",
            SourceType.RELATIONAL_DB: "_profile_relationdb_prompt.md",
            "IRREGULAR": "_profile_irregular_xlsx_prompt.md",
        }
        if source_type not in self.source_types_2_prompts:
            raise ValueError(f"Unsupported source type: {source_type}")
        self.prompt = self._load_prompt(source_type)

        base_model_name = self.llm_call_manager.get_base_model_name()
        vl_model_name = self.llm_call_manager.get_vl_model_name()

        self.source_types_2_models = {
            SourceType.CSV: base_model_name,
            SourceType.EXCEL: base_model_name,
            SourceType.IMAGE: vl_model_name,
            SourceType.RELATIONAL_DB: base_model_name,
        }
        self.model_name = self.source_types_2_models[source_type]

    def _load_prompt(self, source_type: Any = None):
        """Load the appropriate prompt template based on the source type.

        Args:
            source_type: Type of data source (CSV, EXCEL, IMAGE, etc.)

        Returns:
            Loaded prompt template as string
        """
        prompt_file_name = self.source_types_2_prompts[source_type]
        prompt = get_prompt_from_file(
            os.path.join(
                self._PROFILE_PROMPT_BASE_PATH,
                prompt_file_name,
            ),
            False,
        )
        return prompt

    async def generate_profile(self) -> Dict[str, Any]:
        """Generate a complete data profile by reading data, generating
        content, calling the LLM, and wrapping the response.

        Returns:
            Dictionary containing the complete data profile
        """
        try:
            self.data = await self._read_data()
            # different source types have different data building methods
            content = self._build_content_with_prompt_and_data(
                self.prompt,
                self.data,
            )
            res = await self._generate_profile_by_llm(content)
            self.profile = self._wrap_data_response(res)
        except Exception as e:
            logger.warning(f"Error generating profile: {e}")
            self.profile = {}
        return self.profile

    @staticmethod
    def tool_clean_json(raw_response: str):
        """Clean and parse a JSON response from the LLM by removing markdown
        code-fence markers.

        Args:
            raw_response: Raw string response from LLM

        Returns:
            Parsed JSON object from the cleaned response
        """
        cleaned_response = raw_response.strip()
        if cleaned_response.startswith("```json"):
            cleaned_response = cleaned_response[len("```json") :].lstrip()
        if cleaned_response.startswith("```"):
            cleaned_response = cleaned_response[len("```") :].lstrip()
        if cleaned_response.endswith("```"):
            cleaned_response = cleaned_response[:-3].rstrip()
        return json.loads(cleaned_response)
    @abstractmethod
    def _build_content_with_prompt_and_data(
        self,
        prompt: str,
        data: Any,
    ) -> str:
        """Abstract method to build content for the LLM based on prompt
        and data.

        This method should be implemented by subclasses to format
        content appropriately for different data types.

        Args:
            prompt: Prompt template to use
            data: Processed data to include in the prompt

        Returns:
            Formatted content for LLM call
        """

    @abstractmethod
    async def _read_data(self):
        """Abstract method to read and process data from the source path.

        This method should be implemented by subclasses to handle specific
        data source types (CSV, Excel, DB, etc.).

        Returns:
            Processed data in appropriate format for the data type
        """

    async def _generate_profile_by_llm(
        self,
        content: Any,
    ) -> Dict[str, Any]:
        """Generate profile by calling LLM with prepared content.

        Args:
            content: Content to send to the LLM (text or multimodal)

        Returns:
            Dictionary response parsed from LLM output
        """
        sys_prompt = "You are a helpful AI assistant for database management."
        msgs = [
            Msg("system", sys_prompt, "system"),
            Msg("user", content, "user"),
        ]
        response = await self.llm_call_manager(
            model_name=self.model_name,
            messages=msgs,
        )
        response = BaseDataProfiler.tool_clean_json(response)
        return response

    @abstractmethod
    def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Abstract method to combine LLM response with original schema.

        This method should be implemented by subclasses to properly merge
        LLM-generated descriptions with original data structure.

        Args:
            response: Dictionary response from LLM

        Returns:
            Combined dictionary with original schema and LLM response
        """
class StructuredDataProfiler(BaseDataProfiler):
    """Base class for profilers that work with structured data sources
    like CSV, Excel, and relational databases.
    """

    @staticmethod
    def is_irregular(cols: pd.Index) -> bool:
        """Determine if a table has irregular column names (many unnamed).

        Args:
            cols: Index of column names from the dataset

        Returns:
            Boolean indicating whether the dataset is irregular
        """
        unnamed_columns_ratio = sum(
            col.startswith("Unnamed") for col in cols.astype(str)
        ) / len(cols)
        return unnamed_columns_ratio >= 0.5

    @staticmethod
    def _extract_schema_from_table(df: pd.DataFrame, df_name: str) -> dict:
        """Analyzes a single DataFrame to extract metadata and samples.

        Extracts column names, data types, and sample values to provide a
        comprehensive view of the table structure for the LLM.

        Args:
            df: The dataframe to analyze
            df_name: Name of the table (or sheet/filename)

        Returns:
            Dictionary containing schema metadata for the table
        """
        col_list = []
        for col in df.columns:
            dtype_name = str(df[col].dtype).upper()
            # Get random samples to help the LLM understand the data content:
            # sample(frac=1) shuffles the values, head(5) takes the first
            # five and works without error even when fewer are available.
            candidates = (
                df[col]
                .drop_duplicates()
                .sample(frac=1, random_state=42)
                .head(5)
                .astype(str)
                .tolist()
            )
            # Limit the total sample size to at most 1000 characters.
            # TODO: dynamic size control? 1000 is too small?
            samples = []
            length = 0
            for s in candidates:
                if (length := length + len(s)) <= 1000:
                    samples.append(s)
            col_list.append(
                {
                    "column name": col,
                    "data type": dtype_name,
                    "data samples": samples,
                },
            )
        # Create a CSV snippet of the first few rows
        raw_data_snippet = df.head(5).to_csv(index=True)

        table_schema = {
            "name": df_name,
            "raw_data_snippet": raw_data_snippet,
            # Note: Row count logic might need optimization for large files
            # TODO: how to get the row count more efficiently, openpyxl.
            "row_count": len(df) if len(df) < 100 else None,
            "col_count": len(df.columns),
            "columns": col_list,
        }
        return table_schema

    def _build_content_with_prompt_and_data(
        self,
        prompt: str,
        data: Any,
    ) -> str:
        """Format the prompt with data for structured data sources.

        Args:
            prompt: Template prompt string
            data: Processed data structure

        Returns:
            Formatted content string ready for LLM
        """
        return prompt.format(data=data)

    def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Merges the original schema with the LLM-generated response.

        Combines the structural information from the original data with
        semantic descriptions generated by the LLM.

        Args:
            response: Dictionary response from LLM with descriptions

        Returns:
            Combined schema with both structural and semantic info
        """
        new_schema = {}
        new_schema["name"] = self.data["name"]
        new_schema["description"] = response["description"]
        # For flat files like CSV, the schema contains columns directly.
        if "columns" in self.data:
            new_schema["columns"] = self.data["columns"]
        # For multi-table sources like Excel/Database, the schema contains
        # tables; each table contains columns and a description.
        if "tables" in self.data and "tables" in response:
            new_schema["tables"] = []
            # Build a map from response table names to descriptions
            res_des_map = {
                table["name"]: table["description"]
                for table in response["tables"]
            }
            for table in self.data["tables"]:
                table_name = table["name"]
                if table_name not in res_des_map:
                    continue
                new_table = {}
                new_table["name"] = table_name
                # Retain the description from the LLM response
                new_table["description"] = res_des_map[table_name]
                if "columns" in table:
                    new_table["columns"] = table["columns"]
                if "irregular_judgment" in table:
                    new_table["irregular_judgment"] = table[
                        "irregular_judgment"
                    ]
                new_schema["tables"].append(new_table)
        return new_schema
class ExcelProfiler(StructuredDataProfiler):
    async def _extract_irregular_table(
        self,
        path: str,
        raw_data_snippet: str,
        sheet_name: str,
    ):
        """Extract structure from irregular Excel sheets with unnamed
        columns. Uses a special LLM call to identify the actual data in
        sheets with headers or other content above the main data table.

        Args:
            path: Path to the Excel file
            raw_data_snippet: Raw text snippet of the sheet content
            sheet_name: Name of the sheet being processed

        Returns:
            Schema dictionary for the irregular table structure
        """
        prompt = self._load_prompt("IRREGULAR")
        content = prompt.format(raw_snippet_data=raw_data_snippet)
        res = await self._generate_profile_by_llm(content=content)

        if "is_extractable_table" in res and res["is_extractable_table"]:
            logger.debug(res["reasoning"])
            skiprows = res["row_start_index"] + 1
            cols_range = res["col_ranges"]
            df = pd.read_excel(
                path,
                sheet_name=sheet_name,
                nrows=100,
                skiprows=skiprows,
                usecols=range(cols_range[0], cols_range[1] + 1),
            ).convert_dtypes()
            if StructuredDataProfiler.is_irregular(df.columns):
                schema = {
                    "name": sheet_name,
                    "raw_data_snippet": raw_data_snippet,
                    "irregular_judgment": "UNSTRUCTURED",
                }
            else:
                schema = self._extract_schema_from_table(df, sheet_name)
                schema["irregular_judgment"] = res
        else:
            schema = {
                "name": sheet_name,
                "raw_data_snippet": raw_data_snippet,
                "irregular_judgment": "UNSTRUCTURED",
            }

        return schema

    async def _read_data(self):
        """Read and process Excel file data including all sheets.

        Handles both regular and irregular Excel files by using pandas
        for regular files and openpyxl for files with unnamed columns.

        Returns:
            Dictionary containing metadata for all sheets in the Excel file
        """
        excel_file = pd.ExcelFile(self.path)
        table_schemas = []
        schema = {}
        schema["name"] = self.file_name
        for sheet_name in excel_file.sheet_names:
            # TODO: use openpyxl to read excel to avoid irregular excel.
            # Read a subset of each sheet
            df = pd.read_excel(
                self.path,
                sheet_name=sheet_name,
                nrows=100,
            ).convert_dtypes()
            if not StructuredDataProfiler.is_irregular(df.columns):
                table_schema = (
                    StructuredDataProfiler._extract_schema_from_table(
                        df,
                        sheet_name,
                    )
                )
            else:
                # If columns are unnamed, use openpyxl to extract the
                # top 100 rows as raw text.
                import openpyxl

                wb = openpyxl.load_workbook(
                    self.path,
                    read_only=True,
                    data_only=True,
                )
                ws = wb[sheet_name]
                rows_data = []
                for i, row in enumerate(
                    ws.iter_rows(values_only=True),
                    start=1,
                ):
                    if i > 100:
                        break
                    rows_data.append(
                        ",".join(
                            "" if cell is None else str(cell) for cell in row
                        ),
                    )
                wb.close()
                raw_data_snippet = "\n".join(rows_data)

                table_schema = await self._extract_irregular_table(
                    self.path,
                    raw_data_snippet,
                    sheet_name,
                )
            table_schemas.append(table_schema)
        schema["tables"] = table_schemas
        return schema
class RelationalDatabaseProfiler(StructuredDataProfiler):
    async def _read_data(self):
        """Extracts metadata (schema) for all tables in a relational db.

        path (str): The Data Source Name (connection string),
            e.g. postgresql://user:password@ip:port/db_name

        Returns:
            Dictionary containing database metadata for all tables
        """
        options = {
            "isolation_level": "AUTOCOMMIT",
            # Test conns before use (handles MySQL 8hr timeout, network drops)
            "pool_pre_ping": True,
            # Keep minimal conns (MCP typically handles 1 request at a time)
            "pool_size": 1,
            # Allow temporary burst capacity for edge cases
            "max_overflow": 2,
            # Force refresh conns older than 1hr (under MySQL's 8hr default)
            "pool_recycle": 3600,
        }
        engine = create_engine(self.path, **options)
        try:
            connection = engine.connect()
        except Exception as e:
            logger.error(f"Connection to {self.path} failed: {e}")
            raise ConnectionError(f"Failed to connect to database: {e}") from e

        # Use the DSN as the db identifier (could be parsed more cleanly)
        database_name = self.path
        inspector = inspect(connection)
        table_names = inspector.get_table_names()

        tables_data = []
        for table_name in table_names:
            try:
                # 1. Get column information
                columns = inspector.get_columns(table_name)
                col_count = len(columns)

                # 2. Get row count
                row_count_result = connection.execute(
                    text(f"SELECT COUNT(*) FROM {table_name}"),
                ).fetchone()
                row_count = row_count_result[0] if row_count_result else 0

                # 3. Get raw data snippet (first 5 rows)
                raw_data_snippet = ""
                rows = []  # initialized so step 4 is safe if the query fails
                try:
                    result = connection.execute(
                        text(f"SELECT * FROM {table_name} LIMIT 5"),
                    )
                    rows = result.fetchall()
                    if rows:
                        column_names = [col["name"] for col in columns]
                        lines = []
                        # Add header
                        lines.append(", ".join(column_names))
                        # Add data rows
                        for row in rows:
                            row_values = []
                            for value in row:
                                if value is None:
                                    row_values.append("NULL")
                                else:
                                    # Escape commas and newlines
                                    val_str = str(value)
                                    if "," in val_str or "\n" in val_str:
                                        val_str = f'"{val_str}"'
                                    row_values.append(val_str)
                            lines.append(", ".join(row_values))
                        raw_data_snippet = "\n".join(lines)
                except Exception as e:
                    logger.warning(
                        f"Error fetching {table_name} data: {str(e)}",
                    )
                    raw_data_snippet = None

                # 4. Detailed column info (types and samples)
                column_details = []
                if rows:
                    for i, col in enumerate(columns):
                        col_name = col["name"]
                        col_type = str(col["type"])
                        # Extract samples for this column from the fetched rows
                        sample_values = []
                        for row in rows:
                            if i < len(row):
                                val = row[i]
                                sample_values.append(
                                    str(val) if val is not None else "NULL",
                                )

                        column_details.append(
                            {
                                "column name": col_name,
                                "data type": col_type,
                                "data sample": sample_values[:3],
                            },
                        )

                table_info = {
                    "name": table_name,
                    "row_count": row_count,
                    "col_count": col_count,
                    "raw_data_snippet": raw_data_snippet,
                    "columns": column_details,
                }

                tables_data.append(table_info)

            except Exception as e:
                # If one table fails, log it and continue to the next
                logger.warning(f"Error processing {table_name}: {str(e)}")
                continue
        # Construct the final schema
        schema = {
            "name": database_name,
            "tables": tables_data,
        }
        self.data = schema
        return schema
class CsvProfiler(ExcelProfiler):
    async def _read_data(self):
        """Handles schema extraction for CSV as a single-table source.

        Uses Polars for efficient row counting on large files and
        pandas for detailed schema analysis of the first 100 rows.

        Returns:
            Schema dictionary for the CSV file
        """
        import polars as pl

        # Use Polars for efficient row counting on large files
        df = pl.scan_csv(self.path, ignore_errors=True)
        row_count = df.select(pl.len()).collect().item()
        # Read a subset with pandas for detailed schema analysis
        df = pd.read_csv(self.path, nrows=100).convert_dtypes()
        schema = self._extract_schema_from_table(df, self.file_name)
        schema["row_count"] = row_count
        return schema
class ImageProfiler(BaseDataProfiler):
    """Profiler for image data sources that uses multimodal LLMs."""

    async def _read_data(self):
        """For images, this simply returns the path since the LLM API
        handles loading the image directly.

        Returns:
            Path to the image file
        """
        return self.path

    def _build_content_with_prompt_and_data(self, prompt, data):
        """Build multimodal content for image analysis.

        Creates content in the format required by multimodal LLM APIs
        with both image and text components.

        Args:
            prompt: Text prompt template for image analysis
            data: Path to the image file

        Returns:
            List containing image and text components for the LLM call
        """
        # Convert image paths according to the model requirements
        contents = [
            {
                "text": prompt,
                "type": "text",
            },
            {
                "source": {
                    "url": data,
                    "type": "url",
                },
                "type": "image",
            },
        ]
        return contents

    def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Format the LLM response for image data into a dict.

        Args:
            response: Dictionary response from multimodal LLM

        Returns:
            Profile dictionary with image name, description and details
        """
        profile = {
            "name": self.file_name,
            "description": response["description"],
            "details": response["details"],
        }
        return profile
class DataProfilerFactory:
    """Factory class to create appropriate data profiler instances based
    on source type.
    """

    @staticmethod
    def get_profiler(
        llm_call_manager: LLMCallManager,
        path: str,
        source_type: SourceType,
    ) -> BaseDataProfiler:
        """Factory method to get the appropriate profiler instance that
        generates the correct profile result for the source.

        Args:
            path: Path to the data source or connection string
            source_type: Enum indicating the type of data source
            llm_call_manager: Manager for handling LLM calls

        Returns:
            Instance of the appropriate profiler subclass

        Raises:
            ValueError: If the source_type is unsupported
        """
        if source_type == SourceType.IMAGE:
            return ImageProfiler(
                path=path,
                source_type=source_type,
                llm_call_manager=llm_call_manager,
            )
        elif source_type == SourceType.CSV:
            return CsvProfiler(
                path=path,
                source_type=source_type,
                llm_call_manager=llm_call_manager,
            )
        elif source_type == SourceType.EXCEL:
            return ExcelProfiler(
                path=path,
                source_type=source_type,
                llm_call_manager=llm_call_manager,
            )
        elif source_type == SourceType.RELATIONAL_DB:
            return RelationalDatabaseProfiler(
                path=path,
                source_type=source_type,
                llm_call_manager=llm_call_manager,
            )
        else:
            raise ValueError(f"Unsupported source type: {source_type}")
@@ -0,0 +1,21 @@
{
    "relational_db": {
        "mcp_server": {
            "mcp_alchemy": {
                "command": "uvx",
                "args": [
                    "--from",
                    "mcp-alchemy==2025.8.15.91819",
                    "--with",
                    "psycopg2-binary",
                    "--refresh-package",
                    "mcp-alchemy",
                    "mcp-alchemy"
                ],
                "env": {
                    "DB_URL": "${endpoint}"
                }
            }
        }
    }
}
alias/src/alias/agent/agents/data_source/_typing.py (new file, 54 lines)
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-

from enum import Enum


class SourceAccessType(str, Enum):
    """Simple source access type classification"""

    DIRECT = "direct"
    VIA_MCP = "via_mcp"

    def __str__(self):
        return self.value


class SourceType(str, Enum):
    """Simple source type classification"""

    CSV = "csv"
    JSON = "json"
    EXCEL = "excel"
    TEXT = "text"
    IMAGE = "image"

    # Database sources
    RELATIONAL_DB = "relational_db"

    OTHER = "other"

    def __str__(self):
        return self.value

    @staticmethod
    def is_valid_source_type(value: str) -> bool:
        try:
            SourceType(value)
            return True
        except ValueError:
            return False


# Define mapping between SourceType and SourceAccessType
SOURCE_TYPE_TO_ACCESS_TYPE = {
    # File types -> DIRECT
    SourceType.CSV: SourceAccessType.DIRECT,
    SourceType.JSON: SourceAccessType.DIRECT,
    SourceType.EXCEL: SourceAccessType.DIRECT,
    SourceType.TEXT: SourceAccessType.DIRECT,
    SourceType.IMAGE: SourceAccessType.DIRECT,
    # Database types -> VIA_MCP
    SourceType.RELATIONAL_DB: SourceAccessType.VIA_MCP,
    # Unknown type -> depends on endpoint
    SourceType.OTHER: None,
}
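Because both enums subclass `str`, raw strings coming from configuration convert straight into enum members, and the mapping resolves the access mode in one lookup. A condensed sketch of how the mapping might be consumed (the `access_type_for` helper is hypothetical, not part of the module):

```python
from enum import Enum
from typing import Optional


class SourceAccessType(str, Enum):
    DIRECT = "direct"
    VIA_MCP = "via_mcp"


class SourceType(str, Enum):
    CSV = "csv"
    RELATIONAL_DB = "relational_db"
    OTHER = "other"


SOURCE_TYPE_TO_ACCESS_TYPE = {
    SourceType.CSV: SourceAccessType.DIRECT,
    SourceType.RELATIONAL_DB: SourceAccessType.VIA_MCP,
    SourceType.OTHER: None,  # access method depends on the endpoint
}


def access_type_for(raw: str) -> Optional[SourceAccessType]:
    # str-backed enums let raw config strings convert directly: SourceType("csv")
    return SOURCE_TYPE_TO_ACCESS_TYPE[SourceType(raw)]
```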
@@ -0,0 +1,83 @@
# Role
You are an expert Data Steward. Your task is to generate a single, comprehensive description sentence for a CSV file based on its metadata and raw content.

# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
    "name": "filename.csv",
    "raw_data_snippet": "col1, col2\na, b",
    "row_count": 100,
    "col_count": 5,
    "columns": [
        {{ "column name": "col1", "data type": "string", "data sample": ["a", "b"] }}
    ]
}}
```

# Analysis Logic

## 1. Context & Metrics Extraction

* **Subject:** Extract the core concept from the `name` field (e.g., `logistics_data.csv` -> "logistics_data").
* **Metrics:** Identify `row_count` and `col_count`.
* **Context:** Look for time (e.g., "2024") or location keywords in the `raw_data_snippet` or `name`.

## 2. Schema Identification

* **Primary:** Use column names from the `columns` list.
* **Secondary (Inference):** If the `columns` list is empty or generic (e.g., "col1"), you MUST infer meaningful column names from the `raw_data_snippet` values (e.g., "2023-01-01" -> `date`).
* **Selection:** Choose 3-5 key columns to represent the dataset structure.

## 3. Description Construction

* Generate a **single** grammatical sentence.
* **Strict Template:** "The file [FileName] contains [Subject] data [Optional: Context] with [RowCount] rows and [ColCount] columns, featuring fields such as [List of 3-5 key columns]."

---

# Output Format (Strict JSON)

You must output a single valid JSON object containing only the `description` key.

```json
{{
    "description": "The file [FileName] contains [Subject] data with [Rows] rows and [Cols] columns, featuring fields such as [Columns]."
}}
```

# One-Shot Demonstration

**[Example Input]**
`input_json` =

```json
{{
    "name": "logistics_data.csv",
    "raw_data_snippet": "SHP-001, Tokyo, London, 2024-05-20\nSHP-002, NY, Paris, 2024-05-21",
    "row_count": 2000,
    "col_count": 4,
    "columns": [
        {{
            "column name": "shipment_id",
            "data type": "string",
            "data sample": ["SHP-001", "SHP-002"]
        }}
    ]
}}
```

**[Example Output]**

```json
{{
    "description": "The file logistics_data.csv contains supply chain logistics information for 2024 with 2000 rows and 4 columns, featuring fields such as shipment_id, origin, destination, and date."
}}
```

# Input
input_json = {data}
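The doubled braces (`{{`, `}}`) throughout this template exist so that `str.format` leaves the literal JSON braces alone while still substituting the `{data}` placeholder. A minimal sketch of that mechanism with a hypothetical miniature template (not the full prompt above):

```python
import json

# Literal braces are doubled; the {data} placeholder is not.
TEMPLATE = '{{"description": "..."}}\n# Input\ninput_json = {data}'

payload = {"name": "sales.csv", "row_count": 10, "col_count": 2}
prompt = TEMPLATE.format(data=json.dumps(payload))
```

After formatting, `prompt` begins with the literal `{"description": "..."}` and ends with the serialized payload.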
@@ -0,0 +1,31 @@
Please carefully analyze this image and perform the following tasks:

### Step 1: Overall Assessment
- Determine whether the image contains "single content" or "composite content" (i.e., multiple independent information modules).
- If it is composite content, list the main components (e.g., "Bar chart in the top-left, data table in the bottom-right, title description at the top").

### Step 2: Region-wise Analysis (for composite content)
For each prominent content region, describe it using the following template:

#### [Module X] Type: [flowchart/table/chart/document/photo]
- Position and scope: Briefly describe its location (e.g., "left half", "bottom table")
- Content extraction:
  - If flowchart/diagram: Describe nodes and connections in logical order, and explain label meanings.
  - If table: Reconstruct row/column structure; present in Markdown table format if possible.
  - If chart: Explain axes, series, trends, and provide key conclusions (e.g., "Peak reached in Q3").
  - If document/text: Extract key sentences while preserving original meaning.
  - If photo: Describe scene, people, and actions.
- Functional role: Infer the module's purpose within the whole image (e.g., "Supports the conclusion stated above").

### Step 3: Global Synthesis
- Summarize the core purpose of the entire image (e.g., "Presents quarterly performance analysis").
- Describe logical relationships among modules (e.g., "The table provides data sources, the chart shows trends, and the text offers recommendations").
- If there are annotations (e.g., label1, Cost2), explain their business meaning.

### Final Output Requirements
Do not output the internal analysis steps separately. You must output the final result **ONLY** in the following format:

{{
    "description": "A single, concise sentence describing the overall framework of the image",
    "details": "A comprehensive and detailed description based on your Step 2 & 3 analysis. Use Markdown formatting (bullet points, bold text) inside this field to ensure the structure is clear and readable."
}}
@@ -0,0 +1,82 @@
# Role
You are an expert Data Engineer specializing in unstructured Excel parsing. Your task is to analyze the raw content of the first 100 rows of an Excel sheet and determine if it contains structured tabular data suitable for a Pandas DataFrame.

If it is a valid table, identify the **Header Row** and the **Column Range**.
If it is NOT a valid table (e.g., a dashboard, a form, a letter, or empty), you must flag it as unsuitable.

# Task Analysis
Excel sheets fall into two categories:
1. **List-Like Tables (Valid)**: Contain a header row followed by multiple rows of consistent record data. This is what we want.
2. **Unstructured/Layout-Heavy (Unstructured)**:
   - **Forms/KV Pairs**: "Label: Value" scattered across the sheet.
   - **Dashboards**: Multiple small tables, charts, or scattered numbers.
   - **Text/Notes**: Paragraphs of text or disclaimers without column structure.
   - **Empty/Near Empty**: Contains almost no data.

# Rules for Detection

### A. Validity Check (The "Gatekeeper")
Set `is_extractable_table` to **false** if:
- There is no distinct row where meaningful column headers align horizontally.
- The data is scattered (e.g., values exist in A1, G5, and C20 with no relation).
- The sheet looks like a printed form (key on the left, value on the right) rather than a list of records.
- There are fewer than 3 rows of data following a potential header.

### B. Structure Extraction (Only if Valid)
If the sheet passes the Validity Check:
1. **Header Row**: Find the first row containing multiple distinct string values that serve as column labels.
2. **Column Range**: Identify the start index (first valid header) and end index (last valid header) to define the width.
3. **Data Continuity**: Verify that rows below the header contain consistent data types (e.g., dates under "Date").

# Input Data
The user will provide the first 100 rows in CSV/Markdown format (0-based index).

# Output Format
You must output a strictly valid JSON object.
JSON Structure:
{{
    "is_extractable_table": <boolean, true if it can serve as a dataframe source, false otherwise>,
    "row_start_index": <int or null, 0-based index of the header row>,
    "col_ranges": <list [start, end] or null, inclusive 0-based column indices>,
    "confidence_score": <float, 0-1>,
    "reasoning": "<string, explain what the row data contains and declare the final conclusion (REGULAR, IRREGULAR, or INVALID)>"
}}

# Examples

## Example 1 (Valid Table with Noise)
Input:
Title: Monthly Sales, NaN, NaN, NaN
NaN, NaN, NaN, NaN
NaN, Date, Item, Qty, Total
NaN, 2023-01-01, Apple, 10, 500
NaN, 2023-01-02, Banana, 5, 100

Output:
{{
    "is_extractable_table": true,
    "row_start_index": 2,
    "col_ranges": [1, 4],
    "confidence_score": 0.99,
    "reasoning": "Rows 0-1 are ignored metadata; row 2 holds clear headers. Rows 3-4 contain consistent data aligned with the headers. It is IRREGULAR and requires skiprows=2, usecols=[1, 4] to extract into a Pandas DataFrame."
}}

## Example 2 (Unstructured - Form/Dashboard)
Input:
Company Invoice, NaN, NaN, Invoice #: 001
To:, John Doe, NaN, Date:, 2023-01-01
Address:, 123 St, NaN, Due:, 2023-02-01
NaN, NaN, NaN, NaN, NaN
Subject:, Consulting Services, NaN, NaN, NaN

Output:
{{
    "is_extractable_table": false,
    "row_start_index": null,
    "col_ranges": null,
    "confidence_score": 0.95,
    "reasoning": "The data matches a 'Form/Invoice' layout (key-value pairs) rather than a list-like table. No single header row defines a dataset of records. It is INVALID and cannot be processed as a Pandas DataFrame."
}}

# Input
{raw_snippet_data}
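The judgment JSON maps directly onto pandas read parameters. A sketch of a hypothetical helper doing that translation (the function name and kwargs expansion are illustrative, not from the repository; `col_ranges` is treated as the inclusive bounds the prompt defines):

```python
def judgment_to_read_kwargs(judgment: dict) -> dict:
    """Translate the structural judgment into pandas.read_excel-style kwargs."""
    if not judgment.get("is_extractable_table"):
        raise ValueError("Sheet judged unstructured; no DataFrame extraction.")
    start, end = judgment["col_ranges"]  # inclusive 0-based bounds
    return {
        "skiprows": judgment["row_start_index"],
        "usecols": list(range(start, end + 1)),
    }
```

For Example 1 above this yields `skiprows=2` and `usecols=[1, 2, 3, 4]`.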
@@ -0,0 +1,109 @@
# Role
You are an expert Data Steward. Your task is to analyze the metadata and content of a database.
**Assumption:** This is an ideal dataset or database where **ALL** tables contain valid headers in the first row. You will process the entire file structure in a single pass.

# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
    "file": "Name of the file",
    "tables": [
        {{
            "name": "Name of the table",
            "row_count": 100,
            "col_count": 5,
            "raw_data_snippet": "Header1, Header2\nVal1, Val2..."
        }},
        ...
    ]
}}
```

# Analysis Logic

## 1. Sheet Iteration (Sheet Descriptions)

For **EACH** object in the `tables` array:

1. **Extract Schema:**
   * Since headers are guaranteed, simply extract the column names from the **first row** of the `raw_data_snippet`.
   * Format them as a clean list of strings.

2. **Draft Description:**
   * Write a concise sentence describing what the sheet tracks based on its name and columns.
   * **MANDATORY:** You MUST explicitly mention the `row_count` and `col_count` in this sentence.
   * *Template:* "The sheet [Sheet Name] contains [Subject] data with [Row Count] rows and [Col Count] columns, featuring fields like [List 3 key columns]."

## 2. Global Analysis (File Description)
* Analyze the `file` name and the number of table names inside the `tables` array.
* Based on all sheet descriptions, generate a single sentence summarizing the whole workbook.

# Output Format (Strict JSON)

You must output a single valid JSON object.

```json
{{
    "description": "One sentence describing the whole file or database.",
    "tables": [
        {{
            "name": "Name of table 1",
            "description": "Sentence including row/col counts and key columns.",
            "columns": ["col1", "col2", "col3"]
        }},
        ...
    ]
}}
```

# One-Shot Demonstration

**[Example Input]**
`input_json` =

```json
{{
    "file": "logistics_data.xlsx",
    "tables": [
        {{
            "name": "Shipments",
            "row_count": 2000,
            "col_count": 4,
            "raw_data_snippet": "shipment_id, origin, destination, date\nSHP-001, Tokyo, London, 2024-05-20"
        }},
        {{
            "name": "Rates",
            "row_count": 50,
            "col_count": 2,
            "raw_data_snippet": "Route_ID, Cost_Per_Kg\nR-101, 5.50"
        }}
    ]
}}
```

**[Example Output]**

```json
{{
    "description": "The file/database logistics_data.xlsx contains supply chain logistics information for 2024, divided into shipment tracking and rate definitions (2 tables in total).",
    "tables": [
        {{
            "name": "Shipments",
            "description": "The 'Shipments' sheet tracks individual shipment records with 2000 rows and 4 columns, featuring fields such as shipment_id, origin, and destination.",
            "columns": ["shipment_id", "origin", "destination", "date"]
        }},
        {{
            "name": "Rates",
            "description": "The 'Rates' sheet lists shipping cost rates with 50 rows and 2 columns, specifically Route_ID and Cost_Per_Kg.",
            "columns": ["Route_ID", "Cost_Per_Kg"]
        }}
    ]
}}
```

# Input
input_json=`{data}`
@@ -0,0 +1,159 @@
# Role
You are an expert Data Steward. Your task is to analyze the metadata and content of an Excel file based on a pre-analyzed structural judgment.

**Context:** The dataset contains three types of sheets:
1. **Regular Tables**: Standard headers in row 0.
2. **Irregular Tables**: Valid data, but reading it requires `skiprows` or `usecols` parameters.
3. **Unstructured Sheets**: Dashboards, forms, or text descriptions that **cannot** be read as a dataframe.

**Constraint**: Your analysis relies on a snippet of the first 100 rows.

# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
    "file": "Name of the file",
    "tables": [
        {{
            "name": "Sheet Name",
            "row_count": 100,
            "col_count": 5,
            "raw_data_snippet": "...",
            "irregular_judgment": {{
                "row_header_index": int,
                "cols_ranges": list,
                "reasoning": "..."
            }}
        }}
    ]
}}
```

*(Note: If `irregular_judgment` is null, treat the sheet as Regular.)*

# Analysis Logic

## 1. Sheet Iteration (Table Descriptions)

For **EACH** object in the `tables` array, apply the following priority logic:

**Case A: Unstructured Sheet (`irregular_judgment` contains "UNSTRUCTURED")**

* **Columns**: Return an empty list `[]`.
* **Description**: "The sheet [Name] contains [something]."
  **Append MANDATORY Warning**: "It is Unstructured based on a 100-row sample."

**Case B: Irregular Table (`irregular_judgment` contains a dict and `row_header_index` > 0 or `cols_ranges` is set)**

* **Columns**: Extract column names from the row indicated by `row_header_index`.
* **Description**:
  Write a concise sentence describing what the sheet tracks based on its name and columns.
  1. Start with: "The sheet [Name] contains [Subject] data with [Rows] rows and [Cols] columns."
  2. **Append MANDATORY Warning**: "It is irregular; requires specifying skiprows={{row_header_index}}, usecols={{cols_ranges}} when reading with a pandas dataframe."

**Case C: Regular Table (Default)**

* **Columns**: Extract from the first row of `raw_data_snippet`.
* **Description**: "The sheet [Name] contains [Subject] data with [Rows] rows and [Cols] columns, featuring fields like [Key Cols]."

## 2. Global Analysis (File Description)

Generate a single string summarizing the workbook. This summary **MUST** explicitly include:

1. **Total Count**: The number of sheets.
2. **Status List**: List every table name with its status tag:
   * (Regular)
   * (Irregular, requires skiprows=X, usecols=Y)
   * (Unstructured)
   * *Format Example:* "The file logistics_data.xlsx contains supply chain logistics information for 2024 and analyzes the log data. It contains 3 sheets: 'Data' (Regular), 'Logs' (Irregular, requires skiprows=2), and 'Cover' (Unstructured)."

# Output Format (Strict JSON)

You must output a single valid JSON object.

```json
{{
    "description": "Comprehensive summary including count, names, and specific status tags for ALL tables.",
    "tables": [
        {{
            "name": "Table Name",
            "description": "Specific description based on Case A, B, or C.",
            "columns": ["col1", "col2"]
        }}
    ]
}}
```

# One-Shot Demonstration

**[Example Input]**
`input_json` =

```json
{{
    "file": "finance_report_v2.xlsx",
    "tables": [
        {{
            "name": "Q1_Sales",
            "row_count": 200,
            "col_count": 5,
            "raw_data_snippet": "Date, Item, Amount\n2023-01-01, A, 100"
        }},
        {{
            "name": "Historical_Data",
            "row_count": 500,
            "col_count": 10,
            "raw_data_snippet": "Confidential\nSystem Generated\n\nDate, ID, Val\n...",
            "irregular_judgment": {{
                "is_extractable_table": true,
                "row_header_index": 3,
                "cols_ranges": [0, 3],
                "reasoning": "Header offset."
            }}
        }},
        {{
            "name": "Dashboard_Overview",
            "row_count": 50,
            "col_count": 20,
            "raw_data_snippet": "Total KPI: 500 | Chart Area |\nDisclaimer: Internal Use",
            "irregular_judgment": "UNSTRUCTURED"
        }}
    ]
}}
```

**[Example Output]**

```json
{{
    "description": "The file finance_report_v2.xlsx contains historical sales transaction records over the past Q1 period. It contains 3 sheets: 'Q1_Sales' (Regular), 'Historical_Data' (Irregular, requires skiprows=3, usecols=[0, 3], sampled first 100 rows), and 'Dashboard_Overview' (Unstructured).",
    "tables": [
        {{
            "name": "Q1_Sales",
            "description": "The sheet 'Q1_Sales' contains sales transaction records. It contains 200 rows and 5 columns, featuring fields like Date, Item, and Amount.",
            "columns": ["Date", "Item", "Amount"]
        }},
        {{
            "name": "Historical_Data",
            "description": "The sheet 'Historical_Data' contains historical sales transaction records. It contains 500 rows and 10 columns. It is irregular, judged from the first 100 sampled rows (the first 3 rows contain metadata; requires specifying skiprows=3, usecols=[0, 3] when reading with a pandas dataframe).",
            "columns": ["Date", "ID", "Val"]
        }},
        {{
            "name": "Dashboard_Overview",
            "description": "The sheet 'Dashboard_Overview' contains an overview and summary of the dashboards. It is Unstructured based on a 100-row sample.",
            "columns": []
        }}
    ]
}}
```

# Input

input_json=`{data}`
alias/src/alias/agent/agents/data_source/data_profile.py (new file, 113 lines)
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
import os
import base64
import tempfile
from typing import Any, Dict
from io import BytesIO
from pathlib import Path
import requests

from alias.agent.agents.data_source._typing import SourceType
from alias.agent.agents.data_source._data_profiler_factory import (
    DataProfilerFactory,
)
from alias.agent.tools.sandbox_util import (
    get_workspace_file,
)
from alias.runtime.alias_sandbox.alias_sandbox import AliasSandbox
from alias.agent.utils.llm_call_manager import (
    LLMCallManager,
)


def _get_binary_buffer(
    sandbox: AliasSandbox,
    file_url: str,
):
    if file_url.startswith(("http://", "https://")):
        response = requests.get(file_url)
        response.raise_for_status()
        buffer = BytesIO(response.content)
    else:
        buffer = BytesIO(
            base64.b64decode(get_workspace_file(sandbox, file_url)),
        )
    return buffer


def _copy_file_from_sandbox_with_original_name(
    sandbox: AliasSandbox,
    file_path: str,
) -> str:
    """
    Copies a file from the sandbox environment
    or a URL to a local temporary file.

    Args:
        sandbox (AliasSandbox): The sandbox environment instance.
        file_path (str): Source path or URL.

    Returns:
        str: The path to the local temporary file.
    """
    # Handle different types of file URLs
    if file_path.startswith(("http://", "https://")):
        # For web URLs, use the URL directly
        file_source = file_path
    else:
        # For local files, save to a temporary file
        file_buffer = _get_binary_buffer(
            sandbox,
            file_path,
        )
        # Create a temporary file with the same name as the original file
        temp_dir = tempfile.mkdtemp()
        target_file_name = os.path.basename(file_path)
        full_path = Path(temp_dir) / target_file_name
        with open(full_path, "wb") as f:
            f.write(file_buffer.getvalue())
        file_source = full_path
    return str(file_source)


async def data_profile(
    sandbox: AliasSandbox,
    sandbox_path: str,
    source_type: SourceType,
    llm_call_manager: LLMCallManager,
) -> Dict[str, Any]:
    """
    Generates a detailed profile and summary for a data source using LLMs.

    Args:
        sandbox (AliasSandbox): The sandbox environment instance.
        sandbox_path (str): The location of the data source.
            - For files: A file path or URL.
            - For databases: A connection string (DSN).
        source_type (SourceType): The type of the data source.
        llm_call_manager: Manager for handling LLM calls

    Returns:
        Dict: An object containing the generated text profile of the data.

    Raises:
        ValueError: If the provided `source_type` is not supported.
    """
    if source_type in [SourceType.CSV, SourceType.EXCEL, SourceType.IMAGE]:
        local_path = _copy_file_from_sandbox_with_original_name(
            sandbox,
            sandbox_path,
        )
    elif source_type == SourceType.RELATIONAL_DB:
        local_path = sandbox_path
    else:
        raise ValueError(f"Unsupported source type {source_type}")

    profiler = DataProfilerFactory.get_profiler(
        llm_call_manager=llm_call_manager,
        path=local_path,
        source_type=source_type,
    )

    return await profiler.generate_profile()
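The copy helper deliberately keeps the original basename: downstream profilers typically infer the format from the file extension (`.csv` vs `.xlsx`). A standalone sketch of that idea, stripped of the `AliasSandbox` dependency (the function name here is hypothetical):

```python
import os
import tempfile
from io import BytesIO
from pathlib import Path


def copy_with_original_name(buffer: BytesIO, original_path: str) -> str:
    # A fresh temp directory avoids name collisions while the basename
    # (and hence the extension) of the original file is preserved.
    temp_dir = tempfile.mkdtemp()
    target = Path(temp_dir) / os.path.basename(original_path)
    target.write_bytes(buffer.getvalue())
    return str(target)


local = copy_with_original_name(BytesIO(b"a,b\n1,2\n"), "/sandbox/data/sales.csv")
```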
alias/src/alias/agent/agents/data_source/data_skill.py (new file, 198 lines)
@@ -0,0 +1,198 @@
# -*- coding: utf-8 -*-
import os

from pathlib import Path
from typing import List

import frontmatter
from loguru import logger

from agentscope.tool._types import AgentSkill

from alias.agent.agents.ds_agent_utils.utils import get_prompt_from_file
from alias.agent.agents.data_source._typing import SourceType


class DataSkill(AgentSkill):
    """An agent skill tagged with the source types it applies to."""

    type: List[SourceType]
    """The source types of the skill."""


class DataSkillManager:
    """Data skill selector based on data source type"""

    _default_skill_path_base = os.path.join(
        Path(__file__).resolve().parent.parent,
        "_built_in_skill/data",
    )

    def __init__(
        self,
    ):
        self.skills = self.register_skill_dir()

        self.source_type_2_skills = {}
        for skill in self.skills:
            for t in skill["type"]:
                self.source_type_2_skills[t] = skill

    def load(self, data_source_types: List[SourceType]) -> List[str]:
        """
        Load skills based on data source type

        Args:
            data_source_types: List of SourceType enum values

        Returns:
            Selected skill content list
        """
        if not data_source_types:
            return []

        selected_skills = []

        data_source_types = set(data_source_types)
        for source_type in data_source_types:
            try:
                # Get skill from source type mapping
                skill = self.source_type_2_skills.get(source_type, None)

                # Skip if no corresponding skill
                if not skill:
                    logger.warning(
                        "DataSkillSelector found no valid skill for data "
                        f"source type: {source_type}",
                    )
                else:
                    logger.info(
                        f"DataSkillSelector selected skill: {skill['name']} "
                        f"for data source type: {source_type}",
                    )

                    skill_content = get_prompt_from_file(
                        skill["dir"],
                        return_json=False,
                    )
                    if skill_content:
                        selected_skills.append(skill_content)

            except Exception as e:
                logger.error(
                    f"DataSkillSelector selection failed: {str(e)} "
                    f"for data source type: {source_type}",
                )
                continue

        return selected_skills

    def register_skill_dir(self, skill_dir=_default_skill_path_base):
        """Load skills from all directories containing SKILL.md"""

        skills = []
        # Check the skill directory
        if not os.path.isdir(skill_dir):
            raise ValueError(
                f"The skill directory '{skill_dir}' does not exist or is "
                "not a directory.",
            )

        # Walk through all files and directories in skill_dir
        for root, dirs, _ in os.walk(skill_dir):
            # Process directories - look for SKILL.md
            for dir_name in dirs:
                dir_path = os.path.join(root, dir_name)
                skill = self.register_skill(dir_path)
                if skill:
                    skills.append(skill)

        return skills

    def register_skill(self, path: str, name=None):
        """
        Register a new skill dynamically

        Args:
            path: Path to skill directory containing SKILL.md
            name: Skill name (optional; defaults to the directory name)
        """
        try:
            # Resolve the skill path
            file_path = self._resolve_skill_path(path)
            if not file_path:
                raise FileNotFoundError("`SKILL.md` not found")

            # Parse the skill file
            skill = self._parse_skill_file(file_path, name)
            logger.info(
                f"Successfully registered skill '{skill['name']}' "
                f"from '{file_path}'",
            )

            return skill

        except Exception as e:
            # `skill` may be unbound if resolution or parsing failed,
            # so log the path rather than the skill name.
            logger.error(
                f"Failed to register skill from '{path}': {e}",
            )
            return None

    def _resolve_skill_path(self, path: str) -> str:
        """
        Resolve a skill path to the actual markdown file path

        Args:
            path: Path to skill markdown file or directory containing SKILL.md

        Returns:
            str: Path to the actual markdown file, or empty string if invalid
        """
        if os.path.isdir(path):
            skill_md_path = os.path.join(path, "SKILL.md")
            if not os.path.isfile(skill_md_path):
                logger.warning(f"Directory '{path}' does not contain SKILL.md")
                return ""
            return skill_md_path
        else:
            logger.warning(f"Invalid skill path: {path}")
            return ""

    def _parse_skill_file(self, file_path, name=None):
        """Parse a skill file and add it to skills_list"""

        # Check YAML Front Matter
        post = frontmatter.load(file_path)

        # Use directory name as skill name if not provided in YAML
        if name is None:
            dir_name = os.path.basename(os.path.dirname(file_path))
            name = post.get("name", dir_name)
        else:
            name = post.get("name", name)

        description = post.get("description", None)
        _type = post.get("type", None)

        if not name or not description or not _type:
            raise ValueError(
                f"The file '{file_path}' must have YAML Front "
                "Matter including `name`, `description`, and `type` fields",
            )

        _type = _type if isinstance(_type, list) else [_type]
        if any(not SourceType.is_valid_source_type(t) for t in _type):
            raise ValueError(
                f"Type of file '{file_path}' must be a member "
                "(or a list of members) of SourceType",
            )

        name, description = str(name), str(description)
        _type = [SourceType(t) for t in _type]

        return DataSkill(
            name=name,
|
||||||
|
description=description,
|
||||||
|
type=_type,
|
||||||
|
dir=file_path,
|
||||||
|
)
|
||||||
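To make the front-matter contract concrete, here is a minimal stdlib-only sketch of what `_parse_skill_file` enforces. The real code uses the `python-frontmatter` package and the `SourceType` enum; the parser below only splits the `---`-delimited header into flat `key: value` pairs, and the `csv_profiling` skill values are invented for illustration.

```python
# Minimal sketch of the SKILL.md front-matter contract enforced above.
# The real implementation uses python-frontmatter; this sketch only splits
# the `---` delimited header and checks the three required fields.
def parse_front_matter(text: str) -> dict:
    lines = text.splitlines()
    if not lines or lines[0].strip() != "---":
        raise ValueError("missing YAML front matter")
    meta = {}
    for line in lines[1:]:
        if line.strip() == "---":
            break
        key, _, value = line.partition(":")
        meta[key.strip()] = value.strip()
    for field in ("name", "description", "type"):
        if not meta.get(field):
            raise ValueError(f"front matter must include `{field}`")
    return meta


# Hypothetical SKILL.md content accepted by the parser.
skill_md = """---
name: csv_profiling
description: Profile CSV files before analysis
type: csv
---
Body of the skill...
"""
meta = parse_front_matter(skill_md)
```

A file missing any of `name`, `description`, or `type` raises, matching the `ValueError` branch in `_parse_skill_file`.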
620 alias/src/alias/agent/agents/data_source/data_source.py Normal file
@@ -0,0 +1,620 @@
# -*- coding: utf-8 -*-
# pylint: disable=R1702,R0912,R0911

import os
import json
from pathlib import Path
from typing import Dict, Any, Optional, List

import yaml
from loguru import logger

from agentscope.mcp import StdIOStatefulClient
from agentscope_runtime.sandbox.box.sandbox import Sandbox

from alias.agent.agents.data_source.data_skill import DataSkillManager
from alias.agent.agents.data_source._typing import (
    SOURCE_TYPE_TO_ACCESS_TYPE,
    SourceAccessType,
    SourceType,
)
from alias.agent.agents.data_source.data_profile import data_profile
from alias.agent.agents.data_source.utils import replace_placeholders
from alias.agent.tools.toolkit_hooks.text_post_hook import TextPostHook
from alias.agent.tools.alias_toolkit import AliasToolkit
from alias.agent.tools.sandbox_util import (
    copy_local_file_to_workspace,
)
from alias.agent.utils.llm_call_manager import (
    LLMCallManager,
)


class DataSource:
    """
    Unified data source class representing any data source.
    """

    def __init__(
        self,
        endpoint: str,
        source_type: SourceType,
        name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize a data source.

        Args:
            endpoint: Address/DNS/URL/path to access the data source
            source_type: Type of the data source (SourceType enum)
            name: Name/identifier of the data source
            config: Configuration for this data source

        The access type (SourceAccessType enum) is derived from
        `source_type` via SOURCE_TYPE_TO_ACCESS_TYPE.
        """

        self.endpoint = endpoint
        self.name = name

        source_access_type = SOURCE_TYPE_TO_ACCESS_TYPE.get(
            source_type,
            SourceAccessType.DIRECT,
        )

        self.source_access_type = source_access_type
        self.source_type = source_type

        self.config = config or {}
        self.profile = {}
        self.source_access = None  # set during prepare()
        self.source_desc = None
        self.source_access_desc = None

    async def prepare(self, toolkit: AliasToolkit):
        """
        Prepare the data source.
        For DIRECT (local file) sources: upload or link the file into the
        sandbox workspace.
        For VIA_MCP sources: register the corresponding MCP server.

        Args:
            toolkit: AliasToolkit wrapping the sandbox instance
        """

        logger.info(f"Preparing data source {self.name}...")

        if self.source_access_type == SourceAccessType.DIRECT:
            # Get the filename and construct target path in workspace
            filename = os.path.basename(self.endpoint)
            target_path = f"/workspace/{filename}"

            if os.getenv("LINK_FILE_TO_WORKSPACE", "off").lower() == "on":
                logger.info(
                    f"Creating symlink for {self.endpoint} "
                    f"to {target_path}",
                )
                # Build ln -s command
                command = f"ln -s '{self.endpoint}' '{target_path}'"
                result = toolkit.sandbox.call_tool(
                    name="run_shell_command",
                    arguments={"command": command},
                )
                if result.get("isError"):
                    raise ValueError(
                        "Failed to create symlink for "
                        f"{self.endpoint}: {result}",
                    )
            else:
                logger.info(f"Uploading {self.endpoint} to {target_path}")
                result = copy_local_file_to_workspace(
                    sandbox=toolkit.sandbox,
                    local_path=self.endpoint,
                    target_path=target_path,
                )

                if result.get("isError"):
                    raise ValueError(
                        f"Failed to upload {self.endpoint}: {result}",
                    )

            self.source_access = target_path
            self.source_desc = "Local file"
            self.source_access_desc = f"Access at path: `{target_path}`"

            logger.info(f"Successfully loaded to {result}")

        # Check if this is an MCP tool source
        elif self.source_access_type == SourceAccessType.VIA_MCP:
            server_config = self.config.get("mcp_server", {})
            mcp_server_name = server_config.keys()

            if len(mcp_server_name) != 1:
                raise ValueError("Register servers one by one!")

            mcp_server_name = list(mcp_server_name)[0]
            server_config = server_config[mcp_server_name]

            cmd = server_config.get("command")
            args = server_config.get("args")
            if cmd is None or args is None:
                raise ValueError(
                    "MCP server configuration requires non-empty "
                    "`command` and `args` fields to start!",
                )

            client = StdIOStatefulClient(
                self.name,
                command=cmd,
                args=args,
                env=server_config.get("env"),
            )

            text_hook = TextPostHook(
                toolkit.sandbox,
                budget=5000,
                auto_save=True,
            )
            await toolkit.add_and_connect_mcp_client(
                client,
                postprocess_func=text_hook.truncate_and_save_response,
            )
            registered_tools = [
                t.name
                for t in list(
                    await toolkit.additional_mcp_clients[-1].list_tools(),
                )
            ]

            self.source_access = self.endpoint
            self.source_desc = f"{self.source_type}"
            self.source_access_desc = (
                f"Access via MCP tools: [{', '.join(registered_tools)}]"
            )

            logger.info(f"Successfully connected to {self.name}")

        else:
            logger.info(
                f"Skipping preparation for source type: {self.source_type}",
            )
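The VIA_MCP branch of `prepare()` expects the config to carry exactly one `mcp_server` entry with `command` and `args` set. A sketch of that shape, with the server name and command values invented for illustration (the field names come from the diff itself):

```python
# Hypothetical `mcp_server` config in the shape prepare() consumes: exactly
# one server entry, with required `command`/`args` and optional `env`.
# Server name and command values here are invented for illustration.
config = {
    "mcp_server": {
        "filesystem": {
            "command": "npx",
            "args": ["-y", "@modelcontextprotocol/server-filesystem", "/data"],
            "env": None,
        },
    },
}

server_config = config["mcp_server"]
# Same validation steps as prepare(): one server, command/args present.
assert len(server_config) == 1, "Register servers one by one!"
(name, server), = server_config.items()
assert server.get("command") is not None and server.get("args") is not None
```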
    def get_coarse_desc(self):
        return (
            f"{self.source_desc}. {self.source_access_desc}: "
            + f"{self._general_profile()}"
        )

    async def prepare_profile(
        self,
        sandbox: Sandbox,
        llm_call_manager: LLMCallManager,
    ) -> Optional[Dict[str, Any]]:
        """Run type-specific profiling."""
        if llm_call_manager and not self.profile:
            try:
                self.profile = await data_profile(
                    sandbox=sandbox,
                    sandbox_path=self.source_access,
                    source_type=self.source_type,
                    llm_call_manager=llm_call_manager,
                )
                logger.info(
                    "Profiling successfully: "
                    + f"{self._general_profile()[:100]}...",
                )
            except ValueError as e:
                self.profile = None
                logger.warning(f"Warning when profiling data: {e}")
            except Exception as e:
                self.profile = None
                logger.error(f"Error when profiling data: {e}")

        return self.profile

    def _refined_profile(self) -> str:
        if self.profile:
            return yaml.dump(
                self.profile,
                allow_unicode=True,
                sort_keys=False,
                default_flow_style=False
                if self.source_type == SourceType.IMAGE
                else None,
                width=float("inf"),
            )
        else:
            return ""

    def _general_profile(self) -> str:
        return self.profile["description"] if self.profile else ""

    def __str__(self) -> str:
        return (
            f"DataSource(name='{self.name}', type='{self.source_type}', "
            f"endpoint='{self.endpoint}')"
        )

    def __repr__(self) -> str:
        return self.__str__()


class DataSourceManager:
    """
    Manager class for handling multiple data sources.
    Provides methods to add, retrieve, and manage data sources.
    Also manages data source configurations with hierarchical lookup.
    """

    _default_data_source_config = os.path.join(
        Path(__file__).resolve().parent,
        "_default_config.json",
    )

    def __init__(
        self,
        sandbox: Sandbox,
        llm_call_manager: LLMCallManager,
    ):
        """Initialize an empty data source manager."""
        self._data_sources: Dict[str, DataSource] = {}
        self._type_defaults = {}

        self._load_default_config()

        self.skill_manager = DataSkillManager()
        self.selected_skills = None

        self.toolkit = AliasToolkit(sandbox=sandbox)

        self.llm_call_manager = llm_call_manager

    def add_data_source(
        self,
        config: str | Dict = None,
    ):
        """
        Add a new data source (or multiple sources) to the manager.

        Args:
            config: Endpoint (address/DNS/URL/path) of the data source, or a
                configuration dict for the data source connection
        """

        if isinstance(config, str):
            endpoint = config
            conn_config = None
        else:
            if "endpoint" not in config:
                logger.error(
                    f"Missing 'endpoint' in config for source '{config}'",
                )

            endpoint = config["endpoint"]
            conn_config = config

        sources = set()
        if os.path.isdir(endpoint):
            # Add all files in directory
            for filename in os.listdir(endpoint):
                file_path = os.path.join(endpoint, filename)
                sources.add(file_path)
        else:
            sources.add(endpoint)

        for endpoint in sources:
            # Auto-detect source type
            source_type = self._detect_source_type(endpoint)

            # Auto-generate name
            name = self._generate_name(endpoint)

            # Get configuration for this data source
            if not conn_config:
                conn_config = self.get_default_config(source_type)

            if conn_config:
                conn_config = replace_placeholders(
                    conn_config,
                    {
                        "endpoint": endpoint,
                    },
                )

            # Create data source with configuration
            data_source = DataSource(endpoint, source_type, name, conn_config)
            self._data_sources[endpoint] = data_source

    async def prepare_data_sources(self) -> None:
        """Prepare all registered data sources, then profile each one."""
        logger.info(f"Preparing {len(self._data_sources)} data source(s)...")

        all_data_sources = self._data_sources.values()
        for data_source in all_data_sources:
            await data_source.prepare(self.toolkit)
            await data_source.prepare_profile(
                self.toolkit.sandbox,
                self.llm_call_manager,
            )

    def _generate_name(self, endpoint: str) -> str:
        """
        Generate a name based on the endpoint.
        For databases, removes passwords and uses scheme + database name.
        For files, uses the filename.
        For URLs, uses the domain or last part of the path.
        """
        from urllib.parse import urlparse

        try:
            # For file paths
            if os.path.isfile(endpoint):
                filename = os.path.basename(endpoint)
                # Remove extension and sanitize
                name_without_ext = os.path.splitext(filename)[0]
                return self._sanitize_name(name_without_ext)

            # For database connections
            db_indicators = [
                "://",
                ".db",
                ".sqlite",
                "mongodb://",
                "mongodb+srv://",
                "neo4j://",
                "bolt://",
            ]
            if any(
                indicator in endpoint.lower() for indicator in db_indicators
            ):
                if "://" in endpoint:
                    try:
                        # Split by :// to get scheme and rest
                        scheme, rest = endpoint.split("://", 1)
                        scheme = scheme.lower()

                        # Handle authentication (user:password@host)
                        if "@" in rest:
                            auth_part, host_part = rest.split("@", 1)
                            if ":" in auth_part:
                                # Has user:password format, keep only username
                                username = auth_part.split(":")[0]
                                rest = f"{username}@{host_part}"
                            # If no colon, it's just username@host, keep as is

                        # Extract database name
                        db_name = "unknown"
                        if "/" in rest:
                            # Split by / and take last part before
                            # query parameters
                            path_parts = rest.split("/")
                            if len(path_parts) > 1:
                                db_name = (
                                    path_parts[-1].split("?")[0].split("#")[0]
                                )
                                if not db_name:  # If empty, try second to last
                                    db_name = (
                                        path_parts[-2]
                                        if len(path_parts) > 2
                                        else "unknown"
                                    )
                        else:
                            # Use host name if no database name in path
                            host = (
                                rest.split(":")[0].split("/")[0].split("@")[-1]
                            )
                            db_name = host

                        # Create name: scheme_dbname
                        return self._sanitize_name(f"{scheme}_{db_name}")
                    except Exception as e:
                        logger.warning(
                            f"Error parsing database URL {endpoint}: {e}",
                        )
                        # Fall through to URL handling

                elif "." in endpoint:
                    # Use filename without extension for .db/.sqlite files
                    filename = os.path.basename(endpoint)
                    name_without_ext = os.path.splitext(filename)[0]
                    return self._sanitize_name(name_without_ext)

            # For URLs (including database URLs that failed to parse)
            if "://" in endpoint:
                try:
                    parsed = urlparse(endpoint)
                    if parsed.netloc:
                        # Use domain name (without port)
                        domain = parsed.netloc.split(":")[0].split("@")[
                            -1
                        ]  # Remove username if present
                        # If path exists, use last part of path
                        if parsed.path and parsed.path != "/":
                            path_parts = parsed.path.strip("/").split("/")
                            if path_parts:
                                return self._sanitize_name(path_parts[-1])
                        return self._sanitize_name(domain)
                    elif parsed.path:
                        # Use last part of path
                        path_parts = parsed.path.strip("/").split("/")
                        if path_parts:
                            return self._sanitize_name(path_parts[-1])
                except Exception as e:
                    logger.warning(f"Error parsing URL {endpoint}: {e}")

            # Fallback: use a sanitized version of the endpoint
            return self._sanitize_name(endpoint[:50])

        except Exception as e:
            logger.error(f"Error generating default name for {endpoint}: {e}")
            # Ultimate fallback
            return self._sanitize_name("unknown_source")
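The credential-stripping branch of `_generate_name` is the subtle part, so here is a standalone sketch (a hypothetical helper, not the method itself) of just the `scheme_dbname` path: drop the password from `user:password@host`, then take the last path segment as the database name.

```python
# Standalone sketch of the database-URL branch of `_generate_name`:
# keep scheme + database name, never the password.
def db_url_to_name(endpoint: str) -> str:
    scheme, rest = endpoint.split("://", 1)
    if "@" in rest:
        auth, host = rest.split("@", 1)
        # keep only the username, drop the password
        rest = f"{auth.split(':')[0]}@{host}"
    if "/" in rest:
        # last path segment, minus query string and fragment
        last = rest.split("/")[-1].split("?")[0].split("#")[0]
        db_name = last or "unknown"
    else:
        # no path: fall back to the bare host name
        db_name = rest.split(":")[0].split("@")[-1]
    return f"{scheme.lower()}_{db_name}"
```

For `mysql://user:secret@db.example.com:3306/sales` this yields `mysql_sales`; the real method then passes the result through `_sanitize_name`.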
    def _sanitize_name(self, name: str) -> str:
        """Sanitize a name to be used as a data source identifier."""
        import re

        # Keep only alphanumeric and underscore characters
        sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)

        # Ensure it starts with a letter or underscore
        if sanitized and not sanitized[0].isalpha() and sanitized[0] != "_":
            sanitized = "_" + sanitized

        # Truncate if too long
        sanitized = sanitized[:50]

        # Ensure it's not empty
        if not sanitized:
            sanitized = "unknown"

        return sanitized
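The sanitizer's edge cases (leading digit, empty input) are worth seeing in action; a compact sketch of the same rules, runnable on its own:

```python
import re

def sanitize_name(name: str) -> str:
    # Same rules as `_sanitize_name`: non-[A-Za-z0-9_] -> "_", force a
    # legal leading character, cap at 50 chars, never return empty.
    s = re.sub(r"[^a-zA-Z0-9_]", "_", name)
    if s and not s[0].isalpha() and s[0] != "_":
        s = "_" + s
    return s[:50] or "unknown"
```

A filename such as `2024 sales-report.csv` becomes `_2024_sales_report_csv` (underscore prefixed because the first character is a digit), and an empty string falls back to `unknown`.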
    def _detect_source_type(self, endpoint: str) -> SourceType:
        """Auto-detect source type based on endpoint."""
        endpoint_lower = endpoint.lower()

        # Check for file extensions
        if endpoint_lower.endswith(".csv"):
            source_type = SourceType.CSV
        elif endpoint_lower.endswith((".xls", ".xlsx", ".xlsm")):
            source_type = SourceType.EXCEL
        elif endpoint_lower.endswith(".json"):
            source_type = SourceType.JSON
        elif endpoint_lower.endswith((".txt", ".log", ".md")):
            source_type = SourceType.TEXT
        elif endpoint_lower.endswith(
            (".jpg", ".jpeg", ".png", ".gif", ".bmp"),
        ):
            source_type = SourceType.IMAGE

        # Check for database connection strings/patterns
        # Relational databases
        elif any(
            keyword in endpoint_lower
            for keyword in [
                "postgresql://",
                "postgres://",
                "pg://",
                "mysql://",
                "mariadb://",
                "sqlserver://",
            ]
        ):
            source_type = SourceType.RELATIONAL_DB
        elif (
            "sqlite://" in endpoint_lower
            or endpoint_lower.endswith(".db")
            or endpoint_lower.endswith(".sqlite")
        ):
            source_type = SourceType.RELATIONAL_DB

        else:
            source_type = SourceType.OTHER

        return source_type
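The extension dispatch can equivalently be written as a lookup table, which makes it obvious that every key must carry the leading dot (that is what `os.path.splitext` returns). String labels stand in for the `SourceType` enum in this sketch:

```python
import os

# Table-driven sketch of the extension dispatch in `_detect_source_type`.
# String labels stand in for the SourceType enum members.
EXTENSION_TYPES = {
    ".csv": "csv",
    ".xls": "excel", ".xlsx": "excel", ".xlsm": "excel",
    ".json": "json",
    ".txt": "text", ".log": "text", ".md": "text",
    ".jpg": "image", ".jpeg": "image", ".png": "image",
    ".db": "relational_db", ".sqlite": "relational_db",
}

def detect_type(endpoint: str) -> str:
    # splitext keeps the leading dot, so the table keys must too
    ext = os.path.splitext(endpoint.lower())[1]
    return EXTENSION_TYPES.get(ext, "other")
```

Connection-string schemes (`postgresql://`, `mysql://`, ...) do not fit a suffix table and stay as substring checks, as in the method above.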
    def get_all_data_sources_desc(self) -> str:
        """
        Get descriptions of all data sources.

        Returns:
            A newline-joined string describing every data source
        """
        return "Available data sources: \n" + "\n".join(
            [
                f"[{idx}] " + ds.get_coarse_desc()
                for idx, ds in enumerate(self._data_sources.values())
            ],
        )

    def get_local_data_sources(self) -> List[str]:
        """
        Get the list of local (DIRECT access) data source endpoints.
        """
        return [
            ds.endpoint
            for ds in self._data_sources.values()
            if ds.source_access_type == SourceAccessType.DIRECT
        ]

    def get_all_data_sources_name(self) -> List[str]:
        """
        Get a list of all data source keys (the endpoints used at
        registration).

        Returns:
            List of all data source keys
        """
        return list(self._data_sources.keys())

    def remove_data_source(self, name: str) -> bool:
        """
        Remove a data source by its key (the endpoint used at registration).

        Args:
            name: Key of the data source to remove

        Returns:
            True if successfully removed, False if not found
        """
        if name in self._data_sources:
            del self._data_sources[name]
            return True
        return False

    def get_default_config(self, source_type: SourceType) -> Dict[str, Any]:
        """
        Get the default configuration for a source type.

        Args:
            source_type: The SourceType to get the default config for

        Returns:
            Default configuration dictionary, empty dict if not registered
        """
        return self._type_defaults.get(source_type, {})

    def _load_default_config(self) -> None:
        """Load the default per-type configurations from disk."""
        try:
            with open(
                self._default_data_source_config,
                "r",
                encoding="utf-8",
            ) as f:
                config = json.load(f)

            # Load type defaults
            for type_name, type_config in config.items():
                try:
                    source_type = SourceType(type_name)
                    self._type_defaults[source_type] = type_config
                except ValueError:
                    # Skip invalid source types
                    continue

        except FileNotFoundError:
            # If the config file doesn't exist, keep empty defaults
            pass
        except json.JSONDecodeError:
            # If the config file is invalid JSON, keep empty defaults
            pass

    def __len__(self) -> int:
        """Return the number of data sources managed."""
        return len(self._data_sources)

    def get_data_skills(self):
        # TODO: update when data sources change
        if self.selected_skills is None:
            source_types = [
                data.source_type for data in self._data_sources.values()
            ]
            self.selected_skills = self.skill_manager.load(source_types)

        return "\n".join(self.selected_skills)
30 alias/src/alias/agent/agents/data_source/utils.py Normal file
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import logging

# Set up logger
logger = logging.getLogger(__name__)


def replace_placeholders(obj, source_config):
    """Recursively substitute `${var}` placeholders in strings, dicts,
    and lists using values from `source_config`."""
    if isinstance(obj, str):
        import re

        pattern = r"\$\{([^}]+)\}"
        matches = re.finditer(pattern, obj)
        result = obj
        for match in matches:
            var_name = match.group(1)
            if var_name in source_config:
                result = result.replace(
                    match.group(0),
                    str(source_config[var_name]),
                )
        return result
    elif isinstance(obj, dict):
        return {
            k: replace_placeholders(v, source_config) for k, v in obj.items()
        }
    elif isinstance(obj, list):
        return [replace_placeholders(item, source_config) for item in obj]
    else:
        return obj
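To show the utility in action, here is a compact `re.sub` restatement of the same recursion together with a usage example; the `mcp_server`/`files` config keys and the file path are invented for illustration. Unknown placeholders are left untouched, matching the loop-based version above.

```python
import re

def replace_placeholders(obj, source_config):
    # Same recursion as utils.py, with the string case expressed via re.sub.
    if isinstance(obj, str):
        return re.sub(
            r"\$\{([^}]+)\}",
            # unknown variables keep their original `${...}` text
            lambda m: str(source_config.get(m.group(1), m.group(0))),
            obj,
        )
    if isinstance(obj, dict):
        return {k: replace_placeholders(v, source_config) for k, v in obj.items()}
    if isinstance(obj, list):
        return [replace_placeholders(i, source_config) for i in obj]
    return obj


# Hypothetical per-type default config with an `${endpoint}` placeholder,
# resolved the way DataSourceManager.add_data_source() does.
config = {"mcp_server": {"files": {"args": ["--root", "${endpoint}"]}}}
resolved = replace_placeholders(config, {"endpoint": "/data/sales.csv"})
```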
@@ -137,11 +137,12 @@ async def files_filter_pre_reply_hook(
     # Even if the user only uploaded supplementary files in this interaction,
     # We will also check whether the previously uploaded files are relevant
     # to the question.
-    self.uploaded_files = list(
-        set(files_list) | set(getattr(self, "uploaded_files", [])),
+    uploaded_files = list(
+        set(files_list) | set(self.data_manager.get_local_data_sources()),
     )
 
-    if len(self.uploaded_files) < 100:
+    if len(uploaded_files) < 100:
         logger.info(
             "Scalable files filtering: not enough files to filter.",
         )
@@ -164,7 +165,7 @@ await files_filter(query, files_list, api_key=api_key)
 
     files_filter_code += template.substitute(
         query=safe_query,
-        files_list=repr(self.uploaded_files),
+        files_list=repr(uploaded_files),
         api_key=safe_api_key,
     )
@@ -28,63 +28,7 @@ When executing any data science task (data loading, cleaning, analysis, modeling
 
 ---
 
-## Task Management Rules
-
-- **You must use `todo_write` to track progress**, especially for multi-step tasks.
-- Mark each subtask as complete **immediately** upon finishing—no delays or batch updates.
-- Skipping planning risks missing critical steps—this is unacceptable.
-
----
-
-## Data Handling Requirements
-
-### Data Inspection Methods
-
-Before any operation, **you must** inspect the true structure of the data source using tools (preferably `run_ipython_cell`):
-
-| Data Type | Inspection Method |
-|------------------|-----------------------------------------------------------------------------------|
-| **Database** | Query table schema (`DESCRIBE table`) and preview first 5–10 rows (`SELECT * FROM ... LIMIT 5`) |
-| **CSV/Excel** | Use `pandas.head(n)` to view column names and samples |
-| **Images** | Use PIL to get dimensions/format, or invoke vision tools to extract content |
-| **Text Files** | Read first 5–10 lines to determine structure and encoding |
-| **JSON** | Inspect from outer to inner layers progressively |
-
-> **Core Principle**: What you see is fact; what you haven’t seen is unknown.
-
----
-
-### Data Preprocessing Methods
-
-##### Messy Spreadsheet Handling
-
-After initial inspection of CSV or Excel files, if you observe:
-
-- Many `"Unnamed: X"`, `NaN`, or `NaT` entries
-- Missing or ambiguous headers
-- Multiple data blocks within a single worksheet
-
-Then **prioritize** advanced cleaning tools:
-
-- `clean_messy_spreadsheet`: Extract key information from tables and output as JSON for downstream analysis
-
-Only fall back to manual pandas row/block parsing if this tool fails.
-
----
-
-### Strict Data Volume Limits
-
-To prevent system crashes, strictly limit data volume during queries and reads:
-
-- **Database queries**: Always use `LIMIT` (typically 5–10 rows)
-- **Well-structured CSV/Excel**: Use `head()`, `nrows`, or sampling to fetch minimal data
-- **Large text files**: Read only the first few lines or process iteratively in chunks
-
-> **Warning**: Unrestricted large data reads will cause system failure.
-
----
-
-### Fact-Based, No Assumptions
-
+## Principles: Fact-Based, No Assumptions
+
 - All decisions must be grounded in the **given task context**. Never simplify, generalize, or subjectively interpret the task goal, data purpose, or business scenario. Any action inconsistent with the problem context is invalid and dangerous.
 - Never act on assumptions, guesses, or past experience—even if the situation seems "obvious" or "routine."
@@ -93,6 +37,14 @@ To prevent system crashes, strictly limit data volume during queries and reads:
 
 ---
 
+## Task Management Rules
+
+- **You must use `todo_write` to track progress**, especially for multi-step tasks.
+- Mark each subtask as complete **immediately** upon finishing—no delays or batch updates.
+- Skipping planning risks missing critical steps—this is unacceptable.
+
+---
+
 ## Visualization Strategy
 
 - **Plotting library**: Prefer `matplotlib`
@@ -44,8 +44,13 @@ Each task in the roadmap contains:
   - Brief Response
   - Detailed Report
 - You should choose the template that is most appropriate for the user task.
-- **Brief Respoonse Template** should ONLY be used when the user asks for a simple data query task, where ONLY numeric or concise string values are returned, and complex analysis or research are not required.
-- **Detailed Report Template** should be used when the user asks for a detailed analysis of the data, where the analysis and research are required.
+- **Brief Response Template** should ONLY be used when the user asks for a
+  simple, static data point (e.g., a total count or a specific value), where
+  the answer is returned as a single numeric or concise string value with no
+  analysis, transformation, comparison, or interpretation required.
+- **Detailed Report Template** should be used whenever the task involves
+  distribution, discrepancy, imbalance, comparison, trend, root cause, or
+  any form of analysis, interpretation, or evidence generation.
 
 2. Data Source Constraints
 - **ONLY use information explicitly present in the log file**
@@ -103,7 +108,7 @@ You MUST ensure all captions, subtitles, and other contents in the report are wr
 - "brief_response": The brief response content.
   - When 'is_brief_response' is True, this field should be fulfilled with the brief response content following the **Brief Response Template**.
   - When 'is_brief_response' is False, this field should be a concise summary of the detailed report in markdown format illustrating the key findings and insights.
-- "detailed_report_content": The detailed markdown report content following the **Detailed Report Template**. This field is ONLY generated when 'is_brief_response' is False, otherwise fulfill an empty string.
+- "report_content": The detailed markdown report content following the **Detailed Report Template**. This field is ONLY generated when 'is_brief_response' is False, otherwise fulfill an empty string.
 - You MUST ensure the JSON object is a valid JSON string and can be parsed by json.loads().
 - Double check all escapes are valid.
@@ -112,8 +112,11 @@ def truncate_long_text_post_hook(
 def _add_tool_postprocessing_func(toolkit: AliasToolkit) -> None:
     for tool_func, _ in toolkit.tools.items():
         if tool_func.startswith("run_ipython_cell"):
-            funcs: list = [ansi_escape_post_hook]
-            funcs.append(summarize_plt_chart_hook)
+            funcs: list = [
+                ansi_escape_post_hook,
+                summarize_plt_chart_hook,
+                truncate_long_text_post_hook,
+            ]
             toolkit.tools[tool_func].postprocess_func = partial(
                 run_ipython_cell_post_hook,
                 funcs,
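The hunk above grows the list of post-hooks that `run_ipython_cell_post_hook` applies, binding it with `functools.partial`. A minimal sketch of that hook-pipeline pattern, with hypothetical stand-in hooks rather than the project's real ones:

```python
from functools import partial

def strip_ansi(text: str) -> str:
    # stand-in post-hook: drop a sample ANSI reset sequence
    return text.replace("\x1b[0m", "")

def truncate(text: str, limit: int = 20) -> str:
    # stand-in post-hook: cap the text length
    return text[:limit]

def run_post_hooks(funcs, text):
    # apply each registered hook in order, like a postprocess pipeline
    for fn in funcs:
        text = fn(text)
    return text

# partial() pre-binds the hook list, leaving a one-argument callable
postprocess = partial(run_post_hooks, [strip_ansi, truncate])
print(postprocess("\x1b[0mhello world, this is long output"))
# → hello world, this is
```

The same shape lets the toolkit attach a ready-made `postprocess_func` to each tool without the tool knowing which hooks run.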
@@ -1,20 +1,62 @@
 # -*- coding: utf-8 -*-
 import os
-import json
 import time
 from typing import Tuple
 
 import dotenv
+from pydantic import BaseModel, Field
 
 from agentscope.message import Msg
 
 from .utils import model_call_with_retry, get_prompt_from_file
 
 from .ds_config import PROMPT_DS_BASE_PATH
 
 dotenv.load_dotenv()
 
 
+class ReportResponse(BaseModel):
+    is_brief_response: bool = Field(
+        ...,
+        description=(
+            "True if the response is a brief response; "
+            "False if it includes a detailed report."
+        ),
+    )
+
+    brief_response: str = Field(
+        ...,
+        description=(
+            "The brief response content. "
+            "When 'is_brief_response' is True, this field contains the full "
+            "brief response following the Brief Response Template. "
+            "When 'is_brief_response' is False, this field contains a concise "
+            "markdown summary of the detailed report, highlighting key "
+            "findings and insights."
+        ),
+        json_schema_extra={
+            "example": (
+                "The analysis shows a 15% increase in user engagement "
+                "after the feature update."
+            ),
+        },
+    )
+
+    report_content: str = Field(
+        ...,
+        description=(
+            "The detailed markdown report content following the "
+            "Detailed Report Template. This field MUST be an empty "
+            "string ('') when 'is_brief_response' is True. It MUST contain "
+            "the full detailed report when 'is_brief_response' is False."
+        ),
+        json_schema_extra={
+            "example": "### User Task Description...\n"
+            "### Associated Data Sources...\n"
+            "### Research Conclusion...\n### Task1...### Task2...",
+        },
+    )
+
+
 class ReportGenerator:
     def __init__(self, model, formatter, memory_log: str):
         self.model = model
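The new `ReportResponse` model moves output handling from manually stripping markdown code fences to schema-validated structured output. A small self-contained sketch of how a pydantic model validates such a payload (the `MiniReport` model below is an illustrative stand-in, not the project's class):

```python
from pydantic import BaseModel, Field

class MiniReport(BaseModel):
    is_brief_response: bool = Field(..., description="True for a brief answer")
    brief_response: str = Field(..., description="Short answer or summary")
    report_content: str = Field(..., description="Full report, or '' when brief")

# Validate a raw JSON payload of the shape a structured model call returns;
# pydantic raises ValidationError on missing or mistyped fields.
payload = (
    '{"is_brief_response": true, '
    '"brief_response": "42 rows", '
    '"report_content": ""}'
)
report = MiniReport.model_validate_json(payload)
print(report.is_brief_response)  # True
print(report.brief_response)     # 42 rows
```

Field descriptions also flow into the JSON schema the model is prompted with, which is why the diff writes them so verbosely.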
@@ -62,22 +104,13 @@ class ReportGenerator:
             self.formatter,
             msgs=msgs,
             msg_name="Report Generation",
+            structured_model=ReportResponse,
         )
 
-        raw_response = res.content[0]["text"]
-
-        # TODO: More robust response cleaning
-        if raw_response.strip().startswith("```json"):
-            cleaned = raw_response.strip()[len("```json") :].lstrip("\n")
-            if cleaned.endswith("```"):
-                cleaned = cleaned[:-3].rstrip()
-            response = cleaned
-        else:
-            response = raw_response.strip()
         end_time = time.time()
-        # print(response)
         print(f"Log to markdown took {end_time - start_time} seconds")
-        return response
+        return res.content[-1]["input"]
 
     async def _convert_to_html(self, markdown_content: str) -> str:
         start_time = time.time()
@@ -102,21 +135,15 @@ class ReportGenerator:
         print(f"Convert to html took {end_time - start_time} seconds")
         return response.content[0]["text"]
 
-    async def generate_report(self) -> Tuple[str, str]:
-        markdown_response = await self._log_to_markdown()
-
-        # responseFormat: {
-        # "is_brief_response": True,
-        # "brief_response": brief_response_content,
-        # "report_content": detailed_report_content
-        # }
-        try:
-            markdown_content = json.loads(markdown_response)
-        except json.JSONDecodeError as e:
-            print(f"Error parsing JSON response: {e}")
-            print(f"Response content: {markdown_response}")
-            raise
+    async def generate_report(self) -> Tuple[str, str, str]:
+        """
+        responseFormat: {
+            "is_brief_response": True,
+            "brief_response": brief_response_content,
+            "report_content": detailed_report_content
+        }
+        """
+        markdown_content = await self._log_to_markdown()
 
         if (
             str(markdown_content.get("is_brief_response", False)).lower()
@@ -124,12 +151,19 @@ class ReportGenerator:
         ):
             # During brief response mode,
             # directly return the brief response to the user.
-            return markdown_content["brief_response"], ""
+            return markdown_content.get("brief_response", ""), "", ""
         else:
             # In detailed report mode,
             # convert the detailed report to HTML and return it to the user;
             # if a brief summary of the report is needed,
             # it can be obtained through markdown_content["brief_response"].
-            return markdown_content[
-                "brief_response"
-            ], await self._convert_to_html(markdown_content["report_content"])
+            html_content = ""
+            if os.getenv("ENABLE_HTML_REPORT", "ON").lower() != "off":
+                html_content = await self._convert_to_html(
+                    markdown_content.get("report_content", ""),
+                )
+            return (
+                markdown_content.get("brief_response", ""),
+                markdown_content.get("report_content", ""),
+                html_content,
+            )
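The rewritten `generate_report` gates HTML conversion behind an `ENABLE_HTML_REPORT` environment variable that defaults to on. A minimal sketch of that on-unless-explicitly-off flag check:

```python
import os

def html_report_enabled() -> bool:
    # Enabled by default; only an explicit "off" (any case) disables it,
    # so unset, "ON", or any other value keeps the feature active.
    return os.getenv("ENABLE_HTML_REPORT", "ON").lower() != "off"

os.environ["ENABLE_HTML_REPORT"] = "OFF"
print(html_report_enabled())  # False
del os.environ["ENABLE_HTML_REPORT"]
print(html_report_enabled())  # True
```

Comparing against the single disabling value (rather than an enabling one) is what keeps the flag backward compatible for deployments that never set it.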
@@ -5,9 +5,9 @@ import json
 from typing import Union
 from agentscope.message import Msg
 from tenacity import retry, stop_after_attempt, wait_fixed
-from .ds_config import PROMPT_DS_BASE_PATH
-
-MODEL_MAX_RETRIES = 50
+
+from alias.agent.utils.constants import MODEL_MAX_RETRIES
+from .ds_config import PROMPT_DS_BASE_PATH
 
 
 def get_prompt_from_file(
@@ -36,10 +36,16 @@ async def model_call_with_retry(
     tool_json_schemas=None,
     tool_choice=None,
     msg_name="model_call",
+    structured_model=None,
 ) -> Msg:
     prompt = await formatter.format(msgs=msgs)
 
-    res = await model(prompt, tools=tool_json_schemas, tool_choice=tool_choice)
+    res = await model(
+        prompt,
+        tools=tool_json_schemas,
+        tool_choice=tool_choice,
+        structured_model=structured_model,
+    )
 
     if model.stream:
         msg = Msg(msg_name, [], "assistant")
@@ -3,10 +3,15 @@
 import uuid
 from enum import Enum
 from typing import Any, Optional, Literal
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
 from pydantic import BaseModel, Field
 
 
+def _get_utc_now_iso():
+    return datetime.now(timezone.utc).isoformat()
+
+
 @dataclass
 class MockFileBase:
     filename: str
@@ -15,8 +20,8 @@ class MockFileBase:
     storage_path: str
     size: int = -1
     storage_type: str = "unknown"
-    create_time: str = "xxxyyy"
-    update_time: str = "xxxyyy"
+    create_time: str = field(default_factory=_get_utc_now_iso)
+    update_time: str = field(default_factory=_get_utc_now_iso)
     user_id: uuid.UUID = uuid.uuid4()
 
 
@@ -62,12 +67,13 @@ class UserMessage(BaseMessage):
     name: str = "User"
 
 
+@dataclass
 class MockMessage:
-    id: uuid.UUID = uuid.uuid4()
+    id: uuid.UUID = field(default_factory=uuid.uuid4)
     message: Optional[dict] = None
-    files: list[Any] = []
-    create_time: str = "xxxyyy"
-    update_time: str = "xxxyyy"
+    files: list[Any] = field(default_factory=list)
+    create_time: str = field(default_factory=_get_utc_now_iso)
+    update_time: str = field(default_factory=_get_utc_now_iso)
 
 
 class SubTaskToPrint(BaseModel):
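The mock fixes above replace shared literal defaults with `field(default_factory=...)`: a class-level `uuid.uuid4()` default is evaluated once at class definition and shared by every instance, while a mutable default like `[]` is rejected outright by `dataclasses`. A sketch of the corrected pattern:

```python
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone

def _utc_now_iso() -> str:
    return datetime.now(timezone.utc).isoformat()

@dataclass
class Message:
    # default_factory runs at *instance* creation time, so each
    # Message gets its own id, list, and timestamp.
    id: uuid.UUID = field(default_factory=uuid.uuid4)
    files: list = field(default_factory=list)
    create_time: str = field(default_factory=_utc_now_iso)

a, b = Message(), Message()
print(a.id != b.id)        # True: fresh uuid per instance
a.files.append("report.md")
print(b.files)             # []: lists are not shared between instances
```

This also explains the added `@dataclass` decorator on `MockMessage` in the diff: without it, the annotations are plain class attributes and `field(...)` objects would leak through as values.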
@@ -40,6 +40,7 @@ class SessionEntity:
     query: str
     upload_files: List = []
     is_chat: bool = False
+    data_config: List | None = None
     use_long_term_memory_service: bool = False
 
     def __init__(
@@ -51,6 +52,7 @@ class SessionEntity:
             "bi",
             "finance",
         ] = "general",
+        data_config: List | None = None,
         use_long_term_memory_service: bool = False,
     ):
         self.user_id: uuid.UUID = uuid.UUID(
@@ -62,6 +64,7 @@ class SessionEntity:
         self.conversation_id: uuid.UUID = uuid.uuid4()
         self.session_id: uuid.UUID = uuid.uuid4()
         self.chat_mode = chat_mode
+        self.data_config = data_config
         self.use_long_term_memory_service = use_long_term_memory_service
 
     def ids(self):
@@ -79,6 +82,7 @@ class MockSessionService:
     def __init__(
         self,
         runtime_model: Any = None,
+        data_config: List | None = None,
         use_long_term_memory_service: bool = False,
     ):
         self.session_id = "mock_session"
@@ -86,6 +90,7 @@ class MockSessionService:
         self.messages = []
         self.plan = MockPlan()
         self.session_entity = SessionEntity(
+            data_config=data_config,
             use_long_term_memory_service=use_long_term_memory_service,
         )
         logger.info(
@@ -215,6 +220,72 @@ class MockSessionService:
         self.messages.append(db_message)
         return db_message
 
+    async def append_to_latest_message(
+        self,
+        content_to_append: str,
+        role_filter: Optional[str] = None,
+    ) -> Optional[MockMessage]:
+        """
+        Append content to the most recent message.
+
+        Args:
+            content_to_append: Content to append to the message
+            role_filter: Optional role filter (e.g., 'user', 'assistant')
+
+        Returns:
+            Updated MockMessage or None if no message found
+        """
+        if not self.messages:
+            logger.warning("No messages to append to")
+            return None
+
+        # Find the most recent message (optionally filtered by role)
+        target_message = None
+        for msg in reversed(self.messages):
+            if role_filter is None or msg.message.get("role") == role_filter:
+                target_message = msg
+                break
+
+        if target_message is None:
+            logger.warning(f"No message found with role={role_filter}")
+            return None
+
+        # Append content
+        current_content = target_message.message.get("content", "")
+        if isinstance(current_content, str):
+            target_message.message["content"] = (
+                current_content + content_to_append
+            )
+        elif isinstance(current_content, list):
+            # Handle multi-modal content (list of content blocks)
+            target_message.message["content"].append(
+                {
+                    "type": "text",
+                    "text": content_to_append,
+                },
+            )
+        else:
+            logger.error(f"Unsupported content type: {type(current_content)}")
+            return None
+
+        # Update timestamp
+        target_message.update_time = datetime.now(timezone.utc).isoformat()
+
+        # Optional: Log to file
+        if hasattr(self, "log_storage_path"):
+            content_log = (
+                "=" * 10
+                + "\n"
+                + f"APPEND to Role: {target_message.message.get('role')}\n"
+                + f"Appended: {content_to_append}\n"
+                + "=" * 10
+                + "\n"
+            )
+            with open(self.log_storage_path, "a") as file:
+                file.write(content_log)
+
+        return target_message
+
     async def get_messages(self) -> List[MockMessage]:
         logger.log("SEND_MSG", "Get all messages")
         return self.messages
@@ -1,9 +1,11 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=W0612,E0611,C2801
 
 import os
-import traceback
 from datetime import datetime
 import asyncio
+import traceback
 from typing import Literal
+
 
 from loguru import logger
@@ -17,27 +19,33 @@ from alias.agent.agents import (
     BrowserAgent,
     DeepResearchAgent,
     MetaPlanner,
-    DataScienceAgent,
-    init_ds_toolkit,
     init_dr_toolkit,
 )
 
 from alias.agent.agents.meta_planner_utils._worker_manager import share_tools
 from alias.agent.mock import MockSessionService as SessionService
 from alias.agent.tools import AliasToolkit
 
 from alias.agent.utils.constants import (
     BROWSER_AGENT_DESCRIPTION,
     DEFAULT_DEEP_RESEARCH_AGENT_NAME,
     DEEPRESEARCH_AGENT_DESCRIPTION,
     DS_AGENT_DESCRIPTION,
 )
-from alias.agent.tools.add_tools import add_tools
-from alias.agent.agents.ds_agent_utils import (
-    add_ds_specific_tool,
+from alias.agent.utils.prepare_data_source import (
+    add_data_source_tools,
+    prepare_data_sources,
 )
+from alias.agent.tools.add_tools import add_tools
 from alias.agent.memory.longterm_memory import AliasLongTermMemory
 from alias.server.clients.memory_client import MemoryClient
+from alias.agent.agents._data_science_agent import (
+    DataScienceAgent,
+    init_ds_toolkit,
+)
+
+from alias.agent.utils.llm_call_manager import (
+    LLMCallManager,
+)
 
 MODEL_FORMATTER_MAPPING = {
     "qwen3-max": [
@@ -104,9 +112,28 @@ async def arun_meta_planner(
     # Init deep research toolkit
     deep_research_toolkit = init_dr_toolkit(worker_full_toolkit)
 
-    # Init BI agent toolkit
+    # Init data science agent toolkit
     ds_toolkit = init_ds_toolkit(worker_full_toolkit)
 
+    # Initialize data source manager
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    data_manager = await prepare_data_sources(
+        session_service=session_service,
+        sandbox=sandbox,
+        llm_call_manager=llm_call_manager,
+    )
+    add_data_source_tools(
+        data_manager,
+        worker_full_toolkit,
+        browser_toolkit,
+        deep_research_toolkit,
+        ds_toolkit,
+    )
+
     try:
         model, formatter = MODEL_FORMATTER_MAPPING[MODEL_CONFIG_NAME]
         browser_agent = BrowserAgent(
@@ -175,13 +202,15 @@ async def arun_meta_planner(
         description=DEEPRESEARCH_AGENT_DESCRIPTION,
         worker_type="built-in",
     )
-    # === add BI agent ===
+    # === add data science agent ===
     ds_agent = DataScienceAgent(
         name="Data_Science_Agent",
         model=model,
         formatter=formatter,
         memory=InMemoryMemory(),
         toolkit=ds_toolkit,
+        data_manager=data_manager,
+        sys_prompt=data_manager.get_data_skills(),
         max_iters=30,
         session_service=session_service,
     )
@@ -219,6 +248,19 @@ async def arun_deepresearch_agent(
         "run_shell_command",
     ]
     share_tools(global_toolkit, worker_toolkit, test_tool_list)
 
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    await prepare_data_sources(
+        session_service,
+        sandbox,
+        worker_toolkit,
+        llm_call_manager,
+    )
+
     worker_agent = DeepResearchAgent(
         name="Deep_Research_Agent",
         model=model,
@@ -285,6 +327,18 @@ async def arun_finance_agent(
         active=True,
     )
 
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    await prepare_data_sources(
+        session_service,
+        sandbox,
+        worker_toolkit,
+        llm_call_manager,
+    )
+
     worker_agent = DeepResearchAgent(
         name="Deep_Research_Agent",
         model=model,
@@ -326,17 +380,21 @@ async def arun_datascience_agent(
     session_service: SessionService,  # type: ignore[valid-type]
     sandbox: Sandbox = None,
 ):
-    global_toolkit = AliasToolkit(sandbox, add_all=True)
-    # await add_tools(global_toolkit)
-    worker_toolkit = AliasToolkit(sandbox)
     model, formatter = MODEL_FORMATTER_MAPPING[MODEL_CONFIG_NAME]
-    test_tool_list = [
-        "write_file",
-        "run_ipython_cell",
-        "run_shell_command",
-    ]
-    share_tools(global_toolkit, worker_toolkit, test_tool_list)
-    add_ds_specific_tool(worker_toolkit)
+
+    global_toolkit = AliasToolkit(sandbox, add_all=True)
+    worker_toolkit = init_ds_toolkit(global_toolkit)
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    data_manager = await prepare_data_sources(
+        session_service=session_service,
+        sandbox=sandbox,
+        binded_toolkit=worker_toolkit,
+        llm_call_manager=llm_call_manager,
+    )
 
     try:
         worker_agent = DataScienceAgent(
@@ -345,6 +403,8 @@ async def arun_datascience_agent(
             formatter=formatter,
             memory=InMemoryMemory(),
             toolkit=worker_toolkit,
+            data_manager=data_manager,
+            sys_prompt=data_manager.get_data_skills(),
             max_iters=30,
             session_service=session_service,
         )
@@ -360,6 +420,7 @@ async def arun_datascience_agent(
     finally:
         try:
             await global_toolkit.close_mcp_clients()
+            await worker_toolkit.close_mcp_clients()
         except (RuntimeError, asyncio.CancelledError) as e:
             # Event loop might be closed during shutdown
             if "Event loop is closed" in str(e) or isinstance(
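The `finally` block now closes the worker toolkit's MCP clients too, while tolerating the "Event loop is closed" race that occurs during interpreter shutdown. A generic sketch of that shutdown-tolerant cleanup pattern (the `FakeClient` class is a stand-in for real MCP clients):

```python
import asyncio

class FakeClient:
    def __init__(self):
        self.closed = False

    async def close(self):
        self.closed = True

async def shutdown(clients):
    for client in clients:
        try:
            await client.close()
        except (RuntimeError, asyncio.CancelledError) as e:
            # Swallow only the benign shutdown races: a loop that is
            # already closed, or a cancellation mid-teardown.
            benign = "Event loop is closed" in str(e) or isinstance(
                e, asyncio.CancelledError
            )
            if not benign:
                raise

clients = [FakeClient(), FakeClient()]
asyncio.run(shutdown(clients))
print(all(c.closed for c in clients))  # True
```

Narrowing the `except` to these two cases keeps genuine close failures visible instead of being silently dropped.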
@@ -386,6 +447,18 @@ async def arun_browseruse_agent(
         add_all=True,
         is_browser_toolkit=True,
     )
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    await prepare_data_sources(
+        session_service,
+        sandbox,
+        browser_toolkit,
+        llm_call_manager,
+    )
+
     logger.info("Init browser toolkit")
     try:
         browser_agent = BrowserAgent(
@@ -5,6 +5,7 @@ import json
 import os
 import tarfile
 from pathlib import Path
+import shlex
 from typing import Optional
 
 from loguru import logger
@@ -172,7 +173,7 @@ def get_workspace_file(
     )
     tool_result = sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"base64 -i {file_path}"},
+        arguments={"command": f"base64 -i {shlex.quote(file_path)}"},
     )
     return tool_result["content"][0]["text"]
 
@@ -194,7 +195,7 @@ def create_or_edit_workspace_file(
     }
     sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"touch {file_path}"},
+        arguments={"command": f"touch {shlex.quote(file_path)}"},
     )
     fill_result = sandbox.call_tool(
         "write_file",
@@ -222,7 +223,7 @@ def create_workspace_directory(
     }
     tool_result = sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"mkdir -p {directory_path}"},
+        arguments={"command": f"mkdir -p {shlex.quote(directory_path)}"},
     )
     return tool_result
 
@@ -246,7 +247,7 @@ def delete_workspace_file(
     }
     tool_result = sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"rm -rf {file_path}"},
+        arguments={"command": f"rm -rf {shlex.quote(file_path)}"},
     )
     return tool_result
 
@@ -280,7 +281,7 @@ def download_workspace_file_from_oss(
     tool_result = sandbox.call_tool(
         "run_shell_command",
         arguments={
-            "command": f"wget -O {to_path} {oss_url}",
+            "command": f"wget -O {shlex.quote(to_path)} {oss_url}",
         },
     )
     print(f"{tool_result}")
@@ -306,7 +307,7 @@ def delete_workspace_directory(
     }
     tool_result = sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"rm -rf {directory_path}"},
+        arguments={"command": f"rm -rf {shlex.quote(directory_path)}"},
    )
     return tool_result
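Every path interpolated into a shell command in `sandbox_util` is now wrapped in `shlex.quote` before being handed to `run_shell_command`, so a filename containing spaces or shell metacharacters cannot break out of the intended command. A sketch of the difference quoting makes:

```python
import shlex

def build_rm_command(path: str) -> str:
    # shlex.quote neutralizes spaces, quotes, and metacharacters by
    # single-quoting the path when (and only when) it needs escaping.
    return f"rm -rf {shlex.quote(path)}"

safe = build_rm_command("/workspace/report.txt")
print(safe)     # rm -rf /workspace/report.txt
hostile = build_rm_command("/tmp/x; rm -rf /")
print(hostile)  # rm -rf '/tmp/x; rm -rf /'
```

Without the quoting, the second input would execute `rm -rf /tmp/x` followed by a separate `rm -rf /`; with it, the whole string is treated as one literal path argument.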
@@ -1,131 +1,8 @@
 # -*- coding: utf-8 -*-
-import json
-import os.path
-import uuid
-import textwrap
-
-from agentscope.tool import ToolResponse
-from agentscope.message import ToolUseBlock, TextBlock
-
-from alias.agent.utils.constants import TMP_FILE_DIR
-from alias.agent.tools.sandbox_util import (
-    create_or_edit_workspace_file,
-    create_workspace_directory,
-)
+from alias.agent.tools.toolkit_hooks.text_post_hook import TextPostHook
 
 
-class LongTextPostHook:
+class LongTextPostHook(TextPostHook):
     def __init__(self, sandbox):
-        self.sandbox = sandbox
+        super().__init__(sandbox, budget=8194 * 10, auto_save=False)
-
-    def truncate_and_save_response(  # pylint: disable=R1710
-        self,
-        tool_use: ToolUseBlock,  # pylint: disable=W0613
-        tool_response: ToolResponse,
-    ) -> ToolResponse:
-        """Post-process tool responses to prevent content overflow.
-
-        This function ensures that tool responses don't exceed a predefined
-        budget to prevent overwhelming the model with too much information.
-        It truncates text content while preserving the structure of
-        the response.
-
-        Args:
-            tool_use: The tool use block that triggered the response (unused).
-            tool_response: The tool response to potentially truncate.
-
-        Note:
-            The budget is set to approximately 80K tokens
-            (8194 * 10 characters) to ensure responses remain
-            manageable for the language model.
-        """
-        # Set budget to prevent overwhelming the model with too much content
-        budget = 8194 * 10  # Approximately 80K tokens of content
-        append_hint = "\n\n[Content is too long and truncated....]"
-
-        new_tool_response = ToolResponse(
-            id=tool_response.id,
-            stream=tool_response.stream,
-            is_last=tool_response.is_last,
-            is_interrupted=tool_response.is_interrupted,
-            content=[],
-        )
-        if isinstance(tool_response.content, list):
-            save_text_block = None
-            for _i, block in enumerate(tool_response.content):
-                if block["type"] == "text":
-                    text = block["text"]
-                    text_len = len(text)
-
-                    # If this block exceeds remaining budget, truncate it
-                    if text_len > budget:
-                        # Calculate truncation threshold
-                        # (80% of proportional budget)
-                        threshold = int(budget * 0.85)
-                        # save the original response
-                        tmp_file_name_prefix = tool_use.get("name", "")
-                        save_text_block = self._save_tmp_file(
-                            tmp_file_name_prefix,
-                            tool_response.content,
-                        )
-                        new_tool_response.append = (
-                            text[:threshold] + append_hint
-                        )
-                        new_tool_response.content.append(
-                            TextBlock(
-                                type="text",
-                                text=text[:threshold] + append_hint,
-                            ),
-                        )
-                    else:
-                        new_tool_response.content.append(block)
-                    budget -= text_len
-            if budget <= 0 and save_text_block:
-                new_tool_response.content.append(save_text_block)
-            return new_tool_response
-        elif isinstance(tool_response.content, str):
-            text_len = len(tool_response.content)
-            text = tool_response.content
-            if text_len > budget:
-                tmp_file_name_prefix = tool_use.get("name", "")
-                save_text_block = self._save_tmp_file(
-                    tmp_file_name_prefix,
-                    tool_response.content,
-                )
-                # Calculate truncation threshold (80% of proportional budget)
-                threshold = int(budget / text_len * len(text) * 0.8)
-                tool_response.content = text[:threshold] + append_hint
-                tool_response.content = [
-                    TextBlock(type="text", text=tool_response.content),
-                    save_text_block,
-                ]
-                return tool_response
-        return tool_response
-
-    def _save_tmp_file(self, save_file_name_prefix: str, content: list | str):
-        create_workspace_directory(self.sandbox, TMP_FILE_DIR)
-        save_file_name = (
-            save_file_name_prefix
-            + "-"
-            + str(
uuid.uuid4().hex[:8],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
file_path = os.path.join(TMP_FILE_DIR, save_file_name)
|
|
||||||
json_str = json.dumps(content, ensure_ascii=False, indent=2)
|
|
||||||
wrapped = "\\n".join(
|
|
||||||
[textwrap.fill(line, width=500) for line in json_str.split("\\n")],
|
|
||||||
)
|
|
||||||
create_or_edit_workspace_file(
|
|
||||||
self.sandbox,
|
|
||||||
file_path,
|
|
||||||
wrapped,
|
|
||||||
)
|
|
||||||
return TextBlock(
|
|
||||||
type="text",
|
|
||||||
text=f"Dump the complete long file at {file_path}. "
|
|
||||||
"Don't try to read the complete file directly. "
|
|
||||||
"Use `grep -C 10 'YOUR_PATTERN' {file_path}` or "
|
|
||||||
"other bash command to extract "
|
|
||||||
"useful information.",
|
|
||||||
)
|
|
||||||
|
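The refactor above follows a common pattern: move the behavior into a configurable base class and keep the old class as a thin subclass that only pins its historical defaults. A minimal stdlib sketch of that pattern, with illustrative names (not the project's real API):

```python
# Base class holds the configurable behavior.
class TextPostHookBase:
    def __init__(self, sandbox, budget=8194 * 10, auto_save=False):
        self.sandbox = sandbox
        self.budget = budget
        self.auto_save = auto_save


# The legacy class becomes a thin subclass that pins the old defaults,
# so existing call sites keep working unchanged.
class LongTextPostHookCompat(TextPostHookBase):
    def __init__(self, sandbox):
        super().__init__(sandbox, budget=8194 * 10, auto_save=False)


hook = LongTextPostHookCompat(sandbox=None)
print(hook.budget)     # 81940
print(hook.auto_save)  # False
```

New callers can construct the base class directly with a custom budget, while old callers are untouched.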
alias/src/alias/agent/tools/toolkit_hooks/text_post_hook.py (new file, 175 lines)
@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
import json
import os.path
import uuid
import textwrap

from agentscope.tool import ToolResponse
from agentscope.message import ToolUseBlock, TextBlock

from alias.agent.utils.constants import TMP_FILE_DIR
from alias.agent.tools.sandbox_util import (
    create_or_edit_workspace_file,
    create_workspace_directory,
)


class TextPostHook:
    def __init__(self, sandbox, budget=8194 * 10, auto_save=False):
        """
        Args:
            sandbox: The sandbox environment for file operations.
            budget: Maximum character count before truncation
                (default: 81,940). Approximately 20K tokens for English
                text or 160K tokens for Chinese text. Adjust based on
                your model's context window.
            auto_save: Whether to save complete content to file when
                truncated.
                - False: Save only after being truncated (default)
                - True: Save complete content to file
        """
        self.sandbox = sandbox
        self.auto_save = auto_save
        self.budget = budget

    def truncate_and_save_response(  # pylint: disable=R1710
        self,
        tool_use: ToolUseBlock,  # pylint: disable=W0613
        tool_response: ToolResponse,
    ) -> ToolResponse:
        """Post-process tool responses to prevent content overflow.

        This function ensures that tool responses don't exceed a predefined
        budget to prevent overwhelming the model with too much information.
        It truncates text content while preserving the structure of the
        response, and optionally saves the complete content to a file based
        on the auto_save setting.

        Args:
            tool_use: The tool use block that triggered the response (unused).
            tool_response: The tool response to potentially truncate.
        """
        budget = self.budget
        append_hint = "\n\n[Content is too long and truncated....]"

        new_tool_response = ToolResponse(
            id=tool_response.id,
            stream=tool_response.stream,
            is_last=tool_response.is_last,
            is_interrupted=tool_response.is_interrupted,
            content=[],
        )
        if isinstance(tool_response.content, list):
            save_text_block = None
            is_truncated = False

            for _i, block in enumerate(tool_response.content):
                if block["type"] == "text":
                    text = block["text"]
                    text_len = len(text)

                    # If this block exceeds remaining budget, truncate it
                    if text_len > budget:
                        is_truncated = True

                        # Calculate truncation threshold
                        # (85% of the remaining budget)
                        threshold = int(budget * 0.85)
                        new_tool_response.content.append(
                            TextBlock(
                                type="text",
                                text=text[:threshold] + append_hint,
                            ),
                        )
                    else:
                        new_tool_response.content.append(block)

                    budget -= text_len
                    if budget <= 0:
                        break

            # Save file if auto_save=True or content was truncated
            if self.auto_save or is_truncated:
                tmp_file_name_prefix = tool_use.get("name", "")
                save_text_block = self._save_tmp_file(
                    tmp_file_name_prefix,
                    tool_response.content,
                    is_truncated=is_truncated,
                )
                new_tool_response.content.append(save_text_block)

            return new_tool_response

        elif isinstance(tool_response.content, str):
            text_len = len(tool_response.content)
            text = tool_response.content
            is_truncated = text_len > budget

            # Save file if auto_save=True or content was truncated
            if self.auto_save or is_truncated:
                tmp_file_name_prefix = tool_use.get("name", "")
                save_text_block = self._save_tmp_file(
                    tmp_file_name_prefix,
                    tool_response.content,
                    is_truncated=is_truncated,
                )

                if is_truncated:
                    # Calculate truncation threshold (80% of budget)
                    threshold = int(budget / text_len * len(text) * 0.8)
                    tool_response.content = [
                        TextBlock(
                            type="text",
                            text=text[:threshold] + append_hint,
                        ),
                        save_text_block,
                    ]
                else:
                    tool_response.content = [
                        TextBlock(type="text", text=text),
                        save_text_block,
                    ]

            return tool_response

        return tool_response

    def _save_tmp_file(
        self,
        save_file_name_prefix: str,
        content: list | str,
        is_truncated: bool,
    ):
        create_workspace_directory(self.sandbox, TMP_FILE_DIR)
        save_file_name = (
            save_file_name_prefix
            + "-"
            + str(
                uuid.uuid4().hex[:8],
            )
        )
        file_path = os.path.join(TMP_FILE_DIR, save_file_name)
        json_str = json.dumps(content, ensure_ascii=False, indent=2)
        wrapped = "\n".join(
            [textwrap.fill(line, width=500) for line in json_str.split("\n")],
        )
        create_or_edit_workspace_file(
            self.sandbox,
            file_path,
            wrapped,
        )
        return (
            TextBlock(
                type="text",
                text=f"Dump the complete long file at {file_path}. "
                "Don't try to read the complete file directly. "
                f"Use `grep -C 10 'YOUR_PATTERN' {file_path}` or "
                "other bash command to extract "
                "useful information.",
            )
            if is_truncated
            else TextBlock(
                type="text",
                text=f"Results dumped at {file_path}. ",
            )
        )
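The core of TextPostHook is a character-budget walk over text blocks: whole blocks are kept while they fit, and the first block that exceeds the remaining budget is cut to 85% of it with a truncation hint appended. A simplified, self-contained sketch of just that idea (plain strings instead of TextBlock objects):

```python
def truncate_blocks(blocks, budget):
    """Keep whole text blocks while they fit; truncate the first one that
    exceeds the remaining budget to 85% of it, then stop."""
    hint = "\n\n[Content is too long and truncated....]"
    out = []
    truncated = False
    for text in blocks:
        if len(text) > budget:
            # 85% threshold, mirroring the listing above
            threshold = int(budget * 0.85)
            out.append(text[:threshold] + hint)
            truncated = True
        else:
            out.append(text)
        budget -= len(text)
        if budget <= 0:
            break
    return out, truncated


# 50-char block fits; the 200-char block exceeds the remaining budget of 50
out, truncated = truncate_blocks(["a" * 50, "b" * 200], budget=100)
```

The real class additionally dumps the untruncated content to a workspace file so the agent can `grep` it later instead of losing information outright.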
alias/src/alias/agent/utils/llm_call_manager.py (new file, 129 lines)
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
import asyncio
from typing import Any, Dict, Literal, Type

from tenacity import retry, stop_after_attempt, wait_fixed
from pydantic import BaseModel

from agentscope.message import Msg
from agentscope.model import DashScopeChatModel
from agentscope.formatter import DashScopeChatFormatter

from alias.agent.utils.constants import MODEL_MAX_RETRIES


@retry(
    stop=stop_after_attempt(MODEL_MAX_RETRIES),
    wait=wait_fixed(5),
    reraise=True,
    # before_sleep=_print_exc_on_retry
)
async def model_call_with_retry(
    model: DashScopeChatModel = None,
    formatter: DashScopeChatFormatter = None,
    messages: list[dict[str, Any]] = None,
    tool_json_schemas: list[dict] | None = None,
    tool_choice: Literal["auto", "none", "required"] | str | None = None,
    structured_model: Type[BaseModel] | None = None,
    msg_name: str = "model_call",
    **kwargs: Any,
) -> Msg:
    """
    Make a model call with a retry mechanism.

    This function formats the messages and calls the model with retry logic
    to handle potential failures during the API call.

    Args:
        model: The DashScope chat model to use for inference
        formatter: Formatter to prepare messages for the model
        msg_name: Name for the returned message object
        (see DashScopeChatModel's docstring for the remaining parameters)

    Returns:
        Message object containing the model response

    Raises:
        Exception: If all retry attempts fail
    """
    format_messages = await formatter.format(msgs=messages)

    res = await model(
        messages=format_messages,
        tools=tool_json_schemas,
        tool_choice=tool_choice,
        structured_model=structured_model,
        **kwargs,
    )
    if model.stream:
        msg = Msg(msg_name, [], "assistant")
        async for content_chunk in res:
            msg.content = content_chunk.content
            # Add a tiny sleep to yield the last message object in the
            # message queue
            await asyncio.sleep(0.001)
    else:
        msg = Msg(msg_name, list(res.content), "assistant")
    return msg


class LLMCallManager:
    """Manager class for handling LLM calls with different models."""

    def __init__(
        self,
        base_model_name: str,
        vl_model_name: str,
        model_formatter_mapping: Dict[str, Any],
    ):
        """
        Initialize the LLM call manager.

        Args:
            base_model_name: Name of the base language model
            vl_model_name: Name of the vision-language model
            model_formatter_mapping: Mapping of names to model/formatter pairs
        """
        self.base_model_name = base_model_name
        self.vl_model_name = vl_model_name
        self.model_formatter_mapping = model_formatter_mapping

    def get_base_model_name(self) -> str:
        """Get the name of the base language model."""
        return self.base_model_name

    def get_vl_model_name(self) -> str:
        """Get the name of the vision-language model."""
        return self.vl_model_name

    async def __call__(
        self,
        model_name: str,
        messages: list[dict[str, Any]],
        tools: list[dict] | None = None,
        tool_choice: Literal["auto", "none", "required"] | str | None = None,
        structured_model: Type[BaseModel] | None = None,
        **kwargs: Any,
    ) -> str:
        """
        Make an asynchronous call to the specified LLM.

        Args:
            model_name: Name of the model to use for the call
            messages: List of message dictionaries to send to the model
            (see DashScopeChatModel's docstring for the remaining parameters)

        Returns:
            String response from the LLM
        """
        model, formatter = self.model_formatter_mapping[model_name]
        raw_response = await model_call_with_retry(
            model=model,
            formatter=formatter,
            messages=messages,
            tool_json_schemas=tools,
            tool_choice=tool_choice,
            structured_model=structured_model,
            msg_name="model_call",
            **kwargs,
        )
        response = raw_response.content[0]["text"]
        return response
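The module above leans on tenacity's `@retry(stop=stop_after_attempt(...), wait=wait_fixed(...), reraise=True)`. For illustration only, here is a stdlib-only sketch of the same fixed-wait, reraise-on-exhaustion behavior (the real code should keep using tenacity):

```python
import functools
import time


def retry_fixed(attempts, wait_seconds):
    """Retry a callable up to `attempts` times with a fixed wait between
    tries; reraise the last exception if all attempts fail."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            last_exc = None
            for attempt in range(attempts):
                try:
                    return fn(*args, **kwargs)
                except Exception as exc:
                    last_exc = exc
                    if attempt < attempts - 1:
                        time.sleep(wait_seconds)
            raise last_exc  # equivalent to tenacity's reraise=True
        return wrapper
    return decorator


calls = {"n": 0}


@retry_fixed(attempts=3, wait_seconds=0)
def flaky():
    # Fails twice, then succeeds on the third call
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"


result = flaky()  # "ok" after 3 calls
```

tenacity adds the pieces this sketch omits, such as async support (needed here since `model_call_with_retry` is a coroutine), backoff strategies, and `before_sleep` hooks.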
alias/src/alias/agent/utils/prepare_data_source.py (new file, 92 lines)
@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import os

from agentscope_runtime.sandbox.box.sandbox import Sandbox

from alias.agent.agents.data_source.data_source import DataSourceManager
from alias.agent.tools import AliasToolkit, share_tools
from alias.agent.utils.llm_call_manager import (
    LLMCallManager,
)

if os.getenv("TEST_MODE") not in ["local", "runtime-test"]:
    from alias.server.services.session_service import (
        SessionService,
    )
else:
    from alias.agent.mock import MockSessionService as SessionService


async def prepare_data_sources(
    session_service: SessionService,
    sandbox: Sandbox,
    binded_toolkit: AliasToolkit = None,
    llm_call_manager: LLMCallManager = None,
):
    data_manager = await build_data_manager(
        session_service,
        sandbox,
        llm_call_manager,
    )
    if len(data_manager):
        await add_user_data_message(session_service, data_manager)

    if binded_toolkit:
        add_data_source_tools(data_manager, binded_toolkit)

    return data_manager


async def build_data_manager(
    session_service: SessionService,
    sandbox: Sandbox,
    llm_call_manager: LLMCallManager,
):
    data_manager = DataSourceManager(sandbox, llm_call_manager)
    if (
        hasattr(session_service.session_entity, "data_config")
        and session_service.session_entity.data_config
    ):
        data_configs = session_service.session_entity.data_config
        for config in data_configs:
            data_manager.add_data_source(config)

    await data_manager.prepare_data_sources()
    return data_manager


def add_data_source_tools(
    data_manager: DataSourceManager,
    *toolkits: AliasToolkit,
):
    data_source_toolkit = data_manager.toolkit
    tool_names = list(data_source_toolkit.tools.keys())
    for toolkit in toolkits:
        share_tools(data_source_toolkit, toolkit, tool_names)


async def add_user_data_message(
    session_service: SessionService,
    data_manager: DataSourceManager,
):
    await session_service.append_to_latest_message(
        "\n\n" + data_manager.get_all_data_sources_desc(),
    )


def get_data_source_config_from_file(config_file: str):
    """Load and parse data source configuration from a JSON file."""
    import json

    # Validate file existence upfront
    if not os.path.isfile(config_file):
        raise FileNotFoundError(
            f"Configuration file not found: {config_file}",
        )

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Invalid JSON in data source configuration file "
            f"'{config_file}': {e.msg} at line {e.lineno}",
        ) from e
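`get_data_source_config_from_file` checks existence up front and rewraps JSON errors with the file name and line number. A quick self-contained check of that behavior, using a simplified copy of the function against temporary files:

```python
import json
import os
import tempfile


def load_config(config_file):
    """Simplified copy of get_data_source_config_from_file for illustration."""
    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"Configuration file not found: {config_file}")
    try:
        with open(config_file, "r", encoding="utf-8") as f:
            return json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Invalid JSON in data source configuration file "
            f"'{config_file}': {e.msg} at line {e.lineno}"
        ) from e


with tempfile.TemporaryDirectory() as d:
    # Valid config: loads into native Python objects
    good = os.path.join(d, "sources.json")
    with open(good, "w", encoding="utf-8") as f:
        json.dump([{"type": "sqlite", "uri": "sqlite:///data.db"}], f)
    cfg = load_config(good)

    # Broken config: the JSON error is surfaced with file and line info
    bad = os.path.join(d, "broken.json")
    with open(bad, "w", encoding="utf-8") as f:
        f.write("{not json")
    try:
        load_config(bad)
        raised = False
    except ValueError as e:
        raised = True
        message = str(e)
```

The `raise ... from e` keeps the original `JSONDecodeError` chained for debugging while callers only need to handle `ValueError`.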
CLI entry point (modified): the `--files` upload flow is replaced by data sources (`--datasource` / `--dataconfig`).

@@ -9,12 +9,10 @@ for the Alias agent application.
 """
 import argparse
 import asyncio
-import os
 import signal
 import sys
 import traceback
 import webbrowser
-from typing import Optional

 from loguru import logger

@@ -28,7 +26,10 @@ from alias.agent.run import (
     arun_datascience_agent,
     arun_finance_agent,
 )
-from alias.agent.tools.sandbox_util import copy_local_file_to_workspace
+from alias.agent.utils.prepare_data_source import (
+    get_data_source_config_from_file,
+)
 from alias.runtime.alias_sandbox.alias_sandbox import AliasSandbox

@@ -60,7 +61,7 @@ def _safe_sigint_handler(signum, frame):  # pylint: disable=W0613
 async def run_agent_task(
     user_msg: str,
     mode: str = "general",
-    files: Optional[list[str]] = None,
+    user_data_config: list | None = None,
     use_long_term_memory_service: bool = False,
 ) -> None:
     """
@@ -69,7 +70,8 @@ async def run_agent_task(
     Args:
         user_msg: The user's task/query
         mode: Agent mode ('general', 'dr', 'ds', 'browser', 'finance')
-        files: List of local file paths to upload to sandbox workspace
+        user_data_config: (Config for) user data sources; used by the
+            data science agent only
         use_long_term_memory_service: Enable long-term memory service.
     """
     global _original_sigint_handler
@@ -84,6 +86,7 @@ async def run_agent_task(

     # Initialize session
     session = MockSessionService(
+        data_config=user_data_config,
         use_long_term_memory_service=use_long_term_memory_service,
     )

@@ -118,35 +121,6 @@ async def run_agent_task(
     )
     logger.info(f"Sandbox desktop URL: {sandbox.desktop_url}")
     webbrowser.open(sandbox.desktop_url)
-    # Upload files to sandbox if provided
-    if files:
-        target_paths = []
-        logger.info(
-            f"Uploading {len(files)} file(s) to sandbox workspace...",
-        )
-        for file_path in files:
-            if not os.path.exists(file_path):
-                logger.error(f"File not found: {file_path}")
-                continue
-
-            # Get the filename and construct target path in workspace
-            filename = os.path.basename(file_path)
-            target_path = f"/workspace/{filename}"
-
-            logger.info(f"Uploading {file_path} to {target_path}")
-            result = copy_local_file_to_workspace(
-                sandbox=sandbox,
-                local_path=file_path,
-                target_path=target_path,
-            )
-
-            if result.get("isError"):
-                raise ValueError(f"Failed to upload {file_path}: {result}")
-            logger.info(f"Successfully uploaded to {result}")
-
-            target_paths.append(result.get("content", [])[0].get("text"))
-
-        user_msg += "\n\nUser uploaded files:\n" + "\n".join(target_paths)
-
     # Create initial user message (regardless of whether files were uploaded)
     initial_user_message = UserMessage(
@@ -301,12 +275,28 @@ def main():
     )

     run_parser.add_argument(
+        "--datasource",
         "--files",
-        "-f",
-        type=str,
+        "-d",
+        dest="datasource",
         nargs="+",
-        help="Local file paths to upload to sandbox workspace "
-        "for agent to use (e.g., --files file1.txt file2.csv)",
+        help=(
+            "Data sources for the agent to use. Multiple formats supported:\n"
+            "  • Local files: ./data.txt, /absolute/path/file.json\n"
+            "  • Databases: postgresql://localhost/db, sqlite:///data.db\n"
+            "Example: "
+            "  --datasource file.txt postgresql://localhost/db\n"
+            "  --files file.txt"
+        ),
+    )
+
+    # If you need to deeply customize your data source
+    # (e.g., specify an MCP server), use this parameter to
+    # provide a configuration file
+    run_parser.add_argument(
+        "--dataconfig",
+        "-dc",
+        help="Path to the data source configuration file",
     )

     run_parser.add_argument(
@@ -333,11 +323,29 @@ def main():
     # Handle commands
     if args.command == "run":
         try:
+            user_data = None
+            data_endpoint = (
+                args.datasource if hasattr(args, "datasource") else None
+            )
+            if data_endpoint:
+                # List of endpoints to data sources
+                user_data = (
+                    data_endpoint
+                    if isinstance(data_endpoint, list)
+                    else [data_endpoint]
+                )
+            else:
+                # Configuration file
+                if hasattr(args, "dataconfig") and args.dataconfig:
+                    user_data = get_data_source_config_from_file(
+                        args.dataconfig,
+                    )
+
             asyncio.run(
                 run_agent_task(
                     user_msg=args.task,
                     mode=args.mode,
-                    files=args.files if hasattr(args, "files") else None,
+                    user_data_config=user_data,
                     use_long_term_memory_service=(
                         args.use_long_term_memory
                         if hasattr(args, "use_long_term_memory")
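The new CLI surface keeps `--files` as a backward-compatible alias of `--datasource` (both write to `dest="datasource"`), and `--dataconfig` is only consulted when no endpoints are passed directly. A runnable sketch of that argparse setup and precedence logic, simplified from the diff above:

```python
import argparse

parser = argparse.ArgumentParser(prog="alias_agent")
parser.add_argument(
    "--datasource",
    "--files",          # legacy alias; both land in args.datasource
    "-d",
    dest="datasource",
    nargs="+",
    help="Data sources: local files or database URIs",
)
parser.add_argument(
    "--dataconfig",
    "-dc",
    help="Path to the data source configuration file",
)

args = parser.parse_args(
    ["--datasource", "file.txt", "postgresql://localhost/db"],
)

# Precedence: explicit endpoints win; only fall back to the config file
# when no --datasource/--files were given.
user_data = None
if args.datasource:
    user_data = (
        args.datasource
        if isinstance(args.datasource, list)
        else [args.datasource]
    )
elif args.dataconfig:
    user_data = ["<loaded from config file>"]  # placeholder for the loader
```

With `nargs="+"`, argparse already returns a list, so the `isinstance` normalization is mostly defensive.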
alias.runtime package `__init__` (modified):

@@ -1,7 +1,8 @@
 # -*- coding: utf-8 -*-
 """Runtime module for Alias"""

-__all__ = ["alias_sandbox"]
+__all__ = ["alias_sandbox", "runtime_compat"]

 # Import submodule to make it accessible via alias.runtime.alias_sandbox
 from . import alias_sandbox  # noqa: E402, F401
+from . import runtime_compat  # noqa: E402, F401
|||||||
0
alias/src/alias/runtime/runtime_compat/__init__.py
Normal file
0
alias/src/alias/runtime/runtime_compat/__init__.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import json
|
||||||
|
from typing import Any, AsyncIterator, Dict, Optional, Union
|
||||||
|
|
||||||
|
from agentscope_runtime.engine.helpers.agent_api_builder import ResponseBuilder
|
||||||
|
from agentscope_runtime.engine.schemas.agent_schemas import (
|
||||||
|
Content,
|
||||||
|
ContentType,
|
||||||
|
FunctionCall,
|
||||||
|
FunctionCallOutput,
|
||||||
|
Message,
|
||||||
|
MessageType,
|
||||||
|
Role,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _try_deep_parse(val: Any) -> Any:
|
||||||
|
"""
|
||||||
|
Recursively parse JSON-like strings into native Python objects.
|
||||||
|
"""
|
||||||
|
if isinstance(val, str):
|
||||||
|
content = val.strip()
|
||||||
|
if (content.startswith("{") and content.endswith("}")) or (
|
||||||
|
content.startswith("[") and content.endswith("]")
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
parsed = json.loads(content)
|
||||||
|
return _try_deep_parse(parsed)
|
||||||
|
except Exception:
|
||||||
|
# If nested JSON parsing fails, treat it as a normal string.
|
||||||
|
return val
|
||||||
|
return val
|
||||||
|
if isinstance(val, list):
|
||||||
|
return [_try_deep_parse(i) for i in val]
|
||||||
|
if isinstance(val, dict):
|
||||||
|
return {k: _try_deep_parse(v) for k, v in val.items()}
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_safe_json_string(val: Any) -> str:
|
||||||
|
"""
|
||||||
|
Serialize content into a valid JSON string suitable for WebUI parsing.
|
||||||
|
"""
|
||||||
|
parsed_val = _try_deep_parse(val)
|
||||||
|
if parsed_val is None:
|
||||||
|
return "{}"
|
||||||
|
return json.dumps(parsed_val, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_alias_output_obj(content_str: str) -> Any:
|
||||||
|
"""
|
||||||
|
Extract the `output` object from Alias nested tool-result content.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = json.loads(content_str)
|
||||||
|
if isinstance(data, list) and data:
|
||||||
|
return data[0].get("output")
|
||||||
|
except Exception:
|
||||||
|
# Best-effort parse: if the string is not a valid
|
||||||
|
# JSON or doesn't follow the expected structure,
|
||||||
|
# fall back to returning the original string.
|
||||||
|
pass
|
||||||
|
return content_str
|
||||||
|
|
||||||
|
|
||||||
|
class AliasAdapterState:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_builder: Any,
|
||||||
|
content_builder: Any,
|
||||||
|
runtime_type: str,
|
||||||
|
):
|
||||||
|
self.mb = message_builder
|
||||||
|
self.cb = content_builder
|
||||||
|
self.runtime_type = runtime_type
|
||||||
|
self.last_content = ""
|
||||||
|
self.is_completed = False
|
||||||
|
|
||||||
|
|
||||||
|
async def adapt_alias_message_stream(
|
||||||
|
source_stream: AsyncIterator[Dict[str, Any]],
|
||||||
|
) -> AsyncIterator[Union[Message, Content]]:
|
||||||
|
# pylint: disable=too-many-branches, too-many-statements
|
||||||
|
rb = ResponseBuilder()
|
||||||
|
state_map: Dict[str, AliasAdapterState] = {}
|
||||||
|
last_active_key: Optional[str] = None
|
||||||
|
|
||||||
|
yield rb.created()
|
||||||
|
yield rb.in_progress()
|
||||||
|
|
||||||
|
async for chunk in source_stream:
|
||||||
|
if not isinstance(chunk, dict) or "data" not in chunk:
|
||||||
|
continue
|
||||||
|
|
||||||
|
messages = chunk["data"].get("messages") or []
|
||||||
|
for item in messages:
|
||||||
|
alias_id = item.get("id")
|
||||||
|
inner_msg = item.get("message") or {}
|
||||||
|
|
||||||
|
alias_type = inner_msg.get("type")
|
||||||
|
alias_status = inner_msg.get("status")
|
||||||
|
tool_call_id = inner_msg.get("tool_call_id") or alias_id
|
||||||
|
|
||||||
|
if alias_type in ["thought", "sub_thought"]:
|
||||||
|
runtime_type = MessageType.REASONING
|
||||||
|
target_role = Role.ASSISTANT
|
||||||
|
elif alias_type in ["tool_call", "tool_use"]:
|
||||||
|
runtime_type = MessageType.PLUGIN_CALL
|
||||||
|
target_role = Role.ASSISTANT
|
||||||
|
elif alias_type == "tool_result":
|
||||||
|
runtime_type = MessageType.PLUGIN_CALL_OUTPUT
|
||||||
|
target_role = Role.TOOL
|
||||||
|
else:
|
||||||
|
runtime_type = MessageType.MESSAGE
|
||||||
|
target_role = Role.ASSISTANT
|
||||||
|
|
||||||
|
state_key = f"{tool_call_id}_{runtime_type}"
|
||||||
|
|
||||||
|
if last_active_key and last_active_key != state_key:
|
||||||
|
old_state = state_map.get(last_active_key)
|
||||||
|
if old_state and not old_state.is_completed:
|
||||||
|
yield old_state.cb.complete()
|
||||||
|
yield old_state.mb.complete()
|
||||||
|
old_state.is_completed = True
|
||||||
|
|
||||||
|
last_active_key = state_key
|
||||||
|
|
||||||
|
if state_key not in state_map:
|
||||||
|
mb = rb.create_message_builder(role=target_role)
|
||||||
|
mb.message.type = runtime_type
|
||||||
|
```python
                yield mb.get_message_data()

                if runtime_type in [
                    MessageType.PLUGIN_CALL,
                    MessageType.PLUGIN_CALL_OUTPUT,
                ]:
                    c_type = ContentType.DATA
                else:
                    c_type = ContentType.TEXT

                cb = mb.create_content_builder(content_type=c_type)
                state_map[state_key] = AliasAdapterState(mb, cb, runtime_type)

            state = state_map[state_key]

            if runtime_type in [MessageType.MESSAGE, MessageType.REASONING]:
                raw_text = str(inner_msg.get("content") or "")

                if alias_type == "files" and "files" in inner_msg:
                    raw_text = "\n".join(
                        [
                            f"📁 [{f['filename']}]({f['url']})"
                            for f in inner_msg["files"]
                        ],
                    )

                if raw_text.startswith(state.last_content):
                    delta = raw_text[len(state.last_content):]
                    if delta:
                        yield state.cb.add_text_delta(delta)
                    state.last_content = raw_text
                else:
                    yield state.cb.set_text(raw_text)
                    state.last_content = raw_text

            elif runtime_type == MessageType.PLUGIN_CALL:
                args = inner_msg.get("arguments") or {}
                fc = FunctionCall(
                    call_id=tool_call_id,
                    name=inner_msg.get("tool_name") or "tool",
                    arguments=_ensure_safe_json_string(args),
                )
                yield state.cb.set_data(fc.model_dump())

            elif runtime_type == MessageType.PLUGIN_CALL_OUTPUT:
                output_obj = _extract_alias_output_obj(
                    inner_msg.get("content", ""),
                )
                fco = FunctionCallOutput(
                    call_id=tool_call_id,
                    name=inner_msg.get("tool_name") or "tool",
                    output=_ensure_safe_json_string(output_obj),
                )
                yield state.cb.set_data(fco.model_dump())

            if alias_status == "finished" and not state.is_completed:
                yield state.cb.complete()
                yield state.mb.complete()
                state.is_completed = True

    for state in state_map.values():
        if not state.is_completed:
            try:
                yield state.cb.complete()
                yield state.mb.complete()
                state.is_completed = True
            except Exception:
                # Graceful cleanup: ignore errors during the
                # finalization phase to ensure the main response
                # stream can finish without crashing.
                pass

    yield rb.completed()
```
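The incremental text handling above can be isolated into a small helper: when a new cumulative snapshot extends the previously seen text, only the delta is emitted; otherwise the full text is re-sent as a reset. A minimal standalone sketch of that logic (function and event names are illustrative, not from the Alias codebase):

```python
def text_events(snapshots):
    """Turn cumulative text snapshots into delta/reset events,
    mirroring the startswith-based logic used by the stream adapter."""
    last = ""
    events = []
    for raw in snapshots:
        if raw.startswith(last):
            delta = raw[len(last):]
            if delta:  # skip empty deltas, as the adapter does
                events.append(("delta", delta))
        else:
            # Snapshot no longer extends the previous text: reset with full text.
            events.append(("set", raw))
        last = raw
    return events


events = text_events(["He", "Hello", "Hello!", "Goodbye"])
# → [("delta", "He"), ("delta", "llo"), ("delta", "!"), ("set", "Goodbye")]
```

This keeps downstream consumers cheap: they append deltas in the common case and only redraw on a reset.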
`alias/src/alias/runtime/runtime_compat/runner/alias_runner.py` (new file, 371 lines)

```python
# -*- coding: utf-8 -*-
# pylint: disable=unused-argument
from __future__ import annotations

import asyncio
import uuid
from typing import Any, AsyncGenerator, Dict, Optional, Union

from fastapi_limiter import FastAPILimiter
from pydantic import ValidationError

from agentscope_runtime.engine.runner import Runner
from agentscope_runtime.engine.schemas.agent_schemas import (
    AgentRequest,
    AgentResponse,
    Error,
    RunStatus,
    SequenceNumberGenerator,
)

from alias.server.db.init_db import (
    close_database,
    initialize_database,
    session_scope,
)
from alias.server.core.task_manager import task_manager
from alias.server.exceptions.base import BaseError
from alias.runtime.runtime_compat.adapter.alias_stream_adapter import (
    adapt_alias_message_stream,
)
from alias.server.schemas.chat import ChatRequest
from alias.server.services.chat_service import ChatService
from alias.server.services.conversation_service import ConversationService
from alias.server.utils.logger import setup_logger
from alias.server.utils.redis import redis_client


class AliasRunner(Runner):
    FRAMEWORK_TYPE = "Alias"

    def __init__(
        self,
        default_chat_mode: str = "general",
        default_conv_name: str = "webui",
    ) -> None:
        super().__init__()
        self.framework_type = self.FRAMEWORK_TYPE
        self.default_chat_mode = default_chat_mode
        self.default_conv_name = default_conv_name

        self._session_conv_cache: Dict[str, uuid.UUID] = {}

    async def stop(self) -> None:
        if not getattr(self, "_health", False):
            return
        await super().stop()

    async def query_handler(self, *args: Any, **kwargs: Any) -> Any:
        user_id: uuid.UUID = kwargs["user_id"]
        conversation_id: uuid.UUID = kwargs["conversation_id"]
        chat_request: ChatRequest = kwargs["chat_request"]
        task_id: uuid.UUID = kwargs.get("task_id") or uuid.uuid4()

        service = ChatService()
        response_gen = await service.chat(
            user_id=user_id,
            conversation_id=conversation_id,
            chat_request=chat_request,
            task_id=task_id,
        )
        return response_gen

    async def init_handler(self, *args: Any, **kwargs: Any) -> None:
        print("🚀 Starting Alias API Server...")
        setup_logger()

        await initialize_database()
        await task_manager.start()

        await redis_client.ping()
        try:
            await FastAPILimiter.init(redis_client)
        except Exception as exc:
            print(f"redis init error: {str(exc)}")

        print("✅ Alias startup complete.")

    async def shutdown_handler(self, *args: Any, **kwargs: Any) -> None:
        print("Executing Alias shutdown logic...")
        await task_manager.stop()
        await close_database()
        print("Alias shutdown complete.")

    @staticmethod
    def _extract_text_from_agent_request(req_dict: Dict[str, Any]) -> str:
        agent_input = req_dict.get("input")
        if isinstance(agent_input, str):
            return agent_input

        if isinstance(agent_input, list) and agent_input:
            last = agent_input[-1]
            if isinstance(last, dict):
                content = last.get("content")
                if isinstance(content, str):
                    return content
                if isinstance(content, list):
                    for blk in reversed(content):
                        if isinstance(blk, dict) and blk.get("type") == "text":
                            return blk.get("text") or ""
                if "text" in last and isinstance(last["text"], str):
                    return last["text"]
        return ""

    @staticmethod
    def _to_uuid(val: Any) -> Optional[uuid.UUID]:
        if val is None:
            return None
        if isinstance(val, uuid.UUID):
            return val
        try:
            return uuid.UUID(str(val))
        except Exception:
            return None

    @staticmethod
    def _stable_uuid_from_string(s: str) -> uuid.UUID:
        return uuid.uuid5(uuid.NAMESPACE_DNS, f"alias::{s}")

    async def _get_or_create_conversation_id(
        self,
        session_id: str,
        user_uuid: uuid.UUID,
    ) -> uuid.UUID:
        if session_id in self._session_conv_cache:
            return self._session_conv_cache[session_id]

        async with session_scope() as session:
            service = ConversationService(session=session)
            conversation = await service.create_conversation(
                user_id=user_uuid,
                name=self.default_conv_name,
                description="created by AgentScope Runtime WebUI",
                chat_mode=self.default_chat_mode,
            )

            conv_id = getattr(conversation, "id", None)
            conv_id = (
                conv_id
                if isinstance(conv_id, uuid.UUID)
                else self._to_uuid(conv_id)
            )
            if conv_id is None:
                raise RuntimeError(
                    "ConversationService.create_conversation() "
                    "returned invalid id: "
                    f"{conversation}",
                )

        self._session_conv_cache[session_id] = conv_id
        return conv_id

    async def stream_query_native(
        self,
        request: Union[AgentRequest, dict],
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        if not self._health:
            raise RuntimeError(
                "Runner has not been started. Please call "
                "'await runner.start()' or use 'async with Runner()' "
                "before calling 'stream_query'.",
            )

        req_dict = (
            request if isinstance(request, dict) else request.model_dump()
        )
        user_id = kwargs.get("user_id") or self._to_uuid(
            req_dict.get("user_id"),
        )
        conversation_id = kwargs.get("conversation_id") or self._to_uuid(
            req_dict.get("conversation_id"),
        )
        task_id = (
            kwargs.get("task_id")
            or self._to_uuid(req_dict.get("task_id"))
            or uuid.uuid4()
        )

        if user_id is None or conversation_id is None:
            yield {
                "error": "missing_context",
                "code": 422,
                "message": (
                    "Native mode requires user_id and conversation_id "
                    "in kwargs or request body."
                ),
            }
            return

        try:
            chat_request_obj = ChatRequest.model_validate(req_dict)
        except ValidationError as exc:
            yield {
                "error": "invalid_request",
                "code": 422,
                "message": "ChatRequest validation failed",
                "detail": exc.errors(),
            }
            return
        except Exception as exc:
            yield {
                "error": "invalid_request",
                "code": 500,
                "message": str(exc),
            }
            return

        try:
            result = self.query_handler(
                user_id=user_id,
                conversation_id=conversation_id,
                task_id=task_id,
                chat_request=chat_request_obj,
            )
            if asyncio.iscoroutine(result):
                result = await result

            async for chunk in result:
                yield chunk

        except Exception as exc:
            if isinstance(exc, BaseError):
                yield {"error": exc.message, "code": exc.code}
            else:
                yield {
                    "error": str(exc),
                    "code": 500,
                    "error_type": exc.__class__.__name__,
                }
            return

        yield "[DONE]"

    async def stream_query(
        self,
        request: Union[AgentRequest, dict],
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        # pylint: disable=too-many-branches
        # pylint: disable=too-many-statements
        if not self._health:
            raise RuntimeError(
                "Runner has not been started. Please call "
                "'await runner.start()' or use 'async with Runner()' "
                "before calling 'stream_query'.",
            )

        if isinstance(request, AgentRequest):
            req_dict = request.model_dump()
        elif isinstance(request, dict):
            req_dict = request
        else:
            if hasattr(request, "model_dump"):
                req_dict = request.model_dump()
            else:
                req_dict = dict(request)

        request_id = req_dict.get("id") or str(uuid.uuid4())
        session_id = req_dict.get("session_id") or f"session_{uuid.uuid4()}"
        seq_gen = SequenceNumberGenerator()

        response = AgentResponse(id=request_id)
        response.session_id = session_id
        yield seq_gen.yield_with_sequence(response)

        response.in_progress()
        yield seq_gen.yield_with_sequence(response)

        user_text = self._extract_text_from_agent_request(req_dict)
        if not user_text:
            err = Error(
                code="422",
                message="Empty input text in AgentRequest.input.",
            )
            yield seq_gen.yield_with_sequence(response.failed(err))
            return

        raw_user_id = req_dict.get("user_id") or session_id
        user_uuid = self._to_uuid(
            raw_user_id,
        ) or self._stable_uuid_from_string(
            str(raw_user_id),
        )

        conversation_id = self._to_uuid(req_dict.get("conversation_id"))
        if conversation_id is None:
            try:
                conversation_id = await self._get_or_create_conversation_id(
                    session_id=session_id,
                    user_uuid=user_uuid,
                )
            except Exception as exc:
                err = Error(
                    code="500",
                    message=f"Failed to create conversation: {exc}",
                )
                yield seq_gen.yield_with_sequence(response.failed(err))
                return

        task_id = self._to_uuid(req_dict.get("task_id")) or uuid.uuid4()

        try:
            req_chat_mode = req_dict.get("chat_mode") or self.default_chat_mode

            chat_request_obj = ChatRequest.model_validate(
                {
                    "query": user_text,
                    "chat_mode": req_chat_mode,
                },
            )
        except ValidationError as exc:
            err = Error(
                code="422",
                message=f"ChatRequest validation failed: {exc}",
            )
            yield seq_gen.yield_with_sequence(response.failed(err))
            return

        try:
            result = self.query_handler(
                user_id=user_uuid,
                conversation_id=conversation_id,
                task_id=task_id,
                chat_request=chat_request_obj,
            )
            if asyncio.iscoroutine(result):
                result = await result

            async for event in adapt_alias_message_stream(result):
                try:
                    if (
                        getattr(event, "status", None) == RunStatus.Completed
                        and getattr(event, "object", None) == "message"
                    ):
                        response.add_new_message(event)
                except Exception:
                    # Best-effort bookkeeping
                    pass

                yield seq_gen.yield_with_sequence(event)

        except Exception as exc:
            if isinstance(exc, BaseError):
                err = Error(code=str(exc.code), message=exc.message)
            else:
                err = Error(
                    code="500",
                    message=f"Error happens in `query_handler`: {exc}",
                )
            yield seq_gen.yield_with_sequence(response.failed(err))
            return

        try:
            if response.output:
                response.usage = response.output[-1].usage
        except IndexError:
            # Avoid empty message
            pass

        yield seq_gen.yield_with_sequence(response.completed())
        return
```
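`_stable_uuid_from_string` relies on `uuid.uuid5`, which derives a deterministic UUID from a namespace and a name, so the same external user id always maps to the same database UUID across processes and restarts. A quick self-contained demonstration (the `alias::` prefix matches the runner above):

```python
import uuid


def stable_uuid_from_string(s: str) -> uuid.UUID:
    # Deterministic: the same input yields the same UUID on any machine,
    # unlike uuid.uuid4(), which is random per call.
    return uuid.uuid5(uuid.NAMESPACE_DNS, f"alias::{s}")


a = stable_uuid_from_string("user-42")
b = stable_uuid_from_string("user-42")
c = stable_uuid_from_string("user-43")
# a == b (stable), a != c (distinct inputs), and a.version == 5
```

This is why `stream_query` can fall back to hashing an arbitrary `session_id` into a valid user UUID without persisting a mapping anywhere.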
(new file, 25 lines)

```python
# -*- coding: utf-8 -*-
import asyncio
from typing import Optional

from alias.runtime.runtime_compat.runner.alias_runner import AliasRunner

_lock: Optional[asyncio.Lock] = None
_runner: Optional[AliasRunner] = None


async def get_alias_runner() -> AliasRunner:
    global _lock, _runner

    if _runner is not None:
        return _runner

    if _lock is None:
        _lock = asyncio.Lock()

    async with _lock:
        if _runner is not None:
            return _runner
        runner = AliasRunner()
        await runner.start()
        _runner = runner
        return _runner
```
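The double-checked pattern above (fast path without the lock, then a re-check inside it) guarantees that concurrent first calls construct and start only one runner. A self-contained sketch of the same shape, with `AliasRunner` replaced by a stand-in so the behaviour can be observed directly:

```python
import asyncio
from typing import Optional


class DummyRunner:
    """Stand-in for AliasRunner; counts how many instances get built."""
    instances = 0

    def __init__(self) -> None:
        DummyRunner.instances += 1

    async def start(self) -> None:
        await asyncio.sleep(0)  # simulate an async startup that yields control


_lock: Optional[asyncio.Lock] = None
_runner: Optional[DummyRunner] = None


async def get_runner() -> DummyRunner:
    global _lock, _runner
    if _runner is not None:          # fast path: already built, no lock needed
        return _runner
    if _lock is None:                # safe: no await between check and assignment
        _lock = asyncio.Lock()
    async with _lock:
        if _runner is not None:      # re-check: another task may have won the race
            return _runner
        runner = DummyRunner()
        await runner.start()
        _runner = runner
        return _runner


async def main() -> list:
    # Ten concurrent first calls must all resolve to the same instance.
    return await asyncio.gather(*(get_runner() for _ in range(10)))


runners = asyncio.run(main())
```

Because the event loop only switches tasks at `await` points, the `if _lock is None` check needs no extra protection; the lock itself covers the awaited startup.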
`alias/src/alias/server/alias_agent_app.py` (new file, 66 lines)

```python
# -*- coding: utf-8 -*-
from agentscope_runtime.engine.app import AgentApp

from alias.runtime.runtime_compat.runner.alias_runner import AliasRunner

PORT = 8090


def run_app(
    host: str = "127.0.0.1",
    port: int = PORT,
    web_ui: bool = False,
    chat_mode: str = "general",
) -> None:
    agent_app = AgentApp(
        runner=AliasRunner(
            default_chat_mode=chat_mode,
        ),
        app_name="Alias",
        app_description=(
            "An LLM-empowered agent built on AgentScope and AgentScope-Runtime"
        ),
    )
    agent_app.run(host=host, port=port, web_ui=web_ui)


def main() -> None:
    import argparse

    parser = argparse.ArgumentParser(prog="alias_agent_runtime")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=PORT)
    parser.add_argument(
        "--web-ui",
        action="store_true",
        help="Start AgentScope Runtime WebUI (default: False)",
    )
    parser.add_argument(
        "--chat-mode",
        default="general",
        choices=["general", "dr", "browser", "ds", "finance"],
        help=(
            "Default chat mode used by AliasRunner when request doesn't "
            "specify chat_mode."
        ),
    )
    args = parser.parse_args()

    print(
        "[alias_agent_runtime] config:",
        f"host={args.host}",
        f"port={args.port}",
        f"web_ui={args.web_ui}",
        f"chat_mode={args.chat_mode}",
    )

    run_app(
        host=args.host,
        port=args.port,
        web_ui=args.web_ui,
        chat_mode=args.chat_mode,
    )


if __name__ == "__main__":
    main()
```
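The CLI wiring above can be exercised without starting a server: `argparse` accepts an explicit argument list, which makes the parser itself easy to check. A sketch of the same parser in isolation (note that `--web-ui` becomes the attribute `web_ui`):

```python
import argparse

# Rebuild the same parser as alias_agent_app.main(), standalone.
parser = argparse.ArgumentParser(prog="alias_agent_runtime")
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", type=int, default=8090)
parser.add_argument("--web-ui", action="store_true")
parser.add_argument(
    "--chat-mode",
    default="general",
    choices=["general", "dr", "browser", "ds", "finance"],
)

# parse_args takes an explicit list instead of reading sys.argv.
args = parser.parse_args(["--port", "9000", "--web-ui", "--chat-mode", "finance"])
# → args.host == "127.0.0.1", args.port == 9000, args.web_ui is True
```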
```diff
@@ -5,7 +5,18 @@ from alias.server.api.v1.auth import router as auth_router
 from alias.server.api.v1.conversation import (
     router as conversation_router,
 )
-from alias.server.api.v1.chat import router as chat_router
+# Optional backend switch: this import targets Alias's
+# original FastAPI-based API router.
+# Keep it for users who want to revert/switch back from
+# the current AgentScope-Runtime implementation below.
+# from alias.server.api.v1.chat import router as chat_router
+
+# Current default:
+# AgentScope-Runtime-based API router (functionally equivalent
+# to the FastAPI router above).
+from alias.server.api.v1.chat_runtime import router as chat_router
+
 from alias.server.api.v1.file import router as file_router
 from alias.server.api.v1.inner import router as inner_router
 from alias.server.api.v1.share import router as share_router
```
`alias/src/alias/server/api/v1/chat_runtime.py` (new file, 154 lines)

```python
# -*- coding: utf-8 -*-
# pylint: disable=unused-argument
import json
import uuid
from typing import Any, AsyncIterator

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from loguru import logger
from starlette.types import Receive

from alias.server.api.deps import CurrentUser
from alias.server.exceptions.base import BaseError
from alias.server.schemas.chat import (
    ChatRequest,
    StopChatPayload,
    StopChatResponse,
)
from alias.server.services.chat_service import ChatService
from alias.server.utils.request_context import request_context_var
from alias.runtime.runtime_compat.runner.alias_runner import AliasRunner

router = APIRouter(prefix="/conversations", tags=["conversations/chat"])


class EnhancedStreamingResponse(StreamingResponse):
    """
    StreamingResponse with client disconnect handling.
    """

    def __init__(
        self,
        content: Any,
        user_id: uuid.UUID,
        task_id: uuid.UUID,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(content, *args, **kwargs)
        self.user_id = user_id
        self.task_id = task_id

    async def listen_for_disconnect(self, receive: Receive) -> None:
        while True:
            message = await receive()
            if message["type"] == "http.disconnect":
                logger.warning(
                    f"Chat stopped by disconnect from client: "
                    f"task_id={self.task_id}",
                )
                service = ChatService()
                await service.stop_chat(
                    user_id=self.user_id,
                    task_id=self.task_id,
                )
                break


def _to_raw_sse_event(data: Any) -> str:
    """
    Convert a chunk from runner.stream_query_native into
    a raw SSE event string.
    """
    if data == "[DONE]":
        return "data: [DONE]\n\n"

    if hasattr(data, "model_dump"):
        data = data.model_dump()

    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"


async def event_generator(
    runner: AliasRunner,
    request_dict: dict,
    **runner_kwargs: Any,
) -> AsyncIterator[str]:
    """
    Convert AliasRunner.stream_query_native output into
    a raw SSE string stream.
    """
    try:
        async for chunk in runner.stream_query_native(
            request_dict,
            **runner_kwargs,
        ):
            yield _to_raw_sse_event(chunk)
    except Exception as e:
        if not isinstance(e, BaseError):
            e = BaseError(code=500, message=str(e))
        error_data = {
            "code": e.code,
            "message": e.message,
        }
        yield _to_raw_sse_event(error_data)
        yield _to_raw_sse_event("[DONE]")


@router.post("/{conversation_id}/chat")
async def chat(
    current_user: CurrentUser,
    conversation_id: uuid.UUID,
    chat_request: ChatRequest,
) -> EnhancedStreamingResponse:
    """Run chat via AliasRunner and stream results as SSE."""
    request_context = request_context_var.get()
    request_id = request_context.request_id
    task_id = uuid.UUID(request_id) if request_id else uuid.uuid4()
    user_id = current_user.id

    from alias.runtime.runtime_compat.runner.alias_runner_singleton import (
        get_alias_runner,
    )

    runner = await get_alias_runner()

    request_dict = chat_request.model_dump()

    return EnhancedStreamingResponse(
        event_generator(
            runner,
            request_dict,
            user_id=user_id,
            conversation_id=conversation_id,
            task_id=task_id,
        ),
        media_type="text/event-stream",
        user_id=user_id,
        task_id=task_id,
    )


@router.post(
    "/{conversation_id}/chat/{task_id}/stop",
    response_model=StopChatResponse,
)
async def stop_chat(
    current_user: CurrentUser,
    conversation_id: uuid.UUID,
    task_id: uuid.UUID,
) -> StopChatResponse:
    service = ChatService()
    await service.stop_chat(
        user_id=current_user.id,
        task_id=task_id,
    )
    return StopChatResponse(
        status=True,
        message="Stop chat successfully.",
        payload=StopChatPayload(
            conversation_id=conversation_id,
            task_id=task_id,
        ),
    )
```
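`_to_raw_sse_event` produces the wire format expected by `text/event-stream` consumers: each event is a `data: ` line terminated by a blank line, with `[DONE]` passed through as the conventional stream terminator. A standalone sketch of the same conversion:

```python
import json
from typing import Any


def to_raw_sse_event(data: Any) -> str:
    # The terminator sentinel is forwarded verbatim.
    if data == "[DONE]":
        return "data: [DONE]\n\n"
    # Pydantic models expose model_dump(); plain dicts pass through unchanged.
    if hasattr(data, "model_dump"):
        data = data.model_dump()
    # ensure_ascii=False keeps non-ASCII payloads (e.g. Chinese text) readable
    # on the wire instead of \uXXXX escapes.
    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"


event = to_raw_sse_event({"code": 200, "message": "ok"})
done = to_raw_sse_event("[DONE]")
# event → 'data: {"code": 200, "message": "ok"}\n\n'
```

The trailing blank line (`\n\n`) is what delimits events in SSE; a client such as the browser's `EventSource` will not dispatch an event until it sees it.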
```diff
@@ -1,13 +1,13 @@
-# Analyst Personas Configuration
+# 分析师角色配置
 
 fundamentals_analyst:
-  name: "Fundamental Analyst"
+  name: "基本面分析师"
   focus:
-    - "Company financial health and profitability"
-    - "Business model sustainability and competitive advantages"
-    - "Management quality and corporate governance"
-    - "Industry position and market share"
-    - "Long-term investment value assessment"
+    - "公司财务健康状况和盈利能力"
+    - "商业模式可持续性和竞争优势"
+    - "管理层质量和公司治理"
+    - "行业地位和市场份额"
+    - "长期投资价值评估"
   tools:
     - "analyze_profitability"
     - "analyze_growth"
@@ -15,30 +15,30 @@ fundamentals_analyst:
     - "analyze_valuation_ratios"
     - "analyze_efficiency_ratios"
   description: |
-    As a fundamental analyst, you focus on:
-    - Company financial health and profitability
-    - Business model sustainability and competitive advantages
-    - Management quality and corporate governance
-    - Industry position and market share
-    - Long-term investment value assessment
-    You tend to select tools that provide deep insights into company intrinsic value, preferring fundamental and valuation tools.
+    作为基本面分析师,你专注于:
+    - 公司财务健康状况和盈利能力
+    - 商业模式可持续性和竞争优势
+    - 管理层质量和公司治理
+    - 行业地位和市场份额
+    - 长期投资价值评估
+    你倾向于选择能够深入了解公司内在价值的工具,更偏好基本面和估值类工具。
 
 technical_analyst:
-  name: "Technical Analyst"
+  name: "技术分析师"
   focus:
-    - "Price trends and chart patterns"
-    - "Technical indicators and trading signals"
-    - "Market sentiment and capital flows"
-    - "Support/resistance levels and key price points"
-    - "Short to medium-term trading opportunities"
+    - "价格趋势和图表形态"
+    - "技术指标和交易信号"
+    - "市场情绪和资金流向"
+    - "支撑/阻力位和关键价格点"
+    - "中短期交易机会"
   description: |
-    As a technical analyst, you focus on:
-    - Price trends and chart patterns
-    - Technical indicators and trading signals
-    - Market sentiment and capital flows
-    - Support/resistance levels and key price points
-    - Short to medium-term trading opportunities
-    You tend to select tools that capture price dynamics and market trends, preferring technical analysis tools.
+    作为技术分析师,你专注于:
+    - 价格趋势和图表形态
+    - 技术指标和交易信号
+    - 市场情绪和资金流向
+    - 支撑/阻力位和关键价格点
+    - 中短期交易机会
+    你倾向于选择能够捕捉价格动态和市场趋势的工具,更偏好技术分析类工具。
   tools:
     - "analyze_trend_following"
     - "analyze_momentum"
@@ -46,41 +46,41 @@ technical_analyst:
     - "analyze_volatility"
 
 sentiment_analyst:
-  name: "Sentiment Analyst"
+  name: "情绪分析师"
   focus:
-    - "Market participant sentiment changes"
-    - "News opinion and media influence"
-    - "Insider trading behavior"
-    - "Investor panic and greed emotions"
-    - "Market expectations and psychological factors"
+    - "市场参与者情绪变化"
+    - "新闻舆情和媒体影响"
+    - "内部人交易行为"
+    - "投资者恐慌和贪婪情绪"
+    - "市场预期和心理因素"
   description: |
-    As a sentiment analyst, you focus on:
-    - Market participant sentiment changes
-    - News opinion and media influence
-    - Insider trading behavior
-    - Investor panic and greed emotions
-    - Market expectations and psychological factors
-    You tend to select tools that reflect market sentiment and investor behavior, preferring sentiment and behavioral tools.
+    作为情绪分析师,你专注于:
+    - 市场参与者情绪变化
+    - 新闻舆情和媒体影响
+    - 内部人交易行为
+    - 投资者恐慌和贪婪情绪
+    - 市场预期和心理因素
+    你倾向于选择能够反映市场情绪和投资者行为的工具,更偏好情绪和行为类工具。
   tools:
     - "analyze_news_sentiment"
    - "analyze_insider_trading"
 
 valuation_analyst:
-  name: "Valuation Analyst"
+  name: "估值分析师"
   focus:
-    - "Company intrinsic value calculation"
-    - "Comparison of different valuation methods"
-    - "Valuation model assumptions and sensitivity"
-    - "Relative and absolute valuation"
-    - "Investment margin of safety assessment"
+    - "公司内在价值计算"
+    - "不同估值方法的比较"
+    - "估值模型假设和敏感性分析"
+    - "相对估值和绝对估值"
+    - "投资安全边际评估"
   description: |
-    As a valuation analyst, you focus on:
-    - Company intrinsic value calculation
-    - Comparison of different valuation methods
-    - Valuation model assumptions and sensitivity
-    - Relative and absolute valuation
-    - Investment margin of safety assessment
-    You tend to select tools that accurately calculate company value, preferring valuation models and fundamental tools.
+    作为估值分析师,你专注于:
+    - 公司内在价值计算
+    - 不同估值方法的比较
+    - 估值模型假设和敏感性分析
+    - 相对估值和绝对估值
+    - 投资安全边际评估
+    你倾向于选择能够准确计算公司价值的工具,更偏好估值模型和基本面工具。
   tools:
     - "dcf_valuation_analysis"
     - "owner_earnings_valuation_analysis"
@@ -88,21 +88,21 @@ valuation_analyst:
     - "residual_income_valuation_analysis"
 
 comprehensive_analyst:
-  name: "Comprehensive Analyst"
+  name: "综合分析师"
  focus:
-    - "Integrate multiple analytical perspectives"
-    - "Balance short-term and long-term factors"
-    - "Comprehensively consider fundamentals, technicals, and sentiment"
-    - "Provide comprehensive investment advice"
-    - "Adapt to different market environments"
+    - "整合多种分析视角"
+    - "平衡短期和长期因素"
+    - "综合考虑基本面、技术面和情绪面"
+    - "提供全面的投资建议"
+    - "适应不同市场环境"
   description: |
-    As a comprehensive analyst, you need to:
-    - Integrate multiple analytical perspectives
-    - Balance short-term and long-term factors
-    - Consider combined impact of fundamentals, technicals, and sentiment
-    - Provide comprehensive investment advice
-    - Adapt to different market environments
-    You will flexibly select various tools based on specific situations, pursuing comprehensiveness and accuracy in analysis.
+    作为综合分析师,你需要:
+    - 整合多种分析视角
+    - 平衡短期和长期因素
+    - 综合考虑基本面、技术面和情绪面的影响
+    - 提供全面的投资建议
+    - 适应不同市场环境
+    你会根据具体情况灵活选择各类工具,追求分析的全面性和准确性。
   tools:
     - "analyze_profitability"
     - "analyze_growth"
```
```diff
@@ -1,23 +1,23 @@
-You are a professional {{ analyst_type }}.
+你是一位专业的{{ analyst_type }}。
 
-Your Focus:
+你的关注重点:
 {{ focus }}
 
-Your Role:
+你的角色:
 {{ description }}
 
-Note:
+注意:
-- Construct and continuously refine your "Investment Philosophy." Your analyses should not be isolated events but rather manifestations of your overarching worldview and core investment beliefs. After each analysis, you must reflect:
+- 构建并持续完善你的"投资哲学"。你的分析不应是孤立的事件,而应该是你整体投资世界观和核心信念的体现。每次分析后,你必须反思:
-  - How did this case/data validate or challenge your existing conviction?
+  - 这个案例/数据如何验证或挑战了你现有的信念?
-  - What key principle regarding markets, human psychology, valuation, or risk management did you learn from this mistake (or success)?
+  - 你从这次错误(或成功)中学到了关于市场、人性、估值或风险管理的什么关键原则?
-- Deepen your "Investment Logic." Ensure every investment recommendation you make is supported by a clear, traceable, and repeatable logic. Your analysis steps should resemble a rigorous proof, covering:
+- 深化你的"投资逻辑"。确保每一项投资建议都有清晰、可追溯、可重复的逻辑支撑。你的分析步骤应该像严谨的证明一样,涵盖:
-  - Core Driver Identification: What are the genuine variables that influence value?
+  - 核心驱动因素识别:真正影响价值的变量是什么?
-  - Risk Boundary Setting: Under what specific scenarios would your recommendation fail?
+  - 风险边界设定:在什么具体情况下你的建议会失效?
-  - Contrarian Testing: What is the prevailing market consensus, and where is your view differentiated?
+  - 逆向测试:市场主流共识是什么,你的观点有何不同?
-Maintain Humility and Openness. A core trait of an Investment Master is continuous learning and adaptation. In every analysis, you must actively seek out evidence and arguments that contradict your own view and integrate them into your final assessment.
+保持谦逊和开放。投资大师的核心特质是持续学习和适应。在每次分析中,你必须积极寻找与自己观点相悖的证据和论据,并将其纳入最终评估。
-- You have access to analysis tools. Use them to gather relevant data and make informed recommendations.
+- 你可以使用分析工具。用它们来收集相关数据并做出明智的建议。
 
-Output Guidelines:
+输出指南:
-- Return clear investment signals: bullish, bearish, or neutral
+- 给出明确的投资信号:看涨、看跌或中性
-- Include confidence level (0-100)
+- 包含置信度(0-100)
-- Provide reasoning for your analysis (Present your conclusion first if you are sure to share your final analysis. )
+- 为你的分析提供理由(如果你确定要分享最终分析,请先给出结论)
```
@@ -1,31 +1,31 @@
-You are a Portfolio Manager responsible for making investment decisions.
+你是一位负责做出投资决策的投资组合经理。

-Your Core Responsibilities:
+你的核心职责:
-1. Analyze input from analysts and risk managers
+1. 分析分析师和风险管理经理的输入
-2. Make investment decisions based on signals and market context
+2. 基于信号和市场情境做出投资决策
-3. Record your decisions using the available tool
+3. 使用可用工具记录你的决策

-Decision Framework:
+决策框架:
-- Review analysis to understand market views
+- 审阅分析以了解市场观点
-- Consider risk warnings before making decisions
+- 在做决策前考虑风险警告
-- Evaluate current portfolio positions and cash
+- 评估当前投资组合持仓和现金
-- Make decisions that align with the portfolio's investment objectives
+- 做出与投资组合投资目标一致的决策

-Decision Types:
+决策类型:
-- "long": Bullish - recommend buying shares
+- "long":看涨 - 建议买入股票
-- "short": Bearish - recommend selling shares or shorting
+- "short":看跌 - 建议卖出股票或做空
-- "hold": Neutral - maintain current positions
+- "hold":中性 - 维持当前持仓

-Budget Awareness:
+预算意识:
-- Consider available cash when deciding quantities
+- 在决定数量时考虑可用现金
-- Do not recommend buying more than cash allows
+- 不要建议买入超过现金允许的数量
-- Consider margin requirements for short positions
+- 考虑做空头寸的保证金要求

-Output:
+输出:
-Use the `make_decision` tool to record your decision for each ticker.
+使用 `make_decision` 工具记录你对每个股票代码的决策。
-After recording all decisions, provide a summary of your investment rationale.
+记录所有决策后,提供你的投资逻辑总结。

-Important:
+重要:
-- Base decisions on the analyst signals and risk assessments provided
+- 基于提供的分析师信号和风险评估做出决策
-- Be conservative with position sizes relative to portfolio value
+- 相对于投资组合价值保持保守的仓位规模
-- Always provide reasoning for your decisions
+- 始终为你的决策提供理由
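The "Budget Awareness" rules in the prompt above amount to a simple arithmetic constraint. The sketch below is a hypothetical illustration (`max_affordable_shares` is not part of the repository, and the real `make_decision` tool is not shown) of the "do not recommend buying more than cash allows" rule:

```python
def max_affordable_shares(cash: float, price: float) -> int:
    """Largest whole-share quantity that available cash can cover.

    Hypothetical helper illustrating the prompt's budget-awareness rule;
    it is not an actual function in this repository.
    """
    if price <= 0:
        return 0
    return int(cash // price)


# A "long" decision for more shares than this bound would violate the budget rule.
print(max_affordable_shares(10_000.0, 251.30))  # 39
```

Margin requirements for "short" decisions would need an analogous bound against available margin rather than cash.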
@@ -1,21 +1,18 @@
-You are a professional Risk Manager responsible for monitoring portfolio risk and providing risk warnings.
+你是一位专业的风险管理经理,负责监控投资组合风险并提供风险警告。

-Your Core Responsibilities:
-1. Monitor portfolio exposure and concentration risk
-2. Evaluate position sizes relative to volatility
-3. Assess margin usage and leverage levels
-4. Identify potential risk factors and provide warnings
-5. Suggest position limits based on market conditions

-Your Decision Process:
-3. Generate actionable risk warnings and position limit recommendations
-4. Provide clear reasoning for your risk assessments

-Output Guidelines:
-- Be concise but thorough in risk assessments
-- Prioritize warnings by severity
-- Provide specific, actionable recommendations
-- Include quantitative metrics when available

+你的核心职责:
+1. 监控投资组合敞口和集中度风险
+2. 评估仓位规模相对于波动性
+3. 评估保证金使用和杠杆水平
+4. 识别潜在风险因素并提供警告
+5. 基于市场条件建议仓位限制

+你的决策流程:
+3. 生成可操作的风险警告和仓位限制建议
+4. 为你的风险评估提供清晰的理由

+输出指南:
+- 风险评估要简洁但全面
+- 按严重程度优先排序警告
+- 提供具体、可操作的建议
+- 尽可能包含量化指标
@@ -7,12 +7,13 @@ Returns human-readable text format for easy LLM consumption.
 """
 # flake8: noqa: E501
 # pylint: disable=C0301,W0613
+import json
 import logging
 import traceback
 from datetime import datetime, timedelta
 from functools import wraps
 from statistics import median
-from typing import List, Optional
+from typing import List, Optional, Union

 import numpy as np
 import pandas as pd
@@ -37,6 +38,39 @@ def _to_text_response(text: str) -> ToolResponse:
     return ToolResponse(content=[TextBlock(type="text", text=text)])


+def _parse_tickers(tickers: Union[str, List[str], None]) -> List[str]:
+    """
+    Parse tickers parameter which may be a JSON string or a list.
+
+    LLM sometimes passes tickers as a JSON string like '["AAPL", "MSFT"]'
+    instead of an actual list. This function handles both cases.
+
+    Args:
+        tickers: List of stock tickers as a list or JSON string.
+
+    Returns:
+        List of stock tickers.
+    """
+    if tickers is None:
+        return []
+
+    if isinstance(tickers, str):
+        try:
+            parsed = json.loads(tickers)
+            if isinstance(parsed, list):
+                return parsed
+            # If it's a single string, wrap in list
+            return [parsed]
+        except json.JSONDecodeError:
+            # If not valid JSON, treat as comma-separated string
+            return [t.strip() for t in tickers.split(",") if t.strip()]
+
+    if isinstance(tickers, list):
+        return tickers
+
+    return []
+
+
 def _safe_float(value, default=0.0) -> float:
     """Safely convert to float."""
     try:
@@ -100,6 +134,7 @@ def analyze_efficiency_ratios(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Efficiency Ratios Analysis ({current_date}) ===\n"]

     for ticker in tickers:
@@ -141,6 +176,7 @@ def analyze_profitability(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Profitability Analysis ({current_date}) ===\n"]

     for ticker in tickers:
@@ -182,6 +218,7 @@ def analyze_growth(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Growth Analysis ({current_date}) ===\n"]

     for ticker in tickers:
@@ -222,6 +259,7 @@ def analyze_financial_health(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Financial Health Analysis ({current_date}) ===\n"]

     for ticker in tickers:
@@ -265,6 +303,7 @@ def analyze_valuation_ratios(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Valuation Ratios Analysis ({current_date}) ===\n"]

     for ticker in tickers:
@@ -364,6 +403,7 @@ def analyze_trend_following(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Trend Following Analysis ({current_date}) ===\n"]

     end_dt = datetime.strptime(current_date, "%Y-%m-%d")
@@ -459,6 +499,7 @@ def analyze_mean_reversion(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Mean Reversion Analysis ({current_date}) ===\n"]

     end_dt = datetime.strptime(current_date, "%Y-%m-%d")
@@ -545,6 +586,7 @@ def analyze_momentum(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Momentum Analysis ({current_date}) ===\n"]

     end_dt = datetime.strptime(current_date, "%Y-%m-%d")
@@ -640,6 +682,7 @@ def analyze_volatility(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Volatility Analysis ({current_date}) ===\n"]

     end_dt = datetime.strptime(current_date, "%Y-%m-%d")
@@ -718,6 +761,7 @@ def analyze_insider_trading(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Insider Trading Analysis ({current_date}) ===\n"]

     for ticker in tickers:
@@ -786,6 +830,7 @@ def analyze_news_sentiment(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== News Analysis ({current_date}) ===\n"]

     for ticker in tickers:
@@ -835,6 +880,7 @@ def dcf_valuation_analysis(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== DCF Valuation Analysis ({current_date}) ===\n"]

     for ticker in tickers:
@@ -931,6 +977,7 @@ def owner_earnings_valuation_analysis(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Owner Earnings Valuation ({current_date}) ===\n"]

     for ticker in tickers:
@@ -1049,6 +1096,7 @@ def ev_ebitda_valuation_analysis(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== EV/EBITDA Valuation ({current_date}) ===\n"]

     for ticker in tickers:
@@ -1144,6 +1192,7 @@ def residual_income_valuation_analysis(
     """

     current_date = _resolved_date(current_date)
+    tickers = _parse_tickers(tickers)
     lines = [f"=== Residual Income Valuation ({current_date}) ===\n"]

     for ticker in tickers:
@@ -755,7 +755,7 @@ export default function LiveTradingApp() {
 fontFamily: '"Courier New", monospace',
 letterSpacing: '0.5px'
 }}>
-LIVE MOCK MODE
+模拟模式
 </span>
 </div>
 )}
@@ -785,7 +785,7 @@ export default function LiveTradingApp() {
 textTransform: 'uppercase',
 letterSpacing: '0.5px'
 }}>
-VIRTUAL TIME
+虚拟时间
 </span>
 <span style={{
 fontSize: '14px',
@@ -854,13 +854,13 @@ export default function LiveTradingApp() {
 <div className="header-status-inline">
 <span className={`status-dot ${isConnected ? (isUpdating ? 'updating' : 'live') : 'offline'}`} />
 <span className={`status-text ${isConnected ? 'live' : 'offline'}`}>
-{isConnected ? (isUpdating ? 'SYNCING' : 'LIVE') : 'OFFLINE'}
+{isConnected ? (isUpdating ? '同步中' : '在线') : '离线'}
 </span>
 {marketStatus && (
 <>
 <span className="status-sep">·</span>
 <span className={`market-text ${serverMode === 'backtest' ? 'backtest' : (marketStatus.status === 'open' ? 'open' : 'closed')}`}>
-{marketStatus.status_text || (marketStatus.status === 'open' ? 'OPEN' : 'CLOSED')}
+{marketStatus.status_text || (marketStatus.status === 'open' ? '开盘' : '收盘')}
 </span>
 </>
 )}
@@ -903,7 +903,7 @@ export default function LiveTradingApp() {
 ))}
 </div>
 <div className="portfolio-value">
-<span className="portfolio-label">PORTFOLIO</span>
+<span className="portfolio-label">投资组合</span>
 <span className="portfolio-amount">${formatNumber(portfolioData.netValue)}</span>
 </div>
 </div>
@@ -918,28 +918,28 @@ export default function LiveTradingApp() {
 className={`view-nav-btn ${currentView === 'rules' ? 'active' : ''}`}
 onClick={() => setCurrentView('rules')}
 >
-Rules
+规则
 </button>

 <button
 className={`view-nav-btn ${currentView === 'room' ? 'active' : ''}`}
 onClick={() => setCurrentView('room')}
 >
-Trading Room
+交易室
 </button>

 <button
 className={`view-nav-btn ${currentView === 'chart' ? 'active' : ''}`}
 onClick={() => setCurrentView('chart')}
 >
-Performance Chart
+业绩图表
 </button>

 <button
 className={`view-nav-btn ${currentView === 'statistics' ? 'active' : ''}`}
 onClick={() => setCurrentView('statistics')}
 >
-Statistics
+统计
 </button>
 </div>

@@ -970,7 +970,7 @@ export default function LiveTradingApp() {
 className={`chart-tab ${chartTab === 'all' ? 'active' : ''}`}
 onClick={() => setChartTab('all')}
 >
-Daily
+日线
 </button>
 {/* <button
 className={`chart-tab ${chartTab === 'live' ? 'active' : ''} ${!isLiveEnabled ? 'disabled' : ''}`}
@@ -110,12 +110,15 @@ export default function AboutModal({ onClose }) {
 zh: {
 intro: "如果不是让模型彼此竞争,而是像一支高效协作的团队一样进行实时交易,会发生什么?",
 question: "这里不是竞技场,而是团队。我们希望Agents不再单打独斗,而是「组团」进入实时金融市场——这一十分困难且充满噪声的环境。",
+trying: "我们正在探索多智能体协作在实时金融交易中的可能性。",

 title1: "✦ 多智能体的技能互补",
 point1: "不同模型、不同角色的智能体像真实的金融团队一样协作,各自承担数据分析、策略生成、风险控制等职责。",
+point1Sub: "通过通知和会议机制进行信息交换,实现高效协作。",

 title2: "✦ 能够持续进化的智能体系统",
 point2: "依托「记忆」模块,每个智能体都能跨回合保留经验,不断学习、反思与调整。我们希望能看到在长期实时交易中,Agent形成自己的独特方法论,而不是一次性偶然的推理。",
+point2Sub: "ReMe 记忆框架帮助 Agents 持续改进。",

 title3: "✦ 实时参与市场的 AI Agents",
 point3: "Agents从实时行情中学习,并给予即时决策;不是纸上谈兵,而是面对市场的真实波动。"
@@ -6,9 +6,9 @@ import { getModelIcon, getShortModelName } from '../utils/modelIcons';
 * Get rank medal/trophy
 */
 function getRankMedal(rank) {
-if (rank === 1) return { emoji: '🏆', color: '#FFD700', label: 'Gold' };
+if (rank === 1) return { emoji: '🏆', color: '#FFD700', label: '金牌' };
-if (rank === 2) return { emoji: '🥈', color: '#C0C0C0', label: 'Silver' };
+if (rank === 2) return { emoji: '🥈', color: '#C0C0C0', label: '银牌' };
-if (rank === 3) return { emoji: '🥉', color: '#CD7F32', label: 'Bronze' };
+if (rank === 3) return { emoji: '🥉', color: '#CD7F32', label: '铜牌' };
 return { emoji: `#${rank}`, color: '#333333', label: `#${rank}` };
 }

@@ -41,7 +41,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 const rankMedal = agent.rank ? getRankMedal(agent.rank) : null;
 const isPortfolioManager = agent.id === 'portfolio_manager';
 const isRiskManager = agent.id === 'risk_manager';
-const displayName = isPortfolioManager ? 'Team' : agent.name;
+const displayName = isPortfolioManager ? '团队' : agent.name;

 // Get model icon configuration
 const modelInfo = getModelIcon(agent.modelName, agent.modelProvider);
@@ -149,7 +149,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 whiteSpace: 'normal',
 wordWrap: 'break-word'
 }}>
-ⓘ Risk Manager focuses on risk management and does not participate in prediction accuracy ranking.
+ⓘ 风控经理专注于风险管理,不参与预测准确率排名。
 </div>
 </div>
 )}
@@ -173,7 +173,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 whiteSpace: 'normal',
 wordWrap: 'break-word'
 }}>
-ⓘ Portfolio Manager provides the team's final signal(position), synthesizing all analyst recommendations, and does not participate in ranking.
+ⓘ 投资经理综合所有分析师建议,提供团队最终交易信号,不参与排名。
 </div>
 </div>
 )}
@@ -197,7 +197,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 marginBottom: 4,
 textTransform: 'uppercase'
 }}>
-Model
+模型
 </div>
 <div style={{
 height: 40,
@@ -262,7 +262,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 marginBottom: 4,
 textTransform: 'uppercase'
 }}>
-Win Rate
+胜率
 </div>
 <div style={{
 fontSize: 36,
@@ -278,7 +278,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 fontSize: 9,
 color: '#555555'
 }}>
-{bullWins + bearWins}Win / {evaluatedTotal}Eval
+{bullWins + bearWins}胜 / {evaluatedTotal}评
 </div>
 <div style={{
 fontSize: 8,
@@ -288,7 +288,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 lineHeight: 1.2,
 whiteSpace: 'pre-line'
 }}>
-Eval: total evaluated bull & bear signals.{'\n'}Win Rate = correct signals / total evaluated signals
+评估: 总评估多空信号数。{'\n'}胜率 = 正确信号 / 总评估信号
 </div>
 </div>
 )}
@@ -309,7 +309,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 marginBottom: 4,
 textTransform: 'uppercase'
 }}>
-Bull Win Rate
+牛市胜率
 </div>
 <div style={{
 fontSize: 28,
@@ -324,7 +324,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 fontSize: 9,
 color: '#333333'
 }}>
-{bullWins}Win / {evaluatedBull}Eval
+{bullWins}胜 / {evaluatedBull}评
 {bullUnknown > 0 && ` / ${bullUnknown}P`}
 </div>
 </div>
@@ -346,7 +346,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 marginBottom: 4,
 textTransform: 'uppercase'
 }}>
-Bear Win Rate
+熊市胜率
 </div>
 <div style={{
 fontSize: 28,
@@ -361,7 +361,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 fontSize: 9,
 color: '#333333'
 }}>
-{bearWins}Win / {evaluatedBear}Eval
+{bearWins}胜 / {evaluatedBear}评
 {bearUnknown > 0 && ` / ${bearUnknown}P`}
 </div>
 </div>
@@ -433,7 +433,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 fontSize: 16,
 color: isBull ? '#00C853' : isBear ? '#FF1744' : '#555555'
 }}>
-{isBull ? 'bull' : isBear ? 'bear' : 'neutral'}
+{isBull ? '看涨' : isBear ? '看跌' : '中性'}
 </div>
 <div style={{
 fontSize: 8,
@@ -471,14 +471,14 @@ export default function AgentCard({ agent, onClose, isClosing }) {
 fontWeight: 700,
 color: '#1976D2'
 }}>
-ⓘ Info
+ⓘ 说明
 </div>
 <div style={{
 fontSize: 8,
 color: '#1976D2',
 lineHeight: 1.2
 }}>
-Showing recent 5 trading days (1 week) signals only
+仅显示最近5个交易日(1周)的信号
 </div>
 </div>
 </div>
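The card's tooltip above defines the metric in words: Win Rate = correct signals / total evaluated signals, where "Eval" counts evaluated bull and bear signals. A minimal sketch of that formula (the component itself is JSX; the Python names below merely mirror its `bullWins`, `bearWins`, and `evaluatedTotal` variables):

```python
def win_rate(bull_wins: int, bear_wins: int, evaluated_total: int) -> float:
    """Win rate per the card's tooltip: correct signals / total evaluated signals."""
    if evaluated_total == 0:
        # No evaluated signals yet: avoid division by zero
        return 0.0
    return (bull_wins + bear_wins) / evaluated_total


print(win_rate(bull_wins=6, bear_wins=4, evaluated_total=20))  # 0.5
```

Pending signals (the `P` counts in the card) are excluded from the denominator, since only evaluated signals are scored.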
@@ -161,9 +161,9 @@ const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
 return (
 <div className="agent-feed">
 <div className="agent-feed-header">
-<h3 className="agent-feed-title">ACTIVITY FEED</h3>
+<h3 className="agent-feed-title">活动 feed</h3>
 <div className="agent-filter-wrapper">
-<label className="agent-filter-label">Filter:</label>
+<label className="agent-filter-label">筛选:</label>
 <div className="custom-select-wrapper">
 <button
 className="custom-select-trigger"
@@ -191,7 +191,7 @@ const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
 setDropdownOpen(false);
 }}
 >
-<span>All Agents</span>
+<span>全部 Agents</span>
 </div>
 {uniqueAgents.map(agent => {
 const agentInfo = getAgentInfoByName(agent);
@@ -225,8 +225,8 @@ const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
 {filteredFeed.length === 0 && (
 <div className="empty-state">
 {selectedAgent === 'all'
-? 'Waiting for system updates...'
+? '等待系统更新...'
-: `No messages from ${selectedAgent}`}
+: `${selectedAgent} 没有消息`}
 </div>
 )}

@@ -294,9 +294,9 @@ function ConferenceItem({ conference, itemId, isHighlighted, getAgentModelInfo }
 >
 <div className="feed-item-header">
 <span className="feed-item-title" style={{ color: colors.text }}>
-CONFERENCE
+会议
 </span>
-{conference.isLive && <span className="feed-live-badge">● LIVE</span>}
+{conference.isLive && <span className="feed-live-badge">● 实时</span>}
 <span className="feed-item-time">{formatTime(conference.startTime)}</span>
 </div>

@@ -366,7 +366,7 @@ function ConferenceMessage({ message, getAgentModelInfo }) {
 className="conf-expand-btn"
 onClick={() => setExpanded(!expanded)}
 >
-{expanded ? '« Less' : 'More »'}
+{expanded ? '« 收起' : '更多 »'}
 </button>
 )}
 </div>
@@ -397,8 +397,8 @@ function MemoryItem({ memory, itemId, isHighlighted }) {
 }

 const agentLabel = memory.agent && memory.agent !== 'Memory'
-? `MEMORY · ${memory.agent}`
+? `记忆 · ${memory.agent}`
-: 'MEMORY';
+: '记忆';

 return (
 <div
@@ -511,7 +511,7 @@ function MemoryItem({ memory, itemId, isHighlighted }) {
 className="feed-expand-btn"
 onClick={() => setExpanded(!expanded)}
 >
-{expanded ? '« Less' : 'More »'}
+{expanded ? '« 收起' : '更多 »'}
 </button>
 )}
 </div>
@@ -525,7 +525,7 @@ function MessageItem({ message, itemId, isHighlighted, getAgentModelInfo }) {

 const colors = message.agent === 'Memory' ? MESSAGE_COLORS.memory :
 getAgentColors(message.agentId, message.agent);
-const title = message.agent === 'Memory' ? 'MEMORY' : message.agent || 'AGENT';
+const title = message.agent === 'Memory' ? '记忆' : message.agent || 'AGENT';

 const agentModelData = message.agentId && getAgentModelInfo ?
 getAgentModelInfo(message.agentId) :
@@ -614,7 +614,7 @@ function MessageItem({ message, itemId, isHighlighted, getAgentModelInfo }) {
 outline: 'none'
 }}
 >
-📄 {isManagerAgent ? 'View decision log »' : 'View full report »'}
+📄 {isManagerAgent ? '查看决策日志 »' : '查看完整报告 »'}
 </button>
 )}

@@ -623,7 +623,7 @@ function MessageItem({ message, itemId, isHighlighted, getAgentModelInfo }) {
 className="feed-expand-btn"
 onClick={() => setExpanded(!expanded)}
 >
-{expanded ? '« Less' : 'More »'}
+{expanded ? '« 收起' : '更多 »'}
 </button>
 )}
 </div>
@@ -59,7 +59,7 @@ export default function Header({
 letterSpacing: '0.5px',
 marginRight: '0px'
 }}>
-OPEN SOURCE
+开源
 </span>

 <a
@@ -131,7 +131,7 @@ export default function Header({
 }}
 >
 <span className="header-link">
-Contact Us
+联系我们
 </span>

 {/* Two contact buttons */}
@@ -551,7 +551,7 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump

 {/* Hint Text */}
 <div className="agent-hint-text">
-Click avatar to view details
+点击头像查看详情
 </div>
 </div>

@@ -620,14 +620,14 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
 <button
 className="bubble-jump-btn"
 onClick={handleJumpToFeed}
-title="Jump to message in feed"
+title="跳转到消息"
 >
 ↗
 </button>
 <button
 className="bubble-close-btn"
 onClick={(e) => handleCloseBubble(agent.id, bubbleKey, e)}
-title="Close bubble"
+title="关闭"
 >
 ×
 </button>
@@ -728,7 +728,7 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
 title="Replay feed history"
 >
 <span className="replay-icon">▶▶</span>
-<span>REPLAY</span>
+<span>回放</span>
 </button>
 </div>
 )}
@@ -749,7 +749,7 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
 }}
 />
 <div className="replay-indicator">
-<span className="replay-status">{isPaused ? 'PAUSED' : 'REPLAY MODE'}</span>
+<span className="replay-status">{isPaused ? '已暂停' : '回放模式'}</span>
 <button
 className="replay-button"
 onClick={isPaused ? resumeReplay : pauseReplay}
@@ -181,17 +181,17 @@ export default function RulesView() {
 },
 zh: {
 section1Title: "Agent 设定",
-pmRole: "Portfolio Manager",
+pmRole: "投资经理",
 pmDesc: "负责最终交易决策和团队协作",
-rmRole: "Risk Manager",
+rmRole: "风控经理",
 rmDesc: "监控组合风险并执行风险限制",
-analystsRole: "Analysts",
+analystsRole: "分析师",
 analystsDesc: "使用不同工具和 AI 模型进行专业研究:",
 analysts: [
-{ name: "Valuation Analyst", model: "Moonshot", modelKey: "Moonshot" },
-{ name: "Sentiment Analyst", model: "Qwen", modelKey: "Alibaba" },
-{ name: "Fundamentals Analyst", model: "DeepSeek", modelKey: "DeepSeek" },
-{ name: "Technical Analyst", model: "Zhipu AI", modelKey: "Zhipu AI" }
+{ name: "估值分析师", model: "Moonshot", modelKey: "Moonshot" },
+{ name: "情绪分析师", model: "Qwen", modelKey: "Alibaba" },
+{ name: "基本面分析师", model: "DeepSeek", modelKey: "DeepSeek" },
+{ name: "技术分析师", model: "Zhipu AI", modelKey: "Zhipu AI" }
 ],

 section2Title: "Agent 决策机制",
@@ -151,7 +151,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 color: '#000000',
 textTransform: 'uppercase'
 }}>
-Performance
+业绩表现
 </h2>
 </div>

@@ -170,7 +170,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 marginBottom: 12,
 textTransform: 'uppercase'
 }}>
-Total Asset Value
+总资产价值
 </div>
 <div style={{
 fontSize: 36,

@@ -202,7 +202,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 marginBottom: 8,
 textTransform: 'uppercase'
 }}>
-Excess Return
+超额收益
 </div>
 <div style={{
 fontSize: 28,

@@ -218,7 +218,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 marginTop: 4,
 fontFamily: '"Courier New", monospace'
 }}>
-vs. VW: {excessReturnData.benchmarkReturn >= 0 ? '+' : ''}{excessReturnData.benchmarkReturn.toFixed(2)}%
+vs 市值加权: {excessReturnData.benchmarkReturn >= 0 ? '+' : ''}{excessReturnData.benchmarkReturn.toFixed(2)}%
 </div>
 </div>
 ) : null}

@@ -233,7 +233,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 marginBottom: 8,
 textTransform: 'uppercase'
 }}>
-Win Rate
+胜率
 </div>
 <div style={{
 fontSize: 28,

@@ -252,7 +252,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 marginTop: 4,
 fontFamily: '"Courier New", monospace'
 }}>
-{pmWinRateData.totalWins}Win / {pmWinRateData.evaluatedTotal}Eval
+{pmWinRateData.totalWins}胜 / {pmWinRateData.evaluatedTotal}评
 </div>
 )}
 </div>

@@ -267,7 +267,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 marginBottom: 8,
 textTransform: 'uppercase'
 }}>
-Absolute Return
+绝对收益
 </div>
 <div style={{
 fontSize: 28,

@@ -296,7 +296,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 letterSpacing: 0.5,
 textTransform: 'uppercase'
 }}>
-Cash Position
+现金头寸
 </div>
 <div style={{
 fontSize: 16,

@@ -322,7 +322,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 letterSpacing: 0.5,
 textTransform: 'uppercase'
 }}>
-Total Trades
+总交易数
 </div>
 <div style={{
 fontSize: 16,

@@ -350,7 +350,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 textTransform: 'uppercase',
 color: '#666666'
 }}>
-Portfolio Weights
+组合权重
 </div>
 <div className="statistics-table-container" style={{
 display: 'grid',

@@ -397,7 +397,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 fontSize: 12,
 letterSpacing: 0.5
 }}>
-No statistics available
+暂无统计数据
 </div>
 )}
 </div>

@@ -433,7 +433,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 color: '#000000',
 textTransform: 'uppercase'
 }}>
-Portfolio Holdings
+持仓明细
 </h2>
 </div>

@@ -448,7 +448,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 fontSize: 11,
 letterSpacing: 0.5
 }}>
-No positions currently held
+当前无持仓
 </div>
 ) : (
 <>

@@ -456,11 +456,11 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 <table className="data-table">
 <thead>
 <tr>
-<th>Ticker</th>
-<th>Quantity</th>
-<th>Price</th>
-<th>Value</th>
-<th>Weight</th>
+<th>代码</th>
+<th>数量</th>
+<th>价格</th>
+<th>市值</th>
+<th>权重</th>
 </tr>
 </thead>
 <tbody>

@@ -505,7 +505,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 onClick={() => setHoldingsPage(p => Math.max(1, p - 1))}
 disabled={holdingsPage === 1}
 >
-◀ Prev
+◀ 上一页
 </button>

 <div className="pagination-info">

@@ -517,7 +517,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 onClick={() => setHoldingsPage(p => Math.min(totalHoldingsPages, p + 1))}
 disabled={holdingsPage === totalHoldingsPages}
 >
-Next ▶
+下一页 ▶
 </button>
 </div>
 )}

@@ -552,7 +552,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 color: '#000000',
 textTransform: 'uppercase'
 }}>
-Transaction History
+交易历史
 </h2>
 {trades.length > 0 && (
 <div style={{

@@ -560,7 +560,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 color: '#666666',
 fontFamily: '"Courier New", monospace'
 }}>
-{trades.length} total
+共 {trades.length} 笔
 </div>
 )}
 </div>

@@ -576,7 +576,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 fontSize: 11,
 letterSpacing: 0.5
 }}>
-No trades recorded
+暂无交易记录
 </div>
 ) : (
 <>

@@ -584,11 +584,11 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 <table className="data-table">
 <thead>
 <tr>
-<th>Time</th>
-<th>Stock</th>
-<th>Side</th>
-<th>Qty</th>
-<th>Price</th>
+<th>时间</th>
+<th>股票</th>
+<th>方向</th>
+<th>数量</th>
+<th>价格</th>
 </tr>
 </thead>
 <tbody>

@@ -638,7 +638,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 onClick={() => setTradesPage(p => Math.max(1, p - 1))}
 disabled={tradesPage === 1}
 >
-◀ Prev
+◀ 上一页
 </button>

 <div className="pagination-info">

@@ -650,7 +650,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
 onClick={() => setTradesPage(p => Math.min(totalTradesPages, p + 1))}
 disabled={tradesPage === totalTradesPages}
 >
-Next ▶
+下一页 ▶
 </button>
 </div>
 )}
@@ -66,43 +66,43 @@ export const AGENT_SEATS = [
 export const AGENTS = [
 {
 id: "portfolio_manager",
-name: "Portfolio Manager",
-role: "Portfolio Manager",
+name: "投资经理",
+role: "投资经理",
 avatar: CDN_ASSETS.companyRoom.agent_1,
 colors: { bg: "#F9FDFF", text: "#1565C0", accent: "#1565C0" }
 },
 {
 id: "risk_manager",
-name: "Risk Manager",
-role: "Risk Manager",
+name: "风控经理",
+role: "风控经理",
 avatar: CDN_ASSETS.companyRoom.agent_2,
 colors: { bg: "#FFF8F8", text: "#C62828", accent: "#C62828" }
 },
 {
 id: "valuation_analyst",
-name: "Valuation Analyst",
-role: "Valuation Analyst",
+name: "估值分析师",
+role: "估值分析师",
 avatar: CDN_ASSETS.companyRoom.agent_3,
 colors: { bg: "#FAFFFA", text: "#2E7D32", accent: "#2E7D32" }
 },
 {
 id: "sentiment_analyst",
-name: "Sentiment Analyst",
-role: "Sentiment Analyst",
+name: "情绪分析师",
+role: "情绪分析师",
 avatar: CDN_ASSETS.companyRoom.agent_4,
 colors: { bg: "#FCFAFF", text: "#6A1B9A", accent: "#6A1B9A" }
 },
 {
 id: "fundamentals_analyst",
-name: "Fundamentals Analyst",
-role: "Fundamentals Analyst",
+name: "基本面分析师",
+role: "基本面分析师",
 avatar: CDN_ASSETS.companyRoom.agent_5,
 colors: { bg: "#FFFCF7", text: "#E65100", accent: "#E65100" }
 },
 {
 id: "technical_analyst",
-name: "Technical Analyst",
-role: "Technical Analyst",
+name: "技术分析师",
+role: "技术分析师",
 avatar: CDN_ASSETS.companyRoom.agent_6,
 colors: { bg: "#F9FEFF", text: "#00838F", accent: "#00838F" }
 },
@@ -19,7 +19,9 @@ from structured_model import (
 get_seer_model,
 get_hunter_model,
 )
-from prompt import EnglishPrompts as Prompts
+from prompt import (
+    EnglishPrompts as Prompts,
+)  # pylint: disable=no-name-in-module

 # Uncomment the following line to use Chinese prompts
 # from prompt import ChinesePrompts as Prompts
@@ -6,7 +6,9 @@ from typing import Any
 import numpy as np
 from agentscope.agent import AgentBase, ReActAgent
 from agentscope.message import Msg
-from prompt import EnglishPrompts as Prompts
+from prompt import (  # pylint: disable=no-name-in-module
+    EnglishPrompts as Prompts,
+)

 MAX_GAME_ROUND = 30
 MAX_DISCUSSION_ROUND = 3
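Both hunks above apply the same mechanical refactor: a single aliased import is wrapped in parentheses so a `# pylint: disable=no-name-in-module` pragma fits on the import line. A stdlib stand-in (using `json` in place of the samples' local `prompt` module, purely for illustration) looks like:

```python
# Parenthesized form of a single aliased import, matching the pattern in the diff.
# The pragma is only illustrative here; `json.dumps` resolves fine without it.
from json import (  # pylint: disable=no-name-in-module
    dumps as serialize,
)

payload = serialize({"role": "werewolf", "alive": True})
print(payload)  # {"role": "werewolf", "alive": true}
```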
@@ -8,7 +8,7 @@ This directory contains several examples of how to use the AgentScope Tuner for
 | Frozen Lake | Make an agent to navigate the Frozen Lake environment in multi-step interactions. | [frozen_lake](./frozen_lake) | ✅ | ❌ | ❌ | ❌ | ❌ |
 | Learn to Ask | Using LLM as a judge to provide feedback to facilitate agent tuning. | [learn_to_ask](./learn_to_ask) | ✅ | ✅ | ❌ | ❌ | ❌ |
 | Email Search | Enhance the tool use ability of your agent on tasks without ground truth. | [email_search](./email_search) | ✅ | ✅ | ✅ | ❌ | ❌ |
-| Werewolf Game | Enhance the agent's performance in a multi-agent game setting. | [werewolf_game](./werewolf_game)| ✅ | ✅ | ✅ | ✅ | ❌ |
+| Werewolf Game | Enhance the agent's performance in a multi-agent game setting. | [werewolves](./werewolves)| ✅ | ✅ | ✅ | ✅ | ❌ |
 | Data Augment | Data augmentation for better tuning results. | [data_augment](./data_augment) | ❌ | ❌ | ❌ | ❌ | ✅ |

 Each example contains a README file with detailed instructions on how to set up and run the tuning process for that specific scenario. Feel free to explore and modify the examples to suit your needs!

@@ -22,6 +22,6 @@ AgentScope Tuner requires:
 - `agentscope>=1.0.12`
 - `trinity-rft>=0.4.1`

-AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/modelscope/Trinity-RFT).
-Please refer to the [Trinity-RFT installation guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
+AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
+Please refer to the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
 for detailed instructions on how to set up the environment.
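Given the two version floors in the hunk above, a quick stdlib-only environment check can be run before starting any example. The helper `installed_version` below is our own sketch, not part of AgentScope Tuner; the distribution names come straight from the requirement lines:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist: str):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


# Version floors from the prerequisites list above.
for dist, floor in [("agentscope", "1.0.12"), ("trinity-rft", "0.4.1")]:
    found = installed_version(dist)
    print(f"{dist}: {found if found else f'not installed (requires >={floor})'}")
```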
@@ -1,19 +1,19 @@
-# AgentScope Tuner 中文说明
+# AgentScope Tuner

 本目录包含了多个使用 AgentScope Tuner 对 AgentScope 应用进行调优的示例。下表总结了可用的示例:

 | 示例名称 | 描述 | 示例路径 | 多步交互 | LLM 评审 | 工具使用 | 多智能体 | 数据增强 |
 |------------------|-------------------------------------------|---------------------------------|----------|----------|----------|----------|----------|
 | 数学智能体 | 快速入门示例,调优数学智能体以提升其能力。 | [math_agent](./math_agent) | ✅ | ❌ | ❌ | ❌ | ❌ |
-| Frozen Lake | 让智能体在多步交互中导航冰湖环境。 | [frozen_lake](./frozen_lake) | ✅ | ❌ | ❌ | ❌ | ❌ |
+| Frozen Lake | 让智能体在与 frozen lake 环境的多步交互中学习。 | [frozen_lake](./frozen_lake) | ✅ | ❌ | ❌ | ❌ | ❌ |
 | Learn to Ask | 使用 LLM 作为评审,为智能体调优提供反馈 | [learn_to_ask](./learn_to_ask) | ✅ | ✅ | ❌ | ❌ | ❌ |
 | 邮件搜索 | 在无标准答案任务中提升智能体的工具使用能力。 | [email_search](./email_search) | ✅ | ✅ | ✅ | ❌ | ❌ |
-| 狼人杀游戏 | 提升智能体在多智能体游戏场景下的表现。 | [werewolf_game](./werewolf_game)| ✅ | ✅ | ✅ | ✅ | ❌ |
+| 狼人杀游戏 | 提升智能体在多智能体游戏场景下的表现。 | [werewolves](./werewolves)| ✅ | ✅ | ✅ | ✅ | ❌ |
 | 数据增强 | 通过数据增强获得更好的调优效果。 | [data_augment](./data_augment) | ❌ | ❌ | ❌ | ❌ | ✅ |

 每个示例目录下均包含详细的 README 文件,介绍了该场景下的调优流程和使用方法。欢迎根据实际需求进行探索和修改!

-## 先决条件
+## 前置要求

 AgentScope Tuner 需要:

@@ -21,6 +21,6 @@ AgentScope Tuner 需要:
 - `agentscope>=1.0.12`
 - `trinity-rft>=0.4.1`

-AgentScope Tuner 构建于 [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) 之上。
-请参考 [Trinity-RFT 安装指南](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
+AgentScope Tuner 基于 [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) 构建。
+请参考 [Trinity-RFT 安装指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
 获取详细的安装方法。
tuner/data_augment/README.md (new file, 155 lines)
@@ -0,0 +1,155 @@
# Training Math Agent with Data-Augment Strategies

This example demonstrates how to use **AgentScope-Tuner** to enhance a math problem-solving agent. We will focus on leveraging **Data-Centric** features, such as the `difficulty_based` task selector, to improve data utility and training efficiency.

## Task Setting

We use the foundational [math-agent example](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py) as our baseline. The agent is a **`ReActAgent`** that solves mathematical reasoning problems through step-by-step reasoning.

Training can be inefficient if tasks are too easy or too hard. This example demonstrates how to use **task selectors** to dynamically select tasks based on **data feedback**, focusing on "productively challenging" samples to maximize training efficiency. These data-centric techniques are generic and adaptable to other agent workflows.

## Dataset Preparation

To enable difficulty-based sampling, the training data must include difficulty features (e.g., pass rates from LLMs).

1. **Base Dataset**: You can use any standard math problem dataset. A good example is the math data in [LLM360/guru-RL-92k](https://huggingface.co/datasets/LLM360/guru-RL-92k), which comes pre-annotated with pass rates from different LLMs, serving as direct difficulty features.
2. **Build Your Own Features**: If you use your own dataset, you can generate these features by pre-running several models of varying capabilities and recording their pass rates. This can be done within the [**Trinity-RFT**](https://github.com/agentscope-ai/Trinity-RFT/pull/440) framework.
3. **Data Format**: The final dataset should be in HuggingFace format. In this example, data will be converted to *GSM8K format* according to the [workflow](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py). Besides the task content, it must include the difficulty feature columns you've defined (e.g., `qwen2.5_7b_pass_rate`).
4. **Example Data Preparation**: We provide a script for this example. Simply execute `python prepare_data.py` to generate the required dataset.
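The pass-rate features described in step 2 are simple to compute once per-model attempt logs exist. A minimal stdlib sketch follows; the helper name `pass_rate` and the attempt-log layout are illustrative assumptions, not Trinity-RFT's actual interface:

```python
def pass_rate(attempt_results: list[bool]) -> float:
    """Fraction of attempts a model solved the task; used as a difficulty feature."""
    if not attempt_results:
        return 0.0
    return sum(attempt_results) / len(attempt_results)


# One entry per task: boolean outcomes of repeated rollouts by a single model.
attempts = {
    "task_001": [True, True, False, True],    # easy-ish for this model
    "task_002": [False, False, False, False], # hard for this model
}
features = {task_id: pass_rate(results) for task_id, results in attempts.items()}
print(features)  # {'task_001': 0.75, 'task_002': 0.0}
```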
## Code Implementation

### Agent Workflow & Judge Function

This example follows the foundational [math-agent example](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py), adopting its `run_react_agent` and `gsm8k_judge` as the `workflow_func` and `judge_func`, respectively. This highlights a key benefit: you can apply training strategies without altering your core agent logic.

### Design of Data-Centric Features

Leveraging the powerful data processing capabilities of **Trinity-RFT**, **AgentScope-Tuner** provides interfaces for advanced operations like task selection and experience processing.

#### Task Selector

The `Task Selector` determines how samples are selected from a dataset. It can be configured directly in configuration YAML files.

- **Built-in Selectors**:
  - `sequential`: Samples are selected in a fixed order.
  - `shuffle`: The dataset is shuffled at the beginning of each epoch.
  - `random`: Samples are randomly chosen with replacement for each batch.
  - `offline_easy2hard`: Samples are sorted by a predefined feature for curriculum learning.
  - `difficulty_based` (Customized): An adaptive sampler based on task difficulty.

> For more details on `Task Selector`, including how to implement a custom selector based on feedback signals, please refer to **Trinity-RFT**'s **[Selector Development Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)**.

#### Data Processor

The `Data Processor` allows for real-time processing of **Task** and **Experience** during training, enabling operations like calculating feedback metrics, data augmentation, or filtering.

For example, the `difficulty_based` selector requires a `pass_rate_calculator` operator to compute the agent's success rate for each task. This feedback is then used to adjust the sampling strategy.

> For more details on `Data Processor`, please refer to **Trinity-RFT**'s **[Operator Development Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/develop_operator.html)**.
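To make the intuition behind a difficulty-based selector concrete, here is a minimal self-contained sketch. It is not the `difficulty_based` implementation shipped with Trinity-RFT (that one follows the BOTS algorithm); it simply up-weights tasks whose pass rate sits near 0.5, where the training signal is richest. The names `difficulty_weight` and `select_batch` are invented for illustration:

```python
import random


def difficulty_weight(pass_rate: float) -> float:
    """Peak weight at pass_rate == 0.5; near-zero for trivially easy or hard tasks."""
    return max(1e-6, 1.0 - 2.0 * abs(pass_rate - 0.5))


def select_batch(tasks: list[dict], batch_size: int, rng: random.Random) -> list[dict]:
    """Sample a batch (with replacement), weighted by each task's difficulty."""
    weights = [difficulty_weight(t["pass_rate"]) for t in tasks]
    return rng.choices(tasks, weights=weights, k=batch_size)


tasks = [
    {"id": "a", "pass_rate": 0.05},  # too hard: rarely selected
    {"id": "b", "pass_rate": 0.50},  # most informative: selected often
    {"id": "c", "pass_rate": 0.95},  # too easy: rarely selected
]
batch = select_batch(tasks, batch_size=8, rng=random.Random(0))
```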
### Configuring the Experiments

To maintain clarity and simplicity, we recommend defining all data-specific parameters, including dataset paths and task selectors, within YAML configuration files.

We provide two configuration files to compare the baseline `random` selector against the `difficulty_based` selector.

**Experiment 1: Baseline with Random Selector (`config_random.yaml`)**

In `config_random.yaml`, we configure the `task_selector` for random sampling under `buffer.explorer_input.taskset`.

```yaml
# In config_random.yaml
buffer:
  # ...
  explorer_input:
    taskset:  # Training data
      path: "path/to/your/augmented/math_data"
      split: "train"
      task_selector:
        selector_type: random  # Strategy of task selection
```

**Experiment 2: Advanced Training with Difficulty-Based Selector (`config_difficulty.yaml`)**

In `config_difficulty.yaml`, we switch the `task_selector` to `difficulty_based` and provide its specific parameters. Note that this config also enables the `pass_rate_calculator` needed for feedback.

```yaml
# In config_difficulty.yaml

# Enable the calculator to provide feedback for the selector
data_processor:
  experience_pipeline:
    operators:
      - name: pass_rate_calculator

buffer:
  # ...
  explorer_input:
    taskset:  # Training data
      path: "path/to/your/augmented/math_data"
      split: "train"
      task_selector:
        selector_type: difficulty_based  # Strategy of task selection
        feature_keys: [ "qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate" ]
        kwargs:  # Hyper-parameters for the selection algorithm
          m: 8
          # ...
```

> The `difficulty_based` selector in this example is an implementation of the ***BOTS*** algorithm. For details on its inner workings, please refer to the [***BOTS paper***](https://arxiv.org/abs/2510.26374) and its [***tutorials***](https://github.com/agentscope-ai/Trinity-RFT/blob/main/examples/bots/README.md).

## How to Run

### Step 1: Prerequisites

Ensure you have installed **AgentScope** and **Trinity-RFT** following [the guidance](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/README.md#how-to-run).

### Step 2: Prepare the Dataset

Run the data preparation script. Make sure to update the dataset paths in `config_random.yaml` and `config_difficulty.yaml` afterward.

```bash
python prepare_data.py
```

### Step 3: Start Ray Cluster

For distributed training, start a Ray cluster.

```bash
# For single node
ray start --head
```

### Step 4: Run Training

You can now run either the baseline or the difficulty-based training experiment.

- **To run the baseline experiment with a random selector:**

  ```bash
  python main.py --config config_random.yaml
  ```

- **To run the experiment with the difficulty-based selector:**

  ```bash
  python main.py --config config_difficulty.yaml
  ```

## Experimental Results

The following results compare the performance of the `difficulty_based` selection strategy (red line, bots) against a standard `random` selection strategy (black line, random).

<div align="center">
<img src="./training_result.jpg" alt="Training Result Image" width="90%"/>
</div>

### Training Reward Curve

The chart on the left shows the rollout accuracy during training. As can be seen, the tasks sampled by the random strategy appear to be difficult for the model, with the accuracy remaining below 0.2. In contrast, using the difficulty selector results in a higher mean accuracy, indicating that the agent is engaging with more tasks that it can successfully solve.

### Evaluation on AIME-24

For comparison, we evaluated both selection strategies on the AIME-24 benchmark. The chart on the right shows that the difficulty-based method demonstrates a better upward trend in performance over time.
153
tuner/data_augment/README_zh.md
Normal file
153
tuner/data_augment/README_zh.md
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
# 使用数据增强策略训练数学智能体
|
||||||
|
|
||||||
|
本示例演示了如何使用 **AgentScope-Tuner** 训练数学问题求解智能体。我们将重点利用**以数据为中心**的功能,例如 `difficulty_based` 任务选择器,以提高数据利用率和训练效率。
|
||||||
|
|
||||||
|
## 任务设置
|
||||||
|
|
||||||
|
我们使用基础的 [math-agent 示例](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py) 作为基线。智能体是 **`ReActAgent`**,通过逐步推理解决数学推理问题。
|
||||||
|
|
||||||
|
如果任务太容易或太难,训练可能会效率低下。本示例演示如何使用**任务选择器**基于**数据反馈**动态选择任务,专注于"具有挑战性"的样本以最大化训练效率。这些以数据为中心的技术是通用的,可适应其他智能体工作流。
|
||||||
|
|
||||||
|
## 数据集准备
|
||||||
|
|
||||||
|
为启用基于难度的采样,训练数据必须包含难度特征(如 LLM 的通过率)。
|
||||||
|
|
||||||
|
1. **基础数据集**:您可以使用任何标准的数学问题数据集。一个很好的例子是 [LLM360/guru-RL-92k](https://huggingface.co/datasets/LLM360/guru-RL-92k) 中的数学数据,它预先标注了来自不同 LLM 的通过率,作为直接的难度特征。
|
||||||
|
2. **构建您自己的特征**:如果您使用自己的数据集,可以通过预先运行几个不同能力的模型并记录它们的通过率来生成这些特征。这可以在 [**Trinity-RFT**](https://github.com/agentscope-ai/Trinity-RFT/pull/440) 框架内完成。
|
||||||
|
3. **数据格式**:最终数据集应为 HuggingFace 格式。在此示例中,数据将根据[工作流](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py)转换为 *GSM8K 格式*。除了任务内容外,它还必须包含您定义的难度特征列(例如 `qwen2.5_7b_pass_rate`)。
|
||||||
|
4. **示例数据准备**:我们为此示例提供了一个脚本。只需执行 `python prepare_data.py` 即可生成所需的数据集。
|
||||||
## Code Implementation

This example adopts `run_react_agent` and `gsm8k_judge` from the [math-agent example](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py) as the `workflow_func` and `judge_func`, illustrating that training strategies can be applied without changing the core agent logic.

### Design of the Data-Centric Features

Building on the powerful data-processing capabilities of **Trinity-RFT**, **AgentScope-Tuner** provides interfaces for advanced operations such as task selection and experience processing.

#### Task Selector

The `Task Selector` decides how samples are drawn from the dataset. It can be configured directly in the YAML configuration file.

- **Built-in selectors**:
  - `sequential`: selects samples in a fixed order.
  - `shuffle`: shuffles the dataset at the start of each epoch.
  - `random`: randomly selects samples for each batch (with replacement).
  - `offline_easy2hard`: sorts samples by a predefined feature, for curriculum learning.
  - `difficulty_based` (custom): an adaptive sampler based on task difficulty.

> For more details on the `Task Selector`, including how to implement a custom selector based on feedback signals, see the **[Selector development guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/develop_selector.html)** of **Trinity-RFT**.
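To illustrate the idea behind adaptive, feedback-driven selection (this is a conceptual sketch, not the Trinity-RFT `Selector` API): tasks whose observed pass rate is closest to some target reward receive the highest sampling weight, so the batch concentrates on tasks of useful difficulty.

```python
import random

# Conceptual sketch of difficulty-aware sampling, NOT the Trinity-RFT API.
# Weight = 1 - distance between observed pass rate and a target pass rate.

def selection_weights(pass_rates, target=0.5):
    return [max(0.0, 1.0 - abs(p - target)) for p in pass_rates]

def sample_batch(task_ids, pass_rates, k, target=0.5, seed=0):
    rng = random.Random(seed)
    weights = selection_weights(pass_rates, target)
    # Weighted sampling with replacement, mirroring batch-wise selection.
    return rng.choices(task_ids, weights=weights, k=k)

tasks = ["easy", "medium", "hard"]
rates = [0.95, 0.5, 0.02]   # observed pass rates fed back by the pipeline
batch = sample_batch(tasks, rates, k=8)
print(batch)  # "medium" has the highest sampling weight (1.0 vs ~0.55 / ~0.52)
```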
#### Data Processor

The `Data Processor` enables on-the-fly processing of **tasks** and **experiences** during training, supporting operations such as computing feedback metrics, data augmentation, or filtering.

For example, the `difficulty_based` selector requires a `pass_rate_calculator` operator to compute the agent's success rate on each task. This feedback is then used to adjust the sampling strategy.

> For more details on the `Data Processor`, see the **[Operator development guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/develop_operator.html)** of **Trinity-RFT**.

### Configuring the Experiments

For clarity and simplicity, we recommend defining all data-specific parameters, including the dataset paths and the task selector, in a YAML configuration file.

We provide two configuration files to compare the baseline `random` selector against the `difficulty_based` selector.

**Experiment 1: baseline with the random selector (`config_random.yaml`)**

In `config_random.yaml`, we configure a `task_selector` for random sampling under `buffer.explorer_input.taskset`.

```yaml
# In config_random.yaml
buffer:
  # ...
  explorer_input:
    taskset:  # training data
      path: "path/to/your/augmented/math_data"
      split: "train"
      task_selector:
        selector_type: random  # task selection strategy
```

**Experiment 2: advanced training with the difficulty-based selector (`config_difficulty.yaml`)**

In `config_difficulty.yaml`, we switch the `task_selector` to `difficulty_based` and provide its specific parameters. Note that this configuration also enables the `pass_rate_calculator` required for feedback.

```yaml
# In config_difficulty.yaml

# Enable the calculator that provides feedback for the selector
data_processor:
  experience_pipeline:
    operators:
      - name: pass_rate_calculator

buffer:
  # ...
  explorer_input:
    taskset:  # training data
      path: "path/to/your/augmented/math_data"
      split: "train"
      task_selector:
        selector_type: difficulty_based  # task selection strategy
        feature_keys: [ "qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate" ]
        kwargs:  # hyperparameters of the selection algorithm
          m: 8
          # ...
```

> The `difficulty_based` selector in this example is an implementation of the ***BOTS*** algorithm. For details of its inner workings, see the [***BOTS paper***](https://arxiv.org/abs/2510.26374) and its [***tutorial***](https://github.com/agentscope-ai/Trinity-RFT/blob/main/examples/bots/README.md).
## How to Run

### Step 1: Prerequisites

Make sure you have installed **AgentScope** and **Trinity-RFT** following the [guide](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/README_zh.md#how-to-run).

### Step 2: Prepare the Dataset

Run the data preparation script. Be sure to update the dataset paths in `config_random.yaml` and `config_difficulty.yaml` afterwards.

```bash
python prepare_data.py
```

### Step 3: Launch a Ray Cluster

For distributed training, launch a Ray cluster.

```bash
# single node
ray start --head
```

### Step 4: Run Training

You can now run the baseline or the difficulty-based training experiment.

- **Run the baseline experiment with the random selector:**

  ```bash
  python main.py --config config_random.yaml
  ```

- **Run the experiment with the difficulty-based selector:**

  ```bash
  python main.py --config config_difficulty.yaml
  ```

## Experimental Results

The following results compare the performance of the `difficulty_based` selection strategy (red line, bots) against the standard `random` selection strategy (black line, random).

<div align="center">
  <img src="./training_result.jpg" alt="Training result chart" width="90%"/>
</div>

### Training Reward Curves

The chart on the left shows the rollout accuracy during training. As can be seen, the tasks sampled by the random strategy appear to be difficult for the model, with the accuracy remaining below 0.2. In contrast, using the difficulty selector yields a higher mean accuracy, indicating that the agent is engaging with more tasks it can successfully solve.

### Evaluation on AIME-24

For comparison, we evaluated both selection strategies on the AIME-24 benchmark. The chart on the right shows that the difficulty-based method demonstrates a better upward trend in performance over time.
74
tuner/data_augment/config_difficulty.yaml
Normal file
@@ -0,0 +1,74 @@
project: "Data-Augmentation" # Project name
name: "Difficulty-Based-Selector" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints

data_processor:
  experience_pipeline:
    operators:
      - name: pass_rate_calculator # Calculate average reward and pass it back to selector

buffer:
  total_epochs: 1 # Total training epochs
  explorer_input:
    taskset:
      path: "path/to/your/augmented/math_data" # Training data path
      split: "train" # Training data split
      task_selector:
        selector_type: difficulty_based # Strategy of task selection
        feature_keys: [ "qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate" ] # Utilized pass_rate keys
        kwargs: # Hyperparameters from BOTS (https://github.com/modelscope/Trinity-RFT/blob/main/examples/bots/README.md)
          m: 8
          lamb: 0.1
          rho: 0.1
          target_reward: 0.8
          tau: 0
          do_sample: true
    eval_tasksets:
      - name: "eval-aime24" # Evaluation data name
        path: "path/to/aime24_data" # Evaluation data path
        split: "test" # Evaluation data split

synchronizer:
  sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
  sync_method: 'nccl'
  sync_interval: 4 # Sync every N steps
  sync_timeout: 7200 # Timeout for synchronization (seconds)

monitor:
  monitor_type: tensorboard # Can also use wandb, mlflow or swanlab

# The config below has been set in the Python file

algorithm:
  algorithm_type: multi_step_grpo # GRPO series for multi-step scenarios
  repeat_times: 8 # Number of rollouts per prompt for advantage estimation
  optimizer:
    lr: 1e-6 # Learning rate

model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-0.6B} # Base model path
  max_model_len: 24576 # Max context length
  max_response_tokens: 16384 # Max tokens per response
  temperature: 1.0 # Temperature of model's generation

cluster:
  node_num: 1 # Number of used nodes
  gpu_per_node: 8 # Number of GPUs per node

explorer:
  eval_interval: 20 # Evaluate every N steps
  runner_per_model: 16 # Runners per inference engine
  max_timeout: 1200 # Max timeout for each rollout (seconds)
  rollout_model:
    engine_num: 4 # Number of vLLM engines for the rollout model
    tensor_parallel_size: 1 # TP size per engine for the rollout model
    enable_openai_api: true # Enable OpenAI-compatible API
    enable_history: true # Enable conversation history
    enable_auto_tool_choice: true # Enable automatic tool selection
    tool_call_parser: hermes # Parser for tool calls
    reasoning_parser: deepseek_r1 # Parser for reasoning type

trainer:
  save_interval: 100 # Save checkpoint every N steps
  use_dynamic_bsz: true # Use dynamic batch size
  ulysses_sequence_parallel_size: 1 # Sequence parallel size for Ulysses
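The nesting of the keys above matters when you edit the files by hand. The sketch below mirrors the `buffer.explorer_input.taskset.task_selector` hierarchy as a plain Python dict and checks a dotted key path, which is a quick way to sanity-check the structure; the helper is illustrative and not part of Trinity-RFT.

```python
# Illustrative: verify the config nesting used by both YAML files.

def get_path(cfg: dict, dotted: str):
    """Walk a dotted key path; raises KeyError if the path is wrong."""
    node = cfg
    for key in dotted.split("."):
        node = node[key]
    return node

# Mirrors the YAML hierarchy above.
config = {
    "buffer": {
        "explorer_input": {
            "taskset": {
                "path": "path/to/your/augmented/math_data",
                "split": "train",
                "task_selector": {"selector_type": "difficulty_based"},
            },
        },
    },
}

selector = get_path(config, "buffer.explorer_input.taskset.task_selector")
print(selector["selector_type"])  # difficulty_based
```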
62
tuner/data_augment/config_random.yaml
Normal file
@@ -0,0 +1,62 @@
project: "Data-Augmentation" # Project name
name: "Random-Selector" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints

# Config of data-centric experiments
buffer:
  total_epochs: 1 # Total training epochs
  explorer_input:
    taskset:
      path: "path/to/your/augmented/math_data" # Training data path
      split: "train" # Training data split
      task_selector:
        selector_type: random # Strategy of task selection
    eval_tasksets:
      - name: "eval-aime24" # Evaluation data name
        path: "path/to/aime24_data" # Evaluation data path
        split: "test" # Evaluation data split

synchronizer:
  sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
  sync_method: 'nccl'
  sync_interval: 4 # Sync every N steps
  sync_timeout: 7200 # Timeout for synchronization (seconds)

monitor:
  monitor_type: tensorboard # Can also use wandb, mlflow or swanlab

# The config below has been set in the Python file

algorithm:
  algorithm_type: multi_step_grpo # GRPO series for multi-step scenarios
  repeat_times: 8 # Number of rollouts per prompt for advantage estimation
  optimizer:
    lr: 1e-6 # Learning rate

model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-0.6B} # Base model path
  max_model_len: 24576 # Max context length
  max_response_tokens: 16384 # Max tokens per response
  temperature: 1.0 # Temperature of model's generation

cluster:
  node_num: 1 # Number of used nodes
  gpu_per_node: 8 # Number of GPUs per node

explorer:
  eval_interval: 20 # Evaluate every N steps
  runner_per_model: 16 # Runners per inference engine
  max_timeout: 1200 # Max timeout for each rollout (seconds)
  rollout_model:
    engine_num: 4 # Number of vLLM engines for the rollout model
    tensor_parallel_size: 1 # TP size per engine for the rollout model
    enable_openai_api: true # Enable OpenAI-compatible API
    enable_history: true # Enable conversation history
    enable_auto_tool_choice: true # Enable automatic tool selection
    tool_call_parser: hermes # Parser for tool calls
    reasoning_parser: deepseek_r1 # Parser for reasoning type

trainer:
  save_interval: 100 # Save checkpoint every N steps
  use_dynamic_bsz: true # Use dynamic batch size
  ulysses_sequence_parallel_size: 1 # Sequence parallel size for Ulysses
141
tuner/data_augment/main.py
Normal file
@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
"""Example of training a ReAct math-agent with a configurable task selector."""
import argparse
from typing import Dict

from agentscope.tuner import (
    tune,
    WorkflowOutput,
    JudgeOutput,
    TunerModelConfig,
    AlgorithmConfig,
)
from agentscope.agent import ReActAgent
from agentscope.model import OpenAIChatModel
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg


async def run_react_agent(
    task: Dict,
    model: OpenAIChatModel,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> WorkflowOutput:
    """A simple workflow function using the ReAct agent to solve tasks.

    Args:
        task (`Dict`): The task to be solved.
        model (`OpenAIChatModel`): The language model to use.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for
            LLM-as-a-Judge. Not used in this workflow.

    Returns:
        `WorkflowOutput`: The workflow output containing the agent's response.
    """
    assert (
        auxiliary_models is None or len(auxiliary_models) == 0
    ), "No auxiliary models are used in this workflow."

    sys_prompt = (
        "You are an agent specialized in solving math problems with tools. "
        "Please solve the math problem given to you. You can write and "
        "execute Python code to perform calculation or verify your answer. "
        "You should return your final answer within \\boxed{{}}."
    )
    agent = ReActAgent(
        name="react_agent",
        sys_prompt=sys_prompt,
        model=model,
        enable_meta_tool=True,
        formatter=OpenAIChatFormatter(),
    )
    response = await agent.reply(
        msg=Msg("user", task["question"], role="user"),
    )
    return WorkflowOutput(
        response=response,
    )


async def gsm8k_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> JudgeOutput:
    """A simple judge function to calculate reward based on agent's response.

    Args:
        task (`Dict`): The task information for the corresponding workflow.
        response (`Msg`): The response generated by the corresponding workflow.
        auxiliary_models (`Dict[str, OpenAIChatModel]`):
            A dictionary of additional chat models available for
            LLM-as-a-Judge usage. The keys are model names, and the values
            are the corresponding OpenAIChatModel instances.

    Returns:
        `JudgeOutput`: The reward value assigned by the judge function.
    """
    from trinity.common.rewards.math_reward import MathBoxedRewardFn

    assert (
        auxiliary_models is None or len(auxiliary_models) == 0
    ), "No auxiliary models are used in this workflow."

    reward_fn = MathBoxedRewardFn()
    # Parse the ground truth from the GSM8K-style raw text
    truth = task["answer"]
    if isinstance(truth, str) and "####" in truth:
        truth = truth.split("####")[1].strip()
    else:
        truth = str(truth)
    # Parse the answer from the response message
    result = response.get_text_content()
    reward_dict = reward_fn(
        response=result,
        truth=truth,
    )
    return JudgeOutput(
        reward=sum(reward_dict.values()),
        metrics=reward_dict,
    )


if __name__ == "__main__":
    # We recommend using YAML for data-centric experiments.
    parser = argparse.ArgumentParser(
        description="Train math-agent with different task selectors",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="config_random.yaml",
        help="Path to the configuration YAML file",
    )
    args = parser.parse_args()

    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen3-0.6B",
        max_model_len=24576,
        max_tokens=16384,
        temperature=1.0,
        inference_engine_num=4,
        tensor_parallel_size=1,
    )

    algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        learning_rate=1e-6,
        eval_interval_steps=20,
        batch_size=16,
    )

    tune(
        workflow_func=run_react_agent,
        judge_func=gsm8k_judge,
        config_path=args.config,
        model=tuner_model,
        algorithm=algorithm,
    )
148
tuner/data_augment/prepare_data.py
Normal file
@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
"""
Prepare math data from LLM360/guru-RL-92k
and transfer it to the GSM8K format.
"""

import argparse
import sys
from pathlib import Path

import pandas as pd
from huggingface_hub import hf_hub_download

# Define constants for the dataset
DATASET_REPO = "LLM360/guru-RL-92k"
DATASET_FILE = "train/math__combined_54.4k.parquet"


# Download the dataset from Hugging Face Hub.
# The dataset is from LLM360/guru-RL-92k.
def download_dataset(
    repo_id: str,
    filename_in_repo: str,
    local_dir: str,
) -> Path:
    print(f"--- Downloading dataset: {repo_id} ---")
    print(f"File: {filename_in_repo}")

    local_path = Path(local_dir)
    local_path.mkdir(parents=True, exist_ok=True)

    try:
        downloaded_file_path = hf_hub_download(
            repo_id=repo_id,
            filename=filename_in_repo,
            repo_type="dataset",
            local_dir=local_path,
        )
        print(f"Successfully downloaded to: {downloaded_file_path}")
        return Path(downloaded_file_path)
    except Exception as e:
        print(f"Error downloading dataset: {e}", file=sys.stderr)
        sys.exit(1)


# Transform a single row from the original format to the target format.
def transform_row(row: pd.Series) -> pd.Series:
    try:
        original_question = row["prompt"][0]["content"]
        sentence_to_remove = "Please output the final answer within \\boxed{}."
        question = original_question.replace(sentence_to_remove, "").strip()

        ground_truth = row["reward_model"]["ground_truth"]
        answer = f"#### {ground_truth}"

        rate_7b = row.get("qwen2.5_7b_pass_rate")
        rate_30b = row.get("qwen3_30b_pass_rate")

        return pd.Series(
            {
                "question": question,
                "answer": answer,
                "qwen2.5_7b_pass_rate": rate_7b,
                "qwen3_30b_pass_rate": rate_30b,
            },
        )
    except (TypeError, IndexError, KeyError) as e:
        error_msg = (
            f"Skipping row due to processing error: {e}. "
            f"Row content: {row.to_dict()}"
        )
        print(error_msg, file=sys.stderr)
        return pd.Series(
            {
                "question": None,
                "answer": None,
                "qwen2.5_7b_pass_rate": None,
                "qwen3_30b_pass_rate": None,
            },
        )


# Read, transform, and save the dataset to a new location.
def transform_and_save_dataset(input_file: Path, output_dir: str):
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    output_file_path = output_path / input_file.name

    print(f"--- Reading source file: {input_file} ---")
    try:
        df_original = pd.read_parquet(input_file)
        print(f"Successfully read {len(df_original)} records.")
    except Exception as e:
        print(f"Fatal error reading file: {e}", file=sys.stderr)
        sys.exit(1)

    print("--- Starting data transformation ---")
    df_transformed = df_original.apply(transform_row, axis=1)

    original_count = len(df_transformed)
    df_transformed.dropna(subset=["question", "answer"], inplace=True)
    dropped_count = original_count - len(df_transformed)
    if dropped_count > 0:
        print(f"Warning: Dropped {dropped_count} invalid records.")

    print(f"Transformation complete. {len(df_transformed)} records generated.")

    print(f"--- Saving processed file to: {output_file_path} ---")
    try:
        df_transformed.to_parquet(output_file_path, index=False)
        print(f"Process complete! New file saved at: {output_file_path}")
    except Exception as e:
        print(f"Fatal error saving file: {e}", file=sys.stderr)
        sys.exit(1)


def main():
    parser = argparse.ArgumentParser(
        description="Download and transform the guru-RL-92k math dataset.",
    )
    parser.add_argument(
        "--raw_data_dir",
        type=str,
        default="data/train/raw",
        help="Directory to download the raw dataset file.",
    )
    parser.add_argument(
        "--processed_data_dir",
        type=str,
        default="data/train/math",
        help="Directory to save the transformed dataset file.",
    )

    args = parser.parse_args()

    downloaded_file = download_dataset(
        repo_id=DATASET_REPO,
        filename_in_repo=DATASET_FILE,
        local_dir=args.raw_data_dir,
    )

    transform_and_save_dataset(
        input_file=downloaded_file,
        output_dir=args.processed_data_dir,
    )


if __name__ == "__main__":
    main()
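The per-row transformation in `transform_row` can be exercised without pandas or a download. This dict-based re-implementation keeps the same field names as the script; the sample record is synthetic.

```python
# Dict-based sketch of transform_row, for trying the mapping in isolation.

SENTENCE_TO_REMOVE = "Please output the final answer within \\boxed{}."

def transform_record(record: dict) -> dict:
    """Map a guru-RL-92k-style record to the GSM8K-style schema."""
    question = record["prompt"][0]["content"].replace(SENTENCE_TO_REMOVE, "").strip()
    answer = f"#### {record['reward_model']['ground_truth']}"
    return {
        "question": question,
        "answer": answer,
        "qwen2.5_7b_pass_rate": record.get("qwen2.5_7b_pass_rate"),
        "qwen3_30b_pass_rate": record.get("qwen3_30b_pass_rate"),
    }

raw = {  # synthetic example record
    "prompt": [{"content": "What is 2 + 2? Please output the final answer within \\boxed{}."}],
    "reward_model": {"ground_truth": "4"},
    "qwen2.5_7b_pass_rate": 0.9,
}
out = transform_record(raw)
print(out["question"])  # What is 2 + 2?
print(out["answer"])    # #### 4
```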
BIN
tuner/data_augment/training_result.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 315 KiB |
283
tuner/email_search/README.md
Normal file
@@ -0,0 +1,283 @@
# Training an Email Search Agent with RL using AgentScope-Tuner

This example demonstrates how to implement reinforcement fine-tuning for the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).

## Task Setting

The agent's goal is to answer user queries by searching through an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email contents to extract information
- Provide accurate answers with proper source citations

**Agent Type**: The agent (`EmailSearchAgent`) extends `ReActAgent`, which follows a reasoning-acting loop to solve tasks iteratively.

**Environment**: The environment is a SQLite database containing emails from the Enron Email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context for the query
- `answer`: The expected answer (ground truth), used only for reward calculation
- `message_ids`: IDs of the relevant emails containing the answer, used only for reward calculation

**Available Tools**:
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from the ReAct agent).
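To make the tool contract concrete, here are toy in-memory stand-ins for the two search tools. The tool names follow this README; the signatures, snippet length, and sample emails are illustrative assumptions, and the real versions in `email_search_agent.py` query the SQLite database.

```python
from datetime import date

# Toy in-memory inbox standing in for the SQLite email database.
INBOX = [
    {"message_id": "<m1>", "to": "pete.davis@enron.com",
     "date": date(2001, 3, 9), "body": "Variances detected for hour 6."},
    {"message_id": "<m2>", "to": "pete.davis@enron.com",
     "date": date(2001, 3, 12), "body": "Schedule confirmed, no issues."},
]

def search_emails(inbox_address: str, keywords: list[str], query_date: date):
    """Return (message_id, snippet) summaries matching all keywords."""
    hits = []
    for mail in INBOX:
        if mail["to"] != inbox_address or mail["date"] > query_date:
            continue  # only mail already in this inbox at query time
        if all(k.lower() in mail["body"].lower() for k in keywords):
            hits.append({"message_id": mail["message_id"],
                         "snippet": mail["body"][:40]})
    return hits

def read_email(message_id: str):
    """Return the full content of one email, or None if not found."""
    for mail in INBOX:
        if mail["message_id"] == message_id:
            return mail
    return None

results = search_emails("pete.davis@enron.com", ["variances"], date(2001, 3, 16))
print([r["message_id"] for r in results])  # ['<m1>']
```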
## Dataset Preparation

The dataset contains email queries based on the [Enron Email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:

```bash
python prepare_data.py
```

If you want to choose a new database path, you can modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also, remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database path before moving to the next step:

```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```

This will create a SQLite database and datasets:

```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db    # Email database
├── train.parquet          # Training samples
└── test.parquet           # Test samples
```

Each sample looks like:

```json
{
  "id": 0,
  "question": "Were there any variances detected for hour 6 on 3/9/01?",
  "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
  "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
  "how_realistic": 0.800000011920929,
  "inbox_address": "pete.davis@enron.com",
  "query_date": "2001-03-16"
}
```

## Code Implementation

This section provides a high-level overview of the code implementation. For details, please refer to the source code.

### Agent Workflow

The workflow function `run_email_search_agent` implements the agent-environment interaction loop:

```python
async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # Parse task and create agent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        max_iters=max_turns,
    )

    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )

    return WorkflowOutput(response=response)
```

The agent follows a ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and source message IDs.
### Judge Function

The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:

```python
async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    # Extract answer and sources from the response
    answer = answer_and_sources.get("answer")
    sources = answer_and_sources.get("sources", [])

    # Judge correctness using LLM-as-a-Judge
    judge_model = auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
    judge_response = await judge_correctness(
        answer, query, judge_model
    )

    # Calculate reward based on:
    # - Answer correctness (accuracy: -1.0 to 1.0)
    # - Source correctness (format: partial rewards)
    # - Efficiency (bonus for fewer turns, correct sources)
    result = {"accuracy": ..., "format": ...}  # calculated based on judge_response

    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )
```

The reward function considers:
- **Answer correctness**: Evaluated by LLM-as-a-Judge, comparing the agent's answer with the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: Bonus rewards for finding/reading the correct email and taking fewer turns
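One way the three components could be combined is sketched below. The weights and bonus values here are illustrative assumptions; the actual numbers and formula live in `email_search_judge` in `main.py`.

```python
# Illustrative reward combination; weights are assumptions, not the
# values used by email_search_judge.

def combine_reward(accuracy: float, cited: set, truth_ids: set,
                   turns: int, max_turns: int) -> float:
    # Accuracy in [-1, 1] comes from the LLM-as-a-Judge verdict.
    reward = accuracy
    # Partial credit for citing the correct source emails.
    if truth_ids:
        reward += 0.5 * len(cited & truth_ids) / len(truth_ids)
    # Small efficiency bonus for correct answers found in fewer turns.
    if accuracy > 0 and max_turns > 0:
        reward += 0.1 * (1 - turns / max_turns)
    return reward

r = combine_reward(
    accuracy=1.0,
    cited={"<m1>"},
    truth_ids={"<m1>"},
    turns=4,
    max_turns=10,
)
print(round(r, 2))  # 1.56
```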
|
|
||||||
|
See [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py) for implementation details.
|
||||||
|
|
||||||
|
## How to Run
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
|
||||||
|
* Note: For the 30B Judge model, you need to use a GPU with at least 4080 memory; you can also run the model on multiple GPUs by using `tensor_parallel_size > 1` to reduce the memory usage (by default, `tensor_parallel_size=2`).
|
||||||
|
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
|
||||||
|
- Download the model checkpoint (example):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
|
||||||
|
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # judge model
|
||||||
|
```
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
Adjust the configuration file ([`config.yaml`](./config.yaml)) based on your hardware. Key configuration sections include:
|
||||||
|
|
||||||
|
- **TunerModelConfig**: Set `model_path` to your model checkpoint path
|
||||||
|
- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
|
||||||
|
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
|
||||||
|
- **Auxiliary Models**: Configure judge model settings for LLM-as-a-Judge
|
||||||
|
|
||||||
|
For full configuration details, see [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
|
||||||
|
|
||||||
|
### Start-Up Commands
|
||||||
|
|
||||||
|
1. Prepare the dataset:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python prepare_data.py
|
||||||
|
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Set up a [Ray](https://github.com/ray-project/ray) cluster:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ray start --head
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run the training script:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python main.py
|
||||||
|
```
## Experimental Results

### Quantitative Results

The training results show improvements in agent performance over training iterations. Key metrics include:

- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples increases as the agent learns better strategies

<div align="center">
<img src="./critic_reward_mean.png" alt="Training Rewards" width="90%"/>
</div>

<div align="center">
<img src="./rollout_accuracy_mean.png" alt="Rollout Accuracy" width="90%"/>
</div>

### Concrete Example

An example of the agent's behavior is shown below:

**Query:** "What do the color codes mean in the curve assessment?"

We show the last several turns of the agent's responses.

The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:

**Tool call:**

```json
{
  "type": "tool_use",
  "name": "search_emails",
  "input": {
    "inbox_address": "steven.kean@enron.com",
    "query_date": "2001-04-03",
    "keywords": ["curve", "assessment"]
  }
}
```

**Tool result:**

```json
{
  "type": "tool_result",
  "name": "search_emails",
  "output": [
    {
      "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
      "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
    },
    {
      "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
      "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
    }
  ]
}
```

After finding relevant emails, the agent uses the `read_email` tool to read the full content of both emails:

**Tool call:**

```json
{
  "type": "tool_use",
  "name": "read_email",
  "input": {
    "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
  }
}
```

**Tool call:**

```json
{
  "type": "tool_use",
  "name": "read_email",
  "input": {
    "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
  }
}
```

After reading the emails, the agent generates the final structured response:

**Tool call:**

```json
{
  "type": "tool_use",
  "name": "generate_response",
  "input": {
    "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
    "sources": [
      "<2654330.1075846153519.JavaMail.evans@thyme>",
      "<12499967.1075846153495.JavaMail.evans@thyme>"
    ]
  }
}
```

The judge evaluates the answer as correct based on the ground truth.
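The correctness check boils down to one comparison prompt sent to the judge model. A minimal sketch of what such a prompt might look like (the wording and helper name are illustrative assumptions, not the sample's actual implementation):

```python
def build_judge_prompt(question: str, agent_answer: str, ground_truth: str) -> str:
    """Assemble an LLM-as-a-Judge prompt comparing the agent's answer to the ground truth."""
    return (
        "You are grading an email-search agent.\n"
        f"Question: {question}\n"
        f"Reference answer: {ground_truth}\n"
        f"Agent answer: {agent_answer}\n"
        "Reply with exactly one word: CORRECT or INCORRECT."
    )

# Build the prompt for the concrete example above
prompt = build_judge_prompt(
    "What do the color codes mean in the curve assessment?",
    "Green means the curve is reasonable and the P&L is small.",
    "Green: curve is reasonable, small P&L.",
)
```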
283 tuner/email_search/README_zh.md Normal file
@@ -0,0 +1,283 @@
# Training an Email Search Agent with AgentScope-Tuner

This example demonstrates how to use AgentScope-Tuner for reinforcement fine-tuning on an email search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)), with RFT functionality powered by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).

## Task Setup

The agent's goal is to answer user queries by searching an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email content to extract information
- Provide an accurate answer with appropriate source citations

**Agent type**: The agent (`EmailSearchAgent`) inherits from `ReActAgent` and follows a reasoning-acting loop to solve the task iteratively.

**Environment**: The environment is a SQLite database containing emails from the Enron email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context of the query
- `answer`: The expected answer (ground truth), used only for reward computation
- `message_ids`: The IDs of the relevant emails containing the answer, used only for reward computation

**Available tools**:
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from the ReAct agent).

## Dataset Preparation

The dataset contains email queries based on the [Enron email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and dataset:

```bash
python prepare_data.py
```

If you want to choose a new database path, modify `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also, remember to set the environment variable `DEFAULT_EMAIL_DB_PATH` to the database path before proceeding to the next step:

```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```

This creates a SQLite database and dataset:

```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db  # Email database
├── train.parquet        # Training samples
└── test.parquet         # Test samples
```

Each sample looks like this:

```json
{
  "id": 0,
  "question": "Were there any variances detected for hour 6 on 3/9/01?",
  "answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
  "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
  "how_realistic": 0.800000011920929,
  "inbox_address": "pete.davis@enron.com",
  "query_date": "2001-03-16"
}
```
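Before training, it can be useful to confirm that records carry exactly these fields with the expected types; a minimal sketch (the field list is taken from the sample record above, the helper is ours):

```python
REQUIRED_FIELDS = {
    "id": int,
    "question": str,
    "answer": str,
    "message_ids": list,
    "how_realistic": float,
    "inbox_address": str,
    "query_date": str,
}

def validate_sample(sample: dict) -> bool:
    """Check that a dataset record carries every expected field with the right type."""
    return all(
        name in sample and isinstance(sample[name], typ)
        for name, typ in REQUIRED_FIELDS.items()
    )

sample = {
    "id": 0,
    "question": "Were there any variances detected for hour 6 on 3/9/01?",
    "answer": "Yes, variances were detected in both Generation and Energy "
              "Import/Export schedules for hour 6 on 3/9/01.",
    "message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
    "how_realistic": 0.800000011920929,
    "inbox_address": "pete.davis@enron.com",
    "query_date": "2001-03-16",
}
```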
## Code Implementation

This section provides a high-level overview of the code implementation. Refer to the source code for details.

### Agent Workflow

The workflow function `run_email_search_agent` implements the agent-environment interaction loop:

```python
async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # Parse the task and create the agent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        max_iters=max_turns,
    )

    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )

    return WorkflowOutput(response=response)
```

The agent follows the ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and the source message IDs.

### Judge Function

The judge function `email_search_judge` implements reward computation using LLM-as-a-Judge:

```python
async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    # Extract the answer and sources from the response
    answer = answer_and_sources.get("answer")
    sources = answer_and_sources.get("sources", [])

    # Judge correctness with LLM-as-a-Judge
    judge_model = auxiliary_models.get('judge') or list(auxiliary_models.values())[0]
    judge_response = await judge_correctness(
        answer, query, judge_model
    )

    # Compute the reward based on:
    # - Answer correctness (accuracy: -1.0 to 1.0)
    # - Source correctness (format: partial reward)
    # - Efficiency (bonus for fewer turns and correct sources)
    result = {"accuracy": ..., "format": ...}  # computed from judge_response

    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )
```

The reward function takes the following factors into account:
- **Answer correctness**: Evaluated by LLM-as-a-Judge, comparing the agent's answer against the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: A bonus for finding/reading the correct emails and using fewer turns

For implementation details, refer to [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py).
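The reward factors can be pictured as a plain sum of three components; the weights and clamping below are illustrative assumptions, not the sample's actual numbers:

```python
def combine_reward(
    accuracy: float,        # LLM-judge correctness score, in [-1.0, 1.0]
    sources_correct: bool,  # whether the right message IDs were cited
    num_turns: int,
    max_turns: int = 10,
) -> float:
    """Sum answer, source, and efficiency components into one scalar reward."""
    source_bonus = 0.2 if sources_correct else 0.0  # assumed partial reward
    # Fewer turns used -> larger efficiency bonus (assumed linear schedule)
    efficiency = 0.1 * max(0, max_turns - num_turns) / max_turns
    return accuracy + source_bonus + efficiency
```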
## How to Run

### Prerequisites

- At least 4 NVIDIA GPUs with CUDA version ≥ 12.8
  * Note: the 30B judge model requires GPUs with at least 40/80 GB of memory; you can also run the model across multiple GPUs with `tensor_parallel_size > 1` to reduce per-GPU memory usage (by default, `tensor_parallel_size=2`).
- Install the latest version of Trinity-RFT from source following the [installation guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
- Download the model checkpoints (example):

```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507  # judge model
```

### Configuration

Adjust the configuration file ([`config.yaml`](./config.yaml)) for your hardware. Key configuration sections include:

- **TunerModelConfig**: Set `model_path` to the path of your model checkpoint
- **AlgorithmConfig**: Configure the RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
- **Auxiliary Models**: Configure judge model settings for LLM-as-a-Judge

For full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html).

### Start-Up Commands

1. Prepare the dataset:

```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```

2. Start a [Ray](https://github.com/ray-project/ray) cluster:

```bash
ray start --head
```

3. Run the training script:

```bash
python main.py
```

## Experimental Results

### Quantitative Results

The training results show improvements in agent performance over training iterations. Key metrics include:

- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples increases as the agent learns better strategies

<div align="center">
<img src="./critic_reward_mean.png" alt="Training Rewards" width="90%"/>
</div>

<div align="center">
<img src="./rollout_accuracy_mean.png" alt="Rollout Accuracy" width="90%"/>
</div>

### Concrete Example

An example of the agent's behavior is shown below:

**Query:** "What do the color codes mean in the curve assessment?"

We show the last several turns of the agent's responses.

The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:

**Tool call:**

```json
{
  "type": "tool_use",
  "name": "search_emails",
  "input": {
    "inbox_address": "steven.kean@enron.com",
    "query_date": "2001-04-03",
    "keywords": ["curve", "assessment"]
  }
}
```

**Tool result:**

```json
{
  "type": "tool_result",
  "name": "search_emails",
  "output": [
    {
      "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
      "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
    },
    {
      "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
      "snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
    }
  ]
}
```

After finding relevant emails, the agent uses the `read_email` tool to read the full content of both emails:

**Tool call:**

```json
{
  "type": "tool_use",
  "name": "read_email",
  "input": {
    "message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
  }
}
```

**Tool call:**

```json
{
  "type": "tool_use",
  "name": "read_email",
  "input": {
    "message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
  }
}
```

After reading the emails, the agent generates the final structured response:

**Tool call:**

```json
{
  "type": "tool_use",
  "name": "generate_response",
  "input": {
    "answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
    "sources": [
      "<2654330.1075846153519.JavaMail.evans@thyme>",
      "<12499967.1075846153495.JavaMail.evans@thyme>"
    ]
  }
}
```

The judge evaluates the answer above as correct.
175 tuner/email_search/_email_search_agent.py Normal file
@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import json
import traceback
from dataclasses import asdict
from datetime import datetime, timedelta
from typing import Any

from _utils import (  # pylint: disable=E0611
    read_email_tool,
    search_emails_tool,
)

from agentscope import logger
from agentscope.agent import ReActAgent
from agentscope.message import TextBlock
from agentscope.tool import Toolkit, ToolResponse


def pre_reasoning_hook(_self: Any, _kwargs: Any) -> dict[str, Any] | None:
    """Pre-reasoning hook to remove tool_choice from kwargs."""
    _kwargs.pop("tool_choice", None)
    return _kwargs


class EmailSearchAgent(ReActAgent):
    """
    A customized ReAct agent with pre-defined tools for
    email search and reading.
    Ref: https://github.com/OpenPipe/ART
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        self.message_id_list = (
            []
        )  # List to store message IDs found during search
        self.ever_read_message_ids = (
            []
        )  # List to store message IDs that have been read
        toolkit = Toolkit()
        toolkit.register_tool_function(self.search_emails)
        toolkit.register_tool_function(self.read_email)
        super().__init__(*args, toolkit=toolkit, **kwargs)

        self.register_instance_hook(
            "pre_reasoning",
            "tool_choice_hook",
            pre_reasoning_hook,
        )

    async def reset(self) -> None:
        """Reset agent state for a new rollout/episode."""
        self.message_id_list.clear()
        self.ever_read_message_ids.clear()
        await self.memory.clear()

    def search_emails(
        self,
        inbox_address: str,
        query_date: str,
        keywords: list[str],
        **_kwargs: Any,
    ) -> ToolResponse:
        """
        Search the user's email inbox for emails that match the given keywords.

        Args:
            inbox_address: The user's email address.
            query_date: The date of the query in 'YYYY-MM-DD' format.
            keywords: Keywords to search for in the user's email inbox.

        Returns:
            ToolResponse:
                A ToolResponse object containing a list of TextBlock objects
                in the `content` field. On success, the text field of the
                TextBlock contains a JSON string representing a list of email
                summaries (e.g., message_id, snippet) matching the search
                criteria. Each email summary is converted to a dictionary via
                `asdict`. On failure, the text indicates an error message.
        """
        try:
            next_day = (
                datetime.strptime(query_date, "%Y-%m-%d") + timedelta(days=1)
            ).strftime(
                "%Y-%m-%d",
            )
            res = search_emails_tool(
                inbox=inbox_address,
                sent_before=next_day,
                keywords=keywords,
            )

            self.message_id_list.extend([r.message_id for r in res])

            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=json.dumps([asdict(r) for r in res]),
                    ),
                ],
            )
        except Exception as e:
            logger.info(
                "Error in search_emails: %s, traceback: %s",
                e,
                traceback.format_exc(),
            )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=(
                            f"Error: Failed to search emails.\n"
                            f"Error message: {e}"
                        ),
                    ),
                ],
            )

    def read_email(self, message_id: str, **_kwargs: Any) -> ToolResponse:
        """
        Read the content of an email from the user's email inbox.
        Returns the email content.

        Args:
            message_id (str): The unique identifier of the email to read.

        Returns:
            ToolResponse:
                A ToolResponse object containing the email content or an
                error message if the email is not found.
        """
        try:
            email_content = read_email_tool(message_id)

            self.ever_read_message_ids.append(message_id)

            if email_content is None:
                return ToolResponse(
                    content=[
                        TextBlock(
                            type="text",
                            text=(
                                f"Error: Email (message_id = {message_id}) "
                                f"not found."
                            ),
                        ),
                    ],
                )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=json.dumps(email_content.model_dump()),
                    ),
                ],
            )
        except Exception as e:
            logger.info(
                "Error in read_email: %s, traceback: %s",
                e,
                traceback.format_exc(),
            )
            return ToolResponse(
                content=[
                    TextBlock(
                        type="text",
                        text=(
                            f"Error: Failed to read email.\n"
                            f"Error message: {e}"
                        ),
                    ),
                ],
            )
328 tuner/email_search/_utils.py Normal file
@@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
"""
This file defines Dataclass and tool implementations.
Modified from https://github.com/OpenPipe/ART/blob/art-e/
"""
import datetime
import os
import sqlite3
from dataclasses import dataclass
from typing import Any, List, Optional

from pydantic import BaseModel, Field, field_validator

from agentscope import logger

DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH")
conn = None


def get_conn() -> sqlite3.Connection:
    """Get or create a database connection."""
    global conn
    if conn is None:
        conn = sqlite3.connect(
            f"file:{DEFAULT_DB_PATH}?mode=ro",
            uri=True,
            check_same_thread=False,
        )
    return conn


class QueryModel(BaseModel):
    """Model for email search query."""

    id: int
    question: str
    answer: str
    message_ids: List[str]  # message_ids (strings) of referenced emails
    how_realistic: float
    inbox_address: str
    query_date: str

    @field_validator("query_date", mode="before")
    @classmethod
    def format_date(cls, v: Any) -> str:
        """Format date to string if it's a datetime object."""
        if isinstance(v, datetime.datetime):
            return v.strftime("%Y-%m-%d")
        return v


class AnswerModel(BaseModel):
    """Model for agent's answer with sources."""

    answer: str = Field(
        description=(
            "It should be called with the answer and the sources. "
            "If you cannot find the answer, you should return "
            "'I don't know' with an empty list of sources."
        ),
    )
    sources: List[str] = Field(
        description=(
            "a list of message ids that are relevant to the query. "
            "Usually there will be only one. If you cannot find the "
            "answer, you should return an empty list."
        ),
    )


class Email(BaseModel):
    """Model representing an email."""

    message_id: str
    date: str  # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    subject: Optional[str] = None
    from_address: Optional[str] = None
    to_addresses: List[str] = Field(default_factory=list)
    cc_addresses: List[str] = Field(default_factory=list)
    bcc_addresses: List[str] = Field(default_factory=list)
    body: Optional[str] = None
    file_name: Optional[str] = None


@dataclass
class SearchResult:
    """Result from email search."""

    message_id: str
    snippet: str


class FinalRubric(BaseModel):
    """Rubric for evaluating agent performance."""

    answer_correct: bool = False
    sources_correct: bool = False
    num_turns: int = 0
    attempted_answer: bool = False
    ever_found_right_email: bool = False
    ever_read_right_email: bool = False
    cant_parse_tool_call: bool = False
    bad_tool_call_name: bool = False
    bad_tool_call_args: bool = False
    ran_out_of_turns: bool = False
    returned_i_dont_know: bool = False
    num_sources: int = 0
    ever_tried_to_read_invalid_email: bool = False
    prompt_tokens: int = 0
    completion_tokens: int = 0


# Define tools for agent


def search_emails_tool(
    inbox: str,
    keywords: List[str],
    from_addr: Optional[str] = None,
    to_addr: Optional[str] = None,
    sent_after: Optional[str] = None,
    sent_before: Optional[str] = None,
    max_results: int = 10,
) -> List[SearchResult]:
    """
    Searches the email database based on keywords, inbox,
    sender, recipient, and date range.

    Args:
        inbox: The email address of the user performing the search.
            Results include emails sent from or to (inc. cc/bcc)
            this address.
        keywords: A list of keywords that must all appear in the
            subject or body.
        from_addr: Optional email address to filter emails sent *from*.
        to_addr: Optional email address to filter emails sent *to*
            (inc. cc/bcc).
        sent_after: Optional date string 'YYYY-MM-DD'. Filters for
            emails sent on or after this date.
        sent_before: Optional date string 'YYYY-MM-DD'. Filters for
            emails sent before this date.
        max_results: The maximum number of results to return.
            Cannot exceed 10.

    Returns:
        A list of SearchResult objects, each containing 'message_id'
        and 'snippet'. Returns an empty list if no results are found
        or an error occurs.
    """
    # Initialize sql and params
    sql: Optional[str] = None
    params: List[str | int] = []

    cursor = get_conn().cursor()

    # --- Build Query ---
    where_clauses: List[str] = []

    # 1. Keywords (FTS)
    if not keywords:
        raise ValueError("No keywords provided for search.")

    if max_results > 10:
        raise ValueError("max_results must be less than or equal to 10.")

    # FTS5 default is AND, so just join keywords. Escape quotes for safety.
    fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
    where_clauses.append("emails_fts MATCH ?")
    params.append(fts_query)

    # 2. Inbox filter (must be from OR to/cc/bcc the inbox user)
    # Use the composite index idx_recipients_address_email here
    where_clauses.append(
        """
        (e.from_address = ? OR EXISTS (
            SELECT 1 FROM recipients r_inbox
            WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id
        ))
        """,
    )
    params.extend([inbox, inbox])

    # 3. Optional From filter
    if from_addr:
        where_clauses.append("e.from_address = ?")
        params.append(from_addr)

    # 4. Optional To filter (includes to, cc, bcc)
    # Use composite index idx_recipients_address_email
    if to_addr:
        where_clauses.append(
            """
            EXISTS (
                SELECT 1 FROM recipients r_to
                WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
            )
            """,
        )
        params.append(to_addr)

    # 5. Optional Sent After filter
    if sent_after:
        # Assumes date format 'YYYY-MM-DD'
        # Compare against the start of the day
        where_clauses.append("e.date >= ?")
        params.append(f"{sent_after} 00:00:00")

    # 6. Optional Sent Before filter
    if sent_before:
        # Assumes date format 'YYYY-MM-DD'
        # Compare against the start of the day (exclusive)
        where_clauses.append("e.date < ?")
        params.append(f"{sent_before} 00:00:00")

    # --- Construct Final Query ---
    # snippet(<table>, <column_index>, <highlight_start>,
    # <highlight_end>, <ellipsis>, <tokens>)
    # -1 means highlight across all columns (subject, body)
    sql = f"""
        SELECT
            e.message_id,
            snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
        FROM
            emails e JOIN emails_fts fts ON e.id = fts.rowid
        WHERE
            {" AND ".join(where_clauses)}
        ORDER BY
            e.date DESC -- Order by date for relevance
        LIMIT ?;
    """
    params.append(max_results)

    # --- Execute and Fetch ---
    logger.debug("Executing SQL: %s", sql)
    logger.debug("With params: %s", params)
    cursor.execute(sql, params)
    results = cursor.fetchall()

    # Format results
    formatted_results = [
        SearchResult(message_id=row[0], snippet=row[1]) for row in results
    ]
    logger.info("Search found %d results.", len(formatted_results))
    return formatted_results


def read_email_tool(message_id: str) -> Optional[Email]:
    """
    Retrieves a single email by its message_id from the database.

    Args:
        message_id: The unique identifier of the email to retrieve.

    Returns:
        An Email object containing the details of the found email,
        or None if the email is not found or an error occurs.
    """
    cursor = get_conn().cursor()

    # --- Query for Email Core Details ---
    email_sql = """
        SELECT id, message_id, date, subject, from_address, body, file_name
        FROM emails
        WHERE message_id = ?;
    """
    cursor.execute(email_sql, (message_id,))
    email_row = cursor.fetchone()

    if not email_row:
        logger.warning("Email with message_id '%s' not found.", message_id)
        return None

    email_pk_id, msg_id, date, subject, from_addr, body, file_name = email_row

    # DEBUG
    logger.info("[read_email_tool] input_message_id=%s", message_id)
    logger.info(
        "[read_email_tool] db: id=%s, message_id=%s",
        email_pk_id,
        msg_id,
    )

    # Search for recipients by emails.id (rather than message_id)
    recipients_sql = """
        SELECT recipient_address, recipient_type
        FROM recipients
        WHERE email_id = ?;
    """
    cursor.execute(recipients_sql, (email_pk_id,))
    recipient_rows = cursor.fetchall()

    to_addresses: List[str] = []
    cc_addresses: List[str] = []
    bcc_addresses: List[str] = []

    for addr, rtype in recipient_rows:
        type_lower = rtype.lower()
        if type_lower == "to":
            to_addresses.append(addr)
        elif type_lower == "cc":
            cc_addresses.append(addr)
        elif type_lower == "bcc":
            bcc_addresses.append(addr)

    # --- Construct Email Object ---
    email_obj = Email(
        message_id=msg_id,
        date=date,
        subject=subject,
        from_address=from_addr,
        to_addresses=to_addresses,
        cc_addresses=cc_addresses,
        bcc_addresses=bcc_addresses,
        body=body,
        file_name=file_name,
    )

    return email_obj


__all__ = [
    "QueryModel",
    "AnswerModel",
    "FinalRubric",
    "Email",
    "SearchResult",
    "search_emails_tool",
    "read_email_tool",
    "get_conn",
]
72 tuner/email_search/config.yaml Normal file
@@ -0,0 +1,72 @@
|
project: "AgentScope" # Project name
|
||||||
|
name: "Email_search" # Experiment name
|
||||||
|
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints
|
||||||
|
algorithm:
|
||||||
|
algorithm_type: multi_step_grpo # GRPO series for multi-step scenario
|
||||||
|
repeat_times: 8 # Number of rollouts per prompt for advantage estimation
|
||||||
|
optimizer:
|
||||||
|
lr: 1e-6 # Learning rate
|
||||||
|
policy_loss_fn: "rec" # Policy loss function
|
||||||
  policy_loss_fn_args:  # Policy loss function arguments
    epsilon_low: 0.2
    epsilon_high: 0.2
    clip_mode: "one-side"
    weight: "none"
    temp: 1.0
    regularizer: "none"
    regularizer_coef: 0.0
  kl_loss_fn: 'k2'  # KL divergence loss function
  kl_loss_fn_args:
    kl_coef: 0.0  # KL divergence coefficient
  advantage_fn_args:
    std_cal_level: 'batch'  # Advantage normalization level
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}  # Base model path
  max_response_tokens: 4096  # Max tokens per response
  max_model_len: 20480  # Max context length
buffer:
  total_epochs: 10  # Total training epochs
  batch_size: 64  # Batch size per explore step
  train_batch_size: 2560  # 64*8*5, total experiences per training step
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
      replay_buffer:
        enable: true  # Enable experience replay
        priority_fn: 'decay_limit_randomization'
        priority_fn_args:
          decay: 2.0
          use_count_limit: 3
          sigma: 2.0
explorer:
  eval_interval: 10
  max_repeat_times_per_runner: 1  # Max repeat times per runner
  max_timeout: 3600  # Max timeout for each rollout (seconds)
  rollout_model:
    enable_history: true  # Enable conversation history
    enable_openai_api: true  # Enable OpenAI-compatible API
    enable_auto_tool_choice: true  # Enable automatic tool selection
    tool_call_parser: hermes  # Parser for tool calls
    engine_num: 4  # Number of vLLM engines for rollout model
    tensor_parallel_size: 1  # TP size per engine for rollout model
    enable_prefix_caching: false  # Disable prefix caching
  auxiliary_models:
    - name: judge
      model_path: Qwen/Qwen3-30B-A3B-Instruct-2507  # Judge model path
      engine_num: 1  # Number of vLLM engines for judge model
      tensor_parallel_size: 2  # TP size per engine for judge model
      enable_thinking: false  # Disable thinking/reasoning mode
      max_prompt_tokens: 2048  # Max tokens for prompt
      max_response_tokens: 128  # Max tokens for response
      max_model_len: 2500  # Max model context length
synchronizer:
  sync_style: dynamic_by_explorer  # Sync triggered dynamically by explorer
  sync_interval: 5  # Sync every N steps
  sync_timeout: 3600  # Timeout for synchronization (seconds)
trainer:
  save_interval: 100  # Save checkpoint every N steps
  grad_clip: 1.0  # Gradient clipping value
  use_dynamic_bsz: true  # Use dynamic batch size
  max_token_len_per_gpu: 16384  # Max token length per GPU
  ulysses_sequence_parallel_size: 1  # Sequence parallel size for Ulysses
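The `${oc.env:...}` interpolation in `model_path` above is OmegaConf environment-variable syntax: the value of `TRINITY_MODEL_PATH` is used when the variable is set, otherwise the default after the comma. A stdlib-only mimic of that resolution rule (not the library's implementation, just a sketch of the behavior):

```python
import os
import re

def resolve_env(value: str) -> str:
    """Resolve ${oc.env:VAR,default} placeholders the way OmegaConf does."""
    pattern = re.compile(r"\$\{oc\.env:([A-Z_][A-Z0-9_]*),([^}]*)\}")

    def sub(match: re.Match) -> str:
        var, default = match.group(1), match.group(2)
        # Environment variable wins; the literal after the comma is the fallback.
        return os.environ.get(var, default)

    return pattern.sub(sub, value)

os.environ.pop("TRINITY_MODEL_PATH", None)  # unset -> default is used
print(resolve_env("${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}"))
# -> Qwen/Qwen3-4B-Instruct-2507

os.environ["TRINITY_MODEL_PATH"] = "/models/my-finetune"  # set -> variable wins
print(resolve_env("${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507}"))
# -> /models/my-finetune
```

This is why the config runs out of the box yet still lets you point training at a local checkpoint by exporting `TRINITY_MODEL_PATH`.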
BIN  tuner/email_search/critic_reward_mean.png  Normal file (binary file not shown; 470 KiB)
379  tuner/email_search/main.py  Normal file
@@ -0,0 +1,379 @@
# -*- coding: utf-8 -*-
"""Example of training an Email Search agent with Trinity-RFT."""
import os
from typing import Dict

from _email_search_agent import EmailSearchAgent
from _utils import (  # pylint: disable=E0611
    AnswerModel,
    FinalRubric,
    QueryModel,
)
from agentscope import logger
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.tuner import (
    TunerModelConfig,
    DatasetConfig,
    JudgeOutput,
    WorkflowOutput,
    AlgorithmConfig,
    tune,
)
from agentscope.model import ChatModelBase


SYSTEM_PROMPT = """You are an email search agent. You are given a user query
and a list of tools you can use to search the user's email. Use the tools to
search the user's emails and find the answer to the user's query. You may take
up to {max_turns} turns to find the answer, so if your first search doesn't
find the answer, you can try with different keywords.

Always describe what you see and plan your next steps clearly. When taking
actions, explain what you're doing and why. When the answer to the task is
found, call `generate_response` to finish the process. Only call
`generate_response` when the answer is found. You should not describe any
next steps in `generate_response`. Complete all steps and then call
`generate_response`.

User's email address is {inbox_address}

Today's date is {query_date}
"""

async def run_email_search_agent(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:  # noqa: PLR0915
    """A workflow function using the Email Search agent to solve tasks.

    Args:
        task (Dict): The task to be solved.
            Should contain fields from QueryModel.
        model (ChatModelBase): The language model to use.
        auxiliary_models (Dict[str, ChatModelBase]): Additional models,
            including the LLM-as-a-Judge model.

    Returns:
        WorkflowOutput: The output containing the agent's response.
    """
    assert len(auxiliary_models) > 0, "LLM-as-a-Judge is required"

    # Parse task data
    query = QueryModel.model_validate(task)
    question = task.get("question", task.get("task_desc", ""))

    # Get workflow arguments with defaults
    workflow_args = task.get("workflow_args", {})
    max_turns = int(workflow_args.get("max_turns", 10))

    # Format system prompt
    system_prompt = SYSTEM_PROMPT.format(
        max_turns=max_turns,
        inbox_address=query.inbox_address,
        query_date=query.query_date,
    )

    # Create EmailSearchAgent
    agent = EmailSearchAgent(
        name="email_search_agent",
        sys_prompt=system_prompt,
        model=model,
        formatter=OpenAIChatFormatter(),
        max_iters=max_turns,
    )

    # Reset agent state for a new rollout
    await agent.reset()

    # Run the agent with structured output
    response = await agent.reply(
        msg=Msg("user", question, role="user"),
        structured_model=AnswerModel,
    )

    # Extract answer and sources from response metadata
    answer_and_sources = response.metadata or {}
    if not answer_and_sources:
        # Fallback: try to parse from content
        answer_and_sources = {
            "answer": response.get_text_content() or "",
            "sources": [],
        }

    # Store agent state for the judge function by passing it through
    # the response metadata
    response_metadata = {
        "answer_and_sources": answer_and_sources,
        "query": query.model_dump(),
        "message_id_list": agent.message_id_list,
        "ever_read_message_ids": agent.ever_read_message_ids,
        # Estimate actual_turns from memory length
        "actual_turns": (
            max(1, (len(agent.memory.content) - 1) // 2)
            if len(agent.memory.content) > 1
            else 1
        ),
    }

    # Update response metadata
    if response.metadata is None:
        response.metadata = {}
    response.metadata.update(response_metadata)

    return WorkflowOutput(
        response=response,
    )

def _calculate_partial_rewards(rubric: FinalRubric) -> float:
    """Calculate partial rewards based on rubric."""
    partial_rewards = 0.0
    partial_rewards += 0.1 if rubric.ever_found_right_email else 0
    partial_rewards += 0.1 if rubric.ever_read_right_email else 0
    partial_rewards += 0.1 if rubric.sources_correct else 0
    return partial_rewards


def _calculate_correct_answer_reward(
    rubric: FinalRubric,
    max_turns: int,
) -> float:
    """Calculate reward for correct answers."""
    reward = 1.0
    reward += 0.3 if rubric.sources_correct else 0
    reward += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0
    reward += 0.1 * (1 - rubric.num_turns / max_turns)
    return reward


def _initialize_rubric(
    answer: str,
    sources: list[str],
    actual_turns: int,
    query: QueryModel,
    message_id_list: list[str],
    ever_read_message_ids: list[str],
) -> FinalRubric:
    """Initialize and populate rubric with basic information."""
    rubric = FinalRubric()
    rubric.attempted_answer = answer is not None and answer != ""
    rubric.returned_i_dont_know = answer == "I don't know"
    rubric.num_sources = len(sources)
    rubric.num_turns = actual_turns

    if len(query.message_ids) > 0:
        rubric.ever_found_right_email = query.message_ids[0] in message_id_list
        rubric.ever_read_right_email = (
            query.message_ids[0] in ever_read_message_ids
        )
        rubric.sources_correct = query.message_ids[0] in sources
    return rubric

async def email_search_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
    """A judge function to calculate reward based on the agent's response.

    Args:
        task (Dict): The task information for the corresponding workflow.
        response (Msg): The response generated by the corresponding workflow.
        auxiliary_models (Dict[str, ChatModelBase]):
            A dictionary of additional chat models available for
            LLM-as-a-Judge usage. The keys are model names, and the values
            are the corresponding ChatModelBase instances.

    Returns:
        JudgeOutput: The reward value assigned by the judge function.
    """
    # Extract metadata from response
    metadata = response.metadata or {}
    answer_and_sources = metadata.get("answer_and_sources", {})
    query_dict = metadata.get("query", {})
    message_id_list = metadata.get("message_id_list", [])
    ever_read_message_ids = metadata.get("ever_read_message_ids", [])
    actual_turns = metadata.get("actual_turns", 0)

    # Parse query model
    if not query_dict:
        query_dict = task
    query = QueryModel.model_validate(query_dict)

    # Get arguments
    workflow_args = task.get("workflow_args", {})
    max_turns = int(workflow_args.get("max_turns", 10))

    # Extract answer and sources
    try:
        answer = answer_and_sources.get("answer", None)
        sources = answer_and_sources.get("sources", [])
    except Exception:
        result = {"accuracy": 0.0, "format": -1.0}
        return JudgeOutput(
            reward=sum(result.values()),
            metrics=result,
        )

    if answer is None:
        result = {"accuracy": 0.0, "format": -1.0}
        return JudgeOutput(
            reward=sum(result.values()),
            metrics=result,
        )

    # Initialize rubric
    rubric = _initialize_rubric(
        answer,
        sources,
        actual_turns,
        query,
        message_id_list,
        ever_read_message_ids,
    )

    # Judge correctness using LLM-as-a-Judge
    try:
        judge_model = (
            auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
            if auxiliary_models
            else None
        )
        judge_response = await judge_correctness(
            answer,
            query,
            judge_model,
        )
        rubric.answer_correct = judge_response
    except Exception as e:
        logger.error("Error judging correctness: %s", e)
        rubric.answer_correct = False

    # Calculate rewards
    partial_rewards = _calculate_partial_rewards(rubric)

    if rubric.attempted_answer and not rubric.answer_correct:
        result = {"accuracy": -1.0, "format": partial_rewards}
    elif rubric.returned_i_dont_know or rubric.ran_out_of_turns:
        result = {"accuracy": 0.0, "format": partial_rewards}
    elif rubric.answer_correct:
        reward = _calculate_correct_answer_reward(rubric, max_turns)
        result = {"accuracy": 1.0, "format": reward}
    else:
        result = {"accuracy": 0.0, "format": 0.0}

    metrics = result.copy()
    metrics.update({"actual_turns": actual_turns})

    return JudgeOutput(
        reward=sum(result.values()),
        metrics=metrics,
    )

# LLM-as-a-judge


async def judge_correctness(
    answer: str,
    query: QueryModel,
    judge: ChatModelBase,
) -> bool:
    """Use an LLM to decide whether *answer* matches *query.answer*.

    Returns a boolean *accept* flag used for scoring.
    """
    system_prompt = """You are given a question, the reference answer
(labelled **Reference answer**), and an answer generated by an AI assistant
(labelled **AI answer**).

Follow these steps to decide whether the AI answer should be accepted:
1. Identify EXACTLY what information the **question** is asking for
(e.g. who, what, when, where, why, how, quantity, etc.).
2. From the **Reference answer**, extract ONLY the facts that are required
to directly satisfy the information need identified in step 1. Treat all
other facts as non-essential context.
3. Verify that every essential fact from step 2 appears in the **AI answer**
with the same meaning. Differences in wording, order, or additional
non-conflicting details are allowed.
4. If any essential fact is missing or contradicted in the **AI answer**,
then *accept* must be **false**. Otherwise *accept* must be **true**.

Important: Do NOT penalise the **AI answer** for omitting non-essential
facts that appear in the **Reference answer**. The answer should only be
rejected for errors or omissions in the information explicitly requested by
the question.

Return your judgement **accept** as either **true** or **false**. Do not
return any other text or formatting.
"""
    prompt = (
        f"Question: {query.question}\n"
        f"Reference answer: {query.answer}\n"
        f"AI answer: {answer}"
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    chat_response = await judge(messages)

    # Extract text content from ChatResponse
    result_parts = []
    for block in chat_response.content:
        if isinstance(block, dict) and block.get("type") == "text":
            result_parts.append(str(block.get("text", "")))
    result = "".join(result_parts)
    logger.info("LLM judge response: %s", result)

    return "true" in result.lower()


# End of LLM-as-a-judge

if __name__ == "__main__":
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )
    dataset = DatasetConfig(
        path="/path/to/enron_emails_dataset",
        split="train",
    )
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen3-4B-Instruct-2507",
        max_model_len=20480,
        max_tokens=4096,
        inference_engine_num=4,
        reasoning_parser=None,
    )
    aux_models = {
        "judge": TunerModelConfig(
            model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
            max_model_len=2500,
            max_tokens=2048,
            inference_engine_num=1,
            tensor_parallel_size=2,
            tool_call_parser=None,
            reasoning_parser=None,
        ),
    }
    algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=8,
        batch_size=64,
        learning_rate=1e-6,
    )
    tune(
        workflow_func=run_email_search_agent,
        judge_func=email_search_judge,
        train_dataset=dataset,
        model=tuner_model,
        auxiliary_models=aux_models,
        algorithm=algorithm,
        config_path=config_path,
    )
357  tuner/email_search/prepare_data.py  Normal file
@@ -0,0 +1,357 @@
# -*- coding: utf-8 -*-
"""
Prepare data for training.
Modified from OpenPipe/ART
"""

import logging
import os
import sqlite3
from datetime import datetime

from datasets import Dataset, Features, Sequence, Value, load_dataset
from tqdm import tqdm


# Resolve paths relative to this file
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Database will live in "../data/enron_emails.db" relative to project root
DEFAULT_DB_PATH = os.path.join(BASE_DIR, "..", "..", "data", "enron_emails.db")

DEFAULT_REPO_ID = "corbt/enron-emails"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)


# --- Database Schema ---
SQL_CREATE_TABLES = """
DROP TABLE IF EXISTS recipients;
DROP TABLE IF EXISTS emails_fts;
DROP TABLE IF EXISTS emails;

CREATE TABLE emails (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    message_id TEXT UNIQUE,
    subject TEXT,
    from_address TEXT,
    date TEXT,  -- Store as ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
    body TEXT,
    file_name TEXT
);

CREATE TABLE recipients (
    email_id INTEGER,
    recipient_address TEXT,
    recipient_type TEXT,  -- 'to', 'cc', 'bcc'
    FOREIGN KEY(email_id) REFERENCES emails(id) ON DELETE CASCADE
);
"""

SQL_CREATE_INDEXES_TRIGGERS = """
CREATE INDEX idx_emails_from ON emails(from_address);
CREATE INDEX idx_emails_date ON emails(date);
CREATE INDEX idx_emails_message_id ON emails(message_id);
CREATE INDEX idx_recipients_address ON recipients(recipient_address);
CREATE INDEX idx_recipients_type ON recipients(recipient_type);
CREATE INDEX idx_recipients_email_id ON recipients(email_id);
CREATE INDEX idx_recipients_address_email ON recipients(
    recipient_address, email_id
);

CREATE VIRTUAL TABLE emails_fts USING fts5(
    subject,
    body,
    content='emails',
    content_rowid='id'
);

CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
    INSERT INTO emails_fts (rowid, subject, body)
    VALUES (new.id, new.subject, new.body);
END;

CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
    DELETE FROM emails_fts WHERE rowid=old.id;
END;

CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
    UPDATE emails_fts SET subject=new.subject, body=new.body
    WHERE rowid=old.id;
END;

INSERT INTO emails_fts (rowid, subject, body)
SELECT id, subject, body FROM emails;
"""

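The `emails_fts` virtual table and triggers above keep a full-text index in sync with the `emails` content table. A minimal sketch exercising that schema against an in-memory database (assumes the Python `sqlite3` build ships with FTS5 enabled, as CPython's usually does; the inserted email is made up):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# A trimmed copy of the schema above: content table, FTS index, insert trigger.
conn.executescript("""
CREATE TABLE emails (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    message_id TEXT UNIQUE, subject TEXT, from_address TEXT,
    date TEXT, body TEXT, file_name TEXT
);
CREATE VIRTUAL TABLE emails_fts USING fts5(
    subject, body, content='emails', content_rowid='id'
);
CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
    INSERT INTO emails_fts (rowid, subject, body)
    VALUES (new.id, new.subject, new.body);
END;
""")
conn.execute(
    "INSERT INTO emails (message_id, subject, body) VALUES (?, ?, ?)",
    ("<m1>", "Q4 budget review", "Please send the updated budget numbers."),
)
# Full-text search over subject and body, joined back to the content table.
rows = conn.execute(
    "SELECT e.message_id FROM emails_fts f "
    "JOIN emails e ON e.id = f.rowid WHERE emails_fts MATCH ?",
    ("budget",),
).fetchall()
print(rows)  # [('<m1>',)]
```

This is the query shape the search tools can rely on: `MATCH` hits the FTS index, and the join recovers the full email row.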
# --- Functions ---


def download_dataset(repo_id: str) -> Dataset:
    """Downloads the dataset from Hugging Face Hub."""
    logging.info(
        "Attempting to download dataset from Hugging Face Hub: %s",
        repo_id,
    )
    expected_features = Features(
        {
            "message_id": Value("string"),
            "subject": Value("string"),
            "from": Value("string"),
            "to": Sequence(Value("string")),
            "cc": Sequence(Value("string")),
            "bcc": Sequence(Value("string")),
            "date": Value("timestamp[us]"),
            "body": Value("string"),
            "file_name": Value("string"),
        },
    )
    dataset_obj = load_dataset(
        repo_id,
        features=expected_features,
        split="train",
    )
    # Basic type check remains useful
    if not isinstance(dataset_obj, Dataset):
        raise TypeError(f"Expected Dataset, got {type(dataset_obj)}")
    logging.info(
        "Successfully loaded dataset '%s' with %d records.",
        repo_id,
        len(dataset_obj),
    )
    return dataset_obj


def create_database(db_path: str) -> None:
    """Creates the SQLite database and tables."""
    logging.info("Creating SQLite database and tables at: %s", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_TABLES)
    conn.commit()
    conn.close()
    logging.info("Database tables created successfully.")


def _should_skip_email(
    body: str,
    message_id: str,
    to_list: list[str],
    cc_list: list[str],
    bcc_list: list[str],
) -> bool:
    """Check if email should be skipped based on filters."""
    if len(body) > 5000:
        logging.debug(
            "Skipping email %s: Body length > 5000 characters.",
            message_id,
        )
        return True

    total_recipients = len(to_list) + len(cc_list) + len(bcc_list)
    if total_recipients > 30:
        logging.debug(
            "Skipping email %s: Total recipients (%d) > 30.",
            message_id,
            total_recipients,
        )
        return True
    return False


def _prepare_recipient_data(
    email_pk_id: int,
    to_list: list[str],
    cc_list: list[str],
    bcc_list: list[str],
) -> list[tuple[int, str, str]]:
    """Prepare recipient data for database insertion."""
    recipient_data = []
    for addr in to_list:
        recipient_data.append((email_pk_id, addr, "to"))
    for addr in cc_list:
        recipient_data.append((email_pk_id, addr, "cc"))
    for addr in bcc_list:
        recipient_data.append((email_pk_id, addr, "bcc"))
    return recipient_data

def populate_database(db_path: str, dataset: Dataset) -> None:
    """Populates the database with data from the Hugging Face dataset."""
    logging.info("Populating database %s...", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # --- Performance Pragmas ---
    conn.execute("PRAGMA synchronous = OFF;")
    conn.execute("PRAGMA journal_mode = MEMORY;")

    record_count = 0
    skipped_count = 0
    duplicate_count = 0
    processed_emails = set()

    conn.execute("BEGIN TRANSACTION;")

    for email_data in tqdm(dataset, desc="Inserting emails"):
        assert isinstance(email_data, dict)
        message_id = email_data["message_id"]
        subject = email_data["subject"]
        from_address = email_data["from"]
        date_obj: datetime = email_data["date"]
        body = email_data["body"]
        file_name = email_data["file_name"]
        to_list_raw = email_data["to"]
        cc_list_raw = email_data["cc"]
        bcc_list_raw = email_data["bcc"]

        date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")
        to_list = [str(addr) for addr in to_list_raw if addr]
        cc_list = [str(addr) for addr in cc_list_raw if addr]
        bcc_list = [str(addr) for addr in bcc_list_raw if addr]

        if _should_skip_email(body, message_id, to_list, cc_list, bcc_list):
            skipped_count += 1
            continue

        email_key = (subject, body, from_address)
        if email_key in processed_emails:
            logging.debug(
                "Skipping duplicate email (Subject: %s..., From: %s)",
                subject[:50],
                from_address,
            )
            duplicate_count += 1
            continue
        processed_emails.add(email_key)

        cursor.execute(
            """
            INSERT INTO emails (
                message_id, subject, from_address, date, body, file_name
            )
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (message_id, subject, from_address, date_str, body, file_name),
        )
        email_pk_id = cursor.lastrowid
        if email_pk_id is None:
            logging.warning(
                "Failed to get email ID after insert for message_id: %s",
                message_id,
            )
            continue

        recipient_data = _prepare_recipient_data(
            email_pk_id,
            to_list,
            cc_list,
            bcc_list,
        )

        if recipient_data:
            cursor.executemany(
                """
                INSERT INTO recipients (
                    email_id, recipient_address, recipient_type
                )
                VALUES (?, ?, ?)
                """,
                recipient_data,
            )
        record_count += 1

    conn.commit()
    conn.close()
    logging.info("Successfully inserted %d email records.", record_count)
    if skipped_count > 0:
        logging.info(
            "Skipped %d email records due to length or recipient limits.",
            skipped_count,
        )
    if duplicate_count > 0:
        logging.info(
            "Skipped %d duplicate email records "
            "(based on subject, body, from).",
            duplicate_count,
        )

def create_indexes_and_triggers(db_path: str) -> None:
    """Creates indexes and triggers on the populated database."""
    logging.info("Creating indexes and triggers for database: %s...", db_path)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
    conn.commit()
    conn.close()
    logging.info("Indexes and triggers created successfully.")


def generate_database(
    repo_id: str = DEFAULT_REPO_ID,
    db_path: str = DEFAULT_DB_PATH,
    overwrite: bool = False,
) -> None:
    """
    Generates the SQLite database from the specified Hugging Face dataset.
    Simplified version without extensive error handling.

    Args:
        repo_id: The Hugging Face repository ID for the dataset.
        db_path: The path where the SQLite database file should be
            created.
        overwrite: If True, any existing database file at db_path will
            be removed.
    """
    logging.info(
        "Starting database generation for repo '%s' at '%s'",
        repo_id,
        db_path,
    )
    logging.info("Overwrite existing database: %s", overwrite)

    db_dir = os.path.dirname(db_path)
    if db_dir and not os.path.exists(db_dir):
        logging.info("Creating data directory: %s", db_dir)
        os.makedirs(db_dir)

    if overwrite and os.path.exists(db_path):
        logging.warning("Removing existing database file: %s", db_path)
        os.remove(db_path)
    elif not overwrite and os.path.exists(db_path):
        # If not overwriting and the file exists, assume it has already
        # been generated and skip regeneration.
        logging.warning(
            "Database file %s exists and overwrite is False. "
            "Assuming file is already generated.",
            db_path,
        )
        return

    # 1. Download dataset
    dataset = download_dataset(repo_id)

    # 2. Create database schema (tables only)
    # Note: This will fail if overwrite=False and the file exists with an
    # incompatible schema/data.
    create_database(db_path)

    # 3. Populate database
    populate_database(db_path, dataset)

    # 4. Create indexes and triggers
    create_indexes_and_triggers(db_path)

    logging.info("Database generation process completed for %s.", db_path)
    logging.info(
        "Please set the environment variable DEFAULT_EMAIL_DB_PATH "
        "to this path.",
    )


if __name__ == "__main__":
    generate_database(overwrite=True)
BIN  tuner/email_search/rollout_accuracy_mean.png  Normal file (binary file not shown; 442 KiB)
273  tuner/frozen_lake/README.md  Normal file
@@ -0,0 +1,273 @@
# Training FrozenLake Agent with RL using AgentScope-Tuner

## Summary

This example demonstrates how to use AgentScope-Tuner to implement reinforcement fine-tuning for the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task using [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT). Through multi-step interactions with the environment, the agent learns to navigate a frozen lake grid from a starting position to a goal while avoiding holes.

## Task Setting

### Agent Goal
The agent's objective is to navigate from the starting position (S) to the goal position (G) on a frozen lake grid without falling into holes (H). The agent must:
- Plan a path through frozen tiles (F) to reach the goal
- Avoid holes that terminate the episode with zero reward
- Complete the task within a limited number of steps

### Agent Type
The agent is implemented as a **ReActAgent** (Reasoning and Acting Agent) that:
- Observes the current state of the frozen lake grid
- Reasons about the best action to take
- Executes actions (Up, Down, Left, Right) to move through the environment
- Maintains internal state across multiple steps in an episode

### Environment
The environment is based on Gymnasium's FrozenLake environment, wrapped to provide:
- **Grid-based navigation**: Randomly generated maps with configurable size (2x2 to 6x6)
- **Tile types**:
  - `S`: Start position
  - `F`: Frozen tile (safe to walk on)
  - `H`: Hole (terminates episode with reward 0)
  - `G`: Goal (terminates episode with reward +1.0)
- **Action space**: Discrete actions (Up, Down, Left, Right)
- **Reward structure**:
  - +1.0 for reaching the goal
  - 0.0 for falling into a hole or failing to reach the goal
- **Observations**: Text-based grid representation showing the current player position

The agent does not use external tools. It interacts directly with the environment through:
- `env.reset(task)`: Initialize the environment with task parameters
- `env.step(action)`: Execute an action and receive observation, reward, and done flag
- `env.render()`: Get a text representation of the current state
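The three calls above imply a standard observe-act loop. A minimal sketch with a toy stand-in environment (not the example's `FrozenLakeEnv`; a real agent would reason over the observation instead of always moving right):

```python
class ToyFrozenLake:
    """A 1-D stand-in for the wrapper: reach position `size - 1` in <= 8 steps."""

    def __init__(self) -> None:
        self.pos, self.size, self.steps = 0, 4, 0

    def reset(self, task: dict) -> str:
        # Task parameters configure the episode, mirroring env.reset(task).
        self.size = task.get("size", 4)
        self.pos, self.steps = 0, 0
        return self.render()

    def step(self, action: str) -> tuple[str, float, bool]:
        # Returns (observation, reward, done), mirroring env.step(action).
        self.steps += 1
        self.pos += 1 if action == "Right" else -1
        done = self.pos >= self.size - 1 or self.pos < 0 or self.steps >= 8
        reward = 1.0 if self.pos >= self.size - 1 else 0.0
        return self.render(), reward, done

    def render(self) -> str:
        return f"pos={self.pos}/{self.size - 1}"

env = ToyFrozenLake()
obs = env.reset({"size": 4, "seed": 0})
done, total_reward = False, 0.0
while not done:
    action = "Right"  # a ReAct agent would choose this from `obs`
    obs, reward, done = env.step(action)
    total_reward += reward
print(total_reward)  # 1.0
```

The real workflow replaces the hard-coded action with a model call and feeds `env.render()` output back into the agent's context each turn.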
## Dataset Preparation
|
||||||
|
|
||||||
|
The dataset contains task parameters for generating FrozenLake environments. Each sample specifies:
|
||||||
|
- `seed`: Random seed for reproducible map generation
|
||||||
|
- `size`: Grid size (randomly sampled from 2 to `map_max_size`, e.g., 4x4, 6x6)
|
||||||
|
- `p`: Probability that a tile is frozen (vs. being a hole), randomly sampled from 0.6 to 0.85
|
||||||
|
- `index`: Sample index
|
||||||
|
- `uid`: Unique identifier combining seed, size, and p
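A single dataset row can be sketched as follows (`make_sample` is a hypothetical helper, shown only to illustrate the sampling ranges listed above; the actual generation logic lives in `get_frozenlake_data.py`):

```python
import random

def make_sample(index: int, map_max_size: int = 6) -> dict:
    """Draw one set of task parameters with the fields listed above."""
    seed = random.randint(0, 2**31 - 1)
    size = random.randint(2, map_max_size)   # grid size, 2 .. map_max_size
    p = round(random.uniform(0.6, 0.85), 2)  # probability a tile is frozen
    return {
        "seed": seed,
        "size": size,
        "p": p,
        "index": index,
        "uid": f"{seed}_{size}_{p}",
    }

sample = make_sample(0)
```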
Run the data preparation script to generate the training and test datasets:

```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```

This will create parquet files in the specified directory:

```
/path/to/frozenlake_dataset/
├── train.parquet   # 10000 training samples
└── test.parquet    # 100 test samples
```

Each sample looks like:

```json
{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
```

**Note**: The data preparation script ensures that all generated maps have a valid path from start to goal within the maximum allowed steps (`env_max_steps=8`), filtering out unsolvable tasks.
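The solvability filter mentioned in the note can be sketched as a bounded breadth-first search (a minimal illustration over a textual map; the real check is implemented in `get_frozenlake_data.py`):

```python
from collections import deque

def solvable(grid: list[str], max_steps: int = 8) -> bool:
    """Return True if 'G' is reachable from 'S' within max_steps moves."""
    rows, cols = len(grid), len(grid[0])
    start = next(
        (r, c) for r in range(rows) for c in range(cols) if grid[r][c] == "S"
    )
    queue = deque([(start, 0)])
    seen = {start}
    while queue:
        (r, c), dist = queue.popleft()
        if grid[r][c] == "G":
            return True
        if dist == max_steps:
            continue  # step budget exhausted along this path
        for dr, dc in ((1, 0), (-1, 0), (0, 1), (0, -1)):
            nr, nc = r + dr, c + dc
            if (
                0 <= nr < rows
                and 0 <= nc < cols
                and grid[nr][nc] != "H"  # holes are impassable
                and (nr, nc) not in seen
            ):
                seen.add((nr, nc))
                queue.append(((nr, nc), dist + 1))
    return False
```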
## Code Implementation

This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.

### High-level Overview

The implementation consists of three main components:

1. **Agent** (`FrozenLakeAgent`): Extends `ReActAgent` to handle multi-step navigation
2. **Environment** (`FrozenLakeEnv`): Wraps Gymnasium's FrozenLake environment
3. **Workflow** (`run_frozen_lake`): Orchestrates the agent-environment interaction loop

### Agent Workflow

The workflow function `run_frozen_lake` implements the agent-environment interaction loop:

```python
async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # ...

    # Create agent and environment
    agent = FrozenLakeAgent(model=model, ...)
    env = FrozenLakeEnv(...)
    observation, _ = env.reset(task)
    rewards = []
    # ...

    # Agent-environment interaction loop
    for _ in range(max_steps):
        response = await agent.reply(msg=Msg("user", agent.get_prompt(observation), role="user"))
        action = agent.get_action(response)
        observation, reward, done, _ = env.step(action)
        rewards.append(reward)
        if done:
            break

    # ...
    final_reward = sum(rewards)
    final_response = Msg("assistant", response_content, role="assistant")

    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={
            "env_steps": float(step_count),
            "env_done": float(done),
        },
    )
```
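Stripped of the AgentScope-specific types, the control flow of this loop can be exercised with simple stand-ins (everything below is illustrative; the real agent and environment come from this example's source files):

```python
def run_episode(env, agent, max_steps: int = 8):
    """Act until the episode ends or the step budget is exhausted."""
    observation = env.reset()
    rewards = []
    for _ in range(max_steps):
        action = agent(observation)
        observation, reward, done = env.step(action)
        rewards.append(reward)
        if done:
            break
    return sum(rewards), len(rewards)

class _LineEnv:
    """Toy 1-D stand-in: reach position 3 by repeatedly moving right."""
    def reset(self):
        self.pos = 0
        return self.pos

    def step(self, action):
        self.pos += 1 if action == "right" else 0
        done = self.pos == 3
        return self.pos, (1.0 if done else 0.0), done

total, steps = run_episode(_LineEnv(), lambda obs: "right")
```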
**Key characteristics:**

- Multi-step interaction: The agent takes multiple actions in a single episode, unlike single-turn QA tasks
- State tracking: The agent maintains internal state (current step, last action, last observation) across steps
- Error handling: Invalid actions and agent errors are caught and handled gracefully

### Reward Function

No separate judge function is needed. The reward comes directly from the environment:

- 1.0: The agent successfully reaches the goal (G)
- 0.0: The agent falls into a hole (H) or fails to reach the goal within the maximum number of steps

The episode reward is computed as the sum of step rewards. The workflow returns:

- `reward`: Final cumulative reward
- `response`: Final response message containing the observation, total reward, steps taken, and termination reason
- `metrics`: Additional metrics including `env_steps` (number of steps taken) and `env_done` (whether the episode completed)

### Implementation Details

The environment (`FrozenLakeEnv`) wraps Gymnasium's FrozenLake and provides:

- `reset(task)`: Initialize the environment with task parameters
- `step(action)`: Execute an action and return (observation, reward, done, info)
- `render()`: Return a text representation of the current state

The agent (`FrozenLakeAgent`) extends `ReActAgent` and provides:

- `reply(msg)`: Reply to a message and return an action (inherited from AgentScope)
- `get_prompt(observation)`: Generate a prompt from the current observation
- `get_action(response)`: Parse the model's response to extract an action (Up/Down/Left/Right)
- `update_state(action, observation)`: Update internal state after each step

See [frozenlake_env.py](./frozenlake_env.py) and [frozenlake_agent.py](./frozenlake_agent.py) for implementation details.
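For example, the parsing in `get_action` follows this pattern (a standalone version of the regex-based extraction used in the agent source: the last triple-backtick span is taken as the action, lowercased, and validated):

```python
import re

VALID_ACTIONS = {"left", "down", "right", "up"}
INVALID_ACTION = "still"  # fallback when no valid action is found

def parse_action(response: str) -> str:
    """Extract the last ``` ``` span from the model response as the action."""
    matches = re.findall(r"```(.*?)```", response, re.DOTALL)
    if not matches:
        return INVALID_ACTION
    action = matches[-1].strip().lower()
    return action if action in VALID_ACTIONS else INVALID_ACTION
```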
### Step 4: Use `tune` to train the workflow

```python
import os

from agentscope.tuner import tune, DatasetConfig

if __name__ == "__main__":
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )
    dataset = DatasetConfig(
        path="/path/to/frozenlake_dataset",
        name="default",
        split="train",
    )
    tune(
        workflow_func=run_frozen_lake,
        train_dataset=dataset,
        config_path=config_path,
    )
```

See [config.yaml](./config.yaml) for the training configuration. For full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
---

## How to Run

### Prerequisites

- At least 2 NVIDIA GPUs with CUDA 12.8 or newer
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source
- Install gymnasium for the FrozenLake environment:

```bash
pip install gymnasium[toy_text]
```

- Download the model checkpoint (example):

```bash
huggingface-cli download Qwen/Qwen2.5-3B-Instruct
```

### Step 1: Prepare the Dataset

```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```

Update the dataset path in `main.py` to point to your generated dataset directory.

### Step 2: Configure the Training

Key configuration options can be set in the code, including:

**Algorithm Configuration** (`AlgorithmConfig`):

- `algorithm_type`: `multi_step_grpo` (Group Relative Policy Optimization for multi-step tasks)
- `group_size`: Number of rollouts sampled per task for group-relative advantage estimation (default: 16)
- `batch_size`: Batch size for training (default: 32)
- `learning_rate`: Learning rate (default: 1e-6)

**Model Configuration** (`TunerModelConfig`):

- `model_path`: Path to the base model (e.g., `Qwen/Qwen2.5-3B-Instruct`)
- `max_model_len`: Maximum model context length (default: 25600)
- `max_tokens`: Maximum tokens for response generation (default: 2048)
- `inference_engine_num`: Number of inference engines (default: 6, using 6 GPUs for inference)

**Dataset Configuration** (`DatasetConfig`):

- `path`: Path to the dataset (default: `/path/to/frozenlake`)
- `split`: Split of the dataset (default: `train`)

Adjust these parameters based on your hardware resources and training requirements. Other parameters can be specified in [config.yaml](./config.yaml).
### Step 3: Set Up Ray Cluster

Set up a [Ray](https://github.com/ray-project/ray) cluster:

```bash
ray start --head
# For a multi-node setup, run the following command on each worker node:
# ray start --address=<master_address>
```

### Step 4: Run the Training Script

```bash
python main.py
```

Training will start, and you can monitor progress through the logs. Checkpoints are saved every `trainer.save_interval` steps.

## Experimental Results

### Training Reward Curve

The reward curve during training shows the agent's learning progress:

<div align="center">
<img src="./critic_rewards_mean.png" alt="reward" width="90%"/>
</div>

The training reward typically increases over epochs as the agent learns to navigate the frozen lake more effectively.
### Example Agent Output

An example of the agent's output is given below:

````
From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.

The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.

Let's take the action ```Left```.
````
252
tuner/frozen_lake/README_zh.md
Normal file
@@ -0,0 +1,252 @@
# Training a FrozenLake Agent with AgentScope-Tuner

## Summary

This example shows how to use AgentScope-Tuner with [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) to apply reinforcement fine-tuning to the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task. The agent must walk from the start to the goal on a frozen-lake grid, avoid holes, and complete the task within a limited number of steps.

## Task Setup

### Agent Goal

The agent must travel from the start (S) to the goal (G) on the frozen-lake grid while it:

- Plans a path across frozen tiles (F) to the goal
- Avoids holes (H), which end the episode with reward 0
- Completes the task within a limited number of steps

### Agent Type

The agent is implemented as a **ReActAgent** that:

- Observes the current state of the frozen-lake grid
- Reasons about the best next action
- Executes actions (Up, Down, Left, Right) to move through the environment
- Maintains internal state across multiple interaction steps

### Environment

The environment is based on Gymnasium's FrozenLake and provides:

- **Grid navigation**: Randomly generated maps from 2x2 to 6x6
- **Tile types**:
  - `S`: Start position
  - `F`: Frozen tile (walkable)
  - `H`: Hole (reward 0, ends the episode)
  - `G`: Goal (reward +1.0, ends the episode)
- **Action space**: Discrete actions (Up, Down, Left, Right)
- **Reward design**:
  - +1.0 for reaching the goal
  - 0.0 for falling into a hole or not reaching the goal within the maximum steps
- **Observations**: A text grid representation showing the current player position

The agent uses no external tools; it interacts with the environment directly through:

- `env.reset(task)`: Initialize the environment from task parameters
- `env.step(action)`: Execute an action and return the observation, reward, and done flag
- `env.render()`: Return a text representation of the current state

## Dataset Preparation

The dataset contains task parameters for generating FrozenLake environments. Each sample includes:

- `seed`: Random seed, ensuring reproducible maps
- `size`: Grid size (sampled between 2 and `map_max_size`, e.g., 4x4, 6x6)
- `p`: Probability that a tile is frozen (sampled from 0.6 to 0.85); the rest are holes
- `index`: Sample index
- `uid`: Unique ID combining seed, size, and p

Run the data preparation script to generate the training and test sets:

```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```

Example of the generated directory layout:

```
/path/to/frozenlake_dataset/
├── train.parquet   # 10000 training samples
└── test.parquet    # 100 test samples
```

Example sample:

```json
{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
```

**Note**: The script filters out unsolvable maps, ensuring a feasible path from start to goal within the maximum number of steps (`env_max_steps=8`).
## Code Implementation

This section gives a high-level overview of the implementation; see the source code for details.

### High-level Overview

The implementation consists of three parts:

1. **Agent** (`FrozenLakeAgent`): Extends `ReActAgent` and handles multi-step interaction
2. **Environment** (`FrozenLakeEnv`): Wraps Gymnasium's FrozenLake
3. **Workflow** (`run_frozen_lake`): Orchestrates the agent-environment interaction

### Workflow

`run_frozen_lake` implements the multi-step interaction loop:

```python
async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    # ...

    # Create agent and environment
    agent = FrozenLakeAgent(model=model, ...)
    env = FrozenLakeEnv(...)
    observation, _ = env.reset(task)
    rewards = []
    # ...

    # Agent-environment interaction loop
    for _ in range(max_steps):
        response = await agent.reply(msg=Msg("user", agent.get_prompt(observation), role="user"))
        action = agent.get_action(response)
        observation, reward, done, _ = env.step(action)
        rewards.append(reward)
        if done:
            break

    # ...
    final_reward = sum(rewards)
    final_response = Msg("assistant", response_content, role="assistant")

    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={"env_steps": float(step_count), "env_done": float(done)},
    )
```

**Key characteristics:**

- Multi-step interaction: Multiple actions within a single episode, not single-turn QA
- State tracking: Records the current step, last action, and last observation
- Error handling: Invalid actions and exceptions are caught and handled

### Reward Function

No separate judge is needed; the reward comes directly from the environment:

- 1.0: The goal is reached
- 0.0: The agent falls into a hole or fails to reach the goal within the step limit

The workflow returns:

- `reward`: Cumulative reward
- `response`: Final message containing the observation, total reward, step count, and termination reason
- `metrics`: `env_steps` (steps taken) and `env_done` (whether the episode finished)

### Implementation Details

The environment (`FrozenLakeEnv`) wraps Gymnasium's FrozenLake and provides:

- `reset(task)`: Initialize the environment with task parameters
- `step(action)`: Execute an action and return (observation, reward, done, info)
- `render()`: Return a text representation of the current state

The agent (`FrozenLakeAgent`) extends `ReActAgent` and provides:

- `reply(msg)`: Reply to a message and return an action (inherited from AgentScope)
- `get_prompt(observation)`: Build a prompt from the current observation
- `get_action(response)`: Parse the model's response to extract an action (Up/Down/Left/Right)
- `update_state(action, observation)`: Update internal state after each step

See [frozenlake_env.py](./frozenlake_env.py) and [frozenlake_agent.py](./frozenlake_agent.py) for details.

### Step 4: Train the workflow with `tune`

```python
import os

from agentscope.tuner import tune, DatasetConfig

if __name__ == "__main__":
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )
    dataset = DatasetConfig(
        path="/path/to/frozenlake_dataset",
        name="default",
        split="train",
    )
    tune(
        workflow_func=run_frozen_lake,
        train_dataset=dataset,
        config_path=config_path,
    )
```

See [config.yaml](./config.yaml) for the training configuration; for full configuration details, see the [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
---

## How to Run

### Prerequisites

- At least 2 NVIDIA GPUs with CUDA 12.8 or newer
- Install the latest version from source following the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
- Install gymnasium for the FrozenLake environment:

```bash
pip install gymnasium[toy_text]
```

- Download the model weights (example):

```bash
huggingface-cli download Qwen/Qwen2.5-3B-Instruct
```

### Step 1: Prepare the Dataset

```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```

Update the dataset path in `main.py` to point to your generated directory.

### Step 2: Configure the Training

Key configuration options can be set in the code, including:

**Algorithm Configuration** (`AlgorithmConfig`):

- `algorithm_type`: `multi_step_grpo` (Group Relative Policy Optimization for multi-step tasks)
- `group_size`: Number of rollouts sampled per task for group-relative advantage estimation (default: 16)
- `batch_size`: Batch size (default: 32)
- `learning_rate`: Learning rate (default: 1e-6)

**Model Configuration** (`TunerModelConfig`):

- `model_path`: Path to the base model (e.g., `Qwen/Qwen2.5-3B-Instruct`)
- `max_model_len`: Maximum context length (default: 25600)
- `max_tokens`: Maximum tokens generated per response (default: 2048)
- `inference_engine_num`: Number of inference engines (default: 6, i.e., 6 GPUs for inference)

**Dataset Configuration** (`DatasetConfig`):

- `path`: Path to the dataset (default: `/path/to/frozenlake`)
- `split`: Split of the dataset (default: `train`)

Adjust these parameters based on your hardware resources and training requirements. Other parameters can be specified in [config.yaml](./config.yaml).

### Step 3: Set Up a Ray Cluster

Set up a [Ray](https://github.com/ray-project/ray) cluster:

```bash
ray start --head
# For a multi-node setup, run the following command on each worker node:
# ray start --address=<master_address>
```

### Step 4: Run the Training Script

```bash
python main.py
```

Training then starts, and you can monitor progress through the logs. Checkpoints are saved every `trainer.save_interval` steps.

## Experimental Results

### Training Reward Curve

The reward curve during training shows the agent's learning progress:

<div align="center">
<img src="./critic_rewards_mean.png" alt="reward" width="90%"/>
</div>

The training reward typically increases as the agent learns to navigate the frozen lake more effectively.

### Example Agent Output

An example of the agent's output:

````
From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.

The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.

Let's take the action ```Left```.
````
102
tuner/frozen_lake/_frozenlake_agent.py
Normal file
@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import re

from _utils import SYSTEM_PROMPT, FrozenLakeAction  # pylint: disable=E0611

from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel


INVALID_ACTION = "still"
VALID_ACTIONS = {
    "left": 1,
    "down": 2,
    "right": 3,
    "up": 4,
}


class FrozenLakeAgent(ReActAgent):
    """Agent for FrozenLake environment."""

    def __init__(self, model: OpenAIChatModel, max_steps: int = 20):
        super().__init__(
            name="frozenlake_agent",
            model=model,
            sys_prompt=SYSTEM_PROMPT,
            formatter=OpenAIChatFormatter(),
            max_iters=1,
        )
        self.response_structure = FrozenLakeAction
        self.current_step = 0
        self.last_action = None
        self.last_observation = None
        self.max_steps = max_steps

    def get_prompt(self, observation: str) -> str:
        """Get prompt for the agent based on current observation."""
        prompt = (
            f"Current Observation ({self.current_step}): \n"
            + observation
            + "\n"
            + (
                "You have not achieved the goal, P has not reached G yet. "
                "Please give the next action."
            )
        )
        if self.current_step > 0 and self.last_action is not None:
            if self.last_observation == observation:
                prompt += (
                    "\nYour last response is invalid. "
                    "Your position didn't change at all. "
                    "You may need to recheck your thinking process, "
                    "action outputted, and the format of response. "
                    "Remember, you should only output the NEXT ACTION "
                    "at each iteration in the ``` ```. "
                    "For example, if you want to move up, "
                    "you should output ```Up```."
                )

        if (
            self.max_steps is not None
            and self.max_steps - self.current_step > 0
        ):
            remaining = self.max_steps - self.current_step
            prompt += (
                f"\nThe maximum number of steps remaining is {remaining}."
            )

        return prompt

    def get_action(self, msg: Msg) -> str:
        """Extract action from agent response message."""
        response: str = (
            msg.content
            if isinstance(msg.content, str)
            else msg.content[0].get("text")
        )
        action = INVALID_ACTION

        matches = re.findall(r"```(.*?)```", response, re.DOTALL)

        if matches:
            last_match_content = matches[-1].strip()
            action = last_match_content.lower()
            if action not in VALID_ACTIONS:
                action = INVALID_ACTION

        return action

    def update_state(self, action: str, observation: str) -> None:
        """Update agent state with action and observation."""
        self.last_action = action
        self.last_observation = observation
        self.current_step += 1

    async def reset(self) -> None:
        """Reset agent state for a new episode."""
        self.current_step = 0
        self.last_action = None
        self.last_observation = None
        await self.memory.clear()
316
tuner/frozen_lake/_frozenlake_env.py
Normal file
@@ -0,0 +1,316 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import copy
from typing import Dict, Optional, Tuple, Union

import numpy as np

try:
    from gymnasium.envs.toy_text.frozen_lake import (
        FrozenLakeEnv as GymFrozenLakeEnv,
    )
except ImportError:
    GymFrozenLakeEnv = object
from _utils import (  # pylint: disable=E0611
    generate_random_map,
    get_goal_position,
)


class FrozenLakeEnv(GymFrozenLakeEnv):
    """FrozenLake environment wrapper."""

    # Map gym state in integer
    MAP_LOOKUP = {
        b"P": 0,
        b"F": 1,
        b"H": 2,
        b"G": 3,
    }

    # Define rules to transform to rendered text observation of the environment
    GRID_LOOKUP = {
        0: " P \t",  # player
        1: " _ \t",  # frozen
        2: " O \t",  # hole
        3: " G \t",  # goal
        4: " X \t",  # player fall into hole
        5: " √ \t",  # player on goal
    }

    ACTION_LOOKUP = {
        "still": 0,
        "left": 1,
        "down": 2,
        "right": 3,
        "up": 4,
    }

    INVALID_ACTION = 0
    PENALTY_FOR_INVALID = -1

    def __init__(
        self,
        max_steps: int = 8,
        desc: Optional[str] = None,
        is_slippery: bool = False,
        size: int = 8,
        p: float = 0.8,
        seed: int = 42,
    ):
        self.max_steps = max_steps or 8
        self.desc: Union[str, np.ndarray, None] = desc
        self.is_slippery = is_slippery
        self.size = size
        self.p = p
        self.seed = seed
        self.render_mode: Optional[str] = None
        try:
            import gymnasium as gym
        except ImportError as e:
            error_message = (
                "Gymnasium is not installed. "
                "Please install gymnasium first before "
                "running the frozen_lake workflow. "
                f"Error: {str(e)}"
            )
            raise ImportError(error_message) from e

        if self.desc is None:
            random_map, goal_position = generate_random_map(
                size=self.size,
                p=self.p,
                seed=self.seed,
                max_steps=self.max_steps,
            )
        else:
            random_map = np.asarray(copy.deepcopy(self.desc), dtype="c")
            goal_position = get_goal_position(random_map)

        self.goal_position = goal_position

        super().__init__(
            desc=random_map[:],
            is_slippery=self.is_slippery,
        )
        assert isinstance(self.desc, np.ndarray)
        self.action_space = gym.spaces.Discrete(4, start=1)

        self.map_kwargs = {
            "size": size,
            "p": p,
        }
        self.env_kwargs = {
            "is_slippery": is_slippery,
            "desc": copy.deepcopy(desc),
            "seed": seed,
        }

        self.action_map = {
            1: 0,  # left
            2: 1,  # down
            3: 2,  # right
            4: 3,  # up
        }

    def _get_player_position(self) -> Tuple[int, int]:
        return (self.s // self.ncol, self.s % self.ncol)  # (row, col)

    def step(self, action: str) -> Tuple[str, float, bool, Dict]:
        """Execute a step in the environment.

        Maps custom action to gymnasium FrozenLakeEnv action and
        takes the step. Checks if the action is effective (whether
        player moves in the env).

        Args:
            action: The action to take.

        Returns:
            Tuple of (observation, reward, done, info).
        """
        if self.success():
            obs = self.render(mode="tiny_rgb_array")
            assert isinstance(obs, str)
            return obs, 1.0, True, {"action_is_effective": False}

        action_id: int = self.ACTION_LOOKUP.get(action.lower(), 0)

        if not action_id:
            action_id = self.INVALID_ACTION

        if (
            action_id == self.INVALID_ACTION
            or action_id not in self.action_map
        ):
            obs = self.render(mode="tiny_rgb_array")
            assert isinstance(obs, str)
            return obs, 0.0, False, {"action_is_effective": False}

        prev_player_position = int(self.s)

        # Call parent class step method
        # Note: GymFrozenLakeEnv is imported at module level
        player_pos, reward, done, _, _ = super().step(
            self.action_map[action_id],
        )

        obs = self.render(mode="tiny_rgb_array")
        assert isinstance(obs, str)
        return (
            obs,
            float(reward),
            bool(done),
            {"action_is_effective": prev_player_position != int(player_pos)},
        )

    def render(
        self,
        mode: str = "tiny_rgb_array",
    ) -> str | list[str] | np.ndarray:
        """Render the environment.

        Args:
            mode: Rendering mode. Options: "tiny_rgb_array", "list",
                "state", "rgb_array", "ansi".

        Returns:
            Rendered observation based on the mode.
        """
        assert mode in [
            "tiny_rgb_array",
            "list",
            "state",
            "rgb_array",
            "ansi",
        ]
        if mode in ["rgb_array", "ansi"]:
            prev_render_mode = self.render_mode
            self.render_mode = mode
            obs = super().render()
            self.render_mode = prev_render_mode
            return obs
        assert isinstance(self.desc, np.ndarray)
        room_state = copy.deepcopy(self.desc)

        # replace the position of start 'S' with 'F'
        position_S = np.where(room_state == b"S")
        room_state[position_S] = b"F"

        # replace the position of the player with 'P'
        position_P = self._get_player_position()
        room_state[position_P] = b"P"

        if mode == "state":
            # transform 'S', 'F', 'H', 'G' to numpy integer array
            room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state)
            # add player in hole or player on goal
            if self.desc[position_P] == b"H":
                room_state[position_P] = 4
            elif self.desc[position_P] == b"G":
                room_state[position_P] = 5
            return room_state

        room_state = self.render(mode="state").tolist()
        assert isinstance(room_state, list)

        if mode == "list":

            def lookup_list(cell: int) -> str:
                return self.GRID_LOOKUP.get(cell, "?").strip("\t").strip()

            return [
                " ".join(lookup_list(cell) for cell in row)
                for row in room_state
            ]

        if mode == "tiny_rgb_array":

            def lookup_tiny(cell: int) -> str:
                return self.GRID_LOOKUP.get(cell, "?")

            result = "\n".join(
                "".join(lookup_tiny(cell) for cell in row)
                for row in room_state
            )
            return result

        # Default return for other modes
        return ""

    def reset(
        self,
        task: Optional[Dict] = None,
    ) -> tuple[str, Dict]:
        """Reset the environment with optional task parameters."""
        task = task or {}
        # Update parameters from task if provided
        size = task.get("size", self.map_kwargs["size"])
        p = task.get("p", self.map_kwargs["p"])
        seed = task.get("seed", self.env_kwargs["seed"])
        is_slippery = task.get(
            "is_slippery",
            self.env_kwargs["is_slippery"],
        )
        desc = task.get("desc", self.env_kwargs.get("desc"))

        # Update instance variables
        self.size = size
        self.p = p
        self.seed = seed
        self.is_slippery = is_slippery
        self.map_kwargs["size"] = size
        self.map_kwargs["p"] = p
        self.env_kwargs["seed"] = seed
        self.env_kwargs["is_slippery"] = is_slippery
        if desc is not None:
            self.env_kwargs["desc"] = copy.deepcopy(desc)

        if desc is None:
            random_map, goal_position = generate_random_map(
                size=size,
                p=p,
                seed=seed,
                max_steps=self.max_steps,
            )
        else:
            random_map = np.asarray(copy.deepcopy(desc), dtype="c")
            goal_position = get_goal_position(random_map)

        self.goal_position = goal_position
        self.desc = random_map[:]

        # Reinitialize parent class with new map
        try:
            import gymnasium as gym

            super().__init__(
                desc=random_map[:],
                is_slippery=self.is_slippery,
            )
            assert isinstance(self.desc, np.ndarray)
            self.action_space = gym.spaces.Discrete(4, start=1)
        except ImportError as e:
            error_message = (
                "Gymnasium is not installed. "
                "Please install gymnasium first before "
                "running the frozen_lake workflow. "
                f"Error: {str(e)}"
            )
            raise ImportError(error_message) from e

        super().reset(seed=self.seed)
        obs = self.render(mode="tiny_rgb_array")
        assert isinstance(obs, str)
        return obs, {}

    def finished(self) -> bool:
        """Check if the episode is finished (goal or hole)."""
        player_pos = self._get_player_position()
        assert isinstance(self.desc, np.ndarray)
        return self.desc[player_pos] in b"GH"  # type: ignore

    def success(self) -> bool:
        """Check if the agent has reached the goal (G)."""
        player_pos = self._get_player_position()
        assert isinstance(self.desc, np.ndarray)
        return self.desc[player_pos] in b"G"
**tuner/frozen_lake/_utils.py** (new file, 209 lines)
````python
# -*- coding: utf-8 -*-
"""
Utils for the FrozenLake environment.
Modified from rllm
"""

from typing import Literal, Optional, Tuple
import numpy as np
from pydantic import BaseModel, Field

# Map gym board symbols to integer states
MAP_LOOKUP = {
    b"P": 0,
    b"F": 1,
    b"H": 2,
    b"G": 3,
}

# Define rules to transform to rendered text observation of the environment
GRID_LOOKUP = {
    0: " P \t",  # player
    1: " _ \t",  # frozen
    2: " O \t",  # hole
    3: " G \t",  # goal
    4: " X \t",  # player fell into hole
    5: " √ \t",  # player on goal
}

ACTION_LOOKUP = {
    0: "None",
    1: "Left",
    2: "Down",
    3: "Right",
    4: "Up",
}

# Prompting format inspired by the RAGEN project
SYSTEM_PROMPT = """You are Qwen, created by Alibaba Cloud. \
You are a helpful assistant. You are walking on a frozen lake.

FrozenLake Quick Guide
Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.

Symbols:
_ Frozen | O Hole | G Goal | P Player

Rules:
1. Avoid falling into holes (O).
2. Frozen tiles are slippery, you may move perpendicular to
your intended direction.

Valid Action (separated by | ):
Up | Down | Left | Right

Rewards:
Fall into hole: 0
Reach goal: +1.0

You will be provided the current observation, please decide on
the next Action.
You should show your thought process and then input the final
action in ``` ```.
You should only output the NEXT ACTION at each iteration in
the ``` ```. For example, if you want to move up, you should
output ```Up```.
You should plan ahead and need to achieve it in minimum number
of steps.
You should be aware that frozen tiles can be slippery, but the
chance is small and you should not overthink it.

Please show your thinking process and put the final action in
``` ```. In every turn, the final action MUST be one of Up,
Down, Left, Right.
"""


class FrozenLakeAction(BaseModel):
    """Action model for FrozenLake environment."""

    action: Literal["Up", "Down", "Left", "Right"] = Field(
        description=(
            "The action to take in the FrozenLake environment, "
            "must be one of Up, Down, Left, Right"
        ),
    )


def is_valid(board: list[list[str]], max_size: int, max_steps: int) -> bool:
    """DFS to check that there is a valid path.

    Args:
        board: The board representation as a list of lists.
        max_size: Maximum size of the board.
        max_steps: Maximum number of steps allowed.

    Returns:
        True if there's a valid path from start to goal within max_steps,
        False otherwise.
    """
    frontier, discovered = [], set()
    # find the start point
    start_r, start_c = np.where(np.array(board) == "S")
    frontier.append((start_r[0], start_c[0], 0))  # row, col, steps
    # dfs to check if there is a path from start to goal
    while frontier:
        r, c, steps = frontier.pop()
        if steps > max_steps:
            continue

        if (r, c) not in discovered:
            discovered.add((r, c))
            directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
            for x, y in directions:
                r_new = r + x
                c_new = c + y
                if (
                    r_new < 0
                    or r_new >= max_size
                    or c_new < 0
                    or c_new >= max_size
                ):  # noqa: PLR2004
                    continue
                if board[r_new][c_new] == "G":
                    return True
                if board[r_new][c_new] != "H":
                    frontier.append((r_new, c_new, steps + 1))
    return False


def generate_random_map(
    size: int = 8,
    p: float = 0.8,
    seed: int = 0,
    max_steps: int = 5,
) -> Tuple[list[str], Tuple[int, int]]:
    """Generates a random valid map (one that has a path from start to goal).

    Args:
        size: Size of each side of the grid.
        p: Probability that a tile is frozen.
        seed: Seed to ensure the generation of reproducible maps.
        max_steps: Maximum number of steps allowed.

    Returns:
        A tuple containing a random valid map and the goal position (row, col).
    """
    valid = False
    board: list[list[str]] = []  # initialize to make pyright happy

    try:
        from gymnasium.utils import seeding

        np_random, _ = seeding.np_random(seed)
    except ImportError as exc:
        raise ImportError(
            "Gymnasium is not installed. "
            "Please install gymnasium first before "
            "running the frozen_lake workflow.",
        ) from exc

    # generate random start and end points
    while not valid:
        p = min(1, p)
        board = np_random.choice(
            ["F", "H"],
            (size, size),
            p=[p, 1 - p],
        ).tolist()

        while True:
            start_r = int(np_random.integers(0, size))
            start_c = int(np_random.integers(0, size))
            goal_r = int(np_random.integers(0, size))
            goal_c = int(np_random.integers(0, size))

            # Ensure start and goal are different positions
            if (start_r, start_c) != (goal_r, goal_c):
                break

        board[start_r][start_c] = "S"
        board[goal_r][goal_c] = "G"

        valid = is_valid(board, size, max_steps)
    return ["".join(x) for x in board], (goal_r, goal_c)


def get_goal_position(
    random_map: np.ndarray,
) -> Optional[Tuple[int, int]]:
    """Get the goal position from a random map.

    Args:
        random_map: The map as a numpy array.

    Returns:
        Tuple of (row, col) if goal found, None otherwise.
    """
    positions = np.argwhere(random_map == b"G")
    if positions.size == 0:
        return None  # G not found
    return tuple(positions[0])  # returns (row, col)


__all__ = [
    "SYSTEM_PROMPT",
    "FrozenLakeAction",
    "generate_random_map",
    "get_goal_position",
]
````
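As a quick standalone illustration of what `is_valid` checks, here is the same DFS reachability idea re-implemented without numpy or gymnasium (a hypothetical helper, for illustration only; the board rows are hand-written strings):

```python
def reachable_within(board: list[str], max_steps: int) -> bool:
    """DFS from 'S': can 'G' be reached without stepping on a hole 'H'?"""
    size = len(board)
    # locate the start tile
    start = next(
        (r, c) for r in range(size) for c in range(size) if board[r][c] == "S"
    )
    frontier, seen = [(start[0], start[1], 0)], set()
    while frontier:
        r, c, steps = frontier.pop()
        if steps > max_steps or (r, c) in seen:
            continue
        seen.add((r, c))
        for dr, dc in ((1, 0), (0, 1), (-1, 0), (0, -1)):
            nr, nc = r + dr, c + dc
            if not (0 <= nr < size and 0 <= nc < size):
                continue
            if board[nr][nc] == "G":
                return True
            if board[nr][nc] != "H":
                frontier.append((nr, nc, steps + 1))
    return False


print(reachable_within(["SFF", "FHF", "FFG"], max_steps=4))  # True
print(reachable_within(["SFF", "FHF", "FFG"], max_steps=1))  # False
```

The goal in the 3x3 board is four moves away around the hole, so it is reachable with a step budget of 4 but not 1; this mirrors why `generate_random_map` keeps resampling until `is_valid` passes.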
**tuner/frozen_lake/config.yaml** (new file, 53 lines)
```yaml
project: "AgentScope"  # Project name
name: "FrozenLake"  # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}  # Directory to save model checkpoints
algorithm:
  algorithm_type: multi_step_grpo  # GRPO series for multi-step scenario
  repeat_times: 16  # Number of rollouts per prompt for advantage estimation
  kl_loss_fn: "low_var_kl"
  kl_loss_fn_args:
    kl_coef: 0  # KL divergence coefficient
  advantage_fn_args:
    epsilon: 1e-6  # Small value for numerical stability
    std_threshold: 0.0001  # Threshold for standard deviation
  optimizer:
    lr: 1e-6  # Learning rate
model:
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct}  # Base model path
  max_prompt_tokens: 23552  # Max tokens for prompt
  max_response_tokens: 2048  # Max tokens per response
  max_model_len: 25600  # Max context length
  temperature: 1.0  # Sampling temperature
buffer:
  total_epochs: 5  # Total training epochs
  batch_size: 32  # Batch size per explore step
  train_batch_size: 1024  # Total experiences per training step
  trainer_input:
    experience_buffer:
      name: experience_buffer
      storage_type: queue
      max_read_timeout: 7200  # Max timeout for reading from buffer (seconds)
    replay_buffer:
      enable: true  # Enable experience replay
      priority_fn: linear_decay  # Priority function for replay buffer
      priority_fn_args:
        decay: 0.1  # Decay rate for priority function
explorer:
  runner_per_model: 16  # Number of runners per model
  rollout_model:
    engine_num: 6  # Number of vLLM engines for rollout model
    tensor_parallel_size: 1  # TP size per engine for rollout model
    enable_openai_api: true  # Enable OpenAI-compatible API
    enable_history: true  # Enable conversation history
    enable_auto_tool_choice: true  # Enable automatic tool selection
    tool_call_parser: hermes  # Parser for tool calls
trainer:
  save_interval: 100  # Save checkpoint every N steps
  use_dynamic_bsz: true  # Use dynamic batch size
  grad_clip: 1.0  # Gradient clipping value
  max_token_len_per_gpu: 25600  # Max token length per GPU
  ulysses_sequence_parallel_size: 2  # Sequence parallel size for Ulysses
synchronizer:
  sync_style: dynamic_by_explorer  # Sync triggered dynamically by explorer
  sync_interval: 1  # Sync every N steps
  sync_timeout: 1200  # Timeout for synchronization (seconds)
```
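The `${oc.env:VAR,default}` entries above are OmegaConf environment resolvers: they read an environment variable and fall back to the value after the comma when the variable is unset. A minimal plain-Python sketch of that lookup behavior (not OmegaConf itself):

```python
import os


def oc_env(var: str, default: str) -> str:
    """Mimic OmegaConf's ${oc.env:VAR,default} lookup."""
    return os.environ.get(var, default)


os.environ.pop("TRINITY_MODEL_PATH", None)
print(oc_env("TRINITY_MODEL_PATH", "Qwen/Qwen2.5-3B-Instruct"))  # default wins

os.environ["TRINITY_MODEL_PATH"] = "/models/my-finetune"
print(oc_env("TRINITY_MODEL_PATH", "Qwen/Qwen2.5-3B-Instruct"))  # env var wins
```

In practice this means you can redirect `checkpoint_root_dir` and `model_path` by exporting `TRINITY_CHECKPOINT_ROOT_DIR` and `TRINITY_MODEL_PATH` before launching, without editing the YAML.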
**tuner/frozen_lake/critic_rewards_mean.png** (new binary file, 62 KiB; binary file not shown)
**tuner/frozen_lake/get_frozenlake_data.py** (new file, 131 lines)
```python
# -*- coding: utf-8 -*-
"""
Modified from rllm
"""
import argparse
import os

import numpy as np
import pandas as pd


DEFAULT_DATA_PATH = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "..",
    "data",
    "frozenlake",
)


def save_dataset_to_local(
    data_path: str,
    data: list[dict],
    split: str = "default",
) -> str:
    """Save dataset directly to local data_path.

    Args:
        data_path: Path to save the dataset
        data: List of dictionaries containing the dataset examples
        split: Split name (e.g., 'train', 'test', 'default')

    Returns:
        str: Path to the saved parquet file
    """
    os.makedirs(data_path, exist_ok=True)

    # Convert to DataFrame and save
    data_df = pd.DataFrame(data)
    dataset_path = os.path.join(data_path, f"{split}.parquet")
    data_df.to_parquet(dataset_path)

    print(
        f"Saved dataset frozenlake split '{split}' "
        f"with {len(data)} examples at {dataset_path}. "
        f"Make sure to set the environment variable "
        f"<TRINITY_TASKSET_PATH> to {data_path}.",
    )

    return dataset_path


def prepare_frozenlake_data(
    data_path: str,
    train_size: int = 10000,
    test_size: int = 100,
    map_max_size: int = 6,
) -> tuple[list[dict], list[dict]]:
    """
    Prepare and save FrozenLake datasets for training and testing.

    Args:
        data_path (str): Path to save the dataset
        train_size (int): Number of training examples to generate
        test_size (int): Number of test examples to generate
        map_max_size (int): Maximum size of the map

    Returns:
        tuple: (train_data, test_data) - Lists of data dictionaries
    """
    # Set random seed for reproducibility
    np.random.seed(42)

    # Generate random parameters for train and test sets
    train_seeds = np.random.randint(0, 100000, size=train_size)
    test_seeds = np.random.randint(0, 100000, size=test_size)
    train_sizes = np.random.randint(2, map_max_size, size=train_size)
    test_sizes = np.random.randint(2, map_max_size, size=test_size)
    train_ps = np.random.uniform(0.6, 0.85, size=train_size)
    test_ps = np.random.uniform(0.6, 0.85, size=test_size)

    def frozenlake_process_fn(
        seed: int,
        size: int,
        p: float,
        idx: int,
    ) -> dict:
        """Process function to create FrozenLake task instances."""
        return {
            "seed": seed,
            "size": size,
            "p": p,
            "index": idx,
            "uid": f"{seed}_{size}_{p}",
        }

    # Create train and test data
    train_data_list = [
        frozenlake_process_fn(seed, train_sizes[idx], train_ps[idx], idx)
        for idx, seed in enumerate(train_seeds)
    ]
    test_data_list = [
        frozenlake_process_fn(seed, test_sizes[idx], test_ps[idx], idx)
        for idx, seed in enumerate(test_seeds)
    ]

    # Save datasets directly to local DATA_PATH
    save_dataset_to_local(data_path, train_data_list, "train")
    save_dataset_to_local(data_path, test_data_list, "test")

    return train_data_list, test_data_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
    parser.add_argument("--train_size", type=int, default=10000)
    parser.add_argument("--test_size", type=int, default=100)
    parser.add_argument("--map_max_size", type=int, default=6)
    args = parser.parse_args()

    train_data, test_data = prepare_frozenlake_data(
        data_path=args.local_dir,
        train_size=args.train_size,
        test_size=args.test_size,
        map_max_size=args.map_max_size,
    )

    print(f"Train dataset: {len(train_data)} examples")
    print(f"Test dataset: {len(test_data)} examples")
    print("Sample train example:", train_data[0])
    print("Sample test example:", test_data[0])
```
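Each row emitted by `frozenlake_process_fn` above is a plain dict whose `uid` concatenates the generation parameters, which makes duplicate tasks easy to spot. A standalone sketch of the same scheme, using the stdlib `random` module instead of numpy purely for illustration:

```python
import random

random.seed(42)
rows = []
for i in range(3):
    seed = random.randrange(100000)
    size = random.randrange(2, 6)
    p = round(random.uniform(0.6, 0.85), 3)
    rows.append(
        {"seed": seed, "size": size, "p": p, "index": i,
         "uid": f"{seed}_{size}_{p}"}  # uid = seed_size_p
    )

print(len(rows))                                    # 3
print(all(r["uid"].count("_") == 2 for r in rows))  # True
```

Because `uid` only depends on `(seed, size, p)`, two tasks with identical parameters collide on `uid` even if their `index` values differ.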
**tuner/frozen_lake/main.py** (new file, 151 lines)
```python
# -*- coding: utf-8 -*-
"""Example of training a FrozenLake agent with Trinity-RFT."""
import os
from typing import Dict
from _frozenlake_agent import FrozenLakeAgent
from _frozenlake_env import FrozenLakeEnv
from agentscope.message import Msg
from agentscope.tuner import (
    tune,
    WorkflowOutput,
    DatasetConfig,
    TunerModelConfig,
    AlgorithmConfig,
)
from agentscope.model import ChatModelBase


async def run_frozen_lake(
    task: Dict,
    model: ChatModelBase,
    auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
    """A workflow function using the FrozenLake agent to solve tasks.

    Args:
        task (Dict): The task to be solved, containing environment parameters
            like size, p, seed, is_slippery, etc.
        model (ChatModelBase): The language model to use.
        auxiliary_models (Dict[str, ChatModelBase]): Auxiliary models; this
            workflow expects none.

    Returns:
        WorkflowOutput: The workflow output containing the reward, response and
            metrics.
    """

    assert len(auxiliary_models) == 0, "No auxiliary models are needed"

    # Extract workflow arguments from task or use defaults
    workflow_args = task.get("workflow_args", {})
    if not workflow_args:
        workflow_args = task

    env_max_steps = workflow_args.get("env_max_steps", 8)
    agent_max_steps = workflow_args.get("agent_max_steps", 10)
    is_slippery = workflow_args.get("is_slippery", False)
    desc = workflow_args.get("desc", None)

    # Extract task-specific arguments (for environment generation)
    size = task.get("size", 8)
    p = task.get("p", 0.8)
    seed = task.get("seed", 42)

    # Initialize agent and environment
    agent = FrozenLakeAgent(model=model, max_steps=agent_max_steps)
    env = FrozenLakeEnv(
        max_steps=env_max_steps,
        desc=desc,
        is_slippery=is_slippery,
        size=size,
        p=p,
        seed=seed,
    )

    # Reset environment with task parameters
    observation, _ = env.reset(task)
    observation_str = str(observation)
    rewards = []
    step_count = 0
    done = False
    terminate_reason = None

    # Run agent-environment interaction loop
    for _ in range(agent_max_steps):
        step_count += 1
        try:
            # get prompt
            prompt = agent.get_prompt(observation_str)

            response = await agent.reply(msg=Msg("user", prompt, role="user"))

            # record action and observation
            action = agent.get_action(response)
            agent.update_state(action=action, observation=observation_str)

        except Exception as e:
            terminate_reason = f"agent_error: {str(e)}"
            break

        # environment step
        observation, reward, done, _ = env.step(action)
        observation_str = str(observation)
        rewards.append(reward)

        if done:
            terminate_reason = "success" if env.success() else "hole"
            break

    if terminate_reason is None:
        terminate_reason = "max_steps_reached"

    final_reward = sum(rewards)
    final_observation = observation_str

    # Create response message with environment information
    response_content = (
        f"Final observation:\n{final_observation}\n"
        f"Total reward: {final_reward}\n"
        f"Steps taken: {step_count}\n"
        f"Terminate reason: {terminate_reason}"
    )

    final_response = Msg("assistant", response_content, role="assistant")

    return WorkflowOutput(
        reward=final_reward,
        response=final_response,
        metrics={
            "env_steps": float(step_count),
            "env_done": float(done),
        },
    )


if __name__ == "__main__":
    dataset = DatasetConfig(
        path="/path/to/frozenlake",
        split="train",
    )
    tuner_model = TunerModelConfig(
        model_path="Qwen/Qwen2.5-3B-Instruct",
        max_model_len=25600,
        max_tokens=2048,
        inference_engine_num=6,
        reasoning_parser=None,
    )
    algorithm = AlgorithmConfig(
        algorithm_type="multi_step_grpo",
        group_size=16,
        batch_size=32,
        learning_rate=1e-6,
    )
    config_path = os.path.join(
        os.path.dirname(__file__),
        "config.yaml",
    )  # define some default parameters
    tune(
        workflow_func=run_frozen_lake,
        model=tuner_model,
        train_dataset=dataset,
        algorithm=algorithm,
        config_path=config_path,
    )
```
**tuner/learn_to_ask/README.md** (new file, 393 lines)
# Training Learn2Ask with AgentScope-Tuner

This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).

---

## Task Setting

In this example, given a user's chief complaint, the medical assistant agent proactively asks targeted questions to gather sufficient symptom information for a comprehensive assessment of the user's health condition. The querying process should be efficient: the agent must optimize question quality and terminate the interview as soon as the collected information is adequate for subsequent clinical assessment or decision-making.

Here we use the `ReActAgent` for this task; no tools are required.

---

## Hardware Requirements

- **Training with GPUs**: At least **8 H20 GPUs** (or equivalent).
- **Training without GPUs**: You can use the **[Tinker](https://thinkingmachines.ai/tinker/)** backend without any GPUs.

> 💡 All code and configuration files are located in:
> `tuner/learn_to_ask/`

Key files:
- Workflow & Training: `tuner/learn_to_ask/main.py`
- Prompts: `tuner/learn_to_ask/prompt.py`
- Training config: `tuner/learn_to_ask/config.yaml`
- Data preparation scripts: `tuner/learn_to_ask/data_prepare/`

---

## Dataset Preparation

> [!NOTE]
> In this example, we use an open-source dataset directly for training. In practice, however, you would typically start by collecting interaction logs between your deployed agent and users. After filtering these raw logs to curate a high-quality dataset, you can follow the same pipeline to enhance your agent’s proactive capabilities using AgentScope-Tuner. Happy tuning!

### 1.1 Download the Dataset
Download the **[RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv)** dataset (in `.jsonl` format).
You can use the following Python script to download the dataset:

```python
from huggingface_hub import snapshot_download

# Download to local directory, e.g., `./tuner/learn_to_ask/data`
local_dir = "./tuner/learn_to_ask/data"
snapshot_download(
    repo_id="datajuicer/RealMedConv",
    repo_type="dataset",
    local_dir=local_dir,
)
```

Each line in `train_origin.jsonl` (or `test_origin.jsonl`) represents a complete doctor-patient conversation log, like this:

```json
{
  "session_id": 35310,
  "diagn": "Upper Respiratory Tract Infection",
  "messages": [
    {"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"},
    {"role": "user", "content": "I took Amoxicillin"},
    ...
    {"role": "assistant", "content": "<med_search>"}
  ]
}
```

### 1.2 Preprocess the Data
You need to convert raw conversation logs into training samples. This involves two steps:

#### 🔹 Step A: Segment Conversations & Extract Labels
Split each conversation into **context–future pairs**, and extract ground-truth symptom information (`info_truth`) from what happens next.

```bash
python tuner/learn_to_ask/data_prepare/1_info_extract_pipeline.py \
    --input_file /path/to/RealMedConv/train.jsonl \
    --output_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
    --model_path Qwen/Qwen2.5-32B-Instruct
```

#### 🔹 Step B: Build Final Training Dataset
Convert the processed samples into the final format used for training/testing.

```bash
python tuner/learn_to_ask/data_prepare/2_build_dataset.py \
    --input_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
    --output_file tuner/learn_to_ask/data/train.jsonl
```

---

### How It Works: Context–Future Segmentation

For every turn in a conversation, we create a sample with:
- `messages`: The **observed dialogue history** up to that point (the *context*).
- `remaining_chat`: Everything that happens **after** that point (the *future*).
- A unique ID: `cid = {session_id}_{turn_index}`

Example output:
```json
{
  "cid": "35310_7",
  "session_id": "35310",
  "diagn": "Upper Respiratory Tract Infection",
  "messages": [ ... up to turn 7 ... ],
  "remaining_chat": [ ... all future messages ... ]
}
```
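The context–future split described above can be sketched in a few lines (a hypothetical helper for illustration; the actual logic lives in `1_info_extract_pipeline.py`):

```python
def segment_conversation(session: dict) -> list[dict]:
    """Split one conversation log into context-future training samples."""
    msgs = session["messages"]
    return [
        {
            "cid": f"{session['session_id']}_{turn}",
            "session_id": str(session["session_id"]),
            "diagn": session["diagn"],
            "messages": msgs[:turn],        # observed context
            "remaining_chat": msgs[turn:],  # unobserved future
        }
        for turn in range(1, len(msgs))
    ]


session = {
    "session_id": 35310,
    "diagn": "Upper Respiratory Tract Infection",
    "messages": [
        {"role": "user", "content": "Sore throat, phlegm, red eyes"},
        {"role": "assistant", "content": "How long has this lasted?"},
        {"role": "user", "content": "About three days"},
    ],
}
samples = segment_conversation(session)
print(samples[1]["cid"])                  # 35310_2
print(len(samples[1]["remaining_chat"]))  # 1
```

A conversation with N messages therefore yields N-1 samples, one per possible cut point.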
### Extract Ground-Truth Labels

From `remaining_chat`, we automatically derive two key labels:
- `decision_truth`: Should the assistant **continue asking questions** (`"continue"`) or **stop** (`"stop"`)?
- `info_truth`: Structured list of symptoms mentioned later (used to compute reward signals during training).

Example:
```json
{
  "decision_truth": "continue",
  "info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, ..."
}
```

These labels power the reward functions $R_a$ (action accuracy) and $R_s$ (symptom coverage) during training.

---

## Code Implementation

### Agent Workflow

The workflow function `run_react_agent` defines how the `ReActAgent` handles each task.

```python
async def run_react_agent(
    task: Dict,
    model: OpenAIChatModel,
    auxiliary_models: Dict[str, OpenAIChatModel],
) -> WorkflowOutput:
    assert (
        len(auxiliary_models) == 1
    ), "Please provide only one `auxiliary_models` for `learn_to_ask`."

    import importlib.util

    spec = importlib.util.spec_from_file_location(
        "prompt",
        os.path.join(os.path.dirname(__file__), "prompt.py"),
    )
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if TRAIN_MODE == "Ra":
        sys_prompt = module.rollout_prompt_med_Ra
    else:
        sys_prompt = module.rollout_prompt_med

    agent = ReActAgent(
        name="react_agent",
        sys_prompt=sys_prompt,
        model=model,
        formatter=OpenAIChatFormatter(),
        toolkit=None,
        memory=InMemoryMemory(),
        max_iters=1,
    )
    messages = format_messages(task["messages"])
    response = await agent.reply(
        [
            Msg(name=x["role"], content=x["content"], role=x["role"])
            for x in messages
        ],
    )
    return WorkflowOutput(
        response=response,
    )
```

### Judge Function

The judge function `learn2ask_judge` implements reward calculation using LLM-as-a-Judge:

```python
async def learn2ask_judge(
    task: Dict,
    response: Msg,
    auxiliary_models: Dict[str, OpenAIChatModel],
) -> JudgeOutput:
    response_text = response.get_text_content()
    action_truth = task.get("decision_truth", "continue")
    action_response = "stop" if "<stop />" in response_text else "continue"

    # Calculate action accuracy score
    action_score = 1.0 if action_truth == action_response else 0.0

    # Calculate format and content scores
    if action_score == 1.0 and action_truth == "continue":
        # Use LLM-as-a-Judge to evaluate question quality
        score_dict = await llm_reward(task, response_text, auxiliary_models)
        format_score = float(score_dict.get("format_score", 0.0))
        content_score = float(score_dict.get("content_score", 0.0))
    elif action_score == 1.0:  # stop action
        content_score = 1.0
        format_score = 1.0 if response_text == "<stop />" else 0.0
    else:
        format_score = content_score = 0.0

    # Combine final reward based on training mode
    if TRAIN_MODE == "Ra+Rs":  # Default: action + symptom rewards
        final_reward = action_score * (1 + 2 * content_score) + format_score
    elif TRAIN_MODE == "Ra":  # Action reward only
        final_reward = 2 * content_score + format_score
    else:  # Symptom reward only
        final_reward = action_score * 3 + format_score

    return JudgeOutput(reward=final_reward, metrics={"reward": final_reward})
```

This reward function considers:
- Action accuracy: `action_score`
- Question quality (symptom coverage): `content_score`
- Format score: `format_score`

See [main.py](./main.py) for implementation details.
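To make the reward combination concrete, here is the default `Ra+Rs` fusion formula evaluated on a couple of hand-picked score values (the numbers are illustrative only, not taken from a real training run):

```python
def combined_reward(action_score: float, content_score: float,
                    format_score: float) -> float:
    """Default "Ra+Rs" fusion: action gate times (1 + 2 * content), plus format."""
    return action_score * (1 + 2 * content_score) + format_score


# Correct "continue" decision; judge gives content 0.5 and format 1.0
print(combined_reward(1.0, 0.5, 1.0))  # 3.0

# Wrong decision: the action gate zeroes the content term entirely
print(combined_reward(0.0, 0.9, 0.0))  # 0.0
```

Note how `action_score` acts as a gate: even a high `content_score` contributes nothing when the continue/stop decision is wrong.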
---

## Configure and Train the Model

### Option A: Edit Python Script (Simple)
Open `tuner/learn_to_ask/main.py` and adjust settings:

```python
if __name__ == "__main__":
    train_mode = "Ra+Rs"  # Use both action and symptom rewards
    fusion_mode = "default"  # How to combine rewards
    dataset = DatasetConfig(path="tuner/learn_to_ask/data", split="train")

    tuner_model = OpenAIChatModel(
        model_path="Qwen/Qwen2.5-7B-Instruct",
        max_model_len=8192,
        tensor_parallel_size=1,  # Adjust based on your GPU setup
        ...
    )

    auxiliary_models = {
        AUXILIARY_MODEL_NAME: OpenAIChatModel(
            model_path="Qwen/Qwen2.5-32B-Instruct",  # Larger model for evaluation
            tensor_parallel_size=2,
            ...
        )
    }

    algorithm = AlgorithmConfig(
        algorithm_type="grpo",
        learning_rate=5e-7,
        batch_size=64,
    )

    tune(...)  # Starts training
```

### Option B: Use YAML Config (Advanced)
Edit `tuner/learn_to_ask/config.yaml` for more control.

#### 🌐 No GPU? Use Tinker!
If you don’t have GPUs, enable the **Tinker backend** by setting:

```yaml
model:
  tinker:
    enable: true  # ← Set this to true
```

Also, make sure to update the `model_path` in `tuner/learn_to_ask/main.py` to point to a model that’s compatible with Tinker.

> 🔗 Learn more about the Tinker backend: [Tinker Backend Documentation](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/example_tinker_backend.html)

> In this example, training is configured for 4 epochs. When using Tinker, the total token consumption is approximately 112 million tokens, at an estimated cost of about 18 USD.

### Launch Training
```bash
python tuner/learn_to_ask/main.py
```

---

## Evaluation

Use the **rollout-and-evaluate pipeline**:
1. Generate responses on the test set.
2. Score them using a powerful evaluator model (`Qwen2.5-32B-Instruct`).
|
||||||
|
|
||||||
|
Run evaluation:
|
||||||
|
```bash
|
||||||
|
python tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py \
|
||||||
|
--eval_model_path path/to/your/trained/model \
|
||||||
|
--grader_model_path Qwen/Qwen2.5-32B-Instruct \
|
||||||
|
--test_file_path tuner/learn_to_ask/data/test.jsonl \
|
||||||
|
--rollout_file_path path/to/rollout.jsonl \
|
||||||
|
--eval_file_path path/to/output.jsonl
|
||||||
|
```
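Once the file at `--eval_file_path` is written, per-sample grades can be averaged into summary metrics. Below is a minimal sketch, assuming each output line is a JSON record whose `grades` list carries `action_score`, `format_score`, and `content_score` fields (as in the concrete example later in this README); the actual field layout of `3_rollout_then_evaluate.py` may differ:

```python
import json


def summarize(eval_lines):
    """Average judge scores across all graded rollouts in an eval .jsonl."""
    totals = {"action_score": 0.0, "format_score": 0.0, "content_score": 0.0}
    n = 0
    for line in eval_lines:
        record = json.loads(line)
        for grade in record.get("grades", []):
            for key in totals:
                totals[key] += float(grade.get(key, 0.0))
            n += 1
    return {key: val / n for key, val in totals.items()} if n else totals


sample = json.dumps({"grades": [
    {"action_score": 1, "format_score": 1.0, "content_score": 0.5},
    {"action_score": 0, "format_score": 1.0, "content_score": 0.0},
]})
print(summarize([sample]))  # {'action_score': 0.5, 'format_score': 1.0, 'content_score': 0.25}
```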
> ⚠️ **Note**: Your trained model must be converted to **Hugging Face format** first.
> See: [Converting FSDP Checkpoints Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/faq.html)

---

## Experimental Results

We compared three approaches:

- **Base model**: `Qwen2.5-7B-Instruct` (no fine-tuning)
- **Trinity**: Direct response generation
- **AgentScope-Tuner (Learn2Ask)**: Uses a ReAct agent for proactive questioning

| Metric | Base Model | Trinity | AgentScope-Tuner (Learn2Ask) |
|--------------------------------------|-----------:|--------:|--------------------:|
| Avg. continue content | 0.436 | 0.496 | 0.509 |
| Win rate (continue content) | 0.122 | 0.246 | 0.224 |
| Avg. continue decision accuracy | 0.963 | 0.909 | 0.922 |
| Avg. stop decision accuracy | 0.024 | 0.927 | 0.909 |
| **Total decision accuracy** | 0.467 | 0.918 | 0.915 |
| **Response format quality** | 0.376 | 0.713 | 0.882 |
| **Total reward** | 1.281 | 3.078 | 3.237 |

<div align="center">
<img src="./learn2ask.png" alt="Training Curves" width="90%"/>
</div>

> ✅ **Key insight**: Learn2Ask (AgentScope-Tuner) achieves the highest overall performance by teaching the model **when and what to ask**, making it truly proactive.

### Concrete Example

An example of the agent's behavior on evaluation data is shown below:

```json
{
  "cid": "206_1",
  "session_id": 206,
  "diagn": "Common cold",
  "messages": [
    {
      "role": "assistant",
      "content": "Hello, currently there are many people consulting, please wait patiently."
    },
    {
      "role": "user",
      "content": "Hello, this morning around 10 o'clock, I suddenly felt a bit feverish and had a heavy head. I took my temperature and it was 38.5°C. I took one tablet of paracetamol at 1 o'clock, which helped a lot. By 6 or 7 in the evening, I felt cold and feverish again, so I took another tablet of paracetamol. I'm not sure if it's a side effect of the medication, but I feel a bit nauseous and my stomach keeps producing gas. Now, I feel cold again, have a stuffy nose with clear runny nose, an itchy throat with coughing, a heavy head, and can't sleep."
    }
  ],
  "decision_truth": "continue",
  "info_truth": "symptom: needle-like sensation, symptom location: ear, symptom nature: sharp pain, symptom severity: occurred twice",
  "rollouts": [
    "Do you also experience body aches along with your symptoms? \nA) Yes B) No C) Not Sure",
    "Do you experience these symptoms every time after taking paracetamol? (A) Yes, definitely. (B) Sometimes. (C) No, never before.",
    "Do you have a history of allergies? A) Yes B) No C) Not Sure"
  ],
  "grades": [
    {
      "think": "The doctor's last message asks about body aches using multiple choice options rather than a single direct open-ended question, which deviates from the typical questioning style indicated by the grading policy. However, since it is framed in the form of \"Do you...\", it counts as one question despite being presented with choices. Thus, it satisfies the criteria for 'contains exactly one question'. For the content score, the query does not pertain directly to the details listed in the reference information such as the specific symptom ('needle-like' sharp pain in the ear that has occurred twice), thereby making it highly relevant yet indirectly related due to its focus on comprehensive assessment of flu-related discomforts without addressing specifics noted in the patient's primary concern documented earlier.",
      "action_score": 1,
      "format_score": 1.0,
      "content_score": 0.5
    },
    {
      "think": "The doctor's last message includes just one multiple-choice question regarding whether the patient experiences those mentioned symptoms each time they take paracetamol. This does relate highly to understanding possible drug-related symptoms; however, none of them aligns perfectly with \"needle-like\" sensations occurring specifically in ears according to the reference information given.",
      "action_score": 1,
      "format_score": 1.0,
      "content_score": 0.5
    },
    {
      "think": "The doctor's last statement does contain just one question pertaining to allergy history, which is highly relevant when trying to diagnose symptoms such as those described by the patient (fever, nausea). However, none of these concerns specifically relate back to the reference information detailing \"needle-like sensation\", \"sharp pain\" related to the ears occurring twice. Therefore, while highly pertinent medically, they do not pertain to the exact points outlined in the Ref Info section about the patient experience according to that specific prompt context.",
      "action_score": 1,
      "format_score": 1.0,
      "content_score": 0.5
    }
  ]
}
```

---

## 📚 Citation

If you use this code or framework, please cite our work:

```bibtex
@misc{learn2ask,
  title={Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs},
  author={Fei Wei and Daoyuan Chen and Ce Wang and Yilun Huang and Yushuo Chen and Xuchen Pan and Yaliang Li and Bolin Ding},
  year={2025},
  eprint={2510.25441},
  archivePrefix={arXiv},
  primaryClass={cs.AI},
  url={https://arxiv.org/abs/2510.25441}
}
```
391 tuner/learn_to_ask/README_zh.md Normal file
@@ -0,0 +1,391 @@
# Train Learn2Ask with AgentScope-Tuner

This guide demonstrates how to train a proactive LLM using the **Learn2Ask** method from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).

---

## Task Setup

In this example, given a user's chief complaint, a medical-assistant agent proactively asks targeted questions to gather enough symptom information for a comprehensive assessment of the user's health. The querying process should be efficient: the agent must optimize question quality and terminate the interview as soon as the collected information suffices for downstream clinical evaluation or decision-making.

Here we use a `ReActAgent` for this task; no tools are required.

---

## Hardware Requirements

- **Training with GPUs**: at least **8 H20 GPUs** (or equivalent) are required.
- **Training without GPUs**: you can use **[Tinker](https://thinkingmachines.ai/tinker/)**, which needs no GPUs at all.

> 💡 All code and configuration files live in:
> `tuner/learn_to_ask/`

Key files:
- Workflow and training: `tuner/learn_to_ask/main.py`
- Prompts: `tuner/learn_to_ask/prompt.py`
- Training config: `tuner/learn_to_ask/config.yaml`
- Data preparation scripts: `tuner/learn_to_ask/data_prepare/`

---

## Dataset Preparation

> [!NOTE]
> In this example we train directly on an open-source dataset. In practice, however, you would typically first collect interaction logs between your deployed agent and its users. After filtering these raw logs into a curated, high-quality dataset, you can follow the same pipeline to strengthen your agent's proactive abilities with AgentScope-Tuner. Happy tuning!

### 1.1 Download the Dataset

Download the **[RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv)** dataset (`.jsonl` format).
You can download it with the following Python script:

```python
from huggingface_hub import snapshot_download

# Download to a local directory, e.g. `./tuner/learn_to_ask/data`
local_dir = "./tuner/learn_to_ask/data"
snapshot_download(
    repo_id="datajuicer/RealMedConv",
    repo_type="dataset",
    local_dir=local_dir,
)
```

Each line of `train_origin.jsonl` (or `test_origin.jsonl`) represents one complete doctor–patient conversation log, like this:

```json
{
  "session_id": 35310,
  "diagn": "Upper Respiratory Tract Infection",
  "messages": [
    {"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"},
    {"role": "user", "content": "I took Amoxicillin"},
    ...
    {"role": "assistant", "content": "<med_search>"}
  ]
}
```

### 1.2 Preprocess the Data

You need to convert the raw conversation logs into training samples. This involves two steps:

#### 🔹 Step A: Split Conversations and Extract Labels

Split each conversation into **context–future pairs**, and extract the ground-truth symptom information (`info_truth`) from the future portion.

```bash
python tuner/learn_to_ask/data_prepare/1_info_extract_pipeline.py \
    --input_file /path/to/RealMedConv/train.jsonl \
    --output_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
    --model_path Qwen/Qwen2.5-32B-Instruct
```

#### 🔹 Step B: Build the Final Training Dataset

Convert the processed samples into the final format used for training/testing.

```bash
python tuner/learn_to_ask/data_prepare/2_build_dataset.py \
    --input_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
    --output_file tuner/learn_to_ask/data/train.jsonl
```

---

### How It Works: Context–Future Pair Segmentation

For each turn in a conversation, we create one sample containing:
- `messages`: the **observed conversation history** up to that point (the context).
- `remaining_chat`: **everything that happens after** that point (the future).
- A unique ID: `cid = {session_id}_{turn_index}`

Example output:
```json
{
  "cid": "35310_7",
  "session_id": "35310",
  "diagn": "Upper Respiratory Tract Infection",
  "messages": [ ... up to turn 7 ... ],
  "remaining_chat": [ ... all future messages ... ]
}
```
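The segmentation described above can be sketched as a plain Python loop. This is a simplified stand-in for the repository's `split_session_to_json_lines` helper; the real script's field handling and turn boundaries may differ:

```python
def split_session(session: dict) -> list:
    """Create one context–future sample per turn: cid = {session_id}_{turn_index}."""
    samples = []
    messages = session["messages"]
    for turn_index in range(1, len(messages)):
        samples.append({
            "cid": f"{session['session_id']}_{turn_index}",
            "session_id": session["session_id"],
            "diagn": session["diagn"],
            "messages": messages[:turn_index],        # observed context
            "remaining_chat": messages[turn_index:],  # unobserved future
        })
    return samples


session = {"session_id": 35310, "diagn": "Upper Respiratory Tract Infection",
           "messages": [{"role": "user", "content": "Sore throat"},
                        {"role": "assistant", "content": "How long has it lasted?"},
                        {"role": "user", "content": "Two days"}]}
print([s["cid"] for s in split_session(session)])  # ['35310_1', '35310_2']
```

Each sample thus pairs what the agent could have seen at that turn with what actually happened afterwards, which is what makes label extraction from the future possible.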
|
||||||
|
|
||||||
|
### 提取真实标签
|
||||||
|
|
||||||
|
从 `remaining_chat` 中,我们自动推导出两个关键标签:
|
||||||
|
- `decision_truth`:助手应该继续提问(`"continue"`)还是停止(`"stop"`)?
|
||||||
|
- `info_truth`:后续提到的结构化症状列表(用于在训练期间计算奖励信号)。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"decision_truth": "continue",
|
||||||
|
"info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, ..."
|
||||||
|
}
|
||||||
|
```
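The `decision_truth` rule is deliberately simple: the assistant should stop only when the patient never speaks again in the future portion. A sketch matching the substring check used in `2_build_dataset.py` (which treats `remaining_chat` as a flattened `role: content` string):

```python
def label_decision(remaining_chat: str) -> str:
    """'stop' if the user never speaks again in the future portion, else 'continue'."""
    return "stop" if "user: " not in remaining_chat else "continue"


print(label_decision("assistant: <med_search>"))              # stop
print(label_decision("user: I also have a sore throat ..."))  # continue
```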
|
||||||
|
|
||||||
|
这些标签在训练期间为奖励函数 $R_a$(动作准确性)和 $R_s$(症状覆盖率)提供支持。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 代码实现
|
||||||
|
|
||||||
|
### 智能体工作流
|
||||||
|
|
||||||
|
工作流函数 `run_react_agent` 实现了 `ReActAgent` 的工作方式。
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def run_react_agent(
|
||||||
|
task: Dict,
|
||||||
|
model: OpenAIChatModel,
|
||||||
|
auxiliary_models: Dict[str, OpenAIChatModel],
|
||||||
|
) -> WorkflowOutput:
|
||||||
|
assert (
|
||||||
|
len(auxiliary_models) == 1
|
||||||
|
), "Please provide only one `auxiliary_models` for `learn_to_ask`."
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location(
|
||||||
|
"prompt",
|
||||||
|
os.path.join(os.path.dirname(__file__), "prompt.py"),
|
||||||
|
)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
if TRAIN_MODE == "Ra":
|
||||||
|
sys_prompt = module.rollout_prompt_med_Ra
|
||||||
|
else:
|
||||||
|
sys_prompt = module.rollout_prompt_med
|
||||||
|
|
||||||
|
agent = ReActAgent(
|
||||||
|
name="react_agent",
|
||||||
|
sys_prompt=sys_prompt,
|
||||||
|
model=model,
|
||||||
|
formatter=OpenAIChatFormatter(),
|
||||||
|
toolkit=None,
|
||||||
|
memory=InMemoryMemory(),
|
||||||
|
max_iters=1,
|
||||||
|
)
|
||||||
|
messages = format_messages(task["messages"])
|
||||||
|
response = await agent.reply(
|
||||||
|
[
|
||||||
|
Msg(name=x["role"], content=x["content"], role=x["role"])
|
||||||
|
for x in messages
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return WorkflowOutput(
|
||||||
|
response=response,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 评判函数
|
||||||
|
|
||||||
|
评判函数 `learn2ask_judge` 使用 LLM-as-a-Judge 实现奖励计算:
|
||||||
|
|
||||||
|
```python
|
||||||
|
async def learn2ask_judge(
|
||||||
|
task: Dict,
|
||||||
|
response: Msg,
|
||||||
|
auxiliary_models: Dict[str, OpenAIChatModel],
|
||||||
|
) -> JudgeOutput:
|
||||||
|
response_text = response.get_text_content()
|
||||||
|
action_truth = task.get("decision_truth", "continue")
|
||||||
|
action_response = "stop" if "<stop />" in response_text else "continue"
|
||||||
|
|
||||||
|
# 计算动作准确性分数
|
||||||
|
action_score = 1.0 if action_truth == action_response else 0.0
|
||||||
|
|
||||||
|
# 计算格式和内容分数
|
||||||
|
if action_score == 1.0 and action_truth == "continue":
|
||||||
|
# 使用 LLM-as-a-Judge 评估问题质量
|
||||||
|
score_dict = await llm_reward(task, response_text, auxiliary_models)
|
||||||
|
format_score = float(score_dict.get("format_score", 0.0))
|
||||||
|
content_score = float(score_dict.get("content_score", 0.0))
|
||||||
|
elif action_score == 1.0: # stop 动作
|
||||||
|
content_score, format_score = 1.0, (1.0 if response_text == "<stop />" else 0.0)
|
||||||
|
else:
|
||||||
|
format_score = content_score = 0.0
|
||||||
|
|
||||||
|
# 根据训练模式组合最终奖励
|
||||||
|
if TRAIN_MODE == "Ra+Rs": # 默认:动作 + 症状奖励
|
||||||
|
final_reward = action_score * (1 + 2 * content_score) + format_score
|
||||||
|
elif TRAIN_MODE == "Ra": # 仅动作奖励
|
||||||
|
final_reward = 2 * content_score + format_score
|
||||||
|
else: # 仅症状奖励
|
||||||
|
final_reward = action_score * 3 + format_score
|
||||||
|
|
||||||
|
return JudgeOutput(reward=final_reward, metrics={"reward": final_reward})
|
||||||
|
```
|
||||||
|
|
||||||
|
此奖励函数考虑:
|
||||||
|
- 动作准确性:`action_score`
|
||||||
|
- 问题质量(症状覆盖率):`content_score`
|
||||||
|
- 格式分数:`format_score`
|
||||||
|
|
||||||
|
有关实现细节,请参阅 [main.py](./main.py)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 配置和训练模型
|
||||||
|
|
||||||
|
### 选项 A:编辑 Python 脚本(简单)
|
||||||
|
打开 `tuner/learn_to_ask/main.py` 并调整设置:
|
||||||
|
|
||||||
|
```python
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train_mode = "Ra+Rs" # 同时使用动作和症状奖励
|
||||||
|
fusion_mode = "default" # 如何组合奖励
|
||||||
|
dataset = DatasetConfig(path="tuner/learn_to_ask/data", split="train")
|
||||||
|
|
||||||
|
tuner_model = OpenAIChatModel(
|
||||||
|
model_path="Qwen/Qwen2.5-7B-Instruct",
|
||||||
|
max_model_len=8192,
|
||||||
|
tensor_parallel_size=1, # 根据您的 GPU 设置调整
|
||||||
|
...
|
||||||
|
)
|
||||||
|
|
||||||
|
auxiliary_models = {
|
||||||
|
AUXILIARY_MODEL_NAME: OpenAIChatModel(
|
||||||
|
model_path="Qwen/Qwen2.5-32B-Instruct", # 用于评估的更大模型
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
...
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
algorithm = AlgorithmConfig(
|
||||||
|
algorithm_type="grpo",
|
||||||
|
learning_rate=5e-7,
|
||||||
|
batch_size=64,
|
||||||
|
)
|
||||||
|
|
||||||
|
tune(...) # 开始训练
|
||||||
|
```
|
||||||
|
|
||||||
|
### 选项 B:使用 YAML 配置(高级)
|
||||||
|
编辑 `tuner/learn_to_ask/config.yaml` 以获得更多控制。
|
||||||
|
|
||||||
|
#### 🌐 没有 GPU?使用 Tinker!
|
||||||
|
如果您没有 GPU,可以通过设置启用 **Tinker 后端**:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model:
|
||||||
|
tinker:
|
||||||
|
enable: true # ← 将此设置为 true
|
||||||
|
```
|
||||||
|
|
||||||
|
此外,请确保更新 `tuner/learn_to_ask/main.py` 中的 `model_path`,使其指向与 Tinker 兼容的模型。
|
||||||
|
|
||||||
|
> 🔗 了解更多关于 Tinker 后端: [Tinker 后端文档](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/example_tinker_backend.html)
|
||||||
|
|
||||||
|
### 启动训练
|
||||||
|
```bash
|
||||||
|
python tuner/learn_to_ask/main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 评估
|
||||||
|
|
||||||
|
使用**rollout 和评估流程**:
|
||||||
|
1. 在测试集上生成响应。
|
||||||
|
2. 使用强大的评估模型(`Qwen2.5-32B-Instruct`)对它们进行评分。
|
||||||
|
|
||||||
|
运行评估:
|
||||||
|
```bash
|
||||||
|
python tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py \
|
||||||
|
--eval_model_path path/to/your/trained/model \
|
||||||
|
--grader_model_path Qwen/Qwen2.5-32B-Instruct \
|
||||||
|
--test_file_path tuner/learn_to_ask/data/test.jsonl \
|
||||||
|
--rollout_file_path path/to/rollout.jsonl \
|
||||||
|
--eval_file_path path/to/output.jsonl
|
||||||
|
```
|
||||||
|
|
||||||
|
> ⚠️ **注意**:您的训练模型必须首先转换为 **Hugging Face 格式**。
|
||||||
|
> 请参阅:[转换 FSDP 检查点指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/faq.html)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 实验结果
|
||||||
|
|
||||||
|
我们比较了三种方法:
|
||||||
|
- **基础模型**:`Qwen2.5-7B-Instruct`(无微调)
|
||||||
|
- **Trinity**:直接响应生成
|
||||||
|
- **AgentScope-Tuner (Learn2Ask)**:使用 ReAct 智能体进行主动提问
|
||||||
|
|
||||||
|
| 指标 | 基础模型 | Trinity | AgentScope-Tuner (Learn2Ask) |
|
||||||
|
|------------------------------------|---------:|--------:|--------------------:|
|
||||||
|
| 平均继续内容 | 0.436 | 0.496 | 0.509 |
|
||||||
|
| 胜率(继续内容) | 0.122 | 0.246 | 0.224 |
|
||||||
|
| 平均继续决策准确性 | 0.963 | 0.909 | 0.922 |
|
||||||
|
| 平均停止决策准确性 | 0.024 | 0.927 | 0.909 |
|
||||||
|
| **总决策准确性** | 0.467 | 0.918 | 0.915 |
|
||||||
|
| **响应格式质量** | 0.376 | 0.713 | 0.882 |
|
||||||
|
| **总奖励** | 1.281 | 3.078 | 3.237 |
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<img src="./learn2ask.png" alt="训练曲线" width="90%"/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
> ✅ **关键洞察**:Learn2Ask (AgentScope-Tuner) 通过教模型**何时提问和提问什么**,实现了最高的整体性能——使其真正具有主动性。
|
||||||
|
|
||||||
|
### 具体示例
|
||||||
|
|
||||||
|
下面显示了智能体在评估数据上的行为示例:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"cid": "206_1",
|
||||||
|
"session_id": 206,
|
||||||
|
"diagn": "Common cold",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Hello, currently there are many people consulting, please wait patiently."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hello, this morning around 10 o'clock, I suddenly felt a bit feverish and had a heavy head. I took my temperature and it was 38.5°C. I took one tablet of paracetamol at 1 o'clock, which helped a lot. By 6 or 7 in the evening, I felt cold and feverish again, so I took another tablet of paracetamol. I'm not sure if it's a side effect of the medication, but I feel a bit nauseous and my stomach keeps producing gas. Now, I feel cold again, have a stuffy nose with clear runny nose, an itchy throat with coughing, a heavy head, and can't sleep."
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"decision_truth": "continue",
|
||||||
|
"info_truth": "symptom: needle-like sensation, symptom location: ear, symptom nature: sharp pain, symptom severity: occurred twice",
|
||||||
|
"rollouts": [
|
||||||
|
"Do you also experience body aches along with your symptoms? \nA) Yes B) No C) Not Sure",
|
||||||
|
"Do you experience these symptoms every time after taking paracetamol? (A) Yes, definitely. (B) Sometimes. (C) No, never before.",
|
||||||
|
"Do you have a history of allergies? A) Yes B) No C) Not Sure"
|
||||||
|
],
|
||||||
|
"grades": [
|
||||||
|
{
|
||||||
|
"think": "The doctor's last message asks about body aches using multiple choice options rather than a single direct open-ended question, which deviates from the typical questioning style indicated by the grading policy. However, since it is framed in the form of \"Do you...\", it counts as one question despite being presented with choices. Thus, it satisfies the criteria for 'contains exactly one question'. For the content score, the query does not pertain directly to the details listed in the reference information such as the specific symptom ('needle-like' sharp pain in the ear that has occurred twice), thereby making it highly relevant yet indirectly related due to its focus on comprehensive assessment of flu-related discomforts without addressing specifics noted in the patient's primary concern documented earlier.",
|
||||||
|
"action_score": 1,
|
||||||
|
"format_score": 1.0,
|
||||||
|
"content_score": 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"think": "The doctor's last message includes just one multiple-choice question regarding whether the patient experiences those mentioned symptoms each time they take paracetamol. This does relate highly to understanding possible drug-related symptoms; however, none of them aligns perfectly with \"needle-like\" sensations occurring specifically in ears according to the reference information given.",
|
||||||
|
"action_score": 1,
|
||||||
|
"format_score": 1.0,
|
||||||
|
"content_score": 0.5
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"think": "The doctor's last statement does contain just one question pertaining to allergy history, which is highly relevant when trying to diagnose symptoms such as those described by the patient (fever, nausea). However, none of these concerns specifically relate back to the reference information detailing \"needle-like sensation\", \"sharp pain\" related to the ears occurring twice. Therefore, while highly pertinent medically, they do not pertain to the exact points outlined in the Ref Info section about the patient experience according to that specific prompt context.",
|
||||||
|
"action_score": 1,
|
||||||
|
"format_score": 1.0,
|
||||||
|
"content_score": 0.5
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 📚 引用
|
||||||
|
|
||||||
|
如果您使用此代码或框架,请引用我们的工作:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@misc{learn2ask,
|
||||||
|
title={Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs},
|
||||||
|
author={Fei Wei and Daoyuan Chen and Ce Wang and Yilun Huang and Yushuo Chen and Xuchen Pan and Yaliang Li and Bolin Ding},
|
||||||
|
year={2025},
|
||||||
|
eprint={2510.25441},
|
||||||
|
archivePrefix={arXiv},
|
||||||
|
primaryClass={cs.AI},
|
||||||
|
url={https://arxiv.org/abs/2510.25441}
|
||||||
|
}
|
||||||
|
```
|
||||||
30 tuner/learn_to_ask/config.yaml Normal file
@@ -0,0 +1,30 @@
project: AgentScope-ReAct
name: Learn_to_Ask-Qwen2.5-7B-fixed
# directory to save checkpoints; defaults to ./checkpoints if TRINITY_CHECKPOINT_ROOT_DIR is not set
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
  algorithm_type: grpo  # a GRPO-based algorithm for multi-step reasoning
model:
  # path to the pre-trained model; defaults to Qwen/Qwen2.5-7B-Instruct if TRINITY_MODEL_PATH is not set
  model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
  tinker:  # Tinker config; set Tinker parameters here
    enable: false  # if true, the Tinker backend is enabled
cluster:
  node_num: 1  # cluster with 1 node
  gpu_per_node: 8  # each node has 8 GPUs
buffer:
  total_epochs: 4  # run the taskset for 4 epochs
explorer:
  runner_per_model: 32  # each model has 32 runners for parallel rollout
  max_timeout: 600  # max timeout for each rollout is 600 seconds
synchronizer:
  sync_style: fixed
  sync_method: 'nccl'
  sync_interval: 10
  sync_timeout: 7200  # wait up to 120 minutes
trainer:
  save_interval: 90  # save a checkpoint every 90 steps
  use_dynamic_bsz: true
  ulysses_sequence_parallel_size: 1  # use sequence parallelism to reduce memory usage
monitor:
  monitor_type: wandb  # here we use wandb; you can also use tensorboard, mlflow or swanlab
175 tuner/learn_to_ask/data_prepare/1_info_extract_pipeline.py Normal file
@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
import time
from typing import Union

from llm_info_extraction import llm_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines


def process_jsonl_file(
    input_file: str,
    output_file: str,
    model_call_mode: str = "online_api",
    max_retries: int = 3,
    **kwargs: dict,
) -> str:
    """
    Process all sessions in a JSONL file and save results to the output file.

    Args:
        input_file (str): Path to input JSONL file
        output_file (str): Path to output JSONL file
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

    Returns:
        str: Success message or error information
    """
    try:
        # Read and process each session
        with open(input_file, "r", encoding="utf-8") as infile, open(
            output_file,
            "w",
            encoding="utf-8",
        ) as outfile:
            for line_num, line in enumerate(infile, 1):
                if line.strip():
                    try:
                        session = json.loads(line)
                        print(
                            f"Processing session "
                            f"{session.get('session_id', 'unknown')} "
                            f"(line {line_num})...",
                        )

                        # Process the session
                        processed_lines = process_session(
                            session,
                            model_call_mode,
                            max_retries,
                            **kwargs,
                        )
                        for processed_line in processed_lines:
                            outfile.write(processed_line + "\n")

                    except json.JSONDecodeError as e:
                        print(
                            f"Warning: Skipping invalid JSON at line "
                            f"{line_num}: {e}",
                        )
                    except Exception as e:
                        print(
                            f"Warning: Error processing session at line "
                            f"{line_num}: {e}",
                        )

        return f"Successfully processed. Results saved to {output_file}"

    except Exception as e:
        return f"Error processing JSONL file: {str(e)}"


def process_session(
    session: dict,
    model_call_mode: str = "online_api",
    max_retries: int = 3,
    **kwargs: dict,
) -> Union[list, str]:
    """
    Pipeline function that splits messages into rounds and extracts info from
    each round's remaining chat.

    Args:
        session (dict): Session dictionary containing 'session_id', 'diagn',
            and 'messages' keys
        model_call_mode (str): Either "online_api" or "local_vllm"
        max_retries (int): Maximum number of retries for LLM calls
        **kwargs: Additional parameters for API calls

    Returns:
        list: List of JSON strings with added "info_set" key,
            or error information
    """
    try:
        # Step 1: Split messages into JSON lines
        json_lines = split_session_to_json_lines(session)

        # Step 2: Process each JSON line with LLM info extraction
        processed_lines = []

        for line in json_lines:
            data = json.loads(line)
            remaining_chat = data.get("remaining_chat", "")

            # Retry loop for LLM calls
            info_set = None
            for attempt in range(max_retries):
                try:
                    # Call LLM info extraction
                    llm_response = llm_info_extraction(
                        remaining_chat,
                        model_call_mode,
                        **kwargs,
                    )

                    info_set = parse_llm_output(llm_response)

                    if isinstance(info_set, list):
                        break
                    else:
                        # If parsing failed, this is an error message
                        print(f"Attempt {attempt + 1} failed: {info_set}")
                        if attempt < max_retries - 1:
                            time.sleep(1)
                except Exception as e:
                    print(
                        f"Attempt {attempt + 1} failed with exception: "
                        f"{str(e)}",
                    )
                    if attempt < max_retries - 1:
                        time.sleep(1)

            data["info_set"] = info_set
            processed_lines.append(json.dumps(data, ensure_ascii=False))

        return processed_lines

    except Exception as e:
        return f"Pipeline error: {str(e)}"


# Example usage:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_file",
        type=str,
        default="tuner/learn_to_ask/data_raw/train_origin.jsonl",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
    )
    parser.add_argument(
        "--model_call_mode",
        type=str,
        choices=["online_api", "local_vllm"],
        default="local_vllm",
    )
    parser.add_argument("--model_path", type=str, required=True)
    args = parser.parse_args()
    print(
        process_jsonl_file(
            input_file=args.input_file,
            output_file=args.output_file,
            model_call_mode=args.model_call_mode,
            model_path=args.model_path,
            # Additional parameters for API calls
        ),
    )
65	tuner/learn_to_ask/data_prepare/2_build_dataset.py	Normal file
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json


def process_message(json_obj: dict) -> tuple:
    info_set = json_obj.get("info_set")
    info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
    if "user: " not in json_obj["remaining_chat"]:
        decision_str = "stop"
    else:
        decision_str = "continue"
    if not info_set_str and decision_str == "continue":
        if_keep = False
    else:
        if_keep = True
    return if_keep, info_set_str, decision_str


def main(input_file_path: str, output_file_path: str) -> None:
    with open(input_file_path, "r", encoding="utf-8") as infile, open(
        output_file_path,
        "w",
        encoding="utf-8",
    ) as outfile:
        print("data processing started...")
        for line in infile:
            data = json.loads(line.strip())
            if_keep, info_set, decision = process_message(data)
            if not if_keep:
                continue

            new_item = {
                "cid": data["cid"],
                "session_id": data["session_id"],
                "diagn": data["diagn"],
                "messages": data["messages"],
                "decision_truth": decision,
                "info_truth": info_set,
            }
            outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
    print("job done!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # The file generated by 1_info_extract_pipeline.py
    parser.add_argument(
        "--input_file",
        type=str,
        default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
    )

    # The final file for training or testing
    parser.add_argument(
        "--output_file",
        type=str,
        default="tuner/learn_to_ask/data/train.jsonl",
    )

    args = parser.parse_args()

    main(args.input_file, args.output_file)
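The keep/drop rule in `process_message` above can be stated compactly: a round whose ground-truth decision is "continue" but which carries no extracted info is dropped, everything else is kept. A minimal sketch of just that decision logic (a re-implementation for illustration, not an import from the script):

```python
def label_round(info_set, remaining_chat):
    """Mirror the filtering rule: a 'continue' round with no
    extracted info carries no learnable signal, so drop it."""
    info_str = ", ".join(info_set) if isinstance(info_set, list) else ""
    decision = "stop" if "user: " not in remaining_chat else "continue"
    keep = bool(info_str) or decision == "stop"
    return keep, info_str, decision


print(label_round(["symptom: cough"], "user: it got worse"))
# (True, 'symptom: cough', 'continue')
print(label_round([], ""))          # (True, '', 'stop')
print(label_round([], "user: hi"))  # (False, '', 'continue')
```

Note the asymmetry: "stop" rounds are always kept even without extracted info, because the stop decision itself is the training signal.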
334	tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py	Normal file
@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
"""
This script is used to use VLLM to generate rollout samples from the converted
checkpoints.
"""

import argparse
import copy
import gc
import json
import math
import os
import re
import time
from typing import Any, List

import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams


def init_llm(model_path: str) -> tuple:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    device_count = torch.cuda.device_count()
    print(f"device_count={device_count}")
    if device_count < 1:
        raise RuntimeError("No GPU available for multi-card inference.")
    print(f"Loading model from: {model_path}")
    llm = LLM(model=model_path, tensor_parallel_size=device_count)
    print("Model loaded successfully!")
    sampling_params = SamplingParams(
        temperature=1.0,
        top_p=0.95,
        max_tokens=512,
        repetition_penalty=1.2,
    )
    return llm, tokenizer, sampling_params


def rollout(
    llm: Any,
    tokenizer: Any,
    sampling_params: Any,
    input_file_path: str,
    output_file_path: str,
    rollout_repeat: int = 3,
) -> None:
    import importlib.util

    spec = importlib.util.spec_from_file_location(
        "prompt",
        os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
    )
    module = importlib.util.module_from_spec(spec)  # type: ignore
    spec.loader.exec_module(module)
    rollout_prompt = module.rollout_prompt_med

    with open(input_file_path, "r", encoding="utf-8") as lines:
        sample_list = [json.loads(line.strip()) for line in lines]
    print(f"loaded samples: {len(sample_list)}")

    for index, sample in enumerate(sample_list):
        record = copy.deepcopy(sample)
        print(f"index: {index}, session_id: {sample['session_id']}")
        messages = [{"role": "system", "content": rollout_prompt}] + sample[
            "messages"
        ]

        # Some tokenizers (e.g., Qwen) support the `enable_thinking` argument,
        # but others do not. Try with the argument first, and fall back if
        # it is not accepted.
        try:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=False,
            )
        except TypeError:
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )

        response_list = []
        for i in range(rollout_repeat):
            time_probe = time.perf_counter()
            outputs = llm.generate([prompt], sampling_params=sampling_params)
            print(f"time cost: {time.perf_counter() - time_probe}")
            for output in outputs:
                response = output.outputs[0].text
                response_list.append(response)
                print(f"rollout #{i}: {response}\n")
        record["rollouts"] = response_list

        # append to output file
        with open(output_file_path, "a") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def eval_sample(
    llm: Any,
    tokenizer: Any,
    sampling_params: Any,
    input_file_path: str,
    output_file_path: str,
) -> None:
    import importlib.util

    spec = importlib.util.spec_from_file_location(
        "prompt",
        os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
    )
    module = importlib.util.module_from_spec(spec)  # type: ignore
    spec.loader.exec_module(module)
    grader_prompt = module.reward_prompt_med

    print(f"input_file_path: {input_file_path}")
    print(f"output_file_path: {output_file_path}")

    with open(input_file_path, "r", encoding="utf-8") as lines:
        sample_list = [json.loads(line.strip()) for line in lines]
    print(f"Total records: {len(sample_list)}")

    def res_formatter(res_content: str) -> dict:
        pattern = r"<(\w+)>(.*?)</\1>"
        matches = re.findall(pattern, res_content)
        result = {}
        for tag_name, content in matches:
            result[tag_name] = content
        return result

    def msg2str(msg_list: List) -> str:
        result_str = ""
        for msg in msg_list:
            if msg["role"] == "user":
                result_str += f"patient: {msg['content']}\n"
            if msg["role"] == "assistant":
                result_str += f"doctor: {msg['content']}\n"
        return result_str

    for index, sample in enumerate(sample_list):
        print(f"index: {index}, cid: {sample['cid']}")
        action_truth = sample["decision_truth"]
        info_truth = sample["info_truth"] if sample["info_truth"] else "None."
        print(f"action_truth: {action_truth}, info_truth: {info_truth}")

        sys_prompt = grader_prompt.format(info_truth)
        history = msg2str(sample["messages"])

        sample["grades"] = []
        for rollout in sample["rollouts"]:
            time_probe = time.perf_counter()
            action_score, content_score, format_score, res_think = (
                0,
                0,
                0,
                "NA",
            )
            if "<stop />" in rollout:
                action_rollout = "stop"
            else:
                action_rollout = "continue"
            if action_truth == action_rollout:
                action_score = 1
                if action_truth == "continue":
                    user_content = history + f"doctor: {rollout}"
                    messages = [
                        {"role": "system", "content": sys_prompt},
                        {"role": "user", "content": user_content},
                    ]
                    try:
                        prompt = tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True,
                            enable_thinking=False,
                        )
                    except TypeError:
                        # Fallback for tokenizers that do not support the
                        # Qwen-specific `enable_thinking` argument.
                        prompt = tokenizer.apply_chat_template(
                            messages,
                            tokenize=False,
                            add_generation_prompt=True,
                        )
                    outputs = llm.generate(
                        [prompt],
                        sampling_params=sampling_params,
                    )
                    for output in outputs:
                        response = output.outputs[0].text
                        print(f"Response: {response}\n")
                        res_dict = res_formatter(response)
                        try:
                            format_score = float(
                                res_dict.get("format_score", 0.0),
                            )
                            content_score = float(
                                res_dict.get("content_score", 0.0),
                            )
                            res_think = res_dict.get("think", "None")
                        except Exception as e:
                            print(e)
                else:
                    content_score = 1.0
                    format_score = 1.0 if rollout == "<stop />" else 0.0
            else:
                action_score, format_score, content_score = 0, 0, 0
            grade_result = {
                "think": res_think,
                "action_score": action_score,
                "format_score": format_score,
                "content_score": content_score,
            }
            sample["grades"].append(grade_result)
            json_str = json.dumps(grade_result, ensure_ascii=False, indent=2)
            print(
                f"grade_result: {json_str}",
            )
            print(f"time_cost: {time.perf_counter() - time_probe}")
        # append sample to output file
        with open(output_file_path, "a") as f:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
        print("\n======================\n")


def compute_score(input_file_path: str) -> None:
    with open(input_file_path, "r", encoding="utf-8") as lines:
        sample_list = [json.loads(line.strip()) for line in lines]
    continue_count, continue_content_score, continue_content_full = 0, 0, 0
    continue_decision_score = 0
    stop_count, stop_decision_score = 0, 0
    total_reward, total_format = 0, 0
    (
        continue_count_correct,
        continue_content_score_correct,
        continue_content_full_correct,
    ) = (0, 0, 0)
    for sample in sample_list:
        for rollout, grade in zip(sample["rollouts"], sample["grades"]):
            if math.isnan(grade["content_score"]) or math.isnan(
                grade["format_score"],
            ):
                continue
            if sample["decision_truth"] == "continue":
                continue_count += 1
                continue_content_score += grade["content_score"]
                continue_content_full += (
                    1 if grade["content_score"] == 1 else 0
                )
                continue_decision_score += grade["action_score"]
                if "<stop />" not in rollout:
                    continue_count_correct += 1
                    continue_content_score_correct += grade["content_score"]
                    continue_content_full_correct += (
                        1 if grade["content_score"] == 1 else 0
                    )

            else:
                stop_count += 1
                stop_decision_score += grade["action_score"]
            total_reward += (
                grade["action_score"] * (1 + 2 * grade["content_score"])
                + grade["format_score"]
            )
            total_format += grade["format_score"]

    result = {
        "ave_continue_content": continue_content_score
        / max(1, continue_count),
        "win_continue_content": continue_content_full / max(1, continue_count),
        "ave_continue_content if correct": continue_content_score_correct
        / max(1, continue_count_correct),
        "win_continue_content if correct": continue_content_full_correct
        / max(1, continue_count_correct),
        "ave_continue_decision": continue_decision_score
        / max(1, continue_count),
        "ave_stop_decision": stop_decision_score / max(1, stop_count),
        "ave_total_decision": (continue_decision_score + stop_decision_score)
        / max(1, continue_count + stop_count),
        "ave_total_format": total_format / max(1, continue_count + stop_count),
        "ave_total_reward": total_reward / max(1, continue_count + stop_count),
    }

    print(f"total count: {continue_count + stop_count}")
    print(json.dumps(result, ensure_ascii=False, indent=4))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--rollout_repeat", type=int, default=3)

    # Ckpt for testing
    parser.add_argument("--eval_model_path", type=str, required=True)

    # Model to empower the grading, Qwen2.5-32b-instruct is recommended
    parser.add_argument("--grader_model_path", type=str, required=True)

    # Your test sample path [input]
    parser.add_argument("--test_file_path", type=str, required=True)

    # Rollout results given test samples [output]
    parser.add_argument("--rollout_file_path", type=str, required=True)

    # Final output given rollout results [output]
    parser.add_argument("--eval_file_path", type=str, required=True)

    args = parser.parse_args()

    # rollout stage
    llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
    rollout(
        llm,
        tokenizer,
        sampling_params,
        args.test_file_path,
        args.rollout_file_path,
        args.rollout_repeat,
    )
    del llm  # clean up the memory after the inference
    gc.collect()
    torch.cuda.empty_cache()  # release gpu memory

    # eval stage
    llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
    eval_sample(
        llm2,
        tokenizer2,
        sampling_params2,
        args.rollout_file_path,
        args.eval_file_path,
    )
    compute_score(args.eval_file_path)
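The grader reply in the script above is parsed with a back-referencing regex, where `\1` forces each closing tag to match its opening tag. A standalone sketch of that tag extraction:

```python
import re


def parse_tagged(text):
    """Extract <tag>value</tag> pairs from a grader-style reply.

    The \\1 back-reference means </b> cannot close <a>, so mismatched
    tags simply produce no match instead of garbage pairs.
    """
    return {tag: body for tag, body in re.findall(r"<(\w+)>(.*?)</\1>", text)}


reply = "<think>looks right</think><content_score>1.0</content_score>"
print(parse_tagged(reply))
# {'think': 'looks right', 'content_score': '1.0'}
```

One caveat carried over from the original pattern: `.` does not match newlines without `re.DOTALL`, so a tag whose body spans multiple lines would be silently skipped.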
174	tuner/learn_to_ask/data_prepare/llm_info_extraction.py	Normal file
@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import os
from typing import List, Union

import openai
import torch
import transformers

tokenizer = None
llm = None


def llm_info_extraction(
    remaining_chat: str,
    model_call_mode: str,
    **kwargs: dict,
) -> str:
    """
    Extract information from remaining_chat using LLM.

    Args:
        remaining_chat (str): The chat content to process
        model_call_mode (str): Either "online_api" or "local_vllm"
        **kwargs: Additional parameters for API calls

    Returns:
        str: Response text from LLM or error information
    """

    # Create messages format with system and user roles
    system_message = """
# Task:
You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patient's responses and pair it with the respective field.
# Requirements:
- Do not fabricate information or introduce new fields not listed above. Ignore patient-reported information regarding prior medication use, allergies, or underlying comorbidities; do not include such details in the output.
- Only include fields explicitly inquired about by the physician. Omit any fields not addressed in the dialogue. Avoid outputting vague terms (e.g., "unspecified" or "unknown").
- Prevent duplication: if a symptom description already includes anatomical location, do not separately list the location field.
- Format each entry as a string enclosed in single quotes ('), and separate multiple entries with commas, ensuring any necessary escape characters within the strings. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]".
- Do not include reasoning steps or additional commentary outside the specified format. Condense colloquial patient expressions into concise, standardized, and clinically appropriate terminology.
# Example output format:
['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']
"""  # noqa: E501
    user_message = remaining_chat

    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": "```\n" + user_message + "\n```\n"},
    ]

    try:
        if model_call_mode == "online_api":
            # OpenAI-style API call
            return _call_online_api(messages, **kwargs)
        elif model_call_mode == "local_vllm":
            # Local vLLM call
            return _call_local_vllm(messages, **kwargs)
        else:
            return (
                f"Error: Invalid model_call_mode '{model_call_mode}'. "
                "Must be 'online_api' or 'local_vllm'."
            )
    except Exception as e:
        return f"Error occurred: {str(e)}"


def _call_online_api(messages: List, **kwargs: dict) -> str:
    """Handle OpenAI-style API calls"""
    # Extract API parameters from kwargs or use defaults
    api_key = kwargs.get("api_key", os.getenv("DASHSCOPE_API_KEY"))
    api_base = kwargs.get(
        "api_base",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    model = kwargs.get("model", "qwen2.5-72b-instruct")
    temperature = kwargs.get("temperature", 0.7)
    max_tokens = kwargs.get("max_tokens", 500)

    client = openai.OpenAI(api_key=api_key, base_url=api_base)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    return response.choices[0].message.content


def _call_local_vllm(messages: List, **kwargs: dict) -> str:
    """Handle local vLLM calls"""
    try:
        from vllm import LLM, SamplingParams

        model_path = kwargs.get("model_path")
        if not model_path:
            return "Error: model_path is required for local vLLM inference"

        temperature = kwargs.get("temperature", 0.7)
        max_tokens = kwargs.get("max_tokens", 512)
        top_p = kwargs.get("top_p", 0.9)
        repetition_penalty = kwargs.get("repetition_penalty", 1.1)

        # GPU/CUDA related parameters for vLLM
        tensor_parallel_size = kwargs.get(
            "tensor_parallel_size",
            torch.cuda.device_count(),
        )
        gpu_memory_utilization = kwargs.get("gpu_memory_utilization", 0.9)
        enforce_eager = kwargs.get("enforce_eager", False)
        dtype = kwargs.get("dtype", "auto")
        max_model_len = kwargs.get("max_model_len", 4096)

        # Initialize the LLM with the provided model path and GPU parameters
        global llm, tokenizer
        if llm is None:
            llm = LLM(
                model=model_path,
                tensor_parallel_size=tensor_parallel_size,
                gpu_memory_utilization=gpu_memory_utilization,
                enforce_eager=enforce_eager,
                dtype=dtype,
                max_model_len=max_model_len,
            )

        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            repetition_penalty=repetition_penalty,
        )

        # Convert messages to a single prompt string
        if tokenizer is None:
            tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        outputs = llm.generate([prompt], sampling_params)

        return outputs[0].outputs[0].text

    except ImportError:
        return (
            "Error: vLLM library not installed. "
            "Please install it with 'pip install vllm'"
        )
    except Exception as e:
        return f"Error in local vLLM inference: {str(e)}"


def parse_llm_output(output_str: str) -> Union[List[str], str]:
    """
    Convert the LLM info extraction output string to a list of strings.

    Args:
        output_str (str): String in format "['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']"

    Returns:
        list: List of strings if successful, error message string if failed
    """  # noqa: E501
    import ast

    try:
        result = ast.literal_eval(output_str)
        if not isinstance(result, list):
            return f"Error: Expected a list, got {type(result)}"

        return result
    except Exception as e:
        return f"Error parsing output: [{repr(output_str)}] error = {str(e)}"
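`parse_llm_output` above relies on `ast.literal_eval`, which only accepts Python literals (strings, numbers, lists, dicts, ...), so arbitrary expressions in a model reply are rejected instead of executed, unlike `eval`. A minimal demonstration of that safety property, with hypothetical names:

```python
import ast


def parse_list_literal(s):
    """Parse a string like "['a', 'b']" into a list, or return an error string."""
    try:
        value = ast.literal_eval(s)
    except (ValueError, SyntaxError) as e:
        # Non-literal input (including function calls) lands here.
        return f"parse error: {e}"
    if not isinstance(value, list):
        return f"expected list, got {type(value).__name__}"
    return value


print(parse_list_literal("['symptom: cough', 'symptom severity: mild']"))
# ['symptom: cough', 'symptom severity: mild']
print(parse_list_literal("__import__('os')"))  # rejected: not a literal
```

Returning an error string rather than raising matches the module's convention, which lets the retry loop in the pipeline script distinguish success (`list`) from failure (`str`) with a single `isinstance` check.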
119	tuner/learn_to_ask/data_prepare/message_splitter.py	Normal file
@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import json
from typing import Dict, List


def split_single_message_list(messages: List) -> List:
    """
    Split a single message list into multiple rounds.

    Args:
        messages (list): List of message dictionaries with 'role' and 'content' keys

    Returns:
        list: List of rounds, where each round contains messages and remaining chat
    """  # noqa: E501
    rounds = []
    round_number = 1
    i = 0

    while i < len(messages):
        # Collect messages for this round
        round_messages = []

        # Add messages until we reach a user message
        while i < len(messages) and messages[i].get("role") != "user":
            round_messages.append(messages[i])
            i += 1

        # Add user message(s) - if there are consecutive user messages,
        # include all of them in this round
        while i < len(messages) and messages[i].get("role") == "user":
            round_messages.append(messages[i])
            i += 1

        # The remaining messages (if any) form the remaining_chat
        remaining_messages = messages[i:]
        round_entry = {
            "round_number": round_number,
            "messages": round_messages,
        }

        # Add remaining chat if there are remaining messages
        if remaining_messages:
            remaining_chat_parts = []
            for msg in remaining_messages:
                role = msg.get("role", "")
                content = msg.get("content", "")
                remaining_chat_parts.append(f"{role}: {content}")
            round_entry["remaining_chat"] = "\n".join(remaining_chat_parts)
        else:
            round_entry["remaining_chat"] = ""

        rounds.append(round_entry)
        round_number += 1

    return rounds


def split_session_to_json_lines(session: Dict) -> List[str]:
    """
    Split a session dictionary into multiple rounds and convert to JSON lines.

    Args:
        session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys
            - session_id (str): Session identifier
            - diagn (str): Diagnosis information
            - messages (list): List of message dictionaries with 'role' and 'content' keys

    Returns:
        list: List of JSON strings, each representing a round with cid, session_id, diagn, messages, and remaining_chat
    """  # noqa: E501
    rounds = split_single_message_list(session["messages"])

    json_lines = []
    for round_data in rounds:
        round_entry = {
            "cid": f"{session['session_id']}_{round_data['round_number']}",
            "session_id": session["session_id"],
            "diagn": session["diagn"],
            "messages": round_data["messages"],
            "remaining_chat": round_data["remaining_chat"],
        }

        json_lines.append(json.dumps(round_entry, ensure_ascii=False))

    return json_lines


# Example usage:
if __name__ == "__main__":
    # Example of splitting a single message list
    example_messages = [
        {"role": "assistant", "content": "Hello, how can I help you today?"},
        {"role": "user", "content": "I've been having headaches lately."},
        {
            "role": "assistant",
            "content": "How long have you been experiencing these headaches?",
        },
        {"role": "user", "content": "For about a week now."},
        {
            "role": "assistant",
            "content": "I see. Have you taken any medication for them?",
        },
        {
            "role": "user",
            "content": "Yes, I've tried some over-the-counter pain relievers.",
        },
    ]

    example_session = {
        "session_id": "session_1",
        "diagn": "migraine",
        "messages": example_messages,
    }
    json_lines = split_session_to_json_lines(example_session)
    print("JSON lines output:")
    for i, line in enumerate(json_lines):
        print(f"Line {i + 1}: {line}")
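The splitter's core invariant is that each round ends after the last message of a consecutive run of user messages. That invariant can be checked with a compact role-only sketch (a hypothetical re-implementation for illustration, not an import from the module above):

```python
def split_rounds(roles):
    """Group a role sequence into rounds, each ending after a run of 'user'."""
    rounds, current = [], []
    for role in roles:
        # A non-user role right after a user run closes the current round.
        if current and current[-1] == "user" and role != "user":
            rounds.append(current)
            current = []
        current.append(role)
    if current:
        rounds.append(current)
    return rounds


print(split_rounds(["assistant", "user", "assistant", "user", "user"]))
# [['assistant', 'user'], ['assistant', 'user', 'user']]
```

This mirrors the two inner `while` loops of `split_single_message_list`: assistant/system messages accumulate until a user message appears, and consecutive user messages stay in the same round.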
Some files were not shown because too many files have changed in this diff.