Compare commits
32 Commits
a41cd705b4
...
codex/work
| Author | SHA1 | Date | |
|---|---|---|---|
| 4295293a21 | |||
| 4aa69650e8 | |||
| 5c08c1865c | |||
| 6ecc224427 | |||
| 9bcc4221a4 | |||
| fecf8a9466 | |||
| 86eb8c37a9 | |||
| 1f9063edad | |||
| 7e7a58769a | |||
| 16bb3c4211 | |||
| da6d642aaa | |||
| 8d6c3c5647 | |||
| 6413edf8c9 | |||
| c5eaf2b5ad | |||
| 032c37538f | |||
| 456748b01e | |||
| 609b509446 | |||
| 38102d0805 | |||
| 3448667b79 | |||
| 0f1bc2bb39 | |||
| 06a23c32a4 | |||
| 5b925fbe02 | |||
| 4b5ac86b83 | |||
| f4a2b7f3af | |||
| 2dcda63394 | |||
| a3f767126f | |||
| 9ec4a8702d | |||
| 3174734f26 | |||
| 59b44545d0 | |||
| 2daf5717ba | |||
| 1f5ee3698e | |||
| 3a5558b576 |
41
.env.example
Normal file
41
.env.example
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copy this file to `.env` for local development.
|
||||
# Keep `.env` untracked and never paste real secrets into tracked files.
|
||||
|
||||
# ================== General Configuration | 通用配置 ==================
|
||||
TICKERS=AAPL,MSFT,GOOGL,AMZN,NVDA,META,TSLA,AMD,NFLX,AVGO,PLTR,COIN
|
||||
|
||||
# Financial Data API
|
||||
# At least `FINANCIAL_DATASETS_API_KEY` is required when using `FIN_DATA_SOURCE=financial_datasets`.
|
||||
# `FINNHUB_API_KEY` is recommended for `FIN_DATA_SOURCE=finnhub` and required for live mode.
|
||||
FIN_DATA_SOURCE=finnhub
|
||||
ENABLED_DATA_SOURCES=financial_datasets,finnhub,yfinance,local_csv
|
||||
FINANCIAL_DATASETS_API_KEY=
|
||||
FINNHUB_API_KEY=
|
||||
POLYGON_API_KEY=
|
||||
MARKET_DB_PATH=
|
||||
|
||||
# Model API
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
EXPLAIN_ENRICH_USE_LLM=false
|
||||
EXPLAIN_ENRICH_MODEL_PROVIDER=
|
||||
EXPLAIN_ENRICH_MODEL_NAME=
|
||||
EXPLAIN_RANGE_USE_LLM=
|
||||
|
||||
# Memory module
|
||||
MEMORY_API_KEY=
|
||||
|
||||
# ================== Agent-Specific Model Configuration | Agent特定模型配置 ==================
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
|
||||
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# ================== Advanced Configuration | 高阶配置 ==================
|
||||
MAX_COMM_CYCLES=2
|
||||
MARGIN_REQUIREMENT=0.5
|
||||
DATA_START_DATE=2022-01-01
|
||||
AUTO_UPDATE_DATA=true
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -51,13 +51,15 @@ node_modules
|
||||
outputs/
|
||||
/production/
|
||||
/smoke_test/
|
||||
/smoke_live_mock/
|
||||
|
||||
# Local tooling state
|
||||
/.omc/
|
||||
.omc/
|
||||
/.pydeps/
|
||||
/referance/
|
||||
|
||||
# Run outputs
|
||||
/runs/
|
||||
|
||||
# Data files
|
||||
backend/data/ret_data/
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"lastScanned": 1773304964541,
|
||||
"projectRoot": "/Users/cillin/workspeace/agentscope-samples/evotraders",
|
||||
"lastScanned": 1774515151036,
|
||||
"projectRoot": "/Users/cillin/workspeace/evotraders",
|
||||
"techStack": {
|
||||
"languages": [
|
||||
{
|
||||
@@ -40,7 +40,8 @@
|
||||
"isMonorepo": false,
|
||||
"workspaces": [],
|
||||
"mainDirectories": [
|
||||
"docs"
|
||||
"docs",
|
||||
"scripts"
|
||||
],
|
||||
"gitBranches": {
|
||||
"defaultBranch": "main",
|
||||
@@ -52,26 +53,54 @@
|
||||
"backend": {
|
||||
"path": "backend",
|
||||
"purpose": null,
|
||||
"fileCount": 3,
|
||||
"lastAccessed": 1773304964533,
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1774515151025,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"cli.py",
|
||||
"gateway_server.py",
|
||||
"main.py"
|
||||
]
|
||||
},
|
||||
"backtest": {
|
||||
"path": "backtest",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774515151026,
|
||||
"keyFiles": []
|
||||
},
|
||||
"data": {
|
||||
"path": "data",
|
||||
"purpose": "Data files",
|
||||
"fileCount": 3,
|
||||
"lastAccessed": 1774515151027,
|
||||
"keyFiles": [
|
||||
"market_research.db",
|
||||
"market_research.db-shm",
|
||||
"market_research.db-wal"
|
||||
]
|
||||
},
|
||||
"deploy": {
|
||||
"path": "deploy",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774515151027,
|
||||
"keyFiles": []
|
||||
},
|
||||
"docs": {
|
||||
"path": "docs",
|
||||
"purpose": "Documentation",
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1773304964533,
|
||||
"keyFiles": []
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1774515151027,
|
||||
"keyFiles": [
|
||||
"compat-removal-plan.md"
|
||||
]
|
||||
},
|
||||
"evotraders.egg-info": {
|
||||
"path": "evotraders.egg-info",
|
||||
"purpose": null,
|
||||
"fileCount": 6,
|
||||
"lastAccessed": 1773304964534,
|
||||
"lastAccessed": 1774515151028,
|
||||
"keyFiles": [
|
||||
"PKG-INFO",
|
||||
"SOURCES.txt",
|
||||
@@ -83,8 +112,8 @@
|
||||
"frontend": {
|
||||
"path": "frontend",
|
||||
"purpose": null,
|
||||
"fileCount": 12,
|
||||
"lastAccessed": 1773304964535,
|
||||
"fileCount": 13,
|
||||
"lastAccessed": 1774515151028,
|
||||
"keyFiles": [
|
||||
"README.md",
|
||||
"components.json",
|
||||
@@ -93,239 +122,414 @@
|
||||
"index.css"
|
||||
]
|
||||
},
|
||||
"live": {
|
||||
"path": "live",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774515151028,
|
||||
"keyFiles": []
|
||||
},
|
||||
"reference": {
|
||||
"path": "reference",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774515151028,
|
||||
"keyFiles": []
|
||||
},
|
||||
"runs": {
|
||||
"path": "runs",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774515151029,
|
||||
"keyFiles": []
|
||||
},
|
||||
"scripts": {
|
||||
"path": "scripts",
|
||||
"purpose": "Build/utility scripts",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1774515151030,
|
||||
"keyFiles": [
|
||||
"run_prod.sh"
|
||||
]
|
||||
},
|
||||
"services": {
|
||||
"path": "services",
|
||||
"purpose": "Business logic services",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1774515151030,
|
||||
"keyFiles": [
|
||||
"README.md"
|
||||
]
|
||||
},
|
||||
"shared": {
|
||||
"path": "shared",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774515151030,
|
||||
"keyFiles": []
|
||||
},
|
||||
"backend/api": {
|
||||
"path": "backend/api",
|
||||
"purpose": "API routes",
|
||||
"fileCount": 5,
|
||||
"lastAccessed": 1774515151030,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"agents.py",
|
||||
"guard.py"
|
||||
]
|
||||
},
|
||||
"backend/config": {
|
||||
"path": "backend/config",
|
||||
"purpose": "Configuration files",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964535,
|
||||
"fileCount": 6,
|
||||
"lastAccessed": 1774515151030,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"constants.py",
|
||||
"data_config.py"
|
||||
"agent_profiles.yaml",
|
||||
"bootstrap_config.py"
|
||||
]
|
||||
},
|
||||
"backend/data": {
|
||||
"path": "backend/data",
|
||||
"purpose": "Data files",
|
||||
"fileCount": 7,
|
||||
"lastAccessed": 1773304964536,
|
||||
"fileCount": 12,
|
||||
"lastAccessed": 1774515151031,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"cache.py",
|
||||
"historical_price_manager.py"
|
||||
]
|
||||
},
|
||||
"backend/services": {
|
||||
"path": "backend/services",
|
||||
"purpose": "Business logic services",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964536,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"gateway.py",
|
||||
"market.py"
|
||||
]
|
||||
},
|
||||
"backend/tests": {
|
||||
"path": "backend/tests",
|
||||
"purpose": "Test files",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964536,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"test_agents.py",
|
||||
"test_market_service.py"
|
||||
]
|
||||
},
|
||||
"docs/assets": {
|
||||
"path": "docs/assets",
|
||||
"purpose": "Static assets",
|
||||
"fileCount": 5,
|
||||
"lastAccessed": 1773304964536,
|
||||
"lastAccessed": 1774515151031,
|
||||
"keyFiles": [
|
||||
"dashboard.jpg",
|
||||
"evotraders_demo.gif",
|
||||
"evotraders_logo.jpg"
|
||||
]
|
||||
},
|
||||
"frontend/public": {
|
||||
"path": "frontend/public",
|
||||
"purpose": "Public files",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1773304964538,
|
||||
"frontend/dist": {
|
||||
"path": "frontend/dist",
|
||||
"purpose": "Distribution/build output",
|
||||
"fileCount": 2,
|
||||
"lastAccessed": 1774515151031,
|
||||
"keyFiles": [
|
||||
"index.html",
|
||||
"trading_logo.png"
|
||||
]
|
||||
},
|
||||
"frontend/node_modules": {
|
||||
"path": "frontend/node_modules",
|
||||
"purpose": "Dependencies",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1774515151036,
|
||||
"keyFiles": []
|
||||
}
|
||||
},
|
||||
"hotPaths": [
|
||||
{
|
||||
"path": "frontend/src/components/StatisticsView.jsx",
|
||||
"accessCount": 22,
|
||||
"lastAccessed": 1773310044545,
|
||||
"path": "frontend/src/hooks/useWebSocketConnection.js",
|
||||
"accessCount": 100,
|
||||
"lastAccessed": 1774550862686,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AgentCard.jsx",
|
||||
"accessCount": 17,
|
||||
"lastAccessed": 1773309995177,
|
||||
"path": "backend/services/gateway.py",
|
||||
"accessCount": 98,
|
||||
"lastAccessed": 1774550272354,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/services/gateway_openclaw_handlers.py",
|
||||
"accessCount": 91,
|
||||
"lastAccessed": 1774550256325,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/api/openclaw.py",
|
||||
"accessCount": 48,
|
||||
"lastAccessed": 1774545375555,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks/useOpenClawPanel.js",
|
||||
"accessCount": 42,
|
||||
"lastAccessed": 1774550688926,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "shared/client/openclaw_client.py",
|
||||
"accessCount": 39,
|
||||
"lastAccessed": 1774545484770,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src",
|
||||
"accessCount": 35,
|
||||
"lastAccessed": 1774550715529,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "reference/openclaw/src",
|
||||
"accessCount": 33,
|
||||
"lastAccessed": 1774550840611,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/services/openclaw_cli.py",
|
||||
"accessCount": 31,
|
||||
"lastAccessed": 1774545484887,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/TraderView.jsx",
|
||||
"accessCount": 23,
|
||||
"lastAccessed": 1774543366574,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "shared/models/openclaw.py",
|
||||
"accessCount": 22,
|
||||
"lastAccessed": 1774545419541,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/openclawStore.js",
|
||||
"accessCount": 20,
|
||||
"lastAccessed": 1774550319533,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/App.jsx",
|
||||
"accessCount": 12,
|
||||
"lastAccessed": 1773309849392,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AgentFeed.jsx",
|
||||
"accessCount": 12,
|
||||
"lastAccessed": 1773309960022,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": ".env",
|
||||
"accessCount": 7,
|
||||
"lastAccessed": 1773308950505,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/RoomView.jsx",
|
||||
"accessCount": 7,
|
||||
"lastAccessed": 1773309864236,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools/analysis_tools.py",
|
||||
"accessCount": 5,
|
||||
"lastAccessed": 1773312271446,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/Header.jsx",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773309827069,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AboutModal.jsx",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773310093371,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/analyst/personas.yaml",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312049213,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/analyst/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312049696,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/portfolio_manager/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312050326,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/risk_manager/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312050782,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/config/constants.js",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773309824671,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/RulesView.jsx",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773310061939,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773312200721,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/services/gateway.py",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1773312232905,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "README.md",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305013217,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "README_zh.md",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305013274,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "env.template",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305019965,
|
||||
"accessCount": 18,
|
||||
"lastAccessed": 1774544542524,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/services/websocket.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309324302,
|
||||
"accessCount": 18,
|
||||
"lastAccessed": 1774549669596,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/config/data_config.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309324414,
|
||||
"path": "start-dev.sh",
|
||||
"accessCount": 15,
|
||||
"lastAccessed": 1774548224246,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/cli.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309336899,
|
||||
"path": "frontend/src/components/RuntimeView.jsx",
|
||||
"accessCount": 14,
|
||||
"lastAccessed": 1774518525793,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AppShell.jsx",
|
||||
"accessCount": 13,
|
||||
"lastAccessed": 1774533781725,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/main.py",
|
||||
"accessCount": 13,
|
||||
"lastAccessed": 1774548236340,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/portfolio_manager.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311956562,
|
||||
"path": "backend/apps/openclaw_service.py",
|
||||
"accessCount": 10,
|
||||
"lastAccessed": 1774547900186,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/risk_manager.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311956760,
|
||||
"path": "frontend/src/components/OpenClawStatusPanel.jsx",
|
||||
"accessCount": 8,
|
||||
"lastAccessed": 1774533622019,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/analyst.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311963222,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773312289643,
|
||||
"path": "reference/openclaw/src/commands",
|
||||
"accessCount": 7,
|
||||
"lastAccessed": 1774530402019,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools/data_tools.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773312293851,
|
||||
"path": "frontend/src/config/constants.js",
|
||||
"accessCount": 7,
|
||||
"lastAccessed": 1774544689658,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "",
|
||||
"accessCount": 6,
|
||||
"lastAccessed": 1774550700047,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/services",
|
||||
"accessCount": 5,
|
||||
"lastAccessed": 1774550692490,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/uiStore.js",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1774533747700,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/styles/GlobalStyles.jsx",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1774533753657,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/agentStore.js",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1774517930592,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "reference/openclaw/src/cli/skills-cli.ts",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1774527140107,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "reference/openclaw/src/commands/agents.commands.list.ts",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1774533427441,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/runtimeStore.js",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774517930660,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks/useAgentWorkspacePanel.js",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774518021290,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/services/runtimeApi.js",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774518025465,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "reference/openclaw/src/commands/agents.commands.delete.ts",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774530389553,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "reference/openclaw/src/commands/agents.commands.add.ts",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774530389605,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/api/__init__.py",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774542416191,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/vite.config.js",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774544772960,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/index.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774515811752,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/marketStore.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774515838923,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/portfolioStore.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774515839687,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/index.css",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774515988837,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/App.css",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774515998423,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/package.json",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774516005569,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks/useAgentDataRequests.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774517930219,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/services/gateway_admin_handlers.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774517937966,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/apps/agent_service.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774517946208,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774517946260,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks/useFeedProcessor.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774517952115,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "reference/openclaw/src/commands/models/set.ts",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774526963526,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "reference/openclaw/src/commands/models/list.ts",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774526963632,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "reference/openclaw/src/cli/skills-cli.format.ts",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774526963684,
|
||||
"type": "file"
|
||||
}
|
||||
],
|
||||
"userDirectives": []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"timestamp": "2026-03-12T20:33:59.497Z",
|
||||
"timestamp": "2026-03-27T04:53:52.906Z",
|
||||
"backgroundTasks": [],
|
||||
"sessionStartTimestamp": "2026-03-12T14:19:33.615Z",
|
||||
"sessionId": "73b0d597-0141-4873-9d0e-2b60e4e0635e"
|
||||
"sessionStartTimestamp": "2026-03-27T04:53:21.944Z",
|
||||
"sessionId": "cbb9004e-771b-4e82-95d4-cea6d9753642"
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
{"session_id":"73b0d597-0141-4873-9d0e-2b60e4e0635e","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-agentscope-samples-evotraders/73b0d597-0141-4873-9d0e-2b60e4e0635e.jsonl","cwd":"/Users/cillin/workspeace/agentscope-samples/evotraders","model":{"id":"kimi-for-coding","display_name":"kimi-for-coding"},"workspace":{"current_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","project_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","added_dirs":["/Users/cillin/workspeace/agentscope-samples/EvoTraders","/Users/cillin/workspeace/agentscope-samples/evotraders"]},"version":"2.1.63","output_style":{"name":"default"},"cost":{"total_cost_usd":6.822239999999999,"total_duration_ms":42679588,"total_api_duration_ms":1223637,"total_lines_added":275,"total_lines_removed":240},"context_window":{"total_input_tokens":654274,"total_output_tokens":27014,"context_window_size":200000,"current_usage":{"input_tokens":48465,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0},"used_percentage":24,"remaining_percentage":76},"exceeds_200k_tokens":false}
|
||||
{"session_id":"cbb9004e-771b-4e82-95d4-cea6d9753642","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-evotraders/cbb9004e-771b-4e82-95d4-cea6d9753642.jsonl","cwd":"/Users/cillin/workspeace/evotraders","model":{"id":"MiniMax-M2.7-highspeed","display_name":"MiniMax-M2.7-highspeed"},"workspace":{"current_dir":"/Users/cillin/workspeace/evotraders","project_dir":"/Users/cillin/workspeace/evotraders","added_dirs":[]},"version":"2.1.78","output_style":{"name":"default"},"cost":{"total_cost_usd":0.660433,"total_duration_ms":168502,"total_api_duration_ms":37670,"total_lines_added":0,"total_lines_removed":0},"context_window":{"total_input_tokens":14416,"total_output_tokens":1705,"context_window_size":200000,"current_usage":{"input_tokens":461,"output_tokens":214,"cache_creation_input_tokens":0,"cache_read_input_tokens":53991},"used_percentage":27,"remaining_percentage":73},"exceeds_200k_tokens":false}
|
||||
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"lastSentAt": "2026-03-12T20:31:37.362Z"
|
||||
"lastSentAt": "2026-03-27T04:55:49.635Z"
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
{
|
||||
"agents": [
|
||||
{
|
||||
"agent_id": "a4090d26a45ac828d",
|
||||
"agent_type": "oh-my-claudecode:executor",
|
||||
"started_at": "2026-03-12T10:02:38.238Z",
|
||||
"parent_mode": "none",
|
||||
"status": "completed",
|
||||
"completed_at": "2026-03-12T10:10:59.192Z",
|
||||
"duration_ms": 500954
|
||||
},
|
||||
{
|
||||
"agent_id": "af87583ef76a4df30",
|
||||
"agent_type": "oh-my-claudecode:executor",
|
||||
"started_at": "2026-03-12T10:40:04.409Z",
|
||||
"parent_mode": "none",
|
||||
"status": "completed",
|
||||
"completed_at": "2026-03-12T10:41:17.387Z",
|
||||
"duration_ms": 72978
|
||||
}
|
||||
],
|
||||
"total_spawned": 2,
|
||||
"total_completed": 2,
|
||||
"total_failed": 0,
|
||||
"last_updated": "2026-03-12T10:41:17.490Z"
|
||||
}
|
||||
BIN
.playwright-mcp/page-2026-03-26T12-28-14-006Z.png
Normal file
BIN
.playwright-mcp/page-2026-03-26T12-28-14-006Z.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 144 KiB |
376
CLAUDE.md
Normal file
376
CLAUDE.md
Normal file
@@ -0,0 +1,376 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
本文件为 Claude Code (claude.ai/code) 在此代码库中工作时提供指导。
|
||||
|
||||
## 项目概述
|
||||
|
||||
大时代 是一个自进化多智能体交易系统,由 6 个 AI Agent(4 名分析师 + 投资经理 + 风控经理)协作完成交易决策。Agent 基于 AgentScope 框架构建,配合 ReMe 记忆系统实现持续学习。
|
||||
|
||||
## 常用命令
|
||||
|
||||
### Backend (Python)
|
||||
|
||||
```bash
|
||||
# 安装依赖
|
||||
uv pip install -e .
|
||||
|
||||
# 运行命令
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 # 回测模式
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory # 带记忆回测
|
||||
evotraders live # 实盘交易
|
||||
evotraders live -t 22:30 # 定时每日交易
|
||||
evotraders frontend # 启动可视化界面
|
||||
|
||||
# 开发服务器
|
||||
./start-dev.sh # 启动全部 4 个微服务 (agent, runtime, trading, news)
|
||||
|
||||
# Gateway WebSocket 服务器
|
||||
python backend/main.py --mode live --config-name live
|
||||
|
||||
# 单独启动微服务
|
||||
python -m uvicorn backend.apps.runtime_service:app --host 0.0.0.0 --port 8003 --reload
|
||||
python -m uvicorn backend.apps.agent_service:app --host 0.0.0.0 --port 8000 --reload
|
||||
python -m uvicorn backend.apps.trading_service:app --host 0.0.0.0 --port 8001 --reload
|
||||
python -m uvicorn backend.apps.news_service:app --host 0.0.0.0 --port 8002 --reload
|
||||
|
||||
# 测试
|
||||
pytest backend/tests # 运行全部测试
|
||||
pytest backend/tests/test_news_service_app.py -v # 运行单个测试
|
||||
```
|
||||
|
||||
### Frontend (React)
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm run dev # Vite 开发服务器 (http://localhost:5173)
|
||||
npm run build # 生产构建
|
||||
npm run lint # ESLint 检查
|
||||
npm run lint:fix # ESLint 自动修复
|
||||
npm run test # Vitest 单元测试
|
||||
```
|
||||
|
||||
## 架构概览
|
||||
|
||||
### 系统分层
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Frontend (React) │
|
||||
│ WebSocket ws://localhost:8765 连接 Gateway │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Gateway (backend/services/gateway.py) │
|
||||
│ WebSocket 服务器,编排 Pipeline,4 阶段启动 │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│ │ │ │
|
||||
▼ ▼ ▼ ▼
|
||||
┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────────┐
|
||||
│ Market │ │ Storage │ │ Pipeline │ │ Scheduler │
|
||||
│ Service │ │ Service │ │ │ │ │
|
||||
└────────────┘ └────────────┘ └────────────┘ └────────────┘
|
||||
│
|
||||
┌──────────────────────┼──────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌──────────┐ ┌──────────┐
|
||||
│ Analysts │ │ PM │ │ Risk │
|
||||
│ (4 个) │ │ │ │ Manager │
|
||||
└──────────┘ └──────────┘ └──────────┘
|
||||
```
|
||||
|
||||
### 微服务架构 (`backend/apps/`)
|
||||
|
||||
| 服务 | 端口 | 职责 |
|
||||
|------|------|------|
|
||||
| runtime_service | 8003 | 运行时配置、任务启动、Pipeline Runner |
|
||||
| agent_service | 8000 | Agent 生命周期、工作区管理 |
|
||||
| trading_service | 8001 | 市场数据、交易操作 |
|
||||
| news_service | 8002 | 新闻、新闻富化、解释功能 |
|
||||
|
||||
### Gateway 4 阶段启动 (`backend/services/gateway.py`)
|
||||
|
||||
1. **WebSocket Server** - 前端立即可连接
|
||||
2. **Market Service** - 价格数据开始推送
|
||||
3. **Market Status Monitor** - 市场状态监控
|
||||
4. **Scheduler** - 交易周期开始
|
||||
|
||||
### 运行时管理层 (`backend/runtime/`)
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `manager.py` | TradingRuntimeManager - 全局运行时管理器,agent 注册、会话、事件快照 |
|
||||
| `agent_runtime.py` | AgentRuntimeState - 单 agent 状态(status、last_session) |
|
||||
| `context.py` | TradingRunContext - 运行上下文 |
|
||||
| `session.py` | TradingSessionKey - 交易日会话键 |
|
||||
| `registry.py` | RuntimeRegistry - agent 状态注册表 |
|
||||
|
||||
快照持久化到 `runs/<run_id>/state/runtime_state.json`。
|
||||
|
||||
### Pipeline 执行 (`backend/core/`)
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `pipeline.py` | TradingPipeline - 核心编排器(分析→沟通→决策→执行→评估) |
|
||||
| `pipeline_runner.py` | REST API 触发的独立执行,5 阶段启动 |
|
||||
| `scheduler.py` | BacktestScheduler、Scheduler - 回测/实盘调度 |
|
||||
| `state_sync.py` | StateSync - 状态同步和广播 |
|
||||
|
||||
## 后端结构
|
||||
|
||||
```
|
||||
backend/
|
||||
├── agents/ # 多智能体实现
|
||||
│ ├── analyst.py # AnalystAgent 基类
|
||||
│ ├── portfolio_manager.py # PMAgent 投资经理
|
||||
│ ├── risk_manager.py # RiskAgent 风控经理
|
||||
│ ├── factory.py # Agent 实例工厂
|
||||
│ ├── toolkit_factory.py # 工具集工厂
|
||||
│ ├── skills_manager.py # 技能加载管理
|
||||
│ ├── workspace_manager.py # 工作区管理
|
||||
│ ├── skill_loader.py # 技能加载器
|
||||
│ ├── agent_workspace.py # Agent 工作区
|
||||
│ ├── prompt_loader.py # Prompt 加载器
|
||||
│ ├── prompt_factory.py # Prompt 工厂
|
||||
│ ├── skill_metadata.py # 技能元数据
|
||||
│ ├── registry.py # Agent 注册表
|
||||
│ ├── team_pipeline_config.py # 团队 Pipeline 配置
|
||||
│ ├── compat.py # 兼容性层
|
||||
│ ├── templates.py # 模板
|
||||
│ ├── workspace.py # 工作区
|
||||
│ ├── base/ # 核心类、Hooks
|
||||
│ │ ├── evo_agent.py # 基于 AgentScope 的核心实现
|
||||
│ │ └── hooks.py # 生命周期 Hooks
|
||||
│ └── prompts/ # Agent 提示词
|
||||
│ └── analyst/personas.yaml
|
||||
│
|
||||
├── apps/ # 微服务入口
|
||||
│ ├── runtime_service.py # 运行时服务(端口 8003)
|
||||
│ ├── agent_service.py # Agent 服务(端口 8000)
|
||||
│ ├── trading_service.py # 交易服务(端口 8001)
|
||||
│ ├── news_service.py # 新闻服务(端口 8002)
|
||||
│ └── cors.py
|
||||
│
|
||||
├── runtime/ # 运行时管理层
|
||||
│ ├── manager.py # TradingRuntimeManager
|
||||
│ ├── agent_runtime.py # AgentRuntimeState
|
||||
│ ├── context.py # TradingRunContext
|
||||
│ ├── session.py # TradingSessionKey
|
||||
│ └── registry.py # RuntimeRegistry
|
||||
│
|
||||
├── process/ # 进程监管层
|
||||
│ ├── supervisor.py # ProcessSupervisor
|
||||
│ ├── registry.py # RunRegistry
|
||||
│ └── models.py # ProcessRun、ProcessRunState
|
||||
│
|
||||
├── core/ # Pipeline 执行
|
||||
│ ├── pipeline.py # TradingPipeline(核心编排器)
|
||||
│ ├── pipeline_runner.py # 独立 Pipeline 执行
|
||||
│ ├── scheduler.py # 调度器
|
||||
│ └── state_sync.py # 状态同步
|
||||
│
|
||||
├── services/ # Gateway 和服务
|
||||
│ ├── gateway.py # WebSocket 网关
|
||||
│ ├── gateway_*.py # Gateway 子模块
|
||||
│ ├── market.py # 市场数据服务
|
||||
│ ├── storage.py # 存储服务
|
||||
│ ├── runtime_db.py # 运行时数据库
|
||||
│ └── research_db.py # 研究数据库
|
||||
│
|
||||
├── data/ # 市场数据处理
|
||||
│ ├── provider_router.py # 数据源路由
|
||||
│ ├── provider_utils.py # 数据源工具
|
||||
│ ├── market_store.py # 市场数据存储
|
||||
│ ├── market_ingest.py # 数据采集
|
||||
│ ├── cache.py # 缓存
|
||||
│ ├── schema.py # 数据 schema
|
||||
│ ├── historical_price_manager.py # 历史价格管理
|
||||
│ ├── polling_price_manager.py # 轮询价格管理
|
||||
│ ├── news_alignment.py # 新闻对齐
|
||||
│ ├── polygon_client.py # Polygon.io 客户端
|
||||
│ └── ret_data_updater.py # 离线数据更新
|
||||
│
|
||||
├── config/ # 配置
|
||||
│ ├── constants.py # Agent 配置、显示名称
|
||||
│ ├── bootstrap_config.py # 启动配置解析
|
||||
│ ├── env_config.py # 环境变量配置
|
||||
│ ├── data_config.py # 数据源配置
|
||||
│ └── agent_profiles.yaml # Agent Profile 配置
|
||||
│
|
||||
├── domains/ # 领域业务逻辑
|
||||
│ ├── news.py
|
||||
│ └── trading.py
|
||||
│
|
||||
├── llm/ # LLM 集成
|
||||
│ └── models.py # RetryChatModel、TokenRecordingModelWrapper
|
||||
│
|
||||
├── skills/ # 技能定义
|
||||
├── tools/ # 交易和分析工具
|
||||
├── enrich/ # LLM 响应富化
|
||||
├── explain/ # 交易决策解释
|
||||
├── utils/ # 工具函数
|
||||
│ ├── settlement.py # 结算协调器
|
||||
│ ├── trade_executor.py # 交易执行器
|
||||
│ ├── terminal_dashboard.py # 终端仪表板
|
||||
│ ├── analyst_tracker.py # 分析师追踪
|
||||
│ ├── baselines.py # 基准线
|
||||
│ ├── msg_adapter.py # 消息适配器
|
||||
│ └── progress.py # 进度追踪
|
||||
│
|
||||
├── api/ # FastAPI 端点
|
||||
│ └── runtime.py
|
||||
│
|
||||
└── main.py # 主入口点
|
||||
```
|
||||
|
||||
## 前端结构
|
||||
|
||||
```
|
||||
frontend/src/
|
||||
├── App.jsx # 主应用(LiveTradingApp)
|
||||
├── AppShell.jsx # App 外壳(布局、侧边栏)
|
||||
├── components/
|
||||
│ ├── RuntimeView.jsx # 交易运行时 UI
|
||||
│ ├── TraderView.jsx # 交易员界面
|
||||
│ ├── RoomView.jsx # 聊天室视图
|
||||
│ ├── StockExplainView.jsx # 股票解释视图
|
||||
│ ├── RuntimeSettingsPanel.jsx # 运行时设置面板
|
||||
│ ├── RuntimeLogsModal.jsx # 运行时日志弹窗
|
||||
│ ├── WatchlistPanel.jsx # 关注列表
|
||||
│ ├── PerformanceView.jsx # 绩效视图
|
||||
│ ├── StatisticsView.jsx # 统计视图
|
||||
│ ├── NetValueChart.jsx # 净值曲线图
|
||||
│ ├── AgentCard.jsx # Agent 卡片
|
||||
│ ├── AgentFeed.jsx # Agent 动态
|
||||
│ ├── Header.jsx # 头部
|
||||
│ ├── MarkdownModal.jsx # Markdown 弹窗
|
||||
│ ├── StockLogo.jsx # 股票 Logo
|
||||
│ └── explain/ # 解释组件
|
||||
│ ├── ExplainNewsSection.jsx
|
||||
│ ├── ExplainRangeSection.jsx
|
||||
│ ├── ExplainSimilarDaysSection.jsx
|
||||
│ ├── ExplainStorySection.jsx
|
||||
│ └── useExplainModel.js
|
||||
├── hooks/ # React Hooks
|
||||
│ ├── useWebSocketConnection.js # WebSocket 连接管理
|
||||
│ ├── useRuntimeControls.js # 运行时配置管理
|
||||
│ ├── useAgentDataRequests.js # Agent 数据请求
|
||||
│ ├── useStockDataRequests.js # 股票数据请求
|
||||
│ ├── useStockExplainData.js # 股票解释数据
|
||||
│ ├── useAgentWorkspacePanel.js # Agent 工作区面板
|
||||
│ ├── useWebsocketSessionSync.js # WebSocket 会话同步
|
||||
│ └── useFeedProcessor.js # Feed 事件处理
|
||||
├── store/ # Zustand 状态管理
|
||||
│ ├── runtimeStore.js # 连接状态、运行时配置
|
||||
│ ├── marketStore.js # 市场数据、股票价格
|
||||
│ ├── portfolioStore.js # 组合、持仓、交易
|
||||
│ ├── agentStore.js # Agent 技能、工作区
|
||||
│ └── uiStore.js # UI 状态、视图切换
|
||||
├── services/
|
||||
│ ├── websocket.js # WebSocket 客户端
|
||||
│ ├── runtimeApi.js # 运行时 API
|
||||
│ ├── runtimeControls.js # 运行时控制
|
||||
│ ├── newsApi.js # 新闻 API
|
||||
│ └── tradingApi.js # 交易 API
|
||||
├── utils/
|
||||
│ ├── formatters.js # 格式化工具
|
||||
│ └── modelIcons.js # 模型图标
|
||||
└── config/
|
||||
└── constants.js # Agent 定义、配置
|
||||
```
|
||||
|
||||
## Agent 系统
|
||||
|
||||
### 6 种 Agent 角色
|
||||
|
||||
| 角色 ID | 名称 | 职责 |
|
||||
|---------|------|------|
|
||||
| `fundamentals_analyst` | 基本面分析师 | 财务健康、盈利能力、成长质量 |
|
||||
| `technical_analyst` | 技术分析师 | 价格趋势、技术指标、动量分析 |
|
||||
| `sentiment_analyst` | 情绪分析师 | 市场情绪、新闻情绪、内幕交易 |
|
||||
| `valuation_analyst` | 估值分析师 | DCF、EV/EBITDA、intrinsic value |
|
||||
| `portfolio_manager` | 投资经理 | 决策执行、交易协调 |
|
||||
| `risk_manager` | 风控经理 | 实时价格/波动率监控、仓位限制 |
|
||||
|
||||
### 添加自定义分析师
|
||||
|
||||
1. `backend/agents/prompts/analyst/personas.yaml` 注册
|
||||
2. `backend/config/constants.py` 的 `ANALYST_TYPES` 字典添加
|
||||
3. `frontend/src/config/constants.js` 可选更新
|
||||
|
||||
### LLM 模型封装 (`backend/llm/models.py`)
|
||||
|
||||
- **RetryChatModel**: 自动重试瞬态 LLM 错误,指数退避
|
||||
- **TokenRecordingModelWrapper**: 追踪 token 消耗和成本
|
||||
|
||||
## 技能系统 (`backend/skills/`)
|
||||
|
||||
技能定义在 `SKILL.md`,包含 `instructions`、`triggers`、`parameters`、`available_tools`。
|
||||
|
||||
技能管理器支持 6 种作用域:builtin、customized、installed、active、disabled、local。
|
||||
|
||||
## 运行时数据布局
|
||||
|
||||
- `data/market_research.db` - 持久研究数据
|
||||
- `runs/<run_id>/` - 每次任务运行的状态
|
||||
- `runs/<run_id>/team_dashboard/*.json` - 仪表板导出层(非权威源)
|
||||
- `runs/<run_id>/state/runtime_state.json` - 运行时快照
|
||||
- 运行时 API 优先使用 `server_state.json` 和 `runtime.db`
|
||||
|
||||
```bash
|
||||
RUNS_RETENTION_COUNT=20 # 时间戳格式文件夹自动清理
|
||||
```
|
||||
|
||||
## 环境配置
|
||||
|
||||
### Backend (`env.template`)
|
||||
|
||||
```bash
|
||||
# 金融数据源(支持多源fallback)
|
||||
FIN_DATA_SOURCE=finnhub|financial_datasets|yfinance|local_csv
|
||||
ENABLED_DATA_SOURCES=financial_datasets,finnhub,yfinance,local_csv
|
||||
FINANCIAL_DATASETS_API_KEY= # 回测必需
|
||||
FINNHUB_API_KEY= # 实盘必需
|
||||
POLYGON_API_KEY= # Polygon市场库采集可选
|
||||
|
||||
# LLM 配置
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# Agent 特定模型
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
|
||||
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# ReMe 记忆系统
|
||||
MEMORY_API_KEY=
|
||||
MEMORY_MODEL_NAME=qwen3-max
|
||||
MEMORY_EMBEDDING_MODEL=text-embedding-v4
|
||||
|
||||
# 交易参数
|
||||
MAX_COMM_CYCLES=2
|
||||
MARGIN_REQUIREMENT=0.5
|
||||
DATA_START_DATE=2022-01-01
|
||||
AUTO_UPDATE_DATA=true
|
||||
```
|
||||
|
||||
### Frontend (`frontend/env.template`)
|
||||
|
||||
```bash
|
||||
VITE_WS_URL=ws://localhost:8765
|
||||
```
|
||||
|
||||
## 关键依赖
|
||||
|
||||
- **AgentScope** - 多智能体框架
|
||||
- **ReMe** - 持续学习记忆系统
|
||||
- **FastAPI** + **uvicorn** - 后端 API
|
||||
- **websockets** - 实时通信
|
||||
- **React 19** + **Vite** + **TailwindCSS** - 前端
|
||||
- **Zustand** - 状态管理
|
||||
415
README.md
415
README.md
@@ -1,36 +1,34 @@
|
||||
<p align="center">
|
||||
<img src="./docs/assets/evotraders_logo.jpg" width="45%">
|
||||
<img src="./docs/assets/bigtime_logo.jpg" width="45%">
|
||||
</p>
|
||||
|
||||
<h2 align="center">EvoTraders: A Self-Evolving Multi-Agent Trading System</h2>
|
||||
<h2 align="center">大时代:自进化多智能体交易系统</h2>
|
||||
|
||||
<p align="center">
|
||||
📌 <a href="http://trading.evoagents.cn">Visit us at EvoTraders website !</a>
|
||||
📌 <a href="http://trading.evoagents.cn">Visit the 大时代 website</a>
|
||||
</p>
|
||||
|
||||

|
||||

|
||||
|
||||
EvoTraders is an open-source financial trading agent framework that builds a trading system capable of continuous learning and evolution in real markets through multi-agent collaboration and memory systems.
|
||||
大时代 is an open-source financial trading agent framework that combines multi-agent collaboration, run-scoped workspaces, and memory to support both backtests and live trading workflows.
|
||||
|
||||
The repository name and CLI entrypoints still use `evotraders` for compatibility, but the product-facing branding now follows the 大时代 naming used by the reference branch.
|
||||
|
||||
---
|
||||
|
||||
## Core Features
|
||||
|
||||
**Multi-Agent Collaborative Trading**
|
||||
A team of 6 members, including 4 specialized analyst roles (fundamentals, technical, sentiment, valuation) + portfolio manager + risk management, collaborating to make decisions like a real trading team.
|
||||
**Multi-agent trading team**
|
||||
Six roles collaborate like a real desk: four specialist analysts (fundamentals, technical, sentiment, valuation), one portfolio manager, and one risk manager.
|
||||
|
||||
You can customize your Agents here: [Custom Configuration](#custom-configuration)
|
||||
**Continuous learning**
|
||||
Agents can persist long-term memory with ReMe, reflect after each cycle, and evolve their decision patterns over time.
|
||||
|
||||
**Continuous Learning and Evolution**
|
||||
Based on the ReMe memory framework, agents reflect and summarize after each trade, preserving experience across rounds, and forming unique investment methodologies.
|
||||
**Backtest and live modes**
|
||||
The same runtime model supports historical simulation and live execution with real-time market data.
|
||||
|
||||
Through this design, we hope that when AI Agents form a team and enter the real-time market, they will gradually develop their own trading styles and decision preferences, rather than one-time random inference.
|
||||
|
||||
**Real-Time Market Trading**
|
||||
Supports real-time market data integration, providing backtesting mode and live trading mode, allowing AI Agents to learn and make decisions in real market fluctuations.
|
||||
|
||||
**Visualized Trading Information**
|
||||
Observe agents' analysis processes, communication records, and decision evolution in real-time, with complete tracking of return curves and analyst performance.
|
||||
**Operator-facing UI**
|
||||
The frontend exposes the trading room, runtime controls, logs, approvals, agent workspaces, and explain/news views.
|
||||
|
||||
<p>
|
||||
<img src="docs/assets/performance.jpg" width="45%">
|
||||
@@ -39,198 +37,325 @@ Observe agents' analysis processes, communication records, and decision evolutio
|
||||
|
||||
---
|
||||
|
||||
## Current Architecture
|
||||
|
||||
The repository is currently in a transition from a modular monolith to split service surfaces. The split-service path is the default local development mode.
|
||||
|
||||
Current app surfaces:
|
||||
|
||||
- `backend.apps.agent_service` on `:8000`: control plane for workspaces, agents, skills, and guard/approval APIs
|
||||
- `backend.apps.trading_service` on `:8001`: read-only trading data APIs
|
||||
- `backend.apps.news_service` on `:8002`: read-only explain/news APIs
|
||||
- `backend.apps.runtime_service` on `:8003`: runtime lifecycle APIs
|
||||
- `backend.apps.openclaw_service` on `:8004`: read-only OpenClaw facade
|
||||
- WebSocket gateway on `:8765`: live event/feed channel for the frontend
|
||||
|
||||
The most important runtime path today is:
|
||||
|
||||
`frontend -> runtime_service/control APIs -> gateway/runtime manager -> market service + pipeline + storage`
|
||||
|
||||
Reference notes for the migration live in [services/README.md](./services/README.md).
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Installation
|
||||
### 1. Install
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/agentscope-ai/agentscope-samples
|
||||
cd agentscope-samples/EvoTraders
|
||||
# clone this repository, then:
|
||||
cd evotraders
|
||||
|
||||
# Install dependencies (Recommend uv!)
|
||||
# backend runtime dependencies
|
||||
uv pip install -r requirements.txt
|
||||
|
||||
# install package entrypoint in editable mode
|
||||
uv pip install -e .
|
||||
# optional: pip install -e .
|
||||
|
||||
# optional
|
||||
# uv pip install -e ".[dev]"
|
||||
# pip install -e .
|
||||
```
|
||||
|
||||
# Configure environment variables
|
||||
Frontend dependencies:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm ci
|
||||
cd ..
|
||||
```
|
||||
|
||||
Production deployment should prefer `requirements.txt` for backend and `npm ci` for frontend so the pulled environment matches the checked-in lockfiles and version pins.
|
||||
|
||||
### 2. Configure environment
|
||||
|
||||
```bash
|
||||
cp env.template .env
|
||||
# Edit .env file and add your API Keys. The following config are required:
|
||||
```
|
||||
|
||||
# finance data API: At minimum, FINANCIAL_DATASETS_API_KEY is required, corresponding to FIN_DATA_SOURCE=financial_datasets; It is recommended to add FINNHUB_API_KEY, corresponding to FIN_DATA_SOURCE=finnhub; If using live mode, FINNHUB_API_KEY must be added
|
||||
FIN_DATA_SOURCE = #finnhub or financial_datasets
|
||||
FINANCIAL_DATASETS_API_KEY= #Required
|
||||
FINNHUB_API_KEY= #Optional
|
||||
The root `env.template` is the canonical local template. A `.env.example` is also kept in the repo for reference.
|
||||
|
||||
# LLM API for Agents
|
||||
Minimum useful variables:
|
||||
|
||||
```bash
|
||||
# watchlist
|
||||
TICKERS=AAPL,MSFT,GOOGL,NVDA,TSLA,META,AMZN
|
||||
|
||||
# market data
|
||||
FIN_DATA_SOURCE=finnhub
|
||||
FINANCIAL_DATASETS_API_KEY=
|
||||
FINNHUB_API_KEY=
|
||||
POLYGON_API_KEY=
|
||||
|
||||
# agent model
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# LLM & embedding API for Memory
|
||||
# memory (optional unless --enable-memory is used)
|
||||
MEMORY_API_KEY=
|
||||
```
|
||||
|
||||
### Running
|
||||
Notes:
|
||||
|
||||
- `FINNHUB_API_KEY` is required for live mode.
|
||||
- `POLYGON_API_KEY` enables long-lived market-store ingestion and refresh helpers.
|
||||
- `MEMORY_API_KEY` is only required when long-term memory is enabled.
|
||||
|
||||
For a production-style local start flow, you can also use:
|
||||
|
||||
**Backtest Mode:**
|
||||
```bash
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory # Use Memory
|
||||
./start.sh
|
||||
```
|
||||
|
||||
If you do not have market data APIs and just want to try the backtest demo, download the offline data and unzip it into `backend/data`:
|
||||
### 3. Start the stack
|
||||
|
||||
Recommended local development flow:
|
||||
|
||||
```bash
|
||||
./start-dev.sh
|
||||
```
|
||||
|
||||
This starts:
|
||||
|
||||
- `agent_service` at `http://localhost:8000`
|
||||
- `trading_service` at `http://localhost:8001`
|
||||
- `news_service` at `http://localhost:8002`
|
||||
- `runtime_service` at `http://localhost:8003`
|
||||
- gateway WebSocket at `ws://localhost:8765`
|
||||
|
||||
Then start the frontend in another terminal:
|
||||
|
||||
```bash
|
||||
evotraders frontend
|
||||
```
|
||||
|
||||
Open `http://localhost:5173`.
|
||||
|
||||
You can also run services manually:
|
||||
|
||||
```bash
|
||||
python -m uvicorn backend.apps.agent_service:app --host 0.0.0.0 --port 8000 --reload
|
||||
python -m uvicorn backend.apps.trading_service:app --host 0.0.0.0 --port 8001 --reload
|
||||
python -m uvicorn backend.apps.news_service:app --host 0.0.0.0 --port 8002 --reload
|
||||
python -m uvicorn backend.apps.runtime_service:app --host 0.0.0.0 --port 8003 --reload
|
||||
python -m backend.main --mode live --host 0.0.0.0 --port 8765
|
||||
```
|
||||
|
||||
### 4. Run backtest or live mode from CLI
|
||||
|
||||
Backtest:
|
||||
|
||||
```bash
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory
|
||||
evotraders backtest --config-name smoke_fullstack --start 2025-11-01 --end 2025-12-01
|
||||
```
|
||||
|
||||
Live:
|
||||
|
||||
```bash
|
||||
evotraders live
|
||||
evotraders live --enable-memory
|
||||
evotraders live --schedule-mode intraday --interval-minutes 60
|
||||
evotraders live --trigger-time 22:30
|
||||
```
|
||||
|
||||
Help:
|
||||
|
||||
```bash
|
||||
evotraders --help
|
||||
evotraders backtest --help
|
||||
evotraders live --help
|
||||
evotraders frontend --help
|
||||
```
|
||||
|
||||
### Offline backtest data
|
||||
|
||||
If you want a quick backtest demo without external market APIs, download the offline bundle and unzip it into `backend/data`:
|
||||
|
||||
```bash
|
||||
wget "https://agentscope-open.oss-cn-beijing.aliyuncs.com/ret_data.zip"
|
||||
unzip ret_data.zip -d backend/data
|
||||
```
|
||||
The zip includes basic stock price data so you can run the backtest demo out of the box.
|
||||
|
||||
**Live Trading:**
|
||||
```bash
|
||||
evotraders live # Run immediately (default)
|
||||
evotraders live --enable-memory # Use memory
|
||||
evotraders live --mock # Mock mode (testing)
|
||||
evotraders live -t 22:30 # Run daily at 22:30 local time (auto-converts to NYSE timezone)
|
||||
```
|
||||
|
||||
**Get Help:**
|
||||
```bash
|
||||
evotraders --help # View global CLI help
|
||||
evotraders backtest --help # View backtest mode parameters
|
||||
evotraders live --help # View live/mock run parameters
|
||||
```
|
||||
|
||||
**Launch Visualization Interface:**
|
||||
```bash
|
||||
# Ensure npm is installed, otherwise install it:
|
||||
# npm install
|
||||
evotraders frontend # Default connects to port 8765, you can modify the address in ./frontend/env.local to change the port number
|
||||
```
|
||||
|
||||
Visit `http://localhost:5173/` to view the trading room, select a date and click Run/Replay to observe the decision-making process.
|
||||
|
||||
---
|
||||
|
||||
## System Architecture
|
||||
## Runtime Data Layout
|
||||
|
||||

|
||||
- Long-lived research data lives in `data/market_research.db`
|
||||
- Each run writes run-scoped state under `runs/<run_id>/`
|
||||
- `runs/<run_id>/BOOTSTRAP.md` stores run-specific bootstrap values and prompt body
|
||||
- `runs/<run_id>/state/runtime_state.json` stores runtime snapshot state
|
||||
- `runs/<run_id>/team_dashboard/*.json` is a compatibility/export layer for dashboard consumers, not the primary runtime source of truth
|
||||
|
||||
### Agent Design
|
||||
Optional retention control:
|
||||
|
||||
**Analyst Team:**
|
||||
- **Fundamentals Analyst**: Financial health, profitability, growth quality
|
||||
- **Technical Analyst**: Price trends, technical indicators, momentum analysis
|
||||
- **Sentiment Analyst**: Market sentiment, news sentiment, insider trading
|
||||
- **Valuation Analyst**: DCF, residual income, EV/EBITDA
|
||||
|
||||
**Decision Layer:**
|
||||
- **Portfolio Manager**: Integrates analysis signals from analysts, executes communication strategies, combines analyst and team historical performance, recent investment memories, and long-term investment experience to make final decisions
|
||||
- **Risk Management**: Real-time price and volatility monitoring, position limits, multi-layer risk warnings
|
||||
|
||||
### Decision Process
|
||||
|
||||
```
|
||||
Real-time Market Data → Independent Analysis → Intelligent Communication (1v1/1vN/NvN) → Decision Execution → Performance Evaluation → Learning and Evolution (Memory Update)
|
||||
```bash
|
||||
RUNS_RETENTION_COUNT=20
|
||||
```
|
||||
|
||||
Each trading day goes through five stages:
|
||||
|
||||
1. **Analysis Stage**: Each agent independently analyzes based on their respective tools and historical experience
|
||||
2. **Communication Stage**: Exchange views through private chats, notifications, meetings, etc.
|
||||
3. **Decision Stage**: Portfolio manager makes comprehensive judgments and provides final trades
|
||||
4. **Evaluation Stage**
|
||||
- **Performance Charts**: Track portfolio return curves vs. benchmark strategies (equal-weighted, market-cap weighted, momentum). Used to evaluate overall strategy effectiveness.
|
||||
|
||||
- **Analyst Rankings**: Click on avatars in the Trading Room to view analyst performance (win rate, bull/bear market win rate). Used to understand which analysts provide the most valuable insights.
|
||||
|
||||
- **Statistics**: Detailed position and trading history. Used for in-depth analysis of position management and execution quality.
|
||||
|
||||
5. **Review Stage**: Agents reflect on decisions and summarize experiences based on actual returns of the day, and store them in the ReMe memory framework for continuous improvement
|
||||
Only timestamped run folders like `YYYYMMDD_HHMMSS` are pruned automatically. Named runs such as `live`, `smoke_fullstack`, or `reload_demo_*` are preserved.
|
||||
|
||||
---
|
||||
|
||||
### Module Support
|
||||
## Frontend Service Routing
|
||||
|
||||
- **Agent Framework**: [AgentScope](https://github.com/agentscope-ai/agentscope)
|
||||
- **Memory System**: [ReMe](https://github.com/agentscope-ai/reme)
|
||||
- **LLM Support**: OpenAI, DeepSeek, Qwen, Moonshot, Zhipu AI, etc.
|
||||
The frontend always uses the control plane and runtime APIs, and can optionally call split services directly for read-only data.
|
||||
|
||||
Useful frontend env vars:
|
||||
|
||||
```bash
|
||||
VITE_CONTROL_API_BASE_URL=http://localhost:8000/api
|
||||
VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime
|
||||
VITE_NEWS_SERVICE_URL=http://localhost:8002
|
||||
VITE_TRADING_SERVICE_URL=http://localhost:8001
|
||||
VITE_WS_URL=ws://localhost:8765
|
||||
```
|
||||
|
||||
If these are not set, the frontend falls back to its local defaults and compatibility paths where available.
|
||||
|
||||
---
|
||||
|
||||
## Decision Flow
|
||||
|
||||
```text
|
||||
Market data -> independent analyst work -> team communication -> portfolio decision ->
|
||||
risk review -> execution/settlement -> reflection/memory update
|
||||
```
|
||||
|
||||
The runtime manager also tracks:
|
||||
|
||||
- agent registration and status
|
||||
- pending approvals
|
||||
- run events
|
||||
- current session key
|
||||
|
||||
---
|
||||
|
||||
## Custom Configuration
|
||||
|
||||
### Custom Analyst Roles
|
||||
### Add or change analyst roles
|
||||
|
||||
1. Register role information in [./backend/agents/prompts/analyst/personas.yaml](./backend/agents/prompts/analyst/personas.yaml), for example:
|
||||
1. Define the analyst persona in [backend/agents/prompts/analyst/personas.yaml](./backend/agents/prompts/analyst/personas.yaml)
|
||||
2. Register the role in [backend/config/constants.py](./backend/config/constants.py)
|
||||
3. Optionally add/update the frontend seat metadata in [frontend/src/config/constants.js](./frontend/src/config/constants.js)
|
||||
|
||||
Example persona entry:
|
||||
|
||||
```yaml
|
||||
comprehensive_analyst:
|
||||
name: "Comprehensive Analyst"
|
||||
focus:
|
||||
- ...
|
||||
preferred_tools: # Flexibly select based on situation
|
||||
- multi-factor synthesis
|
||||
preferred_tools:
|
||||
- get_stock_price
|
||||
- get_company_financials
|
||||
description: |
|
||||
As a comprehensive analyst ...
|
||||
A generalist analyst that combines multiple signals.
|
||||
```
|
||||
|
||||
2. Add role definition in [./backend/config/constants.py](./backend/config/constants.py)
|
||||
```python
|
||||
ANALYST_TYPES = {
|
||||
# Add new analyst
|
||||
"comprehensive_analyst": {
|
||||
"display_name": "Comprehensive Analyst",
|
||||
"agent_id": "comprehensive_analyst",
|
||||
"description": "Uses LLM to intelligently select analysis tools, performs comprehensive analysis",
|
||||
"order": 15
|
||||
}
|
||||
}
|
||||
```
|
||||
### Configure per-agent models
|
||||
|
||||
3. Introduce new role in frontend configuration [./frontend/src/config/constants.js](./frontend/src/config/constants.js) (optional)
|
||||
```javascript
|
||||
export const AGENTS = [
|
||||
// Override one of the agents
|
||||
{
|
||||
id: "comprehensive_analyst",
|
||||
name: "Comprehensive Analyst",
|
||||
role: "Comprehensive Analyst",
|
||||
avatar: `${ASSET_BASE_URL}/...`,
|
||||
colors: { bg: '#F9FDFF', text: '#1565C0', accent: '#1565C0' }
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Custom Models
|
||||
|
||||
Configure models used by different agents in the [.env](.env) file:
|
||||
Model overrides are configured in `.env`:
|
||||
|
||||
```bash
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=deepseek-chat
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4-plus
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=moonshot-v1-32k
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
|
||||
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
```
|
||||
|
||||
### Project Structure
|
||||
### Run-scoped bootstrap config
|
||||
|
||||
Each run can override defaults through `runs/<run_id>/BOOTSTRAP.md`. The front matter is parsed by [backend/config/bootstrap_config.py](./backend/config/bootstrap_config.py) and can define values such as:
|
||||
|
||||
```yaml
|
||||
tickers:
|
||||
- AAPL
|
||||
- MSFT
|
||||
initial_cash: 100000
|
||||
margin_requirement: 0.5
|
||||
max_comm_cycles: 2
|
||||
schedule_mode: daily
|
||||
trigger_time: "09:30"
|
||||
enable_memory: false
|
||||
```
|
||||
EvoTraders/
|
||||
|
||||
Initialize a run workspace with:
|
||||
|
||||
```bash
|
||||
evotraders init-workspace --config-name my_run
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Project Structure
|
||||
|
||||
```text
|
||||
evotraders/
|
||||
├── backend/
|
||||
│ ├── agents/ # Agent implementation
|
||||
│ ├── communication/ # Communication system
|
||||
│ ├── memory/ # Memory system (ReMe)
|
||||
│ ├── tools/ # Analysis toolset
|
||||
│ ├── servers/ # WebSocket services
|
||||
│ └── cli.py # CLI entry point
|
||||
├── frontend/ # React visualization interface
|
||||
└── logs_and_memory/ # Logs and memory data
|
||||
│ ├── agents/ # agent roles, prompts, skills, workspaces
|
||||
│ ├── api/ # FastAPI routers
|
||||
│ ├── apps/ # split service surfaces
|
||||
│ ├── core/ # pipeline, scheduler, state sync
|
||||
│ ├── runtime/ # runtime manager and agent runtime state
|
||||
│ ├── services/ # gateway, market/storage/db services
|
||||
│ └── cli.py # Typer CLI entrypoint
|
||||
├── frontend/ # React + Vite UI
|
||||
├── shared/ # shared clients and schemas for split services
|
||||
├── runs/ # run-scoped state and dashboards
|
||||
├── data/ # long-lived research artifacts
|
||||
└── services/README.md
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
Backend tests live under `backend/tests` and cover service apps, shared clients, domains, routing, enrichment, gateway support, and runtime support.
|
||||
|
||||
Typical commands:
|
||||
|
||||
```bash
|
||||
pytest
|
||||
pytest backend/tests/test_runtime_service_app.py
|
||||
pytest backend/tests/test_trading_service_app.py
|
||||
```
|
||||
|
||||
Frontend tests:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm test
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License and Disclaimer
|
||||
|
||||
EvoTraders is a research and educational project, open-sourced under the Apache 2.0 license.
|
||||
大时代 is a research and educational project. Review the repository license before redistribution or commercial use.
|
||||
|
||||
**Risk Warning**: Before trading with real funds, please conduct thorough testing and risk assessment. Past performance does not guarantee future returns. Investment involves risks, and decisions should be made with caution.
|
||||
**Risk warning**: this project is not investment advice. Test thoroughly before any real-money deployment. Past performance does not guarantee future returns.
|
||||
|
||||
418
README_zh.md
418
README_zh.md
@@ -1,243 +1,359 @@
|
||||
<p align="center">
|
||||
<img src="./docs/assets/evotraders_logo.jpg" width="45%">
|
||||
<img src="./docs/assets/bigtime_logo.jpg" width="45%">
|
||||
</p>
|
||||
|
||||
<h2 align="center">EvoTraders:自我进化的多智能体交易系统</h2>
|
||||
|
||||
<h2 align="center">大时代:自进化多智能体交易系统</h2>
|
||||
|
||||
<p align="center">
|
||||
📌 <a href="http://trading.evoagents.cn">Visit us at EvoTraders website !</a>
|
||||
📌 <a href="http://trading.evoagents.cn">访问大时代官网</a>
|
||||
</p>
|
||||
|
||||

|
||||

|
||||
|
||||
EvoTraders是一个开源的金融交易智能体框架,通过多智能体协作和记忆系统,构建能够在真实市场中持续学习与进化的交易系统。
|
||||
大时代 是一个开源的金融交易智能体框架,结合多智能体协作、run 级工作区和记忆机制,支持回测与实盘两类交易运行模式。
|
||||
|
||||
---
|
||||
|
||||
## 核心特性
|
||||
|
||||
**多智能体协作交易**
|
||||
6名成员,包含4种专业分析师角色(基本面、技术面、情绪、估值)+ 投资组合经理 + 风险管理,像真实交易团队一样协作决策。
|
||||
**多智能体交易团队**
|
||||
系统默认包含 6 个角色:4 个分析师(基本面、技术面、情绪、估值)+ 投资经理 + 风控经理。
|
||||
|
||||
你可以在这里自定义你的Agents,支持配置不同大模型(如 Qwen、DeepSeek、GPT、Claude等)协同分析:[自定义配置](#自定义配置)
|
||||
**持续学习**
|
||||
可选接入 ReMe 长期记忆,智能体会在每轮结束后反思、复盘并沉淀经验。
|
||||
|
||||
**持续学习与进化**
|
||||
基于 ReMe 记忆框架,智能体在每次交易后反思总结,跨回合保留经验,形成独特的投资方法论。
|
||||
|
||||
通过这样的设计,我们希望当 AI Agents 组成团队进入实时市场,它们会逐渐形成自己的交易风格和决策偏好,而不是一次性的随机推理
|
||||
|
||||
|
||||
**实时市场交易**
|
||||
支持实时行情接入,提供回测模式和实盘模式,让 AI Agents 在真实市场波动中学习和决策。
|
||||
|
||||
**可视化交易信息**
|
||||
实时观察 Agents 的分析过程、沟通记录和决策演化,完整追踪收益曲线和分析师表现。
|
||||
**统一运行时**
|
||||
同一套运行时模型支持历史回测和实时行情驱动的实盘流程。
|
||||
|
||||
**可操作前端**
|
||||
前端不只是展示层,还包含交易室、运行控制、日志、审批、Agent 工作区和 explain/news 视图。
|
||||
|
||||
<p>
|
||||
<img src="docs/assets/performance.jpg" width="45%">
|
||||
<img src="./docs/assets/dashboard.jpg" width="45%">
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## 当前架构
|
||||
|
||||
仓库目前处于“模块化单体 -> 拆分服务”的迁移阶段,本地开发默认走 split-service 路径。
|
||||
|
||||
当前 app surface:
|
||||
|
||||
- `backend.apps.agent_service`,端口 `8000`:控制面,负责 workspaces、agents、skills、审批接口
|
||||
- `backend.apps.trading_service`,端口 `8001`:只读交易数据接口
|
||||
- `backend.apps.news_service`,端口 `8002`:只读 explain/news 接口
|
||||
- `backend.apps.runtime_service`,端口 `8003`:运行时生命周期接口
|
||||
- `backend.apps.openclaw_service`,端口 `8004`:只读 OpenClaw facade
|
||||
- WebSocket gateway,端口 `8765`:前端实时事件和 feed 通道
|
||||
|
||||
当前最关键的主链路是:
|
||||
|
||||
`frontend -> runtime_service/control APIs -> gateway/runtime manager -> market service + pipeline + storage`
|
||||
|
||||
迁移背景可参考 [services/README.md](./services/README.md)。
|
||||
|
||||
---
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 安装
|
||||
### 1. 安装
|
||||
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone https://github.com/agentscope-ai/agentscope-samples
|
||||
cd agentscope-samples/EvoTraders
|
||||
# 克隆仓库后进入项目目录
|
||||
cd evotraders
|
||||
|
||||
# 安装依赖(推荐使用uv)
|
||||
# 安装后端运行时依赖
|
||||
uv pip install -r requirements.txt
|
||||
|
||||
# 安装项目入口(可编辑模式)
|
||||
uv pip install -e .
|
||||
# (可选)pip install -e .
|
||||
|
||||
# 配置环境变量
|
||||
# 可选
|
||||
# uv pip install -e ".[dev]"
|
||||
# pip install -e .
|
||||
```
|
||||
|
||||
前端依赖:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm ci
|
||||
cd ..
|
||||
```
|
||||
|
||||
生产环境部署建议后端使用 `requirements.txt`,前端使用 `npm ci`,这样拉起的环境会严格跟随仓库中锁定的依赖版本。
|
||||
|
||||
### 2. 配置环境变量
|
||||
|
||||
```bash
|
||||
cp env.template .env
|
||||
# 编辑 .env 文件,添加你的 API Keys,以下的配置项为必填项
|
||||
```
|
||||
|
||||
# finance data API:至少需要FINANCIAL_DATASETS_API_KEY,对应FIN_DATA_SOURCE=financial_datasets;推荐添加FINNHUB_API_KEY,对应至少需要FINANCIAL_DATASETS_API_KEY,对应FIN_DATA_SOURCE填为finnhub;如果使用live 模式必须添加FINNHUB_API_KEY
|
||||
FIN_DATA_SOURCE= #finnhub or financial_datasets
|
||||
FINANCIAL_DATASETS_API_KEY= #必需
|
||||
FINNHUB_API_KEY= #可选
|
||||
根目录 `env.template` 是当前本地开发的主模板,仓库里也保留了 `.env.example` 作为参考。
|
||||
|
||||
# LLM API for Agents
|
||||
最常用的配置项:
|
||||
|
||||
```bash
|
||||
# 自选股
|
||||
TICKERS=AAPL,MSFT,GOOGL,NVDA,TSLA,META,AMZN
|
||||
|
||||
# 行情数据
|
||||
FIN_DATA_SOURCE=finnhub
|
||||
FINANCIAL_DATASETS_API_KEY=
|
||||
FINNHUB_API_KEY=
|
||||
POLYGON_API_KEY=
|
||||
|
||||
# Agent 模型
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# LLM & embedding API for Memory
|
||||
# 长期记忆(只有启用 --enable-memory 才需要)
|
||||
MEMORY_API_KEY=
|
||||
```
|
||||
|
||||
### 运行
|
||||
说明:
|
||||
|
||||
- live 模式必须配置 `FINNHUB_API_KEY`
|
||||
- `POLYGON_API_KEY` 用于长期 market store 的补数和刷新
|
||||
- `MEMORY_API_KEY` 仅在启用长期记忆时需要
|
||||
|
||||
如果要用更接近生产的本地启动方式,也可以直接执行:
|
||||
|
||||
**回测模式:**
|
||||
```bash
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory # 使用记忆
|
||||
|
||||
./start.sh
|
||||
```
|
||||
|
||||
如果没有可用的行情 API,想快速体验回测 demo,可直接下载离线数据并解压到 `backend/data`:
|
||||
### 3. 启动服务栈
|
||||
|
||||
本地开发推荐直接使用:
|
||||
|
||||
```bash
|
||||
./start-dev.sh
|
||||
```
|
||||
|
||||
该脚本会启动:
|
||||
|
||||
- `agent_service`:`http://localhost:8000`
|
||||
- `trading_service`:`http://localhost:8001`
|
||||
- `news_service`:`http://localhost:8002`
|
||||
- `runtime_service`:`http://localhost:8003`
|
||||
- gateway WebSocket:`ws://localhost:8765`
|
||||
|
||||
然后在另一个终端启动前端:
|
||||
|
||||
```bash
|
||||
evotraders frontend
|
||||
```
|
||||
|
||||
访问 `http://localhost:5173`。
|
||||
|
||||
也可以手动分别启动:
|
||||
|
||||
```bash
|
||||
python -m uvicorn backend.apps.agent_service:app --host 0.0.0.0 --port 8000 --reload
|
||||
python -m uvicorn backend.apps.trading_service:app --host 0.0.0.0 --port 8001 --reload
|
||||
python -m uvicorn backend.apps.news_service:app --host 0.0.0.0 --port 8002 --reload
|
||||
python -m uvicorn backend.apps.runtime_service:app --host 0.0.0.0 --port 8003 --reload
|
||||
python -m backend.main --mode live --host 0.0.0.0 --port 8765
|
||||
```
|
||||
|
||||
### 4. 使用 CLI 运行回测或实盘
|
||||
|
||||
回测:
|
||||
|
||||
```bash
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory
|
||||
evotraders backtest --config-name smoke_fullstack --start 2025-11-01 --end 2025-12-01
|
||||
```
|
||||
|
||||
实盘:
|
||||
|
||||
```bash
|
||||
evotraders live
|
||||
evotraders live --enable-memory
|
||||
evotraders live --schedule-mode intraday --interval-minutes 60
|
||||
evotraders live --trigger-time 22:30
|
||||
```
|
||||
|
||||
帮助:
|
||||
|
||||
```bash
|
||||
evotraders --help
|
||||
evotraders backtest --help
|
||||
evotraders live --help
|
||||
evotraders frontend --help
|
||||
```
|
||||
|
||||
### 离线回测数据
|
||||
|
||||
如果只是想快速体验回测,不依赖外部行情 API,可以下载离线数据包并解压到 `backend/data`:
|
||||
|
||||
```bash
|
||||
wget "https://agentscope-open.oss-cn-beijing.aliyuncs.com/ret_data.zip"
|
||||
unzip ret_data.zip -d backend/data
|
||||
```
|
||||
该压缩包提供基础的股票行情数据,解压后即可直接用于回测演示。
|
||||
|
||||
**实盘交易:**
|
||||
```bash
|
||||
evotraders live # 立即运行(默认)
|
||||
evotraders live --enable-memory # 使用记忆
|
||||
evotraders live --mock # Mock 模式(测试)
|
||||
evotraders live -t 22:30 # 每天本地时间 22:30 运行(自动转换为 NYSE 时区)
|
||||
```
|
||||
|
||||
**获取帮助:**
|
||||
```bash
|
||||
evotraders --help # 查看整体命令行帮助
|
||||
evotraders backtest --help # 查看回测模式的参数说明
|
||||
evotraders live --help # 查看实盘/Mock 运行的参数说明
|
||||
```
|
||||
|
||||
**启动可视化界面:**
|
||||
```bash
|
||||
# 确保已安装 npm, 否则请安装:
|
||||
# npm install
|
||||
evotraders frontend # 默认连接 8765 端口, 你可以修改 ./frontend/env.local 中的地址从而修改端口号
|
||||
```
|
||||
|
||||
访问 `http://localhost:5173/` 查看交易大厅,选择日期并点击 Run/Replay 观察决策过程。
|
||||
|
||||
---
|
||||
|
||||
## 系统架构
|
||||
## 运行时数据布局
|
||||
|
||||

|
||||
- 长期研究数据保存在 `data/market_research.db`
|
||||
- 每次 run 的状态写入 `runs/<run_id>/`
|
||||
- `runs/<run_id>/BOOTSTRAP.md` 保存该 run 的 bootstrap 值和 prompt body
|
||||
- `runs/<run_id>/state/runtime_state.json` 保存运行时快照
|
||||
- `runs/<run_id>/team_dashboard/*.json` 主要是给 dashboard 用的兼容导出层,不是唯一真相源
|
||||
|
||||
### 智能体设计
|
||||
可选保留策略:
|
||||
|
||||
**分析师团队:**
|
||||
- **基本面分析师**:财务健康度、盈利能力、增长质量
|
||||
- **技术分析师**:价格趋势、技术指标、动量分析
|
||||
- **情绪分析师**:市场情绪、新闻舆情、内部人交易
|
||||
- **估值分析师**:DCF、剩余收益、EV/EBITDA
|
||||
|
||||
**决策层:**
|
||||
- **投资组合经理**:整合来自分析师的分析信号,执行沟通策略,结合分析师和团队历史表现、近期投资记忆和长期投资经验,进行最终决策
|
||||
- **风险管理**:实时价格与波动率监控、头寸限制,多层风险预警
|
||||
|
||||
### 决策流程
|
||||
|
||||
```
|
||||
实时行情 → 独立分析 → 智能沟通 (1v1/1vN/NvN) → 决策执行 → 收益评估 → 学习与进化(记忆更新)
|
||||
```bash
|
||||
RUNS_RETENTION_COUNT=20
|
||||
```
|
||||
|
||||
每个交易日经历五个阶段:
|
||||
|
||||
1. **分析阶段**:各智能体基于各自工具和历史经验独立分析
|
||||
2. **沟通阶段**:通过私聊、通知、会议等方式交换观点
|
||||
3. **决策阶段**:投资组合经理综合判断,给出最终交易
|
||||
4. **评估阶段**
|
||||
- **业绩图表**: 追踪组合收益曲线 vs. 基准策略(等权、市值加权、动量)。用于评估整体策略有效性。
|
||||
|
||||
- **分析师排名**: 在 Trading Room 点击头像查看分析师表现(胜率、牛/熊市胜率)。用于了解哪些分析师提供最有价值的洞察。
|
||||
|
||||
- **统计数据**: 详细的持仓和交易历史。用于深入分析仓位管理和执行质量。
|
||||
|
||||
4. **复盘阶段**:Agents 根据当日实际收益反思决策、总结经验,并存入 ReMe 记忆框架以持续改进
|
||||
只有形如 `YYYYMMDD_HHMMSS` 的时间戳目录会被自动清理;`live`、`smoke_fullstack`、`reload_demo_*` 这类命名 run 会保留。
|
||||
|
||||
---
|
||||
|
||||
### 模块支持
|
||||
## 前端服务路由
|
||||
|
||||
- **智能体框架**:[AgentScope](https://github.com/agentscope-ai/agentscope)
|
||||
- **记忆系统**:[ReMe](https://github.com/agentscope-ai/reme)
|
||||
- **LLM 支持**:OpenAI、DeepSeek、Qwen、Moonshot、Zhipu AI 等
|
||||
前端始终会使用 control plane 和 runtime API,同时可以选择直连拆分服务读取只读数据。
|
||||
|
||||
常用前端环境变量:
|
||||
|
||||
```bash
|
||||
VITE_CONTROL_API_BASE_URL=http://localhost:8000/api
|
||||
VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime
|
||||
VITE_NEWS_SERVICE_URL=http://localhost:8002
|
||||
VITE_TRADING_SERVICE_URL=http://localhost:8001
|
||||
VITE_WS_URL=ws://localhost:8765
|
||||
```
|
||||
|
||||
如果不配置,前端会按本地默认值和兼容回退逻辑运行。
|
||||
|
||||
---
|
||||
|
||||
## 决策流程
|
||||
|
||||
```text
|
||||
市场数据 -> 分析师独立分析 -> 团队沟通 -> 投资决策 ->
|
||||
风控审核 -> 执行/结算 -> 复盘/记忆更新
|
||||
```
|
||||
|
||||
运行时管理器还会跟踪:
|
||||
|
||||
- agent 注册和状态
|
||||
- 待审批项
|
||||
- run 事件
|
||||
- 当前 session key
|
||||
|
||||
---
|
||||
|
||||
## 自定义配置
|
||||
|
||||
### 自定义分析师角色
|
||||
### 新增或修改分析师角色
|
||||
|
||||
1. 在 [./backend/agents/prompts/analyst/personas.yaml](./backend/agents/prompts/analyst/personas.yaml) 中注册角色信息,例如:
|
||||
1. 在 [backend/agents/prompts/analyst/personas.yaml](./backend/agents/prompts/analyst/personas.yaml) 中定义 persona
|
||||
2. 在 [backend/config/constants.py](./backend/config/constants.py) 中注册角色
|
||||
3. 如有需要,在 [frontend/src/config/constants.js](./frontend/src/config/constants.js) 中补充前端展示元数据
|
||||
|
||||
示例:
|
||||
|
||||
```yaml
|
||||
comprehensive_analyst:
|
||||
name: "Comprehensive Analyst"
|
||||
focus:
|
||||
- ...
|
||||
preferred_tools: # Flexibly select based on situation
|
||||
- multi-factor synthesis
|
||||
preferred_tools:
|
||||
- get_stock_price
|
||||
- get_company_financials
|
||||
description: |
|
||||
As a comprehensive analyst ...
|
||||
A generalist analyst that combines multiple signals.
|
||||
```
|
||||
|
||||
2. 在 [./backend/config/constants.py](./backend/config/constants.py) 添加角色定义
|
||||
```python
|
||||
ANALYST_TYPES = {
|
||||
# 增加新的分析师
|
||||
"comprehensive_analyst": {
|
||||
"display_name": "Comprehensive Analyst",
|
||||
"agent_id": "comprehensive_analyst",
|
||||
"description": "Uses LLM to intelligently select analysis tools, performs comprehensive analysis",
|
||||
"order": 15
|
||||
}
|
||||
}
|
||||
```
|
||||
### 配置各 Agent 使用的模型
|
||||
|
||||
3. 在前端配置 [./frontend/src/config/constants.js](./frontend/src/config/constants.js) 中引入新角色(可选)
|
||||
```javascript
|
||||
export const AGENTS = [
|
||||
// 覆盖掉其中某一个agent
|
||||
{
|
||||
id: "comprehensive_analyst",
|
||||
name: "Comprehensive Analyst",
|
||||
role: "Comprehensive Analyst",
|
||||
avatar: `${ASSET_BASE_URL}/...`,
|
||||
colors: { bg: '#F9FDFF', text: '#1565C0', accent: '#1565C0' }
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 自定义模型
|
||||
|
||||
在 [.env](.env) 文件中配置不同智能体使用的模型:
|
||||
模型覆盖在 `.env` 中配置:
|
||||
|
||||
```bash
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_FUNDAMENTAL_ANALYST_MODEL_NAME=deepseek-chat
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4-plus
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=moonshot-v1-32k
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
|
||||
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
```
|
||||
|
||||
### 项目结构
|
||||
### run 级 BOOTSTRAP 配置
|
||||
|
||||
每个 run 都可以通过 `runs/<run_id>/BOOTSTRAP.md` 覆盖默认值。该文件由 [backend/config/bootstrap_config.py](./backend/config/bootstrap_config.py) 解析,front matter 可配置:
|
||||
|
||||
```yaml
|
||||
tickers:
|
||||
- AAPL
|
||||
- MSFT
|
||||
initial_cash: 100000
|
||||
margin_requirement: 0.5
|
||||
max_comm_cycles: 2
|
||||
schedule_mode: daily
|
||||
trigger_time: "09:30"
|
||||
enable_memory: false
|
||||
```
|
||||
EvoTraders/
|
||||
|
||||
初始化一个 run 工作区:
|
||||
|
||||
```bash
|
||||
evotraders init-workspace --config-name my_run
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 项目结构
|
||||
|
||||
```text
|
||||
evotraders/
|
||||
├── backend/
|
||||
│ ├── agents/ # 智能体实现
|
||||
│ ├── communication/ # 通信系统
|
||||
│ ├── memory/ # 记忆系统 (ReMe)
|
||||
│ ├── tools/ # 分析工具集
|
||||
│ ├── servers/ # WebSocket 服务
|
||||
│ └── cli.py # CLI 入口
|
||||
├── frontend/ # React 可视化界面
|
||||
└── logs_and_memory/ # 日志和记忆数据
|
||||
│ ├── agents/ # agent 角色、prompts、skills、workspaces
|
||||
│ ├── api/ # FastAPI 路由层
|
||||
│ ├── apps/ # 拆分服务 app surface
|
||||
│ ├── core/ # pipeline、scheduler、state sync
|
||||
│ ├── runtime/ # runtime manager 和 agent runtime state
|
||||
│ ├── services/ # gateway、market/storage/db 服务
|
||||
│ └── cli.py # Typer CLI 入口
|
||||
├── frontend/ # React + Vite 前端
|
||||
├── shared/ # 拆分服务共用 client 和 schema
|
||||
├── runs/ # run 级状态和 dashboard 导出
|
||||
├── data/ # 长期研究数据
|
||||
└── services/README.md
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 测试
|
||||
|
||||
后端测试位于 `backend/tests`,覆盖 service app、shared client、domain、路由、enrichment、gateway 支撑模块和 runtime 支撑模块。
|
||||
|
||||
常用命令:
|
||||
|
||||
```bash
|
||||
pytest
|
||||
pytest backend/tests/test_runtime_service_app.py
|
||||
pytest backend/tests/test_trading_service_app.py
|
||||
```
|
||||
|
||||
前端测试:
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm test
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 许可与免责
|
||||
|
||||
EvoTraders 是一个研究和教育项目,采用 Apache 2.0 许可协议开源。
|
||||
大时代 是研究和教育用途项目。再次分发或商用前,请先核对仓库中的实际 license 文件。
|
||||
|
||||
**风险提示**:在实际资金交易前,请务必进行充分的测试和风险评估。历史表现不代表未来收益,投资有风险,决策需谨慎。
|
||||
**风险提示**:本项目不构成投资建议。任何实盘部署前都应进行充分测试和风险评估,历史表现不代表未来收益。
|
||||
|
||||
@@ -1,6 +1,56 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Agents package - EvoAgent architecture for trading system.
|
||||
|
||||
Exports:
|
||||
- EvoAgent: Next-generation agent with workspace support
|
||||
- ToolGuardMixin: Tool call approval/denial flow
|
||||
- CommandHandler: System command handling
|
||||
- AgentFactory: Dynamic agent creation and management
|
||||
- WorkspaceManager: Legacy name for the persistent workspace registry
|
||||
- WorkspaceRegistry: Explicit run-time-agnostic workspace registry
|
||||
- RunWorkspaceManager: Run-scoped workspace asset manager
|
||||
- AgentRegistry: Central agent registry
|
||||
- Legacy compatibility: AnalystAgent, PMAgent, RiskAgent
|
||||
"""
|
||||
|
||||
# New EvoAgent architecture (from agent_core.py)
|
||||
from .agent_core import EvoAgent, ToolGuardMixin, CommandHandler
|
||||
from .factory import AgentFactory, ModelConfig
|
||||
from .workspace import WorkspaceManager, WorkspaceRegistry, WorkspaceConfig
|
||||
from .workspace_manager import RunWorkspaceManager
|
||||
from .registry import AgentRegistry, AgentInfo, get_registry, reset_registry
|
||||
|
||||
# Legacy agents (backward compatibility)
|
||||
from .analyst import AnalystAgent
|
||||
from .portfolio_manager import PMAgent
|
||||
from .risk_manager import RiskAgent
|
||||
|
||||
__all__ = ["AnalystAgent", "PMAgent", "RiskAgent"]
|
||||
# Compatibility layer
|
||||
from .compat import LegacyAgentAdapter, adapt_agent, adapt_agents, is_legacy_agent
|
||||
|
||||
__all__ = [
|
||||
# New architecture
|
||||
"EvoAgent",
|
||||
"ToolGuardMixin",
|
||||
"CommandHandler",
|
||||
"AgentFactory",
|
||||
"ModelConfig",
|
||||
"WorkspaceManager",
|
||||
"WorkspaceRegistry",
|
||||
"WorkspaceConfig",
|
||||
"RunWorkspaceManager",
|
||||
"AgentRegistry",
|
||||
"AgentInfo",
|
||||
"get_registry",
|
||||
"reset_registry",
|
||||
# Legacy compatibility
|
||||
"AnalystAgent",
|
||||
"PMAgent",
|
||||
"RiskAgent",
|
||||
# Compatibility layer
|
||||
"LegacyAgentAdapter",
|
||||
"adapt_agent",
|
||||
"adapt_agents",
|
||||
"is_legacy_agent",
|
||||
]
|
||||
|
||||
18
backend/agents/agent_core.py
Normal file
18
backend/agents/agent_core.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Compatibility layer for legacy imports.
|
||||
|
||||
This module re-exports the newer base implementations so existing import
|
||||
paths (`from backend.agents.agent_core import EvoAgent`) continue to work while
|
||||
centralizing the actual logic in `backend.agents.base.evo_agent`.
|
||||
"""
|
||||
|
||||
from .base.command_handler import CommandHandler
|
||||
from .base.evo_agent import EvoAgent
|
||||
from .base.tool_guard import ToolGuardMixin
|
||||
|
||||
__all__ = [
|
||||
"EvoAgent",
|
||||
"ToolGuardMixin",
|
||||
"CommandHandler",
|
||||
]
|
||||
75
backend/agents/agent_workspace.py
Normal file
75
backend/agents/agent_workspace.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Per-agent run-scoped workspace configuration helpers."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentWorkspaceConfig:
|
||||
"""Structured agent config loaded from runs/<config>/agents/<agent>/agent.yaml."""
|
||||
|
||||
values: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def prompt_files(self) -> Optional[List[str]]:
|
||||
raw = self.values.get("prompt_files")
|
||||
if not isinstance(raw, list):
|
||||
return None
|
||||
files = [
|
||||
str(item).strip()
|
||||
for item in raw
|
||||
if isinstance(item, str) and str(item).strip()
|
||||
]
|
||||
return files or None
|
||||
|
||||
@property
|
||||
def enabled_skills(self) -> List[str]:
|
||||
return _normalized_string_list(self.values.get("enabled_skills"))
|
||||
|
||||
@property
|
||||
def disabled_skills(self) -> List[str]:
|
||||
return _normalized_string_list(self.values.get("disabled_skills"))
|
||||
|
||||
@property
|
||||
def active_tool_groups(self) -> Optional[List[str]]:
|
||||
groups = _normalized_string_list(self.values.get("active_tool_groups"))
|
||||
return groups or None
|
||||
|
||||
@property
|
||||
def disabled_tool_groups(self) -> List[str]:
|
||||
return _normalized_string_list(self.values.get("disabled_tool_groups"))
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self.values.get(key, default)
|
||||
|
||||
|
||||
def _normalized_string_list(raw: Any) -> List[str]:
|
||||
if not isinstance(raw, list):
|
||||
return []
|
||||
seen: List[str] = []
|
||||
for item in raw:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
value = item.strip()
|
||||
if value and value not in seen:
|
||||
seen.append(value)
|
||||
return seen
|
||||
|
||||
|
||||
def load_agent_workspace_config(path: Path) -> AgentWorkspaceConfig:
|
||||
"""Load agent.yaml if present."""
|
||||
if not path.exists() or not path.is_file():
|
||||
return AgentWorkspaceConfig()
|
||||
|
||||
raw = path.read_text(encoding="utf-8").strip()
|
||||
if not raw:
|
||||
return AgentWorkspaceConfig()
|
||||
|
||||
parsed = yaml.safe_load(raw) or {}
|
||||
if not isinstance(parsed, dict):
|
||||
parsed = {}
|
||||
return AgentWorkspaceConfig(values=parsed)
|
||||
@@ -84,7 +84,6 @@ class AnalystAgent(ReActAgent):
|
||||
agent_id=self.agent_id,
|
||||
config_name=self.config.get("config_name", "default"),
|
||||
toolkit=self.toolkit,
|
||||
analyst_type=self.analyst_type_key,
|
||||
)
|
||||
|
||||
async def reply(self, x: Msg = None) -> Msg:
|
||||
|
||||
57
backend/agents/base/__init__.py
Normal file
57
backend/agents/base/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Base agent module for 大时代.
|
||||
|
||||
提供Agent基础类、命令处理、工具守卫和钩子管理等功能。
|
||||
"""
|
||||
|
||||
# 命令处理器 (从command_handler.py导入)
|
||||
from .command_handler import (
|
||||
AgentCommandDispatcher,
|
||||
CommandContext,
|
||||
CommandHandler,
|
||||
CommandResult,
|
||||
create_command_dispatcher,
|
||||
)
|
||||
|
||||
# 评估钩子 (从evaluation_hook.py导入)
|
||||
from .evaluation_hook import (
|
||||
EvaluationHook,
|
||||
EvaluationCollector,
|
||||
MetricType,
|
||||
EvaluationMetric,
|
||||
EvaluationResult,
|
||||
parse_evaluation_hooks,
|
||||
)
|
||||
|
||||
# 技能适配钩子 (从skill_adaptation_hook.py导入)
|
||||
from .skill_adaptation_hook import (
|
||||
AdaptationAction,
|
||||
AdaptationThreshold,
|
||||
AdaptationEvent,
|
||||
SkillAdaptationHook,
|
||||
AdaptationManager,
|
||||
get_adaptation_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 命令处理
|
||||
"AgentCommandDispatcher",
|
||||
"CommandContext",
|
||||
"CommandHandler",
|
||||
"CommandResult",
|
||||
"create_command_dispatcher",
|
||||
# 评估钩子
|
||||
"EvaluationHook",
|
||||
"EvaluationCollector",
|
||||
"MetricType",
|
||||
"EvaluationMetric",
|
||||
"EvaluationResult",
|
||||
"parse_evaluation_hooks",
|
||||
# 技能适配钩子
|
||||
"AdaptationAction",
|
||||
"AdaptationThreshold",
|
||||
"AdaptationEvent",
|
||||
"SkillAdaptationHook",
|
||||
"AdaptationManager",
|
||||
"get_adaptation_manager",
|
||||
]
|
||||
543
backend/agents/base/command_handler.py
Normal file
543
backend/agents/base/command_handler.py
Normal file
@@ -0,0 +1,543 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Agent command handler for system commands.
|
||||
|
||||
This module handles system commands like /save, /compact, /skills, /reload, etc.
|
||||
参考CoPaw设计,为EvoAgent提供命令处理能力。
|
||||
"""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent import EvoAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandResult:
|
||||
"""命令执行结果"""
|
||||
success: bool
|
||||
message: str
|
||||
data: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class CommandContext:
|
||||
"""命令执行上下文"""
|
||||
|
||||
def __init__(self, agent: "EvoAgent", raw_query: str, args: str = ""):
|
||||
self.agent = agent
|
||||
self.raw_query = raw_query
|
||||
self.args = args
|
||||
self.config_name = getattr(agent, "config_name", "default")
|
||||
self.agent_id = getattr(agent, "agent_id", "unknown")
|
||||
|
||||
|
||||
class CommandHandler(ABC):
|
||||
"""命令处理器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
"""处理命令"""
|
||||
pass
|
||||
|
||||
|
||||
class SaveCommandHandler(CommandHandler):
|
||||
"""处理 /save <message> 命令 - 保存内容到MEMORY.md"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
message = ctx.args.strip()
|
||||
if not message:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="Usage: /save <message>\n请提供要保存的内容。"
|
||||
)
|
||||
|
||||
try:
|
||||
memory_path = self._get_memory_path(ctx)
|
||||
memory_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = self._get_timestamp()
|
||||
entry = f"\n## {timestamp}\n\n{message}\n"
|
||||
|
||||
with open(memory_path, "a", encoding="utf-8") as f:
|
||||
f.write(entry)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 内容已保存到 MEMORY.md\n- 路径: {memory_path}\n- 长度: {len(message)} 字符",
|
||||
data={"path": str(memory_path), "length": len(message)}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save to MEMORY.md: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 保存失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _get_memory_path(self, ctx: CommandContext) -> Path:
|
||||
"""获取MEMORY.md路径"""
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
asset_dir = sm.get_agent_asset_dir(ctx.config_name, ctx.agent_id)
|
||||
return asset_dir / "MEMORY.md"
|
||||
|
||||
def _get_timestamp(self) -> str:
|
||||
"""获取当前时间戳"""
|
||||
from datetime import datetime
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
|
||||
class CompactCommandHandler(CommandHandler):
|
||||
"""处理 /compact 命令 - 压缩记忆"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
try:
|
||||
agent = ctx.agent
|
||||
memory_manager = getattr(agent, "memory_manager", None)
|
||||
|
||||
if memory_manager is None:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="❌ Memory Manager 未启用\n\n- 记忆压缩功能不可用\n- 请在配置中启用 memory_manager"
|
||||
)
|
||||
|
||||
messages = await self._get_messages(agent)
|
||||
if not messages:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="⚠️ 没有可压缩的消息\n\n- 当前记忆为空\n- 无需执行压缩"
|
||||
)
|
||||
|
||||
compact_content = await memory_manager.compact_memory(messages)
|
||||
await self._update_compressed_summary(agent, compact_content)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 记忆压缩完成\n\n- 压缩了 {len(messages)} 条消息\n- 摘要长度: {len(compact_content)} 字符",
|
||||
data={"message_count": len(messages), "summary_length": len(compact_content)}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to compact memory: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 压缩失败: {str(e)}"
|
||||
)
|
||||
|
||||
async def _get_messages(self, agent: "EvoAgent") -> List[Any]:
|
||||
"""获取Agent的记忆消息"""
|
||||
memory = getattr(agent, "memory", None)
|
||||
if memory is None:
|
||||
return []
|
||||
return await memory.get_memory() if hasattr(memory, "get_memory") else []
|
||||
|
||||
async def _update_compressed_summary(self, agent: "EvoAgent", content: str) -> None:
|
||||
"""更新压缩摘要"""
|
||||
memory = getattr(agent, "memory", None)
|
||||
if memory and hasattr(memory, "update_compressed_summary"):
|
||||
await memory.update_compressed_summary(content)
|
||||
|
||||
|
||||
class SkillsListCommandHandler(CommandHandler):
|
||||
"""处理 /skills list 命令 - 列出已激活技能"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
try:
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
|
||||
active_skills = sm.list_active_skill_metadata(ctx.config_name, ctx.agent_id)
|
||||
catalog = sm.list_agent_skill_catalog(ctx.config_name, ctx.agent_id)
|
||||
|
||||
lines = ["📋 技能列表", ""]
|
||||
|
||||
if active_skills:
|
||||
lines.append("✅ 已激活技能:")
|
||||
for skill in active_skills:
|
||||
lines.append(f" • {skill.name} - {skill.description[:50]}...")
|
||||
else:
|
||||
lines.append("⚠️ 当前没有激活的技能")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f"📚 可用技能总数: {len(catalog)}")
|
||||
lines.append("💡 使用 /skills enable <name> 启用技能")
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message="\n".join(lines),
|
||||
data={
|
||||
"active_count": len(active_skills),
|
||||
"catalog_count": len(catalog),
|
||||
"active": [s.skill_name for s in active_skills]
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list skills: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 获取技能列表失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SkillsEnableCommandHandler(CommandHandler):
|
||||
"""处理 /skills enable <name> 命令 - 启用技能"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
skill_name = ctx.args.strip()
|
||||
if not skill_name:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="Usage: /skills enable <skill_name>\n请提供技能名称。"
|
||||
)
|
||||
|
||||
try:
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
|
||||
result = sm.update_agent_skill_overrides(
|
||||
ctx.config_name,
|
||||
ctx.agent_id,
|
||||
enable=[skill_name]
|
||||
)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 技能已启用: {skill_name}\n\n已启用技能: {', '.join(result['enabled_skills'])}",
|
||||
data=result
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to enable skill: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 启用技能失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SkillsDisableCommandHandler(CommandHandler):
|
||||
"""处理 /skills disable <name> 命令 - 禁用技能"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
skill_name = ctx.args.strip()
|
||||
if not skill_name:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="Usage: /skills disable <skill_name>\n请提供技能名称。"
|
||||
)
|
||||
|
||||
try:
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
|
||||
result = sm.update_agent_skill_overrides(
|
||||
ctx.config_name,
|
||||
ctx.agent_id,
|
||||
disable=[skill_name]
|
||||
)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 技能已禁用: {skill_name}\n\n已禁用技能: {', '.join(result['disabled_skills'])}",
|
||||
data=result
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to disable skill: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 禁用技能失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class SkillsInstallCommandHandler(CommandHandler):
|
||||
"""处理 /skills install <name> 命令 - 安装技能"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
skill_name = ctx.args.strip()
|
||||
if not skill_name:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message="Usage: /skills install <skill_name>\n请提供技能名称。"
|
||||
)
|
||||
|
||||
try:
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.skill_loader import load_skill_from_dir
|
||||
sm = SkillsManager()
|
||||
|
||||
# 查找技能源目录
|
||||
source_dir = self._resolve_skill_source(sm, skill_name)
|
||||
if not source_dir:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 技能未找到: {skill_name}\n\n请检查技能名称是否正确,或技能是否存在于 builtin/customized 目录。"
|
||||
)
|
||||
|
||||
# 加载并验证技能
|
||||
skill_info = load_skill_from_dir(source_dir)
|
||||
if not skill_info:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 技能加载失败: {skill_name}\n\n技能格式可能不正确。"
|
||||
)
|
||||
|
||||
# 安装到agent的installed目录
|
||||
installed_root = sm.get_agent_installed_root(ctx.config_name, ctx.agent_id)
|
||||
target_dir = installed_root / skill_name
|
||||
|
||||
import shutil
|
||||
if target_dir.exists():
|
||||
shutil.rmtree(target_dir)
|
||||
shutil.copytree(source_dir, target_dir)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message=f"✅ 技能已安装: {skill_name}\n\n- 名称: {skill_info.get('name', skill_name)}\n- 版本: {skill_info.get('version', 'unknown')}\n- 路径: {target_dir}",
|
||||
data={"skill_name": skill_name, "target_dir": str(target_dir)}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to install skill: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 安装技能失败: {str(e)}"
|
||||
)
|
||||
|
||||
def _resolve_skill_source(self, sm: "SkillsManager", skill_name: str) -> Optional[Path]:
|
||||
"""解析技能源目录"""
|
||||
for root in [sm.customized_root, sm.builtin_root]:
|
||||
candidate = root / skill_name
|
||||
if candidate.exists() and (candidate / "SKILL.md").exists():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
class ReloadCommandHandler(CommandHandler):
|
||||
"""处理 /reload 命令 - 重新加载配置"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
try:
|
||||
agent = ctx.agent
|
||||
|
||||
# 重新加载配置
|
||||
if hasattr(agent, "reload_config"):
|
||||
await agent.reload_config()
|
||||
|
||||
# 重新加载技能
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
|
||||
# 刷新技能同步
|
||||
active_root = sm.get_agent_active_root(ctx.config_name, ctx.agent_id)
|
||||
if active_root.exists():
|
||||
# 清除缓存,强制重新加载
|
||||
import shutil
|
||||
for item in active_root.iterdir():
|
||||
if item.is_dir():
|
||||
shutil.rmtree(item)
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message="✅ 配置已重新加载\n\n- Agent配置已刷新\n- 技能缓存已清除\n- 请重启对话以应用所有更改",
|
||||
data={"config_name": ctx.config_name, "agent_id": ctx.agent_id}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reload config: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 重新加载失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class StatusCommandHandler(CommandHandler):
|
||||
"""处理 /status 命令 - 显示Agent状态"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
try:
|
||||
agent = ctx.agent
|
||||
|
||||
lines = ["📊 Agent 状态", ""]
|
||||
lines.append(f"🆔 Agent ID: {ctx.agent_id}")
|
||||
lines.append(f"⚙️ Config: {ctx.config_name}")
|
||||
|
||||
# 模型信息
|
||||
model = getattr(agent, "model", None)
|
||||
if model:
|
||||
lines.append(f"🤖 Model: {model}")
|
||||
|
||||
# 记忆状态
|
||||
memory = getattr(agent, "memory", None)
|
||||
if memory:
|
||||
msg_count = len(getattr(memory, "content", []))
|
||||
lines.append(f"💾 Memory: {msg_count} messages")
|
||||
|
||||
# 技能状态
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
sm = SkillsManager()
|
||||
active_skills = sm.list_active_skill_metadata(ctx.config_name, ctx.agent_id)
|
||||
lines.append(f"🔧 Active Skills: {len(active_skills)}")
|
||||
|
||||
# 工具组状态
|
||||
toolkit = getattr(agent, "toolkit", None)
|
||||
if toolkit:
|
||||
groups = getattr(toolkit, "tool_groups", {})
|
||||
active_groups = [name for name, g in groups.items() if getattr(g, "active", False)]
|
||||
lines.append(f"🛠️ Active Tool Groups: {', '.join(active_groups) if active_groups else 'None'}")
|
||||
|
||||
return CommandResult(
|
||||
success=True,
|
||||
message="\n".join(lines),
|
||||
data={
|
||||
"agent_id": ctx.agent_id,
|
||||
"config_name": ctx.config_name,
|
||||
"active_skills_count": len(active_skills)
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get status: {e}")
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"❌ 获取状态失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class HelpCommandHandler(CommandHandler):
|
||||
"""处理 /help 命令 - 显示帮助"""
|
||||
|
||||
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||
help_text = """📖 EvoAgent 命令帮助
|
||||
|
||||
可用命令:
|
||||
/save <message> - 保存内容到 MEMORY.md
|
||||
/compact - 压缩记忆
|
||||
/skills list - 列出已激活技能
|
||||
/skills enable <name> - 启用技能
|
||||
/skills disable <name>- 禁用技能
|
||||
/skills install <name>- 安装技能
|
||||
/reload - 重新加载配置
|
||||
/status - 显示Agent状态
|
||||
/help - 显示此帮助信息
|
||||
|
||||
提示:
|
||||
• 所有命令以 / 开头
|
||||
• 命令不区分大小写
|
||||
• 使用 Tab 键可自动补全命令
|
||||
"""
|
||||
return CommandResult(success=True, message=help_text)
|
||||
|
||||
|
||||
class AgentCommandDispatcher:
|
||||
"""Agent命令分发器
|
||||
|
||||
参考CoPaw的CommandHandler设计,为EvoAgent提供统一的命令处理入口。
|
||||
"""
|
||||
|
||||
# 支持的系统命令
|
||||
SYSTEM_COMMANDS = frozenset({
|
||||
"save", "compact",
|
||||
"skills", "reload",
|
||||
"status", "help"
|
||||
})
|
||||
|
||||
def __init__(self):
|
||||
self._handlers: Dict[str, CommandHandler] = {}
|
||||
self._subcommands: Dict[str, Dict[str, CommandHandler]] = {}
|
||||
self._register_default_handlers()
|
||||
|
||||
def _register_default_handlers(self) -> None:
|
||||
"""注册默认命令处理器"""
|
||||
self._handlers["save"] = SaveCommandHandler()
|
||||
self._handlers["compact"] = CompactCommandHandler()
|
||||
self._handlers["reload"] = ReloadCommandHandler()
|
||||
self._handlers["status"] = StatusCommandHandler()
|
||||
self._handlers["help"] = HelpCommandHandler()
|
||||
|
||||
# 子命令: /skills list/enable/disable/install
|
||||
self._subcommands["skills"] = {
|
||||
"list": SkillsListCommandHandler(),
|
||||
"enable": SkillsEnableCommandHandler(),
|
||||
"disable": SkillsDisableCommandHandler(),
|
||||
"install": SkillsInstallCommandHandler(),
|
||||
}
|
||||
|
||||
def is_command(self, query: str | None) -> bool:
|
||||
"""检查是否为命令
|
||||
|
||||
Args:
|
||||
query: 用户输入字符串
|
||||
|
||||
Returns:
|
||||
True 如果是系统命令
|
||||
"""
|
||||
if not isinstance(query, str) or not query.startswith("/"):
|
||||
return False
|
||||
|
||||
parts = query.strip().lstrip("/").split()
|
||||
if not parts:
|
||||
return False
|
||||
|
||||
cmd = parts[0].lower()
|
||||
|
||||
# 检查主命令
|
||||
if cmd in self.SYSTEM_COMMANDS:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def handle(self, agent: "EvoAgent", query: str) -> CommandResult:
|
||||
"""处理命令
|
||||
|
||||
Args:
|
||||
agent: EvoAgent实例
|
||||
query: 命令字符串
|
||||
|
||||
Returns:
|
||||
命令执行结果
|
||||
"""
|
||||
if not self.is_command(query):
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"未知命令: {query}\n使用 /help 查看可用命令。"
|
||||
)
|
||||
|
||||
# 解析命令和参数
|
||||
parts = query.strip().lstrip("/").split(maxsplit=1)
|
||||
cmd = parts[0].lower()
|
||||
args = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
logger.info(f"Processing command: {cmd}, args: {args}")
|
||||
|
||||
# 处理子命令 (e.g., /skills list)
|
||||
if cmd in self._subcommands:
|
||||
sub_parts = args.split(maxsplit=1)
|
||||
sub_cmd = sub_parts[0].lower() if sub_parts else ""
|
||||
sub_args = sub_parts[1] if len(sub_parts) > 1 else ""
|
||||
|
||||
handlers = self._subcommands[cmd]
|
||||
handler = handlers.get(sub_cmd)
|
||||
|
||||
if handler is None:
|
||||
available = ", ".join(handlers.keys())
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"未知子命令: {sub_cmd}\n可用子命令: {available}"
|
||||
)
|
||||
|
||||
ctx = CommandContext(agent, query, sub_args)
|
||||
return await handler.handle(ctx)
|
||||
|
||||
# 处理主命令
|
||||
handler = self._handlers.get(cmd)
|
||||
if handler is None:
|
||||
return CommandResult(
|
||||
success=False,
|
||||
message=f"命令未实现: {cmd}"
|
||||
)
|
||||
|
||||
ctx = CommandContext(agent, query, args)
|
||||
return await handler.handle(ctx)
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def create_command_dispatcher() -> AgentCommandDispatcher:
|
||||
"""创建命令分发器实例"""
|
||||
return AgentCommandDispatcher()
|
||||
452
backend/agents/base/evaluation_hook.py
Normal file
452
backend/agents/base/evaluation_hook.py
Normal file
@@ -0,0 +1,452 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Evaluation hooks system for skills.
|
||||
|
||||
Provides evaluation metric collection and storage for skill performance tracking.
|
||||
Based on the evaluation hooks design in SKILL_TEMPLATE.md.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(Enum):
|
||||
"""Types of evaluation metrics."""
|
||||
HIT_RATE = "hit_rate" # 信号命中率
|
||||
RISK_VIOLATION = "risk_violation" # 风控违例率
|
||||
POSITION_DEVIATION = "position_deviation" # 仓位偏离率
|
||||
PnL_ATTRIBUTION = "pnl_attribution" # P&L 归因一致性
|
||||
SIGNAL_CONSISTENCY = "signal_consistency" # 信号一致性
|
||||
DECISION_LATENCY = "decision_latency" # 决策延迟
|
||||
TOOL_USAGE = "tool_usage" # 工具使用率
|
||||
CUSTOM = "custom" # 自定义指标
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationMetric:
|
||||
"""A single evaluation metric."""
|
||||
name: str
|
||||
metric_type: MetricType
|
||||
value: float
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"metric_type": self.metric_type.value,
|
||||
"value": self.value,
|
||||
"timestamp": self.timestamp,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationResult:
|
||||
"""Evaluation result for a skill execution."""
|
||||
skill_name: str
|
||||
run_id: str
|
||||
agent_id: str
|
||||
metrics: List[EvaluationMetric] = field(default_factory=list)
|
||||
inputs: Dict[str, Any] = field(default_factory=dict)
|
||||
outputs: Dict[str, Any] = field(default_factory=dict)
|
||||
decision: Optional[str] = None
|
||||
success: bool = True
|
||||
error_message: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"skill_name": self.skill_name,
|
||||
"run_id": self.run_id,
|
||||
"agent_id": self.agent_id,
|
||||
"metrics": [m.to_dict() for m in self.metrics],
|
||||
"inputs": self.inputs,
|
||||
"outputs": self.outputs,
|
||||
"decision": self.decision,
|
||||
"success": self.success,
|
||||
"error_message": self.error_message,
|
||||
"started_at": self.started_at,
|
||||
"completed_at": self.completed_at,
|
||||
}
|
||||
|
||||
|
||||
class EvaluationHook:
|
||||
"""Hook for collecting skill evaluation metrics.
|
||||
|
||||
This hook collects and stores evaluation metrics after skill execution
|
||||
for later analysis and memory/reflection stages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_dir: Path,
|
||||
run_id: str,
|
||||
agent_id: str,
|
||||
):
|
||||
"""Initialize evaluation hook.
|
||||
|
||||
Args:
|
||||
storage_dir: Directory to store evaluation results
|
||||
run_id: Current run identifier
|
||||
agent_id: Current agent identifier
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.run_id = run_id
|
||||
self.agent_id = agent_id
|
||||
self._current_evaluation: Optional[EvaluationResult] = None
|
||||
|
||||
def start_evaluation(
|
||||
self,
|
||||
skill_name: str,
|
||||
inputs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Start a new evaluation session.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill being evaluated
|
||||
inputs: Input parameters for the skill
|
||||
"""
|
||||
self._current_evaluation = EvaluationResult(
|
||||
skill_name=skill_name,
|
||||
run_id=self.run_id,
|
||||
agent_id=self.agent_id,
|
||||
inputs=inputs,
|
||||
started_at=datetime.now().isoformat(),
|
||||
)
|
||||
logger.debug(f"Started evaluation for skill: {skill_name}")
|
||||
|
||||
def add_metric(
|
||||
self,
|
||||
name: str,
|
||||
metric_type: MetricType,
|
||||
value: float,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Add an evaluation metric.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
metric_type: Type of metric
|
||||
value: Metric value
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation session, ignoring metric")
|
||||
return
|
||||
|
||||
metric = EvaluationMetric(
|
||||
name=name,
|
||||
metric_type=metric_type,
|
||||
value=value,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self._current_evaluation.metrics.append(metric)
|
||||
logger.debug(f"Added metric: {name} = {value}")
|
||||
|
||||
def add_metrics(self, metrics: List[EvaluationMetric]) -> None:
|
||||
"""Add multiple evaluation metrics at once.
|
||||
|
||||
Args:
|
||||
metrics: List of metrics to add
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation session, ignoring metrics")
|
||||
return
|
||||
|
||||
self._current_evaluation.metrics.extend(metrics)
|
||||
|
||||
def record_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Record skill outputs.
|
||||
|
||||
Args:
|
||||
outputs: Output from skill execution
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation session, ignoring outputs")
|
||||
return
|
||||
|
||||
self._current_evaluation.outputs = outputs
|
||||
|
||||
def record_decision(self, decision: str) -> None:
|
||||
"""Record the final decision.
|
||||
|
||||
Args:
|
||||
decision: Final decision made by the skill
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation session, ignoring decision")
|
||||
return
|
||||
|
||||
self._current_evaluation.decision = decision
|
||||
|
||||
def complete_evaluation(
|
||||
self,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
) -> Optional[EvaluationResult]:
|
||||
"""Complete the evaluation session and persist results.
|
||||
|
||||
Args:
|
||||
success: Whether the skill execution was successful
|
||||
error_message: Error message if failed
|
||||
|
||||
Returns:
|
||||
The completed evaluation result, or None if no active evaluation
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation to complete")
|
||||
return None
|
||||
|
||||
self._current_evaluation.success = success
|
||||
self._current_evaluation.error_message = error_message
|
||||
self._current_evaluation.completed_at = datetime.now().isoformat()
|
||||
|
||||
# Persist to storage
|
||||
result = self._persist_evaluation(self._current_evaluation)
|
||||
|
||||
self._current_evaluation = None
|
||||
logger.debug(f"Completed evaluation for skill: {result.skill_name}")
|
||||
|
||||
return result
|
||||
|
||||
def _persist_evaluation(self, evaluation: EvaluationResult) -> EvaluationResult:
|
||||
"""Persist evaluation result to storage.
|
||||
|
||||
Args:
|
||||
evaluation: Evaluation result to persist
|
||||
|
||||
Returns:
|
||||
The persisted evaluation
|
||||
"""
|
||||
# Create run-specific directory
|
||||
run_dir = self.storage_dir / self.run_id
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create agent-specific subdirectory
|
||||
agent_dir = run_dir / self.agent_id
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{evaluation.skill_name}_{timestamp}.json"
|
||||
filepath = agent_dir / filename
|
||||
|
||||
# Write evaluation result
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(evaluation.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"Persisted evaluation to: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist evaluation: {e}")
|
||||
|
||||
return evaluation
|
||||
|
||||
def cancel_evaluation(self) -> None:
|
||||
"""Cancel the current evaluation session without saving."""
|
||||
if self._current_evaluation is not None:
|
||||
logger.debug(f"Cancelled evaluation for: {self._current_evaluation.skill_name}")
|
||||
self._current_evaluation = None
|
||||
|
||||
|
||||
class EvaluationCollector:
|
||||
"""Collector for aggregating evaluation metrics across runs.
|
||||
|
||||
Provides methods to query and analyze evaluation results.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_dir: Path):
|
||||
"""Initialize evaluation collector.
|
||||
|
||||
Args:
|
||||
storage_dir: Root directory containing evaluation results
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
|
||||
def get_run_evaluations(
|
||||
self,
|
||||
run_id: str,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> List[EvaluationResult]:
|
||||
"""Get all evaluations for a run.
|
||||
|
||||
Args:
|
||||
run_id: Run identifier
|
||||
agent_id: Optional agent identifier to filter by
|
||||
|
||||
Returns:
|
||||
List of evaluation results
|
||||
"""
|
||||
run_dir = self.storage_dir / run_id
|
||||
if not run_dir.exists():
|
||||
return []
|
||||
|
||||
evaluations = []
|
||||
|
||||
agent_dirs = [run_dir / agent_id] if agent_id else run_dir.iterdir()
|
||||
|
||||
for agent_dir in agent_dirs:
|
||||
if not agent_dir.is_dir():
|
||||
continue
|
||||
|
||||
for eval_file in agent_dir.glob("*.json"):
|
||||
try:
|
||||
with open(eval_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
evaluations.append(self._parse_evaluation(data))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load evaluation {eval_file}: {e}")
|
||||
|
||||
return evaluations
|
||||
|
||||
def get_skill_metrics(
|
||||
self,
|
||||
skill_name: str,
|
||||
run_ids: Optional[List[str]] = None,
|
||||
) -> List[EvaluationMetric]:
|
||||
"""Get all metrics for a specific skill.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill
|
||||
run_ids: Optional list of run IDs to filter by
|
||||
|
||||
Returns:
|
||||
List of metrics for the skill
|
||||
"""
|
||||
metrics = []
|
||||
|
||||
if run_ids is None:
|
||||
run_ids = [d.name for d in self.storage_dir.iterdir() if d.is_dir()]
|
||||
|
||||
for run_id in run_ids:
|
||||
evaluations = self.get_run_evaluations(run_id)
|
||||
for eval_result in evaluations:
|
||||
if eval_result.skill_name == skill_name:
|
||||
metrics.extend(eval_result.metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
def calculate_skill_stats(
|
||||
self,
|
||||
skill_name: str,
|
||||
metric_type: MetricType,
|
||||
run_ids: Optional[List[str]] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""Calculate statistics for a specific metric type.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill
|
||||
metric_type: Type of metric to calculate
|
||||
run_ids: Optional list of run IDs to filter by
|
||||
|
||||
Returns:
|
||||
Dictionary with min, max, avg, count statistics
|
||||
"""
|
||||
metrics = self.get_skill_metrics(skill_name, run_ids)
|
||||
filtered = [m for m in metrics if m.metric_type == metric_type]
|
||||
|
||||
if not filtered:
|
||||
return {"count": 0}
|
||||
|
||||
values = [m.value for m in filtered]
|
||||
return {
|
||||
"count": len(values),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"avg": sum(values) / len(values),
|
||||
}
|
||||
|
||||
def _parse_evaluation(self, data: Dict[str, Any]) -> EvaluationResult:
|
||||
"""Parse evaluation data into EvaluationResult.
|
||||
|
||||
Args:
|
||||
data: Raw evaluation data
|
||||
|
||||
Returns:
|
||||
Parsed EvaluationResult
|
||||
"""
|
||||
metrics = []
|
||||
for m in data.get("metrics", []):
|
||||
metrics.append(EvaluationMetric(
|
||||
name=m["name"],
|
||||
metric_type=MetricType(m["metric_type"]),
|
||||
value=m["value"],
|
||||
timestamp=m.get("timestamp", ""),
|
||||
metadata=m.get("metadata", {}),
|
||||
))
|
||||
|
||||
return EvaluationResult(
|
||||
skill_name=data["skill_name"],
|
||||
run_id=data["run_id"],
|
||||
agent_id=data["agent_id"],
|
||||
metrics=metrics,
|
||||
inputs=data.get("inputs", {}),
|
||||
outputs=data.get("outputs", {}),
|
||||
decision=data.get("decision"),
|
||||
success=data.get("success", True),
|
||||
error_message=data.get("error_message"),
|
||||
started_at=data.get("started_at"),
|
||||
completed_at=data.get("completed_at"),
|
||||
)
|
||||
|
||||
|
||||
def parse_evaluation_hooks(skill_dir: Path) -> Dict[str, Any]:
|
||||
"""Parse evaluation hooks from SKILL.md.
|
||||
|
||||
Extracts the Optional: Evaluation hooks section from skill documentation.
|
||||
|
||||
Args:
|
||||
skill_dir: Skill directory path
|
||||
|
||||
Returns:
|
||||
Dictionary containing evaluation hook definitions
|
||||
"""
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
|
||||
# Extract evaluation hooks section
|
||||
if "## Optional: Evaluation hooks" in content:
|
||||
start = content.find("## Optional: Evaluation hooks")
|
||||
# Find the next ## section or end of file
|
||||
next_section = content.find("\n## ", start + 1)
|
||||
if next_section == -1:
|
||||
eval_section = content[start:]
|
||||
else:
|
||||
eval_section = content[start:next_section]
|
||||
|
||||
# Parse metrics from the section
|
||||
metrics = []
|
||||
for metric_type in MetricType:
|
||||
if metric_type.value.replace("_", " ") in eval_section.lower():
|
||||
metrics.append(metric_type.value)
|
||||
|
||||
return {
|
||||
"supported_metrics": metrics,
|
||||
"section_content": eval_section.strip(),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse evaluation hooks: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MetricType",
|
||||
"EvaluationMetric",
|
||||
"EvaluationResult",
|
||||
"EvaluationHook",
|
||||
"EvaluationCollector",
|
||||
"parse_evaluation_hooks",
|
||||
]
|
||||
510
backend/agents/base/evo_agent.py
Normal file
510
backend/agents/base/evo_agent.py
Normal file
@@ -0,0 +1,510 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""EvoAgent - Core agent implementation for 大时代.
|
||||
|
||||
This module provides the main EvoAgent class built on AgentScope's ReActAgent,
|
||||
with integrated tools, skills, and memory management based on CoPaw design.
|
||||
|
||||
Key features:
|
||||
- Workspace-driven configuration from Markdown files
|
||||
- Dynamic skill loading from skills/active directories
|
||||
- Tool-guard security interception
|
||||
- Hook system for extensibility
|
||||
- Runtime skill and prompt reloading
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory
|
||||
from agentscope.message import Msg
|
||||
from agentscope.tool import Toolkit
|
||||
|
||||
from .tool_guard import ToolGuardMixin
|
||||
from .hooks import (
|
||||
HookManager,
|
||||
BootstrapHook,
|
||||
MemoryCompactionHook,
|
||||
WorkspaceWatchHook,
|
||||
HOOK_PRE_REASONING,
|
||||
)
|
||||
from ..prompts.builder import (
|
||||
PromptBuilder,
|
||||
build_system_prompt_from_workspace,
|
||||
)
|
||||
from ..agent_workspace import load_agent_workspace_config
|
||||
from ..skills_manager import SkillsManager
|
||||
|
||||
# Team infrastructure imports (graceful import - may not exist yet)
|
||||
try:
|
||||
from backend.agents.team.messenger import AgentMessenger
|
||||
from backend.agents.team.task_delegator import TaskDelegator
|
||||
TEAM_INFRA_AVAILABLE = True
|
||||
except ImportError:
|
||||
TEAM_INFRA_AVAILABLE = False
|
||||
AgentMessenger = None
|
||||
TaskDelegator = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentscope.formatter import FormatterBase
|
||||
from agentscope.model import ModelWrapperBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EvoAgent(ToolGuardMixin, ReActAgent):
|
||||
"""EvoAgent with integrated tools, skills, and memory management.
|
||||
|
||||
This agent extends ReActAgent with:
|
||||
- Workspace-driven configuration from AGENTS.md/SOUL.md/PROFILE.md/etc.
|
||||
- Dynamic skill loading from skills/active directories
|
||||
- Tool-guard security interception (via ToolGuardMixin)
|
||||
- Hook system for extensibility (bootstrap, memory compaction)
|
||||
- Runtime skill and prompt reloading
|
||||
|
||||
MRO note
|
||||
~~~~~~~~
|
||||
``ToolGuardMixin`` overrides ``_acting`` and ``_reasoning`` via
|
||||
Python's MRO: EvoAgent → ToolGuardMixin → ReActAgent.
|
||||
|
||||
Example:
|
||||
agent = EvoAgent(
|
||||
agent_id="fundamentals_analyst",
|
||||
config_name="smoke_fullstack",
|
||||
workspace_dir=Path("runs/smoke_fullstack/agents/fundamentals_analyst"),
|
||||
model=model_instance,
|
||||
formatter=formatter_instance,
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
config_name: str,
|
||||
workspace_dir: Path,
|
||||
model: "ModelWrapperBase",
|
||||
formatter: "FormatterBase",
|
||||
skills_manager: Optional[SkillsManager] = None,
|
||||
sys_prompt: Optional[str] = None,
|
||||
max_iters: int = 10,
|
||||
memory: Optional[Any] = None,
|
||||
enable_tool_guard: bool = True,
|
||||
enable_bootstrap_hook: bool = True,
|
||||
enable_memory_compaction: bool = False,
|
||||
memory_manager: Optional[Any] = None,
|
||||
memory_compact_threshold: Optional[int] = None,
|
||||
env_context: Optional[str] = None,
|
||||
prompt_files: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize EvoAgent.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for this agent
|
||||
config_name: Run configuration name (e.g., "smoke_fullstack")
|
||||
workspace_dir: Agent workspace directory containing markdown files
|
||||
model: LLM model instance
|
||||
formatter: Message formatter instance
|
||||
skills_manager: Optional SkillsManager instance
|
||||
sys_prompt: Optional override for system prompt
|
||||
max_iters: Maximum reasoning-acting iterations
|
||||
memory: Optional memory instance (defaults to InMemoryMemory)
|
||||
enable_tool_guard: Enable tool-guard security interception
|
||||
enable_bootstrap_hook: Enable bootstrap guidance on first interaction
|
||||
enable_memory_compaction: Enable automatic memory compaction
|
||||
memory_manager: Optional memory manager for compaction
|
||||
memory_compact_threshold: Token threshold for memory compaction
|
||||
env_context: Optional environment context to prepend to system prompt
|
||||
prompt_files: List of markdown files to load (defaults to standard set)
|
||||
"""
|
||||
self.agent_id = agent_id
|
||||
self.config_name = config_name
|
||||
self.workspace_dir = Path(workspace_dir)
|
||||
self._skills_manager = skills_manager or SkillsManager()
|
||||
self._env_context = env_context
|
||||
self._prompt_files = prompt_files
|
||||
|
||||
# Initialize tool guard
|
||||
if enable_tool_guard:
|
||||
self._init_tool_guard()
|
||||
|
||||
# Load agent configuration from workspace
|
||||
self._agent_config = self._load_agent_config()
|
||||
|
||||
# Build or use provided system prompt
|
||||
if sys_prompt is not None:
|
||||
self._sys_prompt = sys_prompt
|
||||
else:
|
||||
self._sys_prompt = self._build_system_prompt()
|
||||
|
||||
# Create toolkit with skills
|
||||
toolkit = self._create_toolkit()
|
||||
|
||||
# Initialize hook manager
|
||||
self._hook_manager = HookManager()
|
||||
|
||||
# Initialize parent ReActAgent
|
||||
super().__init__(
|
||||
name=agent_id,
|
||||
model=model,
|
||||
sys_prompt=self._sys_prompt,
|
||||
toolkit=toolkit,
|
||||
memory=memory or InMemoryMemory(),
|
||||
formatter=formatter,
|
||||
max_iters=max_iters,
|
||||
)
|
||||
|
||||
# Register hooks
|
||||
self._register_hooks(
|
||||
enable_bootstrap=enable_bootstrap_hook,
|
||||
enable_memory_compaction=enable_memory_compaction,
|
||||
memory_manager=memory_manager,
|
||||
memory_compact_threshold=memory_compact_threshold,
|
||||
)
|
||||
|
||||
# Initialize team infrastructure if available
|
||||
self._messenger: Optional["AgentMessenger"] = None
|
||||
self._task_delegator: Optional["TaskDelegator"] = None
|
||||
if TEAM_INFRA_AVAILABLE:
|
||||
self._init_team_infrastructure()
|
||||
|
||||
logger.info(
|
||||
"EvoAgent initialized: %s (workspace: %s)",
|
||||
agent_id,
|
||||
workspace_dir,
|
||||
)
|
||||
|
||||
def _load_agent_config(self) -> Dict[str, Any]:
|
||||
"""Load agent configuration from workspace.
|
||||
|
||||
Returns:
|
||||
Agent configuration dictionary
|
||||
"""
|
||||
config_path = self.workspace_dir / "agent.yaml"
|
||||
if config_path.exists():
|
||||
loaded = load_agent_workspace_config(config_path)
|
||||
return dict(loaded.values)
|
||||
return {}
|
||||
|
||||
def _build_system_prompt(self) -> str:
|
||||
"""Build system prompt from workspace markdown files.
|
||||
|
||||
Uses PromptBuilder to load and combine AGENTS.md, SOUL.md,
|
||||
PROFILE.md, and other configured files.
|
||||
|
||||
Returns:
|
||||
Complete system prompt string
|
||||
"""
|
||||
prompt = build_system_prompt_from_workspace(
|
||||
workspace_dir=self.workspace_dir,
|
||||
enabled_files=self._prompt_files,
|
||||
agent_id=self.agent_id,
|
||||
extra_context=self._env_context,
|
||||
)
|
||||
return prompt
|
||||
|
||||
def _create_toolkit(self) -> Toolkit:
|
||||
"""Create and populate toolkit with agent skills.
|
||||
|
||||
Loads skills from the agent's active skills directory and
|
||||
registers them with the toolkit.
|
||||
|
||||
Returns:
|
||||
Configured Toolkit instance
|
||||
"""
|
||||
toolkit = Toolkit(
|
||||
agent_skill_instruction=(
|
||||
"<system-info>You have access to specialized skills. "
|
||||
"Each skill lives in a directory and is described by SKILL.md. "
|
||||
"Follow the skill instructions when they are relevant to the current task."
|
||||
"</system-info>"
|
||||
),
|
||||
agent_skill_template="- {name} (dir: {dir}): {description}",
|
||||
)
|
||||
|
||||
# Register skills from active directory
|
||||
active_skills_dir = self._skills_manager.get_agent_active_root(
|
||||
self.config_name,
|
||||
self.agent_id,
|
||||
)
|
||||
|
||||
if active_skills_dir.exists():
|
||||
for skill_dir in sorted(active_skills_dir.iterdir()):
|
||||
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||
try:
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
logger.debug("Registered skill: %s", skill_dir.name)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to register skill '%s': %s",
|
||||
skill_dir.name,
|
||||
e,
|
||||
)
|
||||
|
||||
return toolkit
|
||||
|
||||
def _register_hooks(
|
||||
self,
|
||||
enable_bootstrap: bool,
|
||||
enable_memory_compaction: bool,
|
||||
memory_manager: Optional[Any],
|
||||
memory_compact_threshold: Optional[int],
|
||||
) -> None:
|
||||
"""Register agent hooks.
|
||||
|
||||
Args:
|
||||
enable_bootstrap: Enable bootstrap hook
|
||||
enable_memory_compaction: Enable memory compaction hook
|
||||
memory_manager: Memory manager instance
|
||||
memory_compact_threshold: Token threshold for compaction
|
||||
"""
|
||||
# Bootstrap hook - checks BOOTSTRAP.md on first interaction
|
||||
if enable_bootstrap:
|
||||
bootstrap_hook = BootstrapHook(
|
||||
workspace_dir=self.workspace_dir,
|
||||
language="zh",
|
||||
)
|
||||
self._hook_manager.register(
|
||||
hook_type=HOOK_PRE_REASONING,
|
||||
hook_name="bootstrap",
|
||||
hook=bootstrap_hook,
|
||||
)
|
||||
logger.debug("Registered bootstrap hook")
|
||||
|
||||
# Memory compaction hook
|
||||
if enable_memory_compaction and memory_manager is not None:
|
||||
compaction_hook = MemoryCompactionHook(
|
||||
memory_manager=memory_manager,
|
||||
memory_compact_threshold=memory_compact_threshold,
|
||||
)
|
||||
self._hook_manager.register(
|
||||
hook_type=HOOK_PRE_REASONING,
|
||||
hook_name="memory_compaction",
|
||||
hook=compaction_hook,
|
||||
)
|
||||
logger.debug("Registered memory compaction hook")
|
||||
|
||||
# Workspace watch hook - auto-reload markdown files on change
|
||||
workspace_watch_hook = WorkspaceWatchHook(
|
||||
workspace_dir=self.workspace_dir,
|
||||
)
|
||||
self._hook_manager.register(
|
||||
hook_type=HOOK_PRE_REASONING,
|
||||
hook_name="workspace_watch",
|
||||
hook=workspace_watch_hook,
|
||||
)
|
||||
logger.debug("Registered workspace watch hook")
|
||||
|
||||
async def _reasoning(self, **kwargs) -> Msg:
|
||||
"""Override reasoning to execute pre-reasoning hooks.
|
||||
|
||||
Args:
|
||||
**kwargs: Arguments for reasoning
|
||||
|
||||
Returns:
|
||||
Response message
|
||||
"""
|
||||
# Execute pre-reasoning hooks
|
||||
kwargs = await self._hook_manager.execute(
|
||||
hook_type=HOOK_PRE_REASONING,
|
||||
agent=self,
|
||||
kwargs=kwargs,
|
||||
)
|
||||
|
||||
# Call parent (which may be ToolGuardMixin's _reasoning)
|
||||
return await super()._reasoning(**kwargs)
|
||||
|
||||
def reload_skills(self, active_skill_dirs: Optional[List[Path]] = None) -> None:
|
||||
"""Reload skills at runtime.
|
||||
|
||||
Rebuilds the toolkit with current skills from the active directory.
|
||||
|
||||
Args:
|
||||
active_skill_dirs: Optional list of specific skill directories to load
|
||||
"""
|
||||
logger.info("Reloading skills for agent: %s", self.agent_id)
|
||||
|
||||
# Create new toolkit
|
||||
new_toolkit = Toolkit(
|
||||
agent_skill_instruction=(
|
||||
"<system-info>You have access to specialized skills. "
|
||||
"Each skill lives in a directory and is described by SKILL.md. "
|
||||
"Follow the skill instructions when they are relevant to the current task."
|
||||
"</system-info>"
|
||||
),
|
||||
agent_skill_template="- {name} (dir: {dir}): {description}",
|
||||
)
|
||||
|
||||
# Register skills
|
||||
if active_skill_dirs is None:
|
||||
active_skills_dir = self._skills_manager.get_agent_active_root(
|
||||
self.config_name,
|
||||
self.agent_id,
|
||||
)
|
||||
if active_skills_dir.exists():
|
||||
active_skill_dirs = [
|
||||
d for d in active_skills_dir.iterdir()
|
||||
if d.is_dir() and (d / "SKILL.md").exists()
|
||||
]
|
||||
else:
|
||||
active_skill_dirs = []
|
||||
|
||||
for skill_dir in active_skill_dirs:
|
||||
if skill_dir.exists() and (skill_dir / "SKILL.md").exists():
|
||||
try:
|
||||
new_toolkit.register_agent_skill(str(skill_dir))
|
||||
logger.debug("Reloaded skill: %s", skill_dir.name)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to reload skill '%s': %s",
|
||||
skill_dir.name,
|
||||
e,
|
||||
)
|
||||
|
||||
# Replace toolkit
|
||||
self.toolkit = new_toolkit
|
||||
logger.info("Skills reloaded for agent: %s", self.agent_id)
|
||||
|
||||
def rebuild_sys_prompt(self) -> None:
|
||||
"""Rebuild and replace the system prompt at runtime.
|
||||
|
||||
Useful after updating AGENTS.md, SOUL.md, PROFILE.md, etc.
|
||||
to ensure the prompt reflects the latest configuration.
|
||||
|
||||
Updates both self._sys_prompt and the first system-role
|
||||
message stored in self.memory.content.
|
||||
"""
|
||||
logger.info("Rebuilding system prompt for agent: %s", self.agent_id)
|
||||
|
||||
# Reload agent config in case it changed
|
||||
self._agent_config = self._load_agent_config()
|
||||
|
||||
# Rebuild prompt
|
||||
self._sys_prompt = self._build_system_prompt()
|
||||
|
||||
# Update memory if system message exists
|
||||
if hasattr(self, "memory") and self.memory.content:
|
||||
for msg, _marks in self.memory.content:
|
||||
if getattr(msg, "role", None) == "system":
|
||||
msg.content = self._sys_prompt
|
||||
logger.debug("Updated system message in memory")
|
||||
break
|
||||
|
||||
logger.info("System prompt rebuilt for agent: %s", self.agent_id)
|
||||
|
||||
async def reply(
|
||||
self,
|
||||
msg: Msg | List[Msg] | None = None,
|
||||
structured_model: Optional[Type[Any]] = None,
|
||||
) -> Msg:
|
||||
"""Process a message and return a response.
|
||||
|
||||
Args:
|
||||
msg: Input message(s) from user
|
||||
structured_model: Optional pydantic model for structured output
|
||||
|
||||
Returns:
|
||||
Response message
|
||||
"""
|
||||
# Handle list of messages
|
||||
if isinstance(msg, list):
|
||||
# Process each message in sequence
|
||||
for m in msg[:-1]:
|
||||
await self.memory.add(m)
|
||||
msg = msg[-1] if msg else None
|
||||
|
||||
return await super().reply(msg=msg, structured_model=structured_model)
|
||||
|
||||
def get_agent_info(self) -> Dict[str, Any]:
|
||||
"""Get agent information.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent metadata
|
||||
"""
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"config_name": self.config_name,
|
||||
"workspace_dir": str(self.workspace_dir),
|
||||
"skills_count": len([
|
||||
s for s in self._skills_manager.list_active_skill_metadata(
|
||||
self.config_name,
|
||||
self.agent_id,
|
||||
)
|
||||
]),
|
||||
"registered_hooks": self._hook_manager.list_hooks(),
|
||||
"team_infra_available": TEAM_INFRA_AVAILABLE,
|
||||
}
|
||||
|
||||
def _init_team_infrastructure(self) -> None:
|
||||
"""Initialize team infrastructure components (messenger and task delegator).
|
||||
|
||||
This method initializes the AgentMessenger for inter-agent communication
|
||||
and the TaskDelegator for subagent delegation.
|
||||
"""
|
||||
if not TEAM_INFRA_AVAILABLE:
|
||||
return
|
||||
|
||||
try:
|
||||
self._messenger = AgentMessenger(agent_id=self.agent_id)
|
||||
self._task_delegator = TaskDelegator(agent=self)
|
||||
logger.debug(
|
||||
"Team infrastructure initialized for agent: %s",
|
||||
self.agent_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to initialize team infrastructure for %s: %s",
|
||||
self.agent_id,
|
||||
e,
|
||||
)
|
||||
self._messenger = None
|
||||
self._task_delegator = None
|
||||
|
||||
@property
|
||||
def messenger(self) -> Optional["AgentMessenger"]:
|
||||
"""Get the agent's messenger for inter-agent communication.
|
||||
|
||||
Returns:
|
||||
AgentMessenger instance if available, None otherwise
|
||||
"""
|
||||
return self._messenger
|
||||
|
||||
async def delegate_task(
|
||||
self,
|
||||
task_type: str,
|
||||
task_data: Dict[str, Any],
|
||||
target_agent: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Delegate a task to a subagent using the TaskDelegator.
|
||||
|
||||
Args:
|
||||
task_type: Type of task to delegate
|
||||
task_data: Data/payload for the task
|
||||
target_agent: Optional specific agent ID to delegate to
|
||||
|
||||
Returns:
|
||||
Dict containing the delegation result
|
||||
"""
|
||||
if not TEAM_INFRA_AVAILABLE or self._task_delegator is None:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Team infrastructure not available",
|
||||
}
|
||||
|
||||
try:
|
||||
return await self._task_delegator.delegate_task(
|
||||
task_type=task_type,
|
||||
task_data=task_data,
|
||||
target_agent=target_agent,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Task delegation failed for %s: %s",
|
||||
self.agent_id,
|
||||
e,
|
||||
)
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
__all__ = ["EvoAgent"]
|
||||
613
backend/agents/base/hooks.py
Normal file
613
backend/agents/base/hooks.py
Normal file
@@ -0,0 +1,613 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Hook system for EvoAgent.
|
||||
|
||||
Provides pre_reasoning and post_acting hooks with built-in implementations:
|
||||
- BootstrapHook: First-time setup guidance
|
||||
- MemoryCompactionHook: Automatic memory compression
|
||||
|
||||
Based on CoPaw's hooks design.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentscope.agent import ReActAgent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hook types
|
||||
HookType = str
|
||||
HOOK_PRE_REASONING: HookType = "pre_reasoning"
|
||||
HOOK_POST_ACTING: HookType = "post_acting"
|
||||
|
||||
|
||||
class Hook(ABC):
|
||||
"""Abstract base class for agent hooks."""
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Execute the hook.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments to the method being hooked
|
||||
|
||||
Returns:
|
||||
Modified kwargs or None to use original
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HookManager:
|
||||
"""Manages agent hooks.
|
||||
|
||||
Provides registration and execution of hooks for different
|
||||
lifecycle events in the agent's operation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._hooks: Dict[HookType, List[tuple[str, Hook]]] = {
|
||||
HOOK_PRE_REASONING: [],
|
||||
HOOK_POST_ACTING: [],
|
||||
}
|
||||
|
||||
def register(
|
||||
self,
|
||||
hook_type: HookType,
|
||||
hook_name: str,
|
||||
hook: Hook | Callable,
|
||||
) -> None:
|
||||
"""Register a hook.
|
||||
|
||||
Args:
|
||||
hook_type: Type of hook (pre_reasoning, post_acting)
|
||||
hook_name: Unique name for this hook
|
||||
hook: Hook instance or callable
|
||||
"""
|
||||
# Remove existing hook with same name
|
||||
self._hooks[hook_type] = [
|
||||
(name, h) for name, h in self._hooks[hook_type] if name != hook_name
|
||||
]
|
||||
self._hooks[hook_type].append((hook_name, hook))
|
||||
logger.debug("Registered hook '%s' for type '%s'", hook_name, hook_type)
|
||||
|
||||
def unregister(self, hook_type: HookType, hook_name: str) -> bool:
|
||||
"""Unregister a hook.
|
||||
|
||||
Args:
|
||||
hook_type: Type of hook
|
||||
hook_name: Name of the hook to remove
|
||||
|
||||
Returns:
|
||||
True if hook was found and removed
|
||||
"""
|
||||
original_len = len(self._hooks[hook_type])
|
||||
self._hooks[hook_type] = [
|
||||
(name, h) for name, h in self._hooks[hook_type] if name != hook_name
|
||||
]
|
||||
removed = len(self._hooks[hook_type]) < original_len
|
||||
if removed:
|
||||
logger.debug("Unregistered hook '%s' from type '%s'", hook_name, hook_type)
|
||||
return removed
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
hook_type: HookType,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute all hooks of a given type.
|
||||
|
||||
Args:
|
||||
hook_type: Type of hooks to execute
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments
|
||||
|
||||
Returns:
|
||||
Potentially modified kwargs
|
||||
"""
|
||||
for name, hook in self._hooks[hook_type]:
|
||||
try:
|
||||
result = await hook(agent, kwargs)
|
||||
if result is not None:
|
||||
kwargs = result
|
||||
except Exception as e:
|
||||
logger.error("Hook '%s' failed: %s", name, e, exc_info=True)
|
||||
|
||||
return kwargs
|
||||
|
||||
def list_hooks(self, hook_type: Optional[HookType] = None) -> List[str]:
|
||||
"""List registered hook names.
|
||||
|
||||
Args:
|
||||
hook_type: Optional type to filter by
|
||||
|
||||
Returns:
|
||||
List of hook names
|
||||
"""
|
||||
if hook_type:
|
||||
return [name for name, _ in self._hooks.get(hook_type, [])]
|
||||
|
||||
names = []
|
||||
for hooks in self._hooks.values():
|
||||
names.extend([name for name, _ in hooks])
|
||||
return names
|
||||
|
||||
|
||||
class BootstrapHook(Hook):
|
||||
"""Hook for bootstrap guidance on first user interaction.
|
||||
|
||||
This hook looks for a BOOTSTRAP.md file in the working directory
|
||||
and if found, prepends guidance to the first user message to help
|
||||
establish the agent's identity and user preferences.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
language: str = "zh",
|
||||
):
|
||||
"""Initialize bootstrap hook.
|
||||
|
||||
Args:
|
||||
workspace_dir: Working directory containing BOOTSTRAP.md
|
||||
language: Language code for bootstrap guidance (en/zh)
|
||||
"""
|
||||
self.workspace_dir = Path(workspace_dir)
|
||||
self.language = language
|
||||
self._completed_flag = self.workspace_dir / ".bootstrap_completed"
|
||||
|
||||
def _is_first_user_interaction(self, agent: "ReActAgent") -> bool:
|
||||
"""Check if this is the first user interaction.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
|
||||
Returns:
|
||||
True if first user interaction
|
||||
"""
|
||||
if not hasattr(agent, "memory") or not agent.memory.content:
|
||||
return True
|
||||
|
||||
# Count user messages (excluding system)
|
||||
user_count = sum(
|
||||
1 for msg, _ in agent.memory.content if msg.role == "user"
|
||||
)
|
||||
return user_count <= 1
|
||||
|
||||
def _build_bootstrap_guidance(self) -> str:
|
||||
"""Build bootstrap guidance message.
|
||||
|
||||
Returns:
|
||||
Formatted bootstrap guidance
|
||||
"""
|
||||
if self.language == "zh":
|
||||
return (
|
||||
"# 引导模式\n"
|
||||
"\n"
|
||||
"工作目录中存在 `BOOTSTRAP.md` — 首次设置。\n"
|
||||
"\n"
|
||||
"1. 阅读 BOOTSTRAP.md,友好地表示初次见面,"
|
||||
"引导用户完成设置。\n"
|
||||
"2. 按照 BOOTSTRAP.md 的指示,"
|
||||
"帮助用户定义你的身份和偏好。\n"
|
||||
"3. 按指南创建/更新必要文件"
|
||||
"(PROFILE.md、MEMORY.md 等)。\n"
|
||||
"4. 完成后删除 BOOTSTRAP.md。\n"
|
||||
"\n"
|
||||
"如果用户希望跳过,直接回答下面的问题即可。\n"
|
||||
"\n"
|
||||
"---\n"
|
||||
"\n"
|
||||
)
|
||||
|
||||
return (
|
||||
"# BOOTSTRAP MODE\n"
|
||||
"\n"
|
||||
"`BOOTSTRAP.md` exists — first-time setup.\n"
|
||||
"\n"
|
||||
"1. Read BOOTSTRAP.md, greet the user, "
|
||||
"and guide them through setup.\n"
|
||||
"2. Follow BOOTSTRAP.md instructions "
|
||||
"to define identity and preferences.\n"
|
||||
"3. Create/update files "
|
||||
"(PROFILE.md, MEMORY.md, etc.) as described.\n"
|
||||
"4. Delete BOOTSTRAP.md when done.\n"
|
||||
"\n"
|
||||
"If the user wants to skip, answer their "
|
||||
"question directly instead.\n"
|
||||
"\n"
|
||||
"---\n"
|
||||
"\n"
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Check and load BOOTSTRAP.md on first user interaction.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments to the _reasoning method
|
||||
|
||||
Returns:
|
||||
None (hook doesn't modify kwargs)
|
||||
"""
|
||||
try:
|
||||
bootstrap_path = self.workspace_dir / "BOOTSTRAP.md"
|
||||
|
||||
# Check if bootstrap has already been triggered
|
||||
if self._completed_flag.exists():
|
||||
return None
|
||||
|
||||
if not bootstrap_path.exists():
|
||||
return None
|
||||
|
||||
if not self._is_first_user_interaction(agent):
|
||||
return None
|
||||
|
||||
bootstrap_guidance = self._build_bootstrap_guidance()
|
||||
|
||||
logger.debug("Found BOOTSTRAP.md [%s], prepending guidance", self.language)
|
||||
|
||||
# Prepend to first user message in memory
|
||||
if hasattr(agent, "memory") and agent.memory.content:
|
||||
system_count = sum(
|
||||
1 for msg, _ in agent.memory.content if msg.role == "system"
|
||||
)
|
||||
for msg, _ in agent.memory.content[system_count:]:
|
||||
if msg.role == "user":
|
||||
# Prepend guidance to message content
|
||||
original_content = msg.content
|
||||
msg.content = bootstrap_guidance + original_content
|
||||
break
|
||||
|
||||
logger.debug("Bootstrap guidance prepended to first user message")
|
||||
|
||||
# Create completion flag to prevent repeated triggering
|
||||
self._completed_flag.touch()
|
||||
logger.debug("Created bootstrap completion flag")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to process bootstrap: %s", e, exc_info=True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class WorkspaceWatchHook(Hook):
|
||||
"""Hook for auto-reloading workspace markdown files on change.
|
||||
|
||||
Monitors SOUL.md, AGENTS.md, PROFILE.md, etc. and triggers
|
||||
a prompt rebuild when any of them change. Based on CoPaw's
|
||||
AgentConfigWatcher approach but for markdown files.
|
||||
"""
|
||||
|
||||
# Files to monitor (same as PromptBuilder.DEFAULT_FILES)
|
||||
WATCHED_FILES = frozenset([
|
||||
"SOUL.md", "AGENTS.md", "PROFILE.md",
|
||||
"POLICY.md", "MEMORY.md",
|
||||
"BOOTSTRAP.md",
|
||||
])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
poll_interval: float = 2.0,
|
||||
):
|
||||
"""Initialize workspace watch hook.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory to monitor
|
||||
poll_interval: How often to check for changes (seconds)
|
||||
"""
|
||||
self.workspace_dir = Path(workspace_dir)
|
||||
self.poll_interval = poll_interval
|
||||
self._last_mtimes: dict[str, float] = {}
|
||||
self._initialized = False
|
||||
|
||||
def _scan_mtimes(self) -> dict[str, float]:
|
||||
"""Scan watched files and return their current mtimes."""
|
||||
mtimes = {}
|
||||
for name in self.WATCHED_FILES:
|
||||
path = self.workspace_dir / name
|
||||
if path.exists():
|
||||
mtimes[name] = path.stat().st_mtime
|
||||
return mtimes
|
||||
|
||||
def _has_changes(self) -> bool:
|
||||
"""Check if any watched file has changed since last check."""
|
||||
current = self._scan_mtimes()
|
||||
|
||||
if not self._initialized:
|
||||
self._last_mtimes = current
|
||||
self._initialized = True
|
||||
return False
|
||||
|
||||
# Check for new, modified, or deleted files
|
||||
if set(current.keys()) != set(self._last_mtimes.keys()):
|
||||
self._last_mtimes = current
|
||||
return True
|
||||
|
||||
for name, mtime in current.items():
|
||||
if mtime != self._last_mtimes.get(name):
|
||||
self._last_mtimes = current
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Check for file changes and rebuild prompt if needed.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments (unused)
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
if self._has_changes():
|
||||
logger.info(
|
||||
"Workspace files changed, triggering prompt rebuild for: %s",
|
||||
getattr(agent, "agent_id", "unknown"),
|
||||
)
|
||||
if hasattr(agent, "rebuild_sys_prompt"):
|
||||
agent.rebuild_sys_prompt()
|
||||
else:
|
||||
logger.warning(
|
||||
"Agent %s has no rebuild_sys_prompt method",
|
||||
getattr(agent, "agent_id", "unknown"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Workspace watch hook failed: %s", e, exc_info=True)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MemoryCompactionHook(Hook):
|
||||
"""Hook for automatic memory compaction when context is full.
|
||||
|
||||
This hook monitors the token count of messages and triggers compaction
|
||||
when it exceeds the threshold. It preserves the system prompt and recent
|
||||
messages while summarizing older conversation history.
|
||||
|
||||
Based on CoPaw's memory compaction design with additional improvements:
|
||||
- memory_compact_ratio: Ratio to compact when threshold reached
|
||||
- memory_reserve_ratio: Always keep a reserve of tokens for recent messages
|
||||
- enable_tool_result_compact: Compact tool results separately
|
||||
- tool_result_compact_keep_n: Number of tool results to keep
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
memory_manager: Any,
|
||||
memory_compact_threshold: Optional[int] = None,
|
||||
memory_compact_ratio: float = 0.75,
|
||||
memory_reserve_ratio: float = 0.1,
|
||||
enable_tool_result_compact: bool = False,
|
||||
tool_result_compact_keep_n: int = 5,
|
||||
):
|
||||
"""Initialize memory compaction hook.
|
||||
|
||||
Args:
|
||||
memory_manager: Memory manager instance for compaction
|
||||
memory_compact_threshold: Token threshold for compaction
|
||||
memory_compact_ratio: Target ratio to compact to (e.g., 0.75 = compact to 75%)
|
||||
memory_reserve_ratio: Reserve ratio to always keep free (e.g., 0.1 = 10%)
|
||||
enable_tool_result_compact: Enable tool result compaction
|
||||
tool_result_compact_keep_n: Number of tool results to keep
|
||||
"""
|
||||
self.memory_manager = memory_manager
|
||||
self.memory_compact_threshold = memory_compact_threshold
|
||||
self.memory_compact_ratio = memory_compact_ratio
|
||||
self.memory_reserve_ratio = memory_reserve_ratio
|
||||
self.enable_tool_result_compact = enable_tool_result_compact
|
||||
self.tool_result_compact_keep_n = tool_result_compact_keep_n
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Pre-reasoning hook to check and compact memory if needed.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
kwargs: Input arguments to the _reasoning method
|
||||
|
||||
Returns:
|
||||
None (hook doesn't modify kwargs)
|
||||
"""
|
||||
try:
|
||||
if not hasattr(agent, "memory") or not self.memory_manager:
|
||||
return None
|
||||
|
||||
memory = agent.memory
|
||||
|
||||
# Get current token count estimate
|
||||
messages = await memory.get_memory()
|
||||
total_tokens = self._estimate_tokens(messages)
|
||||
|
||||
if self.memory_compact_threshold is None:
|
||||
return None
|
||||
|
||||
if total_tokens < self.memory_compact_threshold:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
"Memory compaction triggered: %d tokens (threshold: %d)",
|
||||
total_tokens,
|
||||
self.memory_compact_threshold,
|
||||
)
|
||||
|
||||
# Compact memory
|
||||
await self._compact_memory(agent, messages)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to compact memory: %s", e, exc_info=True)
|
||||
|
||||
return None
|
||||
|
||||
def _estimate_tokens(self, messages: List[Any]) -> int:
|
||||
"""Estimate token count for messages.
|
||||
|
||||
Args:
|
||||
messages: List of messages
|
||||
|
||||
Returns:
|
||||
Estimated token count
|
||||
"""
|
||||
# Simple estimation: ~4 chars per token
|
||||
total_chars = sum(
|
||||
len(str(getattr(msg, "content", "")))
|
||||
for msg in messages
|
||||
)
|
||||
return total_chars // 4
|
||||
|
||||
async def _compact_memory(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
messages: List[Any],
|
||||
) -> None:
|
||||
"""Compact memory by summarizing older messages.
|
||||
|
||||
Uses CoPaw-style memory management:
|
||||
- memory_compact_ratio: Target ratio to compact to (e.g., 0.75 means compact to 75%)
|
||||
- memory_reserve_ratio: Always keep this ratio free (e.g., 0.1 means keep 10% for recent)
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Current messages in memory
|
||||
"""
|
||||
if self.memory_compact_threshold is None:
|
||||
return
|
||||
|
||||
# Estimate total tokens
|
||||
total_tokens = self._estimate_tokens(messages)
|
||||
|
||||
# Calculate reserve based on ratio (CoPaw-style)
|
||||
reserve_tokens = int(total_tokens * self.memory_reserve_ratio)
|
||||
|
||||
# Calculate target tokens after compaction
|
||||
target_tokens = int(total_tokens * self.memory_compact_ratio)
|
||||
target_tokens = max(target_tokens, total_tokens - reserve_tokens)
|
||||
|
||||
# Find messages to compact (older ones)
|
||||
# Keep recent messages that fit within target
|
||||
messages_to_compact = []
|
||||
kept_tokens = 0
|
||||
|
||||
# Start from oldest, stop when we've kept enough
|
||||
for msg in messages:
|
||||
msg_tokens = self._estimate_tokens([msg])
|
||||
if kept_tokens + msg_tokens > target_tokens:
|
||||
messages_to_compact.append(msg)
|
||||
else:
|
||||
kept_tokens += msg_tokens
|
||||
|
||||
if not messages_to_compact:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Compacting %d messages (%d tokens) to target %d tokens",
|
||||
len(messages_to_compact),
|
||||
self._estimate_tokens(messages_to_compact),
|
||||
target_tokens,
|
||||
)
|
||||
|
||||
# Use memory manager to compact if available
|
||||
if hasattr(self.memory_manager, "compact_memory"):
|
||||
try:
|
||||
summary = await self.memory_manager.compact_memory(
|
||||
messages=messages_to_compact,
|
||||
)
|
||||
logger.info(
|
||||
"Memory compacted: %d messages summarized, summary: %s",
|
||||
len(messages_to_compact),
|
||||
summary[:200] if summary else "N/A",
|
||||
)
|
||||
|
||||
# Mark messages as compressed if supported
|
||||
if hasattr(agent.memory, "update_messages_mark"):
|
||||
from agentscope.agent._react_agent import _MemoryMark
|
||||
await agent.memory.update_messages_mark(
|
||||
new_mark=_MemoryMark.COMPRESSED,
|
||||
msg_ids=[msg.id for msg in messages_to_compact],
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Memory manager compaction failed: %s", e)
|
||||
|
||||
# Tool result compaction (CoPaw-style)
|
||||
if self.enable_tool_result_compact:
|
||||
await self._compact_tool_results(agent, messages)
|
||||
|
||||
async def _compact_tool_results(
|
||||
self,
|
||||
agent: "ReActAgent",
|
||||
messages: List[Any],
|
||||
) -> None:
|
||||
"""Compact tool results by keeping only recent ones.
|
||||
|
||||
Based on CoPaw's tool_result_compact_keep_n pattern.
|
||||
Tool results can be very verbose, so we keep only the N most recent ones.
|
||||
|
||||
Args:
|
||||
agent: The agent instance
|
||||
messages: Current messages in memory
|
||||
"""
|
||||
if not hasattr(agent.memory, "content"):
|
||||
return
|
||||
|
||||
# Find tool result messages (usually have "tool" role or tool_related content)
|
||||
tool_results = []
|
||||
for msg, _ in agent.memory.content:
|
||||
if hasattr(msg, "role") and msg.role == "tool":
|
||||
tool_results.append(msg)
|
||||
|
||||
if len(tool_results) <= self.tool_result_compact_keep_n:
|
||||
return
|
||||
|
||||
# Keep only the most recent N tool results
|
||||
excess_results = tool_results[:-self.tool_result_compact_keep_n]
|
||||
|
||||
logger.info(
|
||||
"Tool result compaction: %d tool results found, keeping %d, compacting %d",
|
||||
len(tool_results),
|
||||
self.tool_result_compact_keep_n,
|
||||
len(excess_results),
|
||||
)
|
||||
|
||||
# Mark excess tool results as compressed if supported
|
||||
if hasattr(agent.memory, "update_messages_mark"):
|
||||
from agentscope.agent._react_agent import _MemoryMark
|
||||
await agent.memory.update_messages_mark(
|
||||
new_mark=_MemoryMark.COMPRESSED,
|
||||
msg_ids=[msg.id for msg in excess_results],
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Hook",
|
||||
"HookManager",
|
||||
"HookType",
|
||||
"HOOK_PRE_REASONING",
|
||||
"HOOK_POST_ACTING",
|
||||
"BootstrapHook",
|
||||
"MemoryCompactionHook",
|
||||
"WorkspaceWatchHook",
|
||||
]
|
||||
489
backend/agents/base/skill_adaptation_hook.py
Normal file
489
backend/agents/base/skill_adaptation_hook.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Skill adaptation hook for automatic evaluation-to-iteration闭环.
|
||||
|
||||
Monitors evaluation metrics against configurable thresholds and triggers
|
||||
automatic skill reload or logs warnings when thresholds are breached.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from .evaluation_hook import (
|
||||
EvaluationCollector,
|
||||
EvaluationResult,
|
||||
MetricType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdaptationAction(Enum):
|
||||
"""Actions to take when threshold is breached."""
|
||||
RELOAD = "reload" # 自动重新加载技能
|
||||
WARN = "warn" # 记录警告供人工审核
|
||||
BOTH = "both" # 同时执行重载和警告
|
||||
NONE = "none" # 不做任何操作
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdaptationThreshold:
|
||||
"""Threshold configuration for a metric."""
|
||||
metric_type: MetricType
|
||||
operator: str = "lt" # lt (less than), gt (greater than), lte, gte, eq
|
||||
value: float = 0.0
|
||||
window_size: int = 10 # 移动窗口大小,用于计算滑动平均
|
||||
min_samples: int = 5 # 最少样本数才触发检查
|
||||
action: AdaptationAction = AdaptationAction.WARN
|
||||
cooldown_seconds: int = 300 # 触发后的冷却时间
|
||||
|
||||
def evaluate(self, current_value: float) -> bool:
|
||||
"""Evaluate if threshold is breached."""
|
||||
ops = {
|
||||
"lt": lambda x, y: x < y,
|
||||
"lte": lambda x, y: x <= y,
|
||||
"gt": lambda x, y: x > y,
|
||||
"gte": lambda x, y: x >= y,
|
||||
"eq": lambda x, y: x == y,
|
||||
}
|
||||
op_func = ops.get(self.operator)
|
||||
if op_func is None:
|
||||
logger.warning(f"Unknown operator: {self.operator}")
|
||||
return False
|
||||
return op_func(current_value, self.value)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"metric_type": self.metric_type.value,
|
||||
"operator": self.operator,
|
||||
"value": self.value,
|
||||
"window_size": self.window_size,
|
||||
"min_samples": self.min_samples,
|
||||
"action": self.action.value,
|
||||
"cooldown_seconds": self.cooldown_seconds,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdaptationEvent:
|
||||
"""Record of an adaptation trigger event."""
|
||||
timestamp: str
|
||||
skill_name: str
|
||||
metric_type: MetricType
|
||||
threshold: AdaptationThreshold
|
||||
current_value: float
|
||||
avg_value: float
|
||||
action_taken: AdaptationAction
|
||||
details: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"timestamp": self.timestamp,
|
||||
"skill_name": self.skill_name,
|
||||
"metric_type": self.metric_type.value,
|
||||
"threshold": self.threshold.to_dict(),
|
||||
"current_value": self.current_value,
|
||||
"avg_value": self.avg_value,
|
||||
"action_taken": self.action_taken.value,
|
||||
"details": self.details,
|
||||
}
|
||||
|
||||
|
||||
class SkillAdaptationHook:
|
||||
"""Hook for monitoring evaluation metrics and triggering skill adaptation.
|
||||
|
||||
This hook wraps EvaluationHook to add threshold-based adaptation logic.
|
||||
When metrics breach configured thresholds, it can:
|
||||
- Automatically reload skills via SkillsManager
|
||||
- Log warnings for human review
|
||||
- Both
|
||||
"""
|
||||
|
||||
# Default thresholds for common metrics
|
||||
DEFAULT_THRESHOLDS: List[AdaptationThreshold] = [
|
||||
AdaptationThreshold(
|
||||
metric_type=MetricType.HIT_RATE,
|
||||
operator="lt",
|
||||
value=0.5,
|
||||
action=AdaptationAction.WARN,
|
||||
cooldown_seconds=600,
|
||||
),
|
||||
AdaptationThreshold(
|
||||
metric_type=MetricType.RISK_VIOLATION,
|
||||
operator="gt",
|
||||
value=0.1,
|
||||
action=AdaptationAction.WARN,
|
||||
cooldown_seconds=300,
|
||||
),
|
||||
AdaptationThreshold(
|
||||
metric_type=MetricType.DECISION_LATENCY,
|
||||
operator="gt",
|
||||
value=5000, # 5 seconds
|
||||
action=AdaptationAction.WARN,
|
||||
cooldown_seconds=300,
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_dir: Path,
|
||||
run_id: str,
|
||||
agent_id: str,
|
||||
thresholds: Optional[List[AdaptationThreshold]] = None,
|
||||
collector: Optional[EvaluationCollector] = None,
|
||||
):
|
||||
"""Initialize skill adaptation hook.
|
||||
|
||||
Args:
|
||||
storage_dir: Directory to store adaptation events
|
||||
run_id: Current run identifier
|
||||
agent_id: Current agent identifier
|
||||
thresholds: Custom threshold configurations (uses defaults if None)
|
||||
collector: Optional EvaluationCollector for historical data
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.run_id = run_id
|
||||
self.agent_id = agent_id
|
||||
self.thresholds = thresholds or self.DEFAULT_THRESHOLDS
|
||||
self.collector = collector or EvaluationCollector(storage_dir)
|
||||
|
||||
# Track cooldowns to prevent rapid re-triggering
|
||||
self._cooldowns: Dict[str, datetime] = {}
|
||||
|
||||
# Store recent metrics in memory for quick access
|
||||
self._recent_metrics: Dict[str, List[float]] = {}
|
||||
|
||||
# Pending adaptation events
|
||||
self._pending_events: List[AdaptationEvent] = []
|
||||
|
||||
def check_threshold(
|
||||
self,
|
||||
skill_name: str,
|
||||
metric_type: MetricType,
|
||||
current_value: float,
|
||||
) -> Optional[AdaptationEvent]:
|
||||
"""Check if a metric breaches any threshold.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill
|
||||
metric_type: Type of metric
|
||||
current_value: Current metric value
|
||||
|
||||
Returns:
|
||||
AdaptationEvent if threshold breached, None otherwise
|
||||
"""
|
||||
# Find applicable thresholds
|
||||
applicable_thresholds = [
|
||||
t for t in self.thresholds
|
||||
if t.metric_type == metric_type
|
||||
]
|
||||
|
||||
if not applicable_thresholds:
|
||||
return None
|
||||
|
||||
# Check cooldown
|
||||
cooldown_key = f"{skill_name}:{metric_type.value}"
|
||||
now = datetime.now()
|
||||
last_trigger = self._cooldowns.get(cooldown_key)
|
||||
|
||||
# Store current value first for avg calculation
|
||||
self._store_metric(cooldown_key, current_value)
|
||||
|
||||
for threshold in applicable_thresholds:
|
||||
if last_trigger:
|
||||
elapsed = (now - last_trigger).total_seconds()
|
||||
if elapsed < threshold.cooldown_seconds:
|
||||
continue
|
||||
|
||||
# Evaluate threshold
|
||||
if threshold.evaluate(current_value):
|
||||
# Calculate moving average
|
||||
avg_value = self._calculate_avg(skill_name, metric_type, current_value)
|
||||
|
||||
# Check minimum samples (allow immediate trigger if min_samples <= 1)
|
||||
sample_count = len(self._recent_metrics.get(cooldown_key, []))
|
||||
if threshold.min_samples > 1 and sample_count < threshold.min_samples:
|
||||
# Not enough samples yet
|
||||
continue
|
||||
|
||||
# Trigger adaptation
|
||||
event = AdaptationEvent(
|
||||
timestamp=now.isoformat(),
|
||||
skill_name=skill_name,
|
||||
metric_type=metric_type,
|
||||
threshold=threshold,
|
||||
current_value=current_value,
|
||||
avg_value=avg_value,
|
||||
action_taken=threshold.action,
|
||||
details={
|
||||
"run_id": self.run_id,
|
||||
"agent_id": self.agent_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Update cooldown
|
||||
self._cooldowns[cooldown_key] = now
|
||||
|
||||
# Persist event
|
||||
self._persist_event(event)
|
||||
|
||||
logger.info(
|
||||
f"Threshold breached for {skill_name}.{metric_type.value}: "
|
||||
f"current={current_value}, avg={avg_value}, action={threshold.action.value}"
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_avg(
|
||||
self,
|
||||
skill_name: str,
|
||||
metric_type: MetricType,
|
||||
current_value: float,
|
||||
) -> float:
|
||||
"""Calculate moving average for a metric."""
|
||||
key = f"{skill_name}:{metric_type.value}"
|
||||
values = self._recent_metrics.get(key, [])
|
||||
if not values:
|
||||
return current_value
|
||||
return sum(values) / len(values)
|
||||
|
||||
def _store_metric(self, key: str, value: float) -> None:
|
||||
"""Store metric value with sliding window."""
|
||||
if key not in self._recent_metrics:
|
||||
self._recent_metrics[key] = []
|
||||
self._recent_metrics[key].append(value)
|
||||
# Keep only last 100 values
|
||||
if len(self._recent_metrics[key]) > 100:
|
||||
self._recent_metrics[key] = self._recent_metrics[key][-100:]
|
||||
|
||||
def _persist_event(self, event: AdaptationEvent) -> None:
|
||||
"""Persist adaptation event to storage."""
|
||||
run_dir = self.storage_dir / self.run_id / "adaptations"
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{event.skill_name}_{event.metric_type.value}_{timestamp}.json"
|
||||
filepath = run_dir / filename
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(event.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"Persisted adaptation event to: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist adaptation event: {e}")
|
||||
|
||||
# Also add to pending list
|
||||
self._pending_events.append(event)
|
||||
|
||||
def get_pending_warnings(self) -> List[AdaptationEvent]:
|
||||
"""Get all pending warning events that need human review."""
|
||||
return [
|
||||
e for e in self._pending_events
|
||||
if e.action_taken in (AdaptationAction.WARN, AdaptationAction.BOTH)
|
||||
]
|
||||
|
||||
def clear_pending_warnings(self) -> None:
|
||||
"""Clear pending warnings after they have been reviewed."""
|
||||
self._pending_events = [
|
||||
e for e in self._pending_events
|
||||
if e.action_taken == AdaptationAction.RELOAD
|
||||
]
|
||||
|
||||
def get_recent_events(
|
||||
self,
|
||||
skill_name: Optional[str] = None,
|
||||
metric_type: Optional[MetricType] = None,
|
||||
limit: int = 50,
|
||||
) -> List[AdaptationEvent]:
|
||||
"""Get recent adaptation events.
|
||||
|
||||
Args:
|
||||
skill_name: Optional filter by skill name
|
||||
metric_type: Optional filter by metric type
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of recent adaptation events
|
||||
"""
|
||||
events_dir = self.storage_dir / self.run_id / "adaptations"
|
||||
if not events_dir.exists():
|
||||
return []
|
||||
|
||||
events = []
|
||||
for eval_file in sorted(events_dir.glob("*.json"), reverse=True)[:limit]:
|
||||
try:
|
||||
with open(eval_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
event = self._parse_event(data)
|
||||
if skill_name and event.skill_name != skill_name:
|
||||
continue
|
||||
if metric_type and event.metric_type != metric_type:
|
||||
continue
|
||||
events.append(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load adaptation event {eval_file}: {e}")
|
||||
|
||||
return events
|
||||
|
||||
def _parse_event(self, data: Dict[str, Any]) -> AdaptationEvent:
|
||||
"""Parse adaptation event from JSON data."""
|
||||
threshold_data = data.get("threshold", {})
|
||||
metric_type = MetricType(threshold_data.get("metric_type", "custom"))
|
||||
|
||||
threshold = AdaptationThreshold(
|
||||
metric_type=metric_type,
|
||||
operator=threshold_data.get("operator", "lt"),
|
||||
value=threshold_data.get("value", 0.0),
|
||||
window_size=threshold_data.get("window_size", 10),
|
||||
min_samples=threshold_data.get("min_samples", 5),
|
||||
action=AdaptationAction(threshold_data.get("action", "warn")),
|
||||
cooldown_seconds=threshold_data.get("cooldown_seconds", 300),
|
||||
)
|
||||
|
||||
return AdaptationEvent(
|
||||
timestamp=data.get("timestamp", ""),
|
||||
skill_name=data.get("skill_name", ""),
|
||||
metric_type=metric_type,
|
||||
threshold=threshold,
|
||||
current_value=data.get("current_value", 0.0),
|
||||
avg_value=data.get("avg_value", 0.0),
|
||||
action_taken=AdaptationAction(data.get("action_taken", "warn")),
|
||||
details=data.get("details", {}),
|
||||
)
|
||||
|
||||
def add_threshold(self, threshold: AdaptationThreshold) -> None:
|
||||
"""Add a new threshold configuration."""
|
||||
self.thresholds.append(threshold)
|
||||
|
||||
def remove_threshold(self, metric_type: MetricType) -> None:
|
||||
"""Remove all thresholds for a specific metric type."""
|
||||
self.thresholds = [
|
||||
t for t in self.thresholds
|
||||
if t.metric_type != metric_type
|
||||
]
|
||||
|
||||
def update_threshold(
|
||||
self,
|
||||
metric_type: MetricType,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Update threshold configuration for a metric type."""
|
||||
for threshold in self.thresholds:
|
||||
if threshold.metric_type == metric_type:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(threshold, key):
|
||||
setattr(threshold, key, value)
|
||||
|
||||
def get_thresholds(self) -> List[AdaptationThreshold]:
|
||||
"""Get current threshold configurations."""
|
||||
return list(self.thresholds)
|
||||
|
||||
def is_in_cooldown(self, skill_name: str, metric_type: MetricType) -> bool:
|
||||
"""Check if a skill/metric combination is in cooldown period."""
|
||||
key = f"{skill_name}:{metric_type.value}"
|
||||
last_trigger = self._cooldowns.get(key)
|
||||
if not last_trigger:
|
||||
return False
|
||||
|
||||
# Find the threshold for this metric type
|
||||
for threshold in self.thresholds:
|
||||
if threshold.metric_type == metric_type:
|
||||
elapsed = (datetime.now() - last_trigger).total_seconds()
|
||||
return elapsed < threshold.cooldown_seconds
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class AdaptationManager:
|
||||
"""Manager for coordinating skill adaptation across multiple agents.
|
||||
|
||||
Provides centralized tracking of adaptation events and skill reloads.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_dir: Path):
|
||||
"""Initialize adaptation manager.
|
||||
|
||||
Args:
|
||||
storage_dir: Root directory for storing adaptation data
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self._hooks: Dict[str, SkillAdaptationHook] = {}
|
||||
|
||||
def get_hook(
|
||||
self,
|
||||
run_id: str,
|
||||
agent_id: str,
|
||||
thresholds: Optional[List[AdaptationThreshold]] = None,
|
||||
) -> SkillAdaptationHook:
|
||||
"""Get or create an adaptation hook for an agent.
|
||||
|
||||
Args:
|
||||
run_id: Run identifier
|
||||
agent_id: Agent identifier
|
||||
thresholds: Optional custom thresholds
|
||||
|
||||
Returns:
|
||||
SkillAdaptationHook instance
|
||||
"""
|
||||
key = f"{run_id}:{agent_id}"
|
||||
if key not in self._hooks:
|
||||
self._hooks[key] = SkillAdaptationHook(
|
||||
storage_dir=self.storage_dir,
|
||||
run_id=run_id,
|
||||
agent_id=agent_id,
|
||||
thresholds=thresholds,
|
||||
)
|
||||
return self._hooks[key]
|
||||
|
||||
def get_all_pending_warnings(self) -> List[AdaptationEvent]:
|
||||
"""Get all pending warnings from all hooks."""
|
||||
warnings = []
|
||||
for hook in self._hooks.values():
|
||||
warnings.extend(hook.get_pending_warnings())
|
||||
return warnings
|
||||
|
||||
def get_run_adaptations(self, run_id: str) -> List[AdaptationEvent]:
|
||||
"""Get all adaptation events for a run."""
|
||||
events = []
|
||||
for hook in self._hooks.values():
|
||||
if hook.run_id == run_id:
|
||||
events.extend(hook.get_recent_events())
|
||||
return events
|
||||
|
||||
|
||||
# Global manager instance
|
||||
_adaptation_manager: Optional[AdaptationManager] = None
|
||||
|
||||
|
||||
def get_adaptation_manager(storage_dir: Optional[Path] = None) -> AdaptationManager:
|
||||
"""Get global adaptation manager instance.
|
||||
|
||||
Args:
|
||||
storage_dir: Optional storage directory (required on first call)
|
||||
|
||||
Returns:
|
||||
AdaptationManager instance
|
||||
"""
|
||||
global _adaptation_manager
|
||||
if _adaptation_manager is None:
|
||||
if storage_dir is None:
|
||||
raise ValueError("storage_dir required on first initialization")
|
||||
_adaptation_manager = AdaptationManager(storage_dir)
|
||||
return _adaptation_manager
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AdaptationAction",
|
||||
"AdaptationThreshold",
|
||||
"AdaptationEvent",
|
||||
"SkillAdaptationHook",
|
||||
"AdaptationManager",
|
||||
"get_adaptation_manager",
|
||||
]
|
||||
684
backend/agents/base/tool_guard.py
Normal file
684
backend/agents/base/tool_guard.py
Normal file
@@ -0,0 +1,684 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""ToolGuardMixin - Security interception for dangerous tool calls.
|
||||
|
||||
Provides ``_acting`` and ``_reasoning`` overrides that intercept
|
||||
sensitive tool calls before execution, implementing the deny /
|
||||
guard / approve flow.
|
||||
|
||||
Based on CoPaw's tool_guard_mixin.py design.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set
|
||||
|
||||
from agentscope.message import Msg
|
||||
from backend.runtime.manager import get_global_runtime_manager
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SeverityLevel(str, Enum):
|
||||
"""Risk severity level."""
|
||||
|
||||
LOW = "low"
|
||||
MEDIUM = "medium"
|
||||
HIGH = "high"
|
||||
CRITICAL = "critical"
|
||||
|
||||
|
||||
class ApprovalStatus(str, Enum):
|
||||
"""Approval lifecycle state."""
|
||||
|
||||
PENDING = "pending"
|
||||
APPROVED = "approved"
|
||||
DENIED = "denied"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class ToolFindingRecord:
|
||||
"""Internal representation of a guard finding."""
|
||||
|
||||
def __init__(self, severity: SeverityLevel, message: str, field: Optional[str] = None) -> None:
|
||||
self.severity = severity
|
||||
self.message = message
|
||||
self.field = field
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"severity": self.severity.value,
|
||||
"message": self.message,
|
||||
"field": self.field,
|
||||
}
|
||||
|
||||
|
||||
class ApprovalRecord:
|
||||
"""Stores the state of an approval request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
approval_id: str,
|
||||
tool_name: str,
|
||||
tool_input: Dict[str, Any],
|
||||
agent_id: str,
|
||||
workspace_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
findings: Optional[List[ToolFindingRecord]] = None,
|
||||
) -> None:
|
||||
self.approval_id = approval_id
|
||||
self.tool_name = tool_name
|
||||
self.tool_input = tool_input
|
||||
self.agent_id = agent_id
|
||||
self.workspace_id = workspace_id
|
||||
self.session_id = session_id
|
||||
self.status = ApprovalStatus.PENDING
|
||||
self.findings = findings or []
|
||||
self.created_at = datetime.utcnow()
|
||||
self.resolved_at: Optional[datetime] = None
|
||||
self.resolved_by: Optional[str] = None
|
||||
self.metadata: Dict[str, Any] = {}
|
||||
self.pending_request: "ToolApprovalRequest" | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"approval_id": self.approval_id,
|
||||
"status": self.status.value,
|
||||
"tool_name": self.tool_name,
|
||||
"tool_input": self.tool_input,
|
||||
"agent_id": self.agent_id,
|
||||
"workspace_id": self.workspace_id,
|
||||
"session_id": self.session_id,
|
||||
"findings": [f.to_dict() for f in self.findings],
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"resolved_at": self.resolved_at.isoformat() if self.resolved_at else None,
|
||||
"resolved_by": self.resolved_by,
|
||||
}
|
||||
|
||||
|
||||
class ToolGuardStore:
|
||||
"""Simple in-memory approval store for development/testing."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._records: Dict[str, ApprovalRecord] = {}
|
||||
self._counter = 0
|
||||
|
||||
def next_id(self) -> str:
|
||||
self._counter += 1
|
||||
return f"approval_{self._counter:06d}"
|
||||
|
||||
def list(
|
||||
self,
|
||||
status: ApprovalStatus | None = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> Iterable[ApprovalRecord]:
|
||||
for record in self._records.values():
|
||||
if status and record.status != status:
|
||||
continue
|
||||
if workspace_id and record.workspace_id != workspace_id:
|
||||
continue
|
||||
if agent_id and record.agent_id != agent_id:
|
||||
continue
|
||||
yield record
|
||||
|
||||
def get(self, approval_id: str) -> Optional[ApprovalRecord]:
|
||||
return self._records.get(approval_id)
|
||||
|
||||
def create_pending(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: Dict[str, Any],
|
||||
agent_id: str,
|
||||
workspace_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
findings: Optional[List[ToolFindingRecord]] = None,
|
||||
) -> ApprovalRecord:
|
||||
record = ApprovalRecord(
|
||||
approval_id=self.next_id(),
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
agent_id=agent_id,
|
||||
workspace_id=workspace_id,
|
||||
session_id=session_id,
|
||||
findings=findings,
|
||||
)
|
||||
self._records[record.approval_id] = record
|
||||
return record
|
||||
|
||||
def set_status(
|
||||
self,
|
||||
approval_id: str,
|
||||
status: ApprovalStatus,
|
||||
resolved_by: Optional[str] = None,
|
||||
notify_request: bool = True,
|
||||
) -> ApprovalRecord:
|
||||
record = self._records[approval_id]
|
||||
if record.status == status:
|
||||
return record
|
||||
|
||||
record.status = status
|
||||
record.resolved_at = datetime.utcnow()
|
||||
record.resolved_by = resolved_by
|
||||
if notify_request and record.pending_request:
|
||||
if status == ApprovalStatus.APPROVED:
|
||||
record.pending_request.approve()
|
||||
elif status == ApprovalStatus.DENIED:
|
||||
record.pending_request.deny()
|
||||
return record
|
||||
|
||||
def cancel(self, approval_id: str) -> None:
|
||||
self._records.pop(approval_id, None)
|
||||
|
||||
|
||||
TOOL_GUARD_STORE = ToolGuardStore()
|
||||
|
||||
|
||||
def get_tool_guard_store() -> ToolGuardStore:
|
||||
return TOOL_GUARD_STORE
|
||||
|
||||
|
||||
# Default tools that require approval
|
||||
DEFAULT_GUARDED_TOOLS: Set[str] = {
|
||||
"execute_shell_command",
|
||||
"write_file",
|
||||
"edit_file",
|
||||
"place_order",
|
||||
"modify_position",
|
||||
"delete_file",
|
||||
}
|
||||
|
||||
# Default denied tools (cannot be approved)
|
||||
DEFAULT_DENIED_TOOLS: Set[str] = {
|
||||
"execute_shell_command", # Shell execution is dangerous
|
||||
}
|
||||
|
||||
# Mark for tool guard denied messages
|
||||
TOOL_GUARD_DENIED_MARK = "tool_guard_denied"
|
||||
|
||||
|
||||
def default_findings_for_tool(tool_name: str) -> List[ToolFindingRecord]:
|
||||
findings: List[ToolFindingRecord] = []
|
||||
if tool_name in {"execute_trade", "modify_portfolio"}:
|
||||
findings.append(
|
||||
ToolFindingRecord(
|
||||
severity=SeverityLevel.HIGH,
|
||||
message=f"Tool '{tool_name}' touches portfolio state",
|
||||
)
|
||||
)
|
||||
return findings
|
||||
|
||||
|
||||
class ToolApprovalRequest:
|
||||
"""Represents a pending tool approval request."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
approval_id: str,
|
||||
tool_name: str,
|
||||
tool_input: Dict[str, Any],
|
||||
tool_call_id: str,
|
||||
session_id: Optional[str] = None,
|
||||
):
|
||||
self.approval_id = approval_id
|
||||
self.tool_name = tool_name
|
||||
self.tool_input = tool_input
|
||||
self.tool_call_id = tool_call_id
|
||||
self.session_id = session_id
|
||||
self.approved: Optional[bool] = None
|
||||
self._event = asyncio.Event()
|
||||
|
||||
async def wait_for_approval(self, timeout: Optional[float] = None) -> bool:
|
||||
"""Wait for approval decision.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds
|
||||
|
||||
Returns:
|
||||
True if approved, False otherwise
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._event.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
return self.approved is True
|
||||
|
||||
def approve(self) -> None:
|
||||
"""Approve this request."""
|
||||
self.approved = True
|
||||
self._event.set()
|
||||
|
||||
def deny(self) -> None:
|
||||
"""Deny this request."""
|
||||
self.approved = False
|
||||
self._event.set()
|
||||
|
||||
|
||||
class ToolGuardMixin:
|
||||
"""Mixin that adds tool-guard interception to a ReActAgent.
|
||||
|
||||
At runtime this class is combined with ReActAgent via MRO,
|
||||
so ``super()._acting`` and ``super()._reasoning`` resolve to
|
||||
the concrete agent methods.
|
||||
|
||||
Usage:
|
||||
class MyAgent(ToolGuardMixin, ReActAgent):
|
||||
def __init__(self, ...):
|
||||
super().__init__(...)
|
||||
self._init_tool_guard()
|
||||
"""
|
||||
|
||||
def _init_tool_guard(
|
||||
self,
|
||||
guarded_tools: Optional[Set[str]] = None,
|
||||
denied_tools: Optional[Set[str]] = None,
|
||||
approval_timeout: float = 300.0,
|
||||
) -> None:
|
||||
"""Initialize tool guard.
|
||||
|
||||
Args:
|
||||
guarded_tools: Set of tool names requiring approval
|
||||
denied_tools: Set of tool names that are always denied
|
||||
approval_timeout: Timeout for approval requests in seconds
|
||||
"""
|
||||
self._guarded_tools = guarded_tools or DEFAULT_GUARDED_TOOLS.copy()
|
||||
self._denied_tools = denied_tools or DEFAULT_DENIED_TOOLS.copy()
|
||||
self._approval_timeout = approval_timeout
|
||||
self._pending_approval: Optional[ToolApprovalRequest] = None
|
||||
self._approval_callback: Optional[Callable[[ToolApprovalRequest], None]] = None
|
||||
self._approval_lock = asyncio.Lock()
|
||||
|
||||
def set_approval_callback(
|
||||
self,
|
||||
callback: Callable[[ToolApprovalRequest], None],
|
||||
) -> None:
|
||||
"""Set callback for approval requests.
|
||||
|
||||
Args:
|
||||
callback: Function called when approval is needed
|
||||
"""
|
||||
self._approval_callback = callback
|
||||
|
||||
def _is_tool_guarded(self, tool_name: str) -> bool:
|
||||
"""Check if a tool requires approval.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
True if tool requires approval
|
||||
"""
|
||||
return tool_name in self._guarded_tools
|
||||
|
||||
def _is_tool_denied(self, tool_name: str) -> bool:
|
||||
"""Check if a tool is always denied.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
True if tool is denied
|
||||
"""
|
||||
return tool_name in self._denied_tools
|
||||
|
||||
def _last_tool_response_is_denied(self) -> bool:
|
||||
"""Check if the last message is a guard-denied tool result."""
|
||||
if not hasattr(self, "memory") or not self.memory.content:
|
||||
return False
|
||||
|
||||
msg, marks = self.memory.content[-1]
|
||||
return TOOL_GUARD_DENIED_MARK in marks and msg.role == "system"
|
||||
|
||||
async def _cleanup_tool_guard_denied_messages(
|
||||
self,
|
||||
include_denial_response: bool = True,
|
||||
) -> None:
|
||||
"""Remove tool-guard denied messages from memory.
|
||||
|
||||
Args:
|
||||
include_denial_response: Also remove the assistant's denial explanation
|
||||
"""
|
||||
if not hasattr(self, "memory"):
|
||||
return
|
||||
|
||||
ids_to_delete: list[str] = []
|
||||
last_marked_idx = -1
|
||||
|
||||
for i, (msg, marks) in enumerate(self.memory.content):
|
||||
if TOOL_GUARD_DENIED_MARK in marks:
|
||||
ids_to_delete.append(msg.id)
|
||||
last_marked_idx = i
|
||||
|
||||
if (
|
||||
include_denial_response
|
||||
and last_marked_idx >= 0
|
||||
and last_marked_idx + 1 < len(self.memory.content)
|
||||
):
|
||||
next_msg, _ = self.memory.content[last_marked_idx + 1]
|
||||
if next_msg.role == "assistant":
|
||||
ids_to_delete.append(next_msg.id)
|
||||
|
||||
if ids_to_delete:
|
||||
removed = await self.memory.delete(ids_to_delete)
|
||||
logger.info("Tool guard: cleaned up %d denied message(s)", removed)
|
||||
|
||||
async def _request_guard_approval(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_input: Dict[str, Any],
|
||||
tool_call_id: str,
|
||||
) -> bool:
|
||||
"""Request approval for a guarded tool call.
|
||||
|
||||
This method creates a ToolApprovalRequest and waits for
|
||||
external approval via approve_guard_call() or deny_guard_call().
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
tool_input: Tool input parameters
|
||||
tool_call_id: ID of the tool call
|
||||
|
||||
Returns:
|
||||
True if approved, False otherwise
|
||||
"""
|
||||
async with self._approval_lock:
|
||||
record = TOOL_GUARD_STORE.create_pending(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
agent_id=getattr(self, "agent_id", "unknown"),
|
||||
workspace_id=getattr(self, "workspace_id", "default"),
|
||||
session_id=getattr(self, "session_id", None),
|
||||
findings=default_findings_for_tool(tool_name),
|
||||
)
|
||||
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.register_pending_approval(
|
||||
record.approval_id,
|
||||
{
|
||||
"tool_name": record.tool_name,
|
||||
"agent_id": record.agent_id,
|
||||
"workspace_id": record.workspace_id,
|
||||
"session_id": record.session_id,
|
||||
"tool_input": record.tool_input,
|
||||
},
|
||||
)
|
||||
|
||||
self._pending_approval = ToolApprovalRequest(
|
||||
approval_id=record.approval_id,
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
tool_call_id=tool_call_id,
|
||||
session_id=getattr(self, "session_id", None),
|
||||
)
|
||||
record.pending_request = self._pending_approval
|
||||
|
||||
# Notify via callback if set
|
||||
if self._approval_callback:
|
||||
self._approval_callback(self._pending_approval)
|
||||
|
||||
# Wait for approval (lock is released during wait, re-acquired after)
|
||||
approval_request = self._pending_approval
|
||||
|
||||
# Wait for approval outside the lock to allow concurrent approval
|
||||
approved = await approval_request.wait_for_approval(
|
||||
timeout=self._approval_timeout
|
||||
)
|
||||
|
||||
async with self._approval_lock:
|
||||
if approval_request:
|
||||
status = (
|
||||
ApprovalStatus.APPROVED
|
||||
if approval_request.approved is True
|
||||
else ApprovalStatus.DENIED
|
||||
if approval_request.approved is False
|
||||
else ApprovalStatus.EXPIRED
|
||||
)
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
approval_request.approval_id,
|
||||
status,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
approval_request.approval_id,
|
||||
resolved_by="agent",
|
||||
status=status.value,
|
||||
)
|
||||
|
||||
# Only clear if this is still the same request
|
||||
if self._pending_approval is approval_request:
|
||||
self._pending_approval = None
|
||||
|
||||
return approved
|
||||
|
||||
async def approve_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||
"""Approve a pending guard request.
|
||||
|
||||
This method is called externally to approve a tool call
|
||||
that is waiting for approval.
|
||||
|
||||
Args:
|
||||
request_id: Optional request ID to verify (not yet implemented)
|
||||
|
||||
Returns:
|
||||
True if a request was approved, False if no pending request
|
||||
"""
|
||||
async with self._approval_lock:
|
||||
if self._pending_approval is None:
|
||||
logger.warning("No pending approval request to approve")
|
||||
return False
|
||||
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
self._pending_approval.approval_id,
|
||||
ApprovalStatus.APPROVED,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
self._pending_approval.approval_id,
|
||||
resolved_by="agent",
|
||||
status=ApprovalStatus.APPROVED.value,
|
||||
)
|
||||
self._pending_approval.approve()
|
||||
logger.info("Approved tool call: %s", self._pending_approval.tool_name)
|
||||
return True
|
||||
|
||||
async def deny_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||
"""Deny a pending guard request.
|
||||
|
||||
This method is called externally to deny a tool call
|
||||
that is waiting for approval.
|
||||
|
||||
Args:
|
||||
request_id: Optional request ID to verify (not yet implemented)
|
||||
|
||||
Returns:
|
||||
True if a request was denied, False if no pending request
|
||||
"""
|
||||
async with self._approval_lock:
|
||||
if self._pending_approval is None:
|
||||
logger.warning("No pending approval request to deny")
|
||||
return False
|
||||
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
self._pending_approval.approval_id,
|
||||
ApprovalStatus.DENIED,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
self._pending_approval.approval_id,
|
||||
resolved_by="agent",
|
||||
status=ApprovalStatus.DENIED.value,
|
||||
)
|
||||
self._pending_approval.deny()
|
||||
logger.info("Denied tool call: %s", self._pending_approval.tool_name)
|
||||
return True
|
||||
|
||||
async def _acting(self, tool_call) -> dict | None:
|
||||
"""Intercept sensitive tool calls before execution.
|
||||
|
||||
1. If tool is in denied_tools, auto-deny unconditionally.
|
||||
2. Check for a one-shot pre-approval.
|
||||
3. If tool is in the guarded scope, request approval.
|
||||
4. Otherwise, delegate to parent _acting.
|
||||
|
||||
Args:
|
||||
tool_call: Tool call from the model
|
||||
|
||||
Returns:
|
||||
Tool result dict or None
|
||||
"""
|
||||
tool_name: str = tool_call.get("name", "")
|
||||
tool_input: dict = tool_call.get("input", {})
|
||||
tool_call_id: str = tool_call.get("id", "")
|
||||
|
||||
# Check if tool is denied
|
||||
if tool_name and self._is_tool_denied(tool_name):
|
||||
logger.warning("Tool '%s' is in the denied set, auto-denying", tool_name)
|
||||
return await self._acting_auto_denied(tool_call, tool_name)
|
||||
|
||||
# Check if tool is guarded
|
||||
if tool_name and self._is_tool_guarded(tool_name):
|
||||
approved = await self._request_guard_approval(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
tool_call_id=tool_call_id,
|
||||
)
|
||||
|
||||
if not approved:
|
||||
return await self._acting_with_denial(tool_call, tool_name)
|
||||
|
||||
# Call parent _acting
|
||||
return await super()._acting(tool_call) # type: ignore[misc]
|
||||
|
||||
async def _acting_auto_denied(
|
||||
self,
|
||||
tool_call: Dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> dict | None:
|
||||
"""Auto-deny a tool call without offering approval.
|
||||
|
||||
Args:
|
||||
tool_call: Tool call from the model
|
||||
tool_name: Name of the denied tool
|
||||
|
||||
Returns:
|
||||
Denial result
|
||||
"""
|
||||
from agentscope.message import ToolResultBlock
|
||||
|
||||
denied_text = (
|
||||
f"⛔ **Tool Blocked / 工具已拦截**\n\n"
|
||||
f"- Tool / 工具: `{tool_name}`\n"
|
||||
f"- Reason / 原因: This tool is blocked for security reasons\n\n"
|
||||
f"This tool is blocked and cannot be approved.\n"
|
||||
f"该工具已被禁止,无法批准执行。"
|
||||
)
|
||||
|
||||
tool_res_msg = Msg(
|
||||
"system",
|
||||
[
|
||||
ToolResultBlock(
|
||||
type="tool_result",
|
||||
id=tool_call.get("id", ""),
|
||||
name=tool_name,
|
||||
output=[{"type": "text", "text": denied_text}],
|
||||
),
|
||||
],
|
||||
"system",
|
||||
)
|
||||
|
||||
await self.print(tool_res_msg, True)
|
||||
await self.memory.add(tool_res_msg)
|
||||
return None
|
||||
|
||||
async def _acting_with_denial(
|
||||
self,
|
||||
tool_call: Dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> dict | None:
|
||||
"""Deny the tool call after approval was rejected.
|
||||
|
||||
Args:
|
||||
tool_call: Tool call from the model
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Denial result
|
||||
"""
|
||||
from agentscope.message import ToolResultBlock
|
||||
|
||||
params_text = json.dumps(
|
||||
tool_call.get("input", {}),
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
denied_text = (
|
||||
f"⚠️ **Tool Call Denied / 工具调用被拒绝**\n\n"
|
||||
f"- Tool / 工具: `{tool_name}`\n"
|
||||
f"- Parameters / 参数:\n"
|
||||
f"```json\n{params_text}\n```\n\n"
|
||||
f"The tool call was denied by the user or timed out.\n"
|
||||
f"工具调用被用户拒绝或已超时。"
|
||||
)
|
||||
|
||||
tool_res_msg = Msg(
|
||||
"system",
|
||||
[
|
||||
ToolResultBlock(
|
||||
type="tool_result",
|
||||
id=tool_call.get("id", ""),
|
||||
name=tool_name,
|
||||
output=[{"type": "text", "text": denied_text}],
|
||||
),
|
||||
],
|
||||
"system",
|
||||
)
|
||||
|
||||
await self.print(tool_res_msg, True)
|
||||
await self.memory.add(tool_res_msg, marks=TOOL_GUARD_DENIED_MARK)
|
||||
return None
|
||||
|
||||
async def _reasoning(self, **kwargs) -> Msg:
|
||||
"""Short-circuit reasoning when awaiting guard approval.
|
||||
|
||||
If the last message was a guard denial, return a waiting message
|
||||
instead of continuing reasoning.
|
||||
|
||||
Returns:
|
||||
Response message
|
||||
"""
|
||||
if self._last_tool_response_is_denied():
|
||||
msg = Msg(
|
||||
self.name,
|
||||
"⏳ Waiting for approval / 等待审批...\n\n"
|
||||
"Type `/approve` to approve, or send any message to deny.\n"
|
||||
"输入 `/approve` 批准执行,或发送任意消息拒绝。",
|
||||
"assistant",
|
||||
)
|
||||
await self.print(msg, True)
|
||||
await self.memory.add(msg)
|
||||
return msg
|
||||
|
||||
return await super()._reasoning(**kwargs) # type: ignore[misc]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ToolGuardMixin",
|
||||
"ToolApprovalRequest",
|
||||
"DEFAULT_GUARDED_TOOLS",
|
||||
"DEFAULT_DENIED_TOOLS",
|
||||
"TOOL_GUARD_DENIED_MARK",
|
||||
]
|
||||
146
backend/agents/compat.py
Normal file
146
backend/agents/compat.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Compatibility Layer - Adapters for legacy to EvoAgent migration.
|
||||
|
||||
Provides:
|
||||
- LegacyAgentAdapter: Wraps old AnalystAgent to work with new interfaces
|
||||
- Migration utilities for gradual adoption
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
from .agent_core import EvoAgent
|
||||
|
||||
|
||||
class LegacyAgentAdapter:
|
||||
"""
|
||||
Adapter to make legacy AnalystAgent compatible with EvoAgent interfaces.
|
||||
|
||||
This allows gradual migration by wrapping existing agents.
|
||||
"""
|
||||
|
||||
def __init__(self, legacy_agent: Any):
|
||||
"""
|
||||
Initialize adapter.
|
||||
|
||||
Args:
|
||||
legacy_agent: Legacy AnalystAgent instance
|
||||
"""
|
||||
self._agent = legacy_agent
|
||||
self.agent_id = getattr(legacy_agent, 'agent_id', getattr(legacy_agent, 'name', 'unknown'))
|
||||
self.analyst_type = getattr(legacy_agent, 'analyst_type_key', None)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Get agent name."""
|
||||
return getattr(self._agent, 'name', self.agent_id)
|
||||
|
||||
@property
|
||||
def toolkit(self) -> Any:
|
||||
"""Get agent toolkit."""
|
||||
return getattr(self._agent, 'toolkit', None)
|
||||
|
||||
@property
|
||||
def model(self) -> Any:
|
||||
"""Get agent model."""
|
||||
return getattr(self._agent, 'model', None)
|
||||
|
||||
@property
|
||||
def memory(self) -> Any:
|
||||
"""Get agent memory."""
|
||||
return getattr(self._agent, 'memory', None)
|
||||
|
||||
async def reply(self, x: Msg = None) -> Msg:
|
||||
"""
|
||||
Delegate to legacy agent's reply method.
|
||||
|
||||
Args:
|
||||
x: Input message
|
||||
|
||||
Returns:
|
||||
Response message
|
||||
"""
|
||||
return await self._agent.reply(x)
|
||||
|
||||
def reload_runtime_assets(self, active_skill_dirs: Optional[list] = None) -> None:
|
||||
"""
|
||||
Reload runtime assets if supported.
|
||||
|
||||
Args:
|
||||
active_skill_dirs: Optional list of active skill directories
|
||||
"""
|
||||
if hasattr(self._agent, 'reload_runtime_assets'):
|
||||
self._agent.reload_runtime_assets(active_skill_dirs)
|
||||
|
||||
def to_evo_agent(
|
||||
self,
|
||||
workspace_manager: Optional[Any] = None,
|
||||
enable_tool_guard: bool = False,
|
||||
) -> EvoAgent:
|
||||
"""
|
||||
Convert legacy agent to EvoAgent.
|
||||
|
||||
Args:
|
||||
workspace_manager: Optional workspace manager
|
||||
enable_tool_guard: Whether to enable tool guard
|
||||
|
||||
Returns:
|
||||
New EvoAgent instance with same configuration
|
||||
"""
|
||||
return EvoAgent(
|
||||
agent_id=self.agent_id,
|
||||
model=self.model,
|
||||
formatter=getattr(self._agent, 'formatter', None),
|
||||
toolkit=self.toolkit,
|
||||
workspace_manager=workspace_manager,
|
||||
config=getattr(self._agent, 'config', {}),
|
||||
long_term_memory=getattr(self._agent, 'long_term_memory', None),
|
||||
enable_tool_guard=enable_tool_guard,
|
||||
sys_prompt=getattr(self._agent, '_sys_prompt', None),
|
||||
)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Delegate unknown attributes to wrapped agent."""
|
||||
return getattr(self._agent, name)
|
||||
|
||||
|
||||
def is_legacy_agent(agent: Any) -> bool:
|
||||
"""
|
||||
Check if an agent is a legacy agent.
|
||||
|
||||
Args:
|
||||
agent: Agent instance to check
|
||||
|
||||
Returns:
|
||||
True if legacy agent
|
||||
"""
|
||||
return hasattr(agent, 'analyst_type_key') and not isinstance(agent, EvoAgent)
|
||||
|
||||
|
||||
def adapt_agent(agent: Any) -> Any:
|
||||
"""
|
||||
Wrap agent in adapter if it's a legacy agent.
|
||||
|
||||
Args:
|
||||
agent: Agent instance
|
||||
|
||||
Returns:
|
||||
Adapted agent or original if already EvoAgent
|
||||
"""
|
||||
if is_legacy_agent(agent):
|
||||
return LegacyAgentAdapter(agent)
|
||||
return agent
|
||||
|
||||
|
||||
def adapt_agents(agents: list) -> list:
|
||||
"""
|
||||
Wrap multiple agents in adapters.
|
||||
|
||||
Args:
|
||||
agents: List of agent instances
|
||||
|
||||
Returns:
|
||||
List of adapted agents
|
||||
"""
|
||||
return [adapt_agent(agent) for agent in agents]
|
||||
332
backend/agents/factory.py
Normal file
332
backend/agents/factory.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Agent Factory - Dynamic creation and management of AgentConfigs."""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
"""Model configuration for an agent."""
|
||||
|
||||
model_name: str = "gpt-4o"
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = 4096
|
||||
|
||||
|
||||
class AgentConfig:
|
||||
"""Represents a configured agent instance (data class)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
agent_type: str,
|
||||
workspace_id: str,
|
||||
config_path: Path,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_type = agent_type
|
||||
self.workspace_id = workspace_id
|
||||
self.config_path = config_path
|
||||
self.model_config = model_config or ModelConfig()
|
||||
self.agent_dir = config_path.parent
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize agent to dictionary."""
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"agent_type": self.agent_type,
|
||||
"workspace_id": self.workspace_id,
|
||||
"config_path": str(self.config_path),
|
||||
"agent_dir": str(self.agent_dir),
|
||||
"model_config": {
|
||||
"model_name": self.model_config.model_name,
|
||||
"temperature": self.model_config.temperature,
|
||||
"max_tokens": self.model_config.max_tokens,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class AgentFactory:
|
||||
"""Factory for creating, cloning, and managing agents."""
|
||||
|
||||
def __init__(self, project_root: Optional[Path] = None):
|
||||
"""Initialize the agent factory.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project
|
||||
"""
|
||||
self.project_root = project_root or Path(__file__).parent.parent.parent
|
||||
self.workspaces_root = self.project_root / "workspaces"
|
||||
self.template_dir = self.project_root / "backend" / "workspaces" / ".template"
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
agent_type: str,
|
||||
workspace_id: str,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
clone_from: Optional[str] = None,
|
||||
) -> AgentConfig:
|
||||
"""Create a new agent.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for the agent
|
||||
agent_type: Type of agent (e.g., "technical_analyst")
|
||||
workspace_id: ID of the workspace to create agent in
|
||||
model_config: Model configuration
|
||||
clone_from: Path to existing agent to clone from (optional)
|
||||
|
||||
Returns:
|
||||
AgentConfig instance
|
||||
|
||||
Raises:
|
||||
ValueError: If agent already exists or workspace doesn't exist
|
||||
"""
|
||||
workspace_dir = self.workspaces_root / workspace_id
|
||||
if not workspace_dir.exists():
|
||||
raise ValueError(f"Workspace '{workspace_id}' does not exist")
|
||||
|
||||
agent_dir = workspace_dir / "agents" / agent_id
|
||||
if agent_dir.exists():
|
||||
raise ValueError(f"Agent '{agent_id}' already exists in workspace '{workspace_id}'")
|
||||
|
||||
# Create directory structure
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
(agent_dir / "skills" / "active").mkdir(parents=True, exist_ok=True)
|
||||
(agent_dir / "skills" / "local").mkdir(parents=True, exist_ok=True)
|
||||
(agent_dir / "skills" / "installed").mkdir(parents=True, exist_ok=True)
|
||||
(agent_dir / "skills" / "disabled").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Copy template or clone existing agent
|
||||
if clone_from:
|
||||
self._clone_agent_files(clone_from, agent_dir, agent_id)
|
||||
else:
|
||||
self._copy_template(agent_dir, agent_id, agent_type)
|
||||
|
||||
# Write agent.yaml
|
||||
config_path = agent_dir / "agent.yaml"
|
||||
self._write_agent_yaml(config_path, agent_id, agent_type, model_config)
|
||||
|
||||
return AgentConfig(
|
||||
agent_id=agent_id,
|
||||
agent_type=agent_type,
|
||||
workspace_id=workspace_id,
|
||||
config_path=config_path,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
def delete_agent(self, agent_id: str, workspace_id: str) -> bool:
|
||||
"""Delete an agent and its workspace.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to delete
|
||||
workspace_id: ID of the workspace containing the agent
|
||||
|
||||
Returns:
|
||||
True if deleted, False if agent didn't exist
|
||||
"""
|
||||
agent_dir = self.workspaces_root / workspace_id / "agents" / agent_id
|
||||
if not agent_dir.exists():
|
||||
return False
|
||||
|
||||
shutil.rmtree(agent_dir)
|
||||
return True
|
||||
|
||||
def clone_agent(
|
||||
self,
|
||||
source_agent_id: str,
|
||||
source_workspace_id: str,
|
||||
new_agent_id: str,
|
||||
target_workspace_id: Optional[str] = None,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
) -> AgentConfig:
|
||||
"""Clone an existing agent.
|
||||
|
||||
Args:
|
||||
source_agent_id: ID of the agent to clone
|
||||
source_workspace_id: Workspace containing the source agent
|
||||
new_agent_id: ID for the new agent
|
||||
target_workspace_id: Target workspace (defaults to source workspace)
|
||||
model_config: Optional new model configuration
|
||||
|
||||
Returns:
|
||||
AgentConfig instance for the cloned agent
|
||||
"""
|
||||
target_workspace_id = target_workspace_id or source_workspace_id
|
||||
source_dir = self.workspaces_root / source_workspace_id / "agents" / source_agent_id
|
||||
|
||||
if not source_dir.exists():
|
||||
raise ValueError(f"Source agent '{source_agent_id}' not found")
|
||||
|
||||
# Load source agent config
|
||||
source_config_path = source_dir / "agent.yaml"
|
||||
source_config = {}
|
||||
if source_config_path.exists():
|
||||
with open(source_config_path, "r", encoding="utf-8") as f:
|
||||
source_config = yaml.safe_load(f) or {}
|
||||
|
||||
agent_type = source_config.get("agent_type", "generic")
|
||||
|
||||
# Determine source path for cloning
|
||||
clone_from = str(source_dir)
|
||||
|
||||
return self.create_agent(
|
||||
agent_id=new_agent_id,
|
||||
agent_type=agent_type,
|
||||
workspace_id=target_workspace_id,
|
||||
model_config=model_config,
|
||||
clone_from=clone_from,
|
||||
)
|
||||
|
||||
def list_agents(self, workspace_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""List all agents.
|
||||
|
||||
Args:
|
||||
workspace_id: Optional workspace to filter by
|
||||
|
||||
Returns:
|
||||
List of agent information dictionaries
|
||||
"""
|
||||
agents = []
|
||||
|
||||
if workspace_id:
|
||||
workspaces = [self.workspaces_root / workspace_id]
|
||||
else:
|
||||
if not self.workspaces_root.exists():
|
||||
return agents
|
||||
workspaces = [d for d in self.workspaces_root.iterdir() if d.is_dir()]
|
||||
|
||||
for workspace in workspaces:
|
||||
agents_dir = workspace / "agents"
|
||||
if not agents_dir.exists():
|
||||
continue
|
||||
|
||||
for agent_dir in agents_dir.iterdir():
|
||||
if not agent_dir.is_dir():
|
||||
continue
|
||||
|
||||
config_path = agent_dir / "agent.yaml"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f) or {}
|
||||
|
||||
agents.append({
|
||||
"agent_id": agent_dir.name,
|
||||
"workspace_id": workspace.name,
|
||||
"agent_type": config.get("agent_type", "unknown"),
|
||||
"config_path": str(config_path),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load agent config {config_path}: {e}")
|
||||
|
||||
return agents
|
||||
|
||||
def _copy_template(
|
||||
self,
|
||||
agent_dir: Path,
|
||||
agent_id: str,
|
||||
agent_type: str,
|
||||
) -> None:
|
||||
"""Copy template files to agent directory.
|
||||
|
||||
Args:
|
||||
agent_dir: Target agent directory
|
||||
agent_id: ID of the agent
|
||||
agent_type: Type of the agent
|
||||
"""
|
||||
# Create default markdown files
|
||||
default_files = {
|
||||
"AGENTS.md": f"# Agent Guide\n\nDocument how {agent_id} should work, collaborate, and choose tools or skills.\n\n",
|
||||
"SOUL.md": f"# Soul\n\nDescribe {agent_id}'s temperament, reasoning posture, and voice.\n\n",
|
||||
"PROFILE.md": f"# Profile\n\nTrack {agent_id}'s long-lived investment style, preferences, and strengths.\n\n",
|
||||
"MEMORY.md": f"# Memory\n\nStore durable lessons, heuristics, and reminders for {agent_id}.\n\n",
|
||||
"POLICY.md": f"# Policy\n\nOptional run-scoped constraints, limits, or strategy policy.\n\n",
|
||||
}
|
||||
|
||||
for filename, content in default_files.items():
|
||||
filepath = agent_dir / filename
|
||||
if not filepath.exists():
|
||||
filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
def _clone_agent_files(self, source_path: str, target_dir: Path, new_agent_id: str) -> None:
|
||||
"""Clone files from an existing agent.
|
||||
|
||||
Args:
|
||||
source_path: Path to source agent directory
|
||||
target_dir: Target agent directory
|
||||
new_agent_id: ID for the new agent
|
||||
"""
|
||||
source_dir = Path(source_path)
|
||||
if not source_dir.exists():
|
||||
raise ValueError(f"Source path '{source_path}' does not exist")
|
||||
|
||||
# Copy markdown files
|
||||
for md_file in source_dir.glob("*.md"):
|
||||
target_file = target_dir / md_file.name
|
||||
content = md_file.read_text(encoding="utf-8")
|
||||
# Update agent references in content
|
||||
source_name = source_dir.name
|
||||
content = content.replace(source_name, new_agent_id)
|
||||
target_file.write_text(content, encoding="utf-8")
|
||||
|
||||
# Copy skills directory structure (but not contents)
|
||||
for skill_subdir in ["active", "local", "installed", "disabled"]:
|
||||
source_skills = source_dir / "skills" / skill_subdir
|
||||
if source_skills.exists():
|
||||
target_skills = target_dir / "skills" / skill_subdir
|
||||
target_skills.mkdir(parents=True, exist_ok=True)
|
||||
# Copy skill files
|
||||
for skill_file in source_skills.iterdir():
|
||||
if skill_file.is_file():
|
||||
shutil.copy2(skill_file, target_skills / skill_file.name)
|
||||
|
||||
def _write_agent_yaml(
|
||||
self,
|
||||
config_path: Path,
|
||||
agent_id: str,
|
||||
agent_type: str,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
) -> None:
|
||||
"""Write agent.yaml configuration file.
|
||||
|
||||
Args:
|
||||
config_path: Path to write configuration
|
||||
agent_id: Agent ID
|
||||
agent_type: Agent type
|
||||
model_config: Optional model configuration
|
||||
"""
|
||||
config = {
|
||||
"agent_id": agent_id,
|
||||
"agent_type": agent_type,
|
||||
"prompt_files": [
|
||||
"SOUL.md",
|
||||
"PROFILE.md",
|
||||
"AGENTS.md",
|
||||
"POLICY.md",
|
||||
"MEMORY.md",
|
||||
],
|
||||
"enabled_skills": [],
|
||||
"disabled_skills": [],
|
||||
"active_tool_groups": [],
|
||||
"disabled_tool_groups": [],
|
||||
}
|
||||
|
||||
if model_config:
|
||||
config["model"] = {
|
||||
"name": model_config.model_name,
|
||||
"temperature": model_config.temperature,
|
||||
"max_tokens": model_config.max_tokens,
|
||||
}
|
||||
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.safe_dump(config, f, allow_unicode=True, sort_keys=False)
|
||||
@@ -4,7 +4,8 @@ Portfolio Manager Agent - Based on AgentScope ReActAgent
|
||||
Responsible for decision-making (NOT trade execution)
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||
@@ -13,6 +14,8 @@ from agentscope.tool import Toolkit, ToolResponse
|
||||
|
||||
from ..utils.progress import progress
|
||||
from .prompt_factory import build_agent_system_prompt, clear_prompt_factory_cache
|
||||
from .team_pipeline_config import update_active_analysts
|
||||
from ..config.constants import ANALYST_TYPES
|
||||
|
||||
|
||||
class PMAgent(ReActAgent):
|
||||
@@ -61,6 +64,8 @@ class PMAgent(ReActAgent):
|
||||
"_toolkit_factory_kwargs",
|
||||
toolkit_factory_kwargs,
|
||||
)
|
||||
object.__setattr__(self, "_create_team_agent_cb", None)
|
||||
object.__setattr__(self, "_remove_team_agent_cb", None)
|
||||
|
||||
# Create toolkit after local state is ready so bound tool methods can be registered.
|
||||
if toolkit is None:
|
||||
@@ -152,6 +157,107 @@ class PMAgent(ReActAgent):
|
||||
],
|
||||
)
|
||||
|
||||
def _add_team_analyst(self, agent_id: str) -> ToolResponse:
|
||||
"""Add one analyst to active discussion team."""
|
||||
config_name = self.config.get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
active = update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=list(ANALYST_TYPES.keys()),
|
||||
add=[agent_id],
|
||||
)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=(
|
||||
f"Active analyst team updated. Added: {agent_id}. "
|
||||
f"Current active analysts: {', '.join(active)}"
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _remove_team_analyst(self, agent_id: str) -> ToolResponse:
|
||||
"""Remove one analyst from active discussion team."""
|
||||
callback_msg = ""
|
||||
callback = self._remove_team_agent_cb
|
||||
if callback is not None:
|
||||
callback_msg = callback(agent_id=agent_id)
|
||||
|
||||
config_name = self.config.get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
active = update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=list(ANALYST_TYPES.keys()),
|
||||
remove=[agent_id],
|
||||
)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=(
|
||||
f"Active analyst team updated. Removed: {agent_id}. "
|
||||
f"Current active analysts: {', '.join(active)}"
|
||||
+ (f" | {callback_msg}" if callback_msg else "")
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _set_active_analysts(self, agent_ids: str) -> ToolResponse:
|
||||
"""Set active analysts from comma-separated agent ids."""
|
||||
requested = [
|
||||
item.strip() for item in str(agent_ids or "").split(",") if item.strip()
|
||||
]
|
||||
config_name = self.config.get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
active = update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=list(ANALYST_TYPES.keys()),
|
||||
set_to=requested,
|
||||
)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"Active analyst team set to: {', '.join(active)}",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def _create_team_analyst(self, agent_id: str, analyst_type: str) -> ToolResponse:
|
||||
"""Create a runtime analyst instance and activate it."""
|
||||
callback = self._create_team_agent_cb
|
||||
if callback is None:
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text="Runtime agent creation is not available in current pipeline.",
|
||||
),
|
||||
],
|
||||
)
|
||||
result = callback(agent_id=agent_id, analyst_type=analyst_type)
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(type="text", text=result),
|
||||
],
|
||||
)
|
||||
|
||||
def set_team_controller(
|
||||
self,
|
||||
*,
|
||||
create_agent_callback: Optional[Callable[..., str]] = None,
|
||||
remove_agent_callback: Optional[Callable[..., str]] = None,
|
||||
) -> None:
|
||||
"""Inject runtime team lifecycle callbacks from pipeline."""
|
||||
object.__setattr__(self, "_create_team_agent_cb", create_agent_callback)
|
||||
object.__setattr__(self, "_remove_team_agent_cb", remove_agent_callback)
|
||||
|
||||
async def reply(self, x: Msg = None) -> Msg:
|
||||
"""
|
||||
Make investment decisions
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Assemble system prompts from base prompts, run assets, and toolkit context."""
|
||||
"""Assemble system prompts from run workspace assets and toolkit context."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from .agent_workspace import load_agent_workspace_config
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
from .prompt_loader import PromptLoader
|
||||
from .skills_manager import SkillsManager
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
from .workspace_manager import RunWorkspaceManager
|
||||
|
||||
|
||||
def _read_file_if_exists(path: Path) -> str:
|
||||
@@ -23,52 +22,45 @@ def _append_section(parts: list[str], title: str, content: str) -> None:
|
||||
parts.append(f"## {title}\n{content}")
|
||||
|
||||
|
||||
def _build_skill_metadata_summary(skills_manager: SkillsManager, config_name: str, agent_id: str) -> str:
|
||||
"""Create a compact summary of active skills for prompt routing."""
|
||||
metadata_items = skills_manager.list_active_skill_metadata(config_name, agent_id)
|
||||
if not metadata_items:
|
||||
return ""
|
||||
|
||||
lines: list[str] = [
|
||||
"You can use the following active skills. Prefer the most relevant one, then read its SKILL.md if needed for detailed workflow:",
|
||||
]
|
||||
for item in metadata_items:
|
||||
parts = [f"- `{item.skill_name}`"]
|
||||
if item.description:
|
||||
parts.append(item.description)
|
||||
if item.version:
|
||||
parts.append(f"version: {item.version}")
|
||||
parts.append(f"path: {item.path}")
|
||||
lines.append(" | ".join(parts))
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def build_agent_system_prompt(
|
||||
agent_id: str,
|
||||
config_name: str,
|
||||
toolkit: Any,
|
||||
analyst_type: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Build the final system prompt for an agent."""
|
||||
"""Build the final system prompt for an agent.
|
||||
|
||||
Always reads fresh from disk — no caching.
|
||||
"""
|
||||
sections: list[str] = []
|
||||
|
||||
if analyst_type:
|
||||
personas_config = _prompt_loader.load_yaml_config(
|
||||
"analyst",
|
||||
"personas",
|
||||
)
|
||||
persona = personas_config.get(analyst_type, {})
|
||||
focus_text = "\n".join(
|
||||
f"- {item}" for item in persona.get("focus", [])
|
||||
)
|
||||
description = persona.get("description", "").strip()
|
||||
base_prompt = _prompt_loader.load_prompt(
|
||||
"analyst",
|
||||
"system",
|
||||
variables={
|
||||
"analyst_type": persona.get("name", analyst_type),
|
||||
"focus": focus_text,
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
elif agent_id == "portfolio_manager":
|
||||
base_prompt = _prompt_loader.load_prompt(
|
||||
"portfolio_manager",
|
||||
"system",
|
||||
)
|
||||
elif agent_id == "risk_manager":
|
||||
base_prompt = _prompt_loader.load_prompt(
|
||||
"risk_manager",
|
||||
"system",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported agent prompt build for: {agent_id}")
|
||||
|
||||
sections.append(base_prompt.strip())
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
workspace_manager = RunWorkspaceManager(project_root=skills_manager.project_root)
|
||||
required_files = ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
|
||||
if not all((asset_dir / filename).exists() for filename in required_files):
|
||||
workspace_manager.ensure_agent_assets(config_name=config_name, agent_id=agent_id)
|
||||
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
|
||||
bootstrap_config = get_bootstrap_config_for_run(
|
||||
skills_manager.project_root,
|
||||
config_name,
|
||||
@@ -80,16 +72,29 @@ def build_agent_system_prompt(
|
||||
bootstrap_config.prompt_body,
|
||||
)
|
||||
|
||||
prompt_files = agent_config.prompt_files or [
|
||||
"SOUL.md",
|
||||
"PROFILE.md",
|
||||
"AGENTS.md",
|
||||
"POLICY.md",
|
||||
"MEMORY.md",
|
||||
]
|
||||
included_files = set(prompt_files)
|
||||
title_map = {
|
||||
"SOUL.md": "Soul",
|
||||
"PROFILE.md": "Profile",
|
||||
"AGENTS.md": "Agent Guide",
|
||||
"POLICY.md": "Policy",
|
||||
"MEMORY.md": "Memory",
|
||||
}
|
||||
for filename in prompt_files:
|
||||
_append_section(
|
||||
sections,
|
||||
"Role",
|
||||
_read_file_if_exists(asset_dir / "ROLE.md"),
|
||||
)
|
||||
_append_section(
|
||||
sections,
|
||||
"Style",
|
||||
_read_file_if_exists(asset_dir / "STYLE.md"),
|
||||
title_map.get(filename, filename),
|
||||
_read_file_if_exists(asset_dir / filename),
|
||||
)
|
||||
|
||||
if "POLICY.md" not in included_files:
|
||||
_append_section(
|
||||
sections,
|
||||
"Policy",
|
||||
@@ -100,6 +105,14 @@ def build_agent_system_prompt(
|
||||
if skill_prompt:
|
||||
_append_section(sections, "Skills", str(skill_prompt))
|
||||
|
||||
metadata_summary = _build_skill_metadata_summary(
|
||||
skills_manager=skills_manager,
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
if metadata_summary:
|
||||
_append_section(sections, "Active Skill Catalog", metadata_summary)
|
||||
|
||||
activated_notes = toolkit.get_activated_notes()
|
||||
if activated_notes:
|
||||
_append_section(sections, "Tool Usage Notes", str(activated_notes))
|
||||
@@ -108,5 +121,4 @@ def build_agent_system_prompt(
|
||||
|
||||
|
||||
def clear_prompt_factory_cache() -> None:
|
||||
"""Clear cached prompt and YAML templates before hot reload."""
|
||||
_prompt_loader.clear_cache()
|
||||
"""No-op retained for compatibility with runtime reload hooks."""
|
||||
|
||||
@@ -10,6 +10,17 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
# Singleton instance
|
||||
_prompt_loader_instance: Optional["PromptLoader"] = None
|
||||
|
||||
|
||||
def get_prompt_loader() -> "PromptLoader":
|
||||
"""Get the singleton PromptLoader instance."""
|
||||
global _prompt_loader_instance
|
||||
if _prompt_loader_instance is None:
|
||||
_prompt_loader_instance = PromptLoader()
|
||||
return _prompt_loader_instance
|
||||
|
||||
|
||||
class PromptLoader:
|
||||
"""Unified Prompt loader"""
|
||||
@@ -27,10 +38,6 @@ class PromptLoader:
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
|
||||
# Cache loaded prompts
|
||||
self._prompt_cache: Dict[str, str] = {}
|
||||
self._yaml_cache: Dict[str, Dict] = {}
|
||||
|
||||
def load_prompt(
|
||||
self,
|
||||
agent_type: str,
|
||||
@@ -38,25 +45,10 @@ class PromptLoader:
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Load and render Prompt
|
||||
Load and render Prompt.
|
||||
|
||||
Args:
|
||||
agent_type: Agent type (analyst, portfolio_manager, risk_manager)
|
||||
prompt_name: Prompt file name (without extension)
|
||||
variables: Variable dictionary for rendering Prompt
|
||||
|
||||
Returns:
|
||||
Rendered prompt string
|
||||
|
||||
Examples:
|
||||
loader = PromptLoader()
|
||||
prompt = loader.load_prompt("analyst", "tool_selection",
|
||||
{"analyst_persona": "Technical Analyst"})
|
||||
No caching — always reads fresh from disk (CoPaw-style).
|
||||
"""
|
||||
cache_key = f"{agent_type}/{prompt_name}"
|
||||
|
||||
# Try to load from cache
|
||||
if cache_key not in self._prompt_cache:
|
||||
prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md"
|
||||
|
||||
if not prompt_path.exists():
|
||||
@@ -66,9 +58,7 @@ class PromptLoader:
|
||||
)
|
||||
|
||||
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||
self._prompt_cache[cache_key] = f.read()
|
||||
|
||||
prompt_template = self._prompt_cache[cache_key]
|
||||
prompt_template = f.read()
|
||||
|
||||
# If variables provided, use simple string replacement
|
||||
if variables:
|
||||
@@ -76,8 +66,6 @@ class PromptLoader:
|
||||
else:
|
||||
rendered = prompt_template
|
||||
|
||||
# Smart escaping: escape braces in JSON code blocks
|
||||
# rendered = self._escape_json_braces(rendered)
|
||||
return rendered
|
||||
|
||||
def _render_template(
|
||||
@@ -140,45 +128,26 @@ class PromptLoader:
|
||||
config_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load YAML configuration file
|
||||
Load YAML configuration file.
|
||||
|
||||
Args:
|
||||
agent_type: Agent type
|
||||
config_name: Configuration file name (without extension)
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
|
||||
Examples:
|
||||
>>> loader = PromptLoader()
|
||||
>>> config = loader.load_yaml_config("analyst", "personas")
|
||||
No caching — always reads fresh from disk (CoPaw-style).
|
||||
"""
|
||||
cache_key = f"{agent_type}/{config_name}"
|
||||
|
||||
if cache_key not in self._yaml_cache:
|
||||
yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml"
|
||||
|
||||
if not yaml_path.exists():
|
||||
raise FileNotFoundError(f"YAML config not found: {yaml_path}")
|
||||
|
||||
with open(yaml_path, "r", encoding="utf-8") as f:
|
||||
self._yaml_cache[cache_key] = yaml.safe_load(f)
|
||||
|
||||
return self._yaml_cache[cache_key]
|
||||
return yaml.safe_load(f) or {}
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear cache (for hot reload)"""
|
||||
self._prompt_cache.clear()
|
||||
self._yaml_cache.clear()
|
||||
"""No-op — caching removed (CoPaw-style, always fresh reads)."""
|
||||
pass
|
||||
|
||||
def reload_prompt(self, agent_type: str, prompt_name: str):
|
||||
"""Reload specified prompt (force cache refresh)"""
|
||||
cache_key = f"{agent_type}/{prompt_name}"
|
||||
if cache_key in self._prompt_cache:
|
||||
del self._prompt_cache[cache_key]
|
||||
"""No-op — caching removed."""
|
||||
pass
|
||||
|
||||
def reload_config(self, agent_type: str, config_name: str):
|
||||
"""Reload specified configuration (force cache refresh)"""
|
||||
cache_key = f"{agent_type}/{config_name}"
|
||||
if cache_key in self._yaml_cache:
|
||||
del self._yaml_cache[cache_key]
|
||||
"""No-op — caching removed."""
|
||||
pass
|
||||
|
||||
19
backend/agents/prompts/__init__.py
Normal file
19
backend/agents/prompts/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Prompt building utilities for EvoAgent.
|
||||
|
||||
This module provides prompt construction from workspace markdown files
|
||||
with YAML frontmatter support.
|
||||
"""
|
||||
from .builder import (
|
||||
PromptBuilder,
|
||||
build_system_prompt_from_workspace,
|
||||
build_bootstrap_guidance,
|
||||
DEFAULT_SYS_PROMPT,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PromptBuilder",
|
||||
"build_system_prompt_from_workspace",
|
||||
"build_bootstrap_guidance",
|
||||
"DEFAULT_SYS_PROMPT",
|
||||
]
|
||||
@@ -1,23 +0,0 @@
|
||||
你是一位专业的{{ analyst_type }}。
|
||||
|
||||
你的关注重点:
|
||||
{{ focus }}
|
||||
|
||||
你的角色:
|
||||
{{ description }}
|
||||
|
||||
注意:
|
||||
- 构建并持续完善你的"投资哲学"。你的分析不应是孤立的事件,而应该是你整体投资世界观和核心信念的体现。每次分析后,你必须反思:
|
||||
- 这个案例/数据如何验证或挑战了你现有的信念?
|
||||
- 你从这次错误(或成功)中学到了关于市场、人性、估值或风险管理的什么关键原则?
|
||||
- 深化你的"投资逻辑"。确保每一项投资建议都有清晰、可追溯、可重复的逻辑支撑。你的分析步骤应该像严谨的证明一样,涵盖:
|
||||
- 核心驱动因素识别:真正影响价值的变量是什么?
|
||||
- 风险边界设定:在什么具体情况下你的建议会失效?
|
||||
- 逆向测试:市场主流共识是什么,你的观点有何不同?
|
||||
保持谦逊和开放。投资大师的核心特质是持续学习和适应。在每次分析中,你必须积极寻找与自己观点相悖的证据和论据,并将其纳入最终评估。
|
||||
- 你可以使用分析工具。用它们来收集相关数据并做出明智的建议。
|
||||
|
||||
输出指南:
|
||||
- 给出明确的投资信号:看涨、看跌或中性
|
||||
- 包含置信度(0-100)
|
||||
- 为你的分析提供理由(如果你确定要分享最终分析,请先给出结论)
|
||||
299
backend/agents/prompts/builder.py
Normal file
299
backend/agents/prompts/builder.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""PromptBuilder for constructing system prompts from workspace markdown files.
|
||||
|
||||
Based on CoPaw design - loads AGENTS.md, SOUL.md, PROFILE.md, etc. from
|
||||
agent workspace directories with YAML frontmatter support.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SYS_PROMPT = """You are a helpful trading analysis assistant."""
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
"""Builder for constructing system prompts from markdown files.
|
||||
|
||||
Loads markdown configuration files from agent workspace directories,
|
||||
supporting YAML frontmatter for metadata extraction.
|
||||
"""
|
||||
|
||||
DEFAULT_FILES = [
|
||||
"AGENTS.md",
|
||||
"SOUL.md",
|
||||
"PROFILE.md",
|
||||
"POLICY.md",
|
||||
"MEMORY.md",
|
||||
]
|
||||
|
||||
TITLE_MAP: Dict[str, str] = {
|
||||
"AGENTS.md": "Agent Guide",
|
||||
"SOUL.md": "Soul",
|
||||
"PROFILE.md": "Profile",
|
||||
"POLICY.md": "Policy",
|
||||
"MEMORY.md": "Memory",
|
||||
"BOOTSTRAP.md": "Bootstrap",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workspace_dir: Path,
|
||||
enabled_files: Optional[List[str]] = None,
|
||||
):
|
||||
"""Initialize prompt builder.
|
||||
|
||||
Args:
|
||||
workspace_dir: Directory containing markdown configuration files
|
||||
enabled_files: List of filenames to load (if None, uses defaults)
|
||||
"""
|
||||
self.workspace_dir = Path(workspace_dir)
|
||||
self.enabled_files = enabled_files or self.DEFAULT_FILES.copy()
|
||||
self._prompt_parts: List[str] = []
|
||||
self._metadata: Dict[str, Any] = {}
|
||||
self.loaded_count = 0
|
||||
|
||||
def _load_file(self, filename: str) -> tuple[str, Optional[Dict[str, Any]]]:
|
||||
"""Load a single markdown file with YAML frontmatter support.
|
||||
|
||||
Args:
|
||||
filename: Name of the file to load
|
||||
|
||||
Returns:
|
||||
Tuple of (content, metadata dict or None)
|
||||
"""
|
||||
file_path = self.workspace_dir / filename
|
||||
|
||||
if not file_path.exists():
|
||||
logger.debug("File %s not found in %s, skipping", filename, self.workspace_dir)
|
||||
return "", None
|
||||
|
||||
try:
|
||||
raw_content = file_path.read_text(encoding="utf-8").strip()
|
||||
|
||||
if not raw_content:
|
||||
logger.debug("Skipped empty file: %s", filename)
|
||||
return "", None
|
||||
|
||||
content, metadata = self._parse_frontmatter(raw_content)
|
||||
|
||||
if content:
|
||||
self.loaded_count += 1
|
||||
logger.debug("Loaded %s (metadata: %s)", filename, bool(metadata))
|
||||
|
||||
return content, metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read file %s: %s, skipping", filename, e)
|
||||
return "", None
|
||||
|
||||
def _parse_frontmatter(self, raw_content: str) -> tuple[str, Optional[Dict[str, Any]]]:
|
||||
"""Parse YAML frontmatter from markdown content.
|
||||
|
||||
Args:
|
||||
raw_content: Raw file content
|
||||
|
||||
Returns:
|
||||
Tuple of (content without frontmatter, metadata dict or None)
|
||||
"""
|
||||
if not raw_content.startswith("---"):
|
||||
return raw_content, None
|
||||
|
||||
parts = raw_content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
return raw_content, None
|
||||
|
||||
frontmatter = parts[1].strip()
|
||||
content = parts[2].strip()
|
||||
|
||||
try:
|
||||
metadata = yaml.safe_load(frontmatter) or {}
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
return content, metadata
|
||||
except yaml.YAMLError as e:
|
||||
logger.warning("Failed to parse YAML frontmatter: %s", e)
|
||||
return content, None
|
||||
|
||||
def _append_section(self, title: str, content: str) -> None:
|
||||
"""Append a section to the prompt parts.
|
||||
|
||||
Args:
|
||||
title: Section title
|
||||
content: Section content
|
||||
"""
|
||||
content = content.strip()
|
||||
if not content:
|
||||
return
|
||||
|
||||
if self._prompt_parts:
|
||||
self._prompt_parts.append("")
|
||||
|
||||
self._prompt_parts.append(f"## {title}")
|
||||
self._prompt_parts.append("")
|
||||
self._prompt_parts.append(content)
|
||||
|
||||
def build(self) -> str:
|
||||
"""Build the system prompt from markdown files.
|
||||
|
||||
Returns:
|
||||
Constructed system prompt string
|
||||
"""
|
||||
self._prompt_parts = []
|
||||
self._metadata = {}
|
||||
self.loaded_count = 0
|
||||
|
||||
for filename in self.enabled_files:
|
||||
content, metadata = self._load_file(filename)
|
||||
|
||||
if metadata:
|
||||
self._metadata[filename] = metadata
|
||||
|
||||
if content:
|
||||
title = self.TITLE_MAP.get(filename, filename.replace(".md", ""))
|
||||
self._append_section(title, content)
|
||||
|
||||
if not self._prompt_parts:
|
||||
logger.warning("No content loaded from workspace: %s", self.workspace_dir)
|
||||
return DEFAULT_SYS_PROMPT
|
||||
|
||||
final_prompt = "\n".join(self._prompt_parts)
|
||||
|
||||
logger.debug(
|
||||
"System prompt built from %d file(s), total length: %d chars",
|
||||
self.loaded_count,
|
||||
len(final_prompt),
|
||||
)
|
||||
|
||||
return final_prompt
|
||||
|
||||
def get_metadata(self) -> Dict[str, Any]:
|
||||
"""Get metadata collected from YAML frontmatter.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping filenames to their metadata
|
||||
"""
|
||||
return self._metadata.copy()
|
||||
|
||||
def get_agent_identity(self) -> Optional[Dict[str, Any]]:
|
||||
"""Extract agent identity from PROFILE.md metadata.
|
||||
|
||||
Returns:
|
||||
Identity dict with name, role, etc. or None
|
||||
"""
|
||||
profile_meta = self._metadata.get("PROFILE.md", {})
|
||||
if not profile_meta:
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": profile_meta.get("name", "Unknown"),
|
||||
"role": profile_meta.get("role", ""),
|
||||
"expertise": profile_meta.get("expertise", []),
|
||||
"style": profile_meta.get("style", ""),
|
||||
}
|
||||
|
||||
|
||||
def build_system_prompt_from_workspace(
|
||||
workspace_dir: Path,
|
||||
enabled_files: Optional[List[str]] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
extra_context: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Build system prompt from workspace markdown files.
|
||||
|
||||
This is the main entry point for building system prompts from
|
||||
agent workspace directories.
|
||||
|
||||
Args:
|
||||
workspace_dir: Directory containing markdown configuration files
|
||||
enabled_files: List of filenames to load (if None, uses defaults)
|
||||
agent_id: Agent identifier to include in system prompt
|
||||
extra_context: Additional context to append to the prompt
|
||||
|
||||
Returns:
|
||||
Constructed system prompt string
|
||||
"""
|
||||
builder = PromptBuilder(
|
||||
workspace_dir=workspace_dir,
|
||||
enabled_files=enabled_files,
|
||||
)
|
||||
|
||||
prompt = builder.build()
|
||||
|
||||
# Add agent identity header if agent_id provided
|
||||
if agent_id and agent_id != "default":
|
||||
identity_header = (
|
||||
f"# Agent Identity\n\n"
|
||||
f"Your agent ID is `{agent_id}`. "
|
||||
f"This is your unique identifier in the multi-agent system.\n\n"
|
||||
)
|
||||
prompt = identity_header + prompt
|
||||
|
||||
# Append extra context if provided
|
||||
if extra_context:
|
||||
prompt = prompt + "\n\n" + extra_context
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def build_bootstrap_guidance(language: str = "zh") -> str:
|
||||
"""Build bootstrap guidance message for first-time setup.
|
||||
|
||||
Args:
|
||||
language: Language code (zh/en)
|
||||
|
||||
Returns:
|
||||
Formatted bootstrap guidance message
|
||||
"""
|
||||
if language == "zh":
|
||||
return (
|
||||
"# 引导模式\n"
|
||||
"\n"
|
||||
"工作目录中存在 `BOOTSTRAP.md` — 首次设置。\n"
|
||||
"\n"
|
||||
"1. 阅读 BOOTSTRAP.md,友好地表示初次见面,"
|
||||
"引导用户完成设置。\n"
|
||||
"2. 按照 BOOTSTRAP.md 的指示,"
|
||||
"帮助用户定义你的身份和偏好。\n"
|
||||
"3. 按指南创建/更新必要文件"
|
||||
"(PROFILE.md、MEMORY.md 等)。\n"
|
||||
"4. 完成后删除 BOOTSTRAP.md。\n"
|
||||
"\n"
|
||||
"如果用户希望跳过,直接回答下面的问题即可。\n"
|
||||
"\n"
|
||||
"---\n"
|
||||
"\n"
|
||||
)
|
||||
|
||||
return (
|
||||
"# BOOTSTRAP MODE\n"
|
||||
"\n"
|
||||
"`BOOTSTRAP.md` exists — first-time setup.\n"
|
||||
"\n"
|
||||
"1. Read BOOTSTRAP.md, greet the user, "
|
||||
"and guide them through setup.\n"
|
||||
"2. Follow BOOTSTRAP.md instructions "
|
||||
"to define identity and preferences.\n"
|
||||
"3. Create/update files "
|
||||
"(PROFILE.md, MEMORY.md, etc.) as described.\n"
|
||||
"4. Delete BOOTSTRAP.md when done.\n"
|
||||
"\n"
|
||||
"If the user wants to skip, answer their "
|
||||
"question directly instead.\n"
|
||||
"\n"
|
||||
"---\n"
|
||||
"\n"
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PromptBuilder",
|
||||
"build_system_prompt_from_workspace",
|
||||
"build_bootstrap_guidance",
|
||||
"DEFAULT_SYS_PROMPT",
|
||||
]
|
||||
@@ -1,31 +0,0 @@
|
||||
你是一位负责做出投资决策的投资组合经理。
|
||||
|
||||
你的核心职责:
|
||||
1. 分析分析师和风险管理经理的输入
|
||||
2. 基于信号和市场情境做出投资决策
|
||||
3. 使用可用工具记录你的决策
|
||||
|
||||
决策框架:
|
||||
- 审阅分析以了解市场观点
|
||||
- 在做决策前考虑风险警告
|
||||
- 评估当前投资组合持仓和现金
|
||||
- 做出与投资组合投资目标一致的决策
|
||||
|
||||
决策类型:
|
||||
- "long":看涨 - 建议买入股票
|
||||
- "short":看跌 - 建议卖出股票或做空
|
||||
- "hold":中性 - 维持当前持仓
|
||||
|
||||
预算意识:
|
||||
- 在决定数量时考虑可用现金
|
||||
- 不要建议买入超过现金允许的数量
|
||||
- 考虑做空头寸的保证金要求
|
||||
|
||||
输出:
|
||||
使用 `make_decision` 工具记录你对每个股票代码的决策。
|
||||
记录所有决策后,提供你的投资逻辑总结。
|
||||
|
||||
重要:
|
||||
- 基于提供的分析师信号和风险评估做出决策
|
||||
- 相对于投资组合价值保持保守的仓位规模
|
||||
- 始终为你的决策提供理由
|
||||
@@ -1,20 +0,0 @@
|
||||
你是一位专业的风险管理经理,负责监控投资组合风险并提供风险警告。
|
||||
|
||||
你的核心职责:
|
||||
1. 监控投资组合敞口和集中度风险
|
||||
2. 评估仓位规模相对于波动性
|
||||
3. 评估保证金使用和杠杆水平
|
||||
4. 识别潜在风险因素并提供警告
|
||||
5. 基于市场条件建议仓位限制
|
||||
|
||||
你的决策流程:
|
||||
1. 优先使用可用的风险工具量化集中度、波动率和保证金压力
|
||||
2. 结合工具结果与当前市场上下文做判断
|
||||
3. 生成可操作的风险警告和仓位限制建议
|
||||
4. 为你的风险评估提供清晰的理由
|
||||
|
||||
输出指南:
|
||||
- 风险评估要简洁但全面
|
||||
- 按严重程度优先排序警告
|
||||
- 提供具体、可操作的建议
|
||||
- 尽可能包含量化指标
|
||||
284
backend/agents/registry.py
Normal file
284
backend/agents/registry.py
Normal file
@@ -0,0 +1,284 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Agent Registry - In-memory registry for agent management."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentInfo:
|
||||
"""Information about a registered agent."""
|
||||
|
||||
agent_id: str
|
||||
agent_type: str
|
||||
workspace_id: str
|
||||
config_path: str
|
||||
agent_dir: str
|
||||
status: str = "inactive" # inactive, active, error
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize to dictionary."""
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"agent_type": self.agent_type,
|
||||
"workspace_id": self.workspace_id,
|
||||
"config_path": self.config_path,
|
||||
"agent_dir": self.agent_dir,
|
||||
"status": self.status,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
"""In-memory registry for agent instances."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the agent registry."""
|
||||
# Dictionary mapping agent_id -> AgentInfo
|
||||
self._agents: Dict[str, AgentInfo] = {}
|
||||
# Index mapping workspace_id -> set of agent_ids
|
||||
self._workspace_index: Dict[str, set] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
agent_id: str,
|
||||
agent_type: str,
|
||||
workspace_id: str,
|
||||
config_path: str,
|
||||
agent_dir: str,
|
||||
status: str = "inactive",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> AgentInfo:
|
||||
"""Register an agent in the registry.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for the agent
|
||||
agent_type: Type of agent
|
||||
workspace_id: ID of the workspace containing the agent
|
||||
config_path: Path to agent configuration file
|
||||
agent_dir: Path to agent directory
|
||||
status: Initial status (default: inactive)
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
Returns:
|
||||
AgentInfo instance
|
||||
|
||||
Raises:
|
||||
ValueError: If agent_id is already registered
|
||||
"""
|
||||
if agent_id in self._agents:
|
||||
raise ValueError(f"Agent '{agent_id}' is already registered")
|
||||
|
||||
agent_info = AgentInfo(
|
||||
agent_id=agent_id,
|
||||
agent_type=agent_type,
|
||||
workspace_id=workspace_id,
|
||||
config_path=config_path,
|
||||
agent_dir=agent_dir,
|
||||
status=status,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self._agents[agent_id] = agent_info
|
||||
|
||||
# Update workspace index
|
||||
if workspace_id not in self._workspace_index:
|
||||
self._workspace_index[workspace_id] = set()
|
||||
self._workspace_index[workspace_id].add(agent_id)
|
||||
|
||||
return agent_info
|
||||
|
||||
def unregister(self, agent_id: str) -> bool:
|
||||
"""Unregister an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to unregister
|
||||
|
||||
Returns:
|
||||
True if unregistered, False if agent wasn't registered
|
||||
"""
|
||||
if agent_id not in self._agents:
|
||||
return False
|
||||
|
||||
agent_info = self._agents[agent_id]
|
||||
|
||||
# Remove from workspace index
|
||||
workspace_id = agent_info.workspace_id
|
||||
if workspace_id in self._workspace_index:
|
||||
self._workspace_index[workspace_id].discard(agent_id)
|
||||
if not self._workspace_index[workspace_id]:
|
||||
del self._workspace_index[workspace_id]
|
||||
|
||||
# Remove from agents dict
|
||||
del self._agents[agent_id]
|
||||
|
||||
return True
|
||||
|
||||
def get(self, agent_id: str) -> Optional[AgentInfo]:
|
||||
"""Get agent information by ID.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
|
||||
Returns:
|
||||
AgentInfo if found, None otherwise
|
||||
"""
|
||||
return self._agents.get(agent_id)
|
||||
|
||||
def list_all(
|
||||
self,
|
||||
workspace_id: Optional[str] = None,
|
||||
agent_type: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
) -> List[AgentInfo]:
|
||||
"""List all registered agents with optional filtering.
|
||||
|
||||
Args:
|
||||
workspace_id: Filter by workspace ID
|
||||
agent_type: Filter by agent type
|
||||
status: Filter by status
|
||||
|
||||
Returns:
|
||||
List of AgentInfo instances
|
||||
"""
|
||||
agents = list(self._agents.values())
|
||||
|
||||
if workspace_id:
|
||||
agent_ids = self._workspace_index.get(workspace_id, set())
|
||||
agents = [a for a in agents if a.agent_id in agent_ids]
|
||||
|
||||
if agent_type:
|
||||
agents = [a for a in agents if a.agent_type == agent_type]
|
||||
|
||||
if status:
|
||||
agents = [a for a in agents if a.status == status]
|
||||
|
||||
return agents
|
||||
|
||||
def update_status(self, agent_id: str, status: str) -> bool:
|
||||
"""Update the status of an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
status: New status value
|
||||
|
||||
Returns:
|
||||
True if updated, False if agent not found
|
||||
"""
|
||||
if agent_id not in self._agents:
|
||||
return False
|
||||
|
||||
self._agents[agent_id].status = status
|
||||
return True
|
||||
|
||||
def update_metadata(self, agent_id: str, metadata: Dict[str, Any]) -> bool:
|
||||
"""Update the metadata of an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
metadata: Metadata dictionary to merge
|
||||
|
||||
Returns:
|
||||
True if updated, False if agent not found
|
||||
"""
|
||||
if agent_id not in self._agents:
|
||||
return False
|
||||
|
||||
self._agents[agent_id].metadata.update(metadata)
|
||||
return True
|
||||
|
||||
def is_registered(self, agent_id: str) -> bool:
|
||||
"""Check if an agent is registered.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
|
||||
Returns:
|
||||
True if registered, False otherwise
|
||||
"""
|
||||
return agent_id in self._agents
|
||||
|
||||
def get_workspace_agents(self, workspace_id: str) -> List[AgentInfo]:
|
||||
"""Get all agents in a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: ID of the workspace
|
||||
|
||||
Returns:
|
||||
List of AgentInfo instances
|
||||
"""
|
||||
agent_ids = self._workspace_index.get(workspace_id, set())
|
||||
return [self._agents[agent_id] for agent_id in agent_ids if agent_id in self._agents]
|
||||
|
||||
def get_agent_count(self, workspace_id: Optional[str] = None) -> int:
|
||||
"""Get the count of registered agents.
|
||||
|
||||
Args:
|
||||
workspace_id: Optional workspace ID to filter by
|
||||
|
||||
Returns:
|
||||
Number of agents
|
||||
"""
|
||||
if workspace_id:
|
||||
return len(self._workspace_index.get(workspace_id, set()))
|
||||
return len(self._agents)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all registered agents."""
|
||||
self._agents.clear()
|
||||
self._workspace_index.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get registry statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with registry statistics
|
||||
"""
|
||||
stats = {
|
||||
"total_agents": len(self._agents),
|
||||
"workspaces": len(self._workspace_index),
|
||||
"agents_by_workspace": {
|
||||
ws_id: len(agent_ids)
|
||||
for ws_id, agent_ids in self._workspace_index.items()
|
||||
},
|
||||
"agents_by_type": {},
|
||||
"agents_by_status": {},
|
||||
}
|
||||
|
||||
for agent in self._agents.values():
|
||||
# Count by type
|
||||
agent_type = agent.agent_type
|
||||
stats["agents_by_type"][agent_type] = (
|
||||
stats["agents_by_type"].get(agent_type, 0) + 1
|
||||
)
|
||||
|
||||
# Count by status
|
||||
status = agent.status
|
||||
stats["agents_by_status"][status] = (
|
||||
stats["agents_by_status"].get(status, 0) + 1
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_global_registry: Optional[AgentRegistry] = None
|
||||
|
||||
|
||||
def get_registry() -> AgentRegistry:
|
||||
"""Get the global agent registry instance.
|
||||
|
||||
Returns:
|
||||
AgentRegistry instance
|
||||
"""
|
||||
global _global_registry
|
||||
if _global_registry is None:
|
||||
_global_registry = AgentRegistry()
|
||||
return _global_registry
|
||||
|
||||
|
||||
def reset_registry() -> None:
|
||||
"""Reset the global registry (useful for testing)."""
|
||||
global _global_registry
|
||||
_global_registry = None
|
||||
388
backend/agents/skill_loader.py
Normal file
388
backend/agents/skill_loader.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Skill loader for loading and validating skills from directories.
|
||||
|
||||
提供从目录加载技能、解析SKILL.md frontmatter、获取工具列表等功能。
|
||||
"""
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
import yaml
|
||||
|
||||
from backend.agents.skill_metadata import SkillMetadata, parse_skill_metadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SkillInfo:
|
||||
"""完整的技能信息"""
|
||||
name: str
|
||||
description: str
|
||||
version: str
|
||||
source: str
|
||||
path: Path
|
||||
metadata: SkillMetadata
|
||||
tools: List[str] = field(default_factory=list)
|
||||
scripts: List[str] = field(default_factory=list)
|
||||
references: List[str] = field(default_factory=list)
|
||||
content: str = ""
|
||||
|
||||
|
||||
def load_skill_from_dir(skill_dir: Path, source: str = "unknown") -> Optional[Dict[str, Any]]:
|
||||
"""从目录加载技能
|
||||
|
||||
Args:
|
||||
skill_dir: 技能目录路径
|
||||
source: 技能来源 (builtin/customized/local/installed/active)
|
||||
|
||||
Returns:
|
||||
技能信息字典,加载失败返回None
|
||||
"""
|
||||
if not skill_dir.exists() or not skill_dir.is_dir():
|
||||
logger.warning(f"Skill directory does not exist: {skill_dir}")
|
||||
return None
|
||||
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
logger.warning(f"SKILL.md not found in: {skill_dir}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 解析元数据
|
||||
metadata = parse_skill_metadata(skill_dir, source=source)
|
||||
|
||||
# 读取完整内容
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
|
||||
# 提取body (去掉frontmatter)
|
||||
body = content
|
||||
if content.startswith("---"):
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) >= 3:
|
||||
body = parts[2].strip()
|
||||
|
||||
# 获取工具列表
|
||||
tools = get_skill_tools(skill_dir)
|
||||
|
||||
# 获取脚本列表
|
||||
scripts = _get_skill_scripts(skill_dir)
|
||||
|
||||
# 获取参考资料列表
|
||||
references = _get_skill_references(skill_dir)
|
||||
|
||||
return {
|
||||
"name": metadata.name,
|
||||
"skill_name": metadata.skill_name,
|
||||
"description": metadata.description,
|
||||
"version": metadata.version,
|
||||
"source": source,
|
||||
"path": str(skill_dir),
|
||||
"content": body,
|
||||
"tools": tools,
|
||||
"scripts": scripts,
|
||||
"references": references,
|
||||
"metadata": metadata,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load skill from {skill_dir}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def parse_skill_metadata(skill_dir: Path, source: str = "unknown") -> SkillMetadata:
|
||||
"""解析技能元数据 (兼容已有函数)
|
||||
|
||||
Args:
|
||||
skill_dir: 技能目录路径
|
||||
source: 技能来源
|
||||
|
||||
Returns:
|
||||
SkillMetadata对象
|
||||
"""
|
||||
from backend.agents.skill_metadata import parse_skill_metadata as _parse
|
||||
return _parse(skill_dir, source=source)
|
||||
|
||||
|
||||
def get_skill_tools(skill_dir: Path) -> List[str]:
|
||||
"""获取技能提供的工具列表
|
||||
|
||||
从SKILL.md frontmatter的tools字段和scripts目录解析工具。
|
||||
|
||||
Args:
|
||||
skill_dir: 技能目录路径
|
||||
|
||||
Returns:
|
||||
工具名称列表
|
||||
"""
|
||||
tools: Set[str] = set()
|
||||
|
||||
# 1. 从SKILL.md frontmatter读取tools字段
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
if skill_md.exists():
|
||||
try:
|
||||
raw = skill_md.read_text(encoding="utf-8").strip()
|
||||
if raw.startswith("---"):
|
||||
parts = raw.split("---", 2)
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
frontmatter = yaml.safe_load(parts[1].strip()) or {}
|
||||
if isinstance(frontmatter, dict):
|
||||
tools_list = frontmatter.get("tools", [])
|
||||
if isinstance(tools_list, str):
|
||||
tools.add(tools_list.strip())
|
||||
elif isinstance(tools_list, list):
|
||||
for tool in tools_list:
|
||||
if isinstance(tool, str):
|
||||
tools.add(tool.strip())
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse tools from SKILL.md: {e}")
|
||||
|
||||
# 2. 从scripts目录推断工具
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
if scripts_dir.exists() and scripts_dir.is_dir():
|
||||
for script in scripts_dir.iterdir():
|
||||
if script.is_file() and not script.name.startswith("_"):
|
||||
# 去掉扩展名作为工具名
|
||||
tool_name = script.stem
|
||||
tools.add(tool_name)
|
||||
|
||||
return sorted(list(tools))
|
||||
|
||||
|
||||
def _get_skill_scripts(skill_dir: Path) -> List[str]:
|
||||
"""获取技能脚本列表
|
||||
|
||||
Args:
|
||||
skill_dir: 技能目录路径
|
||||
|
||||
Returns:
|
||||
脚本相对路径列表 (相对于scripts目录)
|
||||
"""
|
||||
scripts: List[str] = []
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
|
||||
if not scripts_dir.exists():
|
||||
return scripts
|
||||
|
||||
try:
|
||||
for item in scripts_dir.rglob("*"):
|
||||
if item.is_file() and not item.name.startswith("_"):
|
||||
rel_path = item.relative_to(scripts_dir)
|
||||
scripts.append(str(rel_path))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list scripts in {skill_dir}: {e}")
|
||||
|
||||
return sorted(scripts)
|
||||
|
||||
|
||||
def _get_skill_references(skill_dir: Path) -> List[str]:
|
||||
"""获取技能参考资料列表
|
||||
|
||||
Args:
|
||||
skill_dir: 技能目录路径
|
||||
|
||||
Returns:
|
||||
参考资料相对路径列表 (相对于references目录)
|
||||
"""
|
||||
refs: List[str] = []
|
||||
refs_dir = skill_dir / "references"
|
||||
|
||||
if not refs_dir.exists():
|
||||
return refs
|
||||
|
||||
try:
|
||||
for item in refs_dir.rglob("*"):
|
||||
if item.is_file():
|
||||
rel_path = item.relative_to(refs_dir)
|
||||
refs.append(str(rel_path))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to list references in {skill_dir}: {e}")
|
||||
|
||||
return sorted(refs)
|
||||
|
||||
|
||||
def validate_skill(skill_dir: Path) -> Dict[str, Any]:
|
||||
"""验证技能格式
|
||||
|
||||
检查技能目录结构是否符合规范。
|
||||
|
||||
Args:
|
||||
skill_dir: 技能目录路径
|
||||
|
||||
Returns:
|
||||
验证结果字典,包含:
|
||||
- valid: 是否有效
|
||||
- errors: 错误列表
|
||||
- warnings: 警告列表
|
||||
"""
|
||||
errors: List[str] = []
|
||||
warnings: List[str] = []
|
||||
|
||||
# 检查目录存在
|
||||
if not skill_dir.exists():
|
||||
errors.append(f"Skill directory does not exist: {skill_dir}")
|
||||
return {"valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
if not skill_dir.is_dir():
|
||||
errors.append(f"Path is not a directory: {skill_dir}")
|
||||
return {"valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
# 检查SKILL.md
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
errors.append("SKILL.md is required but not found")
|
||||
return {"valid": False, "errors": errors, "warnings": warnings}
|
||||
|
||||
# 解析frontmatter
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8").strip()
|
||||
if not content.startswith("---"):
|
||||
warnings.append("SKILL.md should have YAML frontmatter (starts with ---)")
|
||||
else:
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) < 3:
|
||||
errors.append("Invalid YAML frontmatter format")
|
||||
else:
|
||||
try:
|
||||
frontmatter = yaml.safe_load(parts[1].strip()) or {}
|
||||
if not isinstance(frontmatter, dict):
|
||||
errors.append("YAML frontmatter must be a dictionary")
|
||||
else:
|
||||
# 检查必需字段
|
||||
if "name" not in frontmatter:
|
||||
warnings.append("Frontmatter should have 'name' field")
|
||||
if "description" not in frontmatter:
|
||||
warnings.append("Frontmatter should have 'description' field")
|
||||
|
||||
# 检查version字段
|
||||
version = frontmatter.get("version")
|
||||
if version and not isinstance(version, str):
|
||||
warnings.append("'version' should be a string")
|
||||
|
||||
# 检查tools字段
|
||||
tools = frontmatter.get("tools")
|
||||
if tools and not isinstance(tools, (str, list)):
|
||||
warnings.append("'tools' should be a string or list")
|
||||
|
||||
except yaml.YAMLError as e:
|
||||
errors.append(f"Invalid YAML in frontmatter: {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"Failed to read SKILL.md: {e}")
|
||||
|
||||
# 检查body内容
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
body = content
|
||||
if content.startswith("---"):
|
||||
parts = content.split("---", 2)
|
||||
if len(parts) >= 3:
|
||||
body = parts[2].strip()
|
||||
|
||||
if not body:
|
||||
warnings.append("SKILL.md body is empty")
|
||||
elif len(body) < 50:
|
||||
warnings.append("SKILL.md body is very short, consider adding more details")
|
||||
except Exception as e:
|
||||
errors.append(f"Failed to validate body: {e}")
|
||||
|
||||
# 检查scripts目录
|
||||
scripts_dir = skill_dir / "scripts"
|
||||
if scripts_dir.exists():
|
||||
if not scripts_dir.is_dir():
|
||||
errors.append("'scripts' exists but is not a directory")
|
||||
else:
|
||||
# 检查是否有可执行脚本
|
||||
has_scripts = any(
|
||||
f.is_file() and not f.name.startswith("_")
|
||||
for f in scripts_dir.iterdir()
|
||||
)
|
||||
if not has_scripts:
|
||||
warnings.append("scripts directory exists but contains no valid scripts")
|
||||
|
||||
# 检查references目录
|
||||
refs_dir = skill_dir / "references"
|
||||
if refs_dir.exists() and not refs_dir.is_dir():
|
||||
errors.append("'references' exists but is not a directory")
|
||||
|
||||
return {
|
||||
"valid": len(errors) == 0,
|
||||
"errors": errors,
|
||||
"warnings": warnings,
|
||||
}
|
||||
|
||||
|
||||
def load_skills_from_directory(
|
||||
directory: Path,
|
||||
source: str = "unknown",
|
||||
recursive: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""从目录加载所有技能
|
||||
|
||||
Args:
|
||||
directory: 包含技能目录的父目录
|
||||
source: 技能来源标识
|
||||
recursive: 是否递归搜索子目录
|
||||
|
||||
Returns:
|
||||
技能信息列表
|
||||
"""
|
||||
skills: List[Dict[str, Any]] = []
|
||||
|
||||
if not directory.exists() or not directory.is_dir():
|
||||
logger.warning(f"Directory does not exist: {directory}")
|
||||
return skills
|
||||
|
||||
try:
|
||||
for item in directory.iterdir():
|
||||
if not item.is_dir():
|
||||
continue
|
||||
|
||||
# 检查是否是技能目录 (包含SKILL.md)
|
||||
if (item / "SKILL.md").exists():
|
||||
skill_info = load_skill_from_dir(item, source=source)
|
||||
if skill_info:
|
||||
skills.append(skill_info)
|
||||
elif recursive:
|
||||
# 递归搜索子目录
|
||||
sub_skills = load_skills_from_directory(item, source, recursive)
|
||||
skills.extend(sub_skills)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load skills from {directory}: {e}")
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
def get_skill_manifest(skill_dir: Path) -> Dict[str, Any]:
|
||||
"""获取技能清单
|
||||
|
||||
生成技能的详细清单,用于调试和展示。
|
||||
|
||||
Args:
|
||||
skill_dir: 技能目录路径
|
||||
|
||||
Returns:
|
||||
技能清单字典
|
||||
"""
|
||||
info = load_skill_from_dir(skill_dir)
|
||||
if not info:
|
||||
return {"error": "Failed to load skill"}
|
||||
|
||||
validation = validate_skill(skill_dir)
|
||||
|
||||
return {
|
||||
"name": info["name"],
|
||||
"skill_name": info["skill_name"],
|
||||
"version": info["version"],
|
||||
"description": info["description"],
|
||||
"source": info["source"],
|
||||
"path": info["path"],
|
||||
"tools": info["tools"],
|
||||
"scripts": info["scripts"],
|
||||
"references": info["references"],
|
||||
"validation": validation,
|
||||
"content_preview": info["content"][:500] + "..." if len(info["content"]) > 500 else info["content"],
|
||||
}
|
||||
83
backend/agents/skill_metadata.py
Normal file
83
backend/agents/skill_metadata.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Skill metadata parsing helpers for SKILL.md files."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SkillMetadata:
|
||||
"""Parsed metadata for a skill package."""
|
||||
|
||||
skill_name: str
|
||||
path: Path
|
||||
source: str
|
||||
name: str
|
||||
description: str
|
||||
version: str = ""
|
||||
tools: List[str] = field(default_factory=list)
|
||||
allowed_tools: List[str] = field(default_factory=list)
|
||||
denied_tools: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def parse_skill_metadata(skill_dir: Path, source: str) -> SkillMetadata:
|
||||
"""Parse SKILL.md frontmatter with a forgiving schema."""
|
||||
skill_name = skill_dir.name
|
||||
skill_file = skill_dir / "SKILL.md"
|
||||
if not skill_file.exists():
|
||||
return SkillMetadata(
|
||||
skill_name=skill_name,
|
||||
path=skill_dir,
|
||||
source=source,
|
||||
name=skill_name,
|
||||
description="",
|
||||
)
|
||||
|
||||
raw = skill_file.read_text(encoding="utf-8").strip()
|
||||
frontmatter = {}
|
||||
body = raw
|
||||
if raw.startswith("---"):
|
||||
parts = raw.split("---", 2)
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
frontmatter = yaml.safe_load(parts[1].strip()) or {}
|
||||
except yaml.YAMLError:
|
||||
frontmatter = {}
|
||||
body = parts[2].strip()
|
||||
if not isinstance(frontmatter, dict):
|
||||
frontmatter = {}
|
||||
|
||||
description = str(frontmatter.get("description") or "").strip()
|
||||
if not description and body:
|
||||
description = body.splitlines()[0].strip().lstrip("#").strip()
|
||||
|
||||
return SkillMetadata(
|
||||
skill_name=skill_name,
|
||||
path=skill_dir,
|
||||
source=source,
|
||||
name=str(frontmatter.get("name") or skill_name).strip() or skill_name,
|
||||
description=description,
|
||||
version=str(frontmatter.get("version") or "").strip(),
|
||||
tools=_string_list(frontmatter.get("tools")),
|
||||
allowed_tools=_string_list(frontmatter.get("allowed_tools")),
|
||||
denied_tools=_string_list(frontmatter.get("denied_tools")),
|
||||
)
|
||||
|
||||
|
||||
def _string_list(value) -> List[str]:
|
||||
if isinstance(value, str):
|
||||
item = value.strip()
|
||||
return [item] if item else []
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
seen: List[str] = []
|
||||
for item in value:
|
||||
if not isinstance(item, str):
|
||||
continue
|
||||
normalized = item.strip()
|
||||
if normalized and normalized not in seen:
|
||||
seen.append(normalized)
|
||||
return seen
|
||||
@@ -1,14 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Manage builtin/customized/active skill directories for each run."""
|
||||
"""Manage agent-installed and run-active skill directories for each run."""
|
||||
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Dict, Iterable, List
|
||||
import tempfile
|
||||
import zipfile
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlretrieve
|
||||
|
||||
import yaml
|
||||
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.skill_metadata import SkillMetadata, parse_skill_metadata
|
||||
from backend.agents.skill_loader import validate_skill
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
|
||||
try:
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler, FileSystemEvent
|
||||
WATCHDOG_AVAILABLE = True
|
||||
except ImportError:
|
||||
WATCHDOG_AVAILABLE = False
|
||||
Observer = None
|
||||
FileSystemEventHandler = object
|
||||
FileSystemEvent = object # type: ignore[misc,assignment]
|
||||
|
||||
|
||||
class SkillsManager:
|
||||
"""Sync named skills into a run-scoped active skills workspace."""
|
||||
@@ -22,16 +40,393 @@ class SkillsManager:
|
||||
self.project_root / "backend" / "skills" / "customized"
|
||||
)
|
||||
self.runs_root = self.project_root / "runs"
|
||||
self._lock = Lock()
|
||||
# Instance-level pending skill changes (thread-safe via self._lock)
|
||||
self._pending_skill_changes: Dict[str, Set[Path]] = {}
|
||||
|
||||
def get_active_root(self, config_name: str) -> Path:
|
||||
return self.runs_root / config_name / "skills" / "active"
|
||||
|
||||
def get_agent_skills_root(self, config_name: str, agent_id: str) -> Path:
|
||||
return self.get_agent_asset_dir(config_name, agent_id) / "skills"
|
||||
|
||||
def get_agent_active_root(self, config_name: str, agent_id: str) -> Path:
|
||||
return self.get_agent_skills_root(config_name, agent_id) / "active"
|
||||
|
||||
def get_agent_installed_root(self, config_name: str, agent_id: str) -> Path:
|
||||
return self.get_agent_skills_root(config_name, agent_id) / "installed"
|
||||
|
||||
def get_agent_disabled_root(self, config_name: str, agent_id: str) -> Path:
|
||||
return self.get_agent_skills_root(config_name, agent_id) / "disabled"
|
||||
|
||||
def get_agent_local_root(self, config_name: str, agent_id: str) -> Path:
|
||||
return self.get_agent_skills_root(config_name, agent_id) / "local"
|
||||
|
||||
def get_activation_manifest_path(self, config_name: str) -> Path:
|
||||
return self.runs_root / config_name / "skills" / "activation.yaml"
|
||||
|
||||
def get_agent_asset_dir(self, config_name: str, agent_id: str) -> Path:
|
||||
return self.runs_root / config_name / "agents" / agent_id
|
||||
|
||||
def list_skill_catalog(self) -> List[SkillMetadata]:
|
||||
"""Return builtin/customized skills with parsed metadata."""
|
||||
catalog: Dict[str, SkillMetadata] = {}
|
||||
|
||||
for source, root in (
|
||||
("builtin", self.builtin_root),
|
||||
("customized", self.customized_root),
|
||||
):
|
||||
if not root.exists():
|
||||
continue
|
||||
for skill_dir in sorted(root.iterdir(), key=lambda item: item.name):
|
||||
if not skill_dir.is_dir():
|
||||
continue
|
||||
if not (skill_dir / "SKILL.md").exists():
|
||||
continue
|
||||
metadata = parse_skill_metadata(skill_dir, source=source)
|
||||
catalog[metadata.skill_name] = metadata
|
||||
|
||||
return sorted(catalog.values(), key=lambda item: item.skill_name)
|
||||
|
||||
def list_agent_skill_catalog(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
) -> List[SkillMetadata]:
|
||||
"""Return shared plus agent-local skills for one agent."""
|
||||
catalog = {
|
||||
item.skill_name: item
|
||||
for item in self.list_skill_catalog()
|
||||
}
|
||||
for item in self.list_agent_local_skills(config_name, agent_id):
|
||||
catalog[item.skill_name] = item
|
||||
return sorted(catalog.values(), key=lambda item: item.skill_name)
|
||||
|
||||
def list_active_skill_metadata(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
) -> List[SkillMetadata]:
|
||||
"""Return metadata for active skills synced for one agent."""
|
||||
active_root = self.get_agent_active_root(config_name, agent_id)
|
||||
if not active_root.exists():
|
||||
return []
|
||||
|
||||
items: List[SkillMetadata] = []
|
||||
for skill_dir in sorted(active_root.iterdir(), key=lambda item: item.name):
|
||||
if not skill_dir.is_dir():
|
||||
continue
|
||||
if not (skill_dir / "SKILL.md").exists():
|
||||
continue
|
||||
items.append(parse_skill_metadata(skill_dir, source="active"))
|
||||
return items
|
||||
|
||||
def list_agent_local_skills(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
) -> List[SkillMetadata]:
|
||||
"""Return metadata for agent-private local skills."""
|
||||
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||
if not local_root.exists():
|
||||
return []
|
||||
|
||||
items: List[SkillMetadata] = []
|
||||
for skill_dir in sorted(local_root.iterdir(), key=lambda item: item.name):
|
||||
if not skill_dir.is_dir():
|
||||
continue
|
||||
if not (skill_dir / "SKILL.md").exists():
|
||||
continue
|
||||
items.append(parse_skill_metadata(skill_dir, source="local"))
|
||||
return items
|
||||
|
||||
def load_skill_document(self, skill_name: str) -> Dict[str, object]:
|
||||
"""Return skill metadata plus markdown body for one skill."""
|
||||
source_dir = self._resolve_source_dir(skill_name)
|
||||
return self._load_skill_document_from_dir(
|
||||
source_dir,
|
||||
source="customized" if source_dir.parent == self.customized_root else "builtin",
|
||||
)
|
||||
|
||||
def load_agent_skill_document(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
) -> Dict[str, object]:
|
||||
"""Return skill metadata plus markdown body for one agent-visible skill."""
|
||||
source_dir = self._resolve_agent_skill_source_dir(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
skill_name=skill_name,
|
||||
)
|
||||
source = "local"
|
||||
if source_dir.parent == self.customized_root:
|
||||
source = "customized"
|
||||
elif source_dir.parent == self.builtin_root:
|
||||
source = "builtin"
|
||||
elif source_dir.parent == self.get_agent_installed_root(config_name, agent_id):
|
||||
source = "installed"
|
||||
return self._load_skill_document_from_dir(source_dir, source=source)
|
||||
|
||||
def create_agent_local_skill(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
) -> Path:
|
||||
"""Create a new local skill directory with a default SKILL.md."""
|
||||
normalized = _normalize_skill_name(skill_name)
|
||||
if not normalized:
|
||||
raise ValueError("Skill name is required.")
|
||||
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||
local_root.mkdir(parents=True, exist_ok=True)
|
||||
skill_dir = local_root / normalized
|
||||
if skill_dir.exists():
|
||||
raise FileExistsError(f"Local skill already exists: {normalized}")
|
||||
skill_dir.mkdir(parents=True, exist_ok=False)
|
||||
(skill_dir / "SKILL.md").write_text(
|
||||
"---\n"
|
||||
f"name: {normalized}\n"
|
||||
"description: 当用户提出与该本地技能相关的专门任务时,应使用此技能。\n"
|
||||
"version: 1.0.0\n"
|
||||
"---\n\n"
|
||||
f"# {normalized}\n\n"
|
||||
"在这里描述该交易员的专有分析流程、判断框架和可复用步骤。\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
return skill_dir
|
||||
|
||||
def install_external_skill_for_agent(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
source: str,
|
||||
*,
|
||||
skill_name: str | None = None,
|
||||
activate: bool = True,
|
||||
) -> Dict[str, object]:
|
||||
"""
|
||||
Install an external skill into one agent's local skill space.
|
||||
|
||||
Supports:
|
||||
- local skill directory containing SKILL.md
|
||||
- local zip archive containing one skill directory
|
||||
- http(s) URL to zip archive
|
||||
"""
|
||||
source_path = self._resolve_external_source_path(source)
|
||||
skill_dir = self._resolve_external_skill_dir(source_path)
|
||||
metadata = parse_skill_metadata(skill_dir, source="external")
|
||||
final_name = _normalize_skill_name(skill_name or metadata.skill_name or skill_dir.name)
|
||||
if not final_name:
|
||||
raise ValueError("Could not determine skill name from external source.")
|
||||
|
||||
target_dir = self.get_agent_local_root(config_name, agent_id) / final_name
|
||||
target_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
if target_dir.exists():
|
||||
shutil.rmtree(target_dir)
|
||||
shutil.copytree(skill_dir, target_dir)
|
||||
|
||||
validation = validate_skill(target_dir)
|
||||
if not validation.get("valid", False):
|
||||
shutil.rmtree(target_dir, ignore_errors=True)
|
||||
raise ValueError(
|
||||
"Installed skill is invalid: "
|
||||
+ "; ".join(validation.get("errors", []))
|
||||
)
|
||||
|
||||
if activate:
|
||||
self.update_agent_skill_overrides(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
enable=[final_name],
|
||||
)
|
||||
return {
|
||||
"skill_name": final_name,
|
||||
"target_dir": str(target_dir),
|
||||
"activated": activate,
|
||||
"warnings": validation.get("warnings", []),
|
||||
}
|
||||
|
||||
def update_agent_local_skill(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
content: str,
|
||||
) -> Path:
|
||||
"""Overwrite one agent-local SKILL.md."""
|
||||
normalized = _normalize_skill_name(skill_name)
|
||||
if not normalized:
|
||||
raise ValueError("Skill name is required.")
|
||||
skill_dir = self.get_agent_local_root(config_name, agent_id) / normalized
|
||||
if not skill_dir.exists():
|
||||
raise FileNotFoundError(f"Unknown local skill: {normalized}")
|
||||
(skill_dir / "SKILL.md").write_text(content, encoding="utf-8")
|
||||
return skill_dir
|
||||
|
||||
def delete_agent_local_skill(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
) -> None:
|
||||
"""Delete one agent-local skill directory."""
|
||||
normalized = _normalize_skill_name(skill_name)
|
||||
if not normalized:
|
||||
raise ValueError("Skill name is required.")
|
||||
skill_dir = self.get_agent_local_root(config_name, agent_id) / normalized
|
||||
if not skill_dir.exists():
|
||||
raise FileNotFoundError(f"Unknown local skill: {normalized}")
|
||||
shutil.rmtree(skill_dir)
|
||||
|
||||
def _load_skill_document_from_dir(
|
||||
self,
|
||||
source_dir: Path,
|
||||
*,
|
||||
source: str,
|
||||
) -> Dict[str, object]:
|
||||
"""Return metadata plus markdown body for one resolved skill directory."""
|
||||
metadata = parse_skill_metadata(
|
||||
source_dir,
|
||||
source=source,
|
||||
)
|
||||
skill_file = source_dir / "SKILL.md"
|
||||
raw = skill_file.read_text(encoding="utf-8").strip() if skill_file.exists() else ""
|
||||
body = raw
|
||||
if raw.startswith("---"):
|
||||
parts = raw.split("---", 2)
|
||||
if len(parts) >= 3:
|
||||
body = parts[2].strip()
|
||||
|
||||
return {
|
||||
"skill_name": metadata.skill_name,
|
||||
"name": metadata.name,
|
||||
"description": metadata.description,
|
||||
"version": metadata.version,
|
||||
"tools": metadata.tools,
|
||||
"source": metadata.source,
|
||||
"content": body,
|
||||
}
|
||||
|
||||
def _resolve_external_source_path(self, source: str) -> Path:
|
||||
"""Resolve source into a local path; download URL when needed."""
|
||||
parsed = urlparse(source)
|
||||
if parsed.scheme in {"http", "https"}:
|
||||
suffix = Path(parsed.path).suffix or ".zip"
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||
temp_path = Path(tmp.name)
|
||||
urlretrieve(source, temp_path)
|
||||
return temp_path
|
||||
return Path(source).expanduser().resolve()
|
||||
|
||||
def _resolve_external_skill_dir(self, source_path: Path) -> Path:
|
||||
"""Resolve external source path to a skill directory containing SKILL.md."""
|
||||
if not source_path.exists():
|
||||
raise FileNotFoundError(f"Source does not exist: {source_path}")
|
||||
|
||||
if source_path.is_dir():
|
||||
if (source_path / "SKILL.md").exists():
|
||||
return source_path
|
||||
children = [
|
||||
item for item in source_path.iterdir()
|
||||
if item.is_dir() and (item / "SKILL.md").exists()
|
||||
]
|
||||
if len(children) == 1:
|
||||
return children[0]
|
||||
raise ValueError(
|
||||
"Source directory must contain SKILL.md "
|
||||
"or exactly one child directory containing SKILL.md."
|
||||
)
|
||||
|
||||
if source_path.suffix.lower() != ".zip":
|
||||
raise ValueError("External source file must be a .zip archive.")
|
||||
|
||||
temp_root = Path(tempfile.mkdtemp(prefix="external_skill_"))
|
||||
with zipfile.ZipFile(source_path, "r") as archive:
|
||||
archive.extractall(temp_root)
|
||||
|
||||
candidates = [
|
||||
item.parent
|
||||
for item in temp_root.rglob("SKILL.md")
|
||||
if item.is_file()
|
||||
]
|
||||
unique = []
|
||||
for item in candidates:
|
||||
if item not in unique:
|
||||
unique.append(item)
|
||||
if len(unique) != 1:
|
||||
raise ValueError(
|
||||
"Zip archive must contain exactly one skill directory with SKILL.md."
|
||||
)
|
||||
return unique[0]
|
||||
|
||||
def update_agent_skill_overrides(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
*,
|
||||
enable: Iterable[str] | None = None,
|
||||
disable: Iterable[str] | None = None,
|
||||
) -> Dict[str, List[str]]:
|
||||
"""Persist per-agent enabled/disabled skill overrides in agent.yaml."""
|
||||
asset_dir = self.get_agent_asset_dir(config_name, agent_id)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
config_path = asset_dir / "agent.yaml"
|
||||
current = load_agent_workspace_config(config_path)
|
||||
values = dict(current.values)
|
||||
|
||||
enabled = _dedupe_preserve_order(current.enabled_skills)
|
||||
disabled_set = set(current.disabled_skills)
|
||||
|
||||
for skill_name in enable or []:
|
||||
if skill_name not in enabled:
|
||||
enabled.append(skill_name)
|
||||
disabled_set.discard(skill_name)
|
||||
|
||||
for skill_name in disable or []:
|
||||
disabled_set.add(skill_name)
|
||||
enabled = [item for item in enabled if item != skill_name]
|
||||
|
||||
values["enabled_skills"] = enabled
|
||||
values["disabled_skills"] = sorted(disabled_set)
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(values, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return {
|
||||
"enabled_skills": enabled,
|
||||
"disabled_skills": sorted(disabled_set),
|
||||
}
|
||||
|
||||
def forget_agent_skill_overrides(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
skill_names: Iterable[str],
|
||||
) -> Dict[str, List[str]]:
|
||||
"""Remove skills from both enabled/disabled overrides in agent.yaml."""
|
||||
asset_dir = self.get_agent_asset_dir(config_name, agent_id)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
config_path = asset_dir / "agent.yaml"
|
||||
current = load_agent_workspace_config(config_path)
|
||||
values = dict(current.values)
|
||||
removed = set(skill_names)
|
||||
|
||||
enabled = [item for item in current.enabled_skills if item not in removed]
|
||||
disabled = [item for item in current.disabled_skills if item not in removed]
|
||||
|
||||
values["enabled_skills"] = enabled
|
||||
values["disabled_skills"] = disabled
|
||||
config_path.write_text(
|
||||
yaml.safe_dump(values, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return {
|
||||
"enabled_skills": enabled,
|
||||
"disabled_skills": disabled,
|
||||
}
|
||||
|
||||
def ensure_activation_manifest(self, config_name: str) -> Path:
|
||||
manifest_path = self.get_activation_manifest_path(config_name)
|
||||
manifest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -62,6 +457,34 @@ class SkillsManager:
|
||||
|
||||
raise FileNotFoundError(f"Unknown skill: {skill_name}")
|
||||
|
||||
def _resolve_agent_skill_source_dir(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
) -> Path:
|
||||
"""Resolve one skill from the agent-local workspace or shared registry."""
|
||||
for root in (
|
||||
self.get_agent_local_root(config_name, agent_id),
|
||||
self.get_agent_installed_root(config_name, agent_id),
|
||||
):
|
||||
candidate = root / skill_name
|
||||
if candidate.exists() and (candidate / "SKILL.md").exists():
|
||||
return candidate
|
||||
return self._resolve_source_dir(skill_name)
|
||||
|
||||
def _skill_exists_for_agent(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
) -> bool:
|
||||
try:
|
||||
self._resolve_agent_skill_source_dir(config_name, agent_id, skill_name)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _persist_runtime_edits(
|
||||
self,
|
||||
config_name: str,
|
||||
@@ -125,6 +548,13 @@ class SkillsManager:
|
||||
bootstrap = get_bootstrap_config_for_run(self.project_root, config_name)
|
||||
override = bootstrap.agent_override(agent_id)
|
||||
skills = list(override.get("skills", list(default_skills)))
|
||||
agent_config = load_agent_workspace_config(
|
||||
self.get_agent_asset_dir(config_name, agent_id) / "agent.yaml",
|
||||
)
|
||||
|
||||
for skill_name in agent_config.enabled_skills:
|
||||
if skill_name not in skills:
|
||||
skills.append(skill_name)
|
||||
|
||||
manifest = self.load_activation_manifest(config_name)
|
||||
for skill_name in manifest.get("global_enabled_skills", []):
|
||||
@@ -139,51 +569,62 @@ class SkillsManager:
|
||||
disabled.update(
|
||||
manifest.get("agent_disabled_skills", {}).get(agent_id, []),
|
||||
)
|
||||
disabled.update(agent_config.disabled_skills)
|
||||
|
||||
return [skill for skill in skills if skill not in disabled]
|
||||
for item in self.list_agent_local_skills(config_name, agent_id):
|
||||
if item.skill_name not in skills:
|
||||
skills.append(item.skill_name)
|
||||
|
||||
def sync_active_skills(
|
||||
return [
|
||||
skill
|
||||
for skill in skills
|
||||
if skill not in disabled
|
||||
and self._skill_exists_for_agent(config_name, agent_id, skill)
|
||||
]
|
||||
|
||||
def sync_skill_dirs(
|
||||
self,
|
||||
config_name: str,
|
||||
skill_names: Iterable[str],
|
||||
target_root: Path,
|
||||
skill_sources: Dict[str, Path],
|
||||
) -> List[Path]:
|
||||
"""Sync selected skills into the run workspace and return their paths."""
|
||||
active_root = self.get_active_root(config_name)
|
||||
active_root.mkdir(parents=True, exist_ok=True)
|
||||
"""Sync selected skill directories into one target root."""
|
||||
target_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
synced_paths: List[Path] = []
|
||||
wanted = set(skill_names)
|
||||
wanted = set(skill_sources)
|
||||
|
||||
for existing in active_root.iterdir():
|
||||
for existing in target_root.iterdir():
|
||||
if existing.is_dir() and existing.name not in wanted:
|
||||
self._persist_runtime_edits(
|
||||
config_name=config_name,
|
||||
skill_name=existing.name,
|
||||
active_dir=existing,
|
||||
)
|
||||
shutil.rmtree(existing)
|
||||
|
||||
for skill_name in skill_names:
|
||||
source_dir = self._resolve_source_dir(skill_name)
|
||||
target_dir = active_root / skill_name
|
||||
for skill_name, source_dir in skill_sources.items():
|
||||
target_dir = target_root / skill_name
|
||||
if target_dir.exists():
|
||||
self._persist_runtime_edits(
|
||||
config_name=config_name,
|
||||
skill_name=skill_name,
|
||||
active_dir=target_dir,
|
||||
)
|
||||
shutil.rmtree(target_dir)
|
||||
shutil.copytree(source_dir, target_dir)
|
||||
synced_paths.append(target_dir)
|
||||
|
||||
return synced_paths
|
||||
|
||||
def sync_active_skills(
|
||||
self,
|
||||
target_root: Path,
|
||||
skill_names: Iterable[str],
|
||||
) -> List[Path]:
|
||||
"""Sync selected shared skills into one active directory."""
|
||||
skill_sources = {
|
||||
skill_name: self._resolve_source_dir(skill_name)
|
||||
for skill_name in skill_names
|
||||
}
|
||||
return self.sync_skill_dirs(target_root, skill_sources)
|
||||
|
||||
def prepare_active_skills(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_defaults: Dict[str, Iterable[str]],
|
||||
auto_reload: bool = False,
|
||||
) -> Dict[str, List[Path]]:
|
||||
"""Resolve all agent skills, sync the union once, and map paths per agent."""
|
||||
"""Resolve all agent skills into per-agent installed/active workspaces."""
|
||||
resolved: Dict[str, List[str]] = {}
|
||||
union: List[str] = []
|
||||
|
||||
@@ -198,10 +639,238 @@ class SkillsManager:
|
||||
if skill_name not in union:
|
||||
union.append(skill_name)
|
||||
|
||||
self.sync_active_skills(config_name=config_name, skill_names=union)
|
||||
active_root = self.get_active_root(config_name)
|
||||
# Maintain the legacy union directory for compatibility/debugging.
|
||||
# Agent-local skills remain private to the agent workspace.
|
||||
self.sync_active_skills(
|
||||
target_root=self.get_active_root(config_name),
|
||||
skill_names=[
|
||||
skill_name
|
||||
for skill_name in union
|
||||
if self._is_shared_skill(skill_name)
|
||||
],
|
||||
)
|
||||
|
||||
return {
|
||||
agent_id: [active_root / skill_name for skill_name in skill_names]
|
||||
for agent_id, skill_names in resolved.items()
|
||||
active_map: Dict[str, List[Path]] = {}
|
||||
for agent_id, skill_names in resolved.items():
|
||||
installed_sources = {
|
||||
skill_name: self._resolve_source_dir(skill_name)
|
||||
for skill_name in skill_names
|
||||
if (self.get_agent_local_root(config_name, agent_id) / skill_name).exists() is False
|
||||
}
|
||||
installed_paths = self.sync_skill_dirs(
|
||||
target_root=self.get_agent_installed_root(config_name, agent_id),
|
||||
skill_sources=installed_sources,
|
||||
)
|
||||
|
||||
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||
local_sources = {
|
||||
skill_name: local_root / skill_name
|
||||
for skill_name in skill_names
|
||||
if (local_root / skill_name).exists()
|
||||
}
|
||||
active_sources = {
|
||||
path.name: path for path in installed_paths
|
||||
}
|
||||
active_sources.update(local_sources)
|
||||
active_map[agent_id] = self.sync_skill_dirs(
|
||||
target_root=self.get_agent_active_root(config_name, agent_id),
|
||||
skill_sources=active_sources,
|
||||
)
|
||||
|
||||
disabled_names = _dedupe_preserve_order(
|
||||
self._resolve_disabled_skill_names(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
default_skills=agent_defaults.get(agent_id, []),
|
||||
),
|
||||
)
|
||||
disabled_sources = {
|
||||
skill_name: self._resolve_agent_skill_source_dir(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
skill_name=skill_name,
|
||||
)
|
||||
for skill_name in disabled_names
|
||||
}
|
||||
self.sync_skill_dirs(
|
||||
target_root=self.get_agent_disabled_root(config_name, agent_id),
|
||||
skill_sources=disabled_sources,
|
||||
)
|
||||
|
||||
if auto_reload:
|
||||
self.watch_active_skills(config_name, agent_defaults)
|
||||
|
||||
return active_map
|
||||
|
||||
def _is_shared_skill(self, skill_name: str) -> bool:
|
||||
try:
|
||||
self._resolve_source_dir(skill_name)
|
||||
except FileNotFoundError:
|
||||
return False
|
||||
return True
|
||||
|
||||
def watch_active_skills(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_defaults: Dict[str, Iterable[str]],
|
||||
callback: Optional[Any] = None,
|
||||
) -> "_SkillsWatcher":
|
||||
"""Start file system monitoring on active skill directories.
|
||||
|
||||
Args:
|
||||
config_name: Run configuration name.
|
||||
agent_defaults: Map of agent_id -> default skill names.
|
||||
callback: Optional callable invoked on file changes with
|
||||
(changed_paths: List[Path]).
|
||||
|
||||
Returns:
|
||||
A _SkillsWatcher instance. Call .stop() to halt monitoring.
|
||||
"""
|
||||
if not WATCHDOG_AVAILABLE:
|
||||
raise ImportError(
|
||||
"watchdog is required for watch_active_skills. "
|
||||
"Install it with: pip install watchdog"
|
||||
)
|
||||
|
||||
watched_paths: List[Path] = []
|
||||
for agent_id in agent_defaults:
|
||||
active_root = self.get_agent_active_root(config_name, agent_id)
|
||||
if active_root.exists():
|
||||
watched_paths.append(active_root)
|
||||
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||
if local_root.exists():
|
||||
watched_paths.append(local_root)
|
||||
|
||||
handler = _SkillsChangeHandler(watched_paths, self._pending_skill_changes, callback, self._lock)
|
||||
observer = Observer()
|
||||
for path in watched_paths:
|
||||
observer.schedule(handler, str(path), recursive=True)
|
||||
observer.start()
|
||||
return _SkillsWatcher(observer, handler)
|
||||
|
||||
def reload_skills_if_changed(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_defaults: Dict[str, Iterable[str]],
|
||||
) -> Dict[str, List[Path]]:
|
||||
"""Check for file changes and reload active skills if needed.
|
||||
|
||||
Args:
|
||||
config_name: Run configuration name.
|
||||
agent_defaults: Map of agent_id -> default skill names.
|
||||
|
||||
Returns:
|
||||
Map of agent_id -> list of reloaded skill paths, or empty dict
|
||||
if no changes were detected.
|
||||
"""
|
||||
with self._lock:
|
||||
changed = self._pending_skill_changes.get(config_name)
|
||||
if not changed:
|
||||
return {}
|
||||
|
||||
self._pending_skill_changes[config_name] = set()
|
||||
|
||||
return self.prepare_active_skills(config_name, agent_defaults)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal change-tracking state (populated by _SkillsChangeHandler)
|
||||
# -------------------------------------------------------------------------
|
||||
# Legacy class-level reference kept for migration compatibility
|
||||
_pending_skill_changes: Dict[str, Set[Path]] = {}
|
||||
|
||||
def _resolve_disabled_skill_names(
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
default_skills: Iterable[str],
|
||||
) -> List[str]:
|
||||
"""Resolve explicit disabled skills for one agent."""
|
||||
bootstrap = get_bootstrap_config_for_run(self.project_root, config_name)
|
||||
override = bootstrap.agent_override(agent_id)
|
||||
baseline = list(override.get("skills", list(default_skills)))
|
||||
agent_config = load_agent_workspace_config(
|
||||
self.get_agent_asset_dir(config_name, agent_id) / "agent.yaml",
|
||||
)
|
||||
manifest = self.load_activation_manifest(config_name)
|
||||
disabled = list(manifest.get("global_disabled_skills", []))
|
||||
disabled.extend(manifest.get("agent_disabled_skills", {}).get(agent_id, []))
|
||||
disabled.extend(agent_config.disabled_skills)
|
||||
for skill_name in baseline:
|
||||
if skill_name in agent_config.disabled_skills and skill_name not in disabled:
|
||||
disabled.append(skill_name)
|
||||
for item in self.list_agent_local_skills(config_name, agent_id):
|
||||
if item.skill_name in agent_config.disabled_skills and item.skill_name not in disabled:
|
||||
disabled.append(item.skill_name)
|
||||
return [
|
||||
skill
|
||||
for skill in disabled
|
||||
if self._skill_exists_for_agent(config_name, agent_id, skill)
|
||||
]
|
||||
|
||||
|
||||
class _SkillsWatcher:
|
||||
"""Handle returned by watch_active_skills; call .stop() to halt monitoring."""
|
||||
|
||||
def __init__(self, observer: Observer, handler: "_SkillsChangeHandler") -> None:
|
||||
self._observer = observer
|
||||
self._handler = handler
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the file system observer."""
|
||||
self._observer.stop()
|
||||
self._observer.join()
|
||||
|
||||
|
||||
class _SkillsChangeHandler(FileSystemEventHandler):
|
||||
"""Collects file-change events on skill directories."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
watched_paths: List[Path],
|
||||
pending_changes: Dict[str, Set[Path]],
|
||||
callback: Optional[Any] = None,
|
||||
lock: Optional[Lock] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._watched_paths = watched_paths
|
||||
self._pending_changes = pending_changes
|
||||
self._callback = callback
|
||||
self._lock = lock
|
||||
|
||||
def on_any_event(self, event: FileSystemEvent) -> None:
|
||||
if event.is_directory:
|
||||
return
|
||||
src_path = Path(event.src_path)
|
||||
for watched in self._watched_paths:
|
||||
if src_path.is_relative_to(watched):
|
||||
run_id = self._run_id_from_path(src_path)
|
||||
if self._lock:
|
||||
with self._lock:
|
||||
self._pending_changes.setdefault(run_id, set()).add(src_path)
|
||||
else:
|
||||
self._pending_changes.setdefault(run_id, set()).add(src_path)
|
||||
if self._callback:
|
||||
self._callback([src_path])
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _run_id_from_path(path: Path) -> str:
|
||||
"""Infer config_name from a path like runs/{config_name}/skills/active/..."""
|
||||
parts = path.parts
|
||||
for i, part in enumerate(parts):
|
||||
if part == "runs" and i + 1 < len(parts):
|
||||
return parts[i + 1]
|
||||
return "default"
|
||||
|
||||
def _dedupe_preserve_order(items: Iterable[str]) -> List[str]:
|
||||
result: List[str] = []
|
||||
for item in items:
|
||||
if item not in result:
|
||||
result.append(item)
|
||||
return result
|
||||
|
||||
|
||||
def _normalize_skill_name(raw_name: str) -> str:
|
||||
normalized = str(raw_name or "").strip().lower().replace(" ", "_").replace("-", "_")
|
||||
allowed = [ch for ch in normalized if ch.isalnum() or ch == "_"]
|
||||
return "".join(allowed).strip("_")
|
||||
|
||||
18
backend/agents/team/__init__.py
Normal file
18
backend/agents/team/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Team module for multi-agent orchestration.
|
||||
|
||||
Provides inter-agent communication, task delegation, and coordination
|
||||
for subagent spawning and lifecycle management.
|
||||
"""
|
||||
|
||||
from .messenger import AgentMessenger
|
||||
from .task_delegator import TaskDelegator
|
||||
from .team_coordinator import TeamCoordinator
|
||||
from .registry import AgentRegistry
|
||||
|
||||
__all__ = [
|
||||
"AgentMessenger",
|
||||
"TaskDelegator",
|
||||
"TeamCoordinator",
|
||||
"AgentRegistry",
|
||||
]
|
||||
225
backend/agents/team/messenger.py
Normal file
225
backend/agents/team/messenger.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""AgentMessenger - Pub/sub inter-agent communication.
|
||||
|
||||
Provides broadcast(), send(), and subscribe() for message passing
|
||||
between agents using AgentScope's Msg format.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentMessenger:
|
||||
"""Pub/sub messenger for inter-agent communication.
|
||||
|
||||
Supports:
|
||||
- broadcast(): Send message to all subscribers
|
||||
- send(): Send message to specific agent
|
||||
- subscribe(): Register callback for agent messages
|
||||
- announce(): Send system-wide announcement
|
||||
- enable_auto_broadcast: Auto-broadcast agent replies to all participants
|
||||
|
||||
Messages use AgentScope's Msg format for compatibility.
|
||||
"""
|
||||
|
||||
def __init__(self, enable_auto_broadcast: bool = False):
|
||||
"""Initialize the messenger.
|
||||
|
||||
Args:
|
||||
enable_auto_broadcast: If True, agent replies are automatically
|
||||
broadcast to all subscribed agents.
|
||||
"""
|
||||
self._subscriptions: Dict[str, List[Callable[[Msg], None]]] = {}
|
||||
self._inbox: Dict[str, List[Msg]] = {}
|
||||
self._locks: Dict[str, asyncio.Lock] = {}
|
||||
self._enable_auto_broadcast = enable_auto_broadcast
|
||||
self._participants: Set[str] = set()
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
agent_id: str,
|
||||
callback: Callable[[Msg], None],
|
||||
) -> None:
|
||||
"""Subscribe an agent to receive messages.
|
||||
|
||||
Args:
|
||||
agent_id: Target agent identifier
|
||||
callback: Async function to call when message received
|
||||
"""
|
||||
if agent_id not in self._subscriptions:
|
||||
self._subscriptions[agent_id] = []
|
||||
self._subscriptions[agent_id].append(callback)
|
||||
logger.debug("Agent %s subscribed to messages", agent_id)
|
||||
|
||||
def unsubscribe(self, agent_id: str, callback: Callable[[Msg], None]) -> None:
|
||||
"""Unsubscribe an agent from messages.
|
||||
|
||||
Args:
|
||||
agent_id: Target agent identifier
|
||||
callback: Callback to remove
|
||||
"""
|
||||
if agent_id in self._subscriptions:
|
||||
try:
|
||||
self._subscriptions[agent_id].remove(callback)
|
||||
logger.debug("Agent %s unsubscribed from messages", agent_id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def send(
|
||||
self,
|
||||
to_agent: str,
|
||||
message: Msg,
|
||||
) -> None:
|
||||
"""Send message to specific agent.
|
||||
|
||||
Args:
|
||||
to_agent: Target agent identifier
|
||||
message: Message to send (uses Msg format)
|
||||
"""
|
||||
async def _deliver():
|
||||
if to_agent in self._subscriptions:
|
||||
for callback in self._subscriptions[to_agent]:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(message)
|
||||
else:
|
||||
callback(message)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error delivering message to %s: %s",
|
||||
to_agent,
|
||||
e,
|
||||
)
|
||||
|
||||
await _deliver()
|
||||
|
||||
async def broadcast(self, message: Msg) -> None:
|
||||
"""Broadcast message to all subscribed agents.
|
||||
|
||||
Args:
|
||||
message: Message to broadcast (uses Msg format)
|
||||
"""
|
||||
delivery_tasks = []
|
||||
for agent_id, callbacks in self._subscriptions.items():
|
||||
for callback in callbacks:
|
||||
async def _deliver(cb=callback, aid=agent_id):
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(cb):
|
||||
await cb(message)
|
||||
else:
|
||||
cb(message)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error broadcasting to %s: %s",
|
||||
aid,
|
||||
e,
|
||||
)
|
||||
delivery_tasks.append(_deliver())
|
||||
|
||||
if delivery_tasks:
|
||||
await asyncio.gather(*delivery_tasks)
|
||||
|
||||
def inbox(self, agent_id: str) -> List[Msg]:
|
||||
"""Get and clear inbox for agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
List of messages in inbox
|
||||
"""
|
||||
messages = self._inbox.get(agent_id, [])
|
||||
self._inbox[agent_id] = []
|
||||
return messages
|
||||
|
||||
def inbox_count(self, agent_id: str) -> int:
|
||||
"""Count messages in agent's inbox without clearing.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Number of messages waiting
|
||||
"""
|
||||
return len(self._inbox.get(agent_id, []))
|
||||
|
||||
def add_participant(self, agent_id: str) -> None:
|
||||
"""Add a participant to the messenger.
|
||||
|
||||
Participants are the agents that can receive auto-broadcast messages.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to add
|
||||
"""
|
||||
self._participants.add(agent_id)
|
||||
logger.debug("Agent %s added as participant", agent_id)
|
||||
|
||||
def remove_participant(self, agent_id: str) -> None:
|
||||
"""Remove a participant from the messenger.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to remove
|
||||
"""
|
||||
self._participants.discard(agent_id)
|
||||
logger.debug("Agent %s removed from participants", agent_id)
|
||||
|
||||
@property
|
||||
def enable_auto_broadcast(self) -> bool:
|
||||
"""Check if auto_broadcast is enabled."""
|
||||
return self._enable_auto_broadcast
|
||||
|
||||
@enable_auto_broadcast.setter
|
||||
def enable_auto_broadcast(self, value: bool) -> None:
|
||||
"""Enable or disable auto_broadcast."""
|
||||
self._enable_auto_broadcast = value
|
||||
logger.debug("Auto_broadcast set to %s", value)
|
||||
|
||||
async def announce(self, message: Msg) -> None:
|
||||
"""Send a system-wide announcement to all participants.
|
||||
|
||||
Unlike broadcast(), announce() sends a message from the system/host
|
||||
to all participants without requiring prior subscription.
|
||||
|
||||
Args:
|
||||
message: Announcement message (uses Msg format)
|
||||
"""
|
||||
logger.info("System announcement: %s", message.content)
|
||||
await self.broadcast(message)
|
||||
|
||||
async def auto_broadcast(self, message: Msg) -> None:
|
||||
"""Auto-broadcast message to all participants.
|
||||
|
||||
This is called internally when enable_auto_broadcast is True.
|
||||
Broadcasts to all registered participants.
|
||||
|
||||
Args:
|
||||
message: Message to auto-broadcast (uses Msg format)
|
||||
"""
|
||||
if not self._enable_auto_broadcast:
|
||||
return
|
||||
|
||||
# Broadcast to all participants
|
||||
for participant_id in self._participants:
|
||||
if participant_id in self._subscriptions:
|
||||
for callback in self._subscriptions[participant_id]:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(message)
|
||||
else:
|
||||
callback(message)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error auto-broadcasting to %s: %s",
|
||||
participant_id,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["AgentMessenger"]
|
||||
188
backend/agents/team/registry.py
Normal file
188
backend/agents/team/registry.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""AgentRegistry - Agent registration and lookup by role.
|
||||
|
||||
Provides register(), unregister(), and get_by_role() for agent
|
||||
discovery and management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
"""Registry for agent instances with role-based lookup.
|
||||
|
||||
Supports:
|
||||
- register(): Add agent with roles
|
||||
- unregister(): Remove agent
|
||||
- get_by_role(): Find agents by role
|
||||
- get_by_id(): Get specific agent
|
||||
|
||||
Each agent can have multiple roles for flexible dispatch.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._agents: Dict[str, Any] = {}
|
||||
self._roles: Dict[str, List[str]] = {}
|
||||
self._agent_roles: Dict[str, List[str]] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
agent_id: str,
|
||||
agent: Any,
|
||||
roles: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Register an agent with optional roles.
|
||||
|
||||
Args:
|
||||
agent_id: Unique agent identifier
|
||||
agent: Agent instance
|
||||
roles: Optional list of role strings
|
||||
"""
|
||||
self._agents[agent_id] = agent
|
||||
self._agent_roles[agent_id] = roles or []
|
||||
|
||||
for role in self._agent_roles[agent_id]:
|
||||
if role not in self._roles:
|
||||
self._roles[role] = []
|
||||
if agent_id not in self._roles[role]:
|
||||
self._roles[role].append(agent_id)
|
||||
|
||||
logger.info(
|
||||
"Registered agent %s with roles %s",
|
||||
agent_id,
|
||||
self._agent_roles[agent_id],
|
||||
)
|
||||
|
||||
def unregister(self, agent_id: str) -> bool:
|
||||
"""Unregister an agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to remove
|
||||
|
||||
Returns:
|
||||
True if agent was removed
|
||||
"""
|
||||
if agent_id not in self._agents:
|
||||
return False
|
||||
|
||||
roles = self._agent_roles.pop(agent_id, [])
|
||||
for role in roles:
|
||||
if role in self._roles:
|
||||
try:
|
||||
self._roles[role].remove(agent_id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
del self._agents[agent_id]
|
||||
logger.info("Unregistered agent: %s", agent_id)
|
||||
return True
|
||||
|
||||
def get_by_id(self, agent_id: str) -> Optional[Any]:
|
||||
"""Get agent by ID.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Agent instance or None
|
||||
"""
|
||||
return self._agents.get(agent_id)
|
||||
|
||||
def get_by_role(self, role: str) -> List[Any]:
|
||||
"""Get all agents with a given role.
|
||||
|
||||
Args:
|
||||
role: Role string to search for
|
||||
|
||||
Returns:
|
||||
List of agent instances with the role
|
||||
"""
|
||||
agent_ids = self._roles.get(role, [])
|
||||
return [self._agents[aid] for aid in agent_ids if aid in self._agents]
|
||||
|
||||
def get_by_roles(self, roles: List[str]) -> List[Any]:
|
||||
"""Get agents matching ANY of the given roles.
|
||||
|
||||
Args:
|
||||
roles: List of role strings
|
||||
|
||||
Returns:
|
||||
List of unique agent instances matching any role
|
||||
"""
|
||||
seen = set()
|
||||
result = []
|
||||
for role in roles:
|
||||
for agent in self.get_by_role(role):
|
||||
if id(agent) not in seen:
|
||||
seen.add(id(agent))
|
||||
result.append(agent)
|
||||
return result
|
||||
|
||||
def list_agents(self) -> List[str]:
|
||||
"""List all registered agent IDs.
|
||||
|
||||
Returns:
|
||||
List of agent identifiers
|
||||
"""
|
||||
return list(self._agents.keys())
|
||||
|
||||
def list_roles(self) -> List[str]:
|
||||
"""List all registered roles.
|
||||
|
||||
Returns:
|
||||
List of role strings
|
||||
"""
|
||||
return list(self._roles.keys())
|
||||
|
||||
def list_roles_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List roles for specific agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
List of role strings
|
||||
"""
|
||||
return list(self._agent_roles.get(agent_id, []))
|
||||
|
||||
def update_roles(self, agent_id: str, roles: List[str]) -> None:
|
||||
"""Update roles for an existing agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
roles: New list of roles
|
||||
"""
|
||||
if agent_id not in self._agents:
|
||||
raise KeyError(f"Agent not registered: {agent_id}")
|
||||
|
||||
old_roles = self._agent_roles.get(agent_id, [])
|
||||
for role in old_roles:
|
||||
if role in self._roles:
|
||||
try:
|
||||
self._roles[role].remove(agent_id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self._agent_roles[agent_id] = roles
|
||||
for role in roles:
|
||||
if role not in self._roles:
|
||||
self._roles[role] = []
|
||||
if agent_id not in self._roles[role]:
|
||||
self._roles[role].append(agent_id)
|
||||
|
||||
logger.info("Updated roles for agent %s: %s", agent_id, roles)
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, Any]:
|
||||
"""Get copy of registered agents dict."""
|
||||
return dict(self._agents)
|
||||
|
||||
|
||||
__all__ = ["AgentRegistry"]
|
||||
620
backend/agents/team/task_delegator.py
Normal file
620
backend/agents/team/task_delegator.py
Normal file
@@ -0,0 +1,620 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskDelegator - Subagent spawning and task delegation.
|
||||
|
||||
Provides delegate() and delegate_parallel() for spawning subagents
|
||||
with separate context and memory. Supports runtime dynamic subagent
|
||||
definition via task_data with description, prompt, and tools.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default timeout for subagent execution (seconds)
|
||||
DEFAULT_EXECUTION_TIMEOUT = 120.0
|
||||
|
||||
|
||||
# Type alias for subagent specification
|
||||
SubagentSpec = Dict[str, Any]
|
||||
"""Subagent specification format:
|
||||
{
|
||||
"description": "Expert code reviewer...",
|
||||
"prompt": "Analyze code quality...",
|
||||
"tools": ["Read", "Glob", "Grep"], # Optional: list of tool names
|
||||
"model": "gpt-4o", # Optional: model name
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class TaskDelegator:
|
||||
"""Delegates tasks to subagents with isolated context.
|
||||
|
||||
Supports:
|
||||
- delegate(): Spawn single subagent for task
|
||||
- delegate_parallel(): Spawn multiple subagents concurrently
|
||||
- delegate_task(): Delegate with dynamic subagent definition from task_data
|
||||
|
||||
Each subagent gets its own memory/context to prevent
|
||||
cross-contamination.
|
||||
|
||||
Dynamic Subagent Definition:
|
||||
task_data can include an "agents" dict to define subagents inline:
|
||||
|
||||
task_data = {
|
||||
"task": "Review the code changes",
|
||||
"agents": {
|
||||
"code-reviewer": {
|
||||
"description": "Expert code reviewer for quality and security.",
|
||||
"prompt": "Analyze code quality and suggest improvements.",
|
||||
"tools": ["Read", "Glob", "Grep"],
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, agent: Any):
|
||||
"""Initialize TaskDelegator.
|
||||
|
||||
Args:
|
||||
agent: Parent EvoAgent instance for accessing model, formatter, workspace
|
||||
"""
|
||||
self._agent = agent
|
||||
# Get messenger from parent agent if available
|
||||
self._messenger = getattr(agent, "messenger", None)
|
||||
self._registry = getattr(agent, "_registry", None)
|
||||
self._subagents: Dict[str, Any] = {}
|
||||
self._dynamic_subagents: Dict[str, SubagentSpec] = {}
|
||||
self._tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
# Extract model and formatter from parent agent
|
||||
self._model = getattr(agent, "model", None)
|
||||
self._formatter = getattr(agent, "formatter", None)
|
||||
self._workspace_dir = getattr(agent, "workspace_dir", None)
|
||||
self._config_name = getattr(agent, "config_name", None)
|
||||
|
||||
async def delegate(
|
||||
self,
|
||||
agent_id: str,
|
||||
task: Callable[..., Awaitable[Msg]],
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
) -> asyncio.Task:
|
||||
"""Delegate task to a single subagent.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for this subagent instance
|
||||
task: Async function representing the task
|
||||
context: Optional context dict for the subagent
|
||||
|
||||
Returns:
|
||||
asyncio.Task for the delegated task
|
||||
"""
|
||||
async def _run_with_context():
|
||||
result = await task(context or {})
|
||||
return result
|
||||
|
||||
self._tasks[agent_id] = asyncio.create_task(_run_with_context())
|
||||
logger.info("Delegated task to subagent: %s", agent_id)
|
||||
return self._tasks[agent_id]
|
||||
|
||||
async def delegate_parallel(
|
||||
self,
|
||||
tasks: List[Dict[str, Any]],
|
||||
) -> List[asyncio.Task]:
|
||||
"""Delegate multiple tasks in parallel.
|
||||
|
||||
Args:
|
||||
tasks: List of task dicts with keys:
|
||||
- agent_id: Unique identifier
|
||||
- task: Async function to execute
|
||||
- context: Optional context dict
|
||||
|
||||
Returns:
|
||||
List of asyncio.Task for all delegated tasks
|
||||
"""
|
||||
async def _run_task(task_def: Dict[str, Any]):
|
||||
agent_id = task_def["agent_id"]
|
||||
task_func = task_def["task"]
|
||||
context = task_def.get("context", {})
|
||||
|
||||
async def _run_with_context():
|
||||
return await task_func(context)
|
||||
|
||||
self._tasks[agent_id] = asyncio.create_task(_run_with_context())
|
||||
return self._tasks[agent_id]
|
||||
|
||||
gathered_tasks = await asyncio.gather(
|
||||
*[_run_task(t) for t in tasks],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
valid_tasks = [t for t in gathered_tasks if isinstance(t, asyncio.Task)]
|
||||
logger.info(
|
||||
"Delegated %d tasks in parallel (%d succeeded)",
|
||||
len(tasks),
|
||||
len(valid_tasks),
|
||||
)
|
||||
return valid_tasks
|
||||
|
||||
async def wait_for(self, agent_id: str, timeout: Optional[float] = None) -> Any:
|
||||
"""Wait for subagent task to complete.
|
||||
|
||||
Args:
|
||||
agent_id: Subagent identifier
|
||||
timeout: Optional timeout in seconds
|
||||
|
||||
Returns:
|
||||
Task result
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If task doesn't complete in time
|
||||
KeyError: If agent_id not found
|
||||
"""
|
||||
if agent_id not in self._tasks:
|
||||
raise KeyError(f"Unknown subagent: {agent_id}")
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._tasks[agent_id],
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Task %s timed out after %s seconds", agent_id, timeout)
|
||||
raise
|
||||
|
||||
async def cancel(self, agent_id: str) -> bool:
|
||||
"""Cancel a subagent task.
|
||||
|
||||
Args:
|
||||
agent_id: Subagent identifier
|
||||
|
||||
Returns:
|
||||
True if task was cancelled
|
||||
"""
|
||||
if agent_id in self._tasks:
|
||||
self._tasks[agent_id].cancel()
|
||||
del self._tasks[agent_id]
|
||||
logger.info("Cancelled subagent task: %s", agent_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_tasks(self) -> List[str]:
|
||||
"""List active subagent task IDs.
|
||||
|
||||
Returns:
|
||||
List of agent_ids with pending tasks
|
||||
"""
|
||||
return list(self._tasks.keys())
|
||||
|
||||
@property
|
||||
def tasks(self) -> Dict[str, asyncio.Task]:
|
||||
"""Get copy of active tasks dict."""
|
||||
return dict(self._tasks)
|
||||
|
||||
async def delegate_task(
|
||||
self,
|
||||
task_type: str,
|
||||
task_data: Dict[str, Any],
|
||||
target_agent: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Delegate a task with optional dynamic subagent definition.
|
||||
|
||||
Supports runtime subagent definition via task_data["agents"]:
|
||||
|
||||
task_data = {
|
||||
"task": "Review code changes",
|
||||
"agents": {
|
||||
"code-reviewer": {
|
||||
"description": "Expert code reviewer...",
|
||||
"prompt": "Analyze code quality...",
|
||||
"tools": ["Read", "Glob", "Grep"],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
task_type: Type of task (e.g., "analysis", "review", "research")
|
||||
task_data: Task payload, may include "agents" for dynamic subagent def
|
||||
target_agent: Optional specific agent ID to delegate to
|
||||
|
||||
Returns:
|
||||
Dict with "success" and result/error
|
||||
"""
|
||||
try:
|
||||
# Extract dynamic subagent definitions from task_data
|
||||
agents_def = task_data.get("agents", {})
|
||||
|
||||
if agents_def:
|
||||
# Register dynamic subagents
|
||||
for agent_name, agent_spec in agents_def.items():
|
||||
self._dynamic_subagents[agent_name] = agent_spec
|
||||
logger.info(
|
||||
"Registered dynamic subagent: %s (description: %s)",
|
||||
agent_name,
|
||||
agent_spec.get("description", "")[:50],
|
||||
)
|
||||
|
||||
# Determine target agent
|
||||
effective_target = target_agent
|
||||
if not effective_target:
|
||||
# Use first available dynamic subagent or default
|
||||
if agents_def:
|
||||
effective_target = next(iter(agents_def.keys()))
|
||||
else:
|
||||
effective_target = "default"
|
||||
|
||||
# Execute the task (async)
|
||||
task_result = await self._execute_task(
|
||||
task_type=task_type,
|
||||
task_data=task_data,
|
||||
target_agent=effective_target,
|
||||
)
|
||||
|
||||
# Clean up dynamic subagents after execution
|
||||
for agent_name in agents_def.keys():
|
||||
self._dynamic_subagents.pop(agent_name, None)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": task_result,
|
||||
"subagents_used": list(agents_def.keys()) if agents_def else [],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Task delegation failed: %s", e)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def _execute_task(
|
||||
self,
|
||||
task_type: str,
|
||||
task_data: Dict[str, Any],
|
||||
target_agent: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the delegated task with a real subagent.
|
||||
|
||||
Args:
|
||||
task_type: Type of task
|
||||
task_data: Task payload
|
||||
target_agent: Target agent identifier
|
||||
|
||||
Returns:
|
||||
Task execution result with success/failure info
|
||||
"""
|
||||
task_content = task_data.get("task", task_data.get("prompt", ""))
|
||||
timeout = task_data.get("timeout", DEFAULT_EXECUTION_TIMEOUT)
|
||||
|
||||
# Check if we have a dynamic subagent spec for this target
|
||||
agent_spec = self._dynamic_subagents.get(target_agent)
|
||||
|
||||
if agent_spec:
|
||||
logger.info(
|
||||
"Executing task '%s' with dynamic subagent '%s'",
|
||||
task_type,
|
||||
target_agent,
|
||||
)
|
||||
return await self._create_and_run_subagent(
|
||||
agent_name=target_agent,
|
||||
agent_spec=agent_spec,
|
||||
task_content=task_content,
|
||||
task_type=task_type,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Fallback: try to use parent agent's model to process the task directly
|
||||
logger.info(
|
||||
"Executing task '%s' with parent agent '%s' (no dynamic subagent)",
|
||||
task_type,
|
||||
target_agent,
|
||||
)
|
||||
return await self._run_with_parent_agent(
|
||||
task_content=task_content,
|
||||
task_type=task_type,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def _create_and_run_subagent(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_spec: SubagentSpec,
|
||||
task_content: str,
|
||||
task_type: str,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create and run a dynamic subagent.
|
||||
|
||||
Args:
|
||||
agent_name: Name identifier for the subagent
|
||||
agent_spec: Subagent specification (description, prompt, tools, model)
|
||||
task_content: Task prompt to send to the subagent
|
||||
task_type: Type of task
|
||||
timeout: Execution timeout in seconds
|
||||
|
||||
Returns:
|
||||
Dict with execution results
|
||||
"""
|
||||
subagent_id = f"subagent_{agent_name}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
try:
|
||||
# Create subagent instance
|
||||
subagent = await self._create_subagent(
|
||||
subagent_id=subagent_id,
|
||||
agent_spec=agent_spec,
|
||||
)
|
||||
|
||||
if subagent is None:
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"subagent": agent_name,
|
||||
"status": "failed",
|
||||
"error": "Failed to create subagent",
|
||||
"message": f"Could not instantiate subagent '{agent_name}'",
|
||||
}
|
||||
|
||||
# Store for potential cleanup
|
||||
self._subagents[subagent_id] = subagent
|
||||
|
||||
# Execute with timeout
|
||||
result = await asyncio.wait_for(
|
||||
self._run_subagent(subagent, task_content),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract response content
|
||||
response_content = ""
|
||||
if isinstance(result, Msg):
|
||||
response_content = result.content
|
||||
elif hasattr(result, "content"):
|
||||
response_content = str(result.content)
|
||||
elif isinstance(result, dict):
|
||||
response_content = result.get("content", str(result))
|
||||
else:
|
||||
response_content = str(result)
|
||||
|
||||
logger.info(
|
||||
"Subagent '%s' completed task '%s' successfully",
|
||||
agent_name,
|
||||
task_type,
|
||||
)
|
||||
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"subagent": {
|
||||
"name": agent_name,
|
||||
"id": subagent_id,
|
||||
"description": agent_spec.get("description", ""),
|
||||
},
|
||||
"status": "completed",
|
||||
"response": response_content,
|
||||
"message": f"Task '{task_type}' executed with subagent '{agent_name}'",
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Subagent '%s' timed out after %.1f seconds for task '%s'",
|
||||
agent_name,
|
||||
timeout,
|
||||
task_type,
|
||||
)
|
||||
# Cancel the task if still running
|
||||
if subagent_id in self._subagents:
|
||||
self._subagents.pop(subagent_id, None)
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"subagent": agent_name,
|
||||
"status": "timeout",
|
||||
"error": f"Execution timed out after {timeout} seconds",
|
||||
"message": f"Task '{task_type}' timed out for subagent '{agent_name}'",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Subagent '%s' failed for task '%s': %s",
|
||||
agent_name,
|
||||
task_type,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
# Cleanup on failure
|
||||
if subagent_id in self._subagents:
|
||||
self._subagents.pop(subagent_id, None)
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"subagent": agent_name,
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": f"Task '{task_type}' failed for subagent '{agent_name}': {e}",
|
||||
}
|
||||
|
||||
async def _create_subagent(
|
||||
self,
|
||||
subagent_id: str,
|
||||
agent_spec: SubagentSpec,
|
||||
) -> Optional[Any]:
|
||||
"""Create a subagent instance.
|
||||
|
||||
Uses the parent agent's model/formatter to create a lightweight
|
||||
subagent for task execution.
|
||||
|
||||
Args:
|
||||
subagent_id: Unique identifier for the subagent
|
||||
agent_spec: Subagent specification
|
||||
|
||||
Returns:
|
||||
Subagent instance or None if creation fails
|
||||
"""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from agentscope.memory import InMemoryMemory
|
||||
|
||||
# Get model and formatter from parent
|
||||
model = self._model
|
||||
formatter = self._formatter
|
||||
|
||||
if model is None:
|
||||
logger.error("Cannot create subagent: parent agent has no model")
|
||||
return None
|
||||
|
||||
# Build system prompt from agent spec
|
||||
description = agent_spec.get("description", "")
|
||||
prompt_template = agent_spec.get("prompt", "")
|
||||
system_prompt = f"""You are {description}
|
||||
|
||||
{prompt_template}
|
||||
|
||||
Your task is to complete the user's request below.
|
||||
"""
|
||||
|
||||
# Create a minimal ReActAgent as the subagent
|
||||
from agentscope.agent import ReActAgent
|
||||
|
||||
subagent = ReActAgent(
|
||||
name=subagent_id,
|
||||
model=model,
|
||||
sys_prompt=system_prompt,
|
||||
toolkit=None, # Could load tools from agent_spec.get("tools", [])
|
||||
memory=InMemoryMemory(),
|
||||
formatter=formatter,
|
||||
max_iters=agent_spec.get("max_iters", 5),
|
||||
)
|
||||
|
||||
logger.debug("Created subagent: %s", subagent_id)
|
||||
return subagent
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create subagent '%s': %s",
|
||||
subagent_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _run_subagent(
|
||||
self,
|
||||
subagent: Any,
|
||||
task_content: str,
|
||||
) -> Any:
|
||||
"""Run a subagent with the given task.
|
||||
|
||||
Args:
|
||||
subagent: Subagent instance
|
||||
task_content: Task prompt
|
||||
|
||||
Returns:
|
||||
Agent response (Msg or similar)
|
||||
"""
|
||||
from agentscope.message import Msg
|
||||
|
||||
# Create message for the subagent
|
||||
task_msg = Msg(
|
||||
name="user",
|
||||
content=task_content,
|
||||
role="user",
|
||||
)
|
||||
|
||||
# Execute the agent
|
||||
response = await subagent.reply(task_msg)
|
||||
return response
|
||||
|
||||
async def _run_with_parent_agent(
|
||||
self,
|
||||
task_content: str,
|
||||
task_type: str,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run task using the parent agent directly.
|
||||
|
||||
Used when no dynamic subagent is defined.
|
||||
|
||||
Args:
|
||||
task_content: Task prompt
|
||||
task_type: Type of task
|
||||
timeout: Execution timeout
|
||||
|
||||
Returns:
|
||||
Dict with execution results
|
||||
"""
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._agent.reply(Msg(
|
||||
name="user",
|
||||
content=task_content,
|
||||
role="user",
|
||||
)),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
response_content = ""
|
||||
if isinstance(result, Msg):
|
||||
response_content = result.content
|
||||
elif hasattr(result, "content"):
|
||||
response_content = str(result.content)
|
||||
else:
|
||||
response_content = str(result)
|
||||
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"status": "completed",
|
||||
"response": response_content,
|
||||
"message": f"Task '{task_type}' executed with parent agent",
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"status": "timeout",
|
||||
"error": f"Execution timed out after {timeout} seconds",
|
||||
"message": f"Task '{task_type}' timed out",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Parent agent failed for task '%s': %s",
|
||||
task_type,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": f"Task '{task_type}' failed: {e}",
|
||||
}
|
||||
|
||||
def get_dynamic_subagent(self, name: str) -> Optional[SubagentSpec]:
|
||||
"""Get a dynamically defined subagent specification.
|
||||
|
||||
Args:
|
||||
name: Subagent name
|
||||
|
||||
Returns:
|
||||
Subagent spec dict or None if not found
|
||||
"""
|
||||
return self._dynamic_subagents.get(name)
|
||||
|
||||
def list_dynamic_subagents(self) -> List[str]:
|
||||
"""List all registered dynamic subagent names.
|
||||
|
||||
Returns:
|
||||
List of subagent names
|
||||
"""
|
||||
return list(self._dynamic_subagents.keys())
|
||||
|
||||
|
||||
__all__ = ["TaskDelegator", "SubagentSpec"]
|
||||
389
backend/agents/team/team_coordinator.py
Normal file
389
backend/agents/team/team_coordinator.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TeamCoordinator - Agent lifecycle management and execution.
|
||||
|
||||
Provides run_parallel() using asyncio.gather() and run_sequential()
|
||||
for coordinating multiple agents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamCoordinator:
|
||||
"""Coordinates agent lifecycle and execution.
|
||||
|
||||
Supports:
|
||||
- run_parallel(): Execute multiple agents concurrently with asyncio.gather()
|
||||
- run_sequential(): Execute agents one after another
|
||||
- run_phase(): Execute a named phase with registered agents
|
||||
- register_agent(): Add agent to coordinator
|
||||
- unregister_agent(): Remove agent from coordinator
|
||||
|
||||
Each agent maintains separate context/memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: Optional[List[Any]] = None,
|
||||
task_content: Optional[str] = None,
|
||||
messenger: Optional[Any] = None,
|
||||
registry: Optional[Any] = None,
|
||||
):
|
||||
"""Initialize TeamCoordinator.
|
||||
|
||||
Args:
|
||||
participants: List of agent instances to coordinate
|
||||
task_content: Task description content for the agents
|
||||
messenger: AgentMessenger for communication (optional)
|
||||
registry: AgentRegistry for agent lookup (optional)
|
||||
"""
|
||||
self._participants = participants or []
|
||||
self._task_content = task_content or ""
|
||||
self._messenger = messenger
|
||||
self._registry = registry
|
||||
self._agents: Dict[str, Any] = {}
|
||||
self._running_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Auto-register participants
|
||||
for agent in self._participants:
|
||||
if hasattr(agent, "name"):
|
||||
self._agents[agent.name] = agent
|
||||
elif hasattr(agent, "id"):
|
||||
self._agents[agent.id] = agent
|
||||
|
||||
def register_agent(self, agent_id: str, agent: Any) -> None:
|
||||
"""Register an agent with the coordinator.
|
||||
|
||||
Args:
|
||||
agent_id: Unique agent identifier
|
||||
agent: Agent instance
|
||||
"""
|
||||
self._agents[agent_id] = agent
|
||||
logger.info("Registered agent: %s", agent_id)
|
||||
|
||||
def unregister_agent(self, agent_id: str) -> None:
|
||||
"""Unregister an agent from the coordinator.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to remove
|
||||
"""
|
||||
if agent_id in self._agents:
|
||||
del self._agents[agent_id]
|
||||
logger.info("Unregistered agent: %s", agent_id)
|
||||
|
||||
def get_agent(self, agent_id: str) -> Any:
|
||||
"""Get registered agent by ID.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Agent instance
|
||||
"""
|
||||
return self._agents.get(agent_id)
|
||||
|
||||
def list_agents(self) -> List[str]:
|
||||
"""List all registered agent IDs.
|
||||
|
||||
Returns:
|
||||
List of agent identifiers
|
||||
"""
|
||||
return list(self._agents.keys())
|
||||
|
||||
async def run_parallel(
|
||||
self,
|
||||
agent_ids: List[str],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run multiple agents in parallel using asyncio.gather().
|
||||
|
||||
Args:
|
||||
agent_ids: List of agent IDs to run concurrently
|
||||
initial_message: Optional initial message to broadcast
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
async def _run_agent(aid: str) -> tuple[str, Any]:
|
||||
agent = self._agents.get(aid)
|
||||
if agent is None:
|
||||
logger.error("Agent %s not found", aid)
|
||||
return (aid, None)
|
||||
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
if initial_message:
|
||||
result = await agent.reply(initial_message)
|
||||
else:
|
||||
result = await agent.reply()
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
logger.info("Agent %s completed successfully", aid)
|
||||
return (aid, result)
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed: %s", aid, e)
|
||||
return (aid, {"error": str(e)})
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[_run_agent(aid) for aid in agent_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
output: Dict[str, Any] = {}
|
||||
for result in results:
|
||||
if isinstance(result, tuple):
|
||||
agent_id, agent_result = result
|
||||
output[agent_id] = agent_result
|
||||
else:
|
||||
logger.error("Unexpected result from asyncio.gather: %s", result)
|
||||
|
||||
logger.info("Parallel run completed for %d agents", len(agent_ids))
|
||||
return output
|
||||
|
||||
async def run_sequential(
|
||||
self,
|
||||
agent_ids: List[str],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run agents one after another in order.
|
||||
|
||||
Args:
|
||||
agent_ids: List of agent IDs to run in sequence
|
||||
initial_message: Optional initial message for first agent
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
output: Dict[str, Any] = {}
|
||||
current_message = initial_message
|
||||
|
||||
for agent_id in agent_ids:
|
||||
agent = self._agents.get(agent_id)
|
||||
if agent is None:
|
||||
logger.error("Agent %s not found", agent_id)
|
||||
output[agent_id] = {"error": "Agent not found"}
|
||||
continue
|
||||
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
result = await agent.reply(current_message)
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
|
||||
output[agent_id] = result
|
||||
current_message = result
|
||||
logger.info("Agent %s completed sequentially", agent_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed: %s", agent_id, e)
|
||||
output[agent_id] = {"error": str(e)}
|
||||
break
|
||||
|
||||
logger.info("Sequential run completed for %d agents", len(agent_ids))
|
||||
return output
|
||||
|
||||
async def run_phase(
|
||||
self,
|
||||
phase_name: str,
|
||||
agent_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Any]:
|
||||
"""Execute a named phase with registered agents.
|
||||
|
||||
Args:
|
||||
phase_name: Name of the phase (e.g., "analyst_analysis")
|
||||
agent_ids: Optional list of agent IDs; if None, uses all registered
|
||||
metadata: Optional metadata to include in the message (e.g., tickers, date)
|
||||
|
||||
Returns:
|
||||
List of results from each agent
|
||||
"""
|
||||
if agent_ids is None:
|
||||
agent_ids = list(self._agents.keys())
|
||||
|
||||
_agent_ids = [aid for aid in agent_ids if aid in self._agents]
|
||||
|
||||
logger.info(
|
||||
"Running phase '%s' with %d agents: %s",
|
||||
phase_name,
|
||||
len(_agent_ids),
|
||||
_agent_ids,
|
||||
)
|
||||
|
||||
# Create messages for each agent
|
||||
results: List[Any] = []
|
||||
for agent_id in _agent_ids:
|
||||
agent = self._agents[agent_id]
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
# Create a message for the agent with proper structure
|
||||
msg = Msg(
|
||||
name="system",
|
||||
content=self._task_content or f"Please execute phase: {phase_name}",
|
||||
role="user",
|
||||
metadata=metadata,
|
||||
)
|
||||
result = await agent.reply(msg)
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
results.append(result)
|
||||
logger.info("Phase '%s': Agent %s completed", phase_name, agent_id)
|
||||
except Exception as e:
|
||||
logger.error("Phase '%s': Agent %s failed: %s", phase_name, agent_id, e)
|
||||
results.append(None)
|
||||
|
||||
logger.info("Phase '%s' completed with %d results", phase_name, len(results))
|
||||
return results
|
||||
|
||||
async def run_with_dependencies(
|
||||
self,
|
||||
agent_tasks: Dict[str, List[str]],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run agents respecting dependency graph.
|
||||
|
||||
Args:
|
||||
agent_tasks: Dict mapping agent_id to list of prerequisite agent_ids
|
||||
initial_message: Optional initial message
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
completed: Dict[str, Any] = {}
|
||||
remaining = set(agent_tasks.keys())
|
||||
|
||||
while remaining:
|
||||
ready = [
|
||||
aid for aid in remaining
|
||||
if all(dep in completed for dep in agent_tasks.get(aid, []))
|
||||
]
|
||||
|
||||
if not ready:
|
||||
logger.error("Circular dependency detected in agent tasks")
|
||||
for aid in remaining:
|
||||
completed[aid] = {"error": "Circular dependency"}
|
||||
break
|
||||
|
||||
results = await self.run_parallel(ready, initial_message)
|
||||
completed.update(results)
|
||||
|
||||
for aid in ready:
|
||||
remaining.discard(aid)
|
||||
initial_message = results.get(aid)
|
||||
|
||||
return completed
|
||||
|
||||
async def fanout_pipeline(
|
||||
self,
|
||||
agents: List[Any],
|
||||
msg: Optional[Msg] = None,
|
||||
) -> List[Msg]:
|
||||
"""Fanout a message to multiple agents concurrently and collect all responses.
|
||||
|
||||
Similar to AgentScope's fanout_pipeline, this sends the same message
|
||||
to all specified agents and returns a list of all agent responses.
|
||||
|
||||
Args:
|
||||
agents: List of agent instances to fanout the message to
|
||||
msg: Message to send to all agents (optional)
|
||||
|
||||
Returns:
|
||||
List of Msg responses from each agent (in the same order as input agents)
|
||||
|
||||
Example:
|
||||
>>> responses = await fanout_pipeline(
|
||||
... agents=[alice, bob, charlie],
|
||||
... msg=question,
|
||||
... )
|
||||
>>> # responses is a list of Msg responses from each agent
|
||||
"""
|
||||
async def _fanout_to_agent(agent: Any) -> Optional[Msg]:
|
||||
"""Send message to a single agent and return its response."""
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
result = await agent.reply(msg) if msg is not None else await agent.reply()
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
|
||||
# Convert result to Msg if needed
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, Msg):
|
||||
return result
|
||||
# If result is a dict with content, wrap it
|
||||
if isinstance(result, dict) and "content" in result:
|
||||
return Msg(
|
||||
name=getattr(agent, "name", "unknown"),
|
||||
content=result.get("content", ""),
|
||||
role="assistant",
|
||||
metadata=result.get("metadata"),
|
||||
)
|
||||
# Otherwise wrap the result
|
||||
return Msg(
|
||||
name=getattr(agent, "name", "unknown"),
|
||||
content=str(result),
|
||||
role="assistant",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed in fanout_pipeline: %s",
|
||||
getattr(agent, "name", "unknown"), e)
|
||||
return None
|
||||
|
||||
# Run all agents concurrently
|
||||
results = await asyncio.gather(
|
||||
*[_fanout_to_agent(agent) for agent in agents],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Filter out exceptions and keep only valid responses
|
||||
responses: List[Msg] = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error("Fanout to agent %d failed: %s", i, result)
|
||||
responses.append(None) # type: ignore[arg-type]
|
||||
else:
|
||||
responses.append(result) # type: ignore[arg-type]
|
||||
|
||||
logger.info("Fanout pipeline completed for %d agents", len(agents))
|
||||
return responses
|
||||
|
||||
async def shutdown(self, timeout: Optional[float] = 5.0) -> None:
|
||||
"""Shutdown all running agents gracefully.
|
||||
|
||||
Args:
|
||||
timeout: Timeout for graceful shutdown
|
||||
"""
|
||||
logger.info("Shutting down TeamCoordinator...")
|
||||
|
||||
cancel_tasks = [
|
||||
asyncio.create_task(asyncio.wait_for(task, timeout=timeout))
|
||||
for task in self._running_tasks.values()
|
||||
]
|
||||
|
||||
if cancel_tasks:
|
||||
await asyncio.gather(*cancel_tasks, return_exceptions=True)
|
||||
|
||||
self._running_tasks.clear()
|
||||
logger.info("TeamCoordinator shutdown complete")
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, Any]:
|
||||
"""Get copy of registered agents dict."""
|
||||
return dict(self._agents)
|
||||
|
||||
|
||||
__all__ = ["TeamCoordinator"]
|
||||
132
backend/agents/team_pipeline_config.py
Normal file
132
backend/agents/team_pipeline_config.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Run-scoped team pipeline configuration helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Dict, Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
DEFAULT_FILENAME = "TEAM_PIPELINE.yaml"
|
||||
|
||||
|
||||
def team_pipeline_path(project_root: Path, config_name: str) -> Path:
|
||||
"""Return run-scoped team pipeline config path."""
|
||||
return project_root / "runs" / config_name / DEFAULT_FILENAME
|
||||
|
||||
|
||||
def ensure_team_pipeline_config(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
default_analysts: Iterable[str],
|
||||
) -> Path:
|
||||
"""Ensure TEAM_PIPELINE.yaml exists for one run."""
|
||||
path = team_pipeline_path(project_root, config_name)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
payload = {
|
||||
"version": 1,
|
||||
"controller_agent": "portfolio_manager",
|
||||
"discussion": {
|
||||
"allow_dynamic_team_update": True,
|
||||
"active_analysts": list(default_analysts),
|
||||
},
|
||||
"decision": {
|
||||
"require_risk_manager": True,
|
||||
},
|
||||
}
|
||||
path.write_text(
|
||||
yaml.safe_dump(payload, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def load_team_pipeline_config(project_root: Path, config_name: str) -> Dict[str, Any]:
|
||||
"""Load TEAM_PIPELINE.yaml and return parsed dict."""
|
||||
path = team_pipeline_path(project_root, config_name)
|
||||
if not path.exists():
|
||||
return {}
|
||||
parsed = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def save_team_pipeline_config(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
config: Dict[str, Any],
|
||||
) -> Path:
|
||||
"""Persist TEAM_PIPELINE.yaml."""
|
||||
path = team_pipeline_path(project_root, config_name)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
yaml.safe_dump(config, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def resolve_active_analysts(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
available_analysts: Iterable[str],
|
||||
) -> List[str]:
|
||||
"""Resolve active analysts from TEAM_PIPELINE.yaml."""
|
||||
available = [item for item in available_analysts]
|
||||
parsed = load_team_pipeline_config(project_root, config_name)
|
||||
discussion = parsed.get("discussion", {}) if isinstance(parsed, dict) else {}
|
||||
configured = discussion.get("active_analysts", [])
|
||||
if not isinstance(configured, list) or not configured:
|
||||
return available
|
||||
|
||||
active = [item for item in configured if item in available]
|
||||
return active or available
|
||||
|
||||
|
||||
def update_active_analysts(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
available_analysts: Iterable[str],
|
||||
*,
|
||||
add: Iterable[str] | None = None,
|
||||
remove: Iterable[str] | None = None,
|
||||
set_to: Iterable[str] | None = None,
|
||||
) -> List[str]:
|
||||
"""Update active analysts and persist TEAM_PIPELINE.yaml."""
|
||||
available = [item for item in available_analysts]
|
||||
ensure_team_pipeline_config(project_root, config_name, available)
|
||||
parsed = load_team_pipeline_config(project_root, config_name)
|
||||
discussion = parsed.setdefault("discussion", {})
|
||||
if not isinstance(discussion, dict):
|
||||
discussion = {}
|
||||
parsed["discussion"] = discussion
|
||||
|
||||
current = discussion.get("active_analysts", [])
|
||||
if not isinstance(current, list):
|
||||
current = []
|
||||
current = [item for item in current if item in available]
|
||||
if not current:
|
||||
current = list(available)
|
||||
|
||||
if set_to is not None:
|
||||
target = [item for item in set_to if item in available]
|
||||
current = target or current
|
||||
|
||||
for item in add or []:
|
||||
if item in available and item not in current:
|
||||
current.append(item)
|
||||
|
||||
for item in remove or []:
|
||||
current = [existing for existing in current if existing != item]
|
||||
|
||||
if not current:
|
||||
current = [available[0]] if available else []
|
||||
|
||||
discussion["active_analysts"] = current
|
||||
save_team_pipeline_config(project_root, config_name, parsed)
|
||||
return current
|
||||
|
||||
@@ -1,21 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Toolkit factory following AgentScope's skill + tool group practices."""
|
||||
"""Toolkit factory following AgentScope's skill + tool group practices.
|
||||
|
||||
from typing import Any, Dict, Iterable
|
||||
支持从Agent工作空间动态创建工具集,加载builtin/customized技能,
|
||||
以及合并Agent特定工具。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set
|
||||
from pathlib import Path
|
||||
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
import yaml
|
||||
|
||||
from .skills_manager import SkillsManager
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.skill_loader import load_skill_from_dir, get_skill_tools
|
||||
from backend.agents.skill_metadata import parse_skill_metadata
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
|
||||
|
||||
def load_agent_profiles() -> Dict[str, Dict[str, Any]]:
|
||||
"""加载Agent配置文件"""
|
||||
config_path = SkillsManager().project_root / "backend" / "config" / "agent_profiles.yaml"
|
||||
with open(config_path, "r", encoding="utf-8") as file:
|
||||
return yaml.safe_load(file) or {}
|
||||
|
||||
|
||||
def _register_analysis_tool_groups(toolkit: Any) -> None:
|
||||
"""注册分析工具组"""
|
||||
from backend.tools.analysis_tools import TOOL_REGISTRY
|
||||
|
||||
tool_groups = {
|
||||
@@ -94,6 +104,7 @@ def _register_analysis_tool_groups(toolkit: Any) -> None:
|
||||
|
||||
|
||||
def _register_portfolio_tool_groups(toolkit: Any, pm_agent: Any) -> None:
|
||||
"""注册投资组合工具组"""
|
||||
toolkit.create_tool_group(
|
||||
group_name="portfolio_ops",
|
||||
description="Portfolio decision recording tools.",
|
||||
@@ -107,9 +118,30 @@ def _register_portfolio_tool_groups(toolkit: Any, pm_agent: Any) -> None:
|
||||
pm_agent._make_decision,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
if hasattr(pm_agent, "_add_team_analyst"):
|
||||
toolkit.register_tool_function(
|
||||
pm_agent._add_team_analyst,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
if hasattr(pm_agent, "_remove_team_analyst"):
|
||||
toolkit.register_tool_function(
|
||||
pm_agent._remove_team_analyst,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
if hasattr(pm_agent, "_set_active_analysts"):
|
||||
toolkit.register_tool_function(
|
||||
pm_agent._set_active_analysts,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
if hasattr(pm_agent, "_create_team_analyst"):
|
||||
toolkit.register_tool_function(
|
||||
pm_agent._create_team_analyst,
|
||||
group_name="portfolio_ops",
|
||||
)
|
||||
|
||||
|
||||
def _register_risk_tool_groups(toolkit: Any) -> None:
|
||||
"""注册风险工具组"""
|
||||
from backend.tools.risk_tools import (
|
||||
assess_margin_and_liquidity,
|
||||
assess_position_concentration,
|
||||
@@ -145,12 +177,25 @@ def create_agent_toolkit(
|
||||
owner: Any = None,
|
||||
active_skill_dirs: Iterable[str] | None = None,
|
||||
) -> Any:
|
||||
"""Create a Toolkit with agent skills and grouped tools."""
|
||||
"""Create a Toolkit with agent skills and grouped tools.
|
||||
|
||||
Args:
|
||||
agent_id: Agent标识符
|
||||
config_name: 运行配置名称
|
||||
owner: Agent实例(用于注册特定方法)
|
||||
active_skill_dirs: 显式指定的活动技能目录列表
|
||||
|
||||
Returns:
|
||||
配置好的Toolkit实例
|
||||
"""
|
||||
from agentscope.tool import Toolkit
|
||||
|
||||
profiles = load_agent_profiles()
|
||||
profile = profiles.get(agent_id, {})
|
||||
skills_manager = SkillsManager()
|
||||
agent_config = load_agent_workspace_config(
|
||||
skills_manager.get_agent_asset_dir(config_name, agent_id) / "agent.yaml",
|
||||
)
|
||||
bootstrap_config = get_bootstrap_config_for_run(
|
||||
skills_manager.project_root,
|
||||
config_name,
|
||||
@@ -158,8 +203,16 @@ def create_agent_toolkit(
|
||||
override = bootstrap_config.agent_override(agent_id)
|
||||
active_groups = override.get(
|
||||
"active_tool_groups",
|
||||
profile.get("active_tool_groups", []),
|
||||
agent_config.active_tool_groups
|
||||
or profile.get("active_tool_groups", []),
|
||||
)
|
||||
disabled_groups = set(agent_config.disabled_tool_groups)
|
||||
if disabled_groups:
|
||||
active_groups = [
|
||||
group_name
|
||||
for group_name in active_groups
|
||||
if group_name not in disabled_groups
|
||||
]
|
||||
|
||||
toolkit = Toolkit(
|
||||
agent_skill_instruction=(
|
||||
@@ -184,14 +237,281 @@ def create_agent_toolkit(
|
||||
default_skills=profile.get("skills", []),
|
||||
)
|
||||
active_skill_dirs = [
|
||||
skills_manager.get_active_root(config_name) / skill_name
|
||||
skills_manager.get_agent_active_root(config_name, agent_id) / skill_name
|
||||
for skill_name in skill_names
|
||||
]
|
||||
|
||||
for skill_dir in active_skill_dirs:
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
|
||||
apply_skill_tool_restrictions(toolkit, active_skill_dirs)
|
||||
|
||||
if active_groups:
|
||||
toolkit.update_tool_groups(group_names=active_groups, active=True)
|
||||
|
||||
return toolkit
|
||||
|
||||
|
||||
def create_toolkit_from_workspace(
|
||||
agent_id: str,
|
||||
config_name: str,
|
||||
owner: Any = None,
|
||||
include_builtin: bool = True,
|
||||
include_customized: bool = True,
|
||||
include_local: bool = True,
|
||||
active_groups: Optional[List[str]] = None,
|
||||
) -> Any:
|
||||
"""从Agent工作空间创建工具集
|
||||
|
||||
这是create_agent_toolkit的增强版本,支持更灵活的技能加载策略。
|
||||
|
||||
Args:
|
||||
agent_id: Agent标识符
|
||||
config_name: 运行配置名称
|
||||
owner: Agent实例
|
||||
include_builtin: 是否包含builtin技能
|
||||
include_customized: 是否包含customized技能
|
||||
include_local: 是否包含agent-local技能
|
||||
active_groups: 显式指定的活动工具组
|
||||
|
||||
Returns:
|
||||
配置好的Toolkit实例
|
||||
"""
|
||||
from agentscope.tool import Toolkit
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
agent_config = load_agent_workspace_config(
|
||||
skills_manager.get_agent_asset_dir(config_name, agent_id) / "agent.yaml",
|
||||
)
|
||||
|
||||
toolkit = Toolkit(
|
||||
agent_skill_instruction=(
|
||||
"<system-info>You have access to project skills. Each skill lives in a "
|
||||
"directory and is described by SKILL.md. Follow the skill instructions "
|
||||
"when they are relevant to the current task.</system-info>"
|
||||
),
|
||||
agent_skill_template="- {name} (dir: {dir}): {description}",
|
||||
)
|
||||
|
||||
# 注册Agent类型的默认工具组
|
||||
if agent_id.endswith("_analyst"):
|
||||
_register_analysis_tool_groups(toolkit)
|
||||
elif agent_id == "portfolio_manager" and owner is not None:
|
||||
_register_portfolio_tool_groups(toolkit, owner)
|
||||
elif agent_id == "risk_manager":
|
||||
_register_risk_tool_groups(toolkit)
|
||||
|
||||
# 收集所有要加载的技能目录
|
||||
skill_dirs: List[Path] = []
|
||||
|
||||
# 1. 从active目录加载已同步的技能
|
||||
active_root = skills_manager.get_agent_active_root(config_name, agent_id)
|
||||
if active_root.exists():
|
||||
for skill_dir in sorted(active_root.iterdir()):
|
||||
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||
skill_dirs.append(skill_dir)
|
||||
|
||||
# 2. 从installed目录加载
|
||||
installed_root = skills_manager.get_agent_installed_root(config_name, agent_id)
|
||||
if installed_root.exists():
|
||||
for skill_dir in sorted(installed_root.iterdir()):
|
||||
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||
if skill_dir not in skill_dirs:
|
||||
skill_dirs.append(skill_dir)
|
||||
|
||||
# 3. 从local目录加载agent-local技能
|
||||
if include_local:
|
||||
local_root = skills_manager.get_agent_local_root(config_name, agent_id)
|
||||
if local_root.exists():
|
||||
for skill_dir in sorted(local_root.iterdir()):
|
||||
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||
if skill_dir not in skill_dirs:
|
||||
skill_dirs.append(skill_dir)
|
||||
|
||||
# 注册技能到toolkit
|
||||
for skill_dir in skill_dirs:
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
|
||||
apply_skill_tool_restrictions(toolkit, skill_dirs)
|
||||
|
||||
# 激活指定的工具组
|
||||
if active_groups is None:
|
||||
# 从配置中读取
|
||||
profiles = load_agent_profiles()
|
||||
profile = profiles.get(agent_id, {})
|
||||
active_groups = agent_config.active_tool_groups or profile.get("active_tool_groups", [])
|
||||
|
||||
# 应用禁用列表
|
||||
disabled_groups = set(agent_config.disabled_tool_groups)
|
||||
if disabled_groups:
|
||||
active_groups = [g for g in active_groups if g not in disabled_groups]
|
||||
|
||||
if active_groups:
|
||||
toolkit.update_tool_groups(group_names=active_groups, active=True)
|
||||
|
||||
return toolkit
|
||||
|
||||
|
||||
def get_toolkit_info(toolkit: Any) -> Dict[str, Any]:
|
||||
"""获取工具集信息
|
||||
|
||||
Args:
|
||||
toolkit: Toolkit实例
|
||||
|
||||
Returns:
|
||||
工具集信息字典
|
||||
"""
|
||||
info = {
|
||||
"tool_groups": {},
|
||||
"skills": [],
|
||||
"tools_count": 0,
|
||||
}
|
||||
|
||||
# 获取工具组信息
|
||||
groups = getattr(toolkit, "tool_groups", {})
|
||||
for name, group in groups.items():
|
||||
info["tool_groups"][name] = {
|
||||
"description": getattr(group, "description", ""),
|
||||
"active": getattr(group, "active", False),
|
||||
"tools": [t.name for t in getattr(group, "tools", [])],
|
||||
}
|
||||
info["tools_count"] += len(getattr(group, "tools", []))
|
||||
|
||||
# 获取技能信息
|
||||
skills = getattr(toolkit, "agent_skills", [])
|
||||
for skill in skills:
|
||||
info["skills"].append({
|
||||
"name": getattr(skill, "name", "unknown"),
|
||||
"path": getattr(skill, "path", ""),
|
||||
"description": getattr(skill, "description", ""),
|
||||
})
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def refresh_toolkit_skills(
|
||||
toolkit: Any,
|
||||
agent_id: str,
|
||||
config_name: str,
|
||||
) -> None:
|
||||
"""刷新工具集中的技能
|
||||
|
||||
重新从工作空间加载技能,用于运行时技能变更。
|
||||
|
||||
Args:
|
||||
toolkit: Toolkit实例
|
||||
agent_id: Agent标识符
|
||||
config_name: 运行配置名称
|
||||
"""
|
||||
skills_manager = SkillsManager()
|
||||
|
||||
# 清除现有技能
|
||||
if hasattr(toolkit, "agent_skills"):
|
||||
toolkit.agent_skills.clear()
|
||||
|
||||
# 重新加载active技能
|
||||
active_root = skills_manager.get_agent_active_root(config_name, agent_id)
|
||||
if active_root.exists():
|
||||
for skill_dir in sorted(active_root.iterdir()):
|
||||
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
|
||||
# 重新加载local技能
|
||||
local_root = skills_manager.get_agent_local_root(config_name, agent_id)
|
||||
if local_root.exists():
|
||||
for skill_dir in sorted(local_root.iterdir()):
|
||||
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||
toolkit.register_agent_skill(str(skill_dir))
|
||||
|
||||
|
||||
def apply_skill_tool_restrictions(toolkit: Any, skill_dirs: List[Path]) -> None:
|
||||
"""Apply per-skill allowed_tools / denied_tools restrictions to a toolkit.
|
||||
|
||||
If a skill specifies allowed_tools, only those tools are accessible when
|
||||
that skill is active. If a skill specifies denied_tools, those tools are
|
||||
removed regardless of allowed_tools. Denied tools take precedence.
|
||||
|
||||
This function annotates the toolkit with a _skill_tool_restrictions map
|
||||
that downstream code can consult when resolving available tools.
|
||||
|
||||
Args:
|
||||
toolkit: The agentscope Toolkit instance.
|
||||
skill_dirs: List of skill directory paths to inspect.
|
||||
"""
|
||||
restrictions: Dict[str, Dict[str, Set[str]]] = {}
|
||||
for skill_dir in skill_dirs:
|
||||
metadata = parse_skill_metadata(skill_dir, source="active")
|
||||
if not metadata.allowed_tools and not metadata.denied_tools:
|
||||
continue
|
||||
restrictions[skill_dir.name] = {
|
||||
"allowed": set(metadata.allowed_tools),
|
||||
"denied": set(metadata.denied_tools),
|
||||
}
|
||||
if hasattr(toolkit, "agent_skills"):
|
||||
for skill in toolkit.agent_skills:
|
||||
skill_name = getattr(skill, "name", "") or ""
|
||||
if skill_name in restrictions:
|
||||
setattr(
|
||||
skill,
|
||||
"_tool_allowed",
|
||||
restrictions[skill_name]["allowed"],
|
||||
)
|
||||
setattr(
|
||||
skill,
|
||||
"_tool_denied",
|
||||
restrictions[skill_name]["denied"],
|
||||
)
|
||||
|
||||
|
||||
def get_skill_effective_tools(skill: Any) -> Optional[Set[str]]:
|
||||
"""Return the effective tool set for a skill after applying restrictions.
|
||||
|
||||
If the skill has no restrictions (no allowed_tools / denied_tools),
|
||||
returns None to indicate "all tools allowed".
|
||||
|
||||
If allowed_tools is set, returns only those tools minus denied_tools.
|
||||
If only denied_tools is set, returns all tools minus denied_tools.
|
||||
|
||||
Args:
|
||||
skill: A skill object previously registered via register_agent_skill.
|
||||
|
||||
Returns:
|
||||
A set of allowed tool names, or None if unrestricted.
|
||||
"""
|
||||
allowed = getattr(skill, "_tool_allowed", None)
|
||||
denied = getattr(skill, "_tool_denied", set())
|
||||
|
||||
if allowed is None:
|
||||
return None
|
||||
|
||||
effective = allowed - denied
|
||||
return effective
|
||||
|
||||
|
||||
def filter_toolkit_by_skill(
|
||||
toolkit: Any,
|
||||
skill_name: str,
|
||||
) -> Set[str]:
|
||||
"""Return the set of tool names that are accessible for a given skill.
|
||||
|
||||
Args:
|
||||
toolkit: The agentscope Toolkit instance.
|
||||
skill_name: Name of the skill to query.
|
||||
|
||||
Returns:
|
||||
Set of allowed tool names, or all registered tool names if unrestricted.
|
||||
"""
|
||||
if not hasattr(toolkit, "agent_skills"):
|
||||
return set()
|
||||
|
||||
for skill in toolkit.agent_skills:
|
||||
name = getattr(skill, "name", "") or ""
|
||||
if name != skill_name:
|
||||
continue
|
||||
effective = get_skill_effective_tools(skill)
|
||||
if effective is None:
|
||||
return set()
|
||||
return effective
|
||||
|
||||
return set()
|
||||
|
||||
|
||||
327
backend/agents/workspace.py
Normal file
327
backend/agents/workspace.py
Normal file
@@ -0,0 +1,327 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Workspace Manager - Create and manage agent workspaces."""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceConfig:
|
||||
"""Configuration for a workspace."""
|
||||
|
||||
workspace_id: str
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
created_at: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Serialize to dictionary."""
|
||||
return {
|
||||
"workspace_id": self.workspace_id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"created_at": self.created_at,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "WorkspaceConfig":
|
||||
"""Create from dictionary."""
|
||||
return cls(
|
||||
workspace_id=data.get("workspace_id", ""),
|
||||
name=data.get("name", ""),
|
||||
description=data.get("description", ""),
|
||||
created_at=data.get("created_at", ""),
|
||||
metadata=data.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
class WorkspaceRegistry:
|
||||
"""Registry for persistent workspace definitions (design-time)."""
|
||||
|
||||
def __init__(self, project_root: Optional[Path] = None):
|
||||
"""Initialize the workspace manager.
|
||||
|
||||
Args:
|
||||
project_root: Root directory of the project
|
||||
"""
|
||||
self.project_root = project_root or Path(__file__).parent.parent.parent
|
||||
self.workspaces_root = self.project_root / "workspaces"
|
||||
self.workspaces_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def create_workspace(
|
||||
self,
|
||||
workspace_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> WorkspaceConfig:
|
||||
"""Create a new workspace with directory structure.
|
||||
|
||||
Args:
|
||||
workspace_id: Unique identifier for the workspace
|
||||
name: Display name for the workspace
|
||||
description: Optional description
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
Returns:
|
||||
WorkspaceConfig instance
|
||||
|
||||
Raises:
|
||||
ValueError: If workspace already exists
|
||||
"""
|
||||
workspace_dir = self.workspaces_root / workspace_id
|
||||
|
||||
if workspace_dir.exists():
|
||||
raise ValueError(f"Workspace '{workspace_id}' already exists")
|
||||
|
||||
# Create directory structure
|
||||
workspace_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create subdirectories
|
||||
(workspace_dir / "agents").mkdir(exist_ok=True)
|
||||
(workspace_dir / "shared" / "market_data").mkdir(parents=True, exist_ok=True)
|
||||
(workspace_dir / "shared" / "memories").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create workspace.yaml
|
||||
from datetime import datetime
|
||||
|
||||
config = WorkspaceConfig(
|
||||
workspace_id=workspace_id,
|
||||
name=name or workspace_id,
|
||||
description=description or "",
|
||||
created_at=datetime.now().isoformat(),
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self._write_workspace_config(workspace_dir, config)
|
||||
|
||||
return config
|
||||
|
||||
def list_workspaces(self) -> List[WorkspaceConfig]:
|
||||
"""List all workspaces.
|
||||
|
||||
Returns:
|
||||
List of WorkspaceConfig instances
|
||||
"""
|
||||
workspaces = []
|
||||
|
||||
if not self.workspaces_root.exists():
|
||||
return workspaces
|
||||
|
||||
for workspace_dir in self.workspaces_root.iterdir():
|
||||
if not workspace_dir.is_dir():
|
||||
continue
|
||||
|
||||
config_path = workspace_dir / "workspace.yaml"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
workspaces.append(WorkspaceConfig.from_dict(data))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load workspace config {config_path}: {e}")
|
||||
|
||||
return workspaces
|
||||
|
||||
def get_workspace_agents(self, workspace_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get all agents in a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: ID of the workspace
|
||||
|
||||
Returns:
|
||||
List of agent information dictionaries
|
||||
|
||||
Raises:
|
||||
ValueError: If workspace doesn't exist
|
||||
"""
|
||||
workspace_dir = self.workspaces_root / workspace_id
|
||||
|
||||
if not workspace_dir.exists():
|
||||
raise ValueError(f"Workspace '{workspace_id}' does not exist")
|
||||
|
||||
agents = []
|
||||
agents_dir = workspace_dir / "agents"
|
||||
|
||||
if not agents_dir.exists():
|
||||
return agents
|
||||
|
||||
for agent_dir in agents_dir.iterdir():
|
||||
if not agent_dir.is_dir():
|
||||
continue
|
||||
|
||||
config_path = agent_dir / "agent.yaml"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = yaml.safe_load(f) or {}
|
||||
|
||||
agents.append({
|
||||
"agent_id": agent_dir.name,
|
||||
"agent_type": config.get("agent_type", "unknown"),
|
||||
"config_path": str(config_path),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load agent config {config_path}: {e}")
|
||||
|
||||
return agents
|
||||
|
||||
def get_agent_workspace(self, agent_id: str, workspace_id: str) -> Optional[Path]:
|
||||
"""Get the workspace path for an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
workspace_id: ID of the workspace
|
||||
|
||||
Returns:
|
||||
Path to agent directory, or None if not found
|
||||
"""
|
||||
agent_dir = self.workspaces_root / workspace_id / "agents" / agent_id
|
||||
|
||||
if agent_dir.exists():
|
||||
return agent_dir
|
||||
|
||||
return None
|
||||
|
||||
def workspace_exists(self, workspace_id: str) -> bool:
|
||||
"""Check if a workspace exists.
|
||||
|
||||
Args:
|
||||
workspace_id: ID of the workspace
|
||||
|
||||
Returns:
|
||||
True if workspace exists, False otherwise
|
||||
"""
|
||||
workspace_dir = self.workspaces_root / workspace_id
|
||||
return workspace_dir.exists() and (workspace_dir / "workspace.yaml").exists()
|
||||
|
||||
def delete_workspace(self, workspace_id: str, force: bool = False) -> bool:
|
||||
"""Delete a workspace and all its agents.
|
||||
|
||||
Args:
|
||||
workspace_id: ID of the workspace to delete
|
||||
force: If True, delete even if workspace has agents
|
||||
|
||||
Returns:
|
||||
True if deleted, False if workspace didn't exist
|
||||
|
||||
Raises:
|
||||
ValueError: If workspace has agents and force is False
|
||||
"""
|
||||
import shutil
|
||||
|
||||
workspace_dir = self.workspaces_root / workspace_id
|
||||
|
||||
if not workspace_dir.exists():
|
||||
return False
|
||||
|
||||
# Check for agents
|
||||
agents_dir = workspace_dir / "agents"
|
||||
if agents_dir.exists() and any(agents_dir.iterdir()):
|
||||
if not force:
|
||||
raise ValueError(
|
||||
f"Workspace '{workspace_id}' contains agents. "
|
||||
"Use force=True to delete anyway."
|
||||
)
|
||||
|
||||
shutil.rmtree(workspace_dir)
|
||||
return True
|
||||
|
||||
def get_workspace_path(self, workspace_id: str) -> Path:
|
||||
"""Get the path to a workspace directory.
|
||||
|
||||
Args:
|
||||
workspace_id: ID of the workspace
|
||||
|
||||
Returns:
|
||||
Path to workspace directory
|
||||
"""
|
||||
return self.workspaces_root / workspace_id
|
||||
|
||||
def get_shared_data_path(self, workspace_id: str) -> Optional[Path]:
|
||||
"""Get the shared data directory for a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: ID of the workspace
|
||||
|
||||
Returns:
|
||||
Path to shared data directory, or None if workspace doesn't exist
|
||||
"""
|
||||
workspace_dir = self.workspaces_root / workspace_id
|
||||
|
||||
if not workspace_dir.exists():
|
||||
return None
|
||||
|
||||
return workspace_dir / "shared"
|
||||
|
||||
def update_workspace_config(
|
||||
self,
|
||||
workspace_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> WorkspaceConfig:
|
||||
"""Update workspace configuration.
|
||||
|
||||
Args:
|
||||
workspace_id: ID of the workspace
|
||||
name: New display name (optional)
|
||||
description: New description (optional)
|
||||
metadata: Metadata to merge (optional)
|
||||
|
||||
Returns:
|
||||
Updated WorkspaceConfig
|
||||
|
||||
Raises:
|
||||
ValueError: If workspace doesn't exist
|
||||
"""
|
||||
workspace_dir = self.workspaces_root / workspace_id
|
||||
|
||||
if not workspace_dir.exists():
|
||||
raise ValueError(f"Workspace '{workspace_id}' does not exist")
|
||||
|
||||
config_path = workspace_dir / "workspace.yaml"
|
||||
current_config = {}
|
||||
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
current_config = yaml.safe_load(f) or {}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load existing config {config_path}: {e}")
|
||||
|
||||
# Update fields
|
||||
if name is not None:
|
||||
current_config["name"] = name
|
||||
if description is not None:
|
||||
current_config["description"] = description
|
||||
if metadata is not None:
|
||||
current_config["metadata"] = {**current_config.get("metadata", {}), **metadata}
|
||||
|
||||
config = WorkspaceConfig.from_dict(current_config)
|
||||
self._write_workspace_config(workspace_dir, config)
|
||||
|
||||
return config
|
||||
|
||||
def _write_workspace_config(self, workspace_dir: Path, config: WorkspaceConfig) -> None:
|
||||
"""Write workspace configuration to file.
|
||||
|
||||
Args:
|
||||
workspace_dir: Workspace directory
|
||||
config: Workspace configuration
|
||||
"""
|
||||
config_path = workspace_dir / "workspace.yaml"
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
yaml.safe_dump(config.to_dict(), f, allow_unicode=True, sort_keys=False)
|
||||
|
||||
|
||||
# Backward-compatible alias: legacy imports expect WorkspaceManager.
|
||||
WorkspaceManager = WorkspaceRegistry
|
||||
@@ -4,10 +4,13 @@
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from .skills_manager import SkillsManager
|
||||
from .team_pipeline_config import ensure_team_pipeline_config
|
||||
|
||||
|
||||
class WorkspaceManager:
|
||||
class RunWorkspaceManager:
|
||||
"""Create and maintain run-level prompt asset files for each agent."""
|
||||
|
||||
def __init__(self, project_root: Optional[Path] = None):
|
||||
@@ -21,6 +24,16 @@ class WorkspaceManager:
|
||||
run_dir = self.get_run_dir(config_name)
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.skills_manager.ensure_activation_manifest(config_name)
|
||||
ensure_team_pipeline_config(
|
||||
project_root=self.project_root,
|
||||
config_name=config_name,
|
||||
default_analysts=[
|
||||
"fundamentals_analyst",
|
||||
"technical_analyst",
|
||||
"sentiment_analyst",
|
||||
"valuation_analyst",
|
||||
],
|
||||
)
|
||||
bootstrap_path = run_dir / "BOOTSTRAP.md"
|
||||
if not bootstrap_path.exists():
|
||||
bootstrap_path.write_text(
|
||||
@@ -28,6 +41,16 @@ class WorkspaceManager:
|
||||
"tickers:\n"
|
||||
" - AAPL\n"
|
||||
" - MSFT\n"
|
||||
" - GOOGL\n"
|
||||
" - AMZN\n"
|
||||
" - NVDA\n"
|
||||
" - META\n"
|
||||
" - TSLA\n"
|
||||
" - AMD\n"
|
||||
" - NFLX\n"
|
||||
" - AVGO\n"
|
||||
" - PLTR\n"
|
||||
" - COIN\n"
|
||||
"initial_cash: 100000\n"
|
||||
"margin_requirement: 0.0\n"
|
||||
"enable_memory: false\n"
|
||||
@@ -50,39 +73,95 @@ class WorkspaceManager:
|
||||
self,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
role_seed: str = "",
|
||||
style_seed: str = "",
|
||||
policy_seed: str = "",
|
||||
file_contents: Optional[Dict[str, str]] = None,
|
||||
persona: Optional[Dict[str, object]] = None,
|
||||
) -> Path:
|
||||
asset_dir = self.skills_manager.get_agent_asset_dir(
|
||||
config_name,
|
||||
agent_id,
|
||||
)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "skills" / "installed").mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "skills" / "active").mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "skills" / "disabled").mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "skills" / "local").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._ensure_file(
|
||||
asset_dir / "ROLE.md",
|
||||
"# Role\n\n"
|
||||
"Optional run-scoped role override.\n\n"
|
||||
f"{role_seed}".strip()
|
||||
+ "\n",
|
||||
file_contents = file_contents or self.build_default_agent_files(agent_id=agent_id)
|
||||
for filename, content in file_contents.items():
|
||||
legacy_contents = self.build_legacy_agent_file_variants(
|
||||
agent_id=agent_id,
|
||||
filename=filename,
|
||||
persona=persona,
|
||||
)
|
||||
self._ensure_file(
|
||||
asset_dir / "STYLE.md",
|
||||
"# Style\n\n"
|
||||
"Optional run-scoped communication or reasoning style.\n\n"
|
||||
f"{style_seed}".strip()
|
||||
+ "\n",
|
||||
)
|
||||
self._ensure_file(
|
||||
asset_dir / "POLICY.md",
|
||||
"# Policy\n\n"
|
||||
"Optional run-scoped constraints, limits, or strategy policy.\n\n"
|
||||
f"{policy_seed}".strip()
|
||||
+ "\n",
|
||||
self._ensure_file(asset_dir / filename, content, legacy_contents=legacy_contents)
|
||||
self._ensure_agent_yaml(
|
||||
asset_dir / "agent.yaml",
|
||||
agent_id=agent_id,
|
||||
)
|
||||
return asset_dir
|
||||
|
||||
def build_default_agent_files(
|
||||
self,
|
||||
*,
|
||||
agent_id: str,
|
||||
persona: Optional[Dict[str, object]] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""Build default workspace markdown files for one agent."""
|
||||
if agent_id.endswith("_analyst"):
|
||||
return self._build_analyst_files(agent_id=agent_id, persona=persona or {})
|
||||
if agent_id == "portfolio_manager":
|
||||
return self._build_portfolio_manager_files()
|
||||
if agent_id == "risk_manager":
|
||||
return self._build_risk_manager_files()
|
||||
return self._build_generic_files(agent_id=agent_id)
|
||||
|
||||
def build_legacy_agent_file_variants(
|
||||
self,
|
||||
*,
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
persona: Optional[Dict[str, object]] = None,
|
||||
) -> list[str]:
|
||||
"""Return known generated legacy variants safe to upgrade in-place."""
|
||||
persona = persona or {}
|
||||
variants: list[dict[str, str]] = [
|
||||
self._build_legacy_english_files(agent_id=agent_id),
|
||||
self._build_previous_chinese_files(agent_id=agent_id, persona=persona),
|
||||
]
|
||||
values: list[str] = []
|
||||
for item in variants:
|
||||
content = item.get(filename)
|
||||
if content:
|
||||
values.append(content)
|
||||
return values
|
||||
|
||||
def load_agent_file(
|
||||
self,
|
||||
*,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
) -> str:
|
||||
"""Load one run-scoped agent workspace file."""
|
||||
path = self.skills_manager.get_agent_asset_dir(config_name, agent_id) / filename
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {filename}")
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
def update_agent_file(
|
||||
self,
|
||||
*,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Write one run-scoped agent workspace file."""
|
||||
asset_dir = self.skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = asset_dir / filename
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
def initialize_default_assets(
|
||||
self,
|
||||
config_name: str,
|
||||
@@ -95,46 +174,310 @@ class WorkspaceManager:
|
||||
for agent_id in agent_ids:
|
||||
if agent_id.endswith("_analyst"):
|
||||
persona = analyst_personas.get(agent_id, {})
|
||||
role_seed = persona.get("description", "").strip()
|
||||
focus_items = persona.get("focus", [])
|
||||
style_seed = "\n".join(f"- {item}" for item in focus_items)
|
||||
policy_seed = (
|
||||
"State a clear signal, confidence, and the conditions that would invalidate the thesis."
|
||||
)
|
||||
elif agent_id == "portfolio_manager":
|
||||
role_seed = (
|
||||
"Synthesize analyst and risk inputs into explicit portfolio decisions."
|
||||
)
|
||||
style_seed = (
|
||||
"Be concise, capital-aware, and explicit about sizing rationale."
|
||||
)
|
||||
policy_seed = (
|
||||
"Respect cash, margin, and portfolio concentration constraints before recording decisions."
|
||||
)
|
||||
elif agent_id == "risk_manager":
|
||||
role_seed = (
|
||||
"Quantify concentration, leverage, liquidity, and volatility risk before trade execution."
|
||||
)
|
||||
style_seed = (
|
||||
"Prioritize the highest-severity risk first and state concrete limits."
|
||||
)
|
||||
policy_seed = (
|
||||
"Use available risk tools before issuing the final risk memo."
|
||||
file_contents = self.build_default_agent_files(
|
||||
agent_id=agent_id,
|
||||
persona=persona,
|
||||
)
|
||||
else:
|
||||
role_seed = ""
|
||||
style_seed = ""
|
||||
policy_seed = ""
|
||||
|
||||
self.ensure_agent_assets(
|
||||
config_name=config_name,
|
||||
persona = None
|
||||
file_contents = self.build_default_agent_files(agent_id=agent_id)
|
||||
asset_dir = self.skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "skills" / "installed").mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "skills" / "active").mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "skills" / "disabled").mkdir(parents=True, exist_ok=True)
|
||||
(asset_dir / "skills" / "local").mkdir(parents=True, exist_ok=True)
|
||||
for filename, content in file_contents.items():
|
||||
self._ensure_file(
|
||||
asset_dir / filename,
|
||||
content,
|
||||
legacy_contents=self.build_legacy_agent_file_variants(
|
||||
agent_id=agent_id,
|
||||
role_seed=role_seed,
|
||||
style_seed=style_seed,
|
||||
policy_seed=policy_seed,
|
||||
filename=filename,
|
||||
persona=persona,
|
||||
),
|
||||
)
|
||||
self._ensure_agent_yaml(asset_dir / "agent.yaml", agent_id=agent_id)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_file(path: Path, content: str) -> None:
|
||||
def _ensure_file(path: Path, content: str, *, legacy_contents: Optional[list[str]] = None) -> None:
|
||||
if not path.exists():
|
||||
path.write_text(content, encoding="utf-8")
|
||||
return
|
||||
existing = path.read_text(encoding="utf-8")
|
||||
normalized_existing = existing.strip()
|
||||
candidates = {item.strip() for item in (legacy_contents or []) if item and item.strip()}
|
||||
if normalized_existing in candidates:
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
@staticmethod
|
||||
def _build_generic_files(agent_id: str) -> Dict[str, str]:
|
||||
return {
|
||||
"SOUL.md": (
|
||||
"# Soul\n\n"
|
||||
f"你是 `{agent_id}`,语气冷静、客观、专业。保持清晰推理,优先基于数据而不是情绪下结论。\n"
|
||||
),
|
||||
"PROFILE.md": (
|
||||
"# Profile\n\n"
|
||||
"记录这个 agent 长期稳定的分析风格、偏好、优势与盲点。\n"
|
||||
),
|
||||
"AGENTS.md": (
|
||||
"# Agent Guide\n\n"
|
||||
"工作要求:\n"
|
||||
"- 优先使用已激活的技能和工具\n"
|
||||
"- 结论要明确,过程要可追溯\n"
|
||||
"- 与其他 agent 协作时保持输入输出简洁\n"
|
||||
"- 最终输出必须使用简体中文;如需引用英文术语,仅保留专有名词,解释和结论必须用中文\n"
|
||||
),
|
||||
"POLICY.md": (
|
||||
"# Policy\n\n"
|
||||
"- 给出结论时说明核心驱动因素\n"
|
||||
"- 明确风险边界和结论失效条件\n"
|
||||
"- 出现反例时需要纳入最终判断\n"
|
||||
"- 不要输出英文报告标题、英文摘要或整段英文正文\n"
|
||||
),
|
||||
"MEMORY.md": (
|
||||
"# Memory\n\n"
|
||||
"记录可复用的经验、失误复盘、有效启发式和需要持续跟踪的提醒。\n"
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _build_analyst_files(cls, *, agent_id: str, persona: Dict[str, object]) -> Dict[str, str]:
|
||||
role_name = str(persona.get("name") or agent_id)
|
||||
focus_items = [
|
||||
str(item).strip()
|
||||
for item in persona.get("focus", [])
|
||||
if str(item).strip()
|
||||
]
|
||||
focus_md = "\n".join(f"- {item}" for item in focus_items) or "- 根据当前任务选择最相关的分析维度"
|
||||
description = str(persona.get("description") or "").strip()
|
||||
|
||||
files = cls._build_generic_files(agent_id)
|
||||
files["SOUL.md"] = (
|
||||
"# Soul\n\n"
|
||||
f"你是一位专业的{role_name}。\n\n"
|
||||
"保持谦逊和开放,主动寻找与自己观点相悖的证据,并将其纳入最终评估。"
|
||||
"你的分析要体现持续演化的投资哲学,而不是一次性的结论。\n"
|
||||
)
|
||||
files["PROFILE.md"] = (
|
||||
"# Profile\n\n"
|
||||
f"角色定位:{role_name}\n\n"
|
||||
"你的关注重点:\n"
|
||||
f"{focus_md}\n\n"
|
||||
"角色说明:\n"
|
||||
f"{description or '围绕最关键的基本面、技术面、情绪面或估值因素形成高质量判断。'}\n"
|
||||
)
|
||||
files["AGENTS.md"] = (
|
||||
"# Agent Guide\n\n"
|
||||
"分析流程:\n"
|
||||
"- 优先识别真正驱动价值或价格变化的核心变量\n"
|
||||
"- 使用相关工具和技能补足证据链\n"
|
||||
"- 给出可验证、可复查、可执行的分析结果\n"
|
||||
"- 在团队讨论中清晰表达你的论点和反论点\n\n"
|
||||
"输出要求:\n"
|
||||
"- 给出明确投资信号:看涨、看跌或中性\n"
|
||||
"- 包含置信度(0-100)\n"
|
||||
"- 如果你确定要分享最终分析,请先给出结论,再给出推理依据\n"
|
||||
"- 最终输出必须使用简体中文,不要生成英文版 analysis report\n"
|
||||
)
|
||||
files["POLICY.md"] = (
|
||||
"# Policy\n\n"
|
||||
"- 深化你的投资逻辑,确保每项建议都有清晰、可追溯、可重复的依据\n"
|
||||
"- 明确风险边界:在什么具体情况下当前结论会失效\n"
|
||||
"- 做逆向测试:说明市场主流共识与你的不同点\n"
|
||||
"- 每次分析后反思这次案例如何验证或挑战你现有的信念\n"
|
||||
"- 即使输入新闻或财报原文是英文,最终表达也必须用中文\n"
|
||||
)
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _build_portfolio_manager_files(cls) -> Dict[str, str]:
|
||||
files = cls._build_generic_files("portfolio_manager")
|
||||
files["SOUL.md"] = (
|
||||
"# Soul\n\n"
|
||||
"你是一位负责做出投资决策的投资组合经理。你需要综合多个分析视角,"
|
||||
"做出保守、明确、资本约束下可执行的组合决策。\n"
|
||||
)
|
||||
files["PROFILE.md"] = (
|
||||
"# Profile\n\n"
|
||||
"核心职责:\n"
|
||||
"- 分析分析师和风险管理经理的输入\n"
|
||||
"- 基于信号和市场情境做出投资决策\n"
|
||||
"- 使用可用工具记录每个 ticker 的决策\n"
|
||||
)
|
||||
files["AGENTS.md"] = (
|
||||
"# Agent Guide\n\n"
|
||||
"决策框架:\n"
|
||||
"- 审阅分析以理解市场观点\n"
|
||||
"- 在做决策前先考虑风险警告\n"
|
||||
"- 评估当前投资组合持仓、现金与保证金占用\n"
|
||||
"- 决策必须与整体投资目标和风险约束一致\n\n"
|
||||
"决策类型:\n"
|
||||
'- `long`:看涨,建议买入\n'
|
||||
'- `short`:看跌,建议卖出或做空\n'
|
||||
'- `hold`:中性,维持当前持仓\n\n'
|
||||
"输出要求:\n"
|
||||
"- 使用 `make_decision` 工具记录每个股票的最终决策\n"
|
||||
"- 记录完成后给出投资逻辑总结\n"
|
||||
"- 最终总结必须使用简体中文\n"
|
||||
)
|
||||
files["POLICY.md"] = (
|
||||
"# Policy\n\n"
|
||||
"- 在决定数量时考虑可用现金,不要超出现金允许范围\n"
|
||||
"- 考虑做空头寸的保证金要求\n"
|
||||
"- 仓位规模相对于组合总资产保持保守\n"
|
||||
"- 始终为决策提供清晰理由\n"
|
||||
"- 不要输出英文投资报告或英文结论\n"
|
||||
)
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _build_risk_manager_files(cls) -> Dict[str, str]:
|
||||
files = cls._build_generic_files("risk_manager")
|
||||
files["SOUL.md"] = (
|
||||
"# Soul\n\n"
|
||||
"你是一位专业的风险管理经理,负责监控投资组合风险并提供风险警告。"
|
||||
"你的目标不是输出空泛的谨慎,而是给出量化、可执行、可优先级排序的风险意见。\n"
|
||||
)
|
||||
files["PROFILE.md"] = (
|
||||
"# Profile\n\n"
|
||||
"核心职责:\n"
|
||||
"- 监控投资组合敞口和集中度风险\n"
|
||||
"- 评估仓位规模相对于波动性是否合理\n"
|
||||
"- 评估保证金使用和杠杆水平\n"
|
||||
"- 识别潜在风险因素并提供警告\n"
|
||||
"- 基于市场条件建议仓位限制\n"
|
||||
)
|
||||
files["AGENTS.md"] = (
|
||||
"# Agent Guide\n\n"
|
||||
"决策流程:\n"
|
||||
"- 优先使用可用的风险工具量化集中度、波动率和保证金压力\n"
|
||||
"- 结合工具结果与当前市场上下文做判断\n"
|
||||
"- 生成可操作的风险警告和仓位限制建议\n"
|
||||
"- 为风险评估提供清晰理由\n\n"
|
||||
"输出要求:\n"
|
||||
"- 风险评估要简洁但全面\n"
|
||||
"- 按严重程度优先排序警告\n"
|
||||
"- 提供具体、可操作的建议\n"
|
||||
"- 尽可能包含量化指标\n"
|
||||
"- 最终风险结论必须使用简体中文\n"
|
||||
)
|
||||
files["POLICY.md"] = (
|
||||
"# Policy\n\n"
|
||||
"- 先量化,再判断,不要只给抽象风险表述\n"
|
||||
"- 高严重度风险必须先说\n"
|
||||
"- 最终结论需要明确仓位限制或调整建议\n"
|
||||
"- 不要输出英文风险报告或英文摘要\n"
|
||||
)
|
||||
return files
|
||||
|
||||
@staticmethod
|
||||
def _build_legacy_english_files(agent_id: str) -> Dict[str, str]:
|
||||
policy_tail = "Optional run-scoped constraints, limits, or strategy policy.\n\n"
|
||||
if agent_id == "portfolio_manager":
|
||||
policy_tail += "Respect cash, margin, and portfolio concentration constraints before recording decisions.\n"
|
||||
elif agent_id == "risk_manager":
|
||||
policy_tail += "Use available risk tools before issuing the final risk memo.\n"
|
||||
elif agent_id.endswith("_analyst"):
|
||||
policy_tail += "State a clear signal, confidence, and the conditions that would invalidate the thesis.\n"
|
||||
return {
|
||||
"SOUL.md": "# Soul\n\nDescribe the agent's temperament, reasoning posture, and voice.\n\n",
|
||||
"PROFILE.md": "# Profile\n\nTrack this agent's long-lived investment style, preferences, and strengths.\n\n",
|
||||
"AGENTS.md": "# Agent Guide\n\nDocument how this agent should work, collaborate, and choose tools or skills.\n\n",
|
||||
"POLICY.md": "# Policy\n\n" + policy_tail,
|
||||
"MEMORY.md": "# Memory\n\nStore durable lessons, heuristics, and reminders for this agent.\n\n",
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _build_previous_chinese_files(cls, *, agent_id: str, persona: Dict[str, object]) -> Dict[str, str]:
|
||||
if agent_id.endswith("_analyst"):
|
||||
role_name = str(persona.get("name") or agent_id)
|
||||
focus_items = [
|
||||
str(item).strip()
|
||||
for item in persona.get("focus", [])
|
||||
if str(item).strip()
|
||||
]
|
||||
focus_md = "\n".join(f"- {item}" for item in focus_items) or "- 根据当前任务选择最相关的分析维度"
|
||||
description = str(persona.get("description") or "").strip()
|
||||
return {
|
||||
"SOUL.md": (
|
||||
"# Soul\n\n"
|
||||
f"你是一位专业的{role_name}。\n\n"
|
||||
"保持谦逊和开放,主动寻找与自己观点相悖的证据,并将其纳入最终评估。"
|
||||
"你的分析要体现持续演化的投资哲学,而不是一次性的结论。\n"
|
||||
),
|
||||
"PROFILE.md": (
|
||||
"# Profile\n\n"
|
||||
f"角色定位:{role_name}\n\n"
|
||||
"你的关注重点:\n"
|
||||
f"{focus_md}\n\n"
|
||||
"角色说明:\n"
|
||||
f"{description or '围绕最关键的基本面、技术面、情绪面或估值因素形成高质量判断。'}\n"
|
||||
),
|
||||
"AGENTS.md": (
|
||||
"# Agent Guide\n\n"
|
||||
"分析流程:\n"
|
||||
"- 优先识别真正驱动价值或价格变化的核心变量\n"
|
||||
"- 使用相关工具和技能补足证据链\n"
|
||||
"- 给出可验证、可复查、可执行的分析结果\n"
|
||||
"- 在团队讨论中清晰表达你的论点和反论点\n\n"
|
||||
"输出要求:\n"
|
||||
"- 给出明确投资信号:看涨、看跌或中性\n"
|
||||
"- 包含置信度(0-100)\n"
|
||||
"- 如果你确定要分享最终分析,请先给出结论,再给出推理依据\n"
|
||||
),
|
||||
"POLICY.md": (
|
||||
"# Policy\n\n"
|
||||
"- 深化你的投资逻辑,确保每项建议都有清晰、可追溯、可重复的依据\n"
|
||||
"- 明确风险边界:在什么具体情况下当前结论会失效\n"
|
||||
"- 做逆向测试:说明市场主流共识与你的不同点\n"
|
||||
"- 每次分析后反思这次案例如何验证或挑战你现有的信念\n"
|
||||
),
|
||||
"MEMORY.md": "# Memory\n\n记录可复用的经验、失误复盘、有效启发式和需要持续跟踪的提醒。\n",
|
||||
}
|
||||
if agent_id == "portfolio_manager":
|
||||
return {
|
||||
"SOUL.md": "# Soul\n\n你是一位负责做出投资决策的投资组合经理。你需要综合多个分析视角,做出保守、明确、资本约束下可执行的组合决策。\n",
|
||||
"PROFILE.md": "# Profile\n\n核心职责:\n- 分析分析师和风险管理经理的输入\n- 基于信号和市场情境做出投资决策\n- 使用可用工具记录每个 ticker 的决策\n",
|
||||
"AGENTS.md": "# Agent Guide\n\n决策框架:\n- 审阅分析以理解市场观点\n- 在做决策前先考虑风险警告\n- 评估当前投资组合持仓、现金与保证金占用\n- 决策必须与整体投资目标和风险约束一致\n\n决策类型:\n- `long`:看涨,建议买入\n- `short`:看跌,建议卖出或做空\n- `hold`:中性,维持当前持仓\n\n输出要求:\n- 使用 `make_decision` 工具记录每个股票的最终决策\n- 记录完成后给出投资逻辑总结\n",
|
||||
"POLICY.md": "# Policy\n\n- 在决定数量时考虑可用现金,不要超出现金允许范围\n- 考虑做空头寸的保证金要求\n- 仓位规模相对于组合总资产保持保守\n- 始终为决策提供清晰理由\n",
|
||||
"MEMORY.md": "# Memory\n\n记录可复用的经验、失误复盘、有效启发式和需要持续跟踪的提醒。\n",
|
||||
}
|
||||
if agent_id == "risk_manager":
|
||||
return {
|
||||
"SOUL.md": "# Soul\n\n你是一位专业的风险管理经理,负责监控投资组合风险并提供风险警告。你的目标不是输出空泛的谨慎,而是给出量化、可执行、可优先级排序的风险意见。\n",
|
||||
"PROFILE.md": "# Profile\n\n核心职责:\n- 监控投资组合敞口和集中度风险\n- 评估仓位规模相对于波动性是否合理\n- 评估保证金使用和杠杆水平\n- 识别潜在风险因素并提供警告\n- 基于市场条件建议仓位限制\n",
|
||||
"AGENTS.md": "# Agent Guide\n\n决策流程:\n- 优先使用可用的风险工具量化集中度、波动率和保证金压力\n- 结合工具结果与当前市场上下文做判断\n- 生成可操作的风险警告和仓位限制建议\n- 为风险评估提供清晰理由\n\n输出要求:\n- 风险评估要简洁但全面\n- 按严重程度优先排序警告\n- 提供具体、可操作的建议\n- 尽可能包含量化指标\n",
|
||||
"POLICY.md": "# Policy\n\n- 先量化,再判断,不要只给抽象风险表述\n- 高严重度风险必须先说\n- 最终结论需要明确仓位限制或调整建议\n",
|
||||
"MEMORY.md": "# Memory\n\n记录可复用的经验、失误复盘、有效启发式和需要持续跟踪的提醒。\n",
|
||||
}
|
||||
return cls._build_legacy_english_files(agent_id)
|
||||
|
||||
@staticmethod
|
||||
def _ensure_agent_yaml(path: Path, agent_id: str) -> None:
|
||||
if path.exists():
|
||||
return
|
||||
|
||||
payload = {
|
||||
"agent_id": agent_id,
|
||||
"prompt_files": [
|
||||
"SOUL.md",
|
||||
"PROFILE.md",
|
||||
"AGENTS.md",
|
||||
"POLICY.md",
|
||||
"MEMORY.md",
|
||||
],
|
||||
"enabled_skills": [],
|
||||
"disabled_skills": [],
|
||||
"active_tool_groups": [],
|
||||
"disabled_tool_groups": [],
|
||||
}
|
||||
path.write_text(
|
||||
yaml.safe_dump(payload, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
|
||||
# Backward-compatible alias: code importing WorkspaceManager from this module should continue to work.
|
||||
WorkspaceManager = RunWorkspaceManager
|
||||
|
||||
23
backend/api/__init__.py
Normal file
23
backend/api/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
API Routes Package
|
||||
|
||||
Provides REST API endpoints for:
|
||||
- Agent management
|
||||
- Workspace management
|
||||
- Tool guard operations
|
||||
"""
|
||||
|
||||
from .agents import router as agents_router
|
||||
from .workspaces import router as workspaces_router
|
||||
from .guard import router as guard_router
|
||||
from .openclaw import router as openclaw_router
|
||||
from .runtime import router as runtime_router
|
||||
|
||||
__all__ = [
|
||||
"agents_router",
|
||||
"workspaces_router",
|
||||
"guard_router",
|
||||
"openclaw_router",
|
||||
"runtime_router",
|
||||
]
|
||||
709
backend/api/agents.py
Normal file
709
backend/api/agents.py
Normal file
@@ -0,0 +1,709 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Agent API Routes
|
||||
|
||||
Provides REST API endpoints for agent management within workspaces.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.agents import AgentFactory, get_registry
|
||||
from backend.agents.workspace_manager import RunWorkspaceManager
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import load_agent_profiles
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
from backend.llm.models import get_agent_model_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/workspaces/{workspace_id}/agents", tags=["agents"])
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class CreateAgentRequest(BaseModel):
|
||||
"""Request to create a new agent."""
|
||||
agent_id: str = Field(..., description="Unique agent identifier")
|
||||
agent_type: str = Field(..., description="Type of agent (e.g., technical_analyst)")
|
||||
name: Optional[str] = Field(None, description="Display name")
|
||||
description: Optional[str] = Field(None, description="Agent description")
|
||||
clone_from: Optional[str] = Field(None, description="Agent ID to clone from")
|
||||
llm_model_config: Optional[Dict[str, Any]] = Field(None, description="LLM model configuration")
|
||||
|
||||
|
||||
class UpdateAgentRequest(BaseModel):
|
||||
"""Request to update an agent."""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
enabled_skills: Optional[List[str]] = None
|
||||
disabled_skills: Optional[List[str]] = None
|
||||
|
||||
|
||||
class InstallExternalSkillRequest(BaseModel):
|
||||
"""Request to install an external skill for one agent."""
|
||||
source: str = Field(..., description="Directory path, zip path, or http(s) zip URL")
|
||||
name: Optional[str] = Field(None, description="Optional override skill name")
|
||||
activate: bool = Field(True, description="Whether to enable skill immediately")
|
||||
|
||||
|
||||
class LocalSkillRequest(BaseModel):
|
||||
skill_name: str = Field(..., description="Local skill name")
|
||||
|
||||
|
||||
class LocalSkillContentRequest(BaseModel):
|
||||
content: str = Field(..., description="Updated SKILL.md content")
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Agent information response."""
|
||||
agent_id: str
|
||||
agent_type: str
|
||||
workspace_id: str
|
||||
config_path: str
|
||||
agent_dir: str
|
||||
status: str = "inactive"
|
||||
|
||||
|
||||
class AgentFileResponse(BaseModel):
|
||||
"""Agent file content response."""
|
||||
filename: str
|
||||
content: str
|
||||
|
||||
|
||||
class AgentProfileResponse(BaseModel):
|
||||
agent_id: str
|
||||
workspace_id: str
|
||||
profile: Dict[str, Any]
|
||||
|
||||
|
||||
class AgentSkillsResponse(BaseModel):
|
||||
agent_id: str
|
||||
workspace_id: str
|
||||
skills: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class SkillDetailResponse(BaseModel):
|
||||
agent_id: str
|
||||
workspace_id: str
|
||||
skill: Dict[str, Any]
|
||||
|
||||
|
||||
# Dependencies
|
||||
def get_agent_factory():
|
||||
"""Get AgentFactory instance."""
|
||||
return AgentFactory()
|
||||
|
||||
|
||||
def get_workspace_manager():
|
||||
"""Get run-scoped workspace manager instance."""
|
||||
return RunWorkspaceManager()
|
||||
|
||||
|
||||
def get_skills_manager():
|
||||
"""Get SkillsManager instance."""
|
||||
return SkillsManager()
|
||||
|
||||
|
||||
# Routes
|
||||
@router.post("", response_model=AgentResponse)
|
||||
async def create_agent(
|
||||
workspace_id: str,
|
||||
request: CreateAgentRequest,
|
||||
factory: AgentFactory = Depends(get_agent_factory),
|
||||
registry = Depends(get_registry),
|
||||
):
|
||||
"""
|
||||
Create a new agent in a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
request: Agent creation parameters
|
||||
|
||||
Returns:
|
||||
Created agent information
|
||||
"""
|
||||
# Check workspace exists
|
||||
if not factory.workspaces_root.exists():
|
||||
raise HTTPException(status_code=404, detail="Workspaces root not found")
|
||||
|
||||
workspace_dir = factory.workspaces_root / workspace_id
|
||||
if not workspace_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Workspace '{workspace_id}' not found")
|
||||
|
||||
try:
|
||||
# Create agent
|
||||
agent = factory.create_agent(
|
||||
agent_id=request.agent_id,
|
||||
agent_type=request.agent_type,
|
||||
workspace_id=workspace_id,
|
||||
clone_from=request.clone_from,
|
||||
)
|
||||
|
||||
# Register in registry
|
||||
registry.register(
|
||||
agent_id=request.agent_id,
|
||||
agent_type=request.agent_type,
|
||||
workspace_id=workspace_id,
|
||||
config_path=str(agent.config_path),
|
||||
agent_dir=str(agent.agent_dir),
|
||||
status="inactive",
|
||||
)
|
||||
|
||||
return AgentResponse(
|
||||
agent_id=agent.agent_id,
|
||||
agent_type=agent.agent_type,
|
||||
workspace_id=agent.workspace_id,
|
||||
config_path=str(agent.config_path),
|
||||
agent_dir=str(agent.agent_dir),
|
||||
status="inactive",
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("", response_model=List[AgentResponse])
|
||||
async def list_agents(
|
||||
workspace_id: str,
|
||||
factory: AgentFactory = Depends(get_agent_factory),
|
||||
):
|
||||
"""
|
||||
List all agents in a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
|
||||
Returns:
|
||||
List of agents
|
||||
"""
|
||||
try:
|
||||
agents_data = factory.list_agents(workspace_id=workspace_id)
|
||||
return [
|
||||
AgentResponse(
|
||||
agent_id=agent["agent_id"],
|
||||
agent_type=agent["agent_type"],
|
||||
workspace_id=workspace_id,
|
||||
config_path=agent["config_path"],
|
||||
agent_dir=str(Path(agent["config_path"]).parent),
|
||||
status="inactive",
|
||||
)
|
||||
for agent in agents_data
|
||||
]
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{agent_id}", response_model=AgentResponse)
|
||||
async def get_agent(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
registry = Depends(get_registry),
|
||||
):
|
||||
"""
|
||||
Get agent details.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Agent information
|
||||
"""
|
||||
agent_info = registry.get(agent_id)
|
||||
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
return AgentResponse(
|
||||
agent_id=agent_info.agent_id,
|
||||
agent_type=agent_info.agent_type,
|
||||
workspace_id=agent_info.workspace_id,
|
||||
config_path=agent_info.config_path,
|
||||
agent_dir=agent_info.agent_dir,
|
||||
status=agent_info.status,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/profile", response_model=AgentProfileResponse)
|
||||
async def get_agent_profile(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skills_manager: SkillsManager = Depends(get_skills_manager),
|
||||
):
|
||||
asset_dir = skills_manager.get_agent_asset_dir(workspace_id, agent_id)
|
||||
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
|
||||
profiles = load_agent_profiles()
|
||||
profile = profiles.get(agent_id, {})
|
||||
bootstrap = get_bootstrap_config_for_run(skills_manager.project_root, workspace_id)
|
||||
override = bootstrap.agent_override(agent_id)
|
||||
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
|
||||
if not isinstance(active_tool_groups, list):
|
||||
active_tool_groups = []
|
||||
disabled_tool_groups = agent_config.disabled_tool_groups
|
||||
if disabled_tool_groups:
|
||||
disabled_set = set(disabled_tool_groups)
|
||||
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
|
||||
|
||||
default_skills = profile.get("skills", [])
|
||||
if not isinstance(default_skills, list):
|
||||
default_skills = []
|
||||
resolved_skills = skills_manager.resolve_agent_skill_names(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
default_skills=default_skills,
|
||||
)
|
||||
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
|
||||
model_name, model_provider = get_agent_model_info(agent_id)
|
||||
|
||||
return AgentProfileResponse(
|
||||
agent_id=agent_id,
|
||||
workspace_id=workspace_id,
|
||||
profile={
|
||||
"model_name": model_name,
|
||||
"model_provider": model_provider,
|
||||
"prompt_files": prompt_files,
|
||||
"default_skills": default_skills,
|
||||
"resolved_skills": resolved_skills,
|
||||
"active_tool_groups": active_tool_groups,
|
||||
"disabled_tool_groups": disabled_tool_groups,
|
||||
"enabled_skills": agent_config.enabled_skills,
|
||||
"disabled_skills": agent_config.disabled_skills,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/skills", response_model=AgentSkillsResponse)
|
||||
async def get_agent_skills(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skills_manager: SkillsManager = Depends(get_skills_manager),
|
||||
):
|
||||
agent_asset_dir = skills_manager.get_agent_asset_dir(workspace_id, agent_id)
|
||||
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
|
||||
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=workspace_id, agent_id=agent_id, default_skills=[]))
|
||||
enabled = set(agent_config.enabled_skills)
|
||||
disabled = set(agent_config.disabled_skills)
|
||||
|
||||
payload = []
|
||||
for item in skills_manager.list_agent_skill_catalog(workspace_id, agent_id):
|
||||
if item.skill_name in disabled:
|
||||
status = "disabled"
|
||||
elif item.skill_name in enabled:
|
||||
status = "enabled"
|
||||
elif item.skill_name in resolved_skills:
|
||||
status = "active"
|
||||
else:
|
||||
status = "available"
|
||||
payload.append({
|
||||
"skill_name": item.skill_name,
|
||||
"name": item.name,
|
||||
"description": item.description,
|
||||
"version": item.version,
|
||||
"source": item.source,
|
||||
"tools": item.tools,
|
||||
"status": status,
|
||||
})
|
||||
|
||||
return AgentSkillsResponse(agent_id=agent_id, workspace_id=workspace_id, skills=payload)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/skills/{skill_name}", response_model=SkillDetailResponse)
|
||||
async def get_agent_skill_detail(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
skills_manager: SkillsManager = Depends(get_skills_manager),
|
||||
):
|
||||
try:
|
||||
detail = skills_manager.load_agent_skill_document(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_name=skill_name,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown skill: {skill_name}")
|
||||
|
||||
return SkillDetailResponse(agent_id=agent_id, workspace_id=workspace_id, skill=detail)
|
||||
|
||||
|
||||
@router.delete("/{agent_id}")
|
||||
async def delete_agent(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
factory: AgentFactory = Depends(get_agent_factory),
|
||||
registry = Depends(get_registry),
|
||||
):
|
||||
"""
|
||||
Delete an agent.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
# Check agent exists in registry
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
# Delete from factory
|
||||
success = factory.delete_agent(agent_id, workspace_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
# Unregister
|
||||
registry.unregister(agent_id)
|
||||
|
||||
return {"message": f"Agent '{agent_id}' deleted successfully"}
|
||||
|
||||
|
||||
@router.patch("/{agent_id}", response_model=AgentResponse)
|
||||
async def update_agent(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
request: UpdateAgentRequest,
|
||||
registry = Depends(get_registry),
|
||||
):
|
||||
"""
|
||||
Update agent configuration.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
agent_id: Agent identifier
|
||||
request: Update parameters
|
||||
|
||||
Returns:
|
||||
Updated agent information
|
||||
"""
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
# Update metadata in registry
|
||||
metadata_updates = {}
|
||||
if request.name:
|
||||
metadata_updates["name"] = request.name
|
||||
if request.description:
|
||||
metadata_updates["description"] = request.description
|
||||
|
||||
if metadata_updates:
|
||||
registry.update_metadata(agent_id, metadata_updates)
|
||||
|
||||
# Update skills if provided
|
||||
if request.enabled_skills or request.disabled_skills:
|
||||
skills_manager = SkillsManager()
|
||||
skills_manager.update_agent_skill_overrides(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
enable=request.enabled_skills or [],
|
||||
disable=request.disabled_skills or [],
|
||||
)
|
||||
|
||||
# Get updated info
|
||||
agent_info = registry.get(agent_id)
|
||||
return AgentResponse(
|
||||
agent_id=agent_info.agent_id,
|
||||
agent_type=agent_info.agent_type,
|
||||
workspace_id=agent_info.workspace_id,
|
||||
config_path=agent_info.config_path,
|
||||
agent_dir=agent_info.agent_dir,
|
||||
status=agent_info.status,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/{skill_name}/enable")
|
||||
async def enable_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
registry = Depends(get_registry),
|
||||
):
|
||||
"""
|
||||
Enable a skill for an agent.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
agent_id: Agent identifier
|
||||
skill_name: Skill name to enable
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
result = skills_manager.update_agent_skill_overrides(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
enable=[skill_name],
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Skill '{skill_name}' enabled for agent '{agent_id}'",
|
||||
"enabled_skills": result["enabled_skills"],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/{skill_name}/disable")
|
||||
async def disable_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
registry = Depends(get_registry),
|
||||
):
|
||||
"""
|
||||
Disable a skill for an agent.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
agent_id: Agent identifier
|
||||
skill_name: Skill name to disable
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
result = skills_manager.update_agent_skill_overrides(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
disable=[skill_name],
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Skill '{skill_name}' disabled for agent '{agent_id}'",
|
||||
"disabled_skills": result["disabled_skills"],
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/install")
|
||||
async def install_external_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
request: InstallExternalSkillRequest,
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
"""Install an external skill into one agent's local skills."""
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
try:
|
||||
result = skills_manager.install_external_skill_for_agent(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
source=request.source,
|
||||
skill_name=request.name,
|
||||
activate=request.activate,
|
||||
)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
return {
|
||||
"message": f"Installed external skill '{result['skill_name']}' for '{agent_id}'",
|
||||
**result,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/local")
|
||||
async def create_local_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
request: LocalSkillRequest,
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
try:
|
||||
skills_manager.create_agent_local_skill(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_name=request.skill_name,
|
||||
)
|
||||
except (ValueError, FileExistsError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
return {"message": f"Created local skill '{request.skill_name}' for '{agent_id}'"}
|
||||
|
||||
|
||||
@router.put("/{agent_id}/skills/local/{skill_name}")
|
||||
async def update_local_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
request: LocalSkillContentRequest,
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
try:
|
||||
skills_manager.update_agent_local_skill(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_name=skill_name,
|
||||
content=request.content,
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
return {"message": f"Updated local skill '{skill_name}' for '{agent_id}'"}
|
||||
|
||||
|
||||
@router.delete("/{agent_id}/skills/local/{skill_name}")
|
||||
async def delete_local_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
try:
|
||||
skills_manager.delete_agent_local_skill(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_name=skill_name,
|
||||
)
|
||||
skills_manager.forget_agent_skill_overrides(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_names=[skill_name],
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
return {"message": f"Deleted local skill '{skill_name}' for '{agent_id}'"}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/upload")
|
||||
async def upload_external_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
file: UploadFile = File(...),
|
||||
name: Optional[str] = Form(None),
|
||||
activate: bool = Form(True),
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
"""Upload a zip skill package from frontend and install for one agent."""
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
original_name = (file.filename or "").strip()
|
||||
if not original_name.lower().endswith(".zip"):
|
||||
raise HTTPException(status_code=400, detail="Uploaded file must be a .zip archive")
|
||||
|
||||
suffix = Path(original_name).suffix or ".zip"
|
||||
temp_path: Optional[str] = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||
temp_path = tmp.name
|
||||
content = await file.read()
|
||||
tmp.write(content)
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
result = skills_manager.install_external_skill_for_agent(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
source=temp_path,
|
||||
skill_name=name,
|
||||
activate=activate,
|
||||
)
|
||||
except (FileNotFoundError, ValueError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
finally:
|
||||
try:
|
||||
await file.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to close uploaded file: {e}")
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
return {
|
||||
"message": f"Uploaded and installed external skill '{result['skill_name']}' for '{agent_id}'",
|
||||
**result,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/{agent_id}/files/{filename}", response_model=AgentFileResponse)
|
||||
async def get_agent_file(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
workspace_manager: RunWorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
Read an agent's workspace file.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
agent_id: Agent identifier
|
||||
filename: File to read (e.g., SOUL.md, PROFILE.md)
|
||||
|
||||
Returns:
|
||||
File content
|
||||
"""
|
||||
try:
|
||||
content = workspace_manager.load_agent_file(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
filename=filename,
|
||||
)
|
||||
return AgentFileResponse(filename=filename, content=content)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"File '{filename}' not found")
|
||||
|
||||
|
||||
@router.put("/{agent_id}/files/{filename}", response_model=AgentFileResponse)
|
||||
async def update_agent_file(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
content: str = Body(..., media_type="text/plain"),
|
||||
workspace_manager: RunWorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
Update an agent's workspace file.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
agent_id: Agent identifier
|
||||
filename: File to update
|
||||
content: New file content
|
||||
|
||||
Returns:
|
||||
Updated file information
|
||||
"""
|
||||
try:
|
||||
workspace_manager.update_agent_file(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
filename=filename,
|
||||
content=content,
|
||||
)
|
||||
return AgentFileResponse(filename=filename, content=content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
257
backend/api/guard.py
Normal file
257
backend/api/guard.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Tool Guard API Routes
|
||||
|
||||
Provides REST API endpoints for tool guard operations.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.agents.base.tool_guard import (
|
||||
ApprovalRecord,
|
||||
ApprovalStatus,
|
||||
SeverityLevel,
|
||||
TOOL_GUARD_STORE,
|
||||
default_findings_for_tool,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/guard", tags=["guard"])
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class ToolCallRequest(BaseModel):
|
||||
"""Tool call request."""
|
||||
tool_name: str = Field(..., description="Name of the tool")
|
||||
tool_input: Dict[str, Any] = Field(default_factory=dict, description="Tool parameters")
|
||||
agent_id: str = Field(..., description="Agent making the request")
|
||||
workspace_id: str = Field(..., description="Workspace context")
|
||||
session_id: Optional[str] = Field(None, description="Session identifier")
|
||||
|
||||
|
||||
class ApprovalRequest(BaseModel):
|
||||
"""Request to approve a tool call."""
|
||||
approval_id: str = Field(..., description="Approval request ID")
|
||||
one_time: bool = Field(True, description="Whether this is a one-time approval")
|
||||
expires_in_minutes: Optional[int] = Field(30, description="Approval expiration time")
|
||||
|
||||
|
||||
class DenyRequest(BaseModel):
|
||||
"""Request to deny a tool call."""
|
||||
approval_id: str = Field(..., description="Approval request ID")
|
||||
reason: Optional[str] = Field(None, description="Reason for denial")
|
||||
|
||||
|
||||
class ToolFinding(BaseModel):
|
||||
"""Tool guard finding."""
|
||||
severity: SeverityLevel
|
||||
message: str
|
||||
field: Optional[str] = None
|
||||
|
||||
|
||||
class ApprovalResponse(BaseModel):
|
||||
"""Tool approval response."""
|
||||
approval_id: str
|
||||
status: ApprovalStatus
|
||||
tool_name: str
|
||||
tool_input: Dict[str, Any]
|
||||
agent_id: str
|
||||
workspace_id: str
|
||||
session_id: Optional[str] = None
|
||||
findings: List[ToolFinding] = Field(default_factory=list)
|
||||
created_at: str
|
||||
resolved_at: Optional[str] = None
|
||||
resolved_by: Optional[str] = None
|
||||
|
||||
|
||||
class PendingApprovalsResponse(BaseModel):
|
||||
"""List of pending approvals."""
|
||||
approvals: List[ApprovalResponse]
|
||||
total: int
|
||||
|
||||
|
||||
STORE = TOOL_GUARD_STORE
|
||||
SAFE_TOOLS = {
|
||||
"get_price",
|
||||
"get_fundamentals",
|
||||
"get_news",
|
||||
"analyze_technical",
|
||||
}
|
||||
|
||||
|
||||
def _to_response(record: ApprovalRecord) -> ApprovalResponse:
|
||||
return ApprovalResponse(
|
||||
approval_id=record.approval_id,
|
||||
status=record.status,
|
||||
tool_name=record.tool_name,
|
||||
tool_input=record.tool_input,
|
||||
agent_id=record.agent_id,
|
||||
workspace_id=record.workspace_id,
|
||||
session_id=record.session_id,
|
||||
findings=[ToolFinding(**f.to_dict()) for f in record.findings],
|
||||
created_at=record.created_at.isoformat(),
|
||||
resolved_at=record.resolved_at.isoformat() if record.resolved_at else None,
|
||||
resolved_by=record.resolved_by,
|
||||
)
|
||||
|
||||
|
||||
# Routes
|
||||
@router.post("/check", response_model=ApprovalResponse)
|
||||
async def check_tool_call(
|
||||
request: ToolCallRequest,
|
||||
):
|
||||
"""
|
||||
Check if a tool call requires approval.
|
||||
|
||||
Args:
|
||||
request: Tool call details
|
||||
|
||||
Returns:
|
||||
Approval status - may be auto-approved, auto-denied, or pending
|
||||
"""
|
||||
record = STORE.create_pending(
|
||||
tool_name=request.tool_name,
|
||||
tool_input=request.tool_input,
|
||||
agent_id=request.agent_id,
|
||||
workspace_id=request.workspace_id,
|
||||
session_id=request.session_id,
|
||||
findings=default_findings_for_tool(request.tool_name),
|
||||
)
|
||||
|
||||
if request.tool_name in SAFE_TOOLS:
|
||||
record.status = ApprovalStatus.APPROVED
|
||||
record.resolved_at = datetime.utcnow()
|
||||
record.resolved_by = "system"
|
||||
STORE.set_status(
|
||||
record.approval_id,
|
||||
ApprovalStatus.APPROVED,
|
||||
resolved_by="system",
|
||||
notify_request=False,
|
||||
)
|
||||
|
||||
return _to_response(record)
|
||||
|
||||
|
||||
@router.post("/approve", response_model=ApprovalResponse)
|
||||
async def approve_tool_call(
|
||||
request: ApprovalRequest,
|
||||
):
|
||||
"""
|
||||
Approve a pending tool call.
|
||||
|
||||
Args:
|
||||
request: Approval parameters
|
||||
|
||||
Returns:
|
||||
Updated approval status
|
||||
"""
|
||||
record = STORE.get(request.approval_id)
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Approval request not found")
|
||||
|
||||
if record.status != ApprovalStatus.PENDING:
|
||||
raise HTTPException(status_code=400, detail=f"Approval already {record.status}")
|
||||
|
||||
record.status = ApprovalStatus.APPROVED
|
||||
record.resolved_at = datetime.utcnow()
|
||||
record.resolved_by = "user"
|
||||
|
||||
return _to_response(record)
|
||||
|
||||
|
||||
@router.post("/deny", response_model=ApprovalResponse)
|
||||
async def deny_tool_call(
|
||||
request: DenyRequest,
|
||||
):
|
||||
"""
|
||||
Deny a pending tool call.
|
||||
|
||||
Args:
|
||||
request: Denial parameters
|
||||
|
||||
Returns:
|
||||
Updated approval status
|
||||
"""
|
||||
record = STORE.get(request.approval_id)
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Approval request not found")
|
||||
|
||||
if record.status != ApprovalStatus.PENDING:
|
||||
raise HTTPException(status_code=400, detail=f"Approval already {record.status}")
|
||||
|
||||
record.status = ApprovalStatus.DENIED
|
||||
record.resolved_at = datetime.utcnow()
|
||||
record.resolved_by = "user"
|
||||
record.metadata["denial_reason"] = request.reason
|
||||
|
||||
return _to_response(record)
|
||||
|
||||
|
||||
@router.get("/pending", response_model=PendingApprovalsResponse)
|
||||
async def list_pending_approvals(
|
||||
workspace_id: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
List pending tool approval requests.
|
||||
|
||||
Args:
|
||||
workspace_id: Filter by workspace
|
||||
agent_id: Filter by agent
|
||||
|
||||
Returns:
|
||||
List of pending approvals
|
||||
"""
|
||||
pending = [
|
||||
_to_response(record)
|
||||
for record in STORE.list(
|
||||
status=ApprovalStatus.PENDING,
|
||||
workspace_id=workspace_id,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
]
|
||||
return PendingApprovalsResponse(approvals=pending, total=len(pending))
|
||||
|
||||
|
||||
@router.get("/approvals/{approval_id}", response_model=ApprovalResponse)
|
||||
async def get_approval_status(
|
||||
approval_id: str,
|
||||
):
|
||||
"""
|
||||
Get the status of a specific approval request.
|
||||
|
||||
Args:
|
||||
approval_id: Approval request ID
|
||||
|
||||
Returns:
|
||||
Approval status
|
||||
"""
|
||||
record = STORE.get(approval_id)
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Approval request not found")
|
||||
return _to_response(record)
|
||||
|
||||
|
||||
@router.delete("/approvals/{approval_id}")
|
||||
async def cancel_approval(
|
||||
approval_id: str,
|
||||
):
|
||||
"""
|
||||
Cancel/delete a pending approval request.
|
||||
|
||||
Args:
|
||||
approval_id: Approval request ID
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
record = STORE.get(approval_id)
|
||||
if not record:
|
||||
raise HTTPException(status_code=404, detail="Approval request not found")
|
||||
|
||||
STORE.cancel(approval_id)
|
||||
return _to_response(record)
|
||||
839
backend/api/openclaw.py
Normal file
839
backend/api/openclaw.py
Normal file
@@ -0,0 +1,839 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Read-only OpenClaw CLI API routes — typed with Pydantic models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.services.openclaw_cli import OpenClawCliError, OpenClawCliService
|
||||
from shared.models.openclaw import OpenClawStatus
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/openclaw", tags=["openclaw"])
|
||||
|
||||
|
||||
def get_openclaw_cli_service() -> OpenClawCliService:
|
||||
"""Build the OpenClaw CLI service dependency."""
|
||||
return OpenClawCliService()
|
||||
|
||||
|
||||
def _raise_cli_http_error(exc: OpenClawCliError) -> None:
|
||||
detail = {
|
||||
"message": str(exc),
|
||||
"command": exc.command,
|
||||
"exit_code": exc.exit_code,
|
||||
"stdout": exc.stdout,
|
||||
"stderr": exc.stderr,
|
||||
}
|
||||
status_code = 503 if exc.exit_code is None else 502
|
||||
raise HTTPException(status_code=status_code, detail=detail) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response wrappers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
status: object
|
||||
|
||||
|
||||
class SessionsResponse(BaseModel):
|
||||
sessions: list[object]
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
session: object | None
|
||||
|
||||
|
||||
class SessionHistoryResponse(BaseModel):
|
||||
session_key: str
|
||||
session_id: str | None
|
||||
events: list[object]
|
||||
history: list[object]
|
||||
raw_text: str | None
|
||||
|
||||
|
||||
class CronResponse(BaseModel):
|
||||
cron: list[object]
|
||||
jobs: list[object]
|
||||
|
||||
|
||||
class ApprovalsResponse(BaseModel):
|
||||
approvals: list[object]
|
||||
pending: list[object]
|
||||
|
||||
|
||||
class AgentsResponse(BaseModel):
|
||||
agents: list[object]
|
||||
|
||||
|
||||
class SkillsResponse(BaseModel):
|
||||
workspace_dir: str
|
||||
managed_skills_dir: str
|
||||
skills: list[object]
|
||||
|
||||
|
||||
class ModelsResponse(BaseModel):
|
||||
models: list[object]
|
||||
|
||||
|
||||
class HooksResponse(BaseModel):
|
||||
workspace_dir: str
|
||||
managed_hooks_dir: str
|
||||
hooks: list[object]
|
||||
|
||||
|
||||
class PluginsResponse(BaseModel):
|
||||
workspace_dir: str
|
||||
plugins: list[object]
|
||||
diagnostics: list[object]
|
||||
|
||||
|
||||
class SecretsAuditResponse(BaseModel):
|
||||
version: int
|
||||
status: str
|
||||
findings: list[object]
|
||||
|
||||
|
||||
class SecurityAuditResponse2(BaseModel):
|
||||
report: object | None
|
||||
secret_diagnostics: list[str]
|
||||
|
||||
|
||||
class DaemonStatusResponse(BaseModel):
|
||||
service: object | None
|
||||
port: object | None
|
||||
rpc: object | None
|
||||
health: object | None
|
||||
|
||||
|
||||
class PairingListResponse2(BaseModel):
|
||||
channel: str
|
||||
requests: list[object]
|
||||
|
||||
|
||||
class QrCodeResponse2(BaseModel):
|
||||
setup_code: str
|
||||
gateway_url: str
|
||||
auth: str
|
||||
url_source: str
|
||||
|
||||
|
||||
class UpdateStatusResponse2(BaseModel):
|
||||
update: object | None
|
||||
channel: object | None
|
||||
|
||||
|
||||
class ModelAliasesResponse(BaseModel):
|
||||
aliases: dict[str, str]
|
||||
|
||||
|
||||
class ModelFallbacksResponse(BaseModel):
|
||||
key: str
|
||||
label: str
|
||||
items: list[object]
|
||||
|
||||
|
||||
class SkillUpdateResponse(BaseModel):
|
||||
ok: bool
|
||||
slug: str
|
||||
version: str
|
||||
error: str | None
|
||||
|
||||
|
||||
class ModelsStatusResponse(BaseModel):
|
||||
configPath: str | None = None
|
||||
agentId: str | None = None
|
||||
agentDir: str | None = None
|
||||
defaultModel: str | None = None
|
||||
resolvedDefault: str | None = None
|
||||
fallbacks: list[str] = Field(default_factory=list)
|
||||
imageModel: str | None = None
|
||||
imageFallbacks: list[str] = Field(default_factory=list)
|
||||
aliases: dict[str, str] = Field(default_factory=dict)
|
||||
allowed: list[str] = Field(default_factory=list)
|
||||
auth: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelsStatusResponse(BaseModel):
|
||||
reachable: bool | None = None
|
||||
channelAccounts: dict[str, Any] = Field(default_factory=dict)
|
||||
channels: list[str] = Field(default_factory=list)
|
||||
issues: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChannelsListResponse(BaseModel):
|
||||
chat: dict[str, list[str]] = Field(default_factory=dict)
|
||||
auth: list[dict[str, Any]] = Field(default_factory=list)
|
||||
usage: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class HookInfoResponse(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
source: str | None = None
|
||||
pluginId: str | None = None
|
||||
filePath: str | None = None
|
||||
handlerPath: str | None = None
|
||||
hookKey: str | None = None
|
||||
emoji: str | None = None
|
||||
homepage: str | None = None
|
||||
events: list[str] = Field(default_factory=list)
|
||||
enabledByConfig: bool | None = None
|
||||
loadable: bool | None = None
|
||||
requirementsSatisfied: bool | None = None
|
||||
requirements: dict[str, Any] = Field(default_factory=dict)
|
||||
error: str | None = None
|
||||
raw: str | None = None
|
||||
|
||||
|
||||
class HooksCheckResponse(BaseModel):
|
||||
workspace_dir: str = ""
|
||||
managed_hooks_dir: str = ""
|
||||
hooks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
eligible: bool | None = None
|
||||
verbose: bool | None = None
|
||||
|
||||
|
||||
class PluginInspectEntry(BaseModel):
|
||||
plugin: dict[str, Any] = Field(default_factory=dict)
|
||||
shape: str | None = None
|
||||
capabilityMode: str | None = None
|
||||
capabilityCount: int = 0
|
||||
capabilities: list[dict[str, Any]] = Field(default_factory=list)
|
||||
typedHooks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
customHooks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
tools: list[dict[str, Any]] = Field(default_factory=list)
|
||||
commands: list[str] = Field(default_factory=list)
|
||||
cliCommands: list[str] = Field(default_factory=list)
|
||||
services: list[str] = Field(default_factory=list)
|
||||
gatewayMethods: list[str] = Field(default_factory=list)
|
||||
mcpServers: list[dict[str, Any]] = Field(default_factory=list)
|
||||
lspServers: list[dict[str, Any]] = Field(default_factory=list)
|
||||
httpRouteCount: int = 0
|
||||
bundleCapabilities: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PluginsInspectResponse(BaseModel):
|
||||
inspect: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentBindingItem(BaseModel):
|
||||
agentId: str
|
||||
match: dict[str, Any]
|
||||
description: str
|
||||
|
||||
|
||||
class AgentsBindingsResponse(BaseModel):
|
||||
bindings: list[AgentBindingItem]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routes — use typed model methods and return Pydantic models directly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.get("/status")
|
||||
async def api_openclaw_status(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> OpenClawStatus:
|
||||
"""Read `openclaw status --json` and return a typed model."""
|
||||
try:
|
||||
return service.status_model()
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/sessions")
|
||||
async def api_openclaw_sessions(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> SessionsResponse:
|
||||
"""Read `openclaw sessions --json` and return a typed SessionsList."""
|
||||
try:
|
||||
result = service.list_sessions_model()
|
||||
return SessionsResponse(sessions=result.sessions)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/sessions/{session_key:path}/history")
|
||||
async def api_openclaw_session_history(
|
||||
session_key: str,
|
||||
limit: int = Query(20, ge=1, le=200),
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> SessionHistoryResponse:
|
||||
"""Read session history and return a typed SessionHistory."""
|
||||
try:
|
||||
result = service.get_session_history_model(session_key, limit=limit)
|
||||
return SessionHistoryResponse(
|
||||
session_key=result.session_key,
|
||||
session_id=result.session_id,
|
||||
events=result.events,
|
||||
history=result.events, # alias for compat
|
||||
raw_text=result.raw_text,
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/sessions/{session_key:path}")
|
||||
async def api_openclaw_session_detail(
|
||||
session_key: str,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> SessionDetailResponse:
|
||||
"""Resolve a single session and return it as a typed model."""
|
||||
try:
|
||||
session = service.get_session_model(session_key)
|
||||
return SessionDetailResponse(session=session)
|
||||
except KeyError as exc:
|
||||
raise HTTPException(status_code=404, detail=f"session '{session_key}' not found") from exc
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/cron")
|
||||
async def api_openclaw_cron(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> CronResponse:
|
||||
"""Read `openclaw cron list --json` and return a typed CronList."""
|
||||
try:
|
||||
result = service.list_cron_jobs_model()
|
||||
return CronResponse(cron=list(result.cron), jobs=list(result.jobs))
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/approvals")
|
||||
async def api_openclaw_approvals(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> ApprovalsResponse:
|
||||
"""Read `openclaw approvals get --json` and return a typed ApprovalsList."""
|
||||
try:
|
||||
result = service.list_approvals_model()
|
||||
return ApprovalsResponse(
|
||||
approvals=list(result.approvals),
|
||||
pending=list(result.pending),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/agents")
|
||||
async def api_openclaw_agents(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> AgentsResponse:
|
||||
"""Read `openclaw agents list --json` and return a typed AgentsList."""
|
||||
try:
|
||||
result = service.list_agents_model()
|
||||
return AgentsResponse(agents=list(result.agents))
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/agents/presence")
|
||||
async def api_openclaw_agents_presence(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> dict[str, Any]:
|
||||
"""Read runtime session presence for all agents from session files."""
|
||||
result = service.agents_presence()
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write agents routes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AgentAddResponse(BaseModel):
|
||||
agentId: str
|
||||
name: str
|
||||
workspace: str
|
||||
agentDir: str
|
||||
model: str | None = None
|
||||
bindings: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentDeleteResponse(BaseModel):
|
||||
agentId: str
|
||||
workspace: str
|
||||
agentDir: str
|
||||
sessionsDir: str
|
||||
removedBindings: list[str] = Field(default_factory=list)
|
||||
removedAllow: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentBindResponse(BaseModel):
|
||||
agentId: str
|
||||
added: list[str] = Field(default_factory=list)
|
||||
updated: list[str] = Field(default_factory=list)
|
||||
skipped: list[str] = Field(default_factory=list)
|
||||
conflicts: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentUnbindResponse(BaseModel):
|
||||
agentId: str
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
missing: list[str] = Field(default_factory=list)
|
||||
conflicts: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AgentIdentityResponse(BaseModel):
|
||||
agentId: str
|
||||
identity: dict[str, Any] = Field(default_factory=dict)
|
||||
workspace: str | None = None
|
||||
identityFile: str | None = None
|
||||
|
||||
|
||||
@router.post("/agents/add")
|
||||
async def api_openclaw_agents_add(
|
||||
name: str,
|
||||
*,
|
||||
workspace: str | None = None,
|
||||
model: str | None = None,
|
||||
agent_dir: str | None = None,
|
||||
bind: list[str] | None = None,
|
||||
non_interactive: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> AgentAddResponse:
|
||||
"""Run `openclaw agents add <name>` and return JSON result."""
|
||||
try:
|
||||
result = service.agents_add(
|
||||
name,
|
||||
workspace=workspace,
|
||||
model=model,
|
||||
agent_dir=agent_dir,
|
||||
bind=bind,
|
||||
non_interactive=non_interactive,
|
||||
)
|
||||
return AgentAddResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.post("/agents/delete/{id}")
|
||||
async def api_openclaw_agents_delete(
|
||||
id: str,
|
||||
force: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> AgentDeleteResponse:
|
||||
"""Run `openclaw agents delete <id> [--force]` and return JSON result."""
|
||||
try:
|
||||
result = service.agents_delete(id, force=force)
|
||||
return AgentDeleteResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.post("/agents/bind")
|
||||
async def api_openclaw_agents_bind(
|
||||
*,
|
||||
agent: str | None = None,
|
||||
bind: list[str] | None = None,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> AgentBindResponse:
|
||||
"""Run `openclaw agents bind [--agent <id>] [--bind <spec>]` and return JSON result."""
|
||||
try:
|
||||
result = service.agents_bind(agent=agent, bind=bind)
|
||||
return AgentBindResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.post("/agents/unbind")
|
||||
async def api_openclaw_agents_unbind(
|
||||
*,
|
||||
agent: str | None = None,
|
||||
bind: list[str] | None = None,
|
||||
all: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> AgentUnbindResponse:
|
||||
"""Run `openclaw agents unbind [--agent <id>] [--bind <spec>] [--all]` and return JSON result."""
|
||||
try:
|
||||
result = service.agents_unbind(agent=agent, bind=bind, all=all)
|
||||
return AgentUnbindResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.post("/agents/set-identity")
|
||||
async def api_openclaw_agents_set_identity(
|
||||
*,
|
||||
agent: str | None = None,
|
||||
workspace: str | None = None,
|
||||
identity_file: str | None = None,
|
||||
name: str | None = None,
|
||||
emoji: str | None = None,
|
||||
theme: str | None = None,
|
||||
avatar: str | None = None,
|
||||
from_identity: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> AgentIdentityResponse:
|
||||
"""Run `openclaw agents set-identity` and return JSON result."""
|
||||
try:
|
||||
result = service.agents_set_identity(
|
||||
agent=agent,
|
||||
workspace=workspace,
|
||||
identity_file=identity_file,
|
||||
name=name,
|
||||
emoji=emoji,
|
||||
theme=theme,
|
||||
avatar=avatar,
|
||||
from_identity=from_identity,
|
||||
)
|
||||
return AgentIdentityResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/skills")
|
||||
async def api_openclaw_skills(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> SkillsResponse:
|
||||
"""Read `openclaw skills list --json` and return a typed SkillStatusReport."""
|
||||
try:
|
||||
result = service.list_skills_model()
|
||||
return SkillsResponse(
|
||||
workspace_dir=result.workspace_dir,
|
||||
managed_skills_dir=result.managed_skills_dir,
|
||||
skills=list(result.skills),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def api_openclaw_models(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> ModelsResponse:
|
||||
"""Read `openclaw models list --json` and return a typed ModelsList."""
|
||||
try:
|
||||
result = service.list_models_model()
|
||||
return ModelsResponse(models=list(result.models))
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/hooks")
|
||||
async def api_openclaw_hooks(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> HooksResponse:
|
||||
try:
|
||||
result = service.list_hooks_model()
|
||||
return HooksResponse(
|
||||
workspace_dir=result.workspace_dir,
|
||||
managed_hooks_dir=result.managed_hooks_dir,
|
||||
hooks=list(result.hooks),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/plugins")
|
||||
async def api_openclaw_plugins(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> PluginsResponse:
|
||||
try:
|
||||
result = service.list_plugins_model()
|
||||
return PluginsResponse(
|
||||
workspace_dir=result.workspace_dir,
|
||||
plugins=list(result.plugins),
|
||||
diagnostics=list(result.diagnostics),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/secrets-audit")
|
||||
async def api_openclaw_secrets_audit(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> SecretsAuditResponse:
|
||||
try:
|
||||
result = service.secrets_audit_model()
|
||||
return SecretsAuditResponse(
|
||||
version=result.version,
|
||||
status=result.status,
|
||||
findings=list(result.findings),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/security-audit")
|
||||
async def api_openclaw_security_audit(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> SecurityAuditResponse2:
|
||||
try:
|
||||
result = service.security_audit_model()
|
||||
return SecurityAuditResponse2(
|
||||
report=result.report.model_dump() if result.report else None,
|
||||
secret_diagnostics=list(result.secret_diagnostics),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/daemon-status")
|
||||
async def api_openclaw_daemon_status(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> DaemonStatusResponse:
|
||||
try:
|
||||
result = service.daemon_status_model()
|
||||
return DaemonStatusResponse(
|
||||
service=result.service.model_dump() if result.service else None,
|
||||
port=result.port.model_dump() if result.port else None,
|
||||
rpc=result.rpc.model_dump() if result.rpc else None,
|
||||
health=result.health.model_dump() if result.health else None,
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/pairing")
|
||||
async def api_openclaw_pairing(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> PairingListResponse2:
|
||||
try:
|
||||
result = service.pairing_list_model()
|
||||
return PairingListResponse2(
|
||||
channel=result.channel,
|
||||
requests=list(result.requests),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/qr")
|
||||
async def api_openclaw_qr(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> QrCodeResponse2:
|
||||
try:
|
||||
result = service.qr_code_model()
|
||||
return QrCodeResponse2(
|
||||
setup_code=result.setup_code,
|
||||
gateway_url=result.gateway_url,
|
||||
auth=result.auth,
|
||||
url_source=result.url_source,
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/update-status")
|
||||
async def api_openclaw_update_status(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> UpdateStatusResponse2:
|
||||
try:
|
||||
result = service.update_status_model()
|
||||
return UpdateStatusResponse2(
|
||||
update=result.update.model_dump() if result.update else None,
|
||||
channel=result.channel.model_dump() if result.channel else None,
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/models-aliases")
|
||||
async def api_openclaw_models_aliases(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> ModelAliasesResponse:
|
||||
try:
|
||||
result = service.list_model_aliases_model()
|
||||
return ModelAliasesResponse(aliases=result.aliases)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/models-fallbacks")
|
||||
async def api_openclaw_models_fallbacks(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> ModelFallbacksResponse:
|
||||
try:
|
||||
result = service.list_model_fallbacks_model()
|
||||
return ModelFallbacksResponse(
|
||||
key=result.key,
|
||||
label=result.label,
|
||||
items=list(result.items),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/models-image-fallbacks")
|
||||
async def api_openclaw_models_image_fallbacks(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> ModelFallbacksResponse:
|
||||
try:
|
||||
result = service.list_model_image_fallbacks_model()
|
||||
return ModelFallbacksResponse(
|
||||
key=result.key,
|
||||
label=result.label,
|
||||
items=list(result.items),
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/skill-update")
|
||||
async def api_openclaw_skill_update(
|
||||
slug: str | None = None,
|
||||
all: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> SkillUpdateResponse:
|
||||
try:
|
||||
result = service.skill_update_model(slug=slug, all=all)
|
||||
return SkillUpdateResponse(
|
||||
ok=result.ok,
|
||||
slug=result.slug,
|
||||
version=result.version,
|
||||
error=result.error,
|
||||
)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/models-status")
|
||||
async def api_openclaw_models_status(
|
||||
probe: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> ModelsStatusResponse:
|
||||
"""Read `openclaw models status --json [--probe]` and return a typed dict."""
|
||||
try:
|
||||
result = service.models_status_model(probe=probe)
|
||||
return ModelsStatusResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/channels-status")
|
||||
async def api_openclaw_channels_status(
|
||||
probe: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> ChannelsStatusResponse:
|
||||
"""Read `openclaw channels status --json [--probe]` and return a typed dict."""
|
||||
try:
|
||||
result = service.channels_status_model(probe=probe)
|
||||
return ChannelsStatusResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/channels-list")
|
||||
async def api_openclaw_channels_list(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> ChannelsListResponse:
|
||||
"""Read `openclaw channels list --json` and return a typed dict."""
|
||||
try:
|
||||
result = service.channels_list_model()
|
||||
return ChannelsListResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/hooks/info/{name}")
|
||||
async def api_openclaw_hook_info(
|
||||
name: str,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> HookInfoResponse:
|
||||
"""Read `openclaw hooks info <name> --json` and return a typed dict."""
|
||||
try:
|
||||
result = service.hook_info_model(name)
|
||||
return HookInfoResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/hooks/check")
|
||||
async def api_openclaw_hooks_check(
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> HooksCheckResponse:
|
||||
"""Read `openclaw hooks check --json` and return a typed dict."""
|
||||
try:
|
||||
result = service.hooks_check_model()
|
||||
return HooksCheckResponse.model_validate(result, strict=False)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/plugins-inspect")
|
||||
async def api_openclaw_plugins_inspect(
|
||||
plugin_id: str | None = None,
|
||||
all: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> PluginsInspectResponse:
|
||||
"""Read `openclaw plugins inspect --json [--all]` and return a typed dict."""
|
||||
try:
|
||||
result = service.plugins_inspect_model(plugin_id=plugin_id, all=all)
|
||||
inspect = result if isinstance(result, list) else result.get("inspect", [])
|
||||
return PluginsInspectResponse(inspect=inspect)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
class AgentBindingItem(BaseModel):
|
||||
agentId: str
|
||||
match: dict[str, Any]
|
||||
description: str
|
||||
|
||||
|
||||
class AgentsBindingsResponse(BaseModel):
|
||||
bindings: list[AgentBindingItem]
|
||||
|
||||
|
||||
@router.get("/agents-bindings")
|
||||
async def api_openclaw_agents_bindings(
|
||||
agent: str | None = None,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> AgentsBindingsResponse:
|
||||
"""Read `openclaw agents bindings --json [--agent <id>]` and return bindings list."""
|
||||
try:
|
||||
result = service.agents_bindings_model(agent=agent)
|
||||
bindings = result if isinstance(result, list) else []
|
||||
return AgentsBindingsResponse(bindings=bindings)
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/gateway-status")
|
||||
async def api_openclaw_gateway_status(
|
||||
url: str | None = None,
|
||||
token: str | None = None,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> dict[str, Any]:
|
||||
"""Read `openclaw gateway status --json [--url <url>] [--token <token>]`. Returns full gateway probe result."""
|
||||
try:
|
||||
result = service.gateway_status(url=url, token=token)
|
||||
return result
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
@router.get("/memory-status")
|
||||
async def api_openclaw_memory_status(
|
||||
agent: str | None = None,
|
||||
deep: bool = False,
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Read `openclaw memory status --json [--agent <id>] [--deep]`. Returns array of per-agent memory status."""
|
||||
try:
|
||||
result = service.memory_status(agent=agent, deep=deep)
|
||||
return result if isinstance(result, list) else []
|
||||
except OpenClawCliError as exc:
|
||||
_raise_cli_http_error(exc)
|
||||
|
||||
|
||||
class WorkspaceFilesResponse(BaseModel):
|
||||
workspace: str
|
||||
files: list[dict[str, Any]]
|
||||
error: str | None = None
|
||||
|
||||
|
||||
@router.get("/workspace-files")
|
||||
async def api_openclaw_workspace_files(
|
||||
workspace: str = Query(..., description="Path to the agent workspace directory"),
|
||||
service: OpenClawCliService = Depends(get_openclaw_cli_service),
|
||||
) -> WorkspaceFilesResponse:
|
||||
"""List .md files in an OpenClaw agent workspace with their content previews."""
|
||||
result = service.list_workspace_files(workspace)
|
||||
return WorkspaceFilesResponse.model_validate(result, strict=False)
|
||||
969
backend/api/runtime.py
Normal file
969
backend/api/runtime.py
Normal file
@@ -0,0 +1,969 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Runtime API routes - Control Plane for managing Gateway processes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.runtime.agent_runtime import AgentRuntimeState
|
||||
from backend.config.bootstrap_config import (
|
||||
resolve_runtime_config,
|
||||
update_bootstrap_values_for_run,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/runtime", tags=["runtime"])
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
class RuntimeState:
|
||||
"""Thread-safe singleton for managing runtime state.
|
||||
|
||||
Encapsulates runtime_manager, _gateway_process, and _gateway_port
|
||||
with asyncio.Lock protection for concurrent access.
|
||||
"""
|
||||
|
||||
_instance: Optional["RuntimeState"] = None
|
||||
_lock: "threading.Lock" = __import__("threading").Lock()
|
||||
|
||||
def __new__(cls) -> "RuntimeState":
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
self._runtime_manager: Optional[Any] = None
|
||||
self._gateway_process: Optional[subprocess.Popen] = None
|
||||
self._gateway_port: int = 8765
|
||||
self._state_lock = asyncio.Lock()
|
||||
self._initialized = True
|
||||
|
||||
@property
|
||||
async def lock(self) -> asyncio.Lock:
|
||||
"""Get the asyncio lock for state synchronization."""
|
||||
return self._state_lock
|
||||
|
||||
@property
|
||||
def runtime_manager(self) -> Optional[Any]:
|
||||
"""Get the runtime manager (no lock - read only)."""
|
||||
return self._runtime_manager
|
||||
|
||||
@runtime_manager.setter
|
||||
def runtime_manager(self, value: Optional[Any]) -> None:
|
||||
"""Set the runtime manager."""
|
||||
self._runtime_manager = value
|
||||
|
||||
@property
|
||||
def gateway_process(self) -> Optional[subprocess.Popen]:
|
||||
"""Get the gateway process (no lock - read only)."""
|
||||
return self._gateway_process
|
||||
|
||||
@gateway_process.setter
|
||||
def gateway_process(self, value: Optional[subprocess.Popen]) -> None:
|
||||
"""Set the gateway process."""
|
||||
self._gateway_process = value
|
||||
|
||||
@property
|
||||
def gateway_port(self) -> int:
|
||||
"""Get the gateway port."""
|
||||
return self._gateway_port
|
||||
|
||||
@gateway_port.setter
|
||||
def gateway_port(self, value: int) -> None:
|
||||
"""Set the gateway port."""
|
||||
self._gateway_port = value
|
||||
|
||||
async def set_runtime_manager(self, manager: Any) -> None:
|
||||
"""Set runtime manager with lock protection."""
|
||||
async with self._state_lock:
|
||||
self._runtime_manager = manager
|
||||
|
||||
async def get_runtime_manager(self) -> Optional[Any]:
|
||||
"""Get runtime manager with lock protection."""
|
||||
async with self._state_lock:
|
||||
return self._runtime_manager
|
||||
|
||||
async def set_gateway_process(self, process: Optional[subprocess.Popen]) -> None:
|
||||
"""Set gateway process with lock protection."""
|
||||
async with self._state_lock:
|
||||
self._gateway_process = process
|
||||
|
||||
async def get_gateway_process(self) -> Optional[subprocess.Popen]:
|
||||
"""Get gateway process with lock protection."""
|
||||
async with self._state_lock:
|
||||
return self._gateway_process
|
||||
|
||||
async def set_gateway_port(self, port: int) -> None:
|
||||
"""Set gateway port with lock protection."""
|
||||
async with self._state_lock:
|
||||
self._gateway_port = port
|
||||
|
||||
async def get_gateway_port(self) -> int:
|
||||
"""Get gateway port with lock protection."""
|
||||
async with self._state_lock:
|
||||
return self._gateway_port
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_runtime_state = RuntimeState()
|
||||
|
||||
|
||||
def get_runtime_state() -> RuntimeState:
|
||||
"""Get the RuntimeState singleton instance."""
|
||||
return _runtime_state
|
||||
|
||||
|
||||
# Backward compatibility: module-level runtime_manager for external imports
|
||||
# This is set by register_runtime_manager() for backward compatibility
|
||||
runtime_manager: Optional[Any] = None
|
||||
|
||||
|
||||
class RunContextResponse(BaseModel):
|
||||
config_name: str
|
||||
run_dir: str
|
||||
bootstrap_values: Dict[str, Any]
|
||||
|
||||
|
||||
class RuntimeAgentState(BaseModel):
|
||||
agent_id: str
|
||||
status: str
|
||||
last_session: Optional[str] = None
|
||||
last_updated: str
|
||||
|
||||
|
||||
class RuntimeAgentsResponse(BaseModel):
|
||||
agents: List[RuntimeAgentState]
|
||||
|
||||
|
||||
class RuntimeEvent(BaseModel):
|
||||
timestamp: str
|
||||
event: str
|
||||
details: Dict[str, Any]
|
||||
session: Optional[str]
|
||||
|
||||
|
||||
class RuntimeEventsResponse(BaseModel):
|
||||
events: List[RuntimeEvent]
|
||||
|
||||
|
||||
class LaunchConfig(BaseModel):
|
||||
"""Configuration for launching a new trading task."""
|
||||
launch_mode: str = Field(default="fresh", description="启动形式: fresh, restore")
|
||||
restore_run_id: Optional[str] = Field(default=None, description="历史任务 run_id,用于恢复启动")
|
||||
tickers: List[str] = Field(default_factory=list, description="股票池")
|
||||
schedule_mode: str = Field(default="daily", description="调度模式: daily, interval")
|
||||
interval_minutes: int = Field(default=60, ge=1, description="间隔分钟数")
|
||||
trigger_time: str = Field(default="09:30", description="触发时间 HH:MM")
|
||||
max_comm_cycles: int = Field(default=2, ge=1, description="最大会商轮数")
|
||||
initial_cash: float = Field(default=100000.0, gt=0, description="初始资金")
|
||||
margin_requirement: float = Field(default=0.0, ge=0, description="保证金要求")
|
||||
enable_memory: bool = Field(default=False, description="是否启用长期记忆")
|
||||
mode: str = Field(default="live", description="运行模式: live, backtest")
|
||||
start_date: Optional[str] = Field(default=None, description="回测开始日期 YYYY-MM-DD")
|
||||
end_date: Optional[str] = Field(default=None, description="回测结束日期 YYYY-MM-DD")
|
||||
poll_interval: int = Field(default=10, ge=1, le=300, description="市场数据轮询间隔(秒)")
|
||||
|
||||
|
||||
class LaunchResponse(BaseModel):
|
||||
run_id: str
|
||||
status: str
|
||||
run_dir: str
|
||||
gateway_port: int
|
||||
message: str
|
||||
|
||||
|
||||
class RuntimeHistoryItem(BaseModel):
|
||||
run_id: str
|
||||
run_dir: str
|
||||
updated_at: Optional[str] = None
|
||||
total_trades: int = 0
|
||||
total_asset_value: Optional[float] = None
|
||||
bootstrap: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RuntimeHistoryResponse(BaseModel):
|
||||
runs: List[RuntimeHistoryItem]
|
||||
|
||||
|
||||
class StopResponse(BaseModel):
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class CleanupResponse(BaseModel):
|
||||
status: str
|
||||
kept: int
|
||||
pruned_run_ids: List[str]
|
||||
|
||||
|
||||
class GatewayStatusResponse(BaseModel):
|
||||
is_running: bool
|
||||
port: int
|
||||
run_id: Optional[str] = None
|
||||
|
||||
|
||||
class RuntimeConfigResponse(BaseModel):
|
||||
run_id: str
|
||||
is_running: bool
|
||||
gateway_port: int
|
||||
bootstrap: Dict[str, Any]
|
||||
resolved: Dict[str, Any]
|
||||
|
||||
|
||||
class RuntimeLogResponse(BaseModel):
|
||||
run_id: Optional[str] = None
|
||||
is_running: bool
|
||||
log_path: Optional[str] = None
|
||||
content: str = ""
|
||||
|
||||
|
||||
class UpdateRuntimeConfigRequest(BaseModel):
|
||||
schedule_mode: Optional[str] = None
|
||||
interval_minutes: Optional[int] = Field(default=None, ge=1)
|
||||
trigger_time: Optional[str] = None
|
||||
max_comm_cycles: Optional[int] = Field(default=None, ge=1)
|
||||
initial_cash: Optional[float] = Field(default=None, gt=0)
|
||||
margin_requirement: Optional[float] = Field(default=None, ge=0)
|
||||
enable_memory: Optional[bool] = None
|
||||
|
||||
|
||||
def _generate_run_id() -> str:
|
||||
"""Generate timestamp-based run ID: YYYYMMDD_HHMMSS"""
|
||||
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def _get_run_dir(run_id: str) -> Path:
|
||||
"""Return the run directory for a given run ID."""
|
||||
return PROJECT_ROOT / "runs" / run_id
|
||||
|
||||
|
||||
def _load_run_snapshot(run_id: str) -> Dict[str, Any]:
|
||||
"""Load a specific run snapshot by run_id."""
|
||||
snapshot_path = _get_run_dir(run_id) / "state" / "runtime_state.json"
|
||||
if not snapshot_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Run snapshot not found: {run_id}")
|
||||
return json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _copy_path_if_exists(src: Path, dst: Path) -> None:
|
||||
if not src.exists():
|
||||
return
|
||||
if src.is_dir():
|
||||
shutil.copytree(src, dst, dirs_exist_ok=True)
|
||||
else:
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
|
||||
def _restore_run_assets(source_run_id: str, target_run_dir: Path) -> None:
|
||||
"""Seed a fresh run directory from a historical run snapshot."""
|
||||
source_run_dir = _get_run_dir(source_run_id)
|
||||
if not source_run_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Source run not found: {source_run_id}")
|
||||
|
||||
for relative in [
|
||||
"team_dashboard",
|
||||
"agents",
|
||||
"skills",
|
||||
"memory",
|
||||
"state/server_state.json",
|
||||
"state/runtime.db",
|
||||
"state/research.db",
|
||||
]:
|
||||
_copy_path_if_exists(source_run_dir / relative, target_run_dir / relative)
|
||||
|
||||
|
||||
def _list_runs(limit: int = 50) -> list[RuntimeHistoryItem]:
|
||||
runs_root = PROJECT_ROOT / "runs"
|
||||
if not runs_root.exists():
|
||||
return []
|
||||
|
||||
items: list[RuntimeHistoryItem] = []
|
||||
run_dirs = sorted(
|
||||
[path for path in runs_root.iterdir() if path.is_dir()],
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
for run_dir in run_dirs[: max(1, int(limit))]:
|
||||
run_id = run_dir.name
|
||||
runtime_state_path = run_dir / "state" / "runtime_state.json"
|
||||
summary_path = run_dir / "team_dashboard" / "summary.json"
|
||||
|
||||
bootstrap: Dict[str, Any] = {}
|
||||
updated_at: Optional[str] = None
|
||||
total_trades = 0
|
||||
total_asset_value: Optional[float] = None
|
||||
|
||||
if runtime_state_path.exists():
|
||||
try:
|
||||
snapshot = json.loads(runtime_state_path.read_text(encoding="utf-8"))
|
||||
context = snapshot.get("context") or {}
|
||||
bootstrap = dict(context.get("bootstrap_values") or {})
|
||||
updated_at = snapshot.get("events", [{}])[-1].get("timestamp") if snapshot.get("events") else None
|
||||
except Exception:
|
||||
bootstrap = {}
|
||||
|
||||
if summary_path.exists():
|
||||
try:
|
||||
summary = json.loads(summary_path.read_text(encoding="utf-8"))
|
||||
total_trades = int(summary.get("totalTrades") or 0)
|
||||
total_asset_value = float(summary.get("totalAssetValue")) if summary.get("totalAssetValue") is not None else None
|
||||
except Exception:
|
||||
total_trades = 0
|
||||
total_asset_value = None
|
||||
|
||||
items.append(
|
||||
RuntimeHistoryItem(
|
||||
run_id=run_id,
|
||||
run_dir=str(run_dir),
|
||||
updated_at=updated_at,
|
||||
total_trades=total_trades,
|
||||
total_asset_value=total_asset_value,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def _is_timestamped_run_dir(path: Path) -> bool:
|
||||
try:
|
||||
datetime.strptime(path.name, "%Y%m%d_%H%M%S")
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _prune_old_timestamped_runs(*, keep: int = 20, exclude_run_ids: Optional[set[str]] = None) -> list[str]:
|
||||
"""Prune old timestamped run directories, preserving the newest N and excluded ids."""
|
||||
exclude = exclude_run_ids or set()
|
||||
runs_root = PROJECT_ROOT / "runs"
|
||||
if not runs_root.exists():
|
||||
return []
|
||||
|
||||
candidates = sorted(
|
||||
[
|
||||
path
|
||||
for path in runs_root.iterdir()
|
||||
if path.is_dir() and _is_timestamped_run_dir(path) and path.name not in exclude
|
||||
],
|
||||
key=lambda path: path.name,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
pruned: list[str] = []
|
||||
for path in candidates[max(0, keep):]:
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
pruned.append(path.name)
|
||||
return pruned
|
||||
|
||||
|
||||
def _find_available_port(start_port: int = 8765, max_port: int = 9000) -> int:
|
||||
"""Find an available port for Gateway."""
|
||||
import socket
|
||||
for port in range(start_port, max_port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
if s.connect_ex(('localhost', port)) != 0:
|
||||
return port
|
||||
raise RuntimeError("No available port found")
|
||||
|
||||
|
||||
def _is_gateway_running() -> bool:
|
||||
"""Check if Gateway process is running.
|
||||
|
||||
Checks both the internally-managed gateway process and falls back to
|
||||
port availability (for externally-managed gateway processes).
|
||||
"""
|
||||
process = _runtime_state.gateway_process
|
||||
if process is not None and process.poll() is None:
|
||||
return True
|
||||
# Fallback: check if the gateway port is in use (for externally started gateway)
|
||||
import socket
|
||||
try:
|
||||
with socket.create_connection(("127.0.0.1", _runtime_state.gateway_port), timeout=1):
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def _stop_gateway() -> bool:
|
||||
"""Stop the Gateway process."""
|
||||
process = _runtime_state.gateway_process
|
||||
if process is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Try graceful shutdown first
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
# Force kill if graceful shutdown fails
|
||||
process.kill()
|
||||
process.wait()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during gateway shutdown: {e}")
|
||||
finally:
|
||||
_runtime_state.gateway_process = None
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _start_gateway_process(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
bootstrap: Dict[str, Any],
|
||||
port: int
|
||||
) -> subprocess.Popen:
|
||||
"""Start Gateway as a separate process."""
|
||||
# Prepare environment
|
||||
env = os.environ.copy()
|
||||
|
||||
# Create command arguments
|
||||
cmd = [
|
||||
sys.executable,
|
||||
"-m", "backend.gateway_server",
|
||||
"--run-id", run_id,
|
||||
"--run-dir", str(run_dir),
|
||||
"--port", str(port),
|
||||
"--bootstrap", json.dumps(bootstrap)
|
||||
]
|
||||
|
||||
log_path = run_dir / "logs" / "gateway.log"
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
log_file = log_path.open("ab")
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=PROJECT_ROOT
|
||||
)
|
||||
finally:
|
||||
log_file.close()
|
||||
|
||||
return process
|
||||
|
||||
|
||||
@router.get("/context", response_model=RunContextResponse)
|
||||
async def get_run_context() -> RunContextResponse:
|
||||
"""Return active runtime context, or latest persisted context when stopped."""
|
||||
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
|
||||
context = snapshot.get("context")
|
||||
if context is None:
|
||||
raise HTTPException(status_code=404, detail="Run context is not ready")
|
||||
|
||||
return RunContextResponse(
|
||||
config_name=context["config_name"],
|
||||
run_dir=context["run_dir"],
|
||||
bootstrap_values=context["bootstrap_values"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/agents", response_model=RuntimeAgentsResponse)
|
||||
async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||
"""Return agent states from the active runtime, or latest persisted run."""
|
||||
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
|
||||
agents = snapshot.get("agents", [])
|
||||
|
||||
return RuntimeAgentsResponse(
|
||||
agents=[RuntimeAgentState(**a) for a in agents]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/events", response_model=RuntimeEventsResponse)
|
||||
async def get_runtime_events() -> RuntimeEventsResponse:
|
||||
"""Return events from the active runtime, or latest persisted run."""
|
||||
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
|
||||
events = snapshot.get("events", [])
|
||||
|
||||
return RuntimeEventsResponse(
|
||||
events=[RuntimeEvent(**e) for e in events]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history", response_model=RuntimeHistoryResponse)
|
||||
async def get_runtime_history(limit: int = 20) -> RuntimeHistoryResponse:
|
||||
"""List recent historical runs for restore/start selection."""
|
||||
return RuntimeHistoryResponse(runs=_list_runs(limit=limit))
|
||||
|
||||
|
||||
@router.get("/gateway/status", response_model=GatewayStatusResponse)
|
||||
async def get_gateway_status() -> GatewayStatusResponse:
|
||||
"""Get Gateway process status and port."""
|
||||
is_running = _is_gateway_running()
|
||||
run_id = None
|
||||
|
||||
if is_running:
|
||||
try:
|
||||
run_id = _get_active_runtime_context().get("config_name")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve active runtime context: {e}")
|
||||
|
||||
return GatewayStatusResponse(
|
||||
is_running=is_running,
|
||||
port=_runtime_state.gateway_port,
|
||||
run_id=run_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/gateway/port")
|
||||
async def get_gateway_port(request: Request) -> Dict[str, Any]:
|
||||
"""Get WebSocket Gateway port for frontend connection."""
|
||||
gateway_port = _runtime_state.gateway_port
|
||||
return {
|
||||
"port": gateway_port,
|
||||
"is_running": _is_gateway_running(),
|
||||
"ws_url": _build_gateway_ws_url(request, gateway_port),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/logs", response_model=RuntimeLogResponse)
|
||||
async def get_runtime_logs() -> RuntimeLogResponse:
|
||||
"""Return current runtime log tail, or the latest run log if runtime is stopped."""
|
||||
try:
|
||||
context = _get_active_runtime_context() if _is_gateway_running() else _get_runtime_context_from_latest_snapshot()
|
||||
except HTTPException:
|
||||
return RuntimeLogResponse(is_running=False, content="")
|
||||
|
||||
run_id = str(context.get("config_name") or "").strip() or None
|
||||
log_path = _get_gateway_log_path_for_run(run_id) if run_id else None
|
||||
content = _read_log_tail(log_path) if log_path else ""
|
||||
|
||||
return RuntimeLogResponse(
|
||||
run_id=run_id,
|
||||
is_running=_is_gateway_running(),
|
||||
log_path=str(log_path) if log_path else None,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _build_gateway_ws_url(request: Request, port: int) -> str:
|
||||
"""Build a proxy-safe Gateway WebSocket URL."""
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "").split(",")[0].strip()
|
||||
scheme = forwarded_proto or request.url.scheme
|
||||
ws_scheme = "wss" if scheme == "https" else "ws"
|
||||
|
||||
forwarded_host = request.headers.get("x-forwarded-host", "").split(",")[0].strip()
|
||||
host = forwarded_host or request.url.hostname or "localhost"
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = host.split(":", 1)[0]
|
||||
|
||||
return f"{ws_scheme}://{host}:{port}"
|
||||
|
||||
|
||||
def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
||||
"""Load the latest persisted runtime snapshot."""
|
||||
snapshots = sorted(
|
||||
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||
return json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _get_active_runtime_snapshot() -> Dict[str, Any]:
|
||||
"""Return the active runtime snapshot, preferring in-memory manager state."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is not None and hasattr(manager, "build_snapshot"):
|
||||
snapshot = manager.build_snapshot()
|
||||
context = snapshot.get("context") or {}
|
||||
if context.get("config_name"):
|
||||
return snapshot
|
||||
|
||||
return _load_latest_runtime_snapshot()
|
||||
|
||||
|
||||
def _get_runtime_context_from_latest_snapshot() -> Dict[str, Any]:
|
||||
"""Return the latest persisted runtime context regardless of active process state."""
|
||||
latest = _load_latest_runtime_snapshot()
|
||||
context = latest.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
raise HTTPException(status_code=404, detail="No runtime context available")
|
||||
return context
|
||||
|
||||
|
||||
def _get_gateway_log_path_for_run(run_id: str) -> Path:
|
||||
return _get_run_dir(run_id) / "logs" / "gateway.log"
|
||||
|
||||
|
||||
def _read_log_tail(path: Path, max_chars: int = 120_000) -> str:
|
||||
if not path.exists() or not path.is_file():
|
||||
return ""
|
||||
text = path.read_text(encoding="utf-8", errors="replace")
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
return text[-max_chars:]
|
||||
|
||||
|
||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context from the latest snapshot."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
snapshot = _get_active_runtime_snapshot()
|
||||
context = snapshot.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
raise HTTPException(status_code=404, detail="No runtime context available")
|
||||
return context
|
||||
|
||||
|
||||
def _get_active_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context, preferring in-memory runtime manager state."""
|
||||
return _get_current_runtime_context()
|
||||
|
||||
|
||||
def _resolve_runtime_response(run_id: str) -> RuntimeConfigResponse:
|
||||
"""Build a normalized runtime config response for the active run."""
|
||||
context = _get_current_runtime_context()
|
||||
bootstrap = dict(context.get("bootstrap_values") or {})
|
||||
resolved = resolve_runtime_config(
|
||||
project_root=PROJECT_ROOT,
|
||||
config_name=run_id,
|
||||
enable_memory=bool(bootstrap.get("enable_memory", False)),
|
||||
schedule_mode=str(bootstrap.get("schedule_mode", "daily")),
|
||||
interval_minutes=int(bootstrap.get("interval_minutes", 60) or 60),
|
||||
trigger_time=str(bootstrap.get("trigger_time", "09:30") or "09:30"),
|
||||
)
|
||||
return RuntimeConfigResponse(
|
||||
run_id=run_id,
|
||||
is_running=True,
|
||||
gateway_port=_runtime_state.gateway_port,
|
||||
bootstrap=bootstrap,
|
||||
resolved=resolved,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_runtime_config_updates(
|
||||
request: UpdateRuntimeConfigRequest,
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate and normalize runtime config updates."""
|
||||
updates: Dict[str, Any] = {}
|
||||
|
||||
if request.schedule_mode is not None:
|
||||
schedule_mode = str(request.schedule_mode).strip().lower()
|
||||
if schedule_mode not in {"daily", "intraday"}:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="schedule_mode must be 'daily' or 'intraday'",
|
||||
)
|
||||
updates["schedule_mode"] = schedule_mode
|
||||
|
||||
if request.interval_minutes is not None:
|
||||
updates["interval_minutes"] = int(request.interval_minutes)
|
||||
|
||||
if request.trigger_time is not None:
|
||||
trigger_time = str(request.trigger_time).strip()
|
||||
if trigger_time and trigger_time != "now":
|
||||
try:
|
||||
datetime.strptime(trigger_time, "%H:%M")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="trigger_time must use HH:MM or 'now'",
|
||||
) from exc
|
||||
updates["trigger_time"] = trigger_time or "09:30"
|
||||
|
||||
if request.max_comm_cycles is not None:
|
||||
updates["max_comm_cycles"] = int(request.max_comm_cycles)
|
||||
|
||||
if request.initial_cash is not None:
|
||||
updates["initial_cash"] = float(request.initial_cash)
|
||||
|
||||
if request.margin_requirement is not None:
|
||||
updates["margin_requirement"] = float(request.margin_requirement)
|
||||
|
||||
if request.enable_memory is not None:
|
||||
updates["enable_memory"] = bool(request.enable_memory)
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(status_code=400, detail="No runtime config updates provided")
|
||||
|
||||
return updates
|
||||
|
||||
|
||||
@router.post("/start", response_model=LaunchResponse)
|
||||
async def start_runtime(
|
||||
config: LaunchConfig,
|
||||
background_tasks: BackgroundTasks
|
||||
) -> LaunchResponse:
|
||||
"""Start a new trading runtime with the given configuration.
|
||||
|
||||
1. Stop existing Gateway if running
|
||||
2. Generate run ID and directory
|
||||
3. Create runtime manager
|
||||
4. Start Gateway as subprocess (Data Plane)
|
||||
5. Return Gateway port for WebSocket connection
|
||||
"""
|
||||
# Lazy import to avoid circular dependency
|
||||
from backend.runtime.manager import TradingRuntimeManager
|
||||
|
||||
# 1. Stop existing Gateway
|
||||
if _is_gateway_running():
|
||||
_stop_gateway()
|
||||
await asyncio.sleep(1) # Wait for port release
|
||||
|
||||
launch_mode = str(config.launch_mode or "fresh").strip().lower()
|
||||
if launch_mode not in {"fresh", "restore"}:
|
||||
raise HTTPException(status_code=400, detail="launch_mode must be 'fresh' or 'restore'")
|
||||
|
||||
# 2. Resolve run ID, directory, and bootstrap
|
||||
if launch_mode == "restore":
|
||||
restore_run_id = str(config.restore_run_id or "").strip()
|
||||
if not restore_run_id:
|
||||
raise HTTPException(status_code=400, detail="restore_run_id is required when launch_mode=restore")
|
||||
snapshot = _load_run_snapshot(restore_run_id)
|
||||
context = snapshot.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
raise HTTPException(status_code=404, detail=f"Run context not found: {restore_run_id}")
|
||||
run_id = restore_run_id
|
||||
run_dir = _get_run_dir(run_id)
|
||||
bootstrap = dict(context.get("bootstrap_values") or {})
|
||||
bootstrap["launch_mode"] = "restore"
|
||||
bootstrap["restore_run_id"] = restore_run_id
|
||||
else:
|
||||
run_id = _generate_run_id()
|
||||
run_dir = _get_run_dir(run_id)
|
||||
bootstrap = {
|
||||
"launch_mode": "fresh",
|
||||
"restore_run_id": None,
|
||||
"tickers": config.tickers,
|
||||
"schedule_mode": config.schedule_mode,
|
||||
"interval_minutes": config.interval_minutes,
|
||||
"trigger_time": config.trigger_time,
|
||||
"max_comm_cycles": config.max_comm_cycles,
|
||||
"initial_cash": config.initial_cash,
|
||||
"margin_requirement": config.margin_requirement,
|
||||
"enable_memory": config.enable_memory,
|
||||
"mode": config.mode,
|
||||
"start_date": config.start_date,
|
||||
"end_date": config.end_date,
|
||||
"poll_interval": config.poll_interval,
|
||||
}
|
||||
|
||||
retention_keep = max(1, int(os.getenv("RUNS_RETENTION_COUNT", "20") or "20"))
|
||||
pruned_run_ids = _prune_old_timestamped_runs(
|
||||
keep=retention_keep,
|
||||
exclude_run_ids={run_id},
|
||||
)
|
||||
if pruned_run_ids:
|
||||
logger.info("Pruned old run directories: %s", ", ".join(pruned_run_ids))
|
||||
|
||||
# 4. Create runtime manager
|
||||
manager = TradingRuntimeManager(
|
||||
config_name=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
manager.prepare_run()
|
||||
register_runtime_manager(manager)
|
||||
|
||||
# 5. Write BOOTSTRAP.md
|
||||
_write_bootstrap_md(run_dir, bootstrap)
|
||||
|
||||
# 6. Find available port and start Gateway process
|
||||
gateway_port = _find_available_port(start_port=8765)
|
||||
_runtime_state.gateway_port = gateway_port
|
||||
|
||||
try:
|
||||
process = _start_gateway_process(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
port=gateway_port
|
||||
)
|
||||
_runtime_state.gateway_process = process
|
||||
|
||||
# Wait briefly to check if process started successfully
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if not _is_gateway_running():
|
||||
_runtime_state.gateway_process = None
|
||||
log_path = _get_gateway_log_path_for_run(run_id)
|
||||
log_tail = _read_log_tail(log_path, max_chars=4000)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Gateway failed to start: {log_tail or 'Unknown error'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
_stop_gateway()
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start Gateway: {str(e)}")
|
||||
|
||||
return LaunchResponse(
|
||||
run_id=run_id,
|
||||
status="started",
|
||||
run_dir=str(run_dir),
|
||||
gateway_port=gateway_port,
|
||||
message=f"Runtime started with run_id: {run_id}, Gateway on port: {gateway_port}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config", response_model=RuntimeConfigResponse)
|
||||
async def get_runtime_config() -> RuntimeConfigResponse:
|
||||
"""Return the current runtime bootstrap and resolved settings."""
|
||||
context = _get_current_runtime_context()
|
||||
return _resolve_runtime_response(context["config_name"])
|
||||
|
||||
|
||||
@router.put("/config", response_model=RuntimeConfigResponse)
|
||||
async def update_runtime_config(
|
||||
request: UpdateRuntimeConfigRequest,
|
||||
) -> RuntimeConfigResponse:
|
||||
"""Persist selected runtime configuration updates for the active run."""
|
||||
context = _get_current_runtime_context()
|
||||
run_id = context["config_name"]
|
||||
updates = _normalize_runtime_config_updates(request)
|
||||
updated = update_bootstrap_values_for_run(PROJECT_ROOT, run_id, updates)
|
||||
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is not None and getattr(manager, "config_name", None) == run_id:
|
||||
manager.bootstrap.update(updates)
|
||||
if getattr(manager, "context", None) is not None:
|
||||
manager.context.bootstrap_values.update(updates)
|
||||
if hasattr(manager, "_persist_snapshot"):
|
||||
manager._persist_snapshot()
|
||||
|
||||
response = _resolve_runtime_response(run_id)
|
||||
response.bootstrap = dict(updated.values)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/stop", response_model=StopResponse)
|
||||
async def stop_runtime(force: bool = True) -> StopResponse:
|
||||
"""Stop the current running runtime."""
|
||||
was_running = _is_gateway_running()
|
||||
|
||||
if not was_running:
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
# Stop Gateway process
|
||||
_stop_gateway()
|
||||
|
||||
# Unregister runtime manager
|
||||
unregister_runtime_manager()
|
||||
|
||||
return StopResponse(
|
||||
status="stopped",
|
||||
message="Runtime stopped successfully",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cleanup", response_model=CleanupResponse)
|
||||
async def cleanup_old_runs(keep: int = 20) -> CleanupResponse:
|
||||
"""Prune old timestamped run directories while preserving named runs."""
|
||||
keep_count = max(1, int(keep))
|
||||
exclude: set[str] = set()
|
||||
|
||||
if _is_gateway_running():
|
||||
try:
|
||||
active_context = _get_active_runtime_context()
|
||||
active_run_id = str(active_context.get("config_name") or "").strip()
|
||||
if active_run_id:
|
||||
exclude.add(active_run_id)
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
pruned = _prune_old_timestamped_runs(keep=keep_count, exclude_run_ids=exclude)
|
||||
return CleanupResponse(status="ok", kept=keep_count, pruned_run_ids=pruned)
|
||||
|
||||
|
||||
@router.post("/restart")
|
||||
async def restart_runtime(
|
||||
config: LaunchConfig,
|
||||
background_tasks: BackgroundTasks
|
||||
):
|
||||
"""Restart the runtime with a new configuration."""
|
||||
# Stop current runtime
|
||||
await stop_runtime(force=True)
|
||||
|
||||
# Start new runtime
|
||||
response = await start_runtime(config, background_tasks)
|
||||
|
||||
return {
|
||||
"run_id": response.run_id,
|
||||
"status": "restarted",
|
||||
"gateway_port": response.gateway_port,
|
||||
"message": f"Runtime restarted with run_id: {response.run_id}",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/current")
|
||||
async def get_current_runtime():
|
||||
"""Get information about the currently running runtime."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
context = _get_active_runtime_context()
|
||||
|
||||
return {
|
||||
"run_id": context.get("config_name"),
|
||||
"run_dir": context.get("run_dir"),
|
||||
"is_running": True,
|
||||
"gateway_port": _runtime_state.gateway_port,
|
||||
"bootstrap": context.get("bootstrap_values", {}),
|
||||
}
|
||||
|
||||
|
||||
def register_runtime_manager(manager: Any) -> None:
|
||||
"""Allow other modules to expose the runtime manager to the API."""
|
||||
global runtime_manager
|
||||
runtime_manager = manager
|
||||
# Also update the RuntimeState singleton for internal consistency
|
||||
_runtime_state.runtime_manager = manager
|
||||
|
||||
|
||||
def unregister_runtime_manager() -> None:
|
||||
"""Drop the runtime manager reference."""
|
||||
global runtime_manager
|
||||
runtime_manager = None
|
||||
# Also update the RuntimeState singleton for internal consistency
|
||||
_runtime_state.runtime_manager = None
|
||||
|
||||
|
||||
def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None:
|
||||
"""Write bootstrap configuration to BOOTSTRAP.md."""
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
yaml = None
|
||||
|
||||
bootstrap_path = run_dir / "BOOTSTRAP.md"
|
||||
bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Filter out None values
|
||||
values = {k: v for k, v in bootstrap.items() if v is not None}
|
||||
|
||||
if yaml:
|
||||
front_matter = yaml.safe_dump(values, allow_unicode=True, sort_keys=False)
|
||||
else:
|
||||
front_matter = json.dumps(values, ensure_ascii=False, indent=2)
|
||||
|
||||
content = f"---\n{front_matter}---\n"
|
||||
bootstrap_path.write_text(content, encoding="utf-8")
|
||||
196
backend/api/workspaces.py
Normal file
196
backend/api/workspaces.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Workspace API Routes
|
||||
|
||||
Provides REST API endpoints for workspace management.
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.agents import WorkspaceManager
|
||||
|
||||
router = APIRouter(prefix="/api/workspaces", tags=["workspaces"])
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class CreateWorkspaceRequest(BaseModel):
|
||||
"""Request to create a new workspace."""
|
||||
workspace_id: str = Field(..., description="Unique workspace identifier")
|
||||
name: Optional[str] = Field(None, description="Display name")
|
||||
description: Optional[str] = Field(None, description="Workspace description")
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
||||
|
||||
|
||||
class UpdateWorkspaceRequest(BaseModel):
|
||||
"""Request to update a workspace."""
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class WorkspaceResponse(BaseModel):
|
||||
"""Workspace information response."""
|
||||
workspace_id: str
|
||||
name: str
|
||||
description: str
|
||||
created_at: Optional[str] = None
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkspaceListResponse(BaseModel):
|
||||
"""List of workspaces response."""
|
||||
workspaces: List[WorkspaceResponse]
|
||||
total: int
|
||||
|
||||
|
||||
# Dependencies
|
||||
def get_workspace_manager():
|
||||
"""Get WorkspaceManager instance."""
|
||||
return WorkspaceManager()
|
||||
|
||||
|
||||
# Routes
|
||||
@router.post("", response_model=WorkspaceResponse)
|
||||
async def create_workspace(
|
||||
request: CreateWorkspaceRequest,
|
||||
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
Create a new workspace.
|
||||
|
||||
Args:
|
||||
request: Workspace creation parameters
|
||||
|
||||
Returns:
|
||||
Created workspace information
|
||||
"""
|
||||
try:
|
||||
config = manager.create_workspace(
|
||||
workspace_id=request.workspace_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
metadata=request.metadata or {},
|
||||
)
|
||||
return WorkspaceResponse(
|
||||
workspace_id=config.workspace_id,
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
created_at=config.created_at,
|
||||
metadata=config.metadata,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("", response_model=WorkspaceListResponse)
|
||||
async def list_workspaces(
|
||||
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
List all workspaces.
|
||||
|
||||
Returns:
|
||||
List of workspaces
|
||||
"""
|
||||
workspaces = manager.list_workspaces()
|
||||
return WorkspaceListResponse(
|
||||
workspaces=[
|
||||
WorkspaceResponse(
|
||||
workspace_id=ws.workspace_id,
|
||||
name=ws.name,
|
||||
description=ws.description,
|
||||
created_at=ws.created_at,
|
||||
metadata=ws.metadata,
|
||||
)
|
||||
for ws in workspaces
|
||||
],
|
||||
total=len(workspaces),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{workspace_id}", response_model=WorkspaceResponse)
|
||||
async def get_workspace(
|
||||
workspace_id: str,
|
||||
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
Get workspace details.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
|
||||
Returns:
|
||||
Workspace information
|
||||
"""
|
||||
workspace = manager.get_workspace(workspace_id)
|
||||
if not workspace:
|
||||
raise HTTPException(status_code=404, detail=f"Workspace '{workspace_id}' not found")
|
||||
|
||||
return WorkspaceResponse(
|
||||
workspace_id=workspace["workspace_id"],
|
||||
name=workspace.get("name", workspace_id),
|
||||
description=workspace.get("description", ""),
|
||||
created_at=workspace.get("created_at"),
|
||||
metadata=workspace.get("metadata", {}),
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{workspace_id}", response_model=WorkspaceResponse)
|
||||
async def update_workspace(
|
||||
workspace_id: str,
|
||||
request: UpdateWorkspaceRequest,
|
||||
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
Update workspace configuration.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
request: Update parameters
|
||||
|
||||
Returns:
|
||||
Updated workspace information
|
||||
"""
|
||||
try:
|
||||
config = manager.update_workspace_config(
|
||||
workspace_id=workspace_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
return WorkspaceResponse(
|
||||
workspace_id=config.workspace_id,
|
||||
name=config.name,
|
||||
description=config.description,
|
||||
created_at=config.created_at,
|
||||
metadata=config.metadata,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{workspace_id}")
|
||||
async def delete_workspace(
|
||||
workspace_id: str,
|
||||
force: bool = False,
|
||||
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
Delete a workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
force: If True, delete even if workspace has agents
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
try:
|
||||
success = manager.delete_workspace(workspace_id, force=force)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail=f"Workspace '{workspace_id}' not found")
|
||||
return {"message": f"Workspace '{workspace_id}' deleted successfully"}
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
34
backend/apps/__init__.py
Normal file
34
backend/apps/__init__.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Application surfaces for progressive service extraction."""
|
||||
|
||||
from .agent_service import app as agent_app
|
||||
from .agent_service import create_app as create_agent_app
|
||||
from .news_service import app as news_app
|
||||
from .news_service import create_app as create_news_app
|
||||
from .openclaw_service import app as openclaw_app
|
||||
from .openclaw_service import create_app as create_openclaw_app
|
||||
from .runtime_service import app as runtime_app
|
||||
from .runtime_service import create_app as create_runtime_app
|
||||
from .trading_service import app as trading_app
|
||||
from .trading_service import create_app as create_trading_app
|
||||
from .cors import add_cors_middleware, get_cors_origins
|
||||
|
||||
app = agent_app
|
||||
create_app = create_agent_app
|
||||
|
||||
__all__ = [
|
||||
"app",
|
||||
"create_app",
|
||||
"agent_app",
|
||||
"create_agent_app",
|
||||
"news_app",
|
||||
"create_news_app",
|
||||
"openclaw_app",
|
||||
"create_openclaw_app",
|
||||
"runtime_app",
|
||||
"create_runtime_app",
|
||||
"trading_app",
|
||||
"create_trading_app",
|
||||
"add_cors_middleware",
|
||||
"get_cors_origins",
|
||||
]
|
||||
89
backend/apps/agent_service.py
Normal file
89
backend/apps/agent_service.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Agent control-plane FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
|
||||
from backend.api import agents_router, guard_router, workspaces_router
|
||||
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||
|
||||
# Global instances (initialized on startup)
|
||||
agent_factory: AgentFactory | None = None
|
||||
workspace_manager: WorkspaceManager | None = None
|
||||
|
||||
|
||||
def create_app(project_root: Path | None = None) -> FastAPI:
|
||||
"""Create the agent control-plane app."""
|
||||
resolved_project_root = project_root or Path(__file__).resolve().parents[2]
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Initialize workspace and registry state for the control plane."""
|
||||
global agent_factory, workspace_manager
|
||||
|
||||
workspace_manager = WorkspaceManager(project_root=resolved_project_root)
|
||||
agent_factory = AgentFactory(project_root=resolved_project_root)
|
||||
agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
registry = get_registry()
|
||||
print("✓ 大时代 API started")
|
||||
print(f" - Workspaces root: {agent_factory.workspaces_root}")
|
||||
print(f" - Registered agents: {registry.get_agent_count()}")
|
||||
|
||||
yield
|
||||
|
||||
print("✓ 大时代 API shutting down")
|
||||
|
||||
app = FastAPI(
|
||||
title="大时代 Agent Service",
|
||||
description="REST API for the 大时代 multi-agent control plane",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, object]:
|
||||
"""Health check endpoint."""
|
||||
registry = get_registry()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"version": "0.1.0",
|
||||
"agents_registered": registry.get_agent_count(),
|
||||
"workspaces_available": (
|
||||
len(workspace_manager.list_workspaces())
|
||||
if workspace_manager
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
@app.get("/api/status")
|
||||
async def api_status() -> dict[str, object]:
|
||||
"""Get API status and registry information."""
|
||||
registry = get_registry()
|
||||
return {
|
||||
"status": "operational",
|
||||
"registry": registry.get_stats(),
|
||||
}
|
||||
|
||||
app.include_router(workspaces_router)
|
||||
app.include_router(agents_router)
|
||||
app.include_router(guard_router)
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
30
backend/apps/cors.py
Normal file
30
backend/apps/cors.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Shared CORS configuration for all microservice apps."""
|
||||
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
def get_cors_origins() -> Sequence[str]:
|
||||
"""Get allowed CORS origins from environment variable.
|
||||
|
||||
Defaults to ["*"] for backward compatibility.
|
||||
Set CORS_ALLOWED_ORIGINS env var (comma-separated) in production.
|
||||
"""
|
||||
origins = os.getenv("CORS_ALLOWED_ORIGINS", "").strip()
|
||||
if not origins:
|
||||
return ["*"]
|
||||
return [o.strip() for o in origins.split(",") if o.strip()]
|
||||
|
||||
|
||||
def add_cors_middleware(app: "FastAPI") -> None:
|
||||
"""Add CORS middleware to app with environment-configured origins."""
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
154
backend/apps/news_service.py
Normal file
154
backend/apps/news_service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""News and explain FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, Query
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.domains import news as news_domain
|
||||
|
||||
|
||||
def get_market_store() -> MarketStore:
|
||||
"""Get the MarketStore singleton dependency."""
|
||||
return MarketStore.get_instance()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create the news/explain service app."""
|
||||
app = FastAPI(
|
||||
title="大时代 News Service",
|
||||
description="Read-only news enrichment and explain service surface extracted from the monolith",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
return {"status": "healthy", "service": "news-service"}
|
||||
|
||||
@app.get("/api/enriched-news")
|
||||
async def api_get_enriched_news(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str | None = Query(None),
|
||||
end_date: str | None = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_enriched_news(
|
||||
store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/news-for-date")
|
||||
async def api_get_news_for_date(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
date: str = Query(...),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_news_for_date(
|
||||
store,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/news-timeline")
|
||||
async def api_get_news_timeline(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str = Query(...),
|
||||
end_date: str = Query(...),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_news_timeline(
|
||||
store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/categories")
|
||||
async def api_get_categories(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str | None = Query(None),
|
||||
end_date: str | None = Query(None),
|
||||
limit: int = Query(200, ge=1, le=1000),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_news_categories(
|
||||
store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/similar-days")
|
||||
async def api_get_similar_days(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
date: str = Query(...),
|
||||
n_similar: int = Query(5, ge=1, le=20),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_similar_days_payload(
|
||||
store,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
n_similar=n_similar,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/stories/{ticker}")
|
||||
async def api_get_story(
|
||||
ticker: str,
|
||||
as_of_date: str = Query(...),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_story_payload(
|
||||
store,
|
||||
ticker=ticker,
|
||||
as_of_date=as_of_date,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/range-explain")
|
||||
async def api_get_range_explain(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str = Query(...),
|
||||
end_date: str = Query(...),
|
||||
article_ids: list[str] = Query(default=[]),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_range_explain_payload(
|
||||
store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
article_ids=article_ids,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8002)
|
||||
49
backend/apps/openclaw_service.py
Normal file
49
backend/apps/openclaw_service.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Read-only OpenClaw CLI FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
|
||||
from backend.api import openclaw_router
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
from backend.api.openclaw import get_openclaw_cli_service
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create the OpenClaw service app."""
|
||||
app = FastAPI(
|
||||
title="大时代 OpenClaw Service",
|
||||
description="Read-only OpenClaw CLI integration service surface",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check(
|
||||
service=Depends(get_openclaw_cli_service),
|
||||
) -> dict[str, object]:
|
||||
return service.health()
|
||||
|
||||
@app.get("/api/status")
|
||||
async def api_status(
|
||||
service=Depends(get_openclaw_cli_service),
|
||||
) -> dict[str, object]:
|
||||
return {
|
||||
"status": "operational",
|
||||
"service": "openclaw-service",
|
||||
"openclaw": service.health(),
|
||||
}
|
||||
|
||||
app.include_router(openclaw_router)
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8004)
|
||||
62
backend/apps/runtime_service.py
Normal file
62
backend/apps/runtime_service.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Dedicated runtime service FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api import runtime_router
|
||||
from backend.api.runtime import get_runtime_state
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create the runtime service app."""
|
||||
app = FastAPI(
|
||||
title="大时代 Runtime Service",
|
||||
description="Runtime lifecycle and gateway service surface extracted from the monolith",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, object]:
|
||||
"""Health check for the runtime service."""
|
||||
runtime_state = get_runtime_state()
|
||||
process = runtime_state.gateway_process
|
||||
is_running = process is not None and process.poll() is None
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "runtime-service",
|
||||
"gateway_running": is_running,
|
||||
"gateway_port": runtime_state.gateway_port,
|
||||
}
|
||||
|
||||
@app.get("/api/status")
|
||||
async def api_status() -> dict[str, object]:
|
||||
"""Service-level status payload for runtime orchestration."""
|
||||
runtime_state = get_runtime_state()
|
||||
process = runtime_state.gateway_process
|
||||
is_running = process is not None and process.poll() is None
|
||||
return {
|
||||
"status": "operational",
|
||||
"service": "runtime-service",
|
||||
"runtime": {
|
||||
"gateway_running": is_running,
|
||||
"gateway_port": runtime_state.gateway_port,
|
||||
"has_runtime_manager": runtime_state.runtime_manager is not None,
|
||||
},
|
||||
}
|
||||
|
||||
app.include_router(runtime_router)
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8003)
|
||||
136
backend/apps/trading_service.py
Normal file
136
backend/apps/trading_service.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Trading data FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Query
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
|
||||
from backend.domains import trading as trading_domain
|
||||
from shared.schema import (
|
||||
CompanyNewsResponse,
|
||||
FinancialMetricsResponse,
|
||||
InsiderTradeResponse,
|
||||
LineItemResponse,
|
||||
PriceResponse,
|
||||
)
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create the trading data service app."""
|
||||
app = FastAPI(
|
||||
title="大时代 Trading Service",
|
||||
description="Read-only trading data service surface extracted from the monolith",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "service": "trading-service"}
|
||||
|
||||
@app.get("/api/prices", response_model=PriceResponse)
|
||||
async def api_get_prices(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str = Query(...),
|
||||
end_date: str = Query(...),
|
||||
) -> PriceResponse:
|
||||
payload = trading_domain.get_prices_payload(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
|
||||
|
||||
@app.get("/api/financials", response_model=FinancialMetricsResponse)
|
||||
async def api_get_financials(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
end_date: str = Query(...),
|
||||
period: str = Query("ttm"),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
) -> FinancialMetricsResponse:
|
||||
payload = trading_domain.get_financials_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
|
||||
|
||||
@app.get("/api/news", response_model=CompanyNewsResponse)
|
||||
async def api_get_news(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
end_date: str = Query(...),
|
||||
start_date: str | None = Query(None),
|
||||
limit: int = Query(1000, ge=1, le=5000),
|
||||
) -> CompanyNewsResponse:
|
||||
payload = trading_domain.get_news_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
return CompanyNewsResponse(news=payload["news"])
|
||||
|
||||
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
|
||||
async def api_get_insider_trades(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
end_date: str = Query(...),
|
||||
start_date: str | None = Query(None),
|
||||
limit: int = Query(1000, ge=1, le=5000),
|
||||
) -> InsiderTradeResponse:
|
||||
payload = trading_domain.get_insider_trades_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
|
||||
|
||||
@app.get("/api/market/status")
|
||||
async def api_get_market_status() -> dict[str, Any]:
|
||||
"""Return current market status using the existing market service logic."""
|
||||
return trading_domain.get_market_status_payload()
|
||||
|
||||
@app.get("/api/market-cap")
|
||||
async def api_get_market_cap(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
end_date: str = Query(...),
|
||||
) -> dict[str, Any]:
|
||||
"""Return market cap for one ticker/date."""
|
||||
return trading_domain.get_market_cap_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
@app.get("/api/line-items", response_model=LineItemResponse)
|
||||
async def api_get_line_items(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
line_items: list[str] = Query(...),
|
||||
end_date: str = Query(...),
|
||||
period: str = Query("ttm"),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
) -> LineItemResponse:
|
||||
payload = trading_domain.get_line_items_payload(
|
||||
ticker=ticker,
|
||||
line_items=line_items,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
return LineItemResponse(search_results=payload["search_results"])
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
740
backend/cli.py
740
backend/cli.py
@@ -1,38 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
EvoTraders CLI - Command-line interface for the EvoTraders trading system.
|
||||
大时代 CLI - Command-line interface for the 大时代 trading system.
|
||||
|
||||
This module provides easy-to-use commands for running backtest, live trading,
|
||||
and frontend development server.
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=R0912, R0915
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import typer
|
||||
import yaml
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
from rich.table import Table
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.team_pipeline_config import (
|
||||
ensure_team_pipeline_config,
|
||||
load_team_pipeline_config,
|
||||
)
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.data.market_ingest import ingest_symbols
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.enrich.llm_enricher import get_explain_model_info, llm_enrichment_enabled
|
||||
from backend.enrich.news_enricher import enrich_symbols
|
||||
|
||||
app = typer.Typer(
|
||||
name="evotraders",
|
||||
help="EvoTraders: A self-evolving multi-agent trading system",
|
||||
help="大时代:自进化多智能体交易系统",
|
||||
add_completion=False,
|
||||
)
|
||||
ingest_app = typer.Typer(help="Ingest Polygon market data into the research warehouse.")
|
||||
app.add_typer(ingest_app, name="ingest")
|
||||
skills_app = typer.Typer(help="Inspect and manage per-agent skills.")
|
||||
app.add_typer(skills_app, name="skills")
|
||||
team_app = typer.Typer(help="Inspect and manage run-scoped team pipeline config.")
|
||||
app.add_typer(team_app, name="team")
|
||||
|
||||
console = Console()
|
||||
_prompt_loader = PromptLoader()
|
||||
_prompt_loader = get_prompt_loader()
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def _normalize_typer_value(value, default):
|
||||
"""Allow CLI command functions to be called directly in tests/internal code."""
|
||||
if hasattr(value, "default"):
|
||||
return value.default
|
||||
return default if value is None else value
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
@@ -75,8 +106,8 @@ def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None:
|
||||
)
|
||||
else:
|
||||
console.print(f" Directory size: [cyan]{size_mb:.1f} MB[/cyan]")
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not calculate directory size: {e}")
|
||||
|
||||
# Show last modified time
|
||||
state_dir = base_data_dir / "state"
|
||||
@@ -177,7 +208,8 @@ def run_data_updater(project_root: Path) -> None:
|
||||
console.print(
|
||||
"[yellow] Data updater module not available, skipping update[/yellow]\n",
|
||||
)
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.debug(f"Data updater check failed: {e}")
|
||||
console.print(
|
||||
"[yellow] Data updater check failed, skipping update[/yellow]\n",
|
||||
)
|
||||
@@ -204,6 +236,202 @@ def initialize_workspace(config_name: str) -> Path:
|
||||
return workspace_manager.get_run_dir(config_name)
|
||||
|
||||
|
||||
def _require_agent_asset_dir(config_name: str, agent_id: str) -> Path:
|
||||
manager = WorkspaceManager(project_root=get_project_root())
|
||||
manager.initialize_default_assets(
|
||||
config_name=config_name,
|
||||
agent_ids=[agent_id],
|
||||
analyst_personas=_prompt_loader.load_yaml_config(
|
||||
"analyst",
|
||||
"personas",
|
||||
),
|
||||
)
|
||||
return manager.skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||
|
||||
|
||||
def _resolve_symbols(raw_tickers: Optional[str], config_name: Optional[str] = None) -> list[str]:
|
||||
"""Resolve symbols from explicit input or runtime bootstrap config."""
|
||||
if raw_tickers and raw_tickers.strip():
|
||||
return [
|
||||
item.strip().upper()
|
||||
for item in raw_tickers.split(",")
|
||||
if item.strip()
|
||||
]
|
||||
|
||||
workspace_manager = WorkspaceManager(project_root=get_project_root())
|
||||
bootstrap_path = workspace_manager.get_run_dir(config_name or "default") / "BOOTSTRAP.md"
|
||||
if bootstrap_path.exists():
|
||||
content = bootstrap_path.read_text(encoding="utf-8")
|
||||
for line in content.splitlines():
|
||||
if line.strip().startswith("tickers:"):
|
||||
raw = line.split(":", 1)[1]
|
||||
return [
|
||||
item.strip().upper()
|
||||
for item in raw.split(",")
|
||||
if item.strip()
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
def _filter_problematic_report_rows(rows: list[dict]) -> list[dict]:
|
||||
"""Keep tickers with incomplete coverage or without any LLM-enriched rows."""
|
||||
return [
|
||||
row
|
||||
for row in rows
|
||||
if float(row.get("coverage_pct") or 0.0) < 100.0
|
||||
or int(row.get("llm_count") or 0) == 0
|
||||
]
|
||||
|
||||
|
||||
def auto_update_market_store(
|
||||
config_name: str,
|
||||
*,
|
||||
end_date: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Refresh the long-lived Polygon market store for the active watchlist."""
|
||||
api_key = os.getenv("POLYGON_API_KEY", "").strip()
|
||||
if not api_key:
|
||||
console.print(
|
||||
"[dim]Skipping Polygon market store update: POLYGON_API_KEY not set[/dim]",
|
||||
)
|
||||
return
|
||||
|
||||
symbols = _resolve_symbols(None, config_name)
|
||||
if not symbols:
|
||||
console.print(
|
||||
f"[dim]Skipping Polygon market store update: no tickers found for config '{config_name}'[/dim]",
|
||||
)
|
||||
return
|
||||
|
||||
target_end = end_date or datetime.now().date().isoformat()
|
||||
console.print(
|
||||
f"[cyan]Updating Polygon market store for {', '.join(symbols)} -> {target_end}[/cyan]",
|
||||
)
|
||||
|
||||
try:
|
||||
results = ingest_symbols(
|
||||
symbols,
|
||||
mode="incremental",
|
||||
end_date=target_end,
|
||||
)
|
||||
except Exception as exc:
|
||||
console.print(
|
||||
f"[yellow]Polygon market store update failed, continuing startup: {exc}[/yellow]",
|
||||
)
|
||||
return
|
||||
|
||||
for result in results:
|
||||
console.print(
|
||||
"[green]"
|
||||
f"{result['symbol']}"
|
||||
"[/green] "
|
||||
f"prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||
)
|
||||
|
||||
|
||||
def auto_prepare_backtest_market_store(
|
||||
config_name: str,
|
||||
*,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> None:
|
||||
"""Ensure the market store has the requested backtest window for the active watchlist."""
|
||||
api_key = os.getenv("POLYGON_API_KEY", "").strip()
|
||||
if not api_key:
|
||||
console.print(
|
||||
"[dim]Skipping Polygon backtest preload: POLYGON_API_KEY not set[/dim]",
|
||||
)
|
||||
return
|
||||
|
||||
symbols = _resolve_symbols(None, config_name)
|
||||
if not symbols:
|
||||
console.print(
|
||||
f"[dim]Skipping Polygon backtest preload: no tickers found for config '{config_name}'[/dim]",
|
||||
)
|
||||
return
|
||||
|
||||
console.print(
|
||||
f"[cyan]Preparing Polygon market store for backtest {start_date} -> {end_date} "
|
||||
f"({', '.join(symbols)})[/cyan]",
|
||||
)
|
||||
|
||||
try:
|
||||
results = ingest_symbols(
|
||||
symbols,
|
||||
mode="full",
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
except Exception as exc:
|
||||
console.print(
|
||||
f"[yellow]Polygon backtest preload failed, continuing startup: {exc}[/yellow]",
|
||||
)
|
||||
return
|
||||
|
||||
for result in results:
|
||||
console.print(
|
||||
"[green]"
|
||||
f"{result['symbol']}"
|
||||
"[/green] "
|
||||
f"prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||
)
|
||||
|
||||
|
||||
def auto_enrich_market_store(
|
||||
config_name: str,
|
||||
*,
|
||||
end_date: Optional[str] = None,
|
||||
lookback_days: int = 120,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Refresh explain-oriented enriched news for the active watchlist."""
|
||||
symbols = _resolve_symbols(None, config_name)
|
||||
if not symbols:
|
||||
console.print(
|
||||
f"[dim]Skipping explain enrich: no tickers found for config '{config_name}'[/dim]",
|
||||
)
|
||||
return
|
||||
|
||||
target_end = end_date or datetime.now().date().isoformat()
|
||||
try:
|
||||
end_dt = datetime.strptime(target_end, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
console.print(
|
||||
f"[yellow]Skipping explain enrich: invalid end date {target_end}[/yellow]",
|
||||
)
|
||||
return
|
||||
|
||||
start_date = (end_dt - timedelta(days=max(1, lookback_days))).date().isoformat()
|
||||
console.print(
|
||||
f"[cyan]Refreshing explain enrich for {', '.join(symbols)} -> {target_end}[/cyan]",
|
||||
)
|
||||
store = MarketStore()
|
||||
try:
|
||||
results = enrich_symbols(
|
||||
store,
|
||||
symbols,
|
||||
start_date=start_date,
|
||||
end_date=target_end,
|
||||
limit=300,
|
||||
skip_existing=not force,
|
||||
)
|
||||
except Exception as exc:
|
||||
console.print(
|
||||
f"[yellow]Explain enrich failed, continuing startup: {exc}[/yellow]",
|
||||
)
|
||||
return
|
||||
|
||||
for result in results:
|
||||
console.print(
|
||||
"[green]"
|
||||
f"{result['symbol']}"
|
||||
"[/green] "
|
||||
f"news={result['news_count']} queued={result['queued_count']} analyzed={result['analyzed']} "
|
||||
f"skipped={result['skipped_existing_count']} deduped={result['deduped_count']} "
|
||||
f"llm={result['llm_count']} local={result['local_count']}"
|
||||
)
|
||||
|
||||
|
||||
@app.command("init-workspace")
|
||||
def init_workspace(
|
||||
config_name: str = typer.Option(
|
||||
@@ -223,6 +451,416 @@ def init_workspace(
|
||||
)
|
||||
|
||||
|
||||
@ingest_app.command("full")
|
||||
def ingest_full(
|
||||
tickers: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--tickers",
|
||||
"-t",
|
||||
help="Comma-separated tickers to ingest",
|
||||
),
|
||||
start: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--start",
|
||||
help="Start date for full ingestion (YYYY-MM-DD)",
|
||||
),
|
||||
end: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--end",
|
||||
help="End date for ingestion (YYYY-MM-DD)",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||
),
|
||||
):
|
||||
"""Run full Polygon ingestion for the specified symbols."""
|
||||
symbols = _resolve_symbols(tickers, config_name)
|
||||
if not symbols:
|
||||
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[cyan]Starting full Polygon ingest for {', '.join(symbols)}[/cyan]")
|
||||
results = ingest_symbols(symbols, mode="full", start_date=start, end_date=end)
|
||||
for result in results:
|
||||
console.print(
|
||||
f"[green]{result['symbol']}[/green] prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||
)
|
||||
|
||||
|
||||
@ingest_app.command("update")
|
||||
def ingest_update(
|
||||
tickers: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--tickers",
|
||||
"-t",
|
||||
help="Comma-separated tickers to update",
|
||||
),
|
||||
end: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--end",
|
||||
help="Optional end date override (YYYY-MM-DD)",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||
),
|
||||
):
|
||||
"""Run incremental Polygon ingestion using stored watermarks."""
|
||||
symbols = _resolve_symbols(tickers, config_name)
|
||||
if not symbols:
|
||||
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[cyan]Starting incremental Polygon ingest for {', '.join(symbols)}[/cyan]")
|
||||
results = ingest_symbols(symbols, mode="incremental", end_date=end)
|
||||
for result in results:
|
||||
console.print(
|
||||
f"[green]{result['symbol']}[/green] prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||
)
|
||||
|
||||
|
||||
@ingest_app.command("enrich")
|
||||
def ingest_enrich(
|
||||
tickers: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--tickers",
|
||||
"-t",
|
||||
help="Comma-separated tickers to enrich",
|
||||
),
|
||||
start: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--start",
|
||||
help="Optional start date for enrichment window (YYYY-MM-DD)",
|
||||
),
|
||||
end: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--end",
|
||||
help="Optional end date for enrichment window (YYYY-MM-DD)",
|
||||
),
|
||||
limit: int = typer.Option(
|
||||
300,
|
||||
"--limit",
|
||||
help="Maximum raw news rows per ticker to analyze",
|
||||
),
|
||||
force: bool = typer.Option(
|
||||
False,
|
||||
"--force",
|
||||
help="Re-analyze already enriched news instead of only missing rows",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||
),
|
||||
):
|
||||
"""Run explain-oriented news enrichment for symbols already in the market store."""
|
||||
symbols = _resolve_symbols(tickers, config_name)
|
||||
if not symbols:
|
||||
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
console.print(f"[cyan]Starting explain enrich for {', '.join(symbols)}[/cyan]")
|
||||
store = MarketStore()
|
||||
results = enrich_symbols(
|
||||
store,
|
||||
symbols,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
limit=max(10, limit),
|
||||
skip_existing=not force,
|
||||
)
|
||||
for result in results:
|
||||
console.print(
|
||||
f"[green]{result['symbol']}[/green] "
|
||||
f"news={result['news_count']} queued={result['queued_count']} analyzed={result['analyzed']} "
|
||||
f"skipped={result['skipped_existing_count']} deduped={result['deduped_count']} "
|
||||
f"llm={result['llm_count']} local={result['local_count']}"
|
||||
)
|
||||
|
||||
|
||||
@ingest_app.command("report")
|
||||
def ingest_report(
|
||||
tickers: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--tickers",
|
||||
"-t",
|
||||
help="Optional comma-separated tickers to report",
|
||||
),
|
||||
start: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--start",
|
||||
help="Optional start date for report window (YYYY-MM-DD)",
|
||||
),
|
||||
end: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--end",
|
||||
help="Optional end date for report window (YYYY-MM-DD)",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||
),
|
||||
only_problematic: bool = typer.Option(
|
||||
False,
|
||||
"--only-problematic",
|
||||
help="Only show tickers with incomplete coverage or no LLM-enriched news",
|
||||
),
|
||||
):
|
||||
"""Show explain enrichment coverage and freshness per ticker."""
|
||||
symbols = _resolve_symbols(tickers, config_name)
|
||||
store = MarketStore()
|
||||
report_rows = store.get_enrich_report(
|
||||
symbols=symbols or None,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
)
|
||||
if only_problematic:
|
||||
report_rows = _filter_problematic_report_rows(report_rows)
|
||||
if not report_rows:
|
||||
if only_problematic:
|
||||
console.print("[green]No problematic enrich report rows found for the requested scope[/green]")
|
||||
else:
|
||||
console.print("[yellow]No enrich report rows found for the requested scope[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
model_info = get_explain_model_info()
|
||||
model_label = model_info["label"] if llm_enrichment_enabled() else "disabled"
|
||||
table = Table(title="Explain Enrichment Report")
|
||||
table.add_column("Ticker", style="cyan")
|
||||
table.add_column("Raw News", justify="right")
|
||||
table.add_column("Analyzed", justify="right")
|
||||
table.add_column("Coverage", justify="right")
|
||||
table.add_column("LLM", justify="right")
|
||||
table.add_column("Local", justify="right")
|
||||
table.add_column("Latest Trade Date")
|
||||
table.add_column("Latest Analysis")
|
||||
table.caption = f"Explain LLM: {model_label}"
|
||||
|
||||
for row in report_rows:
|
||||
table.add_row(
|
||||
row["symbol"],
|
||||
str(row["raw_news_count"]),
|
||||
str(row["analyzed_news_count"]),
|
||||
f'{row["coverage_pct"]:.1f}%',
|
||||
str(row["llm_count"]),
|
||||
str(row["local_count"]),
|
||||
str(row["latest_trade_date"] or "-"),
|
||||
str(row["latest_analysis_at"] or "-"),
|
||||
)
|
||||
console.print(table)
|
||||
|
||||
|
||||
@skills_app.command("list")
|
||||
def skills_list(
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Run config name.",
|
||||
),
|
||||
agent_id: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--agent-id",
|
||||
"-a",
|
||||
help="Optional agent id to show resolved status for.",
|
||||
),
|
||||
):
|
||||
"""List available skills and optional agent-level enablement state."""
|
||||
project_root = get_project_root()
|
||||
skills_manager = SkillsManager(project_root=project_root)
|
||||
catalog = (
|
||||
skills_manager.list_agent_skill_catalog(config_name, agent_id)
|
||||
if agent_id
|
||||
else skills_manager.list_skill_catalog()
|
||||
)
|
||||
if not catalog:
|
||||
console.print("[yellow]No skills found[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
agent_config = None
|
||||
resolved_skills = set()
|
||||
if agent_id:
|
||||
asset_dir = _require_agent_asset_dir(config_name, agent_id)
|
||||
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
|
||||
resolved_skills = set(
|
||||
skills_manager.resolve_agent_skill_names(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
default_skills=[],
|
||||
),
|
||||
)
|
||||
|
||||
table = Table(title="Skill Catalog")
|
||||
table.add_column("Skill", style="cyan")
|
||||
table.add_column("Source")
|
||||
table.add_column("Description")
|
||||
if agent_id:
|
||||
table.add_column("Status")
|
||||
|
||||
enabled = set(agent_config.enabled_skills) if agent_config else set()
|
||||
disabled = set(agent_config.disabled_skills) if agent_config else set()
|
||||
for skill in catalog:
|
||||
row = [
|
||||
skill.skill_name,
|
||||
skill.source,
|
||||
skill.description or "-",
|
||||
]
|
||||
if agent_id:
|
||||
if skill.skill_name in disabled:
|
||||
status = "disabled"
|
||||
elif skill.skill_name in enabled:
|
||||
status = "enabled"
|
||||
elif skill.skill_name in resolved_skills:
|
||||
status = "active"
|
||||
else:
|
||||
status = "-"
|
||||
row.append(status)
|
||||
table.add_row(*row)
|
||||
console.print(table)
|
||||
|
||||
|
||||
@skills_app.command("enable")
|
||||
def skills_enable(
|
||||
agent_id: str = typer.Option(..., "--agent-id", "-a", help="Agent id."),
|
||||
skill: str = typer.Option(..., "--skill", "-s", help="Skill name."),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Run config name.",
|
||||
),
|
||||
):
|
||||
"""Enable a skill for one agent in agent.yaml."""
|
||||
asset_dir = _require_agent_asset_dir(config_name, agent_id)
|
||||
skills_manager = SkillsManager(project_root=get_project_root())
|
||||
catalog = {
|
||||
item.skill_name
|
||||
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
|
||||
}
|
||||
if skill not in catalog:
|
||||
console.print(f"[red]Unknown skill: {skill}[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
result = skills_manager.update_agent_skill_overrides(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
enable=[skill],
|
||||
)
|
||||
console.print(
|
||||
f"[green]Enabled[/green] `{skill}` for `{agent_id}` "
|
||||
f"([{asset_dir / 'agent.yaml'}])",
|
||||
)
|
||||
console.print(f"Enabled skills: {', '.join(result['enabled_skills']) or '-'}")
|
||||
console.print(f"Disabled skills: {', '.join(result['disabled_skills']) or '-'}")
|
||||
|
||||
|
||||
@skills_app.command("disable")
|
||||
def skills_disable(
|
||||
agent_id: str = typer.Option(..., "--agent-id", "-a", help="Agent id."),
|
||||
skill: str = typer.Option(..., "--skill", "-s", help="Skill name."),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Run config name.",
|
||||
),
|
||||
):
|
||||
"""Disable a skill for one agent in agent.yaml."""
|
||||
asset_dir = _require_agent_asset_dir(config_name, agent_id)
|
||||
skills_manager = SkillsManager(project_root=get_project_root())
|
||||
result = skills_manager.update_agent_skill_overrides(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
disable=[skill],
|
||||
)
|
||||
console.print(
|
||||
f"[yellow]Disabled[/yellow] `{skill}` for `{agent_id}` "
|
||||
f"([{asset_dir / 'agent.yaml'}])",
|
||||
)
|
||||
console.print(f"Enabled skills: {', '.join(result['enabled_skills']) or '-'}")
|
||||
console.print(f"Disabled skills: {', '.join(result['disabled_skills']) or '-'}")
|
||||
|
||||
|
||||
@skills_app.command("install")
|
||||
def skills_install(
|
||||
agent_id: str = typer.Option(..., "--agent-id", "-a", help="Target agent id."),
|
||||
source: str = typer.Option(
|
||||
...,
|
||||
"--source",
|
||||
"-s",
|
||||
help="External skill source: directory path, zip path, or http(s) zip URL.",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Run config name.",
|
||||
),
|
||||
name: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--name",
|
||||
help="Optional override skill name.",
|
||||
),
|
||||
activate: bool = typer.Option(
|
||||
True,
|
||||
"--activate/--no-activate",
|
||||
help="Enable the skill for this agent immediately.",
|
||||
),
|
||||
):
|
||||
"""Install an external skill into one agent's local skill directory."""
|
||||
_require_agent_asset_dir(config_name, agent_id)
|
||||
skills_manager = SkillsManager(project_root=get_project_root())
|
||||
result = skills_manager.install_external_skill_for_agent(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
source=source,
|
||||
skill_name=name,
|
||||
activate=activate,
|
||||
)
|
||||
console.print(
|
||||
f"[green]Installed[/green] `{result['skill_name']}` to `{agent_id}`",
|
||||
)
|
||||
console.print(f"Path: {result['target_dir']}")
|
||||
console.print(f"Activated: {result['activated']}")
|
||||
warnings = result.get("warnings") or []
|
||||
if warnings:
|
||||
console.print(f"Warnings: {'; '.join(warnings)}")
|
||||
|
||||
|
||||
@team_app.command("show")
|
||||
def team_show(
|
||||
config_name: str = typer.Option(
|
||||
"default",
|
||||
"--config-name",
|
||||
"-c",
|
||||
help="Run config name.",
|
||||
),
|
||||
):
|
||||
"""Show TEAM_PIPELINE.yaml for one run."""
|
||||
project_root = get_project_root()
|
||||
ensure_team_pipeline_config(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
default_analysts=list(ANALYST_TYPES.keys()),
|
||||
)
|
||||
config = load_team_pipeline_config(project_root, config_name)
|
||||
console.print(
|
||||
Panel.fit(
|
||||
yaml.safe_dump(config, allow_unicode=True, sort_keys=False),
|
||||
title=f"TEAM_PIPELINE ({config_name})",
|
||||
border_style="cyan",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
def backtest(
|
||||
start: Optional[str] = typer.Option(
|
||||
@@ -281,10 +919,11 @@ def backtest(
|
||||
"""
|
||||
console.print(
|
||||
Panel.fit(
|
||||
"[bold cyan]EvoTraders Backtest Mode[/bold cyan]",
|
||||
"[bold cyan]大时代 Backtest Mode[/bold cyan]",
|
||||
border_style="cyan",
|
||||
),
|
||||
)
|
||||
poll_interval = int(_normalize_typer_value(poll_interval, 10))
|
||||
|
||||
# Validate dates - required for backtest
|
||||
if not start or not end:
|
||||
@@ -331,6 +970,16 @@ def backtest(
|
||||
|
||||
# Run data updater
|
||||
run_data_updater(project_root)
|
||||
auto_prepare_backtest_market_store(
|
||||
config_name,
|
||||
start_date=start,
|
||||
end_date=end,
|
||||
)
|
||||
auto_enrich_market_store(
|
||||
config_name,
|
||||
end_date=end,
|
||||
force=False,
|
||||
)
|
||||
|
||||
# Build command using backend.main
|
||||
cmd = [
|
||||
@@ -370,11 +1019,6 @@ def backtest(
|
||||
|
||||
@app.command()
|
||||
def live(
|
||||
mock: bool = typer.Option(
|
||||
False,
|
||||
"--mock",
|
||||
help="Use mock mode with simulated prices (for testing)",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"live",
|
||||
"--config-name",
|
||||
@@ -392,12 +1036,22 @@ def live(
|
||||
"-p",
|
||||
help="WebSocket server port",
|
||||
),
|
||||
schedule_mode: str = typer.Option(
|
||||
"daily",
|
||||
"--schedule-mode",
|
||||
help="Scheduler mode: 'daily' or 'intraday'",
|
||||
),
|
||||
trigger_time: str = typer.Option(
|
||||
"now",
|
||||
"--trigger-time",
|
||||
"-t",
|
||||
help="Trigger time in LOCAL timezone (HH:MM), or 'now' to run immediately",
|
||||
),
|
||||
interval_minutes: int = typer.Option(
|
||||
60,
|
||||
"--interval-minutes",
|
||||
help="When schedule-mode=intraday, run every N minutes",
|
||||
),
|
||||
poll_interval: int = typer.Option(
|
||||
10,
|
||||
"--poll-interval",
|
||||
@@ -419,21 +1073,21 @@ def live(
|
||||
|
||||
Example:
|
||||
evotraders live # Run immediately (default)
|
||||
evotraders live --mock # Mock mode
|
||||
evotraders live -t 22:30 # Run at 22:30 local time daily
|
||||
evotraders live --schedule-mode intraday --interval-minutes 60
|
||||
evotraders live --trigger-time now # Run immediately
|
||||
evotraders live --clean # Clear historical data before starting
|
||||
"""
|
||||
mode_name = "MOCK" if mock else "LIVE"
|
||||
schedule_mode = str(_normalize_typer_value(schedule_mode, "daily"))
|
||||
interval_minutes = int(_normalize_typer_value(interval_minutes, 60))
|
||||
console.print(
|
||||
Panel.fit(
|
||||
f"[bold cyan]EvoTraders {mode_name} Mode[/bold cyan]",
|
||||
"[bold cyan]大时代 LIVE Mode[/bold cyan]",
|
||||
border_style="cyan",
|
||||
),
|
||||
)
|
||||
|
||||
# Check for required API key in live mode
|
||||
if not mock:
|
||||
env_file = get_project_root() / ".env"
|
||||
if not env_file.exists():
|
||||
console.print("\n[yellow]Warning: .env file not found[/yellow]")
|
||||
@@ -455,6 +1109,16 @@ def live(
|
||||
# Handle historical data cleanup
|
||||
handle_history_cleanup(config_name, auto_clean=clean)
|
||||
|
||||
if schedule_mode not in {"daily", "intraday"}:
|
||||
console.print(
|
||||
f"[red]Error: unsupported schedule mode '{schedule_mode}'[/red]",
|
||||
)
|
||||
raise typer.Exit(1)
|
||||
|
||||
if interval_minutes <= 0:
|
||||
console.print("[red]Error: --interval-minutes must be > 0[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Convert local time to NYSE time
|
||||
nyse_tz = ZoneInfo("America/New_York")
|
||||
local_tz = datetime.now().astimezone().tzinfo
|
||||
@@ -462,7 +1126,9 @@ def live(
|
||||
nyse_now = datetime.now(nyse_tz)
|
||||
|
||||
# Convert trigger time from local to NYSE
|
||||
if trigger_time.lower() == "now":
|
||||
if schedule_mode == "intraday":
|
||||
nyse_trigger_time = "now"
|
||||
elif trigger_time.lower() == "now":
|
||||
nyse_trigger_time = "now"
|
||||
else:
|
||||
local_trigger = datetime.strptime(trigger_time, "%H:%M")
|
||||
@@ -482,7 +1148,10 @@ def live(
|
||||
console.print(
|
||||
f" NYSE Time: {nyse_now.strftime('%Y-%m-%d %H:%M:%S %Z')}",
|
||||
)
|
||||
if nyse_trigger_time == "now":
|
||||
console.print(f" Schedule: {schedule_mode}")
|
||||
if schedule_mode == "intraday":
|
||||
console.print(f" Interval: every {interval_minutes} minute(s)")
|
||||
elif nyse_trigger_time == "now":
|
||||
console.print(" Trigger: [green]NOW (immediate)[/green]")
|
||||
else:
|
||||
console.print(
|
||||
@@ -491,9 +1160,6 @@ def live(
|
||||
|
||||
# Display configuration
|
||||
console.print("\n[bold]Configuration:[/bold]")
|
||||
if mock:
|
||||
console.print(" Mode: [yellow]MOCK[/yellow] (Simulated prices)")
|
||||
else:
|
||||
console.print(
|
||||
" Mode: [green]LIVE[/green] (Real-time prices via Finnhub)",
|
||||
)
|
||||
@@ -511,12 +1177,16 @@ def live(
|
||||
project_root = get_project_root()
|
||||
os.chdir(project_root)
|
||||
|
||||
# Data update (if not mock mode)
|
||||
if not mock:
|
||||
# Data update
|
||||
run_data_updater(project_root)
|
||||
else:
|
||||
console.print(
|
||||
"\n[dim]Mock mode enabled - skipping data update[/dim]\n",
|
||||
auto_update_market_store(
|
||||
config_name,
|
||||
end_date=nyse_now.date().isoformat(),
|
||||
)
|
||||
auto_enrich_market_store(
|
||||
config_name,
|
||||
end_date=nyse_now.date().isoformat(),
|
||||
force=False,
|
||||
)
|
||||
|
||||
# Build command using backend.main
|
||||
@@ -533,14 +1203,16 @@ def live(
|
||||
host,
|
||||
"--port",
|
||||
str(port),
|
||||
"--schedule-mode",
|
||||
schedule_mode,
|
||||
"--poll-interval",
|
||||
str(poll_interval),
|
||||
"--trigger-time",
|
||||
nyse_trigger_time,
|
||||
"--interval-minutes",
|
||||
str(interval_minutes),
|
||||
]
|
||||
|
||||
if mock:
|
||||
cmd.append("--mock")
|
||||
if enable_memory:
|
||||
cmd.append("--enable-memory")
|
||||
|
||||
@@ -579,7 +1251,7 @@ def frontend(
|
||||
"""
|
||||
console.print(
|
||||
Panel.fit(
|
||||
"[bold cyan]EvoTraders Frontend[/bold cyan]",
|
||||
"[bold cyan]大时代 Frontend[/bold cyan]",
|
||||
border_style="cyan",
|
||||
),
|
||||
)
|
||||
@@ -647,16 +1319,16 @@ def frontend(
|
||||
|
||||
@app.command()
|
||||
def version():
|
||||
"""Show the version of EvoTraders."""
|
||||
"""Show the version of 大时代."""
|
||||
console.print(
|
||||
"\n[bold cyan]EvoTraders[/bold cyan] version [green]0.1.0[/green]\n",
|
||||
"\n[bold cyan]大时代[/bold cyan] version [green]0.1.0[/green]\n",
|
||||
)
|
||||
|
||||
|
||||
@app.callback()
|
||||
def main():
|
||||
"""
|
||||
EvoTraders: A self-evolving multi-agent trading system
|
||||
大时代:自进化多智能体交易系统
|
||||
|
||||
Use 'evotraders --help' to see available commands.
|
||||
"""
|
||||
|
||||
@@ -4,6 +4,22 @@
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
DEFAULT_TICKERS = [
|
||||
"AAPL",
|
||||
"MSFT",
|
||||
"GOOGL",
|
||||
"AMZN",
|
||||
"NVDA",
|
||||
"META",
|
||||
"TSLA",
|
||||
"AMD",
|
||||
"NFLX",
|
||||
"AVGO",
|
||||
"PLTR",
|
||||
"COIN",
|
||||
]
|
||||
import re
|
||||
|
||||
import yaml
|
||||
@@ -119,12 +135,15 @@ def resolve_runtime_config(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
enable_memory: bool = False,
|
||||
schedule_mode: str = "daily",
|
||||
interval_minutes: int = 60,
|
||||
trigger_time: str = "09:30",
|
||||
) -> Dict[str, Any]:
|
||||
"""Merge env defaults with run-scoped bootstrap front matter."""
|
||||
bootstrap = get_bootstrap_config_for_run(project_root, config_name)
|
||||
return {
|
||||
"tickers": bootstrap.get("tickers")
|
||||
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
|
||||
or get_env_list("TICKERS", DEFAULT_TICKERS),
|
||||
"initial_cash": float(
|
||||
bootstrap.get(
|
||||
"initial_cash",
|
||||
@@ -143,6 +162,18 @@ def resolve_runtime_config(
|
||||
get_env_int("MAX_COMM_CYCLES", 2),
|
||||
),
|
||||
),
|
||||
"schedule_mode": str(
|
||||
bootstrap.get("schedule_mode", schedule_mode),
|
||||
).strip().lower() or schedule_mode,
|
||||
"interval_minutes": int(
|
||||
bootstrap.get(
|
||||
"interval_minutes",
|
||||
interval_minutes or get_env_int("INTERVAL_MINUTES", 60),
|
||||
),
|
||||
),
|
||||
"trigger_time": str(
|
||||
bootstrap.get("trigger_time", trigger_time),
|
||||
).strip() or trigger_time,
|
||||
"enable_memory": bool(enable_memory)
|
||||
or _coerce_bool(bootstrap.get("enable_memory", False)),
|
||||
}
|
||||
|
||||
@@ -76,27 +76,19 @@ def _resolve_config() -> DataSourceConfig:
|
||||
"""
|
||||
Resolve data source configuration based on available API keys.
|
||||
|
||||
Priority:
|
||||
1. FINNHUB_API_KEY (if set)
|
||||
2. FINANCIAL_DATASETS_API_KEY (if set)
|
||||
3. Raises error if neither is available
|
||||
The effective source should always match the first item in the resolved
|
||||
ordered source list.
|
||||
"""
|
||||
sources = _ordered_sources()
|
||||
if "finnhub" in sources:
|
||||
return DataSourceConfig(
|
||||
source="finnhub",
|
||||
api_key=os.getenv("FINNHUB_API_KEY", "").strip(),
|
||||
sources=sources,
|
||||
)
|
||||
if "financial_datasets" in sources:
|
||||
return DataSourceConfig(
|
||||
source="financial_datasets",
|
||||
api_key=os.getenv("FINANCIAL_DATASETS_API_KEY", "").strip(),
|
||||
sources=sources,
|
||||
)
|
||||
if "yfinance" in sources:
|
||||
return DataSourceConfig(source="yfinance", api_key="", sources=sources)
|
||||
return DataSourceConfig(source="local_csv", api_key="", sources=sources)
|
||||
source = sources[0] if sources else "local_csv"
|
||||
|
||||
api_key = ""
|
||||
if source == "finnhub":
|
||||
api_key = os.getenv("FINNHUB_API_KEY", "").strip()
|
||||
elif source == "financial_datasets":
|
||||
api_key = os.getenv("FINANCIAL_DATASETS_API_KEY", "").strip()
|
||||
|
||||
return DataSourceConfig(source=source, api_key=api_key, sources=sources)
|
||||
|
||||
|
||||
def get_config() -> DataSourceConfig:
|
||||
|
||||
@@ -1,7 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Core pipeline and orchestration logic"""
|
||||
"""Core pipeline and orchestration logic.
|
||||
|
||||
Keep ``pipeline_runner`` behind lazy wrappers so importing ``backend.core`` does
|
||||
not immediately pull in the gateway runtime graph.
|
||||
"""
|
||||
|
||||
from .pipeline import TradingPipeline
|
||||
from .state_sync import StateSync
|
||||
|
||||
__all__ = ["TradingPipeline", "StateSync"]
|
||||
|
||||
def create_agents(*args, **kwargs):
|
||||
from .pipeline_runner import create_agents as _create_agents
|
||||
|
||||
return _create_agents(*args, **kwargs)
|
||||
|
||||
|
||||
def create_long_term_memory(*args, **kwargs):
|
||||
from .pipeline_runner import create_long_term_memory as _create_long_term_memory
|
||||
|
||||
return _create_long_term_memory(*args, **kwargs)
|
||||
|
||||
|
||||
def stop_gateway(*args, **kwargs):
|
||||
from .pipeline_runner import stop_gateway as _stop_gateway
|
||||
|
||||
return _stop_gateway(*args, **kwargs)
|
||||
|
||||
__all__ = [
|
||||
"TradingPipeline",
|
||||
"StateSync",
|
||||
"create_agents",
|
||||
"create_long_term_memory",
|
||||
"stop_gateway",
|
||||
]
|
||||
|
||||
@@ -10,26 +10,45 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
from agentscope.pipeline import MsgHub
|
||||
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
from backend.utils.terminal_dashboard import get_dashboard
|
||||
from backend.core.state_sync import StateSync
|
||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||
from backend.runtime.manager import TradingRuntimeManager
|
||||
from backend.runtime.session import TradingSessionKey
|
||||
from backend.agents.team_pipeline_config import (
|
||||
resolve_active_analysts,
|
||||
update_active_analysts,
|
||||
)
|
||||
from backend.agents import AnalystAgent
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
|
||||
# Team infrastructure imports (graceful import - may not exist yet)
|
||||
try:
|
||||
from backend.agents.team.team_coordinator import TeamCoordinator
|
||||
from backend.agents.team.msg_hub import MsgHub as TeamMsgHub
|
||||
TEAM_COORD_AVAILABLE = True
|
||||
except ImportError:
|
||||
TEAM_COORD_AVAILABLE = False
|
||||
TeamCoordinator = None
|
||||
TeamMsgHub = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _log(msg: str):
|
||||
"""Log to dashboard if available, otherwise to logger"""
|
||||
dashboard = get_dashboard()
|
||||
if dashboard.live:
|
||||
dashboard.log(msg)
|
||||
else:
|
||||
def _log(msg: str) -> None:
|
||||
"""Helper function for pipeline logging."""
|
||||
logger.info(msg)
|
||||
|
||||
|
||||
@@ -46,6 +65,8 @@ class TradingPipeline:
|
||||
6. Reflection phase: broadcast closing P&L, agents record to long-term memory
|
||||
|
||||
Real-time updates via StateSync after each agent completes.
|
||||
|
||||
Supports both legacy agent lists and run-scoped agent loading.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -56,6 +77,9 @@ class TradingPipeline:
|
||||
state_sync: Optional["StateSync"] = None,
|
||||
settlement_coordinator: Optional[SettlementCoordinator] = None,
|
||||
max_comm_cycles: Optional[int] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
agent_factory: Optional[Any] = None,
|
||||
runtime_manager: Optional[TradingRuntimeManager] = None,
|
||||
):
|
||||
self.analysts = analysts
|
||||
self.risk_manager = risk_manager
|
||||
@@ -66,6 +90,17 @@ class TradingPipeline:
|
||||
os.getenv("MAX_COMM_CYCLES", "2"),
|
||||
)
|
||||
self.conference_summary = None # Store latest conference summary
|
||||
self.workspace_id = workspace_id
|
||||
self.agent_factory = agent_factory
|
||||
self.runtime_manager = runtime_manager
|
||||
self._session_key: Optional[str] = None
|
||||
self._dynamic_analysts: Dict[str, Any] = {}
|
||||
|
||||
if hasattr(self.pm, "set_team_controller"):
|
||||
self.pm.set_team_controller(
|
||||
create_agent_callback=self._create_runtime_analyst,
|
||||
remove_agent_callback=self._remove_runtime_analyst,
|
||||
)
|
||||
|
||||
async def run_cycle(
|
||||
self,
|
||||
@@ -80,6 +115,7 @@ class TradingPipeline:
|
||||
get_close_prices_fn: Optional[
|
||||
Callable[[], Awaitable[Dict[str, float]]]
|
||||
] = None,
|
||||
execute_decisions: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run one complete trading cycle
|
||||
@@ -101,12 +137,19 @@ class TradingPipeline:
|
||||
Each agent's result is broadcast immediately via StateSync.
|
||||
"""
|
||||
_log(f"Starting cycle {date} - {len(tickers)} tickers")
|
||||
session_key = TradingSessionKey(date=date).key()
|
||||
self._session_key = session_key
|
||||
active_analysts = self._get_active_analysts()
|
||||
if self.runtime_manager:
|
||||
self.runtime_manager.set_session_key(session_key)
|
||||
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date})
|
||||
self._runtime_batch_status(active_analysts, "analysis_in_progress")
|
||||
|
||||
# Phase 0: Clear short-term memory to avoid cross-day context pollution
|
||||
_log("Phase 0: Clearing memory")
|
||||
await self._clear_all_agent_memory()
|
||||
|
||||
participants = self.analysts + [self.risk_manager, self.pm]
|
||||
participants = self._all_analysts() + [self.risk_manager, self.pm]
|
||||
|
||||
# Single MsgHub for entire cycle - no nesting
|
||||
async with MsgHub(
|
||||
@@ -117,12 +160,17 @@ class TradingPipeline:
|
||||
"system",
|
||||
),
|
||||
):
|
||||
# Phase 1.1: Analysts
|
||||
_log("Phase 1.1: Analyst analysis")
|
||||
analyst_results = await self._run_analysts_with_sync(tickers, date)
|
||||
# Phase 1.1: Analysts (parallel execution with TeamCoordinator)
|
||||
_log("Phase 1.1: Analyst analysis (parallel)")
|
||||
analyst_results = await self._run_analysts_parallel(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
|
||||
# Phase 1.2: Risk Manager
|
||||
_log("Phase 1.2: Risk assessment")
|
||||
self._runtime_update_status(self.risk_manager, "risk_assessment")
|
||||
risk_assessment = await self._run_risk_manager_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
@@ -145,6 +193,7 @@ class TradingPipeline:
|
||||
final_predictions = await self._collect_final_predictions(
|
||||
tickers,
|
||||
date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
|
||||
# Record final predictions for leaderboard ranking
|
||||
@@ -161,6 +210,7 @@ class TradingPipeline:
|
||||
|
||||
# Phase 3: PM makes decisions
|
||||
_log("Phase 3.1: PM makes decisions")
|
||||
self._runtime_update_status(self.pm, "decision_phase")
|
||||
pm_result = await self._run_pm_with_sync(
|
||||
tickers,
|
||||
date,
|
||||
@@ -169,10 +219,17 @@ class TradingPipeline:
|
||||
risk_assessment,
|
||||
)
|
||||
|
||||
# Phase 4: Execute decisions
|
||||
_log("Phase 4: Executing trades")
|
||||
decisions = pm_result.get("decisions", {})
|
||||
execution_result = {
|
||||
"executed_trades": [],
|
||||
"portfolio": self.pm.get_portfolio_state(),
|
||||
}
|
||||
if execute_decisions:
|
||||
_log("Phase 4: Executing trades")
|
||||
self._runtime_update_status(self.pm, "executing")
|
||||
execution_result = self._execute_decisions(decisions, prices, date)
|
||||
else:
|
||||
_log("Phase 4: Skipping trade execution")
|
||||
|
||||
# Live mode: wait for market close before settlement
|
||||
if get_close_prices_fn:
|
||||
@@ -184,6 +241,10 @@ class TradingPipeline:
|
||||
settlement_result = None
|
||||
if close_prices and self.settlement_coordinator:
|
||||
_log("Phase 5: Daily review and generate memories")
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"settlement",
|
||||
)
|
||||
|
||||
agent_trajectories = await self._capture_agent_trajectories()
|
||||
|
||||
@@ -214,8 +275,17 @@ class TradingPipeline:
|
||||
settlement_result=settlement_result,
|
||||
conference_summary=self.conference_summary,
|
||||
)
|
||||
self._runtime_batch_status(
|
||||
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||
"reflection",
|
||||
)
|
||||
|
||||
_log(f"Cycle complete: {date}")
|
||||
self._runtime_batch_status(
|
||||
self._all_analysts() + [self.risk_manager, self.pm],
|
||||
"idle",
|
||||
)
|
||||
self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date})
|
||||
|
||||
return {
|
||||
"analyst_results": analyst_results,
|
||||
@@ -248,7 +318,7 @@ class TradingPipeline:
|
||||
},
|
||||
)
|
||||
|
||||
for analyst in self.analysts:
|
||||
for analyst in self._all_analysts():
|
||||
analyst.reload_runtime_assets(
|
||||
active_skill_dirs=active_skill_map.get(analyst.name, []),
|
||||
)
|
||||
@@ -262,7 +332,7 @@ class TradingPipeline:
|
||||
|
||||
return {
|
||||
"config_name": config_name,
|
||||
"reloaded_agents": [agent.name for agent in self.analysts]
|
||||
"reloaded_agents": [agent.name for agent in self._all_analysts()]
|
||||
+ ["risk_manager", "portfolio_manager"],
|
||||
"active_skills": {
|
||||
agent_id: [path.name for path in paths]
|
||||
@@ -273,7 +343,7 @@ class TradingPipeline:
|
||||
|
||||
async def _clear_all_agent_memory(self):
|
||||
"""Clear short-term memory for all agents"""
|
||||
for analyst in self.analysts:
|
||||
for analyst in self._all_analysts():
|
||||
await analyst.memory.clear()
|
||||
|
||||
await self.risk_manager.memory.clear()
|
||||
@@ -355,7 +425,7 @@ class TradingPipeline:
|
||||
trajectories = {}
|
||||
|
||||
# Capture analyst trajectories
|
||||
for analyst in self.analysts:
|
||||
for analyst in self._all_analysts():
|
||||
try:
|
||||
msgs = await analyst.memory.get_memory()
|
||||
if msgs:
|
||||
@@ -565,7 +635,7 @@ class TradingPipeline:
|
||||
)
|
||||
|
||||
# Record for analysts
|
||||
for analyst in self.analysts:
|
||||
for analyst in self._all_analysts():
|
||||
if (
|
||||
hasattr(analyst, "long_term_memory")
|
||||
and analyst.long_term_memory is not None
|
||||
@@ -684,7 +754,22 @@ class TradingPipeline:
|
||||
date=date,
|
||||
)
|
||||
|
||||
# Run discussion cycles (no new MsgHub - use parent's)
|
||||
# Conference participants: analysts + PM
|
||||
conference_participants = self._get_active_analysts() + [self.pm]
|
||||
|
||||
# Use TeamMsgHub for conference if available
|
||||
if TEAM_COORD_AVAILABLE and TeamMsgHub is not None:
|
||||
_log(
|
||||
f"Phase 2.1: Conference using TeamMsgHub with "
|
||||
f"{len(conference_participants)} participants"
|
||||
)
|
||||
conference_hub = TeamMsgHub(participants=conference_participants)
|
||||
else:
|
||||
_log("Phase 2.1: Conference using standard MsgHub context")
|
||||
conference_hub = None
|
||||
|
||||
# Run discussion cycles
|
||||
async with conference_hub if conference_hub else nullcontext(None):
|
||||
for cycle in range(self.max_comm_cycles):
|
||||
_log(
|
||||
"Phase 2.1: Conference discussion - "
|
||||
@@ -717,8 +802,8 @@ class TradingPipeline:
|
||||
content=pm_content,
|
||||
)
|
||||
|
||||
# Analysts share perspectives
|
||||
for analyst in self.analysts:
|
||||
# Analysts share perspectives (supports per-round active team updates)
|
||||
for analyst in self._get_active_analysts():
|
||||
analyst_prompt = self._build_analyst_discussion_prompt(
|
||||
cycle=cycle,
|
||||
tickers=tickers,
|
||||
@@ -845,6 +930,7 @@ class TradingPipeline:
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
active_analysts: Optional[List[Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Collect final predictions from all analysts as simple text responses.
|
||||
@@ -852,14 +938,15 @@ class TradingPipeline:
|
||||
"""
|
||||
_log(
|
||||
"Phase 2.2: Analysts generate final structured predictions\n"
|
||||
f" Starting _collect_final_predictions for {len(self.analysts)} analysts",
|
||||
f" Starting _collect_final_predictions for {len(active_analysts or self.analysts)} analysts",
|
||||
)
|
||||
final_predictions = []
|
||||
|
||||
for i, analyst in enumerate(self.analysts):
|
||||
analysts = active_analysts or self.analysts
|
||||
for i, analyst in enumerate(analysts):
|
||||
_log(
|
||||
"Phase 2.2: Analysts generate final structured predictions\n"
|
||||
f" Collecting prediction from analyst {i+1}/{len(self.analysts)}: {analyst.name}",
|
||||
f" Collecting prediction from analyst {i+1}/{len(analysts)}: {analyst.name}",
|
||||
)
|
||||
|
||||
prompt = (
|
||||
@@ -955,11 +1042,13 @@ class TradingPipeline:
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
active_analysts: Optional[List[Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Run all analysts with real-time sync after each completion"""
|
||||
results = []
|
||||
analysts = active_analysts or self.analysts
|
||||
|
||||
for analyst in self.analysts:
|
||||
for analyst in analysts:
|
||||
content = (
|
||||
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||
f"Provide investment signals with confidence scores and reasoning."
|
||||
@@ -989,15 +1078,107 @@ class TradingPipeline:
|
||||
|
||||
return results
|
||||
|
||||
async def _run_analysts_parallel(
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
active_analysts: Optional[List[Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Run all analysts in parallel using TeamCoordinator.
|
||||
|
||||
This method replaces the sequential analyst loop with parallel execution
|
||||
using the TeamCoordinator for orchestration.
|
||||
|
||||
Args:
|
||||
tickers: List of stock tickers to analyze
|
||||
date: Trading date
|
||||
active_analysts: Optional list of analysts to run
|
||||
|
||||
Returns:
|
||||
List of analyst result dictionaries
|
||||
"""
|
||||
analysts = active_analysts or self.analysts
|
||||
|
||||
if not analysts:
|
||||
return []
|
||||
|
||||
if not TEAM_COORD_AVAILABLE:
|
||||
_log("TeamCoordinator not available, falling back to sequential execution")
|
||||
return await self._run_analysts_with_sync(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
active_analysts=active_analysts,
|
||||
)
|
||||
|
||||
_log(
|
||||
f"Phase 1.1: Running {len(analysts)} analysts in parallel "
|
||||
f"[{', '.join(a.name for a in analysts)}]"
|
||||
)
|
||||
|
||||
# Build the analyst prompt
|
||||
content = (
|
||||
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||
f"Provide investment signals with confidence scores and reasoning."
|
||||
)
|
||||
|
||||
# Create coordinator for parallel execution
|
||||
coordinator = TeamCoordinator(
|
||||
participants=analysts,
|
||||
task_content=content,
|
||||
)
|
||||
|
||||
# Run analysts in parallel via TeamCoordinator
|
||||
results = await coordinator.run_phase(
|
||||
"analyst_analysis",
|
||||
metadata={"tickers": tickers, "date": date},
|
||||
)
|
||||
|
||||
# Process results and sync
|
||||
processed_results = []
|
||||
for i, (analyst, result) in enumerate(zip(analysts, results)):
|
||||
if result is not None:
|
||||
extracted = self._extract_result_from_msg(result)
|
||||
processed_results.append(extracted)
|
||||
|
||||
# Sync retrieved memory
|
||||
await self._sync_memory_if_retrieved(analyst)
|
||||
|
||||
# Broadcast agent result via StateSync
|
||||
if self.state_sync:
|
||||
text_content = self._extract_text_content(result.content)
|
||||
await self.state_sync.on_agent_complete(
|
||||
agent_id=analyst.name,
|
||||
content=text_content,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Analyst %s returned no result",
|
||||
analyst.name,
|
||||
)
|
||||
processed_results.append({
|
||||
"agent": analyst.name,
|
||||
"content": "",
|
||||
"success": False,
|
||||
})
|
||||
|
||||
_log(
|
||||
f"Phase 1.1: Parallel analyst execution complete "
|
||||
f"({len(processed_results)}/{len(analysts)} successful)"
|
||||
)
|
||||
|
||||
return processed_results
|
||||
|
||||
async def _run_analysts(
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
active_analysts: Optional[List[Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Run all analysts (without sync, for backward compatibility)"""
|
||||
results = []
|
||||
analysts = active_analysts or self.analysts
|
||||
|
||||
for analyst in self.analysts:
|
||||
for analyst in analysts:
|
||||
content = (
|
||||
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||
f"Provide investment signals with confidence scores and reasoning."
|
||||
@@ -1306,3 +1487,198 @@ class TradingPipeline:
|
||||
if decision_texts:
|
||||
return "Decisions: " + "; ".join(decision_texts)
|
||||
return "Portfolio analysis completed. No trades recommended."
|
||||
|
||||
def load_agents_from_workspace(
|
||||
self,
|
||||
workspace_id: str,
|
||||
agent_factory: Optional[Any] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load agents from workspace using AgentFactory.
|
||||
|
||||
This method supports the new EvoAgent architecture by loading
|
||||
agents from a workspace instead of using hardcoded agents.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace identifier
|
||||
agent_factory: Optional AgentFactory instance (uses self.agent_factory if None)
|
||||
|
||||
Returns:
|
||||
Dictionary with loaded agents:
|
||||
{
|
||||
"analysts": List[EvoAgent],
|
||||
"risk_manager": EvoAgent,
|
||||
"portfolio_manager": EvoAgent,
|
||||
}
|
||||
|
||||
Raises:
|
||||
ValueError: If workspace doesn't exist or no agents found
|
||||
"""
|
||||
factory = agent_factory or self.agent_factory
|
||||
if factory is None:
|
||||
from backend.agents import AgentFactory
|
||||
factory = AgentFactory()
|
||||
|
||||
# Check workspace exists
|
||||
if not factory.workspaces_root.exists():
|
||||
raise ValueError(f"Workspaces root does not exist: {factory.workspaces_root}")
|
||||
|
||||
workspace_dir = factory.workspaces_root / workspace_id
|
||||
if not workspace_dir.exists():
|
||||
raise ValueError(f"Workspace '{workspace_id}' does not exist")
|
||||
|
||||
# Load agents from workspace
|
||||
agents_data = factory.list_agents(workspace_id=workspace_id)
|
||||
|
||||
if not agents_data:
|
||||
raise ValueError(f"No agents found in workspace '{workspace_id}'")
|
||||
|
||||
# Categorize agents by type
|
||||
analysts = []
|
||||
risk_manager = None
|
||||
portfolio_manager = None
|
||||
|
||||
for agent_data in agents_data:
|
||||
agent_type = agent_data.get("agent_type", "unknown")
|
||||
agent_id = agent_data.get("agent_id")
|
||||
|
||||
# Load full agent configuration
|
||||
config_path = Path(agent_data.get("config_path", ""))
|
||||
if config_path.exists():
|
||||
agent = factory.load_agent(agent_id, workspace_id)
|
||||
|
||||
if agent_type.endswith("_analyst"):
|
||||
analysts.append(agent)
|
||||
elif agent_type == "risk_manager":
|
||||
risk_manager = agent
|
||||
elif agent_type == "portfolio_manager":
|
||||
portfolio_manager = agent
|
||||
|
||||
if not analysts:
|
||||
raise ValueError(f"No analysts found in workspace '{workspace_id}'")
|
||||
if risk_manager is None:
|
||||
raise ValueError(f"No risk_manager found in workspace '{workspace_id}'")
|
||||
if portfolio_manager is None:
|
||||
raise ValueError(f"No portfolio_manager found in workspace '{workspace_id}'")
|
||||
|
||||
return {
|
||||
"analysts": analysts,
|
||||
"risk_manager": risk_manager,
|
||||
"portfolio_manager": portfolio_manager,
|
||||
}
|
||||
|
||||
def reload_agents_from_workspace(self, workspace_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Reload all agents from workspace.
|
||||
|
||||
This updates self.analysts, self.risk_manager, and self.pm
|
||||
with agents loaded from the specified workspace.
|
||||
|
||||
Args:
|
||||
workspace_id: Workspace ID (uses self.workspace_id if None)
|
||||
"""
|
||||
ws_id = workspace_id or self.workspace_id
|
||||
if not ws_id:
|
||||
raise ValueError("No workspace_id specified")
|
||||
|
||||
loaded = self.load_agents_from_workspace(ws_id)
|
||||
|
||||
self.analysts = loaded["analysts"]
|
||||
self.risk_manager = loaded["risk_manager"]
|
||||
self.pm = loaded["portfolio_manager"]
|
||||
self.workspace_id = ws_id
|
||||
|
||||
logger.info(f"Reloaded {len(self.analysts)} analysts from workspace '{ws_id}'")
|
||||
|
||||
def _runtime_update_status(self, agent: Any, status: str) -> None:
|
||||
if not self.runtime_manager:
|
||||
return
|
||||
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||
if not agent_id:
|
||||
return
|
||||
self.runtime_manager.update_agent_status(agent_id, status, self._session_key)
|
||||
|
||||
def _runtime_batch_status(self, agents: List[Any], status: str) -> None:
|
||||
for agent in agents:
|
||||
self._runtime_update_status(agent, status)
|
||||
|
||||
def _all_analysts(self) -> List[Any]:
|
||||
"""Return static analysts plus runtime-created analysts."""
|
||||
return list(self.analysts) + list(self._dynamic_analysts.values())
|
||||
|
||||
def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str:
|
||||
"""Create one runtime analyst instance."""
|
||||
if analyst_type not in ANALYST_TYPES:
|
||||
return (
|
||||
f"Unknown analyst_type '{analyst_type}'. "
|
||||
f"Available: {', '.join(ANALYST_TYPES.keys())}"
|
||||
)
|
||||
if agent_id in {agent.name for agent in self._all_analysts()}:
|
||||
return f"Analyst '{agent_id}' already exists."
|
||||
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
|
||||
persona = personas.get(analyst_type, {})
|
||||
workspace_manager = WorkspaceManager(project_root=project_root)
|
||||
workspace_manager.ensure_agent_assets(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
file_contents=workspace_manager.build_default_agent_files(
|
||||
agent_id=agent_id,
|
||||
persona=persona,
|
||||
),
|
||||
)
|
||||
|
||||
agent = AnalystAgent(
|
||||
analyst_type=analyst_type,
|
||||
toolkit=create_agent_toolkit(
|
||||
agent_id=agent_id,
|
||||
config_name=config_name,
|
||||
active_skill_dirs=[],
|
||||
),
|
||||
model=get_agent_model(analyst_type),
|
||||
formatter=get_agent_formatter(analyst_type),
|
||||
agent_id=agent_id,
|
||||
config={"config_name": config_name},
|
||||
)
|
||||
self._dynamic_analysts[agent_id] = agent
|
||||
update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=[item.name for item in self._all_analysts()],
|
||||
add=[agent_id],
|
||||
)
|
||||
return f"Created runtime analyst '{agent_id}' ({analyst_type})."
|
||||
|
||||
def _remove_runtime_analyst(self, agent_id: str) -> str:
|
||||
"""Remove one runtime-created analyst instance."""
|
||||
if agent_id not in self._dynamic_analysts:
|
||||
return f"Runtime analyst '{agent_id}' not found."
|
||||
self._dynamic_analysts.pop(agent_id, None)
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
update_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=[item.name for item in self._all_analysts()],
|
||||
remove=[agent_id],
|
||||
)
|
||||
return f"Removed runtime analyst '{agent_id}'."
|
||||
|
||||
def _get_active_analysts(self) -> List[Any]:
|
||||
"""Resolve active analyst participants from run-scoped team pipeline config."""
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
analyst_map = {agent.name: agent for agent in self._all_analysts()}
|
||||
active_ids = resolve_active_analysts(
|
||||
project_root=project_root,
|
||||
config_name=config_name,
|
||||
available_analysts=list(analyst_map.keys()),
|
||||
)
|
||||
return [analyst_map[agent_id] for agent_id in active_ids if agent_id in analyst_map]
|
||||
|
||||
def _runtime_log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> None:
|
||||
if not self.runtime_manager:
|
||||
return
|
||||
self.runtime_manager.log_event(event, details)
|
||||
|
||||
481
backend/core/pipeline_runner.py
Normal file
481
backend/core/pipeline_runner.py
Normal file
@@ -0,0 +1,481 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Pipeline Runner - Independent trading pipeline execution
|
||||
|
||||
This module provides functions to start/stop trading pipelines
|
||||
that can be called from the REST API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.runtime.manager import (
|
||||
TradingRuntimeManager,
|
||||
set_global_runtime_manager,
|
||||
clear_global_runtime_manager,
|
||||
set_shutdown_event,
|
||||
clear_shutdown_event,
|
||||
is_shutdown_requested,
|
||||
)
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
# Global gateway reference for cleanup
|
||||
_gateway_instance: Optional[Gateway] = None
|
||||
|
||||
|
||||
def _set_gateway(gateway: Optional[Gateway]) -> None:
|
||||
"""Set global gateway reference."""
|
||||
global _gateway_instance
|
||||
_gateway_instance = gateway
|
||||
|
||||
|
||||
def stop_gateway() -> None:
|
||||
"""Stop the running gateway if exists."""
|
||||
global _gateway_instance
|
||||
if _gateway_instance is not None:
|
||||
try:
|
||||
_gateway_instance.stop()
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).error(f"Error stopping gateway: {e}")
|
||||
finally:
|
||||
_gateway_instance = None
|
||||
|
||||
|
||||
def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
|
||||
"""Create ReMeTaskLongTermMemory for an agent."""
|
||||
try:
|
||||
from agentscope.memory import ReMeTaskLongTermMemory
|
||||
from agentscope.model import DashScopeChatModel
|
||||
from agentscope.embedding import DashScopeTextEmbedding
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
api_key = os.getenv("MEMORY_API_KEY")
|
||||
if not api_key:
|
||||
return None
|
||||
|
||||
memory_dir = str(run_dir / "memory")
|
||||
|
||||
return ReMeTaskLongTermMemory(
|
||||
agent_name=agent_name,
|
||||
user_name=agent_name,
|
||||
model=DashScopeChatModel(
|
||||
model_name=os.getenv("MEMORY_MODEL_NAME", "qwen3-max"),
|
||||
api_key=api_key,
|
||||
stream=False,
|
||||
),
|
||||
embedding_model=DashScopeTextEmbedding(
|
||||
model_name=os.getenv("MEMORY_EMBEDDING_MODEL", "text-embedding-v4"),
|
||||
api_key=api_key,
|
||||
dimensions=1024,
|
||||
),
|
||||
**{
|
||||
"vector_store.default.backend": "local",
|
||||
"vector_store.default.params.store_dir": memory_dir,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def create_agents(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
initial_cash: float,
|
||||
margin_requirement: float,
|
||||
enable_long_term_memory: bool = False,
|
||||
):
|
||||
"""Create all agents for the system."""
|
||||
analysts = []
|
||||
long_term_memories = []
|
||||
|
||||
# Initialize workspace manager and assets
|
||||
workspace_manager = WorkspaceManager()
|
||||
workspace_manager.initialize_default_assets(
|
||||
config_name=run_id,
|
||||
agent_ids=list(ANALYST_TYPES.keys()) + ["risk_manager", "portfolio_manager"],
|
||||
analyst_personas=_prompt_loader.load_yaml_config("analyst", "personas"),
|
||||
)
|
||||
|
||||
profiles = load_agent_profiles()
|
||||
skills_manager = SkillsManager()
|
||||
active_skill_map = skills_manager.prepare_active_skills(
|
||||
config_name=run_id,
|
||||
agent_defaults={
|
||||
agent_id: profile.get("skills", [])
|
||||
for agent_id, profile in profiles.items()
|
||||
},
|
||||
)
|
||||
|
||||
# Create analyst agents
|
||||
for analyst_type in ANALYST_TYPES:
|
||||
model = get_agent_model(analyst_type)
|
||||
formatter = get_agent_formatter(analyst_type)
|
||||
toolkit = create_agent_toolkit(
|
||||
analyst_type,
|
||||
run_id,
|
||||
active_skill_dirs=active_skill_map.get(analyst_type, []),
|
||||
)
|
||||
|
||||
long_term_memory = None
|
||||
if enable_long_term_memory:
|
||||
long_term_memory = create_long_term_memory(analyst_type, run_id, run_dir)
|
||||
if long_term_memory:
|
||||
long_term_memories.append(long_term_memory)
|
||||
|
||||
analyst = AnalystAgent(
|
||||
analyst_type=analyst_type,
|
||||
toolkit=toolkit,
|
||||
model=model,
|
||||
formatter=formatter,
|
||||
agent_id=analyst_type,
|
||||
config={"config_name": run_id},
|
||||
long_term_memory=long_term_memory,
|
||||
)
|
||||
analysts.append(analyst)
|
||||
|
||||
# Create risk manager
|
||||
risk_long_term_memory = None
|
||||
if enable_long_term_memory:
|
||||
risk_long_term_memory = create_long_term_memory("risk_manager", run_id, run_dir)
|
||||
if risk_long_term_memory:
|
||||
long_term_memories.append(risk_long_term_memory)
|
||||
|
||||
risk_manager = RiskAgent(
|
||||
model=get_agent_model("risk_manager"),
|
||||
formatter=get_agent_formatter("risk_manager"),
|
||||
name="risk_manager",
|
||||
config={"config_name": run_id},
|
||||
long_term_memory=risk_long_term_memory,
|
||||
toolkit=create_agent_toolkit(
|
||||
"risk_manager",
|
||||
run_id,
|
||||
active_skill_dirs=active_skill_map.get("risk_manager", []),
|
||||
),
|
||||
)
|
||||
|
||||
# Create portfolio manager
|
||||
pm_long_term_memory = None
|
||||
if enable_long_term_memory:
|
||||
pm_long_term_memory = create_long_term_memory("portfolio_manager", run_id, run_dir)
|
||||
if pm_long_term_memory:
|
||||
long_term_memories.append(pm_long_term_memory)
|
||||
|
||||
portfolio_manager = PMAgent(
|
||||
name="portfolio_manager",
|
||||
model=get_agent_model("portfolio_manager"),
|
||||
formatter=get_agent_formatter("portfolio_manager"),
|
||||
initial_cash=initial_cash,
|
||||
margin_requirement=margin_requirement,
|
||||
config={"config_name": run_id},
|
||||
long_term_memory=pm_long_term_memory,
|
||||
toolkit_factory=create_agent_toolkit,
|
||||
toolkit_factory_kwargs={
|
||||
"active_skill_dirs": active_skill_map.get("portfolio_manager", []),
|
||||
},
|
||||
)
|
||||
|
||||
return analysts, risk_manager, portfolio_manager, long_term_memories
|
||||
|
||||
|
||||
async def run_pipeline(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
bootstrap: Dict[str, Any],
|
||||
stop_event: asyncio.Event,
|
||||
message_callback: Optional[Callable[[str, Any], None]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the trading pipeline with the given configuration.
|
||||
|
||||
Service Startup Order:
|
||||
Phase 1: WebSocket Server - Frontend can connect
|
||||
Phase 2: Market Service - Price data starts flowing
|
||||
Phase 3: Agent Runtime - Create all agents
|
||||
Phase 4: Pipeline & Scheduler - Trading logic ready
|
||||
Phase 5: Gateway Fully Operational - All systems running
|
||||
|
||||
Args:
|
||||
run_id: Unique run identifier (timestamp)
|
||||
run_dir: Run directory path
|
||||
bootstrap: Bootstrap configuration
|
||||
stop_event: Event to signal pipeline stop
|
||||
message_callback: Optional callback for sending messages to clients
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Set global shutdown event
|
||||
set_shutdown_event(stop_event)
|
||||
|
||||
logger.info(f"[Pipeline {run_id}] ======================================")
|
||||
logger.info(f"[Pipeline {run_id}] Starting with 5-phase initialization...")
|
||||
logger.info(f"[Pipeline {run_id}] ======================================")
|
||||
|
||||
try:
|
||||
# Extract config values
|
||||
tickers = bootstrap.get("tickers", ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA", "META", "TSLA", "AMD", "NFLX", "AVGO", "PLTR", "COIN"])
|
||||
initial_cash = float(bootstrap.get("initial_cash", 100000.0))
|
||||
margin_requirement = float(bootstrap.get("margin_requirement", 0.0))
|
||||
max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2))
|
||||
schedule_mode = bootstrap.get("schedule_mode", "daily")
|
||||
trigger_time = bootstrap.get("trigger_time", "09:30")
|
||||
interval_minutes = int(bootstrap.get("interval_minutes", 60))
|
||||
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0))
|
||||
mode = bootstrap.get("mode", "live")
|
||||
start_date = bootstrap.get("start_date")
|
||||
end_date = bootstrap.get("end_date")
|
||||
enable_memory = bootstrap.get("enable_memory", False)
|
||||
|
||||
is_backtest = mode == "backtest"
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 0: Initialize runtime manager
|
||||
# ======================================================================
|
||||
logger.info("[Phase 0/5] Initializing runtime manager...")
|
||||
|
||||
from backend.api.runtime import runtime_manager
|
||||
|
||||
if runtime_manager is None:
|
||||
runtime_manager = TradingRuntimeManager(
|
||||
config_name=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
runtime_manager.prepare_run()
|
||||
|
||||
set_global_runtime_manager(runtime_manager)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
|
||||
# These will be started by Gateway in the correct order
|
||||
# ======================================================================
|
||||
logger.info("[Phase 1-2/5] Creating infrastructure services...")
|
||||
|
||||
# Create storage service
|
||||
storage_service = StorageService(
|
||||
dashboard_dir=run_dir / "team_dashboard",
|
||||
initial_cash=initial_cash,
|
||||
config_name=run_id,
|
||||
)
|
||||
|
||||
if not storage_service.files["summary"].exists():
|
||||
storage_service.initialize_empty_dashboard()
|
||||
else:
|
||||
storage_service.update_leaderboard_model_info()
|
||||
|
||||
# Create market service (data source)
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=10,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
|
||||
backtest_start_date=start_date if is_backtest else None,
|
||||
backtest_end_date=end_date if is_backtest else None,
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 3: Create Agent Runtime
|
||||
# ======================================================================
|
||||
logger.info("[Phase 3/5] Creating agent runtime...")
|
||||
|
||||
analysts, risk_manager, pm, long_term_memories = create_agents(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
initial_cash=initial_cash,
|
||||
margin_requirement=margin_requirement,
|
||||
enable_long_term_memory=enable_memory,
|
||||
)
|
||||
|
||||
# Register agents with runtime manager
|
||||
for agent in analysts + [risk_manager, pm]:
|
||||
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||
if agent_id:
|
||||
runtime_manager.register_agent(agent_id)
|
||||
|
||||
# Load portfolio state
|
||||
portfolio_state = storage_service.load_portfolio_state()
|
||||
pm.load_portfolio_state(portfolio_state)
|
||||
|
||||
# Create settlement coordinator
|
||||
settlement_coordinator = SettlementCoordinator(
|
||||
storage=storage_service,
|
||||
initial_capital=initial_cash,
|
||||
)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 4: Create Pipeline & Scheduler
|
||||
# ======================================================================
|
||||
logger.info("[Phase 4/5] Creating pipeline and scheduler...")
|
||||
|
||||
# Create pipeline
|
||||
pipeline = TradingPipeline(
|
||||
analysts=analysts,
|
||||
risk_manager=risk_manager,
|
||||
portfolio_manager=pm,
|
||||
settlement_coordinator=settlement_coordinator,
|
||||
max_comm_cycles=max_comm_cycles,
|
||||
runtime_manager=runtime_manager,
|
||||
)
|
||||
|
||||
# Create scheduler
|
||||
scheduler_callback = None
|
||||
live_scheduler = None
|
||||
|
||||
if is_backtest:
|
||||
backtest_scheduler = BacktestScheduler(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
trading_calendar="NYSE",
|
||||
delay_between_days=0.5,
|
||||
)
|
||||
trading_dates = backtest_scheduler.get_trading_dates()
|
||||
|
||||
async def scheduler_callback_fn(callback):
|
||||
await backtest_scheduler.start(callback)
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
else:
|
||||
# Live mode
|
||||
live_scheduler = Scheduler(
|
||||
mode=schedule_mode,
|
||||
trigger_time=trigger_time,
|
||||
interval_minutes=interval_minutes,
|
||||
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
|
||||
config={"config_name": run_id},
|
||||
)
|
||||
|
||||
async def scheduler_callback_fn(callback):
|
||||
await live_scheduler.start(callback)
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 5: Start Gateway (WebSocket → Market → Scheduler)
|
||||
# Gateway.start() will handle the final startup sequence:
|
||||
# - WebSocket Server first (frontend can connect)
|
||||
# - Market Service second (price data flows)
|
||||
# - Scheduler last (trading begins)
|
||||
# ======================================================================
|
||||
logger.info("[Phase 5/5] Starting Gateway (WebSocket → Market → Scheduler)...")
|
||||
|
||||
gateway = Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage_service,
|
||||
pipeline=pipeline,
|
||||
scheduler_callback=scheduler_callback,
|
||||
config={
|
||||
"mode": mode,
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": run_id,
|
||||
"schedule_mode": schedule_mode,
|
||||
"interval_minutes": interval_minutes,
|
||||
"trigger_time": trigger_time,
|
||||
"heartbeat_interval": heartbeat_interval,
|
||||
"initial_cash": initial_cash,
|
||||
"margin_requirement": margin_requirement,
|
||||
"max_comm_cycles": max_comm_cycles,
|
||||
"enable_memory": enable_memory,
|
||||
},
|
||||
scheduler=live_scheduler,
|
||||
)
|
||||
_set_gateway(gateway)
|
||||
|
||||
# Start pipeline execution
|
||||
async with AsyncExitStack() as stack:
|
||||
# Enter long-term memory contexts
|
||||
for memory in long_term_memories:
|
||||
await stack.enter_async_context(memory)
|
||||
|
||||
# Start Gateway - this will execute the 4-phase startup:
|
||||
# Phase 1: WebSocket Server (frontend can connect immediately)
|
||||
# Phase 2: Market Service (price updates start flowing)
|
||||
# Phase 3: Market Status Monitor
|
||||
# Phase 4: Scheduler (trading cycles begin)
|
||||
gateway_task = asyncio.create_task(
|
||||
gateway.start(host="0.0.0.0", port=8765)
|
||||
)
|
||||
logger.info("[Pipeline] Gateway startup initiated on ws://localhost:8765")
|
||||
|
||||
# Wait for Gateway to fully initialize all phases
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Define the trading cycle callback
|
||||
async def trading_cycle(session_key: str) -> None:
|
||||
"""Execute one trading cycle."""
|
||||
if is_shutdown_requested():
|
||||
return
|
||||
|
||||
runtime_manager.set_session_key(session_key)
|
||||
runtime_manager.log_event("cycle:start", {"session": session_key})
|
||||
|
||||
try:
|
||||
# Fetch market data
|
||||
market_data = await market_service.get_all_data()
|
||||
|
||||
# Run pipeline
|
||||
await pipeline.run_cycle(
|
||||
session_key=session_key,
|
||||
market_data=market_data,
|
||||
)
|
||||
|
||||
runtime_manager.log_event("cycle:complete", {"session": session_key})
|
||||
|
||||
except Exception as e:
|
||||
runtime_manager.log_event("cycle:error", {"error": str(e)})
|
||||
raise
|
||||
|
||||
# Start scheduler
|
||||
if scheduler_callback:
|
||||
await scheduler_callback(trading_cycle)
|
||||
|
||||
# Wait for stop signal
|
||||
while not stop_event.is_set():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Cancel gateway task
|
||||
if not gateway_task.done():
|
||||
gateway_task.cancel()
|
||||
try:
|
||||
await gateway_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation gracefully
|
||||
raise
|
||||
finally:
|
||||
# Cleanup
|
||||
logger.info("[Pipeline] Cleaning up...")
|
||||
|
||||
# Stop Gateway
|
||||
try:
|
||||
stop_gateway()
|
||||
logger.info("[Pipeline] Gateway stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"[Pipeline] Error stopping gateway: {e}")
|
||||
|
||||
clear_shutdown_event()
|
||||
clear_global_runtime_manager()
|
||||
from backend.api.runtime import unregister_runtime_manager
|
||||
unregister_runtime_manager()
|
||||
logger.info("[Pipeline] Cleanup complete")
|
||||
@@ -4,7 +4,7 @@ Scheduler - Market-aware trigger system for trading cycles
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Any, Callable, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
@@ -28,16 +28,21 @@ class Scheduler:
|
||||
mode: str = "daily",
|
||||
trigger_time: Optional[str] = None,
|
||||
interval_minutes: Optional[int] = None,
|
||||
heartbeat_interval: Optional[int] = None,
|
||||
config: Optional[dict] = None,
|
||||
):
|
||||
self.mode = mode
|
||||
self.trigger_time = trigger_time or "09:30" # NYSE timezone
|
||||
self.trigger_now = self.trigger_time == "now"
|
||||
self.interval_minutes = interval_minutes or 60
|
||||
self.heartbeat_interval = heartbeat_interval # e.g. 3600 = 1 hour
|
||||
self.config = config or {}
|
||||
|
||||
self.running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||
self._callback: Optional[Callable] = None
|
||||
self._heartbeat_callback: Optional[Callable] = None
|
||||
|
||||
def _now_nyse(self) -> datetime:
|
||||
"""Get current time in NYSE timezone"""
|
||||
@@ -52,6 +57,15 @@ class Scheduler:
|
||||
)
|
||||
return len(valid_days) > 0
|
||||
|
||||
def _is_trading_hours(self, now: datetime) -> bool:
|
||||
"""Check if current time is within NYSE trading hours (9:30-16:00 ET)."""
|
||||
market_time = now.time()
|
||||
return time(9, 30) <= market_time <= time(16, 0)
|
||||
|
||||
def set_heartbeat_callback(self, callback: Callable) -> None:
|
||||
"""Register callback for heartbeat triggers."""
|
||||
self._heartbeat_callback = callback
|
||||
|
||||
def _next_trading_day(self, from_date: datetime) -> datetime:
|
||||
"""Find the next trading day from given date"""
|
||||
check_date = from_date
|
||||
@@ -68,18 +82,100 @@ class Scheduler:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._callback = callback
|
||||
self._schedule_task()
|
||||
|
||||
if self.mode == "daily":
|
||||
self._task = asyncio.create_task(self._run_daily(callback))
|
||||
elif self.mode == "intraday":
|
||||
self._task = asyncio.create_task(self._run_intraday(callback))
|
||||
else:
|
||||
raise ValueError(f"Unknown scheduler mode: {self.mode}")
|
||||
# Start heartbeat loop if configured
|
||||
if self.heartbeat_interval and self._heartbeat_callback:
|
||||
self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop())
|
||||
logger.info(
|
||||
f"Heartbeat loop started: interval={self.heartbeat_interval}s",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Scheduler started: mode={self.mode}, timezone=America/New_York",
|
||||
)
|
||||
|
||||
def _schedule_task(self):
|
||||
"""Create the active scheduler task for the current mode."""
|
||||
if not self._callback:
|
||||
raise ValueError("Scheduler callback is not set")
|
||||
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
|
||||
if self.mode == "daily":
|
||||
self._task = asyncio.create_task(self._run_daily(self._callback))
|
||||
elif self.mode == "intraday":
|
||||
self._task = asyncio.create_task(
|
||||
self._run_intraday(self._callback),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown scheduler mode: {self.mode}")
|
||||
|
||||
def reconfigure(
|
||||
self,
|
||||
*,
|
||||
mode: Optional[str] = None,
|
||||
trigger_time: Optional[str] = None,
|
||||
interval_minutes: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""Update scheduler parameters in-place and restart its timing loop."""
|
||||
changed = False
|
||||
|
||||
if mode and mode != self.mode:
|
||||
self.mode = mode
|
||||
changed = True
|
||||
|
||||
if trigger_time and trigger_time != self.trigger_time:
|
||||
self.trigger_time = trigger_time
|
||||
self.trigger_now = self.trigger_time == "now"
|
||||
changed = True
|
||||
|
||||
if (
|
||||
interval_minutes is not None
|
||||
and interval_minutes > 0
|
||||
and interval_minutes != self.interval_minutes
|
||||
):
|
||||
self.interval_minutes = interval_minutes
|
||||
changed = True
|
||||
|
||||
if changed and self.running and self._callback:
|
||||
self._schedule_task()
|
||||
logger.info(
|
||||
"Scheduler reconfigured: mode=%s, trigger_time=%s, interval_minutes=%s",
|
||||
self.mode,
|
||||
self.trigger_time,
|
||||
self.interval_minutes,
|
||||
)
|
||||
|
||||
return changed
|
||||
|
||||
async def _run_heartbeat_loop(self):
|
||||
"""Run heartbeat checks on a separate interval during trading hours."""
|
||||
while self.running:
|
||||
now = self._now_nyse()
|
||||
if self._is_trading_day(now) and self._is_trading_hours(now):
|
||||
if self._heartbeat_callback:
|
||||
try:
|
||||
current_date = now.strftime("%Y-%m-%d")
|
||||
logger.debug(
|
||||
f"[Heartbeat] Triggering heartbeat check for {current_date}",
|
||||
)
|
||||
await self._heartbeat_callback(date=current_date)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[Heartbeat] Callback failed: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[Heartbeat] Callback not set, skipping heartbeat",
|
||||
)
|
||||
|
||||
await asyncio.sleep(self.heartbeat_interval)
|
||||
|
||||
async def _run_daily(self, callback: Callable):
|
||||
"""Run once per trading day at specified time (NYSE timezone)"""
|
||||
first_run = True
|
||||
@@ -154,6 +250,9 @@ class Scheduler:
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
if self._heartbeat_task:
|
||||
self._heartbeat_task.cancel()
|
||||
self._heartbeat_task = None
|
||||
logger.info("Scheduler stopped")
|
||||
|
||||
|
||||
|
||||
@@ -47,6 +47,10 @@ class StateSync:
|
||||
"""Set current simulation date for backtest-compatible timestamps"""
|
||||
self._simulation_date = date
|
||||
|
||||
def clear_simulation_date(self):
|
||||
"""Disable backtest timestamp simulation and use wall-clock time."""
|
||||
self._simulation_date = None
|
||||
|
||||
def _get_timestamp_ms(self) -> int:
|
||||
"""
|
||||
Get timestamp in milliseconds.
|
||||
@@ -97,12 +101,24 @@ class StateSync:
|
||||
if not self._enabled:
|
||||
return
|
||||
|
||||
# Ensure timestamp exists (use simulation date if in backtest mode)
|
||||
# Ensure timestamp exists. Prefer explicit millisecond timestamps so
|
||||
# frontend displays local wall time correctly instead of date-only UTC.
|
||||
if "timestamp" not in event:
|
||||
ts_ms = event.get("ts")
|
||||
if ts_ms is not None:
|
||||
try:
|
||||
event["timestamp"] = datetime.fromtimestamp(
|
||||
float(ts_ms) / 1000.0,
|
||||
).isoformat()
|
||||
except (TypeError, ValueError, OSError):
|
||||
if self._simulation_date:
|
||||
event["timestamp"] = f"{self._simulation_date}"
|
||||
else:
|
||||
event["timestamp"] = datetime.now().isoformat()
|
||||
elif self._simulation_date:
|
||||
event["timestamp"] = f"{self._simulation_date}"
|
||||
else:
|
||||
event["timestamp"] = datetime.now().isoformat()
|
||||
|
||||
# Persist to feed_history
|
||||
if persist:
|
||||
@@ -238,9 +254,12 @@ class StateSync:
|
||||
"""Called at start of trading cycle"""
|
||||
self._state["current_date"] = date
|
||||
self._state["status"] = "running"
|
||||
if self._state.get("server_mode") == "backtest":
|
||||
self.set_simulation_date(
|
||||
date,
|
||||
) # Set for backtest-compatible timestamps
|
||||
else:
|
||||
self.clear_simulation_date()
|
||||
|
||||
await self.emit(
|
||||
{
|
||||
@@ -411,7 +430,9 @@ class StateSync:
|
||||
|
||||
Useful for: frontend reconnection or restoring from saved state
|
||||
"""
|
||||
feed_history = self._state.get("feed_history", [])
|
||||
feed_history = self.storage.runtime_db.get_recent_feed_events(
|
||||
limit=self.storage.max_feed_history,
|
||||
) or self._state.get("feed_history", [])
|
||||
|
||||
# feed_history is newest-first, need to reverse for chronological replay # noqa: E501
|
||||
for event in reversed(feed_history):
|
||||
@@ -434,13 +455,21 @@ class StateSync:
|
||||
Returns:
|
||||
Dictionary suitable for sending to frontend
|
||||
"""
|
||||
feed_history = self.storage.runtime_db.get_recent_feed_events(
|
||||
limit=self.storage.max_feed_history,
|
||||
) or self._state.get("feed_history", [])
|
||||
last_day_history = self.storage.runtime_db.get_last_day_feed_events(
|
||||
current_date=self._state.get("current_date"),
|
||||
limit=self.storage.max_feed_history,
|
||||
) or self._state.get("last_day_history", [])
|
||||
|
||||
payload = {
|
||||
"server_mode": self._state.get("server_mode", "live"),
|
||||
"is_mock_mode": self._state.get("is_mock_mode", False),
|
||||
"is_backtest": self._state.get("is_backtest", False),
|
||||
"tickers": self._state.get("tickers"),
|
||||
"runtime_config": self._state.get("runtime_config"),
|
||||
"feed_history": self._state.get("feed_history", []),
|
||||
"feed_history": feed_history,
|
||||
"last_day_history": last_day_history,
|
||||
"current_date": self._state.get("current_date"),
|
||||
"trading_days_total": self._state.get("trading_days_total", 0),
|
||||
"trading_days_completed": self._state.get(
|
||||
@@ -458,12 +487,13 @@ class StateSync:
|
||||
}
|
||||
|
||||
if include_dashboard:
|
||||
dashboard_snapshot = self.storage.build_dashboard_snapshot_from_state(self._state)
|
||||
payload["dashboard"] = {
|
||||
"summary": self.storage.load_file("summary"),
|
||||
"holdings": self.storage.load_file("holdings"),
|
||||
"stats": self.storage.load_file("stats"),
|
||||
"trades": self.storage.load_file("trades"),
|
||||
"leaderboard": self.storage.load_file("leaderboard"),
|
||||
"summary": dashboard_snapshot.get("summary"),
|
||||
"holdings": dashboard_snapshot.get("holdings"),
|
||||
"stats": dashboard_snapshot.get("stats"),
|
||||
"trades": dashboard_snapshot.get("trades"),
|
||||
"leaderboard": dashboard_snapshot.get("leaderboard"),
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from backend.data.historical_price_manager import HistoricalPriceManager
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
__all__ = ["MockPriceManager", "PollingPriceManager", "HistoricalPriceManager"]
|
||||
__all__ = ["PollingPriceManager", "HistoricalPriceManager"]
|
||||
|
||||
@@ -7,6 +7,7 @@ from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.data.provider_router import get_provider_router
|
||||
|
||||
@@ -26,6 +27,7 @@ class HistoricalPriceManager:
|
||||
self.close_prices = {}
|
||||
self.running = False
|
||||
self._router = get_provider_router()
|
||||
self._market_store = MarketStore()
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
@@ -58,21 +60,48 @@ class HistoricalPriceManager:
|
||||
logger.warning(f"Failed to load CSV for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _load_from_market_db(
|
||||
self,
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> Optional[pd.DataFrame]:
|
||||
"""Load price data from the long-lived market research database."""
|
||||
try:
|
||||
rows = self._market_store.get_ohlc(symbol, start_date, end_date)
|
||||
if not rows:
|
||||
return None
|
||||
df = pd.DataFrame(rows)
|
||||
if df.empty or "date" not in df.columns:
|
||||
return None
|
||||
df["Date"] = pd.to_datetime(df["date"])
|
||||
df.set_index("Date", inplace=True)
|
||||
df.sort_index(inplace=True)
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load market DB data for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def preload_data(self, start_date: str, end_date: str):
|
||||
"""Preload historical data from local CSV files."""
|
||||
"""Preload historical data from market DB first, then local CSV."""
|
||||
logger.info(f"Preloading data: {start_date} to {end_date}")
|
||||
|
||||
for symbol in self.subscribed_symbols:
|
||||
if symbol in self._price_cache:
|
||||
continue
|
||||
|
||||
# Load from local CSV file directly
|
||||
df = self._load_from_market_db(symbol, start_date, end_date)
|
||||
if df is not None and not df.empty:
|
||||
self._price_cache[symbol] = df
|
||||
logger.info(f"Loaded {symbol} from market DB: {len(df)} records")
|
||||
continue
|
||||
|
||||
df = self._load_from_csv(symbol)
|
||||
if df is not None and not df.empty:
|
||||
self._price_cache[symbol] = df
|
||||
logger.info(f"Loaded {symbol} from CSV: {len(df)} records")
|
||||
else:
|
||||
logger.warning(f"No CSV data for {symbol}")
|
||||
logger.warning(f"No market DB or CSV data for {symbol}")
|
||||
|
||||
def set_date(self, date: str):
|
||||
"""Set current trading date and update prices"""
|
||||
|
||||
299
backend/data/market_ingest.py
Normal file
299
backend/data/market_ingest.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Ingest Polygon market data into the long-lived research warehouse."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Iterable
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.data.news_alignment import align_news_for_symbol
|
||||
from backend.data.provider_router import DataProviderRouter
|
||||
from backend.data.polygon_client import (
|
||||
fetch_news,
|
||||
fetch_ohlc,
|
||||
fetch_ticker_details,
|
||||
)
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
|
||||
def _today_utc() -> str:
|
||||
return datetime.now(timezone.utc).date().isoformat()
|
||||
|
||||
|
||||
def _default_start(years: int = 2) -> str:
|
||||
return (datetime.now(timezone.utc).date() - timedelta(days=years * 366)).isoformat()
|
||||
|
||||
|
||||
def _max_news_date(news_rows: Iterable[dict]) -> str | None:
|
||||
dates = [
|
||||
str(item.get("published_utc") or "").strip()[:10]
|
||||
for item in news_rows
|
||||
if str(item.get("published_utc") or "").strip()
|
||||
]
|
||||
dates = [value for value in dates if value]
|
||||
return max(dates) if dates else None
|
||||
|
||||
|
||||
def _effective_last_news_fetch(
|
||||
market_store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
end_date: str,
|
||||
watermark_value: str | None,
|
||||
) -> str | None:
|
||||
"""Clamp stale/future watermarks to the latest actually stored news date."""
|
||||
raw = str(watermark_value or "").strip()[:10]
|
||||
if not raw:
|
||||
return None
|
||||
if raw <= end_date:
|
||||
return raw
|
||||
|
||||
latest_stored = market_store.get_latest_news_date(ticker)
|
||||
if latest_stored and latest_stored <= end_date:
|
||||
return latest_stored
|
||||
return end_date
|
||||
|
||||
|
||||
def _normalize_provider_news_rows(ticker: str, news_items: Iterable[Any]) -> list[dict]:
|
||||
rows: list[dict] = []
|
||||
for item in news_items:
|
||||
payload = item.model_dump() if hasattr(item, "model_dump") else dict(item or {})
|
||||
related = payload.get("related")
|
||||
if isinstance(related, str):
|
||||
related_list = [value.strip().upper() for value in related.split(",") if value.strip()]
|
||||
elif isinstance(related, list):
|
||||
related_list = [str(value).strip().upper() for value in related if str(value).strip()]
|
||||
else:
|
||||
related_list = []
|
||||
if ticker not in related_list:
|
||||
related_list.append(ticker)
|
||||
rows.append(
|
||||
{
|
||||
"title": payload.get("title"),
|
||||
"description": payload.get("summary"),
|
||||
"summary": payload.get("summary"),
|
||||
"article_url": payload.get("url"),
|
||||
"published_utc": payload.get("date"),
|
||||
"publisher": payload.get("source"),
|
||||
"tickers": related_list,
|
||||
"category": payload.get("category"),
|
||||
"raw_json": payload,
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def ingest_ticker_history(
|
||||
symbol: str,
|
||||
*,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
store: MarketStore | None = None,
|
||||
) -> dict:
|
||||
"""Fetch and persist Polygon OHLC + news for a ticker."""
|
||||
ticker = normalize_symbol(symbol)
|
||||
start = start_date or _default_start()
|
||||
end = end_date or _today_utc()
|
||||
market_store = store or MarketStore()
|
||||
|
||||
details = fetch_ticker_details(ticker)
|
||||
market_store.upsert_ticker(
|
||||
symbol=ticker,
|
||||
name=details.get("name"),
|
||||
sector=details.get("sic_description"),
|
||||
is_active=bool(details.get("active", True)),
|
||||
)
|
||||
|
||||
ohlc_rows = fetch_ohlc(ticker, start, end)
|
||||
news_rows = fetch_news(ticker, start, end)
|
||||
price_count = market_store.upsert_ohlc(ticker, ohlc_rows, source="polygon")
|
||||
news_count = market_store.upsert_news(ticker, news_rows, source="polygon")
|
||||
aligned_count = align_news_for_symbol(market_store, ticker)
|
||||
market_store.update_fetch_watermark(
|
||||
symbol=ticker,
|
||||
price_date=end,
|
||||
news_date=_max_news_date(news_rows),
|
||||
)
|
||||
|
||||
return {
|
||||
"symbol": ticker,
|
||||
"start_date": start,
|
||||
"end_date": end,
|
||||
"prices": price_count,
|
||||
"news": news_count,
|
||||
"aligned": aligned_count,
|
||||
}
|
||||
|
||||
|
||||
def update_ticker_incremental(
|
||||
symbol: str,
|
||||
*,
|
||||
end_date: str | None = None,
|
||||
store: MarketStore | None = None,
|
||||
) -> dict:
|
||||
"""Incrementally fetch OHLC + news since the last watermark."""
|
||||
ticker = normalize_symbol(symbol)
|
||||
market_store = store or MarketStore()
|
||||
watermarks = market_store.get_ticker_watermarks(ticker)
|
||||
end = end_date or _today_utc()
|
||||
start_prices = (
|
||||
(datetime.fromisoformat(watermarks["last_price_fetch"]) + timedelta(days=1)).date().isoformat()
|
||||
if watermarks.get("last_price_fetch")
|
||||
else _default_start()
|
||||
)
|
||||
effective_last_news_fetch = _effective_last_news_fetch(
|
||||
market_store,
|
||||
ticker=ticker,
|
||||
end_date=end,
|
||||
watermark_value=watermarks.get("last_news_fetch"),
|
||||
)
|
||||
start_news = (
|
||||
(datetime.fromisoformat(effective_last_news_fetch) + timedelta(days=1)).date().isoformat()
|
||||
if effective_last_news_fetch
|
||||
else _default_start()
|
||||
)
|
||||
|
||||
details = fetch_ticker_details(ticker)
|
||||
market_store.upsert_ticker(
|
||||
symbol=ticker,
|
||||
name=details.get("name"),
|
||||
sector=details.get("sic_description"),
|
||||
is_active=bool(details.get("active", True)),
|
||||
)
|
||||
|
||||
ohlc_rows = [] if start_prices > end else fetch_ohlc(ticker, start_prices, end)
|
||||
news_rows = [] if start_news > end else fetch_news(ticker, start_news, end)
|
||||
price_count = market_store.upsert_ohlc(ticker, ohlc_rows, source="polygon") if ohlc_rows else 0
|
||||
news_count = market_store.upsert_news(ticker, news_rows, source="polygon") if news_rows else 0
|
||||
aligned_count = align_news_for_symbol(market_store, ticker)
|
||||
market_store.update_fetch_watermark(
|
||||
symbol=ticker,
|
||||
price_date=end if ohlc_rows or watermarks.get("last_price_fetch") else None,
|
||||
news_date=_max_news_date(news_rows),
|
||||
)
|
||||
|
||||
return {
|
||||
"symbol": ticker,
|
||||
"start_price_date": start_prices,
|
||||
"start_news_date": start_news,
|
||||
"end_date": end,
|
||||
"prices": price_count,
|
||||
"news": news_count,
|
||||
"aligned": aligned_count,
|
||||
}
|
||||
|
||||
|
||||
def refresh_news_incremental(
|
||||
symbol: str,
|
||||
*,
|
||||
end_date: str | None = None,
|
||||
store: MarketStore | None = None,
|
||||
) -> dict:
|
||||
"""Incrementally fetch company news using the configured provider router."""
|
||||
ticker = normalize_symbol(symbol)
|
||||
market_store = store or MarketStore()
|
||||
watermarks = market_store.get_ticker_watermarks(ticker)
|
||||
end = end_date or _today_utc()
|
||||
effective_last_news_fetch = _effective_last_news_fetch(
|
||||
market_store,
|
||||
ticker=ticker,
|
||||
end_date=end,
|
||||
watermark_value=watermarks.get("last_news_fetch"),
|
||||
)
|
||||
start_news = (
|
||||
(datetime.fromisoformat(effective_last_news_fetch) + timedelta(days=1)).date().isoformat()
|
||||
if effective_last_news_fetch
|
||||
else _default_start()
|
||||
)
|
||||
|
||||
if start_news > end:
|
||||
return {
|
||||
"symbol": ticker,
|
||||
"start_news_date": start_news,
|
||||
"end_date": end,
|
||||
"news": 0,
|
||||
"aligned": 0,
|
||||
}
|
||||
|
||||
router = DataProviderRouter()
|
||||
news_items, source = router.get_company_news(
|
||||
ticker=ticker,
|
||||
start_date=start_news,
|
||||
end_date=end,
|
||||
limit=1000,
|
||||
)
|
||||
news_rows = _normalize_provider_news_rows(ticker, news_items)
|
||||
news_count = market_store.upsert_news(ticker, news_rows, source=source) if news_rows else 0
|
||||
aligned_count = align_news_for_symbol(market_store, ticker)
|
||||
market_store.update_fetch_watermark(
|
||||
symbol=ticker,
|
||||
news_date=_max_news_date(news_rows),
|
||||
)
|
||||
|
||||
return {
|
||||
"symbol": ticker,
|
||||
"start_news_date": start_news,
|
||||
"end_date": end,
|
||||
"news": news_count,
|
||||
"aligned": aligned_count,
|
||||
"source": source,
|
||||
}
|
||||
|
||||
|
||||
def refresh_news_for_symbols(
|
||||
symbols: Iterable[str],
|
||||
*,
|
||||
end_date: str | None = None,
|
||||
store: MarketStore | None = None,
|
||||
) -> list[dict]:
|
||||
"""Incrementally refresh company news for a list of tickers."""
|
||||
market_store = store or MarketStore()
|
||||
results = []
|
||||
for symbol in symbols:
|
||||
ticker = normalize_symbol(symbol)
|
||||
if not ticker:
|
||||
continue
|
||||
results.append(
|
||||
refresh_news_incremental(
|
||||
ticker,
|
||||
end_date=end_date,
|
||||
store=market_store,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def ingest_symbols(
|
||||
symbols: Iterable[str],
|
||||
*,
|
||||
mode: str = "incremental",
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
store: MarketStore | None = None,
|
||||
) -> list[dict]:
|
||||
"""Fetch Polygon data for a list of tickers."""
|
||||
market_store = store or MarketStore()
|
||||
results = []
|
||||
for symbol in symbols:
|
||||
ticker = normalize_symbol(symbol)
|
||||
if not ticker:
|
||||
continue
|
||||
if mode == "full":
|
||||
results.append(
|
||||
ingest_ticker_history(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
store=market_store,
|
||||
)
|
||||
)
|
||||
else:
|
||||
results.append(
|
||||
update_ticker_incremental(
|
||||
ticker,
|
||||
end_date=end_date,
|
||||
store=market_store,
|
||||
)
|
||||
)
|
||||
return results
|
||||
1106
backend/data/market_store.py
Normal file
1106
backend/data/market_store.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,244 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Mock Price Manager - For testing during non-trading hours
|
||||
Generates virtual real-time price data
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockPriceManager:
|
||||
"""Mock Price Manager - Generates virtual prices for testing"""
|
||||
|
||||
def __init__(self, poll_interval: int = 10, volatility: float = 0.5):
|
||||
"""
|
||||
Args:
|
||||
poll_interval: Price update interval in seconds
|
||||
volatility: Price volatility percentage
|
||||
"""
|
||||
if poll_interval is None:
|
||||
poll_interval = int(os.getenv("MOCK_POLL_INTERVAL", "5"))
|
||||
if volatility is None:
|
||||
volatility = float(os.getenv("MOCK_VOLATILITY", "0.5"))
|
||||
|
||||
self.poll_interval = poll_interval
|
||||
self.volatility = volatility
|
||||
|
||||
self.subscribed_symbols: List[str] = []
|
||||
self.base_prices: Dict[str, float] = {}
|
||||
self.open_prices: Dict[str, float] = {}
|
||||
self.latest_prices: Dict[str, float] = {}
|
||||
self.price_callbacks: List[Callable] = []
|
||||
|
||||
self.running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
self.default_base_prices = {
|
||||
"AAPL": 237.50,
|
||||
"MSFT": 425.30,
|
||||
"GOOGL": 161.50,
|
||||
"AMZN": 218.45,
|
||||
"NVDA": 950.00,
|
||||
"META": 573.22,
|
||||
"TSLA": 342.15,
|
||||
"AMD": 168.90,
|
||||
"NFLX": 688.25,
|
||||
"INTC": 42.18,
|
||||
"COIN": 285.50,
|
||||
"PLTR": 45.80,
|
||||
"BABA": 88.30,
|
||||
"DIS": 112.50,
|
||||
"BKNG": 4850.00,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"MockPriceManager initialized (interval: {self.poll_interval}s, "
|
||||
f"volatility: {self.volatility}%)",
|
||||
)
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
symbols: List[str],
|
||||
base_prices: Dict[str, float] = None,
|
||||
):
|
||||
"""Subscribe to stock symbols"""
|
||||
for symbol in symbols:
|
||||
symbol = normalize_symbol(symbol)
|
||||
if symbol not in self.subscribed_symbols:
|
||||
self.subscribed_symbols.append(symbol)
|
||||
|
||||
if base_prices and symbol in base_prices:
|
||||
base_price = base_prices[symbol]
|
||||
elif symbol in self.default_base_prices:
|
||||
base_price = self.default_base_prices[symbol]
|
||||
else:
|
||||
base_price = random.uniform(50, 500)
|
||||
|
||||
self.base_prices[symbol] = base_price
|
||||
self.open_prices[symbol] = base_price
|
||||
self.latest_prices[symbol] = base_price
|
||||
|
||||
logger.info(
|
||||
f"Subscribed to mock price: {symbol} (base: ${base_price:.2f})", # noqa: E501
|
||||
)
|
||||
|
||||
def unsubscribe(self, symbols: List[str]):
|
||||
"""Unsubscribe from symbols"""
|
||||
for symbol in symbols:
|
||||
symbol = normalize_symbol(symbol)
|
||||
if symbol in self.subscribed_symbols:
|
||||
self.subscribed_symbols.remove(symbol)
|
||||
self.base_prices.pop(symbol, None)
|
||||
self.open_prices.pop(symbol, None)
|
||||
self.latest_prices.pop(symbol, None)
|
||||
logger.info(f"Unsubscribed: {symbol}")
|
||||
|
||||
def add_price_callback(self, callback: Callable):
|
||||
"""Add price update callback"""
|
||||
self.price_callbacks.append(callback)
|
||||
|
||||
def _generate_price_update(self, symbol: str) -> float:
|
||||
"""Generate price update based on random walk"""
|
||||
current_price = self.latest_prices.get(
|
||||
symbol,
|
||||
self.base_prices[symbol],
|
||||
)
|
||||
|
||||
change_percent = random.uniform(-self.volatility, self.volatility)
|
||||
new_price = current_price * (1 + change_percent / 100)
|
||||
|
||||
# 10% chance of larger movement
|
||||
if random.random() < 0.1:
|
||||
trend_factor = random.uniform(-2, 2)
|
||||
new_price = new_price * (1 + trend_factor / 100)
|
||||
|
||||
# Limit intraday movement to +/-10%
|
||||
open_price = self.open_prices[symbol]
|
||||
max_price = open_price * 1.10
|
||||
min_price = open_price * 0.90
|
||||
new_price = max(min_price, min(max_price, new_price))
|
||||
|
||||
return new_price
|
||||
|
||||
def _update_prices(self):
|
||||
"""Update prices for all subscribed stocks"""
|
||||
timestamp = int(time.time() * 1000)
|
||||
|
||||
for symbol in self.subscribed_symbols:
|
||||
try:
|
||||
new_price = self._generate_price_update(symbol)
|
||||
self.latest_prices[symbol] = new_price
|
||||
|
||||
open_price = self.open_prices[symbol]
|
||||
ret = ((new_price - open_price) / open_price) * 100
|
||||
|
||||
price_data = {
|
||||
"symbol": symbol,
|
||||
"price": new_price,
|
||||
"timestamp": timestamp,
|
||||
"volume": random.randint(1000000, 10000000),
|
||||
"open": open_price,
|
||||
"high": max(new_price, open_price),
|
||||
"low": min(new_price, open_price),
|
||||
"previous_close": open_price,
|
||||
"ret": ret,
|
||||
}
|
||||
|
||||
for callback in self.price_callbacks:
|
||||
try:
|
||||
callback(price_data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Mock price callback error ({symbol}): {e}",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Mock {symbol}: ${new_price:.2f} [ret: {ret:+.2f}%]",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate mock price ({symbol}): {e}")
|
||||
|
||||
def _polling_loop(self):
|
||||
"""Main polling loop"""
|
||||
logger.info(
|
||||
f"Mock price generation started (interval: {self.poll_interval}s)",
|
||||
)
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
start_time = time.time()
|
||||
self._update_prices()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.poll_interval - elapsed)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Mock polling loop error: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
def start(self):
|
||||
"""Start mock price generation"""
|
||||
if self.running:
|
||||
logger.warning("Mock price manager already running")
|
||||
return
|
||||
|
||||
if not self.subscribed_symbols:
|
||||
logger.warning("No stocks subscribed")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._thread = threading.Thread(target=self._polling_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
logger.info(
|
||||
f"Mock price manager started: {', '.join(self.subscribed_symbols)}", # noqa: E501
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
"""Stop mock price generation"""
|
||||
self.running = False
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
logger.info("Mock price manager stopped")
|
||||
|
||||
def get_latest_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get latest price for symbol"""
|
||||
return self.latest_prices.get(symbol)
|
||||
|
||||
def get_all_latest_prices(self) -> Dict[str, float]:
|
||||
"""Get all latest prices"""
|
||||
return self.latest_prices.copy()
|
||||
|
||||
def get_open_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get open price for symbol"""
|
||||
return self.open_prices.get(symbol)
|
||||
|
||||
def reset_open_prices(self):
|
||||
"""Reset open prices for new trading day"""
|
||||
for symbol in self.subscribed_symbols:
|
||||
last_close = self.latest_prices[symbol]
|
||||
gap_percent = random.uniform(-1, 1)
|
||||
new_open = last_close * (1 + gap_percent / 100)
|
||||
self.open_prices[symbol] = new_open
|
||||
self.latest_prices[symbol] = new_open
|
||||
logger.info("Open prices reset")
|
||||
|
||||
def set_base_price(self, symbol: str, price: float):
|
||||
"""Manually set base price for testing"""
|
||||
if symbol in self.subscribed_symbols:
|
||||
self.base_prices[symbol] = price
|
||||
self.open_prices[symbol] = price
|
||||
self.latest_prices[symbol] = price
|
||||
logger.info(f"{symbol} base price set to: ${price:.2f}")
|
||||
else:
|
||||
logger.warning(f"{symbol} not subscribed")
|
||||
64
backend/data/news_alignment.py
Normal file
64
backend/data/news_alignment.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Align persisted news to the nearest NYSE trading date."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import time
|
||||
|
||||
import pandas as pd
|
||||
import pandas_market_calendars as mcal
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
|
||||
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
def _next_trading_day(date_str: str) -> str:
|
||||
start = pd.Timestamp(date_str).tz_localize(None)
|
||||
sessions = NYSE_CALENDAR.valid_days(
|
||||
start_date=(start - pd.Timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||
end_date=(start + pd.Timedelta(days=10)).strftime("%Y-%m-%d"),
|
||||
)
|
||||
future = [
|
||||
pd.Timestamp(day).tz_localize(None).strftime("%Y-%m-%d")
|
||||
for day in sessions
|
||||
if pd.Timestamp(day).tz_localize(None) >= start
|
||||
]
|
||||
return future[0] if future else date_str
|
||||
|
||||
|
||||
def resolve_trade_date(published_utc: str | None) -> str | None:
|
||||
"""Map a published timestamp to an NYSE trade date."""
|
||||
if not published_utc:
|
||||
return None
|
||||
timestamp = pd.to_datetime(published_utc, utc=True, errors="coerce")
|
||||
if pd.isna(timestamp):
|
||||
return None
|
||||
nyse_time = timestamp.tz_convert("America/New_York")
|
||||
candidate = nyse_time.date().isoformat()
|
||||
valid_days = NYSE_CALENDAR.valid_days(start_date=candidate, end_date=candidate)
|
||||
if len(valid_days) == 0:
|
||||
return _next_trading_day(candidate)
|
||||
if nyse_time.time() >= time(16, 0):
|
||||
return _next_trading_day((nyse_time + pd.Timedelta(days=1)).date().isoformat())
|
||||
return candidate
|
||||
|
||||
|
||||
def align_news_for_symbol(store: MarketStore, symbol: str, *, limit: int = 5000) -> int:
|
||||
"""Fill missing trade_date values for one ticker."""
|
||||
pending = store.get_news_without_trade_date(symbol, limit=limit)
|
||||
updates = []
|
||||
for row in pending:
|
||||
trade_date = resolve_trade_date(row.get("published_utc"))
|
||||
if trade_date:
|
||||
updates.append(
|
||||
{
|
||||
"news_id": row["news_id"],
|
||||
"symbol": row["symbol"],
|
||||
"trade_date": trade_date,
|
||||
}
|
||||
)
|
||||
if not updates:
|
||||
return 0
|
||||
return store.set_trade_dates(updates)
|
||||
@@ -15,6 +15,9 @@ from backend.data.provider_utils import normalize_symbol
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SUPPRESSED_LOG_EVERY = 20
|
||||
|
||||
|
||||
class PollingPriceManager:
|
||||
"""Polling-based price manager using Finnhub or yfinance."""
|
||||
|
||||
@@ -43,6 +46,7 @@ class PollingPriceManager:
|
||||
self.latest_prices: Dict[str, float] = {}
|
||||
self.open_prices: Dict[str, float] = {}
|
||||
self.price_callbacks: List[Callable] = []
|
||||
self._failure_counts: Dict[str, int] = {}
|
||||
|
||||
self.running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
@@ -77,6 +81,8 @@ class PollingPriceManager:
|
||||
for symbol in self.subscribed_symbols:
|
||||
try:
|
||||
quote_data = self._fetch_quote(symbol)
|
||||
if not isinstance(quote_data, dict):
|
||||
raise ValueError(f"{symbol}: Empty quote payload")
|
||||
|
||||
current_price = quote_data.get("c")
|
||||
open_price = quote_data.get("o")
|
||||
@@ -103,6 +109,13 @@ class PollingPriceManager:
|
||||
)
|
||||
|
||||
self.latest_prices[symbol] = current_price
|
||||
previous_failures = self._failure_counts.pop(symbol, 0)
|
||||
if previous_failures > 0:
|
||||
logger.info(
|
||||
"%s quote polling recovered after %d consecutive failures",
|
||||
symbol,
|
||||
previous_failures,
|
||||
)
|
||||
|
||||
price_data = {
|
||||
"symbol": symbol,
|
||||
@@ -128,7 +141,20 @@ class PollingPriceManager:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch {symbol} price: {e}")
|
||||
failure_count = self._failure_counts.get(symbol, 0) + 1
|
||||
self._failure_counts[symbol] = failure_count
|
||||
message = f"Failed to fetch {symbol} price: {e}"
|
||||
|
||||
if failure_count == 1:
|
||||
logger.warning(message)
|
||||
elif failure_count % _SUPPRESSED_LOG_EVERY == 0:
|
||||
logger.warning(
|
||||
"%s (repeated %d times; suppressing intermediate failures)",
|
||||
message,
|
||||
failure_count,
|
||||
)
|
||||
else:
|
||||
logger.debug(message)
|
||||
|
||||
def _fetch_quote(self, symbol: str) -> Dict[str, float]:
|
||||
"""Fetch a normalized quote payload from the configured provider."""
|
||||
@@ -136,7 +162,10 @@ class PollingPriceManager:
|
||||
return self._fetch_yfinance_quote(symbol)
|
||||
if not self.finnhub_client:
|
||||
raise ValueError("Finnhub API key required for finnhub polling")
|
||||
return self.finnhub_client.quote(symbol)
|
||||
quote = self.finnhub_client.quote(symbol)
|
||||
if not isinstance(quote, dict):
|
||||
raise ValueError(f"{symbol}: Invalid Finnhub quote payload")
|
||||
return quote
|
||||
|
||||
def _fetch_yfinance_quote(self, symbol: str) -> Dict[str, float]:
|
||||
"""Fetch quote data from yfinance and normalize to Finnhub-like keys."""
|
||||
@@ -162,6 +191,8 @@ class PollingPriceManager:
|
||||
|
||||
if current_price is None:
|
||||
history = ticker.history(period="1d", interval="1m", auto_adjust=False)
|
||||
if history is None:
|
||||
raise ValueError(f"{symbol}: yfinance returned no history frame")
|
||||
if history.empty:
|
||||
raise ValueError(f"{symbol}: No yfinance quote data")
|
||||
latest = history.iloc[-1]
|
||||
|
||||
161
backend/data/polygon_client.py
Normal file
161
backend/data/polygon_client.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Polygon client used for long-lived market research ingestion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
BASE = "https://api.polygon.io"
|
||||
|
||||
|
||||
def _headers() -> dict[str, str]:
|
||||
api_key = os.getenv("POLYGON_API_KEY", "").strip()
|
||||
if not api_key:
|
||||
raise ValueError("Missing required API key: POLYGON_API_KEY")
|
||||
return {"Authorization": f"Bearer {api_key}"}
|
||||
|
||||
|
||||
def http_get(
|
||||
url: str,
|
||||
params: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
max_retries: int = 8,
|
||||
backoff: float = 2.0,
|
||||
) -> requests.Response:
|
||||
"""HTTP GET with exponential backoff and 429 handling."""
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
response = requests.get(
|
||||
url,
|
||||
params=params or {},
|
||||
headers=_headers(),
|
||||
timeout=30,
|
||||
)
|
||||
except requests.RequestException:
|
||||
time.sleep((backoff**attempt) + 0.5)
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
continue
|
||||
|
||||
if response.status_code == 429:
|
||||
retry_after = response.headers.get("Retry-After")
|
||||
wait = (
|
||||
float(retry_after)
|
||||
if retry_after and retry_after.isdigit()
|
||||
else min((backoff**attempt) + 1.0, 60.0)
|
||||
)
|
||||
time.sleep(wait)
|
||||
if attempt == max_retries - 1:
|
||||
response.raise_for_status()
|
||||
continue
|
||||
|
||||
if 500 <= response.status_code < 600:
|
||||
time.sleep(min((backoff**attempt) + 1.0, 60.0))
|
||||
if attempt == max_retries - 1:
|
||||
response.raise_for_status()
|
||||
continue
|
||||
|
||||
response.raise_for_status()
|
||||
return response
|
||||
raise RuntimeError("Unreachable")
|
||||
|
||||
|
||||
def fetch_ticker_details(symbol: str) -> dict[str, Any]:
|
||||
"""Fetch company metadata from Polygon."""
|
||||
response = http_get(f"{BASE}/v3/reference/tickers/{symbol}")
|
||||
return response.json().get("results", {}) or {}
|
||||
|
||||
|
||||
def fetch_ohlc(symbol: str, start_date: str, end_date: str) -> list[dict[str, Any]]:
|
||||
"""Fetch daily OHLC data from Polygon."""
|
||||
response = http_get(
|
||||
f"{BASE}/v2/aggs/ticker/{symbol}/range/1/day/{start_date}/{end_date}",
|
||||
params={"adjusted": "true", "sort": "asc", "limit": 50000},
|
||||
)
|
||||
results = response.json().get("results") or []
|
||||
rows: list[dict[str, Any]] = []
|
||||
for item in results:
|
||||
rows.append(
|
||||
{
|
||||
"date": datetime.fromtimestamp(
|
||||
int(item["t"]) / 1000,
|
||||
tz=timezone.utc,
|
||||
).date().isoformat(),
|
||||
"open": item.get("o"),
|
||||
"high": item.get("h"),
|
||||
"low": item.get("l"),
|
||||
"close": item.get("c"),
|
||||
"volume": item.get("v"),
|
||||
"vwap": item.get("vw"),
|
||||
"transactions": item.get("n"),
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def fetch_news(
|
||||
symbol: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
*,
|
||||
per_page: int = 50,
|
||||
page_sleep: float = 1.2,
|
||||
max_pages: Optional[int] = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch all Polygon news for a ticker, with pagination."""
|
||||
url = f"{BASE}/v2/reference/news"
|
||||
params = {
|
||||
"ticker": symbol,
|
||||
"published_utc.gte": start_date,
|
||||
"published_utc.lte": end_date,
|
||||
"limit": per_page,
|
||||
"order": "asc",
|
||||
}
|
||||
next_url: Optional[str] = None
|
||||
pages = 0
|
||||
all_articles: list[dict[str, Any]] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
while True:
|
||||
response = http_get(next_url or url, params=None if next_url else params)
|
||||
data = response.json()
|
||||
results = data.get("results") or []
|
||||
if not results:
|
||||
break
|
||||
|
||||
for item in results:
|
||||
article_id = item.get("id")
|
||||
if article_id and article_id in seen_ids:
|
||||
continue
|
||||
all_articles.append(
|
||||
{
|
||||
"id": article_id,
|
||||
"publisher": (item.get("publisher") or {}).get("name"),
|
||||
"title": item.get("title"),
|
||||
"author": item.get("author"),
|
||||
"published_utc": item.get("published_utc"),
|
||||
"amp_url": item.get("amp_url"),
|
||||
"article_url": item.get("article_url"),
|
||||
"tickers": item.get("tickers"),
|
||||
"description": item.get("description"),
|
||||
"insights": item.get("insights"),
|
||||
}
|
||||
)
|
||||
if article_id:
|
||||
seen_ids.add(article_id)
|
||||
|
||||
next_url = data.get("next_url")
|
||||
pages += 1
|
||||
if max_pages is not None and pages >= max_pages:
|
||||
break
|
||||
if not next_url:
|
||||
break
|
||||
time.sleep(page_sleep)
|
||||
|
||||
return all_articles
|
||||
@@ -11,7 +11,7 @@ import pandas as pd
|
||||
import yfinance as yf
|
||||
|
||||
from backend.config.data_config import DataSource, get_data_sources
|
||||
from backend.data.schema import (
|
||||
from shared.schema import (
|
||||
CompanyFactsResponse,
|
||||
CompanyNews,
|
||||
CompanyNewsResponse,
|
||||
|
||||
@@ -1,184 +1,50 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pydantic import BaseModel
|
||||
"""Compatibility schema bridge.
|
||||
|
||||
This module preserves the legacy ``backend.data.schema`` import path while
|
||||
delegating the actual schema definitions to ``shared.schema``. Keeping one
|
||||
canonical DTO set avoids drift as the monolith is split into service-specific
|
||||
packages.
|
||||
"""
|
||||
|
||||
class Price(BaseModel):
|
||||
open: float
|
||||
close: float
|
||||
high: float
|
||||
low: float
|
||||
volume: int
|
||||
time: str
|
||||
from shared.schema import (
|
||||
AgentStateData,
|
||||
AgentStateMetadata,
|
||||
AnalystSignal,
|
||||
CompanyFacts,
|
||||
CompanyFactsResponse,
|
||||
CompanyNews,
|
||||
CompanyNewsResponse,
|
||||
FinancialMetrics,
|
||||
FinancialMetricsResponse,
|
||||
InsiderTrade,
|
||||
InsiderTradeResponse,
|
||||
LineItem,
|
||||
LineItemResponse,
|
||||
Portfolio,
|
||||
Position,
|
||||
Price,
|
||||
PriceResponse,
|
||||
TickerAnalysis,
|
||||
)
|
||||
|
||||
|
||||
class PriceResponse(BaseModel):
|
||||
ticker: str
|
||||
prices: list[Price]
|
||||
|
||||
|
||||
class FinancialMetrics(BaseModel):
|
||||
ticker: str
|
||||
report_period: str
|
||||
period: str
|
||||
currency: str
|
||||
market_cap: float | None
|
||||
enterprise_value: float | None
|
||||
price_to_earnings_ratio: float | None
|
||||
price_to_book_ratio: float | None
|
||||
price_to_sales_ratio: float | None
|
||||
enterprise_value_to_ebitda_ratio: float | None
|
||||
enterprise_value_to_revenue_ratio: float | None
|
||||
free_cash_flow_yield: float | None
|
||||
peg_ratio: float | None
|
||||
gross_margin: float | None
|
||||
operating_margin: float | None
|
||||
net_margin: float | None
|
||||
return_on_equity: float | None
|
||||
return_on_assets: float | None
|
||||
return_on_invested_capital: float | None
|
||||
asset_turnover: float | None
|
||||
inventory_turnover: float | None
|
||||
receivables_turnover: float | None
|
||||
days_sales_outstanding: float | None
|
||||
operating_cycle: float | None
|
||||
working_capital_turnover: float | None
|
||||
current_ratio: float | None
|
||||
quick_ratio: float | None
|
||||
cash_ratio: float | None
|
||||
operating_cash_flow_ratio: float | None
|
||||
debt_to_equity: float | None
|
||||
debt_to_assets: float | None
|
||||
interest_coverage: float | None
|
||||
revenue_growth: float | None
|
||||
earnings_growth: float | None
|
||||
book_value_growth: float | None
|
||||
earnings_per_share_growth: float | None
|
||||
free_cash_flow_growth: float | None
|
||||
operating_income_growth: float | None
|
||||
ebitda_growth: float | None
|
||||
payout_ratio: float | None
|
||||
earnings_per_share: float | None
|
||||
book_value_per_share: float | None
|
||||
free_cash_flow_per_share: float | None
|
||||
|
||||
|
||||
class FinancialMetricsResponse(BaseModel):
|
||||
financial_metrics: list[FinancialMetrics]
|
||||
|
||||
|
||||
class LineItem(BaseModel):
|
||||
ticker: str
|
||||
report_period: str
|
||||
period: str
|
||||
currency: str
|
||||
|
||||
# Allow additional fields dynamically
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class LineItemResponse(BaseModel):
|
||||
search_results: list[LineItem]
|
||||
|
||||
|
||||
class InsiderTrade(BaseModel):
|
||||
ticker: str
|
||||
issuer: str | None
|
||||
name: str | None
|
||||
title: str | None
|
||||
is_board_director: bool | None
|
||||
transaction_date: str | None
|
||||
transaction_shares: float | None
|
||||
transaction_price_per_share: float | None
|
||||
transaction_value: float | None
|
||||
shares_owned_before_transaction: float | None
|
||||
shares_owned_after_transaction: float | None
|
||||
security_title: str | None
|
||||
filing_date: str
|
||||
|
||||
|
||||
class InsiderTradeResponse(BaseModel):
|
||||
insider_trades: list[InsiderTrade]
|
||||
|
||||
|
||||
class CompanyNews(BaseModel):
|
||||
category: str | None = None
|
||||
ticker: str
|
||||
title: str
|
||||
related: str | None = None
|
||||
source: str
|
||||
date: str | None = None
|
||||
url: str
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class CompanyNewsResponse(BaseModel):
|
||||
news: list[CompanyNews]
|
||||
|
||||
|
||||
class CompanyFacts(BaseModel):
|
||||
ticker: str
|
||||
name: str
|
||||
cik: str | None = None
|
||||
industry: str | None = None
|
||||
sector: str | None = None
|
||||
category: str | None = None
|
||||
exchange: str | None = None
|
||||
is_active: bool | None = None
|
||||
listing_date: str | None = None
|
||||
location: str | None = None
|
||||
market_cap: float | None = None
|
||||
number_of_employees: int | None = None
|
||||
sec_filings_url: str | None = None
|
||||
sic_code: str | None = None
|
||||
sic_industry: str | None = None
|
||||
sic_sector: str | None = None
|
||||
website_url: str | None = None
|
||||
weighted_average_shares: int | None = None
|
||||
|
||||
|
||||
class CompanyFactsResponse(BaseModel):
|
||||
company_facts: CompanyFacts
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
"""Position information - for Portfolio mode"""
|
||||
|
||||
long: int = 0 # Long position quantity (shares)
|
||||
short: int = 0 # Short position quantity (shares)
|
||||
long_cost_basis: float = 0.0 # Long position average cost
|
||||
short_cost_basis: float = 0.0 # Short position average cost
|
||||
|
||||
|
||||
class Portfolio(BaseModel):
|
||||
"""Portfolio - for Portfolio mode"""
|
||||
|
||||
cash: float = 100000.0 # Available cash
|
||||
positions: dict[str, Position] = {} # ticker -> Position mapping
|
||||
# Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin)
|
||||
margin_requirement: float = 0.0
|
||||
margin_used: float = 0.0 # Margin used
|
||||
|
||||
|
||||
class AnalystSignal(BaseModel):
|
||||
signal: str | None = None
|
||||
confidence: float | None = None
|
||||
reasoning: dict | str | None = None
|
||||
max_position_size: float | None = None # For risk management signals
|
||||
|
||||
|
||||
class TickerAnalysis(BaseModel):
|
||||
ticker: str
|
||||
analyst_signals: dict[str, AnalystSignal] # agent_name -> signal mapping
|
||||
|
||||
|
||||
class AgentStateData(BaseModel):
|
||||
tickers: list[str]
|
||||
portfolio: Portfolio
|
||||
start_date: str
|
||||
end_date: str
|
||||
ticker_analyses: dict[str, TickerAnalysis] # ticker -> analysis mapping
|
||||
|
||||
|
||||
class AgentStateMetadata(BaseModel):
|
||||
show_reasoning: bool = False
|
||||
model_config = {"extra": "allow"}
|
||||
__all__ = [
|
||||
"Price",
|
||||
"PriceResponse",
|
||||
"FinancialMetrics",
|
||||
"FinancialMetricsResponse",
|
||||
"LineItem",
|
||||
"LineItemResponse",
|
||||
"InsiderTrade",
|
||||
"InsiderTradeResponse",
|
||||
"CompanyNews",
|
||||
"CompanyNewsResponse",
|
||||
"CompanyFacts",
|
||||
"CompanyFactsResponse",
|
||||
"Position",
|
||||
"Portfolio",
|
||||
"AnalystSignal",
|
||||
"TickerAnalysis",
|
||||
"AgentStateData",
|
||||
"AgentStateMetadata",
|
||||
]
|
||||
|
||||
2
backend/domains/__init__.py
Normal file
2
backend/domains/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Domain modules for split service internals."""
|
||||
320
backend/domains/news.py
Normal file
320
backend/domains/news.py
Normal file
@@ -0,0 +1,320 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""News/explain domain helpers shared by app surfaces and gateway fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.data.market_ingest import update_ticker_incremental
|
||||
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||
from backend.explain.range_explainer import build_range_explanation
|
||||
from backend.explain.similarity_service import find_similar_days
|
||||
from backend.explain.story_service import get_or_create_stock_story
|
||||
|
||||
|
||||
def news_rows_need_enrichment(rows: list[dict[str, Any]]) -> bool:
|
||||
"""Return whether news rows are missing explain-oriented analysis fields."""
|
||||
if not rows:
|
||||
return True
|
||||
return all(
|
||||
not row.get("sentiment")
|
||||
and not row.get("relevance")
|
||||
and not row.get("key_discussion")
|
||||
for row in rows
|
||||
)
|
||||
|
||||
|
||||
def ensure_news_fresh(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
target_date: str | None = None,
|
||||
refresh_if_stale: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Refresh raw news incrementally when stored watermarks are stale."""
|
||||
normalized_target = str(target_date or "").strip()[:10]
|
||||
if not normalized_target:
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"target_date": None,
|
||||
"last_news_fetch": None,
|
||||
"refreshed": False,
|
||||
}
|
||||
|
||||
watermarks = store.get_ticker_watermarks(ticker)
|
||||
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
|
||||
refreshed = False
|
||||
if refresh_if_stale and (not last_news_fetch or last_news_fetch < normalized_target):
|
||||
update_ticker_incremental(
|
||||
ticker,
|
||||
end_date=normalized_target,
|
||||
store=store,
|
||||
)
|
||||
refreshed = True
|
||||
watermarks = store.get_ticker_watermarks(ticker)
|
||||
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
|
||||
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"target_date": normalized_target,
|
||||
"last_news_fetch": last_news_fetch or None,
|
||||
"refreshed": refreshed,
|
||||
}
|
||||
|
||||
|
||||
def get_enriched_news(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 100,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=end_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
if news_rows_need_enrichment(rows):
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
return {"ticker": ticker, "news": rows, "freshness": freshness}
|
||||
|
||||
|
||||
def get_news_for_date(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
date: str,
|
||||
limit: int = 20,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
trade_date=date,
|
||||
limit=limit,
|
||||
)
|
||||
if news_rows_need_enrichment(rows):
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=date,
|
||||
end_date=date,
|
||||
limit=limit,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
trade_date=date,
|
||||
limit=limit,
|
||||
)
|
||||
return {"ticker": ticker, "date": date, "news": rows, "freshness": freshness}
|
||||
|
||||
|
||||
def get_news_timeline(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=end_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
timeline = store.get_news_timeline_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
if not timeline:
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=200,
|
||||
)
|
||||
timeline = store.get_news_timeline_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"timeline": timeline,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"freshness": freshness,
|
||||
}
|
||||
|
||||
|
||||
def get_news_categories(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 200,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=end_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
if news_rows_need_enrichment(rows):
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
categories = store.get_news_categories_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
return {"ticker": ticker, "categories": categories, "freshness": freshness}
|
||||
|
||||
|
||||
def get_similar_days_payload(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
date: str,
|
||||
n_similar: int = 5,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
result = find_similar_days(
|
||||
store,
|
||||
symbol=ticker,
|
||||
target_date=date,
|
||||
top_k=n_similar,
|
||||
)
|
||||
result["freshness"] = freshness
|
||||
return result
|
||||
|
||||
|
||||
def get_story_payload(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
as_of_date: str,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=as_of_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
end_date=as_of_date,
|
||||
limit=80,
|
||||
)
|
||||
result = get_or_create_stock_story(
|
||||
store,
|
||||
symbol=ticker,
|
||||
as_of_date=as_of_date,
|
||||
)
|
||||
result["freshness"] = freshness
|
||||
return result
|
||||
|
||||
|
||||
def get_range_explain_payload(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
article_ids: list[str] | None = None,
|
||||
limit: int = 100,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=end_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
news_rows = []
|
||||
if article_ids:
|
||||
news_rows = store.get_news_by_ids_enriched(ticker, article_ids)
|
||||
if not news_rows:
|
||||
news_rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
if news_rows_need_enrichment(news_rows):
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
news_rows = (
|
||||
store.get_news_by_ids_enriched(ticker, article_ids)
|
||||
if article_ids
|
||||
else store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
)
|
||||
result = build_range_explanation(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
news_rows=news_rows,
|
||||
)
|
||||
return {"ticker": ticker, "result": result, "freshness": freshness}
|
||||
106
backend/domains/trading.py
Normal file
106
backend/domains/trading.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Trading domain helpers shared by app surfaces and gateway fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.services.market import MarketService
|
||||
from backend.tools.data_tools import (
|
||||
get_company_news,
|
||||
get_financial_metrics,
|
||||
get_insider_trades,
|
||||
get_market_cap,
|
||||
get_prices,
|
||||
search_line_items,
|
||||
)
|
||||
|
||||
|
||||
def get_prices_payload(*, ticker: str, start_date: str, end_date: str) -> dict[str, Any]:
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"prices": get_prices(ticker, start_date, end_date),
|
||||
}
|
||||
|
||||
|
||||
def get_financials_payload(
|
||||
*,
|
||||
ticker: str,
|
||||
end_date: str,
|
||||
period: str = "ttm",
|
||||
limit: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"financial_metrics": get_financial_metrics(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_news_payload(
|
||||
*,
|
||||
ticker: str,
|
||||
end_date: str,
|
||||
start_date: str | None = None,
|
||||
limit: int = 1000,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"news": get_company_news(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_insider_trades_payload(
|
||||
*,
|
||||
ticker: str,
|
||||
end_date: str,
|
||||
start_date: str | None = None,
|
||||
limit: int = 1000,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"insider_trades": get_insider_trades(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_market_status_payload() -> dict[str, Any]:
|
||||
market_service = MarketService(tickers=[])
|
||||
return market_service.get_market_status()
|
||||
|
||||
|
||||
def get_market_cap_payload(*, ticker: str, end_date: str) -> dict[str, Any]:
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"end_date": end_date,
|
||||
"market_cap": get_market_cap(ticker, end_date),
|
||||
}
|
||||
|
||||
|
||||
def get_line_items_payload(
|
||||
*,
|
||||
ticker: str,
|
||||
line_items: list[str],
|
||||
end_date: str,
|
||||
period: str = "ttm",
|
||||
limit: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"search_results": search_line_items(
|
||||
ticker=ticker,
|
||||
line_items=line_items,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
}
|
||||
2
backend/enrich/__init__.py
Normal file
2
backend/enrich/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""News enrichment utilities for explain-oriented market research."""
|
||||
|
||||
301
backend/enrich/llm_enricher.py
Normal file
301
backend/enrich/llm_enricher.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Optional AgentScope-backed news enrichment with safe local fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.config.env_config import canonicalize_model_provider, get_env_bool, get_env_str
|
||||
from backend.llm.models import create_model
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EnrichedNewsItem(BaseModel):
|
||||
"""Structured output schema for one enriched article."""
|
||||
|
||||
id: str = Field(description="The source article id")
|
||||
relevance: str = Field(description="One of high, medium, low")
|
||||
sentiment: str = Field(description="One of positive, negative, neutral")
|
||||
key_discussion: str = Field(description="Concise core discussion")
|
||||
summary: str = Field(description="Concise factual summary")
|
||||
reason_growth: str = Field(description="Growth-oriented reason if present")
|
||||
reason_decrease: str = Field(description="Downside-oriented reason if present")
|
||||
|
||||
|
||||
class EnrichedNewsBatch(BaseModel):
|
||||
"""Structured output schema for a batch of enriched articles."""
|
||||
|
||||
items: list[EnrichedNewsItem]
|
||||
|
||||
|
||||
class RangeAnalysisPayload(BaseModel):
|
||||
"""Structured output schema for range explanation text."""
|
||||
|
||||
summary: str = Field(description="Concise Chinese range summary for the selected window")
|
||||
trend_analysis: str = Field(description="Concise Chinese trend explanation for the selected window")
|
||||
bullish_factors: list[str] = Field(description="Top bullish factors in Chinese")
|
||||
bearish_factors: list[str] = Field(description="Top bearish factors in Chinese")
|
||||
|
||||
|
||||
def get_explain_model_info() -> dict[str, str]:
|
||||
"""Resolve provider/model used by explain enrichment."""
|
||||
provider = canonicalize_model_provider(
|
||||
get_env_str("EXPLAIN_ENRICH_MODEL_PROVIDER")
|
||||
or get_env_str("MODEL_PROVIDER", "OPENAI"),
|
||||
)
|
||||
model_name = get_env_str("EXPLAIN_ENRICH_MODEL_NAME") or get_env_str(
|
||||
"MODEL_NAME",
|
||||
"gpt-4o-mini",
|
||||
)
|
||||
return {
|
||||
"provider": provider,
|
||||
"model_name": model_name,
|
||||
"label": f"{provider}:{model_name}",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_enrichment_payload(payload: Any) -> dict[str, Any] | None:
|
||||
if isinstance(payload, BaseModel):
|
||||
payload = payload.model_dump()
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
return {
|
||||
"relevance": str(payload.get("relevance") or "").strip().lower() or None,
|
||||
"sentiment": str(payload.get("sentiment") or "").strip().lower() or None,
|
||||
"key_discussion": str(payload.get("key_discussion") or "").strip() or None,
|
||||
"summary": str(payload.get("summary") or "").strip() or None,
|
||||
"reason_growth": str(payload.get("reason_growth") or "").strip() or None,
|
||||
"reason_decrease": str(payload.get("reason_decrease") or "").strip() or None,
|
||||
"raw_json": payload,
|
||||
}
|
||||
|
||||
|
||||
def _run_async(coro: Any) -> Any:
|
||||
"""Run an async AgentScope model call from sync code, even inside a running loop."""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(asyncio.run, coro)
|
||||
return future.result()
|
||||
|
||||
|
||||
def _get_explain_model():
|
||||
"""Create an AgentScope model for explain enrichment."""
|
||||
model_info = get_explain_model_info()
|
||||
return create_model(
|
||||
model_name=model_info["model_name"],
|
||||
provider=model_info["provider"],
|
||||
stream=False,
|
||||
generate_kwargs={"temperature": 0.1},
|
||||
)
|
||||
|
||||
|
||||
def llm_enrichment_enabled() -> bool:
|
||||
"""Return whether AgentScope-backed LLM enrichment should be attempted."""
|
||||
if not get_env_bool("EXPLAIN_ENRICH_USE_LLM", False):
|
||||
return False
|
||||
provider = get_explain_model_info()["provider"]
|
||||
provider_key_map = {
|
||||
"OPENAI": "OPENAI_API_KEY",
|
||||
"ANTHROPIC": "ANTHROPIC_API_KEY",
|
||||
"DASHSCOPE": "DASHSCOPE_API_KEY",
|
||||
"ALIBABA": "DASHSCOPE_API_KEY",
|
||||
"GEMINI": "GOOGLE_API_KEY",
|
||||
"GOOGLE": "GOOGLE_API_KEY",
|
||||
"DEEPSEEK": "DEEPSEEK_API_KEY",
|
||||
"GROQ": "GROQ_API_KEY",
|
||||
"OPENROUTER": "OPENROUTER_API_KEY",
|
||||
}
|
||||
env_key = provider_key_map.get(provider)
|
||||
return bool(get_env_str(env_key)) if env_key else provider == "OLLAMA"
|
||||
|
||||
|
||||
def llm_range_analysis_enabled() -> bool:
|
||||
"""Return whether LLM range analysis should be attempted."""
|
||||
raw_value = get_env_str("EXPLAIN_RANGE_USE_LLM")
|
||||
if raw_value is not None and str(raw_value).strip() != "":
|
||||
return get_env_bool("EXPLAIN_RANGE_USE_LLM", False) and llm_enrichment_enabled()
|
||||
return llm_enrichment_enabled()
|
||||
|
||||
|
||||
def analyze_news_row_with_llm(row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate explain-oriented structured analysis for one article."""
|
||||
if not llm_enrichment_enabled():
|
||||
return None
|
||||
|
||||
model = _get_explain_model()
|
||||
title = str(row.get("title") or "").strip()
|
||||
summary = str(row.get("summary") or "").strip()
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You produce concise structured financial news analysis. "
|
||||
"Use only the requested fields and keep content factual."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Analyze this stock-news article for an explain UI.\n"
|
||||
"Rules:\n"
|
||||
"- relevance must be one of: high, medium, low\n"
|
||||
"- sentiment must be one of: positive, negative, neutral\n"
|
||||
"- keep each text field concise and factual\n"
|
||||
f"- article id: {str(row.get('id') or '').strip()}\n"
|
||||
f"Title: {title}\n"
|
||||
f"Summary: {summary}\n"
|
||||
),
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = _run_async(model(messages=messages, structured_model=EnrichedNewsItem))
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM enrichment failed: {e}")
|
||||
return None
|
||||
|
||||
payload = _normalize_enrichment_payload(getattr(response, "metadata", None))
|
||||
if payload:
|
||||
payload.setdefault("raw_json", {})
|
||||
payload["raw_json"]["model_provider"] = get_explain_model_info()["provider"]
|
||||
payload["raw_json"]["model_name"] = get_explain_model_info()["model_name"]
|
||||
payload["raw_json"]["model_label"] = get_explain_model_info()["label"]
|
||||
return payload
|
||||
|
||||
|
||||
def analyze_news_rows_with_llm(rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||
"""Generate structured analysis for multiple articles in one request."""
|
||||
if not llm_enrichment_enabled() or not rows:
|
||||
return {}
|
||||
|
||||
payload_rows = [
|
||||
{
|
||||
"id": str(row.get("id") or "").strip(),
|
||||
"title": str(row.get("title") or "").strip(),
|
||||
"summary": str(row.get("summary") or "").strip(),
|
||||
}
|
||||
for row in rows
|
||||
if str(row.get("id") or "").strip()
|
||||
]
|
||||
if not payload_rows:
|
||||
return {}
|
||||
|
||||
model = _get_explain_model()
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You produce concise structured financial news analysis in JSON. "
|
||||
"Preserve ids exactly and do not invent extra items."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Analyze these stock-news articles for an explain UI.\n"
|
||||
"For each item return: id, relevance, sentiment, key_discussion, summary, "
|
||||
"reason_growth, reason_decrease.\n"
|
||||
"Rules:\n"
|
||||
"- relevance must be one of: high, medium, low\n"
|
||||
"- sentiment must be one of: positive, negative, neutral\n"
|
||||
"- keep all text concise and factual\n"
|
||||
f"Articles: {payload_rows}"
|
||||
),
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = _run_async(
|
||||
model(messages=messages, structured_model=EnrichedNewsBatch),
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
metadata = getattr(response, "metadata", None)
|
||||
if isinstance(metadata, BaseModel):
|
||||
metadata = metadata.model_dump()
|
||||
items = metadata.get("items") if isinstance(metadata, dict) else None
|
||||
if not isinstance(items, list):
|
||||
return {}
|
||||
|
||||
results: dict[str, dict[str, Any]] = {}
|
||||
for item in items:
|
||||
normalized = _normalize_enrichment_payload(item)
|
||||
news_id = str((item.model_dump() if isinstance(item, BaseModel) else item).get("id") or "").strip() if isinstance(item, (dict, BaseModel)) else ""
|
||||
if normalized and news_id:
|
||||
normalized.setdefault("raw_json", {})
|
||||
normalized["raw_json"]["model_provider"] = get_explain_model_info()["provider"]
|
||||
normalized["raw_json"]["model_name"] = get_explain_model_info()["model_name"]
|
||||
normalized["raw_json"]["model_label"] = get_explain_model_info()["label"]
|
||||
results[news_id] = normalized
|
||||
return results
|
||||
|
||||
|
||||
def analyze_range_with_llm(payload: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate explain-oriented range summary and factor refinement."""
|
||||
if not llm_range_analysis_enabled():
|
||||
return None
|
||||
|
||||
model = _get_explain_model()
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You write concise Chinese stock range analysis for an explain UI. "
|
||||
"Use only the supplied facts. Keep the tone factual and analyst-like."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"请基于给定事实生成区间分析。\n"
|
||||
"输出字段:summary, trend_analysis, bullish_factors, bearish_factors。\n"
|
||||
"要求:\n"
|
||||
"- 全部使用简体中文\n"
|
||||
"- summary 1到2句,概括区间走势、新闻密度和主导主题\n"
|
||||
"- trend_analysis 1句,解释区间内部阶段变化\n"
|
||||
"- bullish_factors 和 bearish_factors 各返回最多3条短句\n"
|
||||
"- 不要编造未提供的信息\n"
|
||||
f"事实数据: {payload}"
|
||||
),
|
||||
},
|
||||
]
|
||||
try:
|
||||
response = _run_async(
|
||||
model(messages=messages, structured_model=RangeAnalysisPayload),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM enrichment failed: {e}")
|
||||
return None
|
||||
|
||||
metadata = getattr(response, "metadata", None)
|
||||
if isinstance(metadata, BaseModel):
|
||||
metadata = metadata.model_dump()
|
||||
if not isinstance(metadata, dict):
|
||||
return None
|
||||
|
||||
return {
|
||||
"summary": str(metadata.get("summary") or "").strip() or None,
|
||||
"trend_analysis": str(metadata.get("trend_analysis") or "").strip() or None,
|
||||
"bullish_factors": [
|
||||
str(item).strip()
|
||||
for item in list(metadata.get("bullish_factors") or [])
|
||||
if str(item).strip()
|
||||
][:3],
|
||||
"bearish_factors": [
|
||||
str(item).strip()
|
||||
for item in list(metadata.get("bearish_factors") or [])
|
||||
if str(item).strip()
|
||||
][:3],
|
||||
"model_provider": get_explain_model_info()["provider"],
|
||||
"model_name": get_explain_model_info()["model_name"],
|
||||
"model_label": get_explain_model_info()["label"],
|
||||
}
|
||||
362
backend/enrich/news_enricher.py
Normal file
362
backend/enrich/news_enricher.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Lightweight news enrichment for explain-oriented market analysis."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
from backend.config.env_config import get_env_int
|
||||
from backend.enrich.llm_enricher import (
|
||||
analyze_news_row_with_llm,
|
||||
analyze_news_rows_with_llm,
|
||||
llm_enrichment_enabled,
|
||||
)
|
||||
from backend.data.market_store import MarketStore
|
||||
|
||||
|
||||
POSITIVE_KEYWORDS = (
|
||||
"beat", "surge", "gain", "growth", "record", "upgrade", "strong",
|
||||
"partnership", "approved", "launch", "expands", "profit",
|
||||
)
|
||||
NEGATIVE_KEYWORDS = (
|
||||
"miss", "drop", "fall", "cut", "downgrade", "weak", "warning",
|
||||
"delay", "lawsuit", "probe", "tariff", "decline", "layoff",
|
||||
)
|
||||
HIGH_RELEVANCE_KEYWORDS = (
|
||||
"earnings", "guidance", "profit", "revenue", "ceo", "fda", "tariff",
|
||||
"regulation", "acquisition", "buyback", "forecast", "launch",
|
||||
)
|
||||
|
||||
|
||||
def _dedupe_key(row: dict[str, Any]) -> str:
|
||||
trade_date = str(row.get("trade_date") or row.get("date") or "")[:10]
|
||||
title = str(row.get("title") or "").strip().lower()
|
||||
summary = str(row.get("summary") or "").strip().lower()[:160]
|
||||
raw = f"{trade_date}::{title}::{summary}"
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _chunk_rows(rows: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]:
|
||||
chunk_size = max(1, int(size))
|
||||
return [rows[index:index + chunk_size] for index in range(0, len(rows), chunk_size)]
|
||||
|
||||
|
||||
def classify_news_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return a lightweight explain-oriented analysis for one article."""
|
||||
llm_result = analyze_news_row_with_llm(row)
|
||||
if isinstance(llm_result, dict):
|
||||
merged = dict(llm_result)
|
||||
merged.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
|
||||
merged.setdefault("raw_json", row)
|
||||
merged["analysis_source"] = "llm"
|
||||
return merged
|
||||
|
||||
title = str(row.get("title") or "").strip()
|
||||
summary = str(row.get("summary") or "").strip()
|
||||
text = f"{title} {summary}".lower()
|
||||
|
||||
positive_hits = [keyword for keyword in POSITIVE_KEYWORDS if keyword in text]
|
||||
negative_hits = [keyword for keyword in NEGATIVE_KEYWORDS if keyword in text]
|
||||
relevance_hits = [keyword for keyword in HIGH_RELEVANCE_KEYWORDS if keyword in text]
|
||||
|
||||
if len(positive_hits) > len(negative_hits):
|
||||
sentiment = "positive"
|
||||
elif len(negative_hits) > len(positive_hits):
|
||||
sentiment = "negative"
|
||||
else:
|
||||
sentiment = "neutral"
|
||||
|
||||
relevance = "high" if relevance_hits else "medium" if title else "low"
|
||||
summary_text = summary or title
|
||||
key_discussion = ""
|
||||
if relevance_hits:
|
||||
key_discussion = f"核心主题集中在 {', '.join(relevance_hits[:3])}"
|
||||
elif summary_text:
|
||||
key_discussion = summary_text[:160]
|
||||
|
||||
reason_growth = ""
|
||||
reason_decrease = ""
|
||||
if sentiment == "positive":
|
||||
reason_growth = summary_text[:200]
|
||||
elif sentiment == "negative":
|
||||
reason_decrease = summary_text[:200]
|
||||
|
||||
return {
|
||||
"relevance": relevance,
|
||||
"sentiment": sentiment,
|
||||
"key_discussion": key_discussion,
|
||||
"summary": summary_text[:280],
|
||||
"reason_growth": reason_growth,
|
||||
"reason_decrease": reason_decrease,
|
||||
"analysis_source": "local",
|
||||
"raw_json": row,
|
||||
}
|
||||
|
||||
|
||||
def attach_forward_returns(
|
||||
*,
|
||||
news_rows: list[dict[str, Any]],
|
||||
ohlc_rows: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Attach forward-return labels to each analyzed row."""
|
||||
if not ohlc_rows:
|
||||
return news_rows
|
||||
|
||||
closes_by_date = {
|
||||
str(row.get("date")): float(row.get("close"))
|
||||
for row in ohlc_rows
|
||||
if row.get("date") is not None and row.get("close") is not None
|
||||
}
|
||||
ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date") is not None]
|
||||
date_index = {date: idx for idx, date in enumerate(ordered_dates)}
|
||||
|
||||
horizons = {
|
||||
"ret_t0": 0,
|
||||
"ret_t1": 1,
|
||||
"ret_t3": 3,
|
||||
"ret_t5": 5,
|
||||
"ret_t10": 10,
|
||||
}
|
||||
|
||||
enriched: list[dict[str, Any]] = []
|
||||
for row in news_rows:
|
||||
trade_date = str(row.get("trade_date") or "")[:10]
|
||||
base_close = closes_by_date.get(trade_date)
|
||||
if not trade_date or base_close in (None, 0):
|
||||
enriched.append(row)
|
||||
continue
|
||||
|
||||
next_row = dict(row)
|
||||
base_index = date_index.get(trade_date)
|
||||
if base_index is None:
|
||||
enriched.append(next_row)
|
||||
continue
|
||||
|
||||
for field, offset in horizons.items():
|
||||
target_index = base_index + offset
|
||||
if target_index >= len(ordered_dates):
|
||||
next_row[field] = None
|
||||
continue
|
||||
target_close = closes_by_date.get(ordered_dates[target_index])
|
||||
next_row[field] = (
|
||||
(float(target_close) - float(base_close)) / float(base_close)
|
||||
if target_close not in (None, 0)
|
||||
else None
|
||||
)
|
||||
enriched.append(next_row)
|
||||
return enriched
|
||||
|
||||
|
||||
def build_analysis_rows(
|
||||
*,
|
||||
symbol: str,
|
||||
news_rows: list[dict[str, Any]],
|
||||
ohlc_rows: list[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], dict[str, int]]:
|
||||
"""Transform raw news rows into market_store news_analysis payloads plus stats."""
|
||||
llm_results: dict[str, dict[str, Any]] = {}
|
||||
if llm_enrichment_enabled():
|
||||
batch_size = get_env_int("EXPLAIN_ENRICH_BATCH_SIZE", 8)
|
||||
for chunk in _chunk_rows(news_rows, batch_size):
|
||||
llm_results.update(analyze_news_rows_with_llm(chunk))
|
||||
|
||||
staged_rows: list[dict[str, Any]] = []
|
||||
seen_dedupe_keys: set[str] = set()
|
||||
deduped_count = 0
|
||||
llm_count = 0
|
||||
local_count = 0
|
||||
for row in news_rows:
|
||||
news_id = str(row.get("id") or "").strip()
|
||||
if not news_id:
|
||||
continue
|
||||
dedupe_key = _dedupe_key(row)
|
||||
if dedupe_key in seen_dedupe_keys:
|
||||
deduped_count += 1
|
||||
continue
|
||||
seen_dedupe_keys.add(dedupe_key)
|
||||
batch_result = llm_results.get(news_id)
|
||||
if isinstance(batch_result, dict):
|
||||
analysis = dict(batch_result)
|
||||
analysis.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
|
||||
analysis.setdefault("raw_json", row)
|
||||
analysis["analysis_source"] = "llm"
|
||||
llm_count += 1
|
||||
else:
|
||||
analysis = classify_news_row(row)
|
||||
if analysis.get("analysis_source") == "llm":
|
||||
llm_count += 1
|
||||
else:
|
||||
local_count += 1
|
||||
staged_rows.append(
|
||||
{
|
||||
"news_id": news_id,
|
||||
"trade_date": str(row.get("trade_date") or "")[:10] or None,
|
||||
**analysis,
|
||||
}
|
||||
)
|
||||
return (
|
||||
attach_forward_returns(news_rows=staged_rows, ohlc_rows=ohlc_rows),
|
||||
{
|
||||
"deduped_count": deduped_count,
|
||||
"llm_count": llm_count,
|
||||
"local_count": local_count,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def enrich_news_for_symbol(
|
||||
store: MarketStore,
|
||||
symbol: str,
|
||||
*,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 200,
|
||||
analysis_source: str = "local",
|
||||
skip_existing: bool = True,
|
||||
only_reanalyze_local: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Read raw market news, compute explain fields, and persist them."""
|
||||
normalized_symbol = str(symbol or "").strip().upper()
|
||||
if not normalized_symbol:
|
||||
return {"symbol": "", "analyzed": 0}
|
||||
|
||||
news_rows = store.get_news_items(
|
||||
normalized_symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
total_news_count = len(news_rows)
|
||||
skipped_existing_count = 0
|
||||
analyzed_sources: dict[str, str] = {}
|
||||
skipped_missing_analysis_count = 0
|
||||
skipped_non_local_count = 0
|
||||
if news_rows and only_reanalyze_local:
|
||||
analyzed_sources = store.get_analyzed_news_sources(
|
||||
normalized_symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
skipped_missing_analysis_count = sum(
|
||||
1
|
||||
for row in news_rows
|
||||
if str(row.get("id") or "").strip() not in analyzed_sources
|
||||
)
|
||||
skipped_non_local_count = sum(
|
||||
1
|
||||
for row in news_rows
|
||||
if str(row.get("id") or "").strip() in analyzed_sources
|
||||
and analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
|
||||
)
|
||||
skipped_existing_count = sum(
|
||||
1
|
||||
for row in news_rows
|
||||
if str(row.get("id") or "").strip() not in analyzed_sources
|
||||
or analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
|
||||
)
|
||||
news_rows = [
|
||||
row for row in news_rows
|
||||
if analyzed_sources.get(str(row.get("id") or "").strip()) == "local"
|
||||
]
|
||||
elif skip_existing and news_rows:
|
||||
analyzed_ids = store.get_analyzed_news_ids(
|
||||
normalized_symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
skipped_existing_count = sum(
|
||||
1
|
||||
for row in news_rows
|
||||
if str(row.get("id") or "").strip() in analyzed_ids
|
||||
)
|
||||
news_rows = [
|
||||
row for row in news_rows
|
||||
if str(row.get("id") or "").strip() not in analyzed_ids
|
||||
]
|
||||
ohlc_start = start_date or (news_rows[-1]["trade_date"] if news_rows and news_rows[-1].get("trade_date") else None)
|
||||
ohlc_end = end_date or (news_rows[0]["trade_date"] if news_rows and news_rows[0].get("trade_date") else None)
|
||||
ohlc_rows = (
|
||||
store.get_ohlc(normalized_symbol, ohlc_start, ohlc_end)
|
||||
if ohlc_start and ohlc_end
|
||||
else []
|
||||
)
|
||||
analysis_rows, stats = build_analysis_rows(
|
||||
symbol=normalized_symbol,
|
||||
news_rows=news_rows,
|
||||
ohlc_rows=ohlc_rows,
|
||||
)
|
||||
analyzed = store.upsert_news_analysis(
|
||||
normalized_symbol,
|
||||
analysis_rows,
|
||||
analysis_source=analysis_source,
|
||||
)
|
||||
upgraded_dates = sorted(
|
||||
{
|
||||
str(row.get("trade_date") or "")[:10]
|
||||
for row in analysis_rows
|
||||
if str(row.get("analysis_source") or "").strip().lower() == "llm"
|
||||
and str(row.get("trade_date") or "").strip()
|
||||
}
|
||||
)
|
||||
remaining_local_titles = [
|
||||
str(row.get("title") or row.get("news_id") or "").strip()
|
||||
for row in news_rows
|
||||
for analyzed_row in analysis_rows
|
||||
if str(analyzed_row.get("news_id") or "").strip() == str(row.get("id") or "").strip()
|
||||
and str(analyzed_row.get("analysis_source") or "").strip().lower() == "local"
|
||||
][:5]
|
||||
return {
|
||||
"symbol": normalized_symbol,
|
||||
"analyzed": analyzed,
|
||||
"news_count": total_news_count,
|
||||
"queued_count": len(news_rows),
|
||||
"skipped_existing_count": skipped_existing_count,
|
||||
"deduped_count": stats["deduped_count"],
|
||||
"llm_count": stats["llm_count"],
|
||||
"local_count": stats["local_count"],
|
||||
"only_reanalyze_local": only_reanalyze_local,
|
||||
"upgraded_local_to_llm_count": (
|
||||
stats["llm_count"]
|
||||
if only_reanalyze_local
|
||||
else 0
|
||||
),
|
||||
"execution_summary": {
|
||||
"upgraded_dates": upgraded_dates[:5],
|
||||
"remaining_local_titles": remaining_local_titles,
|
||||
"skipped_missing_analysis_count": skipped_missing_analysis_count,
|
||||
"skipped_non_local_count": skipped_non_local_count,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def enrich_symbols(
|
||||
store: MarketStore,
|
||||
symbols: list[str],
|
||||
*,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 200,
|
||||
analysis_source: str = "local",
|
||||
skip_existing: bool = True,
|
||||
only_reanalyze_local: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Batch enrich multiple symbols for explain-oriented news analysis."""
|
||||
results = []
|
||||
for symbol in symbols:
|
||||
normalized_symbol = str(symbol or "").strip().upper()
|
||||
if not normalized_symbol:
|
||||
continue
|
||||
results.append(
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
normalized_symbol,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
analysis_source=analysis_source,
|
||||
skip_existing=skip_existing,
|
||||
only_reanalyze_local=only_reanalyze_local,
|
||||
)
|
||||
)
|
||||
return results
|
||||
2
backend/explain/__init__.py
Normal file
2
backend/explain/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Explain-oriented services for stock narratives and news research."""
|
||||
69
backend/explain/category_engine.py
Normal file
69
backend/explain/category_engine.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Rule-based news categorization for explain UI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
|
||||
CATEGORY_KEYWORDS = {
|
||||
"market": [
|
||||
"market", "stock", "rally", "sell-off", "selloff", "trading",
|
||||
"wall street", "s&p", "nasdaq", "dow", "index", "bull", "bear",
|
||||
"correction", "volatility",
|
||||
],
|
||||
"policy": [
|
||||
"regulation", "fed", "federal reserve", "tariff", "sanction",
|
||||
"interest rate", "policy", "government", "congress", "sec",
|
||||
"trade war", "ban", "legislation", "tax",
|
||||
],
|
||||
"earnings": [
|
||||
"earnings", "revenue", "profit", "quarter", "eps", "guidance",
|
||||
"forecast", "income", "sales", "beat", "miss", "outlook",
|
||||
"financial results",
|
||||
],
|
||||
"product_tech": [
|
||||
"product", "ai", "chip", "cloud", "launch", "patent",
|
||||
"technology", "innovation", "release", "platform", "model",
|
||||
"software", "hardware", "gpu", "autonomous",
|
||||
],
|
||||
"competition": [
|
||||
"competitor", "rival", "market share", "overtake", "compete",
|
||||
"competition", "vs", "versus", "battle", "challenge",
|
||||
],
|
||||
"management": [
|
||||
"ceo", "executive", "resign", "layoff", "restructure",
|
||||
"management", "leadership", "appoint", "hire", "board",
|
||||
"chairman",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def categorize_news_rows(rows: Iterable[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||
"""Bucket news rows by keyword categories."""
|
||||
categories: Dict[str, Dict[str, Any]] = {
|
||||
key: {
|
||||
"label": key,
|
||||
"count": 0,
|
||||
"article_ids": [],
|
||||
}
|
||||
for key in CATEGORY_KEYWORDS
|
||||
}
|
||||
|
||||
for row in rows:
|
||||
text = " ".join(
|
||||
[
|
||||
str(row.get("title") or ""),
|
||||
str(row.get("summary") or ""),
|
||||
str(row.get("related") or ""),
|
||||
str(row.get("category") or ""),
|
||||
]
|
||||
).lower()
|
||||
article_id = row.get("id")
|
||||
for category, keywords in CATEGORY_KEYWORDS.items():
|
||||
if any(keyword in text for keyword in keywords):
|
||||
categories[category]["count"] += 1
|
||||
if article_id:
|
||||
categories[category]["article_ids"].append(article_id)
|
||||
|
||||
return categories
|
||||
214
backend/explain/range_explainer.py
Normal file
214
backend/explain/range_explainer.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Local range explanation built from price and persisted news."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from backend.enrich.llm_enricher import analyze_range_with_llm
|
||||
from backend.explain.category_engine import categorize_news_rows
|
||||
from backend.tools.data_tools import get_prices
|
||||
|
||||
|
||||
def _rank_event_score(row: Dict[str, Any]) -> float:
|
||||
relevance = str(row.get("relevance") or "").strip().lower()
|
||||
relevance_score = {"high": 3.0, "relevant": 3.0, "medium": 2.0, "low": 1.0}.get(
|
||||
relevance,
|
||||
0.5,
|
||||
)
|
||||
impact_score = abs(float(row.get("ret_t0") or 0.0)) * 100
|
||||
return relevance_score + impact_score
|
||||
|
||||
|
||||
def summarize_bullish_factors(
|
||||
news_rows: list[Dict[str, Any]],
|
||||
*,
|
||||
limit: int = 5,
|
||||
) -> list[str]:
|
||||
factors = []
|
||||
for row in news_rows:
|
||||
if str(row.get("sentiment") or "").strip().lower() != "positive":
|
||||
continue
|
||||
candidate = row.get("reason_growth") or row.get("key_discussion") or row.get("summary") or row.get("title")
|
||||
if candidate:
|
||||
factors.append(str(candidate).strip())
|
||||
seen = set()
|
||||
output = []
|
||||
for factor in factors:
|
||||
if factor in seen:
|
||||
continue
|
||||
seen.add(factor)
|
||||
output.append(factor[:200])
|
||||
if len(output) >= limit:
|
||||
break
|
||||
return output
|
||||
|
||||
|
||||
def summarize_bearish_factors(
|
||||
news_rows: list[Dict[str, Any]],
|
||||
*,
|
||||
limit: int = 5,
|
||||
) -> list[str]:
|
||||
factors = []
|
||||
for row in news_rows:
|
||||
if str(row.get("sentiment") or "").strip().lower() != "negative":
|
||||
continue
|
||||
candidate = row.get("reason_decrease") or row.get("key_discussion") or row.get("summary") or row.get("title")
|
||||
if candidate:
|
||||
factors.append(str(candidate).strip())
|
||||
seen = set()
|
||||
output = []
|
||||
for factor in factors:
|
||||
if factor in seen:
|
||||
continue
|
||||
seen.add(factor)
|
||||
output.append(factor[:200])
|
||||
if len(output) >= limit:
|
||||
break
|
||||
return output
|
||||
|
||||
|
||||
def build_trend_analysis(prices: list[Any]) -> str:
|
||||
if len(prices) < 2:
|
||||
return "区间样本较短,暂不具备足够趋势信息。"
|
||||
if len(prices) < 3:
|
||||
open_price = float(prices[0].open)
|
||||
close_price = float(prices[-1].close)
|
||||
change = ((close_price - open_price) / open_price) * 100 if open_price else 0.0
|
||||
return f"短区间内价格变动 {change:+.2f}%,趋势信息有限。"
|
||||
|
||||
mid = len(prices) // 2
|
||||
first_open = float(prices[0].open)
|
||||
first_close = float(prices[mid].close)
|
||||
second_open = float(prices[mid].open)
|
||||
second_close = float(prices[-1].close)
|
||||
first_half = ((first_close - first_open) / first_open) * 100 if first_open else 0.0
|
||||
second_half = ((second_close - second_open) / second_open) * 100 if second_open else 0.0
|
||||
return (
|
||||
f"前半段{'上涨' if first_half >= 0 else '下跌'} {abs(first_half):.2f}%,"
|
||||
f"后半段{'上涨' if second_half >= 0 else '下跌'} {abs(second_half):.2f}%,"
|
||||
"说明价格驱动在区间内部出现了阶段性切换。"
|
||||
)
|
||||
|
||||
|
||||
def build_range_explanation(
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
news_rows: list[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Explain a price range with local price and news heuristics."""
|
||||
prices = get_prices(ticker, start_date, end_date)
|
||||
if not prices:
|
||||
return {
|
||||
"symbol": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"error": "No OHLC data for this range",
|
||||
}
|
||||
|
||||
open_price = float(prices[0].open)
|
||||
close_price = float(prices[-1].close)
|
||||
high_price = max(float(price.high) for price in prices)
|
||||
low_price = min(float(price.low) for price in prices)
|
||||
total_volume = sum(int(price.volume) for price in prices)
|
||||
price_change_pct = (
|
||||
((close_price - open_price) / open_price) * 100 if open_price else 0.0
|
||||
)
|
||||
|
||||
categories = categorize_news_rows(news_rows)
|
||||
news_count = len(news_rows)
|
||||
dominant_categories = sorted(
|
||||
(
|
||||
{"category": key, "count": value["count"]}
|
||||
for key, value in categories.items()
|
||||
if value["count"] > 0
|
||||
),
|
||||
key=lambda item: item["count"],
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
direction = "上涨" if price_change_pct > 0 else "下跌" if price_change_pct < 0 else "横盘"
|
||||
category_text = (
|
||||
f"主要主题集中在 {', '.join(item['category'] for item in dominant_categories[:3])}。"
|
||||
if dominant_categories
|
||||
else "区间内未识别出明显的主题聚类。"
|
||||
)
|
||||
summary = (
|
||||
f"{ticker} 在 {start_date} 至 {end_date} 区间内{direction} {abs(price_change_pct):.2f}%,"
|
||||
f"区间覆盖 {len(prices)} 个交易日,关联新闻 {news_count} 条。{category_text}"
|
||||
)
|
||||
|
||||
bullish_factors = summarize_bullish_factors(news_rows)
|
||||
bearish_factors = summarize_bearish_factors(news_rows)
|
||||
trend_analysis = build_trend_analysis(prices)
|
||||
llm_source = "local"
|
||||
|
||||
range_payload = {
|
||||
"ticker": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"price_change_pct": round(price_change_pct, 2),
|
||||
"trading_days": len(prices),
|
||||
"news_count": news_count,
|
||||
"dominant_categories": dominant_categories[:5],
|
||||
"bullish_factors": bullish_factors[:3],
|
||||
"bearish_factors": bearish_factors[:3],
|
||||
"trend_analysis": trend_analysis,
|
||||
"top_news": [
|
||||
{
|
||||
"date": row.get("trade_date") or str(row.get("date") or "")[:10],
|
||||
"title": row.get("title") or "",
|
||||
"summary": row.get("summary") or "",
|
||||
"sentiment": row.get("sentiment") or "",
|
||||
"relevance": row.get("relevance") or "",
|
||||
"ret_t0": row.get("ret_t0"),
|
||||
}
|
||||
for row in sorted(news_rows, key=_rank_event_score, reverse=True)[:5]
|
||||
],
|
||||
}
|
||||
llm_analysis = analyze_range_with_llm(range_payload)
|
||||
if isinstance(llm_analysis, dict):
|
||||
summary = llm_analysis.get("summary") or summary
|
||||
trend_analysis = llm_analysis.get("trend_analysis") or trend_analysis
|
||||
bullish_factors = llm_analysis.get("bullish_factors") or bullish_factors
|
||||
bearish_factors = llm_analysis.get("bearish_factors") or bearish_factors
|
||||
llm_source = "llm"
|
||||
|
||||
key_events = [
|
||||
{
|
||||
"date": row.get("trade_date") or str(row.get("date") or "")[:10],
|
||||
"title": row.get("title") or "Untitled news",
|
||||
"summary": row.get("summary") or "",
|
||||
"category": row.get("category") or "",
|
||||
"id": row.get("id"),
|
||||
"sentiment": row.get("sentiment"),
|
||||
"ret_t0": row.get("ret_t0"),
|
||||
}
|
||||
for row in sorted(news_rows, key=_rank_event_score, reverse=True)[:8]
|
||||
]
|
||||
|
||||
return {
|
||||
"symbol": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"price_change_pct": round(price_change_pct, 2),
|
||||
"open_price": open_price,
|
||||
"close_price": close_price,
|
||||
"high_price": high_price,
|
||||
"low_price": low_price,
|
||||
"total_volume": total_volume,
|
||||
"trading_days": len(prices),
|
||||
"news_count": news_count,
|
||||
"dominant_categories": dominant_categories[:5],
|
||||
"analysis": {
|
||||
"summary": summary,
|
||||
"key_events": key_events,
|
||||
"bullish_factors": bullish_factors,
|
||||
"bearish_factors": bearish_factors,
|
||||
"trend_analysis": trend_analysis,
|
||||
"analysis_source": llm_source,
|
||||
"analysis_model_label": llm_analysis.get("model_label") if isinstance(llm_analysis, dict) else None,
|
||||
},
|
||||
}
|
||||
202
backend/explain/similarity_service.py
Normal file
202
backend/explain/similarity_service.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Same-ticker historical similar day search for explain view."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from math import sqrt
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
|
||||
|
||||
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||
try:
|
||||
parsed = float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
return parsed
|
||||
|
||||
|
||||
def build_daily_feature_rows(
|
||||
*,
|
||||
symbol: str,
|
||||
ohlc_rows: list[dict[str, Any]],
|
||||
news_rows: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Aggregate price/news context into daily feature rows."""
|
||||
price_by_date = {str(row.get("date")): row for row in ohlc_rows if row.get("date")}
|
||||
ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date")]
|
||||
|
||||
news_by_date: dict[str, list[dict[str, Any]]] = {}
|
||||
for row in news_rows:
|
||||
trade_date = str(row.get("trade_date") or "")[:10] or str(row.get("date") or "")[:10]
|
||||
if not trade_date:
|
||||
continue
|
||||
news_by_date.setdefault(trade_date, []).append(row)
|
||||
|
||||
features: list[dict[str, Any]] = []
|
||||
previous_close: float | None = None
|
||||
for idx, date in enumerate(ordered_dates):
|
||||
price_row = price_by_date[date]
|
||||
close_price = _safe_float(price_row.get("close"))
|
||||
open_price = _safe_float(price_row.get("open"), close_price)
|
||||
day_news = news_by_date.get(date, [])
|
||||
positive_count = sum(1 for item in day_news if str(item.get("sentiment") or "").lower() == "positive")
|
||||
negative_count = sum(1 for item in day_news if str(item.get("sentiment") or "").lower() == "negative")
|
||||
high_relevance_count = sum(
|
||||
1 for item in day_news if str(item.get("relevance") or "").lower() in {"high", "relevant"}
|
||||
)
|
||||
ret_1d = (
|
||||
((close_price - previous_close) / previous_close)
|
||||
if previous_close not in (None, 0)
|
||||
else 0.0
|
||||
)
|
||||
intraday_ret = ((close_price - open_price) / open_price) if open_price else 0.0
|
||||
sentiment_score = (
|
||||
(positive_count - negative_count) / max(len(day_news), 1)
|
||||
if day_news
|
||||
else 0.0
|
||||
)
|
||||
future_t1 = None
|
||||
future_t3 = None
|
||||
if idx + 1 < len(ordered_dates) and close_price:
|
||||
next_close = _safe_float(price_by_date[ordered_dates[idx + 1]].get("close"))
|
||||
future_t1 = ((next_close - close_price) / close_price) if next_close else None
|
||||
if idx + 3 < len(ordered_dates) and close_price:
|
||||
next_close = _safe_float(price_by_date[ordered_dates[idx + 3]].get("close"))
|
||||
future_t3 = ((next_close - close_price) / close_price) if next_close else None
|
||||
|
||||
features.append(
|
||||
{
|
||||
"date": date,
|
||||
"symbol": symbol,
|
||||
"n_articles": len(day_news),
|
||||
"positive_count": positive_count,
|
||||
"negative_count": negative_count,
|
||||
"high_relevance_count": high_relevance_count,
|
||||
"sentiment_score": sentiment_score,
|
||||
"ret_1d": ret_1d,
|
||||
"intraday_ret": intraday_ret,
|
||||
"close": close_price,
|
||||
"ret_t1_after": future_t1,
|
||||
"ret_t3_after": future_t3,
|
||||
"news": [
|
||||
{
|
||||
"title": row.get("title") or "",
|
||||
"sentiment": row.get("sentiment") or "neutral",
|
||||
}
|
||||
for row in day_news[:3]
|
||||
],
|
||||
}
|
||||
)
|
||||
previous_close = close_price
|
||||
return features
|
||||
|
||||
|
||||
def compute_similarity_scores(
|
||||
target_vector: list[float],
|
||||
candidate_vectors: list[tuple[str, list[float], dict[str, Any]]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return sorted similarity matches based on normalized Euclidean distance."""
|
||||
if not candidate_vectors:
|
||||
return []
|
||||
dimensions = len(target_vector)
|
||||
ranges = []
|
||||
for dimension in range(dimensions):
|
||||
values = [vector[1][dimension] for vector in candidate_vectors] + [target_vector[dimension]]
|
||||
min_value = min(values)
|
||||
max_value = max(values)
|
||||
ranges.append(max(max_value - min_value, 1e-9))
|
||||
|
||||
scored = []
|
||||
for date, vector, payload in candidate_vectors:
|
||||
distance = sqrt(
|
||||
sum(
|
||||
((target_vector[i] - vector[i]) / ranges[i]) ** 2
|
||||
for i in range(dimensions)
|
||||
)
|
||||
)
|
||||
similarity = 1.0 / (1.0 + distance)
|
||||
scored.append(
|
||||
{
|
||||
"date": date,
|
||||
"score": round(similarity, 4),
|
||||
**payload,
|
||||
}
|
||||
)
|
||||
return sorted(scored, key=lambda item: item["score"], reverse=True)
|
||||
|
||||
|
||||
def find_similar_days(
|
||||
store: MarketStore,
|
||||
*,
|
||||
symbol: str,
|
||||
target_date: str,
|
||||
top_k: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
"""Find same-ticker historical days most similar to a target day."""
|
||||
cached = store.get_similar_day_cache(symbol, target_date=target_date)
|
||||
if cached and cached.get("payload"):
|
||||
return cached["payload"]
|
||||
|
||||
ohlc_rows = store.get_ohlc(symbol, "1900-01-01", target_date)
|
||||
news_rows = store.get_news_items_enriched(symbol, end_date=target_date, limit=500)
|
||||
daily_rows = build_daily_feature_rows(symbol=symbol, ohlc_rows=ohlc_rows, news_rows=news_rows)
|
||||
feature_map = {row["date"]: row for row in daily_rows}
|
||||
target_row = feature_map.get(target_date)
|
||||
if not target_row:
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"target_date": target_date,
|
||||
"items": [],
|
||||
"error": "No feature row for target date",
|
||||
}
|
||||
|
||||
vector_keys = [
|
||||
"sentiment_score",
|
||||
"n_articles",
|
||||
"positive_count",
|
||||
"negative_count",
|
||||
"high_relevance_count",
|
||||
"ret_1d",
|
||||
"intraday_ret",
|
||||
]
|
||||
target_vector = [_safe_float(target_row.get(key)) for key in vector_keys]
|
||||
candidates = []
|
||||
for row in daily_rows:
|
||||
date = row["date"]
|
||||
if date == target_date:
|
||||
continue
|
||||
payload = {
|
||||
"n_articles": row["n_articles"],
|
||||
"sentiment_score": round(row["sentiment_score"], 4),
|
||||
"ret_1d": round(row["ret_1d"] * 100, 2),
|
||||
"intraday_ret": round(row["intraday_ret"] * 100, 2),
|
||||
"ret_t1_after": round(row["ret_t1_after"] * 100, 2) if row["ret_t1_after"] is not None else None,
|
||||
"ret_t3_after": round(row["ret_t3_after"] * 100, 2) if row["ret_t3_after"] is not None else None,
|
||||
"top_reasons": [item["title"] for item in row["news"][:2] if item.get("title")],
|
||||
"news": row["news"],
|
||||
}
|
||||
candidates.append(
|
||||
(
|
||||
date,
|
||||
[_safe_float(row.get(key)) for key in vector_keys],
|
||||
payload,
|
||||
)
|
||||
)
|
||||
|
||||
items = compute_similarity_scores(target_vector, candidates)[: max(1, min(int(top_k), 20))]
|
||||
result = {
|
||||
"symbol": symbol,
|
||||
"target_date": target_date,
|
||||
"target_features": {
|
||||
"sentiment_score": round(target_row["sentiment_score"], 4),
|
||||
"n_articles": target_row["n_articles"],
|
||||
"ret_1d": round(target_row["ret_1d"] * 100, 2),
|
||||
"intraday_ret": round(target_row["intraday_ret"] * 100, 2),
|
||||
"high_relevance_count": target_row["high_relevance_count"],
|
||||
},
|
||||
"items": items,
|
||||
}
|
||||
store.upsert_similar_day_cache(symbol, target_date=target_date, payload=result, source="local")
|
||||
return result
|
||||
127
backend/explain/story_service.py
Normal file
127
backend/explain/story_service.py
Normal file
@@ -0,0 +1,127 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Stock story generation for explain view."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
|
||||
|
||||
def build_stock_story(
|
||||
*,
|
||||
symbol: str,
|
||||
as_of_date: str,
|
||||
price_rows: list[dict[str, Any]],
|
||||
news_rows: list[dict[str, Any]],
|
||||
) -> str:
|
||||
"""Build a compact markdown story from enriched news and recent price action."""
|
||||
lines = [f"## {symbol} Story", f"As of `{as_of_date}`"]
|
||||
if not price_rows:
|
||||
lines.append("")
|
||||
lines.append("No OHLC data available for story generation.")
|
||||
return "\n".join(lines)
|
||||
|
||||
open_price = float(price_rows[0].get("open") or price_rows[0].get("close") or 0.0)
|
||||
close_price = float(price_rows[-1].get("close") or 0.0)
|
||||
price_change = ((close_price - open_price) / open_price) * 100 if open_price else 0.0
|
||||
high_price = max(float(row.get("high") or row.get("close") or 0.0) for row in price_rows)
|
||||
low_price = min(float(row.get("low") or row.get("close") or 0.0) for row in price_rows)
|
||||
|
||||
lines.append("")
|
||||
lines.append(
|
||||
f"The stock moved {'up' if price_change >= 0 else 'down'} "
|
||||
f"{abs(price_change):.2f}% over the recent window, trading between "
|
||||
f"${low_price:.2f} and ${high_price:.2f}."
|
||||
)
|
||||
|
||||
positive = [row for row in news_rows if str(row.get("sentiment") or "").lower() == "positive"]
|
||||
negative = [row for row in news_rows if str(row.get("sentiment") or "").lower() == "negative"]
|
||||
lines.append("")
|
||||
lines.append(
|
||||
f"Recent coverage included {len(news_rows)} relevant articles "
|
||||
f"({len(positive)} positive / {len(negative)} negative)."
|
||||
)
|
||||
|
||||
if news_rows:
|
||||
lines.append("")
|
||||
lines.append("### Key Moments")
|
||||
ranked_rows = sorted(
|
||||
news_rows,
|
||||
key=lambda row: (
|
||||
0 if str(row.get("relevance") or "").lower() in {"high", "relevant"} else 1,
|
||||
-abs(float(row.get("ret_t0") or 0.0)),
|
||||
),
|
||||
)
|
||||
for row in ranked_rows[:5]:
|
||||
trade_date = row.get("trade_date") or str(row.get("date") or "")[:10]
|
||||
title = row.get("title") or "Untitled"
|
||||
key_discussion = row.get("key_discussion") or row.get("summary") or ""
|
||||
sentiment = str(row.get("sentiment") or "neutral").lower()
|
||||
lines.append(
|
||||
f"- `{trade_date}` [{sentiment}] {title}: {str(key_discussion).strip()[:220]}"
|
||||
)
|
||||
|
||||
if positive:
|
||||
lines.append("")
|
||||
lines.append("### Bullish Threads")
|
||||
for row in positive[:3]:
|
||||
reason = row.get("reason_growth") or row.get("key_discussion") or row.get("summary") or row.get("title")
|
||||
lines.append(f"- {str(reason).strip()[:220]}")
|
||||
|
||||
if negative:
|
||||
lines.append("")
|
||||
lines.append("### Bearish Threads")
|
||||
for row in negative[:3]:
|
||||
reason = row.get("reason_decrease") or row.get("key_discussion") or row.get("summary") or row.get("title")
|
||||
lines.append(f"- {str(reason).strip()[:220]}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def get_or_create_stock_story(
|
||||
store: MarketStore,
|
||||
*,
|
||||
symbol: str,
|
||||
as_of_date: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Return cached story or build a new one from recent market context."""
|
||||
cached = store.get_story_cache(symbol, as_of_date=as_of_date)
|
||||
if cached:
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"as_of_date": as_of_date,
|
||||
"story": cached.get("content") or "",
|
||||
"source": cached.get("source") or "cache",
|
||||
}
|
||||
|
||||
start_date = None
|
||||
if len(as_of_date) >= 10:
|
||||
target_date = datetime.strptime(as_of_date[:10], "%Y-%m-%d").date()
|
||||
start_date = (target_date - timedelta(days=29)).isoformat()
|
||||
|
||||
price_rows = (
|
||||
store.get_ohlc(symbol, start_date, as_of_date)
|
||||
if start_date
|
||||
else []
|
||||
)
|
||||
news_rows = store.get_news_items_enriched(
|
||||
symbol,
|
||||
start_date=start_date,
|
||||
end_date=as_of_date,
|
||||
limit=40,
|
||||
)
|
||||
story = build_stock_story(
|
||||
symbol=symbol,
|
||||
as_of_date=as_of_date,
|
||||
price_rows=price_rows,
|
||||
news_rows=news_rows,
|
||||
)
|
||||
store.upsert_story_cache(symbol, as_of_date=as_of_date, content=story, source="local")
|
||||
return {
|
||||
"symbol": symbol,
|
||||
"as_of_date": as_of_date,
|
||||
"story": story,
|
||||
"source": "local",
|
||||
}
|
||||
309
backend/gateway_server.py
Normal file
309
backend/gateway_server.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Gateway Server - Entry point for Gateway subprocess.
|
||||
|
||||
This module is launched as a subprocess by the Control Plane (FastAPI)
|
||||
to run the Data Plane (Gateway + Pipeline).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.pipeline_runner import create_agents, create_long_term_memory
|
||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.runtime.manager import (
|
||||
TradingRuntimeManager,
|
||||
set_global_runtime_manager,
|
||||
clear_global_runtime_manager,
|
||||
)
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
|
||||
INFO_LOGGER_PREFIXES = (
|
||||
"backend.agents",
|
||||
"backend.core.pipeline",
|
||||
"backend.core.scheduler",
|
||||
"backend.services.gateway_cycle_support",
|
||||
)
|
||||
|
||||
NOISY_LOGGER_LEVELS = {
|
||||
"aiohttp": logging.WARNING,
|
||||
"asyncio": logging.WARNING,
|
||||
"dashscope": logging.WARNING,
|
||||
"finnhub": logging.WARNING,
|
||||
"httpcore": logging.WARNING,
|
||||
"httpx": logging.WARNING,
|
||||
"urllib3": logging.WARNING,
|
||||
"websockets": logging.WARNING,
|
||||
"yfinance": logging.WARNING,
|
||||
"backend.data.polling_price_manager": logging.WARNING,
|
||||
"backend.services.gateway": logging.WARNING,
|
||||
"backend.services.market": logging.WARNING,
|
||||
"backend.services.storage": logging.WARNING,
|
||||
}
|
||||
|
||||
|
||||
class SuppressNoisyInfoFilter(logging.Filter):
|
||||
"""Filter out low-signal library INFO logs while keeping warnings/errors."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
message = record.getMessage()
|
||||
if record.name == "httpx" and message.startswith("HTTP Request:"):
|
||||
return False
|
||||
if record.name.startswith("websockets") and "connection open" in message:
|
||||
return False
|
||||
if record.name.startswith("websockets") and "opening handshake failed" in message:
|
||||
return False
|
||||
|
||||
if record.levelno >= logging.WARNING:
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def configure_gateway_logging(verbose: bool = False) -> None:
|
||||
"""Configure gateway logging with low-noise defaults for runtime logs."""
|
||||
root_level = logging.DEBUG if verbose else logging.WARNING
|
||||
logging.basicConfig(
|
||||
level=root_level,
|
||||
format="%(asctime)s | %(levelname)-7s | %(name)s:%(lineno)d - %(message)s",
|
||||
force=True,
|
||||
)
|
||||
|
||||
if not verbose:
|
||||
suppress_filter = SuppressNoisyInfoFilter()
|
||||
for handler in logging.getLogger().handlers:
|
||||
handler.addFilter(suppress_filter)
|
||||
|
||||
for logger_name, level in NOISY_LOGGER_LEVELS.items():
|
||||
logging.getLogger(logger_name).setLevel(logging.DEBUG if verbose else level)
|
||||
|
||||
if not verbose:
|
||||
for prefix in INFO_LOGGER_PREFIXES:
|
||||
logging.getLogger(prefix).setLevel(logging.INFO)
|
||||
|
||||
logging.getLogger(__name__).setLevel(logging.INFO if not verbose else logging.DEBUG)
|
||||
|
||||
|
||||
async def run_gateway(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
bootstrap: dict,
|
||||
port: int
|
||||
):
|
||||
"""Run Gateway with Pipeline."""
|
||||
|
||||
# Extract config
|
||||
tickers = bootstrap.get("tickers", ["AAPL", "MSFT", "GOOGL", "AMZN", "NVDA", "META", "TSLA", "AMD", "NFLX", "AVGO", "PLTR", "COIN"])
|
||||
initial_cash = float(bootstrap.get("initial_cash", 100000.0))
|
||||
margin_requirement = float(bootstrap.get("margin_requirement", 0.0))
|
||||
max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2))
|
||||
schedule_mode = bootstrap.get("schedule_mode", "daily")
|
||||
trigger_time = bootstrap.get("trigger_time", "09:30")
|
||||
interval_minutes = int(bootstrap.get("interval_minutes", 60))
|
||||
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0)) # 0 = disabled
|
||||
mode = bootstrap.get("mode", "live")
|
||||
start_date = bootstrap.get("start_date")
|
||||
end_date = bootstrap.get("end_date")
|
||||
enable_memory = bootstrap.get("enable_memory", False)
|
||||
poll_interval = int(bootstrap.get("poll_interval", 10))
|
||||
|
||||
is_backtest = mode == "backtest"
|
||||
|
||||
logger.info(f"[Gateway Server] Starting run {run_id} on port {port}")
|
||||
|
||||
# Create runtime manager
|
||||
runtime_manager = TradingRuntimeManager(
|
||||
config_name=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
runtime_manager.prepare_run()
|
||||
set_global_runtime_manager(runtime_manager)
|
||||
|
||||
try:
|
||||
async with AsyncExitStack() as stack:
|
||||
# Create services
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=poll_interval,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
|
||||
backtest_start_date=start_date if is_backtest else None,
|
||||
backtest_end_date=end_date if is_backtest else None,
|
||||
)
|
||||
|
||||
storage_service = StorageService(
|
||||
dashboard_dir=run_dir / "team_dashboard",
|
||||
initial_cash=initial_cash,
|
||||
config_name=run_id,
|
||||
)
|
||||
|
||||
if not storage_service.files["summary"].exists():
|
||||
storage_service.initialize_empty_dashboard()
|
||||
else:
|
||||
storage_service.update_leaderboard_model_info()
|
||||
|
||||
# Create agents
|
||||
analysts, risk_manager, pm, long_term_memories = create_agents(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
initial_cash=initial_cash,
|
||||
margin_requirement=margin_requirement,
|
||||
enable_long_term_memory=enable_memory,
|
||||
)
|
||||
|
||||
# Register agents
|
||||
for agent in analysts + [risk_manager, pm]:
|
||||
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||
if agent_id:
|
||||
runtime_manager.register_agent(agent_id)
|
||||
|
||||
# Load portfolio state
|
||||
portfolio_state = storage_service.load_portfolio_state()
|
||||
pm.load_portfolio_state(portfolio_state)
|
||||
|
||||
# Create settlement coordinator
|
||||
settlement_coordinator = SettlementCoordinator(
|
||||
storage=storage_service,
|
||||
initial_capital=initial_cash,
|
||||
)
|
||||
|
||||
# Create pipeline
|
||||
pipeline = TradingPipeline(
|
||||
analysts=analysts,
|
||||
risk_manager=risk_manager,
|
||||
portfolio_manager=pm,
|
||||
settlement_coordinator=settlement_coordinator,
|
||||
max_comm_cycles=max_comm_cycles,
|
||||
runtime_manager=runtime_manager,
|
||||
)
|
||||
|
||||
# Create scheduler
|
||||
scheduler_callback = None
|
||||
live_scheduler = None
|
||||
|
||||
if is_backtest:
|
||||
backtest_scheduler = BacktestScheduler(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
trading_calendar="NYSE",
|
||||
delay_between_days=0.5,
|
||||
)
|
||||
|
||||
async def scheduler_callback_fn(callback):
|
||||
await backtest_scheduler.start(callback)
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
else:
|
||||
live_scheduler = Scheduler(
|
||||
mode=schedule_mode,
|
||||
trigger_time=trigger_time,
|
||||
interval_minutes=interval_minutes,
|
||||
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
|
||||
config={"config_name": run_id},
|
||||
)
|
||||
|
||||
async def scheduler_callback_fn(callback):
|
||||
await live_scheduler.start(callback)
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
|
||||
# Enter long-term memory contexts
|
||||
for memory in long_term_memories:
|
||||
await stack.enter_async_context(memory)
|
||||
|
||||
# Create Gateway
|
||||
gateway = Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage_service,
|
||||
pipeline=pipeline,
|
||||
scheduler_callback=scheduler_callback,
|
||||
config={
|
||||
"mode": mode,
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": run_id,
|
||||
"schedule_mode": schedule_mode,
|
||||
"interval_minutes": interval_minutes,
|
||||
"trigger_time": trigger_time,
|
||||
"heartbeat_interval": heartbeat_interval,
|
||||
"initial_cash": initial_cash,
|
||||
"margin_requirement": margin_requirement,
|
||||
"max_comm_cycles": max_comm_cycles,
|
||||
"enable_memory": enable_memory,
|
||||
},
|
||||
scheduler=live_scheduler,
|
||||
)
|
||||
|
||||
# Start Gateway (blocks until shutdown)
|
||||
logger.info(f"[Gateway Server] Gateway starting on port {port}")
|
||||
await gateway.start(host="0.0.0.0", port=port)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[Gateway Server] Cancelled")
|
||||
raise
|
||||
finally:
|
||||
logger.info("[Gateway Server] Cleaning up")
|
||||
clear_global_runtime_manager()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser(description="Gateway Server")
|
||||
parser.add_argument("--run-id", required=True, help="Run identifier")
|
||||
parser.add_argument("--run-dir", required=True, help="Run directory path")
|
||||
parser.add_argument("--port", type=int, default=8765, help="WebSocket port")
|
||||
parser.add_argument("--bootstrap", required=True, help="Bootstrap config as JSON")
|
||||
parser.add_argument("--verbose", action="store_true", help="Verbose logging")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
configure_gateway_logging(verbose=args.verbose)
|
||||
|
||||
# Parse bootstrap
|
||||
bootstrap = json.loads(args.bootstrap)
|
||||
run_dir = Path(args.run_dir)
|
||||
|
||||
# Run
|
||||
try:
|
||||
asyncio.run(run_gateway(
|
||||
run_id=args.run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
port=args.port
|
||||
))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("[Gateway Server] Interrupted by user")
|
||||
except Exception as e:
|
||||
logger.exception(f"[Gateway Server] Fatal error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,9 +3,13 @@
|
||||
AgentScope Native Model Factory
|
||||
Uses native AgentScope model classes for LLM calls
|
||||
"""
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
|
||||
from agentscope.formatter import (
|
||||
AnthropicChatFormatter,
|
||||
DashScopeChatFormatter,
|
||||
@@ -26,6 +30,331 @@ from backend.config.env_config import (
|
||||
get_env_str,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Retry wrapper types
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _usage_value(usage: Any, key: str, default: Any = 0) -> Any:
|
||||
"""Read usage fields from both object-style and dict-style usage payloads."""
|
||||
if usage is None:
|
||||
return default
|
||||
if isinstance(usage, dict):
|
||||
return usage.get(key, default)
|
||||
try:
|
||||
return getattr(usage, key)
|
||||
except (AttributeError, KeyError):
|
||||
return default
|
||||
|
||||
|
||||
def _usage_total_tokens(usage: Any) -> int:
|
||||
total = _usage_value(usage, "total_tokens", None)
|
||||
if total is not None:
|
||||
return int(total or 0)
|
||||
input_tokens = _usage_value(usage, "input_tokens", 0)
|
||||
output_tokens = _usage_value(usage, "output_tokens", 0)
|
||||
return int((input_tokens or 0) + (output_tokens or 0))
|
||||
|
||||
|
||||
class RetryChatModel:
|
||||
"""Wraps an AgentScope model with automatic retry for transient errors.
|
||||
|
||||
Based on CoPaw's RetryChatModel design. Handles rate limits, timeouts,
|
||||
and other transient failures with exponential backoff.
|
||||
"""
|
||||
|
||||
DEFAULT_MAX_RETRIES = 3
|
||||
DEFAULT_INITIAL_DELAY = 1.0
|
||||
DEFAULT_MAX_DELAY = 60.0
|
||||
DEFAULT_BACKOFF_MULTIPLIER = 2.0
|
||||
|
||||
# Transient error codes/messages that should trigger retry
|
||||
TRANSIENT_ERROR_KEYWORDS = frozenset([
|
||||
"rate_limit",
|
||||
"429",
|
||||
"timeout",
|
||||
"503",
|
||||
"502",
|
||||
"504",
|
||||
"connection",
|
||||
"disconnected",
|
||||
"temporary",
|
||||
"overloaded",
|
||||
"too_many_requests",
|
||||
])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Any,
|
||||
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||
initial_delay: float = DEFAULT_INITIAL_DELAY,
|
||||
max_delay: float = DEFAULT_MAX_DELAY,
|
||||
backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER,
|
||||
on_retry: Optional[Callable[[int, Exception, float], None]] = None,
|
||||
):
|
||||
"""Initialize retry wrapper.
|
||||
|
||||
Args:
|
||||
model: The underlying AgentScope model to wrap
|
||||
max_retries: Maximum number of retry attempts
|
||||
initial_delay: Initial delay in seconds before first retry
|
||||
max_delay: Maximum delay between retries
|
||||
backoff_multiplier: Multiplier for exponential backoff
|
||||
on_retry: Optional callback(retry_count, exception, delay) for logging
|
||||
"""
|
||||
self._model = model
|
||||
self._max_retries = max_retries
|
||||
self._initial_delay = initial_delay
|
||||
self._max_delay = max_delay
|
||||
self._backoff_multiplier = backoff_multiplier
|
||||
self._on_retry = on_retry
|
||||
self._total_tokens_used = 0
|
||||
self._total_cost = 0.0
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return getattr(self._model, "model_name", str(self._model))
|
||||
|
||||
@property
|
||||
def total_tokens_used(self) -> int:
|
||||
return self._total_tokens_used
|
||||
|
||||
@property
|
||||
def total_cost(self) -> float:
|
||||
return self._total_cost
|
||||
|
||||
def _is_transient_error(self, error: Exception) -> bool:
|
||||
"""Check if an error is transient and should be retried.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
True if the error is transient
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
for keyword in self.TRANSIENT_ERROR_KEYWORDS:
|
||||
if keyword in error_str:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _calculate_delay(self, retry_count: int) -> float:
|
||||
"""Calculate delay for given retry attempt with exponential backoff.
|
||||
|
||||
Args:
|
||||
retry_count: Current retry attempt number (1-based)
|
||||
|
||||
Returns:
|
||||
Delay in seconds
|
||||
"""
|
||||
delay = self._initial_delay * (self._backoff_multiplier ** (retry_count - 1))
|
||||
return min(delay, self._max_delay)
|
||||
|
||||
def _call_with_retry(self, func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""Call a function with retry logic for transient errors.
|
||||
|
||||
Args:
|
||||
func: Function to call
|
||||
*args: Positional arguments
|
||||
**kwargs: Keyword arguments
|
||||
|
||||
Returns:
|
||||
Result from func
|
||||
|
||||
Raises:
|
||||
Last exception if all retries exhausted
|
||||
"""
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(1, self._max_retries + 1):
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
# Track usage if available
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
usage = result.usage
|
||||
self._total_tokens_used += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt >= self._max_retries:
|
||||
logger.error(
|
||||
"RetryChatModel: Max retries (%d) exhausted for %s",
|
||||
self._max_retries,
|
||||
self.model_name,
|
||||
)
|
||||
break
|
||||
|
||||
if not self._is_transient_error(e):
|
||||
logger.warning(
|
||||
"RetryChatModel: Non-transient error, not retrying: %s",
|
||||
str(e),
|
||||
)
|
||||
break
|
||||
|
||||
delay = self._calculate_delay(attempt)
|
||||
logger.warning(
|
||||
"RetryChatModel: Transient error on attempt %d/%d, "
|
||||
"retrying in %.1fs: %s",
|
||||
attempt,
|
||||
self._max_retries,
|
||||
delay,
|
||||
str(e)[:200],
|
||||
)
|
||||
|
||||
if self._on_retry:
|
||||
self._on_retry(attempt, e, delay)
|
||||
|
||||
time.sleep(delay)
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")
|
||||
|
||||
async def _call_with_retry_async(self, func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""Call an async function with retry logic for transient errors."""
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(1, self._max_retries + 1):
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
usage = result.usage
|
||||
self._total_tokens_used += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt >= self._max_retries:
|
||||
logger.error(
|
||||
"RetryChatModel: Max retries (%d) exhausted for %s",
|
||||
self._max_retries,
|
||||
self.model_name,
|
||||
)
|
||||
break
|
||||
|
||||
if not self._is_transient_error(e):
|
||||
logger.warning(
|
||||
"RetryChatModel: Non-transient error, not retrying: %s",
|
||||
str(e),
|
||||
)
|
||||
break
|
||||
|
||||
delay = self._calculate_delay(attempt)
|
||||
logger.warning(
|
||||
"RetryChatModel: Transient async error on attempt %d/%d, "
|
||||
"retrying in %.1fs: %s",
|
||||
attempt,
|
||||
self._max_retries,
|
||||
delay,
|
||||
str(e)[:200],
|
||||
)
|
||||
|
||||
if self._on_retry:
|
||||
self._on_retry(attempt, e, delay)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise RuntimeError("RetryChatModel: Unexpected async state, no error but no result")
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Forward calls to the wrapped model with retry logic."""
|
||||
model_call = getattr(self._model, "__call__", None)
|
||||
if inspect.iscoroutinefunction(self._model) or inspect.iscoroutinefunction(model_call):
|
||||
return self._call_with_retry_async(self._model, *args, **kwargs)
|
||||
|
||||
result = self._model(*args, **kwargs)
|
||||
return result
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Proxy attribute access to the wrapped model."""
|
||||
return getattr(self._model, name)
|
||||
|
||||
|
||||
class TokenRecordingModelWrapper:
|
||||
"""Wraps a model to track token usage per provider.
|
||||
|
||||
Based on CoPaw's TokenRecordingModelWrapper design.
|
||||
"""
|
||||
|
||||
def __init__(self, model: Any):
|
||||
"""Initialize token recorder.
|
||||
|
||||
Args:
|
||||
model: The underlying AgentScope model to wrap
|
||||
"""
|
||||
self._model = model
|
||||
self._total_tokens = 0
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_cost = 0.0
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return getattr(self._model, "model_name", str(self._model))
|
||||
|
||||
@property
|
||||
def total_tokens(self) -> int:
|
||||
return self._total_tokens
|
||||
|
||||
@property
|
||||
def prompt_tokens(self) -> int:
|
||||
return self._prompt_tokens
|
||||
|
||||
@property
|
||||
def completion_tokens(self) -> int:
|
||||
return self._completion_tokens
|
||||
|
||||
@property
|
||||
def total_cost(self) -> float:
|
||||
return self._total_cost
|
||||
|
||||
def record_usage(self, usage: Any) -> None:
|
||||
"""Record token usage from a model response.
|
||||
|
||||
Args:
|
||||
usage: Usage object from model response
|
||||
"""
|
||||
if usage is None:
|
||||
return
|
||||
|
||||
prompt_tokens = _usage_value(usage, "prompt_tokens", None)
|
||||
completion_tokens = _usage_value(usage, "completion_tokens", None)
|
||||
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = _usage_value(usage, "input_tokens", 0)
|
||||
if completion_tokens is None:
|
||||
completion_tokens = _usage_value(usage, "output_tokens", 0)
|
||||
|
||||
self._prompt_tokens += int(prompt_tokens or 0)
|
||||
self._completion_tokens += int(completion_tokens or 0)
|
||||
self._total_tokens += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Forward calls and record usage."""
|
||||
result = self._model(*args, **kwargs)
|
||||
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
self.record_usage(result.usage)
|
||||
|
||||
return result
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Proxy attribute access to the wrapped model."""
|
||||
return getattr(self._model, name)
|
||||
|
||||
|
||||
class ModelProvider(Enum):
|
||||
"""Supported model providers"""
|
||||
@@ -161,7 +490,8 @@ def create_model(
|
||||
if host:
|
||||
model_kwargs["host"] = host
|
||||
|
||||
return model_class(**model_kwargs)
|
||||
model = model_class(**model_kwargs)
|
||||
return RetryChatModel(model)
|
||||
|
||||
|
||||
def get_agent_model(agent_id: str, stream: bool = False):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Main Entry Point
|
||||
Supports: backtest, live, mock modes
|
||||
Supports: backtest, live modes
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
@@ -16,23 +16,31 @@ from dotenv import load_dotenv
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.bootstrap_config import resolve_runtime_config
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.api.runtime import register_runtime_manager, unregister_runtime_manager
|
||||
from backend.runtime.manager import (
|
||||
TradingRuntimeManager,
|
||||
set_global_runtime_manager,
|
||||
clear_global_runtime_manager,
|
||||
)
|
||||
from backend.gateway_server import configure_gateway_logging
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
loguru.logger.disable("flowllm")
|
||||
loguru.logger.disable("reme_ai")
|
||||
_prompt_loader = PromptLoader()
|
||||
configure_gateway_logging(verbose=os.getenv("LOG_LEVEL", "").upper() == "DEBUG")
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
|
||||
def _get_run_dir(config_name: str) -> Path:
|
||||
@@ -48,6 +56,9 @@ def _resolve_runtime_config(args) -> dict:
|
||||
project_root=project_root,
|
||||
config_name=args.config_name,
|
||||
enable_memory=args.enable_memory,
|
||||
schedule_mode=args.schedule_mode,
|
||||
interval_minutes=args.interval_minutes,
|
||||
trigger_time=args.trigger_time,
|
||||
)
|
||||
|
||||
|
||||
@@ -210,15 +221,20 @@ async def run_with_gateway(args):
|
||||
initial_cash = runtime_config["initial_cash"]
|
||||
margin_requirement = runtime_config["margin_requirement"]
|
||||
|
||||
runtime_manager = TradingRuntimeManager(
|
||||
config_name=config_name,
|
||||
run_dir=_get_run_dir(config_name),
|
||||
bootstrap=runtime_config,
|
||||
)
|
||||
runtime_manager.prepare_run()
|
||||
set_global_runtime_manager(runtime_manager)
|
||||
|
||||
# Create market service
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=args.poll_interval,
|
||||
mock_mode=args.mock and not is_backtest,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY")
|
||||
if not args.mock and not is_backtest
|
||||
else None,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
|
||||
backtest_start_date=args.start_date if is_backtest else None,
|
||||
backtest_end_date=args.end_date if is_backtest else None,
|
||||
)
|
||||
@@ -242,6 +258,10 @@ async def run_with_gateway(args):
|
||||
margin_requirement=margin_requirement,
|
||||
enable_long_term_memory=runtime_config["enable_memory"],
|
||||
)
|
||||
for agent in analysts + [risk_manager, pm]:
|
||||
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||
if agent_id:
|
||||
runtime_manager.register_agent(agent_id)
|
||||
portfolio_state = storage_service.load_portfolio_state()
|
||||
pm.load_portfolio_state(portfolio_state)
|
||||
|
||||
@@ -256,11 +276,13 @@ async def run_with_gateway(args):
|
||||
portfolio_manager=pm,
|
||||
settlement_coordinator=settlement_coordinator,
|
||||
max_comm_cycles=runtime_config["max_comm_cycles"],
|
||||
runtime_manager=runtime_manager,
|
||||
)
|
||||
|
||||
# Create scheduler callback
|
||||
scheduler_callback = None
|
||||
trading_dates = []
|
||||
live_scheduler = None
|
||||
|
||||
if is_backtest:
|
||||
backtest_scheduler = BacktestScheduler(
|
||||
@@ -276,10 +298,11 @@ async def run_with_gateway(args):
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
else:
|
||||
# Live mode: use daily scheduler with NYSE timezone
|
||||
# Live mode: use daily or intraday scheduler with NYSE timezone
|
||||
live_scheduler = Scheduler(
|
||||
mode="daily",
|
||||
trigger_time=args.trigger_time,
|
||||
mode=runtime_config["schedule_mode"],
|
||||
trigger_time=runtime_config["trigger_time"],
|
||||
interval_minutes=runtime_config["interval_minutes"],
|
||||
config={"config_name": config_name},
|
||||
)
|
||||
|
||||
@@ -296,15 +319,18 @@ async def run_with_gateway(args):
|
||||
scheduler_callback=scheduler_callback,
|
||||
config={
|
||||
"mode": args.mode,
|
||||
"mock_mode": args.mock,
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": config_name,
|
||||
"schedule_mode": runtime_config["schedule_mode"],
|
||||
"interval_minutes": runtime_config["interval_minutes"],
|
||||
"trigger_time": runtime_config["trigger_time"],
|
||||
"initial_cash": initial_cash,
|
||||
"margin_requirement": margin_requirement,
|
||||
"max_comm_cycles": runtime_config["max_comm_cycles"],
|
||||
"enable_memory": runtime_config["enable_memory"],
|
||||
},
|
||||
scheduler=live_scheduler if not is_backtest else None,
|
||||
)
|
||||
|
||||
if is_backtest:
|
||||
@@ -312,20 +338,29 @@ async def run_with_gateway(args):
|
||||
|
||||
# Start long-term memory contexts and run gateway
|
||||
async with AsyncExitStack() as stack:
|
||||
try:
|
||||
for memory in long_term_memories:
|
||||
await stack.enter_async_context(memory)
|
||||
await gateway.start(host=args.host, port=args.port)
|
||||
finally:
|
||||
unregister_runtime_manager()
|
||||
clear_global_runtime_manager()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point"""
|
||||
parser = argparse.ArgumentParser(description="Trading System")
|
||||
parser.add_argument("--mode", choices=["live", "backtest"], default="live")
|
||||
parser.add_argument("--mock", action="store_true")
|
||||
parser.add_argument("--config-name", default="mock")
|
||||
parser.add_argument("--config-name", default="live")
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=8765)
|
||||
parser.add_argument(
|
||||
"--schedule-mode",
|
||||
choices=["daily", "intraday"],
|
||||
default="daily",
|
||||
)
|
||||
parser.add_argument("--trigger-time", default="09:30") # NYSE market open
|
||||
parser.add_argument("--interval-minutes", type=int, default=60)
|
||||
parser.add_argument("--poll-interval", type=int, default=10)
|
||||
parser.add_argument("--start-date")
|
||||
parser.add_argument("--end-date")
|
||||
|
||||
41
backend/process/models.py
Normal file
41
backend/process/models.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Data models for lightweight process supervision."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class ProcessRunState(str, Enum):
|
||||
"""Execution state for supervised runs."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessRun:
|
||||
"""Represents a supervised process run."""
|
||||
|
||||
run_id: str
|
||||
command: str
|
||||
scope_key: str
|
||||
state: ProcessRunState = ProcessRunState.PENDING
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"run_id": self.run_id,
|
||||
"command": self.command,
|
||||
"scope_key": self.scope_key,
|
||||
"state": self.state.value,
|
||||
"metadata": self.metadata,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
35
backend/process/registry.py
Normal file
35
backend/process/registry.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Registry for managing supervised process metadata."""
|
||||
|
||||
from threading import Lock
|
||||
from typing import Dict, Iterable, Optional
|
||||
|
||||
from .models import ProcessRun
|
||||
|
||||
|
||||
class RunRegistry:
|
||||
"""In-memory registry for tracked process runs."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._runs: Dict[str, ProcessRun] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def add(self, run: ProcessRun) -> None:
|
||||
with self._lock:
|
||||
self._runs[run.run_id] = run
|
||||
|
||||
def get(self, run_id: str) -> Optional[ProcessRun]:
|
||||
with self._lock:
|
||||
return self._runs.get(run_id)
|
||||
|
||||
def list(self) -> Iterable[ProcessRun]:
|
||||
with self._lock:
|
||||
return list(self._runs.values())
|
||||
|
||||
def update(self, run: ProcessRun) -> None:
|
||||
with self._lock:
|
||||
self._runs[run.run_id] = run
|
||||
|
||||
def remove(self, run_id: str) -> None:
|
||||
with self._lock:
|
||||
self._runs.pop(run_id, None)
|
||||
61
backend/process/supervisor.py
Normal file
61
backend/process/supervisor.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Minimal supervisor for scripted tasks and long-running utilities."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from .models import ProcessRun, ProcessRunState
|
||||
from .registry import RunRegistry
|
||||
|
||||
|
||||
class ProcessSupervisor:
|
||||
"""Tracks supervised runs without executing real processes yet."""
|
||||
|
||||
def __init__(self, registry: Optional[RunRegistry] = None) -> None:
|
||||
self.registry = registry or RunRegistry()
|
||||
|
||||
def spawn(
|
||||
self,
|
||||
run_id: str,
|
||||
command: str,
|
||||
scope_key: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ProcessRun:
|
||||
run = ProcessRun(
|
||||
run_id=run_id,
|
||||
command=command,
|
||||
scope_key=scope_key,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
run.state = ProcessRunState.RUNNING
|
||||
run.updated_at = datetime.utcnow()
|
||||
self.registry.add(run)
|
||||
return run
|
||||
|
||||
def update_state(
|
||||
self,
|
||||
run_id: str,
|
||||
state: ProcessRunState,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[ProcessRun]:
|
||||
run = self.registry.get(run_id)
|
||||
if not run:
|
||||
return None
|
||||
run.state = state
|
||||
run.metadata.update(metadata or {})
|
||||
run.updated_at = datetime.utcnow()
|
||||
self.registry.update(run)
|
||||
return run
|
||||
|
||||
def cancel(self, run_id: str, reason: Optional[str] = None) -> Optional[ProcessRun]:
|
||||
run = self.registry.get(run_id)
|
||||
if not run:
|
||||
return None
|
||||
run.state = ProcessRunState.CANCELLED
|
||||
run.metadata.setdefault("cancel_reason", reason or "manual")
|
||||
run.updated_at = datetime.utcnow()
|
||||
self.registry.update(run)
|
||||
return run
|
||||
|
||||
def list_runs(self) -> Iterable[ProcessRun]:
|
||||
return self.registry.list()
|
||||
13
backend/runtime/__init__.py
Normal file
13
backend/runtime/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .agent_runtime import AgentRuntimeState
|
||||
from .context import TradingRunContext
|
||||
from .manager import TradingRuntimeManager
|
||||
from .registry import RuntimeRegistry
|
||||
from .session import TradingSessionKey
|
||||
|
||||
__all__ = [
|
||||
"AgentRuntimeState",
|
||||
"TradingRunContext",
|
||||
"TradingRuntimeManager",
|
||||
"RuntimeRegistry",
|
||||
"TradingSessionKey",
|
||||
]
|
||||
26
backend/runtime/agent_runtime.py
Normal file
26
backend/runtime/agent_runtime.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, UTC
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRuntimeState:
|
||||
agent_id: str
|
||||
status: str = "idle"
|
||||
last_session: str | None = None
|
||||
last_updated: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||
|
||||
def update(self, status: str, session_key: str | None = None) -> None:
|
||||
self.status = status
|
||||
self.last_session = session_key
|
||||
self.last_updated = datetime.now(UTC)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"agent_id": self.agent_id,
|
||||
"status": self.status,
|
||||
"last_session": self.last_session,
|
||||
"last_updated": self.last_updated.isoformat(),
|
||||
}
|
||||
15
backend/runtime/context.py
Normal file
15
backend/runtime/context.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TradingRunContext:
|
||||
config_name: str
|
||||
run_dir: Path
|
||||
bootstrap_values: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def describe(self) -> str:
|
||||
return f"Run {self.config_name} @ {self.run_dir}"
|
||||
173
backend/runtime/manager.py
Normal file
173
backend/runtime/manager.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, UTC
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .agent_runtime import AgentRuntimeState
|
||||
from .context import TradingRunContext
|
||||
from .registry import RuntimeRegistry
|
||||
|
||||
_global_runtime_manager: Optional["TradingRuntimeManager"] = None
|
||||
_shutdown_event: Optional[asyncio.Event] = None
|
||||
|
||||
# Lazy import to avoid circular dependency
|
||||
_api_runtime = None
|
||||
|
||||
|
||||
def _get_api_runtime():
|
||||
global _api_runtime
|
||||
if _api_runtime is None:
|
||||
from backend.api import runtime as api_runtime_module
|
||||
_api_runtime = api_runtime_module
|
||||
return _api_runtime
|
||||
|
||||
|
||||
def set_global_runtime_manager(manager: "TradingRuntimeManager") -> None:
|
||||
global _global_runtime_manager
|
||||
_global_runtime_manager = manager
|
||||
# Sync to RuntimeState for consistency
|
||||
_get_api_runtime().register_runtime_manager(manager)
|
||||
|
||||
|
||||
def clear_global_runtime_manager() -> None:
|
||||
global _global_runtime_manager
|
||||
_global_runtime_manager = None
|
||||
# Sync to RuntimeState for consistency
|
||||
_get_api_runtime().unregister_runtime_manager()
|
||||
|
||||
|
||||
def get_global_runtime_manager() -> Optional["TradingRuntimeManager"]:
|
||||
return _global_runtime_manager
|
||||
|
||||
|
||||
def set_shutdown_event(event: asyncio.Event) -> None:
|
||||
"""Set the global shutdown event for signaling runtime stop."""
|
||||
global _shutdown_event
|
||||
_shutdown_event = event
|
||||
|
||||
|
||||
def clear_shutdown_event() -> None:
|
||||
"""Clear the global shutdown event."""
|
||||
global _shutdown_event
|
||||
_shutdown_event = None
|
||||
|
||||
|
||||
def get_shutdown_event() -> Optional[asyncio.Event]:
|
||||
"""Get the global shutdown event if set."""
|
||||
return _shutdown_event
|
||||
|
||||
|
||||
def is_shutdown_requested() -> bool:
|
||||
"""Check if shutdown has been requested."""
|
||||
return _shutdown_event is not None and _shutdown_event.is_set()
|
||||
|
||||
|
||||
class TradingRuntimeManager:
|
||||
def __init__(self, config_name: str, run_dir: Path, bootstrap: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.config_name = config_name
|
||||
self.run_dir = run_dir
|
||||
self.bootstrap = bootstrap or {}
|
||||
self.context: Optional[TradingRunContext] = None
|
||||
self.registry = RuntimeRegistry()
|
||||
self.current_session_key: Optional[str] = None
|
||||
self.events: List[Dict[str, Any]] = []
|
||||
self.pending_approvals: Dict[str, Dict[str, Any]] = {}
|
||||
self.snapshot_path = self.run_dir / "state" / "runtime_state.json"
|
||||
|
||||
def prepare_run(self) -> TradingRunContext:
|
||||
self.run_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.context = TradingRunContext(
|
||||
config_name=self.config_name,
|
||||
run_dir=self.run_dir,
|
||||
bootstrap_values=self.bootstrap,
|
||||
)
|
||||
self._persist_snapshot()
|
||||
return self.context
|
||||
|
||||
def set_session_key(self, session_key: str) -> None:
|
||||
self.current_session_key = session_key
|
||||
self._persist_snapshot()
|
||||
|
||||
def log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
entry = {
|
||||
"timestamp": datetime.now(UTC).isoformat(),
|
||||
"event": event,
|
||||
"details": details or {},
|
||||
"session": self.current_session_key,
|
||||
}
|
||||
self.events.append(entry)
|
||||
self._persist_snapshot()
|
||||
return entry
|
||||
|
||||
def register_agent(self, agent_id: str) -> AgentRuntimeState:
|
||||
state = AgentRuntimeState(agent_id=agent_id)
|
||||
self.registry.register(agent_id, state)
|
||||
self._persist_snapshot()
|
||||
return state
|
||||
|
||||
def register_pending_approval(self, approval_id: str, payload: Dict[str, Any]) -> None:
|
||||
payload.setdefault("status", "pending")
|
||||
payload.setdefault("created_at", datetime.now(UTC).isoformat())
|
||||
self.pending_approvals[approval_id] = payload
|
||||
self._persist_snapshot()
|
||||
|
||||
def update_agent_status(
|
||||
self,
|
||||
agent_id: str,
|
||||
status: str,
|
||||
session_key: Optional[str] = None,
|
||||
) -> AgentRuntimeState:
|
||||
state = self.registry.get(agent_id)
|
||||
if state is None:
|
||||
state = self.register_agent(agent_id)
|
||||
effective_session = session_key or self.current_session_key
|
||||
state.update(status, effective_session)
|
||||
self._persist_snapshot()
|
||||
return state
|
||||
|
||||
def get_agent_state(self, agent_id: str) -> Optional[AgentRuntimeState]:
|
||||
return self.registry.get(agent_id)
|
||||
|
||||
def list_agents(self) -> list[str]:
|
||||
return self.registry.list_agents()
|
||||
|
||||
def resolve_pending_approval(self, approval_id: str, resolved_by: str, status: str) -> None:
|
||||
entry = self.pending_approvals.get(approval_id)
|
||||
if not entry:
|
||||
return
|
||||
entry["status"] = status
|
||||
entry["resolved_at"] = datetime.now(UTC).isoformat()
|
||||
entry["resolved_by"] = resolved_by
|
||||
self._persist_snapshot()
|
||||
|
||||
def list_pending_approvals(self) -> List[Dict[str, Any]]:
|
||||
return list(self.pending_approvals.values())
|
||||
|
||||
def build_snapshot(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"context": {
|
||||
"config_name": self.context.config_name,
|
||||
"run_dir": str(self.context.run_dir),
|
||||
"bootstrap_values": self.context.bootstrap_values,
|
||||
}
|
||||
if self.context
|
||||
else None,
|
||||
"current_session_key": self.current_session_key,
|
||||
"agents": [
|
||||
state.to_dict()
|
||||
for agent_id in self.registry.list_agents()
|
||||
if (state := self.registry.get(agent_id)) is not None
|
||||
],
|
||||
"events": self.events,
|
||||
"pending_approvals": self.list_pending_approvals(),
|
||||
}
|
||||
|
||||
def _persist_snapshot(self) -> None:
|
||||
self.snapshot_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.snapshot_path.write_text(
|
||||
json.dumps(self.build_snapshot(), ensure_ascii=False, indent=2),
|
||||
encoding="utf-8",
|
||||
)
|
||||
20
backend/runtime/registry.py
Normal file
20
backend/runtime/registry.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class RuntimeRegistry:
|
||||
def __init__(self) -> None:
|
||||
self._states: Dict[str, "AgentRuntimeState"] = {}
|
||||
|
||||
def register(self, agent_id: str, state: "AgentRuntimeState") -> None:
|
||||
self._states[agent_id] = state
|
||||
|
||||
def get(self, agent_id: str) -> Optional["AgentRuntimeState"]:
|
||||
return self._states.get(agent_id)
|
||||
|
||||
def list_agents(self) -> list[str]:
|
||||
return list(self._states.keys())
|
||||
|
||||
def clear(self) -> None:
|
||||
self._states.clear()
|
||||
14
backend/runtime/session.py
Normal file
14
backend/runtime/session.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TradingSessionKey:
|
||||
date: str
|
||||
ticker: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.date:
|
||||
raise ValueError("Session must have a date")
|
||||
|
||||
def key(self) -> str:
|
||||
return f"{self.date}:{self.ticker or 'all'}"
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user