10 Commits

Author SHA1 Message Date
7fbfc96795 feat(ui): update frontend components and agent prompts
- Update analyst, portfolio_manager, risk_manager prompts
- Enhance analysis tools functionality
- Improve UI components (AgentCard, AgentFeed, RoomView, etc.)
- Update constants configuration

Co-Authored-By: Claude <noreply@anthropic.com>
2026-03-13 04:31:06 +08:00
Dandan Liu
679431b303 feat(alias core): add data source management (#110)
Co-authored-by: Tianjing Zeng <39507457+StCarmen@users.noreply.github.com>
Co-authored-by: stcarmen <1106135234@qq.com>
2026-02-02 13:57:16 +08:00
Osier-Yi
df0776cac4 initial the create_time to actual timestamp (#113) 2026-01-30 17:32:26 +08:00
Yuchang Sun
400c1e77bf Add README (Chinese) for tuner (#106) 2026-01-20 19:46:50 +08:00
chenyushuo
311ddfff46 Fix in learn to ask (#107) 2026-01-20 14:13:20 +08:00
XiuShenAl
b843abea93 feat(alias): add AgentScope-Runtime support (#95) 2026-01-20 12:23:45 +08:00
lingzhq
e7f1fdf7ff Add example for data augmentation in tuner (#98) 2026-01-19 12:25:52 +08:00
Yuchang Sun
654c35127a Add examples for frozenlake and emailsearch (#94) 2026-01-19 12:25:13 +08:00
chenyushuo
3821fb04ac Add Tuner learn_to_ask example (#101) 2026-01-16 19:24:46 +08:00
garyzhang99
5855c5161b Add examples for werewolf game tuner (#96) 2026-01-16 17:25:49 +08:00
119 changed files with 13464 additions and 574 deletions

View File

@@ -72,7 +72,7 @@ This is a repository that **brings together a variety of ready-to-run Python age
│ ├── frozen_lake/ # Teach an agent to play a game requiring multiple steps
│ ├── learn_to_ask/ # Using LLM-as-a-judge to facilitate agent tuning
│ ├── email_search/ # Enhance the tool use ability of your agent
│ ├── werewolf_game/ # Enhance a multi-agent application
│ ├── werewolves/ # Enhance a multi-agent application
│ └── data_augment/ # Data augmentation for tuning
├── sample_template/ # Template for new sample contributions
└── README.md

View File

@@ -72,7 +72,7 @@
│ ├── frozen_lake/ # Teach an agent to play a game requiring multiple steps
│ ├── learn_to_ask/ # Using LLM-as-a-judge to facilitate agent tuning
│ ├── email_search/ # Enhance the tool use ability of your agent
│ ├── werewolf_game/ # Enhance a multi-agent application
│ ├── werewolves/ # Enhance a multi-agent application
│ └── data_augment/ # Data augmentation for tuning
├── sample_template/ # Template for new sample contributions
└── README.md

View File

@@ -207,10 +207,23 @@ alias_agent run --mode finance --task "Analyze Tesla's Q4 2024 financial perform
# Data Science mode
alias_agent run --mode ds \
--task "Analyze the distribution of incidents across categories in 'incident_records.csv' to identify imbalances, inconsistencies, or anomalies, and determine their root cause." \
--files ./docs/data/incident_records.csv
--datasource ./docs/data/incident_records.csv
```
**Note**: Files uploaded with `--files` are automatically copied to `/workspace` in the sandbox. Generated files are available in `sessions_mount_dir` subdirectories.
#### Input/Output Management
**Input:**
- Use the `--datasource` parameter (alias `--files`, kept for backward compatibility) to specify data sources. Multiple formats are supported:
- **Local files**: such as `./data.txt` or `/absolute/path/file.json`
- **Database DSN**: supports relational databases like PostgreSQL and SQLite, with format like `postgresql://user:password@host:port/database`
Example: `--datasource file.txt postgresql://user:password@localhost:5432/mydb`
- Specified data sources are automatically profiled (analyzed); the resulting profile gives the model guidance for accessing each source efficiently.
- Uploaded files are automatically copied to the `/workspace` directory in the sandbox.
**Output:**
- Generated files are stored in subdirectories of `sessions_mount_dir`, where all output results can be found.
#### Enable Long-Term Memory Service (General Mode Only)
To enable the long-term memory service in General mode, you need to:
@@ -371,6 +384,85 @@ After the first startup, you can log in with the superuser credentials configure
- **Password**: As specified in `FIRST_SUPERUSER_PASSWORD`
### 🌐 Basic Usage -- AgentScope Runtime Deployment
Alias is now fully compatible with [AgentScope Runtime](https://github.com/agentscope-ai/agentscope-runtime/), enabling you to quickly deploy Alias as a standardized backend service. Once launched, you can easily invoke Alias capabilities via the accompanying AgentScope Runtime API.
#### 1. Prerequisites
* **Sandbox & API Keys**: Please refer to the previous sections [🐳 Sandbox Setup (Optional)](#-sandbox-setup-optional) and [🔑 API Keys Configuration](#-api-keys-configuration) to complete the basic environment setup.
* **Environment Variables**: Copy the example environment file from the project root:
```bash
cp .env.example .env
```
* **Start Redis**: Required for caching and session management:
```bash
docker run -d -p 6379:6379 --name alias-redis redis:7-alpine
```
#### 2. Installation & Sandbox Launch
Install the package in editable mode from the project root. This will automatically install the `alias_agent_runtime` CLI tool:
```bash
pip install -e .
```
To ensure proper code execution and file operations, start the sandbox server in a separate terminal:
```bash
runtime-sandbox-server --extension src/alias/runtime/alias_sandbox/alias_sandbox.py
```
#### 3. Launching AgentScope Runtime Service
You can choose to start the service via the CLI or Python code, depending on your use case.
##### Option A: Using CLI (Recommended)
Use the `alias_agent_runtime` command to launch the backend service with one click:
```bash
alias_agent_runtime --host 127.0.0.1 --port 8090 --chat-mode general
```
**Parameter Descriptions**:
* `--host` / `--port`: Specify the service address and port (default port is 8090).
* `--chat-mode`: Set the running mode. Options: `general`, `dr`, `browser`, `ds`, `finance` (default: `general`).
* `--web-ui`: (Optional) Enable AgentScope Runtime WebUI for a visual interaction interface. Skip this if you only need the API.
> **Note**: When enabling `--web-ui` for the first time, the system will automatically install necessary frontend dependencies. This may take a few minutes.
##### Option B: Using Python Code (Recommended for Developers)
If you wish to integrate or customize the launch logic within Python, you can use `AliasRunner` and `AgentApp` as shown below:
```python
from agentscope_runtime.engine.app import AgentApp
from alias.server.runtime.runner.alias_runner import AliasRunner
# 1. Initialize AliasRunner
# default_chat_mode options: "general", "dr", "browser", "ds", "finance"
runner = AliasRunner(
default_chat_mode="general",
)
# 2. Create AgentApp instance
agent_app = AgentApp(
runner=runner,
app_name="Alias",
app_description="An LLM-empowered agent built on AgentScope and AgentScope-Runtime",
)
# 3. Run the service
# Set web_ui=True to enable the visual debugging interface
agent_app.run(host="127.0.0.1", port=8090)
```
#### 4. Accessing the Application
Once the service is running, you can access Alias via:
* **Runtime API Access**: Send standard HTTP POST requests to `http://localhost:8090/process`. This is the primary method for integrating Alias into third-party frontends or backend workflows.
* **Visual Monitoring (Optional)**: If started with the `--web-ui` flag, visit `http://localhost:5173`. This interface allows developers to observe the agent's reasoning process, tool execution traces, and other debugging information.
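For API access, a minimal client call might look like the sketch below. The exact request schema is defined by the AgentScope Runtime API, so treat the payload fields here as illustrative assumptions rather than the definitive contract:

```python
# Hypothetical client sketch; the payload shape below is an assumption,
# not the documented AgentScope Runtime schema.
import requests

resp = requests.post(
    "http://localhost:8090/process",
    json={"input": [{"role": "user", "content": "Summarize this week's AI news."}]},
    timeout=300,
)
resp.raise_for_status()
print(resp.json())
```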
## ⚖️ License
Alias-Agent is released under the **Apache 2.0 License**; see the [LICENSE](https://github.com/agentscope-ai/agentscope-samples/blob/main/LICENSE) file for details.

View File

@@ -208,10 +208,25 @@ alias_agent run --mode finance --task "Analyze Tesla's Q4 2024 financial perform
# Data Science mode
alias_agent run --mode ds \
--task "Analyze the distribution of incidents across categories in 'incident_records.csv' to identify imbalances, inconsistencies, or anomalies, and determine their root cause." \
--files ./docs/data/incident_records.csv
--datasource ./docs/data/incident_records.csv
```
**Note**: Files uploaded with `--files` are automatically copied to `/workspace` in the sandbox. Generated files are available in `sessions_mount_dir` subdirectories.
#### Input/Output Management
**Input:**
- Use the `--datasource` parameter (alias `--files`, kept for backward compatibility) to specify data sources. Multiple formats are supported:
- **Local files**: such as `./data.txt` or `/absolute/path/file.json`
- **Database DSN**: supports relational databases like PostgreSQL and SQLite, with format like `postgresql://user:password@host:port/database`
Example: `--datasource file.txt postgresql://user:password@localhost:5432/mydb`
- Specified data sources are automatically profiled (analyzed); the resulting profile gives the model guidance for accessing each source efficiently.
- Uploaded files are automatically copied to the `/workspace` directory in the sandbox.
**Output:**
- Generated files are stored in subdirectories of `sessions_mount_dir`, where all output results can be found.
#### Enable Long-Term Memory Service (General Mode Only)
To enable the long-term memory service in General mode, you need to:
@@ -371,6 +386,83 @@ bash script/start_memory_service.sh
- **Username**: As specified in `FIRST_SUPERUSER_USERNAME` (default: `alias`)
- **Password**: As specified in `FIRST_SUPERUSER_PASSWORD`
### 🌐 Basic Usage -- AgentScope Runtime Deployment
Alias is now fully compatible with [AgentScope Runtime](https://github.com/agentscope-ai/agentscope-runtime/), enabling you to quickly deploy Alias as a standardized backend service. Once launched, you can easily invoke Alias capabilities via the accompanying AgentScope Runtime API.
#### 1. Prerequisites
* **Sandbox & API Keys**: Please refer to the previous sections [🐳 Sandbox Setup](#-沙盒设置可选) and [🔑 API Keys Configuration](#-api-密钥配置) to complete the basic environment setup.
* **Environment Variables**: Copy the example environment file from the project root:
```bash
cp .env.example .env
```
* **Start Redis**: Required for caching and session management:
```bash
docker run -d -p 6379:6379 --name alias-redis redis:7-alpine
```
#### 2. Installation & Sandbox Launch
Install the package in editable mode from the project root. This will automatically install the `alias_agent_runtime` CLI tool:
```bash
pip install -e .
```
To ensure proper code execution and file operations, start the sandbox server in a separate terminal:
```bash
runtime-sandbox-server --extension src/alias/runtime/alias_sandbox/alias_sandbox.py
```
#### 3. Launching AgentScope Runtime Service
You can choose to start the service via the CLI or Python code, depending on your use case.
##### Option A: Using CLI (Recommended)
Use the `alias_agent_runtime` command to launch the backend service with one click:
```bash
alias_agent_runtime --host 127.0.0.1 --port 8090 --chat-mode general
```
**Parameter Descriptions**:
* `--host` / `--port`: Specify the service address and port (default port is 8090).
* `--chat-mode`: Set the running mode. Options: `general`, `dr`, `browser`, `ds`, `finance` (default: `general`).
* `--web-ui`: (Optional) Enable AgentScope Runtime WebUI for a visual interaction interface. Skip this if you only need the API.
> **Note**: When enabling `--web-ui` for the first time, the system will automatically install necessary frontend dependencies. This may take a few minutes.
##### Option B: Using Python Code (Recommended for Developers)
If you wish to integrate or customize the launch logic within Python, you can use `AliasRunner` and `AgentApp` as shown below:
```python
from agentscope_runtime.engine.app import AgentApp
from alias.server.runtime.runner.alias_runner import AliasRunner
# 1. Initialize AliasRunner
# default_chat_mode options: "general", "dr", "browser", "ds", "finance"
runner = AliasRunner(
default_chat_mode="general",
)
# 2. Create AgentApp instance
agent_app = AgentApp(
runner=runner,
app_name="Alias",
app_description="An LLM-empowered agent built on AgentScope and AgentScope-Runtime",
)
# 3. Run the service
# Set web_ui=True to enable the visual debugging interface
agent_app.run(host="127.0.0.1", port=8090)
```
#### 4. Accessing the Application
Once the service is running, you can access Alias via:
* **Runtime API Access**: Send standard HTTP POST requests to `http://localhost:8090/process`. This is the primary method for integrating Alias into third-party frontends or backend workflows.
* **Visual Monitoring (Optional)**: If started with the `--web-ui` flag, visit `http://localhost:5173`. This interface allows developers to observe the agent's reasoning process, tool execution traces, and other debugging information.
## ⚖️ License

View File

@@ -45,7 +45,8 @@ dependencies = [
"agentscope-runtime>=1.0.0",
"aiosqlite>=0.21.0",
"asyncpg>=0.30.0",
"itsdangerous>=2.2.0"
"itsdangerous>=2.2.0",
"polars>=1.37.1"
]
[tool.setuptools]
@@ -65,3 +66,4 @@ dev = [
[project.scripts]
alias_agent = "alias.cli:main"
alias_agent_runtime = "alias.server.alias_agent_app:main"

View File

@@ -0,0 +1,74 @@
---
name: csv-excel-file
description: Guidelines for handling CSV/Excel files
type:
- csv
- excel
---
# CSV/Excel Handling Specifications
## Goals
- Safely load tabular data without crashing.
- Detect and handle messy spreadsheets (multiple blocks, missing headers, merged cells artifacts).
- Produce reliable outputs (a clean DataFrame for a clean table, or structured JSON for a messy spreadsheet) with validated types.
## Encoding, Delimiters, and Locale
- CSV encoding: Try UTF-8; if garbled, attempt common fallbacks (e.g., gbk, cp1252) based on context.
- Delimiters: Detect common separators (`,`, `\t`, `;`, `|`) during inspection.
- Locale formats: Be cautious with comma decimal separators and thousands separators.
## Inspection (always first)
- Identify file type, encoding (CSV), and sheet names (Excel) before full reads.
- Prefer small reads to preview structure:
- CSV: `pd.read_csv(..., nrows=20)`; if the delimiter is uncertain: `sep=None, engine="python"` (small `nrows` only).
- Excel: `pd.ExcelFile(path).sheet_names`, then `pd.read_excel(..., sheet_name=..., nrows=20)`.
- Use `df.head(n)` and `df.columns` to check:
- Missing/incorrect headers (e.g., columns are numeric 0..N-1)
- "Unnamed: X" columns
- Unexpected NaN/NaT, merged-cell artifacts
- Multiple tables/blocks in one sheet (blank rows separating sections)
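A minimal sketch of this inspection pass, assuming pandas and placeholder file paths:

```python
# Preview-only reads; "data.csv" and "data.xlsx" are placeholder paths.
import pandas as pd

# CSV: small read, sniffing the delimiter on the preview only.
preview = pd.read_csv("data.csv", nrows=20, sep=None, engine="python")
print(preview.columns.tolist())

# Excel: list sheets first, then preview one sheet.
xls = pd.ExcelFile("data.xlsx")
print(xls.sheet_names)
sheet = pd.read_excel("data.xlsx", sheet_name=xls.sheet_names[0], nrows=20)

# Red flags: "Unnamed: X" columns or numeric 0..N-1 headers suggest a messy sheet.
messy = any(str(c).startswith("Unnamed") for c in sheet.columns)
```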
## Preprocessing
- Treat as messy if any of the following is present:
- Columns contain "Unnamed:" or mostly empty column names
- Header row appears inside the data (first rows look like data + later row looks like header)
- Multiple data blocks (large blank-row gaps, repeated header patterns)
- Predominantly NaN/NaT in top rows/left columns
- Notes/metadata blocks above/beside the table (titles, footnotes, merged header areas)
- If messy spreadsheets are detected:
- First choice: use the `clean_messy_spreadsheet` tool to extract key tables/fields and output JSON.
- Only fall back to manual parsing if tool fails, returns empty/incorrect structure, or cannot locate the target table.
## Querying
- Never load entire datasets blindly.
- Use minimal reads:
- `nrows`, `usecols`, `dtype` (or partial dtype mapping), `parse_dates` only when necessary.
- Sampling: `skiprows` with a step pattern for rough profiling when file is huge.
- For very large CSV:
- Prefer `chunksize` iteration; aggregate/compute per chunk.
- For Excel:
- Read only needed `sheet_name`, and consider narrowing `usecols`/`nrows` during exploration.
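For the chunked-iteration case, a sketch along these lines (file and column names are illustrative):

```python
# Aggregate a huge CSV chunk by chunk instead of loading it whole;
# "big.csv" and the "amount" column are illustrative.
import pandas as pd

total = 0.0
rows = 0
for chunk in pd.read_csv("big.csv", usecols=["amount"], chunksize=100_000):
    total += chunk["amount"].sum()
    rows += len(chunk)
print(f"mean amount ~ {total / max(rows, 1):.2f} over {rows} rows")
```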
## Data Quality & Type Validation
- After load/clean:
- Validate types:
- Numeric columns: coerce with `pd.to_numeric(errors="coerce")`
- Datetime columns: `pd.to_datetime(errors="coerce")`
- Report coercion fallout (how many values became NaN/NaT).
- Standardize missing values: treat empty strings, "N/A", and "null" consistently.
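A sketch of the coerce-and-report step, assuming a DataFrame `df` with illustrative column names:

```python
# Coerce types and report fallout; `df`, "price", and "date" are assumptions.
import pandas as pd

valid_before = df["price"].notna().sum()
df["price"] = pd.to_numeric(df["price"], errors="coerce")
df["date"] = pd.to_datetime(df["date"], errors="coerce")
print(f"{valid_before - df['price'].notna().sum()} 'price' values coerced to NaN")

# Standardize common missing-value spellings.
df = df.replace({"": pd.NA, "N/A": pd.NA, "null": pd.NA})
```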
## Best Practices
- Always inspect structure before processing.
- Handle encoding issues appropriately.
- Keep reads minimal; expand only after confirming layout.
- Log decisions: chosen sheet, detected header row, dropped columns/rows, dtype conversions.
- Avoid silent data loss: when dropping/cleaning, summarize what changed.
- Validate data types after loading.

View File

@@ -0,0 +1,47 @@
---
name: image-file
description: Guidelines for handling image files
type: image
---
# Images Handling Specifications
## Goals
- Safely identify image properties and metadata without memory exhaustion.
- Accurately extract text (OCR) and visual elements (Object Detection/Description).
- Perform necessary pre-processing (resize, normalize, crop) for downstream tasks.
- Handle multi-frame or high-resolution images efficiently.
## Inspection (Always First)
- Identify Properties: Use lightweight libraries (e.g., PIL/Pillow) to get `format`, `size` (width/height), and `mode` (RGB, RGBA, CMYK).
- Check File Size: If the image is exceptionally large (e.g., >20MB or >100MP), consider downsampling or tiling before full processing.
- Metadata/EXIF Extraction:
- Read EXIF data for orientation, GPS tags, and timestamps.
- Correction: Automatically apply EXIF orientation to ensure the image is "upright" before visual analysis.
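A lightweight sketch of this inspection step, assuming Pillow and a placeholder path:

```python
# Property check plus EXIF orientation fix; "photo.jpg" is a placeholder.
from PIL import Image, ImageOps

with Image.open("photo.jpg") as img:
    print(img.format, img.size, img.mode)
    if img.size[0] * img.size[1] > 100_000_000:  # >100MP: downsample first
        img.thumbnail((4096, 4096))
    upright = ImageOps.exif_transpose(img)  # apply EXIF orientation
```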
## Content Extraction & Vision
- Vision Analysis:
- Use multimodal vision models to describe scenes, identify objects, and detect activities.
- For complex images (e.g., infographics, UI screenshots), guide the model to focus on specific regions.
- OCR (Optical Character Recognition):
- If text is detected, specify whether to extract "raw text" or "structured data" (like forms/tables).
- Handle low-contrast or noisy backgrounds by applying pre-filters (grayscale, binarization).
- Format Conversion: Convert non-standard formats (e.g., HEIC, TIFF) to standard formats (JPEG/PNG) if tools require it.
## Handling Large or Complex Images
- Tiling: For ultra-high-res images (e.g., satellite maps, medical scans), split into overlapping tiles to avoid missing small details.
- Batching: Process multiple images using generators to keep memory usage stable.
- Alpha Channel: Be mindful of transparency (PNG/WebP); decide whether to discard it or composite against a solid background (e.g., white).
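A sketch of overlapping tiling, assuming Pillow; the tile and overlap sizes are arbitrary choices:

```python
# Yield overlapping tiles of a large image so small details near tile
# borders are not missed; sizes here are illustrative defaults.
from PIL import Image

def iter_tiles(img: Image.Image, tile: int = 1024, overlap: int = 128):
    step = tile - overlap
    width, height = img.size
    for top in range(0, height, step):
        for left in range(0, width, step):
            box = (left, top, min(left + tile, width), min(top + tile, height))
            yield box, img.crop(box)
```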
## Best Practices
- Safety First: Validate that the file is a genuine image (not a renamed malicious script).
- Graceful Failure: Handle corrupted files, truncated downloads, or unsupported formats with descriptive error logs.
- Efficiency: Avoid unnecessary re-encoding (e.g., multiple JPEG saves) to prevent "generation loss" or artifacts.
- Process images individually or in small batches to prevent system crashes
- Consider memory usage when working with large or high-resolution images
- Resource Management: Close file pointers or use context managers (`with Image.open(...) as img:`) to prevent memory leaks.

View File

@@ -0,0 +1,54 @@
---
name: json-file
description: Guidelines for handling json files
type: json
---
# JSON Handling Specifications
## Goals
- Safely parse JSON/JSONL without memory overflow.
- Discover schema structure (keys, nesting depth, data types).
- Flatten complex nested structures into tabular data when necessary.
- Handle inconsistent schemas and "dirty" JSON (e.g., trailing commas, mixed types).
## Inspection (Always First)
- Structure Discovery:
- Determine if the root is a `list` or a `dict`.
- Identify if it's a standard JSON or JSONL (one valid JSON object per line).
- Schema Sampling:
- For large files, read the first few objects/lines to infer the schema.
- Identify top-level keys and their types.
- Detect nesting depth: If depth > 3, consider it a "deeply nested" structure.
- Size Check:
- If the file is large (>50MB), avoid `json.load()`. Use iterative parsing or streaming.
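A minimal probe for these checks, assuming `"data.json"` as a placeholder path (the JSONL test is a heuristic):

```python
# Probe size, root type, and JSON-vs-JSONL without a full load.
import json
import os

path = "data.json"  # placeholder
size_mb = os.path.getsize(path) / 1e6
with open(path, encoding="utf-8") as f:
    first_line = f.readline().strip()

try:
    json.loads(first_line)      # a complete object on line 1 suggests JSONL
    kind = "jsonl (or single-line json)"
except json.JSONDecodeError:
    kind = "multi-line json"

root = "list" if first_line.startswith("[") else "dict"
print(kind, root, f"{size_mb:.1f} MB")
```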
## Processing & Extraction
- Lazy Loading (Streaming):
- For massive JSON: Use `ijson` (Python) or similar streaming parsers to yield specific paths/items.
- For JSONL: Read line-by-line using a generator to minimize memory footprint.
- Flattening & Normalization:
- Use `pandas.json_normalize` to convert nested structures into flat tables if the goal is analysis.
- Specify `max_level` during normalization to prevent "column explosion."
- Data Filtering:
- Extract only required sub-trees (keys) early in the process to reduce the memory object size.
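Streaming and flattening sketches; `ijson` is assumed to be installed, and paths/keys are illustrative:

```python
# Stream items of a top-level JSON array, and flatten JSONL for analysis.
import json

import ijson
import pandas as pd

# Massive JSON: iterate array items without loading the whole file.
with open("big.json", "rb") as f:
    for record in ijson.items(f, "item"):
        pass  # process one record at a time

# JSONL: a generator keeps the memory footprint flat.
def iter_jsonl(path):
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                yield json.loads(line)

# Flatten with a depth cap to avoid column explosion.
df = pd.json_normalize(list(iter_jsonl("events.jsonl")), max_level=1)
```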
## Data Quality & Schema Validation
- Missing Keys: Use `.get(key, default)` or `try-except` blocks. Never assume a key exists in all objects.
- Type Coercion:
- Validate numeric strings vs. actual numbers.
- Standardize `null`, `""`, and `[]` consistently.
- Encoding: Default to UTF-8; check for BOM (utf-8-sig) if parsing fails.
- Malformed JSON Recovery:
- For minor syntax errors (e.g., single quotes instead of double), attempt `ast.literal_eval` or regex-based cleanup only as a fallback.
## Best Practices
- Minimal Reads: Don't load a 50MB JSON just to read one config key; use a streaming approach.
- Schema Logging: Document the detected structure (e.g., "Root is a list of 500 objects; key 'metadata' is nested").
- Error Transparency: When a JSON object in a JSONL stream is corrupted, log the line number, skip it, and continue instead of crashing the entire process.
- Avoid Over-Flattening: Be cautious with deeply nested arrays; flattening them can lead to massive row duplication.
- Strict Typing: After extraction, explicitly convert types (e.g., `pd.to_datetime`) to ensure downstream reliability.

View File

@@ -0,0 +1,70 @@
---
name: database
description: Guidelines for handling databases
type: relational_db
---
# Database Handling Specifications
## Goals
- Safely explore database schema without performance degradation.
- Construct precise, efficient SQL queries that prevent system crashes (OOM & OOT).
- Handle dialect-specific nuances (PostgreSQL, MySQL, SQLite, etc.).
- Transform raw result sets into structured, validated data for analysis.
## Inspection
- Volume Estimation:
- Before any `SELECT *`, always run `SELECT COUNT(*) FROM table_name` to understand the scale.
- If a table has >1,000,000 rows, strictly use indexed columns for filtering.
- Sample Data:
- Use `SELECT * FROM table_name LIMIT 5` to see actual data formats.
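A sketch of this inspection pass, assuming SQLAlchemy and an illustrative DSN and table name:

```python
# Scale check before any broad read; DSN and table name are illustrative.
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://user:password@localhost:5432/mydb")
with engine.connect() as conn:
    row_count = conn.execute(text("SELECT COUNT(*) FROM orders")).scalar_one()
    sample = conn.execute(text("SELECT * FROM orders LIMIT 5")).fetchall()
print(row_count, sample[:2])
```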
## Querying
- Safety Constraints:
- Always use `LIMIT`: Never execute a query without a `LIMIT` clause unless the row count is confirmed to be small.
- Avoid `SELECT *`: In production-scale tables, explicitly name columns to reduce I/O and memory usage.
- Dialect & Syntax:
- Case Sensitivity: If a column/table name contains uppercase or special characters, MUST quote it (e.g., `"UserTable"` in Postgres, `` `UserTable` `` in MySQL).
- Date/Time: Use standard ISO strings for date filtering; be mindful of timezone-aware vs. naive columns.
- Complex Queries:
- For `JOIN` operations, ensure joining columns are indexed to prevent full table scans.
- When performing `GROUP BY`, ensure the result set size is manageable.
## Data Retrieval & Transformation
- Type Mapping:
- Ensure SQL types (e.g., `DECIMAL`, `BIGINT`, `TIMESTAMP`) are correctly mapped to Python/JSON types without precision loss.
- Convert `NULL` values to a consistent "missing" representation (e.g., `None` or `NaN`).
- Chunked Fetching:
- For medium-to-large exports, use `fetchmany(size)` or `OFFSET/LIMIT` pagination instead of fetching everything into memory at once.
- Aggregations:
- Prefer performing calculations (SUM, AVG, COUNT) at the database level rather than pulling raw data to the client for processing.
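A sketch of chunked fetching next to a database-side aggregate, under the same illustrative schema:

```python
# Aggregate in the database; fetch raw rows in chunks only when exporting.
from sqlalchemy import create_engine, text

engine = create_engine("sqlite:///data.db")  # illustrative DSN
with engine.connect() as conn:
    total = conn.execute(text("SELECT SUM(amount) FROM orders")).scalar_one()
    result = conn.execute(text("SELECT id, amount FROM orders"))
    while batch := result.fetchmany(10_000):
        pass  # process this batch, then let it be garbage-collected
```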
## Error Handling & Recovery
- Timeout Management: If a query takes too long, retry with more restrictive filters or optimized joins.
- Syntax Errors: If a query fails, inspect the dialect-specific error message and re-verify the schema (it's often a misspelled column or missing quotes).
## Anti-Pattern Prevention (Avoiding "Bad" SQL)
- Index-Friendly Filters: Never wrap indexed columns in functions (e.g., `DATE()`, `UPPER()`) within the `WHERE` clause.
- Join Safety: Always verify join keys. Before joining, check if the key has high cardinality to avoid massive intermediate result sets.
- Memory Safety:
- Avoid `DISTINCT` and `UNION` (which performs de-duplication) on multi-million row sets unless necessary; use `UNION ALL` if duplicates are acceptable.
- Avoid `ORDER BY` on large non-indexed text fields.
- Wildcard Warning: Strictly avoid leading wildcards in `LIKE` patterns (e.g., `%term`) on large text columns.
- No Function on Columns: `WHERE col = FUNC(val)` is good; `WHERE FUNC(col) = val` is bad.
- Explicit Columns: Only fetch what is necessary.
- Early Filtering: Push `WHERE` conditions as close to the base tables as possible.
- CTE for Clarity: Use `WITH` for complex multi-step logic to improve maintainability and optimizer hints.
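As a concrete contrast for the filter rules above (the schema is illustrative):

```python
# Index-hostile vs index-friendly date filters on the same table.
bad_sql = """
SELECT * FROM events
WHERE DATE(created_at) = '2024-05-20'  -- function on the column defeats the index
"""
good_sql = """
SELECT id, user_id, created_at
FROM events
WHERE created_at >= '2024-05-20'
  AND created_at <  '2024-05-21'       -- range predicate can use the index
LIMIT 1000
"""
```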
## Best Practices
- Always verify database structure before querying.
- Use appropriate sampling techniques for large datasets.
- Optimize queries for efficiency based on schema inspection.
- Self-review the draft SQL against the "Anti-Pattern Prevention" list.
- Perform a silent mental `EXPLAIN` on your query. If it smells like a full table scan on a large table, refactor it before outputting it.

View File

@@ -0,0 +1,50 @@
---
name: text-file
description: Guidelines for handling text files
type: text
---
# Text Files Handling Specifications
## Goals
- Safely read text files without memory exhaustion.
- Accurately detect encoding to avoid garbled characters.
- Identify underlying patterns (e.g., Log formats, Markdown structure, delimiters).
- Efficiently extract or search for specific information within large volumes of text.
## Encoding & Detection
- Encoding Strategy:
- Default to `utf-8`.
- If it fails, try `utf-8-sig` (for files with BOM), `gbk/gb18030` (for Chinese context), or `latin-1`.
- Use `chardet` or similar logic if the encoding is unknown and the first few bytes look non-standard.
- Line Endings: Be aware of `\n` (Unix), `\r\n` (Windows), and `\r` (Legacy Mac) when counting lines or splitting.
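A sketch of the fallback strategy, with a placeholder path:

```python
# Try encodings in order on a small probe read; "notes.txt" is a placeholder.
def read_head_guessing_encoding(
    path,
    candidates=("utf-8", "utf-8-sig", "gb18030", "latin-1"),
):
    for enc in candidates:
        try:
            with open(path, encoding=enc) as f:
                return f.read(4096), enc  # probe only, not the whole file
        except UnicodeDecodeError:
            continue
    raise ValueError(f"no candidate encoding worked for {path}")

head, encoding = read_head_guessing_encoding("notes.txt")
```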
## Inspection
- Preview: Read the first 10-20 lines to determine:
- Content Type: Is it a log, code, prose, or a semi-structured list?
- Uniformity: Does every line follow the same format?
- Metadata: Check total file size before reading. If >50MB, treat as a "large file" and avoid full loading.
## Querying & Reading (Large Files)
- Streaming: For files exceeding memory or >50MB:
- Use `with open(path) as f: for line in f:` to process line-by-line.
- Never use `.read()` or `.readlines()` on large files.
- Random Sampling: To understand a huge file's structure, read the first N lines, the middle N lines (using `f.seek()`), and the last N lines.
- Pattern Matching: Use Regular Expressions (Regex) for targeted extraction instead of complex string slicing.
- Grep-like Search: If searching for a keyword, iterate through lines and only store/return matching lines + context.
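A grep-like streaming sketch; the pattern and path are illustrative:

```python
# Stream a large log line by line, keeping only matches plus line numbers.
import re

pattern = re.compile(r"ERROR|Traceback")  # compiled once, reused per line
matches = []
with open("app.log", encoding="utf-8", errors="replace") as f:
    for lineno, line in enumerate(f, start=1):
        if pattern.search(line):
            matches.append((lineno, line.rstrip()))
print(f"{len(matches)} matching lines")
```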
## Data Quality
- Truncation Warning: If only a portion of the file is read, clearly state: "Displaying first X lines of Y total lines."
- Empty Lines/Comments: Decide early whether to ignore blank lines or lines starting with specific comment characters (e.g., `#`, `//`).
## Best Practices
- Resource Safety: Always use context managers (`with` statement) to ensure file handles are closed.
- Memory Consciousness: For logs and large TXT, prioritize "find and extract" over "load and filter."
- Regex Optimization: Compile regex patterns if they are used repeatedly in a loop over millions of lines.
- Validation: After reading, verify the content isn't binary (e.g., PDF or EXE renamed to .txt) by checking for null bytes or a high density of non-ASCII characters.
- Progress Logging: For long-running text processing, log progress every 100k lines or 10% of file size.

View File

@@ -23,6 +23,7 @@ from alias.agent.tools import AliasToolkit, share_tools
from alias.agent.agents.common_agent_utils import (
get_user_input_to_mem_pre_reply_hook,
)
from alias.agent.agents.data_source.data_source import DataSourceManager
from .ds_agent_utils import (
ReportGenerator,
LLMPromptSelector,
@@ -50,7 +51,8 @@ class DataScienceAgent(AliasAgentBase):
formatter: FormatterBase,
memory: MemoryBase,
toolkit: AliasToolkit,
sys_prompt: str = None,
data_manager: DataSourceManager = None,
sys_prompt: str = "",
max_iters: int = 30,
tmp_file_storage_dir: str = "/workspace",
state_saving_dir: Optional[str] = None,
@@ -71,17 +73,16 @@ class DataScienceAgent(AliasAgentBase):
set_run_ipython_cell(self.toolkit.sandbox)
self.uploaded_files: List[str] = []
self.todo_list: List[Dict[str, Any]] = []
self.infer_trajectories: List[List[Msg]] = []
self.tmp_file_storage_dir = tmp_file_storage_dir
self.data_manager = data_manager
self.detailed_report_path = os.path.join(
tmp_file_storage_dir,
"detailed_report.html",
)
self.tmp_file_storage_dir = tmp_file_storage_dir
self.todo_list_prompt = get_prompt_from_file(
os.path.join(
@@ -91,12 +92,19 @@ class DataScienceAgent(AliasAgentBase):
False,
)
self._sys_prompt = get_prompt_from_file(
os.path.join(
PROMPT_DS_BASE_PATH,
"_agent_system_workflow_prompt.md",
),
False,
self._sys_prompt = (
cast(
str,
get_prompt_from_file(
os.path.join(
PROMPT_DS_BASE_PATH,
"_agent_system_workflow_prompt.md",
),
False,
),
)
+ "\n\n"
+ sys_prompt
)
# load prompts and initialize selector
@@ -167,7 +175,7 @@ class DataScienceAgent(AliasAgentBase):
logger.info(
f"[{self.name}] "
"DeepInsightAgent initialized (fully model-driven).",
"DataScienceAgent initialized (fully model-driven).",
)
@property
@@ -427,27 +435,58 @@ class DataScienceAgent(AliasAgentBase):
memory_log=memory_log,
)
response, report = await report_generator.generate_report()
(
response,
report_md,
report_html,
) = await report_generator.generate_report()
if report_md:
md_report_path = os.path.join(
self.tmp_file_storage_dir,
"detailed_report.md",
)
if report:
# report = report.replace(self.tmp_file_storage_dir, ".")
await self.toolkit.call_tool_function(
ToolUseBlock(
type="tool_use",
id=str(uuid.uuid4()),
name="write_file",
input={
"path": self.detailed_report_path,
"content": report,
"path": md_report_path,
"content": report_md,
},
),
)
response = (
f"{response}\n\n"
"The detailed report has been saved to "
f"{self.detailed_report_path}."
"The detailed report (markdown version) has been saved to "
f"{md_report_path}."
)
if report_html:
html_report_path = os.path.join(
self.tmp_file_storage_dir,
"detailed_report.html",
)
await self.toolkit.call_tool_function(
ToolUseBlock(
type="tool_use",
id=str(uuid.uuid4()),
name="write_file",
input={
"path": html_report_path,
"content": report_html,
},
),
)
response = (
f"{response}\n\n"
"The detailed report (html version) has been saved to "
f"{html_report_path}."
)
kwargs["response"] = response
structured_output = {}

View File

@@ -698,7 +698,7 @@ class MetaPlanner(AliasAgentBase):
):
"""
Directly enter the data science mode.
Use this when the user provides some data files and ask for processing.
Use this for COMPLEX, CODE-BASED data analysis.
Args:
user_query (`str`):

View File

@@ -0,0 +1,722 @@
# -*- coding: utf-8 -*-
# pylint: disable=R1702,R0912,R0915
import os
import json
from abc import ABC, abstractmethod
from typing import Any, Dict
from loguru import logger
import pandas as pd
from sqlalchemy import inspect, text, create_engine
from agentscope.message import Msg
from alias.agent.agents.data_source._typing import SourceType
from alias.agent.agents.ds_agent_utils import (
get_prompt_from_file,
)
from alias.agent.utils.llm_call_manager import (
LLMCallManager,
)
class BaseDataProfiler(ABC):
"""Abstract base class for data profilers that analyze different data
sources like csv, excel, db, etc.
"""
_PROFILE_PROMPT_BASE_PATH = os.path.join(
os.path.dirname(__file__),
"built_in_prompt",
)
def __init__(
self,
path: str,
source_type: SourceType,
llm_call_manager: LLMCallManager,
):
"""Initialize the data profiler with data path, type and LLM manager.
Args:
path: Path to the data source file or connection string
source_type: Enum indicating the type of data source
llm_call_manager: Manager for handling LLM calls
"""
self.path = path
self.file_name = os.path.basename(path)
self.source_type = source_type
self.llm_call_manager = llm_call_manager
self.source_types_2_prompts = {
SourceType.CSV: "_profile_csv_prompt.md",
SourceType.EXCEL: "_profile_xlsx_prompt.md",
SourceType.IMAGE: "_profile_image_prompt.md",
SourceType.RELATIONAL_DB: "_profile_relationdb_prompt.md",
"IRREGULAR": "_profile_irregular_xlsx_prompt.md",
}
if source_type not in self.source_types_2_prompts:
raise ValueError(f"Unsupported source type: {source_type}")
self.prompt = self._load_prompt(source_type)
base_model_name = self.llm_call_manager.get_base_model_name()
vl_model_name = self.llm_call_manager.get_vl_model_name()
self.source_types_2_models = {
SourceType.CSV: base_model_name,
SourceType.EXCEL: base_model_name,
SourceType.IMAGE: vl_model_name,
SourceType.RELATIONAL_DB: base_model_name,
}
self.model_name = self.source_types_2_models[source_type]
def _load_prompt(self, source_type: Any = None):
"""Load the appropriate prompt template based on the source type.
Args:
source_type: Type of data source (CSV, EXCEL, IMAGE, etc.)
Returns:
Loaded prompt template as string
"""
prompt_file_name = self.source_types_2_prompts[source_type]
prompt = get_prompt_from_file(
os.path.join(
self._PROFILE_PROMPT_BASE_PATH,
prompt_file_name,
),
False,
)
return prompt
async def generate_profile(self) -> Dict[str, Any]:
"""Generate a complete data profile
by reading data, generating content,
calling the LLM, and wrapping the response.
Returns:
Dictionary containing the complete data profile
"""
try:
self.data = await self._read_data()
# different source types have different data building methods
content = self._build_content_with_prompt_and_data(
self.prompt,
self.data,
)
# content = self.prompt.format(data=self.data)
res = await self._generate_profile_by_llm(content)
self.profile = self._wrap_data_response(res)
except Exception as e:
logger.warning(f"Error generating profile: {e}")
self.profile = {}
return self.profile
@staticmethod
def tool_clean_json(raw_response: str):
"""Clean and parse JSON response from LLM by removing markdown
markers.
Args:
raw_response: Raw string response from LLM
Returns:
Parsed JSON object from the cleaned response
"""
cleaned_response = raw_response.strip()
if cleaned_response.startswith("```json"):
cleaned_response = cleaned_response[len("```json") :].lstrip()
if cleaned_response.startswith("```"):
cleaned_response = cleaned_response[len("```") :].lstrip()
if cleaned_response.endswith("```"):
cleaned_response = cleaned_response[:-3].rstrip()
return json.loads(cleaned_response)
@abstractmethod
def _build_content_with_prompt_and_data(
self,
prompt: str,
data: Any,
) -> str:
"""Abstract method to build content for LLM based on prompt
and data.
This method should be implemented by subclasses to format
content appropriately for different data types.
Args:
prompt: Prompt template to use
data: Processed data to include in the prompt
Returns:
Formatted content for LLM call
"""
@abstractmethod
async def _read_data(self):
"""Abstract method to read and process data from the source path.
This method should be implemented by subclasses to handle
specific
data source types (CSV, Excel, DB, etc.).
Returns:
Processed data in appropriate format for the data type
"""
async def _generate_profile_by_llm(
self,
content: Any,
) -> Dict[str, Any]:
"""Generate profile by calling LLM with prepared content.
Args:
content: Content to send to the LLM (text or multimodal)
Returns:
Dictionary response parsed from LLM output
"""
sys_prompt = "You are a helpful AI assistant for database management."
msgs = [
Msg("system", sys_prompt, "system"),
Msg("user", content, "user"),
]
response = await self.llm_call_manager(
model_name=self.model_name,
messages=msgs,
)
response = BaseDataProfiler.tool_clean_json(response)
return response
@abstractmethod
def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Abstract method to combine LLM response with original schema.
This method should be implemented by subclasses to properly
merge
LLM-generated descriptions with original data structure.
Args:
response: Dictionary response from LLM
Returns:
Combined dictionary with original schema and LLM response
"""
class StructuredDataProfiler(BaseDataProfiler):
"""Base class for profilers that work with structured data sources
like CSV, Excel, and relational databases.
"""
@staticmethod
def is_irregular(cols: pd.Index) -> bool:
"""Determine if a table has irregular column names (many unnamed).
Args:
cols: List of column names from the dataset
Returns:
Boolean indicating whether the dataset is irregular
"""
# any(col.startswith('Unnamed') for col in df.columns.astype(str))?
unnamed_columns_ratio = sum(
col.startswith("Unnamed") for col in cols.astype(str)
) / len(cols)
return unnamed_columns_ratio >= 0.5
@staticmethod
def _extract_schema_from_table(df: pd.DataFrame, df_name: str) -> dict:
"""Analyzes a single DataFrame to extract metadata and samples.
Extracts column names, data types, and sample values to provide a
comprehensive view of the table structure for the LLM.
Args:
df: The dataframe to analyze
df_name: Name of the table (or sheet/filename)
Returns:
Dictionary containing schema metadata for the table
"""
col_list = []
for col in df.columns:
dtype_name = str(df[col].dtype).upper()
# Get random samples to help the LLM understand the data content.
# sample(frac=1) shuffles the rows; head(5) then takes up to five
# values, returning fewer without error when the column has fewer
# distinct entries.
candidates = (
df[col]
.drop_duplicates()
.sample(frac=1, random_state=42)
.head(5)
.astype(str)
.tolist()
)
# Limit the total sample size to at most 1000 characters.
# TODO: dynamic size control? 1000 is too small?
samples = []
length = 0
for s in candidates:
if (length := length + len(s)) <= 1000:
samples.append(s)
col_list.append(
{
"column name": col,
"data type": dtype_name,
"data samples": samples,
},
)
# Create a CSV snippet of the first few rows
raw_data_snippet = df.head(5).to_csv(index=True)
table_schema = {
"name": df_name,
"raw_data_snippet": raw_data_snippet,
# Note: Row count logic might need optimization for large files
# TODO: how to get the row count more efficiently, openpyxl.
"row_count": len(df) if len(df) < 100 else None,
"col_count": len(df.columns),
"columns": col_list,
}
return table_schema
def _build_content_with_prompt_and_data(
self,
prompt: str,
data: Any,
) -> str:
"""Format the prompt with data for structured data sources.
Args:
prompt: Template prompt string
data: Processed data structure
Returns:
Formatted content string ready for LLM
"""
return prompt.format(data=data)
def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Merges the original schema with the LLM-generated response.
Combines the structural information from the original data with
semantic descriptions generated by the LLM.
Args:
response: Dictionary response from LLM with descriptions
Returns:
Combined schema with both structural and semantic info
"""
new_schema = {}
new_schema["name"] = self.data["name"]
new_schema["description"] = response["description"]
# For flat files like CSV, the schema contains columns directly
if "columns" in self.data:
new_schema["columns"] = self.data["columns"]
# For multi-table sources like Excel/Database, the schema contains
# tables; each table contains columns and a description
if "tables" in self.data and "tables" in response:
new_schema["tables"] = []
# Build a map for response tables and descriptions
res_des_map = {
table["name"]: table["description"]
for table in response["tables"]
}
for table in self.data["tables"]:
table_name = table["name"]
if table_name not in res_des_map:
continue
new_table = {}
new_table["name"] = table_name
# Retain the description from the LLM response
new_table["description"] = res_des_map[table_name]
if "columns" in table:
new_table["columns"] = table["columns"]
if "irregular_judgment" in table:
new_table["irregular_judgment"] = table[
"irregular_judgment"
]
new_schema["tables"].append(new_table)
return new_schema
class ExcelProfiler(StructuredDataProfiler):
async def _extract_irregular_table(
self,
path: str,
raw_data_snippet: str,
sheet_name: str,
):
"""Extract structure from irregular Excel sheets with unnamed
columns. Uses a special LLM call to identify the actual data in
sheets with headers or other content above the main data table.
Args:
path: Path to the Excel file
raw_data_snippet: Raw text snippet of the sheet content
sheet_name: Name of the sheet being processed
Returns:
Schema dictionary for the irregular table structure
"""
prompt = self._load_prompt("IRREGULAR")
content = prompt.format(raw_snippet_data=raw_data_snippet)
res = await self._generate_profile_by_llm(content=content)
if "is_extractable_table" in res and res["is_extractable_table"]:
logger.debug(res["reasoning"])
skiprows = res["row_start_index"] + 1
cols_range = res["col_ranges"]
df = pd.read_excel(
path,
sheet_name=sheet_name,
nrows=100,
skiprows=skiprows,
usecols=range(cols_range[0], cols_range[1] + 1),
).convert_dtypes()
if StructuredDataProfiler.is_irregular(df.columns):
schema = {
"name": sheet_name,
"raw_data_snippet": raw_data_snippet,
"irregular_judgment": "UNSTRUCTURED",
}
else:
schema = self._extract_schema_from_table(df, sheet_name)
schema["irregular_judgment"] = res
else:
schema = {
"name": sheet_name,
"raw_data_snippet": raw_data_snippet,
"irregular_judgment": "UNSTRUCTURED",
}
return schema
async def _read_data(self):
"""Read and process Excel file data including all sheets.
Handles both regular and irregular Excel files by using pandas
for regular files and openpyxl for files with unnamed columns.
Returns:
Dictionary containing metadata for all sheets in the Excel file
"""
excel_file = pd.ExcelFile(self.path)
table_schemas = []
schema = {}
schema["name"] = self.file_name
for sheet_name in excel_file.sheet_names:
# TODO: use openpyxl to read excel to avoid irregular excel.
# Read a subset of each sheet
df = pd.read_excel(
self.path,
sheet_name=sheet_name,
nrows=100,
).convert_dtypes()
if not StructuredDataProfiler.is_irregular(df.columns):
table_schema = (
StructuredDataProfiler._extract_schema_from_table(
df,
sheet_name,
)
)
else:
# if unnamed columns, use openpyxl to extract top 100 rows.
import openpyxl
wb = openpyxl.load_workbook(
self.path,
read_only=True,
data_only=True,
)
ws = wb[sheet_name]
rows_data = []
for i, row in enumerate(
ws.iter_rows(values_only=True),
start=1,
):
if i > 100:
break
rows_data.append(
",".join(
"" if cell is None else str(cell) for cell in row
),
)
wb.close()
raw_data_snippet = "\n".join(rows_data)
table_schema = await self._extract_irregular_table(
self.path,
raw_data_snippet,
sheet_name,
)
# table_schema = {
# "name": sheet_name,
# "raw_data_snippet": "\n".join(rows_data),
# }
table_schemas.append(table_schema)
schema["tables"] = table_schemas
return schema
class RelationalDatabaseProfiler(StructuredDataProfiler):
async def _read_data(self):
"""
Extracts metadata (schema) for all tables in a relational db.
path (str): The Data Source Name (connection string),
e.g. postgresql://user:password@host:port/db_name
Returns:
Dictionary containing database metadata for all tables
"""
options = {
"isolation_level": "AUTOCOMMIT",
# Test conns before use (handles MySQL 8hr timeout, network drops)
"pool_pre_ping": True,
# Keep minimal conns (MCP typically handles 1 request at a time)
"pool_size": 1,
# Allow temporary burst capacity for edge cases
"max_overflow": 2,
# Force refresh conns older than 1hr (under MySQL's 8hr default)
"pool_recycle": 3600,
}
engine = create_engine(self.path, **options)
try:
connection = engine.connect()
except Exception as e:
logger.error(f"Connection to {self.path} failed: {e}")
raise ConnectionError(f"Failed to connect to database: {e}") from e
# Use the DSN as the db identifier (could be parsed into a cleaner name)
database_name = self.path
inspector = inspect(connection)
table_names = inspector.get_table_names()
tables_data = []
for table_name in table_names:
try:
# 1. Get column information
columns = inspector.get_columns(table_name)
col_count = len(columns)
# 2. Get row count
row_count_result = connection.execute(
text(f"SELECT COUNT(*) FROM {table_name}"),
).fetchone()
row_count = row_count_result[0] if row_count_result else 0
# 3. Get raw data snippet (first 5 rows)
raw_data_snippet = ""
rows = []  # ensure defined even if the snippet query below fails
try:
result = connection.execute(
text(f"SELECT * FROM {table_name} LIMIT 5"),
)
rows = result.fetchall()
if rows:
column_names = [col["name"] for col in columns]
lines = []
# Add header
lines.append(", ".join(column_names))
# Add data rows
for row in rows:
row_values = []
for value in row:
if value is None:
row_values.append("NULL")
else:
# Escape commas and newlines
val_str = str(value)
if "," in val_str or "\n" in val_str:
val_str = f'"{val_str}"'
row_values.append(val_str)
lines.append(", ".join(row_values))
raw_data_snippet = "\n".join(lines)
except Exception as e:
logger.warning(
f"Error fetching {table_name} data: {str(e)}",
)
raw_data_snippet = None
# 4. detailed column info (types and samples)
column_details = []
if rows:
for i, col in enumerate(columns):
col_name = col["name"]
col_type = str(col["type"])
# Extract samples for this column from the fetched rows
sample_values = []
for row in rows:
if i < len(row):
val = row[i]
sample_values.append(
str(val) if val is not None else "NULL",
)
column_details.append(
{
"column name": col_name,
"data type": col_type,
"data sample": sample_values[:3],
},
)
table_info = {
"name": table_name,
"row_count": row_count,
"col_count": col_count,
"raw_data_snippet": raw_data_snippet,
"columns": column_details,
}
tables_data.append(table_info)
except Exception as e:
# If one table fails, log it and continue to the next
logger.warning(f"Error processing {table_name}: {str(e)}")
continue
# Construct the final schema
schema = {
"name": database_name,
"tables": tables_data,
}
self.data = schema
return schema
class CsvProfiler(ExcelProfiler):
async def _read_data(self):
"""Handles schema extraction for CSV as single-table sources.
Uses Polars for efficient row counting on large files and
pandas for detailed schema analysis of the first 100 rows.
Returns:
Schema dictionary for the CSV file
"""
import polars as pl
# Use Polars for efficient row counting on large files
df = pl.scan_csv(self.path, ignore_errors=True)
row_count = df.select(pl.len()).collect().item()
# Read a subset with Pandas for detailed schema analysis
df = pd.read_csv(self.path, nrows=100).convert_dtypes()
schema = self._extract_schema_from_table(df, self.file_name)
schema["row_count"] = row_count
# if StructuredDataProfiler.is_irregular(df.columns):
# self._extract_irregular_table(...)
return schema
class ImageProfiler(BaseDataProfiler):
"""Profiler for image data sources that uses multimodal LLMs."""
async def _read_data(self):
"""
For images, this simply returns the path since the LLM API
handles loading the image directly.
Returns:
Path to the image file
"""
return self.path
def _build_content_with_prompt_and_data(self, prompt, data):
"""build multimodal content for image analysis.
Creates content in the format required by multimodal LLM APIs
with both image and text components.
Args:
prompt: Text prompt template for image analysis
data: Path to the image file
Returns:
List containing image and text components for the LLM call
"""
# Convert image paths according to the model requirements
contents = [
{
"text": prompt,
"type": "text",
},
{
"source": {
"url": data,
"type": "url",
},
"type": "image",
},
]
return contents
def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Format the LLM response for image data into dict.
Args:
response: Dictionary response from multimodal LLM
Returns:
Profile dictionary with image name, description and details
"""
profile = {
"name": self.file_name,
"description": response["description"],
"details": response["details"],
}
return profile
class DataProfilerFactory:
"""Factory class to create appropriate data profiler instances based
on source type.
"""
@staticmethod
def get_profiler(
llm_call_manager: LLMCallManager,
path: str,
source_type: SourceType,
) -> BaseDataProfiler:
"""Factory method to get the appropriate profiler instance.
Generate the correct profile result for the source.
Args:
path: Path to the data source or connection string
source_type: Enum indicating the type of data source
llm_call_manager: Manager for handling LLM calls
Returns:
Instance of the appropriate profiler subclass
Raises:
ValueError: If the source_type is unsupported
"""
if source_type == SourceType.IMAGE:
return ImageProfiler(
path=path,
source_type=source_type,
llm_call_manager=llm_call_manager,
)
elif source_type == SourceType.CSV:
return CsvProfiler(
path=path,
source_type=source_type,
llm_call_manager=llm_call_manager,
)
elif source_type == SourceType.EXCEL:
return ExcelProfiler(
path=path,
source_type=source_type,
llm_call_manager=llm_call_manager,
)
elif source_type == SourceType.RELATIONAL_DB:
return RelationalDatabaseProfiler(
path=path,
source_type=source_type,
llm_call_manager=llm_call_manager,
)
else:
raise ValueError(f"Unsupported source type: {source_type}")

View File

@@ -0,0 +1,21 @@
{
"relational_db": {
"mcp_server": {
"mcp_alchemy": {
"command": "uvx",
"args": [
"--from",
"mcp-alchemy==2025.8.15.91819",
"--with",
"psycopg2-binary",
"--refresh-package",
"mcp-alchemy",
"mcp-alchemy"
],
"env": {
"DB_URL": "${endpoint}"
}
}
}
}
}

View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from enum import Enum
class SourceAccessType(str, Enum):
"""Simple source access type classification"""
DIRECT = "direct"
VIA_MCP = "via_mcp"
def __str__(self):
return self.value
class SourceType(str, Enum):
"""Simple source type classification"""
CSV = "csv"
JSON = "json"
EXCEL = "excel"
TEXT = "text"
IMAGE = "image"
# Database sources
RELATIONAL_DB = "relational_db"
OTHER = "other"
def __str__(self):
return self.value
@staticmethod
def is_valid_source_type(value: str) -> bool:
try:
SourceType(value)
return True
except ValueError:
return False
# Define mapping between SourceType and SourceAccessType
SOURCE_TYPE_TO_ACCESS_TYPE = {
# File types -> DIRECT
SourceType.CSV: SourceAccessType.DIRECT,
SourceType.JSON: SourceAccessType.DIRECT,
SourceType.EXCEL: SourceAccessType.DIRECT,
SourceType.TEXT: SourceAccessType.DIRECT,
SourceType.IMAGE: SourceAccessType.DIRECT,
# Database types -> VIA_MCP
SourceType.RELATIONAL_DB: SourceAccessType.VIA_MCP,
# Unknown type -> depends on endpoint
SourceType.OTHER: None,
}

View File

@@ -0,0 +1,83 @@
# Role
You are an expert Data Steward. Your task is to generate a single, comprehensive description sentence for a CSV file based on its metadata and raw content.
# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
"name": "filename.csv",
"raw_data_snippet": "col1, col2\na, b",
"row_count": 100,
"col_count": 5,
"columns": [
{{ "column name": "col1", "data type": "string", "data sample": ["a", "b"] }}
]
}}
```
# Analysis Logic
## 1. Context & Metrics Extraction
* **Subject:** Extract the core concept from the `name` field (e.g., `logistics_data.csv` -> "logistics_data").
* **Metrics:** Identify `row_count` and `col_count`.
* **Context:** Look for time (e.g., "2024") or location keywords in the `raw_data_snippet` or `name`.
## 2. Schema Identification
* **Primary:** Use column names from the `columns` list.
* **Secondary (Inference):** If the `columns` list is empty or generic (e.g., "col1"), you MUST infer meaningful column names from the `raw_data_snippet` values (e.g., "2023-01-01" -> `date`).
* **Selection:** Choose 3-5 key columns to represent the dataset structure.
## 3. Description Construction
* Generate a **single** grammatical sentence.
* **Strict Template:** "The file [FileName] contains [Subject] data [Optional: Context] with [RowCount] rows and [ColCount] columns, featuring fields such as [List of 3-5 key columns]."
---
# Output Format (Strict JSON)
You must output a single valid JSON object containing only the `description` key.
```json
{{
"description": "The file [FileName] contains [Subject] data with [Rows] rows and [Cols] columns, featuring fields such as [Columns]."
}}
```
# One-Shot Demonstration
**[Example Input]**
`input_json` =
```json
{{
"name": "logistics_data.csv",
"raw_data_snippet": "SHP-001, Tokyo, London, 2024-05-20\nSHP-002, NY, Paris, 2024-05-21",
"row_count": 2000,
"col_count": 4,
"columns": [
{{
"column name": "shipment_id",
"data type": "string",
"data sample": ["SHP-001", "SHP-002"]
}}
]
}}
```
**[Example Output]**
```json
{{
"description": "The file logistics_data.csv contains supply chain logistics information for 2024 with 2000 rows and 4 columns, featuring fields such as shipment_id, origin, destination, and date."
}}
```
# Input
input_json = {data}

View File

@@ -0,0 +1,31 @@
Please carefully analyze this image and perform the following tasks:
### Step 1: Overall Assessment
- Determine whether the image contains "single content" or "composite content" (i.e., multiple independent information modules).
- If it is composite content, list the main components (e.g., "Bar chart in the top-left, data table in the bottom-right, title description at the top").
### Step 2: Region-wise Analysis (for composite content)
For each prominent content region, describe it using the following template:
#### [Module X] Type: [flowchart/table/chart/document/photo]
- Position and scope: Briefly describe its location (e.g., "left half", "bottom table")
- Content extraction:
- If flowchart/diagram: Describe nodes and connections in logical order, and explain label meanings.
- If table: Reconstruct row/column structure; present in Markdown table format if possible.
- If chart: Explain axes, series, trends, and provide key conclusions (e.g., "Peak reached in Q3").
- If document/text: Extract key sentences while preserving original meaning.
- If photo: Describe scene, people, and actions.
- Functional role: Infer the module's purpose within the whole image (e.g., "Supports the conclusion stated above").
### Step 3: Global Synthesis
- Summarize the core purpose of the entire image (e.g., "Presents quarterly performance analysis").
- Describe logical relationships among modules (e.g., "The table provides data sources, the chart shows trends, and the text offers recommendations").
- If there are annotations (e.g., label1, Cost2), explain their business meaning.
### Final Output Requirements
Do not output the internal analysis steps separately. You must output the final result **ONLY** in the following format:
{{
"description": "A single, concise sentence describing the overall framework of the image",
"details": "A comprehensive and detailed description based on your Step 2 & 3 analysis. Use Markdown formatting (bullet points, bold text) inside this field to ensure the structure is clear and readable."
}}

View File

@@ -0,0 +1,82 @@
# Role
You are an expert Data Engineer specializing in unstructured Excel parsing. Your task is to analyze the raw content of the first 100 rows of an Excel sheet and determine if it contains structured tabular data suitable for a Pandas DataFrame.
If it is a valid table, identify the **Header Row** and the **Column Range**.
If it is NOT a valid table (e.g., a dashboard, a form, a letter, or empty), you must flag it as unsuitable.
# Task Analysis
Excel sheets fall into two categories:
1. **List-Like Tables (Valid)**: Contains a header row followed by multiple rows of consistent record data. This is what we want.
2. **Unstructured/Layout-Heavy (Unstructured)**:
- **Forms/KV Pairs**: "Label: Value" scattered across the sheet.
- **Dashboards**: Multiple small tables, charts, or scattered numbers.
- **Text/Notes**: Paragraphs of text or disclaimers without column structure.
- **Empty/Near Empty**: Contains almost no data.
# Rules for Detection
### A. Validity Check (The "Gatekeeper")
Set `is_extractable_table` to **false** if:
- There is no distinct row where meaningful column headers align horizontally.
- The data is scattered (e.g., values exist in A1, G5, and C20 with no relation).
- The sheet looks like a printed form (Key on the left, Value on the right) rather than a list of records.
- There are fewer than 3 rows of data following a potential header.
### B. Structure Extraction (Only if Valid)
If the sheet passes the Validity Check:
1. **Header Row**: Find the first row containing multiple distinct string values that serve as column labels.
2. **Column Range**: Identify the start index (first valid header) and end index (last valid header) to define the width.
3. **Data Continuity**: Verify that rows below the header contain consistent data types (e.g., Dates under "Date").
# Input Data
The user will provide the first 100 rows in CSV/Markdown format (0-based index).
# Output Format
You must output a strictly valid JSON object.
JSON Structure:
{{
"is_extractable_table": <boolean, true if it serves as a dataframe source, false otherwise>,
"row_start_index": <int or null, 0-based index of the header row>,
"col_ranges": <list [start, end] or null, inclusive 0-based column indices>,
"confidence_score": <float, 0-1>,
"reasoning": "<string, explain what the row data contains. declare the final conclusion(IRREGULAR,REGULAR,INVALIED). >"
}}
# Examples
## Example 1 (Valid Table with Noise)
Input:
Title: Monthly Sales, NaN, NaN, NaN
NaN, NaN, NaN, NaN
NaN, Date, Item, Qty, Total
NaN, 2023-01-01, Apple, 10, 500
NaN, 2023-01-02, Banana, 5, 100
Output:
{{
"is_extractable_table": true,
"row_start_index": 2,
"col_ranges": [1, 4],
"confidence_score": 0.99,
"reasoning": " Rows 0-1 are ignored metadata, Row 2 is clear headers. Rows 3-4 contain consistent data aligned with headers. It is IRREGULAR and requires skiprows=2, usecols=[1, 4] to extract using Pansa DataFrame."
}}
## Example 2 (Unstructured - Form/Dashboard)
Input:
Company Invoice, NaN, NaN, Invoice #: 001
To:, John Doe, NaN, Date:, 2023-01-01
Address:, 123 St, NaN, Due:, 2023-02-01
NaN, NaN, NaN, NaN, NaN
Subject:, Consulting Services, NaN, NaN, NaN
Output:
{{
"is_extractable_table": false,
"row_start_index": null,
"col_ranges": null,
"confidence_score": 0.95,
"reasoning": "Data matches a 'Form/Invoice' layout (Key-Value pairs) rather than a list-like table. No single header row defines a dataset of records. It is INVALIED and cannot be processed as Pandas DataFrame."
}}
# Input
{raw_snippet_data}

View File

@@ -0,0 +1,109 @@
# Role
You are an expert Data Steward. Your task is to analyze the metadata and content of a database.
**Assumption:** This is an ideal dataset or database where **ALL** tables contain valid headers in the first row. You will process the entire file structure in a single pass.
# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
"file": "Name of the file",
"tables": [
{{
"name": "Name of the table",
"row_count": 100,
"col_count": 5,
"raw_data_snippet": "Header1, Header2\nVal1, Val2..."
}},
...
]
}}
```
# Analysis Logic
## 1. Sheet Iteration (Sheet Descriptions)
For **EACH** object in the `tables` array:
1. **Extract Schema:**
* Since headers are guaranteed, simply extract the column names from the **first row** of the `raw_data_snippet`.
* Format them as a clean list of strings.
2. **Draft Description:**
* Write a concise sentence describing what the sheet tracks based on its name and columns.
* **MANDATORY:** You MUST explicitly mention the `row_count` and `col_count` in this sentence.
* *Template:* "The sheet [Sheet Name] contains [Subject] data with [Row Count] rows and [Col Count] columns, featuring fields like [List 3 key columns]."
## 2. Global Analysis (File Description)
* Analyze the `file` name and the number of tables in the `tables` array.
* Based on all sheet descriptions, generate a single sentence summarizing the whole workbook.
# Output Format (Strict JSON)
You must output a single valid JSON object.
```json
{{
"description": "One sentence describing the whole file or database.",
"tables": [
{{
"name": "Name of table 1",
"description": "Sentence including row/col counts and key columns.",
"columns": ["col1", "col2", "col3"]
}},
...
]
}}
```
# One-Shot Demonstration
**[Example Input]**
`input_json` =
```json
{{
"file": "logistics_data.xlsx",
"tables": [
{{
"na me": "Shipments",
"row_count": 2000,
"col_count": 4,
"raw_data_snippet": "shipment_id, origin, destination, date\nSHP-001, Tokyo, London, 2024-05-20"
}},
{{
"name": "Rates",
"row_count": 50,
"col_count": 2,
"raw_data_snippet": "Route_ID, Cost_Per_Kg\nR-101, 5.50"
}}
]
}}
```
**[Example Output]**
```json
{{
"description": "The file/database logistics_data.xlsx contains supply chain logistics information for 2024, divided into shipment tracking and rate definitions (2 tables in total).",
"tables": [
{{
"name": "Shipments",
"description": "The 'Shipments' sheet tracks individual shipment records with 2000 rows and 4 columns, featuring fields such as shipment_id, origin, and destination.",
"columns": ["shipment_id", "origin", "destination", "date"]
}},
{{
"name": "Rates",
"description": "The 'Rates' sheet lists shipping cost rates with 50 rows and 2 columns, specifically Route_ID and Cost_Per_Kg.",
"columns": ["Route_ID", "Cost_Per_Kg"]
}}
]
}}
```
# Input
input_json=`{data}`

View File

@@ -0,0 +1,159 @@
# Role
You are an expert Data Steward. Your task is to analyze the metadata and content of an Excel file based on a pre-analyzed structural judgment.
**Context:** The dataset contains three types of sheets:
1. **Regular Tables**: Standard headers in row 0.
2. **Irregular Tables**: Valid data but requires `skiprows` or `usecols` parameters.
3. **Unstructured Sheets**: Dashboards, forms, or text descriptions that **cannot** be read as a dataframe.
**Constraint**: Your analysis relies on a snippet of the first 100 rows.
# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
"file": "Name of the file",
"tables": [
{{
"name": "Sheet Name",
"row_count": 100,
"col_count": 5,
"raw_data_snippet": "...",
"irregular_judgment": {{
"row_header_index": int,
"cols_ranges": list,
"reasoning": "..."
}}
}}
]
}}
```
*(Note: If `irregular_judgment` is null, treat it as Regular).*
# Analysis Logic
## 1. Sheet Iteration (Table Descriptions)
For **EACH** object in the `tables` array, apply the following priority logic:
**Case A: Unstructured Sheet (irregular_judgment contains "UNSTRUCTURED")**
* **Columns**: Return an empty list `[]`.
* **Description**: "The sheet [Name] contains [something].
**Append MANDATORY Warning**: "It is Unstructured based on a 100-row sample."
**Case B: Irregular Table (irregular_judgment contains a dict and `row_header_index` > 0 or `cols_ranges` is set)**
* **Columns**: Extract column names from the row indicated by `row_header_index`.
* **Description**:
Write a concise sentence describing what the sheet tracks based on its name and columns.
1. Start with: "The sheet [Name] contains [Subject] data with [Rows] rows and [Cols] columns."
2. **Append MANDATORY Warning**: "It is irregular; requires specifying skiprows={{row_header_index}}, usecols={{cols_ranges}} using pandas dataframe."
**Case C: Regular Table (Default)**
* **Columns**: Extract from the first row of `raw_data_snippet`.
* **Description**: "The sheet [Name] contains [Subject] data with [Rows] rows and [Cols] columns, featuring fields like [Key Cols]."
## 2. Global Analysis (File Description)
Generate a single string summarizing the workbook. This summary **MUST** explicitly include:
1. **Total Count**: The number of sheets.
2. **Status List**: List every table name with its status tag:
* (Regular)
* (Irregular, requires skiprows=X, usecols=Y)
* (Unstructured)
* *Format Example:* "The file logistics_data.xlsx contains supply chain logistics information for 2024. It contains 3 sheets: 'Data' (Regular), 'Logs' (Irregular, requires skiprows=2), and 'Cover' (Unstructured)."
# Output Format (Strict JSON)
You must output a single valid JSON object.
```json
{{
"description": "Comprehensive summary including count, names, and specific status tags for ALL tables.",
"tables": [
{{
"name": "Table Name",
"description": "Specific description based on Case A, B, or C.",
"columns": ["col1", "col2"]
}}
]
}}
```
# One-Shot Demonstration
**[Example Input]**
`input_json` =
```json
{{
"file": "finance_report_v2.xlsx",
"tables": [
{{
"name": "Q1_Sales",
"row_count": 200,
"col_count": 5,
"raw_data_snippet": "Date, Item, Amount\n2023-01-01, A, 100",
}},
{{
"name": "Historical_Data",
"row_count": 500,
"col_count": 10,
"raw_data_snippet": "Confidential\nSystem Generated\n\nDate, ID, Val\n...",
"irregular_judgment": {{
"is_extractable_table": true,
"row_header_index": 3,
"cols_ranges": [0, 3],
"reasoning": "Header offset."
}}
}},
{{
"name": "Dashboard_Overview",
"row_count": 50,
"col_count": 20,
"raw_data_snippet": "Total KPI: 500 | Chart Area |\nDisclaimer: Internal Use",
"irregular_judgment": "UNSTRUCTURED"
}}
]
}}
```
**[Example Output]**
```json
{{
"description": "The file finance_report_v2.xlsx contains historical sales transaction records over the past Q1 period.
It contains 3 sheets: 'Q1_Sales' (Regular), 'Historical_Data' (Irregular, requires skiprows=3, usecols=[0, 3], sampled first 100 rows), and 'Dashboard_Overview' (Unstructured).",
"tables": [
{{
"name": "Q1_Sales",
"description": "The sheet 'Q1_Sales' contains sales transaction records. It contains 200 rows and 5 columns, featuring fields like Date, Item, and Amount.",
"columns": ["Date", "Item", "Amount"]
}},
{{
"name": "Historical_Data",
"description": "The sheet 'Historical_Data' contains historical sales transaction records records. It contains 400 rows and 21 columns. It's irregular judged by the first 100 samples(The first 3 rows contains metadata. requires specifying skiprows=3, usecols=[0, 3] using pandas dataframe.)",
"columns": ["Date", "ID", "Val"]
}},
{{
"name": "Dashboard_Overview",
"description": "The sheet 'Dashboard_Overview' contains the whole overview and summary of the whole dashboards It is Unstructured based on a 100-row sample.",
"columns": []
}}
]
}}
```
# Input
input_json=`{data}`

View File

@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
import os
import base64
import tempfile
from typing import Any, Dict
from io import BytesIO
from pathlib import Path
import requests
from alias.agent.agents.data_source._typing import SourceType
from alias.agent.agents.data_source._data_profiler_factory import (
DataProfilerFactory,
)
from alias.agent.tools.sandbox_util import (
get_workspace_file,
)
from alias.runtime.alias_sandbox.alias_sandbox import AliasSandbox
from alias.agent.utils.llm_call_manager import (
LLMCallManager,
)
def _get_binary_buffer(
sandbox: AliasSandbox,
file_url: str,
):
if file_url.startswith(("http://", "https://")):
response = requests.get(file_url)
response.raise_for_status()
buffer = BytesIO(response.content)
else:
buffer = BytesIO(
base64.b64decode(get_workspace_file(sandbox, file_url)),
)
return buffer
def _copy_file_from_sandbox_with_original_name(
sandbox: AliasSandbox,
file_path: str,
) -> str:
"""
Copies a file from the sandbox environment
or a URL to a local temporary file.
Args:
sandbox (AliasSandbox): The sandbox environment instance.
file_path (str): Source path or URL.
Returns:
str: The local temporary file path, or the original URL for web sources.
"""
# Handle different types of file URLs
if file_path.startswith(("http://", "https://")):
# For web URLs, use the URL directly
file_source = file_path
else:
# For local files, save to a temporary file
file_buffer = _get_binary_buffer(
sandbox,
file_path,
)
# Create a temporary file with the same name as the original file
temp_dir = tempfile.mkdtemp()
target_file_name = os.path.basename(file_path)
full_path = Path(temp_dir) / target_file_name
with open(full_path, "wb") as f:
f.write(file_buffer.getvalue())
file_source = full_path
return str(file_source)
async def data_profile(
sandbox: AliasSandbox,
sandbox_path: str,
source_type: SourceType,
llm_call_manager: LLMCallManager,
) -> Dict[str, Any]:
"""
Generates a detailed profile and summary for a data source using LLMs.
Args:
sandbox (AliasSandbox): The sandbox environment instance.
sandbox_path (str): The location of the data source.
- For files: A file path or URL.
- For databases: A connection string (DSN).
source_type (SourceType): The type of the data source.
llm_call_manager: Manager for handling LLM calls
Returns:
Dict: An object containing the generated text profile of the data.
Raises:
ValueError: If the provided `source_type` is not supported.
"""
if source_type in [SourceType.CSV, SourceType.EXCEL, SourceType.IMAGE]:
local_path = _copy_file_from_sandbox_with_original_name(
sandbox,
sandbox_path,
)
elif source_type == SourceType.RELATIONAL_DB:
local_path = sandbox_path
else:
raise ValueError(f"Unsupported source type {source_type}")
profiler = DataProfilerFactory.get_profiler(
llm_call_manager=llm_call_manager,
path=local_path,
source_type=source_type,
)
return await profiler.generate_profile()
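# --- Usage sketch (illustrative; not part of the module) ---
# Assumes an initialized AliasSandbox and LLMCallManager, with a CSV already
# uploaded to the sandbox workspace:
#
#     profile = await data_profile(
#         sandbox=sandbox,
#         sandbox_path="/workspace/incident_records.csv",
#         source_type=SourceType.CSV,
#         llm_call_manager=llm_call_manager,
#     )
#     print(profile["description"])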

View File

@@ -0,0 +1,198 @@
# -*- coding: utf-8 -*-
import os
from pathlib import Path
from typing import List
import frontmatter
from loguru import logger
from agentscope.tool._types import AgentSkill
from alias.agent.agents.ds_agent_utils.utils import get_prompt_from_file
from alias.agent.agents.data_source._typing import SourceType
class DataSkill(AgentSkill):
"""The source type of the skill."""
type: List[SourceType]
class DataSkillManager:
"""Data Skill Selector Based on Data Source Type"""
_default_skill_path_base = os.path.join(
Path(__file__).resolve().parent.parent,
"_built_in_skill/data",
)
def __init__(
self,
):
self.skills = self.register_skill_dir()
self.source_type_2_skills = {}
for skill in self.skills:
for t in skill["type"]:
self.source_type_2_skills[t] = skill
def load(self, data_source_types: List[SourceType]) -> List[str]:
"""
Load skills based on data source type
Args:
data_source_types: List of SourceType enum values
Returns:
Selected skill content list
"""
if not data_source_types:
return []
selected_skills = []
data_source_types = set(data_source_types)
for source_type in data_source_types:
try:
# Get skill from source type mapping
skill = self.source_type_2_skills.get(source_type, None)
# Skip if no corresponding skill
if not skill:
logger.warning(
"DataSkillSelector found no valid skill for data "
f"source type: {source_type}",
)
else:
logger.info(
f"DataSkillSelector selected skill: {skill['name']} "
f"for data source type: {source_type}",
)
skill_content = get_prompt_from_file(
skill["dir"],
return_json=False,
)
if skill_content:
selected_skills.append(skill_content)
except Exception as e:
logger.error(
f"DataSkillSelector selection failed: {str(e)} "
f"for data source type: {source_type}",
)
continue
return selected_skills
def register_skill_dir(self, skill_dir=_default_skill_path_base):
"""Load skills from all directories containing SKILL.md"""
skills = []
# Check the skill directory
if not os.path.isdir(skill_dir):
raise ValueError(
f"The skill directory '{skill_dir}' does not exist or is "
"not a directory.",
)
# Walk through all files and directories in skill_dir_base
for root, dirs, _ in os.walk(skill_dir):
# Process directories - look for SKILL.md
for dir_name in dirs:
dir_path = os.path.join(root, dir_name)
skill = self.register_skill(dir_path)
if skill:
skills.append(skill)
return skills
def register_skill(self, path: str, name=None):
"""
Register a new skill dynamically
Args:
path: Path to skill directory containing SKILL.md
name: Optional skill name (defaults to the front-matter `name` or the directory name)
"""
try:
# Resolve the skill path
file_path = self._resolve_skill_path(path)
if not file_path:
raise FileNotFoundError("`SKILL.md` not found")
# Parse the skill file
skill = self._parse_skill_file(file_path, name)
logger.info(
f"Successfully registered skill '{skill['name']}' "
f"from '{file_path}'",
)
return skill
except Exception as e:
# `skill` may be unbound if parsing failed, so report the path instead.
logger.error(
f"Failed to register skill from '{path}': {e}",
)
return None
def _resolve_skill_path(self, path: str) -> str:
"""
Resolve a skill path to the actual markdown file path
Args:
path: Path to skill markdown file or directory containing SKILL.md
Returns:
str: Path to the actual markdown file, or empty string if invalid
"""
if os.path.isdir(path):
skill_md_path = os.path.join(path, "SKILL.md")
if not os.path.isfile(skill_md_path):
logger.warning(f"Directory '{path}' does not contain SKILL.md")
return ""
return skill_md_path
else:
logger.warning(f"Invalid skill path: {path}")
return ""
def _parse_skill_file(self, file_path, name=None):
"""Parse a skill file and add it to skills_list"""
# Check YAML Front Matter
post = frontmatter.load(file_path)
# Use directory name as skill name if not provided in YAML
if name is None:
dir_name = os.path.basename(os.path.dirname(file_path))
name = post.get("name", dir_name)
else:
name = post.get("name", name)
description = post.get("description", None)
_type = post.get("type", None)
if not name or not description or not _type:
raise ValueError(
f"The file '{file_path}' must have a YAML Front "
"Matter including `name`, `description`, and `type` fields",
)
_type = _type if isinstance(_type, list) else [_type]
if any(not SourceType.is_valid_source_type(t) for t in _type):
raise ValueError(
f"Type of file '{file_path}' must be a member "
"(or a list of members) of SourceType",
)
name, description = str(name), str(description)
_type = [SourceType(t) for t in _type]
return DataSkill(
name=name,
description=description,
type=_type,
dir=file_path,
)
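# Illustrative SKILL.md front matter expected by `register_skill`; the
# `name`, `description`, and `type` fields are mandatory, and the type
# values shown are hypothetical SourceType members:
#
#     ---
#     name: csv_analysis
#     description: Guidance for profiling and querying CSV files.
#     type: [csv, excel]
#     ---
#     <skill instructions in markdown>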

View File

@@ -0,0 +1,620 @@
# -*- coding: utf-8 -*-
# pylint: disable=R1702,R0912,R0911
import os
import json
from pathlib import Path
from typing import Dict, Any, Optional, List
import yaml
from loguru import logger
from agentscope.mcp import StdIOStatefulClient
from agentscope_runtime.sandbox.box.sandbox import Sandbox
from alias.agent.agents.data_source.data_skill import DataSkillManager
from alias.agent.agents.data_source._typing import (
SOURCE_TYPE_TO_ACCESS_TYPE,
SourceAccessType,
SourceType,
)
from alias.agent.agents.data_source.data_profile import data_profile
from alias.agent.agents.data_source.utils import replace_placeholders
from alias.agent.tools.toolkit_hooks.text_post_hook import TextPostHook
from alias.agent.tools.alias_toolkit import AliasToolkit
from alias.agent.tools.sandbox_util import (
copy_local_file_to_workspace,
)
from alias.agent.utils.llm_call_manager import (
LLMCallManager,
)
class DataSource:
"""
Unified data source class representing any data source.
"""
def __init__(
self,
endpoint: str,
source_type: SourceType,
name: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
):
"""
Initialize a data source.
Args:
endpoint: Address/DSN/URL/path to access the data source
source_type: Type of the data source (SourceType enum)
name: Name/identifier of the data source
config: Configuration for this data source
"""
self.endpoint = endpoint
self.name = name
source_access_type = SOURCE_TYPE_TO_ACCESS_TYPE.get(
source_type,
SourceAccessType.DIRECT,
)
self.source_access_type = source_access_type
self.source_type = source_type
self.config = config or {}
self.profile = {}
self.source_desc = None
self.source_access_desc = None
async def prepare(self, toolkit: AliasToolkit):
"""
Prepare data source.
For LOCAL_FILE: Upload file to sandbox workspace
For MCP_TOOL: Register corresponding MCP server
Args:
toolkit: AliasToolkit instance used to upload files and register MCP clients
"""
logger.info(f"Preparing data source {self.name}...")
if self.source_access_type == SourceAccessType.DIRECT:
# Get the filename and construct target path in workspace
filename = os.path.basename(self.endpoint)
target_path = f"/workspace/{filename}"
if os.getenv("LINK_FILE_TO_WORKSPACE", "off").lower() == "on":
logger.info(
f"Creating symlink for {self.endpoint} "
f"to {target_path}",
)
# Build ln -s command
command = f"ln -s '{self.endpoint}' '{target_path}'"
result = toolkit.sandbox.call_tool(
name="run_shell_command",
arguments={"command": command},
)
if result.get("isError"):
raise ValueError(
"Failed to create symlink for "
f"{self.endpoint}: {result}",
)
else:
logger.info(f"Uploading {self.endpoint} to {target_path}")
result = copy_local_file_to_workspace(
sandbox=toolkit.sandbox,
local_path=self.endpoint,
target_path=target_path,
)
if result.get("isError"):
raise ValueError(
f"Failed to upload {self.endpoint}: " f"{result}",
)
self.source_access = target_path
self.source_desc = "Local file"
self.source_access_desc = f"Access at path: `{target_path}`"
logger.info(f"Successfully loaded to {result}")
# Check if this is an MCP tool source
elif self.source_access_type == SourceAccessType.VIA_MCP:
server_config = self.config.get("mcp_server", {})
mcp_server_name = server_config.keys()
if len(mcp_server_name) != 1:
raise ValueError("Register server one by one!")
mcp_server_name = list(mcp_server_name)[0]
server_config = server_config[mcp_server_name]
cmd = server_config.get("command")
args = server_config.get("args")
if cmd is None or args is None:
raise ValueError(
"MCP server configuration requires non-empty "
"`command` and `args` fields to start!",
)
client = StdIOStatefulClient(
self.name,
command=cmd,
args=args,
env=server_config.get("env"),
)
text_hook = TextPostHook(
toolkit.sandbox,
budget=5000,
auto_save=True,
)
await toolkit.add_and_connect_mcp_client(
client,
postprocess_func=text_hook.truncate_and_save_response,
)
registered_tools = [
t.name
for t in list(
await toolkit.additional_mcp_clients[-1].list_tools(),
)
]
self.source_access = self.endpoint
self.source_desc = f"{self.source_type}"
self.source_access_desc = (
f"Access via MCP tools: [{', '.join(registered_tools)}]"
)
logger.info(f"Successfully connected to {self.name}")
else:
logger.info(
f"Skipping preparation for source type: {self.source_type}",
)
def get_coarse_desc(self):
return (
f"{self.source_desc}. {self.source_access_desc}: "
+ f"{self._general_profile()}"
)
async def prepare_profile(
self,
sandbox: Sandbox,
llm_call_manager: LLMCallManager,
) -> Optional[Dict[str, Any]]:
"""Run type-specific profiling."""
if llm_call_manager and not self.profile:
try:
self.profile = await data_profile(
sandbox=sandbox,
sandbox_path=self.source_access,
source_type=self.source_type,
llm_call_manager=llm_call_manager,
)
logger.info(
"Profiling successfully: "
+ f"{self._general_profile()[:100]}...",
)
except ValueError as e:
self.profile = None
logger.warning(f"Warning when profile data: {e}")
except Exception as e:
self.profile = None
logger.error(f"Error when profile data: {e}")
return self.profile
def _refined_profile(self) -> str:
if self.profile:
return yaml.dump(
self.profile,
allow_unicode=True,
sort_keys=False,
default_flow_style=False
if self.source_type == SourceType.IMAGE
else None,
width=float("inf"),
)
else:
return ""
def _general_profile(self) -> str:
return self.profile["description"] if self.profile else ""
def __str__(self) -> str:
return (
f"DataSource(name='{self.name}', type='{self.source_type}', "
f"endpoint='{self.endpoint}')"
)
def __repr__(self) -> str:
return self.__str__()
class DataSourceManager:
"""
Manager class for handling multiple data sources.
Provides methods to add, retrieve, and manage data sources.
Also manages data source configurations with hierarchical lookup.
"""
_default_data_source_config = os.path.join(
Path(__file__).resolve().parent,
"_default_config.json",
)
def __init__(
self,
sandbox: Sandbox,
llm_call_manager: LLMCallManager,
):
"""Initialize an empty data source manager."""
self._data_sources: Dict[str, DataSource] = {}
self._type_defaults = {}
self._load_default_config()
self.skill_manager = DataSkillManager()
self.selected_skills = None
self.toolkit = AliasToolkit(sandbox=sandbox)
self.llm_call_manager = llm_call_manager
def add_data_source(
self,
config: str | Dict = None,
):
"""
Add a new data source (or multiple sources) to the manager.
Args:
config: endpoint (address/DSN/URL/path of the data source) or a
configuration dict for the data source connection
"""
if isinstance(config, str):
endpoint = config
conn_config = None
else:
if "endpoint" not in config:
logger.error(
f"Missing 'endpoint' in config for source '{config}'",
)
endpoint = config["endpoint"]
conn_config = config
sources = set()
if os.path.isdir(endpoint):
# Add all files in directory
for filename in os.listdir(endpoint):
file_path = os.path.join(endpoint, filename)
sources.add(file_path)
else:
sources.add(endpoint)
for endpoint in sources:
# Auto-detect source type
source_type = self._detect_source_type(endpoint)
# Auto-generate name
name = self._generate_name(endpoint)
# Get configuration for this data source, resolved per endpoint so
# one source's resolved config is not reused for the next
source_config = conn_config or self.get_default_config(source_type)
if source_config:
source_config = replace_placeholders(
source_config,
{
"endpoint": endpoint,
},
)
# Create data source with configuration
data_source = DataSource(endpoint, source_type, name, source_config)
self._data_sources[endpoint] = data_source
async def prepare_data_sources(self) -> None:
"""
Prepare all data sources.
Uses the manager's toolkit/sandbox for file uploads and for starting
MCP servers.
"""
logger.info(f"Preparing {len(self._data_sources)} data source(s)...")
all_data_sources = self._data_sources.values()
for data_source in all_data_sources:
await data_source.prepare(self.toolkit)
await data_source.prepare_profile(
self.toolkit.sandbox,
self.llm_call_manager,
)
def _generate_name(self, endpoint: str) -> str:
"""
Generate a name based on the endpoint.
For databases, removes passwords and uses scheme + database name.
For files, uses filename.
For URLs, uses domain or last part of path.
"""
from urllib.parse import urlparse
try:
# For file paths
if os.path.isfile(endpoint):
filename = os.path.basename(endpoint)
# Remove extension and sanitize
name_without_ext = os.path.splitext(filename)[0]
return self._sanitize_name(name_without_ext)
# For database connections
db_indicators = [
"://",
".db",
".sqlite",
"mongodb://",
"mongodb+srv://",
"neo4j://",
"bolt://",
]
if any(
indicator in endpoint.lower() for indicator in db_indicators
):
if "://" in endpoint:
try:
# Split by :// to get scheme and rest
scheme, rest = endpoint.split("://", 1)
scheme = scheme.lower()
# Handle authentication (user:password@host)
if "@" in rest:
auth_part, host_part = rest.split("@", 1)
if ":" in auth_part:
# Has user:password format, keep only username
username = auth_part.split(":")[0]
rest = f"{username}@{host_part}"
# If no colon, it's just username@host, keep as is
# Extract database name
db_name = "unknown"
if "/" in rest:
# Split by / and take last part before
# query parameters
path_parts = rest.split("/")
if len(path_parts) > 1:
db_name = (
path_parts[-1].split("?")[0].split("#")[0]
)
if not db_name: # If empty, try second to last
db_name = (
path_parts[-2]
if len(path_parts) > 2
else "unknown"
)
else:
# Use host name if no database name in path
host = (
rest.split(":")[0].split("/")[0].split("@")[-1]
)
db_name = host
# Create name: scheme_dbname
return self._sanitize_name(f"{scheme}_{db_name}")
except Exception as e:
logger.warning(
f"Error parsing database URL {endpoint}: {e}",
)
# Fall through to URL handling
elif "." in endpoint:
# Use filename without extension for .db/.sqlite files
filename = os.path.basename(endpoint)
name_without_ext = os.path.splitext(filename)[0]
return self._sanitize_name(name_without_ext)
# For URLs (including database URLs that failed to parse)
if "://" in endpoint:
try:
parsed = urlparse(endpoint)
if parsed.netloc:
# Use domain name (without port)
domain = parsed.netloc.split(":")[0].split("@")[
-1
] # Remove username if present
# If path exists, use last part of path
if parsed.path and parsed.path != "/":
path_parts = parsed.path.strip("/").split("/")
if path_parts:
return self._sanitize_name(path_parts[-1])
return self._sanitize_name(domain)
elif parsed.path:
# Use last part of path
path_parts = parsed.path.strip("/").split("/")
if path_parts:
return self._sanitize_name(path_parts[-1])
except Exception as e:
logger.warning(f"Error parsing URL {endpoint}: {e}")
# Fallback: use a sanitized version of the endpoint
return self._sanitize_name(endpoint[:50])
except Exception as e:
logger.error(f"Error generating default name for {endpoint}: {e}")
# Ultimate fallback
return self._sanitize_name("unknown_source")
def _sanitize_name(self, name: str) -> str:
"""Sanitize a name to be used as a data source identifier."""
import re
# Keep only alphanumeric and underscore characters
sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)
# Ensure it starts with a letter or underscore
if sanitized and not sanitized[0].isalpha() and sanitized[0] != "_":
sanitized = "_" + sanitized
# Truncate if too long
sanitized = sanitized[:50]
# Ensure it's not empty
if not sanitized:
sanitized = "unknown"
return sanitized
def _detect_source_type(self, endpoint: str) -> SourceType:
"""Auto-detect source type based on endpoint."""
endpoint_lower = endpoint.lower()
# Check for file extensions
if endpoint_lower.endswith(".csv"):
source_type = SourceType.CSV
elif endpoint_lower.endswith((".xls", ".xlsx", "xlsm")):
source_type = SourceType.EXCEL
elif endpoint_lower.endswith(".json"):
source_type = SourceType.JSON
elif endpoint_lower.endswith((".txt", ".log", ".md")):
source_type = SourceType.TEXT
elif endpoint_lower.endswith(
(".jpg", ".jpeg", ".png", ".gif", ".bmp"),
):
source_type = SourceType.IMAGE
# Check for database connection strings/patterns
# Relational databases
elif any(
keyword in endpoint_lower
for keyword in [
"postgresql://",
"postgres://",
"pg://",
"mysql://",
"mariadb://",
"sqlserver://",
]
):
source_type = SourceType.RELATIONAL_DB
elif (
"sqlite://" in endpoint_lower
or endpoint_lower.endswith(".db")
or endpoint_lower.endswith(".sqlite")
):
source_type = SourceType.RELATIONAL_DB
else:
source_type = SourceType.OTHER
return source_type
def get_all_data_sources_desc(self) -> str:
"""
Get descriptions of all data sources.
Returns:
A formatted string listing all data source descriptions
"""
return "Available data sources: \n" + "\n".join(
[
f"[{idx}] " + ds.get_coarse_desc()
for idx, ds in enumerate(self._data_sources.values())
],
)
def get_local_data_sources(self) -> List[str]:
"""
Get list of local data source endpoints
"""
return [
ds.endpoint
for ds in self._data_sources.values()
if ds.source_access_type == SourceAccessType.DIRECT
]
def get_all_data_sources_name(self) -> List[str]:
"""
Get a list of all data source names.
Returns:
List of all data source names
"""
return list(self._data_sources.keys())
def remove_data_source(self, name: str) -> bool:
"""
Remove a data source by name.
Args:
name: Name of the data source to remove
Returns:
True if successfully removed, False if not found
"""
if name in self._data_sources:
del self._data_sources[name]
return True
return False
def get_default_config(self, source_type: SourceType) -> Dict[str, Any]:
"""
Get the default configuration for a source type.
Args:
source_type: The SourceType to get default config for
Returns:
Default configuration dictionary, empty dict if not registered
"""
return self._type_defaults.get(source_type, {})
def _load_default_config(self) -> None:
"""Load default type to configuration."""
try:
with open(
self._default_data_source_config,
"r",
encoding="utf-8",
) as f:
config = json.load(f)
# Load type defaults
for type_name, type_config in config.items():
try:
source_type = SourceType(type_name)
self._type_defaults[source_type] = type_config
except ValueError:
# Skip invalid source types
continue
except FileNotFoundError:
# If config file doesn't exist, initialize with empty configs
pass
except json.JSONDecodeError:
# If config file is invalid JSON, initialize with empty configs
pass
def __len__(self) -> int:
"""Return the number of data sources managed."""
return len(self._data_sources)
def get_data_skills(self):
# TODO: update when data source changed
if self.selected_skills is None:
source_types = [
data.source_type for data in self._data_sources.values()
]
self.selected_skills = self.skill_manager.load(source_types)
return "\n".join(self.selected_skills)

View File

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import logging
# Set up logger
logger = logging.getLogger(__name__)
def replace_placeholders(obj, source_config):
if isinstance(obj, str):
import re
pattern = r"\$\{([^}]+)\}"
matches = re.finditer(pattern, obj)
result = obj
for match in matches:
var_name = match.group(1)
if var_name in source_config:
result = result.replace(
match.group(0),
str(source_config[var_name]),
)
return result
elif isinstance(obj, dict):
return {
k: replace_placeholders(v, source_config) for k, v in obj.items()
}
elif isinstance(obj, list):
return [replace_placeholders(item, source_config) for item in obj]
else:
return obj
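# Illustrative example (hypothetical config): resolving an `${endpoint}`
# placeholder from a default connection template:
#
#     >>> cfg = {"args": ["--dsn", "${endpoint}"]}
#     >>> replace_placeholders(cfg, {"endpoint": "sqlite:///data.db"})
#     {'args': ['--dsn', 'sqlite:///data.db']}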

View File

@@ -137,11 +137,12 @@ async def files_filter_pre_reply_hook(
# Even if the user only uploaded supplementary files in this interaction,
# We will also check whether the previously uploaded files are relevant
# to the question.
self.uploaded_files = list(
set(files_list) | set(getattr(self, "uploaded_files", [])),
uploaded_files = list(
set(files_list) | set(self.data_manager.get_local_data_sources()),
)
if len(self.uploaded_files) < 100:
if len(uploaded_files) < 100:
logger.info(
"Scalable files filtering: not enough files to filter.",
)
@@ -164,7 +165,7 @@ await files_filter(query, files_list, api_key=api_key)
files_filter_code += template.substitute(
query=safe_query,
files_list=repr(self.uploaded_files),
files_list=repr(uploaded_files),
api_key=safe_api_key,
)

View File

@@ -28,63 +28,7 @@ When executing any data science task (data loading, cleaning, analysis, modeling
---
## Task Management Rules
- **You must use `todo_write` to track progress**, especially for multi-step tasks.
- Mark each subtask as complete **immediately** upon finishing—no delays or batch updates.
- Skipping planning risks missing critical steps—this is unacceptable.
---
## Data Handling Requirements
### Data Inspection Methods
Before any operation, **you must** inspect the true structure of the data source using tools (preferably `run_ipython_cell`):
| Data Type | Inspection Method |
|------------------|-----------------------------------------------------------------------------------|
| **Database** | Query table schema (`DESCRIBE table`) and preview the first 5-10 rows (`SELECT * FROM ... LIMIT 5`) |
| **CSV/Excel** | Use pandas `head(n)` to view column names and samples |
| **Images** | Use PIL to get dimensions/format, or invoke vision tools to extract content |
| **Text Files** | Read the first 5-10 lines to determine structure and encoding |
| **JSON** | Inspect from outer to inner layers progressively |
> **Core Principle**: What you see is fact; what you haven't seen is unknown.
---
### Data Preprocessing Methods
#### Messy Spreadsheet Handling
After initial inspection of CSV or Excel files, if you observe:
- Many `"Unnamed: X"`, `NaN`, or `NaT` entries
- Missing or ambiguous headers
- Multiple data blocks within a single worksheet
Then **prioritize** advanced cleaning tools:
- `clean_messy_spreadsheet`: Extract key information from tables and output as JSON for downstream analysis
Only fall back to manual pandas row/block parsing if this tool fails.
---
### Strict Data Volume Limits
To prevent system crashes, strictly limit data volume during queries and reads:
- **Database queries**: Always use `LIMIT` (typically 5-10 rows)
- **Well-structured CSV/Excel**: Use `head()`, `nrows`, or sampling to fetch minimal data
- **Large text files**: Read only the first few lines or process iteratively in chunks
> **Warning**: Unrestricted large data reads will cause system failure.
---
### Fact-Based, No Assumptions
## Principles: Fact-Based, No Assumptions
- All decisions must be grounded in the **given task context**. Never simplify, generalize, or subjectively interpret the task goal, data purpose, or business scenario. Any action inconsistent with the problem context is invalid and dangerous.
- Never act on assumptions, guesses, or past experience—even if the situation seems "obvious" or "routine."
@@ -93,6 +37,14 @@ To prevent system crashes, strictly limit data volume during queries and reads:
---
## Task Management Rules
- **You must use `todo_write` to track progress**, especially for multi-step tasks.
- Mark each subtask as complete **immediately** upon finishing—no delays or batch updates.
- Skipping planning risks missing critical steps—this is unacceptable.
---
## Visualization Strategy
- **Plotting library**: Prefer `matplotlib`

View File

@@ -44,8 +44,13 @@ Each task in the roadmap contains:
- Brief Response
- Detailed Report
- You should choose the template that is most appropriate for the user task.
- **Brief Response Template** should ONLY be used when the user asks for a simple data query task, where ONLY numeric or concise string values are returned, and complex analysis or research are not required.
- **Detailed Report Template** should be used when the user asks for a detailed analysis of the data, where the analysis and research are required.
- **Brief Response Template** should ONLY be used when the user asks for a
simple, static data point (e.g., a total count or a specific value), where
the answer is returned as a single numeric or concise string value with no
analysis, transformation, comparison, or interpretation required.
- **Detailed Report Template** should be used whenever the task involves
distribution, discrepancy, imbalance, comparison, trend, root cause, or
any form of analysis, interpretation, or evidence generation.
2. Data Source Constraints
- **ONLY use information explicitly present in the log file**
@@ -103,7 +108,7 @@ You MUST ensure all captions, subtitles, and other contents in the report are wr
- "brief_response": The brief response content.
- When 'is_brief_response' is True, this field should be fulfilled with the brief response content following the **Brief Response Template**.
- When 'is_brief_response' is False, this field should be a concise summary of the detailed report in markdown format illustrating the key findings and insights.
- "detailed_report_content": The detailed markdown report content following the **Detailed Report Template**. This field is ONLY generated when 'is_brief_response' is False, otherwise fulfill an empty string.
- "report_content": The detailed markdown report content following the **Detailed Report Template**. This field is ONLY generated when 'is_brief_response' is False, otherwise fulfill an empty string.
- You MUST ensure the JSON object is a valid JSON string and can be parsed by json.loads().
- Double check all escapes are valid.

View File

@@ -112,8 +112,11 @@ def truncate_long_text_post_hook(
def _add_tool_postprocessing_func(toolkit: AliasToolkit) -> None:
for tool_func, _ in toolkit.tools.items():
if tool_func.startswith("run_ipython_cell"):
funcs: list = [ansi_escape_post_hook]
funcs.append(summarize_plt_chart_hook)
funcs: list = [
ansi_escape_post_hook,
summarize_plt_chart_hook,
truncate_long_text_post_hook,
]
toolkit.tools[tool_func].postprocess_func = partial(
run_ipython_cell_post_hook,
funcs,

View File

@@ -1,20 +1,62 @@
# -*- coding: utf-8 -*-
import os
import json
import time
from typing import Tuple
import dotenv
from pydantic import BaseModel, Field
from agentscope.message import Msg
from .utils import model_call_with_retry, get_prompt_from_file
from .ds_config import PROMPT_DS_BASE_PATH
dotenv.load_dotenv()
class ReportResponse(BaseModel):
is_brief_response: bool = Field(
...,
description=(
"True if the response is a brief response; "
"False if it includes a detailed report."
),
)
brief_response: str = Field(
...,
description=(
"The brief response content. "
"When 'is_brief_response' is True, this field contains the full "
"brief response following the Brief Response Template. "
"When 'is_brief_response' is False, this field contains a concise "
"markdown summary of the detailed report, highlighting key "
"findings and insights."
),
json_schema_extra={
"example": (
"The analysis shows a 15% increase in user engagement "
"after the feature update."
),
},
)
report_content: str = Field(
...,
description=(
"The detailed markdown report content following the "
"Detailed Report Template. This field MUST be an empty "
"string ('') when 'is_brief_response' is True. It MUST contain "
"the full detailed report when 'is_brief_response' is False."
),
json_schema_extra={
"example": "### User Task Description...\n"
"### Associated Data Sources...\n"
"### Research Conclusion...\n### Task1...### Task2...",
},
)
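# Note (illustrative): when `structured_model=ReportResponse` is passed
# through `model_call_with_retry`, the parsed fields arrive as the final
# content block's "input" dict (see `_log_to_markdown` below), e.g.:
#     {"is_brief_response": True, "brief_response": "...", "report_content": ""}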
class ReportGenerator:
def __init__(self, model, formatter, memory_log: str):
self.model = model
@@ -62,22 +104,13 @@ class ReportGenerator:
self.formatter,
msgs=msgs,
msg_name="Report Generation",
structured_model=ReportResponse,
)
raw_response = res.content[0]["text"]
# TODO: More robust response cleaning
if raw_response.strip().startswith("```json"):
cleaned = raw_response.strip()[len("```json") :].lstrip("\n")
if cleaned.endswith("```"):
cleaned = cleaned[:-3].rstrip()
response = cleaned
else:
response = raw_response.strip()
end_time = time.time()
# print(response)
print(f"Log to markdown took {end_time - start_time} seconds")
return response
return res.content[-1]["input"]
async def _convert_to_html(self, markdown_content: str) -> str:
start_time = time.time()
@@ -102,21 +135,15 @@ class ReportGenerator:
print(f"Convert to html took {end_time - start_time} seconds")
return response.content[0]["text"]
async def generate_report(self) -> Tuple[str, str]:
markdown_response = await self._log_to_markdown()
# responseFormat: {
# "is_brief_response": True,
# "brief_response": brief_response_content,
# "report_content": detailed_report_content
# }
try:
markdown_content = json.loads(markdown_response)
except json.JSONDecodeError as e:
print(f"Error parsing JSON response: {e}")
print(f"Response content: {markdown_response}")
raise
async def generate_report(self) -> Tuple[str, str, str]:
"""
responseFormat: {
"is_brief_response": True,
"brief_response": brief_response_content,
"report_content": detailed_report_content
}
"""
markdown_content = await self._log_to_markdown()
if (
str(markdown_content.get("is_brief_response", False)).lower()
@@ -124,12 +151,19 @@ class ReportGenerator:
):
# During brief response mode,
# directly return the brief response to the user.
return markdown_content["brief_response"], ""
return markdown_content.get("brief_response", ""), "", ""
else:
# In detailed report mode,
# convert the detailed report to HTML and return it to the user;
# if a brief summary of the report is needed,
# it can be obtained through markdown_content["brief_response"].
return markdown_content[
"brief_response"
], await self._convert_to_html(markdown_content["report_content"])
html_content = ""
if os.getenv("ENABLE_HTML_REPORT", "ON").lower() != "off":
html_content = await self._convert_to_html(
markdown_content.get("report_content", ""),
)
return (
markdown_content.get("brief_response", ""),
markdown_content.get("report_content", ""),
html_content,
)

View File

@@ -5,9 +5,9 @@ import json
from typing import Union
from agentscope.message import Msg
from tenacity import retry, stop_after_attempt, wait_fixed
from .ds_config import PROMPT_DS_BASE_PATH
MODEL_MAX_RETRIES = 50
from alias.agent.utils.constants import MODEL_MAX_RETRIES
from .ds_config import PROMPT_DS_BASE_PATH
def get_prompt_from_file(
@@ -36,10 +36,16 @@ async def model_call_with_retry(
tool_json_schemas=None,
tool_choice=None,
msg_name="model_call",
structured_model=None,
) -> Msg:
prompt = await formatter.format(msgs=msgs)
res = await model(prompt, tools=tool_json_schemas, tool_choice=tool_choice)
res = await model(
prompt,
tools=tool_json_schemas,
tool_choice=tool_choice,
structured_model=structured_model,
)
if model.stream:
msg = Msg(msg_name, [], "assistant")

View File

@@ -3,10 +3,15 @@
import uuid
from enum import Enum
from typing import Any, Optional, Literal
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pydantic import BaseModel, Field
def _get_utc_now_iso():
return datetime.now(timezone.utc).isoformat()
@dataclass
class MockFileBase:
filename: str
@@ -15,8 +20,8 @@ class MockFileBase:
storage_path: str
size: int = -1
storage_type: str = "unknown"
create_time: str = "xxxyyy"
update_time: str = "xxxyyy"
create_time: str = field(default_factory=_get_utc_now_iso)
update_time: str = field(default_factory=_get_utc_now_iso)
user_id: uuid.UUID = uuid.uuid4()
@@ -62,12 +67,13 @@ class UserMessage(BaseMessage):
name: str = "User"
@dataclass
class MockMessage:
id: uuid.UUID = uuid.uuid4()
id: uuid.UUID = field(default_factory=uuid.uuid4)
message: Optional[dict] = None
files: list[Any] = []
create_time: str = "xxxyyy"
update_time: str = "xxxyyy"
files: list[Any] = field(default_factory=list)
create_time: str = field(default_factory=_get_utc_now_iso)
update_time: str = field(default_factory=_get_utc_now_iso)
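# Note (illustrative): class-level defaults such as `uuid.uuid4()` are
# evaluated once at import time, so every instance would share the same
# value; `field(default_factory=...)` defers evaluation to each
# instantiation instead.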
class SubTaskToPrint(BaseModel):

View File

@@ -40,6 +40,7 @@ class SessionEntity:
query: str
upload_files: List = []
is_chat: bool = False
data_config: List | None = None
use_long_term_memory_service: bool = False
def __init__(
@@ -51,6 +52,7 @@ class SessionEntity:
"bi",
"finance",
] = "general",
data_config: List | None = None,
use_long_term_memory_service: bool = False,
):
self.user_id: uuid.UUID = uuid.UUID(
@@ -62,6 +64,7 @@ class SessionEntity:
self.conversation_id: uuid.UUID = uuid.uuid4()
self.session_id: uuid.UUID = uuid.uuid4()
self.chat_mode = chat_mode
self.data_config = data_config
self.use_long_term_memory_service = use_long_term_memory_service
def ids(self):
@@ -79,6 +82,7 @@ class MockSessionService:
def __init__(
self,
runtime_model: Any = None,
data_config: List | None = None,
use_long_term_memory_service: bool = False,
):
self.session_id = "mock_session"
@@ -86,6 +90,7 @@ class MockSessionService:
self.messages = []
self.plan = MockPlan()
self.session_entity = SessionEntity(
data_config=data_config,
use_long_term_memory_service=use_long_term_memory_service,
)
logger.info(
@@ -215,6 +220,72 @@ class MockSessionService:
self.messages.append(db_message)
return db_message
async def append_to_latest_message(
self,
content_to_append: str,
role_filter: Optional[str] = None,
) -> Optional[MockMessage]:
"""
Append content to the most recent message.
Args:
content_to_append: Content to append to the message
role_filter: Optional role filter (e.g., 'user', 'assistant')
Returns:
Updated MockMessage or None if no message found
"""
if not self.messages:
logger.warning("No messages to append to")
return None
# Find the most recent message (optionally filtered by role)
target_message = None
for msg in reversed(self.messages):
if role_filter is None or msg.message.get("role") == role_filter:
target_message = msg
break
if target_message is None:
logger.warning(f"No message found with role={role_filter}")
return None
# Append content
current_content = target_message.message.get("content", "")
if isinstance(current_content, str):
target_message.message["content"] = (
current_content + content_to_append
)
elif isinstance(current_content, list):
# Handle multi-modal content (list of content blocks)
target_message.message["content"].append(
{
"type": "text",
"text": content_to_append,
},
)
else:
logger.error(f"Unsupported content type: {type(current_content)}")
return None
# Update timestamp
target_message.update_time = datetime.now(timezone.utc).isoformat()
# Optional: Log to file
if hasattr(self, "log_storage_path"):
content_log = (
"=" * 10
+ "\n"
+ f"APPEND to Role: {target_message.message.get('role')}\n"
+ f"Appended: {content_to_append}\n"
+ "=" * 10
+ "\n"
)
with open(self.log_storage_path, "a") as file:
file.write(content_log)
return target_message
async def get_messages(self) -> List[MockMessage]:
logger.log("SEND_MSG", "Get all messages")
return self.messages

View File

@@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-
# pylint: disable=W0612,E0611,C2801
import os
import traceback
from datetime import datetime
import asyncio
import traceback
from typing import Literal
from loguru import logger
@@ -17,27 +19,33 @@ from alias.agent.agents import (
BrowserAgent,
DeepResearchAgent,
MetaPlanner,
DataScienceAgent,
init_ds_toolkit,
init_dr_toolkit,
)
from alias.agent.agents.meta_planner_utils._worker_manager import share_tools
from alias.agent.mock import MockSessionService as SessionService
from alias.agent.tools import AliasToolkit
from alias.agent.utils.constants import (
BROWSER_AGENT_DESCRIPTION,
DEFAULT_DEEP_RESEARCH_AGENT_NAME,
DEEPRESEARCH_AGENT_DESCRIPTION,
DS_AGENT_DESCRIPTION,
)
from alias.agent.tools.add_tools import add_tools
from alias.agent.agents.ds_agent_utils import (
add_ds_specific_tool,
from alias.agent.utils.prepare_data_source import (
add_data_source_tools,
prepare_data_sources,
)
from alias.agent.tools.add_tools import add_tools
from alias.agent.memory.longterm_memory import AliasLongTermMemory
from alias.server.clients.memory_client import MemoryClient
from alias.agent.agents._data_science_agent import (
DataScienceAgent,
init_ds_toolkit,
)
from alias.agent.utils.llm_call_manager import (
LLMCallManager,
)
MODEL_FORMATTER_MAPPING = {
"qwen3-max": [
@@ -104,9 +112,28 @@ async def arun_meta_planner(
# Init deep research toolkit
deep_research_toolkit = init_dr_toolkit(worker_full_toolkit)
# Init BI agent toolkit
# Init data science agent toolkit
ds_toolkit = init_ds_toolkit(worker_full_toolkit)
# Initialize data source manager
llm_call_manager = LLMCallManager(
base_model_name=MODEL_CONFIG_NAME,
vl_model_name=VL_MODEL_NAME,
model_formatter_mapping=MODEL_FORMATTER_MAPPING,
)
data_manager = await prepare_data_sources(
session_service=session_service,
sandbox=sandbox,
llm_call_manager=llm_call_manager,
)
add_data_source_tools(
data_manager,
worker_full_toolkit,
browser_toolkit,
deep_research_toolkit,
ds_toolkit,
)
try:
model, formatter = MODEL_FORMATTER_MAPPING[MODEL_CONFIG_NAME]
browser_agent = BrowserAgent(
@@ -175,13 +202,15 @@ async def arun_meta_planner(
description=DEEPRESEARCH_AGENT_DESCRIPTION,
worker_type="built-in",
)
# === add BI agent ===
# === add data science agent ===
ds_agent = DataScienceAgent(
name="Data_Science_Agent",
model=model,
formatter=formatter,
memory=InMemoryMemory(),
toolkit=ds_toolkit,
data_manager=data_manager,
sys_prompt=data_manager.get_data_skills(),
max_iters=30,
session_service=session_service,
)
@@ -219,6 +248,19 @@ async def arun_deepresearch_agent(
"run_shell_command",
]
share_tools(global_toolkit, worker_toolkit, test_tool_list)
llm_call_manager = LLMCallManager(
base_model_name=MODEL_CONFIG_NAME,
vl_model_name=VL_MODEL_NAME,
model_formatter_mapping=MODEL_FORMATTER_MAPPING,
)
await prepare_data_sources(
session_service,
sandbox,
worker_toolkit,
llm_call_manager,
)
worker_agent = DeepResearchAgent(
name="Deep_Research_Agent",
model=model,
@@ -285,6 +327,18 @@ async def arun_finance_agent(
active=True,
)
llm_call_manager = LLMCallManager(
base_model_name=MODEL_CONFIG_NAME,
vl_model_name=VL_MODEL_NAME,
model_formatter_mapping=MODEL_FORMATTER_MAPPING,
)
await prepare_data_sources(
session_service,
sandbox,
worker_toolkit,
llm_call_manager,
)
worker_agent = DeepResearchAgent(
name="Deep_Research_Agent",
model=model,
@@ -326,17 +380,21 @@ async def arun_datascience_agent(
session_service: SessionService, # type: ignore[valid-type]
sandbox: Sandbox = None,
):
global_toolkit = AliasToolkit(sandbox, add_all=True)
# await add_tools(global_toolkit)
worker_toolkit = AliasToolkit(sandbox)
model, formatter = MODEL_FORMATTER_MAPPING[MODEL_CONFIG_NAME]
test_tool_list = [
"write_file",
"run_ipython_cell",
"run_shell_command",
]
share_tools(global_toolkit, worker_toolkit, test_tool_list)
add_ds_specific_tool(worker_toolkit)
global_toolkit = AliasToolkit(sandbox, add_all=True)
worker_toolkit = init_ds_toolkit(global_toolkit)
llm_call_manager = LLMCallManager(
base_model_name=MODEL_CONFIG_NAME,
vl_model_name=VL_MODEL_NAME,
model_formatter_mapping=MODEL_FORMATTER_MAPPING,
)
data_manager = await prepare_data_sources(
session_service=session_service,
sandbox=sandbox,
binded_toolkit=worker_toolkit,
llm_call_manager=llm_call_manager,
)
try:
worker_agent = DataScienceAgent(
@@ -345,6 +403,8 @@ async def arun_datascience_agent(
formatter=formatter,
memory=InMemoryMemory(),
toolkit=worker_toolkit,
data_manager=data_manager,
sys_prompt=data_manager.get_data_skills(),
max_iters=30,
session_service=session_service,
)
@@ -360,6 +420,7 @@ async def arun_datascience_agent(
finally:
try:
await global_toolkit.close_mcp_clients()
await worker_toolkit.close_mcp_clients()
except (RuntimeError, asyncio.CancelledError) as e:
# Event loop might be closed during shutdown
if "Event loop is closed" in str(e) or isinstance(
@@ -386,6 +447,18 @@ async def arun_browseruse_agent(
add_all=True,
is_browser_toolkit=True,
)
llm_call_manager = LLMCallManager(
base_model_name=MODEL_CONFIG_NAME,
vl_model_name=VL_MODEL_NAME,
model_formatter_mapping=MODEL_FORMATTER_MAPPING,
)
await prepare_data_sources(
session_service,
sandbox,
browser_toolkit,
llm_call_manager,
)
logger.info("Init browser toolkit")
try:
browser_agent = BrowserAgent(

View File

@@ -5,6 +5,7 @@ import json
import os
import tarfile
from pathlib import Path
import shlex
from typing import Optional
from loguru import logger
@@ -172,7 +173,7 @@ def get_workspace_file(
)
tool_result = sandbox.call_tool(
"run_shell_command",
arguments={"command": f"base64 -i {file_path}"},
arguments={"command": f"base64 -i {shlex.quote(file_path)}"},
)
return tool_result["content"][0]["text"]
@@ -194,7 +195,7 @@ def create_or_edit_workspace_file(
}
sandbox.call_tool(
"run_shell_command",
arguments={"command": f"touch {file_path}"},
arguments={"command": f"touch {shlex.quote(file_path)}"},
)
fill_result = sandbox.call_tool(
"write_file",
@@ -222,7 +223,7 @@ def create_workspace_directory(
}
tool_result = sandbox.call_tool(
"run_shell_command",
arguments={"command": f"mkdir -p {directory_path}"},
arguments={"command": f"mkdir -p {shlex.quote(directory_path)}"},
)
return tool_result
@@ -246,7 +247,7 @@ def delete_workspace_file(
}
tool_result = sandbox.call_tool(
"run_shell_command",
arguments={"command": f"rm -rf {file_path}"},
arguments={"command": f"rm -rf {shlex.quote(file_path)}"},
)
return tool_result
@@ -280,7 +281,7 @@ def download_workspace_file_from_oss(
tool_result = sandbox.call_tool(
"run_shell_command",
arguments={
"command": f"wget -O {to_path} {oss_url}",
"command": f"wget -O {shlex.quote(to_path)} {oss_url}",
},
)
print(f"{tool_result}")
@@ -306,7 +307,7 @@ def delete_workspace_directory(
}
tool_result = sandbox.call_tool(
"run_shell_command",
arguments={"command": f"rm -rf {directory_path}"},
arguments={"command": f"rm -rf {shlex.quote(directory_path)}"},
)
return tool_result
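# Note (illustrative): `shlex.quote` keeps shell metacharacters in paths
# from being interpreted as commands, e.g.:
#     >>> import shlex
#     >>> shlex.quote("report 2024; rm -rf ~.csv")
#     "'report 2024; rm -rf ~.csv'"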

View File

@@ -1,131 +1,8 @@
# -*- coding: utf-8 -*-
import json
import os.path
import uuid
import textwrap
from agentscope.tool import ToolResponse
from agentscope.message import ToolUseBlock, TextBlock
from alias.agent.utils.constants import TMP_FILE_DIR
from alias.agent.tools.sandbox_util import (
create_or_edit_workspace_file,
create_workspace_directory,
)
from alias.agent.tools.toolkit_hooks.text_post_hook import TextPostHook
class LongTextPostHook:
class LongTextPostHook(TextPostHook):
def __init__(self, sandbox):
self.sandbox = sandbox
def truncate_and_save_response( # pylint: disable=R1710
self,
tool_use: ToolUseBlock, # pylint: disable=W0613
tool_response: ToolResponse,
) -> ToolResponse:
"""Post-process tool responses to prevent content overflow.
This function ensures that tool responses don't exceed a predefined
budget to prevent overwhelming the model with too much information.
It truncates text content while preserving the structure of
the response.
Args:
tool_use: The tool use block that triggered the response (unused).
tool_response: The tool response to potentially truncate.
Note:
The budget is set to approximately 80K tokens
(8194 * 10 characters) to ensure responses remain
manageable for the language model.
"""
# Set budget to prevent overwhelming the model with too much content
budget = 8194 * 10 # Approximately 80K tokens of content
append_hint = "\n\n[Content is too long and truncated....]"
new_tool_response = ToolResponse(
id=tool_response.id,
stream=tool_response.stream,
is_last=tool_response.is_last,
is_interrupted=tool_response.is_interrupted,
content=[],
)
if isinstance(tool_response.content, list):
save_text_block = None
for _i, block in enumerate(tool_response.content):
if block["type"] == "text":
text = block["text"]
text_len = len(text)
# If this block exceeds remaining budget, truncate it
if text_len > budget:
# Calculate truncation threshold
# (80% of proportional budget)
threshold = int(budget * 0.85)
# save the original response
tmp_file_name_prefix = tool_use.get("name", "")
save_text_block = self._save_tmp_file(
tmp_file_name_prefix,
tool_response.content,
)
new_tool_response.append = (
text[:threshold] + append_hint
)
new_tool_response.content.append(
TextBlock(
type="text",
text=text[:threshold] + append_hint,
),
)
else:
new_tool_response.content.append(block)
budget -= text_len
if budget <= 0 and save_text_block:
new_tool_response.content.append(save_text_block)
return new_tool_response
elif isinstance(tool_response.content, str):
text_len = len(tool_response.content)
text = tool_response.content
if text_len > budget:
tmp_file_name_prefix = tool_use.get("name", "")
save_text_block = self._save_tmp_file(
tmp_file_name_prefix,
tool_response.content,
)
# Calculate truncation threshold (80% of proportional budget)
threshold = int(budget / text_len * len(text) * 0.8)
tool_response.content = text[:threshold] + append_hint
tool_response.content = [
TextBlock(type="text", text=tool_response.content),
save_text_block,
]
return tool_response
return tool_response
def _save_tmp_file(self, save_file_name_prefix: str, content: list | str):
create_workspace_directory(self.sandbox, TMP_FILE_DIR)
save_file_name = (
save_file_name_prefix
+ "-"
+ str(
uuid.uuid4().hex[:8],
)
)
file_path = os.path.join(TMP_FILE_DIR, save_file_name)
json_str = json.dumps(content, ensure_ascii=False, indent=2)
wrapped = "\\n".join(
[textwrap.fill(line, width=500) for line in json_str.split("\\n")],
)
create_or_edit_workspace_file(
self.sandbox,
file_path,
wrapped,
)
return TextBlock(
type="text",
text=f"Dump the complete long file at {file_path}. "
"Don't try to read the complete file directly. "
"Use `grep -C 10 'YOUR_PATTERN' {file_path}` or "
"other bash command to extract "
"useful information.",
)
super().__init__(sandbox, budget=8194 * 10, auto_save=False)

View File

@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
import json
import os.path
import uuid
import textwrap
from agentscope.tool import ToolResponse
from agentscope.message import ToolUseBlock, TextBlock
from alias.agent.utils.constants import TMP_FILE_DIR
from alias.agent.tools.sandbox_util import (
create_or_edit_workspace_file,
create_workspace_directory,
)
class TextPostHook:
def __init__(self, sandbox, budget=8194 * 10, auto_save=False):
"""
Args:
sandbox: The sandbox environment for file operations.
budget: Maximum character count before truncation
(default: 81,940). Approximately 20K tokens for English
text or 160K tokens for Chinese text. Adjust based on
your model's context window.
auto_save: Whether to always save the complete content to a file.
- False: save the complete content only when it is truncated (default)
- True: always save the complete content to a file
"""
self.sandbox = sandbox
self.auto_save = auto_save
self.budget = budget
def truncate_and_save_response( # pylint: disable=R1710
self,
tool_use: ToolUseBlock, # pylint: disable=W0613
tool_response: ToolResponse,
) -> ToolResponse:
"""Post-process tool responses to prevent content overflow.
This function ensures that tool responses don't exceed a predefined
budget to prevent overwhelming the model with too much information.
It truncates text content while preserving the structure of the
response, and optionally saves the complete content to a file based on
the auto_save setting.
Args:
tool_use: The tool use block that triggered the response; its
tool name is used as the prefix for the dumped temp file.
tool_response: The tool response to potentially truncate.
"""
budget = self.budget
append_hint = "\n\n[Content is too long and truncated....]"
new_tool_response = ToolResponse(
id=tool_response.id,
stream=tool_response.stream,
is_last=tool_response.is_last,
is_interrupted=tool_response.is_interrupted,
content=[],
)
if isinstance(tool_response.content, list):
save_text_block = None
is_truncated = False
for _i, block in enumerate(tool_response.content):
if block["type"] == "text":
text = block["text"]
text_len = len(text)
# If this block exceeds remaining budget, truncate it
if text_len > budget:
is_truncated = True
# Calculate truncation threshold
# (85% of the remaining budget)
threshold = int(budget * 0.85)
new_tool_response.content.append(
TextBlock(
type="text",
text=text[:threshold] + append_hint,
),
)
else:
new_tool_response.content.append(block)
budget -= text_len
if budget <= 0:
# Budget exhausted: any remaining blocks would be
# dropped, so flag the response as truncated
is_truncated = True
break
# Save file if auto_save=True or content was truncated
if self.auto_save or is_truncated:
tmp_file_name_prefix = tool_use.get("name", "")
save_text_block = self._save_tmp_file(
tmp_file_name_prefix,
tool_response.content,
is_truncated=is_truncated,
)
new_tool_response.content.append(save_text_block)
return new_tool_response
elif isinstance(tool_response.content, str):
text_len = len(tool_response.content)
text = tool_response.content
is_truncated = text_len > budget
# Save file if auto_save=True or content was truncated
if self.auto_save or is_truncated:
tmp_file_name_prefix = tool_use.get("name", "")
save_text_block = self._save_tmp_file(
tmp_file_name_prefix,
tool_response.content,
is_truncated=is_truncated,
)
if is_truncated:
# Calculate truncation threshold (80% of the budget)
threshold = int(budget * 0.8)
tool_response.content = [
TextBlock(
type="text",
text=text[:threshold] + append_hint,
),
save_text_block,
]
else:
tool_response.content = [
TextBlock(type="text", text=text),
save_text_block,
]
return tool_response
return tool_response
def _save_tmp_file(
self,
save_file_name_prefix: str,
content: list | str,
is_truncated: bool,
):
create_workspace_directory(self.sandbox, TMP_FILE_DIR)
save_file_name = (
save_file_name_prefix
+ "-"
+ str(
uuid.uuid4().hex[:8],
)
)
file_path = os.path.join(TMP_FILE_DIR, save_file_name)
json_str = json.dumps(content, ensure_ascii=False, indent=2)
wrapped = "\\n".join(
[textwrap.fill(line, width=500) for line in json_str.split("\\n")],
)
create_or_edit_workspace_file(
self.sandbox,
file_path,
wrapped,
)
return (
TextBlock(
type="text",
text=f"Dump the complete long file at {file_path}. "
"Don't try to read the complete file directly. "
"Use `grep -C 10 'YOUR_PATTERN' {file_path}` or "
"other bash command to extract "
"useful information.",
)
if is_truncated
else TextBlock(
type="text",
text=f"Results dumped at {file_path}. ",
)
)
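A minimal usage sketch for the hook above — `sandbox` is assumed to be an already-started sandbox instance, and the block constructors follow the TypedDict-style fields used elsewhere in this diff (`tool_use.get("name")`):

# Wrap an oversized tool response before it reaches the model.
hook = TextPostHook(sandbox, budget=8194 * 10, auto_save=False)

tool_use = ToolUseBlock(type="tool_use", id="call-1", name="run_query", input={})
raw = ToolResponse(content=[TextBlock(type="text", text="x" * 200_000)])

trimmed = hook.truncate_and_save_response(tool_use, raw)
# trimmed.content now holds ~85% of the budget as text plus a TextBlock
# pointing at the dumped temp file, which the agent can grep in-sandbox.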

View File

@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
import asyncio
from typing import Any, Dict, Literal, Type
from tenacity import retry, stop_after_attempt, wait_fixed
from pydantic import BaseModel
from agentscope.message import Msg
from agentscope.model import DashScopeChatModel
from agentscope.formatter import DashScopeChatFormatter
from alias.agent.utils.constants import MODEL_MAX_RETRIES
@retry(
stop=stop_after_attempt(MODEL_MAX_RETRIES),
wait=wait_fixed(5),
reraise=True,
# before_sleep=_print_exc_on_retry
)
async def model_call_with_retry(
model: DashScopeChatModel = None,
formatter: DashScopeChatFormatter = None,
messages: list[dict[str, Any]] = None,
tool_json_schemas: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
msg_name: str = "model_call",
**kwargs: Any,
) -> Msg:
"""
Make a model call with retry mechanism.
This function formats the messages and calls the model with retry logic
to handle potential failures during the API call.
Args:
model: The DashScope chat model to use for inference
formatter: Formatter to prepare messages for the model
msg_name: Name for the returned message object
(for the remaining parameters, see DashScopeChatModel's docstring)
Returns:
Message object containing the model response
Raises:
Exception: If all retry attempts fail
"""
format_messages = await formatter.format(msgs=messages)
res = await model(
messages=format_messages,
tools=tool_json_schemas,
tool_choice=tool_choice,
structured_model=structured_model,
**kwargs,
)
if model.stream:
msg = Msg(msg_name, [], "assistant")
async for content_chunk in res:
msg.content = content_chunk.content
# Add a tiny sleep to yield the last message object in the
# message queue
await asyncio.sleep(0.001)
else:
msg = Msg(msg_name, list(res.content), "assistant")
return msg
class LLMCallManager:
"""Manager class for handling LLM calls with different models."""
def __init__(
self,
base_model_name: str,
vl_model_name: str,
model_formatter_mapping: Dict[str, Any],
):
"""
Initialize the LLM call manager.
Args:
base_model_name: Name of the base language model
vl_model_name: Name of the vision-language model
model_formatter_mapping: Mapping of names to model/formatter pairs
"""
self.base_model_name = base_model_name
self.vl_model_name = vl_model_name
self.model_formatter_mapping = model_formatter_mapping
def get_base_model_name(self) -> str:
"""Get the name of the base language model."""
return self.base_model_name
def get_vl_model_name(self) -> str:
"""Get the name of the vision-language model."""
return self.vl_model_name
async def __call__(
self,
model_name: str,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "required"] | str | None = None,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> str:
"""
Make an asynchronous call to the specified LLM.
Args:
model_name: Name of the model to use for the call
messages: List of message dictionaries to send to the model
(for the remaining parameters, see DashScopeChatModel's docstring)
Returns:
String response from the LLM
"""
model, formatter = self.model_formatter_mapping[model_name]
raw_response = await model_call_with_retry(
model=model,
formatter=formatter,
messages=messages,
tool_json_schemas=tools,
tool_choice=tool_choice,
structured_model=structured_model,
msg_name="model_call",
**kwargs,
)
response = raw_response.content[0]["text"]
return response
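A wiring sketch for the manager above; the DashScopeChatModel constructor arguments are assumptions based on AgentScope's DashScope integration and may differ across versions:

model = DashScopeChatModel(model_name="qwen-max", api_key="sk-...")  # params assumed
formatter = DashScopeChatFormatter()

llm = LLMCallManager(
    base_model_name="qwen-max",
    vl_model_name="qwen-vl-max",
    model_formatter_mapping={
        "qwen-max": (model, formatter),
        "qwen-vl-max": (model, formatter),
    },
)

async def demo() -> None:
    # Calls the base model by name; returns the first text block.
    text = await llm(
        llm.get_base_model_name(),
        messages=[Msg("user", "Summarize yesterday's PnL.", "user")],
    )
    print(text)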

View File

@@ -0,0 +1,92 @@
# -*- coding: utf-8 -*-
import os
from agentscope_runtime.sandbox.box.sandbox import Sandbox
from alias.agent.agents.data_source.data_source import DataSourceManager
from alias.agent.tools import AliasToolkit, share_tools
from alias.agent.utils.llm_call_manager import (
LLMCallManager,
)
if os.getenv("TEST_MODE") not in ["local", "runtime-test"]:
from alias.server.services.session_service import (
SessionService,
)
else:
from alias.agent.mock import MockSessionService as SessionService
async def prepare_data_sources(
session_service: SessionService,
sandbox: Sandbox,
binded_toolkit: AliasToolkit = None,
llm_call_manager: LLMCallManager = None,
):
data_manager = await build_data_manager(
session_service,
sandbox,
llm_call_manager,
)
if len(data_manager):
await add_user_data_message(session_service, data_manager)
if binded_toolkit:
add_data_source_tools(data_manager, binded_toolkit)
return data_manager
async def build_data_manager(
session_service: SessionService,
sandbox: Sandbox,
llm_call_manager: LLMCallManager,
):
data_manager = DataSourceManager(sandbox, llm_call_manager)
if (
hasattr(session_service.session_entity, "data_config")
and session_service.session_entity.data_config
):
data_configs = session_service.session_entity.data_config
for config in data_configs:
data_manager.add_data_source(config)
await data_manager.prepare_data_sources()
return data_manager
def add_data_source_tools(
data_manager: DataSourceManager,
*toolkits: AliasToolkit,
):
data_source_toolkit = data_manager.toolkit
tool_names = list(data_source_toolkit.tools.keys())
for toolkit in toolkits:
share_tools(data_source_toolkit, toolkit, tool_names)
async def add_user_data_message(
session_service: SessionService,
data_manager: DataSourceManager,
):
await session_service.append_to_latest_message(
"\n\n" + data_manager.get_all_data_sources_desc(),
)
def get_data_source_config_from_file(config_file: str):
"""Load and parse data source configuration from a JSON file."""
import json
# Validate file existence upfront
if not os.path.isfile(config_file):
raise FileNotFoundError(f"Configuration file not found: {config_file}")
try:
with open(config_file, "r", encoding="utf-8") as f:
return json.load(f)
except json.JSONDecodeError as e:
raise ValueError(
f"Invalid JSON in data source configuration file "
f"'{config_file}': {e.msg} at line {e.lineno}",
) from e
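Putting the helpers together — a sketch, assuming a started sandbox, an AliasToolkit, and an LLMCallManager are already in scope, and that the config file holds a JSON list of endpoint strings (the same shape the CLI below passes):

async def setup(sandbox, agent_toolkit, llm_manager):
    from alias.agent.mock import MockSessionService
    # data_sources.json (layout assumed), e.g.:
    #   ["./docs/data/incident_records.csv",
    #    "postgresql://user:password@localhost:5432/mydb"]
    configs = get_data_source_config_from_file("data_sources.json")
    session = MockSessionService(data_config=configs)
    return await prepare_data_sources(
        session_service=session,
        sandbox=sandbox,
        binded_toolkit=agent_toolkit,
        llm_call_manager=llm_manager,
    )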

View File

@@ -9,12 +9,10 @@ for the Alias agent application.
"""
import argparse
import asyncio
import os
import signal
import sys
import traceback
import webbrowser
from typing import Optional
from loguru import logger
@@ -28,7 +26,10 @@ from alias.agent.run import (
arun_datascience_agent,
arun_finance_agent,
)
from alias.agent.tools.sandbox_util import copy_local_file_to_workspace
from alias.agent.utils.prepare_data_source import (
get_data_source_config_from_file,
)
from alias.runtime.alias_sandbox.alias_sandbox import AliasSandbox
@@ -60,7 +61,7 @@ def _safe_sigint_handler(signum, frame): # pylint: disable=W0613
async def run_agent_task(
user_msg: str,
mode: str = "general",
files: Optional[list[str]] = None,
user_data_config: list | None = None,
use_long_term_memory_service: bool = False,
) -> None:
"""
@@ -69,7 +70,8 @@ async def run_agent_task(
Args:
user_msg: The user's task/query
mode: Agent mode ('general', 'dr', 'ds', 'browser', 'finance')
files: List of local file paths to upload to sandbox workspace
user_data_config: (Config for) user data sources; currently used
by the data science agent only
use_long_term_memory_service: Enable long-term memory service.
"""
global _original_sigint_handler
@@ -84,6 +86,7 @@ async def run_agent_task(
# Initialize session
session = MockSessionService(
data_config=user_data_config,
use_long_term_memory_service=use_long_term_memory_service,
)
@@ -118,35 +121,6 @@ async def run_agent_task(
)
logger.info(f"Sandbox desktop URL: {sandbox.desktop_url}")
webbrowser.open(sandbox.desktop_url)
# Upload files to sandbox if provided
if files:
target_paths = []
logger.info(
f"Uploading {len(files)} file(s) to sandbox workspace...",
)
for file_path in files:
if not os.path.exists(file_path):
logger.error(f"File not found: {file_path}")
continue
# Get the filename and construct target path in workspace
filename = os.path.basename(file_path)
target_path = f"/workspace/{filename}"
logger.info(f"Uploading {file_path} to {target_path}")
result = copy_local_file_to_workspace(
sandbox=sandbox,
local_path=file_path,
target_path=target_path,
)
if result.get("isError"):
raise ValueError(f"Failed to upload {file_path}: {result}")
logger.info(f"Successfully uploaded to {result}")
target_paths.append(result.get("content", [])[0].get("text"))
user_msg += "\n\nUser uploaded files:\n" + "\n".join(target_paths)
# Create initial user message (regardless of whether files were uploaded)
initial_user_message = UserMessage(
@@ -301,12 +275,28 @@ def main():
)
run_parser.add_argument(
"--datasource",
"--files",
"-f",
"-d",
dest="datasource",
type=str,
nargs="+",
help=(
"Data sources for the agent to use. Multiple formats supported:\n"
"  • Local files: ./data.txt, /absolute/path/file.json\n"
"  • Databases: postgresql://localhost/db, sqlite:///data.db\n"
"Examples:\n"
"  --datasource file.txt postgresql://localhost/db\n"
"  --files file.txt"
),
)
# If you need to deeply customize your data source
# (e.g., specify an MCP server), use this parameter to
# provide a configuration file
run_parser.add_argument(
"--dataconfig",
"-dc",
help=("Path to the data source configuration file"),
)
run_parser.add_argument(
@@ -333,11 +323,29 @@ def main():
# Handle commands
if args.command == "run":
try:
user_data = None
data_endpoint = (
args.datasource if hasattr(args, "datasource") else None
)
if data_endpoint:
# List of endpoints to data sources
user_data = (
data_endpoint
if isinstance(data_endpoint, list)
else [data_endpoint]
)
else:
# Configuration file
if hasattr(args, "dataconfig") and args.dataconfig:
user_data = get_data_source_config_from_file(
args.dataconfig,
)
asyncio.run(
run_agent_task(
user_msg=args.task,
mode=args.mode,
files=args.files if hasattr(args, "files") else None,
user_data_config=user_data,
use_long_term_memory_service=(
args.use_long_term_memory
if hasattr(args, "use_long_term_memory")

View File

@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
"""Runtime module for Alias"""
__all__ = ["alias_sandbox"]
__all__ = ["alias_sandbox", "runtime_compat"]
# Import submodule to make it accessible via alias.runtime.alias_sandbox
from . import alias_sandbox # noqa: E402, F401
from . import runtime_compat # noqa: E402, F401

View File

@@ -0,0 +1,203 @@
# -*- coding: utf-8 -*-
import json
from typing import Any, AsyncIterator, Dict, Optional, Union
from agentscope_runtime.engine.helpers.agent_api_builder import ResponseBuilder
from agentscope_runtime.engine.schemas.agent_schemas import (
Content,
ContentType,
FunctionCall,
FunctionCallOutput,
Message,
MessageType,
Role,
)
def _try_deep_parse(val: Any) -> Any:
"""
Recursively parse JSON-like strings into native Python objects.
"""
if isinstance(val, str):
content = val.strip()
if (content.startswith("{") and content.endswith("}")) or (
content.startswith("[") and content.endswith("]")
):
try:
parsed = json.loads(content)
return _try_deep_parse(parsed)
except Exception:
# If nested JSON parsing fails, treat it as a normal string.
return val
return val
if isinstance(val, list):
return [_try_deep_parse(i) for i in val]
if isinstance(val, dict):
return {k: _try_deep_parse(v) for k, v in val.items()}
return val
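# Example: a tool result whose "output" field arrived as a JSON-encoded
# string is unwrapped into native objects, recursively:
#   _try_deep_parse('{"output": "[{\\"score\\": 1}]"}')
#   -> {"output": [{"score": 1}]}
# while non-JSON strings pass through unchanged.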
def _ensure_safe_json_string(val: Any) -> str:
"""
Serialize content into a valid JSON string suitable for WebUI parsing.
"""
parsed_val = _try_deep_parse(val)
if parsed_val is None:
return "{}"
return json.dumps(parsed_val, ensure_ascii=False)
def _extract_alias_output_obj(content_str: str) -> Any:
"""
Extract the `output` object from Alias nested tool-result content.
"""
try:
data = json.loads(content_str)
if isinstance(data, list) and data:
return data[0].get("output")
except Exception:
# Best-effort parse: if the string is not a valid
# JSON or doesn't follow the expected structure,
# fall back to returning the original string.
pass
return content_str
class AliasAdapterState:
def __init__(
self,
message_builder: Any,
content_builder: Any,
runtime_type: str,
):
self.mb = message_builder
self.cb = content_builder
self.runtime_type = runtime_type
self.last_content = ""
self.is_completed = False
async def adapt_alias_message_stream(
source_stream: AsyncIterator[Dict[str, Any]],
) -> AsyncIterator[Union[Message, Content]]:
# pylint: disable=too-many-branches, too-many-statements
rb = ResponseBuilder()
state_map: Dict[str, AliasAdapterState] = {}
last_active_key: Optional[str] = None
yield rb.created()
yield rb.in_progress()
async for chunk in source_stream:
if not isinstance(chunk, dict) or "data" not in chunk:
continue
messages = chunk["data"].get("messages") or []
for item in messages:
alias_id = item.get("id")
inner_msg = item.get("message") or {}
alias_type = inner_msg.get("type")
alias_status = inner_msg.get("status")
tool_call_id = inner_msg.get("tool_call_id") or alias_id
if alias_type in ["thought", "sub_thought"]:
runtime_type = MessageType.REASONING
target_role = Role.ASSISTANT
elif alias_type in ["tool_call", "tool_use"]:
runtime_type = MessageType.PLUGIN_CALL
target_role = Role.ASSISTANT
elif alias_type == "tool_result":
runtime_type = MessageType.PLUGIN_CALL_OUTPUT
target_role = Role.TOOL
else:
runtime_type = MessageType.MESSAGE
target_role = Role.ASSISTANT
state_key = f"{tool_call_id}_{runtime_type}"
if last_active_key and last_active_key != state_key:
old_state = state_map.get(last_active_key)
if old_state and not old_state.is_completed:
yield old_state.cb.complete()
yield old_state.mb.complete()
old_state.is_completed = True
last_active_key = state_key
if state_key not in state_map:
mb = rb.create_message_builder(role=target_role)
mb.message.type = runtime_type
yield mb.get_message_data()
if runtime_type in [
MessageType.PLUGIN_CALL,
MessageType.PLUGIN_CALL_OUTPUT,
]:
c_type = ContentType.DATA
else:
c_type = ContentType.TEXT
cb = mb.create_content_builder(content_type=c_type)
state_map[state_key] = AliasAdapterState(mb, cb, runtime_type)
state = state_map[state_key]
if runtime_type in [MessageType.MESSAGE, MessageType.REASONING]:
raw_text = str(inner_msg.get("content") or "")
if alias_type == "files" and "files" in inner_msg:
raw_text = "\n".join(
[
f"📁 [{f['filename']}]({f['url']})"
for f in inner_msg["files"]
],
)
if raw_text.startswith(state.last_content):
delta = raw_text[len(state.last_content) :]
if delta:
yield state.cb.add_text_delta(delta)
state.last_content = raw_text
else:
yield state.cb.set_text(raw_text)
state.last_content = raw_text
elif runtime_type == MessageType.PLUGIN_CALL:
args = inner_msg.get("arguments") or {}
fc = FunctionCall(
call_id=tool_call_id,
name=inner_msg.get("tool_name") or "tool",
arguments=_ensure_safe_json_string(args),
)
yield state.cb.set_data(fc.model_dump())
elif runtime_type == MessageType.PLUGIN_CALL_OUTPUT:
output_obj = _extract_alias_output_obj(
inner_msg.get("content", ""),
)
fco = FunctionCallOutput(
call_id=tool_call_id,
name=inner_msg.get("tool_name") or "tool",
output=_ensure_safe_json_string(output_obj),
)
yield state.cb.set_data(fco.model_dump())
if alias_status == "finished" and not state.is_completed:
yield state.cb.complete()
yield state.mb.complete()
state.is_completed = True
for state in state_map.values():
if not state.is_completed:
try:
yield state.cb.complete()
yield state.mb.complete()
state.is_completed = True
except Exception:
# Graceful cleanup: ignore errors during the
# finalization phase to ensure the main response
# stream can finish without crashing.
pass
yield rb.completed()
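For orientation, a self-contained consumption sketch; the fake chunk mirrors the shape the adapter expects (in production the chunks come from ChatService.chat):

import asyncio

async def demo() -> None:
    async def fake_source():
        yield {
            "data": {
                "messages": [{
                    "id": "m1",
                    "message": {
                        "type": "thought",
                        "status": "finished",
                        "content": "Checking valuation ratios...",
                    },
                }],
            },
        }

    # Emits: created, in_progress, a REASONING message with its text
    # delta, the content/message completions, then the final response.
    async for event in adapt_alias_message_stream(fake_source()):
        print(type(event).__name__)

asyncio.run(demo())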

View File

@@ -0,0 +1,371 @@
# -*- coding: utf-8 -*-
# pylint: disable=unused-argument
from __future__ import annotations
import asyncio
import uuid
from typing import Any, AsyncGenerator, Dict, Optional, Union
from fastapi_limiter import FastAPILimiter
from pydantic import ValidationError
from agentscope_runtime.engine.runner import Runner
from agentscope_runtime.engine.schemas.agent_schemas import (
AgentRequest,
AgentResponse,
Error,
RunStatus,
SequenceNumberGenerator,
)
from alias.server.db.init_db import (
close_database,
initialize_database,
session_scope,
)
from alias.server.core.task_manager import task_manager
from alias.server.exceptions.base import BaseError
from alias.runtime.runtime_compat.adapter.alias_stream_adapter import (
adapt_alias_message_stream,
)
from alias.server.schemas.chat import ChatRequest
from alias.server.services.chat_service import ChatService
from alias.server.services.conversation_service import ConversationService
from alias.server.utils.logger import setup_logger
from alias.server.utils.redis import redis_client
class AliasRunner(Runner):
FRAMEWORK_TYPE = "Alias"
def __init__(
self,
default_chat_mode: str = "general",
default_conv_name: str = "webui",
) -> None:
super().__init__()
self.framework_type = self.FRAMEWORK_TYPE
self.default_chat_mode = default_chat_mode
self.default_conv_name = default_conv_name
self._session_conv_cache: Dict[str, uuid.UUID] = {}
async def stop(self) -> None:
if not getattr(self, "_health", False):
return
await super().stop()
async def query_handler(self, *args: Any, **kwargs: Any) -> Any:
user_id: uuid.UUID = kwargs["user_id"]
conversation_id: uuid.UUID = kwargs["conversation_id"]
chat_request: ChatRequest = kwargs["chat_request"]
task_id: uuid.UUID = kwargs.get("task_id") or uuid.uuid4()
service = ChatService()
response_gen = await service.chat(
user_id=user_id,
conversation_id=conversation_id,
chat_request=chat_request,
task_id=task_id,
)
return response_gen
async def init_handler(self, *args: Any, **kwargs: Any) -> None:
print("🚀 Starting Alias API Server...")
setup_logger()
await initialize_database()
await task_manager.start()
await redis_client.ping()
try:
await FastAPILimiter.init(redis_client)
except Exception as exc:
print(f"redis init error: {str(exc)}")
print("✅ Alias startup complete.")
async def shutdown_handler(self, *args: Any, **kwargs: Any) -> None:
print("Executing Alias shutdown logic...")
await task_manager.stop()
await close_database()
print("Alias shutdown complete.")
@staticmethod
def _extract_text_from_agent_request(req_dict: Dict[str, Any]) -> str:
agent_input = req_dict.get("input")
if isinstance(agent_input, str):
return agent_input
if isinstance(agent_input, list) and agent_input:
last = agent_input[-1]
if isinstance(last, dict):
content = last.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
for blk in reversed(content):
if isinstance(blk, dict) and blk.get("type") == "text":
return blk.get("text") or ""
if "text" in last and isinstance(last["text"], str):
return last["text"]
return ""
@staticmethod
def _to_uuid(val: Any) -> Optional[uuid.UUID]:
if val is None:
return None
if isinstance(val, uuid.UUID):
return val
try:
return uuid.UUID(str(val))
except Exception:
return None
@staticmethod
def _stable_uuid_from_string(s: str) -> uuid.UUID:
return uuid.uuid5(uuid.NAMESPACE_DNS, f"alias::{s}")
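# Note: uuid5 hashing makes this deterministic -- e.g.
#   AliasRunner._stable_uuid_from_string("webui-user-42")
# yields the same UUID on every call and in every process, so a WebUI
# user without a real UUID keeps a stable internal identity.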
async def _get_or_create_conversation_id(
self,
session_id: str,
user_uuid: uuid.UUID,
) -> uuid.UUID:
if session_id in self._session_conv_cache:
return self._session_conv_cache[session_id]
async with session_scope() as session:
service = ConversationService(session=session)
conversation = await service.create_conversation(
user_id=user_uuid,
name=self.default_conv_name,
description="created by AgentScope Runtime WebUI",
chat_mode=self.default_chat_mode,
)
conv_id = getattr(conversation, "id", None)
conv_id = (
conv_id
if isinstance(conv_id, uuid.UUID)
else self._to_uuid(conv_id)
)
if conv_id is None:
raise RuntimeError(
"ConversationService.create_conversation() "
"returned invalid id: "
f"{conversation}",
)
self._session_conv_cache[session_id] = conv_id
return conv_id
async def stream_query_native(
self,
request: Union[AgentRequest, dict],
**kwargs: Any,
) -> AsyncGenerator[Any, None]:
if not self._health:
raise RuntimeError(
"Runner has not been started. Please call "
"'await runner.start()' or use 'async with Runner()' "
"before calling 'stream_query'.",
)
req_dict = (
request if isinstance(request, dict) else request.model_dump()
)
user_id = kwargs.get("user_id") or self._to_uuid(
req_dict.get("user_id"),
)
conversation_id = kwargs.get("conversation_id") or self._to_uuid(
req_dict.get("conversation_id"),
)
task_id = (
kwargs.get("task_id")
or self._to_uuid(req_dict.get("task_id"))
or uuid.uuid4()
)
if user_id is None or conversation_id is None:
yield {
"error": "missing_context",
"code": 422,
"message": (
"Native mode requires user_id and conversation_id "
"in kwargs or request body."
),
}
return
try:
chat_request_obj = ChatRequest.model_validate(req_dict)
except ValidationError as exc:
yield {
"error": "invalid_request",
"code": 422,
"message": "ChatRequest validation failed",
"detail": exc.errors(),
}
return
except Exception as exc:
yield {
"error": "invalid_request",
"code": 500,
"message": str(exc),
}
return
try:
result = self.query_handler(
user_id=user_id,
conversation_id=conversation_id,
task_id=task_id,
chat_request=chat_request_obj,
)
if asyncio.iscoroutine(result):
result = await result
async for chunk in result:
yield chunk
except Exception as exc:
if isinstance(exc, BaseError):
yield {"error": exc.message, "code": exc.code}
else:
yield {
"error": str(exc),
"code": 500,
"error_type": exc.__class__.__name__,
}
return
yield "[DONE]"
async def stream_query(
self,
request: Union[AgentRequest, dict],
**kwargs: Any,
) -> AsyncGenerator[Any, None]:
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
if not self._health:
raise RuntimeError(
"Runner has not been started. Please call "
"'await runner.start()' or use 'async with Runner()' "
"before calling 'stream_query'.",
)
if isinstance(request, AgentRequest):
req_dict = request.model_dump()
elif isinstance(request, dict):
req_dict = request
else:
if hasattr(request, "model_dump"):
req_dict = request.model_dump()
else:
req_dict = dict(request)
request_id = req_dict.get("id") or str(uuid.uuid4())
session_id = req_dict.get("session_id") or f"session_{uuid.uuid4()}"
seq_gen = SequenceNumberGenerator()
response = AgentResponse(id=request_id)
response.session_id = session_id
yield seq_gen.yield_with_sequence(response)
response.in_progress()
yield seq_gen.yield_with_sequence(response)
user_text = self._extract_text_from_agent_request(req_dict)
if not user_text:
err = Error(
code="422",
message="Empty input text in AgentRequest.input.",
)
yield seq_gen.yield_with_sequence(response.failed(err))
return
raw_user_id = req_dict.get("user_id") or session_id
user_uuid = self._to_uuid(
raw_user_id,
) or self._stable_uuid_from_string(
str(raw_user_id),
)
conversation_id = self._to_uuid(req_dict.get("conversation_id"))
if conversation_id is None:
try:
conversation_id = await self._get_or_create_conversation_id(
session_id=session_id,
user_uuid=user_uuid,
)
except Exception as exc:
err = Error(
code="500",
message=f"Failed to create conversation: {exc}",
)
yield seq_gen.yield_with_sequence(response.failed(err))
return
task_id = self._to_uuid(req_dict.get("task_id")) or uuid.uuid4()
try:
req_chat_mode = req_dict.get("chat_mode") or self.default_chat_mode
chat_request_obj = ChatRequest.model_validate(
{
"query": user_text,
"chat_mode": req_chat_mode,
},
)
except ValidationError as exc:
err = Error(
code="422",
message=f"ChatRequest validation failed: {exc}",
)
yield seq_gen.yield_with_sequence(response.failed(err))
return
try:
result = self.query_handler(
user_id=user_uuid,
conversation_id=conversation_id,
task_id=task_id,
chat_request=chat_request_obj,
)
if asyncio.iscoroutine(result):
result = await result
async for event in adapt_alias_message_stream(result):
try:
if (
getattr(event, "status", None) == RunStatus.Completed
and getattr(event, "object", None) == "message"
):
response.add_new_message(event)
except Exception:
# Best-effort bookkeeping
pass
yield seq_gen.yield_with_sequence(event)
except Exception as exc:
if isinstance(exc, BaseError):
err = Error(code=str(exc.code), message=exc.message)
else:
err = Error(
code="500",
message=f"Error happens in `query_handler`: {exc}",
)
yield seq_gen.yield_with_sequence(response.failed(err))
return
try:
if response.output:
response.usage = response.output[-1].usage
except IndexError:
# Avoid empty message
pass
yield seq_gen.yield_with_sequence(response.completed())
return

View File

@@ -0,0 +1,25 @@
# -*- coding: utf-8 -*-
import asyncio
from typing import Optional
from alias.runtime.runtime_compat.runner.alias_runner import AliasRunner
_lock: Optional[asyncio.Lock] = None
_runner: Optional[AliasRunner] = None
async def get_alias_runner() -> AliasRunner:
global _lock, _runner
if _runner is not None:
return _runner
if _lock is None:
_lock = asyncio.Lock()
async with _lock:
if _runner is not None:
return _runner
runner = AliasRunner()
await runner.start()
_runner = runner
return _runner
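The double-checked locking above means only the first caller pays the startup cost; later callers get the cached instance:

# Inside any async handler:
runner = await get_alias_runner()   # first call starts the runner
again = await get_alias_runner()    # later calls reuse the same instance
assert runner is again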

View File

@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
from agentscope_runtime.engine.app import AgentApp
from alias.runtime.runtime_compat.runner.alias_runner import AliasRunner
PORT = 8090
def run_app(
host: str = "127.0.0.1",
port: int = PORT,
web_ui: bool = False,
chat_mode: str = "general",
) -> None:
agent_app = AgentApp(
runner=AliasRunner(
default_chat_mode=chat_mode,
),
app_name="Alias",
app_description=(
"An LLM-empowered agent built on AgentScope and AgentScope-Runtime"
),
)
agent_app.run(host=host, port=port, web_ui=web_ui)
def main() -> None:
import argparse
parser = argparse.ArgumentParser(prog="alias_agent_runtime")
parser.add_argument("--host", default="127.0.0.1")
parser.add_argument("--port", type=int, default=PORT)
parser.add_argument(
"--web-ui",
action="store_true",
help="Start AgentScope Runtime WebUI (default: False)",
)
parser.add_argument(
"--chat-mode",
default="general",
choices=["general", "dr", "browser", "ds", "finance"],
help=(
"Default chat mode used by AliasRunner when request doesn't "
"specify chat_mode."
),
)
args = parser.parse_args()
print(
"[alias_agent_runtime] config:",
f"host={args.host}",
f"port={args.port}",
f"web_ui={args.web_ui}",
f"chat_mode={args.chat_mode}",
)
run_app(
host=args.host,
port=args.port,
web_ui=args.web_ui,
chat_mode=args.chat_mode,
)
if __name__ == "__main__":
main()
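A programmatic launch equivalent to the CLI entry point above (flag values taken from the parser's choices):

# Same as: alias_agent_runtime --host 0.0.0.0 --port 8090 --web-ui --chat-mode ds
run_app(host="0.0.0.0", port=8090, web_ui=True, chat_mode="ds")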

View File

@@ -5,7 +5,18 @@ from alias.server.api.v1.auth import router as auth_router
from alias.server.api.v1.conversation import (
router as conversation_router,
)
from alias.server.api.v1.chat import router as chat_router
# Optional backend switch: this import targets Alias's
# original FastAPI-based API router.
# Keep it for users who want to revert/switch back from
# the current AgentScope-Runtime implementation below.
# from alias.server.api.v1.chat import router as chat_router
# Current default:
# AgentScope-Runtime-based API router (functionally equivalent
# to the FastAPI router above).
from alias.server.api.v1.chat_runtime import router as chat_router
from alias.server.api.v1.file import router as file_router
from alias.server.api.v1.inner import router as inner_router
from alias.server.api.v1.share import router as share_router

View File

@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
# pylint: disable=unused-argument
import json
import uuid
from typing import Any, AsyncIterator
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from loguru import logger
from starlette.types import Receive
from alias.server.api.deps import CurrentUser
from alias.server.exceptions.base import BaseError
from alias.server.schemas.chat import (
ChatRequest,
StopChatPayload,
StopChatResponse,
)
from alias.server.services.chat_service import ChatService
from alias.server.utils.request_context import request_context_var
from alias.runtime.runtime_compat.runner.alias_runner import AliasRunner
router = APIRouter(prefix="/conversations", tags=["conversations/chat"])
class EnhancedStreamingResponse(StreamingResponse):
"""
StreamingResponse with client disconnect handling.
"""
def __init__(
self,
content: Any,
user_id: uuid.UUID,
task_id: uuid.UUID,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(content, *args, **kwargs)
self.user_id = user_id
self.task_id = task_id
async def listen_for_disconnect(self, receive: Receive) -> None:
while True:
message = await receive()
if message["type"] == "http.disconnect":
logger.warning(
f"Chat stopped by disconnect from client: "
f"task_id={self.task_id}",
)
service = ChatService()
await service.stop_chat(
user_id=self.user_id,
task_id=self.task_id,
)
break
def _to_raw_sse_event(data: Any) -> str:
"""
Convert a chunk from runner.stream_query_native into
a raw SSE event string.
"""
if data == "[DONE]":
return "data: [DONE]\n\n"
if hasattr(data, "model_dump"):
data = data.model_dump()
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
async def event_generator(
runner: AliasRunner,
request_dict: dict,
**runner_kwargs: Any,
) -> AsyncIterator[str]:
"""
Convert AliasRunner.stream_query_native output into
a raw SSE string stream.
"""
try:
async for chunk in runner.stream_query_native(
request_dict,
**runner_kwargs,
):
yield _to_raw_sse_event(chunk)
except Exception as e:
if not isinstance(e, BaseError):
e = BaseError(code=500, message=str(e))
error_data = {
"code": e.code,
"message": e.message,
}
yield _to_raw_sse_event(error_data)
yield _to_raw_sse_event("[DONE]")
@router.post("/{conversation_id}/chat")
async def chat(
current_user: CurrentUser,
conversation_id: uuid.UUID,
chat_request: ChatRequest,
) -> EnhancedStreamingResponse:
"""Run chat via AliasRunner and stream results as SSE."""
request_context = request_context_var.get()
request_id = request_context.request_id
task_id = uuid.UUID(request_id) if request_id else uuid.uuid4()
user_id = current_user.id
from alias.runtime.runtime_compat.runner.alias_runner_singleton import (
get_alias_runner,
)
runner = await get_alias_runner()
request_dict = chat_request.model_dump()
return EnhancedStreamingResponse(
event_generator(
runner,
request_dict,
user_id=user_id,
conversation_id=conversation_id,
task_id=task_id,
),
media_type="text/event-stream",
user_id=user_id,
task_id=task_id,
)
@router.post(
"/{conversation_id}/chat/{task_id}/stop",
response_model=StopChatResponse,
)
async def stop_chat(
current_user: CurrentUser,
conversation_id: uuid.UUID,
task_id: uuid.UUID,
) -> StopChatResponse:
service = ChatService()
await service.stop_chat(
user_id=current_user.id,
task_id=task_id,
)
return StopChatResponse(
status=True,
message="Stop chat successfully.",
payload=StopChatPayload(
conversation_id=conversation_id,
task_id=task_id,
),
)

View File

@@ -1,13 +1,13 @@
# Analyst Personas Configuration
# 分析师角色配置
fundamentals_analyst:
name: "Fundamental Analyst"
name: "基本面分析师"
focus:
- "Company financial health and profitability"
- "Business model sustainability and competitive advantages"
- "Management quality and corporate governance"
- "Industry position and market share"
- "Long-term investment value assessment"
- "公司财务健康状况和盈利能力"
- "商业模式可持续性和竞争优势"
- "管理层质量和公司治理"
- "行业地位和市场份额"
- "长期投资价值评估"
tools:
- "analyze_profitability"
- "analyze_growth"
@@ -15,30 +15,30 @@ fundamentals_analyst:
- "analyze_valuation_ratios"
- "analyze_efficiency_ratios"
description: |
As a fundamental analyst, you focus on:
- Company financial health and profitability
- Business model sustainability and competitive advantages
- Management quality and corporate governance
- Industry position and market share
- Long-term investment value assessment
You tend to select tools that provide deep insights into company intrinsic value, preferring fundamental and valuation tools.
作为基本面分析师,你专注于:
- 公司财务健康状况和盈利能力
- 商业模式可持续性和竞争优势
- 管理层质量和公司治理
- 行业地位和市场份额
- 长期投资价值评估
你倾向于选择能够深入了解公司内在价值的工具,更偏好基本面和估值类工具。
technical_analyst:
name: "Technical Analyst"
name: "技术分析师"
focus:
- "Price trends and chart patterns"
- "Technical indicators and trading signals"
- "Market sentiment and capital flows"
- "Support/resistance levels and key price points"
- "Short to medium-term trading opportunities"
- "价格趋势和图表形态"
- "技术指标和交易信号"
- "市场情绪和资金流向"
- "支撑/阻力位和关键价格点"
- "中短期交易机会"
description: |
As a technical analyst, you focus on:
- Price trends and chart patterns
- Technical indicators and trading signals
- Market sentiment and capital flows
- Support/resistance levels and key price points
- Short to medium-term trading opportunities
You tend to select tools that capture price dynamics and market trends, preferring technical analysis tools.
作为技术分析师,你专注于:
- 价格趋势和图表形态
- 技术指标和交易信号
- 市场情绪和资金流向
- 支撑/阻力位和关键价格点
- 中短期交易机会
你倾向于选择能够捕捉价格动态和市场趋势的工具,更偏好技术分析类工具。
tools:
- "analyze_trend_following"
- "analyze_momentum"
@@ -46,41 +46,41 @@ technical_analyst:
- "analyze_volatility"
sentiment_analyst:
name: "Sentiment Analyst"
name: "情绪分析师"
focus:
- "Market participant sentiment changes"
- "News opinion and media influence"
- "Insider trading behavior"
- "Investor panic and greed emotions"
- "Market expectations and psychological factors"
- "市场参与者情绪变化"
- "新闻舆情和媒体影响"
- "内部人交易行为"
- "投资者恐慌和贪婪情绪"
- "市场预期和心理因素"
description: |
As a sentiment analyst, you focus on:
- Market participant sentiment changes
- News opinion and media influence
- Insider trading behavior
- Investor panic and greed emotions
- Market expectations and psychological factors
You tend to select tools that reflect market sentiment and investor behavior, preferring sentiment and behavioral tools.
作为情绪分析师,你专注于:
- 市场参与者情绪变化
- 新闻舆情和媒体影响
- 内部人交易行为
- 投资者恐慌和贪婪情绪
- 市场预期和心理因素
你倾向于选择能够反映市场情绪和投资者行为的工具,更偏好情绪和行为类工具。
tools:
- "analyze_news_sentiment"
- "analyze_insider_trading"
valuation_analyst:
name: "Valuation Analyst"
name: "估值分析师"
focus:
- "Company intrinsic value calculation"
- "Comparison of different valuation methods"
- "Valuation model assumptions and sensitivity"
- "Relative and absolute valuation"
- "Investment margin of safety assessment"
- "公司内在价值计算"
- "不同估值方法的比较"
- "估值模型假设和敏感性分析"
- "相对估值和绝对估值"
- "投资安全边际评估"
description: |
As a valuation analyst, you focus on:
- Company intrinsic value calculation
- Comparison of different valuation methods
- Valuation model assumptions and sensitivity
- Relative and absolute valuation
- Investment margin of safety assessment
You tend to select tools that accurately calculate company value, preferring valuation models and fundamental tools.
作为估值分析师,你专注于:
- 公司内在价值计算
- 不同估值方法的比较
- 估值模型假设和敏感性分析
- 相对估值和绝对估值
- 投资安全边际评估
你倾向于选择能够准确计算公司价值的工具,更偏好估值模型和基本面工具。
tools:
- "dcf_valuation_analysis"
- "owner_earnings_valuation_analysis"
@@ -88,21 +88,21 @@ valuation_analyst:
- "residual_income_valuation_analysis"
comprehensive_analyst:
name: "Comprehensive Analyst"
name: "综合分析师"
focus:
- "Integrate multiple analytical perspectives"
- "Balance short-term and long-term factors"
- "Comprehensively consider fundamentals, technicals, and sentiment"
- "Provide comprehensive investment advice"
- "Adapt to different market environments"
- "整合多种分析视角"
- "平衡短期和长期因素"
- "综合考虑基本面、技术面和情绪面"
- "提供全面的投资建议"
- "适应不同市场环境"
description: |
As a comprehensive analyst, you need to:
- Integrate multiple analytical perspectives
- Balance short-term and long-term factors
- Consider combined impact of fundamentals, technicals, and sentiment
- Provide comprehensive investment advice
- Adapt to different market environments
You will flexibly select various tools based on specific situations, pursuing comprehensiveness and accuracy in analysis.
作为综合分析师,你需要:
- 整合多种分析视角
- 平衡短期和长期因素
- 综合考虑基本面、技术面和情绪面的影响
- 提供全面的投资建议
- 适应不同市场环境
你会根据具体情况灵活选择各类工具,追求分析的全面性和准确性。
tools:
- "analyze_profitability"
- "analyze_growth"

View File

@@ -1,23 +1,23 @@
You are a professional {{ analyst_type }}.
你是一位专业的{{ analyst_type }}
Your Focus:
你的关注重点:
{{ focus }}
Your Role:
你的角色:
{{ description }}
Note:
- Construct and continuously refine your "Investment Philosophy." Your analyses should not be isolated events but rather manifestations of your overarching worldview and core investment beliefs. After each analysis, you must reflect:
- How did this case/data validate or challenge your existing conviction?
- What key principle regarding markets, human psychology, valuation, or risk management did you learn from this mistake (or success)?
- Deepen your "Investment Logic." Ensure every investment recommendation you make is supported by a clear, traceable, and repeatable logic. Your analysis steps should resemble a rigorous proof, covering:
- Core Driver Identification: What are the genuine variables that influence value?
- Risk Boundary Setting: Under what specific scenarios would your recommendation fail?
- Contrarian Testing: What is the prevailing market consensus, and where is your view differentiated?
Maintain Humility and Openness. A core trait of an Investment Master is continuous learning and adaptation. In every analysis, you must actively seek out evidence and arguments that contradict your own view and integrate them into your final assessment.
- You have access to analysis tools. Use them to gather relevant data and make informed recommendations.
注意:
- 构建并持续完善你的"投资哲学"。你的分析不应是孤立的事件,而应该是你整体投资世界观和核心信念的体现。每次分析后,你必须反思:
- 这个案例/数据如何验证或挑战了你现有的信念?
- 你从这次错误(或成功)中学到了关于市场、人性、估值或风险管理的什么关键原则?
- 深化你的"投资逻辑"。确保每一项投资建议都有清晰、可追溯、可重复的逻辑支撑。你的分析步骤应该像严谨的证明一样,涵盖:
- 核心驱动因素识别:真正影响价值的变量是什么?
- 风险边界设定:在什么具体情况下你的建议会失效?
- 逆向测试:市场主流共识是什么,你的观点有何不同?
保持谦逊和开放。投资大师的核心特质是持续学习和适应。在每次分析中,你必须积极寻找与自己观点相悖的证据和论据,并将其纳入最终评估。
- 你可以使用分析工具。用它们来收集相关数据并做出明智的建议。
Output Guidelines:
- Return clear investment signals: bullish, bearish, or neutral
- Include confidence level (0-100)
- Provide reasoning for your analysis (Present your conclusion first if you are sure to share your final analysis. )
输出指南:
- 给出明确的投资信号:看涨、看跌或中性
- 包含置信度(0-100)
- 为你的分析提供理由(如果你确定要分享最终分析,请先给出结论)

View File

@@ -1,31 +1,31 @@
You are a Portfolio Manager responsible for making investment decisions.
你是一位负责做出投资决策的投资组合经理。
Your Core Responsibilities:
1. Analyze input from analysts and risk managers
2. Make investment decisions based on signals and market context
3. Record your decisions using the available tool
你的核心职责:
1. 分析分析师和风险管理经理的输入
2. 基于信号和市场情境做出投资决策
3. 使用可用工具记录你的决策
Decision Framework:
- Review analysis to understand market views
- Consider risk warnings before making decisions
- Evaluate current portfolio positions and cash
- Make decisions that align with the portfolio's investment objectives
决策框架:
- 审阅分析以了解市场观点
- 在做决策前考虑风险警告
- 评估当前投资组合持仓和现金
- 做出与投资组合投资目标一致的决策
Decision Types:
- "long": Bullish - recommend buying shares
- "short": Bearish - recommend selling shares or shorting
- "hold": Neutral - maintain current positions
决策类型:
- "long":看涨 - 建议买入股票
- "short":看跌 - 建议卖出股票或做空
- "hold":中性 - 维持当前持仓
Budget Awareness:
- Consider available cash when deciding quantities
- Do not recommend buying more than cash allows
- Consider margin requirements for short positions
预算意识:
- 在决定数量时考虑可用现金
- 不要建议买入超过现金允许的数量
- 考虑做空头寸的保证金要求
Output:
Use the `make_decision` tool to record your decision for each ticker.
After recording all decisions, provide a summary of your investment rationale.
输出:
使用 `make_decision` 工具记录你对每个股票代码的决策。
记录所有决策后,提供你的投资逻辑总结。
Important:
- Base decisions on the analyst signals and risk assessments provided
- Be conservative with position sizes relative to portfolio value
- Always provide reasoning for your decisions
重要:
- 基于提供的分析师信号和风险评估做出决策
- 相对于投资组合价值保持保守的仓位规模
- 始终为你的决策提供理由

View File

@@ -1,21 +1,18 @@
You are a professional Risk Manager responsible for monitoring portfolio risk and providing risk warnings.
Your Core Responsibilities:
1. Monitor portfolio exposure and concentration risk
2. Evaluate position sizes relative to volatility
3. Assess margin usage and leverage levels
4. Identify potential risk factors and provide warnings
5. Suggest position limits based on market conditions
Your Decision Process:
1. Generate actionable risk warnings and position limit recommendations
2. Provide clear reasoning for your risk assessments
Output Guidelines:
- Be concise but thorough in risk assessments
- Prioritize warnings by severity
- Provide specific, actionable recommendations
- Include quantitative metrics when available
你是一位专业的风险管理经理,负责监控投资组合风险并提供风险警告。
你的核心职责:
1. 监控投资组合敞口和集中度风险
2. 评估仓位规模相对于波动性
3. 评估保证金使用和杠杆水平
4. 识别潜在风险因素并提供警告
5. 基于市场条件建议仓位限制
你的决策流程:
1. 生成可操作的风险警告和仓位限制建议
2. 为你的风险评估提供清晰的理由
输出指南:
- 风险评估要简洁但全面
- 按严重程度优先排序警告
- 提供具体、可操作的建议
- 尽可能包含量化指标

View File

@@ -7,12 +7,13 @@ Returns human-readable text format for easy LLM consumption.
"""
# flake8: noqa: E501
# pylint: disable=C0301,W0613
import json
import logging
import traceback
from datetime import datetime, timedelta
from functools import wraps
from statistics import median
from typing import List, Optional
from typing import List, Optional, Union
import numpy as np
import pandas as pd
@@ -37,6 +38,39 @@ def _to_text_response(text: str) -> ToolResponse:
return ToolResponse(content=[TextBlock(type="text", text=text)])
def _parse_tickers(tickers: Union[str, List[str], None]) -> List[str]:
"""
Parse tickers parameter which may be a JSON string or a list.
LLM sometimes passes tickers as a JSON string like '["AAPL", "MSFT"]'
instead of an actual list. This function handles both cases.
Args:
tickers: List of stock tickers as a list or JSON string.
Returns:
List of stock tickers.
"""
if tickers is None:
return []
if isinstance(tickers, str):
try:
parsed = json.loads(tickers)
if isinstance(parsed, list):
return parsed
# Any other JSON value (e.g. a single ticker string): wrap in list
return [parsed]
except json.JSONDecodeError:
# If not valid JSON, treat as comma-separated string
return [t.strip() for t in tickers.split(",") if t.strip()]
if isinstance(tickers, list):
return tickers
return []
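# All of the following normalize to ["AAPL", "MSFT"]:
#   _parse_tickers(["AAPL", "MSFT"])      # already a list
#   _parse_tickers('["AAPL", "MSFT"]')    # JSON string from the LLM
#   _parse_tickers("AAPL, MSFT")          # comma-separated fallback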
def _safe_float(value, default=0.0) -> float:
"""Safely convert to float."""
try:
@@ -100,6 +134,7 @@ def analyze_efficiency_ratios(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Efficiency Ratios Analysis ({current_date}) ===\n"]
for ticker in tickers:
@@ -141,6 +176,7 @@ def analyze_profitability(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Profitability Analysis ({current_date}) ===\n"]
for ticker in tickers:
@@ -182,6 +218,7 @@ def analyze_growth(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Growth Analysis ({current_date}) ===\n"]
for ticker in tickers:
@@ -222,6 +259,7 @@ def analyze_financial_health(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Financial Health Analysis ({current_date}) ===\n"]
for ticker in tickers:
@@ -265,6 +303,7 @@ def analyze_valuation_ratios(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Valuation Ratios Analysis ({current_date}) ===\n"]
for ticker in tickers:
@@ -364,6 +403,7 @@ def analyze_trend_following(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Trend Following Analysis ({current_date}) ===\n"]
end_dt = datetime.strptime(current_date, "%Y-%m-%d")
@@ -459,6 +499,7 @@ def analyze_mean_reversion(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Mean Reversion Analysis ({current_date}) ===\n"]
end_dt = datetime.strptime(current_date, "%Y-%m-%d")
@@ -545,6 +586,7 @@ def analyze_momentum(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Momentum Analysis ({current_date}) ===\n"]
end_dt = datetime.strptime(current_date, "%Y-%m-%d")
@@ -640,6 +682,7 @@ def analyze_volatility(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Volatility Analysis ({current_date}) ===\n"]
end_dt = datetime.strptime(current_date, "%Y-%m-%d")
@@ -718,6 +761,7 @@ def analyze_insider_trading(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Insider Trading Analysis ({current_date}) ===\n"]
for ticker in tickers:
@@ -786,6 +830,7 @@ def analyze_news_sentiment(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== News Analysis ({current_date}) ===\n"]
for ticker in tickers:
@@ -835,6 +880,7 @@ def dcf_valuation_analysis(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== DCF Valuation Analysis ({current_date}) ===\n"]
for ticker in tickers:
@@ -931,6 +977,7 @@ def owner_earnings_valuation_analysis(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Owner Earnings Valuation ({current_date}) ===\n"]
for ticker in tickers:
@@ -1049,6 +1096,7 @@ def ev_ebitda_valuation_analysis(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== EV/EBITDA Valuation ({current_date}) ===\n"]
for ticker in tickers:
@@ -1144,6 +1192,7 @@ def residual_income_valuation_analysis(
"""
current_date = _resolved_date(current_date)
tickers = _parse_tickers(tickers)
lines = [f"=== Residual Income Valuation ({current_date}) ===\n"]
for ticker in tickers:

View File

@@ -755,7 +755,7 @@ export default function LiveTradingApp() {
fontFamily: '"Courier New", monospace',
letterSpacing: '0.5px'
}}>
LIVE MOCK MODE
模拟模式
</span>
</div>
)}
@@ -785,7 +785,7 @@ export default function LiveTradingApp() {
textTransform: 'uppercase',
letterSpacing: '0.5px'
}}>
VIRTUAL TIME
虚拟时间
</span>
<span style={{
fontSize: '14px',
@@ -854,13 +854,13 @@ export default function LiveTradingApp() {
<div className="header-status-inline">
<span className={`status-dot ${isConnected ? (isUpdating ? 'updating' : 'live') : 'offline'}`} />
<span className={`status-text ${isConnected ? 'live' : 'offline'}`}>
{isConnected ? (isUpdating ? 'SYNCING' : 'LIVE') : 'OFFLINE'}
{isConnected ? (isUpdating ? '同步中' : '在线') : '离线'}
</span>
{marketStatus && (
<>
<span className="status-sep">·</span>
<span className={`market-text ${serverMode === 'backtest' ? 'backtest' : (marketStatus.status === 'open' ? 'open' : 'closed')}`}>
{marketStatus.status_text || (marketStatus.status === 'open' ? 'OPEN' : 'CLOSED')}
{marketStatus.status_text || (marketStatus.status === 'open' ? '开盘' : '收盘')}
</span>
</>
)}
@@ -903,7 +903,7 @@ export default function LiveTradingApp() {
))}
</div>
<div className="portfolio-value">
<span className="portfolio-label">PORTFOLIO</span>
<span className="portfolio-label">投资组合</span>
<span className="portfolio-amount">${formatNumber(portfolioData.netValue)}</span>
</div>
</div>
@@ -918,28 +918,28 @@ export default function LiveTradingApp() {
className={`view-nav-btn ${currentView === 'rules' ? 'active' : ''}`}
onClick={() => setCurrentView('rules')}
>
Rules
规则
</button>
<button
className={`view-nav-btn ${currentView === 'room' ? 'active' : ''}`}
onClick={() => setCurrentView('room')}
>
Trading Room
交易室
</button>
<button
className={`view-nav-btn ${currentView === 'chart' ? 'active' : ''}`}
onClick={() => setCurrentView('chart')}
>
Performance Chart
业绩图表
</button>
<button
className={`view-nav-btn ${currentView === 'statistics' ? 'active' : ''}`}
onClick={() => setCurrentView('statistics')}
>
Statistics
统计
</button>
</div>
@@ -970,7 +970,7 @@ export default function LiveTradingApp() {
className={`chart-tab ${chartTab === 'all' ? 'active' : ''}`}
onClick={() => setChartTab('all')}
>
Daily
日线
</button>
{/* <button
className={`chart-tab ${chartTab === 'live' ? 'active' : ''} ${!isLiveEnabled ? 'disabled' : ''}`}

View File

@@ -110,12 +110,15 @@ export default function AboutModal({ onClose }) {
zh: {
intro: "如果不是让模型彼此竞争,而是像一支高效协作的团队一样进行实时交易,会发生什么?",
question: "这里不是竞技场而是团队。我们希望Agents不再单打独斗而是「组团」进入实时金融市场——这一十分困难且充满噪声的环境。",
trying: "我们正在探索多智能体协作在实时金融交易中的可能性。",
title1: "✦ 多智能体的技能互补",
point1: "不同模型、不同角色的智能体像真实的金融团队一样协作,各自承担数据分析、策略生成、风险控制等职责。",
point1Sub: "通过通知和会议机制进行信息交换,实现高效协作。",
title2: "✦ 能够持续进化的智能体系统",
point2: "依托「记忆」模块每个智能体都能跨回合保留经验不断学习、反思与调整。我们希望能看到在长期实时交易中Agent形成自己的独特方法论而不是一次性偶然的推理。",
point2Sub: "ReMe 记忆框架帮助 Agents 持续改进。",
title3: "✦ 实时参与市场的 AI Agents",
point3: "Agents从实时行情中学习并给予即时决策不是纸上谈兵而是面对市场的真实波动。"

View File

@@ -6,9 +6,9 @@ import { getModelIcon, getShortModelName } from '../utils/modelIcons';
* Get rank medal/trophy
*/
function getRankMedal(rank) {
if (rank === 1) return { emoji: '🏆', color: '#FFD700', label: 'Gold' };
if (rank === 2) return { emoji: '🥈', color: '#C0C0C0', label: 'Silver' };
if (rank === 3) return { emoji: '🥉', color: '#CD7F32', label: 'Bronze' };
if (rank === 1) return { emoji: '🏆', color: '#FFD700', label: '金牌' };
if (rank === 2) return { emoji: '🥈', color: '#C0C0C0', label: '银牌' };
if (rank === 3) return { emoji: '🥉', color: '#CD7F32', label: '铜牌' };
return { emoji: `#${rank}`, color: '#333333', label: `#${rank}` };
}
@@ -41,7 +41,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
const rankMedal = agent.rank ? getRankMedal(agent.rank) : null;
const isPortfolioManager = agent.id === 'portfolio_manager';
const isRiskManager = agent.id === 'risk_manager';
const displayName = isPortfolioManager ? 'Team' : agent.name;
const displayName = isPortfolioManager ? '团队' : agent.name;
// Get model icon configuration
const modelInfo = getModelIcon(agent.modelName, agent.modelProvider);
@@ -149,7 +149,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
whiteSpace: 'normal',
wordWrap: 'break-word'
}}>
Risk Manager focuses on risk management and does not participate in prediction accuracy ranking.
风控经理专注于风险管理,不参与预测准确率排名。
</div>
</div>
)}
@@ -173,7 +173,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
whiteSpace: 'normal',
wordWrap: 'break-word'
}}>
Portfolio Manager provides the team's final signal(position), synthesizing all analyst recommendations, and does not participate in ranking.
投资经理综合所有分析师建议,提供团队最终交易信号(仓位),不参与排名。
</div>
</div>
)}
@@ -197,7 +197,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
marginBottom: 4,
textTransform: 'uppercase'
}}>
Model
模型
</div>
<div style={{
height: 40,
@@ -262,7 +262,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
marginBottom: 4,
textTransform: 'uppercase'
}}>
Win Rate
胜率
</div>
<div style={{
fontSize: 36,
@@ -278,7 +278,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
fontSize: 9,
color: '#555555'
}}>
{bullWins + bearWins}Win / {evaluatedTotal}Eval
{bullWins + bearWins} / {evaluatedTotal}
</div>
<div style={{
fontSize: 8,
@@ -288,7 +288,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
lineHeight: 1.2,
whiteSpace: 'pre-line'
}}>
Eval: total evaluated bull & bear signals.{'\n'}Win Rate = correct signals / total evaluated signals
评估: 总评估多空信号数{'\n'}胜率 = 正确信号 / 总评估信号
</div>
</div>
)}
@@ -309,7 +309,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
marginBottom: 4,
textTransform: 'uppercase'
}}>
Bull Win Rate
牛市胜率
</div>
<div style={{
fontSize: 28,
@@ -324,7 +324,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
fontSize: 9,
color: '#333333'
}}>
{bullWins}Win / {evaluatedBull}Eval
{bullWins} / {evaluatedBull}
{bullUnknown > 0 && ` / ${bullUnknown}P`}
</div>
</div>
@@ -346,7 +346,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
marginBottom: 4,
textTransform: 'uppercase'
}}>
Bear Win Rate
熊市胜率
</div>
<div style={{
fontSize: 28,
@@ -361,7 +361,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
fontSize: 9,
color: '#333333'
}}>
{bearWins}Win / {evaluatedBear}Eval
{bearWins} / {evaluatedBear}
{bearUnknown > 0 && ` / ${bearUnknown}P`}
</div>
</div>
@@ -433,7 +433,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
fontSize: 16,
color: isBull ? '#00C853' : isBear ? '#FF1744' : '#555555'
}}>
{isBull ? 'bull' : isBear ? 'bear' : 'neutral'}
{isBull ? '看涨' : isBear ? '看跌' : '中性'}
</div>
<div style={{
fontSize: 8,
@@ -471,14 +471,14 @@ export default function AgentCard({ agent, onClose, isClosing }) {
fontWeight: 700,
color: '#1976D2'
}}>
ⓘ Info
说明
</div>
<div style={{
fontSize: 8,
color: '#1976D2',
lineHeight: 1.2
}}>
Showing recent 5 trading days (1 week) signals only
仅显示最近5个交易日(1周)的信号
</div>
</div>
</div>

View File

@@ -161,9 +161,9 @@ const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
return (
<div className="agent-feed">
<div className="agent-feed-header">
<h3 className="agent-feed-title">ACTIVITY FEED</h3>
<h3 className="agent-feed-title">活动 feed</h3>
<div className="agent-filter-wrapper">
<label className="agent-filter-label">Filter:</label>
<label className="agent-filter-label">筛选:</label>
<div className="custom-select-wrapper">
<button
className="custom-select-trigger"
@@ -191,7 +191,7 @@ const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
setDropdownOpen(false);
}}
>
<span>All Agents</span>
<span>全部 Agents</span>
</div>
{uniqueAgents.map(agent => {
const agentInfo = getAgentInfoByName(agent);
@@ -225,8 +225,8 @@ const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
{filteredFeed.length === 0 && (
<div className="empty-state">
{selectedAgent === 'all'
? 'Waiting for system updates...'
: `No messages from ${selectedAgent}`}
? '等待系统更新...'
: `${selectedAgent} 没有消息`}
</div>
)}
@@ -294,9 +294,9 @@ function ConferenceItem({ conference, itemId, isHighlighted, getAgentModelInfo }
>
<div className="feed-item-header">
<span className="feed-item-title" style={{ color: colors.text }}>
CONFERENCE
会议
</span>
{conference.isLive && <span className="feed-live-badge"> LIVE</span>}
{conference.isLive && <span className="feed-live-badge"> 实时</span>}
<span className="feed-item-time">{formatTime(conference.startTime)}</span>
</div>
@@ -366,7 +366,7 @@ function ConferenceMessage({ message, getAgentModelInfo }) {
className="conf-expand-btn"
onClick={() => setExpanded(!expanded)}
>
{expanded ? '« Less' : 'More »'}
{expanded ? '« 收起' : '更多 »'}
</button>
)}
</div>
@@ -397,8 +397,8 @@ function MemoryItem({ memory, itemId, isHighlighted }) {
}
const agentLabel = memory.agent && memory.agent !== 'Memory'
? `MEMORY · ${memory.agent}`
: 'MEMORY';
? `记忆 · ${memory.agent}`
: '记忆';
return (
<div
@@ -511,7 +511,7 @@ function MemoryItem({ memory, itemId, isHighlighted }) {
className="feed-expand-btn"
onClick={() => setExpanded(!expanded)}
>
{expanded ? '« Less' : 'More »'}
{expanded ? '« 收起' : '更多 »'}
</button>
)}
</div>
@@ -525,7 +525,7 @@ function MessageItem({ message, itemId, isHighlighted, getAgentModelInfo }) {
const colors = message.agent === 'Memory' ? MESSAGE_COLORS.memory :
getAgentColors(message.agentId, message.agent);
const title = message.agent === 'Memory' ? 'MEMORY' : message.agent || 'AGENT';
const title = message.agent === 'Memory' ? '记忆' : message.agent || 'AGENT';
const agentModelData = message.agentId && getAgentModelInfo ?
getAgentModelInfo(message.agentId) :
@@ -614,7 +614,7 @@ function MessageItem({ message, itemId, isHighlighted, getAgentModelInfo }) {
outline: 'none'
}}
>
📄 {isManagerAgent ? 'View decision log »' : 'View full report »'}
📄 {isManagerAgent ? '查看决策日志 »' : '查看完整报告 »'}
</button>
)}
@@ -623,7 +623,7 @@ function MessageItem({ message, itemId, isHighlighted, getAgentModelInfo }) {
className="feed-expand-btn"
onClick={() => setExpanded(!expanded)}
>
{expanded ? '« Less' : 'More »'}
{expanded ? '« 收起' : '更多 »'}
</button>
)}
</div>


@@ -59,7 +59,7 @@ export default function Header({
letterSpacing: '0.5px',
marginRight: '0px'
}}>
OPEN SOURCE
开源
</span>
<a
@@ -131,7 +131,7 @@ export default function Header({
}}
>
<span className="header-link">
Contact Us
联系我们
</span>
{/* Two contact buttons */}


@@ -551,7 +551,7 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
{/* Hint Text */}
<div className="agent-hint-text">
Click avatar to view details
点击头像查看详情
</div>
</div>
@@ -620,14 +620,14 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
<button
className="bubble-jump-btn"
onClick={handleJumpToFeed}
title="Jump to message in feed"
title="跳转到消息"
>
</button>
<button
className="bubble-close-btn"
onClick={(e) => handleCloseBubble(agent.id, bubbleKey, e)}
title="Close bubble"
title="关闭"
>
×
</button>
@@ -728,7 +728,7 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
title="Replay feed history"
>
<span className="replay-icon">&#9654;&#9654;</span>
<span>REPLAY</span>
<span>回放</span>
</button>
</div>
)}
@@ -749,7 +749,7 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
}}
/>
<div className="replay-indicator">
<span className="replay-status">{isPaused ? 'PAUSED' : 'REPLAY MODE'}</span>
<span className="replay-status">{isPaused ? '已暂停' : '回放模式'}</span>
<button
className="replay-button"
onClick={isPaused ? resumeReplay : pauseReplay}


@@ -181,17 +181,17 @@ export default function RulesView() {
},
zh: {
section1Title: "Agent 设定",
pmRole: "Portfolio Manager",
pmRole: "投资经理",
pmDesc: "负责最终交易决策和团队协作",
rmRole: "Risk Manager",
rmRole: "风控经理",
rmDesc: "监控组合风险并执行风险限制",
analystsRole: "Analysts",
analystsRole: "分析师",
analystsDesc: "使用不同工具和 AI 模型进行专业研究:",
analysts: [
{ name: "Valuation Analyst", model: "Moonshot", modelKey: "Moonshot" },
{ name: "Sentiment Analyst", model: "Qwen", modelKey: "Alibaba" },
{ name: "Fundamentals Analyst", model: "DeepSeek", modelKey: "DeepSeek" },
{ name: "Technical Analyst", model: "Zhipu AI", modelKey: "Zhipu AI" }
{ name: "估值分析师", model: "Moonshot", modelKey: "Moonshot" },
{ name: "情绪分析师", model: "Qwen", modelKey: "Alibaba" },
{ name: "基本面分析师", model: "DeepSeek", modelKey: "DeepSeek" },
{ name: "技术分析师", model: "Zhipu AI", modelKey: "Zhipu AI" }
],
section2Title: "Agent 决策机制",


@@ -151,7 +151,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
color: '#000000',
textTransform: 'uppercase'
}}>
Performance
业绩表现
</h2>
</div>
@@ -170,7 +170,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
marginBottom: 12,
textTransform: 'uppercase'
}}>
Total Asset Value
总资产价值
</div>
<div style={{
fontSize: 36,
@@ -202,7 +202,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
marginBottom: 8,
textTransform: 'uppercase'
}}>
Excess Return
超额收益
</div>
<div style={{
fontSize: 28,
@@ -218,7 +218,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
marginTop: 4,
fontFamily: '"Courier New", monospace'
}}>
vs. VW: {excessReturnData.benchmarkReturn >= 0 ? '+' : ''}{excessReturnData.benchmarkReturn.toFixed(2)}%
vs 市值加权: {excessReturnData.benchmarkReturn >= 0 ? '+' : ''}{excessReturnData.benchmarkReturn.toFixed(2)}%
</div>
</div>
) : null}
@@ -233,7 +233,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
marginBottom: 8,
textTransform: 'uppercase'
}}>
Win Rate
胜率
</div>
<div style={{
fontSize: 28,
@@ -252,7 +252,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
marginTop: 4,
fontFamily: '"Courier New", monospace'
}}>
{pmWinRateData.totalWins}Win / {pmWinRateData.evaluatedTotal}Eval
{pmWinRateData.totalWins} / {pmWinRateData.evaluatedTotal}
</div>
)}
</div>
@@ -267,7 +267,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
marginBottom: 8,
textTransform: 'uppercase'
}}>
Absolute Return
绝对收益
</div>
<div style={{
fontSize: 28,
@@ -296,7 +296,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
letterSpacing: 0.5,
textTransform: 'uppercase'
}}>
Cash Position
现金头寸
</div>
<div style={{
fontSize: 16,
@@ -322,7 +322,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
letterSpacing: 0.5,
textTransform: 'uppercase'
}}>
Total Trades
总交易数
</div>
<div style={{
fontSize: 16,
@@ -350,7 +350,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
textTransform: 'uppercase',
color: '#666666'
}}>
Portfolio Weights
组合权重
</div>
<div className="statistics-table-container" style={{
display: 'grid',
@@ -397,7 +397,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
fontSize: 12,
letterSpacing: 0.5
}}>
No statistics available
暂无统计数据
</div>
)}
</div>
@@ -433,7 +433,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
color: '#000000',
textTransform: 'uppercase'
}}>
Portfolio Holdings
持仓明细
</h2>
</div>
@@ -448,7 +448,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
fontSize: 11,
letterSpacing: 0.5
}}>
No positions currently held
当前无持仓
</div>
) : (
<>
@@ -456,11 +456,11 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
<table className="data-table">
<thead>
<tr>
<th>Ticker</th>
<th>Quantity</th>
<th>Price</th>
<th>Value</th>
<th>Weight</th>
<th>代码</th>
<th>数量</th>
<th>价格</th>
<th>市值</th>
<th>权重</th>
</tr>
</thead>
<tbody>
@@ -505,7 +505,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
onClick={() => setHoldingsPage(p => Math.max(1, p - 1))}
disabled={holdingsPage === 1}
>
Prev
上一页
</button>
<div className="pagination-info">
@@ -517,7 +517,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
onClick={() => setHoldingsPage(p => Math.min(totalHoldingsPages, p + 1))}
disabled={holdingsPage === totalHoldingsPages}
>
Next
下一页
</button>
</div>
)}
@@ -552,7 +552,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
color: '#000000',
textTransform: 'uppercase'
}}>
Transaction History
交易历史
</h2>
{trades.length > 0 && (
<div style={{
@@ -560,7 +560,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
color: '#666666',
fontFamily: '"Courier New", monospace'
}}>
{trades.length} total
{trades.length}
</div>
)}
</div>
@@ -576,7 +576,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
fontSize: 11,
letterSpacing: 0.5
}}>
No trades recorded
暂无交易记录
</div>
) : (
<>
@@ -584,11 +584,11 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
<table className="data-table">
<thead>
<tr>
<th>Time</th>
<th>Stock</th>
<th>Side</th>
<th>Qty</th>
<th>Price</th>
<th>时间</th>
<th>股票</th>
<th>方向</th>
<th>数量</th>
<th>价格</th>
</tr>
</thead>
<tbody>
@@ -638,7 +638,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
onClick={() => setTradesPage(p => Math.max(1, p - 1))}
disabled={tradesPage === 1}
>
Prev
上一页
</button>
<div className="pagination-info">
@@ -650,7 +650,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
onClick={() => setTradesPage(p => Math.min(totalTradesPages, p + 1))}
disabled={tradesPage === totalTradesPages}
>
Next
下一页
</button>
</div>
)}


@@ -66,43 +66,43 @@ export const AGENT_SEATS = [
export const AGENTS = [
{
id: "portfolio_manager",
name: "Portfolio Manager",
role: "Portfolio Manager",
name: "投资经理",
role: "投资经理",
avatar: CDN_ASSETS.companyRoom.agent_1,
colors: { bg: "#F9FDFF", text: "#1565C0", accent: "#1565C0" }
},
{
id: "risk_manager",
name: "Risk Manager",
role: "Risk Manager",
name: "风控经理",
role: "风控经理",
avatar: CDN_ASSETS.companyRoom.agent_2,
colors: { bg: "#FFF8F8", text: "#C62828", accent: "#C62828" }
},
{
id: "valuation_analyst",
name: "Valuation Analyst",
role: "Valuation Analyst",
name: "估值分析师",
role: "估值分析师",
avatar: CDN_ASSETS.companyRoom.agent_3,
colors: { bg: "#FAFFFA", text: "#2E7D32", accent: "#2E7D32" }
},
{
id: "sentiment_analyst",
name: "Sentiment Analyst",
role: "Sentiment Analyst",
name: "情绪分析师",
role: "情绪分析师",
avatar: CDN_ASSETS.companyRoom.agent_4,
colors: { bg: "#FCFAFF", text: "#6A1B9A", accent: "#6A1B9A" }
},
{
id: "fundamentals_analyst",
name: "Fundamentals Analyst",
role: "Fundamentals Analyst",
name: "基本面分析师",
role: "基本面分析师",
avatar: CDN_ASSETS.companyRoom.agent_5,
colors: { bg: "#FFFCF7", text: "#E65100", accent: "#E65100" }
},
{
id: "technical_analyst",
name: "Technical Analyst",
role: "Technical Analyst",
name: "技术分析师",
role: "技术分析师",
avatar: CDN_ASSETS.companyRoom.agent_6,
colors: { bg: "#F9FEFF", text: "#00838F", accent: "#00838F" }
},


@@ -19,7 +19,9 @@ from structured_model import (
get_seer_model,
get_hunter_model,
)
from prompt import EnglishPrompts as Prompts
from prompt import (
EnglishPrompts as Prompts,
) # pylint: disable=no-name-in-module
# Uncomment the following line to use Chinese prompts
# from prompt import ChinesePrompts as Prompts


@@ -6,7 +6,9 @@ from typing import Any
import numpy as np
from agentscope.agent import AgentBase, ReActAgent
from agentscope.message import Msg
from prompt import EnglishPrompts as Prompts
from prompt import ( # pylint: disable=no-name-in-module
EnglishPrompts as Prompts,
)
MAX_GAME_ROUND = 30
MAX_DISCUSSION_ROUND = 3


@@ -8,7 +8,7 @@ This directory contains several examples of how to use the AgentScope Tuner for
| Frozen Lake | Make an agent to navigate the Frozen Lake environment in multi-step interactions. | [frozen_lake](./frozen_lake) | ✅ | ❌ | ❌ | ❌ | ❌ |
| Learn to Ask | Using LLM as a judge to provide feedback to facilitate agent tuning. | [learn_to_ask](./learn_to_ask) | ✅ | ✅ | ❌ | ❌ | ❌ |
| Email Search | Enhance the tool use ability of your agent on tasks without ground truth. | [email_search](./email_search) | ✅ | ✅ | ✅ | ❌ | ❌ |
| Werewolf Game | Enhance the agent's performance in a multi-agent game setting. | [werewolf_game](./werewolf_game)| ✅ | ✅ | ✅ | ✅ | ❌ |
| Werewolf Game | Enhance the agent's performance in a multi-agent game setting. | [werewolves](./werewolves)| ✅ | ✅ | ✅ | ✅ | ❌ |
| Data Augment | Data augmentation for better tuning results. | [data_augment](./data_augment) | ❌ | ❌ | ❌ | ❌ | ✅ |
Each example contains a README file with detailed instructions on how to set up and run the tuning process for that specific scenario. Feel free to explore and modify the examples to suit your needs!
@@ -22,6 +22,6 @@ AgentScope Tuner requires:
- `agentscope>=1.0.12`
- `trinity-rft>=0.4.1`
AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/modelscope/Trinity-RFT).
Please refer to the [Trinity-RFT installation guide](https://modelscope.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
AgentScope Tuner is built on top of [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
Please refer to the [Trinity-RFT installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html)
for detailed instructions on how to set up the environment.


@@ -1,19 +1,19 @@
# AgentScope Tuner 中文说明
# AgentScope Tuner
本目录包含了多个使用 AgentScope Tuner 对 AgentScope 应用进行调优的示例。下表总结了可用的示例:
| 示例名称 | 描述 | 示例路径 | 多步交互 | LLM 评审 | 工具使用 | 多智能体 | 数据增强 |
|------------------|-------------------------------------------|---------------------------------|----------|----------|----------|----------|----------|
| 数学智能体 | 快速入门示例,调优数学智能体以提升其能力。 | [math_agent](./math_agent) | ✅ | ❌ | ❌ | ❌ | ❌ |
| Frozen Lake | 让智能体在多步交互中导航冰湖环境。 | [frozen_lake](./frozen_lake) | ✅ | ❌ | ❌ | ❌ | ❌ |
| Frozen Lake | 让智能体在与 frozen lake 环境的多步交互中学习。 | [frozen_lake](./frozen_lake) | ✅ | ❌ | ❌ | ❌ | ❌ |
| Learn to Ask | 使用 LLM 作为评审,为智能体调优提供反馈 | [learn_to_ask](./learn_to_ask) | ✅ | ✅ | ❌ | ❌ | ❌ |
| 邮件搜索 | 在无标准答案任务中提升智能体的工具使用能力。 | [email_search](./email_search) | ✅ | ✅ | ✅ | ❌ | ❌ |
| 狼人杀游戏 | 提升智能体在多智能体游戏场景下的表现。 | [werewolf_game](./werewolf_game)| ✅ | ✅ | ✅ | ✅ | ❌ |
| 狼人杀游戏 | 提升智能体在多智能体游戏场景下的表现。 | [werewolves](./werewolves)| ✅ | ✅ | ✅ | ✅ | ❌ |
| 数据增强 | 通过数据增强获得更好的调优效果。 | [data_augment](./data_augment) | ❌ | ❌ | ❌ | ❌ | ✅ |
每个示例目录下均包含详细的 README 文件,介绍了该场景下的调优流程和使用方法。欢迎根据实际需求进行探索和修改!
## 先决条件
## 前置要求
AgentScope Tuner 需要:
@@ -21,6 +21,6 @@ AgentScope Tuner 需要:
- `agentscope>=1.0.12`
- `trinity-rft>=0.4.1`
AgentScope Tuner 构建于 [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) 之上
请参考 [Trinity-RFT 安装指南](https://modelscope.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
AgentScope Tuner 于 [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) 构建
请参考 [Trinity-RFT 安装指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html)
获取详细的安装方法。


@@ -0,0 +1,155 @@
# Training a Math Agent with Data-Augmentation Strategies
This example demonstrates how to use **AgentScope-Tuner** to enhance a math problem-solving agent. We will focus on leveraging **Data-Centric** features, such as the `difficulty_based` task selector, to improve data utility and training efficiency.
## Task Setting
We use the foundational [math-agent example](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py) as our baseline. The agent is a **`ReActAgent`** that solves mathematical reasoning problems through step-by-step reasoning.
Training can be inefficient if tasks are too easy or too hard. This example demonstrates how to use **task selectors** to dynamically select tasks based on **data feedback**, focusing on "productively challenging" samples to maximize training efficiency. These data-centric techniques are generic and adaptable to other agent workflows.
## Dataset Preparation
To enable difficulty-based sampling, the training data must include difficulty features (e.g., pass rates from LLMs).
1. **Base Dataset**: You can use any standard math problem dataset. A good example is the math data in [LLM360/guru-RL-92k](https://huggingface.co/datasets/LLM360/guru-RL-92k), which comes pre-annotated with pass rates from different LLMs, serving as direct difficulty features.
2. **Build Your Own Features**: If you use your own dataset, you can generate these features by pre-running several models of varying capabilities and recording their pass rates. This can be done within the [**Trinity-RFT**](https://github.com/agentscope-ai/Trinity-RFT/pull/440) framework; a sketch of this feature-generation step appears after this list.
3. **Data Format**: The final dataset should be in HuggingFace format. In this example, data will be converted to *GSM8K format* according to the [workflow](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py). Besides the task content, it must include the difficulty feature columns you've defined (e.g. `qwen2.5_7b_pass_rate`).
4. **Example Data Preparation**: We provide a script for this example. Simply execute `python prepare_data.py` to generate the required dataset.
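The sketch below illustrates the feature-generation step from item 2: pre-run several models of varying capability on each task and record their pass rates. Here `solve_once` is a hypothetical placeholder for your own inference and answer-checking logic:
```python
# Sketch: annotate each task with per-model pass rates.
import random
import pandas as pd

def solve_once(model_name: str, question: str, answer: str) -> bool:
    """Placeholder: run one rollout of `model_name` and check its answer.
    Replace with real inference (e.g., a vLLM or OpenAI-compatible client)."""
    return random.random() < 0.5  # dummy outcome, for illustration only

def annotate_pass_rates(df: pd.DataFrame, model_names: list[str], k: int = 8) -> pd.DataFrame:
    """Add a `<model>_pass_rate` column per model, averaged over k rollouts."""
    for model_name in model_names:
        df[f"{model_name}_pass_rate"] = [
            sum(solve_once(model_name, r.question, r.answer) for _ in range(k)) / k
            for r in df.itertuples()
        ]
    return df

# e.g. annotate_pass_rates(df, ["qwen2.5_7b", "qwen3_30b"], k=8)
```
Averaging over `k` independent rollouts per model gives a smoother difficulty signal than a single pass/fail bit.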
## Code Implementation
### Agent Workflow & Judge Function
This example follows the foundational [math-agent example](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py), adopting its `run_react_agent` and `gsm8k_judge` as the `workflow_func` and `judge_func`, respectively. This highlights a key benefit: you can apply training strategies without altering your core agent logic.
### Design of Data-Centric Features
Leveraging the powerful data processing capabilities of **Trinity-RFT**, **AgentScope-Tuner** provides interfaces for advanced operations like task selection and experience processing.
#### Task Selector
The `Task Selector` determines how samples are selected from a dataset. It can be configured directly in configuration YAML files.
- **Built-in Selectors**:
- `sequential`: Samples are selected in a fixed order.
- `shuffle`: The dataset is shuffled at the beginning of each epoch.
- `random`: Samples are randomly chosen with replacement for each batch.
- `offline_easy2hard`: Samples are sorted by a predefined feature for curriculum learning.
- `difficulty_based` (Customized): An adaptive sampler based on task difficulty.
> For more details on `Task Selector`, including how to implement a custom selector based on feedback signals, please refer to **Trinity-RFT**'s **[Selector Development Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/develop_selector.html)**.
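For flavor, here is a minimal, self-contained sketch of a feedback-driven selector; the class shape and method names are hypothetical, and the actual base-class interface is described in the Selector Development Guide linked above:
```python
# Toy feedback-driven selector (hypothetical interface, for illustration).
import random

class FeedbackSelector:
    """Prefers tasks whose observed pass rate is close to a target value."""

    def __init__(self, features: list[float], target: float = 0.5):
        self.scores = list(features)  # one difficulty estimate per task
        self.target = target

    def update(self, task_index: int, pass_rate: float) -> None:
        # Feedback hook: replace the prior estimate with the observed pass rate.
        self.scores[task_index] = pass_rate

    def select(self, batch_size: int) -> list[int]:
        # Weight tasks inversely to their distance from the target pass rate.
        weights = [1.0 / (abs(s - self.target) + 1e-3) for s in self.scores]
        return random.choices(range(len(self.scores)), weights=weights, k=batch_size)

# e.g. FeedbackSelector([0.1, 0.5, 0.9]).select(batch_size=2)
```
The built-in `difficulty_based` selector plays the same role, but replaces this naive inverse-distance weighting with a principled bandit formulation (see the BOTS references later in this README).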
#### Data Processor
The `Data Processor` allows for real-time processing of **Task** and **Experience** during training, enabling operations like calculating feedback metrics, data augmentation, or filtering.
For example, the `difficulty_based` selector requires a `pass_rate_calculator` operator to compute the agent's success rate for each task. This feedback is then used to adjust the sampling strategy.
> For more details on `Data Processor`, please refer to **Trinity-RFT**'s **[Operator Development Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/develop_operator.html)**.
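Conceptually, such an operator reduces the group of rollouts for each task to a scalar feedback signal. A schematic stand-in is shown below; the real `pass_rate_calculator` ships with Trinity-RFT:
```python
# Schematic of the quantity a pass-rate operator computes.
def pass_rate(rewards: list[float], threshold: float = 0.5) -> float:
    """Fraction of rollouts for one task whose reward counts as a pass."""
    if not rewards:
        return 0.0
    return sum(r >= threshold for r in rewards) / len(rewards)

# e.g. with repeat_times=8 rollouts for one task:
assert pass_rate([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0]) == 0.625
```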
### Configuring the Experiments
To maintain clarity and simplicity, we recommend defining all data-specific parameters, including dataset paths and task selectors, within YAML configuration files.
We provide two configuration files to compare the baseline `random` selector against the `difficulty_based` selector.
**Experiment 1: Baseline with Random Selector (`config_random.yaml`)**
In `config_random.yaml`, we configure the `task_selector` for random sampling under `buffer.explorer_input.taskset`.
```yaml
# In config_random.yaml
buffer:
# ...
explorer_input:
taskset: # Training data
path: "path/to/your/augmented/math_data"
split: "train"
task_selector:
selector_type: random # Strategy of task selection
```
**Experiment 2: Advanced Training with Difficulty-Based Selector (`config_difficulty.yaml`)**
In `config_difficulty.yaml`, we switch the `task_selector` to `difficulty_based` and provide its specific parameters. Note that this config also enables the `pass_rate_calculator` needed for feedback.
```yaml
# In config_difficulty.yaml
# Enable the calculator to provide feedback for the selector
data_processor:
experience_pipeline:
operators:
- name: pass_rate_calculator
buffer:
# ...
explorer_input:
taskset: # Training data
path: "path/to/your/augmented/math_data"
split: "train"
task_selector:
selector_type: difficulty_based # Strategy of task selection
feature_keys: [ "qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate" ]
kwargs: # Hyper-parameters for the selection algorithm
m: 8
# ...
```
> The `difficulty_based` selector in this example is an implementation of the ***BOTS*** algorithm. For details on its inner workings, please refer to the [***BOTS paper***](https://arxiv.org/abs/2510.26374) and its [***tutorials***](https://github.com/agentscope-ai/Trinity-RFT/blob/main/examples/bots/README.md).
## How to Run
### Step 1: Prerequisites
Ensure you have installed **AgentScope** and **Trinity-RFT** following [the guide](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/README.md#how-to-run).
### Step 2: Prepare the Dataset
Run the data preparation script. Make sure to update the dataset paths in `config_random.yaml` and `config_difficulty.yaml` afterward.
```bash
python prepare_data.py
```
### Step 3: Start Ray Cluster
For distributed training, start a Ray cluster.
```bash
# For single node
ray start --head
```
### Step 4: Run Training
You can now run either the baseline or the difficulty-based training experiment.
- **To run the baseline experiment with a random selector:**
```bash
python main.py --config config_random.yaml
```
- **To run the experiment with the difficulty-based selector:**
```bash
python main.py --config config_difficulty.yaml
```
## Experimental Results
The following results compare the performance of the `difficulty-based` selection strategy (red line, bots) against a standard `random` selection strategy (black line, random).
<div align="center">
<img src="./training_result.jpg" alt="Training Result Image" width="90%"/>
</div>
### Training Reward Curve
The chart on the left shows the rollout accuracy during training. As can be seen, the tasks sampled by the random strategy appear to be difficult for the model, with the accuracy remaining below 0.2. In contrast, using the difficulty selector results in a higher mean accuracy, indicating that the agent is engaging with more tasks that it can successfully solve.
### Evaluation on AIME-24
For comparison, we evaluated both selection strategies on the AIME-24 benchmark. The chart on the right shows that the difficulty-based method demonstrates a better upward trend in performance over time.


@@ -0,0 +1,153 @@
# 使用数据增强策略训练数学智能体
本示例演示了如何使用 **AgentScope-Tuner** 训练数学问题求解智能体。我们将重点利用**以数据为中心**的功能,例如 `difficulty_based` 任务选择器,以提高数据利用率和训练效率。
## 任务设置
我们使用基础的 [math-agent 示例](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py) 作为基线。智能体是 **`ReActAgent`**,通过逐步推理解决数学推理问题。
如果任务太容易或太难,训练可能会效率低下。本示例演示如何使用**任务选择器**基于**数据反馈**动态选择任务,专注于"具有挑战性"的样本以最大化训练效率。这些以数据为中心的技术是通用的,可适应其他智能体工作流。
## 数据集准备
为启用基于难度的采样,训练数据必须包含难度特征(如 LLM 的通过率)。
1. **基础数据集**:您可以使用任何标准的数学问题数据集。一个很好的例子是 [LLM360/guru-RL-92k](https://huggingface.co/datasets/LLM360/guru-RL-92k) 中的数学数据,它预先标注了来自不同 LLM 的通过率,作为直接的难度特征。
2. **构建您自己的特征**:如果您使用自己的数据集,可以通过预先运行几个不同能力的模型并记录它们的通过率来生成这些特征。这可以在 [**Trinity-RFT**](https://github.com/agentscope-ai/Trinity-RFT/pull/440) 框架内完成。
3. **数据格式**:最终数据集应为 HuggingFace 格式。在此示例中,数据将根据[工作流](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py)转换为 *GSM8K 格式*。除了任务内容外,它还必须包含您定义的难度特征列(例如 `qwen2.5_7b_pass_rate`)。
4. **示例数据准备**:我们为此示例提供了一个脚本。只需执行 `python prepare_data.py` 即可生成所需的数据集。
## 代码实现
本示例采用 [math-agent 示例](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/main.py) 的 `run_react_agent``gsm8k_judge` 作为 `workflow_func``judge_func`,说明可以在不改变核心智能体逻辑的情况下应用训练策略。
### 以数据为中心功能的设计
利用 **Trinity-RFT** 强大的数据处理能力,**AgentScope-Tuner** 为任务选择和经验处理等高级操作提供了接口。
#### 任务选择器
`Task Selector` 决定如何从数据集中选择样本。它可以直接在 YAML 配置文件中配置。
- **内置选择器**
- `sequential`:按固定顺序选择样本。
- `shuffle`:在每个 epoch 开始时打乱数据集。
- `random`:为每个批次随机选择样本(有放回)。
- `offline_easy2hard`:按预定义特征对样本进行排序,用于课程学习。
- `difficulty_based`(自定义):基于任务难度的自适应采样器。
> 有关 `Task Selector` 的更多详细信息,包括如何基于反馈信号实现自定义选择器,请参阅 **Trinity-RFT** 的 **[Selector 开发指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/develop_selector.html)**。
#### 数据处理器
`Data Processor` 允许在训练期间实时处理**任务**task和**经验**experience支持计算反馈指标、数据增强或过滤等操作。
例如,`difficulty_based` 选择器需要一个 `pass_rate_calculator` 操作符来计算智能体对每个任务的成功率。然后使用此反馈来调整采样策略。
> 有关 `Data Processor` 的更多详细信息,请参阅 **Trinity-RFT** 的 **[Operator 开发指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/develop_operator.html)**。
### 配置实验
为了保持清晰和简洁,我们建议在 YAML 配置文件中定义所有数据特定参数,包括数据集路径和任务选择器。
我们提供两个配置文件,用于比较基线 `random` 选择器与 `difficulty_based` 选择器。
**实验 1使用随机选择器的基线`config_random.yaml`**
`config_random.yaml` 中,我们在 `buffer.explorer_input.taskset` 下配置用于随机采样的 `task_selector`
```yaml
# 在 config_random.yaml 中
buffer:
# ...
explorer_input:
taskset: # 训练数据
path: "path/to/your/augmented/math_data"
split: "train"
task_selector:
selector_type: random # 任务选择策略
```
**实验 2使用基于难度选择器的高级训练`config_difficulty.yaml`**
`config_difficulty.yaml` 中,我们将 `task_selector` 切换为 `difficulty_based` 并提供其特定参数。请注意,此配置还启用了反馈所需的 `pass_rate_calculator`
```yaml
# 在 config_difficulty.yaml 中
# 启用计算器为选择器提供反馈
data_processor:
experience_pipeline:
operators:
- name: pass_rate_calculator
buffer:
# ...
explorer_input:
taskset: # 训练数据
path: "path/to/your/augmented/math_data"
split: "train"
task_selector:
selector_type: difficulty_based # 任务选择策略
feature_keys: [ "qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate" ]
kwargs: # 选择算法的超参数
m: 8
# ...
```
> 本示例中的 `difficulty_based` 选择器是 ***BOTS*** 算法的实现。有关其内部工作原理的详细信息,请参阅 [***BOTS 论文***](https://arxiv.org/abs/2510.26374) 及其 [***教程***](https://github.com/agentscope-ai/Trinity-RFT/blob/main/examples/bots/README.md)。
## 如何运行
### 步骤 1前置要求
确保您已按照[指南](https://github.com/agentscope-ai/agentscope-samples/blob/main/tuner/math_agent/README_zh.md#how-to-run)安装了 **AgentScope****Trinity-RFT**
### 步骤 2准备数据集
运行数据准备脚本。确保之后更新 `config_random.yaml``config_difficulty.yaml` 中的数据集路径。
```bash
python prepare_data.py
```
### 步骤 3启动 Ray 集群
对于分布式训练,启动 Ray 集群。
```bash
# 单节点
ray start --head
```
### 步骤 4运行训练
您现在可以运行基线或基于难度的训练实验。
- **使用随机选择器运行基线实验:**
```bash
python main.py --config config_random.yaml
```
- **使用基于难度的选择器运行实验:**
```bash
python main.py --config config_difficulty.yaml
```
## 实验结果
以下结果比较了 `difficulty-based` 选择策略红线bots与标准 `random` 选择策略黑线random的性能。
<div align="center">
<img src="./training_result.jpg" alt="训练结果图" width="90%"/>
</div>
### 训练奖励曲线
左侧图表显示了训练期间的 rollout 准确率。可以看出,随机策略采样的任务对模型来说似乎很困难,准确率保持在 0.2 以下。相比之下,使用难度选择器会产生更高的平均准确率,表明智能体正在处理更多可以成功解决的任务。
### 在 AIME-24 上的评估
为了比较,我们在 AIME-24 基准上评估了两种选择策略。右侧图表显示,基于难度的方法在性能上表现出更好的上升趋势。


@@ -0,0 +1,74 @@
project: "Data-Augmentation" # Project name
name: "Difficulty-Based-Selector" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints
data_processor:
experience_pipeline:
operators:
- name: pass_rate_calculator # Calculate average reward and pass it back to selector
buffer:
total_epochs: 1 # Total training epochs
explorer_input:
taskset:
path: "path/to/your/augmented/math_data" # Training data path
split: "train" # Training data split
task_selector:
selector_type: difficulty_based # Strategy of task selection
feature_keys: [ "qwen2.5_7b_pass_rate", "qwen3_30b_pass_rate" ] # Utilized pass_rate key
kwargs: # Hyperparameters from [BOTS](https://github.com/modelscope/Trinity-RFT/blob/main/examples/bots/README.md)
m: 8
lamb: 0.1
rho: 0.1
target_reward: 0.8
tau: 0
do_sample: true
eval_tasksets:
- name: "eval-aime24" # Evaluation data name
path: "path/to/aime24_data" # Evaluation data path
split: "test" # Evaluation data split
synchronizer:
sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
sync_method: 'nccl'
sync_interval: 4 # Sync every N steps
sync_timeout: 7200 # Timeout for synchronization (seconds)
monitor:
monitor_type: tensorboard # Can also use wandb, mlflow or swanlab
# The config below has already been set in the Python file (main.py)
algorithm:
algorithm_type: multi_step_grpo # GRPO series for multi-step scenario
repeat_times: 8 # Number of rollouts per prompt for advantage estimation
optimizer:
lr: 1e-6 # Learning rate
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-0.6B} # Base model path
max_model_len: 24576 # Max context length
max_response_tokens: 16384 # Max tokens per response
temperature: 1.0 # Temperature of model's generation
cluster:
node_num: 1 # Number of used nodes
gpu_per_node: 8 # Number of GPUs every node
explorer:
eval_interval: 20 # Evaluation every N steps
runner_per_model: 16 # Runners per infer engine
max_timeout: 1200 # Max timeout for each rollout (seconds)
rollout_model:
engine_num: 4 # Number of vLLM engines for rollout model
tensor_parallel_size: 1 # TP size per engine for rollout model
enable_openai_api: true # Enable OpenAI-compatible API
enable_history: true # Enable conversation history
enable_auto_tool_choice: true # Enable automatic tool selection
tool_call_parser: hermes # Parser for tool calls
reasoning_parser: deepseek_r1 # Parser for reasoning type
trainer:
save_interval: 100 # Save checkpoint every N steps
use_dynamic_bsz: true # Use dynamic batch size
ulysses_sequence_parallel_size: 1 # Sequence parallel size for Ulysses


@@ -0,0 +1,62 @@
project: "Data-Augmentation" # Project name
name: "Random-Selector" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints
# Config of data-centric experiments
buffer:
total_epochs: 1 # Total training epochs
explorer_input:
taskset:
path: "path/to/your/augmented/math_data" # Training data path
split: "train" # Training data split
task_selector:
selector_type: random # Strategy of task selection
eval_tasksets:
- name: "eval-aime24" # Evaluation data name
path: "path/to/aime24_data" # Evaluation data path
split: "test" # Evaluation data split
synchronizer:
sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
sync_method: 'nccl'
sync_interval: 4 # Sync every N steps
sync_timeout: 7200 # Timeout for synchronization (seconds)
monitor:
monitor_type: tensorboard # Can also use wandb, mlflow or swanlab
# The config below has already been set in the Python file (main.py)
algorithm:
algorithm_type: multi_step_grpo # GRPO series for multi-step scenario
repeat_times: 8 # Number of rollouts per prompt for advantage estimation
optimizer:
lr: 1e-6 # Learning rate
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-0.6B} # Base model path
max_model_len: 24576 # Max context length
max_response_tokens: 16384 # Max tokens per response
temperature: 1.0 # Temperature of model's generation
cluster:
node_num: 1 # Number of used nodes
gpu_per_node: 8 # Number of GPUs every node
explorer:
eval_interval: 20 # Evaluation every N steps
runner_per_model: 16 # Runners per infer engine
max_timeout: 1200 # Max timeout for each rollout (seconds)
rollout_model:
engine_num: 4 # Number of vLLM engines for rollout model
tensor_parallel_size: 1 # TP size per engine for rollout model
enable_openai_api: true # Enable OpenAI-compatible API
enable_history: true # Enable conversation history
enable_auto_tool_choice: true # Enable automatic tool selection
tool_call_parser: hermes # Parser for tool calls
reasoning_parser: deepseek_r1 # Parser for reasoning type
trainer:
save_interval: 100 # Save checkpoint every N steps
use_dynamic_bsz: true # Use dynamic batch size
ulysses_sequence_parallel_size: 1 # Sequence parallel size for Ulysses

tuner/data_augment/main.py (new file, 141 lines)

@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
"""Example of training a ReAct math-agent with configurable task selector."""
from typing import Dict
import argparse
from agentscope.tuner import (
tune,
WorkflowOutput,
JudgeOutput,
TunerModelConfig,
AlgorithmConfig,
)
from agentscope.agent import ReActAgent
from agentscope.model import OpenAIChatModel
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
async def run_react_agent(
task: Dict,
model: OpenAIChatModel,
auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> WorkflowOutput:
"""A simple workflow function using the ReAct agent to solve tasks.
Args:
task (`Dict`): The task to be solved.
model (`OpenAIChatModel`): The language model to use.
auxiliary_models (`Dict[str, OpenAIChatModel]`):
A dictionary of additional chat models available for
LLM-as-a-Judge. Not used in this workflow.
Returns:
`WorkflowOutput`: The workflow output containing the agent's response.
"""
assert (
auxiliary_models is None or len(auxiliary_models) == 0
), "No auxiliary models are used in this workflow."
sys_prompt = (
"You are an agent specialized in solving math problems with tools. "
"Please solve the math problem given to you. You can write and "
"execute Python code to perform calculation or verify your answer. "
"You should return your final answer within \\boxed{{}}."
)
agent = ReActAgent(
name="react_agent",
sys_prompt=sys_prompt,
model=model,
enable_meta_tool=True,
formatter=OpenAIChatFormatter(),
)
response = await agent.reply(
msg=Msg("user", task["question"], role="user"),
)
return WorkflowOutput(
response=response,
)
async def gsm8k_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, OpenAIChatModel] | None = None,
) -> JudgeOutput:
"""A simple judge function to calculate reward based on agent's response.
Args:
task (`Dict`): The task information for the corresponding workflow.
response (`Msg`): The response generated by the corresponding workflow.
auxiliary_models (`Dict[str, OpenAIChatModel]`):
A dictionary of additional chat models available for LLM-as-a-Judge
usage. The keys are model names, and the values are the
corresponding OpenAIChatModel instances.
Returns:
`JudgeOutput`: The reward value assigned by the judge function.
"""
from trinity.common.rewards.math_reward import MathBoxedRewardFn
assert (
auxiliary_models is None or len(auxiliary_models) == 0
), "No auxiliary models are used in this workflow."
reward_fn = MathBoxedRewardFn()
# parse truth from gsm8k raw text
truth = task["answer"]
if isinstance(truth, str) and "####" in truth:
truth = truth.split("####")[1].strip()
else:
truth = str(truth)
# parse answer from response message
result = response.get_text_content()
reward_dict = reward_fn(
response=result,
truth=truth,
)
return JudgeOutput(
reward=sum(reward_dict.values()),
metrics=reward_dict,
)
if __name__ == "__main__":
# We recommend using YAML for data-centric experiments.
parser = argparse.ArgumentParser(
description="Train math-agent with different task selectors",
)
parser.add_argument(
"--config",
type=str,
default="config_random.yaml",
help="Path to the configuration YAML file",
)
args = parser.parse_args()
tuner_model = TunerModelConfig(
model_path="Qwen/Qwen3-0.6B",
max_model_len=24576,
max_tokens=16384,
temperature=1.0,
inference_engine_num=4,
tensor_parallel_size=1,
)
algorithm = AlgorithmConfig(
algorithm_type="multi_step_grpo",
group_size=8,
learning_rate=1e-6,
eval_interval_steps=20,
batch_size=16,
)
tune(
workflow_func=run_react_agent,
judge_func=gsm8k_judge,
config_path=args.config,
model=tuner_model,
algorithm=algorithm,
)


@@ -0,0 +1,148 @@
# -*- coding: utf-8 -*-
"""
Prepare math data from LLM360/guru-RL-92k
Transform it to the GSM8K format
"""
import argparse
import sys
from pathlib import Path
import pandas as pd
from huggingface_hub import hf_hub_download
# Define constants for the dataset
DATASET_REPO = "LLM360/guru-RL-92k"
DATASET_FILE = "train/math__combined_54.4k.parquet"
# Download the dataset from Hugging Face Hub.
# The dataset is from LLM360/guru-RL-92k.
def download_dataset(
repo_id: str,
filename_in_repo: str,
local_dir: str,
) -> Path:
print(f"--- Downloading dataset: {repo_id} ---")
print(f"File: {filename_in_repo}")
local_path = Path(local_dir)
local_path.mkdir(parents=True, exist_ok=True)
try:
downloaded_file_path = hf_hub_download(
repo_id=repo_id,
filename=filename_in_repo,
repo_type="dataset",
local_dir=local_path,
)
print(f"Successfully downloaded to: {downloaded_file_path}")
return Path(downloaded_file_path)
except Exception as e:
print(f"Error downloading dataset: {e}", file=sys.stderr)
sys.exit(1)
# Transform a single row from the original format to the target format.
def transform_row(row: pd.Series) -> pd.Series:
try:
original_question = row["prompt"][0]["content"]
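# Strip the dataset's boxed-answer instruction here; the training
# workflow's system prompt adds its own \boxed{} requirement.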
sentence_to_remove = "Please output the final answer within \\boxed{}."
question = original_question.replace(sentence_to_remove, "").strip()
ground_truth = row["reward_model"]["ground_truth"]
answer = f"#### {ground_truth}"
rate_7b = row.get("qwen2.5_7b_pass_rate")
rate_30b = row.get("qwen3_30b_pass_rate")
return pd.Series(
{
"question": question,
"answer": answer,
"qwen2.5_7b_pass_rate": rate_7b,
"qwen3_30b_pass_rate": rate_30b,
},
)
except (TypeError, IndexError, KeyError) as e:
error_msg = (
f"Skipping row due to processing error: {e}. "
f"Row content: {row.to_dict()}"
)
print(error_msg, file=sys.stderr)
return pd.Series(
{
"question": None,
"answer": None,
"qwen2.5_7b_pass_rate": None,
"qwen3_30b_pass_rate": None,
},
)
# Read, transform, and save the dataset to a new location.
def transform_and_save_dataset(input_file: Path, output_dir: str):
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
output_file_path = output_path / input_file.name
print(f"--- Reading source file: {input_file} ---")
try:
df_original = pd.read_parquet(input_file)
print(f"Successfully read {len(df_original)} records.")
except Exception as e:
print(f"Fatal error reading file: {e}", file=sys.stderr)
sys.exit(1)
print("--- Starting data transformation ---")
df_transformed = df_original.apply(transform_row, axis=1)
original_count = len(df_transformed)
df_transformed.dropna(subset=["question", "answer"], inplace=True)
dropped_count = original_count - len(df_transformed)
if dropped_count > 0:
print(f"Warning: Dropped {dropped_count} invalid records.")
print(f"Transformation complete. {len(df_transformed)} generated.")
print(f"--- Saving processed file to: {output_file_path} ---")
try:
df_transformed.to_parquet(output_file_path, index=False)
print(f"Process complete! New file saved at: {output_file_path}")
except Exception as e:
print(f"Fatal error saving file: {e}", file=sys.stderr)
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
description="Download and transform the guru-RL-92k math dataset.",
)
parser.add_argument(
"--raw_data_dir",
type=str,
default="data/train/raw",
help="Directory to download the raw dataset file.",
)
parser.add_argument(
"--processed_data_dir",
type=str,
default="data/train/math",
help="Directory to save the transformed dataset file.",
)
args = parser.parse_args()
downloaded_file = download_dataset(
repo_id=DATASET_REPO,
filename_in_repo=DATASET_FILE,
local_dir=args.raw_data_dir,
)
transform_and_save_dataset(
input_file=downloaded_file,
output_dir=args.processed_data_dir,
)
if __name__ == "__main__":
main()

Binary image file added (315 KiB), not shown.

@@ -0,0 +1,283 @@
# Training Email Search Agent with RL using AgentScope-Tuner
This example demonstrates how to implement reinforcement fine-tuning for the Email Search task (inspired by [ART](https://openpipe.ai/blog/art-e-mail-agent)) using AgentScope-Tuner, whose RFT functionality is backed by [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT).
## Task Setting
The agent's goal is to answer user queries by searching through an email inbox. The agent needs to:
- Understand the user's question
- Search for relevant emails using keywords
- Read email contents to extract information
- Provide accurate answers with proper source citations
**Agent Type**: The agent (`EmailSearchAgent`) extends `ReActAgent`, which follows a reasoning-acting loop to solve tasks iteratively.
**Environment**: The environment is a SQLite database containing emails from the Enron Email dataset. Each task provides:
- `question`: The user's email search query
- `inbox_address`: The email inbox to search
- `query_date`: The date context for the query
- `answer`: The expected answer (ground truth), only for reward calculation
- `message_ids`: IDs of relevant emails containing the answer, only for reward calculation
**Available Tools**:
- `search_emails`: Find emails by keywords, inbox address, and date range. Returns a list of email summaries (message_id and snippet).
- `read_email`: Read the full content of a specific email by message_id.
- `generate_response`: Provide the final structured answer with sources (inherited from ReAct agent).
## Dataset Preparation
The dataset contains email queries based on the [Enron Email dataset](https://huggingface.co/datasets/corbt/enron-emails). Run the data preparation script to generate the email database and datasets:
```bash
python prepare_data.py
```
If you want to choose a new database path, you can modify the `DEFAULT_DB_PATH` in [`prepare_data.py`](./prepare_data.py). Also, remember to set an environment variable `DEFAULT_EMAIL_DB_PATH` to point to the database path before moving to the next step:
```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
This will create a SQLite database and datasets:
```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db    # Email database
├── train.parquet          # Training samples
└── test.parquet           # Test samples
```
Each sample looks like:
```json
{
"id": 0,
"question": "Were there any variances detected for hour 6 on 3/9/01?",
"answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
"message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
"how_realistic": 0.800000011920929,
"inbox_address": "pete.davis@enron.com",
"query_date": "2001-03-16"
}
```
## Code Implementation
This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.
### Agent Workflow
The workflow function `run_email_search_agent` implements the agent-environment interaction loop:
```python
async def run_email_search_agent(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
# Parse task and create agent
agent = EmailSearchAgent(
name="email_search_agent",
sys_prompt=system_prompt,
model=model,
max_iters=max_turns,
)
# Run the agent with structured output
response = await agent.reply(
msg=Msg("user", question, role="user"),
structured_model=AnswerModel,
)
return WorkflowOutput(response=response)
```
The agent follows a ReAct pattern: it reasons about the task, calls tools to search and read emails, and finally generates a structured response containing the answer and source message IDs.
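The `structured_model=AnswerModel` argument is what constrains the final reply to this answer-plus-sources schema. In this example it is a pydantic model along the following lines (simplified from the `AnswerModel` in the example's utility module):
```python
from typing import List
from pydantic import BaseModel, Field

class AnswerModel(BaseModel):
    """Schema for the agent's final structured reply (simplified)."""
    answer: str = Field(
        description="The answer, or 'I don't know' if it cannot be found.",
    )
    sources: List[str] = Field(
        description="message_ids of the emails that support the answer.",
    )
```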
### Judge Function
The judge function `email_search_judge` implements reward calculation using LLM-as-a-Judge:
```python
async def email_search_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
# Extract answer and sources from response
answer = answer_and_sources.get("answer")
sources = answer_and_sources.get("sources", [])
# Judge correctness using LLM-as-a-Judge
judge_model = auxiliary_models.get('judge') or list(auxiliary_models.values())[0]
judge_response = await judge_correctness(
answer, query, judge_model
)
# Calculate reward based on:
# - Answer correctness (accuracy: -1.0 to 1.0)
# - Source correctness (format: partial rewards)
# - Efficiency (bonus for fewer turns, correct sources)
result = {"accuracy": ..., "format": ...} # calculated based on judge_response
return JudgeOutput(
reward=sum(result.values()),
metrics=metrics,
)
```
The reward function considers:
- **Answer correctness**: Evaluated by LLM-as-a-Judge comparing the agent's answer with the ground truth
- **Source correctness**: Whether the agent cited the correct email message IDs
- **Efficiency**: Bonus rewards for finding/reading the correct email and taking fewer turns
See [`main.py`](./main.py) and [`email_search_agent.py`](./email_search_agent.py) for implementation details.
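To make the reward shape concrete, a schematic combination of these components is sketched below; the weights and bonus values are illustrative only, not the exact constants used in `main.py`:
```python
# Illustrative reward combination (values are examples, not the exact
# constants used in main.py).
def combine_reward(
    answer_correct: bool,   # verdict from LLM-as-a-Judge
    sources_correct: bool,  # cited message_ids match the ground truth
    num_turns: int,
    max_turns: int,
) -> float:
    accuracy = 1.0 if answer_correct else -1.0
    fmt = 0.3 if sources_correct else 0.0
    # Small efficiency bonus for solving the task in fewer turns.
    efficiency = 0.1 * (max_turns - num_turns) / max_turns if answer_correct else 0.0
    return accuracy + fmt + efficiency

print(combine_reward(True, True, num_turns=4, max_turns=10))  # 1.36
```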
## How to Run
### Prerequisites
- At least 4 NVIDIA GPUs with CUDA 12.8 or newer
* Note: For the 30B Judge model, you need to use a GPU with at least 80 GB of memory; you can also run the model on multiple GPUs by using `tensor_parallel_size > 1` to reduce the per-GPU memory usage (by default, `tensor_parallel_size=2`).
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
- Download the model checkpoint (example):
```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # judge model
```
### Configuration
Adjust the configuration file ([`config.yaml`](./config.yaml)) based on your hardware. Key configuration sections include:
- **TunerModelConfig**: Set `model_path` to your model checkpoint path
- **AlgorithmConfig**: Configure RL algorithm parameters (e.g., `multi_step_grpo`, learning rate, policy loss function)
- **DatasetConfig**: The dataset path is specified in `main.py` when creating the `DatasetConfig` object
- **Auxiliary Models**: Configure judge model settings for LLM-as-a-Judge
For full configuration details, see [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
### Start-Up Commands
1. Prepare the dataset:
```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
2. Set up a [Ray](https://github.com/ray-project/ray) cluster:
```bash
ray start --head
```
3. Run the training script:
```bash
python main.py
```
## Experimental Results
### Quantitative Results
The training results show improvements in agent performance over training iterations. Key metrics include:
- **Train reward**: The average reward on training samples increases as the agent learns better strategies
- **Rollout accuracy**: The average accuracy on rollout samples increases as the agent learns better strategies
<div align="center">
<img src="./critic_reward_mean.png" alt="Training Rewards" width="90%"/>
</div>
<div align="center">
<img src="./rollout_accuracy_mean.png" alt="Rollout Accuracy" width="90%"/>
</div>
### Concrete Example
An example of the agent's behavior is shown below:
**Query:** "What do the color codes mean in the curve assessment?"
We show the last several turns of the agent's responses:
The agent performs multiple search attempts to find relevant emails. After some unsuccessful searches, the agent tries:
**Tool call:**
```json
{
"type": "tool_use",
"name": "search_emails",
"input": {
"inbox_address": "steven.kean@enron.com",
"query_date": "2001-04-03",
"keywords": ["curve", "assessment"]
}
}
```
**Tool result:**
```json
{
"type": "tool_result",
"name": "search_emails",
"output": [
{
"message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
"snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
},
{
"message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
"snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
}
]
}
```
After finding relevant emails, the agent uses `read_email` tool to read the full content of both emails:
**Tool call:**
```json
{
"type": "tool_use",
"name": "read_email",
"input": {
"message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
}
}
```
**Tool call:**
```json
{
"type": "tool_use",
"name": "read_email",
"input": {
"message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
}
}
```
After reading the emails, the agent generates the final structured response:
**Tool call:**
```json
{
"type": "tool_use",
"name": "generate_response",
"input": {
"answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
"sources": [
"<2654330.1075846153519.JavaMail.evans@thyme>",
"<12499967.1075846153495.JavaMail.evans@thyme>"
]
}
}
```
The judge evaluates the answer as correct based on the ground truth.


@@ -0,0 +1,283 @@
# 使用 AgentScope-Tuner 训练邮件搜索智能体
本示例展示如何使用 AgentScope-Tuner 对邮件搜索任务(灵感来自 [ART](https://openpipe.ai/blog/art-e-mail-agent))进行强化微调,其 RFT 功能由 [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) 提供支持。
## 任务设定
智能体的目标是通过搜索邮件收件箱来回答用户查询。智能体需要:
- 理解用户的问题
- 使用关键词搜索相关邮件
- 阅读邮件内容以提取信息
- 提供准确的答案并附上适当的来源引用
**智能体类型**:智能体(`EmailSearchAgent`)继承自 `ReActAgent`,遵循推理-行动循环来迭代解决任务。
**环境**:环境是一个包含来自 Enron 邮件数据集的 SQLite 数据库。每个任务提供:
- `question`:用户的邮件搜索查询
- `inbox_address`:要搜索的邮件收件箱
- `query_date`:查询的日期上下文
- `answer`:预期答案(真实值),仅用于奖励计算
- `message_ids`:包含答案的相关邮件 ID仅用于奖励计算
**可用工具**
- `search_emails`通过关键词、收件箱地址和日期范围查找邮件。返回邮件摘要列表message_id 和片段)。
- `read_email`:通过 message_id 读取特定邮件的完整内容。
- `generate_response`:提供带有来源的最终结构化答案(继承自 ReAct 智能体)。
## 数据集准备
数据集包含基于 [Enron 邮件数据集](https://huggingface.co/datasets/corbt/enron-emails) 的邮件查询。运行数据准备脚本以生成邮件数据库和数据集:
```bash
python prepare_data.py
```
如果你想选择新的数据库路径,可以修改 [`prepare_data.py`](./prepare_data.py) 中的 `DEFAULT_DB_PATH`。同时,请记住在进入下一步之前设置环境变量 `DEFAULT_EMAIL_DB_PATH` 指向数据库路径:
```bash
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
这将创建一个 SQLite 数据库和数据集:
```
/path/to/enron_emails_dataset/
├── data
│   └── enron_emails.db    # 邮件数据库
├── train.parquet          # 训练样本
└── test.parquet           # 测试样本
```
每个样本如下所示:
```json
{
"id": 0,
"question": "Were there any variances detected for hour 6 on 3/9/01?",
"answer": "Yes, variances were detected in both Generation and Energy Import/Export schedules for hour 6 on 3/9/01.",
"message_ids": ["<17407857.1075840601283.JavaMail.evans@thyme>"],
"how_realistic": 0.800000011920929,
"inbox_address": "pete.davis@enron.com",
"query_date": "2001-03-16"
}
```
## 代码实现
本节提供代码实现的高级概览。详细实现请参考源代码。
### 智能体工作流
工作流函数 `run_email_search_agent` 实现智能体-环境交互循环:
```python
async def run_email_search_agent(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
# 解析任务并创建智能体
agent = EmailSearchAgent(
name="email_search_agent",
sys_prompt=system_prompt,
model=model,
max_iters=max_turns,
)
# 使用结构化输出运行智能体
response = await agent.reply(
msg=Msg("user", question, role="user"),
structured_model=AnswerModel,
)
return WorkflowOutput(response=response)
```
智能体遵循 ReAct 模式:它推理任务,调用工具搜索和阅读邮件,最后生成包含答案和来源消息 ID 的结构化响应。
### 评判函数
评判函数 `email_search_judge` 使用 LLM-as-a-Judge 实现奖励计算:
```python
async def email_search_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
# 从响应中提取答案和来源
answer = answer_and_sources.get("answer")
sources = answer_and_sources.get("sources", [])
# 使用 LLM-as-a-Judge 评判正确性
judge_model = auxiliary_models.get('judge') or list(auxiliary_models.values())[0]
judge_response = await judge_correctness(
answer, query, judge_model
)
# 基于以下因素计算奖励:
# - 答案正确性(准确度:-1.0 到 1.0
# - 来源正确性(格式:部分奖励)
# - 效率(对更少轮次、正确来源的奖励)
result = {"accuracy": ..., "format": ...} # 基于 judge_response 计算
return JudgeOutput(
reward=sum(result.values()),
metrics=metrics,
)
```
奖励函数考虑以下因素:
- **答案正确性**:通过 LLM-as-a-Judge 比较智能体的答案与真实值进行评估
- **来源正确性**:智能体是否引用了正确的邮件消息 ID
- **效率**:对找到/阅读正确邮件和更少轮次的奖励
详细实现请参考 [`main.py`](./main.py) 和 [`email_search_agent.py`](./email_search_agent.py)。
## 运行方法
### 前置要求
- 至少 4 张 NVIDIA GPUCUDA 版本 ≥ 12.8
* 注意:对于 30B 评判模型,需要使用显存至少 80 GB 的 GPU;你也可以通过使用 `tensor_parallel_size > 1` 在多张 GPU 上运行模型以减少单卡显存使用(默认情况下,`tensor_parallel_size=2`)。
- 按照 Trinity-RFT [安装指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_installation.html) 从源码安装最新版本
- 下载模型检查点(示例):
```bash
huggingface-cli download Qwen/Qwen3-4B-Instruct-2507
huggingface-cli download Qwen/Qwen3-30B-A3B-Instruct-2507 # 评判模型
```
### 配置
根据你的硬件调整配置文件([`config.yaml`](./config.yaml))。关键配置部分包括:
- **TunerModelConfig**:将 `model_path` 设置为你的模型检查点路径
- **AlgorithmConfig**:配置 RL 算法参数(例如,`multi_step_grpo`、学习率、策略损失函数)
- **DatasetConfig**:数据集路径在创建 `DatasetConfig` 对象时在 `main.py` 中指定
- **辅助模型**:为 LLM-as-a-Judge 配置评判模型设置
完整配置详情请参考 [Trinity-RFT 配置指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/trinity_configs.html)。
### 启动命令
1. 准备数据集:
```bash
python prepare_data.py
export DEFAULT_EMAIL_DB_PATH=/path/to/enron_emails_dataset/data/enron_emails.db
```
2. 启动 [Ray](https://github.com/ray-project/ray) 集群:
```bash
ray start --head
```
3. 运行训练脚本:
```bash
python main.py
```
## 实验结果
### 定量结果
训练结果显示智能体性能随训练迭代次数的提升。关键指标包括:
- **训练奖励**:训练样本上的平均奖励随着智能体学习更好的策略而增加
- **Rollout 准确度**Rollout 样本上的平均准确度随着智能体学习更好的策略而增加
<div align="center">
<img src="./critic_reward_mean.png" alt="Training Rewards" width="90%"/>
</div>
<div align="center">
<img src="./rollout_accuracy_mean.png" alt="Rollout Accuracy" width="90%"/>
</div>
### 具体示例
智能体行为示例如下:
**查询:** "What do the color codes mean in the curve assessment?"
我们展示智能体响应的最后几轮:
智能体执行多次搜索尝试以找到相关邮件。经过一些不成功的搜索后,智能体尝试:
**工具调用:**
```json
{
"type": "tool_use",
"name": "search_emails",
"input": {
"inbox_address": "steven.kean@enron.com",
"query_date": "2001-04-03",
"keywords": ["curve", "assessment"]
}
}
```
**工具结果:**
```json
{
"type": "tool_result",
"name": "search_emails",
"output": [
{
"message_id": "<2654330.1075846153519.JavaMail.evans@thyme>",
"snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
},
{
"message_id": "<12499967.1075846153495.JavaMail.evans@thyme>",
"snippet": "... curves and Senior Management as necessary.\\n\\n<b>Curve Assessment</b>\\n Green: \\n - curve is reasonable\\n - small P&L ..."
}
]
}
```
找到相关邮件后,智能体使用 `read_email` 工具读取两封邮件的完整内容:
**工具调用:**
```json
{
"type": "tool_use",
"name": "read_email",
"input": {
"message_id": "<2654330.1075846153519.JavaMail.evans@thyme>"
}
}
```
**工具调用:**
```json
{
"type": "tool_use",
"name": "read_email",
"input": {
"message_id": "<12499967.1075846153495.JavaMail.evans@thyme>"
}
}
```
阅读邮件后,智能体生成最终的结构化响应:
**工具调用:**
```json
{
"type": "tool_use",
"name": "generate_response",
"input": {
"answer": "In the curve assessment, the color codes are used as follows:\n- Green: The curve is considered reasonable, and the P&L (profit and loss) is small.",
"sources": [
"<2654330.1075846153519.JavaMail.evans@thyme>",
"<12499967.1075846153495.JavaMail.evans@thyme>"
]
}
}
```
评判器评估上面的答案为正确。


@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import json
import traceback
from dataclasses import asdict
from datetime import datetime, timedelta
from typing import Any
from _utils import ( # pylint: disable=E0611
read_email_tool,
search_emails_tool,
)
from agentscope import logger
from agentscope.agent import ReActAgent
from agentscope.message import TextBlock
from agentscope.tool import Toolkit, ToolResponse
def pre_reasoning_hook(_self: Any, _kwargs: Any) -> dict[str, Any] | None:
"""Pre-reasoning hook to remove tool_choice from kwargs."""
_kwargs.pop("tool_choice", None)
return _kwargs
class EmailSearchAgent(ReActAgent):
"""
A customized ReAct agent with pre-defined tools for
email search and reading.
Ref: https://github.com/OpenPipe/ART
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.message_id_list = (
[]
) # List to store message IDs found during search
self.ever_read_message_ids = (
[]
) # List to store message IDs that have been read
toolkit = Toolkit()
toolkit.register_tool_function(self.search_emails)
toolkit.register_tool_function(self.read_email)
super().__init__(*args, toolkit=toolkit, **kwargs)
self.register_instance_hook(
"pre_reasoning",
"tool_choice_hook",
pre_reasoning_hook,
)
async def reset(self) -> None:
"""Reset agent state for a new rollout/episode."""
self.message_id_list.clear()
self.ever_read_message_ids.clear()
await self.memory.clear()
def search_emails(
self,
inbox_address: str,
query_date: str,
keywords: list[str],
**_kwargs: Any,
) -> ToolResponse:
"""
Search the user's email inbox for emails that match the given keywords.
Args:
inbox_address: The user's email address.
query_date: The date of the query in 'YYYY-MM-DD' format.
keywords: Keywords to search for in the user's email inbox.
Returns:
ToolResponse:
A ToolResponse object containing a list of TextBlock objects
in the `content` field. On success, the text field of the
TextBlock contains a JSON string representing a list of email
summaries (e.g., message_id, snippet) matching the search
criteria. Each email summary is converted to a dictionary via
`asdict`. On failure, the text indicates an error message.
"""
try:
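# Search strictly before the day after query_date, so that emails sent
# on the query date itself are still included in the results.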
next_day = (
datetime.strptime(query_date, "%Y-%m-%d") + timedelta(days=1)
).strftime(
"%Y-%m-%d",
)
res = search_emails_tool(
inbox=inbox_address,
sent_before=next_day,
keywords=keywords,
)
self.message_id_list.extend([r.message_id for r in res])
return ToolResponse(
content=[
TextBlock(
type="text",
text=json.dumps([asdict(r) for r in res]),
),
],
)
except Exception as e:
logger.info(
"Error in search_emails: %s, traceback: %s",
e,
traceback.format_exc(),
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=(
f"Error: Failed to search emails.\n"
f"Error message: {e}"
),
),
],
)
def read_email(self, message_id: str, **_kwargs: Any) -> ToolResponse:
"""
Read the content of an email from the user's email inbox.
Returns the email content.
Args:
message_id (str): The unique identifier of the email to read.
Returns:
ToolResponse:
A ToolResponse object containing the email content or an
error message if the email is not found.
"""
try:
email_content = read_email_tool(message_id)
self.ever_read_message_ids.append(message_id)
if email_content is None:
return ToolResponse(
content=[
TextBlock(
type="text",
text=(
f"Error: Email (message_id = {message_id}) "
f"not found."
),
),
],
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=json.dumps(email_content.model_dump()),
),
],
)
except Exception as e:
logger.info(
"Error in read_email: %s, traceback: %s",
e,
traceback.format_exc(),
)
return ToolResponse(
content=[
TextBlock(
type="text",
text=(
f"Error: Failed to read email.\n"
f"Error message: {e}"
),
),
],
)

View File

@@ -0,0 +1,328 @@
# -*- coding: utf-8 -*-
"""
This file defines Dataclass and tool implementations.
Modified from https://github.com/OpenPipe/ART/blob/art-e/
"""
import datetime
import os
import sqlite3
from dataclasses import dataclass
from typing import Any, List, Optional
from pydantic import BaseModel, Field, field_validator
from agentscope import logger
DEFAULT_DB_PATH = os.environ.get("DEFAULT_EMAIL_DB_PATH")
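# Module-level cache holding a single shared read-only SQLite connection.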
conn = None
def get_conn() -> sqlite3.Connection:
"""Get or create a database connection."""
global conn
if conn is None:
conn = sqlite3.connect(
f"file:{DEFAULT_DB_PATH}?mode=ro",
uri=True,
check_same_thread=False,
)
return conn
class QueryModel(BaseModel):
"""Model for email search query."""
id: int
question: str
answer: str
message_ids: List[str] # message_ids (strings) of referenced emails
how_realistic: float
inbox_address: str
query_date: str
@field_validator("query_date", mode="before")
@classmethod
def format_date(cls, v: Any) -> str:
"""Format date to string if it's a datetime object."""
if isinstance(v, datetime.datetime):
return v.strftime("%Y-%m-%d")
return v
class AnswerModel(BaseModel):
"""Model for agent's answer with sources."""
answer: str = Field(
description=(
"It should be called with the answer and the sources. "
"If you cannot find the answer, you should return "
"'I don't know' with an empty list of sources."
),
)
sources: List[str] = Field(
description=(
"a list of message ids that are relevant to the query. "
"Usually there will be only one. If you cannot find the "
"answer, you should return an empty list."
),
)
class Email(BaseModel):
"""Model representing an email."""
message_id: str
date: str # ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
subject: Optional[str] = None
from_address: Optional[str] = None
to_addresses: List[str] = Field(default_factory=list)
cc_addresses: List[str] = Field(default_factory=list)
bcc_addresses: List[str] = Field(default_factory=list)
body: Optional[str] = None
file_name: Optional[str] = None
@dataclass
class SearchResult:
"""Result from email search."""
message_id: str
snippet: str
class FinalRubric(BaseModel):
"""Rubric for evaluating agent performance."""
answer_correct: bool = False
sources_correct: bool = False
num_turns: int = 0
attempted_answer: bool = False
ever_found_right_email: bool = False
ever_read_right_email: bool = False
cant_parse_tool_call: bool = False
bad_tool_call_name: bool = False
bad_tool_call_args: bool = False
ran_out_of_turns: bool = False
returned_i_dont_know: bool = False
num_sources: int = 0
ever_tried_to_read_invalid_email: bool = False
prompt_tokens: int = 0
completion_tokens: int = 0
# Define tools for agent
def search_emails_tool(
inbox: str,
keywords: List[str],
from_addr: Optional[str] = None,
to_addr: Optional[str] = None,
sent_after: Optional[str] = None,
sent_before: Optional[str] = None,
max_results: int = 10,
) -> List[SearchResult]:
"""
Searches the email database based on keywords, inbox,
sender, recipient, and date range.
Args:
inbox: The email address of the user performing the search.
Results include emails sent from or to (inc. cc/bcc)
this address.
keywords: A list of keywords that must all appear in the
subject or body.
from_addr: Optional email address to filter emails sent *from*.
to_addr: Optional email address to filter emails sent *to*
(inc. cc/bcc).
sent_after: Optional date string 'YYYY-MM-DD'. Filters for
emails sent on or after this date.
sent_before: Optional date string 'YYYY-MM-DD'. Filters for
emails sent before this date.
max_results: The maximum number of results to return.
Cannot exceed 10.
Returns:
A list of SearchResult objects, each containing 'message_id'
and 'snippet'. Returns an empty list if no results are found
or an error occurs.
"""
# Initialize sql and params
sql: Optional[str] = None
params: List[str | int] = []
cursor = get_conn().cursor()
# --- Build Query ---
where_clauses: List[str] = []
# 1. Keywords (FTS)
if not keywords:
raise ValueError("No keywords provided for search.")
if max_results > 10:
raise ValueError("max_results must be less than or equal to 10.")
# FTS5 default is AND, so just join keywords. Escape quotes for safety.
fts_query = " ".join(f""" "{k.replace('"', '""')}" """ for k in keywords)
where_clauses.append("emails_fts MATCH ?")
params.append(fts_query)
# 2. Inbox filter (must be from OR to/cc/bcc the inbox user)
# Use the composite index idx_recipients_address_email here
where_clauses.append(
"""
(e.from_address = ? OR EXISTS (
SELECT 1 FROM recipients r_inbox
WHERE r_inbox.recipient_address = ? AND r_inbox.email_id = e.id
))
""",
)
params.extend([inbox, inbox])
# 3. Optional From filter
if from_addr:
where_clauses.append("e.from_address = ?")
params.append(from_addr)
# 4. Optional To filter (includes to, cc, bcc)
# Use composite index idx_recipients_address_email
if to_addr:
where_clauses.append(
"""
EXISTS (
SELECT 1 FROM recipients r_to
WHERE r_to.recipient_address = ? AND r_to.email_id = e.id
)
""",
)
params.append(to_addr)
# 5. Optional Sent After filter
if sent_after:
# Assumes date format 'YYYY-MM-DD'
# Compare against the start of the day
where_clauses.append("e.date >= ?")
params.append(f"{sent_after} 00:00:00")
# 6. Optional Sent Before filter
if sent_before:
# Assumes date format 'YYYY-MM-DD'
# Compare against the start of the day (exclusive)
where_clauses.append("e.date < ?")
params.append(f"{sent_before} 00:00:00")
# --- Construct Final Query ---
# snippet(<table>, <column_index>, <highlight_start>,
# <highlight_end>, <ellipsis>, <tokens>)
# -1 means highlight across all columns (subject, body)
sql = f"""
SELECT
e.message_id,
snippet(emails_fts, -1, '<b>', '</b>', ' ... ', 15) as snippet
FROM
emails e JOIN emails_fts fts ON e.id = fts.rowid
WHERE
{" AND ".join(where_clauses)}
ORDER BY
e.date DESC -- Order by date for relevance
LIMIT ?;
"""
params.append(max_results)
# --- Execute and Fetch ---
logger.debug("Executing SQL: %s", sql)
logger.debug("With params: %s", params)
cursor.execute(sql, params)
results = cursor.fetchall()
# Format results
formatted_results = [
SearchResult(message_id=row[0], snippet=row[1]) for row in results
]
logger.info("Search found %d results.", len(formatted_results))
return formatted_results
def read_email_tool(message_id: str) -> Optional[Email]:
"""
Retrieves a single email by its message_id from the database.
Args:
message_id: The unique identifier of the email to retrieve.
Returns:
An Email object containing the details of the found email,
or None if the email is not found or an error occurs.
"""
cursor = get_conn().cursor()
# --- Query for Email Core Details ---
email_sql = """
SELECT id, message_id, date, subject, from_address, body, file_name
FROM emails
WHERE message_id = ?;
"""
cursor.execute(email_sql, (message_id,))
email_row = cursor.fetchone()
if not email_row:
logger.warning("Email with message_id '%s' not found.", message_id)
return None
email_pk_id, msg_id, date, subject, from_addr, body, file_name = email_row
# DEBUG
logger.info("[read_email_tool] input_message_id=%s", message_id)
logger.info(
"[read_email_tool] db: id=%s, message_id=%s",
email_pk_id,
msg_id,
)
# search for recipients by emails.id (rather than message_id)
recipients_sql = """
SELECT recipient_address, recipient_type
FROM recipients
WHERE email_id = ?;
"""
cursor.execute(recipients_sql, (email_pk_id,))
recipient_rows = cursor.fetchall()
to_addresses: List[str] = []
cc_addresses: List[str] = []
bcc_addresses: List[str] = []
for addr, rtype in recipient_rows:
type_lower = rtype.lower()
if type_lower == "to":
to_addresses.append(addr)
elif type_lower == "cc":
cc_addresses.append(addr)
elif type_lower == "bcc":
bcc_addresses.append(addr)
# --- Construct Email Object ---
email_obj = Email(
        message_id=msg_id,  # stored as TEXT in SQLite, already a string
date=date,
subject=subject,
from_address=from_addr,
to_addresses=to_addresses,
cc_addresses=cc_addresses,
bcc_addresses=bcc_addresses,
body=body,
file_name=file_name,
)
return email_obj
__all__ = [
"QueryModel",
"AnswerModel",
"FinalRubric",
"Email",
"SearchResult",
"search_emails_tool",
"read_email_tool",
"get_conn",
]

View File

@@ -0,0 +1,72 @@
project: "AgentScope" # Project name
name: "Email_search" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints
algorithm:
algorithm_type: multi_step_grpo # GRPO series for multi-step scenario
repeat_times: 8 # Number of rollouts per prompt for advantage estimation
optimizer:
lr: 1e-6 # Learning rate
policy_loss_fn: "rec" # Policy loss function
policy_loss_fn_args: # Policy loss function arguments
epsilon_low: 0.2
epsilon_high: 0.2
clip_mode: "one-side"
weight: "none"
temp: 1.0
regularizer: "none"
regularizer_coef: 0.0
kl_loss_fn: 'k2' # KL divergence loss function
kl_loss_fn_args:
kl_coef: 0.0 # KL divergence coefficient
advantage_fn_args:
std_cal_level: 'batch' # Advantage normalization level
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen3-4B-Instruct-2507} # Base model path
max_response_tokens: 4096 # Max tokens per response
max_model_len: 20480 # Max context length
buffer:
total_epochs: 10 # Total training epochs
batch_size: 64 # Batch size per explore step
train_batch_size: 2560 # 64*8*5, total experiences per training step
trainer_input:
experience_buffer:
name: experience_buffer
storage_type: queue
replay_buffer:
enable: true # Enable experience replay
priority_fn: 'decay_limit_randomization'
priority_fn_args:
decay: 2.0
use_count_limit: 3
sigma: 2.0
explorer:
eval_interval: 10
max_repeat_times_per_runner: 1 # Max repeat times per runner
max_timeout: 3600 # Max timeout for each rollout (seconds)
rollout_model:
enable_history: true # Enable conversation history
enable_openai_api: true # Enable OpenAI-compatible API
enable_auto_tool_choice: true # Enable automatic tool selection
tool_call_parser: hermes # Parser for tool calls
engine_num: 4 # Number of vLLM engines for rollout model
tensor_parallel_size: 1 # TP size per engine for rollout model
enable_prefix_caching: false # Disable prefix caching
auxiliary_models:
- name: judge
model_path: Qwen/Qwen3-30B-A3B-Instruct-2507 # Judge model path
engine_num: 1 # Number of vLLM engines for judge model
tensor_parallel_size: 2 # TP size per engine for judge model
enable_thinking: false # Disable thinking/reasoning mode
max_prompt_tokens: 2048 # Max tokens for prompt
max_response_tokens: 128 # Max tokens for response
max_model_len: 2500 # Max model context length
synchronizer:
sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
sync_interval: 5 # Sync every N steps
sync_timeout: 3600 # Timeout for synchronization (seconds)
trainer:
save_interval: 100 # Save checkpoint every N steps
grad_clip: 1.0 # Gradient clipping value
use_dynamic_bsz: true # Use dynamic batch size
max_token_len_per_gpu: 16384 # Max token length per GPU
ulysses_sequence_parallel_size: 1 # Sequence parallel size for Ulysses

Binary file not shown (image, 470 KiB)

379 tuner/email_search/main.py Normal file
View File

@@ -0,0 +1,379 @@
# -*- coding: utf-8 -*-
"""Example of training an Email Search agent with Trinity-RFT."""
import os
from typing import Dict
from _email_search_agent import EmailSearchAgent
from _utils import ( # pylint: disable=E0611
AnswerModel,
FinalRubric,
QueryModel,
)
from agentscope import logger
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.tuner import (
TunerModelConfig,
DatasetConfig,
JudgeOutput,
WorkflowOutput,
AlgorithmConfig,
tune,
)
from agentscope.model import ChatModelBase
SYSTEM_PROMPT = """You are an email search agent. You are given a user query
and a list of tools you can use to search the user's email. Use the tools to
search the user's emails and find the answer to the user's query. You may take
up to {max_turns} turns to find the answer, so if your first search doesn't
find the answer, you can try with different keywords.
Always describe what you see and plan your next steps clearly. When taking
actions, explain what you're doing and why. When the answer to the task is
found, call `generate_response` to finish the process. Only call
`generate_response` when the answer is found. You should not include any next
steps in `generate_response`. Complete all steps and then call `generate_response`.
User's email address is {inbox_address}
Today's date is {query_date}
"""
async def run_email_search_agent(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput: # noqa: PLR0915
"""A workflow function using the Email Search agent to solve tasks.
Args:
task (Dict): The task to be solved.
Should contain fields from QueryModel.
        model (ChatModelBase): The language model to use.
        auxiliary_models (Dict[str, ChatModelBase]): Additional chat
            models available to the workflow, e.g. the LLM judge.
Returns:
WorkflowOutput: The output containing the agent's response.
"""
assert len(auxiliary_models) > 0, "LLM-as-a-Judge is required"
# Parse task data
query = QueryModel.model_validate(task)
question = task.get("question", task.get("task_desc", ""))
# Get workflow arguments with defaults
workflow_args = task.get("workflow_args", {})
max_turns = int(workflow_args.get("max_turns", 10))
# Format system prompt
system_prompt = SYSTEM_PROMPT.format(
max_turns=max_turns,
inbox_address=query.inbox_address,
query_date=query.query_date,
)
# Create EmailSearchAgent
agent = EmailSearchAgent(
name="email_search_agent",
sys_prompt=system_prompt,
model=model,
formatter=OpenAIChatFormatter(),
max_iters=max_turns,
)
# Reset agent state for a new rollout
await agent.reset()
# Run the agent with structured output
response = await agent.reply(
msg=Msg("user", question, role="user"),
structured_model=AnswerModel,
)
# Extract answer and sources from response metadata
answer_and_sources = response.metadata or {}
if not answer_and_sources:
# Fallback: try to parse from content
answer_and_sources = {
"answer": response.get_text_content() or "",
"sources": [],
}
# Store agent state for judge function
# We'll pass this through the response metadata
response_metadata = {
"answer_and_sources": answer_and_sources,
"query": query.model_dump(),
"message_id_list": agent.message_id_list,
"ever_read_message_ids": agent.ever_read_message_ids,
# Estimate actual_turns from memory length
"actual_turns": (
max(1, (len(agent.memory.content) - 1) // 2)
if len(agent.memory.content) > 1
else 1
),
}
# Update response metadata
if response.metadata is None:
response.metadata = {}
response.metadata.update(response_metadata)
return WorkflowOutput(
response=response,
)
def _calculate_partial_rewards(rubric: FinalRubric) -> float:
"""Calculate partial rewards based on rubric."""
partial_rewards = 0.0
partial_rewards += 0.1 if rubric.ever_found_right_email else 0
partial_rewards += 0.1 if rubric.ever_read_right_email else 0
partial_rewards += 0.1 if rubric.sources_correct else 0
return partial_rewards
def _calculate_correct_answer_reward(
rubric: FinalRubric,
max_turns: int,
) -> float:
"""Calculate reward for correct answers."""
reward = 1.0
reward += 0.3 if rubric.sources_correct else 0
reward += 0.1 / rubric.num_sources if rubric.num_sources > 0 else 0
reward += 0.1 * (1 - rubric.num_turns / max_turns)
return reward
def _initialize_rubric(
answer: str,
sources: list[str],
actual_turns: int,
query: QueryModel,
message_id_list: list[str],
ever_read_message_ids: list[str],
) -> FinalRubric:
"""Initialize and populate rubric with basic information."""
rubric = FinalRubric()
rubric.attempted_answer = answer is not None and answer != ""
rubric.returned_i_dont_know = answer == "I don't know"
rubric.num_sources = len(sources)
rubric.num_turns = actual_turns
if len(query.message_ids) > 0:
rubric.ever_found_right_email = query.message_ids[0] in message_id_list
rubric.ever_read_right_email = (
query.message_ids[0] in ever_read_message_ids
)
rubric.sources_correct = query.message_ids[0] in sources
return rubric
async def email_search_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, ChatModelBase],
) -> JudgeOutput:
"""A judge function to calculate reward based on agent's response.
Args:
task (Dict): The task information for the corresponding workflow.
response (Msg): The response generated by the corresponding workflow.
auxiliary_models (Dict[str, ChatModelBase]):
A dictionary of additional chat models available for LLM-as-a-Judge
usage. The keys are model names, and the values are the
corresponding ChatModelBase instances.
Returns:
JudgeOutput: The reward value assigned by the judge function.
"""
# Extract metadata from response
metadata = response.metadata or {}
answer_and_sources = metadata.get("answer_and_sources", {})
query_dict = metadata.get("query", {})
message_id_list = metadata.get("message_id_list", [])
ever_read_message_ids = metadata.get("ever_read_message_ids", [])
actual_turns = metadata.get("actual_turns", 0)
# Parse query model
if not query_dict:
query_dict = task
query = QueryModel.model_validate(query_dict)
# Get arguments
workflow_args = task.get("workflow_args", {})
max_turns = int(workflow_args.get("max_turns", 10))
# Extract answer and sources
try:
answer = answer_and_sources.get("answer", None)
sources = answer_and_sources.get("sources", [])
except Exception:
result = {"accuracy": 0.0, "format": -1.0}
return JudgeOutput(
reward=sum(result.values()),
metrics=result,
)
if answer is None:
result = {"accuracy": 0.0, "format": -1.0}
return JudgeOutput(
reward=sum(result.values()),
metrics=result,
)
# Initialize rubric
rubric = _initialize_rubric(
answer,
sources,
actual_turns,
query,
message_id_list,
ever_read_message_ids,
)
# Judge correctness using LLM-as-a-Judge
try:
judge_model = (
auxiliary_models.get("judge") or list(auxiliary_models.values())[0]
if auxiliary_models
else None
)
judge_response = await judge_correctness(
answer,
query,
judge_model,
)
rubric.answer_correct = judge_response
except Exception as e:
logger.error("Error judging correctness: %s", e)
rubric.answer_correct = False
# Calculate rewards
partial_rewards = _calculate_partial_rewards(rubric)
if rubric.attempted_answer and not rubric.answer_correct:
result = {"accuracy": -1.0, "format": partial_rewards}
elif rubric.returned_i_dont_know or rubric.ran_out_of_turns:
result = {"accuracy": 0.0, "format": partial_rewards}
elif rubric.answer_correct:
reward = _calculate_correct_answer_reward(rubric, max_turns)
result = {"accuracy": 1.0, "format": reward}
else:
result = {"accuracy": 0.0, "format": 0.0}
metrics = result.copy()
metrics.update({"actual_turns": actual_turns})
return JudgeOutput(
reward=sum(result.values()),
metrics=metrics,
)
# LLM-as-a-judge
async def judge_correctness(
answer: str,
query: QueryModel,
judge: ChatModelBase,
) -> bool:
"""Use an LLM to decide whether *answer* matches *query.answer*.
Returns a boolean *accept* flag used for scoring.
"""
system_prompt = """You are given a question, the reference answer
(labelled **Reference answer**), and an answer generated by an AI assistant
(labelled **AI answer**).
Follow these steps to decide whether the AI answer should be accepted:
1. Identify EXACTLY what information the **question** is asking for
(e.g. who, what, when, where, why, how, quantity, etc.).
2. From the **Reference answer**, extract ONLY the facts that are required
to directly satisfy the information need identified in step 1. Treat all
other facts as non-essential context.
3. Verify that every essential fact from step 2 appears in the **AI answer**
with the same meaning. Differences in wording, order, or additional
non-conflicting details are allowed.
4. If any essential fact is missing or contradicted in the **AI answer**,
then *accept* must be **false**. Otherwise *accept* must be **true**.
Important: Do NOT penalise the **AI answer** for omitting non-essential
facts that appear in the **Reference answer**. The answer should only be
rejected for errors or omissions in the information explicitly requested by
the question.
Return your judgement **accept** as either **true** or **false**. Do not return
any other text or formatting.
"""
prompt = (
f"Question: {query.question}\n"
f"Reference answer: {query.answer}\n"
f"AI answer: {answer}"
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
]
chat_response = await judge(messages)
# Extract text content from ChatResponse
result_parts = []
for block in chat_response.content:
if isinstance(block, dict) and block.get("type") == "text":
result_parts.append(str(block.get("text", "")))
result = "".join(result_parts)
logger.info("LLM judge response: %s", result)
return "true" in result.lower()
# End of LLM-as-a-judge
if __name__ == "__main__":
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
)
dataset = DatasetConfig(
path="/path/to/enron_emails_dataset",
split="train",
)
tuner_model = TunerModelConfig(
model_path="Qwen/Qwen3-4B-Instruct-2507",
max_model_len=20480,
max_tokens=4096,
inference_engine_num=4,
reasoning_parser=None,
)
aux_models = {
"judge": TunerModelConfig(
model_path="Qwen/Qwen3-30B-A3B-Instruct-2507",
max_model_len=2500,
max_tokens=2048,
inference_engine_num=1,
tensor_parallel_size=2,
tool_call_parser=None,
reasoning_parser=None,
),
}
algorithm = AlgorithmConfig(
algorithm_type="multi_step_grpo",
group_size=8,
batch_size=64,
learning_rate=1e-6,
)
tune(
workflow_func=run_email_search_agent,
judge_func=email_search_judge,
train_dataset=dataset,
model=tuner_model,
auxiliary_models=aux_models,
algorithm=algorithm,
config_path=config_path,
)

View File

@@ -0,0 +1,357 @@
# -*- coding: utf-8 -*-
"""
Prepare data for training.
Modified from OpenPipe/ART
"""
import logging
import os
import sqlite3
from datetime import datetime
from datasets import Dataset, Features, Sequence, Value, load_dataset
from tqdm import tqdm
# Resolve paths relative to this file
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Database will live in "../data/enron_emails.db" relative to project root
DEFAULT_DB_PATH = os.path.join(BASE_DIR, "..", "..", "data", "enron_emails.db")
DEFAULT_REPO_ID = "corbt/enron-emails"
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
# --- Database Schema ---
SQL_CREATE_TABLES = """
DROP TABLE IF EXISTS recipients;
DROP TABLE IF EXISTS emails_fts;
DROP TABLE IF EXISTS emails;
CREATE TABLE emails (
id INTEGER PRIMARY KEY AUTOINCREMENT,
message_id TEXT UNIQUE,
subject TEXT,
from_address TEXT,
date TEXT, -- Store as ISO 8601 string 'YYYY-MM-DD HH:MM:SS'
body TEXT,
file_name TEXT
);
CREATE TABLE recipients (
email_id INTEGER,
recipient_address TEXT,
recipient_type TEXT, -- 'to', 'cc', 'bcc'
FOREIGN KEY(email_id) REFERENCES emails(id) ON DELETE CASCADE
);
"""
SQL_CREATE_INDEXES_TRIGGERS = """
CREATE INDEX idx_emails_from ON emails(from_address);
CREATE INDEX idx_emails_date ON emails(date);
CREATE INDEX idx_emails_message_id ON emails(message_id);
CREATE INDEX idx_recipients_address ON recipients(recipient_address);
CREATE INDEX idx_recipients_type ON recipients(recipient_type);
CREATE INDEX idx_recipients_email_id ON recipients(email_id);
CREATE INDEX idx_recipients_address_email ON recipients(
recipient_address, email_id
);
CREATE VIRTUAL TABLE emails_fts USING fts5(
subject,
body,
content='emails',
content_rowid='id'
);
CREATE TRIGGER emails_ai AFTER INSERT ON emails BEGIN
INSERT INTO emails_fts (rowid, subject, body)
VALUES (new.id, new.subject, new.body);
END;
CREATE TRIGGER emails_ad AFTER DELETE ON emails BEGIN
DELETE FROM emails_fts WHERE rowid=old.id;
END;
CREATE TRIGGER emails_au AFTER UPDATE ON emails BEGIN
UPDATE emails_fts SET subject=new.subject, body=new.body
WHERE rowid=old.id;
END;
INSERT INTO emails_fts (rowid, subject, body)
SELECT id, subject, body FROM emails;
"""
# --- Functions ---
def download_dataset(repo_id: str) -> Dataset:
"""Downloads the dataset from Hugging Face Hub."""
logging.info(
"Attempting to download dataset from Hugging Face Hub: %s",
repo_id,
)
expected_features = Features(
{
"message_id": Value("string"),
"subject": Value("string"),
"from": Value("string"),
"to": Sequence(Value("string")),
"cc": Sequence(Value("string")),
"bcc": Sequence(Value("string")),
"date": Value("timestamp[us]"),
"body": Value("string"),
"file_name": Value("string"),
},
)
dataset_obj = load_dataset(
repo_id,
features=expected_features,
split="train",
)
# Basic type check remains useful
if not isinstance(dataset_obj, Dataset):
raise TypeError(f"Expected Dataset, got {type(dataset_obj)}")
logging.info(
"Successfully loaded dataset '%s' with %d records.",
repo_id,
len(dataset_obj),
)
return dataset_obj
def create_database(db_path: str) -> None:
"""Creates the SQLite database and tables."""
logging.info("Creating SQLite database and tables at: %s", db_path)
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.executescript(SQL_CREATE_TABLES)
conn.commit()
conn.close()
logging.info("Database tables created successfully.")
def _should_skip_email(
body: str,
message_id: str,
to_list: list[str],
cc_list: list[str],
bcc_list: list[str],
) -> bool:
"""Check if email should be skipped based on filters."""
if len(body) > 5000:
logging.debug(
"Skipping email %s: Body length > 5000 characters.",
message_id,
)
return True
total_recipients = len(to_list) + len(cc_list) + len(bcc_list)
if total_recipients > 30:
logging.debug(
"Skipping email %s: Total recipients (%d) > 30.",
message_id,
total_recipients,
)
return True
return False
def _prepare_recipient_data(
email_pk_id: int,
to_list: list[str],
cc_list: list[str],
bcc_list: list[str],
) -> list[tuple[int, str, str]]:
"""Prepare recipient data for database insertion."""
recipient_data = []
for addr in to_list:
recipient_data.append((email_pk_id, addr, "to"))
for addr in cc_list:
recipient_data.append((email_pk_id, addr, "cc"))
for addr in bcc_list:
recipient_data.append((email_pk_id, addr, "bcc"))
return recipient_data
def populate_database(db_path: str, dataset: Dataset) -> None:
"""Populates the database with data from the Hugging Face dataset."""
logging.info("Populating database %s...", db_path)
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# --- Performance Pragmas ---
conn.execute("PRAGMA synchronous = OFF;")
conn.execute("PRAGMA journal_mode = MEMORY;")
record_count = 0
skipped_count = 0
duplicate_count = 0
processed_emails = set()
conn.execute("BEGIN TRANSACTION;")
for email_data in tqdm(dataset, desc="Inserting emails"):
assert isinstance(email_data, dict)
message_id = email_data["message_id"]
subject = email_data["subject"]
from_address = email_data["from"]
date_obj: datetime = email_data["date"]
body = email_data["body"]
file_name = email_data["file_name"]
to_list_raw = email_data["to"]
cc_list_raw = email_data["cc"]
bcc_list_raw = email_data["bcc"]
date_str = date_obj.strftime("%Y-%m-%d %H:%M:%S")
to_list = [str(addr) for addr in to_list_raw if addr]
cc_list = [str(addr) for addr in cc_list_raw if addr]
bcc_list = [str(addr) for addr in bcc_list_raw if addr]
if _should_skip_email(body, message_id, to_list, cc_list, bcc_list):
skipped_count += 1
continue
email_key = (subject, body, from_address)
if email_key in processed_emails:
logging.debug(
"Skipping duplicate email (Subject: %s..., From: %s)",
subject[:50],
from_address,
)
duplicate_count += 1
continue
processed_emails.add(email_key)
cursor.execute(
"""
INSERT INTO emails (
message_id, subject, from_address, date, body, file_name
)
VALUES (?, ?, ?, ?, ?, ?)
""",
(message_id, subject, from_address, date_str, body, file_name),
)
email_pk_id = cursor.lastrowid
if email_pk_id is None:
logging.warning(
"Failed to get email ID after insert for message_id: %s",
message_id,
)
continue
recipient_data = _prepare_recipient_data(
email_pk_id,
to_list,
cc_list,
bcc_list,
)
if recipient_data:
cursor.executemany(
"""
INSERT INTO recipients (
email_id, recipient_address, recipient_type
)
VALUES (?, ?, ?)
""",
recipient_data,
)
record_count += 1
conn.commit()
conn.close()
logging.info("Successfully inserted %d email records.", record_count)
if skipped_count > 0:
logging.info(
"Skipped %d email records due to length or recipient limits.",
skipped_count,
)
if duplicate_count > 0:
logging.info(
"Skipped %d duplicate email records "
"(based on subject, body, from).",
duplicate_count,
)
def create_indexes_and_triggers(db_path: str) -> None:
"""Creates indexes and triggers on the populated database."""
logging.info("Creating indexes and triggers for database: %s...", db_path)
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.executescript(SQL_CREATE_INDEXES_TRIGGERS)
conn.commit()
conn.close()
logging.info("Indexes and triggers created successfully.")
def generate_database(
repo_id: str = DEFAULT_REPO_ID,
db_path: str = DEFAULT_DB_PATH,
overwrite: bool = False,
) -> None:
"""
Generates the SQLite database from the specified Hugging Face dataset.
Simplified version without extensive error handling.
Args:
repo_id: The Hugging Face repository ID for the dataset.
db_path: The path where the SQLite database file should be
created.
overwrite: If True, any existing database file at db_path will
be removed.
"""
logging.info(
"Starting database generation for repo '%s' at '%s'",
repo_id,
db_path,
)
logging.info("Overwrite existing database: %s", overwrite)
db_dir = os.path.dirname(db_path)
if db_dir and not os.path.exists(db_dir):
logging.info("Creating data directory: %s", db_dir)
os.makedirs(db_dir)
if overwrite and os.path.exists(db_path):
logging.warning("Removing existing database file: %s", db_path)
os.remove(db_path)
elif not overwrite and os.path.exists(db_path):
# If not overwriting and file exists, subsequent steps might fail
# or behave unexpectedly. We are removing the explicit error here
# as requested.
logging.warning(
"Database file %s exists and overwrite is False. "
"Assuming file is already generated.",
db_path,
)
return
# 1. Download dataset
dataset = download_dataset(repo_id)
# 2. Create database schema (Tables only)
# Note: This will fail if overwrite=False and the file exists with
# incompatible schema/data.
create_database(db_path)
# 3. Populate database
populate_database(db_path, dataset)
# 4. Create Indexes and Triggers
create_indexes_and_triggers(db_path)
logging.info("Database generation process completed for %s.", db_path)
logging.info(
"Please set the environment variable DEFAULT_EMAIL_DB_PATH "
"to this path.",
)
if __name__ == "__main__":
generate_database(overwrite=True)

Binary file not shown (image, 442 KiB)

273 tuner/frozen_lake/README.md Normal file
View File

@@ -0,0 +1,273 @@
# Training FrozenLake Agent with RL using AgentScope-Tuner
## Summary
This example demonstrates how to use AgentScope-Tuner to implement reinforcement fine-tuning for the [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) task using [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT). The agent learns to navigate a frozen lake grid from a starting position to a goal while avoiding holes through multi-step interactions with the environment.
## Task Setting
### Agent Goal
The agent's objective is to navigate from the starting position (S) to the goal position (G) on a frozen lake grid without falling into holes (H). The agent must:
- Plan a path through frozen tiles (F) to reach the goal
- Avoid holes that terminate the episode with zero reward
- Complete the task within a limited number of steps
### Agent Type
The agent is implemented as a **ReActAgent** (Reasoning and Acting Agent) that:
- Observes the current state of the frozen lake grid
- Reasons about the best action to take
- Executes actions (Up, Down, Left, Right) to move through the environment
- Maintains internal state across multiple steps in an episode
### Environment
The environment is based on Gymnasium's FrozenLake environment, wrapped to provide:
- **Grid-based navigation**: Randomly generated maps with configurable size (2x2 to 6x6)
- **Tile types**:
- `S`: Start position
- `F`: Frozen tile (safe to walk on)
- `H`: Hole (terminates episode with reward 0)
- `G`: Goal (terminates episode with reward +1.0)
- **Action space**: Discrete actions (Up, Down, Left, Right)
- **Reward structure**:
- +1.0 for reaching the goal
- 0.0 for falling into a hole or failing to reach the goal
- **Observations**: Text-based grid representation showing current player position
The agent does not use external tools. It interacts directly with the environment through the interface below (a minimal usage sketch follows the list):
- `env.reset(task)`: Initialize environment with task parameters
- `env.step(action)`: Execute action and receive observation, reward, and done flag
- `env.render()`: Get text representation of current state
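For orientation, here is a minimal interaction sketch built on this interface. It assumes the local `FrozenLakeEnv` from [frozenlake_env.py](./frozenlake_env.py); the hard-coded action list is only a stand-in for the agent's policy:
```python
from frozenlake_env import FrozenLakeEnv

# Build a 5x5 map deterministically from a seed, then roll out a few actions.
env = FrozenLakeEnv(max_steps=8, size=5, p=0.75, seed=12345)
observation, _ = env.reset()

for action in ["Down", "Right", "Down", "Right"]:  # placeholder policy
    observation, reward, done, info = env.step(action)
    print(observation)  # text grid: P player, _ frozen, O hole, G goal
    if done:
        print(f"Episode ended with reward {reward}")
        break
```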
## Dataset Preparation
The dataset contains task parameters for generating FrozenLake environments. Each sample specifies:
- `seed`: Random seed for reproducible map generation
- `size`: Grid size (randomly sampled from 2 to `map_max_size`, e.g., 4x4, 6x6)
- `p`: Probability that a tile is frozen (vs. being a hole), randomly sampled from 0.6 to 0.85
- `index`: Sample index
- `uid`: Unique identifier combining seed, size, and p
Run the data preparation script to generate training and test datasets:
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
This will create parquet files in the specified directory:
```
/path/to/frozenlake_dataset/
├── train.parquet # 10000 training samples
└── test.parquet # 100 test samples
```
Each sample looks like:
```json
{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
```
**Note**: The data preparation script ensures that all generated maps have a valid path from start to goal within the maximum allowed steps (`env_max_steps=8`), filtering out unsolvable tasks.
## Code Implementation
This section provides a high-level overview of the code implementation. For detailed implementation, please refer to the source code.
### High-level Overview
The implementation consists of three main components:
1. **Agent** (`FrozenLakeAgent`): Extends `ReActAgent` to handle multi-step navigation
2. **Environment** (`FrozenLakeEnv`): Wraps Gymnasium's FrozenLake environment
3. **Workflow** (`run_frozen_lake`): Orchestrates the agent-environment interaction loop
### Agent Workflow
The workflow function `run_frozen_lake` implements the agent-environment interaction loop:
```python
async def run_frozen_lake(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
# ...
# Create agent and environment
agent = FrozenLakeAgent(model=model, ...)
env = FrozenLakeEnv(...)
observation, _ = env.reset(task)
rewards = []
# ...
# Agent-environment interaction loop
for _ in range(max_steps):
response = await agent.reply(msg=Msg("user", agent.get_prompt(observation), role="user"))
action = agent.get_action(response)
observation, reward, done, _ = env.step(action)
rewards.append(reward)
if done:
break
# ...
final_reward = sum(rewards)
final_response = Msg("assistant", response_content, role="assistant")
return WorkflowOutput(
reward=final_reward,
response=final_response,
metrics={
"env_steps": float(step_count),
"env_done": float(done),
},
)
```
**Key characteristics:**
- Multi-step interaction: The agent takes multiple actions in a single episode, unlike single-turn QA tasks
- State tracking: The agent maintains internal state (current step, last action, last observation) across steps
- Error handling: Invalid actions or agent errors are caught and handled gracefully
### Reward Function
No separate judge function is needed. The reward comes directly from the environment:
- 1.0: Agent successfully reaches the goal (G)
- 0.0: Agent falls into a hole (H) or fails to reach the goal within the maximum steps
The reward is computed as the sum of step rewards throughout the episode. The workflow returns:
- `reward`: Final cumulative reward
- `response`: Final response message containing observation, total reward, steps taken, and termination reason
- `metrics`: Additional metrics including `env_steps` (number of steps taken) and `env_done` (whether episode completed)
### Implementation Details
The environment (`FrozenLakeEnv`) wraps Gymnasium's FrozenLake and provides:
- `reset(task)`: Initialize the environment with task parameters
- `step(action)`: Execute an action and return (observation, reward, done, info)
- `render()`: Return a text representation of the current state
The agent (`FrozenLakeAgent`) extends `ReActAgent` and provides:
- `reply(msg)`: Reply to a message and return the agent's response message (inherited from AgentScope)
- `get_prompt(observation)`: Generate a prompt from the current observation
- `get_action(response)`: Parse the model's response to extract an action (Up/Down/Left/Right)
- `update_state(action, observation)`: Update internal state after each step
See [frozenlake_env.py](./frozenlake_env.py) and [frozenlake_agent.py](./frozenlake_agent.py) for implementation details.
### Step 4: Use `tune` to train the workflow
```python
from agentscope.tuner import tune, DatasetConfig
if __name__ == "__main__":
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
)
dataset = DatasetConfig(
path="/path/to/frozenlake_dataset",
name="default",
split="train",
)
tune(
workflow_func=run_frozen_lake,
train_dataset=dataset,
config_path=config_path,
)
```
See [config.yaml](./config.yaml) for the training configuration. For full configuration details, see [Trinity-RFT Configuration Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html).
---
## How to Run
### Prerequisites
- At least 2 NVIDIA GPUs with CUDA 12.8 or newer
- Follow the Trinity-RFT [installation guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) to install the latest version from source code
- Install gymnasium for the FrozenLake environment:
```bash
pip install gymnasium[toy_text]
```
- Download the model checkpoint (example):
```bash
huggingface-cli download Qwen/Qwen2.5-3B-Instruct
```
### Step 1: Prepare the Dataset
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
Update the dataset path in `main.py` to point to your generated dataset directory.
### Step 2: Configure the Training
Key configuration options are set directly in the code, including:
**Algorithm Configuration** (`AlgorithmConfig`):
- `algorithm_type`: `multi_step_grpo` (Group Relative Policy Optimization for multi-step tasks)
- `group_size`: Number of rollouts per task used for group-relative advantage estimation (default: 16)
- `batch_size`: Batch size for training (default: 32)
- `learning_rate`: Learning rate (default: 1e-6)
**Model Configuration** (`TunerModelConfig`):
- `model_path`: Path to the base model (e.g., `Qwen/Qwen2.5-3B-Instruct`)
- `max_model_len`: Maximum model context length (default: 25600)
- `max_tokens`: Maximum tokens for response generation (default: 2048)
- `inference_engine_num`: Number of inference engines (default: 6, using 6 GPUs for inference)
**Dataset Configuration** (`DatasetConfig`):
- `path`: Path to the dataset (default: `/path/to/frozenlake`)
- `split`: Split of the dataset (default: `train`)
Adjust these parameters based on your hardware resources and training requirements. Other parameters can be specified in [config.yaml](./config.yaml). A minimal sketch of these configuration objects is shown below.
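For reference, a minimal sketch of how these objects might be assembled with the defaults above (import paths mirror the other tuner examples in this repository; the dataset path is a placeholder):
```python
from agentscope.tuner import AlgorithmConfig, DatasetConfig, TunerModelConfig

algorithm = AlgorithmConfig(
    algorithm_type="multi_step_grpo",
    group_size=16,       # rollouts per task for advantage estimation
    batch_size=32,
    learning_rate=1e-6,
)
model = TunerModelConfig(
    model_path="Qwen/Qwen2.5-3B-Instruct",
    max_model_len=25600,
    max_tokens=2048,
    inference_engine_num=6,
)
dataset = DatasetConfig(path="/path/to/frozenlake_dataset", split="train")
```
These objects are then passed to `tune` together with the workflow function, as in Step 4 above.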
### Step 3: Set Up Ray Cluster
Set up a [Ray](https://github.com/ray-project/ray) cluster:
```bash
ray start --head
# for multi-node setup, run the following command on worker nodes
# ray start --address=<master_address>
```
### Step 4: Run the Training Script
```bash
python main.py
```
The training will start and you can monitor the progress through the logs. Checkpoints will be saved once every `trainer.save_interval` steps.
## Experimental Results
### Training Reward Curve
The reward curve during training shows the agent's learning progress:
<div align="center">
<img src="./critic_rewards_mean.png" alt="reward" width="90%"/>
</div>
The training reward typically increases over epochs as the agent learns to navigate the frozen lake more effectively.
### Example Agent Output
An example of agent output is given below:
```
From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.
The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.
Let's take the action ```Left```.
```

View File

@@ -0,0 +1,252 @@
# 使用 AgentScope-Tuner 训练 FrozenLake Agent
## 摘要
本示例展示如何使用 AgentScope-Tuner 配合 [Trinity-RFT](https://github.com/agentscope-ai/Trinity-RFT) 对 [Frozen Lake](https://gymnasium.farama.org/environments/toy_text/frozen_lake/) 任务进行强化微调。智能体需要在冰湖网格中从起点走到终点,避开坑洞,并在有限步数内完成任务。
## 任务设定
### 智能体目标
智能体要在冰湖网格上从起点 (S) 抵达终点 (G),同时:
- 规划路径经过冰面 (F) 到达终点
- 避开会结束回合且奖励为 0 的坑洞 (H)
- 在限定步数内完成任务
### 智能体类型
智能体实现为 **ReActAgent**,它的行为包括:
- 观察当前冰湖网格状态
- 推理下一步最优动作
- 执行动作(上、下、左、右)在环境中移动
- 在多步交互中维护内部状态
### 环境
环境基于 Gymnasium 的 FrozenLake并提供
- **网格导航**:随机生成 2x2 至 6x6 的地图
- **格子类型**
- `S`:起点
- `F`:冰面(可通行)
- `H`:坑洞(奖励 0结束回合
- `G`:终点(奖励 +1.0,结束回合)
- **动作空间**:离散动作(上、下、左、右)
- **奖励设计**
- 到达终点 +1.0
- 掉入坑洞或未在最大步数内到达终点为 0.0
- **观测**:返回当前玩家位置的文本网格表示
智能体不使用外部工具,直接通过以下接口与环境交互(列表后附最小用法示例):
- `env.reset(task)`:根据任务参数初始化环境
- `env.step(action)`:执行动作,返回观测、奖励和结束标志
- `env.render()`:返回当前状态的文本表示
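下面给出一个基于上述接口的最小交互示例(假设使用本目录 [frozenlake_env.py](./frozenlake_env.py) 中的 `FrozenLakeEnv`;示例中固定的动作序列仅作演示,实际动作由智能体给出):
```python
from frozenlake_env import FrozenLakeEnv

# Build a 5x5 map deterministically from a seed, then roll out a few actions.
env = FrozenLakeEnv(max_steps=8, size=5, p=0.75, seed=12345)
observation, _ = env.reset()

for action in ["Down", "Right", "Down", "Right"]:  # placeholder policy
    observation, reward, done, info = env.step(action)
    print(observation)  # text grid: P player, _ frozen, O hole, G goal
    if done:
        print(f"Episode ended with reward {reward}")
        break
```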
## 数据集准备
数据集包含用于生成 FrozenLake 环境的任务参数,每个样本包含:
- `seed`:随机种子,保证地图可复现
- `size`:网格大小(在 2 和 `map_max_size` 之间随机,如 4x4、6x6
- `p`格子为冰面的概率0.6 到 0.85 之间随机),其余为坑洞
- `index`:样本索引
- `uid`:由 seed、size、p 组合而成的唯一 ID
运行数据准备脚本生成训练集与测试集:
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
生成的目录结构示例:
```
/path/to/frozenlake_dataset/
├── train.parquet # 10000 条训练样本
└── test.parquet # 100 条测试样本
```
样本示例:
```json
{"seed": 12345, "size": 5, "p": 0.75, "index": 0, "uid": "12345_5_0.75"}
```
**注意**:脚本会过滤无解的地图,确保在最大步数 (`env_max_steps=8`) 内存在从起点到终点的可行路径。
## 代码实现
本节提供代码实现的高级概览。详细实现请参考源代码。
### 高级概览
实现由三部分组成:
1. **Agent** (`FrozenLakeAgent`):继承 `ReActAgent`,负责多步交互
2. **环境** (`FrozenLakeEnv`):封装 Gymnasium FrozenLake
3. **工作流** (`run_frozen_lake`):组织智能体与环境的交互流程
### 工作流
`run_frozen_lake` 实现多步交互流程:
```python
async def run_frozen_lake(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
# ...
# 创建智能体和环境
agent = FrozenLakeAgent(model=model, ...)
env = FrozenLakeEnv(...)
observation, _ = env.reset(task)
rewards = []
# ...
# 智能体-环境交互循环
for _ in range(max_steps):
response = await agent.reply(msg=Msg("user", agent.get_prompt(observation), role="user"))
action = agent.get_action(response)
observation, reward, done, _ = env.step(action)
rewards.append(reward)
if done:
break
# ...
final_reward = sum(rewards)
final_response = Msg("assistant", response_content, role="assistant")
return WorkflowOutput(
reward=final_reward,
response=final_response,
metrics={"env_steps": float(step_count), "env_done": float(done)},
)
```
**关键特性:**
- 多步交互:单次 episode 内多次动作,不是单轮 QA
- 状态跟踪:记录当前步、上次动作与观测
- 错误处理:无效动作或异常会被捕获并处理
### 奖励函数
无需额外 judge奖励由环境直接给出
- 1.0:到达终点
- 0.0:掉入坑洞或超步数未达终点
工作流返回:
- `reward`:累计奖励
- `response`:包含观测、总奖励、步数、终止原因的最终回复
- `metrics``env_steps`(步数)、`env_done`(是否结束)
### 实现细节
环境 (`FrozenLakeEnv`) 封装了 Gymnasium 的 FrozenLake提供
- `reset(task)`: 使用任务参数初始化环境
- `step(action)`: 执行动作并返回 (observation, reward, done, info)
- `render()`: 返回当前状态的文本表示
智能体 (`FrozenLakeAgent`) 继承 `ReActAgent`,提供:
- `reply(msg)`: 回复消息并返回模型的响应消息(继承自 AgentScope)
- `get_prompt(observation)`: 从当前观测生成提示
- `get_action(response)`: 解析模型响应以提取动作Up/Down/Left/Right
- `update_state(action, observation)`: 在每步后更新内部状态
详细实现请参考 [frozenlake_env.py](./frozenlake_env.py) 和 [frozenlake_agent.py](./frozenlake_agent.py)。
### 步骤 4使用 `tune` 训练工作流
```python
from agentscope.tuner import tune, DatasetConfig
if __name__ == "__main__":
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
)
dataset = DatasetConfig(
path="/path/to/frozenlake_dataset",
name="default",
split="train",
)
tune(
workflow_func=run_frozen_lake,
train_dataset=dataset,
config_path=config_path,
)
```
训练配置请参考 [config.yaml](./config.yaml)。完整配置详情请参考 [Trinity-RFT 配置指南](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_configs.html)。
---
## 运行方法
### 依赖
- 至少 2 张 NVIDIA GPUCUDA 版本 ≥ 12.8
- 按 [Trinity-RFT 安装指南](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/trinity_installation.html) 从源码安装
- 安装 gymnasium 冰湖环境:
```bash
pip install gymnasium[toy_text]
```
- 下载模型权重(示例):
```bash
huggingface-cli download Qwen/Qwen2.5-3B-Instruct
```
### 步骤 1准备数据集
```bash
python get_frozenlake_data.py --map_max_size 6 --train_size 10000 --test_size 100
```
将 `main.py` 中的数据集路径改为你的生成目录。
### 步骤 2配置训练
关键配置可在代码中设置,包括:
**算法配置** (`AlgorithmConfig`)
- `algorithm_type`: `multi_step_grpo`(用于多步任务的组相对策略优化)
- `group_size`: 每个任务的采样(rollout)次数,用于组内相对优势估计(默认 16)
- `batch_size`: 批大小(默认 32
- `learning_rate`: 学习率(默认 1e-6
**模型配置** (`TunerModelConfig`)
- `model_path`: 基础模型路径(如 `Qwen/Qwen2.5-3B-Instruct`
- `max_model_len`: 最大上下文长度(默认 25600
- `max_tokens`: 响应最大生成长度(默认 2048
- `inference_engine_num`: 推理引擎数量(默认 6表示用 6 个 GPU 进行推理)
**数据集配置** (`DatasetConfig`)
- `path`: 数据集路径(默认 `/path/to/frozenlake`
- `split`: 数据集分片(默认 `train`
可根据硬件资源和训练需求调整这些参数。其他参数可在 [config.yaml](./config.yaml) 中指定。下面给出这些配置对象的最小示例。
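作为参考,下面是按上述默认值组装这些配置对象的最小示例草稿(导入方式与本仓库其他 tuner 示例一致,数据集路径为占位符);这些对象随后与工作流函数一起传入 `tune`(见上文步骤 4):
```python
from agentscope.tuner import AlgorithmConfig, DatasetConfig, TunerModelConfig

algorithm = AlgorithmConfig(
    algorithm_type="multi_step_grpo",
    group_size=16,       # rollouts per task for advantage estimation
    batch_size=32,
    learning_rate=1e-6,
)
model = TunerModelConfig(
    model_path="Qwen/Qwen2.5-3B-Instruct",
    max_model_len=25600,
    max_tokens=2048,
    inference_engine_num=6,
)
dataset = DatasetConfig(path="/path/to/frozenlake_dataset", split="train")
```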
### 步骤 3设置 Ray 集群
设置 [Ray](https://github.com/ray-project/ray) 集群:
```bash
ray start --head
# 对于多节点设置,在工作节点上运行以下命令
# ray start --address=<master_address>
```
### 步骤 4运行训练脚本
```bash
python main.py
```
训练将开始,可通过日志监控进度。检查点将每 `trainer.save_interval` 步保存一次。
## 实验结果
### 训练奖励曲线
训练过程中的奖励曲线显示智能体的学习进度:
<div align="center">
<img src="./critic_rewards_mean.png" alt="reward" width="90%"/>
</div>
训练奖励通常随着智能体学习更有效地导航冰湖而随训练轮次增加。
### 智能体输出示例
智能体输出示例如下:
```
From the current observation, let's analyze the situation. The player (P) is at: (4, 0), and the goal (G) is at: (2, 3). There is also a hole (O) at (4, 4). Given this, I can move towards the goal without worrying about slippery tiles right now.
The shortest path from P to G involves moving left (4 steps) followed by moving down (1 step), since going directly would bypass the hole or move us further from the goal. Let's move left first.
Let's take the action ```Left```.
```

View File

@@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import re
from _utils import SYSTEM_PROMPT, FrozenLakeAction # pylint: disable=E0611
from agentscope.agent import ReActAgent
from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg
from agentscope.model import OpenAIChatModel
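# Sentinel returned when no valid action can be parsed from the model's reply;
# the environment maps "still" to a no-op step.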
INVALID_ACTION = "still"
VALID_ACTIONS = {
"left": 1,
"down": 2,
"right": 3,
"up": 4,
}
class FrozenLakeAgent(ReActAgent):
"""Agent for FrozenLake environment."""
def __init__(self, model: OpenAIChatModel, max_steps: int = 20):
super().__init__(
name="frozenlake_agent",
model=model,
sys_prompt=SYSTEM_PROMPT,
formatter=OpenAIChatFormatter(),
max_iters=1,
)
self.response_structure = FrozenLakeAction
self.current_step = 0
self.last_action = None
self.last_observation = None
self.max_steps = max_steps
def get_prompt(self, observation: str) -> str:
"""Get prompt for the agent based on current observation."""
prompt = (
f"Current Observation ({self.current_step}): \n"
+ observation
+ "\n"
+ (
"You have not achieved the goal, P has not reached G yet. "
"Please give the next action."
)
)
if self.current_step > 0 and self.last_action is not None:
if self.last_observation == observation:
prompt += (
"\nYour last response is invalid. "
"Your position didn't change at all. "
"You may need to recheck your thinking process, "
"action outputted, and the format of response. "
"Remember, you should only output the NEXT ACTION "
"at each iteration in the ``` ```. "
"For example, if you want to move up, "
"you should output ```Up```."
)
if (
self.max_steps is not None
and self.max_steps - self.current_step > 0
):
remaining = self.max_steps - self.current_step
prompt += (
f"\nThe maximum number of steps remaining is {remaining}."
)
return prompt
def get_action(self, msg: Msg) -> str:
"""Extract action from agent response message."""
response: str = (
msg.content
if isinstance(msg.content, str)
else msg.content[0].get("text")
)
action = INVALID_ACTION
matches = re.findall(r"```(.*?)```", response, re.DOTALL)
if matches:
last_match_content = matches[-1].strip()
action = last_match_content.lower()
if action not in VALID_ACTIONS:
action = INVALID_ACTION
return action
def update_state(self, action: str, observation: str) -> None:
"""Update agent state with action and observation."""
self.last_action = action
self.last_observation = observation
self.current_step += 1
async def reset(self) -> None:
"""Reset agent state for a new episode."""
self.current_step = 0
self.last_action = None
self.last_observation = None
await self.memory.clear()

View File

@@ -0,0 +1,316 @@
# -*- coding: utf-8 -*-
"""Adapted from Trinity-RFT"""
import copy
from typing import Dict, Optional, Tuple, Union
import numpy as np
try:
from gymnasium.envs.toy_text.frozen_lake import (
FrozenLakeEnv as GymFrozenLakeEnv,
)
except ImportError:
GymFrozenLakeEnv = object
from _utils import ( # pylint: disable=E0611
generate_random_map,
get_goal_position,
)
class FrozenLakeEnv(GymFrozenLakeEnv):
"""FrozenLake environment wrapper."""
# Map gym state in integer
MAP_LOOKUP = {
b"P": 0,
b"F": 1,
b"H": 2,
b"G": 3,
}
# Define rules to transform to rendered text observation of the environment
GRID_LOOKUP = {
0: " P \t", # player
1: " _ \t", # frozen
2: " O \t", # hole
3: " G \t", # goal
4: " X \t", # player fall into hole
5: "\t", # player on goal
}
ACTION_LOOKUP = {
"still": 0,
"left": 1,
"down": 2,
"right": 3,
"up": 4,
}
INVALID_ACTION = 0
PENALTY_FOR_INVALID = -1
def __init__(
self,
max_steps: int = 8,
desc: Optional[str] = None,
is_slippery: bool = False,
size: int = 8,
p: float = 0.8,
seed: int = 42,
):
self.max_steps = max_steps or 8
self.desc: Union[str, np.ndarray, None] = desc
self.is_slippery = is_slippery
self.size = size
self.p = p
self.seed = seed
self.render_mode: Optional[str] = None
try:
import gymnasium as gym
except ImportError as e:
error_message = (
"Gymnasium is not installed. "
"Please install gymnasium first before "
"running the frozen_lake workflow. "
f"Error: {str(e)}"
)
raise ImportError(error_message) from e
if self.desc is None:
random_map, goal_position = generate_random_map(
size=self.size,
p=self.p,
seed=self.seed,
max_steps=self.max_steps,
)
else:
random_map = np.asarray(copy.deepcopy(self.desc), dtype="c")
goal_position = get_goal_position(random_map)
self.goal_position = goal_position
super().__init__(
desc=random_map[:],
is_slippery=self.is_slippery,
)
assert isinstance(self.desc, np.ndarray)
self.action_space = gym.spaces.Discrete(4, start=1)
self.map_kwargs = {
"size": size,
"p": p,
}
self.env_kwargs = {
"is_slippery": is_slippery,
"desc": copy.deepcopy(desc),
"seed": seed,
}
self.action_map = {
1: 0, # left
2: 1, # down
3: 2, # right
4: 3, # up
}
def _get_player_position(self) -> Tuple[int, int]:
return (self.s // self.ncol, self.s % self.ncol) # (row, col)
def step(self, action: str) -> Tuple[str, float, bool, Dict]:
"""Execute a step in the environment.
Maps custom action to gymnasium FrozenLakeEnv action and
takes the step. Checks if the action is effective (whether
player moves in the env).
Args:
action: The action to take.
Returns:
Tuple of (observation, reward, done, info).
"""
if self.success():
obs = self.render(mode="tiny_rgb_array")
assert isinstance(obs, str)
return obs, 1.0, True, {"action_is_effective": False}
action_id: int = self.ACTION_LOOKUP.get(action.lower(), 0)
if not action_id:
action_id = self.INVALID_ACTION
if (
action_id == self.INVALID_ACTION
or action_id not in self.action_map
):
obs = self.render(mode="tiny_rgb_array")
assert isinstance(obs, str)
return obs, 0.0, False, {"action_is_effective": False}
prev_player_position = int(self.s)
# Call parent class step method
# Note: GymFrozenLakeEnv is imported at module level
player_pos, reward, done, _, _ = super().step(
self.action_map[action_id],
)
obs = self.render(mode="tiny_rgb_array")
assert isinstance(obs, str)
return (
obs,
float(reward),
bool(done),
{"action_is_effective": prev_player_position != int(player_pos)},
)
def render(
self,
mode: str = "tiny_rgb_array",
) -> str | list[str] | np.ndarray:
"""Render the environment.
Args:
mode: Rendering mode. Options: "tiny_rgb_array", "list",
"state", "rgb_array", "ansi".
Returns:
Rendered observation based on the mode.
"""
assert mode in [
"tiny_rgb_array",
"list",
"state",
"rgb_array",
"ansi",
]
if mode in ["rgb_array", "ansi"]:
prev_render_mode = self.render_mode
self.render_mode = mode
obs = super().render()
self.render_mode = prev_render_mode
return obs
assert isinstance(self.desc, np.ndarray)
room_state = copy.deepcopy(self.desc)
# replace the position of start 'S' with 'F'
position_S = np.where(room_state == b"S")
room_state[position_S] = b"F"
# replace the position of the player with 'P'
position_P = self._get_player_position()
room_state[position_P] = b"P"
if mode == "state":
# transform 'S', 'F', 'H', 'G' to numpy integer array
room_state = np.vectorize(lambda x: self.MAP_LOOKUP[x])(room_state)
# add player in hole or player on goal
if self.desc[position_P] == b"H":
room_state[position_P] = 4
elif self.desc[position_P] == b"G":
room_state[position_P] = 5
return room_state
room_state = self.render(mode="state").tolist()
assert isinstance(room_state, list)
if mode == "list":
def lookup_list(cell: int) -> str:
return self.GRID_LOOKUP.get(cell, "?").strip("\t").strip()
return [
" ".join(lookup_list(cell) for cell in row)
for row in room_state
]
if mode == "tiny_rgb_array":
def lookup_tiny(cell: int) -> str:
return self.GRID_LOOKUP.get(cell, "?")
result = "\n".join(
"".join(lookup_tiny(cell) for cell in row)
for row in room_state
)
return result
# Default return for other modes
return ""
def reset(
self,
task: Optional[Dict] = None,
) -> tuple[str, Dict]:
"""Reset the environment with optional task parameters."""
task = task or {}
# Update parameters from task if provided
size = task.get("size", self.map_kwargs["size"])
p = task.get("p", self.map_kwargs["p"])
seed = task.get("seed", self.env_kwargs["seed"])
is_slippery = task.get(
"is_slippery",
self.env_kwargs["is_slippery"],
)
desc = task.get("desc", self.env_kwargs.get("desc"))
# Update instance variables
self.size = size
self.p = p
self.seed = seed
self.is_slippery = is_slippery
self.map_kwargs["size"] = size
self.map_kwargs["p"] = p
self.env_kwargs["seed"] = seed
self.env_kwargs["is_slippery"] = is_slippery
if desc is not None:
self.env_kwargs["desc"] = copy.deepcopy(desc)
if desc is None:
random_map, goal_position = generate_random_map(
size=size,
p=p,
seed=seed,
max_steps=self.max_steps,
)
else:
random_map = np.asarray(copy.deepcopy(desc), dtype="c")
goal_position = get_goal_position(random_map)
self.goal_position = goal_position
self.desc = random_map[:]
# Reinitialize parent class with new map
try:
import gymnasium as gym
super().__init__(
desc=random_map[:],
is_slippery=self.is_slippery,
)
assert isinstance(self.desc, np.ndarray)
self.action_space = gym.spaces.Discrete(4, start=1)
except ImportError as e:
error_message = (
"Gymnasium is not installed. "
"Please install gymnasium first before "
"running the frozen_lake workflow. "
f"Error: {str(e)}"
)
raise ImportError(error_message) from e
super().reset(seed=self.seed)
obs = self.render(mode="tiny_rgb_array")
assert isinstance(obs, str)
return obs, {}
def finished(self) -> bool:
"""Check if the episode is finished (goal or hole)."""
player_pos = self._get_player_position()
assert isinstance(self.desc, np.ndarray)
return self.desc[player_pos] in b"GH" # type: ignore
def success(self) -> bool:
"""Check if the agent has reached the goal (G)."""
player_pos = self._get_player_position()
assert isinstance(self.desc, np.ndarray)
return self.desc[player_pos] in b"G"

209 tuner/frozen_lake/_utils.py Normal file
View File

@@ -0,0 +1,209 @@
# -*- coding: utf-8 -*-
"""
Utils for the FrozenLake environment.
Modified from rllm
"""
from typing import Literal, Optional, Tuple
import numpy as np
from pydantic import BaseModel, Field
# Map gym state in integer
MAP_LOOKUP = {
b"P": 0,
b"F": 1,
b"H": 2,
b"G": 3,
}
# Define rules to transform to rendered text observation of the environment
GRID_LOOKUP = {
0: " P \t", # player
1: " _ \t", # frozen
2: " O \t", # hole
3: " G \t", # goal
4: " X \t", # player fall into hole
5: "\t", # player on goal
}
ACTION_LOOKUP = {
0: "None",
1: "Left",
2: "Down",
3: "Right",
4: "Up",
}
# Prompting format inspired by the RAGEN project
SYSTEM_PROMPT = """You are Qwen, created by Alibaba Cloud. \
You are a helpful assistant. You are walking on a frozen lake.
FrozenLake Quick Guide
Goal: Reach the goal (G). Player (P) and Goal (G) must overlap.
Symbols:
_ Frozen | O Hole | G Goal | P Player
Rules:
1. Avoid falling into holes (O).
2. Frozen tiles are slippery, you may move perpendicular to
your intended direction.
Valid Action (separated by | ):
Up | Down | Left | Right
Rewards:
Fall into hole: 0
Reach goal: +1.0
You will be provided the current observation, please decide on
the next Action.
You should show your thought process and then input the final
action in ``` ```.
You should only output the NEXT ACTION at each iteration in
the ``` ```. For example, if you want to move up, you should
output ```Up```.
You should plan ahead and aim to reach the goal in the minimum
number of steps.
You should be aware that frozen tiles can be slippery, but the
chance is small and you should not overthink it.
Please show your thinking process and put the final action in
``` ```. In every turn, the final action MUST be one of Up,
Down, Left, Right.
"""
class FrozenLakeAction(BaseModel):
"""Action model for FrozenLake environment."""
action: Literal["Up", "Down", "Left", "Right"] = Field(
description=(
"The action to take in the FrozenLake environment, "
"must be one of Up, Down, Left, Right"
),
)
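# Example: FrozenLakeAction(action="Up") validates, while any other
# string (e.g. "Stay") raises a pydantic ValidationError.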
def is_valid(board: list[list[str]], max_size: int, max_steps: int) -> bool:
"""DFS to check that it's a valid path.
Args:
board: The board representation as a list of lists.
max_size: Maximum size of the board.
max_steps: Maximum number of steps allowed.
Returns:
True if there's a valid path from start to goal within max_steps,
False otherwise.
"""
frontier, discovered = [], set()
# find the start point
start_r, start_c = np.where(np.array(board) == "S")
frontier.append((start_r[0], start_c[0], 0)) # row, col, steps
# dfs to check if there is a path from start to goal
while frontier:
r, c, steps = frontier.pop()
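# LIFO pop gives depth-first traversal; paths longer than
# max_steps are pruned below.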
if steps > max_steps:
continue
if (r, c) not in discovered:
discovered.add((r, c))
directions = [(1, 0), (0, 1), (-1, 0), (0, -1)]
for x, y in directions:
r_new = r + x
c_new = c + y
if (
r_new < 0
or r_new >= max_size
or c_new < 0
or c_new >= max_size
): # noqa: PLR2004
continue
if board[r_new][c_new] == "G":
return True
if board[r_new][c_new] != "H":
frontier.append((r_new, c_new, steps + 1))
return False
def generate_random_map(
size: int = 8,
p: float = 0.8,
seed: int = 0,
max_steps: int = 5,
) -> Tuple[list[str], Tuple[int, int]]:
"""Generates a random valid map (one that has a path from start to goal).
Args:
size: Size of each side of the grid.
p: Probability that a tile is frozen.
seed: Seed to ensure the generation of reproducible maps.
max_steps: Maximum number of steps allowed.
Returns:
A tuple containing a random valid map and the goal position (row, col).
"""
valid = False
board: list[list[str]] = [] # initialize to make pyright happy
try:
from gymnasium.utils import seeding
np_random, _ = seeding.np_random(seed)
except ImportError as exc:
raise ImportError(
"Gymnasium is not installed. "
"Please install gymnasium first before "
"running the frozen_lake workflow.",
) from exc
# generate random start and end points
while not valid:
p = min(1, p)
board = np_random.choice(
["F", "H"],
(size, size),
p=[p, 1 - p],
).tolist()
while True:
start_r = int(np_random.integers(0, size))
start_c = int(np_random.integers(0, size))
goal_r = int(np_random.integers(0, size))
goal_c = int(np_random.integers(0, size))
# Ensure start and goal are different positions
if (start_r, start_c) != (goal_r, goal_c):
break
board[start_r][start_c] = "S"
board[goal_r][goal_c] = "G"
valid = is_valid(board, size, max_steps)
return ["".join(x) for x in board], (goal_r, goal_c)
def get_goal_position(
random_map: np.ndarray,
) -> Optional[Tuple[int, int]]:
"""Get the goal position from a random map.
Args:
random_map: The map as a numpy array.
Returns:
Tuple of (row, col) if goal found, None otherwise.
"""
positions = np.argwhere(random_map == b"G")
if positions.size == 0:
return None # G not found
return tuple(positions[0]) # returns (row, col)
__all__ = [
"SYSTEM_PROMPT",
"FrozenLakeAction",
"generate_random_map",
"get_goal_position",
]

View File

@@ -0,0 +1,53 @@
project: "AgentScope" # Project name
name: "FrozenLake" # Experiment name
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints} # Directory to save model checkpoints
algorithm:
algorithm_type: multi_step_grpo # GRPO series for multi-step scenario
repeat_times: 16 # Number of rollouts per prompt for advantage estimation
kl_loss_fn: "low_var_kl"
kl_loss_fn_args:
kl_coef: 0 # KL divergence coefficient
advantage_fn_args:
epsilon: 1e-6 # Small value for numerical stability
std_threshold: 0.0001 # Threshold for standard deviation
optimizer:
lr: 1e-6 # Learning rate
model:
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-3B-Instruct} # Base model path
max_prompt_tokens: 23552 # Max tokens for prompt
max_response_tokens: 2048 # Max tokens per response
max_model_len: 25600 # Max context length
temperature: 1.0 # Sampling temperature
buffer:
total_epochs: 5 # Total training epochs
batch_size: 32 # Batch size per explore step
train_batch_size: 1024 # Total experiences per training step
trainer_input:
experience_buffer:
name: experience_buffer
storage_type: queue
max_read_timeout: 7200 # Max timeout for reading from buffer (seconds)
replay_buffer:
enable: true # Enable experience replay
priority_fn: linear_decay # Priority function for replay buffer
priority_fn_args:
decay: 0.1 # Decay rate for priority function
explorer:
runner_per_model: 16 # Number of runners per model
rollout_model:
engine_num: 6 # Number of vLLM engines for rollout model
tensor_parallel_size: 1 # TP size per engine for rollout model
enable_openai_api: true # Enable OpenAI-compatible API
enable_history: true # Enable conversation history
enable_auto_tool_choice: true # Enable automatic tool selection
tool_call_parser: hermes # Parser for tool calls
trainer:
save_interval: 100 # Save checkpoint every N steps
use_dynamic_bsz: true # Use dynamic batch size
grad_clip: 1.0 # Gradient clipping value
max_token_len_per_gpu: 25600 # Max token length per GPU
ulysses_sequence_parallel_size: 2 # Sequence parallel size for Ulysses
synchronizer:
sync_style: dynamic_by_explorer # Sync triggered dynamically by explorer
sync_interval: 1 # Sync every N steps
sync_timeout: 1200 # Timeout for synchronization (seconds)

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

View File

@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""
Modified from rllm
"""
import argparse
import os
import numpy as np
import pandas as pd
DEFAULT_DATA_PATH = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"data",
"frozenlake",
)
def save_dataset_to_local(
data_path: str,
data: list[dict],
split: str = "default",
) -> str:
"""Save dataset directly to local data_path.
Args:
data_path: Path to save the dataset
data: List of dictionaries containing the dataset examples
split: Split name (e.g., 'train', 'test', 'default')
Returns:
str: Path to the saved parquet file
"""
os.makedirs(data_path, exist_ok=True)
# Convert to DataFrame and save
data_df = pd.DataFrame(data)
dataset_path = os.path.join(data_path, f"{split}.parquet")
data_df.to_parquet(dataset_path)
print(
f"Saved dataset frozenlake split '{split}' "
f"with {len(data)} examples at {dataset_path}. "
f"Make sure to set the environment variable "
f"<TRINITY_TASKSET_PATH> to {data_path}.",
)
return dataset_path
def prepare_frozenlake_data(
data_path: str,
train_size: int = 10000,
test_size: int = 100,
map_max_size: int = 6,
) -> tuple[list[dict], list[dict]]:
"""
Prepare and save FrozenLake datasets for training and testing.
Args:
data_path (str): Path to save the dataset
train_size (int): Number of training examples to generate
test_size (int): Number of test examples to generate
map_max_size (int): Maximum size of the map
Returns:
tuple: (train_data, test_data) - Lists of data dictionaries
"""
# Set random seed for reproducibility
np.random.seed(42)
# Generate random parameters for train and test sets
train_seeds = np.random.randint(0, 100000, size=train_size)
test_seeds = np.random.randint(0, 100000, size=test_size)
train_sizes = np.random.randint(2, map_max_size, size=train_size)
test_sizes = np.random.randint(2, map_max_size, size=test_size)
train_ps = np.random.uniform(0.6, 0.85, size=train_size)
test_ps = np.random.uniform(0.6, 0.85, size=test_size)
def frozenlake_process_fn(
seed: int,
size: int,
p: float,
idx: int,
) -> dict:
"""Process function to create FrozenLake task instances."""
return {
"seed": seed,
"size": size,
"p": p,
"index": idx,
"uid": f"{seed}_{size}_{p}",
}
# Create train and test data
train_data_list = [
frozenlake_process_fn(seed, train_sizes[idx], train_ps[idx], idx)
for idx, seed in enumerate(train_seeds)
]
test_data_list = [
frozenlake_process_fn(seed, test_sizes[idx], test_ps[idx], idx)
for idx, seed in enumerate(test_seeds)
]
# Save datasets directly to local DATA_PATH
save_dataset_to_local(data_path, train_data_list, "train")
save_dataset_to_local(data_path, test_data_list, "test")
return train_data_list, test_data_list
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dir", default=DEFAULT_DATA_PATH)
parser.add_argument("--train_size", type=int, default=10000)
parser.add_argument("--test_size", type=int, default=100)
parser.add_argument("--map_max_size", type=int, default=6)
args = parser.parse_args()
train_data, test_data = prepare_frozenlake_data(
data_path=args.local_dir,
train_size=args.train_size,
test_size=args.test_size,
map_max_size=args.map_max_size,
)
print(f"Train dataset: {len(train_data)} examples")
print(f"Test dataset: {len(test_data)} examples")
print("Sample train example:", train_data[0])
print("Sample test example:", test_data[0])

151
tuner/frozen_lake/main.py Normal file
View File

@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
"""Example of training a FrozenLake agent with Trinity-RFT."""
import os
from typing import Dict
from _frozenlake_agent import FrozenLakeAgent
from _frozenlake_env import FrozenLakeEnv
from agentscope.message import Msg
from agentscope.tuner import (
tune,
WorkflowOutput,
DatasetConfig,
TunerModelConfig,
AlgorithmConfig,
)
from agentscope.model import ChatModelBase
async def run_frozen_lake(
task: Dict,
model: ChatModelBase,
auxiliary_models: Dict[str, ChatModelBase],
) -> WorkflowOutput:
"""A workflow function using the FrozenLake agent to solve tasks.
Args:
task (Dict): The task to be solved, containing environment parameters
like size, p, seed, is_slippery, etc.
model (ChatModelBase): The language model to use.
auxiliary_models (Dict[str, ChatModelBase]): Auxiliary models; this
workflow expects an empty dict.
Returns:
WorkflowOutput: The workflow output containing the reward, response and
metrics.
"""
assert len(auxiliary_models) == 0, "No auxiliary models are needed"
# Extract workflow arguments from task or use defaults
workflow_args = task.get("workflow_args", {})
if not workflow_args:
workflow_args = task
env_max_steps = workflow_args.get("env_max_steps", 8)
agent_max_steps = workflow_args.get("agent_max_steps", 10)
is_slippery = workflow_args.get("is_slippery", False)
desc = workflow_args.get("desc", None)
# Extract task-specific arguments (for environment generation)
size = task.get("size", 8)
p = task.get("p", 0.8)
seed = task.get("seed", 42)
# Initialize agent and environment
agent = FrozenLakeAgent(model=model, max_steps=agent_max_steps)
env = FrozenLakeEnv(
max_steps=env_max_steps,
desc=desc,
is_slippery=is_slippery,
size=size,
p=p,
seed=seed,
)
# Reset environment with task parameters
observation, _ = env.reset(task)
observation_str = str(observation)
rewards = []
step_count = 0
done = False
terminate_reason = None
# Run agent-environment interaction loop
for _ in range(agent_max_steps):
step_count += 1
try:
# get prompt
prompt = agent.get_prompt(observation_str)
response = await agent.reply(msg=Msg("user", prompt, role="user"))
# record action and observation
action = agent.get_action(response)
agent.update_state(action=action, observation=observation_str)
except Exception as e:
terminate_reason = f"agent_error: {str(e)}"
break
# environment step
observation, reward, done, _ = env.step(action)
observation_str = str(observation)
rewards.append(reward)
if done:
terminate_reason = "success" if env.success() else "hole"
break
if terminate_reason is None:
terminate_reason = "max_steps_reached"
final_reward = sum(rewards)
final_observation = observation_str
# Create response message with environment information
response_content = (
f"Final observation:\n{final_observation}\n"
f"Total reward: {final_reward}\n"
f"Steps taken: {step_count}\n"
f"Terminate reason: {terminate_reason}"
)
final_response = Msg("assistant", response_content, role="assistant")
return WorkflowOutput(
reward=final_reward,
response=final_response,
metrics={
"env_steps": float(step_count),
"env_done": float(done),
},
)
if __name__ == "__main__":
dataset = DatasetConfig(
path="/path/to/frozenlake",
split="train",
)
tuner_model = TunerModelConfig(
model_path="Qwen/Qwen2.5-3B-Instruct",
max_model_len=25600,
max_tokens=2048,
inference_engine_num=6,
reasoning_parser=None,
)
algorithm = AlgorithmConfig(
algorithm_type="multi_step_grpo",
group_size=16,
batch_size=32,
learning_rate=1e-6,
)
config_path = os.path.join(
os.path.dirname(__file__),
"config.yaml",
) # define some default parameters
tune(
workflow_func=run_frozen_lake,
model=tuner_model,
train_dataset=dataset,
algorithm=algorithm,
config_path=config_path,
)

View File

@@ -0,0 +1,393 @@
# Training Learn2Ask with AgentScope-Tuner
This guide demonstrates how to train a proactive LLM using the **Learn2Ask** framework from [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441).
---
## Task Setting
In this example, given a user's chief complaint, the medical assistant agent proactively asks targeted questions to gather sufficient symptom information for a comprehensive assessment of the user's health condition. The querying process should be efficient: the agent must optimize question quality, and terminate the interview as soon as the collected information is adequate for subsequent clinical assessment or decision-making.
Here we use the `ReActAgent` for this task and no tools are required.
---
## Hardware Requirements
- **Training with GPUs**: At least **8 H20 GPUs** (or equivalent).
- **Training without GPUs**: You can use **[Tinker](https://thinkingmachines.ai/tinker/)**, which requires no GPUs.
> 💡 All code and configuration files are located in:
> `tuner/learn_to_ask/`
Key files:
- Workflow & Training: `tuner/learn_to_ask/main.py`
- Prompts: `tuner/learn_to_ask/prompt.py`
- Training config: `tuner/learn_to_ask/config.yaml`
- Data preparation scripts: `tuner/learn_to_ask/data_prepare/`
---
## Dataset Preparation
> [!NOTE]
> In this example, we use an open-source dataset directly for training. In practice, however, you would typically start by collecting interaction logs between your deployed agent and users. After filtering these raw logs to curate a high-quality dataset, you can follow the same pipeline to enhance your agent's proactive capabilities using AgentScope-Tuner. Happy tuning!
### 1.1 Download the Dataset
Download the **[RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv)** dataset (in `.jsonl` format).
You can use the following python scripts to download the dataset:
```python
from huggingface_hub import snapshot_download
# Download to local directory, e.g., `./tuner/learn_to_ask/data`
local_dir = "./tuner/learn_to_ask/data"
snapshot_download(
repo_id="datajuicer/RealMedConv",
repo_type="dataset",
local_dir=local_dir,
)
```
Each line in `train_origin.jsonl` (or `test_origin.jsonl`) represents a complete doctor-patient conversation log, like this:
```json
{
"session_id": 35310,
"diagn": "Upper Respiratory Tract Infection",
"messages": [
{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"},
{"role": "user", "content": "I took Amoxicillin"},
...
{"role": "assistant", "content": "<med_search>"}
]
}
```
### 1.2 Preprocess the Data
You need to convert raw conversation logs into training samples. This involves two steps:
#### 🔹 Step A: Segment Conversations & Extract Labels
Split each conversation into **context-future pairs**, and extract ground-truth symptom information (`info_truth`) from what happens next.
```bash
python tuner/learn_to_ask/data_prepare/1_info_extract_pipeline.py \
--input_file /path/to/RealMedConv/train.jsonl \
--output_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
--model_path Qwen/Qwen2.5-32B-Instruct
```
#### 🔹 Step B: Build Final Training Dataset
Convert the processed samples into the final format used for training/testing.
```bash
python tuner/learn_to_ask/data_prepare/2_build_dataset.py \
--input_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
--output_file tuner/learn_to_ask/data/train.jsonl
```
---
### How It Works: Context-Future Segmentation
For every turn in a conversation, we create a sample with:
- `messages`: The **observed dialogue history** up to that point (the *context*).
- `remaining_chat`: Everything that happens **after** that point (the *future*).
- A unique ID: `cid = {session_id}_{turn_index}`
Example output:
```json
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [ ... up to turn 7 ... ],
"remaining_chat": [ ... all future messages ... ]
}
```
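Conceptually, the split is a sliding cut over the message list. Below is a minimal sketch of the idea (illustrative only; the shipped implementation is `split_session_to_json_lines` in the `data_prepare/` scripts, which may differ in details such as how `remaining_chat` is serialized):

```python
import json

def split_session(session: dict) -> list:
    """Sketch of context-future segmentation (hypothetical helper)."""
    samples = []
    messages = session["messages"]
    for turn_index in range(1, len(messages)):
        sample = {
            "cid": f"{session['session_id']}_{turn_index}",
            "session_id": str(session["session_id"]),
            "diagn": session["diagn"],
            "messages": messages[:turn_index],        # observed context
            "remaining_chat": messages[turn_index:],  # unobserved future
        }
        samples.append(json.dumps(sample, ensure_ascii=False))
    return samples
```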
### Extract Ground-Truth Labels
From `remaining_chat`, we automatically derive two key labels:
- `decision_truth`: Should the assistant **continue asking questions** (`"continue"`) or **stop** (`"stop"`)?
- `info_truth`: Structured list of symptoms mentioned later (used to compute reward signals during training).
Example:
```json
{
"decision_truth": "continue",
"info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, ..."
}
```
These labels power the reward functions $R_a$ (action accuracy) and $R_s$ (symptom coverage) during training.
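The decision label itself follows a simple rule that mirrors `process_message` in `2_build_dataset.py`: if the future contains no further patient turn, the correct action was to stop.

```python
def derive_decision_truth(remaining_chat: str) -> str:
    # No future "user: " turn means the interview was effectively over.
    return "stop" if "user: " not in remaining_chat else "continue"
```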
---
## Code Implementation
### Agent Workflow
The workflow function `run_react_agent` implements how the `ReActAgent` works.
```python
async def run_react_agent(
task: Dict,
model: OpenAIChatModel,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> WorkflowOutput:
assert (
len(auxiliary_models) == 1
), "Please provide only one `auxiliary_models` for `learn_to_ask`."
import importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "prompt.py"),
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if TRAIN_MODE == "Ra":
sys_prompt = module.rollout_prompt_med_Ra
else:
sys_prompt = module.rollout_prompt_med
agent = ReActAgent(
name="react_agent",
sys_prompt=sys_prompt,
model=model,
formatter=OpenAIChatFormatter(),
toolkit=None,
memory=InMemoryMemory(),
max_iters=1,
)
messages = format_messages(task["messages"])
response = await agent.reply(
[
Msg(name=x["role"], content=x["content"], role=x["role"])
for x in messages
],
)
return WorkflowOutput(
response=response,
)
```
### Judge Function
The judge function `learn2ask_judge` implements reward calculation using LLM-as-a-Judge:
```python
async def learn2ask_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> JudgeOutput:
response_text = response.get_text_content()
action_truth = task.get("decision_truth", "continue")
action_response = "stop" if "<stop />" in response_text else "continue"
# Calculate action accuracy score
action_score = 1.0 if action_truth == action_response else 0.0
# Calculate format and content scores
if action_score == 1.0 and action_truth == "continue":
# Use LLM-as-a-Judge to evaluate question quality
score_dict = await llm_reward(task, response_text, auxiliary_models)
format_score = float(score_dict.get("format_score", 0.0))
content_score = float(score_dict.get("content_score", 0.0))
elif action_score == 1.0: # stop action
content_score, format_score = 1.0, (1.0 if response_text == "<stop />" else 0.0)
else:
format_score = content_score = 0.0
# Combine final reward based on training mode
if TRAIN_MODE == "Ra+Rs": # Default: action + symptom rewards
final_reward = action_score * (1 + 2 * content_score) + format_score
elif TRAIN_MODE == "Ra": # Action reward only
final_reward = 2 * content_score + format_score
else: # Symptom reward only
final_reward = action_score * 3 + format_score
return JudgeOutput(reward=final_reward, metrics={"reward": final_reward})
```
This reward function considers:
- Action accuracy: `action_score`
- Question quality (Symptom coverage): `content_score`
- Format score: `format_score`
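For example, under the default `Ra+Rs` fusion, a correct `continue` decision with `content_score = 0.5` and `format_score = 1.0` yields a reward of `1.0 * (1 + 2 * 0.5) + 1.0 = 3.0`.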
See [main.py](./main.py) for implementation details.
---
## Configure and Train the Model
### Option A: Edit Python Script (Simple)
Open `tuner/learn_to_ask/main.py` and adjust settings:
```python
if __name__ == "__main__":
train_mode = "Ra+Rs" # Use both action and symptom rewards
fusion_mode = "default" # How to combine rewards
dataset = DatasetConfig(path="tuner/learn_to_ask/data", split="train")
tuner_model = OpenAIChatModel(
model_path="Qwen/Qwen2.5-7B-Instruct",
max_model_len=8192,
tensor_parallel_size=1, # Adjust based on your GPU setup
...
)
auxiliary_models = {
AUXILIARY_MODEL_NAME: OpenAIChatModel(
model_path="Qwen/Qwen2.5-32B-Instruct", # Larger model for evaluation
tensor_parallel_size=2,
...
)
}
algorithm = AlgorithmConfig(
algorithm_type="grpo",
learning_rate=5e-7,
batch_size=64,
)
tune(...) # Starts training
```
### Option B: Use YAML Config (Advanced)
Edit `tuner/learn_to_ask/config.yaml` for more control.
#### 🌐 No GPU? Use Tinker!
If you don't have GPUs, enable the **Tinker backend** by setting:
```yaml
model:
tinker:
enable: true # ← Set this to true
```
Also, make sure to update the `model_path` in `tuner/learn_to_ask/main.py` to point to a model that's compatible with Tinker.
> 🔗 Learn more about Tinker Backend: [Tinker Backend Documentation](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/example_tinker_backend.html)
> In the provided example, training is configured for 4 epochs. When using Tinker, total token consumption is approximately 112 million tokens, for an estimated cost of about 18 USD.
### Launch Training
```bash
python tuner/learn_to_ask/main.py
```
---
## Evaluation
Use the **rollout-and-evaluate pipeline**:
1. Generate responses on the test set.
2. Score them using a powerful evaluator model (`Qwen2.5-32B-Instruct`).
Run evaluation:
```bash
python tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py \
--eval_model_path path/to/your/trained/model \
--grader_model_path Qwen/Qwen2.5-32B-Instruct \
--test_file_path tuner/learn_to_ask/data/test.jsonl \
--rollout_file_path path/to/rollout.jsonl \
--eval_file_path path/to/output.jsonl
```
> ⚠️ **Note**: Your trained model must be converted to **Hugging Face format** first.
> See: [Converting FSDP Checkpoints Guide](https://agentscope-ai.github.io/Trinity-RFT/en/main/tutorial/faq.html)
---
## Experimental Results
We compared three approaches:
- **Base model**: `Qwen2.5-7B-Instruct` (no fine-tuning)
- **Trinity**: Direct response generation
- **AgentScope-Tuner (Learn2Ask)**: Uses a ReAct agent for proactive questioning
| Metric | Base Model | Trinity | AgentScope-Tuner (Learn2Ask) |
|--------------------------------------|-----------:|--------:|--------------------:|
| Avg. continue content | 0.436 | 0.496 | 0.509 |
| Win rate (continue content) | 0.122 | 0.246 | 0.224 |
| Avg. continue decision accuracy | 0.963 | 0.909 | 0.922 |
| Avg. stop decision accuracy | 0.024 | 0.927 | 0.909 |
| **Total decision accuracy** | 0.467 | 0.918 | 0.915 |
| **Response format quality** | 0.376 | 0.713 | 0.882 |
| **Total reward** | 1.281 | 3.078 | 3.237 |
<div align="center">
<img src="./learn2ask.png" alt="Training Curves" width="90%"/>
</div>
> ✅ **Key insight**: Learn2Ask (AgentScope-Tuner) achieves the highest overall performance by teaching the model **when and what to ask**—making it truly proactive.
### Concrete Example
An example of the agent's behavior on evaluation data is shown below:
```json
{
"cid": "206_1",
"session_id": 206,
"diagn": "Common cold",
"messages": [
{
"role": "assistant",
"content": "Hello, currently there are many people consulting, please wait patiently."
},
{
"role": "user",
"content": "Hello, this morning around 10 o'clock, I suddenly felt a bit feverish and had a heavy head. I took my temperature and it was 38.5°C. I took one tablet of paracetamol at 1 o'clock, which helped a lot. By 6 or 7 in the evening, I felt cold and feverish again, so I took another tablet of paracetamol. I'm not sure if it's a side effect of the medication, but I feel a bit nauseous and my stomach keeps producing gas. Now, I feel cold again, have a stuffy nose with clear runny nose, an itchy throat with coughing, a heavy head, and can't sleep."
}
],
"decision_truth": "continue",
"info_truth": "symptom: needle-like sensation, symptom location: ear, symptom nature: sharp pain, symptom severity: occurred twice",
"rollouts": [
"Do you also experience body aches along with your symptoms? \nA) Yes B) No C) Not Sure",
"Do you experience these symptoms every time after taking paracetamol? (A) Yes, definitely. (B) Sometimes. (C) No, never before.",
"Do you have a history of allergies? A) Yes B) No C) Not Sure"
],
"grades": [
{
"think": "The doctor's last message asks about body aches using multiple choice options rather than a single direct open-ended question, which deviates from the typical questioning style indicated by the grading policy. However, since it is framed in the form of \"Do you...\", it counts as one question despite being presented with choices. Thus, it satisfies the criteria for 'contains exactly one question'. For the content score, the query does not pertain directly to the details listed in the reference information such as the specific symptom ('needle-like' sharp pain in the ear that has occurred twice), thereby making it highly relevant yet indirectly related due to its focus on comprehensive assessment of flu-related discomforts without addressing specifics noted in the patients primary concern documented earlier.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
},
{
"think": "The doctor's last message includes just one multiple-choice question regarding whether the patient experiences those mentioned symptoms each time they take paracetamol. This does relate highly to understanding possible drug-related symptoms; however, none of them aligns perfectly with \"needle-like\" sensations occurring specifically in ears according to the reference information given.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
},
{
"think": "The doctors last statement does contain just one question pertaining to allergy history, which is highly relevant when trying to diagnose symptoms such as those described by the patient (fever, nausea). However, none of these concerns specifically relate back to the reference information detailing \"needle-like sensation\", \"sharp pain\" related to the ears occurring twice. Therefore, while highly pertinent medically, they do not pertain to the exact points outlined in the Ref Info section about the patient experience according to that specific prompt context.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
}
]
}
```
---
## 📚 Citation
If you use this code or framework, please cite our work:
```bibtex
@misc{learn2ask,
title={Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs},
author={Fei Wei and Daoyuan Chen and Ce Wang and Yilun Huang and Yushuo Chen and Xuchen Pan and Yaliang Li and Bolin Ding},
year={2025},
eprint={2510.25441},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2510.25441}
}
```

View File

@@ -0,0 +1,391 @@
# 使用 AgentScope-Tuner 训练 Learn2Ask
本指南演示了如何使用来自 [Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs](https://arxiv.org/abs/2510.25441) 的 **Learn2Ask** 方法训练主动式 LLM。
---
## 任务设置
在此示例中,给定用户的主诉,医疗助手智能体主动提出有针对性的问题,以收集足够的症状信息,从而全面评估用户的健康状况。查询过程应该高效:智能体必须优化问题质量,并在收集的信息足以进行后续临床评估或决策时立即终止访谈。
这里我们使用 `ReActAgent` 来完成此任务,不需要工具。
---
## 硬件要求
- **使用 GPU 训练**:至少需要 **8 个 H20 GPU**(或同等配置)。
- **不使用 GPU 训练**:您可以使用 **[Tinker](https://thinkingmachines.ai/tinker/)**,无需任何 GPU。
> 💡 所有代码和配置文件位于:
> `tuner/learn_to_ask/`
关键文件:
- 工作流和训练:`tuner/learn_to_ask/main.py`
- 提示词:`tuner/learn_to_ask/prompt.py`
- 训练配置:`tuner/learn_to_ask/config.yaml`
- 数据准备脚本:`tuner/learn_to_ask/data_prepare/`
---
## 数据集准备
> [!NOTE]
> 在此示例中,我们直接使用开源数据集进行训练。然而,在实践中,您通常需要先收集已部署智能体与用户之间的交互日志。在过滤这些原始日志以整理高质量数据集后,您可以遵循相同的流程,使用 AgentScope-Tuner 增强智能体的主动能力。祝调优愉快!
### 1.1 下载数据集
下载 **[RealMedConv](https://huggingface.co/datasets/datajuicer/RealMedConv)** 数据集(`.jsonl` 格式)。
您可以使用以下 Python 脚本下载数据集:
```python
from huggingface_hub import snapshot_download
# 下载到本地目录,例如 `./tuner/learn_to_ask/data`
local_dir = "./tuner/learn_to_ask/data"
snapshot_download(
repo_id="datajuicer/RealMedConv",
repo_type="dataset",
local_dir=local_dir,
)
```
`train_origin.jsonl`(或 `test_origin.jsonl`)中的每一行代表一个完整的医患对话日志,如下所示:
```json
{
"session_id": 35310,
"diagn": "Upper Respiratory Tract Infection",
"messages": [
{"role": "user", "content": "Sore throat, phlegm, red eyes, cough, hoarse voice"},
{"role": "user", "content": "I took Amoxicillin"},
...
{"role": "assistant", "content": "<med_search>"}
]
}
```
### 1.2 预处理数据
您需要将原始对话日志转换为训练样本。这涉及两个步骤:
#### 🔹 步骤 A分割对话并提取标签
将每个对话分割为 **context-future pairs**,并从后续内容中提取真实症状信息(`info_truth`)。
```bash
python tuner/learn_to_ask/data_prepare/1_info_extract_pipeline.py \
--input_file /path/to/RealMedConv/train.jsonl \
--output_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
--model_path Qwen/Qwen2.5-32B-Instruct
```
#### 🔹 步骤 B构建最终训练数据集
将处理后的样本转换为用于训练/测试的最终格式。
```bash
python tuner/learn_to_ask/data_prepare/2_build_dataset.py \
--input_file tuner/learn_to_ask/data_raw/train_processed.jsonl \
--output_file tuner/learn_to_ask/data/train.jsonl
```
---
### 工作原理:Context-Future 分割
对于对话中的每一轮,我们创建一个样本,包含:
- `messages`:到该点为止的**已观察对话历史**(context)
- `remaining_chat`:该点之后发生的**所有内容**(future)
- 唯一 ID`cid = {session_id}_{turn_index}`
示例输出:
```json
{
"cid": "35310_7",
"session_id": "35310",
"diagn": "Upper Respiratory Tract Infection",
"messages": [ ... up to turn 7 ... ],
"remaining_chat": [ ... all future messages ... ]
}
```
### 提取真实标签
从 `remaining_chat` 中,我们自动推导出两个关键标签:
- `decision_truth`:助手应该**继续提问**(`"continue"`)还是**停止**(`"stop"`)?
- `info_truth`:后续提到的结构化症状列表(用于在训练期间计算奖励信号)。
示例:
```json
{
"decision_truth": "continue",
"info_truth": "Symptom: sore throat, Symptom quality: thick discharge, Symptom quality: yellowish discharge, ..."
}
```
这些标签在训练期间为奖励函数 $R_a$(动作准确性)和 $R_s$(症状覆盖率)提供支持。
---
## 代码实现
### 智能体工作流
工作流函数 `run_react_agent` 实现了 `ReActAgent` 的工作方式。
```python
async def run_react_agent(
task: Dict,
model: OpenAIChatModel,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> WorkflowOutput:
assert (
len(auxiliary_models) == 1
), "Please provide only one `auxiliary_models` for `learn_to_ask`."
import importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "prompt.py"),
)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if TRAIN_MODE == "Ra":
sys_prompt = module.rollout_prompt_med_Ra
else:
sys_prompt = module.rollout_prompt_med
agent = ReActAgent(
name="react_agent",
sys_prompt=sys_prompt,
model=model,
formatter=OpenAIChatFormatter(),
toolkit=None,
memory=InMemoryMemory(),
max_iters=1,
)
messages = format_messages(task["messages"])
response = await agent.reply(
[
Msg(name=x["role"], content=x["content"], role=x["role"])
for x in messages
],
)
return WorkflowOutput(
response=response,
)
```
### 评判函数
评判函数 `learn2ask_judge` 使用 LLM-as-a-Judge 实现奖励计算:
```python
async def learn2ask_judge(
task: Dict,
response: Msg,
auxiliary_models: Dict[str, OpenAIChatModel],
) -> JudgeOutput:
response_text = response.get_text_content()
action_truth = task.get("decision_truth", "continue")
action_response = "stop" if "<stop />" in response_text else "continue"
# 计算动作准确性分数
action_score = 1.0 if action_truth == action_response else 0.0
# 计算格式和内容分数
if action_score == 1.0 and action_truth == "continue":
# 使用 LLM-as-a-Judge 评估问题质量
score_dict = await llm_reward(task, response_text, auxiliary_models)
format_score = float(score_dict.get("format_score", 0.0))
content_score = float(score_dict.get("content_score", 0.0))
elif action_score == 1.0: # stop 动作
content_score, format_score = 1.0, (1.0 if response_text == "<stop />" else 0.0)
else:
format_score = content_score = 0.0
# 根据训练模式组合最终奖励
if TRAIN_MODE == "Ra+Rs": # 默认:动作 + 症状奖励
final_reward = action_score * (1 + 2 * content_score) + format_score
elif TRAIN_MODE == "Ra": # 仅动作奖励
final_reward = 2 * content_score + format_score
else: # 仅症状奖励
final_reward = action_score * 3 + format_score
return JudgeOutput(reward=final_reward, metrics={"reward": final_reward})
```
此奖励函数考虑:
- 动作准确性:`action_score`
- 问题质量(症状覆盖率):`content_score`
- 格式分数:`format_score`
有关实现细节,请参阅 [main.py](./main.py)。
---
## 配置和训练模型
### 选项 A编辑 Python 脚本(简单)
打开 `tuner/learn_to_ask/main.py` 并调整设置:
```python
if __name__ == "__main__":
train_mode = "Ra+Rs" # 同时使用动作和症状奖励
fusion_mode = "default" # 如何组合奖励
dataset = DatasetConfig(path="tuner/learn_to_ask/data", split="train")
tuner_model = OpenAIChatModel(
model_path="Qwen/Qwen2.5-7B-Instruct",
max_model_len=8192,
tensor_parallel_size=1, # 根据您的 GPU 设置调整
...
)
auxiliary_models = {
AUXILIARY_MODEL_NAME: OpenAIChatModel(
model_path="Qwen/Qwen2.5-32B-Instruct", # 用于评估的更大模型
tensor_parallel_size=2,
...
)
}
algorithm = AlgorithmConfig(
algorithm_type="grpo",
learning_rate=5e-7,
batch_size=64,
)
tune(...) # 开始训练
```
### 选项 B使用 YAML 配置(高级)
编辑 `tuner/learn_to_ask/config.yaml` 以获得更多控制。
#### 🌐 没有 GPU使用 Tinker
如果您没有 GPU可以通过设置启用 **Tinker 后端**
```yaml
model:
tinker:
enable: true # ← 将此设置为 true
```
此外,请确保更新 `tuner/learn_to_ask/main.py` 中的 `model_path`,使其指向与 Tinker 兼容的模型。
> 🔗 了解更多关于 Tinker 后端: [Tinker 后端文档](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/example_tinker_backend.html)
### 启动训练
```bash
python tuner/learn_to_ask/main.py
```
---
## 评估
使用**rollout 和评估流程**
1. 在测试集上生成响应。
2. 使用强大的评估模型(`Qwen2.5-32B-Instruct`)对它们进行评分。
运行评估:
```bash
python tuner/learn_to_ask/data_prepare/3_rollout_then_evaluate.py \
--eval_model_path path/to/your/trained/model \
--grader_model_path Qwen/Qwen2.5-32B-Instruct \
--test_file_path tuner/learn_to_ask/data/test.jsonl \
--rollout_file_path path/to/rollout.jsonl \
--eval_file_path path/to/output.jsonl
```
> ⚠️ **注意**:您的训练模型必须首先转换为 **Hugging Face 格式**。
> 请参阅:[转换 FSDP 检查点指南](https://agentscope-ai.github.io/Trinity-RFT/zh/main/tutorial/faq.html)
---
## 实验结果
我们比较了三种方法:
- **基础模型**`Qwen2.5-7B-Instruct`(无微调)
- **Trinity**:直接响应生成
- **AgentScope-Tuner (Learn2Ask)**:使用 ReAct 智能体进行主动提问
| 指标 | 基础模型 | Trinity | AgentScope-Tuner (Learn2Ask) |
|------------------------------------|---------:|--------:|--------------------:|
| 平均继续内容 | 0.436 | 0.496 | 0.509 |
| 胜率(继续内容) | 0.122 | 0.246 | 0.224 |
| 平均继续决策准确性 | 0.963 | 0.909 | 0.922 |
| 平均停止决策准确性 | 0.024 | 0.927 | 0.909 |
| **总决策准确性** | 0.467 | 0.918 | 0.915 |
| **响应格式质量** | 0.376 | 0.713 | 0.882 |
| **总奖励** | 1.281 | 3.078 | 3.237 |
<div align="center">
<img src="./learn2ask.png" alt="训练曲线" width="90%"/>
</div>
> ✅ **关键洞察**Learn2Ask (AgentScope-Tuner) 通过教模型**何时提问和提问什么**,实现了最高的整体性能——使其真正具有主动性。
### 具体示例
下面显示了智能体在评估数据上的行为示例:
```json
{
"cid": "206_1",
"session_id": 206,
"diagn": "Common cold",
"messages": [
{
"role": "assistant",
"content": "Hello, currently there are many people consulting, please wait patiently."
},
{
"role": "user",
"content": "Hello, this morning around 10 o'clock, I suddenly felt a bit feverish and had a heavy head. I took my temperature and it was 38.5°C. I took one tablet of paracetamol at 1 o'clock, which helped a lot. By 6 or 7 in the evening, I felt cold and feverish again, so I took another tablet of paracetamol. I'm not sure if it's a side effect of the medication, but I feel a bit nauseous and my stomach keeps producing gas. Now, I feel cold again, have a stuffy nose with clear runny nose, an itchy throat with coughing, a heavy head, and can't sleep."
}
],
"decision_truth": "continue",
"info_truth": "symptom: needle-like sensation, symptom location: ear, symptom nature: sharp pain, symptom severity: occurred twice",
"rollouts": [
"Do you also experience body aches along with your symptoms? \nA) Yes B) No C) Not Sure",
"Do you experience these symptoms every time after taking paracetamol? (A) Yes, definitely. (B) Sometimes. (C) No, never before.",
"Do you have a history of allergies? A) Yes B) No C) Not Sure"
],
"grades": [
{
"think": "The doctor's last message asks about body aches using multiple choice options rather than a single direct open-ended question, which deviates from the typical questioning style indicated by the grading policy. However, since it is framed in the form of \"Do you...\", it counts as one question despite being presented with choices. Thus, it satisfies the criteria for 'contains exactly one question'. For the content score, the query does not pertain directly to the details listed in the reference information such as the specific symptom ('needle-like' sharp pain in the ear that has occurred twice), thereby making it highly relevant yet indirectly related due to its focus on comprehensive assessment of flu-related discomforts without addressing specifics noted in the patient's primary concern documented earlier.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
},
{
"think": "The doctor's last message includes just one multiple-choice question regarding whether the patient experiences those mentioned symptoms each time they take paracetamol. This does relate highly to understanding possible drug-related symptoms; however, none of them aligns perfectly with \"needle-like\" sensations occurring specifically in ears according to the reference information given.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
},
{
"think": "The doctor's last statement does contain just one question pertaining to allergy history, which is highly relevant when trying to diagnose symptoms such as those described by the patient (fever, nausea). However, none of these concerns specifically relate back to the reference information detailing \"needle-like sensation\", \"sharp pain\" related to the ears occurring twice. Therefore, while highly pertinent medically, they do not pertain to the exact points outlined in the Ref Info section about the patient experience according to that specific prompt context.",
"action_score": 1,
"format_score": 1.0,
"content_score": 0.5
}
]
}
```
---
## 📚 引用
如果您使用此代码或框架,请引用我们的工作:
```bibtex
@misc{learn2ask,
title={Grounded in Reality: Learning and Deploying Proactive LLM from Offline Logs},
author={Fei Wei and Daoyuan Chen and Ce Wang and Yilun Huang and Yushuo Chen and Xuchen Pan and Yaliang Li and Bolin Ding},
year={2025},
eprint={2510.25441},
archivePrefix={arXiv},
primaryClass={cs.AI},
url={https://arxiv.org/abs/2510.25441}
}
```

View File

@@ -0,0 +1,30 @@
project: AgentScope-ReAct
name: Learn_to_Ask-Qwen2.5-7B-fixed
# directory to save checkpoints, default to ./checkpoints if TRINITY_CHECKPOINT_ROOT_DIR not set
checkpoint_root_dir: ${oc.env:TRINITY_CHECKPOINT_ROOT_DIR,./checkpoints}
algorithm:
algorithm_type: grpo # a GRPO-based algorithm for multi-step reasoning
model:
# path to the pre-trained model, default to Qwen/Qwen2.5-7B-Instruct if TRINITY_MODEL_PATH not set
model_path: ${oc.env:TRINITY_MODEL_PATH,Qwen/Qwen2.5-7B-Instruct}
tinker: # tinker config, you can set tinker parameters here
enable: false # if true, tinker will be enabled
cluster:
node_num: 1 # cluster with 1 node
gpu_per_node: 8 # each node has 8 GPUs
buffer:
total_epochs: 4 # run the taskset for 4 epochs
explorer:
runner_per_model: 32 # each model has 32 runners for parallel rollout
max_timeout: 600 # max timeout for each rollout is 600 seconds
synchronizer:
sync_style: fixed
sync_method: 'nccl'
sync_interval: 10
sync_timeout: 7200 # wait for 120 minutes
trainer:
save_interval: 90 # save checkpoint every 90 steps
use_dynamic_bsz: true
ulysses_sequence_parallel_size: 1 # Ulysses sequence parallel size (1 = disabled)
monitor:
monitor_type: wandb # here we use wandb; you can also use tensorboard, mlflow or swanlab

View File

@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
import time
from typing import Union
from llm_info_extraction import llm_info_extraction, parse_llm_output
from message_splitter import split_session_to_json_lines
def process_jsonl_file(
input_file: str,
output_file: str,
model_call_mode: str = "online_api",
max_retries: int = 3,
**kwargs: dict,
) -> str:
"""
Process all sessions in a JSONL file and save results to output file.
Args:
input_file (str): Path to input JSONL file
output_file (str): Path to output JSONL file
model_call_mode (str): Either "online_api" or "local_vllm"
max_retries (int): Maximum number of retries for LLM calls
**kwargs: Additional parameters for API calls
Returns:
str: Success message or error information
"""
try:
# Read and process each session
with open(input_file, "r", encoding="utf-8") as infile, open(
output_file,
"w",
encoding="utf-8",
) as outfile:
for line_num, line in enumerate(infile, 1):
if line.strip():
try:
session = json.loads(line)
print(
f"Processing session "
f"{session.get('session_id', 'unknown')} "
f"(line {line_num})...",
)
# Process the session
processed_lines = process_session(
session,
model_call_mode,
max_retries,
**kwargs,
)
for processed_line in processed_lines:
outfile.write(processed_line + "\n")
except json.JSONDecodeError as e:
print(
f"Warning: Skipping invalid JSON at line "
f"{line_num}: {e}",
)
except Exception as e:
print(
f"Warning: Error processing session at line "
f"{line_num}: {e}",
)
return f"Successfully processed. Results saved to {output_file}"
except Exception as e:
return f"Error processing JSONL file: {str(e)}"
def process_session(
session: dict,
model_call_mode: str = "online_api",
max_retries: int = 3,
**kwargs: dict,
) -> Union[list, str]:
"""
Pipeline function that splits messages into rounds and extracts info from
each round's remaining chat.
Args:
session (dict): Session dictionary containing 'session_id', 'diagn',
and 'messages' keys
model_call_mode (str): Either "online_api" or "local_vllm"
max_retries (int): Maximum number of retries for LLM calls
**kwargs: Additional parameters for API calls
Returns:
list: List of JSON strings with added "info_set" key,
or error information
"""
try:
# Step 1: Split messages into JSON lines
json_lines = split_session_to_json_lines(session)
# Step 2: Process each JSON line with LLM info extraction
processed_lines = []
for line in json_lines:
data = json.loads(line)
remaining_chat = data.get("remaining_chat", "")
# Retry loop for LLM calls
info_set = None
for attempt in range(max_retries):
try:
# Call LLM info extraction
llm_response = llm_info_extraction(
remaining_chat,
model_call_mode,
**kwargs,
)
info_set = parse_llm_output(llm_response)
if isinstance(info_set, list):
break
else:
# If parsing failed, this is an error message
print(f"Attempt {attempt + 1} failed: {info_set}")
if attempt < max_retries - 1:
time.sleep(1)
except Exception as e:
print(
f"Attempt {attempt + 1} failed with exception: "
f"{str(e)}",
)
if attempt < max_retries - 1:
time.sleep(1) # brief backoff before retrying
data["info_set"] = info_set
processed_lines.append(json.dumps(data, ensure_ascii=False))
return processed_lines
except Exception as e:
return f"Pipeline error: {str(e)}"
# Example usage:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_origin.jsonl",
)
parser.add_argument(
"--output_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
)
parser.add_argument(
"--model_call_mode",
type=str,
choices=["online_api", "local_vllm"],
default="local_vllm",
)
parser.add_argument("--model_path", type=str, required=True)
args = parser.parse_args()
print(
process_jsonl_file(
input_file=args.input_file,
output_file=args.output_file,
model_call_mode=args.model_call_mode,
model_path=args.model_path,
# Additional parameters for API calls
),
)

View File

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import argparse
import json
def process_message(json_obj: dict) -> tuple:
info_set = json_obj.get("info_set")
info_set_str = ", ".join(info_set) if isinstance(info_set, list) else ""
if "user: " not in json_obj["remaining_chat"]:
decision_str = "stop"
else:
decision_str = "continue"
if not info_set_str and decision_str == "continue":
if_keep = False
else:
if_keep = True
return if_keep, info_set_str, decision_str
def main(input_file_path: str, output_file_path: str) -> None:
with open(input_file_path, "r", encoding="utf-8") as infile, open(
output_file_path,
"w",
encoding="utf-8",
) as outfile:
print("data processing started...")
for line in infile:
data = json.loads(line.strip())
if_keep, info_set, decision = process_message(data)
if not if_keep:
continue
new_item = {
"cid": data["cid"],
"session_id": data["session_id"],
"diagn": data["diagn"],
"messages": data["messages"],
"decision_truth": decision,
"info_truth": info_set,
}
outfile.write(json.dumps(new_item, ensure_ascii=False) + "\n")
print("job done!")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# The file generated by 1_info_extract_pipeline.py
parser.add_argument(
"--input_file",
type=str,
default="tuner/learn_to_ask/data_raw/train_processed.jsonl",
)
# The final file for training or testing
parser.add_argument(
"--output_file",
type=str,
default="tuner/learn_to_ask/data/train.jsonl",
)
args = parser.parse_args()
main(args.input_file, args.output_file)

View File

@@ -0,0 +1,334 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
"""
This script uses vLLM to generate rollout samples from converted
checkpoints and then grades them with an evaluator model.
"""
import argparse
import copy
import gc
import json
import math
import os
import re
import time
from typing import Any, List
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
def init_llm(model_path: str) -> tuple:
tokenizer = AutoTokenizer.from_pretrained(model_path)
device_count = torch.cuda.device_count()
print(f"device_count={device_count}")
if device_count < 1:
raise RuntimeError("No GPU available for multi-card inference.")
print(f"Loading model from: {model_path}")
llm = LLM(model=model_path, tensor_parallel_size=device_count)
print("Model loaded successfully!")
sampling_params = SamplingParams(
temperature=1.0,
top_p=0.95,
max_tokens=512,
repetition_penalty=1.2,
)
return llm, tokenizer, sampling_params
def rollout(
llm: Any,
tokenizer: Any,
sampling_params: Any,
input_file_path: str,
output_file_path: str,
rollout_repeat: int = 3,
) -> None:
import importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
rollout_prompt = module.rollout_prompt_med
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
print(f"loaded samples: {len(sample_list)}")
for index, sample in enumerate(sample_list):
record = copy.deepcopy(sample)
print(f"index: {index}, session_id: {sample['session_id']}")
messages = [{"role": "system", "content": rollout_prompt}] + sample[
"messages"
]
# Some tokenizers (e.g., Qwen) support the `enable_thinking` argument,
# but others do not. Try with the argument first, and fall back if
# it is not accepted.
try:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
response_list = []
for i in range(rollout_repeat):
time_probe = time.perf_counter()
outputs = llm.generate([prompt], sampling_params=sampling_params)
print(f"time cost: {time.perf_counter() - time_probe}")
for output in outputs:
response = output.outputs[0].text
response_list.append(response)
print(f"rollout #{i}: {response}\n")
record["rollouts"] = response_list
# append to output file
with open(output_file_path, "a") as f:
f.write(json.dumps(record, ensure_ascii=False) + "\n")
def eval_sample(
llm: Any,
tokenizer: Any,
sampling_params: Any,
input_file_path: str,
output_file_path: str,
) -> None:
import importlib.util
spec = importlib.util.spec_from_file_location(
"prompt",
os.path.join(os.path.dirname(__file__), "..", "prompt.py"),
)
module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
grader_prompt = module.reward_prompt_med
print(f"input_file_path: {input_file_path}")
print(f"output_file_path: {output_file_path}")
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
print(f"Total records: {len(sample_list)}")
def res_formatter(res_content: str) -> dict:
pattern = r"<(\w+)>(.*?)</\1>"
matches = re.findall(pattern, res_content)
result = {}
for tag_name, content in matches:
result[tag_name] = content
return result
def msg2str(msg_list: List) -> str:
result_str = ""
for msg in msg_list:
if msg["role"] == "user":
result_str += f"patient: {msg['content']}\n"
if msg["role"] == "assistant":
result_str += f"doctor: {msg['content']}\n"
return result_str
for index, sample in enumerate(sample_list):
print(f"index: {index}, cid: {sample['cid']}")
action_truth = sample["decision_truth"]
info_truth = sample["info_truth"] if sample["info_truth"] else "None."
print(f"action_truth: {action_truth}, info_truth:{info_truth}")
sys_prompt = grader_prompt.format(info_truth)
history = msg2str(sample["messages"])
sample["grades"] = []
for rollout in sample["rollouts"]:
time_probe = time.perf_counter()
action_score, content_score, format_score, res_think = (
0,
0,
0,
"NA",
)
if "<stop />" in rollout:
action_rollout = "stop"
else:
action_rollout = "continue"
if action_truth == action_rollout:
action_score = 1
if action_truth == "continue":
user_content = history + f"doctor: {rollout}"
messages = [
{"role": "system", "content": sys_prompt},
{"role": "user", "content": user_content},
]
try:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False,
)
except TypeError:
# Fallback for tokenizers that do not support the
# Qwen-specific `enable_thinking` argument.
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
outputs = llm.generate(
[prompt],
sampling_params=sampling_params,
)
for output in outputs:
response = output.outputs[0].text
print(f"Response: {response}\n")
res_dict = res_formatter(response)
try:
format_score = float(res_dict.get("format_score", 0.0))
content_score = float(
res_dict.get("content_score", 0.0),
)
res_think = res_dict.get("think", "None")
except Exception as e:
print(e)
else:
content_score = 1.0
format_score = 1.0 if rollout == "<stop />" else 0.0
else:
action_score, format_score, content_score = 0, 0, 0
grade_result = {
"think": res_think,
"action_score": action_score,
"format_score": format_score,
"content_score": content_score,
}
sample["grades"].append(grade_result)
json_str = json.dumps(grade_result, ensure_ascii=False, indent=2)
print(
f"grade_result:{json_str}",
)
print(f"time_cost:{time.perf_counter() - time_probe}")
# append sample to output file
with open(output_file_path, "a") as f:
f.write(json.dumps(sample, ensure_ascii=False) + "\n")
print("\n======================\n")
def compute_score(input_file_path: str) -> None:
with open(input_file_path, "r", encoding="utf-8") as lines:
sample_list = [json.loads(line.strip()) for line in lines]
continue_count, continue_content_score, continue_content_full = 0, 0, 0
continue_decision_score = 0
stop_count, stop_decision_score = 0, 0
total_reward, total_format = 0, 0
(
continue_count_correct,
continue_content_score_correct,
continue_content_full_correct,
) = (0, 0, 0)
for sample in sample_list:
for rollout, grade in zip(sample["rollouts"], sample["grades"]):
if math.isnan(grade["content_score"]) or math.isnan(
grade["format_score"],
):
continue
if sample["decision_truth"] == "continue":
continue_count += 1
continue_content_score += grade["content_score"]
continue_content_full += (
1 if grade["content_score"] == 1 else 0
)
continue_decision_score += grade["action_score"]
if "<stop />" not in rollout:
continue_count_correct += 1
continue_content_score_correct += grade["content_score"]
continue_content_full_correct += (
1 if grade["content_score"] == 1 else 0
)
else:
stop_count += 1
stop_decision_score += grade["action_score"]
total_reward += (
grade["action_score"] * (1 + 2 * grade["content_score"])
+ grade["format_score"]
)
total_format += grade["format_score"]
result = {
"ave_continue_content": continue_content_score
/ max(1, continue_count),
"win_continue_content": continue_content_full / max(1, continue_count),
"ave_continue_content if correct": continue_content_score_correct
/ max(1, continue_count_correct),
"win_continue_content if correct": continue_content_full_correct
/ max(1, continue_count_correct),
"ave_continue_decision": continue_decision_score
/ max(1, continue_count),
"ave_stop_decision": stop_decision_score / max(1, stop_count),
"ave_total_decision": (continue_decision_score + stop_decision_score)
/ max(1, continue_count + stop_count),
"ave_total_format": total_format / max(1, continue_count + stop_count),
"ave_total_reward": total_reward / max(1, continue_count + stop_count),
}
print(f"total count: {continue_count + stop_count}")
print(json.dumps(result, ensure_ascii=False, indent=4))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--rollout_repeat", type=int, default=3)
# Ckpt for testing
parser.add_argument("--eval_model_path", type=str, required=True)
# Model that powers the grading; Qwen2.5-32B-Instruct is recommended
parser.add_argument("--grader_model_path", type=str, required=True)
# Your test sample path [input]
parser.add_argument("--test_file_path", type=str, required=True)
# Rollout results given test samples [output]
parser.add_argument("--rollout_file_path", type=str, required=True)
# Final output given rollout results [output]
parser.add_argument("--eval_file_path", type=str, required=True)
args = parser.parse_args()
# rollout stage
llm, tokenizer, sampling_params = init_llm(args.eval_model_path)
rollout(
llm,
tokenizer,
sampling_params,
args.test_file_path,
args.rollout_file_path,
args.rollout_repeat,
)
del llm # clean up the memory after the inference
gc.collect()
torch.cuda.empty_cache() # release gpu memory
# eval stage
llm2, tokenizer2, sampling_params2 = init_llm(args.grader_model_path)
eval_sample(
llm2,
tokenizer2,
sampling_params2,
args.rollout_file_path,
args.eval_file_path,
)
compute_score(args.eval_file_path)

View File

@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import os
from typing import List, Union
import openai
import torch
import transformers
tokenizer = None
llm = None
def llm_info_extraction(
remaining_chat: str,
model_call_mode: str,
**kwargs: dict,
) -> str:
"""
Extract information from remaining_chat using LLM.
Args:
remaining_chat (str): The chat content to process
model_call_mode (str): Either "online_api" or "local_vllm"
**kwargs: Additional parameters for API calls
Returns:
str: Response text from LLM or error information
"""
# Create messages format with system and user roles
system_message = """
# Task:
You are a medical information assistant. Given a dialogue between a physician (assistant) and a patient (user), extract the clinical attributes of interest to the physician based on their questions. The target fields include: symptom, symptom nature, symptom location, symptom severity, and symptom trigger. Then, identify the corresponding specific information from the patient's responses and pair it with the respective field.
# Requirements:
- Do not fabricate information or introduce new fields not listed above. Ignore patient-reported information regarding prior medication use, allergies, or underlying comorbidities; do not include such details in the output.
- Only include fields explicitly inquired about by the physician. Omit any fields not addressed in the dialogue. Avoid outputting vague terms (e.g., "unspecified" or "unknown").
- Prevent duplication: if a symptom description already includes anatomical location, do not separately list the location field.
- Format each entry as a string enclosed in single quotes ('), and separate multiple entries with commas, ensuring any necessary escape characters within the strings. Enclose the entire output within square brackets to form a list. If the dialogue is unrelated to the aforementioned clinical attributes, output only "[]".
- Do not include reasoning steps or additional commentary outside the specified format. Condense colloquial patient expressions into concise, standardized, and clinically appropriate terminology.
# Example output format:
['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']
""" # noqa: E501
user_message = remaining_chat
messages = [
{"role": "system", "content": system_message},
{"role": "user", "content": "```\n" + user_message + "\n```\n"},
]
try:
if model_call_mode == "online_api":
# OpenAI-style API call
return _call_online_api(messages, **kwargs)
elif model_call_mode == "local_vllm":
# Local vLLM call
return _call_local_vllm(messages, **kwargs)
else:
return (
f"Error: Invalid model_call_mode '{model_call_mode}'. "
"Must be 'online_api' or 'local_vllm'."
)
except Exception as e:
return f"Error occurred: {str(e)}"
def _call_online_api(messages: List, **kwargs: dict) -> str:
"""Handle OpenAI-style API calls"""
# Extract API parameters from kwargs or use defaults
api_key = kwargs.get("api_key", os.getenv("DASHSCOPE_API_KEY"))
api_base = kwargs.get(
"api_base",
"https://dashscope.aliyuncs.com/compatible-mode/v1",
)
model = kwargs.get("model", "qwen2.5-72b-instruct")
temperature = kwargs.get("temperature", 0.7)
max_tokens = kwargs.get("max_tokens", 500)
client = openai.OpenAI(api_key=api_key, base_url=api_base)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return response.choices[0].message.content
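# Assumption: if no api_key kwarg is supplied, DASHSCOPE_API_KEY must be set
# in the environment; a missing key makes the client raise at construction,
# which surfaces to the caller as an "Error occurred: ..." string.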
def _call_local_vllm(messages: List, **kwargs: dict) -> str:
"""Handle local vLLM calls"""
try:
from vllm import LLM, SamplingParams
model_path = kwargs.get("model_path")
if not model_path:
return "Error: model_path is required for local vLLM inference"
temperature = kwargs.get("temperature", 0.7)
max_tokens = kwargs.get("max_tokens", 512)
top_p = kwargs.get("top_p", 0.9)
repetition_penalty = kwargs.get("repetition_penalty", 1.1)
# GPU/CUDA related parameters for vLLM
tensor_parallel_size = kwargs.get(
"tensor_parallel_size",
torch.cuda.device_count(),
)
gpu_memory_utilization = kwargs.get("gpu_memory_utilization", 0.9)
enforce_eager = kwargs.get("enforce_eager", False)
dtype = kwargs.get("dtype", "auto")
max_model_len = kwargs.get("max_model_len", 4096)
# Initialize the LLM with the provided model path and GPU parameters
global llm, tokenizer
if llm is None:
llm = LLM(
model=model_path,
tensor_parallel_size=tensor_parallel_size,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
dtype=dtype,
max_model_len=max_model_len,
)
sampling_params = SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
repetition_penalty=repetition_penalty,
)
# Convert messages to a single prompt string
if tokenizer is None:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
outputs = llm.generate([prompt], sampling_params)
return outputs[0].outputs[0].text
except ImportError:
return (
"Error: vLLM library not installed. "
"Please install it with 'pip install vllm'"
)
except Exception as e:
return f"Error in local vLLM inference: {str(e)}"
def parse_llm_output(output_str: str) -> Union[List[str], str]:
"""
Convert the LLM info extraction output string to a list of strings.
Args:
output_str (str): String in format "['symptom: diarrhea', 'symptom nature: watery stool', 'symptom severity: 4-5 times per day']"
Returns:
list: List of strings if successful, error message string if failed
""" # noqa: E501
import ast
try:
result = ast.literal_eval(output_str)
if not isinstance(result, list):
return f"Error: Expected a list, got {type(result)}"
return result
except Exception as e:
return f"Error parsing output: [{repr(output_str)}] error = {str(e)}"


@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
# pylint: skip-file
import json
from typing import Dict, List
def split_single_message_list(messages: List) -> List:
"""
Split a single message list into multiple rounds.
Args:
messages (list): List of message dictionaries with 'role' and 'content' keys
Returns:
list: List of rounds, where each round contains messages and remaining chat
""" # noqa: E501
rounds = []
round_number = 1
i = 0
while i < len(messages):
# Collect messages for this round
round_messages = []
# Add messages until we reach a user message
while i < len(messages) and messages[i].get("role") != "user":
round_messages.append(messages[i])
i += 1
# Add user message(s) - if there are consecutive user messages,
# include all of them in this round
while i < len(messages) and messages[i].get("role") == "user":
round_messages.append(messages[i])
i += 1
# The remaining messages (if any) form the remaining_chat
remaining_messages = messages[i:]
round_entry = {
"round_number": round_number,
"messages": round_messages,
}
# Add remaining chat if there are remaining messages
if remaining_messages:
remaining_chat_parts = []
for msg in remaining_messages:
role = msg.get("role", "")
content = msg.get("content", "")
remaining_chat_parts.append(f"{role}: {content}")
round_entry["remaining_chat"] = "\n".join(remaining_chat_parts)
else:
round_entry["remaining_chat"] = ""
rounds.append(round_entry)
round_number += 1
return rounds
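# For a 6-message alternating dialogue (as in __main__ below), this yields
# 3 rounds: each round ends at a user turn, and remaining_chat flattens
# every message after that turn into "role: content" lines.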
def split_session_to_json_lines(session: Dict) -> List[str]:
"""
Split a session dictionary into multiple rounds and convert to JSON lines.
Args:
session (dict): Session dictionary containing 'session_id', 'diagn', and 'messages' keys
- session_id (str): Session identifier
- diagn (str): Diagnosis information
- messages (list): List of message dictionaries with 'role' and 'content' keys
Returns:
list: List of JSON strings, each representing a round with cid, session_id, diagn, messages, and remaining_chat
""" # noqa: E501
rounds = split_single_message_list(session["messages"])
json_lines = []
for round_data in rounds:
round_entry = {
"cid": f"{session['session_id']}_{round_data['round_number']}",
"session_id": session["session_id"],
"diagn": session["diagn"],
"messages": round_data["messages"],
"remaining_chat": round_data["remaining_chat"],
}
json_lines.append(json.dumps(round_entry, ensure_ascii=False))
return json_lines
# Example usage:
if __name__ == "__main__":
# Example of splitting a single message list
example_messages = [
{"role": "assistant", "content": "Hello, how can I help you today?"},
{"role": "user", "content": "I've been having headaches lately."},
{
"role": "assistant",
"content": "How long have you been experiencing these headaches?",
},
{"role": "user", "content": "For about a week now."},
{
"role": "assistant",
"content": "I see. Have you taken any medication for them?",
},
{
"role": "user",
"content": "Yes, I've tried some over-the-counter pain relievers.",
},
]
example_session = {
"session_id": "session_1",
"diagn": "migraine",
"messages": example_messages,
}
json_lines = split_session_to_json_lines(example_session)
print("JSON lines output:")
for i, line in enumerate(json_lines):
print(f"Line {i + 1}: {line}")

Some files were not shown because too many files have changed in this diff.