feat(alias core): add data source management (#110)
Co-authored-by: Tianjing Zeng <39507457+StCarmen@users.noreply.github.com>
Co-authored-by: stcarmen <1106135234@qq.com>
@@ -207,10 +207,23 @@ alias_agent run --mode finance --task "Analyze Tesla's Q4 2024 financial perform
 # Data Science mode
 alias_agent run --mode ds \
     --task "Analyze the distribution of incidents across categories in 'incident_records.csv' to identify imbalances, inconsistencies, or anomalies, and determine their root cause." \
-    --files ./docs/data/incident_records.csv
+    --datasource ./docs/data/incident_records.csv
 ```
 
-**Note**: Files uploaded with `--files` are automatically copied to `/workspace` in the sandbox. Generated files are available in `sessions_mount_dir` subdirectories.
+#### Input/Output Management
+
+**Input:**
+
+- Use the `--datasource` parameter (with the alias `--files` kept for backward compatibility) to specify data sources, supporting multiple formats:
+  - **Local files**: e.g., `./data.txt` or `/absolute/path/file.json`
+  - **Database DSN**: relational databases such as PostgreSQL and SQLite, in the form `postgresql://user:password@host:port/database`
+
+  Example: `--datasource file.txt postgresql://user:password@localhost:5432/mydb`
+
+- Specified data sources are automatically profiled (analyzed), and the profile gives the model guidance for efficient data source access.
+- Uploaded files are automatically copied to the `/workspace` directory in the sandbox.
+
+**Output:**
+
+- Generated files are stored in subdirectories of `sessions_mount_dir`, where all output results can be found.
 
 #### Enable Long-Term Memory Service (General Mode Only)
 To enable the long-term memory service in General mode, you need to:
@@ -208,10 +208,25 @@ alias_agent run --mode finance --task "Analyze Tesla's Q4 2024 financial perform
 # Data Science mode
 alias_agent run --mode ds \
     --task "Analyze the distribution of incidents across categories in 'incident_records.csv' to identify imbalances, inconsistencies, or anomalies, and determine their root cause." \
-    --files ./docs/data/incident_records.csv
+    --datasource ./docs/data/incident_records.csv
 ```
 
-**Note**: Files uploaded with `--files` are automatically copied to `/workspace` in the sandbox. Generated files can be found in subdirectories of `sessions_mount_dir`.
+#### Input/Output Management
+
+**Input:**
+
+- Use the `--datasource` parameter to specify data sources, supporting multiple formats (`--files` is also accepted for backward compatibility):
+  - **Local files**: e.g., `./data.txt` or `/absolute/path/file.json`
+  - **Database DSN**: relational databases such as PostgreSQL and SQLite, in the form `postgresql://user:password@host:port/database`
+
+  Example: `--datasource file.txt postgresql://user:password@localhost:5432/mydb`
+
+- Specified data sources are automatically profiled (analyzed), and the profile gives the model guidance for efficient data source access.
+- Uploaded files are automatically copied to the `/workspace` directory in the sandbox.
+
+**Output:**
+
+- Generated files are stored in subdirectories of `sessions_mount_dir`, where all output results can be found.
 
 #### Enable Long-Term Memory Service (General Mode Only)
 To enable the long-term memory service in General mode, you need to:
@@ -45,7 +45,8 @@ dependencies = [
     "agentscope-runtime>=1.0.0",
    "aiosqlite>=0.21.0",
    "asyncpg>=0.30.0",
-    "itsdangerous>=2.2.0"
+    "itsdangerous>=2.2.0",
+    "polars>=1.37.1"
 ]
 
 [tool.setuptools]
@@ -0,0 +1,74 @@
---
name: csv-excel-file
description: Guidelines for handling CSV/Excel files
type:
- csv
- excel
---

# CSV/Excel Handling Specifications

## Goals

- Safely load tabular data without crashing.
- Detect and handle messy spreadsheets (multiple blocks, missing headers, merged-cell artifacts).
- Produce reliable outputs (a clean dataframe for a clean table, or structured JSON for a messy spreadsheet) with validated types.

## Encoding, Delimiters, and Locale

- CSV encoding: Try UTF-8; if garbled, attempt common fallbacks (e.g., gbk, cp1252) based on context.
- Delimiters: Detect common separators (`,`, `\t`, `;`, `|`) during inspection.
- Locale formats: Be cautious with comma decimal separators and thousands separators.
## Inspection (always first)

- Identify file type, encoding (CSV), and sheet names (Excel) before full reads.
- Prefer small reads to preview structure (see the sketch below):
  - CSV: `pd.read_csv(..., nrows=20)`; if the delimiter is uncertain: `sep=None, engine="python"` (small `nrows` only).
  - Excel: `pd.ExcelFile(path).sheet_names`, then `pd.read_excel(..., sheet_name=..., nrows=20)`.
- Use `df.head(n)` and `df.columns` to check:
  - Missing/incorrect headers (e.g., columns are numeric 0..N-1)
  - "Unnamed: X" columns
  - Unexpected NaN/NaT, merged-cell artifacts
  - Multiple tables/blocks in one sheet (blank rows separating sections)
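A minimal inspection sketch along these lines (pandas; `path` is a placeholder):

```python
import pandas as pd

def preview_tabular(path: str) -> None:
    """Peek at structure with small reads before committing to a full load."""
    if path.endswith((".xlsx", ".xls")):
        sheets = pd.ExcelFile(path).sheet_names
        print("sheets:", sheets)
        df = pd.read_excel(path, sheet_name=sheets[0], nrows=20)
    else:
        # sep=None with engine="python" lets pandas sniff the delimiter
        df = pd.read_csv(path, sep=None, engine="python", nrows=20)
    print(df.columns.tolist())
    print(df.head(5))
    # Red flag for a messy sheet: "Unnamed: X" headers
    unnamed = [c for c in df.columns.astype(str) if c.startswith("Unnamed")]
    if unnamed:
        print("possible messy sheet, unnamed columns:", unnamed)
```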
## Preprocessing

- Treat as messy if any of the following is present:
  - Columns contain "Unnamed:" or mostly empty column names
  - Header row appears inside the data (first rows look like data, a later row looks like a header)
  - Multiple data blocks (large blank-row gaps, repeated header patterns)
  - Predominantly NaN/NaT in top rows/left columns
  - Notes/metadata blocks above/beside the table (titles, footnotes, merged header areas)
- If messy spreadsheets are detected:
  - First choice: use the `clean_messy_spreadsheet` tool to extract key tables/fields and output JSON.
  - Only fall back to manual parsing if the tool fails, returns an empty/incorrect structure, or cannot locate the target table.
## Querying

- Never load entire datasets blindly.
- Use minimal reads:
  - `nrows`, `usecols`, `dtype` (or a partial dtype mapping); `parse_dates` only when necessary.
  - Sampling: `skiprows` with a step pattern for rough profiling when the file is huge.
- For very large CSV:
  - Prefer `chunksize` iteration; aggregate/compute per chunk (see the sketch below).
- For Excel:
  - Read only the needed `sheet_name`, and consider narrowing `usecols`/`nrows` during exploration.
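A chunked-aggregation sketch (the file name and `category` column are placeholders):

```python
import pandas as pd

# Count rows per category without ever holding the full CSV in memory.
total_rows = 0
category_counts: dict[str, int] = {}
for chunk in pd.read_csv("incidents.csv", usecols=["category"], chunksize=100_000):
    total_rows += len(chunk)
    for cat, n in chunk["category"].value_counts().items():
        category_counts[cat] = category_counts.get(cat, 0) + int(n)
print(total_rows, category_counts)
```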
## Data Quality & Type Validation

- After load/clean:
  - Validate types:
    - Numeric columns: coerce with `pd.to_numeric(errors="coerce")`
    - Datetime columns: `pd.to_datetime(errors="coerce")`
  - Report coercion fallout (how many values became NaN/NaT), as in the sketch below.
  - Standardize missing values: treat empty strings/"N/A"/"null" consistently.
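A minimal coerce-and-report sketch (the column lists are placeholders):

```python
import pandas as pd

def coerce_and_report(
    df: pd.DataFrame,
    numeric_cols: list[str],
    date_cols: list[str],
) -> pd.DataFrame:
    """Coerce types and report how many values were lost to NaN/NaT."""
    df = df.replace({"": pd.NA, "N/A": pd.NA, "null": pd.NA})
    for col in numeric_cols:
        before = df[col].notna().sum()
        df[col] = pd.to_numeric(df[col], errors="coerce")
        print(f"{col}: {before - df[col].notna().sum()} values coerced to NaN")
    for col in date_cols:
        before = df[col].notna().sum()
        df[col] = pd.to_datetime(df[col], errors="coerce")
        print(f"{col}: {before - df[col].notna().sum()} values coerced to NaT")
    return df
```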
## Best Practices

- Always inspect structure before processing.
- Handle encoding issues appropriately.
- Keep reads minimal; expand only after confirming layout.
- Log decisions: chosen sheet, detected header row, dropped columns/rows, dtype conversions.
- Avoid silent data loss: when dropping/cleaning, summarize what changed.
- Validate data types after loading.
@@ -0,0 +1,47 @@
---
name: image-file
description: Guidelines for handling image files
type: image
---

# Image Handling Specifications

## Goals

- Safely identify image properties and metadata without memory exhaustion.
- Accurately extract text (OCR) and visual elements (object detection/description).
- Perform necessary pre-processing (resize, normalize, crop) for downstream tasks.
- Handle multi-frame or high-resolution images efficiently.

## Inspection (Always First)

- Identify Properties: Use lightweight libraries (e.g., PIL/Pillow) to get `format`, `size` (width/height), and `mode` (RGB, RGBA, CMYK).
- Check File Size: If the image is exceptionally large (e.g., >20MB or >100MP), consider downsampling or tiling before full processing.
- Metadata/EXIF Extraction:
  - Read EXIF data for orientation, GPS tags, and timestamps.
  - Correction: Automatically apply EXIF orientation so the image is "upright" before visual analysis (see the sketch below).
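A minimal Pillow inspection sketch (assumes Pillow is installed; `path` is a placeholder):

```python
from PIL import Image, ImageOps

def inspect_image(path: str) -> Image.Image:
    """Read basic properties and return an upright copy of the image."""
    with Image.open(path) as img:
        print(img.format, img.size, img.mode)
        print("EXIF tags:", len(img.getexif()))
        # Apply the EXIF orientation tag so downstream analysis
        # always sees an upright image.
        upright = ImageOps.exif_transpose(img)
        upright.load()  # materialize pixels before the file handle closes
    return upright
```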
## Content Extraction & Vision

- Vision Analysis:
  - Use multimodal vision models to describe scenes, identify objects, and detect activities.
  - For complex images (e.g., infographics, UI screenshots), guide the model to focus on specific regions.
- OCR (Optical Character Recognition):
  - If text is detected, specify whether to extract "raw text" or "structured data" (like forms/tables).
  - Handle low-contrast or noisy backgrounds by applying pre-filters (grayscale, binarization).
- Format Conversion: Convert non-standard formats (e.g., HEIC, TIFF) to standard formats (JPEG/PNG) if tools require it.

## Handling Large or Complex Images

- Tiling: For ultra-high-res images (e.g., satellite maps, medical scans), split into overlapping tiles to avoid missing small details.
- Batching: Process multiple images using generators to keep memory usage stable.
- Alpha Channel: Be mindful of transparency (PNG/WebP); decide whether to discard it or composite against a solid background (e.g., white).

## Best Practices

- Safety First: Validate that the file is a genuine image (not a renamed malicious script).
- Graceful Failure: Handle corrupted files, truncated downloads, or unsupported formats with descriptive error logs.
- Efficiency: Avoid unnecessary re-encoding (e.g., multiple JPEG saves) to prevent "generation loss" or artifacts.
- Process images individually or in small batches to prevent system crashes.
- Consider memory usage when working with large or high-resolution images.
- Resource Management: Close file pointers or use context managers (`with Image.open(...) as img:`) to prevent memory leaks.
@@ -0,0 +1,54 @@
---
name: json-file
description: Guidelines for handling JSON files
type: json
---

# JSON Handling Specifications

## Goals

- Safely parse JSON/JSONL without memory overflow.
- Discover schema structure (keys, nesting depth, data types).
- Flatten complex nested structures into tabular data when necessary.
- Handle inconsistent schemas and "dirty" JSON (e.g., trailing commas, mixed types).

## Inspection (Always First)

- Structure Discovery:
  - Determine if the root is a `list` or a `dict`.
  - Identify whether it's standard JSON or JSONL (one valid JSON object per line).
- Schema Sampling:
  - For large files, read the first few objects/lines to infer the schema.
  - Identify top-level keys and their types.
  - Detect nesting depth: If depth > 3, consider it a "deeply nested" structure.
- Size Check:
  - If the file is large (>50MB), avoid `json.load()`. Use iterative parsing or streaming.
## Processing & Extraction

- Lazy Loading (Streaming):
  - For massive JSON: Use `ijson` (Python) or similar streaming parsers to yield specific paths/items.
  - For JSONL: Read line-by-line using a generator to minimize memory footprint (see the sketch below).
- Flattening & Normalization:
  - Use `pandas.json_normalize` to convert nested structures into flat tables if the goal is analysis.
  - Specify `max_level` during normalization to prevent "column explosion."
- Data Filtering:
  - Extract only required sub-trees (keys) early in the process to reduce the in-memory object size.
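A tolerant JSONL streaming sketch (the file name and keys are placeholders):

```python
import json
from typing import Any, Iterator

def iter_jsonl(path: str) -> Iterator[dict[str, Any]]:
    """Yield one object per line; log and skip corrupted lines instead of crashing."""
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError as e:
                print(f"skipping malformed line {lineno}: {e}")

# Extract only the needed sub-tree early to keep memory small.
names = [obj.get("user", {}).get("name") for obj in iter_jsonl("events.jsonl")]
```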
## Data Quality & Schema Validation

- Missing Keys: Use `.get(key, default)` or `try-except` blocks. Never assume a key exists in all objects.
- Type Coercion:
  - Validate numeric strings vs. actual numbers.
  - Standardize `null`, `""`, and `[]` consistently.
- Encoding: Default to UTF-8; check for a BOM (`utf-8-sig`) if parsing fails.
- Malformed JSON Recovery:
  - For minor syntax errors (e.g., single quotes instead of double), attempt `ast.literal_eval` or regex-based cleanup only as a fallback.

## Best Practices

- Minimal Reads: Don't load a 50MB JSON just to read one config key; use a streaming approach.
- Schema Logging: Document the detected structure (e.g., "Root is a list of 500 objects; key 'metadata' is nested").
- Error Transparency: When a JSON object in a JSONL stream is corrupted, log the line number, skip it, and continue instead of crashing the entire process.
- Avoid Over-Flattening: Be cautious with deeply nested arrays; flattening them can lead to massive row duplication.
- Strict Typing: After extraction, explicitly convert types (e.g., `pd.to_datetime`) to ensure downstream reliability.
@@ -0,0 +1,70 @@
---
name: database
description: Guidelines for handling databases
type: relational_db
---

# Database Handling Specifications

## Goals

- Safely explore database schema without performance degradation.
- Construct precise, efficient SQL queries that prevent system crashes (out-of-memory and timeout failures).
- Handle dialect-specific nuances (PostgreSQL, MySQL, SQLite, etc.).
- Transform raw result sets into structured, validated data for analysis.

## Inspection

- Volume Estimation:
  - Before any `SELECT *`, always run `SELECT COUNT(*) FROM table_name` to understand the scale.
  - If a table has >1,000,000 rows, strictly use indexed columns for filtering.
- Sample Data:
  - Use `SELECT * FROM table_name LIMIT 5` to see actual data formats (see the sketch below).
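A minimal inspection sketch with SQLAlchemy (the DSN and table name are placeholders):

```python
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://user:password@localhost:5432/mydb")
with engine.connect() as conn:
    # Scale first, sample second.
    n = conn.execute(text("SELECT COUNT(*) FROM incidents")).scalar_one()
    sample = conn.execute(text("SELECT * FROM incidents LIMIT 5")).fetchall()
print(f"{n} rows; first rows: {sample[:2]}")
```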
## Querying

- Safety Constraints:
  - Always use `LIMIT`: Never execute a query without a `LIMIT` clause unless the row count is confirmed to be small.
  - Avoid `SELECT *`: In production-scale tables, explicitly name columns to reduce I/O and memory usage.
- Dialect & Syntax:
  - Case Sensitivity: If a column/table name contains uppercase or special characters, it MUST be quoted (e.g., `"UserTable"` in Postgres, `` `UserTable` `` in MySQL).
  - Date/Time: Use standard ISO strings for date filtering; be mindful of timezone-aware vs. naive columns.
- Complex Queries:
  - For `JOIN` operations, ensure joining columns are indexed to prevent full table scans.
  - When performing `GROUP BY`, ensure the result set size is manageable.

## Data Retrieval & Transformation

- Type Mapping:
  - Ensure SQL types (e.g., `DECIMAL`, `BIGINT`, `TIMESTAMP`) are correctly mapped to Python/JSON types without precision loss.
  - Convert `NULL` values to a consistent "missing" representation (e.g., `None` or `NaN`).
- Chunked Fetching:
  - For medium-to-large exports, use `fetchmany(size)` or `OFFSET/LIMIT` pagination instead of fetching everything into memory at once (see the sketch below).
- Aggregations:
  - Prefer performing calculations (SUM, AVG, COUNT) at the database level rather than pulling raw data to the client for processing.
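A chunked-fetch sketch using `fetchmany` (same placeholder DSN and table as above):

```python
from sqlalchemy import create_engine, text

engine = create_engine("postgresql://user:password@localhost:5432/mydb")
counts: dict[str, int] = {}
with engine.connect() as conn:
    result = conn.execute(text("SELECT category FROM incidents"))
    while True:
        batch = result.fetchmany(10_000)  # bounded memory per batch
        if not batch:
            break
        for (category,) in batch:
            counts[category] = counts.get(category, 0) + 1
print(counts)
```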
## Error Handling & Recovery

- Timeout Management: If a query takes too long, retry with more restrictive filters or optimized joins.
- Syntax Errors: If a query fails, inspect the dialect-specific error message and re-verify the schema (it's often a misspelled column or missing quotes).

## Anti-Pattern Prevention (Avoiding "Bad" SQL)

- Index-Friendly Filters: Never wrap indexed columns in functions (e.g., `DATE()`, `UPPER()`) within the `WHERE` clause.
- Join Safety: Always verify join keys. Before joining, check whether the key has high cardinality to avoid massive intermediate result sets.
- Memory Safety:
  - Avoid `DISTINCT` and `UNION` (which performs de-duplication) on multi-million row sets unless necessary; use `UNION ALL` if duplicates are acceptable.
  - Avoid `ORDER BY` on large non-indexed text fields.
- Wildcard Warning: Strictly avoid leading wildcards in `LIKE` patterns (e.g., `%term`) on large text columns.
- No Function on Columns: `WHERE col = FUNC(val)` is good; `WHERE FUNC(col) = val` is bad.
- Explicit Columns: Only fetch what is necessary.
- Early Filtering: Push `WHERE` conditions as close to the base tables as possible.
- CTE for Clarity: Use `WITH` for complex multi-step logic to improve maintainability and optimizer hints.

## Best Practices

- Always verify database structure before querying.
- Use appropriate sampling techniques for large datasets.
- Optimize queries for efficiency based on schema inspection.
- Self-review the draft SQL against the "Anti-Pattern Prevention" list.
- Perform a silent mental `EXPLAIN` on your query. If it smells like a full table scan on a large table, refactor it before outputting.
@@ -0,0 +1,50 @@
---
name: text-file
description: Guidelines for handling text files
type: text
---

# Text File Handling Specifications

## Goals

- Safely read text files without memory exhaustion.
- Accurately detect encoding to avoid garbled characters.
- Identify underlying patterns (e.g., log formats, Markdown structure, delimiters).
- Efficiently extract or search for specific information within large volumes of text.

## Encoding & Detection

- Encoding Strategy (see the sketch below):
  - Default to `utf-8`.
  - If it fails, try `utf-8-sig` (for files with a BOM), `gbk`/`gb18030` (for Chinese context), or `latin-1`.
  - Use `chardet` or similar logic if the encoding is unknown and the first few bytes look non-standard.
- Line Endings: Be aware of `\n` (Unix), `\r\n` (Windows), and `\r` (legacy Mac) when counting lines or splitting.
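A minimal fallback-decoding sketch (for files small enough to read whole):

```python
def read_text_with_fallback(path: str) -> str:
    """Try encodings in order; a lenient codec is the last resort."""
    for enc in ("utf-8", "utf-8-sig", "gb18030", "latin-1"):
        try:
            with open(path, encoding=enc) as f:
                return f.read()
        except UnicodeDecodeError:
            continue  # wrong guess, try the next encoding
    raise ValueError(f"could not decode {path} with any known encoding")
```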
## Inspection

- Preview: Read the first 10-20 lines to determine:
  - Content Type: Is it a log, code, prose, or a semi-structured list?
  - Uniformity: Does every line follow the same format?
- Metadata: Check total file size before reading. If >50MB, treat as a "large file" and avoid full loading.

## Querying & Reading (Large Files)

- Streaming: For files exceeding memory or >50MB:
  - Use `with open(path) as f: for line in f:` to process line-by-line.
  - Never use `.read()` or `.readlines()` on large files.
- Random Sampling: To understand a huge file's structure, read the first N lines, the middle N lines (using `f.seek()`), and the last N lines.
- Pattern Matching: Use regular expressions (regex) for targeted extraction instead of complex string slicing.
- Grep-like Search: If searching for a keyword, iterate through lines and only store/return matching lines plus context (see the sketch below).
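A streaming grep-like sketch that keeps matches plus a little leading context:

```python
import re

def grep_lines(path: str, pattern: str, context: int = 1) -> list[str]:
    """Stream a large file and keep only matching lines plus brief context."""
    rx = re.compile(pattern)  # compile once when scanning millions of lines
    hits: list[str] = []
    window: list[str] = []
    with open(path, encoding="utf-8", errors="replace") as f:
        for line in f:
            window.append(line.rstrip("\n"))
            if len(window) > context + 1:
                window.pop(0)
            if rx.search(line):
                hits.extend(window)  # the match and its preceding context
                window = []
    return hits
```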
## Data Quality

- Truncation Warning: If only a portion of the file is read, clearly state: "Displaying first X lines of Y total lines."
- Empty Lines/Comments: Decide early whether to ignore blank lines or lines starting with specific comment characters (e.g., `#`, `//`).

## Best Practices

- Resource Safety: Always use context managers (`with` statement) to ensure file handles are closed.
- Memory Consciousness: For logs and large TXT, prioritize "find and extract" over "load and filter."
- Regex Optimization: Compile regex patterns if they are used repeatedly in a loop over millions of lines.
- Validation: After reading, verify the content isn't binary (e.g., a PDF or EXE renamed to .txt) by checking for null bytes or a high density of non-ASCII characters.
- Progress Logging: For long-running text processing, log progress every 100k lines or 10% of file size.
@@ -23,6 +23,7 @@ from alias.agent.tools import AliasToolkit, share_tools
 from alias.agent.agents.common_agent_utils import (
     get_user_input_to_mem_pre_reply_hook,
 )
+from alias.agent.agents.data_source.data_source import DataSourceManager
 from .ds_agent_utils import (
     ReportGenerator,
     LLMPromptSelector,
@@ -50,7 +51,8 @@ class DataScienceAgent(AliasAgentBase):
         formatter: FormatterBase,
         memory: MemoryBase,
         toolkit: AliasToolkit,
-        sys_prompt: str = None,
+        data_manager: DataSourceManager = None,
+        sys_prompt: str = "",
         max_iters: int = 30,
         tmp_file_storage_dir: str = "/workspace",
         state_saving_dir: Optional[str] = None,
@@ -71,17 +73,16 @@ class DataScienceAgent(AliasAgentBase):
 
         set_run_ipython_cell(self.toolkit.sandbox)
 
-        self.uploaded_files: List[str] = []
-
         self.todo_list: List[Dict[str, Any]] = []
-        self.infer_trajectories: List[List[Msg]] = []
+        self.tmp_file_storage_dir = tmp_file_storage_dir
+        self.data_manager = data_manager
 
         self.detailed_report_path = os.path.join(
             tmp_file_storage_dir,
             "detailed_report.html",
         )
-        self.tmp_file_storage_dir = tmp_file_storage_dir
 
         self.todo_list_prompt = get_prompt_from_file(
             os.path.join(
@@ -91,12 +92,19 @@ class DataScienceAgent(AliasAgentBase):
             False,
         )
 
-        self._sys_prompt = get_prompt_from_file(
-            os.path.join(
-                PROMPT_DS_BASE_PATH,
-                "_agent_system_workflow_prompt.md",
-            ),
-            False,
-        )
+        self._sys_prompt = (
+            cast(
+                str,
+                get_prompt_from_file(
+                    os.path.join(
+                        PROMPT_DS_BASE_PATH,
+                        "_agent_system_workflow_prompt.md",
+                    ),
+                    False,
+                ),
+            )
+            + "\n\n"
+            + sys_prompt
+        )
 
         # load prompts and initialize selector
@@ -167,7 +175,7 @@ class DataScienceAgent(AliasAgentBase):
 
         logger.info(
             f"[{self.name}] "
-            "DeepInsightAgent initialized (fully model-driven).",
+            "DataScienceAgent initialized (fully model-driven).",
         )
 
     @property
@@ -427,25 +435,56 @@ class DataScienceAgent(AliasAgentBase):
             memory_log=memory_log,
         )
 
-        response, report = await report_generator.generate_report()
+        (
+            response,
+            report_md,
+            report_html,
+        ) = await report_generator.generate_report()
+
+        if report_md:
+            md_report_path = os.path.join(
+                self.tmp_file_storage_dir,
+                "detailed_report.md",
+            )
 
-        if report:
-            # report = report.replace(self.tmp_file_storage_dir, ".")
             await self.toolkit.call_tool_function(
                 ToolUseBlock(
                     type="tool_use",
                     id=str(uuid.uuid4()),
                     name="write_file",
                     input={
-                        "path": self.detailed_report_path,
-                        "content": report,
+                        "path": md_report_path,
+                        "content": report_md,
                     },
                 ),
             )
             response = (
                 f"{response}\n\n"
-                "The detailed report has been saved to "
-                f"{self.detailed_report_path}."
+                "The detailed report (markdown version) has been saved to "
+                f"{md_report_path}."
+            )
+
+        if report_html:
+            html_report_path = os.path.join(
+                self.tmp_file_storage_dir,
+                "detailed_report.html",
+            )
+
+            await self.toolkit.call_tool_function(
+                ToolUseBlock(
+                    type="tool_use",
+                    id=str(uuid.uuid4()),
+                    name="write_file",
+                    input={
+                        "path": html_report_path,
+                        "content": report_html,
+                    },
+                ),
+            )
+            response = (
+                f"{response}\n\n"
+                "The detailed report (html version) has been saved to "
+                f"{html_report_path}."
             )
 
         kwargs["response"] = response
@@ -698,7 +698,7 @@ class MetaPlanner(AliasAgentBase):
     ):
         """
         Directly enter the data science mode.
-        Use this when the user provides some data files and ask for processing.
+        Use this for COMPLEX, CODE-BASED data analysis.
 
         Args:
             user_query (`str`):
@@ -0,0 +1,722 @@
# -*- coding: utf-8 -*-
# pylint: disable=R1702,R0912,R0915

import os
import json
from abc import ABC, abstractmethod
from typing import Any, Dict

from loguru import logger
import pandas as pd
from sqlalchemy import inspect, text, create_engine
from agentscope.message import Msg

from alias.agent.agents.data_source._typing import SourceType
from alias.agent.agents.ds_agent_utils import (
    get_prompt_from_file,
)
from alias.agent.utils.llm_call_manager import (
    LLMCallManager,
)


class BaseDataProfiler(ABC):
    """Abstract base class for data profilers that analyze different data
    sources like csv, excel, db, etc.
    """

    _PROFILE_PROMPT_BASE_PATH = os.path.join(
        os.path.dirname(__file__),
        "built_in_prompt",
    )

    def __init__(
        self,
        path: str,
        source_type: SourceType,
        llm_call_manager: LLMCallManager,
    ):
        """Initialize the data profiler with data path, type and LLM manager.

        Args:
            path: Path to the data source file or connection string
            source_type: Enum indicating the type of data source
            llm_call_manager: Manager for handling LLM calls
        """
        self.path = path
        self.file_name = os.path.basename(path)
        self.source_type = source_type
        self.llm_call_manager = llm_call_manager

        self.source_types_2_prompts = {
            SourceType.CSV: "_profile_csv_prompt.md",
            SourceType.EXCEL: "_profile_xlsx_prompt.md",
            SourceType.IMAGE: "_profile_image_prompt.md",
            SourceType.RELATIONAL_DB: "_profile_relationdb_prompt.md",
            "IRREGULAR": "_profile_irregular_xlsx_prompt.md",
        }
        if source_type not in self.source_types_2_prompts:
            raise ValueError(f"Unsupported source type: {source_type}")
        self.prompt = self._load_prompt(source_type)

        base_model_name = self.llm_call_manager.get_base_model_name()
        vl_model_name = self.llm_call_manager.get_vl_model_name()

        self.source_types_2_models = {
            SourceType.CSV: base_model_name,
            SourceType.EXCEL: base_model_name,
            SourceType.IMAGE: vl_model_name,
            SourceType.RELATIONAL_DB: base_model_name,
        }
        self.model_name = self.source_types_2_models[source_type]

    def _load_prompt(self, source_type: Any = None):
        """Load the appropriate prompt template based on the source type.

        Args:
            source_type: Type of data source (CSV, EXCEL, IMAGE, etc.)

        Returns:
            Loaded prompt template as string
        """
        prompt_file_name = self.source_types_2_prompts[source_type]
        prompt = get_prompt_from_file(
            os.path.join(
                self._PROFILE_PROMPT_BASE_PATH,
                prompt_file_name,
            ),
            False,
        )
        return prompt

    async def generate_profile(self) -> Dict[str, Any]:
        """Generate a complete data profile by reading data, generating
        content, calling the LLM, and wrapping the response.

        Returns:
            Dictionary containing the complete data profile
        """
        try:
            self.data = await self._read_data()
            # different source types have different data building methods
            content = self._build_content_with_prompt_and_data(
                self.prompt,
                self.data,
            )
            # content = self.prompt.format(data=self.data)
            res = await self._generate_profile_by_llm(content)
            self.profile = self._wrap_data_response(res)
        except Exception as e:
            logger.warning(f"Error generating profile: {e}")
            self.profile = {}
        return self.profile

    @staticmethod
    def tool_clean_json(raw_response: str):
        """Clean and parse JSON response from LLM by removing markdown
        markers.

        Args:
            raw_response: Raw string response from LLM

        Returns:
            Parsed JSON object from the cleaned response
        """
        cleaned_response = raw_response.strip()
        if cleaned_response.startswith("```json"):
            cleaned_response = cleaned_response[len("```json") :].lstrip()
        if cleaned_response.startswith("```"):
            cleaned_response = cleaned_response[len("```") :].lstrip()
        if cleaned_response.endswith("```"):
            cleaned_response = cleaned_response[:-3].rstrip()
        return json.loads(cleaned_response)

    @abstractmethod
    def _build_content_with_prompt_and_data(
        self,
        prompt: str,
        data: Any,
    ) -> str:
        """Abstract method to build content for LLM based on prompt and data.

        This method should be implemented by subclasses to format
        content appropriately for different data types.

        Args:
            prompt: Prompt template to use
            data: Processed data to include in the prompt

        Returns:
            Formatted content for LLM call
        """

    @abstractmethod
    async def _read_data(self):
        """Abstract method to read and process data from the source path.

        This method should be implemented by subclasses to handle specific
        data source types (CSV, Excel, DB, etc.).

        Returns:
            Processed data in appropriate format for the data type
        """

    async def _generate_profile_by_llm(
        self,
        content: Any,
    ) -> Dict[str, Any]:
        """Generate profile by calling LLM with prepared content.

        Args:
            content: Content to send to the LLM (text or multimodal)

        Returns:
            Dictionary response parsed from LLM output
        """
        sys_prompt = "You are a helpful AI assistant for database management."
        msgs = [
            Msg("system", sys_prompt, "system"),
            Msg("user", content, "user"),
        ]
        response = await self.llm_call_manager(
            model_name=self.model_name,
            messages=msgs,
        )
        response = BaseDataProfiler.tool_clean_json(response)
        return response

    @abstractmethod
    def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Abstract method to combine LLM response with original schema.

        This method should be implemented by subclasses to properly merge
        LLM-generated descriptions with the original data structure.

        Args:
            response: Dictionary response from LLM

        Returns:
            Combined dictionary with original schema and LLM response
        """


class StructuredDataProfiler(BaseDataProfiler):
    """Base class for profilers that work with structured data sources
    like CSV, Excel, and relational databases.
    """

    @staticmethod
    def is_irregular(cols: pd.Index) -> bool:
        """Determine if a table has irregular column names (many unnamed).

        Args:
            cols: Index of column names from the dataset

        Returns:
            Boolean indicating whether the dataset is irregular
        """
        # any(col.startswith('Unnamed') for col in df.columns.astype(str))?
        unnamed_columns_ratio = sum(
            col.startswith("Unnamed") for col in cols.astype(str)
        ) / len(cols)
        return unnamed_columns_ratio >= 0.5

    @staticmethod
    def _extract_schema_from_table(df: pd.DataFrame, df_name: str) -> dict:
        """Analyzes a single DataFrame to extract metadata and samples.

        Extracts column names, data types, and sample values to provide a
        comprehensive view of the table structure for the LLM.

        Args:
            df: The dataframe to analyze
            df_name: Name of the table (or sheet/filename)

        Returns:
            Dictionary containing schema metadata for the table
        """
        col_list = []
        for col in df.columns:
            dtype_name = str(df[col].dtype).upper()
            # Get random samples to help LLM understand the data content.
            # sample(frac=1): shuffle the data;
            # head(n_samples): take the first n_samples; if fewer rows
            # exist, head() simply returns them all without error.
            candidates = (
                df[col]
                .drop_duplicates()
                .sample(frac=1, random_state=42)
                .head(5)
                .astype(str)
                .tolist()
            )
            # Limit the total sample size to 1000 characters.
            # TODO: dynamic size control? 1000 is too small?
            samples = []
            length = 0
            for s in candidates:
                if (length := length + len(s)) <= 1000:
                    samples.append(s)
            col_list.append(
                {
                    "column name": col,
                    "data type": dtype_name,
                    "data samples": samples,
                },
            )
        # Create a CSV snippet of the first few rows
        raw_data_snippet = df.head(5).to_csv(index=True)

        table_schema = {
            "name": df_name,
            "raw_data_snippet": raw_data_snippet,
            # Note: Row count logic might need optimization for large files
            # TODO: how to get the row count more efficiently, openpyxl.
            "row_count": len(df) if len(df) < 100 else None,
            "col_count": len(df.columns),
            "columns": col_list,
        }
        return table_schema

    def _build_content_with_prompt_and_data(
        self,
        prompt: str,
        data: Any,
    ) -> str:
        """Format the prompt with data for structured data sources.

        Args:
            prompt: Template prompt string
            data: Processed data structure

        Returns:
            Formatted content string ready for LLM
        """
        return prompt.format(data=data)

    def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Merges the original schema with the LLM-generated response.

        Combines the structural information from the original data with
        semantic descriptions generated by the LLM.

        Args:
            response: Dictionary response from LLM with descriptions

        Returns:
            Combined schema with both structural and semantic info
        """
        new_schema = {}
        new_schema["name"] = self.data["name"]
        new_schema["description"] = response["description"]
        # For flat files like CSV, the schema contains columns directly.
        if "columns" in self.data:
            new_schema["columns"] = self.data["columns"]
        # For multi-table sources like Excel/Database, the schema contains
        # tables. Each table contains columns and a description.
        if "tables" in self.data and "tables" in response:
            new_schema["tables"] = []
            # Build a map for response tables and descriptions
            res_des_map = {
                table["name"]: table["description"]
                for table in response["tables"]
            }
            for table in self.data["tables"]:
                table_name = table["name"]
                if table_name not in res_des_map:
                    continue
                new_table = {}
                new_table["name"] = table_name
                # Retain the description from the LLM response
                new_table["description"] = res_des_map[table_name]
                if "columns" in table:
                    new_table["columns"] = table["columns"]
                if "irregular_judgment" in table:
                    new_table["irregular_judgment"] = table[
                        "irregular_judgment"
                    ]
                new_schema["tables"].append(new_table)
        return new_schema


class ExcelProfiler(StructuredDataProfiler):
    async def _extract_irregular_table(
        self,
        path: str,
        raw_data_snippet: str,
        sheet_name: str,
    ):
        """Extract structure from irregular Excel sheets with unnamed
        columns. Uses a special LLM call to identify the actual data in
        sheets with headers or other content above the main data table.

        Args:
            path: Path to the Excel file
            raw_data_snippet: Raw text snippet of the sheet content
            sheet_name: Name of the sheet being processed

        Returns:
            Schema dictionary for the irregular table structure
        """
        prompt = self._load_prompt("IRREGULAR")
        content = prompt.format(raw_snippet_data=raw_data_snippet)
        res = await self._generate_profile_by_llm(content=content)

        if "is_extractable_table" in res and res["is_extractable_table"]:
            logger.debug(res["reasoning"])
            skiprows = res["row_start_index"] + 1
            cols_range = res["col_ranges"]
            df = pd.read_excel(
                path,
                sheet_name=sheet_name,
                nrows=100,
                skiprows=skiprows,
                usecols=range(cols_range[0], cols_range[1] + 1),
            ).convert_dtypes()
            if StructuredDataProfiler.is_irregular(df.columns):
                schema = {
                    "name": sheet_name,
                    "raw_data_snippet": raw_data_snippet,
                    "irregular_judgment": "UNSTRUCTURED",
                }
            else:
                schema = self._extract_schema_from_table(df, sheet_name)
                schema["irregular_judgment"] = res
        else:
            schema = {
                "name": sheet_name,
                "raw_data_snippet": raw_data_snippet,
                "irregular_judgment": "UNSTRUCTURED",
            }

        return schema

    async def _read_data(self):
        """Read and process Excel file data including all sheets.

        Handles both regular and irregular Excel files by using pandas
        for regular files and openpyxl for files with unnamed columns.

        Returns:
            Dictionary containing metadata for all sheets in the Excel file
        """
        excel_file = pd.ExcelFile(self.path)
        table_schemas = []
        schema = {}
        schema["name"] = self.file_name
        for sheet_name in excel_file.sheet_names:
            # TODO: use openpyxl to read excel to avoid irregular excel.
            # Read a subset of each sheet
            df = pd.read_excel(
                self.path,
                sheet_name=sheet_name,
                nrows=100,
            ).convert_dtypes()
            if not StructuredDataProfiler.is_irregular(df.columns):
                table_schema = (
                    StructuredDataProfiler._extract_schema_from_table(
                        df,
                        sheet_name,
                    )
                )
            else:
                # if unnamed columns, use openpyxl to extract top 100 rows.
                import openpyxl

                wb = openpyxl.load_workbook(
                    self.path,
                    read_only=True,
                    data_only=True,
                )
                ws = wb[sheet_name]
                rows_data = []
                for i, row in enumerate(
                    ws.iter_rows(values_only=True),
                    start=1,
                ):
                    if i > 100:
                        break
                    rows_data.append(
                        ",".join(
                            "" if cell is None else str(cell) for cell in row
                        ),
                    )
                wb.close()
                raw_data_snippet = "\n".join(rows_data)

                table_schema = await self._extract_irregular_table(
                    self.path,
                    raw_data_snippet,
                    sheet_name,
                )
                # table_schema = {
                #     "name": sheet_name,
                #     "raw_data_snippet": "\n".join(rows_data),
                # }
            table_schemas.append(table_schema)
        schema["tables"] = table_schemas
        return schema


class RelationalDatabaseProfiler(StructuredDataProfiler):
    async def _read_data(self):
        """
        Extracts metadata (schema) for all tables in a relational db.

        path (str): The Database Source Name (connection string),
            e.g. postgresql://user:password@ip:port/db_name

        Returns:
            Dictionary containing database metadata for all tables
        """
        options = {
            "isolation_level": "AUTOCOMMIT",
            # Test conns before use (handles MySQL 8hr timeout, network drops)
            "pool_pre_ping": True,
            # Keep minimal conns (MCP typically handles 1 request at a time)
            "pool_size": 1,
            # Allow temporary burst capacity for edge cases
            "max_overflow": 2,
            # Force refresh conns older than 1hr (under MySQL's 8hr default)
            "pool_recycle": 3600,
        }
        engine = create_engine(self.path, **options)
        try:
            connection = engine.connect()
        except Exception as e:
            logger.error(f"Connection to {self.path} failed: {e}")
            raise ConnectionError(f"Failed to connect to database: {e}") from e

        # Use DSN as the db identifier (could be parsed more cleanly)
        database_name = self.path
        inspector = inspect(connection)
        table_names = inspector.get_table_names()

        tables_data = []
        for table_name in table_names:
            try:
                # 1. Get column information
                columns = inspector.get_columns(table_name)
                col_count = len(columns)

                # 2. Get row count
                row_count_result = connection.execute(
                    text(f"SELECT COUNT(*) FROM {table_name}"),
                ).fetchone()
                row_count = row_count_result[0] if row_count_result else 0

                # 3. Get raw data snippet (first 5 rows)
                raw_data_snippet = ""
                rows = []  # keep defined even if the snippet query fails
                try:
                    result = connection.execute(
                        text(f"SELECT * FROM {table_name} LIMIT 5"),
                    )
                    rows = result.fetchall()
                    if rows:
                        column_names = [col["name"] for col in columns]
                        lines = []
                        # Add header
                        lines.append(", ".join(column_names))
                        # Add data rows
                        for row in rows:
                            row_values = []
                            for value in row:
                                if value is None:
                                    row_values.append("NULL")
                                else:
                                    # Escape commas and newlines
                                    val_str = str(value)
                                    if "," in val_str or "\n" in val_str:
                                        val_str = f'"{val_str}"'
                                    row_values.append(val_str)
                            lines.append(", ".join(row_values))
                        raw_data_snippet = "\n".join(lines)
                except Exception as e:
                    logger.warning(
                        f"Error fetching {table_name} data: {str(e)}",
                    )
                    raw_data_snippet = None
                # 4. detailed column info (types and samples)
                column_details = []
                if rows:
                    for i, col in enumerate(columns):
                        col_name = col["name"]
                        col_type = str(col["type"])
                        # Extract samples for this column from the fetched rows
                        sample_values = []
                        for row in rows:
                            if i < len(row):
                                val = row[i]
                                sample_values.append(
                                    str(val) if val is not None else "NULL",
                                )

                        column_details.append(
                            {
                                "column name": col_name,
                                "data type": col_type,
                                "data sample": sample_values[:3],
                            },
                        )

                table_info = {
                    "name": table_name,
                    "row_count": row_count,
                    "col_count": col_count,
                    "raw_data_snippet": raw_data_snippet,
                    "columns": column_details,
                }

                tables_data.append(table_info)

            except Exception as e:
                # If one table fails, log it and continue to the next
                logger.warning(f"Error processing {table_name}: {str(e)}")
                continue
        # Construct the final schema
        schema = {
            "name": database_name,
            "tables": tables_data,
        }
        self.data = schema
        return schema


class CsvProfiler(ExcelProfiler):
    async def _read_data(self):
        """Handles schema extraction for CSV as single-table sources.

        Uses Polars for efficient row counting on large files and
        pandas for detailed schema analysis of the first 100 rows.

        Returns:
            Schema dictionary for the CSV file
        """
        import polars as pl

        # Use Polars for efficient row counting on large files
        df = pl.scan_csv(self.path, ignore_errors=True)
        row_count = df.select(pl.len()).collect().item()
        # Read a subset with Pandas for detailed schema analysis
        df = pd.read_csv(self.path, nrows=100).convert_dtypes()
        schema = self._extract_schema_from_table(df, self.file_name)
        schema["row_count"] = row_count
        # if StructuredDataProfiler.is_irregular(df.columns):
        #     self._extract_irregular_table(...)
        return schema


class ImageProfiler(BaseDataProfiler):
    """Profiler for image data sources that uses multimodal LLMs."""

    async def _read_data(self):
        """
        For images, this simply returns the path since the LLM API
        handles loading the image directly.

        Returns:
            Path to the image file
        """
        return self.path

    def _build_content_with_prompt_and_data(self, prompt, data):
        """Build multimodal content for image analysis.

        Creates content in the format required by multimodal LLM APIs
        with both image and text components.

        Args:
            prompt: Text prompt template for image analysis
            data: Path to the image file

        Returns:
            List containing image and text components for the LLM call
        """
        # Convert image paths according to the model requirements
        contents = [
            {
                "text": prompt,
                "type": "text",
            },
            {
                "source": {
                    "url": data,
                    "type": "url",
                },
                "type": "image",
            },
        ]
        return contents

    def _wrap_data_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
        """Format the LLM response for image data into dict.

        Args:
            response: Dictionary response from multimodal LLM

        Returns:
            Profile dictionary with image name, description and details
        """
        profile = {
            "name": self.file_name,
            "description": response["description"],
            "details": response["details"],
        }
        return profile


class DataProfilerFactory:
    """Factory class to create appropriate data profiler instances based
    on source type.
    """

    @staticmethod
    def get_profiler(
        llm_call_manager: LLMCallManager,
        path: str,
        source_type: SourceType,
    ) -> BaseDataProfiler:
        """Factory method to get the appropriate profiler instance,
        which generates the correct profile result for the source.

        Args:
            path: Path to the data source or connection string
            source_type: Enum indicating the type of data source
            llm_call_manager: Manager for handling LLM calls

        Returns:
            Instance of the appropriate profiler subclass

        Raises:
            ValueError: If the source_type is unsupported
        """
        if source_type == SourceType.IMAGE:
            return ImageProfiler(
                path=path,
                source_type=source_type,
                llm_call_manager=llm_call_manager,
            )
        elif source_type == SourceType.CSV:
            return CsvProfiler(
                path=path,
                source_type=source_type,
                llm_call_manager=llm_call_manager,
            )
        elif source_type == SourceType.EXCEL:
            return ExcelProfiler(
                path=path,
                source_type=source_type,
                llm_call_manager=llm_call_manager,
            )
        elif source_type == SourceType.RELATIONAL_DB:
            return RelationalDatabaseProfiler(
                path=path,
                source_type=source_type,
                llm_call_manager=llm_call_manager,
            )
        else:
            raise ValueError(f"Unsupported source type: {source_type}")
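A minimal usage sketch for the factory (hypothetical wiring, not part of the commit; assumes an already-initialized `LLMCallManager` and an async context):

```python
# Hypothetical usage: profile a CSV data source.
profiler = DataProfilerFactory.get_profiler(
    llm_call_manager=llm_call_manager,  # an initialized LLMCallManager
    path="./docs/data/incident_records.csv",
    source_type=SourceType.CSV,
)
profile = await profiler.generate_profile()
print(profile.get("description"))
```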
@@ -0,0 +1,21 @@
{
    "relational_db": {
        "mcp_server": {
            "mcp_alchemy": {
                "command": "uvx",
                "args": [
                    "--from",
                    "mcp-alchemy==2025.8.15.91819",
                    "--with",
                    "psycopg2-binary",
                    "--refresh-package",
                    "mcp-alchemy",
                    "mcp-alchemy"
                ],
                "env": {
                    "DB_URL": "${endpoint}"
                }
            }
        }
    }
}
alias/src/alias/agent/agents/data_source/_typing.py (new file, 54 lines)
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-

from enum import Enum


class SourceAccessType(str, Enum):
    """Simple source access type classification"""

    DIRECT = "direct"
    VIA_MCP = "via_mcp"

    def __str__(self):
        return self.value


class SourceType(str, Enum):
    """Simple source type classification"""

    CSV = "csv"
    JSON = "json"
    EXCEL = "excel"
    TEXT = "text"
    IMAGE = "image"

    # Database sources
    RELATIONAL_DB = "relational_db"

    OTHER = "other"

    def __str__(self):
        return self.value

    @staticmethod
    def is_valid_source_type(value: str) -> bool:
        try:
            SourceType(value)
            return True
        except ValueError:
            return False


# Define mapping between SourceType and SourceAccessType
SOURCE_TYPE_TO_ACCESS_TYPE = {
    # File types -> DIRECT
    SourceType.CSV: SourceAccessType.DIRECT,
    SourceType.JSON: SourceAccessType.DIRECT,
    SourceType.EXCEL: SourceAccessType.DIRECT,
    SourceType.TEXT: SourceAccessType.DIRECT,
    SourceType.IMAGE: SourceAccessType.DIRECT,
    # Database types -> VIA_MCP
    SourceType.RELATIONAL_DB: SourceAccessType.VIA_MCP,
    # Unknown type -> depends on endpoint
    SourceType.OTHER: None,
}
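A minimal usage sketch for the mapping (hypothetical helper, not part of the commit):

```python
from typing import Optional

def resolve_access_type(value: str) -> Optional[SourceAccessType]:
    """Resolve how a source should be accessed; OTHER maps to None
    ("depends on endpoint")."""
    if SourceType.is_valid_source_type(value):
        source = SourceType(value)
    else:
        source = SourceType.OTHER
    return SOURCE_TYPE_TO_ACCESS_TYPE[source]

assert resolve_access_type("csv") is SourceAccessType.DIRECT
assert resolve_access_type("relational_db") is SourceAccessType.VIA_MCP
assert resolve_access_type("s3://bucket") is None
```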
@@ -0,0 +1,83 @@
# Role
You are an expert Data Steward. Your task is to generate a single, comprehensive description sentence for a CSV file based on its metadata and raw content.

# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
    "name": "filename.csv",
    "raw_data_snippet": "col1, col2\na, b",
    "row_count": 100,
    "col_count": 5,
    "columns": [
        {{ "column name": "col1", "data type": "string", "data sample": ["a", "b"] }}
    ]
}}
```

# Analysis Logic

## 1. Context & Metrics Extraction

* **Subject:** Extract the core concept from the `name` field (e.g., `logistics_data.csv` -> "logistics_data").
* **Metrics:** Identify `row_count` and `col_count`.
* **Context:** Look for time (e.g., "2024") or location keywords in the `raw_data_snippet` or `name`.

## 2. Schema Identification

* **Primary:** Use column names from the `columns` list.
* **Secondary (Inference):** If the `columns` list is empty or generic (e.g., "col1"), you MUST infer meaningful column names from the `raw_data_snippet` values (e.g., "2023-01-01" -> `date`).
* **Selection:** Choose 3-5 key columns to represent the dataset structure.

## 3. Description Construction

* Generate a **single** grammatical sentence.
* **Strict Template:** "The file [FileName] contains [Subject] data [Optional: Context] with [RowCount] rows and [ColCount] columns, featuring fields such as [List of 3-5 key columns]."

---

# Output Format (Strict JSON)

You must output a single valid JSON object containing only the `description` key.

```json
{{
    "description": "The file [FileName] contains [Subject] data with [Rows] rows and [Cols] columns, featuring fields such as [Columns]."
}}
```

# One-Shot Demonstration

**[Example Input]**
`input_json` =

```json
{{
    "name": "logistics_data.csv",
    "raw_data_snippet": "SHP-001, Tokyo, London, 2024-05-20\nSHP-002, NY, Paris, 2024-05-21",
    "row_count": 2000,
    "col_count": 4,
    "columns": [
        {{
            "column name": "shipment_id",
            "data type": "string",
            "data sample": ["SHP-001", "SHP-002"]
        }}
    ]
}}
```

**[Example Output]**

```json
{{
    "description": "The file logistics_data.csv contains supply chain logistics information for 2024 with 2000 rows and 4 columns, featuring fields such as shipment_id, origin, destination, and date."
}}
```

# Input
input_json = {data}
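The doubled braces in the template above are `str.format` escapes. A minimal sketch of how the template might be rendered (the file name is hypothetical; only `{data}` is substituted, while `{{`/`}}` survive as literal braces):

```python
import json

# Hypothetical path; the prompt file's real name isn't visible in this diff.
template = open("csv_profile_prompt.md", encoding="utf-8").read()

payload = {
    "name": "incident_records.csv",
    "raw_data_snippet": "id, category\n1, outage",
    "row_count": 1200,
    "col_count": 2,
    "columns": [],
}

prompt = template.format(data=json.dumps(payload))
```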
@@ -0,0 +1,31 @@
Please carefully analyze this image and perform the following tasks:

### Step 1: Overall Assessment
- Determine whether the image contains "single content" or "composite content" (i.e., multiple independent information modules).
- If it is composite content, list the main components (e.g., "Bar chart in the top-left, data table in the bottom-right, title description at the top").

### Step 2: Region-wise Analysis (for composite content)
For each prominent content region, describe it using the following template:

#### [Module X] Type: [flowchart/table/chart/document/photo]
- Position and scope: Briefly describe its location (e.g., "left half", "bottom table")
- Content extraction:
  - If flowchart/diagram: Describe nodes and connections in logical order, and explain label meanings.
  - If table: Reconstruct row/column structure; present in Markdown table format if possible.
  - If chart: Explain axes, series, trends, and provide key conclusions (e.g., "Peak reached in Q3").
  - If document/text: Extract key sentences while preserving original meaning.
  - If photo: Describe scene, people, and actions.
- Functional role: Infer the module's purpose within the whole image (e.g., "Supports the conclusion stated above").

### Step 3: Global Synthesis
- Summarize the core purpose of the entire image (e.g., "Presents quarterly performance analysis").
- Describe logical relationships among modules (e.g., "The table provides data sources, the chart shows trends, and the text offers recommendations").
- If there are annotations (e.g., label1, Cost2), explain their business meaning.

### Final Output Requirements
Do not output the internal analysis steps separately. You must output the final result **ONLY** in the following format:

{{
    "description": "A single, concise sentence describing the overall framework of the image",
    "details": "A comprehensive and detailed description based on your Step 2 & 3 analysis. Use Markdown formatting (bullet points, bold text) inside this field to ensure the structure is clear and readable."
}}
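The doubled braces in the output spec are `str.format` escapes, so the vision model is asked for plain JSON. A minimal sketch of how a caller might consume the reply (the reply text is illustrative):

```python
import json

reply = (
    '{"description": "Quarterly performance dashboard with chart and table",'
    ' "details": "- **Chart**: revenue by quarter..."}'
)
parsed = json.loads(reply)
print(parsed["description"])
```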
@@ -0,0 +1,82 @@
# Role
You are an expert Data Engineer specializing in unstructured Excel parsing. Your task is to analyze the raw content of the first 100 rows of an Excel sheet and determine if it contains structured tabular data suitable for a Pandas DataFrame.

If it is a valid table, identify the **Header Row** and the **Column Range**.
If it is NOT a valid table (e.g., a dashboard, a form, a letter, or empty), you must flag it as unsuitable.

# Task Analysis
Excel sheets fall into two categories:
1. **List-Like Tables (Valid)**: Contains a header row followed by multiple rows of consistent record data. This is what we want.
2. **Unstructured/Layout-Heavy (Unstructured)**:
   - **Forms/KV Pairs**: "Label: Value" scattered across the sheet.
   - **Dashboards**: Multiple small tables, charts, or scattered numbers.
   - **Text/Notes**: Paragraphs of text or disclaimers without column structure.
   - **Empty/Near Empty**: Contains almost no data.

# Rules for Detection

### A. Validity Check (The "Gatekeeper")
Set `is_extractable_table` to **false** if:
- There is no distinct row where meaningful column headers align horizontally.
- The data is scattered (e.g., values exist in A1, G5, and C20 with no relation).
- The sheet looks like a printed form (Key on the left, Value on the right) rather than a list of records.
- There are fewer than 3 rows of data following a potential header.

### B. Structure Extraction (Only if Valid)
If the sheet passes the Validity Check:
1. **Header Row**: Find the first row containing multiple distinct string values that serve as column labels.
2. **Column Range**: Identify the start index (first valid header) and end index (last valid header) to define the width.
3. **Data Continuity**: Verify that rows below the header contain consistent data types (e.g., Dates under "Date").

# Input Data
The user will provide the first 100 rows in CSV/Markdown format (0-based index).

# Output Format
You must output a strictly valid JSON object.
JSON Structure:
{{
    "is_extractable_table": <boolean, true if it serves as a dataframe source, false otherwise>,
    "row_start_index": <int or null, 0-based index of the header row>,
    "col_ranges": <list [start, end] or null, inclusive 0-based column indices>,
    "confidence_score": <float, 0-1>,
    "reasoning": "<string, explain what the row data contains and declare the final conclusion (IRREGULAR, REGULAR, INVALID).>"
}}

# Examples

## Example 1 (Valid Table with Noise)
Input:
Title: Monthly Sales, NaN, NaN, NaN
NaN, NaN, NaN, NaN
NaN, Date, Item, Qty, Total
NaN, 2023-01-01, Apple, 10, 500
NaN, 2023-01-02, Banana, 5, 100

Output:
{{
    "is_extractable_table": true,
    "row_start_index": 2,
    "col_ranges": [1, 4],
    "confidence_score": 0.99,
    "reasoning": "Rows 0-1 are ignored metadata, Row 2 is clear headers. Rows 3-4 contain consistent data aligned with headers. It is IRREGULAR and requires skiprows=2, usecols=[1, 4] to extract using a Pandas DataFrame."
}}

## Example 2 (Unstructured - Form/Dashboard)
Input:
Company Invoice, NaN, NaN, Invoice #: 001
To:, John Doe, NaN, Date:, 2023-01-01
Address:, 123 St, NaN, Due:, 2023-02-01
NaN, NaN, NaN, NaN, NaN
Subject:, Consulting Services, NaN, NaN, NaN

Output:
{{
    "is_extractable_table": false,
    "row_start_index": null,
    "col_ranges": null,
    "confidence_score": 0.95,
    "reasoning": "Data matches a 'Form/Invoice' layout (Key-Value pairs) rather than a list-like table. No single header row defines a dataset of records. It is INVALID and cannot be processed as a Pandas DataFrame."
}}

# Input
{raw_snippet_data}
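The `row_start_index` and `col_ranges` this prompt returns are meant to be fed to pandas. A minimal sketch of how a caller might apply Example 1's judgment (the workbook name and exact wiring are assumptions):

```python
import pandas as pd

judgment = {
    "is_extractable_table": True,
    "row_start_index": 2,
    "col_ranges": [1, 4],
}

if judgment["is_extractable_table"]:
    start, end = judgment["col_ranges"]
    df = pd.read_excel(
        "monthly_sales.xlsx",  # hypothetical workbook
        skiprows=judgment["row_start_index"],  # rows 0-1 are title/noise
        usecols=range(start, end + 1),  # inclusive [1, 4] -> columns 1..4
    )
```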
@@ -0,0 +1,109 @@
# Role
You are an expert Data Steward. Your task is to analyze the metadata and content of a database.
**Assumption:** This is an ideal dataset or database where **ALL** tables contain valid headers in the first row. You will process the entire file structure in a single pass.

# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
    "file": "Name of the file",
    "tables": [
        {{
            "name": "Name of the table",
            "row_count": 100,
            "col_count": 5,
            "raw_data_snippet": "Header1, Header2\nVal1, Val2..."
        }},
        ...
    ]
}}
```
# Analysis Logic

## 1. Sheet Iteration (Sheet Descriptions)

For **EACH** object in the `tables` array:

1. **Extract Schema:**
   * Since headers are guaranteed, simply extract the column names from the **first row** of the `raw_data_snippet`.
   * Format them as a clean list of strings.

2. **Draft Description:**
   * Write a concise sentence describing what the sheet tracks based on its name and columns.
   * **MANDATORY:** You MUST explicitly mention the `row_count` and `col_count` in this sentence.
   * *Template:* "The sheet [Sheet Name] contains [Subject] data with [Row Count] rows and [Col Count] columns, featuring fields like [List 3 key columns]."

## 2. Global Analysis (File Description)
* Analyze the `file` name and the number of tables inside the `tables` array.
* Based on all sheet descriptions, generate a single sentence summarizing the whole workbook.

# Output Format (Strict JSON)

You must output a single valid JSON object.

```json
{{
    "description": "One sentence describing the whole file or database.",
    "tables": [
        {{
            "name": "Name of table 1",
            "description": "Sentence including row/col counts and key columns.",
            "columns": ["col1", "col2", "col3"]
        }},
        ...
    ]
}}
```

# One-Shot Demonstration

**[Example Input]**
`input_json` =

```json
{{
    "file": "logistics_data.xlsx",
    "tables": [
        {{
            "name": "Shipments",
            "row_count": 2000,
            "col_count": 4,
            "raw_data_snippet": "shipment_id, origin, destination, date\nSHP-001, Tokyo, London, 2024-05-20"
        }},
        {{
            "name": "Rates",
            "row_count": 50,
            "col_count": 2,
            "raw_data_snippet": "Route_ID, Cost_Per_Kg\nR-101, 5.50"
        }}
    ]
}}
```

**[Example Output]**

```json
{{
    "description": "The file/database logistics_data.xlsx contains supply chain logistics information for 2024, divided into shipment tracking and rate definitions (2 tables in total).",
    "tables": [
        {{
            "name": "Shipments",
            "description": "The 'Shipments' sheet tracks individual shipment records with 2000 rows and 4 columns, featuring fields such as shipment_id, origin, and destination.",
            "columns": ["shipment_id", "origin", "destination", "date"]
        }},
        {{
            "name": "Rates",
            "description": "The 'Rates' sheet lists shipping cost rates with 50 rows and 2 columns, specifically Route_ID and Cost_Per_Kg.",
            "columns": ["Route_ID", "Cost_Per_Kg"]
        }}
    ]
}}
```

# Input
input_json=`{data}`
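A hedged sketch of how the `input_json` payload for this prompt could be assembled from a SQLite database (the database file and wiring are assumptions, not part of this commit):

```python
import json
import sqlite3

import pandas as pd

conn = sqlite3.connect("logistics.db")  # hypothetical database file
table_names = [
    row[0]
    for row in conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'",
    )
]

payload = {"file": "logistics.db", "tables": []}
for table in table_names:
    snippet = pd.read_sql_query(f'SELECT * FROM "{table}" LIMIT 5', conn)
    row_count = conn.execute(f'SELECT COUNT(*) FROM "{table}"').fetchone()[0]
    payload["tables"].append(
        {
            "name": table,
            "row_count": row_count,
            "col_count": len(snippet.columns),
            "raw_data_snippet": snippet.to_csv(index=False),
        },
    )

input_json = json.dumps(payload)  # substituted into the prompt's {data} slot
```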
@@ -0,0 +1,159 @@
# Role
You are an expert Data Steward. Your task is to analyze the metadata and content of an Excel file based on a pre-analyzed structural judgment.

**Context:** The dataset contains three types of sheets:
1. **Regular Tables**: Standard headers in row 0.
2. **Irregular Tables**: Valid data but requires `skiprows` or `usecols` parameters.
3. **Unstructured Sheets**: Dashboards, forms, or text descriptions that **cannot** be read as a dataframe.

**Constraint**: Your analysis relies on a snippet of the first 100 rows.

# Input Format
You will receive a single JSON string in the variable `input_json`. The structure is:
```json
{{
    "file": "Name of the file",
    "tables": [
        {{
            "name": "Sheet Name",
            "row_count": 100,
            "col_count": 5,
            "raw_data_snippet": "...",
            "irregular_judgment": {{
                "row_header_index": int,
                "cols_ranges": list,
                "reasoning": "..."
            }}
        }}
    ]
}}
```

*(Note: If `irregular_judgment` is null, treat it as Regular.)*

# Analysis Logic

## 1. Sheet Iteration (Table Descriptions)

For **EACH** object in the `tables` array, apply the following priority logic:

**Case A: Unstructured Sheet (`irregular_judgment` contains "UNSTRUCTURED")**

* **Columns**: Return an empty list `[]`.
* **Description**: "The sheet [Name] contains [something]."
  **Append MANDATORY Warning**: "It is Unstructured based on a 100-row sample."

**Case B: Irregular Table (`irregular_judgment` contains a dict, and `row_header_index` > 0 or `cols_ranges` is set)**

* **Columns**: Extract column names from the row indicated by `row_header_index`.
* **Description**:
  Write a concise sentence describing what the sheet tracks based on its name and columns.
  1. Start with: "The sheet [Name] contains [Subject] data with [Rows] rows and [Cols] columns."
  2. **Append MANDATORY Warning**: "It is irregular; requires specifying skiprows={{row_header_index}}, usecols={{cols_ranges}} when loading a pandas DataFrame."

**Case C: Regular Table (Default)**

* **Columns**: Extract from the first row of `raw_data_snippet`.
* **Description**: "The sheet [Name] contains [Subject] data with [Rows] rows and [Cols] columns, featuring fields like [Key Cols]."

## 2. Global Analysis (File Description)

Generate a single string summarizing the workbook. This summary **MUST** explicitly include:

1. **Total Count**: The number of sheets.
2. **Status List**: List every table name with its status tag:
   * (Regular)
   * (Irregular, requires skiprows=X, usecols=Y)
   * (Unstructured)
   * *Format Example:* "The file logistics_data.xlsx contains supply chain logistics information for 2024. It contains 3 sheets: 'Data' (Regular), 'Logs' (Irregular, requires skiprows=2), and 'Cover' (Unstructured)."

# Output Format (Strict JSON)

You must output a single valid JSON object.

```json
{{
    "description": "Comprehensive summary including count, names, and specific status tags for ALL tables.",
    "tables": [
        {{
            "name": "Table Name",
            "description": "Specific description based on Case A, B, or C.",
            "columns": ["col1", "col2"]
        }}
    ]
}}
```

# One-Shot Demonstration

**[Example Input]**
`input_json` =

```json
{{
    "file": "finance_report_v2.xlsx",
    "tables": [
        {{
            "name": "Q1_Sales",
            "row_count": 200,
            "col_count": 5,
            "raw_data_snippet": "Date, Item, Amount\n2023-01-01, A, 100"
        }},
        {{
            "name": "Historical_Data",
            "row_count": 500,
            "col_count": 10,
            "raw_data_snippet": "Confidential\nSystem Generated\n\nDate, ID, Val\n...",
            "irregular_judgment": {{
                "is_extractable_table": true,
                "row_header_index": 3,
                "cols_ranges": [0, 3],
                "reasoning": "Header offset."
            }}
        }},
        {{
            "name": "Dashboard_Overview",
            "row_count": 50,
            "col_count": 20,
            "raw_data_snippet": "Total KPI: 500 | Chart Area |\nDisclaimer: Internal Use",
            "irregular_judgment": "UNSTRUCTURED"
        }}
    ]
}}
```

**[Example Output]**

```json
{{
    "description": "The file finance_report_v2.xlsx contains historical sales transaction records over the past Q1 period. It contains 3 sheets: 'Q1_Sales' (Regular), 'Historical_Data' (Irregular, requires skiprows=3, usecols=[0, 3], sampled first 100 rows), and 'Dashboard_Overview' (Unstructured).",
    "tables": [
        {{
            "name": "Q1_Sales",
            "description": "The sheet 'Q1_Sales' contains sales transaction records. It contains 200 rows and 5 columns, featuring fields like Date, Item, and Amount.",
            "columns": ["Date", "Item", "Amount"]
        }},
        {{
            "name": "Historical_Data",
            "description": "The sheet 'Historical_Data' contains historical sales transaction records. It contains 500 rows and 10 columns. It is irregular, judged from the first 100 sampled rows (the first 3 rows contain metadata; requires specifying skiprows=3, usecols=[0, 3] when loading a pandas DataFrame).",
            "columns": ["Date", "ID", "Val"]
        }},
        {{
            "name": "Dashboard_Overview",
            "description": "The sheet 'Dashboard_Overview' contains an overview and summary of the dashboard. It is Unstructured based on a 100-row sample.",
            "columns": []
        }}
    ]
}}
```

# Input

input_json=`{data}`
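For Case B, the header row can be pulled straight out of the snippet once `row_header_index` is known. A minimal sketch using the example's `Historical_Data` values (variable names are assumptions):

```python
snippet = "Confidential\nSystem Generated\n\nDate, ID, Val\n..."
row_header_index = 3  # from irregular_judgment

header_line = snippet.split("\n")[row_header_index]
columns = [col.strip() for col in header_line.split(",")]
assert columns == ["Date", "ID", "Val"]
```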
alias/src/alias/agent/agents/data_source/data_profile.py (new file, 113 lines)
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
import os
import base64
import tempfile
from typing import Any, Dict
from io import BytesIO
from pathlib import Path
import requests

from alias.agent.agents.data_source._typing import SourceType
from alias.agent.agents.data_source._data_profiler_factory import (
    DataProfilerFactory,
)
from alias.agent.tools.sandbox_util import (
    get_workspace_file,
)
from alias.runtime.alias_sandbox.alias_sandbox import AliasSandbox
from alias.agent.utils.llm_call_manager import (
    LLMCallManager,
)


def _get_binary_buffer(
    sandbox: AliasSandbox,
    file_url: str,
):
    if file_url.startswith(("http://", "https://")):
        response = requests.get(file_url)
        response.raise_for_status()
        buffer = BytesIO(response.content)
    else:
        buffer = BytesIO(
            base64.b64decode(get_workspace_file(sandbox, file_url)),
        )
    return buffer


def _copy_file_from_sandbox_with_original_name(
    sandbox: AliasSandbox,
    file_path: str,
) -> str:
    """
    Copies a file from the sandbox environment
    or a URL to a local temporary file.

    Args:
        sandbox (AliasSandbox): The sandbox environment instance.
        file_path (str): Source path or URL.

    Returns:
        str: The path to the local temporary file.
    """
    # Handle different types of file URLs
    if file_path.startswith(("http://", "https://")):
        # For web URLs, use the URL directly
        file_source = file_path
    else:
        # For local files, save to a temporary file
        file_buffer = _get_binary_buffer(
            sandbox,
            file_path,
        )
        # Create a temporary file with the same name as the original file
        temp_dir = tempfile.mkdtemp()
        target_file_name = os.path.basename(file_path)
        full_path = Path(temp_dir) / target_file_name
        with open(full_path, "wb") as f:
            f.write(file_buffer.getvalue())
        file_source = full_path
    return str(file_source)


async def data_profile(
    sandbox: AliasSandbox,
    sandbox_path: str,
    source_type: SourceType,
    llm_call_manager: LLMCallManager,
) -> Dict[str, Any]:
    """
    Generates a detailed profile and summary for a data source using LLMs.

    Args:
        sandbox (AliasSandbox): The sandbox environment instance.
        sandbox_path (str): The location of the data source.
            - For files: A file path or URL.
            - For databases: A connection string (DSN).
        source_type (SourceType): The type of the data source.
        llm_call_manager: Manager for handling LLM calls.

    Returns:
        Dict: An object containing the generated text profile of the data.

    Raises:
        ValueError: If the provided `source_type` is not supported.
    """
    if source_type in [SourceType.CSV, SourceType.EXCEL, SourceType.IMAGE]:
        local_path = _copy_file_from_sandbox_with_original_name(
            sandbox,
            sandbox_path,
        )
    elif source_type == SourceType.RELATIONAL_DB:
        local_path = sandbox_path
    else:
        raise ValueError(f"Unsupported source type {source_type}")

    profiler = DataProfilerFactory.get_profiler(
        llm_call_manager=llm_call_manager,
        path=local_path,
        source_type=source_type,
    )

    return await profiler.generate_profile()
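A minimal usage sketch, assuming an already-initialised `AliasSandbox` and `LLMCallManager` and a CSV previously uploaded to the workspace; the `description` key follows the prompt templates above:

```python
from alias.agent.agents.data_source._typing import SourceType
from alias.agent.agents.data_source.data_profile import data_profile


async def profile_uploaded_csv(sandbox, llm_call_manager):
    profile = await data_profile(
        sandbox=sandbox,
        sandbox_path="/workspace/incident_records.csv",  # assumed upload
        source_type=SourceType.CSV,
        llm_call_manager=llm_call_manager,
    )
    print(profile["description"])
```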
alias/src/alias/agent/agents/data_source/data_skill.py (new file, 198 lines)
@@ -0,0 +1,198 @@
# -*- coding: utf-8 -*-
import os

from pathlib import Path
from typing import List

import frontmatter
from loguru import logger

from agentscope.tool._types import AgentSkill

from alias.agent.agents.ds_agent_utils.utils import get_prompt_from_file
from alias.agent.agents.data_source._typing import SourceType


class DataSkill(AgentSkill):
    """An agent skill tagged with the data source types it applies to."""

    type: List[SourceType]
    """The source types of the skill."""


class DataSkillManager:
    """Data skill selector based on data source type"""

    _default_skill_path_base = os.path.join(
        Path(__file__).resolve().parent.parent,
        "_built_in_skill/data",
    )

    def __init__(
        self,
    ):
        self.skills = self.register_skill_dir()

        self.source_type_2_skills = {}
        for skill in self.skills:
            for t in skill["type"]:
                self.source_type_2_skills[t] = skill

    def load(self, data_source_types: List[SourceType]) -> List[str]:
        """
        Load skills based on data source types.

        Args:
            data_source_types: List of SourceType enum values

        Returns:
            Selected skill content list
        """
        if not data_source_types:
            return []

        selected_skills = []

        data_source_types = set(data_source_types)
        for source_type in data_source_types:
            try:
                # Get skill from source type mapping
                skill = self.source_type_2_skills.get(source_type, None)

                # Skip if no corresponding skill
                if not skill:
                    logger.warning(
                        "DataSkillSelector found no valid skill for data "
                        f"source type: {source_type}",
                    )
                else:
                    logger.info(
                        f"DataSkillSelector selected skill: {skill['name']} "
                        f"for data source type: {source_type}",
                    )

                    skill_content = get_prompt_from_file(
                        skill["dir"],
                        return_json=False,
                    )
                    if skill_content:
                        selected_skills.append(skill_content)

            except Exception as e:
                logger.error(
                    f"DataSkillSelector selection failed: {str(e)} "
                    f"for data source type: {source_type}",
                )
                continue

        return selected_skills

    def register_skill_dir(self, skill_dir=_default_skill_path_base):
        """Load skills from all directories containing SKILL.md"""

        skills = []
        # Check the skill directory
        if not os.path.isdir(skill_dir):
            raise ValueError(
                f"The skill directory '{skill_dir}' does not exist or is "
                "not a directory.",
            )

        # Walk through all files and directories in skill_dir
        for root, dirs, _ in os.walk(skill_dir):
            # Process directories - look for SKILL.md
            for dir_name in dirs:
                dir_path = os.path.join(root, dir_name)
                skill = self.register_skill(dir_path)
                if skill:
                    skills.append(skill)

        return skills

    def register_skill(self, path: str, name=None):
        """
        Register a new skill dynamically.

        Args:
            path: Path to the skill directory containing SKILL.md
            name: Optional skill name; falls back to the YAML front matter
                or the directory name
        """
        try:
            # Resolve the skill path
            file_path = self._resolve_skill_path(path)
            if not file_path:
                raise FileNotFoundError("`SKILL.md` not found")

            # Parse the skill file
            skill = self._parse_skill_file(file_path, name)
            logger.info(
                f"Successfully registered skill '{skill['name']}' "
                f"from '{file_path}'",
            )

            return skill

        except Exception as e:
            # `skill` may be unbound here, so only the path is logged.
            logger.error(
                f"Failed to register skill from '{path}': {e}",
            )
            return None

    def _resolve_skill_path(self, path: str) -> str:
        """
        Resolve a skill path to the actual markdown file path.

        Args:
            path: Path to a skill markdown file or a directory containing
                SKILL.md

        Returns:
            str: Path to the actual markdown file, or empty string if invalid
        """
        if os.path.isdir(path):
            skill_md_path = os.path.join(path, "SKILL.md")
            if not os.path.isfile(skill_md_path):
                logger.warning(f"Directory '{path}' does not contain SKILL.md")
                return ""
            return skill_md_path
        else:
            logger.warning(f"Invalid skill path: {path}")
            return ""

    def _parse_skill_file(self, file_path, name=None):
        """Parse a skill file and build a DataSkill from it."""

        # Check YAML Front Matter
        post = frontmatter.load(file_path)

        # Use directory name as skill name if not provided in YAML
        if name is None:
            dir_name = os.path.basename(os.path.dirname(file_path))
            name = post.get("name", dir_name)
        else:
            name = post.get("name", name)

        description = post.get("description", None)
        _type = post.get("type", None)

        if not name or not description or not _type:
            raise ValueError(
                f"The file '{file_path}' must have a YAML Front "
                "Matter including `name`, `description`, and `type` fields",
            )

        _type = _type if isinstance(_type, list) else [_type]
        if any(not SourceType.is_valid_source_type(t) for t in _type):
            raise ValueError(
                f"Type of file '{file_path}' must be a member "
                "(or a list of members) of SourceType",
            )

        name, description = str(name), str(description)
        _type = [SourceType(t) for t in _type]

        return DataSkill(
            name=name,
            description=description,
            type=_type,
            dir=file_path,
        )
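`_parse_skill_file` expects `SKILL.md` to carry YAML front matter with `name`, `description`, and `type` (a `SourceType` value or list of values). A hypothetical front matter that would pass validation, checked with the same `frontmatter` library:

```python
import frontmatter

# Hypothetical SKILL.md content; the skill name and wording are illustrative.
SKILL_MD = """\
---
name: csv_handling
description: Tips for loading and validating CSV data sources.
type:
  - csv
---
# CSV Handling
Prefer `pandas.read_csv` with an explicit `nrows` cap for inspection.
"""

post = frontmatter.loads(SKILL_MD)
assert post["name"] == "csv_handling"
assert post["type"] == ["csv"]
```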
alias/src/alias/agent/agents/data_source/data_source.py (new file, 620 lines)
@@ -0,0 +1,620 @@
# -*- coding: utf-8 -*-
# pylint: disable=R1702,R0912,R0911

import os
import json
from pathlib import Path
from typing import Dict, Any, Optional, List
import yaml
from loguru import logger

from agentscope.mcp import StdIOStatefulClient
from agentscope_runtime.sandbox.box.sandbox import Sandbox

from alias.agent.agents.data_source.data_skill import DataSkillManager
from alias.agent.agents.data_source._typing import (
    SOURCE_TYPE_TO_ACCESS_TYPE,
    SourceAccessType,
    SourceType,
)
from alias.agent.agents.data_source.data_profile import data_profile
from alias.agent.agents.data_source.utils import replace_placeholders
from alias.agent.tools.toolkit_hooks.text_post_hook import TextPostHook
from alias.agent.tools.alias_toolkit import AliasToolkit
from alias.agent.tools.sandbox_util import (
    copy_local_file_to_workspace,
)
from alias.agent.utils.llm_call_manager import (
    LLMCallManager,
)


class DataSource:
    """
    Unified data source class representing any data source.
    """

    def __init__(
        self,
        endpoint: str,
        source_type: SourceType,
        name: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize a data source.

        Args:
            endpoint: Address/DSN/URL/path to access the data source
            source_type: Type of the data source (SourceType enum)
            name: Name/identifier of the data source
            config: Configuration for this data source
        """
        self.endpoint = endpoint
        self.name = name

        source_access_type = SOURCE_TYPE_TO_ACCESS_TYPE.get(
            source_type,
            SourceAccessType.DIRECT,
        )

        self.source_access_type = source_access_type
        self.source_type = source_type

        self.config = config or {}
        self.profile = {}
        self.source_access = endpoint  # default; refined in `prepare()`
        self.source_desc = None
        self.source_access_desc = None

    async def prepare(self, toolkit: AliasToolkit):
        """
        Prepare the data source.
        For DIRECT access: upload the file to the sandbox workspace.
        For VIA_MCP access: register the corresponding MCP server.

        Args:
            toolkit: Toolkit wrapping the sandbox instance
        """
        logger.info(f"Preparing data source {self.name}...")

        if self.source_access_type == SourceAccessType.DIRECT:
            # Get the filename and construct target path in workspace
            filename = os.path.basename(self.endpoint)
            target_path = f"/workspace/{filename}"

            if os.getenv("LINK_FILE_TO_WORKSPACE", "off").lower() == "on":
                logger.info(
                    f"Creating symlink for {self.endpoint} "
                    f"to {target_path}",
                )
                # Build ln -s command
                command = f"ln -s '{self.endpoint}' '{target_path}'"
                result = toolkit.sandbox.call_tool(
                    name="run_shell_command",
                    arguments={"command": command},
                )
                if result.get("isError"):
                    raise ValueError(
                        "Failed to create symlink for "
                        f"{self.endpoint}: {result}",
                    )
            else:
                logger.info(f"Uploading {self.endpoint} to {target_path}")
                result = copy_local_file_to_workspace(
                    sandbox=toolkit.sandbox,
                    local_path=self.endpoint,
                    target_path=target_path,
                )

                if result.get("isError"):
                    raise ValueError(
                        f"Failed to upload {self.endpoint}: {result}",
                    )

            self.source_access = target_path
            self.source_desc = "Local file"
            self.source_access_desc = f"Access at path: `{target_path}`"

            logger.info(f"Successfully loaded to {result}")

        # Check if this is an MCP tool source
        elif self.source_access_type == SourceAccessType.VIA_MCP:
            server_config = self.config.get("mcp_server", {})
            mcp_server_name = server_config.keys()

            if len(mcp_server_name) != 1:
                raise ValueError("Register servers one by one!")

            mcp_server_name = list(mcp_server_name)[0]
            server_config = server_config[mcp_server_name]

            cmd = server_config.get("command")
            args = server_config.get("args")
            if cmd is None or args is None:
                raise ValueError(
                    "MCP server configuration requires non-empty "
                    "`command` and `args` fields to start!",
                )

            client = StdIOStatefulClient(
                self.name,
                command=cmd,
                args=args,
                env=server_config.get("env"),
            )

            text_hook = TextPostHook(
                toolkit.sandbox,
                budget=5000,
                auto_save=True,
            )
            await toolkit.add_and_connect_mcp_client(
                client,
                postprocess_func=text_hook.truncate_and_save_response,
            )
            registered_tools = [
                t.name
                for t in list(
                    await toolkit.additional_mcp_clients[-1].list_tools(),
                )
            ]

            self.source_access = self.endpoint
            self.source_desc = f"{self.source_type}"
            self.source_access_desc = (
                f"Access via MCP tools: [{', '.join(registered_tools)}]"
            )

            logger.info(f"Successfully connected to {self.name}")

        else:
            logger.info(
                f"Skipping preparation for source type: {self.source_type}",
            )

    def get_coarse_desc(self):
        return (
            f"{self.source_desc}. {self.source_access_desc}: "
            + f"{self._general_profile()}"
        )

    async def prepare_profile(
        self,
        sandbox: Sandbox,
        llm_call_manager: LLMCallManager,
    ) -> Optional[Dict[str, Any]]:
        """Run type-specific profiling."""
        if llm_call_manager and not self.profile:
            try:
                self.profile = await data_profile(
                    sandbox=sandbox,
                    sandbox_path=self.source_access,
                    source_type=self.source_type,
                    llm_call_manager=llm_call_manager,
                )
                logger.info(
                    "Profiling succeeded: "
                    + f"{self._general_profile()[:100]}...",
                )
            except ValueError as e:
                self.profile = None
                logger.warning(f"Warning when profiling data: {e}")
            except Exception as e:
                self.profile = None
                logger.error(f"Error when profiling data: {e}")

        return self.profile

    def _refined_profile(self) -> str:
        if self.profile:
            return yaml.dump(
                self.profile,
                allow_unicode=True,
                sort_keys=False,
                default_flow_style=False
                if self.source_type == SourceType.IMAGE
                else None,
                width=float("inf"),
            )
        else:
            return ""

    def _general_profile(self) -> str:
        return self.profile["description"] if self.profile else ""

    def __str__(self) -> str:
        return (
            f"DataSource(name='{self.name}', type='{self.source_type}', "
            f"endpoint='{self.endpoint}')"
        )

    def __repr__(self) -> str:
        return self.__str__()


class DataSourceManager:
    """
    Manager class for handling multiple data sources.
    Provides methods to add, retrieve, and manage data sources.
    Also manages data source configurations with hierarchical lookup.
    """

    _default_data_source_config = os.path.join(
        Path(__file__).resolve().parent,
        "_default_config.json",
    )

    def __init__(
        self,
        sandbox: Sandbox,
        llm_call_manager: LLMCallManager,
    ):
        """Initialize an empty data source manager."""
        self._data_sources: Dict[str, DataSource] = {}
        self._type_defaults = {}

        self._load_default_config()

        self.skill_manager = DataSkillManager()
        self.selected_skills = None

        self.toolkit = AliasToolkit(sandbox=sandbox)

        self.llm_call_manager = llm_call_manager

    def add_data_source(
        self,
        config: str | Dict = None,
    ):
        """
        Add a new data source (or multiple sources) to the manager.

        Args:
            config: endpoint (address/DSN/URL/path to the data source) or
                configuration for the data source connection
        """
        if isinstance(config, str):
            endpoint = config
            conn_config = None
        else:
            if "endpoint" not in config:
                logger.error(
                    f"Missing 'endpoint' in config for source '{config}'",
                )
                raise ValueError(f"Missing 'endpoint' in config: {config}")

            endpoint = config["endpoint"]
            conn_config = config

        sources = set()
        if os.path.isdir(endpoint):
            # Add all files in directory
            for filename in os.listdir(endpoint):
                file_path = os.path.join(endpoint, filename)
                sources.add(file_path)
        else:
            sources.add(endpoint)

        for endpoint in sources:
            # Auto-detect source type
            source_type = self._detect_source_type(endpoint)

            # Auto-generate name
            name = self._generate_name(endpoint)

            # Get configuration for this data source; fall back to the
            # registered default for the detected type
            source_config = conn_config or self.get_default_config(
                source_type,
            )

            if source_config:
                source_config = replace_placeholders(
                    source_config,
                    {
                        "endpoint": endpoint,
                    },
                )

            # Create data source with configuration
            data_source = DataSource(
                endpoint,
                source_type,
                name,
                source_config,
            )
            self._data_sources[endpoint] = data_source

    async def prepare_data_sources(self) -> None:
        """
        Prepare all data sources: upload files and start MCP servers in
        the sandbox, then profile each source.
        """
        logger.info(f"Preparing {len(self._data_sources)} data source(s)...")

        all_data_sources = self._data_sources.values()
        for data_source in all_data_sources:
            await data_source.prepare(self.toolkit)
            await data_source.prepare_profile(
                self.toolkit.sandbox,
                self.llm_call_manager,
            )

    def _generate_name(self, endpoint: str) -> str:
        """
        Generate a name based on the endpoint.
        For databases, removes passwords and uses scheme + database name.
        For files, uses the filename.
        For URLs, uses the domain or the last part of the path.
        """
        from urllib.parse import urlparse

        try:
            # For file paths
            if os.path.isfile(endpoint):
                filename = os.path.basename(endpoint)
                # Remove extension and sanitize
                name_without_ext = os.path.splitext(filename)[0]
                return self._sanitize_name(name_without_ext)

            # For database connections
            db_indicators = [
                "://",
                ".db",
                ".sqlite",
                "mongodb://",
                "mongodb+srv://",
                "neo4j://",
                "bolt://",
            ]
            if any(
                indicator in endpoint.lower() for indicator in db_indicators
            ):
                if "://" in endpoint:
                    try:
                        # Split by :// to get scheme and rest
                        scheme, rest = endpoint.split("://", 1)
                        scheme = scheme.lower()

                        # Handle authentication (user:password@host)
                        if "@" in rest:
                            auth_part, host_part = rest.split("@", 1)
                            if ":" in auth_part:
                                # Has user:password format, keep only username
                                username = auth_part.split(":")[0]
                                rest = f"{username}@{host_part}"
                            # If no colon, it's just username@host, keep as is

                        # Extract database name
                        db_name = "unknown"
                        if "/" in rest:
                            # Split by / and take last part before
                            # query parameters
                            path_parts = rest.split("/")
                            if len(path_parts) > 1:
                                db_name = (
                                    path_parts[-1].split("?")[0].split("#")[0]
                                )
                            if not db_name:  # If empty, try second to last
                                db_name = (
                                    path_parts[-2]
                                    if len(path_parts) > 2
                                    else "unknown"
                                )
                        else:
                            # Use host name if no database name in path
                            host = (
                                rest.split(":")[0].split("/")[0].split("@")[-1]
                            )
                            db_name = host

                        # Create name: scheme_dbname
                        return self._sanitize_name(f"{scheme}_{db_name}")
                    except Exception as e:
                        logger.warning(
                            f"Error parsing database URL {endpoint}: {e}",
                        )
                        # Fall through to URL handling

                elif "." in endpoint:
                    # Use filename without extension for .db/.sqlite files
                    filename = os.path.basename(endpoint)
                    name_without_ext = os.path.splitext(filename)[0]
                    return self._sanitize_name(name_without_ext)

            # For URLs (including database URLs that failed to parse)
            if "://" in endpoint:
                try:
                    parsed = urlparse(endpoint)
                    if parsed.netloc:
                        # Use domain name (without port)
                        domain = parsed.netloc.split(":")[0].split("@")[
                            -1
                        ]  # Remove username if present
                        # If path exists, use last part of path
                        if parsed.path and parsed.path != "/":
                            path_parts = parsed.path.strip("/").split("/")
                            if path_parts:
                                return self._sanitize_name(path_parts[-1])
                        return self._sanitize_name(domain)
                    elif parsed.path:
                        # Use last part of path
                        path_parts = parsed.path.strip("/").split("/")
                        if path_parts:
                            return self._sanitize_name(path_parts[-1])
                except Exception as e:
                    logger.warning(f"Error parsing URL {endpoint}: {e}")

            # Fallback: use a sanitized version of the endpoint
            return self._sanitize_name(endpoint[:50])

        except Exception as e:
            logger.error(f"Error generating default name for {endpoint}: {e}")
            # Ultimate fallback
            return self._sanitize_name("unknown_source")

    def _sanitize_name(self, name: str) -> str:
        """Sanitize a name to be used as a data source identifier."""
        import re

        # Keep only alphanumeric and underscore characters
        sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name)

        # Ensure it starts with a letter or underscore
        if sanitized and not sanitized[0].isalpha() and sanitized[0] != "_":
            sanitized = "_" + sanitized

        # Truncate if too long
        sanitized = sanitized[:50]

        # Ensure it's not empty
        if not sanitized:
            sanitized = "unknown"

        return sanitized

    def _detect_source_type(self, endpoint: str) -> SourceType:
        """Auto-detect source type based on endpoint."""
        endpoint_lower = endpoint.lower()

        # Check for file extensions
        if endpoint_lower.endswith(".csv"):
            source_type = SourceType.CSV
        elif endpoint_lower.endswith((".xls", ".xlsx", ".xlsm")):
            source_type = SourceType.EXCEL
        elif endpoint_lower.endswith(".json"):
            source_type = SourceType.JSON
        elif endpoint_lower.endswith((".txt", ".log", ".md")):
            source_type = SourceType.TEXT
        elif endpoint_lower.endswith(
            (".jpg", ".jpeg", ".png", ".gif", ".bmp"),
        ):
            source_type = SourceType.IMAGE

        # Check for database connection strings/patterns
        # Relational databases
        elif any(
            keyword in endpoint_lower
            for keyword in [
                "postgresql://",
                "postgres://",
                "pg://",
                "mysql://",
                "mariadb://",
                "sqlserver://",
            ]
        ):
            source_type = SourceType.RELATIONAL_DB
        elif (
            "sqlite://" in endpoint_lower
            or endpoint_lower.endswith(".db")
            or endpoint_lower.endswith(".sqlite")
        ):
            source_type = SourceType.RELATIONAL_DB

        else:
            source_type = SourceType.OTHER

        return source_type

    def get_all_data_sources_desc(self) -> str:
        """
        Get descriptions of all data sources.

        Returns:
            A formatted string listing every data source description
        """
        return "Available data sources: \n" + "\n".join(
            [
                f"[{idx}] " + ds.get_coarse_desc()
                for idx, ds in enumerate(self._data_sources.values())
            ],
        )

    def get_local_data_sources(self) -> List[str]:
        """
        Get the list of local (directly accessed) data source endpoints.
        """
        return [
            ds.endpoint
            for ds in self._data_sources.values()
            if ds.source_access_type == SourceAccessType.DIRECT
        ]

    def get_all_data_sources_name(self) -> List[str]:
        """
        Get a list of all data source keys (endpoints).

        Returns:
            List of all data source keys
        """
        return list(self._data_sources.keys())

    def remove_data_source(self, name: str) -> bool:
        """
        Remove a data source by its key (endpoint).

        Args:
            name: Key of the data source to remove

        Returns:
            True if successfully removed, False if not found
        """
        if name in self._data_sources:
            del self._data_sources[name]
            return True
        return False

    def get_default_config(self, source_type: SourceType) -> Dict[str, Any]:
        """
        Get the default configuration for a source type.

        Args:
            source_type: The SourceType to get the default config for

        Returns:
            Default configuration dictionary, empty dict if not registered
        """
        return self._type_defaults.get(source_type, {})

    def _load_default_config(self) -> None:
        """Load the default type-to-configuration mapping."""
        try:
            with open(
                self._default_data_source_config,
                "r",
                encoding="utf-8",
            ) as f:
                config = json.load(f)

            # Load type defaults
            for type_name, type_config in config.items():
                try:
                    source_type = SourceType(type_name)
                    self._type_defaults[source_type] = type_config
                except ValueError:
                    # Skip invalid source types
                    continue

        except FileNotFoundError:
            # If config file doesn't exist, initialize with empty configs
            pass
        except json.JSONDecodeError:
            # If config file is invalid JSON, initialize with empty configs
            pass

    def __len__(self) -> int:
        """Return the number of data sources managed."""
        return len(self._data_sources)

    def get_data_skills(self):
        # TODO: update when data source changed
        if self.selected_skills is None:
            source_types = [
                data.source_type for data in self._data_sources.values()
            ]
            self.selected_skills = self.skill_manager.load(source_types)

        return "\n".join(self.selected_skills)
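An end-to-end sketch of the manager, assuming an already-initialised sandbox and LLM call manager; the endpoints are illustrative:

```python
from alias.agent.agents.data_source.data_source import DataSourceManager


async def attach_sources(sandbox, llm_call_manager):
    manager = DataSourceManager(sandbox, llm_call_manager)
    manager.add_data_source("./docs/data/incident_records.csv")
    manager.add_data_source("postgresql://user:password@localhost:5432/mydb")

    # Uploads files / starts MCP servers, then profiles each source.
    await manager.prepare_data_sources()

    print(manager.get_all_data_sources_desc())
    print(manager.get_data_skills())
```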
alias/src/alias/agent/agents/data_source/utils.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
import logging
import re

# Set up logger
logger = logging.getLogger(__name__)


def replace_placeholders(obj, source_config):
    """Recursively replace `${var}` placeholders in strings, dicts, and
    lists with the corresponding values from `source_config`."""
    if isinstance(obj, str):
        pattern = r"\$\{([^}]+)\}"
        matches = re.finditer(pattern, obj)
        result = obj
        for match in matches:
            var_name = match.group(1)
            if var_name in source_config:
                result = result.replace(
                    match.group(0),
                    str(source_config[var_name]),
                )
        return result
    elif isinstance(obj, dict):
        return {
            k: replace_placeholders(v, source_config) for k, v in obj.items()
        }
    elif isinstance(obj, list):
        return [replace_placeholders(item, source_config) for item in obj]
    else:
        return obj
@@ -137,11 +137,12 @@ async def files_filter_pre_reply_hook(
     # Even if the user only uploaded supplementary files in this interaction,
     # We will also check whether the previously uploaded files are relevant
     # to the question.
-    self.uploaded_files = list(
-        set(files_list) | set(getattr(self, "uploaded_files", [])),
+    uploaded_files = list(
+        set(files_list) | set(self.data_manager.get_local_data_sources()),
     )
 
-    if len(self.uploaded_files) < 100:
+    if len(uploaded_files) < 100:
         logger.info(
             "Scalable files filtering: not enough files to filter.",
         )
@@ -164,7 +165,7 @@ await files_filter(query, files_list, api_key=api_key)
 
     files_filter_code += template.substitute(
         query=safe_query,
-        files_list=repr(self.uploaded_files),
+        files_list=repr(uploaded_files),
         api_key=safe_api_key,
     )
@@ -28,63 +28,7 @@ When executing any data science task (data loading, cleaning, analysis, modeling
 
 ---
 
-## Task Management Rules
+## Principles: Fact-Based, No Assumptions
 
-- **You must use `todo_write` to track progress**, especially for multi-step tasks.
-- Mark each subtask as complete **immediately** upon finishing—no delays or batch updates.
-- Skipping planning risks missing critical steps—this is unacceptable.
-
----
-
-## Data Handling Requirements
-
-### Data Inspection Methods
-
-Before any operation, **you must** inspect the true structure of the data source using tools (preferably `run_ipython_cell`):
-
-| Data Type      | Inspection Method |
-|----------------|-------------------|
-| **Database**   | Query table schema (`DESCRIBE table`) and preview first 5–10 rows (`SELECT * FROM ... LIMIT 5`) |
-| **CSV/Excel**  | Use `pandas.head(n)` to view column names and samples |
-| **Images**     | Use PIL to get dimensions/format, or invoke vision tools to extract content |
-| **Text Files** | Read first 5–10 lines to determine structure and encoding |
-| **JSON**       | Inspect from outer to inner layers progressively |
-
-> **Core Principle**: What you see is fact; what you haven’t seen is unknown.
-
----
-
-### Data Preprocessing Methods
-
-##### Messy Spreadsheet Handling
-
-After initial inspection of CSV or Excel files, if you observe:
-
-- Many `"Unnamed: X"`, `NaN`, or `NaT` entries
-- Missing or ambiguous headers
-- Multiple data blocks within a single worksheet
-
-Then **prioritize** advanced cleaning tools:
-
-- `clean_messy_spreadsheet`: Extract key information from tables and output as JSON for downstream analysis
-
-Only fall back to manual pandas row/block parsing if this tool fails.
-
----
-
-### Strict Data Volume Limits
-
-To prevent system crashes, strictly limit data volume during queries and reads:
-
-- **Database queries**: Always use `LIMIT` (typically 5–10 rows)
-- **Well-structured CSV/Excel**: Use `head()`, `nrows`, or sampling to fetch minimal data
-- **Large text files**: Read only the first few lines or process iteratively in chunks
-
-> **Warning**: Unrestricted large data reads will cause system failure.
-
----
-
-### Fact-Based, No Assumptions
-
 - All decisions must be grounded in the **given task context**. Never simplify, generalize, or subjectively interpret the task goal, data purpose, or business scenario. Any action inconsistent with the problem context is invalid and dangerous.
 - Never act on assumptions, guesses, or past experience—even if the situation seems "obvious" or "routine."
@@ -93,6 +37,14 @@ To prevent system crashes, strictly limit data volume during queries and reads:
 
 ---
 
+## Task Management Rules
+
+- **You must use `todo_write` to track progress**, especially for multi-step tasks.
+- Mark each subtask as complete **immediately** upon finishing—no delays or batch updates.
+- Skipping planning risks missing critical steps—this is unacceptable.
+
+---
+
 ## Visualization Strategy
 
 - **Plotting library**: Prefer `matplotlib`
@@ -44,8 +44,13 @@ Each task in the roadmap contains:
   - Brief Response
   - Detailed Report
 - You should choose the template that is most appropriate for the user task.
-  - **Brief Respoonse Template** should ONLY be used when the user asks for a simple data query task, where ONLY numeric or concise string values are returned, and complex analysis or research are not required.
-  - **Detailed Report Template** should be used when the user asks for a detailed analysis of the data, where the analysis and research are required.
+  - **Brief Respoonse Template** should ONLY be used when the user asks for a
+    simple, static data point (e.g., a total count or a specific value), where
+    the answer is returned as a single numeric or concise string value with no
+    analysis, transformation, comparison, or interpretation required.
+  - **Detailed Report Template** should be used whenever the task involves
+    distribution, discrepancy, imbalance, comparison, trend, root cause, or
+    any form of analysis, interpretation, or evidence generation.
 
 2. Data Source Constraints
   - **ONLY use information explicitly present in the log file**
@@ -103,7 +108,7 @@ You MUST ensure all captions, subtitles, and other contents in the report are wr
   - "brief_response": The brief response content.
     - When 'is_brief_response' is True, this field should be fulfilled with the brief response content following the **Brief Response Template**.
     - When 'is_brief_response' is False, this field should be a concise summary of the detailed report in in markdown format illustrating the key findings and insights.
-  - "detailed_report_content": The detailed markdown report content following the **Detailed Report Template**. This field is ONLY generated when 'is_brief_response' is False, otherwise fulfill an empty string.
+  - "report_content": The detailed markdown report content following the **Detailed Report Template**. This field is ONLY generated when 'is_brief_response' is False, otherwise fulfill an empty string.
   - You MUST ensure the JSON object is a valid JSON string and can be parsed by json.loads().
   - Double check all escapes are valid.
@@ -112,8 +112,11 @@ def truncate_long_text_post_hook(
 def _add_tool_postprocessing_func(toolkit: AliasToolkit) -> None:
     for tool_func, _ in toolkit.tools.items():
         if tool_func.startswith("run_ipython_cell"):
-            funcs: list = [ansi_escape_post_hook]
-            funcs.append(summarize_plt_chart_hook)
+            funcs: list = [
+                ansi_escape_post_hook,
+                summarize_plt_chart_hook,
+                truncate_long_text_post_hook,
+            ]
             toolkit.tools[tool_func].postprocess_func = partial(
                 run_ipython_cell_post_hook,
                 funcs,
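The registration above suggests `run_ipython_cell_post_hook` applies the listed hooks to each tool response in order. A minimal sketch of such a hook pipeline (the function shapes are assumptions; the real signatures live in the toolkit module):

```python
from functools import partial
import re

# Hypothetical post-hook pipeline: each hook takes and returns a response dict.
def run_ipython_cell_post_hook(funcs, tool_use, tool_response):
    for func in funcs:
        tool_response = func(tool_use, tool_response)
    return tool_response

def ansi_escape_post_hook(tool_use, resp):
    # e.g., strip ANSI color codes from captured stdout (sketch only)
    resp["text"] = re.sub(r"\x1b\[[0-9;]*m", "", resp.get("text", ""))
    return resp

pipeline = partial(run_ipython_cell_post_hook, [ansi_escape_post_hook])
print(pipeline(None, {"text": "\x1b[31mred\x1b[0m output"}))  # {'text': 'red output'}
```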
@@ -1,20 +1,62 @@
 # -*- coding: utf-8 -*-
 import os
-import json
 import time
 from typing import Tuple
 
 import dotenv
+from pydantic import BaseModel, Field
 
 from agentscope.message import Msg
 
 from .utils import model_call_with_retry, get_prompt_from_file
 
 from .ds_config import PROMPT_DS_BASE_PATH
 
 dotenv.load_dotenv()
 
 
+class ReportResponse(BaseModel):
+    is_brief_response: bool = Field(
+        ...,
+        description=(
+            "True if the response is a brief response; "
+            "False if it includes a detailed report."
+        ),
+    )
+
+    brief_response: str = Field(
+        ...,
+        description=(
+            "The brief response content. "
+            "When 'is_brief_response' is True, this field contains the full "
+            "brief response following the Brief Response Template. "
+            "When 'is_brief_response' is False, this field contains a concise "
+            "markdown summary of the detailed report, highlighting key "
+            "findings and insights."
+        ),
+        json_schema_extra={
+            "example": (
+                "The analysis shows a 15% increase in user engagement "
+                "after the feature update."
+            ),
+        },
+    )
+
+    report_content: str = Field(
+        ...,
+        description=(
+            "The detailed markdown report content following the "
+            "Detailed Report Template. This field MUST be an empty "
+            "string ('') when 'is_brief_response' is True. It MUST contain "
+            "the full detailed report when 'is_brief_response' is False."
+        ),
+        json_schema_extra={
+            "example": "### User Task Description...\n"
+            "### Associated Data Sources...\n"
+            "### Research Conclusion...\n### Task1...### Task2...",
+        },
+    )
+
+
 class ReportGenerator:
     def __init__(self, model, formatter, memory_log: str):
         self.model = model
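`ReportResponse` gives the model a schema to fill instead of free-form JSON that must be fenced and re-parsed. A quick sketch of validating a structured payload against it (illustrative; assumes pydantic v2):

```python
from pydantic import BaseModel, ValidationError

class ReportResponse(BaseModel):
    is_brief_response: bool
    brief_response: str
    report_content: str

# A payload like the one the model is asked to return
payload = {
    "is_brief_response": True,
    "brief_response": "42 incidents were recorded in Q4.",
    "report_content": "",
}
try:
    parsed = ReportResponse.model_validate(payload)
    print(parsed.brief_response)
except ValidationError as e:
    print(e)  # schema violations are caught before report rendering
```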
@@ -62,22 +104,13 @@ class ReportGenerator:
             self.formatter,
             msgs=msgs,
             msg_name="Report Generation",
+            structured_model=ReportResponse,
         )
 
-        raw_response = res.content[0]["text"]
-
-        # TODO: More robust response cleaning
-        if raw_response.strip().startswith("```json"):
-            cleaned = raw_response.strip()[len("```json") :].lstrip("\n")
-            if cleaned.endswith("```"):
-                cleaned = cleaned[:-3].rstrip()
-            response = cleaned
-        else:
-            response = raw_response.strip()
         end_time = time.time()
-        # print(response)
         print(f"Log to markdown took {end_time - start_time} seconds")
-        return response
+        return res.content[-1]["input"]
 
     async def _convert_to_html(self, markdown_content: str) -> str:
         start_time = time.time()
@@ -102,21 +135,15 @@ class ReportGenerator:
         print(f"Convert to html took {end_time - start_time} seconds")
         return response.content[0]["text"]
 
-    async def generate_report(self) -> Tuple[str, str]:
-        markdown_response = await self._log_to_markdown()
-        # responseFormat: {
-        #     "is_brief_response": True,
-        #     "brief_response": brief_response_content,
-        #     "report_content": detailed_report_content
-        # }
-        try:
-            markdown_content = json.loads(markdown_response)
-        except json.JSONDecodeError as e:
-            print(f"Error parsing JSON response: {e}")
-            print(f"Response content: {markdown_response}")
-            raise
+    async def generate_report(self) -> Tuple[str, str, str]:
+        """
+        responseFormat: {
+            "is_brief_response": True,
+            "brief_response": brief_response_content,
+            "report_content": detailed_report_content
+        }
+        """
+        markdown_content = await self._log_to_markdown()
 
         if (
             str(markdown_content.get("is_brief_response", False)).lower()
@@ -124,12 +151,19 @@ class ReportGenerator:
         ):
             # During brief response mode,
             # directly return the brief response to the user.
-            return markdown_content["brief_response"], ""
+            return markdown_content.get("brief_response", ""), "", ""
         else:
             # In detailed report mode,
             # convert the detailed report to HTML and return it to the user;
             # if a brief summary of the report is needed,
             # it can be obtained through markdown_content["brief_response"].
-            return markdown_content[
-                "brief_response"
-            ], await self._convert_to_html(markdown_content["report_content"])
+            html_content = ""
+            if os.getenv("ENABLE_HTML_REPORT", "ON").lower() != "off":
+                html_content = await self._convert_to_html(
+                    markdown_content.get("report_content", ""),
+                )
+            return (
+                markdown_content.get("brief_response", ""),
+                markdown_content.get("report_content", ""),
+                html_content,
+            )
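With the structured model in place, `generate_report` now returns a (brief, markdown, html) triple, and the HTML conversion can be switched off via `ENABLE_HTML_REPORT`. A hedged sketch of a caller (the surrounding object names are assumptions):

```python
import os

os.environ["ENABLE_HTML_REPORT"] = "off"   # skip the HTML-conversion model call

async def render(report_generator):
    # generate_report() now yields three values instead of two
    brief, markdown, html = await report_generator.generate_report()
    print(brief)          # summary, or the full answer in brief mode
    if markdown:
        with open("report.md", "w", encoding="utf-8") as f:
            f.write(markdown)
    if html:              # empty string when HTML rendering is disabled
        with open("report.html", "w", encoding="utf-8") as f:
            f.write(html)
```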
@@ -5,9 +5,9 @@ import json
 from typing import Union
 from agentscope.message import Msg
 from tenacity import retry, stop_after_attempt, wait_fixed
-from .ds_config import PROMPT_DS_BASE_PATH
 
-MODEL_MAX_RETRIES = 50
+from alias.agent.utils.constants import MODEL_MAX_RETRIES
+from .ds_config import PROMPT_DS_BASE_PATH
 
 
 def get_prompt_from_file(
@@ -36,10 +36,16 @@ async def model_call_with_retry(
     tool_json_schemas=None,
     tool_choice=None,
     msg_name="model_call",
+    structured_model=None,
 ) -> Msg:
     prompt = await formatter.format(msgs=msgs)
 
-    res = await model(prompt, tools=tool_json_schemas, tool_choice=tool_choice)
+    res = await model(
+        prompt,
+        tools=tool_json_schemas,
+        tool_choice=tool_choice,
+        structured_model=structured_model,
+    )
 
     if model.stream:
         msg = Msg(msg_name, [], "assistant")
@@ -40,6 +40,7 @@ class SessionEntity:
     query: str
     upload_files: List = []
     is_chat: bool = False
+    data_config: List | None = None
     use_long_term_memory_service: bool = False
 
     def __init__(
@@ -51,6 +52,7 @@ class SessionEntity:
             "bi",
             "finance",
         ] = "general",
+        data_config: List | None = None,
         use_long_term_memory_service: bool = False,
     ):
         self.user_id: uuid.UUID = uuid.UUID(
@@ -62,6 +64,7 @@ class SessionEntity:
         self.conversation_id: uuid.UUID = uuid.uuid4()
         self.session_id: uuid.UUID = uuid.uuid4()
         self.chat_mode = chat_mode
+        self.data_config = data_config
         self.use_long_term_memory_service = use_long_term_memory_service
 
     def ids(self):
@@ -79,6 +82,7 @@ class MockSessionService:
     def __init__(
         self,
         runtime_model: Any = None,
+        data_config: List | None = None,
         use_long_term_memory_service: bool = False,
     ):
         self.session_id = "mock_session"
@@ -86,6 +90,7 @@ class MockSessionService:
         self.messages = []
         self.plan = MockPlan()
         self.session_entity = SessionEntity(
+            data_config=data_config,
            use_long_term_memory_service=use_long_term_memory_service,
         )
         logger.info(
@@ -215,6 +220,72 @@ class MockSessionService:
         self.messages.append(db_message)
         return db_message
 
+    async def append_to_latest_message(
+        self,
+        content_to_append: str,
+        role_filter: Optional[str] = None,
+    ) -> Optional[MockMessage]:
+        """
+        Append content to the most recent message.
+
+        Args:
+            content_to_append: Content to append to the message
+            role_filter: Optional role filter (e.g., 'user', 'assistant')
+
+        Returns:
+            Updated MockMessage or None if no message found
+        """
+        if not self.messages:
+            logger.warning("No messages to append to")
+            return None
+
+        # Find the most recent message (optionally filtered by role)
+        target_message = None
+        for msg in reversed(self.messages):
+            if role_filter is None or msg.message.get("role") == role_filter:
+                target_message = msg
+                break
+
+        if target_message is None:
+            logger.warning(f"No message found with role={role_filter}")
+            return None
+
+        # Append content
+        current_content = target_message.message.get("content", "")
+        if isinstance(current_content, str):
+            target_message.message["content"] = (
+                current_content + content_to_append
+            )
+        elif isinstance(current_content, list):
+            # Handle multi-modal content (list of content blocks)
+            target_message.message["content"].append(
+                {
+                    "type": "text",
+                    "text": content_to_append,
+                },
+            )
+        else:
+            logger.error(f"Unsupported content type: {type(current_content)}")
+            return None
+
+        # Update timestamp
+        target_message.update_time = datetime.now(timezone.utc).isoformat()
+
+        # Optional: Log to file
+        if hasattr(self, "log_storage_path"):
+            content_log = (
+                "=" * 10
+                + "\n"
+                + f"APPEND to Role: {target_message.message.get('role')}\n"
+                + f"Appended: {content_to_append}\n"
+                + "=" * 10
+                + "\n"
+            )
+            with open(self.log_storage_path, "a") as file:
+                file.write(content_log)
+
+        return target_message
+
     async def get_messages(self) -> List[MockMessage]:
         logger.log("SEND_MSG", "Get all messages")
         return self.messages
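`append_to_latest_message` lets the data-source layer attach its profiling summary to the message the user just sent. A small sketch of the call pattern (the message shape mirrors the mock above; treat it as illustrative):

```python
# Sketch: appending a data-source description to the latest user message.
messages = [{"role": "user", "content": "Analyze incident_records.csv"}]

def append_to_latest(messages, text, role_filter=None):
    for msg in reversed(messages):
        if role_filter is None or msg["role"] == role_filter:
            msg["content"] += text
            return msg
    return None

append_to_latest(
    messages,
    "\n\nData sources:\n- /workspace/incident_records.csv (CSV, 12 columns)",
    role_filter="user",
)
print(messages[0]["content"])
```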
@@ -1,9 +1,11 @@
 # -*- coding: utf-8 -*-
 # pylint: disable=W0612,E0611,C2801
 
 import os
-import traceback
 from datetime import datetime
 import asyncio
+import traceback
 from typing import Literal
 
 from loguru import logger
@@ -17,27 +19,33 @@ from alias.agent.agents (
     BrowserAgent,
     DeepResearchAgent,
     MetaPlanner,
-    DataScienceAgent,
-    init_ds_toolkit,
     init_dr_toolkit,
 )
 
 from alias.agent.agents.meta_planner_utils._worker_manager import share_tools
 from alias.agent.mock import MockSessionService as SessionService
 from alias.agent.tools import AliasToolkit
 
 from alias.agent.utils.constants import (
     BROWSER_AGENT_DESCRIPTION,
     DEFAULT_DEEP_RESEARCH_AGENT_NAME,
     DEEPRESEARCH_AGENT_DESCRIPTION,
     DS_AGENT_DESCRIPTION,
 )
-from alias.agent.tools.add_tools import add_tools
-from alias.agent.agents.ds_agent_utils import (
-    add_ds_specific_tool,
+from alias.agent.utils.prepare_data_source import (
+    add_data_source_tools,
+    prepare_data_sources,
 )
+from alias.agent.tools.add_tools import add_tools
 from alias.agent.memory.longterm_memory import AliasLongTermMemory
 from alias.server.clients.memory_client import MemoryClient
+from alias.agent.agents._data_science_agent import (
+    DataScienceAgent,
+    init_ds_toolkit,
+)
+
+from alias.agent.utils.llm_call_manager import (
+    LLMCallManager,
+)
 
 MODEL_FORMATTER_MAPPING = {
     "qwen3-max": [
@@ -104,9 +112,28 @@ async def arun_meta_planner(
     # Init deep research toolkit
     deep_research_toolkit = init_dr_toolkit(worker_full_toolkit)
 
-    # Init BI agent toolkit
+    # Init data science agent toolkit
     ds_toolkit = init_ds_toolkit(worker_full_toolkit)
 
+    # Initialize data source manager
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    data_manager = await prepare_data_sources(
+        session_service=session_service,
+        sandbox=sandbox,
+        llm_call_manager=llm_call_manager,
+    )
+    add_data_source_tools(
+        data_manager,
+        worker_full_toolkit,
+        browser_toolkit,
+        deep_research_toolkit,
+        ds_toolkit,
+    )
+
     try:
         model, formatter = MODEL_FORMATTER_MAPPING[MODEL_CONFIG_NAME]
         browser_agent = BrowserAgent(
@@ -175,13 +202,15 @@ async def arun_meta_planner(
             description=DEEPRESEARCH_AGENT_DESCRIPTION,
             worker_type="built-in",
         )
-        # === add BI agent ===
+        # === add data science agent ===
         ds_agent = DataScienceAgent(
             name="Data_Science_Agent",
             model=model,
             formatter=formatter,
             memory=InMemoryMemory(),
             toolkit=ds_toolkit,
+            data_manager=data_manager,
+            sys_prompt=data_manager.get_data_skills(),
             max_iters=30,
             session_service=session_service,
         )
@@ -219,6 +248,19 @@ async def arun_deepresearch_agent(
         "run_shell_command",
     ]
     share_tools(global_toolkit, worker_toolkit, test_tool_list)
 
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    await prepare_data_sources(
+        session_service,
+        sandbox,
+        worker_toolkit,
+        llm_call_manager,
+    )
+
     worker_agent = DeepResearchAgent(
         name="Deep_Research_Agent",
         model=model,
@@ -285,6 +327,18 @@ async def arun_finance_agent(
         active=True,
     )
 
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    await prepare_data_sources(
+        session_service,
+        sandbox,
+        worker_toolkit,
+        llm_call_manager,
+    )
+
     worker_agent = DeepResearchAgent(
         name="Deep_Research_Agent",
         model=model,
@@ -326,17 +380,21 @@ async def arun_datascience_agent(
     session_service: SessionService,  # type: ignore[valid-type]
     sandbox: Sandbox = None,
 ):
-    global_toolkit = AliasToolkit(sandbox, add_all=True)
-    # await add_tools(global_toolkit)
-    worker_toolkit = AliasToolkit(sandbox)
     model, formatter = MODEL_FORMATTER_MAPPING[MODEL_CONFIG_NAME]
-    test_tool_list = [
-        "write_file",
-        "run_ipython_cell",
-        "run_shell_command",
-    ]
-    share_tools(global_toolkit, worker_toolkit, test_tool_list)
-    add_ds_specific_tool(worker_toolkit)
+
+    global_toolkit = AliasToolkit(sandbox, add_all=True)
+    worker_toolkit = init_ds_toolkit(global_toolkit)
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    data_manager = await prepare_data_sources(
+        session_service=session_service,
+        sandbox=sandbox,
+        binded_toolkit=worker_toolkit,
+        llm_call_manager=llm_call_manager,
+    )
+
     try:
         worker_agent = DataScienceAgent(
@@ -345,6 +403,8 @@ async def arun_datascience_agent(
             formatter=formatter,
             memory=InMemoryMemory(),
             toolkit=worker_toolkit,
+            data_manager=data_manager,
+            sys_prompt=data_manager.get_data_skills(),
             max_iters=30,
             session_service=session_service,
         )
@@ -360,6 +420,7 @@ async def arun_datascience_agent(
     finally:
         try:
             await global_toolkit.close_mcp_clients()
+            await worker_toolkit.close_mcp_clients()
         except (RuntimeError, asyncio.CancelledError) as e:
             # Event loop might be closed during shutdown
             if "Event loop is closed" in str(e) or isinstance(
@@ -386,6 +447,18 @@ async def arun_browseruse_agent(
         add_all=True,
         is_browser_toolkit=True,
     )
+    llm_call_manager = LLMCallManager(
+        base_model_name=MODEL_CONFIG_NAME,
+        vl_model_name=VL_MODEL_NAME,
+        model_formatter_mapping=MODEL_FORMATTER_MAPPING,
+    )
+    await prepare_data_sources(
+        session_service,
+        sandbox,
+        browser_toolkit,
+        llm_call_manager,
+    )
+
     logger.info("Init browser toolkit")
     try:
         browser_agent = BrowserAgent(
@@ -5,6 +5,7 @@ import json
 import os
 import tarfile
 from pathlib import Path
+import shlex
 from typing import Optional
 
 from loguru import logger
@@ -172,7 +173,7 @@ def get_workspace_file(
     )
     tool_result = sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"base64 -i {file_path}"},
+        arguments={"command": f"base64 -i {shlex.quote(file_path)}"},
     )
     return tool_result["content"][0]["text"]
 
@@ -194,7 +195,7 @@ def create_or_edit_workspace_file(
     }
     sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"touch {file_path}"},
+        arguments={"command": f"touch {shlex.quote(file_path)}"},
     )
     fill_result = sandbox.call_tool(
         "write_file",
@@ -222,7 +223,7 @@ def create_workspace_directory(
     }
     tool_result = sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"mkdir -p {directory_path}"},
+        arguments={"command": f"mkdir -p {shlex.quote(directory_path)}"},
     )
     return tool_result
 
@@ -246,7 +247,7 @@ def delete_workspace_file(
     }
     tool_result = sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"rm -rf {file_path}"},
+        arguments={"command": f"rm -rf {shlex.quote(file_path)}"},
     )
     return tool_result
 
@@ -280,7 +281,7 @@ def download_workspace_file_from_oss(
     tool_result = sandbox.call_tool(
         "run_shell_command",
         arguments={
-            "command": f"wget -O {to_path} {oss_url}",
+            "command": f"wget -O {shlex.quote(to_path)} {oss_url}",
         },
     )
     print(f"{tool_result}")
@@ -306,7 +307,7 @@ def delete_workspace_directory(
     }
     tool_result = sandbox.call_tool(
         "run_shell_command",
-        arguments={"command": f"rm -rf {directory_path}"},
+        arguments={"command": f"rm -rf {shlex.quote(directory_path)}"},
     )
     return tool_result
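All shell commands built from file paths are now routed through `shlex.quote`, which neutralizes spaces and shell metacharacters in user-supplied names. A quick demonstration of the difference:

```python
import shlex

path = "report; rm -rf /tmp/x.csv"          # hostile-looking filename
print(f"base64 -i {path}")                   # unquoted: injects a second command
print(f"base64 -i {shlex.quote(path)}")      # quoted: one safe argument
# base64 -i report; rm -rf /tmp/x.csv
# base64 -i 'report; rm -rf /tmp/x.csv'
```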
@@ -1,131 +1,8 @@
 # -*- coding: utf-8 -*-
-import json
-import os.path
-import uuid
-import textwrap
-
-from agentscope.tool import ToolResponse
-from agentscope.message import ToolUseBlock, TextBlock
-
-from alias.agent.utils.constants import TMP_FILE_DIR
-from alias.agent.tools.sandbox_util import (
-    create_or_edit_workspace_file,
-    create_workspace_directory,
-)
+from alias.agent.tools.toolkit_hooks.text_post_hook import TextPostHook
 
 
-class LongTextPostHook:
+class LongTextPostHook(TextPostHook):
     def __init__(self, sandbox):
-        self.sandbox = sandbox
+        super().__init__(sandbox, budget=8194 * 10, auto_save=False)
-
-    def truncate_and_save_response(  # pylint: disable=R1710
-        self,
-        tool_use: ToolUseBlock,  # pylint: disable=W0613
-        tool_response: ToolResponse,
-    ) -> ToolResponse:
-        """Post-process tool responses to prevent content overflow.
-
-        This function ensures that tool responses don't exceed a predefined
-        budget to prevent overwhelming the model with too much information.
-        It truncates text content while preserving the structure of
-        the response.
-
-        Args:
-            tool_use: The tool use block that triggered the response (unused).
-            tool_response: The tool response to potentially truncate.
-
-        Note:
-            The budget is set to approximately 80K tokens
-            (8194 * 10 characters) to ensure responses remain
-            manageable for the language model.
-        """
-        # Set budget to prevent overwhelming the model with too much content
-        budget = 8194 * 10  # Approximately 80K tokens of content
-        append_hint = "\n\n[Content is too long and truncated....]"
-
-        new_tool_response = ToolResponse(
-            id=tool_response.id,
-            stream=tool_response.stream,
-            is_last=tool_response.is_last,
-            is_interrupted=tool_response.is_interrupted,
-            content=[],
-        )
-        if isinstance(tool_response.content, list):
-            save_text_block = None
-            for _i, block in enumerate(tool_response.content):
-                if block["type"] == "text":
-                    text = block["text"]
-                    text_len = len(text)
-
-                    # If this block exceeds remaining budget, truncate it
-                    if text_len > budget:
-                        # Calculate truncation threshold
-                        # (80% of proportional budget)
-                        threshold = int(budget * 0.85)
-                        # save the original response
-                        tmp_file_name_prefix = tool_use.get("name", "")
-                        save_text_block = self._save_tmp_file(
-                            tmp_file_name_prefix,
-                            tool_response.content,
-                        )
-                        new_tool_response.append = (
-                            text[:threshold] + append_hint
-                        )
-                        new_tool_response.content.append(
-                            TextBlock(
-                                type="text",
-                                text=text[:threshold] + append_hint,
-                            ),
-                        )
-                    else:
-                        new_tool_response.content.append(block)
-                    budget -= text_len
-            if budget <= 0 and save_text_block:
-                new_tool_response.content.append(save_text_block)
-            return new_tool_response
-        elif isinstance(tool_response.content, str):
-            text_len = len(tool_response.content)
-            text = tool_response.content
-            if text_len > budget:
-                tmp_file_name_prefix = tool_use.get("name", "")
-                save_text_block = self._save_tmp_file(
-                    tmp_file_name_prefix,
-                    tool_response.content,
-                )
-                # Calculate truncation threshold (80% of proportional budget)
-                threshold = int(budget / text_len * len(text) * 0.8)
-                tool_response.content = text[:threshold] + append_hint
-                tool_response.content = [
-                    TextBlock(type="text", text=tool_response.content),
-                    save_text_block,
-                ]
-                return tool_response
-        return tool_response
-
-    def _save_tmp_file(self, save_file_name_prefix: str, content: list | str):
-        create_workspace_directory(self.sandbox, TMP_FILE_DIR)
-        save_file_name = (
-            save_file_name_prefix
-            + "-"
-            + str(
-                uuid.uuid4().hex[:8],
-            )
-        )
-        file_path = os.path.join(TMP_FILE_DIR, save_file_name)
-        json_str = json.dumps(content, ensure_ascii=False, indent=2)
-        wrapped = "\\n".join(
-            [textwrap.fill(line, width=500) for line in json_str.split("\\n")],
-        )
-        create_or_edit_workspace_file(
-            self.sandbox,
-            file_path,
-            wrapped,
-        )
-        return TextBlock(
-            type="text",
-            text=f"Dump the complete long file at {file_path}. "
-            "Don't try to read the complete file directly. "
-            "Use `grep -C 10 'YOUR_PATTERN' {file_path}` or "
-            "other bash command to extract "
-            "useful information.",
-        )
alias/src/alias/agent/tools/toolkit_hooks/text_post_hook.py (new file, 175 lines)
@@ -0,0 +1,175 @@
+# -*- coding: utf-8 -*-
+import json
+import os.path
+import uuid
+import textwrap
+
+from agentscope.tool import ToolResponse
+from agentscope.message import ToolUseBlock, TextBlock
+
+from alias.agent.utils.constants import TMP_FILE_DIR
+from alias.agent.tools.sandbox_util import (
+    create_or_edit_workspace_file,
+    create_workspace_directory,
+)
+
+
+class TextPostHook:
+    def __init__(self, sandbox, budget=8194 * 10, auto_save=False):
+        """
+        Args:
+            sandbox: The sandbox environment for file operations.
+            budget: Maximum character count before truncation
+                (default: 81,940). Approximately 20K tokens for English
+                text or 160K tokens for Chinese text. Adjust based on
+                your model's context window.
+            auto_save: Whether to save complete content to file when truncated.
+                - False: Save only after being truncated (default)
+                - True: Save complete content to file
+        """
+        self.sandbox = sandbox
+        self.auto_save = auto_save
+        self.budget = budget
+
+    def truncate_and_save_response(  # pylint: disable=R1710
+        self,
+        tool_use: ToolUseBlock,  # pylint: disable=W0613
+        tool_response: ToolResponse,
+    ) -> ToolResponse:
+        """Post-process tool responses to prevent content overflow.
+
+        This function ensures that tool responses don't exceed a predefined
+        budget to prevent overwhelming the model with too much information.
+        It truncates text content while preserving the structure of the
+        response, and optionally saves the complete content to a file based on
+        the auto_save setting.
+
+        Args:
+            tool_use: The tool use block that triggered the response (unused).
+            tool_response: The tool response to potentially truncate.
+        """
+
+        budget = self.budget
+        append_hint = "\n\n[Content is too long and truncated....]"
+
+        new_tool_response = ToolResponse(
+            id=tool_response.id,
+            stream=tool_response.stream,
+            is_last=tool_response.is_last,
+            is_interrupted=tool_response.is_interrupted,
+            content=[],
+        )
+        if isinstance(tool_response.content, list):
+            save_text_block = None
+            is_truncated = False
+
+            for _i, block in enumerate(tool_response.content):
+                if block["type"] == "text":
+                    text = block["text"]
+                    text_len = len(text)
+
+                    # If this block exceeds remaining budget, truncate it
+                    if text_len > budget:
+                        is_truncated = True
+
+                        # Calculate truncation threshold
+                        # (80% of proportional budget)
+                        threshold = int(budget * 0.85)
+                        new_tool_response.content.append(
+                            TextBlock(
+                                type="text",
+                                text=text[:threshold] + append_hint,
+                            ),
+                        )
+                    else:
+                        new_tool_response.content.append(block)
+
+                    budget -= text_len
+                    if budget <= 0:
+                        break
+
+            # Save file if auto_save=True or content was truncated
+            if self.auto_save or is_truncated:
+                tmp_file_name_prefix = tool_use.get("name", "")
+                save_text_block = self._save_tmp_file(
+                    tmp_file_name_prefix,
+                    tool_response.content,
+                    is_truncated=is_truncated,
+                )
+                new_tool_response.content.append(save_text_block)
+
+            return new_tool_response
+
+        elif isinstance(tool_response.content, str):
+            text_len = len(tool_response.content)
+            text = tool_response.content
+            is_truncated = text_len > budget
+
+            # Save file if auto_save=True or content was truncated
+            if self.auto_save or is_truncated:
+                tmp_file_name_prefix = tool_use.get("name", "")
+                save_text_block = self._save_tmp_file(
+                    tmp_file_name_prefix,
+                    tool_response.content,
+                    is_truncated=is_truncated,
+                )
+
+                if is_truncated:
+                    # Calculate truncation threshold (80% of budget)
+                    threshold = int(budget / text_len * len(text) * 0.8)
+                    tool_response.content = [
+                        TextBlock(
+                            type="text",
+                            text=text[:threshold] + append_hint,
+                        ),
+                        save_text_block,
+                    ]
+                else:
+                    tool_response.content = [
+                        TextBlock(type="text", text=text),
+                        save_text_block,
+                    ]
+
+            return tool_response
+
+        return tool_response
+
+    def _save_tmp_file(
+        self,
+        save_file_name_prefix: str,
+        content: list | str,
+        is_truncated: bool,
+    ):
+        create_workspace_directory(self.sandbox, TMP_FILE_DIR)
+        save_file_name = (
+            save_file_name_prefix
+            + "-"
+            + str(
+                uuid.uuid4().hex[:8],
+            )
+        )
+        file_path = os.path.join(TMP_FILE_DIR, save_file_name)
+        json_str = json.dumps(content, ensure_ascii=False, indent=2)
+        wrapped = "\\n".join(
+            [textwrap.fill(line, width=500) for line in json_str.split("\\n")],
+        )
+        create_or_edit_workspace_file(
+            self.sandbox,
+            file_path,
+            wrapped,
+        )
+        return (
+            TextBlock(
+                type="text",
+                text=f"Dump the complete long file at {file_path}. "
+                "Don't try to read the complete file directly. "
+                "Use `grep -C 10 'YOUR_PATTERN' {file_path}` or "
+                "other bash command to extract "
+                "useful information.",
+            )
+            if is_truncated
+            else TextBlock(
+                type="text",
+                text=f"Results dumped at {file_path}. ",
+            )
+        )
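`TextPostHook` is the generalized truncate-and-dump hook: responses over `budget` characters are cut at 85% of the budget, and the full payload is written to a workspace file the agent can `grep`. A hedged sketch of how it might be attached (the toolkit wiring is an assumption based on the hook registration earlier in this commit):

```python
# Sketch: wiring the hook into a toolkit (names assumed, not from the commit).
hook = TextPostHook(sandbox, budget=8194 * 10, auto_save=True)

for tool_name, tool in toolkit.tools.items():
    if tool_name.startswith("run_ipython_cell"):
        tool.postprocess_func = hook.truncate_and_save_response
```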
alias/src/alias/agent/utils/llm_call_manager.py (new file, 129 lines)
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+import asyncio
+from typing import Any, Dict, Literal, Type
+from tenacity import retry, stop_after_attempt, wait_fixed
+from pydantic import BaseModel
+
+from agentscope.message import Msg
+from agentscope.model import DashScopeChatModel
+from agentscope.formatter import DashScopeChatFormatter
+
+from alias.agent.utils.constants import MODEL_MAX_RETRIES
+
+
+@retry(
+    stop=stop_after_attempt(MODEL_MAX_RETRIES),
+    wait=wait_fixed(5),
+    reraise=True,
+    # before_sleep=_print_exc_on_retry
+)
+async def model_call_with_retry(
+    model: DashScopeChatModel = None,
+    formatter: DashScopeChatFormatter = None,
+    messages: list[dict[str, Any]] = None,
+    tool_json_schemas: list[dict] | None = None,
+    tool_choice: Literal["auto", "none", "required"] | str | None = None,
+    structured_model: Type[BaseModel] | None = None,
+    msg_name: str = "model_call",
+    **kwargs: Any,
+) -> Msg:
+    """
+    Make a model call with retry mechanism.
+    This function formats the messages and calls the model with retry logic
+    to handle potential failures during the API call.
+
+    Args:
+        model: The DashScope chat model to use for inference
+        formatter: Formatter to prepare messages for the model
+        msg_name: Name for the returned message object
+        see DashScopeChatModel's docstring for more details
+
+    Returns:
+        Message object containing the model response
+
+    Raises:
+        Exception: If all retry attempts fail
+    """
+    format_messages = await formatter.format(msgs=messages)
+
+    res = await model(
+        messages=format_messages,
+        tools=tool_json_schemas,
+        tool_choice=tool_choice,
+        structured_model=structured_model,
+        kwargs=kwargs,
+    )
+    if model.stream:
+        msg = Msg(msg_name, [], "assistant")
+        async for content_chunk in res:
+            msg.content = content_chunk.content
+            # Add a tiny sleep to yield the last message object in the
+            # message queue
+            await asyncio.sleep(0.001)
+    else:
+        msg = Msg(msg_name, list(res.content), "assistant")
+    return msg
+
+
+class LLMCallManager:
+    """Manager class for handling LLM calls with different models."""
+
+    def __init__(
+        self,
+        base_model_name: str,
+        vl_model_name: str,
+        model_formatter_mapping: Dict[str, Any],
+    ):
+        """
+        Initialize the LLM call manager.
+
+        Args:
+            base_model_name: Name of the base language model
+            vl_model_name: Name of the vision-language model
+            model_formatter_mapping: Mapping of names to model/formatter pairs
+        """
+        self.base_model_name = base_model_name
+        self.vl_model_name = vl_model_name
+        self.model_formatter_mapping = model_formatter_mapping
+
+    def get_base_model_name(self) -> str:
+        """Get the name of the base language model."""
+        return self.base_model_name
+
+    def get_vl_model_name(self) -> str:
+        """Get the name of the vision-language model."""
+        return self.vl_model_name
+
+    async def __call__(
+        self,
+        model_name: str,
+        messages: list[dict[str, Any]],
+        tools: list[dict] | None = None,
+        tool_choice: Literal["auto", "none", "required"] | str | None = None,
+        structured_model: Type[BaseModel] | None = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Make an asynchronous call to the specified LLM.
+
+        Args:
+            model_name: Name of the model to use for the call
+            messages: List of message dictionaries to send to the model
+            see DashScopeChatModel's docstring for more details
+
+        Returns:
+            String response from the LLM
+        """
+        model, formatter = self.model_formatter_mapping[model_name]
+        raw_response = await model_call_with_retry(
+            model=model,
+            formatter=formatter,
+            messages=messages,
+            tool_json_schemas=tools,
+            tool_choice=tool_choice,
+            structured_model=structured_model,
+            msg_name="model_call",
+            kwargs=kwargs,
+        )
+        response = raw_response.content[0]["text"]
+        return response
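`LLMCallManager` hides the model/formatter lookup behind a single callable, so callers like the data-source profiler only pass a model name and messages. A minimal usage sketch (the VL model name and the `Msg` construction are assumptions, mirroring how `Msg` is built above):

```python
from agentscope.message import Msg

# Assumes MODEL_FORMATTER_MAPPING maps names to (model, formatter) pairs.
manager = LLMCallManager(
    base_model_name="qwen3-max",
    vl_model_name="qwen-vl-max",          # hypothetical VL model name
    model_formatter_mapping=MODEL_FORMATTER_MAPPING,
)

async def profile_table(schema_text: str) -> str:
    msgs = [Msg("user", f"Summarize this table schema:\n{schema_text}", "user")]
    # __call__ returns the text of the first content block
    return await manager(manager.get_base_model_name(), msgs)
```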
alias/src/alias/agent/utils/prepare_data_source.py (new file, 92 lines)
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+import os
+
+from agentscope_runtime.sandbox.box.sandbox import Sandbox
+
+from alias.agent.agents.data_source.data_source import DataSourceManager
+from alias.agent.tools import AliasToolkit, share_tools
+from alias.agent.utils.llm_call_manager import (
+    LLMCallManager,
+)
+
+if os.getenv("TEST_MODE") not in ["local", "runtime-test"]:
+    from alias.server.services.session_service import (
+        SessionService,
+    )
+else:
+    from alias.agent.mock import MockSessionService as SessionService
+
+
+async def prepare_data_sources(
+    session_service: SessionService,
+    sandbox: Sandbox,
+    binded_toolkit: AliasToolkit = None,
+    llm_call_manager: LLMCallManager = None,
+):
+    data_manager = await build_data_manager(
+        session_service,
+        sandbox,
+        llm_call_manager,
+    )
+    if len(data_manager):
+        await add_user_data_message(session_service, data_manager)
+
+    if binded_toolkit:
+        add_data_source_tools(data_manager, binded_toolkit)
+
+    return data_manager
+
+
+async def build_data_manager(
+    session_service: SessionService,
+    sandbox: Sandbox,
+    llm_call_manager: LLMCallManager,
+):
+    data_manager = DataSourceManager(sandbox, llm_call_manager)
+    if (
+        hasattr(session_service.session_entity, "data_config")
+        and session_service.session_entity.data_config
+    ):
+        data_configs = session_service.session_entity.data_config
+        for config in data_configs:
+            data_manager.add_data_source(config)
+
+    await data_manager.prepare_data_sources()
+    return data_manager
+
+
+def add_data_source_tools(
+    data_manager: DataSourceManager,
+    *toolkits: AliasToolkit,
+):
+    data_source_toolkit = data_manager.toolkit
+    tool_names = list(data_source_toolkit.tools.keys())
+    for toolkit in toolkits:
+        share_tools(data_source_toolkit, toolkit, tool_names)
+
+
+async def add_user_data_message(
+    session_service: SessionService,
+    data_manager: DataSourceManager,
+):
+    await session_service.append_to_latest_message(
+        "\n\n" + data_manager.get_all_data_sources_desc(),
+    )
+
+
+def get_data_source_config_from_file(config_file: str):
+    """Load and parse data source configuration from a JSON file."""
+    import json
+
+    # Validate file existence upfront
+    if not os.path.isfile(config_file):
+        raise FileNotFoundError(f"Configuration file not found: {config_file}")
+
+    try:
+        with open(config_file, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f"Invalid JSON in data source configuration file `'{config_file}'`\
+: {e.msg} at line {e.lineno}",
+        ) from e
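`get_data_source_config_from_file` expects a JSON list of data-source configs, which `build_data_manager` feeds to `DataSourceManager.add_data_source` one by one. A hedged round-trip sketch (the config keys are assumptions; the real schema is defined by `DataSourceManager`):

```python
import json, tempfile

# Hypothetical config entries for the --dataconfig flag.
config = [
    {"type": "file", "path": "./docs/data/incident_records.csv"},
    {"type": "postgresql", "dsn": "postgresql://user:pw@localhost:5432/mydb"},
]
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(config, f)

loaded = get_data_source_config_from_file(f.name)
assert loaded == config
```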
@@ -9,12 +9,10 @@ for the Alias agent application.
 """
 import argparse
 import asyncio
-import os
 import signal
 import sys
 import traceback
 import webbrowser
-from typing import Optional
 
 from loguru import logger
 
@@ -28,7 +26,10 @@ from alias.agent.run (
     arun_datascience_agent,
     arun_finance_agent,
 )
-from alias.agent.tools.sandbox_util import copy_local_file_to_workspace
+from alias.agent.utils.prepare_data_source import (
+    get_data_source_config_from_file,
+)
 
 from alias.runtime.alias_sandbox.alias_sandbox import AliasSandbox
@@ -60,7 +61,7 @@ def _safe_sigint_handler(signum, frame): # pylint: disable=W0613
 async def run_agent_task(
     user_msg: str,
     mode: str = "general",
-    files: Optional[list[str]] = None,
+    user_data_config: list | None = None,
     use_long_term_memory_service: bool = False,
 ) -> None:
     """
@@ -69,7 +70,8 @@ async def run_agent_task(
     Args:
         user_msg: The user's task/query
        mode: Agent mode ('general', 'dr', 'ds', 'browser', 'finance')
-        files: List of local file paths to upload to sandbox workspace
+        user_data: (Config for) User data sources, used for data science \
+            agent only
         use_long_term_memory_service: Enable long-term memory service.
     """
     global _original_sigint_handler
@@ -84,6 +86,7 @@ async def run_agent_task(
 
     # Initialize session
     session = MockSessionService(
+        data_config=user_data_config,
         use_long_term_memory_service=use_long_term_memory_service,
     )
@@ -118,35 +121,6 @@ async def run_agent_task(
     )
     logger.info(f"Sandbox desktop URL: {sandbox.desktop_url}")
     webbrowser.open(sandbox.desktop_url)
-    # Upload files to sandbox if provided
-    if files:
-        target_paths = []
-        logger.info(
-            f"Uploading {len(files)} file(s) to sandbox workspace...",
-        )
-        for file_path in files:
-            if not os.path.exists(file_path):
-                logger.error(f"File not found: {file_path}")
-                continue
-
-            # Get the filename and construct target path in workspace
-            filename = os.path.basename(file_path)
-            target_path = f"/workspace/{filename}"
-
-            logger.info(f"Uploading {file_path} to {target_path}")
-            result = copy_local_file_to_workspace(
-                sandbox=sandbox,
-                local_path=file_path,
-                target_path=target_path,
-            )
-
-            if result.get("isError"):
-                raise ValueError(f"Failed to upload {file_path}: {result}")
-            logger.info(f"Successfully uploaded to {result}")
-
-            target_paths.append(result.get("content", [])[0].get("text"))
-
-        user_msg += "\n\nUser uploaded files:\n" + "\n".join(target_paths)
-
     # Create initial user message (regardless of whether files were uploaded)
     initial_user_message = UserMessage(
@@ -301,12 +275,28 @@ def main():
     )
 
     run_parser.add_argument(
+        "--datasource",
         "--files",
-        "-f",
-        type=str,
+        "-d",
+        dest="datasource",
         nargs="+",
-        help="Local file paths to upload to sandbox workspace "
-        "for agent to use (e.g., --files file1.txt file2.csv)",
+        help=(
+            "Data sources for the agent to use. Multiple formats supported:\n"
+            "  • Local files: ./data.txt, /absolute/path/file.json\n"
+            "  • Databases: postgresql://localhost/db, sqlite:///data.db\n"
+            "Example: "
+            "  --datasource file.txt postgresql://localhost/db\n"
+            "  --files file.txt"
+        ),
+    )
+
+    # If you need to deeply customize your data source
+    # (e.g., specify an MCP server), use this parameter to
+    # provide a configuration file
+    run_parser.add_argument(
+        "--dataconfig",
+        "-dc",
+        help=("Path to the data source configuration file"),
     )
 
     run_parser.add_argument(
@@ -333,11 +323,29 @@ def main():
     # Handle commands
     if args.command == "run":
         try:
+            user_data = None
+            data_endpoint = (
+                args.datasource if hasattr(args, "datasource") else None
+            )
+            if data_endpoint:
+                # List of endpoints to data sources
+                user_data = (
+                    data_endpoint
+                    if isinstance(data_endpoint, list)
+                    else [data_endpoint]
+                )
+            else:
+                # Configuration file
+                if hasattr(args, "dataconfig") and args.dataconfig:
+                    user_data = get_data_source_config_from_file(
+                        args.dataconfig,
+                    )
+
             asyncio.run(
                 run_agent_task(
                     user_msg=args.task,
                     mode=args.mode,
-                    files=args.files if hasattr(args, "files") else None,
+                    user_data_config=user_data,
                     use_long_term_memory_service=(
                         args.use_long_term_memory
                         if hasattr(args, "use_long_term_memory")
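In short, `--datasource` endpoints take precedence over `--dataconfig`; the config file is only consulted when no endpoints are given. Condensed into a standalone sketch of the same resolution order:

```python
def resolve_user_data(datasource, dataconfig):
    """Mirror of the CLI resolution order above (sketch)."""
    if datasource:                        # --datasource / --files endpoints
        return datasource if isinstance(datasource, list) else [datasource]
    if dataconfig:                        # --dataconfig JSON file
        return get_data_source_config_from_file(dataconfig)
    return None
```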