Compare commits
19 Commits
dev
...
456748b01e
| Author | SHA1 | Date | |
|---|---|---|---|
| 456748b01e | |||
| 609b509446 | |||
| 38102d0805 | |||
| 3448667b79 | |||
| 0f1bc2bb39 | |||
| 06a23c32a4 | |||
| 5b925fbe02 | |||
| 4b5ac86b83 | |||
| f4a2b7f3af | |||
| 2dcda63394 | |||
| a3f767126f | |||
| 9ec4a8702d | |||
| 3174734f26 | |||
| 59b44545d0 | |||
| 2daf5717ba | |||
| 1f5ee3698e | |||
| 3a5558b576 | |||
| a41cd705b4 | |||
| 564c92c0c8 |
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"version": "1.0.0",
|
"version": "1.0.0",
|
||||||
"lastScanned": 1773304964541,
|
"lastScanned": 1773938154948,
|
||||||
"projectRoot": "/Users/cillin/workspeace/agentscope-samples/evotraders",
|
"projectRoot": "/Users/cillin/workspeace/evotraders",
|
||||||
"techStack": {
|
"techStack": {
|
||||||
"languages": [
|
"languages": [
|
||||||
{
|
{
|
||||||
@@ -11,6 +11,14 @@
|
|||||||
"markers": [
|
"markers": [
|
||||||
"pyproject.toml"
|
"pyproject.toml"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "C/C++",
|
||||||
|
"version": null,
|
||||||
|
"confidence": "high",
|
||||||
|
"markers": [
|
||||||
|
"Makefile"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"frameworks": [
|
"frameworks": [
|
||||||
@@ -24,8 +32,8 @@
|
|||||||
"runtime": null
|
"runtime": null
|
||||||
},
|
},
|
||||||
"build": {
|
"build": {
|
||||||
"buildCommand": null,
|
"buildCommand": "make build",
|
||||||
"testCommand": "pytest",
|
"testCommand": "make test",
|
||||||
"lintCommand": "ruff check",
|
"lintCommand": "ruff check",
|
||||||
"devCommand": null,
|
"devCommand": null,
|
||||||
"scripts": {}
|
"scripts": {}
|
||||||
@@ -40,7 +48,8 @@
|
|||||||
"isMonorepo": false,
|
"isMonorepo": false,
|
||||||
"workspaces": [],
|
"workspaces": [],
|
||||||
"mainDirectories": [
|
"mainDirectories": [
|
||||||
"docs"
|
"docs",
|
||||||
|
"scripts"
|
||||||
],
|
],
|
||||||
"gitBranches": {
|
"gitBranches": {
|
||||||
"defaultBranch": "main",
|
"defaultBranch": "main",
|
||||||
@@ -49,29 +58,64 @@
|
|||||||
},
|
},
|
||||||
"customNotes": [],
|
"customNotes": [],
|
||||||
"directoryMap": {
|
"directoryMap": {
|
||||||
|
"agent-service": {
|
||||||
|
"path": "agent-service",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 2,
|
||||||
|
"lastAccessed": 1773938154941,
|
||||||
|
"keyFiles": [
|
||||||
|
"Dockerfile",
|
||||||
|
"requirements.txt"
|
||||||
|
]
|
||||||
|
},
|
||||||
"backend": {
|
"backend": {
|
||||||
"path": "backend",
|
"path": "backend",
|
||||||
"purpose": null,
|
"purpose": null,
|
||||||
"fileCount": 3,
|
"fileCount": 5,
|
||||||
"lastAccessed": 1773304964533,
|
"lastAccessed": 1773938154941,
|
||||||
"keyFiles": [
|
"keyFiles": [
|
||||||
"__init__.py",
|
"__init__.py",
|
||||||
|
"app.py",
|
||||||
"cli.py",
|
"cli.py",
|
||||||
|
"gateway_server.py",
|
||||||
"main.py"
|
"main.py"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"backtest": {
|
||||||
|
"path": "backtest",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 0,
|
||||||
|
"lastAccessed": 1773938154941,
|
||||||
|
"keyFiles": []
|
||||||
|
},
|
||||||
|
"data": {
|
||||||
|
"path": "data",
|
||||||
|
"purpose": "Data files",
|
||||||
|
"fileCount": 1,
|
||||||
|
"lastAccessed": 1773938154941,
|
||||||
|
"keyFiles": [
|
||||||
|
"market_research.db"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"deploy": {
|
||||||
|
"path": "deploy",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 0,
|
||||||
|
"lastAccessed": 1773938154942,
|
||||||
|
"keyFiles": []
|
||||||
|
},
|
||||||
"docs": {
|
"docs": {
|
||||||
"path": "docs",
|
"path": "docs",
|
||||||
"purpose": "Documentation",
|
"purpose": "Documentation",
|
||||||
"fileCount": 0,
|
"fileCount": 0,
|
||||||
"lastAccessed": 1773304964533,
|
"lastAccessed": 1773938154942,
|
||||||
"keyFiles": []
|
"keyFiles": []
|
||||||
},
|
},
|
||||||
"evotraders.egg-info": {
|
"evotraders.egg-info": {
|
||||||
"path": "evotraders.egg-info",
|
"path": "evotraders.egg-info",
|
||||||
"purpose": null,
|
"purpose": null,
|
||||||
"fileCount": 6,
|
"fileCount": 6,
|
||||||
"lastAccessed": 1773304964534,
|
"lastAccessed": 1773938154942,
|
||||||
"keyFiles": [
|
"keyFiles": [
|
||||||
"PKG-INFO",
|
"PKG-INFO",
|
||||||
"SOURCES.txt",
|
"SOURCES.txt",
|
||||||
@@ -83,8 +127,8 @@
|
|||||||
"frontend": {
|
"frontend": {
|
||||||
"path": "frontend",
|
"path": "frontend",
|
||||||
"purpose": null,
|
"purpose": null,
|
||||||
"fileCount": 12,
|
"fileCount": 13,
|
||||||
"lastAccessed": 1773304964535,
|
"lastAccessed": 1773938154942,
|
||||||
"keyFiles": [
|
"keyFiles": [
|
||||||
"README.md",
|
"README.md",
|
||||||
"components.json",
|
"components.json",
|
||||||
@@ -93,239 +137,488 @@
|
|||||||
"index.css"
|
"index.css"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"live": {
|
||||||
|
"path": "live",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 0,
|
||||||
|
"lastAccessed": 1773938154943,
|
||||||
|
"keyFiles": []
|
||||||
|
},
|
||||||
|
"logs": {
|
||||||
|
"path": "logs",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 7,
|
||||||
|
"lastAccessed": 1773938154943,
|
||||||
|
"keyFiles": [
|
||||||
|
"2026-03-16_00-48-03.log",
|
||||||
|
"2026-03-18_23-17-29.log",
|
||||||
|
"2026-03-18_23-17-30.2026-03-18_23-17-30_000801.log.zip",
|
||||||
|
"2026-03-18_23-17-30.log",
|
||||||
|
"2026-03-19_00-18-04.log"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"news-service": {
|
||||||
|
"path": "news-service",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 3,
|
||||||
|
"lastAccessed": 1773938154943,
|
||||||
|
"keyFiles": [
|
||||||
|
"Dockerfile",
|
||||||
|
"requirements.txt"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"reference": {
|
||||||
|
"path": "reference",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 0,
|
||||||
|
"lastAccessed": 1773938154943,
|
||||||
|
"keyFiles": []
|
||||||
|
},
|
||||||
|
"runs": {
|
||||||
|
"path": "runs",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 0,
|
||||||
|
"lastAccessed": 1773938154944,
|
||||||
|
"keyFiles": []
|
||||||
|
},
|
||||||
|
"scripts": {
|
||||||
|
"path": "scripts",
|
||||||
|
"purpose": "Build/utility scripts",
|
||||||
|
"fileCount": 1,
|
||||||
|
"lastAccessed": 1773938154944,
|
||||||
|
"keyFiles": [
|
||||||
|
"run_prod.sh"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"services": {
|
||||||
|
"path": "services",
|
||||||
|
"purpose": "Business logic services",
|
||||||
|
"fileCount": 1,
|
||||||
|
"lastAccessed": 1773938154944,
|
||||||
|
"keyFiles": [
|
||||||
|
"README.md"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"shared": {
|
||||||
|
"path": "shared",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 0,
|
||||||
|
"lastAccessed": 1773938154944,
|
||||||
|
"keyFiles": []
|
||||||
|
},
|
||||||
|
"trading-service": {
|
||||||
|
"path": "trading-service",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 4,
|
||||||
|
"lastAccessed": 1773938154944,
|
||||||
|
"keyFiles": [
|
||||||
|
"Dockerfile",
|
||||||
|
"README.md",
|
||||||
|
"requirements.txt"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"workspaces": {
|
||||||
|
"path": "workspaces",
|
||||||
|
"purpose": null,
|
||||||
|
"fileCount": 0,
|
||||||
|
"lastAccessed": 1773938154944,
|
||||||
|
"keyFiles": []
|
||||||
|
},
|
||||||
|
"agent-service/src": {
|
||||||
|
"path": "agent-service/src",
|
||||||
|
"purpose": "Source code",
|
||||||
|
"fileCount": 5,
|
||||||
|
"lastAccessed": 1773938154944,
|
||||||
|
"keyFiles": [
|
||||||
|
"__init__.py",
|
||||||
|
"config.py",
|
||||||
|
"main.py"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"backend/api": {
|
||||||
|
"path": "backend/api",
|
||||||
|
"purpose": "API routes",
|
||||||
|
"fileCount": 5,
|
||||||
|
"lastAccessed": 1773938154944,
|
||||||
|
"keyFiles": [
|
||||||
|
"__init__.py",
|
||||||
|
"agents.py",
|
||||||
|
"guard.py"
|
||||||
|
]
|
||||||
|
},
|
||||||
"backend/config": {
|
"backend/config": {
|
||||||
"path": "backend/config",
|
"path": "backend/config",
|
||||||
"purpose": "Configuration files",
|
"purpose": "Configuration files",
|
||||||
"fileCount": 4,
|
"fileCount": 6,
|
||||||
"lastAccessed": 1773304964535,
|
"lastAccessed": 1773938154944,
|
||||||
"keyFiles": [
|
"keyFiles": [
|
||||||
"__init__.py",
|
"__init__.py",
|
||||||
"constants.py",
|
"agent_profiles.yaml",
|
||||||
"data_config.py"
|
"bootstrap_config.py"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"backend/data": {
|
"backend/data": {
|
||||||
"path": "backend/data",
|
"path": "backend/data",
|
||||||
"purpose": "Data files",
|
"purpose": "Data files",
|
||||||
"fileCount": 7,
|
"fileCount": 13,
|
||||||
"lastAccessed": 1773304964536,
|
"lastAccessed": 1773938154944,
|
||||||
"keyFiles": [
|
"keyFiles": [
|
||||||
"__init__.py",
|
"__init__.py",
|
||||||
"cache.py",
|
"cache.py",
|
||||||
"historical_price_manager.py"
|
"historical_price_manager.py"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"backend/services": {
|
|
||||||
"path": "backend/services",
|
|
||||||
"purpose": "Business logic services",
|
|
||||||
"fileCount": 4,
|
|
||||||
"lastAccessed": 1773304964536,
|
|
||||||
"keyFiles": [
|
|
||||||
"__init__.py",
|
|
||||||
"gateway.py",
|
|
||||||
"market.py"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"backend/tests": {
|
|
||||||
"path": "backend/tests",
|
|
||||||
"purpose": "Test files",
|
|
||||||
"fileCount": 4,
|
|
||||||
"lastAccessed": 1773304964536,
|
|
||||||
"keyFiles": [
|
|
||||||
"__init__.py",
|
|
||||||
"test_agents.py",
|
|
||||||
"test_market_service.py"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"docs/assets": {
|
"docs/assets": {
|
||||||
"path": "docs/assets",
|
"path": "docs/assets",
|
||||||
"purpose": "Static assets",
|
"purpose": "Static assets",
|
||||||
"fileCount": 5,
|
"fileCount": 5,
|
||||||
"lastAccessed": 1773304964536,
|
"lastAccessed": 1773938154944,
|
||||||
"keyFiles": [
|
"keyFiles": [
|
||||||
"dashboard.jpg",
|
"dashboard.jpg",
|
||||||
"evotraders_demo.gif",
|
"evotraders_demo.gif",
|
||||||
"evotraders_logo.jpg"
|
"evotraders_logo.jpg"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"frontend/public": {
|
"frontend/dist": {
|
||||||
"path": "frontend/public",
|
"path": "frontend/dist",
|
||||||
"purpose": "Public files",
|
"purpose": "Distribution/build output",
|
||||||
"fileCount": 1,
|
"fileCount": 2,
|
||||||
"lastAccessed": 1773304964538,
|
"lastAccessed": 1773938154945,
|
||||||
"keyFiles": [
|
"keyFiles": [
|
||||||
|
"index.html",
|
||||||
"trading_logo.png"
|
"trading_logo.png"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
"frontend/node_modules": {
|
||||||
|
"path": "frontend/node_modules",
|
||||||
|
"purpose": "Dependencies",
|
||||||
|
"fileCount": 1,
|
||||||
|
"lastAccessed": 1773938154947,
|
||||||
|
"keyFiles": []
|
||||||
|
},
|
||||||
|
"news-service/src": {
|
||||||
|
"path": "news-service/src",
|
||||||
|
"purpose": "Source code",
|
||||||
|
"fileCount": 3,
|
||||||
|
"lastAccessed": 1773938154948,
|
||||||
|
"keyFiles": [
|
||||||
|
"__init__.py",
|
||||||
|
"config.py",
|
||||||
|
"main.py"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"trading-service/src": {
|
||||||
|
"path": "trading-service/src",
|
||||||
|
"purpose": "Source code",
|
||||||
|
"fileCount": 8,
|
||||||
|
"lastAccessed": 1773938154948,
|
||||||
|
"keyFiles": [
|
||||||
|
"__init__.py",
|
||||||
|
"config.py",
|
||||||
|
"main.py"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"hotPaths": [
|
"hotPaths": [
|
||||||
{
|
{
|
||||||
"path": "frontend/src/components/StatisticsView.jsx",
|
"path": "backend/agents/factory.py",
|
||||||
"accessCount": 22,
|
|
||||||
"lastAccessed": 1773310044545,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "frontend/src/components/AgentCard.jsx",
|
|
||||||
"accessCount": 17,
|
"accessCount": 17,
|
||||||
"lastAccessed": 1773309995177,
|
"lastAccessed": 1773939950376,
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "frontend/src/App.jsx",
|
|
||||||
"accessCount": 12,
|
|
||||||
"lastAccessed": 1773309849392,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "frontend/src/components/AgentFeed.jsx",
|
|
||||||
"accessCount": 12,
|
|
||||||
"lastAccessed": 1773309960022,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": ".env",
|
|
||||||
"accessCount": 7,
|
|
||||||
"lastAccessed": 1773308950505,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "frontend/src/components/RoomView.jsx",
|
|
||||||
"accessCount": 7,
|
|
||||||
"lastAccessed": 1773309864236,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "backend/tools/analysis_tools.py",
|
|
||||||
"accessCount": 5,
|
|
||||||
"lastAccessed": 1773312271446,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "frontend/src/components/Header.jsx",
|
|
||||||
"accessCount": 4,
|
|
||||||
"lastAccessed": 1773309827069,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "frontend/src/components/AboutModal.jsx",
|
|
||||||
"accessCount": 4,
|
|
||||||
"lastAccessed": 1773310093371,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "backend/agents/prompts/analyst/personas.yaml",
|
|
||||||
"accessCount": 4,
|
|
||||||
"lastAccessed": 1773312049213,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "backend/agents/prompts/analyst/system.md",
|
|
||||||
"accessCount": 4,
|
|
||||||
"lastAccessed": 1773312049696,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "backend/agents/prompts/portfolio_manager/system.md",
|
|
||||||
"accessCount": 4,
|
|
||||||
"lastAccessed": 1773312050326,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "backend/agents/prompts/risk_manager/system.md",
|
|
||||||
"accessCount": 4,
|
|
||||||
"lastAccessed": 1773312050782,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "frontend/src/config/constants.js",
|
|
||||||
"accessCount": 3,
|
|
||||||
"lastAccessed": 1773309824671,
|
|
||||||
"type": "file"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "frontend/src/components/RulesView.jsx",
|
|
||||||
"accessCount": 3,
|
|
||||||
"lastAccessed": 1773310061939,
|
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "backend",
|
"path": "backend",
|
||||||
"accessCount": 3,
|
"accessCount": 16,
|
||||||
"lastAccessed": 1773312200721,
|
"lastAccessed": 1773940042371,
|
||||||
"type": "directory"
|
"type": "directory"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"path": "",
|
||||||
|
"accessCount": 13,
|
||||||
|
"lastAccessed": 1773939899611,
|
||||||
|
"type": "directory"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/main.py",
|
||||||
|
"accessCount": 7,
|
||||||
|
"lastAccessed": 1773939993951,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/gateway_server.py",
|
||||||
|
"accessCount": 7,
|
||||||
|
"lastAccessed": 1773940004402,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/main.py",
|
||||||
|
"accessCount": 5,
|
||||||
|
"lastAccessed": 1773938385662,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/core/pipeline.py",
|
||||||
|
"accessCount": 5,
|
||||||
|
"lastAccessed": 1773940024933,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/enrich/news_enricher.py",
|
||||||
|
"accessCount": 4,
|
||||||
|
"lastAccessed": 1773938508417,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "start-dev.sh",
|
||||||
|
"accessCount": 4,
|
||||||
|
"lastAccessed": 1773939259381,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "services/README.md",
|
||||||
|
"accessCount": 4,
|
||||||
|
"lastAccessed": 1773939281935,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/app.py",
|
||||||
|
"accessCount": 4,
|
||||||
|
"lastAccessed": 1773939648215,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/routes/news.py",
|
||||||
|
"accessCount": 3,
|
||||||
|
"lastAccessed": 1773938438928,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news",
|
||||||
|
"accessCount": 3,
|
||||||
|
"lastAccessed": 1773938468730,
|
||||||
|
"type": "directory"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "frontend/src/config/constants.js",
|
||||||
|
"accessCount": 3,
|
||||||
|
"lastAccessed": 1773939204395,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"path": "backend/services/gateway.py",
|
"path": "backend/services/gateway.py",
|
||||||
|
"accessCount": 3,
|
||||||
|
"lastAccessed": 1773939672930,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/core/__init__.py",
|
||||||
|
"accessCount": 3,
|
||||||
|
"lastAccessed": 1773939963627,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/trading/main.py",
|
||||||
"accessCount": 2,
|
"accessCount": 2,
|
||||||
"lastAccessed": 1773312232905,
|
"lastAccessed": 1773938360736,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/agents/main.py",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773938361040,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/trading/data/__init__.py",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773938402496,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/explain/__init__.py",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773938460019,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/enrich/__init__.py",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773938465216,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/explain/range_explainer.py",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773938481152,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/enrich/llm_enricher.py",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773938499885,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "CLAUDE.md",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773939273598,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/agents/__init__.py",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773939883015,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/agents/agent_core.py",
|
||||||
|
"accessCount": 2,
|
||||||
|
"lastAccessed": 1773939886997,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "Makefile",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938226307,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "docker-compose.yml",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938226360,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/shared/trading_client.py",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938370618,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/agents",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938397772,
|
||||||
"type": "directory"
|
"type": "directory"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "README.md",
|
"path": "backend/services/trading",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773305013217,
|
"lastAccessed": 1773938397823,
|
||||||
|
"type": "directory"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938405541,
|
||||||
|
"type": "directory"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/news/config.py",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938638664,
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "README_zh.md",
|
"path": "shared/client/news_client.py",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773305013274,
|
"lastAccessed": 1773938638715,
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "env.template",
|
"path": "shared/client/trading_client.py",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773305019965,
|
"lastAccessed": 1773938638770,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/api",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938669143,
|
||||||
|
"type": "directory"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "frontend",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938669195,
|
||||||
|
"type": "directory"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": ".env.example",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773938849397,
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "frontend/src/services/websocket.js",
|
"path": "frontend/src/services/websocket.js",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773309324302,
|
"lastAccessed": 1773938849448,
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "backend/config/data_config.py",
|
"path": "frontend/src/services/runtimeApi.js",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773309324414,
|
"lastAccessed": 1773938849500,
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "backend/cli.py",
|
"path": "backend/services/agents/routes/websocket.py",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773309336899,
|
"lastAccessed": 1773939001692,
|
||||||
"type": "directory"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"path": "backend/agents/portfolio_manager.py",
|
|
||||||
"accessCount": 1,
|
|
||||||
"lastAccessed": 1773311956562,
|
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "backend/agents/risk_manager.py",
|
"path": "backend/services/agents/routes/agents.py",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773311956760,
|
"lastAccessed": 1773939016291,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/services/agents/routes/run.py",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773939016343,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/__init__.py",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773939648323,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/api/__init__.py",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773939658650,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/runtime/__init__.py",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773939658687,
|
||||||
|
"type": "file"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"path": "backend/agents/base/evo_agent.py",
|
||||||
|
"accessCount": 1,
|
||||||
|
"lastAccessed": 1773939664916,
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "backend/agents/analyst.py",
|
"path": "backend/agents/analyst.py",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773311963222,
|
"lastAccessed": 1773939664967,
|
||||||
"type": "file"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "backend/tools",
|
"path": "backend/agents/base/hooks.py",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773312289643,
|
"lastAccessed": 1773939672727,
|
||||||
"type": "directory"
|
"type": "file"
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"path": "backend/tools/data_tools.py",
|
"path": "pyproject.toml",
|
||||||
"accessCount": 1,
|
"accessCount": 1,
|
||||||
"lastAccessed": 1773312293851,
|
"lastAccessed": 1773939672778,
|
||||||
"type": "directory"
|
"type": "file"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"userDirectives": []
|
"userDirectives": []
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"timestamp": "2026-03-12T20:33:59.497Z",
|
"timestamp": "2026-03-19T16:36:52.471Z",
|
||||||
"backgroundTasks": [],
|
"backgroundTasks": [],
|
||||||
"sessionStartTimestamp": "2026-03-12T14:19:33.615Z",
|
"sessionStartTimestamp": "2026-03-19T16:36:42.224Z",
|
||||||
"sessionId": "73b0d597-0141-4873-9d0e-2b60e4e0635e"
|
"sessionId": "ef02339a-1eec-4c7a-95ac-c8cfa0b5067d"
|
||||||
}
|
}
|
||||||
@@ -1 +1 @@
|
|||||||
{"session_id":"73b0d597-0141-4873-9d0e-2b60e4e0635e","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-agentscope-samples-evotraders/73b0d597-0141-4873-9d0e-2b60e4e0635e.jsonl","cwd":"/Users/cillin/workspeace/agentscope-samples/evotraders","model":{"id":"kimi-for-coding","display_name":"kimi-for-coding"},"workspace":{"current_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","project_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","added_dirs":["/Users/cillin/workspeace/agentscope-samples/EvoTraders","/Users/cillin/workspeace/agentscope-samples/evotraders"]},"version":"2.1.63","output_style":{"name":"default"},"cost":{"total_cost_usd":6.822239999999999,"total_duration_ms":42679588,"total_api_duration_ms":1223637,"total_lines_added":275,"total_lines_removed":240},"context_window":{"total_input_tokens":654274,"total_output_tokens":27014,"context_window_size":200000,"current_usage":{"input_tokens":48465,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0},"used_percentage":24,"remaining_percentage":76},"exceeds_200k_tokens":false}
|
{"session_id":"ef02339a-1eec-4c7a-95ac-c8cfa0b5067d","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-evotraders/ef02339a-1eec-4c7a-95ac-c8cfa0b5067d.jsonl","cwd":"/Users/cillin/workspeace/evotraders","model":{"id":"MiniMax-M2.7-highspeed","display_name":"MiniMax-M2.7-highspeed"},"workspace":{"current_dir":"/Users/cillin/workspeace/evotraders","project_dir":"/Users/cillin/workspeace/evotraders","added_dirs":[]},"version":"2.1.78","output_style":{"name":"default"},"cost":{"total_cost_usd":17.458779250000003,"total_duration_ms":1866224,"total_api_duration_ms":1188013,"total_lines_added":257,"total_lines_removed":290},"context_window":{"total_input_tokens":195204,"total_output_tokens":48917,"context_window_size":200000,"current_usage":{"input_tokens":481,"output_tokens":0,"cache_creation_input_tokens":149,"cache_read_input_tokens":163286},"used_percentage":82,"remaining_percentage":18},"exceeds_200k_tokens":false}
|
||||||
@@ -1,3 +1,3 @@
|
|||||||
{
|
{
|
||||||
"lastSentAt": "2026-03-12T20:31:37.362Z"
|
"lastSentAt": "2026-03-19T17:02:32.170Z"
|
||||||
}
|
}
|
||||||
@@ -1,26 +1,17 @@
|
|||||||
{
|
{
|
||||||
"agents": [
|
"agents": [
|
||||||
{
|
{
|
||||||
"agent_id": "a4090d26a45ac828d",
|
"agent_id": "a8305a91e192b2196",
|
||||||
"agent_type": "oh-my-claudecode:executor",
|
"agent_type": "Explore",
|
||||||
"started_at": "2026-03-12T10:02:38.238Z",
|
"started_at": "2026-03-19T17:00:33.284Z",
|
||||||
"parent_mode": "none",
|
"parent_mode": "none",
|
||||||
"status": "completed",
|
"status": "completed",
|
||||||
"completed_at": "2026-03-12T10:10:59.192Z",
|
"completed_at": "2026-03-19T17:02:19.439Z",
|
||||||
"duration_ms": 500954
|
"duration_ms": 106155
|
||||||
},
|
|
||||||
{
|
|
||||||
"agent_id": "af87583ef76a4df30",
|
|
||||||
"agent_type": "oh-my-claudecode:executor",
|
|
||||||
"started_at": "2026-03-12T10:40:04.409Z",
|
|
||||||
"parent_mode": "none",
|
|
||||||
"status": "completed",
|
|
||||||
"completed_at": "2026-03-12T10:41:17.387Z",
|
|
||||||
"duration_ms": 72978
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"total_spawned": 2,
|
"total_spawned": 1,
|
||||||
"total_completed": 2,
|
"total_completed": 1,
|
||||||
"total_failed": 0,
|
"total_failed": 0,
|
||||||
"last_updated": "2026-03-12T10:41:17.490Z"
|
"last_updated": "2026-03-19T17:02:39.175Z"
|
||||||
}
|
}
|
||||||
302
CLAUDE.md
Normal file
302
CLAUDE.md
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
本文件为 Claude Code (claude.ai/code) 在此代码库中工作时提供指导。
|
||||||
|
|
||||||
|
## 项目概述
|
||||||
|
|
||||||
|
EvoTraders 是一个自进化多智能体交易系统,由 6 个 AI Agent(4 名分析师 + 投资经理 + 风控经理)协作完成交易决策。Agent 基于 AgentScope 框架构建,配合 ReMe 记忆系统实现持续学习。
|
||||||
|
|
||||||
|
## 常用命令
|
||||||
|
|
||||||
|
### Backend (Python)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装依赖
|
||||||
|
uv pip install -e .
|
||||||
|
|
||||||
|
# 运行命令
|
||||||
|
evotraders backtest --start 2025-11-01 --end 2025-12-01 # 回测模式
|
||||||
|
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory # 带记忆回测
|
||||||
|
evotraders live # 实盘交易
|
||||||
|
evotraders live --mock # 模拟/测试模式
|
||||||
|
evotraders live -t 22:30 # 定时每日交易
|
||||||
|
evotraders frontend # 启动可视化界面
|
||||||
|
|
||||||
|
# 开发服务器
|
||||||
|
./start-dev.sh # 启动全部 4 个微服务
|
||||||
|
|
||||||
|
# 单独启动某个服务
|
||||||
|
python -m uvicorn backend.apps.agent_service:app --host 0.0.0.0 --port 8000 --reload
|
||||||
|
python -m uvicorn backend.apps.runtime_service:app --host 0.0.0.0 --port 8003 --reload
|
||||||
|
python -m uvicorn backend.apps.trading_service:app --host 0.0.0.0 --port 8001 --reload
|
||||||
|
python -m uvicorn backend.apps.news_service:app --host 0.0.0.0 --port 8002 --reload
|
||||||
|
|
||||||
|
# 测试
|
||||||
|
pytest backend/tests # 运行全部测试
|
||||||
|
pytest backend/tests/test_news_service_app.py -v # 运行单个测试文件
|
||||||
|
pytest backend/tests/test_news_service_app.py::test_news_service_routes_are_exposed -v # 运行单个测试
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend (React)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd frontend
|
||||||
|
npm run dev # Vite 开发服务器 (http://localhost:5173)
|
||||||
|
npm run build # 生产构建
|
||||||
|
npm run lint # ESLint 检查
|
||||||
|
npm run lint:fix # ESLint 自动修复
|
||||||
|
npm run test # Vitest 单元测试
|
||||||
|
npm run test:watch # 监听模式
|
||||||
|
```
|
||||||
|
|
||||||
|
## 架构概览
|
||||||
|
|
||||||
|
### 微服务架构 (`backend/apps/`)
|
||||||
|
|
||||||
|
项目采用 split-first 微服务架构,4 个独立的 FastAPI 服务:
|
||||||
|
|
||||||
|
| 服务 | 入口 | 端口 | 职责 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| agent_service | `backend.apps.agent_service:app` | 8000 | Agent 生命周期、工作区管理 |
|
||||||
|
| runtime_service | `backend.apps.runtime_service:app` | 8003 | 运行时配置、任务启动 |
|
||||||
|
| trading_service | `backend.apps.trading_service:app` | 8001 | 市场数据、交易操作 |
|
||||||
|
| news_service | `backend.apps.news_service:app` | 8002 | 新闻、新闻富化、解释功能 |
|
||||||
|
|
||||||
|
服务间通过环境变量通信(详见 `start-dev.sh`):
|
||||||
|
```bash
|
||||||
|
export TRADING_SERVICE_URL=http://localhost:8001
|
||||||
|
export NEWS_SERVICE_URL=http://localhost:8002
|
||||||
|
export RUNTIME_SERVICE_URL=http://localhost:8003
|
||||||
|
```
|
||||||
|
|
||||||
|
### Gateway 网关 (`backend/services/gateway.py`)
|
||||||
|
|
||||||
|
Gateway 是统一的请求路由器,根据路径前缀将请求转发到对应的微服务:
|
||||||
|
- `/control/*` → agent_service
|
||||||
|
- `/runtime/*` → runtime_service
|
||||||
|
- `/trading/*` → trading_service
|
||||||
|
- `/news/*` → news_service
|
||||||
|
|
||||||
|
新增接口时应注册到对应的 service app,而非直接添加到 gateway。
|
||||||
|
|
||||||
|
### 共享客户端 (`shared/client/`)
|
||||||
|
|
||||||
|
统一的服务客户端库,所有前端和后端服务间通信都使用此处定义的客户端:
|
||||||
|
|
||||||
|
| 客户端 | 用途 |
|
||||||
|
|--------|------|
|
||||||
|
| `ControlPlaneClient` | Agent 服务通信 |
|
||||||
|
| `RuntimeServiceClient` | 运行时服务通信 |
|
||||||
|
| `TradingServiceClient` | 交易服务通信 |
|
||||||
|
| `NewsServiceClient` | 新闻服务通信 |
|
||||||
|
|
||||||
|
### 领域层 (`backend/domains/`)
|
||||||
|
|
||||||
|
业务逻辑按领域分离:
|
||||||
|
|
||||||
|
- `news.py` - 新闻领域操作
|
||||||
|
- `trading.py` - 交易领域操作
|
||||||
|
|
||||||
|
## 后端结构
|
||||||
|
|
||||||
|
```
|
||||||
|
backend/
|
||||||
|
├── agents/ # 多智能体实现
|
||||||
|
│ ├── base/ # 核心类、Hooks、评估
|
||||||
|
│ │ ├── evo_agent.py # 基于 AgentScope 的核心实现
|
||||||
|
│ │ ├── hooks.py # 生命周期 Hooks
|
||||||
|
│ │ │ ├── BootstrapHook # 启动初始化
|
||||||
|
│ │ │ ├── MemoryCompactionHook # 内存压缩(基于 CoPaw)
|
||||||
|
│ │ │ ├── HeartbeatHook # 心跳检测
|
||||||
|
│ │ │ └── WorkspaceWatchHook # 工作区监控
|
||||||
|
│ │ ├── evaluation_hook.py # 执行后评估
|
||||||
|
│ │ ├── skill_adaptation_hook.py # 动态技能适配
|
||||||
|
│ │ └── tool_guard.py # 工具调用守卫
|
||||||
|
│ ├── prompts/ # Agent 提示词和角色定义
|
||||||
|
│ │ ├── analyst/personas.yaml # 分析师角色配置
|
||||||
|
│ │ └── portfolio_manager/
|
||||||
|
│ ├── team/ # 团队协作逻辑
|
||||||
|
│ │ ├── registry.py # Agent 注册表
|
||||||
|
│ │ ├── coordinator.py # 协作协调器
|
||||||
|
│ │ ├── messenger.py # 消息传递
|
||||||
|
│ │ └── task_delegator.py # 任务分发
|
||||||
|
│ ├── factory.py # Agent 实例工厂
|
||||||
|
│ ├── skills_manager.py # 技能加载管理(6 种作用域)
|
||||||
|
│ └── toolkit_factory.py # 工具集工厂
|
||||||
|
├── apps/ # 微服务入口(split-first)
|
||||||
|
│ ├── agent_service.py
|
||||||
|
│ ├── runtime_service.py
|
||||||
|
│ ├── trading_service.py
|
||||||
|
│ └── news_service.py
|
||||||
|
├── domains/ # 领域业务逻辑
|
||||||
|
│ ├── news.py
|
||||||
|
│ └── trading.py
|
||||||
|
├── services/ # Gateway 和辅助服务
|
||||||
|
│ ├── gateway.py # 统一路由网关
|
||||||
|
│ ├── gateway_*.py # Gateway 子模块
|
||||||
|
│ └── market.py # 市场数据服务
|
||||||
|
├── api/ # FastAPI 端点
|
||||||
|
├── config/ # 常量和配置
|
||||||
|
│ └── constants.py # Agent 配置、显示名称等
|
||||||
|
├── core/ # Pipeline 执行逻辑
|
||||||
|
├── data/ # 市场数据处理
|
||||||
|
│ ├── provider_router.py # 数据源路由
|
||||||
|
│ └── schema.py # 数据 schema
|
||||||
|
├── enrich/ # LLM 响应富化
|
||||||
|
├── explain/ # 交易决策解释
|
||||||
|
├── llm/ # LLM 集成
|
||||||
|
│ └── models.py # RetryChatModel、TokenRecordingModelWrapper
|
||||||
|
├── skills/ # 技能定义(内置 + 自定义)
|
||||||
|
├── tools/ # 交易和分析工具
|
||||||
|
└── utils/ # 工具函数
|
||||||
|
```
|
||||||
|
|
||||||
|
## 前端结构
|
||||||
|
|
||||||
|
```
|
||||||
|
frontend/src/
|
||||||
|
├── App.jsx # React 主应用
|
||||||
|
├── components/ # React 组件
|
||||||
|
│ ├── RuntimeView.jsx # 交易运行时 UI
|
||||||
|
│ ├── TraderView.jsx # 交易员界面
|
||||||
|
│ ├── RoomView.jsx # 聊天室视图
|
||||||
|
│ ├── StockExplainView.jsx # 股票解释视图
|
||||||
|
│ ├── RuntimeSettingsPanel.jsx # 运行时设置面板
|
||||||
|
│ ├── WatchlistPanel.jsx # 关注列表
|
||||||
|
│ ├── PerformanceView.jsx # 绩效视图
|
||||||
|
│ ├── StatisticsView.jsx # 统计视图
|
||||||
|
│ ├── NetValueChart.jsx # 净值曲线图
|
||||||
|
│ ├── AgentCard.jsx # Agent 卡片
|
||||||
|
│ ├── AgentFeed.jsx # Agent 动态
|
||||||
|
│ └── explain/ # 解释相关组件
|
||||||
|
│ ├── ExplainNewsSection.jsx
|
||||||
|
│ ├── ExplainRangeSection.jsx
|
||||||
|
│ ├── ExplainSimilarDaysSection.jsx
|
||||||
|
│ ├── ExplainStorySection.jsx
|
||||||
|
│ └── useExplainModel.js
|
||||||
|
├── services/ # API 服务
|
||||||
|
│ ├── runtimeApi.js # 运行时 API 调用
|
||||||
|
│ ├── websocket.js # WebSocket 实时通信
|
||||||
|
│ ├── newsApi.js # 新闻服务客户端
|
||||||
|
│ └── tradingApi.js # 交易服务客户端
|
||||||
|
├── config/
|
||||||
|
│ └── constants.js # Agent 定义、配置
|
||||||
|
└── hooks/ # React Hooks
|
||||||
|
```
|
||||||
|
|
||||||
|
## Agent 系统
|
||||||
|
|
||||||
|
### 6 种 Agent 角色
|
||||||
|
|
||||||
|
| 角色 ID | 名称 | 职责 |
|
||||||
|
|---------|------|------|
|
||||||
|
| `fundamentals_analyst` | 基本面分析师 | 财务健康、盈利能力、成长质量 |
|
||||||
|
| `technical_analyst` | 技术分析师 | 价格趋势、技术指标、动量分析 |
|
||||||
|
| `sentiment_analyst` | 情绪分析师 | 市场情绪、新闻情绪、内幕交易 |
|
||||||
|
| `valuation_analyst` | 估值分析师 | DCF、EV/EBITDA、 intrinsic value |
|
||||||
|
| `portfolio_manager` | 投资经理 | 决策执行、交易协调 |
|
||||||
|
| `risk_manager` | 风控经理 | 实时价格/波动率监控、仓位限制、多层风险预警 |
|
||||||
|
|
||||||
|
### Hook 系统 (`base/hooks.py`)
|
||||||
|
|
||||||
|
- **MemoryCompactionHook**: 基于 CoPaw 的内存压缩
|
||||||
|
- `memory_compact_ratio`: 压缩目标比例(默认 0.75)
|
||||||
|
- `memory_reserve_ratio`: 保留比例(默认 0.1)
|
||||||
|
- `enable_tool_result_compact`: 工具结果压缩
|
||||||
|
- `tool_result_compact_keep_n`: 保留最近 N 条工具结果
|
||||||
|
|
||||||
|
### 添加自定义分析师
|
||||||
|
|
||||||
|
1. 在 `backend/agents/prompts/analyst/personas.yaml` 注册
|
||||||
|
2. 在 `backend/config/constants.py` 的 `ANALYST_TYPES` 字典中添加
|
||||||
|
3. 可选:在 `frontend/src/config/constants.js` 中更新前端配置
|
||||||
|
|
||||||
|
### LLM 模型封装 (`backend/llm/models.py`)
|
||||||
|
|
||||||
|
基于 CoPaw 的模型封装设计:
|
||||||
|
|
||||||
|
- **RetryChatModel**: 自动重试瞬态 LLM 错误(rate limit、timeout、502/503 等),指数退避
|
||||||
|
- `max_retries`: 最大重试次数(默认 3)
|
||||||
|
- `initial_delay`: 初始延迟秒数(默认 1.0)
|
||||||
|
- `backoff_multiplier`: 退避倍数(默认 2.0)
|
||||||
|
|
||||||
|
- **TokenRecordingModelWrapper**: 追踪每个 provider 的 token 消耗和成本
|
||||||
|
|
||||||
|
```python
|
||||||
|
from backend.llm.models import create_model, RetryChatModel
|
||||||
|
|
||||||
|
model = RetryChatModel(create_model("gpt-4o", "OPENAI"), max_retries=3)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 技能系统 (`backend/skills/`)
|
||||||
|
|
||||||
|
技能定义在 `SKILL.md` 文件中,包含:
|
||||||
|
- `instructions` - 技能说明
|
||||||
|
- `triggers` - 触发条件
|
||||||
|
- `parameters` - 输入/输出 schema
|
||||||
|
- `available_tools` - 技能可使用的工具
|
||||||
|
|
||||||
|
技能由 `skills_manager.py` 加载,通过 `skill_adaptation_hook.py` 绑定到 Agent。
|
||||||
|
|
||||||
|
技能管理器支持 6 种作用域:builtin、customized、installed、active、disabled、local。
|
||||||
|
|
||||||
|
## Pipeline 执行 (`backend/core/`)
|
||||||
|
|
||||||
|
每日交易流程:
|
||||||
|
|
||||||
|
1. **分析阶段** - 各 Agent 基于工具和历史经验独立分析
|
||||||
|
2. **沟通阶段** - 通过私聊、通知、会议等方式交换观点(1v1/1vN/NvN)
|
||||||
|
3. **决策阶段** - 投资经理综合判断,给出最终交易
|
||||||
|
4. **评估阶段** - 绩效跟踪
|
||||||
|
5. **复盘阶段** - Agent 根据当日实际收益反思总结,通过 ReMe 记忆框架更新经验
|
||||||
|
|
||||||
|
## 前端状态管理
|
||||||
|
|
||||||
|
项目正在向 Zustand 状态管理过渡,已创建的 store:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
frontend/src/store/
|
||||||
|
├── index.js # 导出所有 store
|
||||||
|
├── runtimeStore.js # 连接状态、运行时配置
|
||||||
|
├── marketStore.js # 市场数据、股票价格
|
||||||
|
├── portfolioStore.js # 组合、持仓、交易
|
||||||
|
├── agentStore.js # Agent 技能、工作区
|
||||||
|
└── uiStore.js # UI 状态、视图切换
|
||||||
|
```
|
||||||
|
|
||||||
|
**迁移状态**:
|
||||||
|
- Stores 已创建但尚未在 App.jsx 中使用
|
||||||
|
- 计划:逐步迁移 60+ 个 useState 到对应 store
|
||||||
|
|
||||||
|
## 环境配置
|
||||||
|
|
||||||
|
`.env` 必需配置:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 金融数据源
|
||||||
|
FIN_DATA_SOURCE=finnhub|financial_datasets
|
||||||
|
FINANCIAL_DATASETS_API_KEY= # 回测必需
|
||||||
|
FINNHUB_API_KEY= # 实盘必需
|
||||||
|
|
||||||
|
# Agent LLM
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
OPENAI_BASE_URL=
|
||||||
|
MODEL_NAME=qwen3-max-preview
|
||||||
|
|
||||||
|
# 可为不同 Agent 指定不同模型
|
||||||
|
AGENT_SENTIMENT_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||||
|
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=deepseek-chat
|
||||||
|
|
||||||
|
# ReMe 记忆系统
|
||||||
|
MEMORY_API_KEY=
|
||||||
|
```
|
||||||
|
|
||||||
|
## 关键依赖
|
||||||
|
|
||||||
|
- **AgentScope** - 多智能体框架
|
||||||
|
- **ReMe** - 持续学习记忆系统
|
||||||
|
- **FastAPI** + **uvicorn** - 后端 API 服务器
|
||||||
|
- **websockets** - 实时通信
|
||||||
|
- **React 19** + **Vite** + **TailwindCSS** - 前端
|
||||||
|
- **React Context** - 前端状态管理(App.jsx 中使用 useState + useCallback)
|
||||||
|
- **Three.js** / **React-Three-Fiber** - 3D 可视化
|
||||||
51
README_zh.md
51
README_zh.md
@@ -96,8 +96,11 @@ evotraders live # 立即运行(默认)
|
|||||||
evotraders live --enable-memory # 使用记忆
|
evotraders live --enable-memory # 使用记忆
|
||||||
evotraders live --mock # Mock 模式(测试)
|
evotraders live --mock # Mock 模式(测试)
|
||||||
evotraders live -t 22:30 # 每天本地时间 22:30 运行(自动转换为 NYSE 时区)
|
evotraders live -t 22:30 # 每天本地时间 22:30 运行(自动转换为 NYSE 时区)
|
||||||
|
evotraders live --schedule-mode intraday --interval-minutes 60 # 每隔 1 小时触发一次;仅交易时段执行交易,其他时段只分析
|
||||||
```
|
```
|
||||||
|
|
||||||
|
前端的“运行设置”面板也支持热更新 `schedule_mode`、`interval_minutes`、`max_comm_cycles`;其中 daily 模式时间当前按 NYSE/ET 配置。
|
||||||
|
|
||||||
**获取帮助:**
|
**获取帮助:**
|
||||||
```bash
|
```bash
|
||||||
evotraders --help # 查看整体命令行帮助
|
evotraders --help # 查看整体命令行帮助
|
||||||
@@ -114,6 +117,54 @@ evotraders frontend # 默认连接 8765 端口, 你可以修改 .
|
|||||||
|
|
||||||
访问 `http://localhost:5173/` 查看交易大厅,选择日期并点击 Run/Replay 观察决策过程。
|
访问 `http://localhost:5173/` 查看交易大厅,选择日期并点击 Run/Replay 观察决策过程。
|
||||||
|
|
||||||
|
### 迁移期服务边界说明
|
||||||
|
|
||||||
|
当前仓库正处于从模块化单体向独立服务迁移的阶段,当前默认开发路径已经切到独立 app surface:
|
||||||
|
|
||||||
|
- `backend.apps.agent_service`
|
||||||
|
- `backend.apps.runtime_service`
|
||||||
|
- `backend.apps.trading_service`
|
||||||
|
- `backend.apps.news_service`
|
||||||
|
|
||||||
|
当前本地开发默认推荐直接运行拆分后的服务:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./start-dev.sh split
|
||||||
|
|
||||||
|
# 或分别手动启动
|
||||||
|
python -m uvicorn backend.apps.agent_service:app --port 8000 --reload
|
||||||
|
python -m uvicorn backend.apps.runtime_service:app --port 8003 --reload
|
||||||
|
python -m uvicorn backend.apps.trading_service:app --port 8001 --reload
|
||||||
|
python -m uvicorn backend.apps.news_service:app --port 8002 --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
迁移期关键环境变量:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 后端 Gateway 优先走独立服务读取
|
||||||
|
NEWS_SERVICE_URL=http://localhost:8002
|
||||||
|
TRADING_SERVICE_URL=http://localhost:8001
|
||||||
|
|
||||||
|
# 前端浏览器直连控制面 / 运行时面
|
||||||
|
VITE_CONTROL_API_BASE_URL=http://localhost:8000/api
|
||||||
|
VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime
|
||||||
|
|
||||||
|
# 前端浏览器优先直连独立服务
|
||||||
|
VITE_NEWS_SERVICE_URL=http://localhost:8002
|
||||||
|
VITE_TRADING_SERVICE_URL=http://localhost:8001
|
||||||
|
```
|
||||||
|
|
||||||
|
目前前端已支持直连 `news-service` 的 explain 只读路径包括:
|
||||||
|
|
||||||
|
- runtime panel / gateway port 查询已可独立指向 `runtime-service`
|
||||||
|
- story
|
||||||
|
- similar days
|
||||||
|
- range explain
|
||||||
|
- news for date
|
||||||
|
- news categories
|
||||||
|
|
||||||
|
如果没有配置这些变量,系统会继续走当前保留的本地回退逻辑。
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 系统架构
|
## 系统架构
|
||||||
|
|||||||
@@ -1,6 +1,57 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Agents package - EvoAgent architecture for trading system.
|
||||||
|
|
||||||
|
Exports:
|
||||||
|
- EvoAgent: Next-generation agent with workspace support
|
||||||
|
- ToolGuardMixin: Tool call approval/denial flow
|
||||||
|
- CommandHandler: System command handling
|
||||||
|
- AgentFactory: Dynamic agent creation and management
|
||||||
|
- WorkspaceManager: Legacy name for the persistent workspace registry
|
||||||
|
- WorkspaceRegistry: Explicit run-time-agnostic workspace registry
|
||||||
|
- RunWorkspaceManager: Run-scoped workspace asset manager
|
||||||
|
- AgentRegistry: Central agent registry
|
||||||
|
- Legacy compatibility: AnalystAgent, PMAgent, RiskAgent
|
||||||
|
"""
|
||||||
|
|
||||||
|
# New EvoAgent architecture (from agent_core.py)
|
||||||
|
from .agent_core import EvoAgent, ToolGuardMixin, CommandHandler
|
||||||
|
from .factory import AgentFactory, ModelConfig, RoleConfig
|
||||||
|
from .workspace import WorkspaceManager, WorkspaceRegistry, WorkspaceConfig
|
||||||
|
from .workspace_manager import RunWorkspaceManager
|
||||||
|
from .registry import AgentRegistry, AgentInfo, get_registry, reset_registry
|
||||||
|
|
||||||
|
# Legacy agents (backward compatibility)
|
||||||
from .analyst import AnalystAgent
|
from .analyst import AnalystAgent
|
||||||
from .portfolio_manager import PMAgent
|
from .portfolio_manager import PMAgent
|
||||||
from .risk_manager import RiskAgent
|
from .risk_manager import RiskAgent
|
||||||
|
|
||||||
__all__ = ["AnalystAgent", "PMAgent", "RiskAgent"]
|
# Compatibility layer
|
||||||
|
from .compat import LegacyAgentAdapter, adapt_agent, adapt_agents, is_legacy_agent
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# New architecture
|
||||||
|
"EvoAgent",
|
||||||
|
"ToolGuardMixin",
|
||||||
|
"CommandHandler",
|
||||||
|
"AgentFactory",
|
||||||
|
"ModelConfig",
|
||||||
|
"RoleConfig",
|
||||||
|
"WorkspaceManager",
|
||||||
|
"WorkspaceRegistry",
|
||||||
|
"WorkspaceConfig",
|
||||||
|
"RunWorkspaceManager",
|
||||||
|
"AgentRegistry",
|
||||||
|
"AgentInfo",
|
||||||
|
"get_registry",
|
||||||
|
"reset_registry",
|
||||||
|
# Legacy compatibility
|
||||||
|
"AnalystAgent",
|
||||||
|
"PMAgent",
|
||||||
|
"RiskAgent",
|
||||||
|
# Compatibility layer
|
||||||
|
"LegacyAgentAdapter",
|
||||||
|
"adapt_agent",
|
||||||
|
"adapt_agents",
|
||||||
|
"is_legacy_agent",
|
||||||
|
]
|
||||||
|
|||||||
18
backend/agents/agent_core.py
Normal file
18
backend/agents/agent_core.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Compatibility layer for legacy imports.
|
||||||
|
|
||||||
|
This module re-exports the newer base implementations so existing import
|
||||||
|
paths (`from backend.agents.agent_core import EvoAgent`) continue to work while
|
||||||
|
centralizing the actual logic in `backend.agents.base.evo_agent`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .base.command_handler import CommandHandler
|
||||||
|
from .base.evo_agent import EvoAgent
|
||||||
|
from .base.tool_guard import ToolGuardMixin
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"EvoAgent",
|
||||||
|
"ToolGuardMixin",
|
||||||
|
"CommandHandler",
|
||||||
|
]
|
||||||
75
backend/agents/agent_workspace.py
Normal file
75
backend/agents/agent_workspace.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Per-agent run-scoped workspace configuration helpers."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AgentWorkspaceConfig:
|
||||||
|
"""Structured agent config loaded from runs/<config>/agents/<agent>/agent.yaml."""
|
||||||
|
|
||||||
|
values: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prompt_files(self) -> Optional[List[str]]:
|
||||||
|
raw = self.values.get("prompt_files")
|
||||||
|
if not isinstance(raw, list):
|
||||||
|
return None
|
||||||
|
files = [
|
||||||
|
str(item).strip()
|
||||||
|
for item in raw
|
||||||
|
if isinstance(item, str) and str(item).strip()
|
||||||
|
]
|
||||||
|
return files or None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enabled_skills(self) -> List[str]:
|
||||||
|
return _normalized_string_list(self.values.get("enabled_skills"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def disabled_skills(self) -> List[str]:
|
||||||
|
return _normalized_string_list(self.values.get("disabled_skills"))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def active_tool_groups(self) -> Optional[List[str]]:
|
||||||
|
groups = _normalized_string_list(self.values.get("active_tool_groups"))
|
||||||
|
return groups or None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def disabled_tool_groups(self) -> List[str]:
|
||||||
|
return _normalized_string_list(self.values.get("disabled_tool_groups"))
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
return self.values.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalized_string_list(raw: Any) -> List[str]:
|
||||||
|
if not isinstance(raw, list):
|
||||||
|
return []
|
||||||
|
seen: List[str] = []
|
||||||
|
for item in raw:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
continue
|
||||||
|
value = item.strip()
|
||||||
|
if value and value not in seen:
|
||||||
|
seen.append(value)
|
||||||
|
return seen
|
||||||
|
|
||||||
|
|
||||||
|
def load_agent_workspace_config(path: Path) -> AgentWorkspaceConfig:
|
||||||
|
"""Load agent.yaml if present."""
|
||||||
|
if not path.exists() or not path.is_file():
|
||||||
|
return AgentWorkspaceConfig()
|
||||||
|
|
||||||
|
raw = path.read_text(encoding="utf-8").strip()
|
||||||
|
if not raw:
|
||||||
|
return AgentWorkspaceConfig()
|
||||||
|
|
||||||
|
parsed = yaml.safe_load(raw) or {}
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
parsed = {}
|
||||||
|
return AgentWorkspaceConfig(values=parsed)
|
||||||
@@ -48,15 +48,19 @@ class AnalystAgent(ReActAgent):
|
|||||||
f"Must be one of: {list(ANALYST_TYPES.keys())}",
|
f"Must be one of: {list(ANALYST_TYPES.keys())}",
|
||||||
)
|
)
|
||||||
|
|
||||||
self.analyst_type_key = analyst_type
|
object.__setattr__(self, "analyst_type_key", analyst_type)
|
||||||
self.analyst_persona = ANALYST_TYPES[analyst_type]["display_name"]
|
object.__setattr__(
|
||||||
|
self,
|
||||||
|
"analyst_persona",
|
||||||
|
ANALYST_TYPES[analyst_type]["display_name"],
|
||||||
|
)
|
||||||
|
|
||||||
if agent_id is None:
|
if agent_id is None:
|
||||||
agent_id = analyst_type
|
agent_id = analyst_type
|
||||||
self.agent_id = agent_id
|
object.__setattr__(self, "agent_id", agent_id)
|
||||||
|
|
||||||
self.config = config or {}
|
object.__setattr__(self, "config", config or {})
|
||||||
self.toolkit = toolkit
|
object.__setattr__(self, "toolkit", toolkit)
|
||||||
sys_prompt = self._load_system_prompt()
|
sys_prompt = self._load_system_prompt()
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@@ -125,4 +129,12 @@ class AnalystAgent(ReActAgent):
|
|||||||
self.config.get("config_name", "default"),
|
self.config.get("config_name", "default"),
|
||||||
active_skill_dirs=active_skill_dirs,
|
active_skill_dirs=active_skill_dirs,
|
||||||
)
|
)
|
||||||
self.sys_prompt = self._load_system_prompt()
|
self._apply_runtime_sys_prompt(self._load_system_prompt())
|
||||||
|
|
||||||
|
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
|
||||||
|
"""Update the prompt used by future turns and the cached system msg."""
|
||||||
|
self._sys_prompt = sys_prompt
|
||||||
|
for msg, _marks in self.memory.content:
|
||||||
|
if getattr(msg, "role", None) == "system":
|
||||||
|
msg.content = sys_prompt
|
||||||
|
break
|
||||||
|
|||||||
57
backend/agents/base/__init__.py
Normal file
57
backend/agents/base/__init__.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Base agent module for EvoTraders.
|
||||||
|
|
||||||
|
提供Agent基础类、命令处理、工具守卫和钩子管理等功能。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 命令处理器 (从command_handler.py导入)
|
||||||
|
from .command_handler import (
|
||||||
|
AgentCommandDispatcher,
|
||||||
|
CommandContext,
|
||||||
|
CommandHandler,
|
||||||
|
CommandResult,
|
||||||
|
create_command_dispatcher,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 评估钩子 (从evaluation_hook.py导入)
|
||||||
|
from .evaluation_hook import (
|
||||||
|
EvaluationHook,
|
||||||
|
EvaluationCollector,
|
||||||
|
MetricType,
|
||||||
|
EvaluationMetric,
|
||||||
|
EvaluationResult,
|
||||||
|
parse_evaluation_hooks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 技能适配钩子 (从skill_adaptation_hook.py导入)
|
||||||
|
from .skill_adaptation_hook import (
|
||||||
|
AdaptationAction,
|
||||||
|
AdaptationThreshold,
|
||||||
|
AdaptationEvent,
|
||||||
|
SkillAdaptationHook,
|
||||||
|
AdaptationManager,
|
||||||
|
get_adaptation_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# 命令处理
|
||||||
|
"AgentCommandDispatcher",
|
||||||
|
"CommandContext",
|
||||||
|
"CommandHandler",
|
||||||
|
"CommandResult",
|
||||||
|
"create_command_dispatcher",
|
||||||
|
# 评估钩子
|
||||||
|
"EvaluationHook",
|
||||||
|
"EvaluationCollector",
|
||||||
|
"MetricType",
|
||||||
|
"EvaluationMetric",
|
||||||
|
"EvaluationResult",
|
||||||
|
"parse_evaluation_hooks",
|
||||||
|
# 技能适配钩子
|
||||||
|
"AdaptationAction",
|
||||||
|
"AdaptationThreshold",
|
||||||
|
"AdaptationEvent",
|
||||||
|
"SkillAdaptationHook",
|
||||||
|
"AdaptationManager",
|
||||||
|
"get_adaptation_manager",
|
||||||
|
]
|
||||||
543
backend/agents/base/command_handler.py
Normal file
543
backend/agents/base/command_handler.py
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Agent command handler for system commands.
|
||||||
|
|
||||||
|
This module handles system commands like /save, /compact, /skills, /reload, etc.
|
||||||
|
参考CoPaw设计,为EvoAgent提供命令处理能力。
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .agent import EvoAgent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandResult:
|
||||||
|
"""命令执行结果"""
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
data: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandContext:
|
||||||
|
"""命令执行上下文"""
|
||||||
|
|
||||||
|
def __init__(self, agent: "EvoAgent", raw_query: str, args: str = ""):
|
||||||
|
self.agent = agent
|
||||||
|
self.raw_query = raw_query
|
||||||
|
self.args = args
|
||||||
|
self.config_name = getattr(agent, "config_name", "default")
|
||||||
|
self.agent_id = getattr(agent, "agent_id", "unknown")
|
||||||
|
|
||||||
|
|
||||||
|
class CommandHandler(ABC):
|
||||||
|
"""命令处理器抽象基类"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
"""处理命令"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SaveCommandHandler(CommandHandler):
|
||||||
|
"""处理 /save <message> 命令 - 保存内容到MEMORY.md"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
message = ctx.args.strip()
|
||||||
|
if not message:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message="Usage: /save <message>\n请提供要保存的内容。"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
memory_path = self._get_memory_path(ctx)
|
||||||
|
memory_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
timestamp = self._get_timestamp()
|
||||||
|
entry = f"\n## {timestamp}\n\n{message}\n"
|
||||||
|
|
||||||
|
with open(memory_path, "a", encoding="utf-8") as f:
|
||||||
|
f.write(entry)
|
||||||
|
|
||||||
|
return CommandResult(
|
||||||
|
success=True,
|
||||||
|
message=f"✅ 内容已保存到 MEMORY.md\n- 路径: {memory_path}\n- 长度: {len(message)} 字符",
|
||||||
|
data={"path": str(memory_path), "length": len(message)}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to save to MEMORY.md: {e}")
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 保存失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_memory_path(self, ctx: CommandContext) -> Path:
|
||||||
|
"""获取MEMORY.md路径"""
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
sm = SkillsManager()
|
||||||
|
asset_dir = sm.get_agent_asset_dir(ctx.config_name, ctx.agent_id)
|
||||||
|
return asset_dir / "MEMORY.md"
|
||||||
|
|
||||||
|
def _get_timestamp(self) -> str:
|
||||||
|
"""获取当前时间戳"""
|
||||||
|
from datetime import datetime
|
||||||
|
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
|
||||||
|
class CompactCommandHandler(CommandHandler):
|
||||||
|
"""处理 /compact 命令 - 压缩记忆"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
try:
|
||||||
|
agent = ctx.agent
|
||||||
|
memory_manager = getattr(agent, "memory_manager", None)
|
||||||
|
|
||||||
|
if memory_manager is None:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message="❌ Memory Manager 未启用\n\n- 记忆压缩功能不可用\n- 请在配置中启用 memory_manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = await self._get_messages(agent)
|
||||||
|
if not messages:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message="⚠️ 没有可压缩的消息\n\n- 当前记忆为空\n- 无需执行压缩"
|
||||||
|
)
|
||||||
|
|
||||||
|
compact_content = await memory_manager.compact_memory(messages)
|
||||||
|
await self._update_compressed_summary(agent, compact_content)
|
||||||
|
|
||||||
|
return CommandResult(
|
||||||
|
success=True,
|
||||||
|
message=f"✅ 记忆压缩完成\n\n- 压缩了 {len(messages)} 条消息\n- 摘要长度: {len(compact_content)} 字符",
|
||||||
|
data={"message_count": len(messages), "summary_length": len(compact_content)}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to compact memory: {e}")
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 压缩失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_messages(self, agent: "EvoAgent") -> List[Any]:
|
||||||
|
"""获取Agent的记忆消息"""
|
||||||
|
memory = getattr(agent, "memory", None)
|
||||||
|
if memory is None:
|
||||||
|
return []
|
||||||
|
return await memory.get_memory() if hasattr(memory, "get_memory") else []
|
||||||
|
|
||||||
|
async def _update_compressed_summary(self, agent: "EvoAgent", content: str) -> None:
|
||||||
|
"""更新压缩摘要"""
|
||||||
|
memory = getattr(agent, "memory", None)
|
||||||
|
if memory and hasattr(memory, "update_compressed_summary"):
|
||||||
|
await memory.update_compressed_summary(content)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillsListCommandHandler(CommandHandler):
|
||||||
|
"""处理 /skills list 命令 - 列出已激活技能"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
try:
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
sm = SkillsManager()
|
||||||
|
|
||||||
|
active_skills = sm.list_active_skill_metadata(ctx.config_name, ctx.agent_id)
|
||||||
|
catalog = sm.list_agent_skill_catalog(ctx.config_name, ctx.agent_id)
|
||||||
|
|
||||||
|
lines = ["📋 技能列表", ""]
|
||||||
|
|
||||||
|
if active_skills:
|
||||||
|
lines.append("✅ 已激活技能:")
|
||||||
|
for skill in active_skills:
|
||||||
|
lines.append(f" • {skill.name} - {skill.description[:50]}...")
|
||||||
|
else:
|
||||||
|
lines.append("⚠️ 当前没有激活的技能")
|
||||||
|
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"📚 可用技能总数: {len(catalog)}")
|
||||||
|
lines.append("💡 使用 /skills enable <name> 启用技能")
|
||||||
|
|
||||||
|
return CommandResult(
|
||||||
|
success=True,
|
||||||
|
message="\n".join(lines),
|
||||||
|
data={
|
||||||
|
"active_count": len(active_skills),
|
||||||
|
"catalog_count": len(catalog),
|
||||||
|
"active": [s.skill_name for s in active_skills]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to list skills: {e}")
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 获取技能列表失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillsEnableCommandHandler(CommandHandler):
|
||||||
|
"""处理 /skills enable <name> 命令 - 启用技能"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
skill_name = ctx.args.strip()
|
||||||
|
if not skill_name:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message="Usage: /skills enable <skill_name>\n请提供技能名称。"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
sm = SkillsManager()
|
||||||
|
|
||||||
|
result = sm.update_agent_skill_overrides(
|
||||||
|
ctx.config_name,
|
||||||
|
ctx.agent_id,
|
||||||
|
enable=[skill_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
return CommandResult(
|
||||||
|
success=True,
|
||||||
|
message=f"✅ 技能已启用: {skill_name}\n\n已启用技能: {', '.join(result['enabled_skills'])}",
|
||||||
|
data=result
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to enable skill: {e}")
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 启用技能失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillsDisableCommandHandler(CommandHandler):
|
||||||
|
"""处理 /skills disable <name> 命令 - 禁用技能"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
skill_name = ctx.args.strip()
|
||||||
|
if not skill_name:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message="Usage: /skills disable <skill_name>\n请提供技能名称。"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
sm = SkillsManager()
|
||||||
|
|
||||||
|
result = sm.update_agent_skill_overrides(
|
||||||
|
ctx.config_name,
|
||||||
|
ctx.agent_id,
|
||||||
|
disable=[skill_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
return CommandResult(
|
||||||
|
success=True,
|
||||||
|
message=f"✅ 技能已禁用: {skill_name}\n\n已禁用技能: {', '.join(result['disabled_skills'])}",
|
||||||
|
data=result
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to disable skill: {e}")
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 禁用技能失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SkillsInstallCommandHandler(CommandHandler):
|
||||||
|
"""处理 /skills install <name> 命令 - 安装技能"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
skill_name = ctx.args.strip()
|
||||||
|
if not skill_name:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message="Usage: /skills install <skill_name>\n请提供技能名称。"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
from backend.agents.skill_loader import load_skill_from_dir
|
||||||
|
sm = SkillsManager()
|
||||||
|
|
||||||
|
# 查找技能源目录
|
||||||
|
source_dir = self._resolve_skill_source(sm, skill_name)
|
||||||
|
if not source_dir:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 技能未找到: {skill_name}\n\n请检查技能名称是否正确,或技能是否存在于 builtin/customized 目录。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 加载并验证技能
|
||||||
|
skill_info = load_skill_from_dir(source_dir)
|
||||||
|
if not skill_info:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 技能加载失败: {skill_name}\n\n技能格式可能不正确。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 安装到agent的installed目录
|
||||||
|
installed_root = sm.get_agent_installed_root(ctx.config_name, ctx.agent_id)
|
||||||
|
target_dir = installed_root / skill_name
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
if target_dir.exists():
|
||||||
|
shutil.rmtree(target_dir)
|
||||||
|
shutil.copytree(source_dir, target_dir)
|
||||||
|
|
||||||
|
return CommandResult(
|
||||||
|
success=True,
|
||||||
|
message=f"✅ 技能已安装: {skill_name}\n\n- 名称: {skill_info.get('name', skill_name)}\n- 版本: {skill_info.get('version', 'unknown')}\n- 路径: {target_dir}",
|
||||||
|
data={"skill_name": skill_name, "target_dir": str(target_dir)}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to install skill: {e}")
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 安装技能失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _resolve_skill_source(self, sm: "SkillsManager", skill_name: str) -> Optional[Path]:
|
||||||
|
"""解析技能源目录"""
|
||||||
|
for root in [sm.customized_root, sm.builtin_root]:
|
||||||
|
candidate = root / skill_name
|
||||||
|
if candidate.exists() and (candidate / "SKILL.md").exists():
|
||||||
|
return candidate
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ReloadCommandHandler(CommandHandler):
|
||||||
|
"""处理 /reload 命令 - 重新加载配置"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
try:
|
||||||
|
agent = ctx.agent
|
||||||
|
|
||||||
|
# 重新加载配置
|
||||||
|
if hasattr(agent, "reload_config"):
|
||||||
|
await agent.reload_config()
|
||||||
|
|
||||||
|
# 重新加载技能
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
sm = SkillsManager()
|
||||||
|
|
||||||
|
# 刷新技能同步
|
||||||
|
active_root = sm.get_agent_active_root(ctx.config_name, ctx.agent_id)
|
||||||
|
if active_root.exists():
|
||||||
|
# 清除缓存,强制重新加载
|
||||||
|
import shutil
|
||||||
|
for item in active_root.iterdir():
|
||||||
|
if item.is_dir():
|
||||||
|
shutil.rmtree(item)
|
||||||
|
|
||||||
|
return CommandResult(
|
||||||
|
success=True,
|
||||||
|
message="✅ 配置已重新加载\n\n- Agent配置已刷新\n- 技能缓存已清除\n- 请重启对话以应用所有更改",
|
||||||
|
data={"config_name": ctx.config_name, "agent_id": ctx.agent_id}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to reload config: {e}")
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 重新加载失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StatusCommandHandler(CommandHandler):
|
||||||
|
"""处理 /status 命令 - 显示Agent状态"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
try:
|
||||||
|
agent = ctx.agent
|
||||||
|
|
||||||
|
lines = ["📊 Agent 状态", ""]
|
||||||
|
lines.append(f"🆔 Agent ID: {ctx.agent_id}")
|
||||||
|
lines.append(f"⚙️ Config: {ctx.config_name}")
|
||||||
|
|
||||||
|
# 模型信息
|
||||||
|
model = getattr(agent, "model", None)
|
||||||
|
if model:
|
||||||
|
lines.append(f"🤖 Model: {model}")
|
||||||
|
|
||||||
|
# 记忆状态
|
||||||
|
memory = getattr(agent, "memory", None)
|
||||||
|
if memory:
|
||||||
|
msg_count = len(getattr(memory, "content", []))
|
||||||
|
lines.append(f"💾 Memory: {msg_count} messages")
|
||||||
|
|
||||||
|
# 技能状态
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
sm = SkillsManager()
|
||||||
|
active_skills = sm.list_active_skill_metadata(ctx.config_name, ctx.agent_id)
|
||||||
|
lines.append(f"🔧 Active Skills: {len(active_skills)}")
|
||||||
|
|
||||||
|
# 工具组状态
|
||||||
|
toolkit = getattr(agent, "toolkit", None)
|
||||||
|
if toolkit:
|
||||||
|
groups = getattr(toolkit, "tool_groups", {})
|
||||||
|
active_groups = [name for name, g in groups.items() if getattr(g, "active", False)]
|
||||||
|
lines.append(f"🛠️ Active Tool Groups: {', '.join(active_groups) if active_groups else 'None'}")
|
||||||
|
|
||||||
|
return CommandResult(
|
||||||
|
success=True,
|
||||||
|
message="\n".join(lines),
|
||||||
|
data={
|
||||||
|
"agent_id": ctx.agent_id,
|
||||||
|
"config_name": ctx.config_name,
|
||||||
|
"active_skills_count": len(active_skills)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to get status: {e}")
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"❌ 获取状态失败: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HelpCommandHandler(CommandHandler):
|
||||||
|
"""处理 /help 命令 - 显示帮助"""
|
||||||
|
|
||||||
|
async def handle(self, ctx: CommandContext) -> CommandResult:
|
||||||
|
help_text = """📖 EvoAgent 命令帮助
|
||||||
|
|
||||||
|
可用命令:
|
||||||
|
/save <message> - 保存内容到 MEMORY.md
|
||||||
|
/compact - 压缩记忆
|
||||||
|
/skills list - 列出已激活技能
|
||||||
|
/skills enable <name> - 启用技能
|
||||||
|
/skills disable <name>- 禁用技能
|
||||||
|
/skills install <name>- 安装技能
|
||||||
|
/reload - 重新加载配置
|
||||||
|
/status - 显示Agent状态
|
||||||
|
/help - 显示此帮助信息
|
||||||
|
|
||||||
|
提示:
|
||||||
|
• 所有命令以 / 开头
|
||||||
|
• 命令不区分大小写
|
||||||
|
• 使用 Tab 键可自动补全命令
|
||||||
|
"""
|
||||||
|
return CommandResult(success=True, message=help_text)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentCommandDispatcher:
|
||||||
|
"""Agent命令分发器
|
||||||
|
|
||||||
|
参考CoPaw的CommandHandler设计,为EvoAgent提供统一的命令处理入口。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 支持的系统命令
|
||||||
|
SYSTEM_COMMANDS = frozenset({
|
||||||
|
"save", "compact",
|
||||||
|
"skills", "reload",
|
||||||
|
"status", "help"
|
||||||
|
})
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._handlers: Dict[str, CommandHandler] = {}
|
||||||
|
self._subcommands: Dict[str, Dict[str, CommandHandler]] = {}
|
||||||
|
self._register_default_handlers()
|
||||||
|
|
||||||
|
def _register_default_handlers(self) -> None:
|
||||||
|
"""注册默认命令处理器"""
|
||||||
|
self._handlers["save"] = SaveCommandHandler()
|
||||||
|
self._handlers["compact"] = CompactCommandHandler()
|
||||||
|
self._handlers["reload"] = ReloadCommandHandler()
|
||||||
|
self._handlers["status"] = StatusCommandHandler()
|
||||||
|
self._handlers["help"] = HelpCommandHandler()
|
||||||
|
|
||||||
|
# 子命令: /skills list/enable/disable/install
|
||||||
|
self._subcommands["skills"] = {
|
||||||
|
"list": SkillsListCommandHandler(),
|
||||||
|
"enable": SkillsEnableCommandHandler(),
|
||||||
|
"disable": SkillsDisableCommandHandler(),
|
||||||
|
"install": SkillsInstallCommandHandler(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def is_command(self, query: str | None) -> bool:
|
||||||
|
"""检查是否为命令
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户输入字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True 如果是系统命令
|
||||||
|
"""
|
||||||
|
if not isinstance(query, str) or not query.startswith("/"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
parts = query.strip().lstrip("/").split()
|
||||||
|
if not parts:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cmd = parts[0].lower()
|
||||||
|
|
||||||
|
# 检查主命令
|
||||||
|
if cmd in self.SYSTEM_COMMANDS:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def handle(self, agent: "EvoAgent", query: str) -> CommandResult:
|
||||||
|
"""处理命令
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: EvoAgent实例
|
||||||
|
query: 命令字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
命令执行结果
|
||||||
|
"""
|
||||||
|
if not self.is_command(query):
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"未知命令: {query}\n使用 /help 查看可用命令。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析命令和参数
|
||||||
|
parts = query.strip().lstrip("/").split(maxsplit=1)
|
||||||
|
cmd = parts[0].lower()
|
||||||
|
args = parts[1] if len(parts) > 1 else ""
|
||||||
|
|
||||||
|
logger.info(f"Processing command: {cmd}, args: {args}")
|
||||||
|
|
||||||
|
# 处理子命令 (e.g., /skills list)
|
||||||
|
if cmd in self._subcommands:
|
||||||
|
sub_parts = args.split(maxsplit=1)
|
||||||
|
sub_cmd = sub_parts[0].lower() if sub_parts else ""
|
||||||
|
sub_args = sub_parts[1] if len(sub_parts) > 1 else ""
|
||||||
|
|
||||||
|
handlers = self._subcommands[cmd]
|
||||||
|
handler = handlers.get(sub_cmd)
|
||||||
|
|
||||||
|
if handler is None:
|
||||||
|
available = ", ".join(handlers.keys())
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"未知子命令: {sub_cmd}\n可用子命令: {available}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = CommandContext(agent, query, sub_args)
|
||||||
|
return await handler.handle(ctx)
|
||||||
|
|
||||||
|
# 处理主命令
|
||||||
|
handler = self._handlers.get(cmd)
|
||||||
|
if handler is None:
|
||||||
|
return CommandResult(
|
||||||
|
success=False,
|
||||||
|
message=f"命令未实现: {cmd}"
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = CommandContext(agent, query, args)
|
||||||
|
return await handler.handle(ctx)
|
||||||
|
|
||||||
|
|
||||||
|
# 便捷函数
|
||||||
|
def create_command_dispatcher() -> AgentCommandDispatcher:
|
||||||
|
"""创建命令分发器实例"""
|
||||||
|
return AgentCommandDispatcher()
|
||||||
452
backend/agents/base/evaluation_hook.py
Normal file
452
backend/agents/base/evaluation_hook.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Evaluation hooks system for skills.
|
||||||
|
|
||||||
|
Provides evaluation metric collection and storage for skill performance tracking.
|
||||||
|
Based on the evaluation hooks design in SKILL_TEMPLATE.md.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field, asdict
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MetricType(Enum):
|
||||||
|
"""Types of evaluation metrics."""
|
||||||
|
HIT_RATE = "hit_rate" # 信号命中率
|
||||||
|
RISK_VIOLATION = "risk_violation" # 风控违例率
|
||||||
|
POSITION_DEVIATION = "position_deviation" # 仓位偏离率
|
||||||
|
PnL_ATTRIBUTION = "pnl_attribution" # P&L 归因一致性
|
||||||
|
SIGNAL_CONSISTENCY = "signal_consistency" # 信号一致性
|
||||||
|
DECISION_LATENCY = "decision_latency" # 决策延迟
|
||||||
|
TOOL_USAGE = "tool_usage" # 工具使用率
|
||||||
|
CUSTOM = "custom" # 自定义指标
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvaluationMetric:
|
||||||
|
"""A single evaluation metric."""
|
||||||
|
name: str
|
||||||
|
metric_type: MetricType
|
||||||
|
value: float
|
||||||
|
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"metric_type": self.metric_type.value,
|
||||||
|
"value": self.value,
|
||||||
|
"timestamp": self.timestamp,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvaluationResult:
|
||||||
|
"""Evaluation result for a skill execution."""
|
||||||
|
skill_name: str
|
||||||
|
run_id: str
|
||||||
|
agent_id: str
|
||||||
|
metrics: List[EvaluationMetric] = field(default_factory=list)
|
||||||
|
inputs: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
outputs: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
decision: Optional[str] = None
|
||||||
|
success: bool = True
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
started_at: Optional[str] = None
|
||||||
|
completed_at: Optional[str] = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"skill_name": self.skill_name,
|
||||||
|
"run_id": self.run_id,
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
"metrics": [m.to_dict() for m in self.metrics],
|
||||||
|
"inputs": self.inputs,
|
||||||
|
"outputs": self.outputs,
|
||||||
|
"decision": self.decision,
|
||||||
|
"success": self.success,
|
||||||
|
"error_message": self.error_message,
|
||||||
|
"started_at": self.started_at,
|
||||||
|
"completed_at": self.completed_at,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationHook:
|
||||||
|
"""Hook for collecting skill evaluation metrics.
|
||||||
|
|
||||||
|
This hook collects and stores evaluation metrics after skill execution
|
||||||
|
for later analysis and memory/reflection stages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_dir: Path,
|
||||||
|
run_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
):
|
||||||
|
"""Initialize evaluation hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_dir: Directory to store evaluation results
|
||||||
|
run_id: Current run identifier
|
||||||
|
agent_id: Current agent identifier
|
||||||
|
"""
|
||||||
|
self.storage_dir = Path(storage_dir)
|
||||||
|
self.run_id = run_id
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self._current_evaluation: Optional[EvaluationResult] = None
|
||||||
|
|
||||||
|
def start_evaluation(
|
||||||
|
self,
|
||||||
|
skill_name: str,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
"""Start a new evaluation session.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Name of the skill being evaluated
|
||||||
|
inputs: Input parameters for the skill
|
||||||
|
"""
|
||||||
|
self._current_evaluation = EvaluationResult(
|
||||||
|
skill_name=skill_name,
|
||||||
|
run_id=self.run_id,
|
||||||
|
agent_id=self.agent_id,
|
||||||
|
inputs=inputs,
|
||||||
|
started_at=datetime.now().isoformat(),
|
||||||
|
)
|
||||||
|
logger.debug(f"Started evaluation for skill: {skill_name}")
|
||||||
|
|
||||||
|
def add_metric(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
metric_type: MetricType,
|
||||||
|
value: float,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Add an evaluation metric.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Metric name
|
||||||
|
metric_type: Type of metric
|
||||||
|
value: Metric value
|
||||||
|
metadata: Additional metadata
|
||||||
|
"""
|
||||||
|
if self._current_evaluation is None:
|
||||||
|
logger.warning("No active evaluation session, ignoring metric")
|
||||||
|
return
|
||||||
|
|
||||||
|
metric = EvaluationMetric(
|
||||||
|
name=name,
|
||||||
|
metric_type=metric_type,
|
||||||
|
value=value,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
self._current_evaluation.metrics.append(metric)
|
||||||
|
logger.debug(f"Added metric: {name} = {value}")
|
||||||
|
|
||||||
|
def add_metrics(self, metrics: List[EvaluationMetric]) -> None:
|
||||||
|
"""Add multiple evaluation metrics at once.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
metrics: List of metrics to add
|
||||||
|
"""
|
||||||
|
if self._current_evaluation is None:
|
||||||
|
logger.warning("No active evaluation session, ignoring metrics")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._current_evaluation.metrics.extend(metrics)
|
||||||
|
|
||||||
|
def record_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||||
|
"""Record skill outputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
outputs: Output from skill execution
|
||||||
|
"""
|
||||||
|
if self._current_evaluation is None:
|
||||||
|
logger.warning("No active evaluation session, ignoring outputs")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._current_evaluation.outputs = outputs
|
||||||
|
|
||||||
|
def record_decision(self, decision: str) -> None:
|
||||||
|
"""Record the final decision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
decision: Final decision made by the skill
|
||||||
|
"""
|
||||||
|
if self._current_evaluation is None:
|
||||||
|
logger.warning("No active evaluation session, ignoring decision")
|
||||||
|
return
|
||||||
|
|
||||||
|
self._current_evaluation.decision = decision
|
||||||
|
|
||||||
|
def complete_evaluation(
|
||||||
|
self,
|
||||||
|
success: bool = True,
|
||||||
|
error_message: Optional[str] = None,
|
||||||
|
) -> Optional[EvaluationResult]:
|
||||||
|
"""Complete the evaluation session and persist results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
success: Whether the skill execution was successful
|
||||||
|
error_message: Error message if failed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The completed evaluation result, or None if no active evaluation
|
||||||
|
"""
|
||||||
|
if self._current_evaluation is None:
|
||||||
|
logger.warning("No active evaluation to complete")
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._current_evaluation.success = success
|
||||||
|
self._current_evaluation.error_message = error_message
|
||||||
|
self._current_evaluation.completed_at = datetime.now().isoformat()
|
||||||
|
|
||||||
|
# Persist to storage
|
||||||
|
result = self._persist_evaluation(self._current_evaluation)
|
||||||
|
|
||||||
|
self._current_evaluation = None
|
||||||
|
logger.debug(f"Completed evaluation for skill: {result.skill_name}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _persist_evaluation(self, evaluation: EvaluationResult) -> EvaluationResult:
|
||||||
|
"""Persist evaluation result to storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
evaluation: Evaluation result to persist
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The persisted evaluation
|
||||||
|
"""
|
||||||
|
# Create run-specific directory
|
||||||
|
run_dir = self.storage_dir / self.run_id
|
||||||
|
run_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create agent-specific subdirectory
|
||||||
|
agent_dir = run_dir / self.agent_id
|
||||||
|
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Generate filename with timestamp
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
filename = f"{evaluation.skill_name}_{timestamp}.json"
|
||||||
|
filepath = agent_dir / filename
|
||||||
|
|
||||||
|
# Write evaluation result
|
||||||
|
try:
|
||||||
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(evaluation.to_dict(), f, ensure_ascii=False, indent=2)
|
||||||
|
logger.info(f"Persisted evaluation to: {filepath}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to persist evaluation: {e}")
|
||||||
|
|
||||||
|
return evaluation
|
||||||
|
|
||||||
|
def cancel_evaluation(self) -> None:
|
||||||
|
"""Cancel the current evaluation session without saving."""
|
||||||
|
if self._current_evaluation is not None:
|
||||||
|
logger.debug(f"Cancelled evaluation for: {self._current_evaluation.skill_name}")
|
||||||
|
self._current_evaluation = None
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluationCollector:
|
||||||
|
"""Collector for aggregating evaluation metrics across runs.
|
||||||
|
|
||||||
|
Provides methods to query and analyze evaluation results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, storage_dir: Path):
|
||||||
|
"""Initialize evaluation collector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_dir: Root directory containing evaluation results
|
||||||
|
"""
|
||||||
|
self.storage_dir = Path(storage_dir)
|
||||||
|
|
||||||
|
def get_run_evaluations(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
) -> List[EvaluationResult]:
|
||||||
|
"""Get all evaluations for a run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Run identifier
|
||||||
|
agent_id: Optional agent identifier to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of evaluation results
|
||||||
|
"""
|
||||||
|
run_dir = self.storage_dir / run_id
|
||||||
|
if not run_dir.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
evaluations = []
|
||||||
|
|
||||||
|
agent_dirs = [run_dir / agent_id] if agent_id else run_dir.iterdir()
|
||||||
|
|
||||||
|
for agent_dir in agent_dirs:
|
||||||
|
if not agent_dir.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
for eval_file in agent_dir.glob("*.json"):
|
||||||
|
try:
|
||||||
|
with open(eval_file, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
evaluations.append(self._parse_evaluation(data))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load evaluation {eval_file}: {e}")
|
||||||
|
|
||||||
|
return evaluations
|
||||||
|
|
||||||
|
def get_skill_metrics(
|
||||||
|
self,
|
||||||
|
skill_name: str,
|
||||||
|
run_ids: Optional[List[str]] = None,
|
||||||
|
) -> List[EvaluationMetric]:
|
||||||
|
"""Get all metrics for a specific skill.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Name of the skill
|
||||||
|
run_ids: Optional list of run IDs to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of metrics for the skill
|
||||||
|
"""
|
||||||
|
metrics = []
|
||||||
|
|
||||||
|
if run_ids is None:
|
||||||
|
run_ids = [d.name for d in self.storage_dir.iterdir() if d.is_dir()]
|
||||||
|
|
||||||
|
for run_id in run_ids:
|
||||||
|
evaluations = self.get_run_evaluations(run_id)
|
||||||
|
for eval_result in evaluations:
|
||||||
|
if eval_result.skill_name == skill_name:
|
||||||
|
metrics.extend(eval_result.metrics)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def calculate_skill_stats(
|
||||||
|
self,
|
||||||
|
skill_name: str,
|
||||||
|
metric_type: MetricType,
|
||||||
|
run_ids: Optional[List[str]] = None,
|
||||||
|
) -> Dict[str, float]:
|
||||||
|
"""Calculate statistics for a specific metric type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Name of the skill
|
||||||
|
metric_type: Type of metric to calculate
|
||||||
|
run_ids: Optional list of run IDs to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with min, max, avg, count statistics
|
||||||
|
"""
|
||||||
|
metrics = self.get_skill_metrics(skill_name, run_ids)
|
||||||
|
filtered = [m for m in metrics if m.metric_type == metric_type]
|
||||||
|
|
||||||
|
if not filtered:
|
||||||
|
return {"count": 0}
|
||||||
|
|
||||||
|
values = [m.value for m in filtered]
|
||||||
|
return {
|
||||||
|
"count": len(values),
|
||||||
|
"min": min(values),
|
||||||
|
"max": max(values),
|
||||||
|
"avg": sum(values) / len(values),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _parse_evaluation(self, data: Dict[str, Any]) -> EvaluationResult:
|
||||||
|
"""Parse evaluation data into EvaluationResult.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Raw evaluation data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed EvaluationResult
|
||||||
|
"""
|
||||||
|
metrics = []
|
||||||
|
for m in data.get("metrics", []):
|
||||||
|
metrics.append(EvaluationMetric(
|
||||||
|
name=m["name"],
|
||||||
|
metric_type=MetricType(m["metric_type"]),
|
||||||
|
value=m["value"],
|
||||||
|
timestamp=m.get("timestamp", ""),
|
||||||
|
metadata=m.get("metadata", {}),
|
||||||
|
))
|
||||||
|
|
||||||
|
return EvaluationResult(
|
||||||
|
skill_name=data["skill_name"],
|
||||||
|
run_id=data["run_id"],
|
||||||
|
agent_id=data["agent_id"],
|
||||||
|
metrics=metrics,
|
||||||
|
inputs=data.get("inputs", {}),
|
||||||
|
outputs=data.get("outputs", {}),
|
||||||
|
decision=data.get("decision"),
|
||||||
|
success=data.get("success", True),
|
||||||
|
error_message=data.get("error_message"),
|
||||||
|
started_at=data.get("started_at"),
|
||||||
|
completed_at=data.get("completed_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_evaluation_hooks(skill_dir: Path) -> Dict[str, Any]:
|
||||||
|
"""Parse evaluation hooks from SKILL.md.
|
||||||
|
|
||||||
|
Extracts the Optional: Evaluation hooks section from skill documentation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: Skill directory path
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary containing evaluation hook definitions
|
||||||
|
"""
|
||||||
|
skill_md = skill_dir / "SKILL.md"
|
||||||
|
if not skill_md.exists():
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = skill_md.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# Extract evaluation hooks section
|
||||||
|
if "## Optional: Evaluation hooks" in content:
|
||||||
|
start = content.find("## Optional: Evaluation hooks")
|
||||||
|
# Find the next ## section or end of file
|
||||||
|
next_section = content.find("\n## ", start + 1)
|
||||||
|
if next_section == -1:
|
||||||
|
eval_section = content[start:]
|
||||||
|
else:
|
||||||
|
eval_section = content[start:next_section]
|
||||||
|
|
||||||
|
# Parse metrics from the section
|
||||||
|
metrics = []
|
||||||
|
for metric_type in MetricType:
|
||||||
|
if metric_type.value.replace("_", " ") in eval_section.lower():
|
||||||
|
metrics.append(metric_type.value)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"supported_metrics": metrics,
|
||||||
|
"section_content": eval_section.strip(),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse evaluation hooks: {e}")
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MetricType",
|
||||||
|
"EvaluationMetric",
|
||||||
|
"EvaluationResult",
|
||||||
|
"EvaluationHook",
|
||||||
|
"EvaluationCollector",
|
||||||
|
"parse_evaluation_hooks",
|
||||||
|
]
|
||||||
510
backend/agents/base/evo_agent.py
Normal file
510
backend/agents/base/evo_agent.py
Normal file
@@ -0,0 +1,510 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""EvoAgent - Core agent implementation for EvoTraders.
|
||||||
|
|
||||||
|
This module provides the main EvoAgent class built on AgentScope's ReActAgent,
|
||||||
|
with integrated tools, skills, and memory management based on CoPaw design.
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- Workspace-driven configuration from Markdown files
|
||||||
|
- Dynamic skill loading from skills/active directories
|
||||||
|
- Tool-guard security interception
|
||||||
|
- Hook system for extensibility
|
||||||
|
- Runtime skill and prompt reloading
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
|
||||||
|
|
||||||
|
from agentscope.agent import ReActAgent
|
||||||
|
from agentscope.memory import InMemoryMemory
|
||||||
|
from agentscope.message import Msg
|
||||||
|
from agentscope.tool import Toolkit
|
||||||
|
|
||||||
|
from .tool_guard import ToolGuardMixin
|
||||||
|
from .hooks import (
|
||||||
|
HookManager,
|
||||||
|
BootstrapHook,
|
||||||
|
MemoryCompactionHook,
|
||||||
|
WorkspaceWatchHook,
|
||||||
|
HOOK_PRE_REASONING,
|
||||||
|
)
|
||||||
|
from ..prompts.builder import (
|
||||||
|
PromptBuilder,
|
||||||
|
build_system_prompt_from_workspace,
|
||||||
|
)
|
||||||
|
from ..agent_workspace import load_agent_workspace_config
|
||||||
|
from ..skills_manager import SkillsManager
|
||||||
|
|
||||||
|
# Team infrastructure imports (graceful import - may not exist yet)
|
||||||
|
try:
|
||||||
|
from backend.agents.team.messenger import AgentMessenger
|
||||||
|
from backend.agents.team.task_delegator import TaskDelegator
|
||||||
|
TEAM_INFRA_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TEAM_INFRA_AVAILABLE = False
|
||||||
|
AgentMessenger = None
|
||||||
|
TaskDelegator = None
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentscope.formatter import FormatterBase
|
||||||
|
from agentscope.model import ModelWrapperBase
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EvoAgent(ToolGuardMixin, ReActAgent):
|
||||||
|
"""EvoAgent with integrated tools, skills, and memory management.
|
||||||
|
|
||||||
|
This agent extends ReActAgent with:
|
||||||
|
- Workspace-driven configuration from AGENTS.md/SOUL.md/PROFILE.md/etc.
|
||||||
|
- Dynamic skill loading from skills/active directories
|
||||||
|
- Tool-guard security interception (via ToolGuardMixin)
|
||||||
|
- Hook system for extensibility (bootstrap, memory compaction)
|
||||||
|
- Runtime skill and prompt reloading
|
||||||
|
|
||||||
|
MRO note
|
||||||
|
~~~~~~~~
|
||||||
|
``ToolGuardMixin`` overrides ``_acting`` and ``_reasoning`` via
|
||||||
|
Python's MRO: EvoAgent → ToolGuardMixin → ReActAgent.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
agent = EvoAgent(
|
||||||
|
agent_id="fundamentals_analyst",
|
||||||
|
config_name="smoke_fullstack",
|
||||||
|
workspace_dir=Path("runs/smoke_fullstack/agents/fundamentals_analyst"),
|
||||||
|
model=model_instance,
|
||||||
|
formatter=formatter_instance,
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
config_name: str,
|
||||||
|
workspace_dir: Path,
|
||||||
|
model: "ModelWrapperBase",
|
||||||
|
formatter: "FormatterBase",
|
||||||
|
skills_manager: Optional[SkillsManager] = None,
|
||||||
|
sys_prompt: Optional[str] = None,
|
||||||
|
max_iters: int = 10,
|
||||||
|
memory: Optional[Any] = None,
|
||||||
|
enable_tool_guard: bool = True,
|
||||||
|
enable_bootstrap_hook: bool = True,
|
||||||
|
enable_memory_compaction: bool = False,
|
||||||
|
memory_manager: Optional[Any] = None,
|
||||||
|
memory_compact_threshold: Optional[int] = None,
|
||||||
|
env_context: Optional[str] = None,
|
||||||
|
prompt_files: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""Initialize EvoAgent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Unique identifier for this agent
|
||||||
|
config_name: Run configuration name (e.g., "smoke_fullstack")
|
||||||
|
workspace_dir: Agent workspace directory containing markdown files
|
||||||
|
model: LLM model instance
|
||||||
|
formatter: Message formatter instance
|
||||||
|
skills_manager: Optional SkillsManager instance
|
||||||
|
sys_prompt: Optional override for system prompt
|
||||||
|
max_iters: Maximum reasoning-acting iterations
|
||||||
|
memory: Optional memory instance (defaults to InMemoryMemory)
|
||||||
|
enable_tool_guard: Enable tool-guard security interception
|
||||||
|
enable_bootstrap_hook: Enable bootstrap guidance on first interaction
|
||||||
|
enable_memory_compaction: Enable automatic memory compaction
|
||||||
|
memory_manager: Optional memory manager for compaction
|
||||||
|
memory_compact_threshold: Token threshold for memory compaction
|
||||||
|
env_context: Optional environment context to prepend to system prompt
|
||||||
|
prompt_files: List of markdown files to load (defaults to standard set)
|
||||||
|
"""
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.config_name = config_name
|
||||||
|
self.workspace_dir = Path(workspace_dir)
|
||||||
|
self._skills_manager = skills_manager or SkillsManager()
|
||||||
|
self._env_context = env_context
|
||||||
|
self._prompt_files = prompt_files
|
||||||
|
|
||||||
|
# Initialize tool guard
|
||||||
|
if enable_tool_guard:
|
||||||
|
self._init_tool_guard()
|
||||||
|
|
||||||
|
# Load agent configuration from workspace
|
||||||
|
self._agent_config = self._load_agent_config()
|
||||||
|
|
||||||
|
# Build or use provided system prompt
|
||||||
|
if sys_prompt is not None:
|
||||||
|
self._sys_prompt = sys_prompt
|
||||||
|
else:
|
||||||
|
self._sys_prompt = self._build_system_prompt()
|
||||||
|
|
||||||
|
# Create toolkit with skills
|
||||||
|
toolkit = self._create_toolkit()
|
||||||
|
|
||||||
|
# Initialize hook manager
|
||||||
|
self._hook_manager = HookManager()
|
||||||
|
|
||||||
|
# Initialize parent ReActAgent
|
||||||
|
super().__init__(
|
||||||
|
name=agent_id,
|
||||||
|
model=model,
|
||||||
|
sys_prompt=self._sys_prompt,
|
||||||
|
toolkit=toolkit,
|
||||||
|
memory=memory or InMemoryMemory(),
|
||||||
|
formatter=formatter,
|
||||||
|
max_iters=max_iters,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register hooks
|
||||||
|
self._register_hooks(
|
||||||
|
enable_bootstrap=enable_bootstrap_hook,
|
||||||
|
enable_memory_compaction=enable_memory_compaction,
|
||||||
|
memory_manager=memory_manager,
|
||||||
|
memory_compact_threshold=memory_compact_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize team infrastructure if available
|
||||||
|
self._messenger: Optional["AgentMessenger"] = None
|
||||||
|
self._task_delegator: Optional["TaskDelegator"] = None
|
||||||
|
if TEAM_INFRA_AVAILABLE:
|
||||||
|
self._init_team_infrastructure()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"EvoAgent initialized: %s (workspace: %s)",
|
||||||
|
agent_id,
|
||||||
|
workspace_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_agent_config(self) -> Dict[str, Any]:
|
||||||
|
"""Load agent configuration from workspace.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent configuration dictionary
|
||||||
|
"""
|
||||||
|
config_path = self.workspace_dir / "agent.yaml"
|
||||||
|
if config_path.exists():
|
||||||
|
loaded = load_agent_workspace_config(config_path)
|
||||||
|
return dict(loaded.values)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _build_system_prompt(self) -> str:
|
||||||
|
"""Build system prompt from workspace markdown files.
|
||||||
|
|
||||||
|
Uses PromptBuilder to load and combine AGENTS.md, SOUL.md,
|
||||||
|
PROFILE.md, and other configured files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Complete system prompt string
|
||||||
|
"""
|
||||||
|
prompt = build_system_prompt_from_workspace(
|
||||||
|
workspace_dir=self.workspace_dir,
|
||||||
|
enabled_files=self._prompt_files,
|
||||||
|
agent_id=self.agent_id,
|
||||||
|
extra_context=self._env_context,
|
||||||
|
)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def _create_toolkit(self) -> Toolkit:
|
||||||
|
"""Create and populate toolkit with agent skills.
|
||||||
|
|
||||||
|
Loads skills from the agent's active skills directory and
|
||||||
|
registers them with the toolkit.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Configured Toolkit instance
|
||||||
|
"""
|
||||||
|
toolkit = Toolkit(
|
||||||
|
agent_skill_instruction=(
|
||||||
|
"<system-info>You have access to specialized skills. "
|
||||||
|
"Each skill lives in a directory and is described by SKILL.md. "
|
||||||
|
"Follow the skill instructions when they are relevant to the current task."
|
||||||
|
"</system-info>"
|
||||||
|
),
|
||||||
|
agent_skill_template="- {name} (dir: {dir}): {description}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register skills from active directory
|
||||||
|
active_skills_dir = self._skills_manager.get_agent_active_root(
|
||||||
|
self.config_name,
|
||||||
|
self.agent_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if active_skills_dir.exists():
|
||||||
|
for skill_dir in sorted(active_skills_dir.iterdir()):
|
||||||
|
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||||
|
try:
|
||||||
|
toolkit.register_agent_skill(str(skill_dir))
|
||||||
|
logger.debug("Registered skill: %s", skill_dir.name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to register skill '%s': %s",
|
||||||
|
skill_dir.name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
return toolkit
|
||||||
|
|
||||||
|
def _register_hooks(
|
||||||
|
self,
|
||||||
|
enable_bootstrap: bool,
|
||||||
|
enable_memory_compaction: bool,
|
||||||
|
memory_manager: Optional[Any],
|
||||||
|
memory_compact_threshold: Optional[int],
|
||||||
|
) -> None:
|
||||||
|
"""Register agent hooks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable_bootstrap: Enable bootstrap hook
|
||||||
|
enable_memory_compaction: Enable memory compaction hook
|
||||||
|
memory_manager: Memory manager instance
|
||||||
|
memory_compact_threshold: Token threshold for compaction
|
||||||
|
"""
|
||||||
|
# Bootstrap hook - checks BOOTSTRAP.md on first interaction
|
||||||
|
if enable_bootstrap:
|
||||||
|
bootstrap_hook = BootstrapHook(
|
||||||
|
workspace_dir=self.workspace_dir,
|
||||||
|
language="zh",
|
||||||
|
)
|
||||||
|
self._hook_manager.register(
|
||||||
|
hook_type=HOOK_PRE_REASONING,
|
||||||
|
hook_name="bootstrap",
|
||||||
|
hook=bootstrap_hook,
|
||||||
|
)
|
||||||
|
logger.debug("Registered bootstrap hook")
|
||||||
|
|
||||||
|
# Memory compaction hook
|
||||||
|
if enable_memory_compaction and memory_manager is not None:
|
||||||
|
compaction_hook = MemoryCompactionHook(
|
||||||
|
memory_manager=memory_manager,
|
||||||
|
memory_compact_threshold=memory_compact_threshold,
|
||||||
|
)
|
||||||
|
self._hook_manager.register(
|
||||||
|
hook_type=HOOK_PRE_REASONING,
|
||||||
|
hook_name="memory_compaction",
|
||||||
|
hook=compaction_hook,
|
||||||
|
)
|
||||||
|
logger.debug("Registered memory compaction hook")
|
||||||
|
|
||||||
|
# Workspace watch hook - auto-reload markdown files on change
|
||||||
|
workspace_watch_hook = WorkspaceWatchHook(
|
||||||
|
workspace_dir=self.workspace_dir,
|
||||||
|
)
|
||||||
|
self._hook_manager.register(
|
||||||
|
hook_type=HOOK_PRE_REASONING,
|
||||||
|
hook_name="workspace_watch",
|
||||||
|
hook=workspace_watch_hook,
|
||||||
|
)
|
||||||
|
logger.debug("Registered workspace watch hook")
|
||||||
|
|
||||||
|
async def _reasoning(self, **kwargs) -> Msg:
|
||||||
|
"""Override reasoning to execute pre-reasoning hooks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**kwargs: Arguments for reasoning
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response message
|
||||||
|
"""
|
||||||
|
# Execute pre-reasoning hooks
|
||||||
|
kwargs = await self._hook_manager.execute(
|
||||||
|
hook_type=HOOK_PRE_REASONING,
|
||||||
|
agent=self,
|
||||||
|
kwargs=kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Call parent (which may be ToolGuardMixin's _reasoning)
|
||||||
|
return await super()._reasoning(**kwargs)
|
||||||
|
|
||||||
|
def reload_skills(self, active_skill_dirs: Optional[List[Path]] = None) -> None:
|
||||||
|
"""Reload skills at runtime.
|
||||||
|
|
||||||
|
Rebuilds the toolkit with current skills from the active directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
active_skill_dirs: Optional list of specific skill directories to load
|
||||||
|
"""
|
||||||
|
logger.info("Reloading skills for agent: %s", self.agent_id)
|
||||||
|
|
||||||
|
# Create new toolkit
|
||||||
|
new_toolkit = Toolkit(
|
||||||
|
agent_skill_instruction=(
|
||||||
|
"<system-info>You have access to specialized skills. "
|
||||||
|
"Each skill lives in a directory and is described by SKILL.md. "
|
||||||
|
"Follow the skill instructions when they are relevant to the current task."
|
||||||
|
"</system-info>"
|
||||||
|
),
|
||||||
|
agent_skill_template="- {name} (dir: {dir}): {description}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register skills
|
||||||
|
if active_skill_dirs is None:
|
||||||
|
active_skills_dir = self._skills_manager.get_agent_active_root(
|
||||||
|
self.config_name,
|
||||||
|
self.agent_id,
|
||||||
|
)
|
||||||
|
if active_skills_dir.exists():
|
||||||
|
active_skill_dirs = [
|
||||||
|
d for d in active_skills_dir.iterdir()
|
||||||
|
if d.is_dir() and (d / "SKILL.md").exists()
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
active_skill_dirs = []
|
||||||
|
|
||||||
|
for skill_dir in active_skill_dirs:
|
||||||
|
if skill_dir.exists() and (skill_dir / "SKILL.md").exists():
|
||||||
|
try:
|
||||||
|
new_toolkit.register_agent_skill(str(skill_dir))
|
||||||
|
logger.debug("Reloaded skill: %s", skill_dir.name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to reload skill '%s': %s",
|
||||||
|
skill_dir.name,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace toolkit
|
||||||
|
self.toolkit = new_toolkit
|
||||||
|
logger.info("Skills reloaded for agent: %s", self.agent_id)
|
||||||
|
|
||||||
|
def rebuild_sys_prompt(self) -> None:
|
||||||
|
"""Rebuild and replace the system prompt at runtime.
|
||||||
|
|
||||||
|
Useful after updating AGENTS.md, SOUL.md, PROFILE.md, etc.
|
||||||
|
to ensure the prompt reflects the latest configuration.
|
||||||
|
|
||||||
|
Updates both self._sys_prompt and the first system-role
|
||||||
|
message stored in self.memory.content.
|
||||||
|
"""
|
||||||
|
logger.info("Rebuilding system prompt for agent: %s", self.agent_id)
|
||||||
|
|
||||||
|
# Reload agent config in case it changed
|
||||||
|
self._agent_config = self._load_agent_config()
|
||||||
|
|
||||||
|
# Rebuild prompt
|
||||||
|
self._sys_prompt = self._build_system_prompt()
|
||||||
|
|
||||||
|
# Update memory if system message exists
|
||||||
|
if hasattr(self, "memory") and self.memory.content:
|
||||||
|
for msg, _marks in self.memory.content:
|
||||||
|
if getattr(msg, "role", None) == "system":
|
||||||
|
msg.content = self._sys_prompt
|
||||||
|
logger.debug("Updated system message in memory")
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("System prompt rebuilt for agent: %s", self.agent_id)
|
||||||
|
|
||||||
|
async def reply(
|
||||||
|
self,
|
||||||
|
msg: Msg | List[Msg] | None = None,
|
||||||
|
structured_model: Optional[Type[Any]] = None,
|
||||||
|
) -> Msg:
|
||||||
|
"""Process a message and return a response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
msg: Input message(s) from user
|
||||||
|
structured_model: Optional pydantic model for structured output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response message
|
||||||
|
"""
|
||||||
|
# Handle list of messages
|
||||||
|
if isinstance(msg, list):
|
||||||
|
# Process each message in sequence
|
||||||
|
for m in msg[:-1]:
|
||||||
|
await self.memory.add(m)
|
||||||
|
msg = msg[-1] if msg else None
|
||||||
|
|
||||||
|
return await super().reply(msg=msg, structured_model=structured_model)
|
||||||
|
|
||||||
|
def get_agent_info(self) -> Dict[str, Any]:
|
||||||
|
"""Get agent information.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with agent metadata
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
"config_name": self.config_name,
|
||||||
|
"workspace_dir": str(self.workspace_dir),
|
||||||
|
"skills_count": len([
|
||||||
|
s for s in self._skills_manager.list_active_skill_metadata(
|
||||||
|
self.config_name,
|
||||||
|
self.agent_id,
|
||||||
|
)
|
||||||
|
]),
|
||||||
|
"registered_hooks": self._hook_manager.list_hooks(),
|
||||||
|
"team_infra_available": TEAM_INFRA_AVAILABLE,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _init_team_infrastructure(self) -> None:
|
||||||
|
"""Initialize team infrastructure components (messenger and task delegator).
|
||||||
|
|
||||||
|
This method initializes the AgentMessenger for inter-agent communication
|
||||||
|
and the TaskDelegator for subagent delegation.
|
||||||
|
"""
|
||||||
|
if not TEAM_INFRA_AVAILABLE:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._messenger = AgentMessenger(agent_id=self.agent_id)
|
||||||
|
self._task_delegator = TaskDelegator(agent=self)
|
||||||
|
logger.debug(
|
||||||
|
"Team infrastructure initialized for agent: %s",
|
||||||
|
self.agent_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to initialize team infrastructure for %s: %s",
|
||||||
|
self.agent_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self._messenger = None
|
||||||
|
self._task_delegator = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def messenger(self) -> Optional["AgentMessenger"]:
|
||||||
|
"""Get the agent's messenger for inter-agent communication.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentMessenger instance if available, None otherwise
|
||||||
|
"""
|
||||||
|
return self._messenger
|
||||||
|
|
||||||
|
async def delegate_task(
|
||||||
|
self,
|
||||||
|
task_type: str,
|
||||||
|
task_data: Dict[str, Any],
|
||||||
|
target_agent: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Delegate a task to a subagent using the TaskDelegator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: Type of task to delegate
|
||||||
|
task_data: Data/payload for the task
|
||||||
|
target_agent: Optional specific agent ID to delegate to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict containing the delegation result
|
||||||
|
"""
|
||||||
|
if not TEAM_INFRA_AVAILABLE or self._task_delegator is None:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "Team infrastructure not available",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await self._task_delegator.delegate_task(
|
||||||
|
task_type=task_type,
|
||||||
|
task_data=task_data,
|
||||||
|
target_agent=target_agent,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Task delegation failed for %s: %s",
|
||||||
|
self.agent_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["EvoAgent"]
|
||||||
702
backend/agents/base/hooks.py
Normal file
702
backend/agents/base/hooks.py
Normal file
@@ -0,0 +1,702 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Hook system for EvoAgent.
|
||||||
|
|
||||||
|
Provides pre_reasoning and post_acting hooks with built-in implementations:
|
||||||
|
- BootstrapHook: First-time setup guidance
|
||||||
|
- MemoryCompactionHook: Automatic memory compression
|
||||||
|
|
||||||
|
Based on CoPaw's hooks design.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentscope.agent import ReActAgent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Hook types
|
||||||
|
HookType = str
|
||||||
|
HOOK_PRE_REASONING: HookType = "pre_reasoning"
|
||||||
|
HOOK_POST_ACTING: HookType = "post_acting"
|
||||||
|
|
||||||
|
|
||||||
|
class Hook(ABC):
|
||||||
|
"""Abstract base class for agent hooks."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
agent: "ReActAgent",
|
||||||
|
kwargs: Dict[str, Any],
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Execute the hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
kwargs: Input arguments to the method being hooked
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified kwargs or None to use original
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HookManager:
|
||||||
|
"""Manages agent hooks.
|
||||||
|
|
||||||
|
Provides registration and execution of hooks for different
|
||||||
|
lifecycle events in the agent's operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._hooks: Dict[HookType, List[tuple[str, Hook]]] = {
|
||||||
|
HOOK_PRE_REASONING: [],
|
||||||
|
HOOK_POST_ACTING: [],
|
||||||
|
}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
hook_type: HookType,
|
||||||
|
hook_name: str,
|
||||||
|
hook: Hook | Callable,
|
||||||
|
) -> None:
|
||||||
|
"""Register a hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_type: Type of hook (pre_reasoning, post_acting)
|
||||||
|
hook_name: Unique name for this hook
|
||||||
|
hook: Hook instance or callable
|
||||||
|
"""
|
||||||
|
# Remove existing hook with same name
|
||||||
|
self._hooks[hook_type] = [
|
||||||
|
(name, h) for name, h in self._hooks[hook_type] if name != hook_name
|
||||||
|
]
|
||||||
|
self._hooks[hook_type].append((hook_name, hook))
|
||||||
|
logger.debug("Registered hook '%s' for type '%s'", hook_name, hook_type)
|
||||||
|
|
||||||
|
def unregister(self, hook_type: HookType, hook_name: str) -> bool:
|
||||||
|
"""Unregister a hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_type: Type of hook
|
||||||
|
hook_name: Name of the hook to remove
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if hook was found and removed
|
||||||
|
"""
|
||||||
|
original_len = len(self._hooks[hook_type])
|
||||||
|
self._hooks[hook_type] = [
|
||||||
|
(name, h) for name, h in self._hooks[hook_type] if name != hook_name
|
||||||
|
]
|
||||||
|
removed = len(self._hooks[hook_type]) < original_len
|
||||||
|
if removed:
|
||||||
|
logger.debug("Unregistered hook '%s' from type '%s'", hook_name, hook_type)
|
||||||
|
return removed
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
hook_type: HookType,
|
||||||
|
agent: "ReActAgent",
|
||||||
|
kwargs: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Execute all hooks of a given type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_type: Type of hooks to execute
|
||||||
|
agent: The agent instance
|
||||||
|
kwargs: Input arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Potentially modified kwargs
|
||||||
|
"""
|
||||||
|
for name, hook in self._hooks[hook_type]:
|
||||||
|
try:
|
||||||
|
result = await hook(agent, kwargs)
|
||||||
|
if result is not None:
|
||||||
|
kwargs = result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Hook '%s' failed: %s", name, e, exc_info=True)
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
def list_hooks(self, hook_type: Optional[HookType] = None) -> List[str]:
|
||||||
|
"""List registered hook names.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hook_type: Optional type to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of hook names
|
||||||
|
"""
|
||||||
|
if hook_type:
|
||||||
|
return [name for name, _ in self._hooks.get(hook_type, [])]
|
||||||
|
|
||||||
|
names = []
|
||||||
|
for hooks in self._hooks.values():
|
||||||
|
names.extend([name for name, _ in hooks])
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
class BootstrapHook(Hook):
|
||||||
|
"""Hook for bootstrap guidance on first user interaction.
|
||||||
|
|
||||||
|
This hook looks for a BOOTSTRAP.md file in the working directory
|
||||||
|
and if found, prepends guidance to the first user message to help
|
||||||
|
establish the agent's identity and user preferences.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workspace_dir: Path,
|
||||||
|
language: str = "zh",
|
||||||
|
):
|
||||||
|
"""Initialize bootstrap hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_dir: Working directory containing BOOTSTRAP.md
|
||||||
|
language: Language code for bootstrap guidance (en/zh)
|
||||||
|
"""
|
||||||
|
self.workspace_dir = Path(workspace_dir)
|
||||||
|
self.language = language
|
||||||
|
self._completed_flag = self.workspace_dir / ".bootstrap_completed"
|
||||||
|
|
||||||
|
def _is_first_user_interaction(self, agent: "ReActAgent") -> bool:
|
||||||
|
"""Check if this is the first user interaction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if first user interaction
|
||||||
|
"""
|
||||||
|
if not hasattr(agent, "memory") or not agent.memory.content:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Count user messages (excluding system)
|
||||||
|
user_count = sum(
|
||||||
|
1 for msg, _ in agent.memory.content if msg.role == "user"
|
||||||
|
)
|
||||||
|
return user_count <= 1
|
||||||
|
|
||||||
|
def _build_bootstrap_guidance(self) -> str:
|
||||||
|
"""Build bootstrap guidance message.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted bootstrap guidance
|
||||||
|
"""
|
||||||
|
if self.language == "zh":
|
||||||
|
return (
|
||||||
|
"# 引导模式\n"
|
||||||
|
"\n"
|
||||||
|
"工作目录中存在 `BOOTSTRAP.md` — 首次设置。\n"
|
||||||
|
"\n"
|
||||||
|
"1. 阅读 BOOTSTRAP.md,友好地表示初次见面,"
|
||||||
|
"引导用户完成设置。\n"
|
||||||
|
"2. 按照 BOOTSTRAP.md 的指示,"
|
||||||
|
"帮助用户定义你的身份和偏好。\n"
|
||||||
|
"3. 按指南创建/更新必要文件"
|
||||||
|
"(PROFILE.md、MEMORY.md 等)。\n"
|
||||||
|
"4. 完成后删除 BOOTSTRAP.md。\n"
|
||||||
|
"\n"
|
||||||
|
"如果用户希望跳过,直接回答下面的问题即可。\n"
|
||||||
|
"\n"
|
||||||
|
"---\n"
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
"# BOOTSTRAP MODE\n"
|
||||||
|
"\n"
|
||||||
|
"`BOOTSTRAP.md` exists — first-time setup.\n"
|
||||||
|
"\n"
|
||||||
|
"1. Read BOOTSTRAP.md, greet the user, "
|
||||||
|
"and guide them through setup.\n"
|
||||||
|
"2. Follow BOOTSTRAP.md instructions "
|
||||||
|
"to define identity and preferences.\n"
|
||||||
|
"3. Create/update files "
|
||||||
|
"(PROFILE.md, MEMORY.md, etc.) as described.\n"
|
||||||
|
"4. Delete BOOTSTRAP.md when done.\n"
|
||||||
|
"\n"
|
||||||
|
"If the user wants to skip, answer their "
|
||||||
|
"question directly instead.\n"
|
||||||
|
"\n"
|
||||||
|
"---\n"
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
agent: "ReActAgent",
|
||||||
|
kwargs: Dict[str, Any],
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Check and load BOOTSTRAP.md on first user interaction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
kwargs: Input arguments to the _reasoning method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None (hook doesn't modify kwargs)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
bootstrap_path = self.workspace_dir / "BOOTSTRAP.md"
|
||||||
|
|
||||||
|
# Check if bootstrap has already been triggered
|
||||||
|
if self._completed_flag.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not bootstrap_path.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not self._is_first_user_interaction(agent):
|
||||||
|
return None
|
||||||
|
|
||||||
|
bootstrap_guidance = self._build_bootstrap_guidance()
|
||||||
|
|
||||||
|
logger.debug("Found BOOTSTRAP.md [%s], prepending guidance", self.language)
|
||||||
|
|
||||||
|
# Prepend to first user message in memory
|
||||||
|
if hasattr(agent, "memory") and agent.memory.content:
|
||||||
|
system_count = sum(
|
||||||
|
1 for msg, _ in agent.memory.content if msg.role == "system"
|
||||||
|
)
|
||||||
|
for msg, _ in agent.memory.content[system_count:]:
|
||||||
|
if msg.role == "user":
|
||||||
|
# Prepend guidance to message content
|
||||||
|
original_content = msg.content
|
||||||
|
msg.content = bootstrap_guidance + original_content
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.debug("Bootstrap guidance prepended to first user message")
|
||||||
|
|
||||||
|
# Create completion flag to prevent repeated triggering
|
||||||
|
self._completed_flag.touch()
|
||||||
|
logger.debug("Created bootstrap completion flag")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to process bootstrap: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceWatchHook(Hook):
|
||||||
|
"""Hook for auto-reloading workspace markdown files on change.
|
||||||
|
|
||||||
|
Monitors SOUL.md, AGENTS.md, PROFILE.md, etc. and triggers
|
||||||
|
a prompt rebuild when any of them change. Based on CoPaw's
|
||||||
|
AgentConfigWatcher approach but for markdown files.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Files to monitor (same as PromptBuilder.DEFAULT_FILES)
|
||||||
|
WATCHED_FILES = frozenset([
|
||||||
|
"SOUL.md", "AGENTS.md", "PROFILE.md", "ROLE.md",
|
||||||
|
"POLICY.md", "MEMORY.md", "HEARTBEAT.md", "STYLE.md",
|
||||||
|
"BOOTSTRAP.md",
|
||||||
|
])
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workspace_dir: Path,
|
||||||
|
poll_interval: float = 2.0,
|
||||||
|
):
|
||||||
|
"""Initialize workspace watch hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_dir: Workspace directory to monitor
|
||||||
|
poll_interval: How often to check for changes (seconds)
|
||||||
|
"""
|
||||||
|
self.workspace_dir = Path(workspace_dir)
|
||||||
|
self.poll_interval = poll_interval
|
||||||
|
self._last_mtimes: dict[str, float] = {}
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
def _scan_mtimes(self) -> dict[str, float]:
|
||||||
|
"""Scan watched files and return their current mtimes."""
|
||||||
|
mtimes = {}
|
||||||
|
for name in self.WATCHED_FILES:
|
||||||
|
path = self.workspace_dir / name
|
||||||
|
if path.exists():
|
||||||
|
mtimes[name] = path.stat().st_mtime
|
||||||
|
return mtimes
|
||||||
|
|
||||||
|
def _has_changes(self) -> bool:
|
||||||
|
"""Check if any watched file has changed since last check."""
|
||||||
|
current = self._scan_mtimes()
|
||||||
|
|
||||||
|
if not self._initialized:
|
||||||
|
self._last_mtimes = current
|
||||||
|
self._initialized = True
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for new, modified, or deleted files
|
||||||
|
if set(current.keys()) != set(self._last_mtimes.keys()):
|
||||||
|
self._last_mtimes = current
|
||||||
|
return True
|
||||||
|
|
||||||
|
for name, mtime in current.items():
|
||||||
|
if mtime != self._last_mtimes.get(name):
|
||||||
|
self._last_mtimes = current
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
agent: "ReActAgent",
|
||||||
|
kwargs: Dict[str, Any],
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Check for file changes and rebuild prompt if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
kwargs: Input arguments (unused)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if self._has_changes():
|
||||||
|
logger.info(
|
||||||
|
"Workspace files changed, triggering prompt rebuild for: %s",
|
||||||
|
getattr(agent, "agent_id", "unknown"),
|
||||||
|
)
|
||||||
|
if hasattr(agent, "rebuild_sys_prompt"):
|
||||||
|
agent.rebuild_sys_prompt()
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Agent %s has no rebuild_sys_prompt method",
|
||||||
|
getattr(agent, "agent_id", "unknown"),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Workspace watch hook failed: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCompactionHook(Hook):
|
||||||
|
"""Hook for automatic memory compaction when context is full.
|
||||||
|
|
||||||
|
This hook monitors the token count of messages and triggers compaction
|
||||||
|
when it exceeds the threshold. It preserves the system prompt and recent
|
||||||
|
messages while summarizing older conversation history.
|
||||||
|
|
||||||
|
Based on CoPaw's memory compaction design with additional improvements:
|
||||||
|
- memory_compact_ratio: Ratio to compact when threshold reached
|
||||||
|
- memory_reserve_ratio: Always keep a reserve of tokens for recent messages
|
||||||
|
- enable_tool_result_compact: Compact tool results separately
|
||||||
|
- tool_result_compact_keep_n: Number of tool results to keep
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
memory_manager: Any,
|
||||||
|
memory_compact_threshold: Optional[int] = None,
|
||||||
|
memory_compact_ratio: float = 0.75,
|
||||||
|
memory_reserve_ratio: float = 0.1,
|
||||||
|
enable_tool_result_compact: bool = False,
|
||||||
|
tool_result_compact_keep_n: int = 5,
|
||||||
|
):
|
||||||
|
"""Initialize memory compaction hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memory_manager: Memory manager instance for compaction
|
||||||
|
memory_compact_threshold: Token threshold for compaction
|
||||||
|
memory_compact_ratio: Target ratio to compact to (e.g., 0.75 = compact to 75%)
|
||||||
|
memory_reserve_ratio: Reserve ratio to always keep free (e.g., 0.1 = 10%)
|
||||||
|
enable_tool_result_compact: Enable tool result compaction
|
||||||
|
tool_result_compact_keep_n: Number of tool results to keep
|
||||||
|
"""
|
||||||
|
self.memory_manager = memory_manager
|
||||||
|
self.memory_compact_threshold = memory_compact_threshold
|
||||||
|
self.memory_compact_ratio = memory_compact_ratio
|
||||||
|
self.memory_reserve_ratio = memory_reserve_ratio
|
||||||
|
self.enable_tool_result_compact = enable_tool_result_compact
|
||||||
|
self.tool_result_compact_keep_n = tool_result_compact_keep_n
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
agent: "ReActAgent",
|
||||||
|
kwargs: Dict[str, Any],
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Pre-reasoning hook to check and compact memory if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
kwargs: Input arguments to the _reasoning method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None (hook doesn't modify kwargs)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not hasattr(agent, "memory") or not self.memory_manager:
|
||||||
|
return None
|
||||||
|
|
||||||
|
memory = agent.memory
|
||||||
|
|
||||||
|
# Get current token count estimate
|
||||||
|
messages = await memory.get_memory()
|
||||||
|
total_tokens = self._estimate_tokens(messages)
|
||||||
|
|
||||||
|
if self.memory_compact_threshold is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if total_tokens < self.memory_compact_threshold:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Memory compaction triggered: %d tokens (threshold: %d)",
|
||||||
|
total_tokens,
|
||||||
|
self.memory_compact_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compact memory
|
||||||
|
await self._compact_memory(agent, messages)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to compact memory: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _estimate_tokens(self, messages: List[Any]) -> int:
|
||||||
|
"""Estimate token count for messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List of messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Estimated token count
|
||||||
|
"""
|
||||||
|
# Simple estimation: ~4 chars per token
|
||||||
|
total_chars = sum(
|
||||||
|
len(str(getattr(msg, "content", "")))
|
||||||
|
for msg in messages
|
||||||
|
)
|
||||||
|
return total_chars // 4
|
||||||
|
|
||||||
|
async def _compact_memory(
|
||||||
|
self,
|
||||||
|
agent: "ReActAgent",
|
||||||
|
messages: List[Any],
|
||||||
|
) -> None:
|
||||||
|
"""Compact memory by summarizing older messages.
|
||||||
|
|
||||||
|
Uses CoPaw-style memory management:
|
||||||
|
- memory_compact_ratio: Target ratio to compact to (e.g., 0.75 means compact to 75%)
|
||||||
|
- memory_reserve_ratio: Always keep this ratio free (e.g., 0.1 means keep 10% for recent)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
messages: Current messages in memory
|
||||||
|
"""
|
||||||
|
if self.memory_compact_threshold is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Estimate total tokens
|
||||||
|
total_tokens = self._estimate_tokens(messages)
|
||||||
|
|
||||||
|
# Calculate reserve based on ratio (CoPaw-style)
|
||||||
|
reserve_tokens = int(total_tokens * self.memory_reserve_ratio)
|
||||||
|
|
||||||
|
# Calculate target tokens after compaction
|
||||||
|
target_tokens = int(total_tokens * self.memory_compact_ratio)
|
||||||
|
target_tokens = max(target_tokens, total_tokens - reserve_tokens)
|
||||||
|
|
||||||
|
# Find messages to compact (older ones)
|
||||||
|
# Keep recent messages that fit within target
|
||||||
|
messages_to_compact = []
|
||||||
|
kept_tokens = 0
|
||||||
|
|
||||||
|
# Start from oldest, stop when we've kept enough
|
||||||
|
for msg in messages:
|
||||||
|
msg_tokens = self._estimate_tokens([msg])
|
||||||
|
if kept_tokens + msg_tokens > target_tokens:
|
||||||
|
messages_to_compact.append(msg)
|
||||||
|
else:
|
||||||
|
kept_tokens += msg_tokens
|
||||||
|
|
||||||
|
if not messages_to_compact:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Compacting %d messages (%d tokens) to target %d tokens",
|
||||||
|
len(messages_to_compact),
|
||||||
|
self._estimate_tokens(messages_to_compact),
|
||||||
|
target_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use memory manager to compact if available
|
||||||
|
if hasattr(self.memory_manager, "compact_memory"):
|
||||||
|
try:
|
||||||
|
summary = await self.memory_manager.compact_memory(
|
||||||
|
messages=messages_to_compact,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Memory compacted: %d messages summarized, summary: %s",
|
||||||
|
len(messages_to_compact),
|
||||||
|
summary[:200] if summary else "N/A",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark messages as compressed if supported
|
||||||
|
if hasattr(agent.memory, "update_messages_mark"):
|
||||||
|
from agentscope.agent._react_agent import _MemoryMark
|
||||||
|
await agent.memory.update_messages_mark(
|
||||||
|
new_mark=_MemoryMark.COMPRESSED,
|
||||||
|
msg_ids=[msg.id for msg in messages_to_compact],
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Memory manager compaction failed: %s", e)
|
||||||
|
|
||||||
|
# Tool result compaction (CoPaw-style)
|
||||||
|
if self.enable_tool_result_compact:
|
||||||
|
await self._compact_tool_results(agent, messages)
|
||||||
|
|
||||||
|
async def _compact_tool_results(
|
||||||
|
self,
|
||||||
|
agent: "ReActAgent",
|
||||||
|
messages: List[Any],
|
||||||
|
) -> None:
|
||||||
|
"""Compact tool results by keeping only recent ones.
|
||||||
|
|
||||||
|
Based on CoPaw's tool_result_compact_keep_n pattern.
|
||||||
|
Tool results can be very verbose, so we keep only the N most recent ones.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
messages: Current messages in memory
|
||||||
|
"""
|
||||||
|
if not hasattr(agent.memory, "content"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Find tool result messages (usually have "tool" role or tool_related content)
|
||||||
|
tool_results = []
|
||||||
|
for msg, _ in agent.memory.content:
|
||||||
|
if hasattr(msg, "role") and msg.role == "tool":
|
||||||
|
tool_results.append(msg)
|
||||||
|
|
||||||
|
if len(tool_results) <= self.tool_result_compact_keep_n:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Keep only the most recent N tool results
|
||||||
|
excess_results = tool_results[:-self.tool_result_compact_keep_n]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Tool result compaction: %d tool results found, keeping %d, compacting %d",
|
||||||
|
len(tool_results),
|
||||||
|
self.tool_result_compact_keep_n,
|
||||||
|
len(excess_results),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mark excess tool results as compressed if supported
|
||||||
|
if hasattr(agent.memory, "update_messages_mark"):
|
||||||
|
from agentscope.agent._react_agent import _MemoryMark
|
||||||
|
await agent.memory.update_messages_mark(
|
||||||
|
new_mark=_MemoryMark.COMPRESSED,
|
||||||
|
msg_ids=[msg.id for msg in excess_results],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class HeartbeatHook(Hook):
|
||||||
|
"""Pre-reasoning hook that injects HEARTBEAT.md content.
|
||||||
|
|
||||||
|
Reads the agent's HEARTBEAT.md file and prepends it to the
|
||||||
|
reasoning input, causing the agent to perform self-checks.
|
||||||
|
|
||||||
|
This enables "主动检查" (proactive monitoring) - periodic
|
||||||
|
market condition and position checks during trading hours.
|
||||||
|
"""
|
||||||
|
|
||||||
|
HEARTBEAT_FILE = "HEARTBEAT.md"
|
||||||
|
|
||||||
|
def __init__(self, workspace_dir: Path):
|
||||||
|
"""Initialize heartbeat hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_dir: Working directory containing HEARTBEAT.md
|
||||||
|
"""
|
||||||
|
self.workspace_dir = Path(workspace_dir)
|
||||||
|
self._completed_flag = self.workspace_dir / ".heartbeat_completed"
|
||||||
|
|
||||||
|
def _read_heartbeat_content(self) -> Optional[str]:
|
||||||
|
"""Read HEARTBEAT.md if it exists and is non-empty.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The HEARTBEAT.md content stripped of whitespace, or None
|
||||||
|
if the file is absent or empty.
|
||||||
|
"""
|
||||||
|
hb_path = self.workspace_dir / self.HEARTBEAT_FILE
|
||||||
|
if not hb_path.exists():
|
||||||
|
return None
|
||||||
|
content = hb_path.read_text(encoding="utf-8").strip()
|
||||||
|
return content if content else None
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
agent: "ReActAgent",
|
||||||
|
kwargs: Dict[str, Any],
|
||||||
|
) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Prepend heartbeat task to user message.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The agent instance
|
||||||
|
kwargs: Input arguments to the _reasoning method
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified kwargs with heartbeat content prepended, or None
|
||||||
|
if no HEARTBEAT.md content is available.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
content = self._read_heartbeat_content()
|
||||||
|
if not content:
|
||||||
|
return None
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Heartbeat: found HEARTBEAT.md for agent %s",
|
||||||
|
getattr(agent, "agent_id", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build heartbeat task instruction (Chinese)
|
||||||
|
hb_task = (
|
||||||
|
"# 定期主动检查\n\n"
|
||||||
|
f"{content}\n\n"
|
||||||
|
"请执行上述检查并报告结果。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inject into the first user message in memory
|
||||||
|
if hasattr(agent, "memory") and agent.memory.content:
|
||||||
|
system_count = sum(
|
||||||
|
1 for msg, _ in agent.memory.content if msg.role == "system"
|
||||||
|
)
|
||||||
|
for msg, _ in agent.memory.content[system_count:]:
|
||||||
|
if msg.role == "user":
|
||||||
|
original_content = msg.content
|
||||||
|
msg.content = hb_task + "\n\n" + original_content
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"Heartbeat task prepended for agent %s",
|
||||||
|
getattr(agent, "agent_id", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Heartbeat hook failed: %s", e, exc_info=True)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Hook",
|
||||||
|
"HookManager",
|
||||||
|
"HookType",
|
||||||
|
"HOOK_PRE_REASONING",
|
||||||
|
"HOOK_POST_ACTING",
|
||||||
|
"BootstrapHook",
|
||||||
|
"HeartbeatHook",
|
||||||
|
"MemoryCompactionHook",
|
||||||
|
"WorkspaceWatchHook",
|
||||||
|
]
|
||||||
489
backend/agents/base/skill_adaptation_hook.py
Normal file
489
backend/agents/base/skill_adaptation_hook.py
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Skill adaptation hook for automatic evaluation-to-iteration闭环.
|
||||||
|
|
||||||
|
Monitors evaluation metrics against configurable thresholds and triggers
|
||||||
|
automatic skill reload or logs warnings when thresholds are breached.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from .evaluation_hook import (
|
||||||
|
EvaluationCollector,
|
||||||
|
EvaluationResult,
|
||||||
|
MetricType,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptationAction(Enum):
|
||||||
|
"""Actions to take when threshold is breached."""
|
||||||
|
RELOAD = "reload" # 自动重新加载技能
|
||||||
|
WARN = "warn" # 记录警告供人工审核
|
||||||
|
BOTH = "both" # 同时执行重载和警告
|
||||||
|
NONE = "none" # 不做任何操作
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdaptationThreshold:
|
||||||
|
"""Threshold configuration for a metric."""
|
||||||
|
metric_type: MetricType
|
||||||
|
operator: str = "lt" # lt (less than), gt (greater than), lte, gte, eq
|
||||||
|
value: float = 0.0
|
||||||
|
window_size: int = 10 # 移动窗口大小,用于计算滑动平均
|
||||||
|
min_samples: int = 5 # 最少样本数才触发检查
|
||||||
|
action: AdaptationAction = AdaptationAction.WARN
|
||||||
|
cooldown_seconds: int = 300 # 触发后的冷却时间
|
||||||
|
|
||||||
|
def evaluate(self, current_value: float) -> bool:
|
||||||
|
"""Evaluate if threshold is breached."""
|
||||||
|
ops = {
|
||||||
|
"lt": lambda x, y: x < y,
|
||||||
|
"lte": lambda x, y: x <= y,
|
||||||
|
"gt": lambda x, y: x > y,
|
||||||
|
"gte": lambda x, y: x >= y,
|
||||||
|
"eq": lambda x, y: x == y,
|
||||||
|
}
|
||||||
|
op_func = ops.get(self.operator)
|
||||||
|
if op_func is None:
|
||||||
|
logger.warning(f"Unknown operator: {self.operator}")
|
||||||
|
return False
|
||||||
|
return op_func(current_value, self.value)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"metric_type": self.metric_type.value,
|
||||||
|
"operator": self.operator,
|
||||||
|
"value": self.value,
|
||||||
|
"window_size": self.window_size,
|
||||||
|
"min_samples": self.min_samples,
|
||||||
|
"action": self.action.value,
|
||||||
|
"cooldown_seconds": self.cooldown_seconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AdaptationEvent:
|
||||||
|
"""Record of an adaptation trigger event."""
|
||||||
|
timestamp: str
|
||||||
|
skill_name: str
|
||||||
|
metric_type: MetricType
|
||||||
|
threshold: AdaptationThreshold
|
||||||
|
current_value: float
|
||||||
|
avg_value: float
|
||||||
|
action_taken: AdaptationAction
|
||||||
|
details: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"timestamp": self.timestamp,
|
||||||
|
"skill_name": self.skill_name,
|
||||||
|
"metric_type": self.metric_type.value,
|
||||||
|
"threshold": self.threshold.to_dict(),
|
||||||
|
"current_value": self.current_value,
|
||||||
|
"avg_value": self.avg_value,
|
||||||
|
"action_taken": self.action_taken.value,
|
||||||
|
"details": self.details,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SkillAdaptationHook:
|
||||||
|
"""Hook for monitoring evaluation metrics and triggering skill adaptation.
|
||||||
|
|
||||||
|
This hook wraps EvaluationHook to add threshold-based adaptation logic.
|
||||||
|
When metrics breach configured thresholds, it can:
|
||||||
|
- Automatically reload skills via SkillsManager
|
||||||
|
- Log warnings for human review
|
||||||
|
- Both
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Default thresholds for common metrics
|
||||||
|
DEFAULT_THRESHOLDS: List[AdaptationThreshold] = [
|
||||||
|
AdaptationThreshold(
|
||||||
|
metric_type=MetricType.HIT_RATE,
|
||||||
|
operator="lt",
|
||||||
|
value=0.5,
|
||||||
|
action=AdaptationAction.WARN,
|
||||||
|
cooldown_seconds=600,
|
||||||
|
),
|
||||||
|
AdaptationThreshold(
|
||||||
|
metric_type=MetricType.RISK_VIOLATION,
|
||||||
|
operator="gt",
|
||||||
|
value=0.1,
|
||||||
|
action=AdaptationAction.WARN,
|
||||||
|
cooldown_seconds=300,
|
||||||
|
),
|
||||||
|
AdaptationThreshold(
|
||||||
|
metric_type=MetricType.DECISION_LATENCY,
|
||||||
|
operator="gt",
|
||||||
|
value=5000, # 5 seconds
|
||||||
|
action=AdaptationAction.WARN,
|
||||||
|
cooldown_seconds=300,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
storage_dir: Path,
|
||||||
|
run_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
thresholds: Optional[List[AdaptationThreshold]] = None,
|
||||||
|
collector: Optional[EvaluationCollector] = None,
|
||||||
|
):
|
||||||
|
"""Initialize skill adaptation hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_dir: Directory to store adaptation events
|
||||||
|
run_id: Current run identifier
|
||||||
|
agent_id: Current agent identifier
|
||||||
|
thresholds: Custom threshold configurations (uses defaults if None)
|
||||||
|
collector: Optional EvaluationCollector for historical data
|
||||||
|
"""
|
||||||
|
self.storage_dir = Path(storage_dir)
|
||||||
|
self.run_id = run_id
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.thresholds = thresholds or self.DEFAULT_THRESHOLDS
|
||||||
|
self.collector = collector or EvaluationCollector(storage_dir)
|
||||||
|
|
||||||
|
# Track cooldowns to prevent rapid re-triggering
|
||||||
|
self._cooldowns: Dict[str, datetime] = {}
|
||||||
|
|
||||||
|
# Store recent metrics in memory for quick access
|
||||||
|
self._recent_metrics: Dict[str, List[float]] = {}
|
||||||
|
|
||||||
|
# Pending adaptation events
|
||||||
|
self._pending_events: List[AdaptationEvent] = []
|
||||||
|
|
||||||
|
def check_threshold(
|
||||||
|
self,
|
||||||
|
skill_name: str,
|
||||||
|
metric_type: MetricType,
|
||||||
|
current_value: float,
|
||||||
|
) -> Optional[AdaptationEvent]:
|
||||||
|
"""Check if a metric breaches any threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Name of the skill
|
||||||
|
metric_type: Type of metric
|
||||||
|
current_value: Current metric value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AdaptationEvent if threshold breached, None otherwise
|
||||||
|
"""
|
||||||
|
# Find applicable thresholds
|
||||||
|
applicable_thresholds = [
|
||||||
|
t for t in self.thresholds
|
||||||
|
if t.metric_type == metric_type
|
||||||
|
]
|
||||||
|
|
||||||
|
if not applicable_thresholds:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Check cooldown
|
||||||
|
cooldown_key = f"{skill_name}:{metric_type.value}"
|
||||||
|
now = datetime.now()
|
||||||
|
last_trigger = self._cooldowns.get(cooldown_key)
|
||||||
|
|
||||||
|
# Store current value first for avg calculation
|
||||||
|
self._store_metric(cooldown_key, current_value)
|
||||||
|
|
||||||
|
for threshold in applicable_thresholds:
|
||||||
|
if last_trigger:
|
||||||
|
elapsed = (now - last_trigger).total_seconds()
|
||||||
|
if elapsed < threshold.cooldown_seconds:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Evaluate threshold
|
||||||
|
if threshold.evaluate(current_value):
|
||||||
|
# Calculate moving average
|
||||||
|
avg_value = self._calculate_avg(skill_name, metric_type, current_value)
|
||||||
|
|
||||||
|
# Check minimum samples (allow immediate trigger if min_samples <= 1)
|
||||||
|
sample_count = len(self._recent_metrics.get(cooldown_key, []))
|
||||||
|
if threshold.min_samples > 1 and sample_count < threshold.min_samples:
|
||||||
|
# Not enough samples yet
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Trigger adaptation
|
||||||
|
event = AdaptationEvent(
|
||||||
|
timestamp=now.isoformat(),
|
||||||
|
skill_name=skill_name,
|
||||||
|
metric_type=metric_type,
|
||||||
|
threshold=threshold,
|
||||||
|
current_value=current_value,
|
||||||
|
avg_value=avg_value,
|
||||||
|
action_taken=threshold.action,
|
||||||
|
details={
|
||||||
|
"run_id": self.run_id,
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update cooldown
|
||||||
|
self._cooldowns[cooldown_key] = now
|
||||||
|
|
||||||
|
# Persist event
|
||||||
|
self._persist_event(event)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Threshold breached for {skill_name}.{metric_type.value}: "
|
||||||
|
f"current={current_value}, avg={avg_value}, action={threshold.action.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return event
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _calculate_avg(
|
||||||
|
self,
|
||||||
|
skill_name: str,
|
||||||
|
metric_type: MetricType,
|
||||||
|
current_value: float,
|
||||||
|
) -> float:
|
||||||
|
"""Calculate moving average for a metric."""
|
||||||
|
key = f"{skill_name}:{metric_type.value}"
|
||||||
|
values = self._recent_metrics.get(key, [])
|
||||||
|
if not values:
|
||||||
|
return current_value
|
||||||
|
return sum(values) / len(values)
|
||||||
|
|
||||||
|
def _store_metric(self, key: str, value: float) -> None:
|
||||||
|
"""Store metric value with sliding window."""
|
||||||
|
if key not in self._recent_metrics:
|
||||||
|
self._recent_metrics[key] = []
|
||||||
|
self._recent_metrics[key].append(value)
|
||||||
|
# Keep only last 100 values
|
||||||
|
if len(self._recent_metrics[key]) > 100:
|
||||||
|
self._recent_metrics[key] = self._recent_metrics[key][-100:]
|
||||||
|
|
||||||
|
def _persist_event(self, event: AdaptationEvent) -> None:
|
||||||
|
"""Persist adaptation event to storage."""
|
||||||
|
run_dir = self.storage_dir / self.run_id / "adaptations"
|
||||||
|
run_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
filename = f"{event.skill_name}_{event.metric_type.value}_{timestamp}.json"
|
||||||
|
filepath = run_dir / filename
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(filepath, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(event.to_dict(), f, ensure_ascii=False, indent=2)
|
||||||
|
logger.debug(f"Persisted adaptation event to: {filepath}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to persist adaptation event: {e}")
|
||||||
|
|
||||||
|
# Also add to pending list
|
||||||
|
self._pending_events.append(event)
|
||||||
|
|
||||||
|
def get_pending_warnings(self) -> List[AdaptationEvent]:
|
||||||
|
"""Get all pending warning events that need human review."""
|
||||||
|
return [
|
||||||
|
e for e in self._pending_events
|
||||||
|
if e.action_taken in (AdaptationAction.WARN, AdaptationAction.BOTH)
|
||||||
|
]
|
||||||
|
|
||||||
|
def clear_pending_warnings(self) -> None:
|
||||||
|
"""Clear pending warnings after they have been reviewed."""
|
||||||
|
self._pending_events = [
|
||||||
|
e for e in self._pending_events
|
||||||
|
if e.action_taken == AdaptationAction.RELOAD
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_recent_events(
|
||||||
|
self,
|
||||||
|
skill_name: Optional[str] = None,
|
||||||
|
metric_type: Optional[MetricType] = None,
|
||||||
|
limit: int = 50,
|
||||||
|
) -> List[AdaptationEvent]:
|
||||||
|
"""Get recent adaptation events.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_name: Optional filter by skill name
|
||||||
|
metric_type: Optional filter by metric type
|
||||||
|
limit: Maximum number of events to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of recent adaptation events
|
||||||
|
"""
|
||||||
|
events_dir = self.storage_dir / self.run_id / "adaptations"
|
||||||
|
if not events_dir.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
events = []
|
||||||
|
for eval_file in sorted(events_dir.glob("*.json"), reverse=True)[:limit]:
|
||||||
|
try:
|
||||||
|
with open(eval_file, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
event = self._parse_event(data)
|
||||||
|
if skill_name and event.skill_name != skill_name:
|
||||||
|
continue
|
||||||
|
if metric_type and event.metric_type != metric_type:
|
||||||
|
continue
|
||||||
|
events.append(event)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load adaptation event {eval_file}: {e}")
|
||||||
|
|
||||||
|
return events
|
||||||
|
|
||||||
|
def _parse_event(self, data: Dict[str, Any]) -> AdaptationEvent:
|
||||||
|
"""Parse adaptation event from JSON data."""
|
||||||
|
threshold_data = data.get("threshold", {})
|
||||||
|
metric_type = MetricType(threshold_data.get("metric_type", "custom"))
|
||||||
|
|
||||||
|
threshold = AdaptationThreshold(
|
||||||
|
metric_type=metric_type,
|
||||||
|
operator=threshold_data.get("operator", "lt"),
|
||||||
|
value=threshold_data.get("value", 0.0),
|
||||||
|
window_size=threshold_data.get("window_size", 10),
|
||||||
|
min_samples=threshold_data.get("min_samples", 5),
|
||||||
|
action=AdaptationAction(threshold_data.get("action", "warn")),
|
||||||
|
cooldown_seconds=threshold_data.get("cooldown_seconds", 300),
|
||||||
|
)
|
||||||
|
|
||||||
|
return AdaptationEvent(
|
||||||
|
timestamp=data.get("timestamp", ""),
|
||||||
|
skill_name=data.get("skill_name", ""),
|
||||||
|
metric_type=metric_type,
|
||||||
|
threshold=threshold,
|
||||||
|
current_value=data.get("current_value", 0.0),
|
||||||
|
avg_value=data.get("avg_value", 0.0),
|
||||||
|
action_taken=AdaptationAction(data.get("action_taken", "warn")),
|
||||||
|
details=data.get("details", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_threshold(self, threshold: AdaptationThreshold) -> None:
|
||||||
|
"""Add a new threshold configuration."""
|
||||||
|
self.thresholds.append(threshold)
|
||||||
|
|
||||||
|
def remove_threshold(self, metric_type: MetricType) -> None:
|
||||||
|
"""Remove all thresholds for a specific metric type."""
|
||||||
|
self.thresholds = [
|
||||||
|
t for t in self.thresholds
|
||||||
|
if t.metric_type != metric_type
|
||||||
|
]
|
||||||
|
|
||||||
|
def update_threshold(
|
||||||
|
self,
|
||||||
|
metric_type: MetricType,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""Update threshold configuration for a metric type."""
|
||||||
|
for threshold in self.thresholds:
|
||||||
|
if threshold.metric_type == metric_type:
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(threshold, key):
|
||||||
|
setattr(threshold, key, value)
|
||||||
|
|
||||||
|
def get_thresholds(self) -> List[AdaptationThreshold]:
|
||||||
|
"""Get current threshold configurations."""
|
||||||
|
return list(self.thresholds)
|
||||||
|
|
||||||
|
def is_in_cooldown(self, skill_name: str, metric_type: MetricType) -> bool:
|
||||||
|
"""Check if a skill/metric combination is in cooldown period."""
|
||||||
|
key = f"{skill_name}:{metric_type.value}"
|
||||||
|
last_trigger = self._cooldowns.get(key)
|
||||||
|
if not last_trigger:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Find the threshold for this metric type
|
||||||
|
for threshold in self.thresholds:
|
||||||
|
if threshold.metric_type == metric_type:
|
||||||
|
elapsed = (datetime.now() - last_trigger).total_seconds()
|
||||||
|
return elapsed < threshold.cooldown_seconds
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class AdaptationManager:
|
||||||
|
"""Manager for coordinating skill adaptation across multiple agents.
|
||||||
|
|
||||||
|
Provides centralized tracking of adaptation events and skill reloads.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, storage_dir: Path):
|
||||||
|
"""Initialize adaptation manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_dir: Root directory for storing adaptation data
|
||||||
|
"""
|
||||||
|
self.storage_dir = Path(storage_dir)
|
||||||
|
self._hooks: Dict[str, SkillAdaptationHook] = {}
|
||||||
|
|
||||||
|
def get_hook(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
thresholds: Optional[List[AdaptationThreshold]] = None,
|
||||||
|
) -> SkillAdaptationHook:
|
||||||
|
"""Get or create an adaptation hook for an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Run identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
thresholds: Optional custom thresholds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SkillAdaptationHook instance
|
||||||
|
"""
|
||||||
|
key = f"{run_id}:{agent_id}"
|
||||||
|
if key not in self._hooks:
|
||||||
|
self._hooks[key] = SkillAdaptationHook(
|
||||||
|
storage_dir=self.storage_dir,
|
||||||
|
run_id=run_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
thresholds=thresholds,
|
||||||
|
)
|
||||||
|
return self._hooks[key]
|
||||||
|
|
||||||
|
def get_all_pending_warnings(self) -> List[AdaptationEvent]:
|
||||||
|
"""Get all pending warnings from all hooks."""
|
||||||
|
warnings = []
|
||||||
|
for hook in self._hooks.values():
|
||||||
|
warnings.extend(hook.get_pending_warnings())
|
||||||
|
return warnings
|
||||||
|
|
||||||
|
def get_run_adaptations(self, run_id: str) -> List[AdaptationEvent]:
|
||||||
|
"""Get all adaptation events for a run."""
|
||||||
|
events = []
|
||||||
|
for hook in self._hooks.values():
|
||||||
|
if hook.run_id == run_id:
|
||||||
|
events.extend(hook.get_recent_events())
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
# Global manager instance
|
||||||
|
_adaptation_manager: Optional[AdaptationManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_adaptation_manager(storage_dir: Optional[Path] = None) -> AdaptationManager:
|
||||||
|
"""Get global adaptation manager instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_dir: Optional storage directory (required on first call)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AdaptationManager instance
|
||||||
|
"""
|
||||||
|
global _adaptation_manager
|
||||||
|
if _adaptation_manager is None:
|
||||||
|
if storage_dir is None:
|
||||||
|
raise ValueError("storage_dir required on first initialization")
|
||||||
|
_adaptation_manager = AdaptationManager(storage_dir)
|
||||||
|
return _adaptation_manager
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AdaptationAction",
|
||||||
|
"AdaptationThreshold",
|
||||||
|
"AdaptationEvent",
|
||||||
|
"SkillAdaptationHook",
|
||||||
|
"AdaptationManager",
|
||||||
|
"get_adaptation_manager",
|
||||||
|
]
|
||||||
684
backend/agents/base/tool_guard.py
Normal file
684
backend/agents/base/tool_guard.py
Normal file
@@ -0,0 +1,684 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""ToolGuardMixin - Security interception for dangerous tool calls.
|
||||||
|
|
||||||
|
Provides ``_acting`` and ``_reasoning`` overrides that intercept
|
||||||
|
sensitive tool calls before execution, implementing the deny /
|
||||||
|
guard / approve flow.
|
||||||
|
|
||||||
|
Based on CoPaw's tool_guard_mixin.py design.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Set
|
||||||
|
|
||||||
|
from agentscope.message import Msg
|
||||||
|
from backend.runtime.manager import get_global_runtime_manager
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class SeverityLevel(str, Enum):
|
||||||
|
"""Risk severity level."""
|
||||||
|
|
||||||
|
LOW = "low"
|
||||||
|
MEDIUM = "medium"
|
||||||
|
HIGH = "high"
|
||||||
|
CRITICAL = "critical"
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalStatus(str, Enum):
|
||||||
|
"""Approval lifecycle state."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
APPROVED = "approved"
|
||||||
|
DENIED = "denied"
|
||||||
|
EXPIRED = "expired"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolFindingRecord:
|
||||||
|
"""Internal representation of a guard finding."""
|
||||||
|
|
||||||
|
def __init__(self, severity: SeverityLevel, message: str, field: Optional[str] = None) -> None:
|
||||||
|
self.severity = severity
|
||||||
|
self.message = message
|
||||||
|
self.field = field
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"severity": self.severity.value,
|
||||||
|
"message": self.message,
|
||||||
|
"field": self.field,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalRecord:
|
||||||
|
"""Stores the state of an approval request."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
approval_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: Dict[str, Any],
|
||||||
|
agent_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
findings: Optional[List[ToolFindingRecord]] = None,
|
||||||
|
) -> None:
|
||||||
|
self.approval_id = approval_id
|
||||||
|
self.tool_name = tool_name
|
||||||
|
self.tool_input = tool_input
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.workspace_id = workspace_id
|
||||||
|
self.session_id = session_id
|
||||||
|
self.status = ApprovalStatus.PENDING
|
||||||
|
self.findings = findings or []
|
||||||
|
self.created_at = datetime.utcnow()
|
||||||
|
self.resolved_at: Optional[datetime] = None
|
||||||
|
self.resolved_by: Optional[str] = None
|
||||||
|
self.metadata: Dict[str, Any] = {}
|
||||||
|
self.pending_request: "ToolApprovalRequest" | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"approval_id": self.approval_id,
|
||||||
|
"status": self.status.value,
|
||||||
|
"tool_name": self.tool_name,
|
||||||
|
"tool_input": self.tool_input,
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
"workspace_id": self.workspace_id,
|
||||||
|
"session_id": self.session_id,
|
||||||
|
"findings": [f.to_dict() for f in self.findings],
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
"resolved_at": self.resolved_at.isoformat() if self.resolved_at else None,
|
||||||
|
"resolved_by": self.resolved_by,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ToolGuardStore:
|
||||||
|
"""Simple in-memory approval store for development/testing."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._records: Dict[str, ApprovalRecord] = {}
|
||||||
|
self._counter = 0
|
||||||
|
|
||||||
|
def next_id(self) -> str:
|
||||||
|
self._counter += 1
|
||||||
|
return f"approval_{self._counter:06d}"
|
||||||
|
|
||||||
|
def list(
|
||||||
|
self,
|
||||||
|
status: ApprovalStatus | None = None,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
) -> Iterable[ApprovalRecord]:
|
||||||
|
for record in self._records.values():
|
||||||
|
if status and record.status != status:
|
||||||
|
continue
|
||||||
|
if workspace_id and record.workspace_id != workspace_id:
|
||||||
|
continue
|
||||||
|
if agent_id and record.agent_id != agent_id:
|
||||||
|
continue
|
||||||
|
yield record
|
||||||
|
|
||||||
|
def get(self, approval_id: str) -> Optional[ApprovalRecord]:
|
||||||
|
return self._records.get(approval_id)
|
||||||
|
|
||||||
|
def create_pending(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: Dict[str, Any],
|
||||||
|
agent_id: str,
|
||||||
|
workspace_id: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
findings: Optional[List[ToolFindingRecord]] = None,
|
||||||
|
) -> ApprovalRecord:
|
||||||
|
record = ApprovalRecord(
|
||||||
|
approval_id=self.next_id(),
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
agent_id=agent_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
session_id=session_id,
|
||||||
|
findings=findings,
|
||||||
|
)
|
||||||
|
self._records[record.approval_id] = record
|
||||||
|
return record
|
||||||
|
|
||||||
|
def set_status(
|
||||||
|
self,
|
||||||
|
approval_id: str,
|
||||||
|
status: ApprovalStatus,
|
||||||
|
resolved_by: Optional[str] = None,
|
||||||
|
notify_request: bool = True,
|
||||||
|
) -> ApprovalRecord:
|
||||||
|
record = self._records[approval_id]
|
||||||
|
if record.status == status:
|
||||||
|
return record
|
||||||
|
|
||||||
|
record.status = status
|
||||||
|
record.resolved_at = datetime.utcnow()
|
||||||
|
record.resolved_by = resolved_by
|
||||||
|
if notify_request and record.pending_request:
|
||||||
|
if status == ApprovalStatus.APPROVED:
|
||||||
|
record.pending_request.approve()
|
||||||
|
elif status == ApprovalStatus.DENIED:
|
||||||
|
record.pending_request.deny()
|
||||||
|
return record
|
||||||
|
|
||||||
|
def cancel(self, approval_id: str) -> None:
|
||||||
|
self._records.pop(approval_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
TOOL_GUARD_STORE = ToolGuardStore()
|
||||||
|
|
||||||
|
|
||||||
|
def get_tool_guard_store() -> ToolGuardStore:
|
||||||
|
return TOOL_GUARD_STORE
|
||||||
|
|
||||||
|
|
||||||
|
# Default tools that require approval
|
||||||
|
DEFAULT_GUARDED_TOOLS: Set[str] = {
|
||||||
|
"execute_shell_command",
|
||||||
|
"write_file",
|
||||||
|
"edit_file",
|
||||||
|
"place_order",
|
||||||
|
"modify_position",
|
||||||
|
"delete_file",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Default denied tools (cannot be approved)
|
||||||
|
DEFAULT_DENIED_TOOLS: Set[str] = {
|
||||||
|
"execute_shell_command", # Shell execution is dangerous
|
||||||
|
}
|
||||||
|
|
||||||
|
# Mark for tool guard denied messages
|
||||||
|
TOOL_GUARD_DENIED_MARK = "tool_guard_denied"
|
||||||
|
|
||||||
|
|
||||||
|
def default_findings_for_tool(tool_name: str) -> List[ToolFindingRecord]:
|
||||||
|
findings: List[ToolFindingRecord] = []
|
||||||
|
if tool_name in {"execute_trade", "modify_portfolio"}:
|
||||||
|
findings.append(
|
||||||
|
ToolFindingRecord(
|
||||||
|
severity=SeverityLevel.HIGH,
|
||||||
|
message=f"Tool '{tool_name}' touches portfolio state",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return findings
|
||||||
|
|
||||||
|
|
||||||
|
class ToolApprovalRequest:
|
||||||
|
"""Represents a pending tool approval request."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
approval_id: str,
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: Dict[str, Any],
|
||||||
|
tool_call_id: str,
|
||||||
|
session_id: Optional[str] = None,
|
||||||
|
):
|
||||||
|
self.approval_id = approval_id
|
||||||
|
self.tool_name = tool_name
|
||||||
|
self.tool_input = tool_input
|
||||||
|
self.tool_call_id = tool_call_id
|
||||||
|
self.session_id = session_id
|
||||||
|
self.approved: Optional[bool] = None
|
||||||
|
self._event = asyncio.Event()
|
||||||
|
|
||||||
|
async def wait_for_approval(self, timeout: Optional[float] = None) -> bool:
|
||||||
|
"""Wait for approval decision.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Maximum time to wait in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if approved, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._event.wait(), timeout=timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return False
|
||||||
|
return self.approved is True
|
||||||
|
|
||||||
|
def approve(self) -> None:
|
||||||
|
"""Approve this request."""
|
||||||
|
self.approved = True
|
||||||
|
self._event.set()
|
||||||
|
|
||||||
|
def deny(self) -> None:
|
||||||
|
"""Deny this request."""
|
||||||
|
self.approved = False
|
||||||
|
self._event.set()
|
||||||
|
|
||||||
|
|
||||||
|
class ToolGuardMixin:
|
||||||
|
"""Mixin that adds tool-guard interception to a ReActAgent.
|
||||||
|
|
||||||
|
At runtime this class is combined with ReActAgent via MRO,
|
||||||
|
so ``super()._acting`` and ``super()._reasoning`` resolve to
|
||||||
|
the concrete agent methods.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
class MyAgent(ToolGuardMixin, ReActAgent):
|
||||||
|
def __init__(self, ...):
|
||||||
|
super().__init__(...)
|
||||||
|
self._init_tool_guard()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _init_tool_guard(
|
||||||
|
self,
|
||||||
|
guarded_tools: Optional[Set[str]] = None,
|
||||||
|
denied_tools: Optional[Set[str]] = None,
|
||||||
|
approval_timeout: float = 300.0,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize tool guard.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
guarded_tools: Set of tool names requiring approval
|
||||||
|
denied_tools: Set of tool names that are always denied
|
||||||
|
approval_timeout: Timeout for approval requests in seconds
|
||||||
|
"""
|
||||||
|
self._guarded_tools = guarded_tools or DEFAULT_GUARDED_TOOLS.copy()
|
||||||
|
self._denied_tools = denied_tools or DEFAULT_DENIED_TOOLS.copy()
|
||||||
|
self._approval_timeout = approval_timeout
|
||||||
|
self._pending_approval: Optional[ToolApprovalRequest] = None
|
||||||
|
self._approval_callback: Optional[Callable[[ToolApprovalRequest], None]] = None
|
||||||
|
self._approval_lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def set_approval_callback(
|
||||||
|
self,
|
||||||
|
callback: Callable[[ToolApprovalRequest], None],
|
||||||
|
) -> None:
|
||||||
|
"""Set callback for approval requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Function called when approval is needed
|
||||||
|
"""
|
||||||
|
self._approval_callback = callback
|
||||||
|
|
||||||
|
def _is_tool_guarded(self, tool_name: str) -> bool:
|
||||||
|
"""Check if a tool requires approval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if tool requires approval
|
||||||
|
"""
|
||||||
|
return tool_name in self._guarded_tools
|
||||||
|
|
||||||
|
def _is_tool_denied(self, tool_name: str) -> bool:
|
||||||
|
"""Check if a tool is always denied.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if tool is denied
|
||||||
|
"""
|
||||||
|
return tool_name in self._denied_tools
|
||||||
|
|
||||||
|
def _last_tool_response_is_denied(self) -> bool:
|
||||||
|
"""Check if the last message is a guard-denied tool result."""
|
||||||
|
if not hasattr(self, "memory") or not self.memory.content:
|
||||||
|
return False
|
||||||
|
|
||||||
|
msg, marks = self.memory.content[-1]
|
||||||
|
return TOOL_GUARD_DENIED_MARK in marks and msg.role == "system"
|
||||||
|
|
||||||
|
async def _cleanup_tool_guard_denied_messages(
|
||||||
|
self,
|
||||||
|
include_denial_response: bool = True,
|
||||||
|
) -> None:
|
||||||
|
"""Remove tool-guard denied messages from memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
include_denial_response: Also remove the assistant's denial explanation
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "memory"):
|
||||||
|
return
|
||||||
|
|
||||||
|
ids_to_delete: list[str] = []
|
||||||
|
last_marked_idx = -1
|
||||||
|
|
||||||
|
for i, (msg, marks) in enumerate(self.memory.content):
|
||||||
|
if TOOL_GUARD_DENIED_MARK in marks:
|
||||||
|
ids_to_delete.append(msg.id)
|
||||||
|
last_marked_idx = i
|
||||||
|
|
||||||
|
if (
|
||||||
|
include_denial_response
|
||||||
|
and last_marked_idx >= 0
|
||||||
|
and last_marked_idx + 1 < len(self.memory.content)
|
||||||
|
):
|
||||||
|
next_msg, _ = self.memory.content[last_marked_idx + 1]
|
||||||
|
if next_msg.role == "assistant":
|
||||||
|
ids_to_delete.append(next_msg.id)
|
||||||
|
|
||||||
|
if ids_to_delete:
|
||||||
|
removed = await self.memory.delete(ids_to_delete)
|
||||||
|
logger.info("Tool guard: cleaned up %d denied message(s)", removed)
|
||||||
|
|
||||||
|
async def _request_guard_approval(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_input: Dict[str, Any],
|
||||||
|
tool_call_id: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Request approval for a guarded tool call.
|
||||||
|
|
||||||
|
This method creates a ToolApprovalRequest and waits for
|
||||||
|
external approval via approve_guard_call() or deny_guard_call().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: Name of the tool
|
||||||
|
tool_input: Tool input parameters
|
||||||
|
tool_call_id: ID of the tool call
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if approved, False otherwise
|
||||||
|
"""
|
||||||
|
async with self._approval_lock:
|
||||||
|
record = TOOL_GUARD_STORE.create_pending(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
agent_id=getattr(self, "agent_id", "unknown"),
|
||||||
|
workspace_id=getattr(self, "workspace_id", "default"),
|
||||||
|
session_id=getattr(self, "session_id", None),
|
||||||
|
findings=default_findings_for_tool(tool_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
manager = get_global_runtime_manager()
|
||||||
|
if manager:
|
||||||
|
manager.register_pending_approval(
|
||||||
|
record.approval_id,
|
||||||
|
{
|
||||||
|
"tool_name": record.tool_name,
|
||||||
|
"agent_id": record.agent_id,
|
||||||
|
"workspace_id": record.workspace_id,
|
||||||
|
"session_id": record.session_id,
|
||||||
|
"tool_input": record.tool_input,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._pending_approval = ToolApprovalRequest(
|
||||||
|
approval_id=record.approval_id,
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
session_id=getattr(self, "session_id", None),
|
||||||
|
)
|
||||||
|
record.pending_request = self._pending_approval
|
||||||
|
|
||||||
|
# Notify via callback if set
|
||||||
|
if self._approval_callback:
|
||||||
|
self._approval_callback(self._pending_approval)
|
||||||
|
|
||||||
|
# Wait for approval (lock is released during wait, re-acquired after)
|
||||||
|
approval_request = self._pending_approval
|
||||||
|
|
||||||
|
# Wait for approval outside the lock to allow concurrent approval
|
||||||
|
approved = await approval_request.wait_for_approval(
|
||||||
|
timeout=self._approval_timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self._approval_lock:
|
||||||
|
if approval_request:
|
||||||
|
status = (
|
||||||
|
ApprovalStatus.APPROVED
|
||||||
|
if approval_request.approved is True
|
||||||
|
else ApprovalStatus.DENIED
|
||||||
|
if approval_request.approved is False
|
||||||
|
else ApprovalStatus.EXPIRED
|
||||||
|
)
|
||||||
|
TOOL_GUARD_STORE.set_status(
|
||||||
|
approval_request.approval_id,
|
||||||
|
status,
|
||||||
|
resolved_by="agent",
|
||||||
|
notify_request=False,
|
||||||
|
)
|
||||||
|
manager = get_global_runtime_manager()
|
||||||
|
if manager:
|
||||||
|
manager.resolve_pending_approval(
|
||||||
|
approval_request.approval_id,
|
||||||
|
resolved_by="agent",
|
||||||
|
status=status.value,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only clear if this is still the same request
|
||||||
|
if self._pending_approval is approval_request:
|
||||||
|
self._pending_approval = None
|
||||||
|
|
||||||
|
return approved
|
||||||
|
|
||||||
|
async def approve_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||||
|
"""Approve a pending guard request.
|
||||||
|
|
||||||
|
This method is called externally to approve a tool call
|
||||||
|
that is waiting for approval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: Optional request ID to verify (not yet implemented)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a request was approved, False if no pending request
|
||||||
|
"""
|
||||||
|
async with self._approval_lock:
|
||||||
|
if self._pending_approval is None:
|
||||||
|
logger.warning("No pending approval request to approve")
|
||||||
|
return False
|
||||||
|
|
||||||
|
TOOL_GUARD_STORE.set_status(
|
||||||
|
self._pending_approval.approval_id,
|
||||||
|
ApprovalStatus.APPROVED,
|
||||||
|
resolved_by="agent",
|
||||||
|
notify_request=False,
|
||||||
|
)
|
||||||
|
manager = get_global_runtime_manager()
|
||||||
|
if manager:
|
||||||
|
manager.resolve_pending_approval(
|
||||||
|
self._pending_approval.approval_id,
|
||||||
|
resolved_by="agent",
|
||||||
|
status=ApprovalStatus.APPROVED.value,
|
||||||
|
)
|
||||||
|
self._pending_approval.approve()
|
||||||
|
logger.info("Approved tool call: %s", self._pending_approval.tool_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def deny_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||||
|
"""Deny a pending guard request.
|
||||||
|
|
||||||
|
This method is called externally to deny a tool call
|
||||||
|
that is waiting for approval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_id: Optional request ID to verify (not yet implemented)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if a request was denied, False if no pending request
|
||||||
|
"""
|
||||||
|
async with self._approval_lock:
|
||||||
|
if self._pending_approval is None:
|
||||||
|
logger.warning("No pending approval request to deny")
|
||||||
|
return False
|
||||||
|
|
||||||
|
TOOL_GUARD_STORE.set_status(
|
||||||
|
self._pending_approval.approval_id,
|
||||||
|
ApprovalStatus.DENIED,
|
||||||
|
resolved_by="agent",
|
||||||
|
notify_request=False,
|
||||||
|
)
|
||||||
|
manager = get_global_runtime_manager()
|
||||||
|
if manager:
|
||||||
|
manager.resolve_pending_approval(
|
||||||
|
self._pending_approval.approval_id,
|
||||||
|
resolved_by="agent",
|
||||||
|
status=ApprovalStatus.DENIED.value,
|
||||||
|
)
|
||||||
|
self._pending_approval.deny()
|
||||||
|
logger.info("Denied tool call: %s", self._pending_approval.tool_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _acting(self, tool_call) -> dict | None:
|
||||||
|
"""Intercept sensitive tool calls before execution.
|
||||||
|
|
||||||
|
1. If tool is in denied_tools, auto-deny unconditionally.
|
||||||
|
2. Check for a one-shot pre-approval.
|
||||||
|
3. If tool is in the guarded scope, request approval.
|
||||||
|
4. Otherwise, delegate to parent _acting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_call: Tool call from the model
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tool result dict or None
|
||||||
|
"""
|
||||||
|
tool_name: str = tool_call.get("name", "")
|
||||||
|
tool_input: dict = tool_call.get("input", {})
|
||||||
|
tool_call_id: str = tool_call.get("id", "")
|
||||||
|
|
||||||
|
# Check if tool is denied
|
||||||
|
if tool_name and self._is_tool_denied(tool_name):
|
||||||
|
logger.warning("Tool '%s' is in the denied set, auto-denying", tool_name)
|
||||||
|
return await self._acting_auto_denied(tool_call, tool_name)
|
||||||
|
|
||||||
|
# Check if tool is guarded
|
||||||
|
if tool_name and self._is_tool_guarded(tool_name):
|
||||||
|
approved = await self._request_guard_approval(
|
||||||
|
tool_name=tool_name,
|
||||||
|
tool_input=tool_input,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not approved:
|
||||||
|
return await self._acting_with_denial(tool_call, tool_name)
|
||||||
|
|
||||||
|
# Call parent _acting
|
||||||
|
return await super()._acting(tool_call) # type: ignore[misc]
|
||||||
|
|
||||||
|
async def _acting_auto_denied(
|
||||||
|
self,
|
||||||
|
tool_call: Dict[str, Any],
|
||||||
|
tool_name: str,
|
||||||
|
) -> dict | None:
|
||||||
|
"""Auto-deny a tool call without offering approval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_call: Tool call from the model
|
||||||
|
tool_name: Name of the denied tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Denial result
|
||||||
|
"""
|
||||||
|
from agentscope.message import ToolResultBlock
|
||||||
|
|
||||||
|
denied_text = (
|
||||||
|
f"⛔ **Tool Blocked / 工具已拦截**\n\n"
|
||||||
|
f"- Tool / 工具: `{tool_name}`\n"
|
||||||
|
f"- Reason / 原因: This tool is blocked for security reasons\n\n"
|
||||||
|
f"This tool is blocked and cannot be approved.\n"
|
||||||
|
f"该工具已被禁止,无法批准执行。"
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_res_msg = Msg(
|
||||||
|
"system",
|
||||||
|
[
|
||||||
|
ToolResultBlock(
|
||||||
|
type="tool_result",
|
||||||
|
id=tool_call.get("id", ""),
|
||||||
|
name=tool_name,
|
||||||
|
output=[{"type": "text", "text": denied_text}],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"system",
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.print(tool_res_msg, True)
|
||||||
|
await self.memory.add(tool_res_msg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _acting_with_denial(
|
||||||
|
self,
|
||||||
|
tool_call: Dict[str, Any],
|
||||||
|
tool_name: str,
|
||||||
|
) -> dict | None:
|
||||||
|
"""Deny the tool call after approval was rejected.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_call: Tool call from the model
|
||||||
|
tool_name: Name of the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Denial result
|
||||||
|
"""
|
||||||
|
from agentscope.message import ToolResultBlock
|
||||||
|
|
||||||
|
params_text = json.dumps(
|
||||||
|
tool_call.get("input", {}),
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
denied_text = (
|
||||||
|
f"⚠️ **Tool Call Denied / 工具调用被拒绝**\n\n"
|
||||||
|
f"- Tool / 工具: `{tool_name}`\n"
|
||||||
|
f"- Parameters / 参数:\n"
|
||||||
|
f"```json\n{params_text}\n```\n\n"
|
||||||
|
f"The tool call was denied by the user or timed out.\n"
|
||||||
|
f"工具调用被用户拒绝或已超时。"
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_res_msg = Msg(
|
||||||
|
"system",
|
||||||
|
[
|
||||||
|
ToolResultBlock(
|
||||||
|
type="tool_result",
|
||||||
|
id=tool_call.get("id", ""),
|
||||||
|
name=tool_name,
|
||||||
|
output=[{"type": "text", "text": denied_text}],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
"system",
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.print(tool_res_msg, True)
|
||||||
|
await self.memory.add(tool_res_msg, marks=TOOL_GUARD_DENIED_MARK)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _reasoning(self, **kwargs) -> Msg:
|
||||||
|
"""Short-circuit reasoning when awaiting guard approval.
|
||||||
|
|
||||||
|
If the last message was a guard denial, return a waiting message
|
||||||
|
instead of continuing reasoning.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response message
|
||||||
|
"""
|
||||||
|
if self._last_tool_response_is_denied():
|
||||||
|
msg = Msg(
|
||||||
|
self.name,
|
||||||
|
"⏳ Waiting for approval / 等待审批...\n\n"
|
||||||
|
"Type `/approve` to approve, or send any message to deny.\n"
|
||||||
|
"输入 `/approve` 批准执行,或发送任意消息拒绝。",
|
||||||
|
"assistant",
|
||||||
|
)
|
||||||
|
await self.print(msg, True)
|
||||||
|
await self.memory.add(msg)
|
||||||
|
return msg
|
||||||
|
|
||||||
|
return await super()._reasoning(**kwargs) # type: ignore[misc]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ToolGuardMixin",
|
||||||
|
"ToolApprovalRequest",
|
||||||
|
"DEFAULT_GUARDED_TOOLS",
|
||||||
|
"DEFAULT_DENIED_TOOLS",
|
||||||
|
"TOOL_GUARD_DENIED_MARK",
|
||||||
|
]
|
||||||
146
backend/agents/compat.py
Normal file
146
backend/agents/compat.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Compatibility Layer - Adapters for legacy to EvoAgent migration.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- LegacyAgentAdapter: Wraps old AnalystAgent to work with new interfaces
|
||||||
|
- Migration utilities for gradual adoption
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from agentscope.message import Msg
|
||||||
|
|
||||||
|
from .agent_core import EvoAgent
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyAgentAdapter:
|
||||||
|
"""
|
||||||
|
Adapter to make legacy AnalystAgent compatible with EvoAgent interfaces.
|
||||||
|
|
||||||
|
This allows gradual migration by wrapping existing agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, legacy_agent: Any):
|
||||||
|
"""
|
||||||
|
Initialize adapter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
legacy_agent: Legacy AnalystAgent instance
|
||||||
|
"""
|
||||||
|
self._agent = legacy_agent
|
||||||
|
self.agent_id = getattr(legacy_agent, 'agent_id', getattr(legacy_agent, 'name', 'unknown'))
|
||||||
|
self.analyst_type = getattr(legacy_agent, 'analyst_type_key', None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Get agent name."""
|
||||||
|
return getattr(self._agent, 'name', self.agent_id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def toolkit(self) -> Any:
|
||||||
|
"""Get agent toolkit."""
|
||||||
|
return getattr(self._agent, 'toolkit', None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self) -> Any:
|
||||||
|
"""Get agent model."""
|
||||||
|
return getattr(self._agent, 'model', None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def memory(self) -> Any:
|
||||||
|
"""Get agent memory."""
|
||||||
|
return getattr(self._agent, 'memory', None)
|
||||||
|
|
||||||
|
async def reply(self, x: Msg = None) -> Msg:
|
||||||
|
"""
|
||||||
|
Delegate to legacy agent's reply method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Input message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response message
|
||||||
|
"""
|
||||||
|
return await self._agent.reply(x)
|
||||||
|
|
||||||
|
def reload_runtime_assets(self, active_skill_dirs: Optional[list] = None) -> None:
|
||||||
|
"""
|
||||||
|
Reload runtime assets if supported.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
active_skill_dirs: Optional list of active skill directories
|
||||||
|
"""
|
||||||
|
if hasattr(self._agent, 'reload_runtime_assets'):
|
||||||
|
self._agent.reload_runtime_assets(active_skill_dirs)
|
||||||
|
|
||||||
|
def to_evo_agent(
|
||||||
|
self,
|
||||||
|
workspace_manager: Optional[Any] = None,
|
||||||
|
enable_tool_guard: bool = False,
|
||||||
|
) -> EvoAgent:
|
||||||
|
"""
|
||||||
|
Convert legacy agent to EvoAgent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_manager: Optional workspace manager
|
||||||
|
enable_tool_guard: Whether to enable tool guard
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
New EvoAgent instance with same configuration
|
||||||
|
"""
|
||||||
|
return EvoAgent(
|
||||||
|
agent_id=self.agent_id,
|
||||||
|
model=self.model,
|
||||||
|
formatter=getattr(self._agent, 'formatter', None),
|
||||||
|
toolkit=self.toolkit,
|
||||||
|
workspace_manager=workspace_manager,
|
||||||
|
config=getattr(self._agent, 'config', {}),
|
||||||
|
long_term_memory=getattr(self._agent, 'long_term_memory', None),
|
||||||
|
enable_tool_guard=enable_tool_guard,
|
||||||
|
sys_prompt=getattr(self._agent, '_sys_prompt', None),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
"""Delegate unknown attributes to wrapped agent."""
|
||||||
|
return getattr(self._agent, name)
|
||||||
|
|
||||||
|
|
||||||
|
def is_legacy_agent(agent: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an agent is a legacy agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: Agent instance to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if legacy agent
|
||||||
|
"""
|
||||||
|
return hasattr(agent, 'analyst_type_key') and not isinstance(agent, EvoAgent)
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_agent(agent: Any) -> Any:
|
||||||
|
"""
|
||||||
|
Wrap agent in adapter if it's a legacy agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: Agent instance
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adapted agent or original if already EvoAgent
|
||||||
|
"""
|
||||||
|
if is_legacy_agent(agent):
|
||||||
|
return LegacyAgentAdapter(agent)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
def adapt_agents(agents: list) -> list:
|
||||||
|
"""
|
||||||
|
Wrap multiple agents in adapters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agents: List of agent instances
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of adapted agents
|
||||||
|
"""
|
||||||
|
return [adapt_agent(agent) for agent in agents]
|
||||||
497
backend/agents/factory.py
Normal file
497
backend/agents/factory.py
Normal file
@@ -0,0 +1,497 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Agent Factory - Dynamic creation and management of AgentConfigs."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelConfig:
|
||||||
|
"""Model configuration for an agent."""
|
||||||
|
|
||||||
|
model_name: str = "gpt-4o"
|
||||||
|
temperature: float = 0.7
|
||||||
|
max_tokens: int = 4096
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RoleConfig:
|
||||||
|
"""Role configuration for an agent."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
focus_areas: List[str] = None
|
||||||
|
constraints: List[str] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.focus_areas is None:
|
||||||
|
self.focus_areas = []
|
||||||
|
if self.constraints is None:
|
||||||
|
self.constraints = []
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig:
|
||||||
|
"""Represents a configured agent instance (data class)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
agent_type: str,
|
||||||
|
workspace_id: str,
|
||||||
|
config_path: Path,
|
||||||
|
model_config: Optional[ModelConfig] = None,
|
||||||
|
role_config: Optional[RoleConfig] = None,
|
||||||
|
):
|
||||||
|
self.agent_id = agent_id
|
||||||
|
self.agent_type = agent_type
|
||||||
|
self.workspace_id = workspace_id
|
||||||
|
self.config_path = config_path
|
||||||
|
self.model_config = model_config or ModelConfig()
|
||||||
|
self.role_config = role_config
|
||||||
|
self.agent_dir = config_path.parent
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Serialize agent to dictionary."""
|
||||||
|
return {
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
"agent_type": self.agent_type,
|
||||||
|
"workspace_id": self.workspace_id,
|
||||||
|
"config_path": str(self.config_path),
|
||||||
|
"agent_dir": str(self.agent_dir),
|
||||||
|
"model_config": {
|
||||||
|
"model_name": self.model_config.model_name,
|
||||||
|
"temperature": self.model_config.temperature,
|
||||||
|
"max_tokens": self.model_config.max_tokens,
|
||||||
|
},
|
||||||
|
"role_config": self.role_config.__dict__ if self.role_config else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFactory:
|
||||||
|
"""Factory for creating, cloning, and managing agents."""
|
||||||
|
|
||||||
|
# Default role templates by agent type
|
||||||
|
ROLE_TEMPLATES = {
|
||||||
|
"technical_analyst": {
|
||||||
|
"name": "Technical Analyst",
|
||||||
|
"description": "Analyze price patterns, trends, and technical indicators.",
|
||||||
|
"focus_areas": [
|
||||||
|
"Price action and chart patterns",
|
||||||
|
"Support and resistance levels",
|
||||||
|
"Technical indicators (RSI, MACD, Moving Averages)",
|
||||||
|
"Volume analysis",
|
||||||
|
],
|
||||||
|
"constraints": [
|
||||||
|
"State clear signal, confidence, and invalidation conditions",
|
||||||
|
"Use available technical analysis tools",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"fundamentals_analyst": {
|
||||||
|
"name": "Fundamentals Analyst",
|
||||||
|
"description": "Analyze company financials, earnings, and business metrics.",
|
||||||
|
"focus_areas": [
|
||||||
|
"Financial statements analysis",
|
||||||
|
"Earnings reports and guidance",
|
||||||
|
"Valuation metrics",
|
||||||
|
"Business model and competitive position",
|
||||||
|
],
|
||||||
|
"constraints": [
|
||||||
|
"State clear signal, confidence, and invalidation conditions",
|
||||||
|
"Use available fundamental analysis tools",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"sentiment_analyst": {
|
||||||
|
"name": "Sentiment Analyst",
|
||||||
|
"description": "Analyze market sentiment, news, and social signals.",
|
||||||
|
"focus_areas": [
|
||||||
|
"News sentiment analysis",
|
||||||
|
"Social media sentiment",
|
||||||
|
"Analyst ratings and price targets",
|
||||||
|
"Insider activity",
|
||||||
|
],
|
||||||
|
"constraints": [
|
||||||
|
"State clear signal, confidence, and invalidation conditions",
|
||||||
|
"Use available sentiment analysis tools",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"valuation_analyst": {
|
||||||
|
"name": "Valuation Analyst",
|
||||||
|
"description": "Perform valuation analysis and price target calculations.",
|
||||||
|
"focus_areas": [
|
||||||
|
"DCF and comparable valuation",
|
||||||
|
"Price target derivation",
|
||||||
|
"Margin of safety assessment",
|
||||||
|
"Risk-adjusted return expectations",
|
||||||
|
],
|
||||||
|
"constraints": [
|
||||||
|
"State clear signal, confidence, and invalidation conditions",
|
||||||
|
"Use available valuation tools",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"risk_manager": {
|
||||||
|
"name": "Risk Manager",
|
||||||
|
"description": "Quantify concentration, leverage, liquidity, and volatility risk.",
|
||||||
|
"focus_areas": [
|
||||||
|
"Portfolio concentration risk",
|
||||||
|
"Leverage and margin analysis",
|
||||||
|
"Liquidity assessment",
|
||||||
|
"Volatility and drawdown risk",
|
||||||
|
],
|
||||||
|
"constraints": [
|
||||||
|
"Prioritize highest-severity risk first",
|
||||||
|
"State concrete limits and recommendations",
|
||||||
|
"Use available risk tools before issuing final memo",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"portfolio_manager": {
|
||||||
|
"name": "Portfolio Manager",
|
||||||
|
"description": "Synthesize analyst and risk inputs into portfolio decisions.",
|
||||||
|
"focus_areas": [
|
||||||
|
"Position sizing and allocation",
|
||||||
|
"Risk-adjusted portfolio construction",
|
||||||
|
"Trade execution timing",
|
||||||
|
"Portfolio rebalancing",
|
||||||
|
],
|
||||||
|
"constraints": [
|
||||||
|
"Be concise, capital-aware, and explicit about sizing rationale",
|
||||||
|
"Respect cash, margin, and concentration constraints",
|
||||||
|
"Consider all analyst inputs before decisions",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, project_root: Optional[Path] = None):
|
||||||
|
"""Initialize the agent factory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_root: Root directory of the project
|
||||||
|
"""
|
||||||
|
self.project_root = project_root or Path(__file__).parent.parent.parent
|
||||||
|
self.workspaces_root = self.project_root / "workspaces"
|
||||||
|
self.template_dir = self.project_root / "backend" / "workspaces" / ".template"
|
||||||
|
|
||||||
|
def create_agent(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
agent_type: str,
|
||||||
|
workspace_id: str,
|
||||||
|
model_config: Optional[ModelConfig] = None,
|
||||||
|
role_config: Optional[RoleConfig] = None,
|
||||||
|
clone_from: Optional[str] = None,
|
||||||
|
) -> AgentConfig:
|
||||||
|
"""Create a new agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Unique identifier for the agent
|
||||||
|
agent_type: Type of agent (e.g., "technical_analyst")
|
||||||
|
workspace_id: ID of the workspace to create agent in
|
||||||
|
model_config: Model configuration
|
||||||
|
role_config: Role configuration (auto-generated if None)
|
||||||
|
clone_from: Path to existing agent to clone from (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentConfig instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If agent already exists or workspace doesn't exist
|
||||||
|
"""
|
||||||
|
workspace_dir = self.workspaces_root / workspace_id
|
||||||
|
if not workspace_dir.exists():
|
||||||
|
raise ValueError(f"Workspace '{workspace_id}' does not exist")
|
||||||
|
|
||||||
|
agent_dir = workspace_dir / "agents" / agent_id
|
||||||
|
if agent_dir.exists():
|
||||||
|
raise ValueError(f"Agent '{agent_id}' already exists in workspace '{workspace_id}'")
|
||||||
|
|
||||||
|
# Create directory structure
|
||||||
|
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(agent_dir / "skills" / "active").mkdir(parents=True, exist_ok=True)
|
||||||
|
(agent_dir / "skills" / "local").mkdir(parents=True, exist_ok=True)
|
||||||
|
(agent_dir / "skills" / "installed").mkdir(parents=True, exist_ok=True)
|
||||||
|
(agent_dir / "skills" / "disabled").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Copy template or clone existing agent
|
||||||
|
if clone_from:
|
||||||
|
self._clone_agent_files(clone_from, agent_dir, agent_id)
|
||||||
|
else:
|
||||||
|
self._copy_template(agent_dir, agent_id, agent_type)
|
||||||
|
|
||||||
|
# Generate role config if not provided
|
||||||
|
if role_config is None:
|
||||||
|
role_config = self._generate_role_config(agent_type)
|
||||||
|
|
||||||
|
# Generate ROLE.md
|
||||||
|
self._generate_role_md(agent_dir, role_config)
|
||||||
|
|
||||||
|
# Write agent.yaml
|
||||||
|
config_path = agent_dir / "agent.yaml"
|
||||||
|
self._write_agent_yaml(config_path, agent_id, agent_type, model_config)
|
||||||
|
|
||||||
|
return AgentConfig(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
config_path=config_path,
|
||||||
|
model_config=model_config,
|
||||||
|
role_config=role_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_agent(self, agent_id: str, workspace_id: str) -> bool:
|
||||||
|
"""Delete an agent and its workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: ID of the agent to delete
|
||||||
|
workspace_id: ID of the workspace containing the agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if agent didn't exist
|
||||||
|
"""
|
||||||
|
agent_dir = self.workspaces_root / workspace_id / "agents" / agent_id
|
||||||
|
if not agent_dir.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
shutil.rmtree(agent_dir)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def clone_agent(
|
||||||
|
self,
|
||||||
|
source_agent_id: str,
|
||||||
|
source_workspace_id: str,
|
||||||
|
new_agent_id: str,
|
||||||
|
target_workspace_id: Optional[str] = None,
|
||||||
|
model_config: Optional[ModelConfig] = None,
|
||||||
|
) -> AgentConfig:
|
||||||
|
"""Clone an existing agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_agent_id: ID of the agent to clone
|
||||||
|
source_workspace_id: Workspace containing the source agent
|
||||||
|
new_agent_id: ID for the new agent
|
||||||
|
target_workspace_id: Target workspace (defaults to source workspace)
|
||||||
|
model_config: Optional new model configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentConfig instance for the cloned agent
|
||||||
|
"""
|
||||||
|
target_workspace_id = target_workspace_id or source_workspace_id
|
||||||
|
source_dir = self.workspaces_root / source_workspace_id / "agents" / source_agent_id
|
||||||
|
|
||||||
|
if not source_dir.exists():
|
||||||
|
raise ValueError(f"Source agent '{source_agent_id}' not found")
|
||||||
|
|
||||||
|
# Load source agent config
|
||||||
|
source_config_path = source_dir / "agent.yaml"
|
||||||
|
source_config = {}
|
||||||
|
if source_config_path.exists():
|
||||||
|
with open(source_config_path, "r", encoding="utf-8") as f:
|
||||||
|
source_config = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
agent_type = source_config.get("agent_type", "generic")
|
||||||
|
|
||||||
|
# Determine source path for cloning
|
||||||
|
clone_from = str(source_dir)
|
||||||
|
|
||||||
|
return self.create_agent(
|
||||||
|
agent_id=new_agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
workspace_id=target_workspace_id,
|
||||||
|
model_config=model_config,
|
||||||
|
clone_from=clone_from,
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_agents(self, workspace_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
|
"""List all agents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Optional workspace to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of agent information dictionaries
|
||||||
|
"""
|
||||||
|
agents = []
|
||||||
|
|
||||||
|
if workspace_id:
|
||||||
|
workspaces = [self.workspaces_root / workspace_id]
|
||||||
|
else:
|
||||||
|
if not self.workspaces_root.exists():
|
||||||
|
return agents
|
||||||
|
workspaces = [d for d in self.workspaces_root.iterdir() if d.is_dir()]
|
||||||
|
|
||||||
|
for workspace in workspaces:
|
||||||
|
agents_dir = workspace / "agents"
|
||||||
|
if not agents_dir.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
for agent_dir in agents_dir.iterdir():
|
||||||
|
if not agent_dir.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
config_path = agent_dir / "agent.yaml"
|
||||||
|
if config_path.exists():
|
||||||
|
try:
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
agents.append({
|
||||||
|
"agent_id": agent_dir.name,
|
||||||
|
"workspace_id": workspace.name,
|
||||||
|
"agent_type": config.get("agent_type", "unknown"),
|
||||||
|
"config_path": str(config_path),
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load agent config {config_path}: {e}")
|
||||||
|
|
||||||
|
return agents
|
||||||
|
|
||||||
|
def _copy_template(
|
||||||
|
self,
|
||||||
|
agent_dir: Path,
|
||||||
|
agent_id: str,
|
||||||
|
agent_type: str,
|
||||||
|
) -> None:
|
||||||
|
"""Copy template files to agent directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_dir: Target agent directory
|
||||||
|
agent_id: ID of the agent
|
||||||
|
agent_type: Type of the agent
|
||||||
|
"""
|
||||||
|
# Create default markdown files
|
||||||
|
default_files = {
|
||||||
|
"AGENTS.md": f"# Agent Guide\n\nDocument how {agent_id} should work, collaborate, and choose tools or skills.\n\n",
|
||||||
|
"SOUL.md": f"# Soul\n\nDescribe {agent_id}'s temperament, reasoning posture, and voice.\n\n",
|
||||||
|
"PROFILE.md": f"# Profile\n\nTrack {agent_id}'s long-lived investment style, preferences, and strengths.\n\n",
|
||||||
|
"MEMORY.md": f"# Memory\n\nStore durable lessons, heuristics, and reminders for {agent_id}.\n\n",
|
||||||
|
"HEARTBEAT.md": f"# Heartbeat\n\nOptional checklist for periodic review or self-reflection.\n\n",
|
||||||
|
"POLICY.md": f"# Policy\n\nOptional run-scoped constraints, limits, or strategy policy.\n\n",
|
||||||
|
"STYLE.md": f"# Style\n\nOptional run-scoped communication or reasoning style.\n\n",
|
||||||
|
}
|
||||||
|
|
||||||
|
for filename, content in default_files.items():
|
||||||
|
filepath = agent_dir / filename
|
||||||
|
if not filepath.exists():
|
||||||
|
filepath.write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
|
def _clone_agent_files(self, source_path: str, target_dir: Path, new_agent_id: str) -> None:
|
||||||
|
"""Clone files from an existing agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_path: Path to source agent directory
|
||||||
|
target_dir: Target agent directory
|
||||||
|
new_agent_id: ID for the new agent
|
||||||
|
"""
|
||||||
|
source_dir = Path(source_path)
|
||||||
|
if not source_dir.exists():
|
||||||
|
raise ValueError(f"Source path '{source_path}' does not exist")
|
||||||
|
|
||||||
|
# Copy markdown files
|
||||||
|
for md_file in source_dir.glob("*.md"):
|
||||||
|
target_file = target_dir / md_file.name
|
||||||
|
content = md_file.read_text(encoding="utf-8")
|
||||||
|
# Update agent references in content
|
||||||
|
source_name = source_dir.name
|
||||||
|
content = content.replace(source_name, new_agent_id)
|
||||||
|
target_file.write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
|
# Copy skills directory structure (but not contents)
|
||||||
|
for skill_subdir in ["active", "local", "installed", "disabled"]:
|
||||||
|
source_skills = source_dir / "skills" / skill_subdir
|
||||||
|
if source_skills.exists():
|
||||||
|
target_skills = target_dir / "skills" / skill_subdir
|
||||||
|
target_skills.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Copy skill files
|
||||||
|
for skill_file in source_skills.iterdir():
|
||||||
|
if skill_file.is_file():
|
||||||
|
shutil.copy2(skill_file, target_skills / skill_file.name)
|
||||||
|
|
||||||
|
def _generate_role_config(self, agent_type: str) -> RoleConfig:
|
||||||
|
"""Generate role configuration for an agent type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_type: Type of agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RoleConfig instance
|
||||||
|
"""
|
||||||
|
template = self.ROLE_TEMPLATES.get(agent_type, {})
|
||||||
|
return RoleConfig(
|
||||||
|
name=template.get("name", agent_type.replace("_", " ").title()),
|
||||||
|
description=template.get("description", ""),
|
||||||
|
focus_areas=template.get("focus_areas", []),
|
||||||
|
constraints=template.get("constraints", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_role_md(self, agent_dir: Path, role_config: RoleConfig) -> None:
|
||||||
|
"""Generate ROLE.md file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_dir: Agent directory
|
||||||
|
role_config: Role configuration
|
||||||
|
"""
|
||||||
|
lines = [f"# {role_config.name}", ""]
|
||||||
|
|
||||||
|
if role_config.description:
|
||||||
|
lines.extend([role_config.description, ""])
|
||||||
|
|
||||||
|
if role_config.focus_areas:
|
||||||
|
lines.extend(["## Focus Areas", ""])
|
||||||
|
for area in role_config.focus_areas:
|
||||||
|
lines.append(f"- {area}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
if role_config.constraints:
|
||||||
|
lines.extend(["## Constraints", ""])
|
||||||
|
for constraint in role_config.constraints:
|
||||||
|
lines.append(f"- {constraint}")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
content = "\n".join(lines)
|
||||||
|
(agent_dir / "ROLE.md").write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
|
def _write_agent_yaml(
|
||||||
|
self,
|
||||||
|
config_path: Path,
|
||||||
|
agent_id: str,
|
||||||
|
agent_type: str,
|
||||||
|
model_config: Optional[ModelConfig] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Write agent.yaml configuration file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to write configuration
|
||||||
|
agent_id: Agent ID
|
||||||
|
agent_type: Agent type
|
||||||
|
model_config: Optional model configuration
|
||||||
|
"""
|
||||||
|
config = {
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"agent_type": agent_type,
|
||||||
|
"prompt_files": [
|
||||||
|
"SOUL.md",
|
||||||
|
"PROFILE.md",
|
||||||
|
"AGENTS.md",
|
||||||
|
"POLICY.md",
|
||||||
|
"MEMORY.md",
|
||||||
|
],
|
||||||
|
"enabled_skills": [],
|
||||||
|
"disabled_skills": [],
|
||||||
|
"active_tool_groups": [],
|
||||||
|
"disabled_tool_groups": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if model_config:
|
||||||
|
config["model"] = {
|
||||||
|
"name": model_config.model_name,
|
||||||
|
"temperature": model_config.temperature,
|
||||||
|
"max_tokens": model_config.max_tokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
yaml.safe_dump(config, f, allow_unicode=True, sort_keys=False)
|
||||||
@@ -4,7 +4,8 @@ Portfolio Manager Agent - Based on AgentScope ReActAgent
|
|||||||
Responsible for decision-making (NOT trade execution)
|
Responsible for decision-making (NOT trade execution)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Callable
|
||||||
|
|
||||||
from agentscope.agent import ReActAgent
|
from agentscope.agent import ReActAgent
|
||||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||||
@@ -13,6 +14,8 @@ from agentscope.tool import Toolkit, ToolResponse
|
|||||||
|
|
||||||
from ..utils.progress import progress
|
from ..utils.progress import progress
|
||||||
from .prompt_factory import build_agent_system_prompt, clear_prompt_factory_cache
|
from .prompt_factory import build_agent_system_prompt, clear_prompt_factory_cache
|
||||||
|
from .team_pipeline_config import update_active_analysts
|
||||||
|
from ..config.constants import ANALYST_TYPES
|
||||||
|
|
||||||
|
|
||||||
class PMAgent(ReActAgent):
|
class PMAgent(ReActAgent):
|
||||||
@@ -38,21 +41,31 @@ class PMAgent(ReActAgent):
|
|||||||
toolkit_factory_kwargs: Optional[Dict[str, Any]] = None,
|
toolkit_factory_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
toolkit: Optional[Toolkit] = None,
|
toolkit: Optional[Toolkit] = None,
|
||||||
):
|
):
|
||||||
self.config = config or {}
|
object.__setattr__(self, "config", config or {})
|
||||||
|
|
||||||
# Portfolio state
|
# Portfolio state
|
||||||
self.portfolio = {
|
object.__setattr__(
|
||||||
|
self,
|
||||||
|
"portfolio",
|
||||||
|
{
|
||||||
"cash": initial_cash,
|
"cash": initial_cash,
|
||||||
"positions": {},
|
"positions": {},
|
||||||
"margin_used": 0.0,
|
"margin_used": 0.0,
|
||||||
"margin_requirement": margin_requirement,
|
"margin_requirement": margin_requirement,
|
||||||
}
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Decisions made in current cycle
|
# Decisions made in current cycle
|
||||||
self._decisions: Dict[str, Dict] = {}
|
object.__setattr__(self, "_decisions", {})
|
||||||
toolkit_factory_kwargs = toolkit_factory_kwargs or {}
|
toolkit_factory_kwargs = toolkit_factory_kwargs or {}
|
||||||
self._toolkit_factory = toolkit_factory
|
object.__setattr__(self, "_toolkit_factory", toolkit_factory)
|
||||||
self._toolkit_factory_kwargs = toolkit_factory_kwargs
|
object.__setattr__(
|
||||||
|
self,
|
||||||
|
"_toolkit_factory_kwargs",
|
||||||
|
toolkit_factory_kwargs,
|
||||||
|
)
|
||||||
|
object.__setattr__(self, "_create_team_agent_cb", None)
|
||||||
|
object.__setattr__(self, "_remove_team_agent_cb", None)
|
||||||
|
|
||||||
# Create toolkit after local state is ready so bound tool methods can be registered.
|
# Create toolkit after local state is ready so bound tool methods can be registered.
|
||||||
if toolkit is None:
|
if toolkit is None:
|
||||||
@@ -65,7 +78,7 @@ class PMAgent(ReActAgent):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
toolkit = self._create_toolkit()
|
toolkit = self._create_toolkit()
|
||||||
self.toolkit = toolkit
|
object.__setattr__(self, "toolkit", toolkit)
|
||||||
|
|
||||||
sys_prompt = build_agent_system_prompt(
|
sys_prompt = build_agent_system_prompt(
|
||||||
agent_id=name,
|
agent_id=name,
|
||||||
@@ -144,6 +157,107 @@ class PMAgent(ReActAgent):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _add_team_analyst(self, agent_id: str) -> ToolResponse:
|
||||||
|
"""Add one analyst to active discussion team."""
|
||||||
|
config_name = self.config.get("config_name", "default")
|
||||||
|
project_root = Path(__file__).resolve().parents[2]
|
||||||
|
active = update_active_analysts(
|
||||||
|
project_root=project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
available_analysts=list(ANALYST_TYPES.keys()),
|
||||||
|
add=[agent_id],
|
||||||
|
)
|
||||||
|
return ToolResponse(
|
||||||
|
content=[
|
||||||
|
TextBlock(
|
||||||
|
type="text",
|
||||||
|
text=(
|
||||||
|
f"Active analyst team updated. Added: {agent_id}. "
|
||||||
|
f"Current active analysts: {', '.join(active)}"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _remove_team_analyst(self, agent_id: str) -> ToolResponse:
|
||||||
|
"""Remove one analyst from active discussion team."""
|
||||||
|
callback_msg = ""
|
||||||
|
callback = self._remove_team_agent_cb
|
||||||
|
if callback is not None:
|
||||||
|
callback_msg = callback(agent_id=agent_id)
|
||||||
|
|
||||||
|
config_name = self.config.get("config_name", "default")
|
||||||
|
project_root = Path(__file__).resolve().parents[2]
|
||||||
|
active = update_active_analysts(
|
||||||
|
project_root=project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
available_analysts=list(ANALYST_TYPES.keys()),
|
||||||
|
remove=[agent_id],
|
||||||
|
)
|
||||||
|
return ToolResponse(
|
||||||
|
content=[
|
||||||
|
TextBlock(
|
||||||
|
type="text",
|
||||||
|
text=(
|
||||||
|
f"Active analyst team updated. Removed: {agent_id}. "
|
||||||
|
f"Current active analysts: {', '.join(active)}"
|
||||||
|
+ (f" | {callback_msg}" if callback_msg else "")
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_active_analysts(self, agent_ids: str) -> ToolResponse:
|
||||||
|
"""Set active analysts from comma-separated agent ids."""
|
||||||
|
requested = [
|
||||||
|
item.strip() for item in str(agent_ids or "").split(",") if item.strip()
|
||||||
|
]
|
||||||
|
config_name = self.config.get("config_name", "default")
|
||||||
|
project_root = Path(__file__).resolve().parents[2]
|
||||||
|
active = update_active_analysts(
|
||||||
|
project_root=project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
available_analysts=list(ANALYST_TYPES.keys()),
|
||||||
|
set_to=requested,
|
||||||
|
)
|
||||||
|
return ToolResponse(
|
||||||
|
content=[
|
||||||
|
TextBlock(
|
||||||
|
type="text",
|
||||||
|
text=f"Active analyst team set to: {', '.join(active)}",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_team_analyst(self, agent_id: str, analyst_type: str) -> ToolResponse:
|
||||||
|
"""Create a runtime analyst instance and activate it."""
|
||||||
|
callback = self._create_team_agent_cb
|
||||||
|
if callback is None:
|
||||||
|
return ToolResponse(
|
||||||
|
content=[
|
||||||
|
TextBlock(
|
||||||
|
type="text",
|
||||||
|
text="Runtime agent creation is not available in current pipeline.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
result = callback(agent_id=agent_id, analyst_type=analyst_type)
|
||||||
|
return ToolResponse(
|
||||||
|
content=[
|
||||||
|
TextBlock(type="text", text=result),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_team_controller(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
create_agent_callback: Optional[Callable[..., str]] = None,
|
||||||
|
remove_agent_callback: Optional[Callable[..., str]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Inject runtime team lifecycle callbacks from pipeline."""
|
||||||
|
object.__setattr__(self, "_create_team_agent_cb", create_agent_callback)
|
||||||
|
object.__setattr__(self, "_remove_team_agent_cb", remove_agent_callback)
|
||||||
|
|
||||||
async def reply(self, x: Msg = None) -> Msg:
|
async def reply(self, x: Msg = None) -> Msg:
|
||||||
"""
|
"""
|
||||||
Make investment decisions
|
Make investment decisions
|
||||||
@@ -205,6 +319,42 @@ class PMAgent(ReActAgent):
|
|||||||
"""Update portfolio after external execution"""
|
"""Update portfolio after external execution"""
|
||||||
self.portfolio.update(portfolio)
|
self.portfolio.update(portfolio)
|
||||||
|
|
||||||
|
def _has_open_positions(self) -> bool:
|
||||||
|
"""Return whether the current portfolio still has non-zero positions."""
|
||||||
|
for position in self.portfolio.get("positions", {}).values():
|
||||||
|
if position.get("long", 0) or position.get("short", 0):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def can_apply_initial_cash(self) -> bool:
|
||||||
|
"""Only allow cash rebasing before any positions or margin exist."""
|
||||||
|
return (
|
||||||
|
not self._has_open_positions()
|
||||||
|
and float(self.portfolio.get("margin_used", 0.0) or 0.0) == 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def apply_runtime_portfolio_config(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
margin_requirement: Optional[float] = None,
|
||||||
|
initial_cash: Optional[float] = None,
|
||||||
|
) -> Dict[str, bool]:
|
||||||
|
"""Apply safe run-time portfolio config updates."""
|
||||||
|
result = {
|
||||||
|
"margin_requirement": False,
|
||||||
|
"initial_cash": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
if margin_requirement is not None:
|
||||||
|
self.portfolio["margin_requirement"] = float(margin_requirement)
|
||||||
|
result["margin_requirement"] = True
|
||||||
|
|
||||||
|
if initial_cash is not None and self.can_apply_initial_cash():
|
||||||
|
self.portfolio["cash"] = float(initial_cash)
|
||||||
|
result["initial_cash"] = True
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def reload_runtime_assets(self, active_skill_dirs: Optional[list] = None) -> None:
|
def reload_runtime_assets(self, active_skill_dirs: Optional[list] = None) -> None:
|
||||||
"""Reload toolkit and system prompt from current run assets."""
|
"""Reload toolkit and system prompt from current run assets."""
|
||||||
from .toolkit_factory import create_agent_toolkit
|
from .toolkit_factory import create_agent_toolkit
|
||||||
@@ -221,8 +371,18 @@ class PMAgent(ReActAgent):
|
|||||||
owner=self,
|
owner=self,
|
||||||
**toolkit_kwargs,
|
**toolkit_kwargs,
|
||||||
)
|
)
|
||||||
self.sys_prompt = build_agent_system_prompt(
|
self._apply_runtime_sys_prompt(
|
||||||
|
build_agent_system_prompt(
|
||||||
agent_id=self.name,
|
agent_id=self.name,
|
||||||
config_name=self.config.get("config_name", "default"),
|
config_name=self.config.get("config_name", "default"),
|
||||||
toolkit=self.toolkit,
|
toolkit=self.toolkit,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
|
||||||
|
"""Update the prompt used by future turns and the cached system msg."""
|
||||||
|
self._sys_prompt = sys_prompt
|
||||||
|
for msg, _marks in self.memory.content:
|
||||||
|
if getattr(msg, "role", None) == "system":
|
||||||
|
msg.content = sys_prompt
|
||||||
|
break
|
||||||
|
|||||||
@@ -4,11 +4,12 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from .agent_workspace import load_agent_workspace_config
|
||||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||||
from .prompt_loader import PromptLoader
|
from .prompt_loader import get_prompt_loader
|
||||||
from .skills_manager import SkillsManager
|
from .skills_manager import SkillsManager
|
||||||
|
|
||||||
_prompt_loader = PromptLoader()
|
_prompt_loader = get_prompt_loader()
|
||||||
|
|
||||||
|
|
||||||
def _read_file_if_exists(path: Path) -> str:
|
def _read_file_if_exists(path: Path) -> str:
|
||||||
@@ -23,14 +24,47 @@ def _append_section(parts: list[str], title: str, content: str) -> None:
|
|||||||
parts.append(f"## {title}\n{content}")
|
parts.append(f"## {title}\n{content}")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_skill_metadata_summary(skills_manager: SkillsManager, config_name: str, agent_id: str) -> str:
|
||||||
|
"""Create a compact summary of active skills for prompt routing."""
|
||||||
|
metadata_items = skills_manager.list_active_skill_metadata(config_name, agent_id)
|
||||||
|
if not metadata_items:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines: list[str] = [
|
||||||
|
"You can use the following active skills. Prefer the most relevant one, then read its SKILL.md if needed for detailed workflow:",
|
||||||
|
]
|
||||||
|
for item in metadata_items:
|
||||||
|
parts = [f"- `{item.skill_name}`"]
|
||||||
|
if item.description:
|
||||||
|
parts.append(item.description)
|
||||||
|
if item.version:
|
||||||
|
parts.append(f"version: {item.version}")
|
||||||
|
parts.append(f"path: {item.path}")
|
||||||
|
lines.append(" | ".join(parts))
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
def build_agent_system_prompt(
|
def build_agent_system_prompt(
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
config_name: str,
|
config_name: str,
|
||||||
toolkit: Any,
|
toolkit: Any,
|
||||||
analyst_type: Optional[str] = None,
|
analyst_type: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build the final system prompt for an agent."""
|
"""Build the final system prompt for an agent.
|
||||||
|
|
||||||
|
Always reads fresh from disk — no caching.
|
||||||
|
"""
|
||||||
|
# Clear any cached templates before building (CoPaw-style, no caching)
|
||||||
|
_prompt_loader.clear_cache()
|
||||||
|
|
||||||
sections: list[str] = []
|
sections: list[str] = []
|
||||||
|
canonical_agent_id = (
|
||||||
|
"portfolio_manager"
|
||||||
|
if "portfolio" in agent_id
|
||||||
|
else "risk_manager"
|
||||||
|
if "risk" in agent_id and not analyst_type
|
||||||
|
else agent_id
|
||||||
|
)
|
||||||
|
|
||||||
if analyst_type:
|
if analyst_type:
|
||||||
personas_config = _prompt_loader.load_yaml_config(
|
personas_config = _prompt_loader.load_yaml_config(
|
||||||
@@ -56,11 +90,21 @@ def build_agent_system_prompt(
|
|||||||
"portfolio_manager",
|
"portfolio_manager",
|
||||||
"system",
|
"system",
|
||||||
)
|
)
|
||||||
|
elif canonical_agent_id == "portfolio_manager":
|
||||||
|
base_prompt = _prompt_loader.load_prompt(
|
||||||
|
"portfolio_manager",
|
||||||
|
"system",
|
||||||
|
)
|
||||||
elif agent_id == "risk_manager":
|
elif agent_id == "risk_manager":
|
||||||
base_prompt = _prompt_loader.load_prompt(
|
base_prompt = _prompt_loader.load_prompt(
|
||||||
"risk_manager",
|
"risk_manager",
|
||||||
"system",
|
"system",
|
||||||
)
|
)
|
||||||
|
elif canonical_agent_id == "risk_manager":
|
||||||
|
base_prompt = _prompt_loader.load_prompt(
|
||||||
|
"risk_manager",
|
||||||
|
"system",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported agent prompt build for: {agent_id}")
|
raise ValueError(f"Unsupported agent prompt build for: {agent_id}")
|
||||||
|
|
||||||
@@ -69,6 +113,7 @@ def build_agent_system_prompt(
|
|||||||
skills_manager = SkillsManager()
|
skills_manager = SkillsManager()
|
||||||
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
|
||||||
bootstrap_config = get_bootstrap_config_for_run(
|
bootstrap_config = get_bootstrap_config_for_run(
|
||||||
skills_manager.project_root,
|
skills_manager.project_root,
|
||||||
config_name,
|
config_name,
|
||||||
@@ -80,16 +125,44 @@ def build_agent_system_prompt(
|
|||||||
bootstrap_config.prompt_body,
|
bootstrap_config.prompt_body,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prompt_files = agent_config.prompt_files or [
|
||||||
|
"SOUL.md",
|
||||||
|
"PROFILE.md",
|
||||||
|
"AGENTS.md",
|
||||||
|
"POLICY.md",
|
||||||
|
"MEMORY.md",
|
||||||
|
]
|
||||||
|
included_files = set(prompt_files)
|
||||||
|
title_map = {
|
||||||
|
"SOUL.md": "Soul",
|
||||||
|
"PROFILE.md": "Profile",
|
||||||
|
"AGENTS.md": "Agent Guide",
|
||||||
|
"POLICY.md": "Policy",
|
||||||
|
"MEMORY.md": "Memory",
|
||||||
|
"HEARTBEAT.md": "Heartbeat",
|
||||||
|
"ROLE.md": "Role",
|
||||||
|
"STYLE.md": "Style",
|
||||||
|
}
|
||||||
|
for filename in prompt_files:
|
||||||
|
_append_section(
|
||||||
|
sections,
|
||||||
|
title_map.get(filename, filename),
|
||||||
|
_read_file_if_exists(asset_dir / filename),
|
||||||
|
)
|
||||||
|
|
||||||
|
if "ROLE.md" not in included_files:
|
||||||
_append_section(
|
_append_section(
|
||||||
sections,
|
sections,
|
||||||
"Role",
|
"Role",
|
||||||
_read_file_if_exists(asset_dir / "ROLE.md"),
|
_read_file_if_exists(asset_dir / "ROLE.md"),
|
||||||
)
|
)
|
||||||
|
if "STYLE.md" not in included_files:
|
||||||
_append_section(
|
_append_section(
|
||||||
sections,
|
sections,
|
||||||
"Style",
|
"Style",
|
||||||
_read_file_if_exists(asset_dir / "STYLE.md"),
|
_read_file_if_exists(asset_dir / "STYLE.md"),
|
||||||
)
|
)
|
||||||
|
if "POLICY.md" not in included_files:
|
||||||
_append_section(
|
_append_section(
|
||||||
sections,
|
sections,
|
||||||
"Policy",
|
"Policy",
|
||||||
@@ -100,6 +173,14 @@ def build_agent_system_prompt(
|
|||||||
if skill_prompt:
|
if skill_prompt:
|
||||||
_append_section(sections, "Skills", str(skill_prompt))
|
_append_section(sections, "Skills", str(skill_prompt))
|
||||||
|
|
||||||
|
metadata_summary = _build_skill_metadata_summary(
|
||||||
|
skills_manager=skills_manager,
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
if metadata_summary:
|
||||||
|
_append_section(sections, "Active Skill Catalog", metadata_summary)
|
||||||
|
|
||||||
activated_notes = toolkit.get_activated_notes()
|
activated_notes = toolkit.get_activated_notes()
|
||||||
if activated_notes:
|
if activated_notes:
|
||||||
_append_section(sections, "Tool Usage Notes", str(activated_notes))
|
_append_section(sections, "Tool Usage Notes", str(activated_notes))
|
||||||
|
|||||||
@@ -10,6 +10,17 @@ from typing import Any, Dict, Optional
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
_prompt_loader_instance: Optional["PromptLoader"] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_loader() -> "PromptLoader":
|
||||||
|
"""Get the singleton PromptLoader instance."""
|
||||||
|
global _prompt_loader_instance
|
||||||
|
if _prompt_loader_instance is None:
|
||||||
|
_prompt_loader_instance = PromptLoader()
|
||||||
|
return _prompt_loader_instance
|
||||||
|
|
||||||
|
|
||||||
class PromptLoader:
|
class PromptLoader:
|
||||||
"""Unified Prompt loader"""
|
"""Unified Prompt loader"""
|
||||||
@@ -27,10 +38,6 @@ class PromptLoader:
|
|||||||
else:
|
else:
|
||||||
self.prompts_dir = Path(prompts_dir)
|
self.prompts_dir = Path(prompts_dir)
|
||||||
|
|
||||||
# Cache loaded prompts
|
|
||||||
self._prompt_cache: Dict[str, str] = {}
|
|
||||||
self._yaml_cache: Dict[str, Dict] = {}
|
|
||||||
|
|
||||||
def load_prompt(
|
def load_prompt(
|
||||||
self,
|
self,
|
||||||
agent_type: str,
|
agent_type: str,
|
||||||
@@ -38,25 +45,10 @@ class PromptLoader:
|
|||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Load and render Prompt
|
Load and render Prompt.
|
||||||
|
|
||||||
Args:
|
No caching — always reads fresh from disk (CoPaw-style).
|
||||||
agent_type: Agent type (analyst, portfolio_manager, risk_manager)
|
|
||||||
prompt_name: Prompt file name (without extension)
|
|
||||||
variables: Variable dictionary for rendering Prompt
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Rendered prompt string
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
loader = PromptLoader()
|
|
||||||
prompt = loader.load_prompt("analyst", "tool_selection",
|
|
||||||
{"analyst_persona": "Technical Analyst"})
|
|
||||||
"""
|
"""
|
||||||
cache_key = f"{agent_type}/{prompt_name}"
|
|
||||||
|
|
||||||
# Try to load from cache
|
|
||||||
if cache_key not in self._prompt_cache:
|
|
||||||
prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md"
|
prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md"
|
||||||
|
|
||||||
if not prompt_path.exists():
|
if not prompt_path.exists():
|
||||||
@@ -66,9 +58,7 @@ class PromptLoader:
|
|||||||
)
|
)
|
||||||
|
|
||||||
with open(prompt_path, "r", encoding="utf-8") as f:
|
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||||
self._prompt_cache[cache_key] = f.read()
|
prompt_template = f.read()
|
||||||
|
|
||||||
prompt_template = self._prompt_cache[cache_key]
|
|
||||||
|
|
||||||
# If variables provided, use simple string replacement
|
# If variables provided, use simple string replacement
|
||||||
if variables:
|
if variables:
|
||||||
@@ -76,8 +66,6 @@ class PromptLoader:
|
|||||||
else:
|
else:
|
||||||
rendered = prompt_template
|
rendered = prompt_template
|
||||||
|
|
||||||
# Smart escaping: escape braces in JSON code blocks
|
|
||||||
# rendered = self._escape_json_braces(rendered)
|
|
||||||
return rendered
|
return rendered
|
||||||
|
|
||||||
def _render_template(
|
def _render_template(
|
||||||
@@ -140,45 +128,26 @@ class PromptLoader:
|
|||||||
config_name: str,
|
config_name: str,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Load YAML configuration file
|
Load YAML configuration file.
|
||||||
|
|
||||||
Args:
|
No caching — always reads fresh from disk (CoPaw-style).
|
||||||
agent_type: Agent type
|
|
||||||
config_name: Configuration file name (without extension)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Configuration dictionary
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> loader = PromptLoader()
|
|
||||||
>>> config = loader.load_yaml_config("analyst", "personas")
|
|
||||||
"""
|
"""
|
||||||
cache_key = f"{agent_type}/{config_name}"
|
|
||||||
|
|
||||||
if cache_key not in self._yaml_cache:
|
|
||||||
yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml"
|
yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml"
|
||||||
|
|
||||||
if not yaml_path.exists():
|
if not yaml_path.exists():
|
||||||
raise FileNotFoundError(f"YAML config not found: {yaml_path}")
|
raise FileNotFoundError(f"YAML config not found: {yaml_path}")
|
||||||
|
|
||||||
with open(yaml_path, "r", encoding="utf-8") as f:
|
with open(yaml_path, "r", encoding="utf-8") as f:
|
||||||
self._yaml_cache[cache_key] = yaml.safe_load(f)
|
return yaml.safe_load(f) or {}
|
||||||
|
|
||||||
return self._yaml_cache[cache_key]
|
|
||||||
|
|
||||||
def clear_cache(self):
|
def clear_cache(self):
|
||||||
"""Clear cache (for hot reload)"""
|
"""No-op — caching removed (CoPaw-style, always fresh reads)."""
|
||||||
self._prompt_cache.clear()
|
pass
|
||||||
self._yaml_cache.clear()
|
|
||||||
|
|
||||||
def reload_prompt(self, agent_type: str, prompt_name: str):
|
def reload_prompt(self, agent_type: str, prompt_name: str):
|
||||||
"""Reload specified prompt (force cache refresh)"""
|
"""No-op — caching removed."""
|
||||||
cache_key = f"{agent_type}/{prompt_name}"
|
pass
|
||||||
if cache_key in self._prompt_cache:
|
|
||||||
del self._prompt_cache[cache_key]
|
|
||||||
|
|
||||||
def reload_config(self, agent_type: str, config_name: str):
|
def reload_config(self, agent_type: str, config_name: str):
|
||||||
"""Reload specified configuration (force cache refresh)"""
|
"""No-op — caching removed."""
|
||||||
cache_key = f"{agent_type}/{config_name}"
|
pass
|
||||||
if cache_key in self._yaml_cache:
|
|
||||||
del self._yaml_cache[cache_key]
|
|
||||||
|
|||||||
19
backend/agents/prompts/__init__.py
Normal file
19
backend/agents/prompts/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Prompt building utilities for EvoAgent.
|
||||||
|
|
||||||
|
This module provides prompt construction from workspace markdown files
|
||||||
|
with YAML frontmatter support.
|
||||||
|
"""
|
||||||
|
from .builder import (
|
||||||
|
PromptBuilder,
|
||||||
|
build_system_prompt_from_workspace,
|
||||||
|
build_bootstrap_guidance,
|
||||||
|
DEFAULT_SYS_PROMPT,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PromptBuilder",
|
||||||
|
"build_system_prompt_from_workspace",
|
||||||
|
"build_bootstrap_guidance",
|
||||||
|
"DEFAULT_SYS_PROMPT",
|
||||||
|
]
|
||||||
305
backend/agents/prompts/builder.py
Normal file
305
backend/agents/prompts/builder.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""PromptBuilder for constructing system prompts from workspace markdown files.
|
||||||
|
|
||||||
|
Based on CoPaw design - loads AGENTS.md, SOUL.md, PROFILE.md, etc. from
|
||||||
|
agent workspace directories with YAML frontmatter support.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_SYS_PROMPT = """You are a helpful trading analysis assistant."""
|
||||||
|
|
||||||
|
|
||||||
|
class PromptBuilder:
|
||||||
|
"""Builder for constructing system prompts from markdown files.
|
||||||
|
|
||||||
|
Loads markdown configuration files from agent workspace directories,
|
||||||
|
supporting YAML frontmatter for metadata extraction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_FILES = [
|
||||||
|
"AGENTS.md",
|
||||||
|
"SOUL.md",
|
||||||
|
"PROFILE.md",
|
||||||
|
"ROLE.md",
|
||||||
|
"POLICY.md",
|
||||||
|
"MEMORY.md",
|
||||||
|
"HEARTBEAT.md",
|
||||||
|
"STYLE.md",
|
||||||
|
]
|
||||||
|
|
||||||
|
TITLE_MAP: Dict[str, str] = {
|
||||||
|
"AGENTS.md": "Agent Guide",
|
||||||
|
"SOUL.md": "Soul",
|
||||||
|
"PROFILE.md": "Profile",
|
||||||
|
"ROLE.md": "Role",
|
||||||
|
"POLICY.md": "Policy",
|
||||||
|
"MEMORY.md": "Memory",
|
||||||
|
"HEARTBEAT.md": "Heartbeat",
|
||||||
|
"STYLE.md": "Style",
|
||||||
|
"BOOTSTRAP.md": "Bootstrap",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workspace_dir: Path,
|
||||||
|
enabled_files: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
|
"""Initialize prompt builder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_dir: Directory containing markdown configuration files
|
||||||
|
enabled_files: List of filenames to load (if None, uses defaults)
|
||||||
|
"""
|
||||||
|
self.workspace_dir = Path(workspace_dir)
|
||||||
|
self.enabled_files = enabled_files or self.DEFAULT_FILES.copy()
|
||||||
|
self._prompt_parts: List[str] = []
|
||||||
|
self._metadata: Dict[str, Any] = {}
|
||||||
|
self.loaded_count = 0
|
||||||
|
|
||||||
|
def _load_file(self, filename: str) -> tuple[str, Optional[Dict[str, Any]]]:
|
||||||
|
"""Load a single markdown file with YAML frontmatter support.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filename: Name of the file to load
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (content, metadata dict or None)
|
||||||
|
"""
|
||||||
|
file_path = self.workspace_dir / filename
|
||||||
|
|
||||||
|
if not file_path.exists():
|
||||||
|
logger.debug("File %s not found in %s, skipping", filename, self.workspace_dir)
|
||||||
|
return "", None
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw_content = file_path.read_text(encoding="utf-8").strip()
|
||||||
|
|
||||||
|
if not raw_content:
|
||||||
|
logger.debug("Skipped empty file: %s", filename)
|
||||||
|
return "", None
|
||||||
|
|
||||||
|
content, metadata = self._parse_frontmatter(raw_content)
|
||||||
|
|
||||||
|
if content:
|
||||||
|
self.loaded_count += 1
|
||||||
|
logger.debug("Loaded %s (metadata: %s)", filename, bool(metadata))
|
||||||
|
|
||||||
|
return content, metadata
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to read file %s: %s, skipping", filename, e)
|
||||||
|
return "", None
|
||||||
|
|
||||||
|
def _parse_frontmatter(self, raw_content: str) -> tuple[str, Optional[Dict[str, Any]]]:
|
||||||
|
"""Parse YAML frontmatter from markdown content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_content: Raw file content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (content without frontmatter, metadata dict or None)
|
||||||
|
"""
|
||||||
|
if not raw_content.startswith("---"):
|
||||||
|
return raw_content, None
|
||||||
|
|
||||||
|
parts = raw_content.split("---", 2)
|
||||||
|
if len(parts) < 3:
|
||||||
|
return raw_content, None
|
||||||
|
|
||||||
|
frontmatter = parts[1].strip()
|
||||||
|
content = parts[2].strip()
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata = yaml.safe_load(frontmatter) or {}
|
||||||
|
if not isinstance(metadata, dict):
|
||||||
|
metadata = {}
|
||||||
|
return content, metadata
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
logger.warning("Failed to parse YAML frontmatter: %s", e)
|
||||||
|
return content, None
|
||||||
|
|
||||||
|
def _append_section(self, title: str, content: str) -> None:
|
||||||
|
"""Append a section to the prompt parts.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
title: Section title
|
||||||
|
content: Section content
|
||||||
|
"""
|
||||||
|
content = content.strip()
|
||||||
|
if not content:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._prompt_parts:
|
||||||
|
self._prompt_parts.append("")
|
||||||
|
|
||||||
|
self._prompt_parts.append(f"## {title}")
|
||||||
|
self._prompt_parts.append("")
|
||||||
|
self._prompt_parts.append(content)
|
||||||
|
|
||||||
|
def build(self) -> str:
|
||||||
|
"""Build the system prompt from markdown files.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Constructed system prompt string
|
||||||
|
"""
|
||||||
|
self._prompt_parts = []
|
||||||
|
self._metadata = {}
|
||||||
|
self.loaded_count = 0
|
||||||
|
|
||||||
|
for filename in self.enabled_files:
|
||||||
|
content, metadata = self._load_file(filename)
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
self._metadata[filename] = metadata
|
||||||
|
|
||||||
|
if content:
|
||||||
|
title = self.TITLE_MAP.get(filename, filename.replace(".md", ""))
|
||||||
|
self._append_section(title, content)
|
||||||
|
|
||||||
|
if not self._prompt_parts:
|
||||||
|
logger.warning("No content loaded from workspace: %s", self.workspace_dir)
|
||||||
|
return DEFAULT_SYS_PROMPT
|
||||||
|
|
||||||
|
final_prompt = "\n".join(self._prompt_parts)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"System prompt built from %d file(s), total length: %d chars",
|
||||||
|
self.loaded_count,
|
||||||
|
len(final_prompt),
|
||||||
|
)
|
||||||
|
|
||||||
|
return final_prompt
|
||||||
|
|
||||||
|
def get_metadata(self) -> Dict[str, Any]:
|
||||||
|
"""Get metadata collected from YAML frontmatter.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary mapping filenames to their metadata
|
||||||
|
"""
|
||||||
|
return self._metadata.copy()
|
||||||
|
|
||||||
|
def get_agent_identity(self) -> Optional[Dict[str, Any]]:
|
||||||
|
"""Extract agent identity from PROFILE.md metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Identity dict with name, role, etc. or None
|
||||||
|
"""
|
||||||
|
profile_meta = self._metadata.get("PROFILE.md", {})
|
||||||
|
if not profile_meta:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": profile_meta.get("name", "Unknown"),
|
||||||
|
"role": profile_meta.get("role", ""),
|
||||||
|
"expertise": profile_meta.get("expertise", []),
|
||||||
|
"style": profile_meta.get("style", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_system_prompt_from_workspace(
|
||||||
|
workspace_dir: Path,
|
||||||
|
enabled_files: Optional[List[str]] = None,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
extra_context: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build system prompt from workspace markdown files.
|
||||||
|
|
||||||
|
This is the main entry point for building system prompts from
|
||||||
|
agent workspace directories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_dir: Directory containing markdown configuration files
|
||||||
|
enabled_files: List of filenames to load (if None, uses defaults)
|
||||||
|
agent_id: Agent identifier to include in system prompt
|
||||||
|
extra_context: Additional context to append to the prompt
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Constructed system prompt string
|
||||||
|
"""
|
||||||
|
builder = PromptBuilder(
|
||||||
|
workspace_dir=workspace_dir,
|
||||||
|
enabled_files=enabled_files,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = builder.build()
|
||||||
|
|
||||||
|
# Add agent identity header if agent_id provided
|
||||||
|
if agent_id and agent_id != "default":
|
||||||
|
identity_header = (
|
||||||
|
f"# Agent Identity\n\n"
|
||||||
|
f"Your agent ID is `{agent_id}`. "
|
||||||
|
f"This is your unique identifier in the multi-agent system.\n\n"
|
||||||
|
)
|
||||||
|
prompt = identity_header + prompt
|
||||||
|
|
||||||
|
# Append extra context if provided
|
||||||
|
if extra_context:
|
||||||
|
prompt = prompt + "\n\n" + extra_context
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def build_bootstrap_guidance(language: str = "zh") -> str:
|
||||||
|
"""Build bootstrap guidance message for first-time setup.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
language: Language code (zh/en)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Formatted bootstrap guidance message
|
||||||
|
"""
|
||||||
|
if language == "zh":
|
||||||
|
return (
|
||||||
|
"# 引导模式\n"
|
||||||
|
"\n"
|
||||||
|
"工作目录中存在 `BOOTSTRAP.md` — 首次设置。\n"
|
||||||
|
"\n"
|
||||||
|
"1. 阅读 BOOTSTRAP.md,友好地表示初次见面,"
|
||||||
|
"引导用户完成设置。\n"
|
||||||
|
"2. 按照 BOOTSTRAP.md 的指示,"
|
||||||
|
"帮助用户定义你的身份和偏好。\n"
|
||||||
|
"3. 按指南创建/更新必要文件"
|
||||||
|
"(PROFILE.md、MEMORY.md 等)。\n"
|
||||||
|
"4. 完成后删除 BOOTSTRAP.md。\n"
|
||||||
|
"\n"
|
||||||
|
"如果用户希望跳过,直接回答下面的问题即可。\n"
|
||||||
|
"\n"
|
||||||
|
"---\n"
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
"# BOOTSTRAP MODE\n"
|
||||||
|
"\n"
|
||||||
|
"`BOOTSTRAP.md` exists — first-time setup.\n"
|
||||||
|
"\n"
|
||||||
|
"1. Read BOOTSTRAP.md, greet the user, "
|
||||||
|
"and guide them through setup.\n"
|
||||||
|
"2. Follow BOOTSTRAP.md instructions "
|
||||||
|
"to define identity and preferences.\n"
|
||||||
|
"3. Create/update files "
|
||||||
|
"(PROFILE.md, MEMORY.md, etc.) as described.\n"
|
||||||
|
"4. Delete BOOTSTRAP.md when done.\n"
|
||||||
|
"\n"
|
||||||
|
"If the user wants to skip, answer their "
|
||||||
|
"question directly instead.\n"
|
||||||
|
"\n"
|
||||||
|
"---\n"
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"PromptBuilder",
|
||||||
|
"build_system_prompt_from_workspace",
|
||||||
|
"build_bootstrap_guidance",
|
||||||
|
"DEFAULT_SYS_PROMPT",
|
||||||
|
]
|
||||||
284
backend/agents/registry.py
Normal file
284
backend/agents/registry.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Agent Registry - In-memory registry for agent management."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentInfo:
|
||||||
|
"""Information about a registered agent."""
|
||||||
|
|
||||||
|
agent_id: str
|
||||||
|
agent_type: str
|
||||||
|
workspace_id: str
|
||||||
|
config_path: str
|
||||||
|
agent_dir: str
|
||||||
|
status: str = "inactive" # inactive, active, error
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Serialize to dictionary."""
|
||||||
|
return {
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
"agent_type": self.agent_type,
|
||||||
|
"workspace_id": self.workspace_id,
|
||||||
|
"config_path": self.config_path,
|
||||||
|
"agent_dir": self.agent_dir,
|
||||||
|
"status": self.status,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
|
||||||
|
"""In-memory registry for agent instances."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the agent registry."""
|
||||||
|
# Dictionary mapping agent_id -> AgentInfo
|
||||||
|
self._agents: Dict[str, AgentInfo] = {}
|
||||||
|
# Index mapping workspace_id -> set of agent_ids
|
||||||
|
self._workspace_index: Dict[str, set] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
agent_type: str,
|
||||||
|
workspace_id: str,
|
||||||
|
config_path: str,
|
||||||
|
agent_dir: str,
|
||||||
|
status: str = "inactive",
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> AgentInfo:
|
||||||
|
"""Register an agent in the registry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Unique identifier for the agent
|
||||||
|
agent_type: Type of agent
|
||||||
|
workspace_id: ID of the workspace containing the agent
|
||||||
|
config_path: Path to agent configuration file
|
||||||
|
agent_dir: Path to agent directory
|
||||||
|
status: Initial status (default: inactive)
|
||||||
|
metadata: Optional metadata dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInfo instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If agent_id is already registered
|
||||||
|
"""
|
||||||
|
if agent_id in self._agents:
|
||||||
|
raise ValueError(f"Agent '{agent_id}' is already registered")
|
||||||
|
|
||||||
|
agent_info = AgentInfo(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
config_path=config_path,
|
||||||
|
agent_dir=agent_dir,
|
||||||
|
status=status,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._agents[agent_id] = agent_info
|
||||||
|
|
||||||
|
# Update workspace index
|
||||||
|
if workspace_id not in self._workspace_index:
|
||||||
|
self._workspace_index[workspace_id] = set()
|
||||||
|
self._workspace_index[workspace_id].add(agent_id)
|
||||||
|
|
||||||
|
return agent_info
|
||||||
|
|
||||||
|
def unregister(self, agent_id: str) -> bool:
|
||||||
|
"""Unregister an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: ID of the agent to unregister
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if unregistered, False if agent wasn't registered
|
||||||
|
"""
|
||||||
|
if agent_id not in self._agents:
|
||||||
|
return False
|
||||||
|
|
||||||
|
agent_info = self._agents[agent_id]
|
||||||
|
|
||||||
|
# Remove from workspace index
|
||||||
|
workspace_id = agent_info.workspace_id
|
||||||
|
if workspace_id in self._workspace_index:
|
||||||
|
self._workspace_index[workspace_id].discard(agent_id)
|
||||||
|
if not self._workspace_index[workspace_id]:
|
||||||
|
del self._workspace_index[workspace_id]
|
||||||
|
|
||||||
|
# Remove from agents dict
|
||||||
|
del self._agents[agent_id]
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get(self, agent_id: str) -> Optional[AgentInfo]:
|
||||||
|
"""Get agent information by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: ID of the agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentInfo if found, None otherwise
|
||||||
|
"""
|
||||||
|
return self._agents.get(agent_id)
|
||||||
|
|
||||||
|
def list_all(
|
||||||
|
self,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
agent_type: Optional[str] = None,
|
||||||
|
status: Optional[str] = None,
|
||||||
|
) -> List[AgentInfo]:
|
||||||
|
"""List all registered agents with optional filtering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Filter by workspace ID
|
||||||
|
agent_type: Filter by agent type
|
||||||
|
status: Filter by status
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of AgentInfo instances
|
||||||
|
"""
|
||||||
|
agents = list(self._agents.values())
|
||||||
|
|
||||||
|
if workspace_id:
|
||||||
|
agent_ids = self._workspace_index.get(workspace_id, set())
|
||||||
|
agents = [a for a in agents if a.agent_id in agent_ids]
|
||||||
|
|
||||||
|
if agent_type:
|
||||||
|
agents = [a for a in agents if a.agent_type == agent_type]
|
||||||
|
|
||||||
|
if status:
|
||||||
|
agents = [a for a in agents if a.status == status]
|
||||||
|
|
||||||
|
return agents
|
||||||
|
|
||||||
|
def update_status(self, agent_id: str, status: str) -> bool:
|
||||||
|
"""Update the status of an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: ID of the agent
|
||||||
|
status: New status value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if updated, False if agent not found
|
||||||
|
"""
|
||||||
|
if agent_id not in self._agents:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._agents[agent_id].status = status
|
||||||
|
return True
|
||||||
|
|
||||||
|
def update_metadata(self, agent_id: str, metadata: Dict[str, Any]) -> bool:
|
||||||
|
"""Update the metadata of an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: ID of the agent
|
||||||
|
metadata: Metadata dictionary to merge
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if updated, False if agent not found
|
||||||
|
"""
|
||||||
|
if agent_id not in self._agents:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._agents[agent_id].metadata.update(metadata)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def is_registered(self, agent_id: str) -> bool:
|
||||||
|
"""Check if an agent is registered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: ID of the agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if registered, False otherwise
|
||||||
|
"""
|
||||||
|
return agent_id in self._agents
|
||||||
|
|
||||||
|
def get_workspace_agents(self, workspace_id: str) -> List[AgentInfo]:
|
||||||
|
"""Get all agents in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: ID of the workspace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of AgentInfo instances
|
||||||
|
"""
|
||||||
|
agent_ids = self._workspace_index.get(workspace_id, set())
|
||||||
|
return [self._agents[agent_id] for agent_id in agent_ids if agent_id in self._agents]
|
||||||
|
|
||||||
|
def get_agent_count(self, workspace_id: Optional[str] = None) -> int:
|
||||||
|
"""Get the count of registered agents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Optional workspace ID to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of agents
|
||||||
|
"""
|
||||||
|
if workspace_id:
|
||||||
|
return len(self._workspace_index.get(workspace_id, set()))
|
||||||
|
return len(self._agents)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Clear all registered agents."""
|
||||||
|
self._agents.clear()
|
||||||
|
self._workspace_index.clear()
|
||||||
|
|
||||||
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
"""Get registry statistics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with registry statistics
|
||||||
|
"""
|
||||||
|
stats = {
|
||||||
|
"total_agents": len(self._agents),
|
||||||
|
"workspaces": len(self._workspace_index),
|
||||||
|
"agents_by_workspace": {
|
||||||
|
ws_id: len(agent_ids)
|
||||||
|
for ws_id, agent_ids in self._workspace_index.items()
|
||||||
|
},
|
||||||
|
"agents_by_type": {},
|
||||||
|
"agents_by_status": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
for agent in self._agents.values():
|
||||||
|
# Count by type
|
||||||
|
agent_type = agent.agent_type
|
||||||
|
stats["agents_by_type"][agent_type] = (
|
||||||
|
stats["agents_by_type"].get(agent_type, 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count by status
|
||||||
|
status = agent.status
|
||||||
|
stats["agents_by_status"][status] = (
|
||||||
|
stats["agents_by_status"].get(status, 0) + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
# Global registry instance
|
||||||
|
_global_registry: Optional[AgentRegistry] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_registry() -> AgentRegistry:
|
||||||
|
"""Get the global agent registry instance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentRegistry instance
|
||||||
|
"""
|
||||||
|
global _global_registry
|
||||||
|
if _global_registry is None:
|
||||||
|
_global_registry = AgentRegistry()
|
||||||
|
return _global_registry
|
||||||
|
|
||||||
|
|
||||||
|
def reset_registry() -> None:
|
||||||
|
"""Reset the global registry (useful for testing)."""
|
||||||
|
global _global_registry
|
||||||
|
_global_registry = None
|
||||||
@@ -39,12 +39,12 @@ class RiskAgent(ReActAgent):
|
|||||||
config: Configuration dictionary
|
config: Configuration dictionary
|
||||||
long_term_memory: Optional ReMeTaskLongTermMemory instance
|
long_term_memory: Optional ReMeTaskLongTermMemory instance
|
||||||
"""
|
"""
|
||||||
self.config = config or {}
|
object.__setattr__(self, "config", config or {})
|
||||||
self.agent_id = name
|
object.__setattr__(self, "agent_id", name)
|
||||||
|
|
||||||
if toolkit is None:
|
if toolkit is None:
|
||||||
toolkit = Toolkit()
|
toolkit = Toolkit()
|
||||||
self.toolkit = toolkit
|
object.__setattr__(self, "toolkit", toolkit)
|
||||||
|
|
||||||
sys_prompt = self._load_system_prompt()
|
sys_prompt = self._load_system_prompt()
|
||||||
|
|
||||||
@@ -99,4 +99,12 @@ class RiskAgent(ReActAgent):
|
|||||||
self.config.get("config_name", "default"),
|
self.config.get("config_name", "default"),
|
||||||
active_skill_dirs=active_skill_dirs,
|
active_skill_dirs=active_skill_dirs,
|
||||||
)
|
)
|
||||||
self.sys_prompt = self._load_system_prompt()
|
self._apply_runtime_sys_prompt(self._load_system_prompt())
|
||||||
|
|
||||||
|
def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
|
||||||
|
"""Update the prompt used by future turns and the cached system msg."""
|
||||||
|
self._sys_prompt = sys_prompt
|
||||||
|
for msg, _marks in self.memory.content:
|
||||||
|
if getattr(msg, "role", None) == "system":
|
||||||
|
msg.content = sys_prompt
|
||||||
|
break
|
||||||
|
|||||||
388
backend/agents/skill_loader.py
Normal file
388
backend/agents/skill_loader.py
Normal file
@@ -0,0 +1,388 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Skill loader for loading and validating skills from directories.
|
||||||
|
|
||||||
|
提供从目录加载技能、解析SKILL.md frontmatter、获取工具列表等功能。
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from backend.agents.skill_metadata import SkillMetadata, parse_skill_metadata
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SkillInfo:
|
||||||
|
"""完整的技能信息"""
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
version: str
|
||||||
|
source: str
|
||||||
|
path: Path
|
||||||
|
metadata: SkillMetadata
|
||||||
|
tools: List[str] = field(default_factory=list)
|
||||||
|
scripts: List[str] = field(default_factory=list)
|
||||||
|
references: List[str] = field(default_factory=list)
|
||||||
|
content: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def load_skill_from_dir(skill_dir: Path, source: str = "unknown") -> Optional[Dict[str, Any]]:
|
||||||
|
"""从目录加载技能
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: 技能目录路径
|
||||||
|
source: 技能来源 (builtin/customized/local/installed/active)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
技能信息字典,加载失败返回None
|
||||||
|
"""
|
||||||
|
if not skill_dir.exists() or not skill_dir.is_dir():
|
||||||
|
logger.warning(f"Skill directory does not exist: {skill_dir}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
skill_md = skill_dir / "SKILL.md"
|
||||||
|
if not skill_md.exists():
|
||||||
|
logger.warning(f"SKILL.md not found in: {skill_dir}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析元数据
|
||||||
|
metadata = parse_skill_metadata(skill_dir, source=source)
|
||||||
|
|
||||||
|
# 读取完整内容
|
||||||
|
content = skill_md.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
# 提取body (去掉frontmatter)
|
||||||
|
body = content
|
||||||
|
if content.startswith("---"):
|
||||||
|
parts = content.split("---", 2)
|
||||||
|
if len(parts) >= 3:
|
||||||
|
body = parts[2].strip()
|
||||||
|
|
||||||
|
# 获取工具列表
|
||||||
|
tools = get_skill_tools(skill_dir)
|
||||||
|
|
||||||
|
# 获取脚本列表
|
||||||
|
scripts = _get_skill_scripts(skill_dir)
|
||||||
|
|
||||||
|
# 获取参考资料列表
|
||||||
|
references = _get_skill_references(skill_dir)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": metadata.name,
|
||||||
|
"skill_name": metadata.skill_name,
|
||||||
|
"description": metadata.description,
|
||||||
|
"version": metadata.version,
|
||||||
|
"source": source,
|
||||||
|
"path": str(skill_dir),
|
||||||
|
"content": body,
|
||||||
|
"tools": tools,
|
||||||
|
"scripts": scripts,
|
||||||
|
"references": references,
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load skill from {skill_dir}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def parse_skill_metadata(skill_dir: Path, source: str = "unknown") -> SkillMetadata:
|
||||||
|
"""解析技能元数据 (兼容已有函数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: 技能目录路径
|
||||||
|
source: 技能来源
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SkillMetadata对象
|
||||||
|
"""
|
||||||
|
from backend.agents.skill_metadata import parse_skill_metadata as _parse
|
||||||
|
return _parse(skill_dir, source=source)
|
||||||
|
|
||||||
|
|
||||||
|
def get_skill_tools(skill_dir: Path) -> List[str]:
|
||||||
|
"""获取技能提供的工具列表
|
||||||
|
|
||||||
|
从SKILL.md frontmatter的tools字段和scripts目录解析工具。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: 技能目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具名称列表
|
||||||
|
"""
|
||||||
|
tools: Set[str] = set()
|
||||||
|
|
||||||
|
# 1. 从SKILL.md frontmatter读取tools字段
|
||||||
|
skill_md = skill_dir / "SKILL.md"
|
||||||
|
if skill_md.exists():
|
||||||
|
try:
|
||||||
|
raw = skill_md.read_text(encoding="utf-8").strip()
|
||||||
|
if raw.startswith("---"):
|
||||||
|
parts = raw.split("---", 2)
|
||||||
|
if len(parts) >= 3:
|
||||||
|
try:
|
||||||
|
frontmatter = yaml.safe_load(parts[1].strip()) or {}
|
||||||
|
if isinstance(frontmatter, dict):
|
||||||
|
tools_list = frontmatter.get("tools", [])
|
||||||
|
if isinstance(tools_list, str):
|
||||||
|
tools.add(tools_list.strip())
|
||||||
|
elif isinstance(tools_list, list):
|
||||||
|
for tool in tools_list:
|
||||||
|
if isinstance(tool, str):
|
||||||
|
tools.add(tool.strip())
|
||||||
|
except yaml.YAMLError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse tools from SKILL.md: {e}")
|
||||||
|
|
||||||
|
# 2. 从scripts目录推断工具
|
||||||
|
scripts_dir = skill_dir / "scripts"
|
||||||
|
if scripts_dir.exists() and scripts_dir.is_dir():
|
||||||
|
for script in scripts_dir.iterdir():
|
||||||
|
if script.is_file() and not script.name.startswith("_"):
|
||||||
|
# 去掉扩展名作为工具名
|
||||||
|
tool_name = script.stem
|
||||||
|
tools.add(tool_name)
|
||||||
|
|
||||||
|
return sorted(list(tools))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_skill_scripts(skill_dir: Path) -> List[str]:
|
||||||
|
"""获取技能脚本列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: 技能目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
脚本相对路径列表 (相对于scripts目录)
|
||||||
|
"""
|
||||||
|
scripts: List[str] = []
|
||||||
|
scripts_dir = skill_dir / "scripts"
|
||||||
|
|
||||||
|
if not scripts_dir.exists():
|
||||||
|
return scripts
|
||||||
|
|
||||||
|
try:
|
||||||
|
for item in scripts_dir.rglob("*"):
|
||||||
|
if item.is_file() and not item.name.startswith("_"):
|
||||||
|
rel_path = item.relative_to(scripts_dir)
|
||||||
|
scripts.append(str(rel_path))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to list scripts in {skill_dir}: {e}")
|
||||||
|
|
||||||
|
return sorted(scripts)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_skill_references(skill_dir: Path) -> List[str]:
|
||||||
|
"""获取技能参考资料列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: 技能目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
参考资料相对路径列表 (相对于references目录)
|
||||||
|
"""
|
||||||
|
refs: List[str] = []
|
||||||
|
refs_dir = skill_dir / "references"
|
||||||
|
|
||||||
|
if not refs_dir.exists():
|
||||||
|
return refs
|
||||||
|
|
||||||
|
try:
|
||||||
|
for item in refs_dir.rglob("*"):
|
||||||
|
if item.is_file():
|
||||||
|
rel_path = item.relative_to(refs_dir)
|
||||||
|
refs.append(str(rel_path))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to list references in {skill_dir}: {e}")
|
||||||
|
|
||||||
|
return sorted(refs)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_skill(skill_dir: Path) -> Dict[str, Any]:
|
||||||
|
"""验证技能格式
|
||||||
|
|
||||||
|
检查技能目录结构是否符合规范。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: 技能目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
验证结果字典,包含:
|
||||||
|
- valid: 是否有效
|
||||||
|
- errors: 错误列表
|
||||||
|
- warnings: 警告列表
|
||||||
|
"""
|
||||||
|
errors: List[str] = []
|
||||||
|
warnings: List[str] = []
|
||||||
|
|
||||||
|
# 检查目录存在
|
||||||
|
if not skill_dir.exists():
|
||||||
|
errors.append(f"Skill directory does not exist: {skill_dir}")
|
||||||
|
return {"valid": False, "errors": errors, "warnings": warnings}
|
||||||
|
|
||||||
|
if not skill_dir.is_dir():
|
||||||
|
errors.append(f"Path is not a directory: {skill_dir}")
|
||||||
|
return {"valid": False, "errors": errors, "warnings": warnings}
|
||||||
|
|
||||||
|
# 检查SKILL.md
|
||||||
|
skill_md = skill_dir / "SKILL.md"
|
||||||
|
if not skill_md.exists():
|
||||||
|
errors.append("SKILL.md is required but not found")
|
||||||
|
return {"valid": False, "errors": errors, "warnings": warnings}
|
||||||
|
|
||||||
|
# 解析frontmatter
|
||||||
|
try:
|
||||||
|
content = skill_md.read_text(encoding="utf-8").strip()
|
||||||
|
if not content.startswith("---"):
|
||||||
|
warnings.append("SKILL.md should have YAML frontmatter (starts with ---)")
|
||||||
|
else:
|
||||||
|
parts = content.split("---", 2)
|
||||||
|
if len(parts) < 3:
|
||||||
|
errors.append("Invalid YAML frontmatter format")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
frontmatter = yaml.safe_load(parts[1].strip()) or {}
|
||||||
|
if not isinstance(frontmatter, dict):
|
||||||
|
errors.append("YAML frontmatter must be a dictionary")
|
||||||
|
else:
|
||||||
|
# 检查必需字段
|
||||||
|
if "name" not in frontmatter:
|
||||||
|
warnings.append("Frontmatter should have 'name' field")
|
||||||
|
if "description" not in frontmatter:
|
||||||
|
warnings.append("Frontmatter should have 'description' field")
|
||||||
|
|
||||||
|
# 检查version字段
|
||||||
|
version = frontmatter.get("version")
|
||||||
|
if version and not isinstance(version, str):
|
||||||
|
warnings.append("'version' should be a string")
|
||||||
|
|
||||||
|
# 检查tools字段
|
||||||
|
tools = frontmatter.get("tools")
|
||||||
|
if tools and not isinstance(tools, (str, list)):
|
||||||
|
warnings.append("'tools' should be a string or list")
|
||||||
|
|
||||||
|
except yaml.YAMLError as e:
|
||||||
|
errors.append(f"Invalid YAML in frontmatter: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(f"Failed to read SKILL.md: {e}")
|
||||||
|
|
||||||
|
# 检查body内容
|
||||||
|
try:
|
||||||
|
content = skill_md.read_text(encoding="utf-8")
|
||||||
|
body = content
|
||||||
|
if content.startswith("---"):
|
||||||
|
parts = content.split("---", 2)
|
||||||
|
if len(parts) >= 3:
|
||||||
|
body = parts[2].strip()
|
||||||
|
|
||||||
|
if not body:
|
||||||
|
warnings.append("SKILL.md body is empty")
|
||||||
|
elif len(body) < 50:
|
||||||
|
warnings.append("SKILL.md body is very short, consider adding more details")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(f"Failed to validate body: {e}")
|
||||||
|
|
||||||
|
# 检查scripts目录
|
||||||
|
scripts_dir = skill_dir / "scripts"
|
||||||
|
if scripts_dir.exists():
|
||||||
|
if not scripts_dir.is_dir():
|
||||||
|
errors.append("'scripts' exists but is not a directory")
|
||||||
|
else:
|
||||||
|
# 检查是否有可执行脚本
|
||||||
|
has_scripts = any(
|
||||||
|
f.is_file() and not f.name.startswith("_")
|
||||||
|
for f in scripts_dir.iterdir()
|
||||||
|
)
|
||||||
|
if not has_scripts:
|
||||||
|
warnings.append("scripts directory exists but contains no valid scripts")
|
||||||
|
|
||||||
|
# 检查references目录
|
||||||
|
refs_dir = skill_dir / "references"
|
||||||
|
if refs_dir.exists() and not refs_dir.is_dir():
|
||||||
|
errors.append("'references' exists but is not a directory")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"valid": len(errors) == 0,
|
||||||
|
"errors": errors,
|
||||||
|
"warnings": warnings,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_skills_from_directory(
|
||||||
|
directory: Path,
|
||||||
|
source: str = "unknown",
|
||||||
|
recursive: bool = False,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""从目录加载所有技能
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: 包含技能目录的父目录
|
||||||
|
source: 技能来源标识
|
||||||
|
recursive: 是否递归搜索子目录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
技能信息列表
|
||||||
|
"""
|
||||||
|
skills: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
if not directory.exists() or not directory.is_dir():
|
||||||
|
logger.warning(f"Directory does not exist: {directory}")
|
||||||
|
return skills
|
||||||
|
|
||||||
|
try:
|
||||||
|
for item in directory.iterdir():
|
||||||
|
if not item.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查是否是技能目录 (包含SKILL.md)
|
||||||
|
if (item / "SKILL.md").exists():
|
||||||
|
skill_info = load_skill_from_dir(item, source=source)
|
||||||
|
if skill_info:
|
||||||
|
skills.append(skill_info)
|
||||||
|
elif recursive:
|
||||||
|
# 递归搜索子目录
|
||||||
|
sub_skills = load_skills_from_directory(item, source, recursive)
|
||||||
|
skills.extend(sub_skills)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load skills from {directory}: {e}")
|
||||||
|
|
||||||
|
return skills
|
||||||
|
|
||||||
|
|
||||||
|
def get_skill_manifest(skill_dir: Path) -> Dict[str, Any]:
|
||||||
|
"""获取技能清单
|
||||||
|
|
||||||
|
生成技能的详细清单,用于调试和展示。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill_dir: 技能目录路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
技能清单字典
|
||||||
|
"""
|
||||||
|
info = load_skill_from_dir(skill_dir)
|
||||||
|
if not info:
|
||||||
|
return {"error": "Failed to load skill"}
|
||||||
|
|
||||||
|
validation = validate_skill(skill_dir)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": info["name"],
|
||||||
|
"skill_name": info["skill_name"],
|
||||||
|
"version": info["version"],
|
||||||
|
"description": info["description"],
|
||||||
|
"source": info["source"],
|
||||||
|
"path": info["path"],
|
||||||
|
"tools": info["tools"],
|
||||||
|
"scripts": info["scripts"],
|
||||||
|
"references": info["references"],
|
||||||
|
"validation": validation,
|
||||||
|
"content_preview": info["content"][:500] + "..." if len(info["content"]) > 500 else info["content"],
|
||||||
|
}
|
||||||
83
backend/agents/skill_metadata.py
Normal file
83
backend/agents/skill_metadata.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Skill metadata parsing helpers for SKILL.md files."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class SkillMetadata:
|
||||||
|
"""Parsed metadata for a skill package."""
|
||||||
|
|
||||||
|
skill_name: str
|
||||||
|
path: Path
|
||||||
|
source: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
version: str = ""
|
||||||
|
tools: List[str] = field(default_factory=list)
|
||||||
|
allowed_tools: List[str] = field(default_factory=list)
|
||||||
|
denied_tools: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_skill_metadata(skill_dir: Path, source: str) -> SkillMetadata:
|
||||||
|
"""Parse SKILL.md frontmatter with a forgiving schema."""
|
||||||
|
skill_name = skill_dir.name
|
||||||
|
skill_file = skill_dir / "SKILL.md"
|
||||||
|
if not skill_file.exists():
|
||||||
|
return SkillMetadata(
|
||||||
|
skill_name=skill_name,
|
||||||
|
path=skill_dir,
|
||||||
|
source=source,
|
||||||
|
name=skill_name,
|
||||||
|
description="",
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = skill_file.read_text(encoding="utf-8").strip()
|
||||||
|
frontmatter = {}
|
||||||
|
body = raw
|
||||||
|
if raw.startswith("---"):
|
||||||
|
parts = raw.split("---", 2)
|
||||||
|
if len(parts) >= 3:
|
||||||
|
try:
|
||||||
|
frontmatter = yaml.safe_load(parts[1].strip()) or {}
|
||||||
|
except yaml.YAMLError:
|
||||||
|
frontmatter = {}
|
||||||
|
body = parts[2].strip()
|
||||||
|
if not isinstance(frontmatter, dict):
|
||||||
|
frontmatter = {}
|
||||||
|
|
||||||
|
description = str(frontmatter.get("description") or "").strip()
|
||||||
|
if not description and body:
|
||||||
|
description = body.splitlines()[0].strip().lstrip("#").strip()
|
||||||
|
|
||||||
|
return SkillMetadata(
|
||||||
|
skill_name=skill_name,
|
||||||
|
path=skill_dir,
|
||||||
|
source=source,
|
||||||
|
name=str(frontmatter.get("name") or skill_name).strip() or skill_name,
|
||||||
|
description=description,
|
||||||
|
version=str(frontmatter.get("version") or "").strip(),
|
||||||
|
tools=_string_list(frontmatter.get("tools")),
|
||||||
|
allowed_tools=_string_list(frontmatter.get("allowed_tools")),
|
||||||
|
denied_tools=_string_list(frontmatter.get("denied_tools")),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _string_list(value) -> List[str]:
|
||||||
|
if isinstance(value, str):
|
||||||
|
item = value.strip()
|
||||||
|
return [item] if item else []
|
||||||
|
if not isinstance(value, list):
|
||||||
|
return []
|
||||||
|
seen: List[str] = []
|
||||||
|
for item in value:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
continue
|
||||||
|
normalized = item.strip()
|
||||||
|
if normalized and normalized not in seen:
|
||||||
|
seen.append(normalized)
|
||||||
|
return seen
|
||||||
@@ -1,14 +1,32 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Manage builtin/customized/active skill directories for each run."""
|
"""Manage agent-installed and run-active skill directories for each run."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import shutil
|
import shutil
|
||||||
from typing import Dict, Iterable, List
|
import tempfile
|
||||||
|
import zipfile
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
from urllib.request import urlretrieve
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||||
|
from backend.agents.skill_metadata import SkillMetadata, parse_skill_metadata
|
||||||
|
from backend.agents.skill_loader import validate_skill
|
||||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||||
|
|
||||||
|
try:
|
||||||
|
from watchdog.observers import Observer
|
||||||
|
from watchdog.events import FileSystemEventHandler, FileSystemEvent
|
||||||
|
WATCHDOG_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
WATCHDOG_AVAILABLE = False
|
||||||
|
Observer = None
|
||||||
|
FileSystemEventHandler = object
|
||||||
|
FileSystemEvent = object # type: ignore[misc,assignment]
|
||||||
|
|
||||||
|
|
||||||
class SkillsManager:
|
class SkillsManager:
|
||||||
"""Sync named skills into a run-scoped active skills workspace."""
|
"""Sync named skills into a run-scoped active skills workspace."""
|
||||||
@@ -22,16 +40,391 @@ class SkillsManager:
|
|||||||
self.project_root / "backend" / "skills" / "customized"
|
self.project_root / "backend" / "skills" / "customized"
|
||||||
)
|
)
|
||||||
self.runs_root = self.project_root / "runs"
|
self.runs_root = self.project_root / "runs"
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
def get_active_root(self, config_name: str) -> Path:
|
def get_active_root(self, config_name: str) -> Path:
|
||||||
return self.runs_root / config_name / "skills" / "active"
|
return self.runs_root / config_name / "skills" / "active"
|
||||||
|
|
||||||
|
def get_agent_skills_root(self, config_name: str, agent_id: str) -> Path:
|
||||||
|
return self.get_agent_asset_dir(config_name, agent_id) / "skills"
|
||||||
|
|
||||||
|
def get_agent_active_root(self, config_name: str, agent_id: str) -> Path:
|
||||||
|
return self.get_agent_skills_root(config_name, agent_id) / "active"
|
||||||
|
|
||||||
|
def get_agent_installed_root(self, config_name: str, agent_id: str) -> Path:
|
||||||
|
return self.get_agent_skills_root(config_name, agent_id) / "installed"
|
||||||
|
|
||||||
|
def get_agent_disabled_root(self, config_name: str, agent_id: str) -> Path:
|
||||||
|
return self.get_agent_skills_root(config_name, agent_id) / "disabled"
|
||||||
|
|
||||||
|
def get_agent_local_root(self, config_name: str, agent_id: str) -> Path:
|
||||||
|
return self.get_agent_skills_root(config_name, agent_id) / "local"
|
||||||
|
|
||||||
def get_activation_manifest_path(self, config_name: str) -> Path:
|
def get_activation_manifest_path(self, config_name: str) -> Path:
|
||||||
return self.runs_root / config_name / "skills" / "activation.yaml"
|
return self.runs_root / config_name / "skills" / "activation.yaml"
|
||||||
|
|
||||||
def get_agent_asset_dir(self, config_name: str, agent_id: str) -> Path:
|
def get_agent_asset_dir(self, config_name: str, agent_id: str) -> Path:
|
||||||
return self.runs_root / config_name / "agents" / agent_id
|
return self.runs_root / config_name / "agents" / agent_id
|
||||||
|
|
||||||
|
def list_skill_catalog(self) -> List[SkillMetadata]:
|
||||||
|
"""Return builtin/customized skills with parsed metadata."""
|
||||||
|
catalog: Dict[str, SkillMetadata] = {}
|
||||||
|
|
||||||
|
for source, root in (
|
||||||
|
("builtin", self.builtin_root),
|
||||||
|
("customized", self.customized_root),
|
||||||
|
):
|
||||||
|
if not root.exists():
|
||||||
|
continue
|
||||||
|
for skill_dir in sorted(root.iterdir(), key=lambda item: item.name):
|
||||||
|
if not skill_dir.is_dir():
|
||||||
|
continue
|
||||||
|
if not (skill_dir / "SKILL.md").exists():
|
||||||
|
continue
|
||||||
|
metadata = parse_skill_metadata(skill_dir, source=source)
|
||||||
|
catalog[metadata.skill_name] = metadata
|
||||||
|
|
||||||
|
return sorted(catalog.values(), key=lambda item: item.skill_name)
|
||||||
|
|
||||||
|
def list_agent_skill_catalog(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
) -> List[SkillMetadata]:
|
||||||
|
"""Return shared plus agent-local skills for one agent."""
|
||||||
|
catalog = {
|
||||||
|
item.skill_name: item
|
||||||
|
for item in self.list_skill_catalog()
|
||||||
|
}
|
||||||
|
for item in self.list_agent_local_skills(config_name, agent_id):
|
||||||
|
catalog[item.skill_name] = item
|
||||||
|
return sorted(catalog.values(), key=lambda item: item.skill_name)
|
||||||
|
|
||||||
|
def list_active_skill_metadata(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
) -> List[SkillMetadata]:
|
||||||
|
"""Return metadata for active skills synced for one agent."""
|
||||||
|
active_root = self.get_agent_active_root(config_name, agent_id)
|
||||||
|
if not active_root.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
items: List[SkillMetadata] = []
|
||||||
|
for skill_dir in sorted(active_root.iterdir(), key=lambda item: item.name):
|
||||||
|
if not skill_dir.is_dir():
|
||||||
|
continue
|
||||||
|
if not (skill_dir / "SKILL.md").exists():
|
||||||
|
continue
|
||||||
|
items.append(parse_skill_metadata(skill_dir, source="active"))
|
||||||
|
return items
|
||||||
|
|
||||||
|
def list_agent_local_skills(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
) -> List[SkillMetadata]:
|
||||||
|
"""Return metadata for agent-private local skills."""
|
||||||
|
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||||
|
if not local_root.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
items: List[SkillMetadata] = []
|
||||||
|
for skill_dir in sorted(local_root.iterdir(), key=lambda item: item.name):
|
||||||
|
if not skill_dir.is_dir():
|
||||||
|
continue
|
||||||
|
if not (skill_dir / "SKILL.md").exists():
|
||||||
|
continue
|
||||||
|
items.append(parse_skill_metadata(skill_dir, source="local"))
|
||||||
|
return items
|
||||||
|
|
||||||
|
def load_skill_document(self, skill_name: str) -> Dict[str, object]:
|
||||||
|
"""Return skill metadata plus markdown body for one skill."""
|
||||||
|
source_dir = self._resolve_source_dir(skill_name)
|
||||||
|
return self._load_skill_document_from_dir(
|
||||||
|
source_dir,
|
||||||
|
source="customized" if source_dir.parent == self.customized_root else "builtin",
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_agent_skill_document(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_name: str,
|
||||||
|
) -> Dict[str, object]:
|
||||||
|
"""Return skill metadata plus markdown body for one agent-visible skill."""
|
||||||
|
source_dir = self._resolve_agent_skill_source_dir(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
skill_name=skill_name,
|
||||||
|
)
|
||||||
|
source = "local"
|
||||||
|
if source_dir.parent == self.customized_root:
|
||||||
|
source = "customized"
|
||||||
|
elif source_dir.parent == self.builtin_root:
|
||||||
|
source = "builtin"
|
||||||
|
elif source_dir.parent == self.get_agent_installed_root(config_name, agent_id):
|
||||||
|
source = "installed"
|
||||||
|
return self._load_skill_document_from_dir(source_dir, source=source)
|
||||||
|
|
||||||
|
def create_agent_local_skill(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_name: str,
|
||||||
|
) -> Path:
|
||||||
|
"""Create a new local skill directory with a default SKILL.md."""
|
||||||
|
normalized = _normalize_skill_name(skill_name)
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("Skill name is required.")
|
||||||
|
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||||
|
local_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
skill_dir = local_root / normalized
|
||||||
|
if skill_dir.exists():
|
||||||
|
raise FileExistsError(f"Local skill already exists: {normalized}")
|
||||||
|
skill_dir.mkdir(parents=True, exist_ok=False)
|
||||||
|
(skill_dir / "SKILL.md").write_text(
|
||||||
|
"---\n"
|
||||||
|
f"name: {normalized}\n"
|
||||||
|
"description: 当用户提出与该本地技能相关的专门任务时,应使用此技能。\n"
|
||||||
|
"version: 1.0.0\n"
|
||||||
|
"---\n\n"
|
||||||
|
f"# {normalized}\n\n"
|
||||||
|
"在这里描述该交易员的专有分析流程、判断框架和可复用步骤。\n",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return skill_dir
|
||||||
|
|
||||||
|
def install_external_skill_for_agent(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
source: str,
|
||||||
|
*,
|
||||||
|
skill_name: str | None = None,
|
||||||
|
activate: bool = True,
|
||||||
|
) -> Dict[str, object]:
|
||||||
|
"""
|
||||||
|
Install an external skill into one agent's local skill space.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- local skill directory containing SKILL.md
|
||||||
|
- local zip archive containing one skill directory
|
||||||
|
- http(s) URL to zip archive
|
||||||
|
"""
|
||||||
|
source_path = self._resolve_external_source_path(source)
|
||||||
|
skill_dir = self._resolve_external_skill_dir(source_path)
|
||||||
|
metadata = parse_skill_metadata(skill_dir, source="external")
|
||||||
|
final_name = _normalize_skill_name(skill_name or metadata.skill_name or skill_dir.name)
|
||||||
|
if not final_name:
|
||||||
|
raise ValueError("Could not determine skill name from external source.")
|
||||||
|
|
||||||
|
target_dir = self.get_agent_local_root(config_name, agent_id) / final_name
|
||||||
|
target_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if target_dir.exists():
|
||||||
|
shutil.rmtree(target_dir)
|
||||||
|
shutil.copytree(skill_dir, target_dir)
|
||||||
|
|
||||||
|
validation = validate_skill(target_dir)
|
||||||
|
if not validation.get("valid", False):
|
||||||
|
shutil.rmtree(target_dir, ignore_errors=True)
|
||||||
|
raise ValueError(
|
||||||
|
"Installed skill is invalid: "
|
||||||
|
+ "; ".join(validation.get("errors", []))
|
||||||
|
)
|
||||||
|
|
||||||
|
if activate:
|
||||||
|
self.update_agent_skill_overrides(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
enable=[final_name],
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"skill_name": final_name,
|
||||||
|
"target_dir": str(target_dir),
|
||||||
|
"activated": activate,
|
||||||
|
"warnings": validation.get("warnings", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
def update_agent_local_skill(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_name: str,
|
||||||
|
content: str,
|
||||||
|
) -> Path:
|
||||||
|
"""Overwrite one agent-local SKILL.md."""
|
||||||
|
normalized = _normalize_skill_name(skill_name)
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("Skill name is required.")
|
||||||
|
skill_dir = self.get_agent_local_root(config_name, agent_id) / normalized
|
||||||
|
if not skill_dir.exists():
|
||||||
|
raise FileNotFoundError(f"Unknown local skill: {normalized}")
|
||||||
|
(skill_dir / "SKILL.md").write_text(content, encoding="utf-8")
|
||||||
|
return skill_dir
|
||||||
|
|
||||||
|
def delete_agent_local_skill(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""Delete one agent-local skill directory."""
|
||||||
|
normalized = _normalize_skill_name(skill_name)
|
||||||
|
if not normalized:
|
||||||
|
raise ValueError("Skill name is required.")
|
||||||
|
skill_dir = self.get_agent_local_root(config_name, agent_id) / normalized
|
||||||
|
if not skill_dir.exists():
|
||||||
|
raise FileNotFoundError(f"Unknown local skill: {normalized}")
|
||||||
|
shutil.rmtree(skill_dir)
|
||||||
|
|
||||||
|
def _load_skill_document_from_dir(
|
||||||
|
self,
|
||||||
|
source_dir: Path,
|
||||||
|
*,
|
||||||
|
source: str,
|
||||||
|
) -> Dict[str, object]:
|
||||||
|
"""Return metadata plus markdown body for one resolved skill directory."""
|
||||||
|
metadata = parse_skill_metadata(
|
||||||
|
source_dir,
|
||||||
|
source=source,
|
||||||
|
)
|
||||||
|
skill_file = source_dir / "SKILL.md"
|
||||||
|
raw = skill_file.read_text(encoding="utf-8").strip() if skill_file.exists() else ""
|
||||||
|
body = raw
|
||||||
|
if raw.startswith("---"):
|
||||||
|
parts = raw.split("---", 2)
|
||||||
|
if len(parts) >= 3:
|
||||||
|
body = parts[2].strip()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"skill_name": metadata.skill_name,
|
||||||
|
"name": metadata.name,
|
||||||
|
"description": metadata.description,
|
||||||
|
"version": metadata.version,
|
||||||
|
"tools": metadata.tools,
|
||||||
|
"source": metadata.source,
|
||||||
|
"content": body,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _resolve_external_source_path(self, source: str) -> Path:
|
||||||
|
"""Resolve source into a local path; download URL when needed."""
|
||||||
|
parsed = urlparse(source)
|
||||||
|
if parsed.scheme in {"http", "https"}:
|
||||||
|
suffix = Path(parsed.path).suffix or ".zip"
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
||||||
|
temp_path = Path(tmp.name)
|
||||||
|
urlretrieve(source, temp_path)
|
||||||
|
return temp_path
|
||||||
|
return Path(source).expanduser().resolve()
|
||||||
|
|
||||||
|
def _resolve_external_skill_dir(self, source_path: Path) -> Path:
|
||||||
|
"""Resolve external source path to a skill directory containing SKILL.md."""
|
||||||
|
if not source_path.exists():
|
||||||
|
raise FileNotFoundError(f"Source does not exist: {source_path}")
|
||||||
|
|
||||||
|
if source_path.is_dir():
|
||||||
|
if (source_path / "SKILL.md").exists():
|
||||||
|
return source_path
|
||||||
|
children = [
|
||||||
|
item for item in source_path.iterdir()
|
||||||
|
if item.is_dir() and (item / "SKILL.md").exists()
|
||||||
|
]
|
||||||
|
if len(children) == 1:
|
||||||
|
return children[0]
|
||||||
|
raise ValueError(
|
||||||
|
"Source directory must contain SKILL.md "
|
||||||
|
"or exactly one child directory containing SKILL.md."
|
||||||
|
)
|
||||||
|
|
||||||
|
if source_path.suffix.lower() != ".zip":
|
||||||
|
raise ValueError("External source file must be a .zip archive.")
|
||||||
|
|
||||||
|
temp_root = Path(tempfile.mkdtemp(prefix="external_skill_"))
|
||||||
|
with zipfile.ZipFile(source_path, "r") as archive:
|
||||||
|
archive.extractall(temp_root)
|
||||||
|
|
||||||
|
candidates = [
|
||||||
|
item.parent
|
||||||
|
for item in temp_root.rglob("SKILL.md")
|
||||||
|
if item.is_file()
|
||||||
|
]
|
||||||
|
unique = []
|
||||||
|
for item in candidates:
|
||||||
|
if item not in unique:
|
||||||
|
unique.append(item)
|
||||||
|
if len(unique) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Zip archive must contain exactly one skill directory with SKILL.md."
|
||||||
|
)
|
||||||
|
return unique[0]
|
||||||
|
|
||||||
|
def update_agent_skill_overrides(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
*,
|
||||||
|
enable: Iterable[str] | None = None,
|
||||||
|
disable: Iterable[str] | None = None,
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
|
"""Persist per-agent enabled/disabled skill overrides in agent.yaml."""
|
||||||
|
asset_dir = self.get_agent_asset_dir(config_name, agent_id)
|
||||||
|
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_path = asset_dir / "agent.yaml"
|
||||||
|
current = load_agent_workspace_config(config_path)
|
||||||
|
values = dict(current.values)
|
||||||
|
|
||||||
|
enabled = _dedupe_preserve_order(current.enabled_skills)
|
||||||
|
disabled_set = set(current.disabled_skills)
|
||||||
|
|
||||||
|
for skill_name in enable or []:
|
||||||
|
if skill_name not in enabled:
|
||||||
|
enabled.append(skill_name)
|
||||||
|
disabled_set.discard(skill_name)
|
||||||
|
|
||||||
|
for skill_name in disable or []:
|
||||||
|
disabled_set.add(skill_name)
|
||||||
|
enabled = [item for item in enabled if item != skill_name]
|
||||||
|
|
||||||
|
values["enabled_skills"] = enabled
|
||||||
|
values["disabled_skills"] = sorted(disabled_set)
|
||||||
|
config_path.write_text(
|
||||||
|
yaml.safe_dump(values, allow_unicode=True, sort_keys=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"enabled_skills": enabled,
|
||||||
|
"disabled_skills": sorted(disabled_set),
|
||||||
|
}
|
||||||
|
|
||||||
|
def forget_agent_skill_overrides(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_names: Iterable[str],
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
|
"""Remove skills from both enabled/disabled overrides in agent.yaml."""
|
||||||
|
asset_dir = self.get_agent_asset_dir(config_name, agent_id)
|
||||||
|
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
config_path = asset_dir / "agent.yaml"
|
||||||
|
current = load_agent_workspace_config(config_path)
|
||||||
|
values = dict(current.values)
|
||||||
|
removed = set(skill_names)
|
||||||
|
|
||||||
|
enabled = [item for item in current.enabled_skills if item not in removed]
|
||||||
|
disabled = [item for item in current.disabled_skills if item not in removed]
|
||||||
|
|
||||||
|
values["enabled_skills"] = enabled
|
||||||
|
values["disabled_skills"] = disabled
|
||||||
|
config_path.write_text(
|
||||||
|
yaml.safe_dump(values, allow_unicode=True, sort_keys=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"enabled_skills": enabled,
|
||||||
|
"disabled_skills": disabled,
|
||||||
|
}
|
||||||
|
|
||||||
def ensure_activation_manifest(self, config_name: str) -> Path:
|
def ensure_activation_manifest(self, config_name: str) -> Path:
|
||||||
manifest_path = self.get_activation_manifest_path(config_name)
|
manifest_path = self.get_activation_manifest_path(config_name)
|
||||||
manifest_path.parent.mkdir(parents=True, exist_ok=True)
|
manifest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -62,6 +455,87 @@ class SkillsManager:
|
|||||||
|
|
||||||
raise FileNotFoundError(f"Unknown skill: {skill_name}")
|
raise FileNotFoundError(f"Unknown skill: {skill_name}")
|
||||||
|
|
||||||
|
def _resolve_agent_skill_source_dir(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_name: str,
|
||||||
|
) -> Path:
|
||||||
|
"""Resolve one skill from the agent-local workspace or shared registry."""
|
||||||
|
for root in (
|
||||||
|
self.get_agent_local_root(config_name, agent_id),
|
||||||
|
self.get_agent_installed_root(config_name, agent_id),
|
||||||
|
):
|
||||||
|
candidate = root / skill_name
|
||||||
|
if candidate.exists() and (candidate / "SKILL.md").exists():
|
||||||
|
return candidate
|
||||||
|
return self._resolve_source_dir(skill_name)
|
||||||
|
|
||||||
|
def _skill_exists_for_agent(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_name: str,
|
||||||
|
) -> bool:
|
||||||
|
try:
|
||||||
|
self._resolve_agent_skill_source_dir(config_name, agent_id, skill_name)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _persist_runtime_edits(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
skill_name: str,
|
||||||
|
active_dir: Path,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Persist run-time edits from active skills into customized skills.
|
||||||
|
|
||||||
|
This keeps active skill experiments from being lost on the next reload
|
||||||
|
while still allowing the active directory to be re-synced cleanly.
|
||||||
|
"""
|
||||||
|
if not active_dir.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
source_dir = self._resolve_source_dir(skill_name)
|
||||||
|
if active_dir.resolve() == source_dir.resolve():
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self._directories_match(active_dir, source_dir):
|
||||||
|
customized_dir = self.customized_root / skill_name
|
||||||
|
customized_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if customized_dir.exists():
|
||||||
|
shutil.rmtree(customized_dir)
|
||||||
|
shutil.copytree(active_dir, customized_dir)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _directories_match(left: Path, right: Path) -> bool:
|
||||||
|
"""Compare two directory trees by file contents."""
|
||||||
|
if not left.exists() or not right.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
left_items = sorted(
|
||||||
|
path.relative_to(left)
|
||||||
|
for path in left.rglob("*")
|
||||||
|
)
|
||||||
|
right_items = sorted(
|
||||||
|
path.relative_to(right)
|
||||||
|
for path in right.rglob("*")
|
||||||
|
)
|
||||||
|
if left_items != right_items:
|
||||||
|
return False
|
||||||
|
|
||||||
|
for relative_path in left_items:
|
||||||
|
left_path = left / relative_path
|
||||||
|
right_path = right / relative_path
|
||||||
|
if left_path.is_dir() != right_path.is_dir():
|
||||||
|
return False
|
||||||
|
if left_path.is_file():
|
||||||
|
if left_path.read_bytes() != right_path.read_bytes():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def resolve_agent_skill_names(
|
def resolve_agent_skill_names(
|
||||||
self,
|
self,
|
||||||
config_name: str,
|
config_name: str,
|
||||||
@@ -72,6 +546,13 @@ class SkillsManager:
|
|||||||
bootstrap = get_bootstrap_config_for_run(self.project_root, config_name)
|
bootstrap = get_bootstrap_config_for_run(self.project_root, config_name)
|
||||||
override = bootstrap.agent_override(agent_id)
|
override = bootstrap.agent_override(agent_id)
|
||||||
skills = list(override.get("skills", list(default_skills)))
|
skills = list(override.get("skills", list(default_skills)))
|
||||||
|
agent_config = load_agent_workspace_config(
|
||||||
|
self.get_agent_asset_dir(config_name, agent_id) / "agent.yaml",
|
||||||
|
)
|
||||||
|
|
||||||
|
for skill_name in agent_config.enabled_skills:
|
||||||
|
if skill_name not in skills:
|
||||||
|
skills.append(skill_name)
|
||||||
|
|
||||||
manifest = self.load_activation_manifest(config_name)
|
manifest = self.load_activation_manifest(config_name)
|
||||||
for skill_name in manifest.get("global_enabled_skills", []):
|
for skill_name in manifest.get("global_enabled_skills", []):
|
||||||
@@ -86,28 +567,36 @@ class SkillsManager:
|
|||||||
disabled.update(
|
disabled.update(
|
||||||
manifest.get("agent_disabled_skills", {}).get(agent_id, []),
|
manifest.get("agent_disabled_skills", {}).get(agent_id, []),
|
||||||
)
|
)
|
||||||
|
disabled.update(agent_config.disabled_skills)
|
||||||
|
|
||||||
return [skill for skill in skills if skill not in disabled]
|
for item in self.list_agent_local_skills(config_name, agent_id):
|
||||||
|
if item.skill_name not in skills:
|
||||||
|
skills.append(item.skill_name)
|
||||||
|
|
||||||
def sync_active_skills(
|
return [
|
||||||
|
skill
|
||||||
|
for skill in skills
|
||||||
|
if skill not in disabled
|
||||||
|
and self._skill_exists_for_agent(config_name, agent_id, skill)
|
||||||
|
]
|
||||||
|
|
||||||
|
def sync_skill_dirs(
|
||||||
self,
|
self,
|
||||||
config_name: str,
|
target_root: Path,
|
||||||
skill_names: Iterable[str],
|
skill_sources: Dict[str, Path],
|
||||||
) -> List[Path]:
|
) -> List[Path]:
|
||||||
"""Sync selected skills into the run workspace and return their paths."""
|
"""Sync selected skill directories into one target root."""
|
||||||
active_root = self.get_active_root(config_name)
|
target_root.mkdir(parents=True, exist_ok=True)
|
||||||
active_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
synced_paths: List[Path] = []
|
synced_paths: List[Path] = []
|
||||||
wanted = set(skill_names)
|
wanted = set(skill_sources)
|
||||||
|
|
||||||
for existing in active_root.iterdir():
|
for existing in target_root.iterdir():
|
||||||
if existing.is_dir() and existing.name not in wanted:
|
if existing.is_dir() and existing.name not in wanted:
|
||||||
shutil.rmtree(existing)
|
shutil.rmtree(existing)
|
||||||
|
|
||||||
for skill_name in skill_names:
|
for skill_name, source_dir in skill_sources.items():
|
||||||
source_dir = self._resolve_source_dir(skill_name)
|
target_dir = target_root / skill_name
|
||||||
target_dir = active_root / skill_name
|
|
||||||
if target_dir.exists():
|
if target_dir.exists():
|
||||||
shutil.rmtree(target_dir)
|
shutil.rmtree(target_dir)
|
||||||
shutil.copytree(source_dir, target_dir)
|
shutil.copytree(source_dir, target_dir)
|
||||||
@@ -115,12 +604,25 @@ class SkillsManager:
|
|||||||
|
|
||||||
return synced_paths
|
return synced_paths
|
||||||
|
|
||||||
|
def sync_active_skills(
|
||||||
|
self,
|
||||||
|
target_root: Path,
|
||||||
|
skill_names: Iterable[str],
|
||||||
|
) -> List[Path]:
|
||||||
|
"""Sync selected shared skills into one active directory."""
|
||||||
|
skill_sources = {
|
||||||
|
skill_name: self._resolve_source_dir(skill_name)
|
||||||
|
for skill_name in skill_names
|
||||||
|
}
|
||||||
|
return self.sync_skill_dirs(target_root, skill_sources)
|
||||||
|
|
||||||
def prepare_active_skills(
|
def prepare_active_skills(
|
||||||
self,
|
self,
|
||||||
config_name: str,
|
config_name: str,
|
||||||
agent_defaults: Dict[str, Iterable[str]],
|
agent_defaults: Dict[str, Iterable[str]],
|
||||||
|
auto_reload: bool = False,
|
||||||
) -> Dict[str, List[Path]]:
|
) -> Dict[str, List[Path]]:
|
||||||
"""Resolve all agent skills, sync the union once, and map paths per agent."""
|
"""Resolve all agent skills into per-agent installed/active workspaces."""
|
||||||
resolved: Dict[str, List[str]] = {}
|
resolved: Dict[str, List[str]] = {}
|
||||||
union: List[str] = []
|
union: List[str] = []
|
||||||
|
|
||||||
@@ -135,10 +637,239 @@ class SkillsManager:
|
|||||||
if skill_name not in union:
|
if skill_name not in union:
|
||||||
union.append(skill_name)
|
union.append(skill_name)
|
||||||
|
|
||||||
self.sync_active_skills(config_name=config_name, skill_names=union)
|
# Maintain the legacy union directory for compatibility/debugging.
|
||||||
active_root = self.get_active_root(config_name)
|
# Agent-local skills remain private to the agent workspace.
|
||||||
|
self.sync_active_skills(
|
||||||
|
target_root=self.get_active_root(config_name),
|
||||||
|
skill_names=[
|
||||||
|
skill_name
|
||||||
|
for skill_name in union
|
||||||
|
if self._is_shared_skill(skill_name)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
active_map: Dict[str, List[Path]] = {}
|
||||||
agent_id: [active_root / skill_name for skill_name in skill_names]
|
for agent_id, skill_names in resolved.items():
|
||||||
for agent_id, skill_names in resolved.items()
|
installed_sources = {
|
||||||
|
skill_name: self._resolve_source_dir(skill_name)
|
||||||
|
for skill_name in skill_names
|
||||||
|
if (self.get_agent_local_root(config_name, agent_id) / skill_name).exists() is False
|
||||||
}
|
}
|
||||||
|
installed_paths = self.sync_skill_dirs(
|
||||||
|
target_root=self.get_agent_installed_root(config_name, agent_id),
|
||||||
|
skill_sources=installed_sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||||
|
local_sources = {
|
||||||
|
skill_name: local_root / skill_name
|
||||||
|
for skill_name in skill_names
|
||||||
|
if (local_root / skill_name).exists()
|
||||||
|
}
|
||||||
|
active_sources = {
|
||||||
|
path.name: path for path in installed_paths
|
||||||
|
}
|
||||||
|
active_sources.update(local_sources)
|
||||||
|
active_map[agent_id] = self.sync_skill_dirs(
|
||||||
|
target_root=self.get_agent_active_root(config_name, agent_id),
|
||||||
|
skill_sources=active_sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
disabled_names = _dedupe_preserve_order(
|
||||||
|
self._resolve_disabled_skill_names(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
default_skills=agent_defaults.get(agent_id, []),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
disabled_sources = {
|
||||||
|
skill_name: self._resolve_agent_skill_source_dir(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
skill_name=skill_name,
|
||||||
|
)
|
||||||
|
for skill_name in disabled_names
|
||||||
|
}
|
||||||
|
self.sync_skill_dirs(
|
||||||
|
target_root=self.get_agent_disabled_root(config_name, agent_id),
|
||||||
|
skill_sources=disabled_sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
if auto_reload:
|
||||||
|
self.watch_active_skills(config_name, agent_defaults)
|
||||||
|
|
||||||
|
return active_map
|
||||||
|
|
||||||
|
def _is_shared_skill(self, skill_name: str) -> bool:
|
||||||
|
try:
|
||||||
|
self._resolve_source_dir(skill_name)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def watch_active_skills(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_defaults: Dict[str, Iterable[str]],
|
||||||
|
callback: Optional[Any] = None,
|
||||||
|
) -> "_SkillsWatcher":
|
||||||
|
"""Start file system monitoring on active skill directories.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_name: Run configuration name.
|
||||||
|
agent_defaults: Map of agent_id -> default skill names.
|
||||||
|
callback: Optional callable invoked on file changes with
|
||||||
|
(changed_paths: List[Path]).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A _SkillsWatcher instance. Call .stop() to halt monitoring.
|
||||||
|
"""
|
||||||
|
if not WATCHDOG_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"watchdog is required for watch_active_skills. "
|
||||||
|
"Install it with: pip install watchdog"
|
||||||
|
)
|
||||||
|
|
||||||
|
watched_paths: List[Path] = []
|
||||||
|
for agent_id in agent_defaults:
|
||||||
|
active_root = self.get_agent_active_root(config_name, agent_id)
|
||||||
|
if active_root.exists():
|
||||||
|
watched_paths.append(active_root)
|
||||||
|
local_root = self.get_agent_local_root(config_name, agent_id)
|
||||||
|
if local_root.exists():
|
||||||
|
watched_paths.append(local_root)
|
||||||
|
|
||||||
|
handler = _SkillsChangeHandler(watched_paths, callback, self._lock)
|
||||||
|
observer = Observer()
|
||||||
|
for path in watched_paths:
|
||||||
|
observer.schedule(handler, str(path), recursive=True)
|
||||||
|
observer.start()
|
||||||
|
return _SkillsWatcher(observer, handler)
|
||||||
|
|
||||||
|
def reload_skills_if_changed(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_defaults: Dict[str, Iterable[str]],
|
||||||
|
) -> Dict[str, List[Path]]:
|
||||||
|
"""Check for file changes and reload active skills if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_name: Run configuration name.
|
||||||
|
agent_defaults: Map of agent_id -> default skill names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Map of agent_id -> list of reloaded skill paths, or empty dict
|
||||||
|
if no changes were detected.
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
changed = self._pending_skill_changes.get(config_name)
|
||||||
|
if not changed:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
self._pending_skill_changes[config_name] = set()
|
||||||
|
|
||||||
|
return self.prepare_active_skills(config_name, agent_defaults)
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
# Internal change-tracking state (populated by _SkillsChangeHandler)
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
_pending_skill_changes: Dict[str, Set[Path]] = {}
|
||||||
|
|
||||||
|
def _resolve_disabled_skill_names(
|
||||||
|
self,
|
||||||
|
config_name: str,
|
||||||
|
agent_id: str,
|
||||||
|
default_skills: Iterable[str],
|
||||||
|
) -> List[str]:
|
||||||
|
"""Resolve explicit disabled skills for one agent."""
|
||||||
|
bootstrap = get_bootstrap_config_for_run(self.project_root, config_name)
|
||||||
|
override = bootstrap.agent_override(agent_id)
|
||||||
|
baseline = list(override.get("skills", list(default_skills)))
|
||||||
|
agent_config = load_agent_workspace_config(
|
||||||
|
self.get_agent_asset_dir(config_name, agent_id) / "agent.yaml",
|
||||||
|
)
|
||||||
|
manifest = self.load_activation_manifest(config_name)
|
||||||
|
disabled = list(manifest.get("global_disabled_skills", []))
|
||||||
|
disabled.extend(manifest.get("agent_disabled_skills", {}).get(agent_id, []))
|
||||||
|
disabled.extend(agent_config.disabled_skills)
|
||||||
|
for skill_name in baseline:
|
||||||
|
if skill_name in agent_config.disabled_skills and skill_name not in disabled:
|
||||||
|
disabled.append(skill_name)
|
||||||
|
for item in self.list_agent_local_skills(config_name, agent_id):
|
||||||
|
if item.skill_name in agent_config.disabled_skills and item.skill_name not in disabled:
|
||||||
|
disabled.append(item.skill_name)
|
||||||
|
return [
|
||||||
|
skill
|
||||||
|
for skill in disabled
|
||||||
|
if self._skill_exists_for_agent(config_name, agent_id, skill)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class _SkillsWatcher:
|
||||||
|
"""Handle returned by watch_active_skills; call .stop() to halt monitoring."""
|
||||||
|
|
||||||
|
def __init__(self, observer: Observer, handler: "_SkillsChangeHandler") -> None:
|
||||||
|
self._observer = observer
|
||||||
|
self._handler = handler
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop the file system observer."""
|
||||||
|
self._observer.stop()
|
||||||
|
self._observer.join()
|
||||||
|
|
||||||
|
|
||||||
|
class _SkillsChangeHandler(FileSystemEventHandler):
|
||||||
|
"""Collects file-change events on skill directories."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
watched_paths: List[Path],
|
||||||
|
callback: Optional[Any] = None,
|
||||||
|
lock: Optional[Lock] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._watched_paths = watched_paths
|
||||||
|
self._callback = callback
|
||||||
|
self._lock = lock
|
||||||
|
|
||||||
|
def on_any_event(self, event: FileSystemEvent) -> None:
|
||||||
|
if event.is_directory:
|
||||||
|
return
|
||||||
|
src_path = Path(event.src_path)
|
||||||
|
for watched in self._watched_paths:
|
||||||
|
if src_path.is_relative_to(watched):
|
||||||
|
run_id = self._run_id_from_path(src_path)
|
||||||
|
if self._lock:
|
||||||
|
with self._lock:
|
||||||
|
SkillsManager._pending_skill_changes.setdefault(
|
||||||
|
run_id, set()
|
||||||
|
).add(src_path)
|
||||||
|
else:
|
||||||
|
SkillsManager._pending_skill_changes.setdefault(
|
||||||
|
run_id, set()
|
||||||
|
).add(src_path)
|
||||||
|
if self._callback:
|
||||||
|
self._callback([src_path])
|
||||||
|
break
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _run_id_from_path(path: Path) -> str:
|
||||||
|
"""Infer config_name from a path like runs/{config_name}/skills/active/..."""
|
||||||
|
parts = path.parts
|
||||||
|
for i, part in enumerate(parts):
|
||||||
|
if part == "runs" and i + 1 < len(parts):
|
||||||
|
return parts[i + 1]
|
||||||
|
return "default"
|
||||||
|
|
||||||
|
def _dedupe_preserve_order(items: Iterable[str]) -> List[str]:
|
||||||
|
result: List[str] = []
|
||||||
|
for item in items:
|
||||||
|
if item not in result:
|
||||||
|
result.append(item)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_skill_name(raw_name: str) -> str:
|
||||||
|
normalized = str(raw_name or "").strip().lower().replace(" ", "_").replace("-", "_")
|
||||||
|
allowed = [ch for ch in normalized if ch.isalnum() or ch == "_"]
|
||||||
|
return "".join(allowed).strip("_")
|
||||||
|
|||||||
18
backend/agents/team/__init__.py
Normal file
18
backend/agents/team/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Team module for multi-agent orchestration.
|
||||||
|
|
||||||
|
Provides inter-agent communication, task delegation, and coordination
|
||||||
|
for subagent spawning and lifecycle management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .messenger import AgentMessenger
|
||||||
|
from .task_delegator import TaskDelegator
|
||||||
|
from .team_coordinator import TeamCoordinator
|
||||||
|
from .registry import AgentRegistry
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentMessenger",
|
||||||
|
"TaskDelegator",
|
||||||
|
"TeamCoordinator",
|
||||||
|
"AgentRegistry",
|
||||||
|
]
|
||||||
225
backend/agents/team/messenger.py
Normal file
225
backend/agents/team/messenger.py
Normal file
@@ -0,0 +1,225 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""AgentMessenger - Pub/sub inter-agent communication.
|
||||||
|
|
||||||
|
Provides broadcast(), send(), and subscribe() for message passing
|
||||||
|
between agents using AgentScope's Msg format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
from agentscope.message import Msg
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMessenger:
|
||||||
|
"""Pub/sub messenger for inter-agent communication.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- broadcast(): Send message to all subscribers
|
||||||
|
- send(): Send message to specific agent
|
||||||
|
- subscribe(): Register callback for agent messages
|
||||||
|
- announce(): Send system-wide announcement
|
||||||
|
- enable_auto_broadcast: Auto-broadcast agent replies to all participants
|
||||||
|
|
||||||
|
Messages use AgentScope's Msg format for compatibility.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, enable_auto_broadcast: bool = False):
|
||||||
|
"""Initialize the messenger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enable_auto_broadcast: If True, agent replies are automatically
|
||||||
|
broadcast to all subscribed agents.
|
||||||
|
"""
|
||||||
|
self._subscriptions: Dict[str, List[Callable[[Msg], None]]] = {}
|
||||||
|
self._inbox: Dict[str, List[Msg]] = {}
|
||||||
|
self._locks: Dict[str, asyncio.Lock] = {}
|
||||||
|
self._enable_auto_broadcast = enable_auto_broadcast
|
||||||
|
self._participants: Set[str] = set()
|
||||||
|
|
||||||
|
def subscribe(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
callback: Callable[[Msg], None],
|
||||||
|
) -> None:
|
||||||
|
"""Subscribe an agent to receive messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Target agent identifier
|
||||||
|
callback: Async function to call when message received
|
||||||
|
"""
|
||||||
|
if agent_id not in self._subscriptions:
|
||||||
|
self._subscriptions[agent_id] = []
|
||||||
|
self._subscriptions[agent_id].append(callback)
|
||||||
|
logger.debug("Agent %s subscribed to messages", agent_id)
|
||||||
|
|
||||||
|
def unsubscribe(self, agent_id: str, callback: Callable[[Msg], None]) -> None:
|
||||||
|
"""Unsubscribe an agent from messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Target agent identifier
|
||||||
|
callback: Callback to remove
|
||||||
|
"""
|
||||||
|
if agent_id in self._subscriptions:
|
||||||
|
try:
|
||||||
|
self._subscriptions[agent_id].remove(callback)
|
||||||
|
logger.debug("Agent %s unsubscribed from messages", agent_id)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send(
|
||||||
|
self,
|
||||||
|
to_agent: str,
|
||||||
|
message: Msg,
|
||||||
|
) -> None:
|
||||||
|
"""Send message to specific agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
to_agent: Target agent identifier
|
||||||
|
message: Message to send (uses Msg format)
|
||||||
|
"""
|
||||||
|
async def _deliver():
|
||||||
|
if to_agent in self._subscriptions:
|
||||||
|
for callback in self._subscriptions[to_agent]:
|
||||||
|
try:
|
||||||
|
if asyncio.iscoroutinefunction(callback):
|
||||||
|
await callback(message)
|
||||||
|
else:
|
||||||
|
callback(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Error delivering message to %s: %s",
|
||||||
|
to_agent,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
await _deliver()
|
||||||
|
|
||||||
|
async def broadcast(self, message: Msg) -> None:
|
||||||
|
"""Broadcast message to all subscribed agents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Message to broadcast (uses Msg format)
|
||||||
|
"""
|
||||||
|
delivery_tasks = []
|
||||||
|
for agent_id, callbacks in self._subscriptions.items():
|
||||||
|
for callback in callbacks:
|
||||||
|
async def _deliver(cb=callback, aid=agent_id):
|
||||||
|
try:
|
||||||
|
if asyncio.iscoroutinefunction(cb):
|
||||||
|
await cb(message)
|
||||||
|
else:
|
||||||
|
cb(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Error broadcasting to %s: %s",
|
||||||
|
aid,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
delivery_tasks.append(_deliver())
|
||||||
|
|
||||||
|
if delivery_tasks:
|
||||||
|
await asyncio.gather(*delivery_tasks)
|
||||||
|
|
||||||
|
def inbox(self, agent_id: str) -> List[Msg]:
|
||||||
|
"""Get and clear inbox for agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of messages in inbox
|
||||||
|
"""
|
||||||
|
messages = self._inbox.get(agent_id, [])
|
||||||
|
self._inbox[agent_id] = []
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def inbox_count(self, agent_id: str) -> int:
|
||||||
|
"""Count messages in agent's inbox without clearing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of messages waiting
|
||||||
|
"""
|
||||||
|
return len(self._inbox.get(agent_id, []))
|
||||||
|
|
||||||
|
def add_participant(self, agent_id: str) -> None:
|
||||||
|
"""Add a participant to the messenger.
|
||||||
|
|
||||||
|
Participants are the agents that can receive auto-broadcast messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier to add
|
||||||
|
"""
|
||||||
|
self._participants.add(agent_id)
|
||||||
|
logger.debug("Agent %s added as participant", agent_id)
|
||||||
|
|
||||||
|
def remove_participant(self, agent_id: str) -> None:
|
||||||
|
"""Remove a participant from the messenger.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier to remove
|
||||||
|
"""
|
||||||
|
self._participants.discard(agent_id)
|
||||||
|
logger.debug("Agent %s removed from participants", agent_id)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def enable_auto_broadcast(self) -> bool:
|
||||||
|
"""Check if auto_broadcast is enabled."""
|
||||||
|
return self._enable_auto_broadcast
|
||||||
|
|
||||||
|
@enable_auto_broadcast.setter
|
||||||
|
def enable_auto_broadcast(self, value: bool) -> None:
|
||||||
|
"""Enable or disable auto_broadcast."""
|
||||||
|
self._enable_auto_broadcast = value
|
||||||
|
logger.debug("Auto_broadcast set to %s", value)
|
||||||
|
|
||||||
|
async def announce(self, message: Msg) -> None:
|
||||||
|
"""Send a system-wide announcement to all participants.
|
||||||
|
|
||||||
|
Unlike broadcast(), announce() sends a message from the system/host
|
||||||
|
to all participants without requiring prior subscription.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Announcement message (uses Msg format)
|
||||||
|
"""
|
||||||
|
logger.info("System announcement: %s", message.content)
|
||||||
|
await self.broadcast(message)
|
||||||
|
|
||||||
|
async def auto_broadcast(self, message: Msg) -> None:
|
||||||
|
"""Auto-broadcast message to all participants.
|
||||||
|
|
||||||
|
This is called internally when enable_auto_broadcast is True.
|
||||||
|
Broadcasts to all registered participants.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: Message to auto-broadcast (uses Msg format)
|
||||||
|
"""
|
||||||
|
if not self._enable_auto_broadcast:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Broadcast to all participants
|
||||||
|
for participant_id in self._participants:
|
||||||
|
if participant_id in self._subscriptions:
|
||||||
|
for callback in self._subscriptions[participant_id]:
|
||||||
|
try:
|
||||||
|
if asyncio.iscoroutinefunction(callback):
|
||||||
|
await callback(message)
|
||||||
|
else:
|
||||||
|
callback(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Error auto-broadcasting to %s: %s",
|
||||||
|
participant_id,
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["AgentMessenger"]
|
||||||
188
backend/agents/team/registry.py
Normal file
188
backend/agents/team/registry.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""AgentRegistry - Agent registration and lookup by role.
|
||||||
|
|
||||||
|
Provides register(), unregister(), and get_by_role() for agent
|
||||||
|
discovery and management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from agentscope.message import Msg
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
|
||||||
|
"""Registry for agent instances with role-based lookup.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- register(): Add agent with roles
|
||||||
|
- unregister(): Remove agent
|
||||||
|
- get_by_role(): Find agents by role
|
||||||
|
- get_by_id(): Get specific agent
|
||||||
|
|
||||||
|
Each agent can have multiple roles for flexible dispatch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._agents: Dict[str, Any] = {}
|
||||||
|
self._roles: Dict[str, List[str]] = {}
|
||||||
|
self._agent_roles: Dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
agent: Any,
|
||||||
|
roles: Optional[List[str]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Register an agent with optional roles.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Unique agent identifier
|
||||||
|
agent: Agent instance
|
||||||
|
roles: Optional list of role strings
|
||||||
|
"""
|
||||||
|
self._agents[agent_id] = agent
|
||||||
|
self._agent_roles[agent_id] = roles or []
|
||||||
|
|
||||||
|
for role in self._agent_roles[agent_id]:
|
||||||
|
if role not in self._roles:
|
||||||
|
self._roles[role] = []
|
||||||
|
if agent_id not in self._roles[role]:
|
||||||
|
self._roles[role].append(agent_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Registered agent %s with roles %s",
|
||||||
|
agent_id,
|
||||||
|
self._agent_roles[agent_id],
|
||||||
|
)
|
||||||
|
|
||||||
|
def unregister(self, agent_id: str) -> bool:
|
||||||
|
"""Unregister an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier to remove
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if agent was removed
|
||||||
|
"""
|
||||||
|
if agent_id not in self._agents:
|
||||||
|
return False
|
||||||
|
|
||||||
|
roles = self._agent_roles.pop(agent_id, [])
|
||||||
|
for role in roles:
|
||||||
|
if role in self._roles:
|
||||||
|
try:
|
||||||
|
self._roles[role].remove(agent_id)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
del self._agents[agent_id]
|
||||||
|
logger.info("Unregistered agent: %s", agent_id)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_by_id(self, agent_id: str) -> Optional[Any]:
|
||||||
|
"""Get agent by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent instance or None
|
||||||
|
"""
|
||||||
|
return self._agents.get(agent_id)
|
||||||
|
|
||||||
|
def get_by_role(self, role: str) -> List[Any]:
|
||||||
|
"""Get all agents with a given role.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role: Role string to search for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of agent instances with the role
|
||||||
|
"""
|
||||||
|
agent_ids = self._roles.get(role, [])
|
||||||
|
return [self._agents[aid] for aid in agent_ids if aid in self._agents]
|
||||||
|
|
||||||
|
def get_by_roles(self, roles: List[str]) -> List[Any]:
|
||||||
|
"""Get agents matching ANY of the given roles.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roles: List of role strings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of unique agent instances matching any role
|
||||||
|
"""
|
||||||
|
seen = set()
|
||||||
|
result = []
|
||||||
|
for role in roles:
|
||||||
|
for agent in self.get_by_role(role):
|
||||||
|
if id(agent) not in seen:
|
||||||
|
seen.add(id(agent))
|
||||||
|
result.append(agent)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def list_agents(self) -> List[str]:
|
||||||
|
"""List all registered agent IDs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of agent identifiers
|
||||||
|
"""
|
||||||
|
return list(self._agents.keys())
|
||||||
|
|
||||||
|
def list_roles(self) -> List[str]:
|
||||||
|
"""List all registered roles.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of role strings
|
||||||
|
"""
|
||||||
|
return list(self._roles.keys())
|
||||||
|
|
||||||
|
def list_roles_for_agent(self, agent_id: str) -> List[str]:
|
||||||
|
"""List roles for specific agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of role strings
|
||||||
|
"""
|
||||||
|
return list(self._agent_roles.get(agent_id, []))
|
||||||
|
|
||||||
|
def update_roles(self, agent_id: str, roles: List[str]) -> None:
|
||||||
|
"""Update roles for an existing agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier
|
||||||
|
roles: New list of roles
|
||||||
|
"""
|
||||||
|
if agent_id not in self._agents:
|
||||||
|
raise KeyError(f"Agent not registered: {agent_id}")
|
||||||
|
|
||||||
|
old_roles = self._agent_roles.get(agent_id, [])
|
||||||
|
for role in old_roles:
|
||||||
|
if role in self._roles:
|
||||||
|
try:
|
||||||
|
self._roles[role].remove(agent_id)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._agent_roles[agent_id] = roles
|
||||||
|
for role in roles:
|
||||||
|
if role not in self._roles:
|
||||||
|
self._roles[role] = []
|
||||||
|
if agent_id not in self._roles[role]:
|
||||||
|
self._roles[role].append(agent_id)
|
||||||
|
|
||||||
|
logger.info("Updated roles for agent %s: %s", agent_id, roles)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agents(self) -> Dict[str, Any]:
|
||||||
|
"""Get copy of registered agents dict."""
|
||||||
|
return dict(self._agents)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["AgentRegistry"]
|
||||||
620
backend/agents/team/task_delegator.py
Normal file
620
backend/agents/team/task_delegator.py
Normal file
@@ -0,0 +1,620 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""TaskDelegator - Subagent spawning and task delegation.
|
||||||
|
|
||||||
|
Provides delegate() and delegate_parallel() for spawning subagents
|
||||||
|
with separate context and memory. Supports runtime dynamic subagent
|
||||||
|
definition via task_data with description, prompt, and tools.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from agentscope.message import Msg
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Default timeout for subagent execution (seconds)
|
||||||
|
DEFAULT_EXECUTION_TIMEOUT = 120.0
|
||||||
|
|
||||||
|
|
||||||
|
# Type alias for subagent specification
|
||||||
|
SubagentSpec = Dict[str, Any]
|
||||||
|
"""Subagent specification format:
|
||||||
|
{
|
||||||
|
"description": "Expert code reviewer...",
|
||||||
|
"prompt": "Analyze code quality...",
|
||||||
|
"tools": ["Read", "Glob", "Grep"], # Optional: list of tool names
|
||||||
|
"model": "gpt-4o", # Optional: model name
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TaskDelegator:
|
||||||
|
"""Delegates tasks to subagents with isolated context.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- delegate(): Spawn single subagent for task
|
||||||
|
- delegate_parallel(): Spawn multiple subagents concurrently
|
||||||
|
- delegate_task(): Delegate with dynamic subagent definition from task_data
|
||||||
|
|
||||||
|
Each subagent gets its own memory/context to prevent
|
||||||
|
cross-contamination.
|
||||||
|
|
||||||
|
Dynamic Subagent Definition:
|
||||||
|
task_data can include an "agents" dict to define subagents inline:
|
||||||
|
|
||||||
|
task_data = {
|
||||||
|
"task": "Review the code changes",
|
||||||
|
"agents": {
|
||||||
|
"code-reviewer": {
|
||||||
|
"description": "Expert code reviewer for quality and security.",
|
||||||
|
"prompt": "Analyze code quality and suggest improvements.",
|
||||||
|
"tools": ["Read", "Glob", "Grep"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, agent: Any):
|
||||||
|
"""Initialize TaskDelegator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: Parent EvoAgent instance for accessing model, formatter, workspace
|
||||||
|
"""
|
||||||
|
self._agent = agent
|
||||||
|
# Get messenger from parent agent if available
|
||||||
|
self._messenger = getattr(agent, "messenger", None)
|
||||||
|
self._registry = getattr(agent, "_registry", None)
|
||||||
|
self._subagents: Dict[str, Any] = {}
|
||||||
|
self._dynamic_subagents: Dict[str, SubagentSpec] = {}
|
||||||
|
self._tasks: Dict[str, asyncio.Task] = {}
|
||||||
|
|
||||||
|
# Extract model and formatter from parent agent
|
||||||
|
self._model = getattr(agent, "model", None)
|
||||||
|
self._formatter = getattr(agent, "formatter", None)
|
||||||
|
self._workspace_dir = getattr(agent, "workspace_dir", None)
|
||||||
|
self._config_name = getattr(agent, "config_name", None)
|
||||||
|
|
||||||
|
async def delegate(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
task: Callable[..., Awaitable[Msg]],
|
||||||
|
context: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> asyncio.Task:
|
||||||
|
"""Delegate task to a single subagent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Unique identifier for this subagent instance
|
||||||
|
task: Async function representing the task
|
||||||
|
context: Optional context dict for the subagent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
asyncio.Task for the delegated task
|
||||||
|
"""
|
||||||
|
async def _run_with_context():
|
||||||
|
result = await task(context or {})
|
||||||
|
return result
|
||||||
|
|
||||||
|
self._tasks[agent_id] = asyncio.create_task(_run_with_context())
|
||||||
|
logger.info("Delegated task to subagent: %s", agent_id)
|
||||||
|
return self._tasks[agent_id]
|
||||||
|
|
||||||
|
async def delegate_parallel(
|
||||||
|
self,
|
||||||
|
tasks: List[Dict[str, Any]],
|
||||||
|
) -> List[asyncio.Task]:
|
||||||
|
"""Delegate multiple tasks in parallel.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tasks: List of task dicts with keys:
|
||||||
|
- agent_id: Unique identifier
|
||||||
|
- task: Async function to execute
|
||||||
|
- context: Optional context dict
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of asyncio.Task for all delegated tasks
|
||||||
|
"""
|
||||||
|
async def _run_task(task_def: Dict[str, Any]):
|
||||||
|
agent_id = task_def["agent_id"]
|
||||||
|
task_func = task_def["task"]
|
||||||
|
context = task_def.get("context", {})
|
||||||
|
|
||||||
|
async def _run_with_context():
|
||||||
|
return await task_func(context)
|
||||||
|
|
||||||
|
self._tasks[agent_id] = asyncio.create_task(_run_with_context())
|
||||||
|
return self._tasks[agent_id]
|
||||||
|
|
||||||
|
gathered_tasks = await asyncio.gather(
|
||||||
|
*[_run_task(t) for t in tasks],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
valid_tasks = [t for t in gathered_tasks if isinstance(t, asyncio.Task)]
|
||||||
|
logger.info(
|
||||||
|
"Delegated %d tasks in parallel (%d succeeded)",
|
||||||
|
len(tasks),
|
||||||
|
len(valid_tasks),
|
||||||
|
)
|
||||||
|
return valid_tasks
|
||||||
|
|
||||||
|
async def wait_for(self, agent_id: str, timeout: Optional[float] = None) -> Any:
|
||||||
|
"""Wait for subagent task to complete.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Subagent identifier
|
||||||
|
timeout: Optional timeout in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task result
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
asyncio.TimeoutError: If task doesn't complete in time
|
||||||
|
KeyError: If agent_id not found
|
||||||
|
"""
|
||||||
|
if agent_id not in self._tasks:
|
||||||
|
raise KeyError(f"Unknown subagent: {agent_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
self._tasks[agent_id],
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning("Task %s timed out after %s seconds", agent_id, timeout)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def cancel(self, agent_id: str) -> bool:
|
||||||
|
"""Cancel a subagent task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Subagent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if task was cancelled
|
||||||
|
"""
|
||||||
|
if agent_id in self._tasks:
|
||||||
|
self._tasks[agent_id].cancel()
|
||||||
|
del self._tasks[agent_id]
|
||||||
|
logger.info("Cancelled subagent task: %s", agent_id)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def list_tasks(self) -> List[str]:
|
||||||
|
"""List active subagent task IDs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of agent_ids with pending tasks
|
||||||
|
"""
|
||||||
|
return list(self._tasks.keys())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tasks(self) -> Dict[str, asyncio.Task]:
|
||||||
|
"""Get copy of active tasks dict."""
|
||||||
|
return dict(self._tasks)
|
||||||
|
|
||||||
|
async def delegate_task(
|
||||||
|
self,
|
||||||
|
task_type: str,
|
||||||
|
task_data: Dict[str, Any],
|
||||||
|
target_agent: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Delegate a task with optional dynamic subagent definition.
|
||||||
|
|
||||||
|
Supports runtime subagent definition via task_data["agents"]:
|
||||||
|
|
||||||
|
task_data = {
|
||||||
|
"task": "Review code changes",
|
||||||
|
"agents": {
|
||||||
|
"code-reviewer": {
|
||||||
|
"description": "Expert code reviewer...",
|
||||||
|
"prompt": "Analyze code quality...",
|
||||||
|
"tools": ["Read", "Glob", "Grep"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: Type of task (e.g., "analysis", "review", "research")
|
||||||
|
task_data: Task payload, may include "agents" for dynamic subagent def
|
||||||
|
target_agent: Optional specific agent ID to delegate to
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with "success" and result/error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Extract dynamic subagent definitions from task_data
|
||||||
|
agents_def = task_data.get("agents", {})
|
||||||
|
|
||||||
|
if agents_def:
|
||||||
|
# Register dynamic subagents
|
||||||
|
for agent_name, agent_spec in agents_def.items():
|
||||||
|
self._dynamic_subagents[agent_name] = agent_spec
|
||||||
|
logger.info(
|
||||||
|
"Registered dynamic subagent: %s (description: %s)",
|
||||||
|
agent_name,
|
||||||
|
agent_spec.get("description", "")[:50],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine target agent
|
||||||
|
effective_target = target_agent
|
||||||
|
if not effective_target:
|
||||||
|
# Use first available dynamic subagent or default
|
||||||
|
if agents_def:
|
||||||
|
effective_target = next(iter(agents_def.keys()))
|
||||||
|
else:
|
||||||
|
effective_target = "default"
|
||||||
|
|
||||||
|
# Execute the task (async)
|
||||||
|
task_result = await self._execute_task(
|
||||||
|
task_type=task_type,
|
||||||
|
task_data=task_data,
|
||||||
|
target_agent=effective_target,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up dynamic subagents after execution
|
||||||
|
for agent_name in agents_def.keys():
|
||||||
|
self._dynamic_subagents.pop(agent_name, None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"result": task_result,
|
||||||
|
"subagents_used": list(agents_def.keys()) if agents_def else [],
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Task delegation failed: %s", e)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _execute_task(
|
||||||
|
self,
|
||||||
|
task_type: str,
|
||||||
|
task_data: Dict[str, Any],
|
||||||
|
target_agent: str,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Execute the delegated task with a real subagent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: Type of task
|
||||||
|
task_data: Task payload
|
||||||
|
target_agent: Target agent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task execution result with success/failure info
|
||||||
|
"""
|
||||||
|
task_content = task_data.get("task", task_data.get("prompt", ""))
|
||||||
|
timeout = task_data.get("timeout", DEFAULT_EXECUTION_TIMEOUT)
|
||||||
|
|
||||||
|
# Check if we have a dynamic subagent spec for this target
|
||||||
|
agent_spec = self._dynamic_subagents.get(target_agent)
|
||||||
|
|
||||||
|
if agent_spec:
|
||||||
|
logger.info(
|
||||||
|
"Executing task '%s' with dynamic subagent '%s'",
|
||||||
|
task_type,
|
||||||
|
target_agent,
|
||||||
|
)
|
||||||
|
return await self._create_and_run_subagent(
|
||||||
|
agent_name=target_agent,
|
||||||
|
agent_spec=agent_spec,
|
||||||
|
task_content=task_content,
|
||||||
|
task_type=task_type,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fallback: try to use parent agent's model to process the task directly
|
||||||
|
logger.info(
|
||||||
|
"Executing task '%s' with parent agent '%s' (no dynamic subagent)",
|
||||||
|
task_type,
|
||||||
|
target_agent,
|
||||||
|
)
|
||||||
|
return await self._run_with_parent_agent(
|
||||||
|
task_content=task_content,
|
||||||
|
task_type=task_type,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _create_and_run_subagent(
|
||||||
|
self,
|
||||||
|
agent_name: str,
|
||||||
|
agent_spec: SubagentSpec,
|
||||||
|
task_content: str,
|
||||||
|
task_type: str,
|
||||||
|
timeout: float,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Create and run a dynamic subagent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_name: Name identifier for the subagent
|
||||||
|
agent_spec: Subagent specification (description, prompt, tools, model)
|
||||||
|
task_content: Task prompt to send to the subagent
|
||||||
|
task_type: Type of task
|
||||||
|
timeout: Execution timeout in seconds
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with execution results
|
||||||
|
"""
|
||||||
|
subagent_id = f"subagent_{agent_name}_{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create subagent instance
|
||||||
|
subagent = await self._create_subagent(
|
||||||
|
subagent_id=subagent_id,
|
||||||
|
agent_spec=agent_spec,
|
||||||
|
)
|
||||||
|
|
||||||
|
if subagent is None:
|
||||||
|
return {
|
||||||
|
"task_type": task_type,
|
||||||
|
"task": task_content,
|
||||||
|
"subagent": agent_name,
|
||||||
|
"status": "failed",
|
||||||
|
"error": "Failed to create subagent",
|
||||||
|
"message": f"Could not instantiate subagent '{agent_name}'",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Store for potential cleanup
|
||||||
|
self._subagents[subagent_id] = subagent
|
||||||
|
|
||||||
|
# Execute with timeout
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self._run_subagent(subagent, task_content),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract response content
|
||||||
|
response_content = ""
|
||||||
|
if isinstance(result, Msg):
|
||||||
|
response_content = result.content
|
||||||
|
elif hasattr(result, "content"):
|
||||||
|
response_content = str(result.content)
|
||||||
|
elif isinstance(result, dict):
|
||||||
|
response_content = result.get("content", str(result))
|
||||||
|
else:
|
||||||
|
response_content = str(result)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Subagent '%s' completed task '%s' successfully",
|
||||||
|
agent_name,
|
||||||
|
task_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task_type": task_type,
|
||||||
|
"task": task_content,
|
||||||
|
"subagent": {
|
||||||
|
"name": agent_name,
|
||||||
|
"id": subagent_id,
|
||||||
|
"description": agent_spec.get("description", ""),
|
||||||
|
},
|
||||||
|
"status": "completed",
|
||||||
|
"response": response_content,
|
||||||
|
"message": f"Task '{task_type}' executed with subagent '{agent_name}'",
|
||||||
|
}
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(
|
||||||
|
"Subagent '%s' timed out after %.1f seconds for task '%s'",
|
||||||
|
agent_name,
|
||||||
|
timeout,
|
||||||
|
task_type,
|
||||||
|
)
|
||||||
|
# Cancel the task if still running
|
||||||
|
if subagent_id in self._subagents:
|
||||||
|
self._subagents.pop(subagent_id, None)
|
||||||
|
return {
|
||||||
|
"task_type": task_type,
|
||||||
|
"task": task_content,
|
||||||
|
"subagent": agent_name,
|
||||||
|
"status": "timeout",
|
||||||
|
"error": f"Execution timed out after {timeout} seconds",
|
||||||
|
"message": f"Task '{task_type}' timed out for subagent '{agent_name}'",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Subagent '%s' failed for task '%s': %s",
|
||||||
|
agent_name,
|
||||||
|
task_type,
|
||||||
|
e,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
# Cleanup on failure
|
||||||
|
if subagent_id in self._subagents:
|
||||||
|
self._subagents.pop(subagent_id, None)
|
||||||
|
return {
|
||||||
|
"task_type": task_type,
|
||||||
|
"task": task_content,
|
||||||
|
"subagent": agent_name,
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"message": f"Task '{task_type}' failed for subagent '{agent_name}': {e}",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _create_subagent(
|
||||||
|
self,
|
||||||
|
subagent_id: str,
|
||||||
|
agent_spec: SubagentSpec,
|
||||||
|
) -> Optional[Any]:
|
||||||
|
"""Create a subagent instance.
|
||||||
|
|
||||||
|
Uses the parent agent's model/formatter to create a lightweight
|
||||||
|
subagent for task execution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subagent_id: Unique identifier for the subagent
|
||||||
|
agent_spec: Subagent specification
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Subagent instance or None if creation fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from agentscope.memory import InMemoryMemory
|
||||||
|
|
||||||
|
# Get model and formatter from parent
|
||||||
|
model = self._model
|
||||||
|
formatter = self._formatter
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
logger.error("Cannot create subagent: parent agent has no model")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Build system prompt from agent spec
|
||||||
|
description = agent_spec.get("description", "")
|
||||||
|
prompt_template = agent_spec.get("prompt", "")
|
||||||
|
system_prompt = f"""You are {description}
|
||||||
|
|
||||||
|
{prompt_template}
|
||||||
|
|
||||||
|
Your task is to complete the user's request below.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Create a minimal ReActAgent as the subagent
|
||||||
|
from agentscope.agent import ReActAgent
|
||||||
|
|
||||||
|
subagent = ReActAgent(
|
||||||
|
name=subagent_id,
|
||||||
|
model=model,
|
||||||
|
sys_prompt=system_prompt,
|
||||||
|
toolkit=None, # Could load tools from agent_spec.get("tools", [])
|
||||||
|
memory=InMemoryMemory(),
|
||||||
|
formatter=formatter,
|
||||||
|
max_iters=agent_spec.get("max_iters", 5),
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug("Created subagent: %s", subagent_id)
|
||||||
|
return subagent
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to create subagent '%s': %s",
|
||||||
|
subagent_id,
|
||||||
|
e,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _run_subagent(
|
||||||
|
self,
|
||||||
|
subagent: Any,
|
||||||
|
task_content: str,
|
||||||
|
) -> Any:
|
||||||
|
"""Run a subagent with the given task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
subagent: Subagent instance
|
||||||
|
task_content: Task prompt
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent response (Msg or similar)
|
||||||
|
"""
|
||||||
|
from agentscope.message import Msg
|
||||||
|
|
||||||
|
# Create message for the subagent
|
||||||
|
task_msg = Msg(
|
||||||
|
name="user",
|
||||||
|
content=task_content,
|
||||||
|
role="user",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute the agent
|
||||||
|
response = await subagent.reply(task_msg)
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _run_with_parent_agent(
|
||||||
|
self,
|
||||||
|
task_content: str,
|
||||||
|
task_type: str,
|
||||||
|
timeout: float,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Run task using the parent agent directly.
|
||||||
|
|
||||||
|
Used when no dynamic subagent is defined.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_content: Task prompt
|
||||||
|
task_type: Type of task
|
||||||
|
timeout: Execution timeout
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict with execution results
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self._agent.reply(Msg(
|
||||||
|
name="user",
|
||||||
|
content=task_content,
|
||||||
|
role="user",
|
||||||
|
)),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_content = ""
|
||||||
|
if isinstance(result, Msg):
|
||||||
|
response_content = result.content
|
||||||
|
elif hasattr(result, "content"):
|
||||||
|
response_content = str(result.content)
|
||||||
|
else:
|
||||||
|
response_content = str(result)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"task_type": task_type,
|
||||||
|
"task": task_content,
|
||||||
|
"status": "completed",
|
||||||
|
"response": response_content,
|
||||||
|
"message": f"Task '{task_type}' executed with parent agent",
|
||||||
|
}
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return {
|
||||||
|
"task_type": task_type,
|
||||||
|
"task": task_content,
|
||||||
|
"status": "timeout",
|
||||||
|
"error": f"Execution timed out after {timeout} seconds",
|
||||||
|
"message": f"Task '{task_type}' timed out",
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Parent agent failed for task '%s': %s",
|
||||||
|
task_type,
|
||||||
|
e,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"task_type": task_type,
|
||||||
|
"task": task_content,
|
||||||
|
"status": "error",
|
||||||
|
"error": str(e),
|
||||||
|
"message": f"Task '{task_type}' failed: {e}",
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_dynamic_subagent(self, name: str) -> Optional[SubagentSpec]:
|
||||||
|
"""Get a dynamically defined subagent specification.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Subagent name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Subagent spec dict or None if not found
|
||||||
|
"""
|
||||||
|
return self._dynamic_subagents.get(name)
|
||||||
|
|
||||||
|
def list_dynamic_subagents(self) -> List[str]:
|
||||||
|
"""List all registered dynamic subagent names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of subagent names
|
||||||
|
"""
|
||||||
|
return list(self._dynamic_subagents.keys())
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["TaskDelegator", "SubagentSpec"]
|
||||||
389
backend/agents/team/team_coordinator.py
Normal file
389
backend/agents/team/team_coordinator.py
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""TeamCoordinator - Agent lifecycle management and execution.
|
||||||
|
|
||||||
|
Provides run_parallel() using asyncio.gather() and run_sequential()
|
||||||
|
for coordinating multiple agents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from agentscope.message import Msg
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TeamCoordinator:
|
||||||
|
"""Coordinates agent lifecycle and execution.
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- run_parallel(): Execute multiple agents concurrently with asyncio.gather()
|
||||||
|
- run_sequential(): Execute agents one after another
|
||||||
|
- run_phase(): Execute a named phase with registered agents
|
||||||
|
- register_agent(): Add agent to coordinator
|
||||||
|
- unregister_agent(): Remove agent from coordinator
|
||||||
|
|
||||||
|
Each agent maintains separate context/memory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
participants: Optional[List[Any]] = None,
|
||||||
|
task_content: Optional[str] = None,
|
||||||
|
messenger: Optional[Any] = None,
|
||||||
|
registry: Optional[Any] = None,
|
||||||
|
):
|
||||||
|
"""Initialize TeamCoordinator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
participants: List of agent instances to coordinate
|
||||||
|
task_content: Task description content for the agents
|
||||||
|
messenger: AgentMessenger for communication (optional)
|
||||||
|
registry: AgentRegistry for agent lookup (optional)
|
||||||
|
"""
|
||||||
|
self._participants = participants or []
|
||||||
|
self._task_content = task_content or ""
|
||||||
|
self._messenger = messenger
|
||||||
|
self._registry = registry
|
||||||
|
self._agents: Dict[str, Any] = {}
|
||||||
|
self._running_tasks: Dict[str, asyncio.Task] = {}
|
||||||
|
# Auto-register participants
|
||||||
|
for agent in self._participants:
|
||||||
|
if hasattr(agent, "name"):
|
||||||
|
self._agents[agent.name] = agent
|
||||||
|
elif hasattr(agent, "id"):
|
||||||
|
self._agents[agent.id] = agent
|
||||||
|
|
||||||
|
def register_agent(self, agent_id: str, agent: Any) -> None:
|
||||||
|
"""Register an agent with the coordinator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Unique agent identifier
|
||||||
|
agent: Agent instance
|
||||||
|
"""
|
||||||
|
self._agents[agent_id] = agent
|
||||||
|
logger.info("Registered agent: %s", agent_id)
|
||||||
|
|
||||||
|
def unregister_agent(self, agent_id: str) -> None:
|
||||||
|
"""Unregister an agent from the coordinator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier to remove
|
||||||
|
"""
|
||||||
|
if agent_id in self._agents:
|
||||||
|
del self._agents[agent_id]
|
||||||
|
logger.info("Unregistered agent: %s", agent_id)
|
||||||
|
|
||||||
|
def get_agent(self, agent_id: str) -> Any:
|
||||||
|
"""Get registered agent by ID.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent instance
|
||||||
|
"""
|
||||||
|
return self._agents.get(agent_id)
|
||||||
|
|
||||||
|
def list_agents(self) -> List[str]:
|
||||||
|
"""List all registered agent IDs.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of agent identifiers
|
||||||
|
"""
|
||||||
|
return list(self._agents.keys())
|
||||||
|
|
||||||
|
async def run_parallel(
|
||||||
|
self,
|
||||||
|
agent_ids: List[str],
|
||||||
|
initial_message: Optional[Msg] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Run multiple agents in parallel using asyncio.gather().
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_ids: List of agent IDs to run concurrently
|
||||||
|
initial_message: Optional initial message to broadcast
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping agent_id to result
|
||||||
|
"""
|
||||||
|
async def _run_agent(aid: str) -> tuple[str, Any]:
|
||||||
|
agent = self._agents.get(aid)
|
||||||
|
if agent is None:
|
||||||
|
logger.error("Agent %s not found", aid)
|
||||||
|
return (aid, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||||
|
if initial_message:
|
||||||
|
result = await agent.reply(initial_message)
|
||||||
|
else:
|
||||||
|
result = await agent.reply()
|
||||||
|
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||||
|
result = await agent.run()
|
||||||
|
else:
|
||||||
|
result = await agent()
|
||||||
|
logger.info("Agent %s completed successfully", aid)
|
||||||
|
return (aid, result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Agent %s failed: %s", aid, e)
|
||||||
|
return (aid, {"error": str(e)})
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[_run_agent(aid) for aid in agent_ids],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
output: Dict[str, Any] = {}
|
||||||
|
for result in results:
|
||||||
|
if isinstance(result, tuple):
|
||||||
|
agent_id, agent_result = result
|
||||||
|
output[agent_id] = agent_result
|
||||||
|
else:
|
||||||
|
logger.error("Unexpected result from asyncio.gather: %s", result)
|
||||||
|
|
||||||
|
logger.info("Parallel run completed for %d agents", len(agent_ids))
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def run_sequential(
|
||||||
|
self,
|
||||||
|
agent_ids: List[str],
|
||||||
|
initial_message: Optional[Msg] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Run agents one after another in order.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_ids: List of agent IDs to run in sequence
|
||||||
|
initial_message: Optional initial message for first agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping agent_id to result
|
||||||
|
"""
|
||||||
|
output: Dict[str, Any] = {}
|
||||||
|
current_message = initial_message
|
||||||
|
|
||||||
|
for agent_id in agent_ids:
|
||||||
|
agent = self._agents.get(agent_id)
|
||||||
|
if agent is None:
|
||||||
|
logger.error("Agent %s not found", agent_id)
|
||||||
|
output[agent_id] = {"error": "Agent not found"}
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||||
|
result = await agent.reply(current_message)
|
||||||
|
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||||
|
result = await agent.run()
|
||||||
|
else:
|
||||||
|
result = await agent()
|
||||||
|
|
||||||
|
output[agent_id] = result
|
||||||
|
current_message = result
|
||||||
|
logger.info("Agent %s completed sequentially", agent_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Agent %s failed: %s", agent_id, e)
|
||||||
|
output[agent_id] = {"error": str(e)}
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info("Sequential run completed for %d agents", len(agent_ids))
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def run_phase(
|
||||||
|
self,
|
||||||
|
phase_name: str,
|
||||||
|
agent_ids: Optional[List[str]] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> List[Any]:
|
||||||
|
"""Execute a named phase with registered agents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
phase_name: Name of the phase (e.g., "analyst_analysis")
|
||||||
|
agent_ids: Optional list of agent IDs; if None, uses all registered
|
||||||
|
metadata: Optional metadata to include in the message (e.g., tickers, date)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of results from each agent
|
||||||
|
"""
|
||||||
|
if agent_ids is None:
|
||||||
|
agent_ids = list(self._agents.keys())
|
||||||
|
|
||||||
|
_agent_ids = [aid for aid in agent_ids if aid in self._agents]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Running phase '%s' with %d agents: %s",
|
||||||
|
phase_name,
|
||||||
|
len(_agent_ids),
|
||||||
|
_agent_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create messages for each agent
|
||||||
|
results: List[Any] = []
|
||||||
|
for agent_id in _agent_ids:
|
||||||
|
agent = self._agents[agent_id]
|
||||||
|
try:
|
||||||
|
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||||
|
# Create a message for the agent with proper structure
|
||||||
|
msg = Msg(
|
||||||
|
name="system",
|
||||||
|
content=self._task_content or f"Please execute phase: {phase_name}",
|
||||||
|
role="user",
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
result = await agent.reply(msg)
|
||||||
|
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||||
|
result = await agent.run()
|
||||||
|
else:
|
||||||
|
result = await agent()
|
||||||
|
results.append(result)
|
||||||
|
logger.info("Phase '%s': Agent %s completed", phase_name, agent_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Phase '%s': Agent %s failed: %s", phase_name, agent_id, e)
|
||||||
|
results.append(None)
|
||||||
|
|
||||||
|
logger.info("Phase '%s' completed with %d results", phase_name, len(results))
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def run_with_dependencies(
|
||||||
|
self,
|
||||||
|
agent_tasks: Dict[str, List[str]],
|
||||||
|
initial_message: Optional[Msg] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Run agents respecting dependency graph.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_tasks: Dict mapping agent_id to list of prerequisite agent_ids
|
||||||
|
initial_message: Optional initial message
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping agent_id to result
|
||||||
|
"""
|
||||||
|
completed: Dict[str, Any] = {}
|
||||||
|
remaining = set(agent_tasks.keys())
|
||||||
|
|
||||||
|
while remaining:
|
||||||
|
ready = [
|
||||||
|
aid for aid in remaining
|
||||||
|
if all(dep in completed for dep in agent_tasks.get(aid, []))
|
||||||
|
]
|
||||||
|
|
||||||
|
if not ready:
|
||||||
|
logger.error("Circular dependency detected in agent tasks")
|
||||||
|
for aid in remaining:
|
||||||
|
completed[aid] = {"error": "Circular dependency"}
|
||||||
|
break
|
||||||
|
|
||||||
|
results = await self.run_parallel(ready, initial_message)
|
||||||
|
completed.update(results)
|
||||||
|
|
||||||
|
for aid in ready:
|
||||||
|
remaining.discard(aid)
|
||||||
|
initial_message = results.get(aid)
|
||||||
|
|
||||||
|
return completed
|
||||||
|
|
||||||
|
async def fanout_pipeline(
|
||||||
|
self,
|
||||||
|
agents: List[Any],
|
||||||
|
msg: Optional[Msg] = None,
|
||||||
|
) -> List[Msg]:
|
||||||
|
"""Fanout a message to multiple agents concurrently and collect all responses.
|
||||||
|
|
||||||
|
Similar to AgentScope's fanout_pipeline, this sends the same message
|
||||||
|
to all specified agents and returns a list of all agent responses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agents: List of agent instances to fanout the message to
|
||||||
|
msg: Message to send to all agents (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Msg responses from each agent (in the same order as input agents)
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> responses = await fanout_pipeline(
|
||||||
|
... agents=[alice, bob, charlie],
|
||||||
|
... msg=question,
|
||||||
|
... )
|
||||||
|
>>> # responses is a list of Msg responses from each agent
|
||||||
|
"""
|
||||||
|
async def _fanout_to_agent(agent: Any) -> Optional[Msg]:
|
||||||
|
"""Send message to a single agent and return its response."""
|
||||||
|
try:
|
||||||
|
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||||
|
result = await agent.reply(msg) if msg is not None else await agent.reply()
|
||||||
|
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||||
|
result = await agent.run()
|
||||||
|
else:
|
||||||
|
result = await agent()
|
||||||
|
|
||||||
|
# Convert result to Msg if needed
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
if isinstance(result, Msg):
|
||||||
|
return result
|
||||||
|
# If result is a dict with content, wrap it
|
||||||
|
if isinstance(result, dict) and "content" in result:
|
||||||
|
return Msg(
|
||||||
|
name=getattr(agent, "name", "unknown"),
|
||||||
|
content=result.get("content", ""),
|
||||||
|
role="assistant",
|
||||||
|
metadata=result.get("metadata"),
|
||||||
|
)
|
||||||
|
# Otherwise wrap the result
|
||||||
|
return Msg(
|
||||||
|
name=getattr(agent, "name", "unknown"),
|
||||||
|
content=str(result),
|
||||||
|
role="assistant",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Agent %s failed in fanout_pipeline: %s",
|
||||||
|
getattr(agent, "name", "unknown"), e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Run all agents concurrently
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[_fanout_to_agent(agent) for agent in agents],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Filter out exceptions and keep only valid responses
|
||||||
|
responses: List[Msg] = []
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error("Fanout to agent %d failed: %s", i, result)
|
||||||
|
responses.append(None) # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
responses.append(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
logger.info("Fanout pipeline completed for %d agents", len(agents))
|
||||||
|
return responses
|
||||||
|
|
||||||
|
async def shutdown(self, timeout: Optional[float] = 5.0) -> None:
|
||||||
|
"""Shutdown all running agents gracefully.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Timeout for graceful shutdown
|
||||||
|
"""
|
||||||
|
logger.info("Shutting down TeamCoordinator...")
|
||||||
|
|
||||||
|
cancel_tasks = [
|
||||||
|
asyncio.create_task(asyncio.wait_for(task, timeout=timeout))
|
||||||
|
for task in self._running_tasks.values()
|
||||||
|
]
|
||||||
|
|
||||||
|
if cancel_tasks:
|
||||||
|
await asyncio.gather(*cancel_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
self._running_tasks.clear()
|
||||||
|
logger.info("TeamCoordinator shutdown complete")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agents(self) -> Dict[str, Any]:
|
||||||
|
"""Get copy of registered agents dict."""
|
||||||
|
return dict(self._agents)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["TeamCoordinator"]
|
||||||
132
backend/agents/team_pipeline_config.py
Normal file
132
backend/agents/team_pipeline_config.py
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Run-scoped team pipeline configuration helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable, List, Dict, Any
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_FILENAME = "TEAM_PIPELINE.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def team_pipeline_path(project_root: Path, config_name: str) -> Path:
|
||||||
|
"""Return run-scoped team pipeline config path."""
|
||||||
|
return project_root / "runs" / config_name / DEFAULT_FILENAME
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_team_pipeline_config(
|
||||||
|
project_root: Path,
|
||||||
|
config_name: str,
|
||||||
|
default_analysts: Iterable[str],
|
||||||
|
) -> Path:
|
||||||
|
"""Ensure TEAM_PIPELINE.yaml exists for one run."""
|
||||||
|
path = team_pipeline_path(project_root, config_name)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if path.exists():
|
||||||
|
return path
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"version": 1,
|
||||||
|
"controller_agent": "portfolio_manager",
|
||||||
|
"discussion": {
|
||||||
|
"allow_dynamic_team_update": True,
|
||||||
|
"active_analysts": list(default_analysts),
|
||||||
|
},
|
||||||
|
"decision": {
|
||||||
|
"require_risk_manager": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
path.write_text(
|
||||||
|
yaml.safe_dump(payload, allow_unicode=True, sort_keys=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def load_team_pipeline_config(project_root: Path, config_name: str) -> Dict[str, Any]:
|
||||||
|
"""Load TEAM_PIPELINE.yaml and return parsed dict."""
|
||||||
|
path = team_pipeline_path(project_root, config_name)
|
||||||
|
if not path.exists():
|
||||||
|
return {}
|
||||||
|
parsed = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||||
|
return parsed if isinstance(parsed, dict) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def save_team_pipeline_config(
|
||||||
|
project_root: Path,
|
||||||
|
config_name: str,
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> Path:
|
||||||
|
"""Persist TEAM_PIPELINE.yaml."""
|
||||||
|
path = team_pipeline_path(project_root, config_name)
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_text(
|
||||||
|
yaml.safe_dump(config, allow_unicode=True, sort_keys=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_active_analysts(
|
||||||
|
project_root: Path,
|
||||||
|
config_name: str,
|
||||||
|
available_analysts: Iterable[str],
|
||||||
|
) -> List[str]:
|
||||||
|
"""Resolve active analysts from TEAM_PIPELINE.yaml."""
|
||||||
|
available = [item for item in available_analysts]
|
||||||
|
parsed = load_team_pipeline_config(project_root, config_name)
|
||||||
|
discussion = parsed.get("discussion", {}) if isinstance(parsed, dict) else {}
|
||||||
|
configured = discussion.get("active_analysts", [])
|
||||||
|
if not isinstance(configured, list) or not configured:
|
||||||
|
return available
|
||||||
|
|
||||||
|
active = [item for item in configured if item in available]
|
||||||
|
return active or available
|
||||||
|
|
||||||
|
|
||||||
|
def update_active_analysts(
|
||||||
|
project_root: Path,
|
||||||
|
config_name: str,
|
||||||
|
available_analysts: Iterable[str],
|
||||||
|
*,
|
||||||
|
add: Iterable[str] | None = None,
|
||||||
|
remove: Iterable[str] | None = None,
|
||||||
|
set_to: Iterable[str] | None = None,
|
||||||
|
) -> List[str]:
|
||||||
|
"""Update active analysts and persist TEAM_PIPELINE.yaml."""
|
||||||
|
available = [item for item in available_analysts]
|
||||||
|
ensure_team_pipeline_config(project_root, config_name, available)
|
||||||
|
parsed = load_team_pipeline_config(project_root, config_name)
|
||||||
|
discussion = parsed.setdefault("discussion", {})
|
||||||
|
if not isinstance(discussion, dict):
|
||||||
|
discussion = {}
|
||||||
|
parsed["discussion"] = discussion
|
||||||
|
|
||||||
|
current = discussion.get("active_analysts", [])
|
||||||
|
if not isinstance(current, list):
|
||||||
|
current = []
|
||||||
|
current = [item for item in current if item in available]
|
||||||
|
if not current:
|
||||||
|
current = list(available)
|
||||||
|
|
||||||
|
if set_to is not None:
|
||||||
|
target = [item for item in set_to if item in available]
|
||||||
|
current = target or current
|
||||||
|
|
||||||
|
for item in add or []:
|
||||||
|
if item in available and item not in current:
|
||||||
|
current.append(item)
|
||||||
|
|
||||||
|
for item in remove or []:
|
||||||
|
current = [existing for existing in current if existing != item]
|
||||||
|
|
||||||
|
if not current:
|
||||||
|
current = [available[0]] if available else []
|
||||||
|
|
||||||
|
discussion["active_analysts"] = current
|
||||||
|
save_team_pipeline_config(project_root, config_name, parsed)
|
||||||
|
return current
|
||||||
|
|
||||||
286
backend/agents/templates.py
Normal file
286
backend/agents/templates.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
"""
|
||||||
|
Agent模板定义
|
||||||
|
|
||||||
|
包含各角色的ROLE.md内容字典,供程序生成Agent工作空间时使用。
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 基础模板文件内容
|
||||||
|
BASE_TEMPLATES = {
|
||||||
|
"AGENTS.md": """# Agent Guide
|
||||||
|
|
||||||
|
## 工作流程
|
||||||
|
1. 接收分析任务
|
||||||
|
2. 调用相关工具/技能
|
||||||
|
3. 生成分析报告
|
||||||
|
4. 参与团队决策
|
||||||
|
|
||||||
|
## 工具使用规范
|
||||||
|
- 优先使用已激活的技能
|
||||||
|
- 不确定时询问Portfolio Manager
|
||||||
|
- 重要发现用 `/save` 记录
|
||||||
|
|
||||||
|
## 记忆管理
|
||||||
|
- 使用 `/compact` 定期压缩记忆
|
||||||
|
- 投资经验记录在MEMORY.md
|
||||||
|
""",
|
||||||
|
|
||||||
|
"SOUL.md": """# Soul
|
||||||
|
|
||||||
|
你是专业的金融分析师,语气冷静、客观、专业。
|
||||||
|
你的分析应该数据驱动,避免情绪化表达。
|
||||||
|
""",
|
||||||
|
|
||||||
|
"PROFILE.md": """# Profile
|
||||||
|
|
||||||
|
## 投资风格
|
||||||
|
- 风险承受能力:中等
|
||||||
|
- 投资期限:中期(3-12个月)
|
||||||
|
- 偏好行业:科技、医疗、消费
|
||||||
|
|
||||||
|
## 优势
|
||||||
|
- 财务分析
|
||||||
|
- 趋势识别
|
||||||
|
|
||||||
|
## 改进方向
|
||||||
|
- 市场情绪把握
|
||||||
|
""",
|
||||||
|
|
||||||
|
"MEMORY.md": """# Memory
|
||||||
|
|
||||||
|
<!-- 此文件用于记录Agent的学习经验和重要发现 -->
|
||||||
|
|
||||||
|
## 经验总结
|
||||||
|
|
||||||
|
## 重要事件
|
||||||
|
|
||||||
|
## 改进记录
|
||||||
|
""",
|
||||||
|
|
||||||
|
"HEARTBEAT.md": """# Heartbeat
|
||||||
|
|
||||||
|
## 定时任务
|
||||||
|
- 每日开盘前检查持仓
|
||||||
|
- 收盘后记录当日表现
|
||||||
|
""",
|
||||||
|
|
||||||
|
"POLICY.md": """# Policy
|
||||||
|
|
||||||
|
## 风控规则
|
||||||
|
- 单一持仓不超过20%
|
||||||
|
- 止损线:-15%
|
||||||
|
""",
|
||||||
|
|
||||||
|
"STYLE.md": """# Style
|
||||||
|
|
||||||
|
- 使用结构化输出(JSON/Markdown表格)
|
||||||
|
- 包含置信度评分
|
||||||
|
- 列出关键假设
|
||||||
|
""",
|
||||||
|
|
||||||
|
"agent.yaml": """agent_id: {agent_id}
|
||||||
|
agent_type: {agent_type}
|
||||||
|
name: {name}
|
||||||
|
model:
|
||||||
|
provider: openai
|
||||||
|
model_name: gpt-4o
|
||||||
|
temperature: 0.3
|
||||||
|
enabled_skills: []
|
||||||
|
disabled_skills: []
|
||||||
|
settings: {{}}
|
||||||
|
""",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 角色专用模板
|
||||||
|
ROLE_TEMPLATES = {
|
||||||
|
"fundamental": {
|
||||||
|
"ROLE.md": """# Role: Fundamental Analyst
|
||||||
|
|
||||||
|
## 职责
|
||||||
|
分析公司财务报表、盈利能力、成长性、竞争优势等基本面因素。
|
||||||
|
|
||||||
|
## 分析维度
|
||||||
|
- 财务报表分析(资产负债表、利润表、现金流量表)
|
||||||
|
- 盈利能力指标(ROE、ROA、毛利率、净利率)
|
||||||
|
- 成长性指标(营收增长率、利润增长率)
|
||||||
|
- 估值指标(P/E、P/B、P/S)
|
||||||
|
- 行业地位和竞争优势
|
||||||
|
|
||||||
|
## 输出格式
|
||||||
|
- 财务健康度评分(1-10)
|
||||||
|
- 成长性评分(1-10)
|
||||||
|
- 关键财务亮点和风险
|
||||||
|
- 同业对比分析
|
||||||
|
""",
|
||||||
|
"SOUL.md": """# Soul
|
||||||
|
|
||||||
|
你是严谨的基本面分析师,像沃伦·巴菲特一样注重企业内在价值。
|
||||||
|
你的分析深入细致,关注长期价值而非短期波动。
|
||||||
|
语气沉稳、逻辑严密,善于发现财务数据背后的商业本质。
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
|
||||||
|
"technical": {
|
||||||
|
"ROLE.md": """# Role: Technical Analyst
|
||||||
|
|
||||||
|
## 职责
|
||||||
|
分析价格走势、交易量、技术指标,识别买卖时机。
|
||||||
|
|
||||||
|
## 分析维度
|
||||||
|
- 趋势分析(长期/中期/短期趋势)
|
||||||
|
- 支撑阻力位识别
|
||||||
|
- 技术指标(MACD、RSI、KDJ、布林带等)
|
||||||
|
- 形态识别(头肩顶/底、双底、三角形等)
|
||||||
|
- 量价关系分析
|
||||||
|
|
||||||
|
## 输出格式
|
||||||
|
- 趋势方向(上涨/下跌/震荡)
|
||||||
|
- 关键价位(支撑/阻力)
|
||||||
|
- 技术信号(买入/卖出/观望)
|
||||||
|
- 置信度评分
|
||||||
|
""",
|
||||||
|
"SOUL.md": """# Soul
|
||||||
|
|
||||||
|
你是敏锐的技术分析师,相信价格包含一切信息。
|
||||||
|
你善于从图表中发现规律,像侦探一样寻找市场留下的痕迹。
|
||||||
|
语气果断、快速反应,善于捕捉稍纵即逝的交易机会。
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
|
||||||
|
"sentiment": {
|
||||||
|
"ROLE.md": """# Role: Sentiment Analyst
|
||||||
|
|
||||||
|
## 职责
|
||||||
|
分析市场情绪、资金流向、新闻舆情,判断市场心理状态。
|
||||||
|
|
||||||
|
## 分析维度
|
||||||
|
- 市场情绪指标(恐慌/贪婪指数)
|
||||||
|
- 资金流向分析(主力/散户资金)
|
||||||
|
- 新闻舆情分析(正面/负面/中性)
|
||||||
|
- 社交媒体情绪
|
||||||
|
- 机构持仓变化
|
||||||
|
|
||||||
|
## 输出格式
|
||||||
|
- 情绪评分(-10到+10,极度恐慌到极度贪婪)
|
||||||
|
- 资金流向判断
|
||||||
|
- 舆情摘要
|
||||||
|
- 情绪拐点预警
|
||||||
|
""",
|
||||||
|
"SOUL.md": """# Soul
|
||||||
|
|
||||||
|
你是敏感的市场情绪捕手,善于感知市场的恐惧与贪婪。
|
||||||
|
你关注人性在金融市场中的表现,理解情绪如何驱动价格。
|
||||||
|
语气富有洞察力、善于捕捉微妙变化,像心理学家一样理解市场参与者。
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
|
||||||
|
"valuation": {
|
||||||
|
"ROLE.md": """# Role: Valuation Analyst
|
||||||
|
|
||||||
|
## 职责
|
||||||
|
评估公司内在价值,计算合理价格区间,识别高估/低估机会。
|
||||||
|
|
||||||
|
## 分析维度
|
||||||
|
- DCF现金流折现模型
|
||||||
|
- 相对估值法(P/E、EV/EBITDA等)
|
||||||
|
- 资产重估法
|
||||||
|
- 分部估值(SOTP)
|
||||||
|
- 安全边际计算
|
||||||
|
|
||||||
|
## 输出格式
|
||||||
|
- 内在价值估算
|
||||||
|
- 合理价格区间
|
||||||
|
- 当前价格vs内在价值(高估/低估百分比)
|
||||||
|
- 估值假设和敏感性分析
|
||||||
|
""",
|
||||||
|
"SOUL.md": """# Soul
|
||||||
|
|
||||||
|
你是精确的估值分析师,追求计算内在价值的准确区间。
|
||||||
|
你像精算师一样严谨,注重假设的合理性和安全边际。
|
||||||
|
语气精确、注重数字,善于发现市场定价错误带来的机会。
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
|
||||||
|
"portfolio": {
|
||||||
|
"ROLE.md": """# Role: Portfolio Manager
|
||||||
|
|
||||||
|
## 职责
|
||||||
|
统筹各分析师意见,制定投资决策,管理投资组合配置。
|
||||||
|
|
||||||
|
## 分析维度
|
||||||
|
- 资产配置策略(股债比例、行业分布)
|
||||||
|
- 风险收益平衡
|
||||||
|
- 仓位管理(建仓/加仓/减仓/清仓)
|
||||||
|
- 再平衡时机
|
||||||
|
- 组合相关性分析
|
||||||
|
|
||||||
|
## 输出格式
|
||||||
|
- 投资决策(买入/卖出/持有)
|
||||||
|
- 建议仓位比例
|
||||||
|
- 目标价位
|
||||||
|
- 止损止盈设置
|
||||||
|
- 组合调整建议
|
||||||
|
""",
|
||||||
|
"SOUL.md": """# Soul
|
||||||
|
|
||||||
|
你是睿智的投资组合经理,像将军一样统筹全局。
|
||||||
|
你善于权衡各方意见,做出果断而理性的投资决策。
|
||||||
|
语气权威、决策果断,对组合整体表现负有最终责任。
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
|
||||||
|
"risk": {
|
||||||
|
"ROLE.md": """# Role: Risk Manager
|
||||||
|
|
||||||
|
## 职责
|
||||||
|
识别、评估和监控投资风险,确保组合风险在可控范围内。
|
||||||
|
|
||||||
|
## 分析维度
|
||||||
|
- 市场风险(Beta、波动率)
|
||||||
|
- 信用风险
|
||||||
|
- 流动性风险
|
||||||
|
- 集中度风险
|
||||||
|
- 尾部风险(VaR、CVaR)
|
||||||
|
- 压力测试
|
||||||
|
|
||||||
|
## 输出格式
|
||||||
|
- 风险等级(低/中/高/极高)
|
||||||
|
- 风险敞口分析
|
||||||
|
- 风险调整建议
|
||||||
|
- 预警阈值设置
|
||||||
|
- 应急预案
|
||||||
|
""",
|
||||||
|
"SOUL.md": """# Soul
|
||||||
|
|
||||||
|
你是谨慎的风险管理者,时刻警惕潜在的损失。
|
||||||
|
你像守门员一样守护组合安全,宁可错过机会也不冒无法承受的风险。
|
||||||
|
语气保守、风险意识强,善于发现隐藏的威胁和脆弱性。
|
||||||
|
""",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_template(filename: str) -> str | None:
|
||||||
|
"""获取基础模板内容"""
|
||||||
|
return BASE_TEMPLATES.get(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def get_role_template(role_type: str, filename: str) -> str | None:
|
||||||
|
"""获取角色专用模板内容"""
|
||||||
|
role = ROLE_TEMPLATES.get(role_type)
|
||||||
|
if role:
|
||||||
|
return role.get(filename)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_role_types() -> list[str]:
|
||||||
|
"""获取所有角色类型列表"""
|
||||||
|
return list(ROLE_TEMPLATES.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def render_agent_yaml(agent_id: str, agent_type: str, name: str) -> str:
|
||||||
|
"""渲染agent.yaml模板"""
|
||||||
|
return BASE_TEMPLATES["agent.yaml"].format(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
name=name
|
||||||
|
)
|
||||||
@@ -1,21 +1,31 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Toolkit factory following AgentScope's skill + tool group practices."""
|
"""Toolkit factory following AgentScope's skill + tool group practices.
|
||||||
|
|
||||||
from typing import Any, Dict, Iterable
|
支持从Agent工作空间动态创建工具集,加载builtin/customized技能,
|
||||||
|
以及合并Agent特定工具。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, Iterable, List, Optional, Set
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from .skills_manager import SkillsManager
|
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
from backend.agents.skill_loader import load_skill_from_dir, get_skill_tools
|
||||||
|
from backend.agents.skill_metadata import parse_skill_metadata
|
||||||
|
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||||
|
|
||||||
|
|
||||||
def load_agent_profiles() -> Dict[str, Dict[str, Any]]:
|
def load_agent_profiles() -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""加载Agent配置文件"""
|
||||||
config_path = SkillsManager().project_root / "backend" / "config" / "agent_profiles.yaml"
|
config_path = SkillsManager().project_root / "backend" / "config" / "agent_profiles.yaml"
|
||||||
with open(config_path, "r", encoding="utf-8") as file:
|
with open(config_path, "r", encoding="utf-8") as file:
|
||||||
return yaml.safe_load(file) or {}
|
return yaml.safe_load(file) or {}
|
||||||
|
|
||||||
|
|
||||||
def _register_analysis_tool_groups(toolkit: Any) -> None:
|
def _register_analysis_tool_groups(toolkit: Any) -> None:
|
||||||
|
"""注册分析工具组"""
|
||||||
from backend.tools.analysis_tools import TOOL_REGISTRY
|
from backend.tools.analysis_tools import TOOL_REGISTRY
|
||||||
|
|
||||||
tool_groups = {
|
tool_groups = {
|
||||||
@@ -94,6 +104,7 @@ def _register_analysis_tool_groups(toolkit: Any) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _register_portfolio_tool_groups(toolkit: Any, pm_agent: Any) -> None:
|
def _register_portfolio_tool_groups(toolkit: Any, pm_agent: Any) -> None:
|
||||||
|
"""注册投资组合工具组"""
|
||||||
toolkit.create_tool_group(
|
toolkit.create_tool_group(
|
||||||
group_name="portfolio_ops",
|
group_name="portfolio_ops",
|
||||||
description="Portfolio decision recording tools.",
|
description="Portfolio decision recording tools.",
|
||||||
@@ -107,9 +118,30 @@ def _register_portfolio_tool_groups(toolkit: Any, pm_agent: Any) -> None:
|
|||||||
pm_agent._make_decision,
|
pm_agent._make_decision,
|
||||||
group_name="portfolio_ops",
|
group_name="portfolio_ops",
|
||||||
)
|
)
|
||||||
|
if hasattr(pm_agent, "_add_team_analyst"):
|
||||||
|
toolkit.register_tool_function(
|
||||||
|
pm_agent._add_team_analyst,
|
||||||
|
group_name="portfolio_ops",
|
||||||
|
)
|
||||||
|
if hasattr(pm_agent, "_remove_team_analyst"):
|
||||||
|
toolkit.register_tool_function(
|
||||||
|
pm_agent._remove_team_analyst,
|
||||||
|
group_name="portfolio_ops",
|
||||||
|
)
|
||||||
|
if hasattr(pm_agent, "_set_active_analysts"):
|
||||||
|
toolkit.register_tool_function(
|
||||||
|
pm_agent._set_active_analysts,
|
||||||
|
group_name="portfolio_ops",
|
||||||
|
)
|
||||||
|
if hasattr(pm_agent, "_create_team_analyst"):
|
||||||
|
toolkit.register_tool_function(
|
||||||
|
pm_agent._create_team_analyst,
|
||||||
|
group_name="portfolio_ops",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _register_risk_tool_groups(toolkit: Any) -> None:
|
def _register_risk_tool_groups(toolkit: Any) -> None:
|
||||||
|
"""注册风险工具组"""
|
||||||
from backend.tools.risk_tools import (
|
from backend.tools.risk_tools import (
|
||||||
assess_margin_and_liquidity,
|
assess_margin_and_liquidity,
|
||||||
assess_position_concentration,
|
assess_position_concentration,
|
||||||
@@ -145,12 +177,25 @@ def create_agent_toolkit(
|
|||||||
owner: Any = None,
|
owner: Any = None,
|
||||||
active_skill_dirs: Iterable[str] | None = None,
|
active_skill_dirs: Iterable[str] | None = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Create a Toolkit with agent skills and grouped tools."""
|
"""Create a Toolkit with agent skills and grouped tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent标识符
|
||||||
|
config_name: 运行配置名称
|
||||||
|
owner: Agent实例(用于注册特定方法)
|
||||||
|
active_skill_dirs: 显式指定的活动技能目录列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的Toolkit实例
|
||||||
|
"""
|
||||||
from agentscope.tool import Toolkit
|
from agentscope.tool import Toolkit
|
||||||
|
|
||||||
profiles = load_agent_profiles()
|
profiles = load_agent_profiles()
|
||||||
profile = profiles.get(agent_id, {})
|
profile = profiles.get(agent_id, {})
|
||||||
skills_manager = SkillsManager()
|
skills_manager = SkillsManager()
|
||||||
|
agent_config = load_agent_workspace_config(
|
||||||
|
skills_manager.get_agent_asset_dir(config_name, agent_id) / "agent.yaml",
|
||||||
|
)
|
||||||
bootstrap_config = get_bootstrap_config_for_run(
|
bootstrap_config = get_bootstrap_config_for_run(
|
||||||
skills_manager.project_root,
|
skills_manager.project_root,
|
||||||
config_name,
|
config_name,
|
||||||
@@ -158,8 +203,16 @@ def create_agent_toolkit(
|
|||||||
override = bootstrap_config.agent_override(agent_id)
|
override = bootstrap_config.agent_override(agent_id)
|
||||||
active_groups = override.get(
|
active_groups = override.get(
|
||||||
"active_tool_groups",
|
"active_tool_groups",
|
||||||
profile.get("active_tool_groups", []),
|
agent_config.active_tool_groups
|
||||||
|
or profile.get("active_tool_groups", []),
|
||||||
)
|
)
|
||||||
|
disabled_groups = set(agent_config.disabled_tool_groups)
|
||||||
|
if disabled_groups:
|
||||||
|
active_groups = [
|
||||||
|
group_name
|
||||||
|
for group_name in active_groups
|
||||||
|
if group_name not in disabled_groups
|
||||||
|
]
|
||||||
|
|
||||||
toolkit = Toolkit(
|
toolkit = Toolkit(
|
||||||
agent_skill_instruction=(
|
agent_skill_instruction=(
|
||||||
@@ -184,14 +237,281 @@ def create_agent_toolkit(
|
|||||||
default_skills=profile.get("skills", []),
|
default_skills=profile.get("skills", []),
|
||||||
)
|
)
|
||||||
active_skill_dirs = [
|
active_skill_dirs = [
|
||||||
skills_manager.get_active_root(config_name) / skill_name
|
skills_manager.get_agent_active_root(config_name, agent_id) / skill_name
|
||||||
for skill_name in skill_names
|
for skill_name in skill_names
|
||||||
]
|
]
|
||||||
|
|
||||||
for skill_dir in active_skill_dirs:
|
for skill_dir in active_skill_dirs:
|
||||||
toolkit.register_agent_skill(str(skill_dir))
|
toolkit.register_agent_skill(str(skill_dir))
|
||||||
|
|
||||||
|
apply_skill_tool_restrictions(toolkit, active_skill_dirs)
|
||||||
|
|
||||||
if active_groups:
|
if active_groups:
|
||||||
toolkit.update_tool_groups(group_names=active_groups, active=True)
|
toolkit.update_tool_groups(group_names=active_groups, active=True)
|
||||||
|
|
||||||
return toolkit
|
return toolkit
|
||||||
|
|
||||||
|
|
||||||
|
def create_toolkit_from_workspace(
|
||||||
|
agent_id: str,
|
||||||
|
config_name: str,
|
||||||
|
owner: Any = None,
|
||||||
|
include_builtin: bool = True,
|
||||||
|
include_customized: bool = True,
|
||||||
|
include_local: bool = True,
|
||||||
|
active_groups: Optional[List[str]] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""从Agent工作空间创建工具集
|
||||||
|
|
||||||
|
这是create_agent_toolkit的增强版本,支持更灵活的技能加载策略。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: Agent标识符
|
||||||
|
config_name: 运行配置名称
|
||||||
|
owner: Agent实例
|
||||||
|
include_builtin: 是否包含builtin技能
|
||||||
|
include_customized: 是否包含customized技能
|
||||||
|
include_local: 是否包含agent-local技能
|
||||||
|
active_groups: 显式指定的活动工具组
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
配置好的Toolkit实例
|
||||||
|
"""
|
||||||
|
from agentscope.tool import Toolkit
|
||||||
|
|
||||||
|
skills_manager = SkillsManager()
|
||||||
|
agent_config = load_agent_workspace_config(
|
||||||
|
skills_manager.get_agent_asset_dir(config_name, agent_id) / "agent.yaml",
|
||||||
|
)
|
||||||
|
|
||||||
|
toolkit = Toolkit(
|
||||||
|
agent_skill_instruction=(
|
||||||
|
"<system-info>You have access to project skills. Each skill lives in a "
|
||||||
|
"directory and is described by SKILL.md. Follow the skill instructions "
|
||||||
|
"when they are relevant to the current task.</system-info>"
|
||||||
|
),
|
||||||
|
agent_skill_template="- {name} (dir: {dir}): {description}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册Agent类型的默认工具组
|
||||||
|
if agent_id.endswith("_analyst"):
|
||||||
|
_register_analysis_tool_groups(toolkit)
|
||||||
|
elif agent_id == "portfolio_manager" and owner is not None:
|
||||||
|
_register_portfolio_tool_groups(toolkit, owner)
|
||||||
|
elif agent_id == "risk_manager":
|
||||||
|
_register_risk_tool_groups(toolkit)
|
||||||
|
|
||||||
|
# 收集所有要加载的技能目录
|
||||||
|
skill_dirs: List[Path] = []
|
||||||
|
|
||||||
|
# 1. 从active目录加载已同步的技能
|
||||||
|
active_root = skills_manager.get_agent_active_root(config_name, agent_id)
|
||||||
|
if active_root.exists():
|
||||||
|
for skill_dir in sorted(active_root.iterdir()):
|
||||||
|
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||||
|
skill_dirs.append(skill_dir)
|
||||||
|
|
||||||
|
# 2. 从installed目录加载
|
||||||
|
installed_root = skills_manager.get_agent_installed_root(config_name, agent_id)
|
||||||
|
if installed_root.exists():
|
||||||
|
for skill_dir in sorted(installed_root.iterdir()):
|
||||||
|
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||||
|
if skill_dir not in skill_dirs:
|
||||||
|
skill_dirs.append(skill_dir)
|
||||||
|
|
||||||
|
# 3. 从local目录加载agent-local技能
|
||||||
|
if include_local:
|
||||||
|
local_root = skills_manager.get_agent_local_root(config_name, agent_id)
|
||||||
|
if local_root.exists():
|
||||||
|
for skill_dir in sorted(local_root.iterdir()):
|
||||||
|
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||||
|
if skill_dir not in skill_dirs:
|
||||||
|
skill_dirs.append(skill_dir)
|
||||||
|
|
||||||
|
# 注册技能到toolkit
|
||||||
|
for skill_dir in skill_dirs:
|
||||||
|
toolkit.register_agent_skill(str(skill_dir))
|
||||||
|
|
||||||
|
apply_skill_tool_restrictions(toolkit, skill_dirs)
|
||||||
|
|
||||||
|
# 激活指定的工具组
|
||||||
|
if active_groups is None:
|
||||||
|
# 从配置中读取
|
||||||
|
profiles = load_agent_profiles()
|
||||||
|
profile = profiles.get(agent_id, {})
|
||||||
|
active_groups = agent_config.active_tool_groups or profile.get("active_tool_groups", [])
|
||||||
|
|
||||||
|
# 应用禁用列表
|
||||||
|
disabled_groups = set(agent_config.disabled_tool_groups)
|
||||||
|
if disabled_groups:
|
||||||
|
active_groups = [g for g in active_groups if g not in disabled_groups]
|
||||||
|
|
||||||
|
if active_groups:
|
||||||
|
toolkit.update_tool_groups(group_names=active_groups, active=True)
|
||||||
|
|
||||||
|
return toolkit
|
||||||
|
|
||||||
|
|
||||||
|
def get_toolkit_info(toolkit: Any) -> Dict[str, Any]:
|
||||||
|
"""获取工具集信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
toolkit: Toolkit实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具集信息字典
|
||||||
|
"""
|
||||||
|
info = {
|
||||||
|
"tool_groups": {},
|
||||||
|
"skills": [],
|
||||||
|
"tools_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 获取工具组信息
|
||||||
|
groups = getattr(toolkit, "tool_groups", {})
|
||||||
|
for name, group in groups.items():
|
||||||
|
info["tool_groups"][name] = {
|
||||||
|
"description": getattr(group, "description", ""),
|
||||||
|
"active": getattr(group, "active", False),
|
||||||
|
"tools": [t.name for t in getattr(group, "tools", [])],
|
||||||
|
}
|
||||||
|
info["tools_count"] += len(getattr(group, "tools", []))
|
||||||
|
|
||||||
|
# 获取技能信息
|
||||||
|
skills = getattr(toolkit, "agent_skills", [])
|
||||||
|
for skill in skills:
|
||||||
|
info["skills"].append({
|
||||||
|
"name": getattr(skill, "name", "unknown"),
|
||||||
|
"path": getattr(skill, "path", ""),
|
||||||
|
"description": getattr(skill, "description", ""),
|
||||||
|
})
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_toolkit_skills(
|
||||||
|
toolkit: Any,
|
||||||
|
agent_id: str,
|
||||||
|
config_name: str,
|
||||||
|
) -> None:
|
||||||
|
"""刷新工具集中的技能
|
||||||
|
|
||||||
|
重新从工作空间加载技能,用于运行时技能变更。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
toolkit: Toolkit实例
|
||||||
|
agent_id: Agent标识符
|
||||||
|
config_name: 运行配置名称
|
||||||
|
"""
|
||||||
|
skills_manager = SkillsManager()
|
||||||
|
|
||||||
|
# 清除现有技能
|
||||||
|
if hasattr(toolkit, "agent_skills"):
|
||||||
|
toolkit.agent_skills.clear()
|
||||||
|
|
||||||
|
# 重新加载active技能
|
||||||
|
active_root = skills_manager.get_agent_active_root(config_name, agent_id)
|
||||||
|
if active_root.exists():
|
||||||
|
for skill_dir in sorted(active_root.iterdir()):
|
||||||
|
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||||
|
toolkit.register_agent_skill(str(skill_dir))
|
||||||
|
|
||||||
|
# 重新加载local技能
|
||||||
|
local_root = skills_manager.get_agent_local_root(config_name, agent_id)
|
||||||
|
if local_root.exists():
|
||||||
|
for skill_dir in sorted(local_root.iterdir()):
|
||||||
|
if skill_dir.is_dir() and (skill_dir / "SKILL.md").exists():
|
||||||
|
toolkit.register_agent_skill(str(skill_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def apply_skill_tool_restrictions(toolkit: Any, skill_dirs: List[Path]) -> None:
|
||||||
|
"""Apply per-skill allowed_tools / denied_tools restrictions to a toolkit.
|
||||||
|
|
||||||
|
If a skill specifies allowed_tools, only those tools are accessible when
|
||||||
|
that skill is active. If a skill specifies denied_tools, those tools are
|
||||||
|
removed regardless of allowed_tools. Denied tools take precedence.
|
||||||
|
|
||||||
|
This function annotates the toolkit with a _skill_tool_restrictions map
|
||||||
|
that downstream code can consult when resolving available tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
toolkit: The agentscope Toolkit instance.
|
||||||
|
skill_dirs: List of skill directory paths to inspect.
|
||||||
|
"""
|
||||||
|
restrictions: Dict[str, Dict[str, Set[str]]] = {}
|
||||||
|
for skill_dir in skill_dirs:
|
||||||
|
metadata = parse_skill_metadata(skill_dir, source="active")
|
||||||
|
if not metadata.allowed_tools and not metadata.denied_tools:
|
||||||
|
continue
|
||||||
|
restrictions[skill_dir.name] = {
|
||||||
|
"allowed": set(metadata.allowed_tools),
|
||||||
|
"denied": set(metadata.denied_tools),
|
||||||
|
}
|
||||||
|
if hasattr(toolkit, "agent_skills"):
|
||||||
|
for skill in toolkit.agent_skills:
|
||||||
|
skill_name = getattr(skill, "name", "") or ""
|
||||||
|
if skill_name in restrictions:
|
||||||
|
setattr(
|
||||||
|
skill,
|
||||||
|
"_tool_allowed",
|
||||||
|
restrictions[skill_name]["allowed"],
|
||||||
|
)
|
||||||
|
setattr(
|
||||||
|
skill,
|
||||||
|
"_tool_denied",
|
||||||
|
restrictions[skill_name]["denied"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_skill_effective_tools(skill: Any) -> Optional[Set[str]]:
|
||||||
|
"""Return the effective tool set for a skill after applying restrictions.
|
||||||
|
|
||||||
|
If the skill has no restrictions (no allowed_tools / denied_tools),
|
||||||
|
returns None to indicate "all tools allowed".
|
||||||
|
|
||||||
|
If allowed_tools is set, returns only those tools minus denied_tools.
|
||||||
|
If only denied_tools is set, returns all tools minus denied_tools.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skill: A skill object previously registered via register_agent_skill.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A set of allowed tool names, or None if unrestricted.
|
||||||
|
"""
|
||||||
|
allowed = getattr(skill, "_tool_allowed", None)
|
||||||
|
denied = getattr(skill, "_tool_denied", set())
|
||||||
|
|
||||||
|
if allowed is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
effective = allowed - denied
|
||||||
|
return effective
|
||||||
|
|
||||||
|
|
||||||
|
def filter_toolkit_by_skill(
|
||||||
|
toolkit: Any,
|
||||||
|
skill_name: str,
|
||||||
|
) -> Set[str]:
|
||||||
|
"""Return the set of tool names that are accessible for a given skill.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
toolkit: The agentscope Toolkit instance.
|
||||||
|
skill_name: Name of the skill to query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of allowed tool names, or all registered tool names if unrestricted.
|
||||||
|
"""
|
||||||
|
if not hasattr(toolkit, "agent_skills"):
|
||||||
|
return set()
|
||||||
|
|
||||||
|
for skill in toolkit.agent_skills:
|
||||||
|
name = getattr(skill, "name", "") or ""
|
||||||
|
if name != skill_name:
|
||||||
|
continue
|
||||||
|
effective = get_skill_effective_tools(skill)
|
||||||
|
if effective is None:
|
||||||
|
return set()
|
||||||
|
return effective
|
||||||
|
|
||||||
|
return set()
|
||||||
|
|
||||||
|
|||||||
327
backend/agents/workspace.py
Normal file
327
backend/agents/workspace.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Workspace Manager - Create and manage agent workspaces."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WorkspaceConfig:
|
||||||
|
"""Configuration for a workspace."""
|
||||||
|
|
||||||
|
workspace_id: str
|
||||||
|
name: str = ""
|
||||||
|
description: str = ""
|
||||||
|
created_at: str = ""
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""Serialize to dictionary."""
|
||||||
|
return {
|
||||||
|
"workspace_id": self.workspace_id,
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"created_at": self.created_at,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict[str, Any]) -> "WorkspaceConfig":
|
||||||
|
"""Create from dictionary."""
|
||||||
|
return cls(
|
||||||
|
workspace_id=data.get("workspace_id", ""),
|
||||||
|
name=data.get("name", ""),
|
||||||
|
description=data.get("description", ""),
|
||||||
|
created_at=data.get("created_at", ""),
|
||||||
|
metadata=data.get("metadata", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceRegistry:
|
||||||
|
"""Registry for persistent workspace definitions (design-time)."""
|
||||||
|
|
||||||
|
def __init__(self, project_root: Optional[Path] = None):
|
||||||
|
"""Initialize the workspace manager.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
project_root: Root directory of the project
|
||||||
|
"""
|
||||||
|
self.project_root = project_root or Path(__file__).parent.parent.parent
|
||||||
|
self.workspaces_root = self.project_root / "workspaces"
|
||||||
|
self.workspaces_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def create_workspace(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> WorkspaceConfig:
|
||||||
|
"""Create a new workspace with directory structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Unique identifier for the workspace
|
||||||
|
name: Display name for the workspace
|
||||||
|
description: Optional description
|
||||||
|
metadata: Optional metadata dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WorkspaceConfig instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workspace already exists
|
||||||
|
"""
|
||||||
|
workspace_dir = self.workspaces_root / workspace_id
|
||||||
|
|
||||||
|
if workspace_dir.exists():
|
||||||
|
raise ValueError(f"Workspace '{workspace_id}' already exists")
|
||||||
|
|
||||||
|
# Create directory structure
|
||||||
|
workspace_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create subdirectories
|
||||||
|
(workspace_dir / "agents").mkdir(exist_ok=True)
|
||||||
|
(workspace_dir / "shared" / "market_data").mkdir(parents=True, exist_ok=True)
|
||||||
|
(workspace_dir / "shared" / "memories").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Create workspace.yaml
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
config = WorkspaceConfig(
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
name=name or workspace_id,
|
||||||
|
description=description or "",
|
||||||
|
created_at=datetime.now().isoformat(),
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._write_workspace_config(workspace_dir, config)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def list_workspaces(self) -> List[WorkspaceConfig]:
|
||||||
|
"""List all workspaces.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of WorkspaceConfig instances
|
||||||
|
"""
|
||||||
|
workspaces = []
|
||||||
|
|
||||||
|
if not self.workspaces_root.exists():
|
||||||
|
return workspaces
|
||||||
|
|
||||||
|
for workspace_dir in self.workspaces_root.iterdir():
|
||||||
|
if not workspace_dir.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
config_path = workspace_dir / "workspace.yaml"
|
||||||
|
if config_path.exists():
|
||||||
|
try:
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
data = yaml.safe_load(f) or {}
|
||||||
|
workspaces.append(WorkspaceConfig.from_dict(data))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load workspace config {config_path}: {e}")
|
||||||
|
|
||||||
|
return workspaces
|
||||||
|
|
||||||
|
def get_workspace_agents(self, workspace_id: str) -> List[Dict[str, Any]]:
|
||||||
|
"""Get all agents in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: ID of the workspace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of agent information dictionaries
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workspace doesn't exist
|
||||||
|
"""
|
||||||
|
workspace_dir = self.workspaces_root / workspace_id
|
||||||
|
|
||||||
|
if not workspace_dir.exists():
|
||||||
|
raise ValueError(f"Workspace '{workspace_id}' does not exist")
|
||||||
|
|
||||||
|
agents = []
|
||||||
|
agents_dir = workspace_dir / "agents"
|
||||||
|
|
||||||
|
if not agents_dir.exists():
|
||||||
|
return agents
|
||||||
|
|
||||||
|
for agent_dir in agents_dir.iterdir():
|
||||||
|
if not agent_dir.is_dir():
|
||||||
|
continue
|
||||||
|
|
||||||
|
config_path = agent_dir / "agent.yaml"
|
||||||
|
if config_path.exists():
|
||||||
|
try:
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = yaml.safe_load(f) or {}
|
||||||
|
|
||||||
|
agents.append({
|
||||||
|
"agent_id": agent_dir.name,
|
||||||
|
"agent_type": config.get("agent_type", "unknown"),
|
||||||
|
"config_path": str(config_path),
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load agent config {config_path}: {e}")
|
||||||
|
|
||||||
|
return agents
|
||||||
|
|
||||||
|
def get_agent_workspace(self, agent_id: str, workspace_id: str) -> Optional[Path]:
|
||||||
|
"""Get the workspace path for an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent_id: ID of the agent
|
||||||
|
workspace_id: ID of the workspace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to agent directory, or None if not found
|
||||||
|
"""
|
||||||
|
agent_dir = self.workspaces_root / workspace_id / "agents" / agent_id
|
||||||
|
|
||||||
|
if agent_dir.exists():
|
||||||
|
return agent_dir
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def workspace_exists(self, workspace_id: str) -> bool:
|
||||||
|
"""Check if a workspace exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: ID of the workspace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if workspace exists, False otherwise
|
||||||
|
"""
|
||||||
|
workspace_dir = self.workspaces_root / workspace_id
|
||||||
|
return workspace_dir.exists() and (workspace_dir / "workspace.yaml").exists()
|
||||||
|
|
||||||
|
def delete_workspace(self, workspace_id: str, force: bool = False) -> bool:
|
||||||
|
"""Delete a workspace and all its agents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: ID of the workspace to delete
|
||||||
|
force: If True, delete even if workspace has agents
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if deleted, False if workspace didn't exist
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workspace has agents and force is False
|
||||||
|
"""
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
workspace_dir = self.workspaces_root / workspace_id
|
||||||
|
|
||||||
|
if not workspace_dir.exists():
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check for agents
|
||||||
|
agents_dir = workspace_dir / "agents"
|
||||||
|
if agents_dir.exists() and any(agents_dir.iterdir()):
|
||||||
|
if not force:
|
||||||
|
raise ValueError(
|
||||||
|
f"Workspace '{workspace_id}' contains agents. "
|
||||||
|
"Use force=True to delete anyway."
|
||||||
|
)
|
||||||
|
|
||||||
|
shutil.rmtree(workspace_dir)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_workspace_path(self, workspace_id: str) -> Path:
|
||||||
|
"""Get the path to a workspace directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: ID of the workspace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to workspace directory
|
||||||
|
"""
|
||||||
|
return self.workspaces_root / workspace_id
|
||||||
|
|
||||||
|
def get_shared_data_path(self, workspace_id: str) -> Optional[Path]:
|
||||||
|
"""Get the shared data directory for a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: ID of the workspace
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Path to shared data directory, or None if workspace doesn't exist
|
||||||
|
"""
|
||||||
|
workspace_dir = self.workspaces_root / workspace_id
|
||||||
|
|
||||||
|
if not workspace_dir.exists():
|
||||||
|
return None
|
||||||
|
|
||||||
|
return workspace_dir / "shared"
|
||||||
|
|
||||||
|
def update_workspace_config(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
name: Optional[str] = None,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> WorkspaceConfig:
|
||||||
|
"""Update workspace configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: ID of the workspace
|
||||||
|
name: New display name (optional)
|
||||||
|
description: New description (optional)
|
||||||
|
metadata: Metadata to merge (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated WorkspaceConfig
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workspace doesn't exist
|
||||||
|
"""
|
||||||
|
workspace_dir = self.workspaces_root / workspace_id
|
||||||
|
|
||||||
|
if not workspace_dir.exists():
|
||||||
|
raise ValueError(f"Workspace '{workspace_id}' does not exist")
|
||||||
|
|
||||||
|
config_path = workspace_dir / "workspace.yaml"
|
||||||
|
current_config = {}
|
||||||
|
|
||||||
|
if config_path.exists():
|
||||||
|
try:
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
current_config = yaml.safe_load(f) or {}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load existing config {config_path}: {e}")
|
||||||
|
|
||||||
|
# Update fields
|
||||||
|
if name is not None:
|
||||||
|
current_config["name"] = name
|
||||||
|
if description is not None:
|
||||||
|
current_config["description"] = description
|
||||||
|
if metadata is not None:
|
||||||
|
current_config["metadata"] = {**current_config.get("metadata", {}), **metadata}
|
||||||
|
|
||||||
|
config = WorkspaceConfig.from_dict(current_config)
|
||||||
|
self._write_workspace_config(workspace_dir, config)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def _write_workspace_config(self, workspace_dir: Path, config: WorkspaceConfig) -> None:
|
||||||
|
"""Write workspace configuration to file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_dir: Workspace directory
|
||||||
|
config: Workspace configuration
|
||||||
|
"""
|
||||||
|
config_path = workspace_dir / "workspace.yaml"
|
||||||
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
yaml.safe_dump(config.to_dict(), f, allow_unicode=True, sort_keys=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Backward-compatible alias: legacy imports expect WorkspaceManager.
|
||||||
|
WorkspaceManager = WorkspaceRegistry
|
||||||
@@ -4,10 +4,13 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Iterable, Optional
|
from typing import Dict, Iterable, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
from .skills_manager import SkillsManager
|
from .skills_manager import SkillsManager
|
||||||
|
from .team_pipeline_config import ensure_team_pipeline_config
|
||||||
|
|
||||||
|
|
||||||
class WorkspaceManager:
|
class RunWorkspaceManager:
|
||||||
"""Create and maintain run-level prompt asset files for each agent."""
|
"""Create and maintain run-level prompt asset files for each agent."""
|
||||||
|
|
||||||
def __init__(self, project_root: Optional[Path] = None):
|
def __init__(self, project_root: Optional[Path] = None):
|
||||||
@@ -21,6 +24,16 @@ class WorkspaceManager:
|
|||||||
run_dir = self.get_run_dir(config_name)
|
run_dir = self.get_run_dir(config_name)
|
||||||
run_dir.mkdir(parents=True, exist_ok=True)
|
run_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.skills_manager.ensure_activation_manifest(config_name)
|
self.skills_manager.ensure_activation_manifest(config_name)
|
||||||
|
ensure_team_pipeline_config(
|
||||||
|
project_root=self.project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
default_analysts=[
|
||||||
|
"fundamentals_analyst",
|
||||||
|
"technical_analyst",
|
||||||
|
"sentiment_analyst",
|
||||||
|
"valuation_analyst",
|
||||||
|
],
|
||||||
|
)
|
||||||
bootstrap_path = run_dir / "BOOTSTRAP.md"
|
bootstrap_path = run_dir / "BOOTSTRAP.md"
|
||||||
if not bootstrap_path.exists():
|
if not bootstrap_path.exists():
|
||||||
bootstrap_path.write_text(
|
bootstrap_path.write_text(
|
||||||
@@ -59,6 +72,10 @@ class WorkspaceManager:
|
|||||||
agent_id,
|
agent_id,
|
||||||
)
|
)
|
||||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
(asset_dir / "skills" / "installed").mkdir(parents=True, exist_ok=True)
|
||||||
|
(asset_dir / "skills" / "active").mkdir(parents=True, exist_ok=True)
|
||||||
|
(asset_dir / "skills" / "disabled").mkdir(parents=True, exist_ok=True)
|
||||||
|
(asset_dir / "skills" / "local").mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
self._ensure_file(
|
self._ensure_file(
|
||||||
asset_dir / "ROLE.md",
|
asset_dir / "ROLE.md",
|
||||||
@@ -81,6 +98,35 @@ class WorkspaceManager:
|
|||||||
f"{policy_seed}".strip()
|
f"{policy_seed}".strip()
|
||||||
+ "\n",
|
+ "\n",
|
||||||
)
|
)
|
||||||
|
self._ensure_file(
|
||||||
|
asset_dir / "SOUL.md",
|
||||||
|
"# Soul\n\n"
|
||||||
|
"Describe the agent's temperament, reasoning posture, and voice.\n\n",
|
||||||
|
)
|
||||||
|
self._ensure_file(
|
||||||
|
asset_dir / "PROFILE.md",
|
||||||
|
"# Profile\n\n"
|
||||||
|
"Track this agent's long-lived investment style, preferences, and strengths.\n\n",
|
||||||
|
)
|
||||||
|
self._ensure_file(
|
||||||
|
asset_dir / "AGENTS.md",
|
||||||
|
"# Agent Guide\n\n"
|
||||||
|
"Document how this agent should work, collaborate, and choose tools or skills.\n\n",
|
||||||
|
)
|
||||||
|
self._ensure_file(
|
||||||
|
asset_dir / "MEMORY.md",
|
||||||
|
"# Memory\n\n"
|
||||||
|
"Store durable lessons, heuristics, and reminders for this agent.\n\n",
|
||||||
|
)
|
||||||
|
self._ensure_file(
|
||||||
|
asset_dir / "HEARTBEAT.md",
|
||||||
|
"# Heartbeat\n\n"
|
||||||
|
"Optional checklist for periodic review or self-reflection.\n\n",
|
||||||
|
)
|
||||||
|
self._ensure_agent_yaml(
|
||||||
|
asset_dir / "agent.yaml",
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
return asset_dir
|
return asset_dir
|
||||||
|
|
||||||
def initialize_default_assets(
|
def initialize_default_assets(
|
||||||
@@ -138,3 +184,31 @@ class WorkspaceManager:
|
|||||||
def _ensure_file(path: Path, content: str) -> None:
|
def _ensure_file(path: Path, content: str) -> None:
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
path.write_text(content, encoding="utf-8")
|
path.write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ensure_agent_yaml(path: Path, agent_id: str) -> None:
|
||||||
|
if path.exists():
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"prompt_files": [
|
||||||
|
"SOUL.md",
|
||||||
|
"PROFILE.md",
|
||||||
|
"AGENTS.md",
|
||||||
|
"POLICY.md",
|
||||||
|
"MEMORY.md",
|
||||||
|
],
|
||||||
|
"enabled_skills": [],
|
||||||
|
"disabled_skills": [],
|
||||||
|
"active_tool_groups": [],
|
||||||
|
"disabled_tool_groups": [],
|
||||||
|
}
|
||||||
|
path.write_text(
|
||||||
|
yaml.safe_dump(payload, allow_unicode=True, sort_keys=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Backward-compatible alias: code importing WorkspaceManager from this module should continue to work.
|
||||||
|
WorkspaceManager = RunWorkspaceManager
|
||||||
|
|||||||
21
backend/api/__init__.py
Normal file
21
backend/api/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
API Routes Package
|
||||||
|
|
||||||
|
Provides REST API endpoints for:
|
||||||
|
- Agent management
|
||||||
|
- Workspace management
|
||||||
|
- Tool guard operations
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .agents import router as agents_router
|
||||||
|
from .workspaces import router as workspaces_router
|
||||||
|
from .guard import router as guard_router
|
||||||
|
from .runtime import router as runtime_router
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"agents_router",
|
||||||
|
"workspaces_router",
|
||||||
|
"guard_router",
|
||||||
|
"runtime_router",
|
||||||
|
]
|
||||||
497
backend/api/agents.py
Normal file
497
backend/api/agents.py
Normal file
@@ -0,0 +1,497 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Agent API Routes
|
||||||
|
|
||||||
|
Provides REST API endpoints for agent management within workspaces.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Depends, Body, UploadFile, File, Form
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/workspaces/{workspace_id}/agents", tags=["agents"])
|
||||||
|
|
||||||
|
|
||||||
|
# Request/Response Models
|
||||||
|
class CreateAgentRequest(BaseModel):
|
||||||
|
"""Request to create a new agent."""
|
||||||
|
agent_id: str = Field(..., description="Unique agent identifier")
|
||||||
|
agent_type: str = Field(..., description="Type of agent (e.g., technical_analyst)")
|
||||||
|
name: Optional[str] = Field(None, description="Display name")
|
||||||
|
description: Optional[str] = Field(None, description="Agent description")
|
||||||
|
clone_from: Optional[str] = Field(None, description="Agent ID to clone from")
|
||||||
|
llm_model_config: Optional[Dict[str, Any]] = Field(None, description="LLM model configuration")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateAgentRequest(BaseModel):
|
||||||
|
"""Request to update an agent."""
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
enabled_skills: Optional[List[str]] = None
|
||||||
|
disabled_skills: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class InstallExternalSkillRequest(BaseModel):
|
||||||
|
"""Request to install an external skill for one agent."""
|
||||||
|
source: str = Field(..., description="Directory path, zip path, or http(s) zip URL")
|
||||||
|
name: Optional[str] = Field(None, description="Optional override skill name")
|
||||||
|
activate: bool = Field(True, description="Whether to enable skill immediately")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentResponse(BaseModel):
|
||||||
|
"""Agent information response."""
|
||||||
|
agent_id: str
|
||||||
|
agent_type: str
|
||||||
|
workspace_id: str
|
||||||
|
config_path: str
|
||||||
|
agent_dir: str
|
||||||
|
status: str = "inactive"
|
||||||
|
|
||||||
|
|
||||||
|
class AgentFileResponse(BaseModel):
|
||||||
|
"""Agent file content response."""
|
||||||
|
filename: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
# Dependencies
|
||||||
|
def get_agent_factory():
|
||||||
|
"""Get AgentFactory instance."""
|
||||||
|
return AgentFactory()
|
||||||
|
|
||||||
|
|
||||||
|
def get_workspace_manager():
|
||||||
|
"""Get WorkspaceManager instance."""
|
||||||
|
return WorkspaceManager()
|
||||||
|
|
||||||
|
|
||||||
|
def get_skills_manager():
|
||||||
|
"""Get SkillsManager instance."""
|
||||||
|
return SkillsManager()
|
||||||
|
|
||||||
|
|
||||||
|
# Routes
|
||||||
|
@router.post("", response_model=AgentResponse)
|
||||||
|
async def create_agent(
|
||||||
|
workspace_id: str,
|
||||||
|
request: CreateAgentRequest,
|
||||||
|
factory: AgentFactory = Depends(get_agent_factory),
|
||||||
|
registry = Depends(get_registry),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new agent in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
request: Agent creation parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created agent information
|
||||||
|
"""
|
||||||
|
# Check workspace exists
|
||||||
|
if not factory.workspaces_root.exists():
|
||||||
|
raise HTTPException(status_code=404, detail="Workspaces root not found")
|
||||||
|
|
||||||
|
workspace_dir = factory.workspaces_root / workspace_id
|
||||||
|
if not workspace_dir.exists():
|
||||||
|
raise HTTPException(status_code=404, detail=f"Workspace '{workspace_id}' not found")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create agent
|
||||||
|
agent = factory.create_agent(
|
||||||
|
agent_id=request.agent_id,
|
||||||
|
agent_type=request.agent_type,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
clone_from=request.clone_from,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register in registry
|
||||||
|
registry.register(
|
||||||
|
agent_id=request.agent_id,
|
||||||
|
agent_type=request.agent_type,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
config_path=str(agent.config_path),
|
||||||
|
agent_dir=str(agent.agent_dir),
|
||||||
|
status="inactive",
|
||||||
|
)
|
||||||
|
|
||||||
|
return AgentResponse(
|
||||||
|
agent_id=agent.agent_id,
|
||||||
|
agent_type=agent.agent_type,
|
||||||
|
workspace_id=agent.workspace_id,
|
||||||
|
config_path=str(agent.config_path),
|
||||||
|
agent_dir=str(agent.agent_dir),
|
||||||
|
status="inactive",
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=List[AgentResponse])
|
||||||
|
async def list_agents(
|
||||||
|
workspace_id: str,
|
||||||
|
factory: AgentFactory = Depends(get_agent_factory),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List all agents in a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of agents
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
agents_data = factory.list_agents(workspace_id=workspace_id)
|
||||||
|
return [
|
||||||
|
AgentResponse(
|
||||||
|
agent_id=agent["agent_id"],
|
||||||
|
agent_type=agent["agent_type"],
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
config_path=agent["config_path"],
|
||||||
|
agent_dir=str(Path(agent["config_path"]).parent),
|
||||||
|
status="inactive",
|
||||||
|
)
|
||||||
|
for agent in agents_data
|
||||||
|
]
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{agent_id}", response_model=AgentResponse)
|
||||||
|
async def get_agent(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
registry = Depends(get_registry),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get agent details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent information
|
||||||
|
"""
|
||||||
|
agent_info = registry.get(agent_id)
|
||||||
|
|
||||||
|
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||||
|
|
||||||
|
return AgentResponse(
|
||||||
|
agent_id=agent_info.agent_id,
|
||||||
|
agent_type=agent_info.agent_type,
|
||||||
|
workspace_id=agent_info.workspace_id,
|
||||||
|
config_path=agent_info.config_path,
|
||||||
|
agent_dir=agent_info.agent_dir,
|
||||||
|
status=agent_info.status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{agent_id}")
|
||||||
|
async def delete_agent(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
factory: AgentFactory = Depends(get_agent_factory),
|
||||||
|
registry = Depends(get_registry),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
# Check agent exists in registry
|
||||||
|
agent_info = registry.get(agent_id)
|
||||||
|
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||||
|
|
||||||
|
# Delete from factory
|
||||||
|
success = factory.delete_agent(agent_id, workspace_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||||
|
|
||||||
|
# Unregister
|
||||||
|
registry.unregister(agent_id)
|
||||||
|
|
||||||
|
return {"message": f"Agent '{agent_id}' deleted successfully"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{agent_id}", response_model=AgentResponse)
|
||||||
|
async def update_agent(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
request: UpdateAgentRequest,
|
||||||
|
registry = Depends(get_registry),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update agent configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
request: Update parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated agent information
|
||||||
|
"""
|
||||||
|
agent_info = registry.get(agent_id)
|
||||||
|
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||||
|
|
||||||
|
# Update metadata in registry
|
||||||
|
metadata_updates = {}
|
||||||
|
if request.name:
|
||||||
|
metadata_updates["name"] = request.name
|
||||||
|
if request.description:
|
||||||
|
metadata_updates["description"] = request.description
|
||||||
|
|
||||||
|
if metadata_updates:
|
||||||
|
registry.update_metadata(agent_id, metadata_updates)
|
||||||
|
|
||||||
|
# Update skills if provided
|
||||||
|
if request.enabled_skills or request.disabled_skills:
|
||||||
|
skills_manager = SkillsManager()
|
||||||
|
skills_manager.update_agent_skill_overrides(
|
||||||
|
config_name=workspace_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
enable=request.enabled_skills or [],
|
||||||
|
disable=request.disabled_skills or [],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get updated info
|
||||||
|
agent_info = registry.get(agent_id)
|
||||||
|
return AgentResponse(
|
||||||
|
agent_id=agent_info.agent_id,
|
||||||
|
agent_type=agent_info.agent_type,
|
||||||
|
workspace_id=agent_info.workspace_id,
|
||||||
|
config_path=agent_info.config_path,
|
||||||
|
agent_dir=agent_info.agent_dir,
|
||||||
|
status=agent_info.status,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/skills/{skill_name}/enable")
|
||||||
|
async def enable_skill(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_name: str,
|
||||||
|
registry = Depends(get_registry),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Enable a skill for an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
skill_name: Skill name to enable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
agent_info = registry.get(agent_id)
|
||||||
|
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||||
|
|
||||||
|
skills_manager = SkillsManager()
|
||||||
|
result = skills_manager.update_agent_skill_overrides(
|
||||||
|
config_name=workspace_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
enable=[skill_name],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Skill '{skill_name}' enabled for agent '{agent_id}'",
|
||||||
|
"enabled_skills": result["enabled_skills"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/skills/{skill_name}/disable")
|
||||||
|
async def disable_skill(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
skill_name: str,
|
||||||
|
registry = Depends(get_registry),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Disable a skill for an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
skill_name: Skill name to disable
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
agent_info = registry.get(agent_id)
|
||||||
|
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||||
|
|
||||||
|
skills_manager = SkillsManager()
|
||||||
|
result = skills_manager.update_agent_skill_overrides(
|
||||||
|
config_name=workspace_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
disable=[skill_name],
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Skill '{skill_name}' disabled for agent '{agent_id}'",
|
||||||
|
"disabled_skills": result["disabled_skills"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/skills/install")
|
||||||
|
async def install_external_skill(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
request: InstallExternalSkillRequest,
|
||||||
|
registry=Depends(get_registry),
|
||||||
|
):
|
||||||
|
"""Install an external skill into one agent's local skills."""
|
||||||
|
agent_info = registry.get(agent_id)
|
||||||
|
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||||
|
|
||||||
|
skills_manager = SkillsManager()
|
||||||
|
try:
|
||||||
|
result = skills_manager.install_external_skill_for_agent(
|
||||||
|
config_name=workspace_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
source=request.source,
|
||||||
|
skill_name=request.name,
|
||||||
|
activate=request.activate,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ValueError) as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Installed external skill '{result['skill_name']}' for '{agent_id}'",
|
||||||
|
**result,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/skills/upload")
|
||||||
|
async def upload_external_skill(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
name: Optional[str] = Form(None),
|
||||||
|
activate: bool = Form(True),
|
||||||
|
registry=Depends(get_registry),
|
||||||
|
):
|
||||||
|
"""Upload a zip skill package from frontend and install for one agent."""
|
||||||
|
agent_info = registry.get(agent_id)
|
||||||
|
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||||
|
|
||||||
|
original_name = (file.filename or "").strip()
|
||||||
|
if not original_name.lower().endswith(".zip"):
|
||||||
|
raise HTTPException(status_code=400, detail="Uploaded file must be a .zip archive")
|
||||||
|
|
||||||
|
suffix = Path(original_name).suffix or ".zip"
|
||||||
|
temp_path: Optional[str] = None
|
||||||
|
try:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
||||||
|
temp_path = tmp.name
|
||||||
|
content = await file.read()
|
||||||
|
tmp.write(content)
|
||||||
|
|
||||||
|
skills_manager = SkillsManager()
|
||||||
|
result = skills_manager.install_external_skill_for_agent(
|
||||||
|
config_name=workspace_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
source=temp_path,
|
||||||
|
skill_name=name,
|
||||||
|
activate=activate,
|
||||||
|
)
|
||||||
|
except (FileNotFoundError, ValueError) as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=str(exc))
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await file.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to close uploaded file: {e}")
|
||||||
|
if temp_path and os.path.exists(temp_path):
|
||||||
|
os.remove(temp_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"message": f"Uploaded and installed external skill '{result['skill_name']}' for '{agent_id}'",
|
||||||
|
**result,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{agent_id}/files/{filename}", response_model=AgentFileResponse)
|
||||||
|
async def get_agent_file(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
filename: str,
|
||||||
|
workspace_manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Read an agent's workspace file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
filename: File to read (e.g., SOUL.md, ROLE.md)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
File content
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
content = workspace_manager.load_agent_file(
|
||||||
|
config_name=workspace_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
filename=filename,
|
||||||
|
)
|
||||||
|
return AgentFileResponse(filename=filename, content=content)
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise HTTPException(status_code=404, detail=f"File '{filename}' not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{agent_id}/files/{filename}", response_model=AgentFileResponse)
|
||||||
|
async def update_agent_file(
|
||||||
|
workspace_id: str,
|
||||||
|
agent_id: str,
|
||||||
|
filename: str,
|
||||||
|
content: str = Body(..., media_type="text/plain"),
|
||||||
|
workspace_manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update an agent's workspace file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
agent_id: Agent identifier
|
||||||
|
filename: File to update
|
||||||
|
content: New file content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated file information
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
workspace_manager.update_agent_file(
|
||||||
|
config_name=workspace_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
filename=filename,
|
||||||
|
content=content,
|
||||||
|
)
|
||||||
|
return AgentFileResponse(filename=filename, content=content)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
257
backend/api/guard.py
Normal file
257
backend/api/guard.py
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Tool Guard API Routes
|
||||||
|
|
||||||
|
Provides REST API endpoints for tool guard operations.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.agents.base.tool_guard import (
|
||||||
|
ApprovalRecord,
|
||||||
|
ApprovalStatus,
|
||||||
|
SeverityLevel,
|
||||||
|
TOOL_GUARD_STORE,
|
||||||
|
default_findings_for_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/guard", tags=["guard"])
|
||||||
|
|
||||||
|
|
||||||
|
# Request/Response Models
|
||||||
|
class ToolCallRequest(BaseModel):
|
||||||
|
"""Tool call request."""
|
||||||
|
tool_name: str = Field(..., description="Name of the tool")
|
||||||
|
tool_input: Dict[str, Any] = Field(default_factory=dict, description="Tool parameters")
|
||||||
|
agent_id: str = Field(..., description="Agent making the request")
|
||||||
|
workspace_id: str = Field(..., description="Workspace context")
|
||||||
|
session_id: Optional[str] = Field(None, description="Session identifier")
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalRequest(BaseModel):
|
||||||
|
"""Request to approve a tool call."""
|
||||||
|
approval_id: str = Field(..., description="Approval request ID")
|
||||||
|
one_time: bool = Field(True, description="Whether this is a one-time approval")
|
||||||
|
expires_in_minutes: Optional[int] = Field(30, description="Approval expiration time")
|
||||||
|
|
||||||
|
|
||||||
|
class DenyRequest(BaseModel):
|
||||||
|
"""Request to deny a tool call."""
|
||||||
|
approval_id: str = Field(..., description="Approval request ID")
|
||||||
|
reason: Optional[str] = Field(None, description="Reason for denial")
|
||||||
|
|
||||||
|
|
||||||
|
class ToolFinding(BaseModel):
|
||||||
|
"""Tool guard finding."""
|
||||||
|
severity: SeverityLevel
|
||||||
|
message: str
|
||||||
|
field: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalResponse(BaseModel):
|
||||||
|
"""Tool approval response."""
|
||||||
|
approval_id: str
|
||||||
|
status: ApprovalStatus
|
||||||
|
tool_name: str
|
||||||
|
tool_input: Dict[str, Any]
|
||||||
|
agent_id: str
|
||||||
|
workspace_id: str
|
||||||
|
session_id: Optional[str] = None
|
||||||
|
findings: List[ToolFinding] = Field(default_factory=list)
|
||||||
|
created_at: str
|
||||||
|
resolved_at: Optional[str] = None
|
||||||
|
resolved_by: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PendingApprovalsResponse(BaseModel):
|
||||||
|
"""List of pending approvals."""
|
||||||
|
approvals: List[ApprovalResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
STORE = TOOL_GUARD_STORE
|
||||||
|
SAFE_TOOLS = {
|
||||||
|
"get_price",
|
||||||
|
"get_fundamentals",
|
||||||
|
"get_news",
|
||||||
|
"analyze_technical",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _to_response(record: ApprovalRecord) -> ApprovalResponse:
|
||||||
|
return ApprovalResponse(
|
||||||
|
approval_id=record.approval_id,
|
||||||
|
status=record.status,
|
||||||
|
tool_name=record.tool_name,
|
||||||
|
tool_input=record.tool_input,
|
||||||
|
agent_id=record.agent_id,
|
||||||
|
workspace_id=record.workspace_id,
|
||||||
|
session_id=record.session_id,
|
||||||
|
findings=[ToolFinding(**f.to_dict()) for f in record.findings],
|
||||||
|
created_at=record.created_at.isoformat(),
|
||||||
|
resolved_at=record.resolved_at.isoformat() if record.resolved_at else None,
|
||||||
|
resolved_by=record.resolved_by,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Routes
|
||||||
|
@router.post("/check", response_model=ApprovalResponse)
|
||||||
|
async def check_tool_call(
|
||||||
|
request: ToolCallRequest,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Check if a tool call requires approval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Tool call details
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Approval status - may be auto-approved, auto-denied, or pending
|
||||||
|
"""
|
||||||
|
record = STORE.create_pending(
|
||||||
|
tool_name=request.tool_name,
|
||||||
|
tool_input=request.tool_input,
|
||||||
|
agent_id=request.agent_id,
|
||||||
|
workspace_id=request.workspace_id,
|
||||||
|
session_id=request.session_id,
|
||||||
|
findings=default_findings_for_tool(request.tool_name),
|
||||||
|
)
|
||||||
|
|
||||||
|
if request.tool_name in SAFE_TOOLS:
|
||||||
|
record.status = ApprovalStatus.APPROVED
|
||||||
|
record.resolved_at = datetime.utcnow()
|
||||||
|
record.resolved_by = "system"
|
||||||
|
STORE.set_status(
|
||||||
|
record.approval_id,
|
||||||
|
ApprovalStatus.APPROVED,
|
||||||
|
resolved_by="system",
|
||||||
|
notify_request=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _to_response(record)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/approve", response_model=ApprovalResponse)
|
||||||
|
async def approve_tool_call(
|
||||||
|
request: ApprovalRequest,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Approve a pending tool call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Approval parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated approval status
|
||||||
|
"""
|
||||||
|
record = STORE.get(request.approval_id)
|
||||||
|
if not record:
|
||||||
|
raise HTTPException(status_code=404, detail="Approval request not found")
|
||||||
|
|
||||||
|
if record.status != ApprovalStatus.PENDING:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Approval already {record.status}")
|
||||||
|
|
||||||
|
record.status = ApprovalStatus.APPROVED
|
||||||
|
record.resolved_at = datetime.utcnow()
|
||||||
|
record.resolved_by = "user"
|
||||||
|
|
||||||
|
return _to_response(record)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/deny", response_model=ApprovalResponse)
|
||||||
|
async def deny_tool_call(
|
||||||
|
request: DenyRequest,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Deny a pending tool call.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Denial parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated approval status
|
||||||
|
"""
|
||||||
|
record = STORE.get(request.approval_id)
|
||||||
|
if not record:
|
||||||
|
raise HTTPException(status_code=404, detail="Approval request not found")
|
||||||
|
|
||||||
|
if record.status != ApprovalStatus.PENDING:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Approval already {record.status}")
|
||||||
|
|
||||||
|
record.status = ApprovalStatus.DENIED
|
||||||
|
record.resolved_at = datetime.utcnow()
|
||||||
|
record.resolved_by = "user"
|
||||||
|
record.metadata["denial_reason"] = request.reason
|
||||||
|
|
||||||
|
return _to_response(record)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/pending", response_model=PendingApprovalsResponse)
|
||||||
|
async def list_pending_approvals(
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List pending tool approval requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Filter by workspace
|
||||||
|
agent_id: Filter by agent
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of pending approvals
|
||||||
|
"""
|
||||||
|
pending = [
|
||||||
|
_to_response(record)
|
||||||
|
for record in STORE.list(
|
||||||
|
status=ApprovalStatus.PENDING,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
return PendingApprovalsResponse(approvals=pending, total=len(pending))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/approvals/{approval_id}", response_model=ApprovalResponse)
|
||||||
|
async def get_approval_status(
|
||||||
|
approval_id: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get the status of a specific approval request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
approval_id: Approval request ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Approval status
|
||||||
|
"""
|
||||||
|
record = STORE.get(approval_id)
|
||||||
|
if not record:
|
||||||
|
raise HTTPException(status_code=404, detail="Approval request not found")
|
||||||
|
return _to_response(record)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/approvals/{approval_id}")
|
||||||
|
async def cancel_approval(
|
||||||
|
approval_id: str,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Cancel/delete a pending approval request.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
approval_id: Approval request ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
record = STORE.get(approval_id)
|
||||||
|
if not record:
|
||||||
|
raise HTTPException(status_code=404, detail="Approval request not found")
|
||||||
|
|
||||||
|
STORE.cancel(approval_id)
|
||||||
|
return _to_response(record)
|
||||||
720
backend/api/runtime.py
Normal file
720
backend/api/runtime.py
Normal file
@@ -0,0 +1,720 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Runtime API routes - Control Plane for managing Gateway processes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.runtime.agent_runtime import AgentRuntimeState
|
||||||
|
from backend.config.bootstrap_config import (
|
||||||
|
resolve_runtime_config,
|
||||||
|
update_bootstrap_values_for_run,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/runtime", tags=["runtime"])
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeState:
|
||||||
|
"""Thread-safe singleton for managing runtime state.
|
||||||
|
|
||||||
|
Encapsulates runtime_manager, _gateway_process, and _gateway_port
|
||||||
|
with asyncio.Lock protection for concurrent access.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_instance: Optional["RuntimeState"] = None
|
||||||
|
_lock: asyncio.Lock = asyncio.Lock()
|
||||||
|
|
||||||
|
def __new__(cls) -> "RuntimeState":
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._initialized = False
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
if self._initialized:
|
||||||
|
return
|
||||||
|
self._runtime_manager: Optional[Any] = None
|
||||||
|
self._gateway_process: Optional[subprocess.Popen] = None
|
||||||
|
self._gateway_port: int = 8765
|
||||||
|
self._state_lock = asyncio.Lock()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
async def lock(self) -> asyncio.Lock:
|
||||||
|
"""Get the asyncio lock for state synchronization."""
|
||||||
|
return self._state_lock
|
||||||
|
|
||||||
|
@property
|
||||||
|
def runtime_manager(self) -> Optional[Any]:
|
||||||
|
"""Get the runtime manager (no lock - read only)."""
|
||||||
|
return self._runtime_manager
|
||||||
|
|
||||||
|
@runtime_manager.setter
|
||||||
|
def runtime_manager(self, value: Optional[Any]) -> None:
|
||||||
|
"""Set the runtime manager."""
|
||||||
|
self._runtime_manager = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gateway_process(self) -> Optional[subprocess.Popen]:
|
||||||
|
"""Get the gateway process (no lock - read only)."""
|
||||||
|
return self._gateway_process
|
||||||
|
|
||||||
|
@gateway_process.setter
|
||||||
|
def gateway_process(self, value: Optional[subprocess.Popen]) -> None:
|
||||||
|
"""Set the gateway process."""
|
||||||
|
self._gateway_process = value
|
||||||
|
|
||||||
|
@property
|
||||||
|
def gateway_port(self) -> int:
|
||||||
|
"""Get the gateway port."""
|
||||||
|
return self._gateway_port
|
||||||
|
|
||||||
|
@gateway_port.setter
|
||||||
|
def gateway_port(self, value: int) -> None:
|
||||||
|
"""Set the gateway port."""
|
||||||
|
self._gateway_port = value
|
||||||
|
|
||||||
|
async def set_runtime_manager(self, manager: Any) -> None:
|
||||||
|
"""Set runtime manager with lock protection."""
|
||||||
|
async with self._state_lock:
|
||||||
|
self._runtime_manager = manager
|
||||||
|
|
||||||
|
async def get_runtime_manager(self) -> Optional[Any]:
|
||||||
|
"""Get runtime manager with lock protection."""
|
||||||
|
async with self._state_lock:
|
||||||
|
return self._runtime_manager
|
||||||
|
|
||||||
|
async def set_gateway_process(self, process: Optional[subprocess.Popen]) -> None:
|
||||||
|
"""Set gateway process with lock protection."""
|
||||||
|
async with self._state_lock:
|
||||||
|
self._gateway_process = process
|
||||||
|
|
||||||
|
async def get_gateway_process(self) -> Optional[subprocess.Popen]:
|
||||||
|
"""Get gateway process with lock protection."""
|
||||||
|
async with self._state_lock:
|
||||||
|
return self._gateway_process
|
||||||
|
|
||||||
|
async def set_gateway_port(self, port: int) -> None:
|
||||||
|
"""Set gateway port with lock protection."""
|
||||||
|
async with self._state_lock:
|
||||||
|
self._gateway_port = port
|
||||||
|
|
||||||
|
async def get_gateway_port(self) -> int:
|
||||||
|
"""Get gateway port with lock protection."""
|
||||||
|
async with self._state_lock:
|
||||||
|
return self._gateway_port
|
||||||
|
|
||||||
|
|
||||||
|
# Singleton instance
|
||||||
|
_runtime_state = RuntimeState()
|
||||||
|
|
||||||
|
|
||||||
|
def get_runtime_state() -> RuntimeState:
|
||||||
|
"""Get the RuntimeState singleton instance."""
|
||||||
|
return _runtime_state
|
||||||
|
|
||||||
|
|
||||||
|
# Backward compatibility: module-level runtime_manager for external imports
|
||||||
|
# This is set by register_runtime_manager() for backward compatibility
|
||||||
|
runtime_manager: Optional[Any] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RunContextResponse(BaseModel):
|
||||||
|
config_name: str
|
||||||
|
run_dir: str
|
||||||
|
bootstrap_values: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeAgentState(BaseModel):
|
||||||
|
agent_id: str
|
||||||
|
status: str
|
||||||
|
last_session: Optional[str] = None
|
||||||
|
last_updated: str
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeAgentsResponse(BaseModel):
|
||||||
|
agents: List[RuntimeAgentState]
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeEvent(BaseModel):
|
||||||
|
timestamp: str
|
||||||
|
event: str
|
||||||
|
details: Dict[str, Any]
|
||||||
|
session: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeEventsResponse(BaseModel):
|
||||||
|
events: List[RuntimeEvent]
|
||||||
|
|
||||||
|
|
||||||
|
class LaunchConfig(BaseModel):
|
||||||
|
"""Configuration for launching a new trading task."""
|
||||||
|
tickers: List[str] = Field(default_factory=list, description="股票池")
|
||||||
|
schedule_mode: str = Field(default="daily", description="调度模式: daily, interval")
|
||||||
|
interval_minutes: int = Field(default=60, ge=1, description="间隔分钟数")
|
||||||
|
trigger_time: str = Field(default="09:30", description="触发时间 HH:MM")
|
||||||
|
max_comm_cycles: int = Field(default=2, ge=1, description="最大会商轮数")
|
||||||
|
initial_cash: float = Field(default=100000.0, gt=0, description="初始资金")
|
||||||
|
margin_requirement: float = Field(default=0.0, ge=0, description="保证金要求")
|
||||||
|
enable_memory: bool = Field(default=False, description="是否启用长期记忆")
|
||||||
|
mode: str = Field(default="live", description="运行模式: live, backtest")
|
||||||
|
start_date: Optional[str] = Field(default=None, description="回测开始日期 YYYY-MM-DD")
|
||||||
|
end_date: Optional[str] = Field(default=None, description="回测结束日期 YYYY-MM-DD")
|
||||||
|
poll_interval: int = Field(default=10, ge=1, le=300, description="市场数据轮询间隔(秒)")
|
||||||
|
enable_mock: bool = Field(default=False, description="是否启用模拟模式(使用模拟价格数据)")
|
||||||
|
|
||||||
|
|
||||||
|
class LaunchResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
status: str
|
||||||
|
run_dir: str
|
||||||
|
gateway_port: int
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class StopResponse(BaseModel):
|
||||||
|
status: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayStatusResponse(BaseModel):
|
||||||
|
is_running: bool
|
||||||
|
port: int
|
||||||
|
run_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeConfigResponse(BaseModel):
|
||||||
|
run_id: str
|
||||||
|
is_running: bool
|
||||||
|
gateway_port: int
|
||||||
|
bootstrap: Dict[str, Any]
|
||||||
|
resolved: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateRuntimeConfigRequest(BaseModel):
|
||||||
|
schedule_mode: Optional[str] = None
|
||||||
|
interval_minutes: Optional[int] = Field(default=None, ge=1)
|
||||||
|
trigger_time: Optional[str] = None
|
||||||
|
max_comm_cycles: Optional[int] = Field(default=None, ge=1)
|
||||||
|
initial_cash: Optional[float] = Field(default=None, gt=0)
|
||||||
|
margin_requirement: Optional[float] = Field(default=None, ge=0)
|
||||||
|
enable_memory: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_run_id() -> str:
|
||||||
|
"""Generate timestamp-based run ID: YYYYMMDD_HHMMSS"""
|
||||||
|
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_run_dir(run_id: str) -> Path:
|
||||||
|
"""Return the run directory for a given run ID."""
|
||||||
|
return PROJECT_ROOT / "runs" / run_id
|
||||||
|
|
||||||
|
|
||||||
|
def _find_available_port(start_port: int = 8765, max_port: int = 9000) -> int:
|
||||||
|
"""Find an available port for Gateway."""
|
||||||
|
import socket
|
||||||
|
for port in range(start_port, max_port):
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
if s.connect_ex(('localhost', port)) != 0:
|
||||||
|
return port
|
||||||
|
raise RuntimeError("No available port found")
|
||||||
|
|
||||||
|
|
||||||
|
def _is_gateway_running() -> bool:
|
||||||
|
"""Check if Gateway process is running."""
|
||||||
|
process = _runtime_state.gateway_process
|
||||||
|
if process is None:
|
||||||
|
return False
|
||||||
|
return process.poll() is None
|
||||||
|
|
||||||
|
|
||||||
|
def _stop_gateway() -> bool:
|
||||||
|
"""Stop the Gateway process."""
|
||||||
|
process = _runtime_state.gateway_process
|
||||||
|
if process is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try graceful shutdown first
|
||||||
|
process.terminate()
|
||||||
|
try:
|
||||||
|
process.wait(timeout=5)
|
||||||
|
except subprocess.TimeoutExpired:
|
||||||
|
# Force kill if graceful shutdown fails
|
||||||
|
process.kill()
|
||||||
|
process.wait()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error during gateway shutdown: {e}")
|
||||||
|
finally:
|
||||||
|
_runtime_state.gateway_process = None
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _start_gateway_process(
|
||||||
|
run_id: str,
|
||||||
|
run_dir: Path,
|
||||||
|
bootstrap: Dict[str, Any],
|
||||||
|
port: int
|
||||||
|
) -> subprocess.Popen:
|
||||||
|
"""Start Gateway as a separate process."""
|
||||||
|
# Prepare environment
|
||||||
|
env = os.environ.copy()
|
||||||
|
|
||||||
|
# Create command arguments
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"-m", "backend.gateway_server",
|
||||||
|
"--run-id", run_id,
|
||||||
|
"--run-dir", str(run_dir),
|
||||||
|
"--port", str(port),
|
||||||
|
"--bootstrap", json.dumps(bootstrap)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Start process
|
||||||
|
process = subprocess.Popen(
|
||||||
|
cmd,
|
||||||
|
env=env,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
cwd=PROJECT_ROOT
|
||||||
|
)
|
||||||
|
|
||||||
|
return process
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/context", response_model=RunContextResponse)
|
||||||
|
async def get_run_context() -> RunContextResponse:
|
||||||
|
"""Return the most recent run context."""
|
||||||
|
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||||
|
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
if not snapshots:
|
||||||
|
raise HTTPException(status_code=404, detail="No run context available")
|
||||||
|
|
||||||
|
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||||
|
context = latest.get("context")
|
||||||
|
if context is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Run context is not ready")
|
||||||
|
|
||||||
|
return RunContextResponse(
|
||||||
|
config_name=context["config_name"],
|
||||||
|
run_dir=context["run_dir"],
|
||||||
|
bootstrap_values=context["bootstrap_values"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/agents", response_model=RuntimeAgentsResponse)
|
||||||
|
async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||||
|
"""Return agent states from the most recent run."""
|
||||||
|
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||||
|
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
if not snapshots:
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||||
|
|
||||||
|
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||||
|
agents = latest.get("agents", [])
|
||||||
|
|
||||||
|
return RuntimeAgentsResponse(
|
||||||
|
agents=[RuntimeAgentState(**a) for a in agents]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/events", response_model=RuntimeEventsResponse)
|
||||||
|
async def get_runtime_events() -> RuntimeEventsResponse:
|
||||||
|
"""Return events from the most recent run."""
|
||||||
|
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||||
|
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
if not snapshots:
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||||
|
|
||||||
|
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||||
|
events = latest.get("events", [])
|
||||||
|
|
||||||
|
return RuntimeEventsResponse(
|
||||||
|
events=[RuntimeEvent(**e) for e in events]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/gateway/status", response_model=GatewayStatusResponse)
|
||||||
|
async def get_gateway_status() -> GatewayStatusResponse:
|
||||||
|
"""Get Gateway process status and port."""
|
||||||
|
is_running = _is_gateway_running()
|
||||||
|
run_id = None
|
||||||
|
|
||||||
|
if is_running:
|
||||||
|
# Try to find run_id from runtime state
|
||||||
|
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||||
|
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
if snapshots:
|
||||||
|
try:
|
||||||
|
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||||
|
run_id = latest.get("context", {}).get("config_name")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to parse latest snapshot: {e}")
|
||||||
|
|
||||||
|
return GatewayStatusResponse(
|
||||||
|
is_running=is_running,
|
||||||
|
port=_runtime_state.gateway_port,
|
||||||
|
run_id=run_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/gateway/port")
|
||||||
|
async def get_gateway_port(request: Request) -> Dict[str, Any]:
|
||||||
|
"""Get WebSocket Gateway port for frontend connection."""
|
||||||
|
gateway_port = _runtime_state.gateway_port
|
||||||
|
return {
|
||||||
|
"port": gateway_port,
|
||||||
|
"is_running": _is_gateway_running(),
|
||||||
|
"ws_url": _build_gateway_ws_url(request, gateway_port),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gateway_ws_url(request: Request, port: int) -> str:
|
||||||
|
"""Build a proxy-safe Gateway WebSocket URL."""
|
||||||
|
forwarded_proto = request.headers.get("x-forwarded-proto", "").split(",")[0].strip()
|
||||||
|
scheme = forwarded_proto or request.url.scheme
|
||||||
|
ws_scheme = "wss" if scheme == "https" else "ws"
|
||||||
|
|
||||||
|
forwarded_host = request.headers.get("x-forwarded-host", "").split(",")[0].strip()
|
||||||
|
host = forwarded_host or request.url.hostname or "localhost"
|
||||||
|
if ":" in host and not host.startswith("["):
|
||||||
|
host = host.split(":", 1)[0]
|
||||||
|
|
||||||
|
return f"{ws_scheme}://{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
||||||
|
"""Load the latest persisted runtime snapshot."""
|
||||||
|
snapshots = sorted(
|
||||||
|
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
|
||||||
|
key=lambda p: p.stat().st_mtime,
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
if not snapshots:
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||||
|
return json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||||
|
"""Return the active runtime context from the latest snapshot."""
|
||||||
|
if not _is_gateway_running():
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||||
|
latest = _load_latest_runtime_snapshot()
|
||||||
|
context = latest.get("context") or {}
|
||||||
|
if not context.get("config_name"):
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime context available")
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_runtime_response(run_id: str) -> RuntimeConfigResponse:
|
||||||
|
"""Build a normalized runtime config response for the active run."""
|
||||||
|
context = _get_current_runtime_context()
|
||||||
|
bootstrap = dict(context.get("bootstrap_values") or {})
|
||||||
|
resolved = resolve_runtime_config(
|
||||||
|
project_root=PROJECT_ROOT,
|
||||||
|
config_name=run_id,
|
||||||
|
enable_memory=bool(bootstrap.get("enable_memory", False)),
|
||||||
|
schedule_mode=str(bootstrap.get("schedule_mode", "daily")),
|
||||||
|
interval_minutes=int(bootstrap.get("interval_minutes", 60) or 60),
|
||||||
|
trigger_time=str(bootstrap.get("trigger_time", "09:30") or "09:30"),
|
||||||
|
)
|
||||||
|
return RuntimeConfigResponse(
|
||||||
|
run_id=run_id,
|
||||||
|
is_running=True,
|
||||||
|
gateway_port=_runtime_state.gateway_port,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
resolved=resolved,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_runtime_config_updates(
|
||||||
|
request: UpdateRuntimeConfigRequest,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Validate and normalize runtime config updates."""
|
||||||
|
updates: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if request.schedule_mode is not None:
|
||||||
|
schedule_mode = str(request.schedule_mode).strip().lower()
|
||||||
|
if schedule_mode not in {"daily", "intraday"}:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="schedule_mode must be 'daily' or 'intraday'",
|
||||||
|
)
|
||||||
|
updates["schedule_mode"] = schedule_mode
|
||||||
|
|
||||||
|
if request.interval_minutes is not None:
|
||||||
|
updates["interval_minutes"] = int(request.interval_minutes)
|
||||||
|
|
||||||
|
if request.trigger_time is not None:
|
||||||
|
trigger_time = str(request.trigger_time).strip()
|
||||||
|
if trigger_time and trigger_time != "now":
|
||||||
|
try:
|
||||||
|
datetime.strptime(trigger_time, "%H:%M")
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="trigger_time must use HH:MM or 'now'",
|
||||||
|
) from exc
|
||||||
|
updates["trigger_time"] = trigger_time or "09:30"
|
||||||
|
|
||||||
|
if request.max_comm_cycles is not None:
|
||||||
|
updates["max_comm_cycles"] = int(request.max_comm_cycles)
|
||||||
|
|
||||||
|
if request.initial_cash is not None:
|
||||||
|
updates["initial_cash"] = float(request.initial_cash)
|
||||||
|
|
||||||
|
if request.margin_requirement is not None:
|
||||||
|
updates["margin_requirement"] = float(request.margin_requirement)
|
||||||
|
|
||||||
|
if request.enable_memory is not None:
|
||||||
|
updates["enable_memory"] = bool(request.enable_memory)
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
raise HTTPException(status_code=400, detail="No runtime config updates provided")
|
||||||
|
|
||||||
|
return updates
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_model=LaunchResponse)
|
||||||
|
async def start_runtime(
|
||||||
|
config: LaunchConfig,
|
||||||
|
background_tasks: BackgroundTasks
|
||||||
|
) -> LaunchResponse:
|
||||||
|
"""Start a new trading runtime with the given configuration.
|
||||||
|
|
||||||
|
1. Stop existing Gateway if running
|
||||||
|
2. Generate run ID and directory
|
||||||
|
3. Create runtime manager
|
||||||
|
4. Start Gateway as subprocess (Data Plane)
|
||||||
|
5. Return Gateway port for WebSocket connection
|
||||||
|
"""
|
||||||
|
# Lazy import to avoid circular dependency
|
||||||
|
from backend.runtime.manager import TradingRuntimeManager
|
||||||
|
|
||||||
|
# 1. Stop existing Gateway
|
||||||
|
if _is_gateway_running():
|
||||||
|
_stop_gateway()
|
||||||
|
await asyncio.sleep(1) # Wait for port release
|
||||||
|
|
||||||
|
# 2. Generate run ID and directory
|
||||||
|
run_id = _generate_run_id()
|
||||||
|
run_dir = _get_run_dir(run_id)
|
||||||
|
|
||||||
|
# 3. Prepare bootstrap config
|
||||||
|
bootstrap = {
|
||||||
|
"tickers": config.tickers,
|
||||||
|
"schedule_mode": config.schedule_mode,
|
||||||
|
"interval_minutes": config.interval_minutes,
|
||||||
|
"trigger_time": config.trigger_time,
|
||||||
|
"max_comm_cycles": config.max_comm_cycles,
|
||||||
|
"initial_cash": config.initial_cash,
|
||||||
|
"margin_requirement": config.margin_requirement,
|
||||||
|
"enable_memory": config.enable_memory,
|
||||||
|
"mode": config.mode,
|
||||||
|
"start_date": config.start_date,
|
||||||
|
"end_date": config.end_date,
|
||||||
|
"poll_interval": config.poll_interval,
|
||||||
|
"enable_mock": config.enable_mock,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 4. Create runtime manager
|
||||||
|
manager = TradingRuntimeManager(
|
||||||
|
config_name=run_id,
|
||||||
|
run_dir=run_dir,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
)
|
||||||
|
manager.prepare_run()
|
||||||
|
register_runtime_manager(manager)
|
||||||
|
|
||||||
|
# 5. Write BOOTSTRAP.md
|
||||||
|
_write_bootstrap_md(run_dir, bootstrap)
|
||||||
|
|
||||||
|
# 6. Find available port and start Gateway process
|
||||||
|
gateway_port = _find_available_port(start_port=8765)
|
||||||
|
_runtime_state.gateway_port = gateway_port
|
||||||
|
|
||||||
|
try:
|
||||||
|
process = _start_gateway_process(
|
||||||
|
run_id=run_id,
|
||||||
|
run_dir=run_dir,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
port=gateway_port
|
||||||
|
)
|
||||||
|
_runtime_state.gateway_process = process
|
||||||
|
|
||||||
|
# Wait briefly to check if process started successfully
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
if not _is_gateway_running():
|
||||||
|
stdout, stderr = process.communicate(timeout=1)
|
||||||
|
_runtime_state.gateway_process = None
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=500,
|
||||||
|
detail=f"Gateway failed to start: {stderr.decode() if stderr else 'Unknown error'}"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
_stop_gateway()
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to start Gateway: {str(e)}")
|
||||||
|
|
||||||
|
return LaunchResponse(
|
||||||
|
run_id=run_id,
|
||||||
|
status="started",
|
||||||
|
run_dir=str(run_dir),
|
||||||
|
gateway_port=gateway_port,
|
||||||
|
message=f"Runtime started with run_id: {run_id}, Gateway on port: {gateway_port}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/config", response_model=RuntimeConfigResponse)
|
||||||
|
async def get_runtime_config() -> RuntimeConfigResponse:
|
||||||
|
"""Return the current runtime bootstrap and resolved settings."""
|
||||||
|
context = _get_current_runtime_context()
|
||||||
|
return _resolve_runtime_response(context["config_name"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/config", response_model=RuntimeConfigResponse)
|
||||||
|
async def update_runtime_config(
|
||||||
|
request: UpdateRuntimeConfigRequest,
|
||||||
|
) -> RuntimeConfigResponse:
|
||||||
|
"""Persist selected runtime configuration updates for the active run."""
|
||||||
|
context = _get_current_runtime_context()
|
||||||
|
run_id = context["config_name"]
|
||||||
|
updates = _normalize_runtime_config_updates(request)
|
||||||
|
updated = update_bootstrap_values_for_run(PROJECT_ROOT, run_id, updates)
|
||||||
|
|
||||||
|
manager = _runtime_state.runtime_manager
|
||||||
|
if manager is not None and getattr(manager, "config_name", None) == run_id:
|
||||||
|
manager.bootstrap.update(updates)
|
||||||
|
if getattr(manager, "context", None) is not None:
|
||||||
|
manager.context.bootstrap_values.update(updates)
|
||||||
|
if hasattr(manager, "_persist_snapshot"):
|
||||||
|
manager._persist_snapshot()
|
||||||
|
|
||||||
|
response = _resolve_runtime_response(run_id)
|
||||||
|
response.bootstrap = dict(updated.values)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/stop", response_model=StopResponse)
|
||||||
|
async def stop_runtime(force: bool = True) -> StopResponse:
|
||||||
|
"""Stop the current running runtime."""
|
||||||
|
was_running = _is_gateway_running()
|
||||||
|
|
||||||
|
if not was_running:
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||||
|
|
||||||
|
# Stop Gateway process
|
||||||
|
_stop_gateway()
|
||||||
|
|
||||||
|
# Unregister runtime manager
|
||||||
|
unregister_runtime_manager()
|
||||||
|
|
||||||
|
return StopResponse(
|
||||||
|
status="stopped",
|
||||||
|
message="Runtime stopped successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/restart")
|
||||||
|
async def restart_runtime(
|
||||||
|
config: LaunchConfig,
|
||||||
|
background_tasks: BackgroundTasks
|
||||||
|
):
|
||||||
|
"""Restart the runtime with a new configuration."""
|
||||||
|
# Stop current runtime
|
||||||
|
await stop_runtime(force=True)
|
||||||
|
|
||||||
|
# Start new runtime
|
||||||
|
response = await start_runtime(config, background_tasks)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"run_id": response.run_id,
|
||||||
|
"status": "restarted",
|
||||||
|
"gateway_port": response.gateway_port,
|
||||||
|
"message": f"Runtime restarted with run_id: {response.run_id}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/current")
|
||||||
|
async def get_current_runtime():
|
||||||
|
"""Get information about the currently running runtime."""
|
||||||
|
if not _is_gateway_running():
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||||
|
|
||||||
|
# Find latest runtime state
|
||||||
|
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||||
|
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||||
|
|
||||||
|
if not snapshots:
|
||||||
|
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||||
|
|
||||||
|
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||||
|
context = latest.get("context", {})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"run_id": context.get("config_name"),
|
||||||
|
"run_dir": context.get("run_dir"),
|
||||||
|
"is_running": True,
|
||||||
|
"gateway_port": _runtime_state.gateway_port,
|
||||||
|
"bootstrap": context.get("bootstrap_values", {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def register_runtime_manager(manager: Any) -> None:
|
||||||
|
"""Allow other modules to expose the runtime manager to the API."""
|
||||||
|
global runtime_manager
|
||||||
|
runtime_manager = manager
|
||||||
|
# Also update the RuntimeState singleton for internal consistency
|
||||||
|
_runtime_state.runtime_manager = manager
|
||||||
|
|
||||||
|
|
||||||
|
def unregister_runtime_manager() -> None:
|
||||||
|
"""Drop the runtime manager reference."""
|
||||||
|
global runtime_manager
|
||||||
|
runtime_manager = None
|
||||||
|
# Also update the RuntimeState singleton for internal consistency
|
||||||
|
_runtime_state.runtime_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None:
|
||||||
|
"""Write bootstrap configuration to BOOTSTRAP.md."""
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
except ImportError:
|
||||||
|
yaml = None
|
||||||
|
|
||||||
|
bootstrap_path = run_dir / "BOOTSTRAP.md"
|
||||||
|
bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Filter out None values
|
||||||
|
values = {k: v for k, v in bootstrap.items() if v is not None}
|
||||||
|
|
||||||
|
if yaml:
|
||||||
|
front_matter = yaml.safe_dump(values, allow_unicode=True, sort_keys=False)
|
||||||
|
else:
|
||||||
|
front_matter = json.dumps(values, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
content = f"---\n{front_matter}---\n"
|
||||||
|
bootstrap_path.write_text(content, encoding="utf-8")
|
||||||
196
backend/api/workspaces.py
Normal file
196
backend/api/workspaces.py
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Workspace API Routes
|
||||||
|
|
||||||
|
Provides REST API endpoints for workspace management.
|
||||||
|
"""
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Depends
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.agents import WorkspaceManager
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/workspaces", tags=["workspaces"])
|
||||||
|
|
||||||
|
|
||||||
|
# Request/Response Models
|
||||||
|
class CreateWorkspaceRequest(BaseModel):
|
||||||
|
"""Request to create a new workspace."""
|
||||||
|
workspace_id: str = Field(..., description="Unique workspace identifier")
|
||||||
|
name: Optional[str] = Field(None, description="Display name")
|
||||||
|
description: Optional[str] = Field(None, description="Workspace description")
|
||||||
|
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateWorkspaceRequest(BaseModel):
|
||||||
|
"""Request to update a workspace."""
|
||||||
|
name: Optional[str] = None
|
||||||
|
description: Optional[str] = None
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceResponse(BaseModel):
|
||||||
|
"""Workspace information response."""
|
||||||
|
workspace_id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
created_at: Optional[str] = None
|
||||||
|
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceListResponse(BaseModel):
|
||||||
|
"""List of workspaces response."""
|
||||||
|
workspaces: List[WorkspaceResponse]
|
||||||
|
total: int
|
||||||
|
|
||||||
|
|
||||||
|
# Dependencies
|
||||||
|
def get_workspace_manager():
|
||||||
|
"""Get WorkspaceManager instance."""
|
||||||
|
return WorkspaceManager()
|
||||||
|
|
||||||
|
|
||||||
|
# Routes
|
||||||
|
@router.post("", response_model=WorkspaceResponse)
|
||||||
|
async def create_workspace(
|
||||||
|
request: CreateWorkspaceRequest,
|
||||||
|
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Workspace creation parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Created workspace information
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
config = manager.create_workspace(
|
||||||
|
workspace_id=request.workspace_id,
|
||||||
|
name=request.name,
|
||||||
|
description=request.description,
|
||||||
|
metadata=request.metadata or {},
|
||||||
|
)
|
||||||
|
return WorkspaceResponse(
|
||||||
|
workspace_id=config.workspace_id,
|
||||||
|
name=config.name,
|
||||||
|
description=config.description,
|
||||||
|
created_at=config.created_at,
|
||||||
|
metadata=config.metadata,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("", response_model=WorkspaceListResponse)
|
||||||
|
async def list_workspaces(
|
||||||
|
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List all workspaces.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of workspaces
|
||||||
|
"""
|
||||||
|
workspaces = manager.list_workspaces()
|
||||||
|
return WorkspaceListResponse(
|
||||||
|
workspaces=[
|
||||||
|
WorkspaceResponse(
|
||||||
|
workspace_id=ws.workspace_id,
|
||||||
|
name=ws.name,
|
||||||
|
description=ws.description,
|
||||||
|
created_at=ws.created_at,
|
||||||
|
metadata=ws.metadata,
|
||||||
|
)
|
||||||
|
for ws in workspaces
|
||||||
|
],
|
||||||
|
total=len(workspaces),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{workspace_id}", response_model=WorkspaceResponse)
|
||||||
|
async def get_workspace(
|
||||||
|
workspace_id: str,
|
||||||
|
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get workspace details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Workspace information
|
||||||
|
"""
|
||||||
|
workspace = manager.get_workspace(workspace_id)
|
||||||
|
if not workspace:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Workspace '{workspace_id}' not found")
|
||||||
|
|
||||||
|
return WorkspaceResponse(
|
||||||
|
workspace_id=workspace["workspace_id"],
|
||||||
|
name=workspace.get("name", workspace_id),
|
||||||
|
description=workspace.get("description", ""),
|
||||||
|
created_at=workspace.get("created_at"),
|
||||||
|
metadata=workspace.get("metadata", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{workspace_id}", response_model=WorkspaceResponse)
|
||||||
|
async def update_workspace(
|
||||||
|
workspace_id: str,
|
||||||
|
request: UpdateWorkspaceRequest,
|
||||||
|
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update workspace configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
request: Update parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated workspace information
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
config = manager.update_workspace_config(
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
name=request.name,
|
||||||
|
description=request.description,
|
||||||
|
metadata=request.metadata,
|
||||||
|
)
|
||||||
|
return WorkspaceResponse(
|
||||||
|
workspace_id=config.workspace_id,
|
||||||
|
name=config.name,
|
||||||
|
description=config.description,
|
||||||
|
created_at=config.created_at,
|
||||||
|
metadata=config.metadata,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{workspace_id}")
|
||||||
|
async def delete_workspace(
|
||||||
|
workspace_id: str,
|
||||||
|
force: bool = False,
|
||||||
|
manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete a workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
force: If True, delete even if workspace has agents
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
success = manager.delete_workspace(workspace_id, force=force)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Workspace '{workspace_id}' not found")
|
||||||
|
return {"message": f"Workspace '{workspace_id}' deleted successfully"}
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
27
backend/apps/__init__.py
Normal file
27
backend/apps/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Application surfaces for progressive service extraction."""
|
||||||
|
|
||||||
|
from .agent_service import app as agent_app
|
||||||
|
from .agent_service import create_app as create_agent_app
|
||||||
|
from .news_service import app as news_app
|
||||||
|
from .news_service import create_app as create_news_app
|
||||||
|
from .runtime_service import app as runtime_app
|
||||||
|
from .runtime_service import create_app as create_runtime_app
|
||||||
|
from .trading_service import app as trading_app
|
||||||
|
from .trading_service import create_app as create_trading_app
|
||||||
|
|
||||||
|
app = agent_app
|
||||||
|
create_app = create_agent_app
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"app",
|
||||||
|
"create_app",
|
||||||
|
"agent_app",
|
||||||
|
"create_agent_app",
|
||||||
|
"news_app",
|
||||||
|
"create_news_app",
|
||||||
|
"runtime_app",
|
||||||
|
"create_runtime_app",
|
||||||
|
"trading_app",
|
||||||
|
"create_trading_app",
|
||||||
|
]
|
||||||
94
backend/apps/agent_service.py
Normal file
94
backend/apps/agent_service.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Agent control-plane FastAPI surface."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from backend.api import agents_router, guard_router, workspaces_router
|
||||||
|
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||||
|
|
||||||
|
# Global instances (initialized on startup)
|
||||||
|
agent_factory: AgentFactory | None = None
|
||||||
|
workspace_manager: WorkspaceManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def create_app(project_root: Path | None = None) -> FastAPI:
|
||||||
|
"""Create the agent control-plane app."""
|
||||||
|
resolved_project_root = project_root or Path(__file__).resolve().parents[2]
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
"""Initialize workspace and registry state for the control plane."""
|
||||||
|
global agent_factory, workspace_manager
|
||||||
|
|
||||||
|
workspace_manager = WorkspaceManager(project_root=resolved_project_root)
|
||||||
|
agent_factory = AgentFactory(project_root=resolved_project_root)
|
||||||
|
agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
registry = get_registry()
|
||||||
|
print("✓ EvoTraders API started")
|
||||||
|
print(f" - Workspaces root: {agent_factory.workspaces_root}")
|
||||||
|
print(f" - Registered agents: {registry.get_agent_count()}")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
print("✓ EvoTraders API shutting down")
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="EvoTraders Agent Service",
|
||||||
|
description="REST API for the EvoTraders multi-agent control plane",
|
||||||
|
version="0.1.0",
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check() -> dict[str, object]:
|
||||||
|
"""Health check endpoint."""
|
||||||
|
registry = get_registry()
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"agents_registered": registry.get_agent_count(),
|
||||||
|
"workspaces_available": (
|
||||||
|
len(workspace_manager.list_workspaces())
|
||||||
|
if workspace_manager
|
||||||
|
else 0
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/api/status")
|
||||||
|
async def api_status() -> dict[str, object]:
|
||||||
|
"""Get API status and registry information."""
|
||||||
|
registry = get_registry()
|
||||||
|
return {
|
||||||
|
"status": "operational",
|
||||||
|
"registry": registry.get_stats(),
|
||||||
|
}
|
||||||
|
|
||||||
|
app.include_router(workspaces_router)
|
||||||
|
app.include_router(agents_router)
|
||||||
|
app.include_router(guard_router)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
153
backend/apps/news_service.py
Normal file
153
backend/apps/news_service.py
Normal file
@@ -0,0 +1,153 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""News and explain FastAPI surface."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import Depends, FastAPI, Query
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
from backend.domains import news as news_domain
|
||||||
|
|
||||||
|
|
||||||
|
def get_market_store() -> MarketStore:
|
||||||
|
"""Create a market store dependency."""
|
||||||
|
return MarketStore()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""Create the news/explain service app."""
|
||||||
|
app = FastAPI(
|
||||||
|
title="EvoTraders News Service",
|
||||||
|
description="Read-only news enrichment and explain service surface extracted from the monolith",
|
||||||
|
version="0.1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check() -> dict[str, str]:
|
||||||
|
return {"status": "healthy", "service": "news-service"}
|
||||||
|
|
||||||
|
@app.get("/api/enriched-news")
|
||||||
|
async def api_get_enriched_news(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
start_date: str | None = Query(None),
|
||||||
|
end_date: str | None = Query(None),
|
||||||
|
limit: int = Query(100, ge=1, le=1000),
|
||||||
|
store: MarketStore = Depends(get_market_store),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return news_domain.get_enriched_news(
|
||||||
|
store,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/news-for-date")
|
||||||
|
async def api_get_news_for_date(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
date: str = Query(...),
|
||||||
|
limit: int = Query(20, ge=1, le=100),
|
||||||
|
store: MarketStore = Depends(get_market_store),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return news_domain.get_news_for_date(
|
||||||
|
store,
|
||||||
|
ticker=ticker,
|
||||||
|
date=date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/news-timeline")
|
||||||
|
async def api_get_news_timeline(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
start_date: str = Query(...),
|
||||||
|
end_date: str = Query(...),
|
||||||
|
store: MarketStore = Depends(get_market_store),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return news_domain.get_news_timeline(
|
||||||
|
store,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/categories")
|
||||||
|
async def api_get_categories(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
start_date: str | None = Query(None),
|
||||||
|
end_date: str | None = Query(None),
|
||||||
|
limit: int = Query(200, ge=1, le=1000),
|
||||||
|
store: MarketStore = Depends(get_market_store),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return news_domain.get_news_categories(
|
||||||
|
store,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/similar-days")
|
||||||
|
async def api_get_similar_days(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
date: str = Query(...),
|
||||||
|
n_similar: int = Query(5, ge=1, le=20),
|
||||||
|
store: MarketStore = Depends(get_market_store),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return news_domain.get_similar_days_payload(
|
||||||
|
store,
|
||||||
|
ticker=ticker,
|
||||||
|
date=date,
|
||||||
|
n_similar=n_similar,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/stories/{ticker}")
|
||||||
|
async def api_get_story(
|
||||||
|
ticker: str,
|
||||||
|
as_of_date: str = Query(...),
|
||||||
|
store: MarketStore = Depends(get_market_store),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return news_domain.get_story_payload(
|
||||||
|
store,
|
||||||
|
ticker=ticker,
|
||||||
|
as_of_date=as_of_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/range-explain")
|
||||||
|
async def api_get_range_explain(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
start_date: str = Query(...),
|
||||||
|
end_date: str = Query(...),
|
||||||
|
article_ids: list[str] = Query(default=[]),
|
||||||
|
limit: int = Query(100, ge=1, le=500),
|
||||||
|
store: MarketStore = Depends(get_market_store),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return news_domain.get_range_explain_payload(
|
||||||
|
store,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
article_ids=article_ids,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8002)
|
||||||
68
backend/apps/runtime_service.py
Normal file
68
backend/apps/runtime_service.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Dedicated runtime service FastAPI surface."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from backend.api import runtime_router
|
||||||
|
from backend.api.runtime import get_runtime_state
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""Create the runtime service app."""
|
||||||
|
app = FastAPI(
|
||||||
|
title="EvoTraders Runtime Service",
|
||||||
|
description="Runtime lifecycle and gateway service surface extracted from the monolith",
|
||||||
|
version="0.1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check() -> dict[str, object]:
|
||||||
|
"""Health check for the runtime service."""
|
||||||
|
runtime_state = get_runtime_state()
|
||||||
|
process = runtime_state.gateway_process
|
||||||
|
is_running = process is not None and process.poll() is None
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"service": "runtime-service",
|
||||||
|
"gateway_running": is_running,
|
||||||
|
"gateway_port": runtime_state.gateway_port,
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/api/status")
|
||||||
|
async def api_status() -> dict[str, object]:
|
||||||
|
"""Service-level status payload for runtime orchestration."""
|
||||||
|
runtime_state = get_runtime_state()
|
||||||
|
process = runtime_state.gateway_process
|
||||||
|
is_running = process is not None and process.poll() is None
|
||||||
|
return {
|
||||||
|
"status": "operational",
|
||||||
|
"service": "runtime-service",
|
||||||
|
"runtime": {
|
||||||
|
"gateway_running": is_running,
|
||||||
|
"gateway_port": runtime_state.gateway_port,
|
||||||
|
"has_runtime_manager": runtime_state.runtime_manager is not None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
app.include_router(runtime_router)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8003)
|
||||||
142
backend/apps/trading_service.py
Normal file
142
backend/apps/trading_service.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Trading data FastAPI surface."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import FastAPI, Query
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from backend.domains import trading as trading_domain
|
||||||
|
from shared.schema import (
|
||||||
|
CompanyNewsResponse,
|
||||||
|
FinancialMetricsResponse,
|
||||||
|
InsiderTradeResponse,
|
||||||
|
LineItemResponse,
|
||||||
|
PriceResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""Create the trading data service app."""
|
||||||
|
app = FastAPI(
|
||||||
|
title="EvoTraders Trading Service",
|
||||||
|
description="Read-only trading data service surface extracted from the monolith",
|
||||||
|
version="0.1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check() -> dict[str, str]:
|
||||||
|
"""Health check endpoint."""
|
||||||
|
return {"status": "healthy", "service": "trading-service"}
|
||||||
|
|
||||||
|
@app.get("/api/prices", response_model=PriceResponse)
|
||||||
|
async def api_get_prices(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
start_date: str = Query(...),
|
||||||
|
end_date: str = Query(...),
|
||||||
|
) -> PriceResponse:
|
||||||
|
payload = trading_domain.get_prices_payload(
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
|
||||||
|
|
||||||
|
@app.get("/api/financials", response_model=FinancialMetricsResponse)
|
||||||
|
async def api_get_financials(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
end_date: str = Query(...),
|
||||||
|
period: str = Query("ttm"),
|
||||||
|
limit: int = Query(10, ge=1, le=100),
|
||||||
|
) -> FinancialMetricsResponse:
|
||||||
|
payload = trading_domain.get_financials_payload(
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
period=period,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
|
||||||
|
|
||||||
|
@app.get("/api/news", response_model=CompanyNewsResponse)
|
||||||
|
async def api_get_news(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
end_date: str = Query(...),
|
||||||
|
start_date: str | None = Query(None),
|
||||||
|
limit: int = Query(1000, ge=1, le=5000),
|
||||||
|
) -> CompanyNewsResponse:
|
||||||
|
payload = trading_domain.get_news_payload(
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
start_date=start_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return CompanyNewsResponse(news=payload["news"])
|
||||||
|
|
||||||
|
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
|
||||||
|
async def api_get_insider_trades(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
end_date: str = Query(...),
|
||||||
|
start_date: str | None = Query(None),
|
||||||
|
limit: int = Query(1000, ge=1, le=5000),
|
||||||
|
) -> InsiderTradeResponse:
|
||||||
|
payload = trading_domain.get_insider_trades_payload(
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
start_date=start_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
|
||||||
|
|
||||||
|
@app.get("/api/market/status")
|
||||||
|
async def api_get_market_status() -> dict[str, Any]:
|
||||||
|
"""Return current market status using the existing market service logic."""
|
||||||
|
return trading_domain.get_market_status_payload()
|
||||||
|
|
||||||
|
@app.get("/api/market-cap")
|
||||||
|
async def api_get_market_cap(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
end_date: str = Query(...),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return market cap for one ticker/date."""
|
||||||
|
return trading_domain.get_market_cap_payload(
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.get("/api/line-items", response_model=LineItemResponse)
|
||||||
|
async def api_get_line_items(
|
||||||
|
ticker: str = Query(..., min_length=1),
|
||||||
|
line_items: list[str] = Query(...),
|
||||||
|
end_date: str = Query(...),
|
||||||
|
period: str = Query("ttm"),
|
||||||
|
limit: int = Query(10, ge=1, le=100),
|
||||||
|
) -> LineItemResponse:
|
||||||
|
payload = trading_domain.get_line_items_payload(
|
||||||
|
ticker=ticker,
|
||||||
|
line_items=line_items,
|
||||||
|
end_date=end_date,
|
||||||
|
period=period,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return LineItemResponse(search_results=payload["search_results"])
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||||
711
backend/cli.py
711
backend/cli.py
@@ -8,31 +8,62 @@ and frontend development server.
|
|||||||
"""
|
"""
|
||||||
# flake8: noqa: E501
|
# flake8: noqa: E501
|
||||||
# pylint: disable=R0912, R0915
|
# pylint: disable=R0912, R0915
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
import typer
|
import typer
|
||||||
|
import yaml
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.panel import Panel
|
from rich.panel import Panel
|
||||||
from rich.prompt import Confirm
|
from rich.prompt import Confirm
|
||||||
|
from rich.table import Table
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from backend.agents.prompt_loader import PromptLoader
|
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||||
|
from backend.agents.prompt_loader import get_prompt_loader
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
from backend.agents.team_pipeline_config import (
|
||||||
|
ensure_team_pipeline_config,
|
||||||
|
load_team_pipeline_config,
|
||||||
|
)
|
||||||
from backend.agents.workspace_manager import WorkspaceManager
|
from backend.agents.workspace_manager import WorkspaceManager
|
||||||
|
from backend.config.constants import ANALYST_TYPES
|
||||||
|
from backend.data.market_ingest import ingest_symbols
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
from backend.enrich.llm_enricher import get_explain_model_info, llm_enrichment_enabled
|
||||||
|
from backend.enrich.news_enricher import enrich_symbols
|
||||||
|
|
||||||
app = typer.Typer(
|
app = typer.Typer(
|
||||||
name="evotraders",
|
name="evotraders",
|
||||||
help="EvoTraders: A self-evolving multi-agent trading system",
|
help="EvoTraders: A self-evolving multi-agent trading system",
|
||||||
add_completion=False,
|
add_completion=False,
|
||||||
)
|
)
|
||||||
|
ingest_app = typer.Typer(help="Ingest Polygon market data into the research warehouse.")
|
||||||
|
app.add_typer(ingest_app, name="ingest")
|
||||||
|
skills_app = typer.Typer(help="Inspect and manage per-agent skills.")
|
||||||
|
app.add_typer(skills_app, name="skills")
|
||||||
|
team_app = typer.Typer(help="Inspect and manage run-scoped team pipeline config.")
|
||||||
|
app.add_typer(team_app, name="team")
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
_prompt_loader = PromptLoader()
|
_prompt_loader = get_prompt_loader()
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_typer_value(value, default):
|
||||||
|
"""Allow CLI command functions to be called directly in tests/internal code."""
|
||||||
|
if hasattr(value, "default"):
|
||||||
|
return value.default
|
||||||
|
return default if value is None else value
|
||||||
|
|
||||||
|
|
||||||
def get_project_root() -> Path:
|
def get_project_root() -> Path:
|
||||||
@@ -49,9 +80,8 @@ def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None:
|
|||||||
config_name: Configuration name for the run
|
config_name: Configuration name for the run
|
||||||
auto_clean: If True, skip confirmation and clean automatically
|
auto_clean: If True, skip confirmation and clean automatically
|
||||||
"""
|
"""
|
||||||
# logs_dir = get_project_root() / "logs"
|
workspace_manager = WorkspaceManager(project_root=get_project_root())
|
||||||
logs_dir = get_project_root()
|
base_data_dir = workspace_manager.get_run_dir(config_name)
|
||||||
base_data_dir = logs_dir / config_name
|
|
||||||
|
|
||||||
# Check if historical data exists
|
# Check if historical data exists
|
||||||
if not base_data_dir.exists() or not any(base_data_dir.iterdir()):
|
if not base_data_dir.exists() or not any(base_data_dir.iterdir()):
|
||||||
@@ -76,8 +106,8 @@ def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
console.print(f" Directory size: [cyan]{size_mb:.1f} MB[/cyan]")
|
console.print(f" Directory size: [cyan]{size_mb:.1f} MB[/cyan]")
|
||||||
except Exception:
|
except Exception as e:
|
||||||
pass
|
logger.debug(f"Could not calculate directory size: {e}")
|
||||||
|
|
||||||
# Show last modified time
|
# Show last modified time
|
||||||
state_dir = base_data_dir / "state"
|
state_dir = base_data_dir / "state"
|
||||||
@@ -178,7 +208,8 @@ def run_data_updater(project_root: Path) -> None:
|
|||||||
console.print(
|
console.print(
|
||||||
"[yellow] Data updater module not available, skipping update[/yellow]\n",
|
"[yellow] Data updater module not available, skipping update[/yellow]\n",
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.debug(f"Data updater check failed: {e}")
|
||||||
console.print(
|
console.print(
|
||||||
"[yellow] Data updater check failed, skipping update[/yellow]\n",
|
"[yellow] Data updater check failed, skipping update[/yellow]\n",
|
||||||
)
|
)
|
||||||
@@ -205,6 +236,202 @@ def initialize_workspace(config_name: str) -> Path:
|
|||||||
return workspace_manager.get_run_dir(config_name)
|
return workspace_manager.get_run_dir(config_name)
|
||||||
|
|
||||||
|
|
||||||
|
def _require_agent_asset_dir(config_name: str, agent_id: str) -> Path:
|
||||||
|
manager = WorkspaceManager(project_root=get_project_root())
|
||||||
|
manager.initialize_default_assets(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_ids=[agent_id],
|
||||||
|
analyst_personas=_prompt_loader.load_yaml_config(
|
||||||
|
"analyst",
|
||||||
|
"personas",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return manager.skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_symbols(raw_tickers: Optional[str], config_name: Optional[str] = None) -> list[str]:
|
||||||
|
"""Resolve symbols from explicit input or runtime bootstrap config."""
|
||||||
|
if raw_tickers and raw_tickers.strip():
|
||||||
|
return [
|
||||||
|
item.strip().upper()
|
||||||
|
for item in raw_tickers.split(",")
|
||||||
|
if item.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
workspace_manager = WorkspaceManager(project_root=get_project_root())
|
||||||
|
bootstrap_path = workspace_manager.get_run_dir(config_name or "default") / "BOOTSTRAP.md"
|
||||||
|
if bootstrap_path.exists():
|
||||||
|
content = bootstrap_path.read_text(encoding="utf-8")
|
||||||
|
for line in content.splitlines():
|
||||||
|
if line.strip().startswith("tickers:"):
|
||||||
|
raw = line.split(":", 1)[1]
|
||||||
|
return [
|
||||||
|
item.strip().upper()
|
||||||
|
for item in raw.split(",")
|
||||||
|
if item.strip()
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _filter_problematic_report_rows(rows: list[dict]) -> list[dict]:
|
||||||
|
"""Keep tickers with incomplete coverage or without any LLM-enriched rows."""
|
||||||
|
return [
|
||||||
|
row
|
||||||
|
for row in rows
|
||||||
|
if float(row.get("coverage_pct") or 0.0) < 100.0
|
||||||
|
or int(row.get("llm_count") or 0) == 0
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def auto_update_market_store(
|
||||||
|
config_name: str,
|
||||||
|
*,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""Refresh the long-lived Polygon market store for the active watchlist."""
|
||||||
|
api_key = os.getenv("POLYGON_API_KEY", "").strip()
|
||||||
|
if not api_key:
|
||||||
|
console.print(
|
||||||
|
"[dim]Skipping Polygon market store update: POLYGON_API_KEY not set[/dim]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
symbols = _resolve_symbols(None, config_name)
|
||||||
|
if not symbols:
|
||||||
|
console.print(
|
||||||
|
f"[dim]Skipping Polygon market store update: no tickers found for config '{config_name}'[/dim]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
target_end = end_date or datetime.now().date().isoformat()
|
||||||
|
console.print(
|
||||||
|
f"[cyan]Updating Polygon market store for {', '.join(symbols)} -> {target_end}[/cyan]",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = ingest_symbols(
|
||||||
|
symbols,
|
||||||
|
mode="incremental",
|
||||||
|
end_date=target_end,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Polygon market store update failed, continuing startup: {exc}[/yellow]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
console.print(
|
||||||
|
"[green]"
|
||||||
|
f"{result['symbol']}"
|
||||||
|
"[/green] "
|
||||||
|
f"prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def auto_prepare_backtest_market_store(
|
||||||
|
config_name: str,
|
||||||
|
*,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure the market store has the requested backtest window for the active watchlist."""
|
||||||
|
api_key = os.getenv("POLYGON_API_KEY", "").strip()
|
||||||
|
if not api_key:
|
||||||
|
console.print(
|
||||||
|
"[dim]Skipping Polygon backtest preload: POLYGON_API_KEY not set[/dim]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
symbols = _resolve_symbols(None, config_name)
|
||||||
|
if not symbols:
|
||||||
|
console.print(
|
||||||
|
f"[dim]Skipping Polygon backtest preload: no tickers found for config '{config_name}'[/dim]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
console.print(
|
||||||
|
f"[cyan]Preparing Polygon market store for backtest {start_date} -> {end_date} "
|
||||||
|
f"({', '.join(symbols)})[/cyan]",
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
results = ingest_symbols(
|
||||||
|
symbols,
|
||||||
|
mode="full",
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Polygon backtest preload failed, continuing startup: {exc}[/yellow]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
console.print(
|
||||||
|
"[green]"
|
||||||
|
f"{result['symbol']}"
|
||||||
|
"[/green] "
|
||||||
|
f"prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def auto_enrich_market_store(
|
||||||
|
config_name: str,
|
||||||
|
*,
|
||||||
|
end_date: Optional[str] = None,
|
||||||
|
lookback_days: int = 120,
|
||||||
|
force: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""Refresh explain-oriented enriched news for the active watchlist."""
|
||||||
|
symbols = _resolve_symbols(None, config_name)
|
||||||
|
if not symbols:
|
||||||
|
console.print(
|
||||||
|
f"[dim]Skipping explain enrich: no tickers found for config '{config_name}'[/dim]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
target_end = end_date or datetime.now().date().isoformat()
|
||||||
|
try:
|
||||||
|
end_dt = datetime.strptime(target_end, "%Y-%m-%d")
|
||||||
|
except ValueError:
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Skipping explain enrich: invalid end date {target_end}[/yellow]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
start_date = (end_dt - timedelta(days=max(1, lookback_days))).date().isoformat()
|
||||||
|
console.print(
|
||||||
|
f"[cyan]Refreshing explain enrich for {', '.join(symbols)} -> {target_end}[/cyan]",
|
||||||
|
)
|
||||||
|
store = MarketStore()
|
||||||
|
try:
|
||||||
|
results = enrich_symbols(
|
||||||
|
store,
|
||||||
|
symbols,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=target_end,
|
||||||
|
limit=300,
|
||||||
|
skip_existing=not force,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Explain enrich failed, continuing startup: {exc}[/yellow]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
console.print(
|
||||||
|
"[green]"
|
||||||
|
f"{result['symbol']}"
|
||||||
|
"[/green] "
|
||||||
|
f"news={result['news_count']} queued={result['queued_count']} analyzed={result['analyzed']} "
|
||||||
|
f"skipped={result['skipped_existing_count']} deduped={result['deduped_count']} "
|
||||||
|
f"llm={result['llm_count']} local={result['local_count']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.command("init-workspace")
|
@app.command("init-workspace")
|
||||||
def init_workspace(
|
def init_workspace(
|
||||||
config_name: str = typer.Option(
|
config_name: str = typer.Option(
|
||||||
@@ -224,6 +451,416 @@ def init_workspace(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ingest_app.command("full")
|
||||||
|
def ingest_full(
|
||||||
|
tickers: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--tickers",
|
||||||
|
"-t",
|
||||||
|
help="Comma-separated tickers to ingest",
|
||||||
|
),
|
||||||
|
start: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--start",
|
||||||
|
help="Start date for full ingestion (YYYY-MM-DD)",
|
||||||
|
),
|
||||||
|
end: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--end",
|
||||||
|
help="End date for ingestion (YYYY-MM-DD)",
|
||||||
|
),
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Run full Polygon ingestion for the specified symbols."""
|
||||||
|
symbols = _resolve_symbols(tickers, config_name)
|
||||||
|
if not symbols:
|
||||||
|
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print(f"[cyan]Starting full Polygon ingest for {', '.join(symbols)}[/cyan]")
|
||||||
|
results = ingest_symbols(symbols, mode="full", start_date=start, end_date=end)
|
||||||
|
for result in results:
|
||||||
|
console.print(
|
||||||
|
f"[green]{result['symbol']}[/green] prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ingest_app.command("update")
|
||||||
|
def ingest_update(
|
||||||
|
tickers: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--tickers",
|
||||||
|
"-t",
|
||||||
|
help="Comma-separated tickers to update",
|
||||||
|
),
|
||||||
|
end: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--end",
|
||||||
|
help="Optional end date override (YYYY-MM-DD)",
|
||||||
|
),
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Run incremental Polygon ingestion using stored watermarks."""
|
||||||
|
symbols = _resolve_symbols(tickers, config_name)
|
||||||
|
if not symbols:
|
||||||
|
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print(f"[cyan]Starting incremental Polygon ingest for {', '.join(symbols)}[/cyan]")
|
||||||
|
results = ingest_symbols(symbols, mode="incremental", end_date=end)
|
||||||
|
for result in results:
|
||||||
|
console.print(
|
||||||
|
f"[green]{result['symbol']}[/green] prices={result['prices']} news={result['news']} aligned={result['aligned']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ingest_app.command("enrich")
|
||||||
|
def ingest_enrich(
|
||||||
|
tickers: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--tickers",
|
||||||
|
"-t",
|
||||||
|
help="Comma-separated tickers to enrich",
|
||||||
|
),
|
||||||
|
start: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--start",
|
||||||
|
help="Optional start date for enrichment window (YYYY-MM-DD)",
|
||||||
|
),
|
||||||
|
end: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--end",
|
||||||
|
help="Optional end date for enrichment window (YYYY-MM-DD)",
|
||||||
|
),
|
||||||
|
limit: int = typer.Option(
|
||||||
|
300,
|
||||||
|
"--limit",
|
||||||
|
help="Maximum raw news rows per ticker to analyze",
|
||||||
|
),
|
||||||
|
force: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--force",
|
||||||
|
help="Re-analyze already enriched news instead of only missing rows",
|
||||||
|
),
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Run explain-oriented news enrichment for symbols already in the market store."""
|
||||||
|
symbols = _resolve_symbols(tickers, config_name)
|
||||||
|
if not symbols:
|
||||||
|
console.print("[red]No tickers provided and none found in BOOTSTRAP.md[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
console.print(f"[cyan]Starting explain enrich for {', '.join(symbols)}[/cyan]")
|
||||||
|
store = MarketStore()
|
||||||
|
results = enrich_symbols(
|
||||||
|
store,
|
||||||
|
symbols,
|
||||||
|
start_date=start,
|
||||||
|
end_date=end,
|
||||||
|
limit=max(10, limit),
|
||||||
|
skip_existing=not force,
|
||||||
|
)
|
||||||
|
for result in results:
|
||||||
|
console.print(
|
||||||
|
f"[green]{result['symbol']}[/green] "
|
||||||
|
f"news={result['news_count']} queued={result['queued_count']} analyzed={result['analyzed']} "
|
||||||
|
f"skipped={result['skipped_existing_count']} deduped={result['deduped_count']} "
|
||||||
|
f"llm={result['llm_count']} local={result['local_count']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ingest_app.command("report")
|
||||||
|
def ingest_report(
|
||||||
|
tickers: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--tickers",
|
||||||
|
"-t",
|
||||||
|
help="Optional comma-separated tickers to report",
|
||||||
|
),
|
||||||
|
start: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--start",
|
||||||
|
help="Optional start date for report window (YYYY-MM-DD)",
|
||||||
|
),
|
||||||
|
end: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--end",
|
||||||
|
help="Optional end date for report window (YYYY-MM-DD)",
|
||||||
|
),
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Fallback config to read tickers from BOOTSTRAP.md",
|
||||||
|
),
|
||||||
|
only_problematic: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--only-problematic",
|
||||||
|
help="Only show tickers with incomplete coverage or no LLM-enriched news",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Show explain enrichment coverage and freshness per ticker."""
|
||||||
|
symbols = _resolve_symbols(tickers, config_name)
|
||||||
|
store = MarketStore()
|
||||||
|
report_rows = store.get_enrich_report(
|
||||||
|
symbols=symbols or None,
|
||||||
|
start_date=start,
|
||||||
|
end_date=end,
|
||||||
|
)
|
||||||
|
if only_problematic:
|
||||||
|
report_rows = _filter_problematic_report_rows(report_rows)
|
||||||
|
if not report_rows:
|
||||||
|
if only_problematic:
|
||||||
|
console.print("[green]No problematic enrich report rows found for the requested scope[/green]")
|
||||||
|
else:
|
||||||
|
console.print("[yellow]No enrich report rows found for the requested scope[/yellow]")
|
||||||
|
raise typer.Exit(0)
|
||||||
|
|
||||||
|
model_info = get_explain_model_info()
|
||||||
|
model_label = model_info["label"] if llm_enrichment_enabled() else "disabled"
|
||||||
|
table = Table(title="Explain Enrichment Report")
|
||||||
|
table.add_column("Ticker", style="cyan")
|
||||||
|
table.add_column("Raw News", justify="right")
|
||||||
|
table.add_column("Analyzed", justify="right")
|
||||||
|
table.add_column("Coverage", justify="right")
|
||||||
|
table.add_column("LLM", justify="right")
|
||||||
|
table.add_column("Local", justify="right")
|
||||||
|
table.add_column("Latest Trade Date")
|
||||||
|
table.add_column("Latest Analysis")
|
||||||
|
table.caption = f"Explain LLM: {model_label}"
|
||||||
|
|
||||||
|
for row in report_rows:
|
||||||
|
table.add_row(
|
||||||
|
row["symbol"],
|
||||||
|
str(row["raw_news_count"]),
|
||||||
|
str(row["analyzed_news_count"]),
|
||||||
|
f'{row["coverage_pct"]:.1f}%',
|
||||||
|
str(row["llm_count"]),
|
||||||
|
str(row["local_count"]),
|
||||||
|
str(row["latest_trade_date"] or "-"),
|
||||||
|
str(row["latest_analysis_at"] or "-"),
|
||||||
|
)
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@skills_app.command("list")
|
||||||
|
def skills_list(
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Run config name.",
|
||||||
|
),
|
||||||
|
agent_id: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--agent-id",
|
||||||
|
"-a",
|
||||||
|
help="Optional agent id to show resolved status for.",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""List available skills and optional agent-level enablement state."""
|
||||||
|
project_root = get_project_root()
|
||||||
|
skills_manager = SkillsManager(project_root=project_root)
|
||||||
|
catalog = (
|
||||||
|
skills_manager.list_agent_skill_catalog(config_name, agent_id)
|
||||||
|
if agent_id
|
||||||
|
else skills_manager.list_skill_catalog()
|
||||||
|
)
|
||||||
|
if not catalog:
|
||||||
|
console.print("[yellow]No skills found[/yellow]")
|
||||||
|
raise typer.Exit(0)
|
||||||
|
|
||||||
|
agent_config = None
|
||||||
|
resolved_skills = set()
|
||||||
|
if agent_id:
|
||||||
|
asset_dir = _require_agent_asset_dir(config_name, agent_id)
|
||||||
|
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
|
||||||
|
resolved_skills = set(
|
||||||
|
skills_manager.resolve_agent_skill_names(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
default_skills=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
table = Table(title="Skill Catalog")
|
||||||
|
table.add_column("Skill", style="cyan")
|
||||||
|
table.add_column("Source")
|
||||||
|
table.add_column("Description")
|
||||||
|
if agent_id:
|
||||||
|
table.add_column("Status")
|
||||||
|
|
||||||
|
enabled = set(agent_config.enabled_skills) if agent_config else set()
|
||||||
|
disabled = set(agent_config.disabled_skills) if agent_config else set()
|
||||||
|
for skill in catalog:
|
||||||
|
row = [
|
||||||
|
skill.skill_name,
|
||||||
|
skill.source,
|
||||||
|
skill.description or "-",
|
||||||
|
]
|
||||||
|
if agent_id:
|
||||||
|
if skill.skill_name in disabled:
|
||||||
|
status = "disabled"
|
||||||
|
elif skill.skill_name in enabled:
|
||||||
|
status = "enabled"
|
||||||
|
elif skill.skill_name in resolved_skills:
|
||||||
|
status = "active"
|
||||||
|
else:
|
||||||
|
status = "-"
|
||||||
|
row.append(status)
|
||||||
|
table.add_row(*row)
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
|
||||||
|
@skills_app.command("enable")
|
||||||
|
def skills_enable(
|
||||||
|
agent_id: str = typer.Option(..., "--agent-id", "-a", help="Agent id."),
|
||||||
|
skill: str = typer.Option(..., "--skill", "-s", help="Skill name."),
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Run config name.",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Enable a skill for one agent in agent.yaml."""
|
||||||
|
asset_dir = _require_agent_asset_dir(config_name, agent_id)
|
||||||
|
skills_manager = SkillsManager(project_root=get_project_root())
|
||||||
|
catalog = {
|
||||||
|
item.skill_name
|
||||||
|
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
|
||||||
|
}
|
||||||
|
if skill not in catalog:
|
||||||
|
console.print(f"[red]Unknown skill: {skill}[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
result = skills_manager.update_agent_skill_overrides(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
enable=[skill],
|
||||||
|
)
|
||||||
|
console.print(
|
||||||
|
f"[green]Enabled[/green] `{skill}` for `{agent_id}` "
|
||||||
|
f"([{asset_dir / 'agent.yaml'}])",
|
||||||
|
)
|
||||||
|
console.print(f"Enabled skills: {', '.join(result['enabled_skills']) or '-'}")
|
||||||
|
console.print(f"Disabled skills: {', '.join(result['disabled_skills']) or '-'}")
|
||||||
|
|
||||||
|
|
||||||
|
@skills_app.command("disable")
|
||||||
|
def skills_disable(
|
||||||
|
agent_id: str = typer.Option(..., "--agent-id", "-a", help="Agent id."),
|
||||||
|
skill: str = typer.Option(..., "--skill", "-s", help="Skill name."),
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Run config name.",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Disable a skill for one agent in agent.yaml."""
|
||||||
|
asset_dir = _require_agent_asset_dir(config_name, agent_id)
|
||||||
|
skills_manager = SkillsManager(project_root=get_project_root())
|
||||||
|
result = skills_manager.update_agent_skill_overrides(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
disable=[skill],
|
||||||
|
)
|
||||||
|
console.print(
|
||||||
|
f"[yellow]Disabled[/yellow] `{skill}` for `{agent_id}` "
|
||||||
|
f"([{asset_dir / 'agent.yaml'}])",
|
||||||
|
)
|
||||||
|
console.print(f"Enabled skills: {', '.join(result['enabled_skills']) or '-'}")
|
||||||
|
console.print(f"Disabled skills: {', '.join(result['disabled_skills']) or '-'}")
|
||||||
|
|
||||||
|
|
||||||
|
@skills_app.command("install")
|
||||||
|
def skills_install(
|
||||||
|
agent_id: str = typer.Option(..., "--agent-id", "-a", help="Target agent id."),
|
||||||
|
source: str = typer.Option(
|
||||||
|
...,
|
||||||
|
"--source",
|
||||||
|
"-s",
|
||||||
|
help="External skill source: directory path, zip path, or http(s) zip URL.",
|
||||||
|
),
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Run config name.",
|
||||||
|
),
|
||||||
|
name: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--name",
|
||||||
|
help="Optional override skill name.",
|
||||||
|
),
|
||||||
|
activate: bool = typer.Option(
|
||||||
|
True,
|
||||||
|
"--activate/--no-activate",
|
||||||
|
help="Enable the skill for this agent immediately.",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Install an external skill into one agent's local skill directory."""
|
||||||
|
_require_agent_asset_dir(config_name, agent_id)
|
||||||
|
skills_manager = SkillsManager(project_root=get_project_root())
|
||||||
|
result = skills_manager.install_external_skill_for_agent(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
source=source,
|
||||||
|
skill_name=name,
|
||||||
|
activate=activate,
|
||||||
|
)
|
||||||
|
console.print(
|
||||||
|
f"[green]Installed[/green] `{result['skill_name']}` to `{agent_id}`",
|
||||||
|
)
|
||||||
|
console.print(f"Path: {result['target_dir']}")
|
||||||
|
console.print(f"Activated: {result['activated']}")
|
||||||
|
warnings = result.get("warnings") or []
|
||||||
|
if warnings:
|
||||||
|
console.print(f"Warnings: {'; '.join(warnings)}")
|
||||||
|
|
||||||
|
|
||||||
|
@team_app.command("show")
|
||||||
|
def team_show(
|
||||||
|
config_name: str = typer.Option(
|
||||||
|
"default",
|
||||||
|
"--config-name",
|
||||||
|
"-c",
|
||||||
|
help="Run config name.",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Show TEAM_PIPELINE.yaml for one run."""
|
||||||
|
project_root = get_project_root()
|
||||||
|
ensure_team_pipeline_config(
|
||||||
|
project_root=project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
default_analysts=list(ANALYST_TYPES.keys()),
|
||||||
|
)
|
||||||
|
config = load_team_pipeline_config(project_root, config_name)
|
||||||
|
console.print(
|
||||||
|
Panel.fit(
|
||||||
|
yaml.safe_dump(config, allow_unicode=True, sort_keys=False),
|
||||||
|
title=f"TEAM_PIPELINE ({config_name})",
|
||||||
|
border_style="cyan",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def backtest(
|
def backtest(
|
||||||
start: Optional[str] = typer.Option(
|
start: Optional[str] = typer.Option(
|
||||||
@@ -286,6 +923,7 @@ def backtest(
|
|||||||
border_style="cyan",
|
border_style="cyan",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
poll_interval = int(_normalize_typer_value(poll_interval, 10))
|
||||||
|
|
||||||
# Validate dates - required for backtest
|
# Validate dates - required for backtest
|
||||||
if not start or not end:
|
if not start or not end:
|
||||||
@@ -332,6 +970,16 @@ def backtest(
|
|||||||
|
|
||||||
# Run data updater
|
# Run data updater
|
||||||
run_data_updater(project_root)
|
run_data_updater(project_root)
|
||||||
|
auto_prepare_backtest_market_store(
|
||||||
|
config_name,
|
||||||
|
start_date=start,
|
||||||
|
end_date=end,
|
||||||
|
)
|
||||||
|
auto_enrich_market_store(
|
||||||
|
config_name,
|
||||||
|
end_date=end,
|
||||||
|
force=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Build command using backend.main
|
# Build command using backend.main
|
||||||
cmd = [
|
cmd = [
|
||||||
@@ -393,12 +1041,22 @@ def live(
|
|||||||
"-p",
|
"-p",
|
||||||
help="WebSocket server port",
|
help="WebSocket server port",
|
||||||
),
|
),
|
||||||
|
schedule_mode: str = typer.Option(
|
||||||
|
"daily",
|
||||||
|
"--schedule-mode",
|
||||||
|
help="Scheduler mode: 'daily' or 'intraday'",
|
||||||
|
),
|
||||||
trigger_time: str = typer.Option(
|
trigger_time: str = typer.Option(
|
||||||
"now",
|
"now",
|
||||||
"--trigger-time",
|
"--trigger-time",
|
||||||
"-t",
|
"-t",
|
||||||
help="Trigger time in LOCAL timezone (HH:MM), or 'now' to run immediately",
|
help="Trigger time in LOCAL timezone (HH:MM), or 'now' to run immediately",
|
||||||
),
|
),
|
||||||
|
interval_minutes: int = typer.Option(
|
||||||
|
60,
|
||||||
|
"--interval-minutes",
|
||||||
|
help="When schedule-mode=intraday, run every N minutes",
|
||||||
|
),
|
||||||
poll_interval: int = typer.Option(
|
poll_interval: int = typer.Option(
|
||||||
10,
|
10,
|
||||||
"--poll-interval",
|
"--poll-interval",
|
||||||
@@ -422,9 +1080,12 @@ def live(
|
|||||||
evotraders live # Run immediately (default)
|
evotraders live # Run immediately (default)
|
||||||
evotraders live --mock # Mock mode
|
evotraders live --mock # Mock mode
|
||||||
evotraders live -t 22:30 # Run at 22:30 local time daily
|
evotraders live -t 22:30 # Run at 22:30 local time daily
|
||||||
|
evotraders live --schedule-mode intraday --interval-minutes 60
|
||||||
evotraders live --trigger-time now # Run immediately
|
evotraders live --trigger-time now # Run immediately
|
||||||
evotraders live --clean # Clear historical data before starting
|
evotraders live --clean # Clear historical data before starting
|
||||||
"""
|
"""
|
||||||
|
schedule_mode = str(_normalize_typer_value(schedule_mode, "daily"))
|
||||||
|
interval_minutes = int(_normalize_typer_value(interval_minutes, 60))
|
||||||
mode_name = "MOCK" if mock else "LIVE"
|
mode_name = "MOCK" if mock else "LIVE"
|
||||||
console.print(
|
console.print(
|
||||||
Panel.fit(
|
Panel.fit(
|
||||||
@@ -456,6 +1117,16 @@ def live(
|
|||||||
# Handle historical data cleanup
|
# Handle historical data cleanup
|
||||||
handle_history_cleanup(config_name, auto_clean=clean)
|
handle_history_cleanup(config_name, auto_clean=clean)
|
||||||
|
|
||||||
|
if schedule_mode not in {"daily", "intraday"}:
|
||||||
|
console.print(
|
||||||
|
f"[red]Error: unsupported schedule mode '{schedule_mode}'[/red]",
|
||||||
|
)
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
if interval_minutes <= 0:
|
||||||
|
console.print("[red]Error: --interval-minutes must be > 0[/red]")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
# Convert local time to NYSE time
|
# Convert local time to NYSE time
|
||||||
nyse_tz = ZoneInfo("America/New_York")
|
nyse_tz = ZoneInfo("America/New_York")
|
||||||
local_tz = datetime.now().astimezone().tzinfo
|
local_tz = datetime.now().astimezone().tzinfo
|
||||||
@@ -463,7 +1134,9 @@ def live(
|
|||||||
nyse_now = datetime.now(nyse_tz)
|
nyse_now = datetime.now(nyse_tz)
|
||||||
|
|
||||||
# Convert trigger time from local to NYSE
|
# Convert trigger time from local to NYSE
|
||||||
if trigger_time.lower() == "now":
|
if schedule_mode == "intraday":
|
||||||
|
nyse_trigger_time = "now"
|
||||||
|
elif trigger_time.lower() == "now":
|
||||||
nyse_trigger_time = "now"
|
nyse_trigger_time = "now"
|
||||||
else:
|
else:
|
||||||
local_trigger = datetime.strptime(trigger_time, "%H:%M")
|
local_trigger = datetime.strptime(trigger_time, "%H:%M")
|
||||||
@@ -483,7 +1156,10 @@ def live(
|
|||||||
console.print(
|
console.print(
|
||||||
f" NYSE Time: {nyse_now.strftime('%Y-%m-%d %H:%M:%S %Z')}",
|
f" NYSE Time: {nyse_now.strftime('%Y-%m-%d %H:%M:%S %Z')}",
|
||||||
)
|
)
|
||||||
if nyse_trigger_time == "now":
|
console.print(f" Schedule: {schedule_mode}")
|
||||||
|
if schedule_mode == "intraday":
|
||||||
|
console.print(f" Interval: every {interval_minutes} minute(s)")
|
||||||
|
elif nyse_trigger_time == "now":
|
||||||
console.print(" Trigger: [green]NOW (immediate)[/green]")
|
console.print(" Trigger: [green]NOW (immediate)[/green]")
|
||||||
else:
|
else:
|
||||||
console.print(
|
console.print(
|
||||||
@@ -515,6 +1191,15 @@ def live(
|
|||||||
# Data update (if not mock mode)
|
# Data update (if not mock mode)
|
||||||
if not mock:
|
if not mock:
|
||||||
run_data_updater(project_root)
|
run_data_updater(project_root)
|
||||||
|
auto_update_market_store(
|
||||||
|
config_name,
|
||||||
|
end_date=nyse_now.date().isoformat(),
|
||||||
|
)
|
||||||
|
auto_enrich_market_store(
|
||||||
|
config_name,
|
||||||
|
end_date=nyse_now.date().isoformat(),
|
||||||
|
force=False,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
console.print(
|
console.print(
|
||||||
"\n[dim]Mock mode enabled - skipping data update[/dim]\n",
|
"\n[dim]Mock mode enabled - skipping data update[/dim]\n",
|
||||||
@@ -534,10 +1219,14 @@ def live(
|
|||||||
host,
|
host,
|
||||||
"--port",
|
"--port",
|
||||||
str(port),
|
str(port),
|
||||||
|
"--schedule-mode",
|
||||||
|
schedule_mode,
|
||||||
"--poll-interval",
|
"--poll-interval",
|
||||||
str(poll_interval),
|
str(poll_interval),
|
||||||
"--trigger-time",
|
"--trigger-time",
|
||||||
nyse_trigger_time,
|
nyse_trigger_time,
|
||||||
|
"--interval-minutes",
|
||||||
|
str(interval_minutes),
|
||||||
]
|
]
|
||||||
|
|
||||||
if mock:
|
if mock:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Parse run-scoped BOOTSTRAP.md into structured configuration."""
|
"""Parse run-scoped BOOTSTRAP.md into structured and runtime config."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -8,6 +8,8 @@ import re
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
from backend.config.env_config import get_env_float, get_env_int, get_env_list
|
||||||
|
|
||||||
|
|
||||||
BOOTSTRAP_FRONT_MATTER_RE = re.compile(
|
BOOTSTRAP_FRONT_MATTER_RE = re.compile(
|
||||||
r"^---\s*\n(.*?)\n---\s*\n?(.*)$",
|
r"^---\s*\n(.*?)\n---\s*\n?(.*)$",
|
||||||
@@ -63,3 +65,99 @@ def get_bootstrap_config_for_run(
|
|||||||
return load_bootstrap_config(
|
return load_bootstrap_config(
|
||||||
project_root / "runs" / config_name / "BOOTSTRAP.md",
|
project_root / "runs" / config_name / "BOOTSTRAP.md",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def save_bootstrap_config(bootstrap_path: Path, config: BootstrapConfig) -> None:
|
||||||
|
"""Persist structured bootstrap config back to BOOTSTRAP.md."""
|
||||||
|
bootstrap_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
values = config.values if isinstance(config.values, dict) else {}
|
||||||
|
front_matter = yaml.safe_dump(
|
||||||
|
values,
|
||||||
|
allow_unicode=True,
|
||||||
|
sort_keys=False,
|
||||||
|
).strip()
|
||||||
|
body = (config.prompt_body or "").strip()
|
||||||
|
|
||||||
|
content = f"---\n{front_matter}\n---"
|
||||||
|
if body:
|
||||||
|
content += f"\n\n{body}\n"
|
||||||
|
else:
|
||||||
|
content += "\n"
|
||||||
|
|
||||||
|
bootstrap_path.write_text(content, encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def update_bootstrap_values_for_run(
|
||||||
|
project_root: Path,
|
||||||
|
config_name: str,
|
||||||
|
updates: Dict[str, Any],
|
||||||
|
) -> BootstrapConfig:
|
||||||
|
"""Patch selected front matter keys for a run and persist them."""
|
||||||
|
bootstrap_path = project_root / "runs" / config_name / "BOOTSTRAP.md"
|
||||||
|
existing = load_bootstrap_config(bootstrap_path)
|
||||||
|
values = dict(existing.values)
|
||||||
|
values.update(updates)
|
||||||
|
updated = BootstrapConfig(values=values, prompt_body=existing.prompt_body)
|
||||||
|
save_bootstrap_config(bootstrap_path, updated)
|
||||||
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_bool(value: Any) -> bool:
|
||||||
|
"""Parse booleans from bootstrap-friendly string values."""
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
normalized = value.strip().lower()
|
||||||
|
if normalized in {"1", "true", "yes", "on"}:
|
||||||
|
return True
|
||||||
|
if normalized in {"0", "false", "no", "off"}:
|
||||||
|
return False
|
||||||
|
return bool(value)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_runtime_config(
|
||||||
|
project_root: Path,
|
||||||
|
config_name: str,
|
||||||
|
enable_memory: bool = False,
|
||||||
|
schedule_mode: str = "daily",
|
||||||
|
interval_minutes: int = 60,
|
||||||
|
trigger_time: str = "09:30",
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Merge env defaults with run-scoped bootstrap front matter."""
|
||||||
|
bootstrap = get_bootstrap_config_for_run(project_root, config_name)
|
||||||
|
return {
|
||||||
|
"tickers": bootstrap.get("tickers")
|
||||||
|
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
|
||||||
|
"initial_cash": float(
|
||||||
|
bootstrap.get(
|
||||||
|
"initial_cash",
|
||||||
|
get_env_float("INITIAL_CASH", 100000.0),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"margin_requirement": float(
|
||||||
|
bootstrap.get(
|
||||||
|
"margin_requirement",
|
||||||
|
get_env_float("MARGIN_REQUIREMENT", 0.0),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"max_comm_cycles": int(
|
||||||
|
bootstrap.get(
|
||||||
|
"max_comm_cycles",
|
||||||
|
get_env_int("MAX_COMM_CYCLES", 2),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"schedule_mode": str(
|
||||||
|
bootstrap.get("schedule_mode", schedule_mode),
|
||||||
|
).strip().lower() or schedule_mode,
|
||||||
|
"interval_minutes": int(
|
||||||
|
bootstrap.get(
|
||||||
|
"interval_minutes",
|
||||||
|
interval_minutes or get_env_int("INTERVAL_MINUTES", 60),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
"trigger_time": str(
|
||||||
|
bootstrap.get("trigger_time", trigger_time),
|
||||||
|
).strip() or trigger_time,
|
||||||
|
"enable_memory": bool(enable_memory)
|
||||||
|
or _coerce_bool(bootstrap.get("enable_memory", False)),
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,35 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Core pipeline and orchestration logic"""
|
"""Core pipeline and orchestration logic.
|
||||||
|
|
||||||
|
Keep ``pipeline_runner`` behind lazy wrappers so importing ``backend.core`` does
|
||||||
|
not immediately pull in the gateway runtime graph.
|
||||||
|
"""
|
||||||
|
|
||||||
from .pipeline import TradingPipeline
|
from .pipeline import TradingPipeline
|
||||||
from .state_sync import StateSync
|
from .state_sync import StateSync
|
||||||
|
|
||||||
__all__ = ["TradingPipeline", "StateSync"]
|
|
||||||
|
def create_agents(*args, **kwargs):
|
||||||
|
from .pipeline_runner import create_agents as _create_agents
|
||||||
|
|
||||||
|
return _create_agents(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def create_long_term_memory(*args, **kwargs):
|
||||||
|
from .pipeline_runner import create_long_term_memory as _create_long_term_memory
|
||||||
|
|
||||||
|
return _create_long_term_memory(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def stop_gateway(*args, **kwargs):
|
||||||
|
from .pipeline_runner import stop_gateway as _stop_gateway
|
||||||
|
|
||||||
|
return _stop_gateway(*args, **kwargs)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TradingPipeline",
|
||||||
|
"StateSync",
|
||||||
|
"create_agents",
|
||||||
|
"create_long_term_memory",
|
||||||
|
"stop_gateway",
|
||||||
|
]
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||||
|
|
||||||
from agentscope.message import Msg
|
from agentscope.message import Msg
|
||||||
@@ -19,6 +21,28 @@ from backend.utils.settlement import SettlementCoordinator
|
|||||||
from backend.utils.terminal_dashboard import get_dashboard
|
from backend.utils.terminal_dashboard import get_dashboard
|
||||||
from backend.core.state_sync import StateSync
|
from backend.core.state_sync import StateSync
|
||||||
from backend.utils.trade_executor import PortfolioTradeExecutor
|
from backend.utils.trade_executor import PortfolioTradeExecutor
|
||||||
|
from backend.runtime.manager import TradingRuntimeManager
|
||||||
|
from backend.runtime.session import TradingSessionKey
|
||||||
|
from backend.agents.team_pipeline_config import (
|
||||||
|
resolve_active_analysts,
|
||||||
|
update_active_analysts,
|
||||||
|
)
|
||||||
|
from backend.agents import AnalystAgent
|
||||||
|
from backend.agents.toolkit_factory import create_agent_toolkit
|
||||||
|
from backend.agents.workspace_manager import WorkspaceManager
|
||||||
|
from backend.agents.prompt_loader import get_prompt_loader
|
||||||
|
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||||
|
from backend.config.constants import ANALYST_TYPES
|
||||||
|
|
||||||
|
# Team infrastructure imports (graceful import - may not exist yet)
|
||||||
|
try:
|
||||||
|
from backend.agents.team.team_coordinator import TeamCoordinator
|
||||||
|
from backend.agents.team.msg_hub import MsgHub as TeamMsgHub
|
||||||
|
TEAM_COORD_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
TEAM_COORD_AVAILABLE = False
|
||||||
|
TeamCoordinator = None
|
||||||
|
TeamMsgHub = None
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -46,6 +70,8 @@ class TradingPipeline:
|
|||||||
6. Reflection phase: broadcast closing P&L, agents record to long-term memory
|
6. Reflection phase: broadcast closing P&L, agents record to long-term memory
|
||||||
|
|
||||||
Real-time updates via StateSync after each agent completes.
|
Real-time updates via StateSync after each agent completes.
|
||||||
|
|
||||||
|
Supports both legacy agent lists and new workspace-based agent loading.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -56,6 +82,9 @@ class TradingPipeline:
|
|||||||
state_sync: Optional["StateSync"] = None,
|
state_sync: Optional["StateSync"] = None,
|
||||||
settlement_coordinator: Optional[SettlementCoordinator] = None,
|
settlement_coordinator: Optional[SettlementCoordinator] = None,
|
||||||
max_comm_cycles: Optional[int] = None,
|
max_comm_cycles: Optional[int] = None,
|
||||||
|
workspace_id: Optional[str] = None,
|
||||||
|
agent_factory: Optional[Any] = None,
|
||||||
|
runtime_manager: Optional[TradingRuntimeManager] = None,
|
||||||
):
|
):
|
||||||
self.analysts = analysts
|
self.analysts = analysts
|
||||||
self.risk_manager = risk_manager
|
self.risk_manager = risk_manager
|
||||||
@@ -66,6 +95,17 @@ class TradingPipeline:
|
|||||||
os.getenv("MAX_COMM_CYCLES", "2"),
|
os.getenv("MAX_COMM_CYCLES", "2"),
|
||||||
)
|
)
|
||||||
self.conference_summary = None # Store latest conference summary
|
self.conference_summary = None # Store latest conference summary
|
||||||
|
self.workspace_id = workspace_id
|
||||||
|
self.agent_factory = agent_factory
|
||||||
|
self.runtime_manager = runtime_manager
|
||||||
|
self._session_key: Optional[str] = None
|
||||||
|
self._dynamic_analysts: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if hasattr(self.pm, "set_team_controller"):
|
||||||
|
self.pm.set_team_controller(
|
||||||
|
create_agent_callback=self._create_runtime_analyst,
|
||||||
|
remove_agent_callback=self._remove_runtime_analyst,
|
||||||
|
)
|
||||||
|
|
||||||
async def run_cycle(
|
async def run_cycle(
|
||||||
self,
|
self,
|
||||||
@@ -80,6 +120,7 @@ class TradingPipeline:
|
|||||||
get_close_prices_fn: Optional[
|
get_close_prices_fn: Optional[
|
||||||
Callable[[], Awaitable[Dict[str, float]]]
|
Callable[[], Awaitable[Dict[str, float]]]
|
||||||
] = None,
|
] = None,
|
||||||
|
execute_decisions: bool = True,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Run one complete trading cycle
|
Run one complete trading cycle
|
||||||
@@ -101,12 +142,19 @@ class TradingPipeline:
|
|||||||
Each agent's result is broadcast immediately via StateSync.
|
Each agent's result is broadcast immediately via StateSync.
|
||||||
"""
|
"""
|
||||||
_log(f"Starting cycle {date} - {len(tickers)} tickers")
|
_log(f"Starting cycle {date} - {len(tickers)} tickers")
|
||||||
|
session_key = TradingSessionKey(date=date).key()
|
||||||
|
self._session_key = session_key
|
||||||
|
active_analysts = self._get_active_analysts()
|
||||||
|
if self.runtime_manager:
|
||||||
|
self.runtime_manager.set_session_key(session_key)
|
||||||
|
self._runtime_log_event("cycle:start", {"tickers": tickers, "date": date})
|
||||||
|
self._runtime_batch_status(active_analysts, "analysis_in_progress")
|
||||||
|
|
||||||
# Phase 0: Clear short-term memory to avoid cross-day context pollution
|
# Phase 0: Clear short-term memory to avoid cross-day context pollution
|
||||||
_log("Phase 0: Clearing memory")
|
_log("Phase 0: Clearing memory")
|
||||||
await self._clear_all_agent_memory()
|
await self._clear_all_agent_memory()
|
||||||
|
|
||||||
participants = self.analysts + [self.risk_manager, self.pm]
|
participants = self._all_analysts() + [self.risk_manager, self.pm]
|
||||||
|
|
||||||
# Single MsgHub for entire cycle - no nesting
|
# Single MsgHub for entire cycle - no nesting
|
||||||
async with MsgHub(
|
async with MsgHub(
|
||||||
@@ -117,12 +165,17 @@ class TradingPipeline:
|
|||||||
"system",
|
"system",
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
# Phase 1.1: Analysts
|
# Phase 1.1: Analysts (parallel execution with TeamCoordinator)
|
||||||
_log("Phase 1.1: Analyst analysis")
|
_log("Phase 1.1: Analyst analysis (parallel)")
|
||||||
analyst_results = await self._run_analysts_with_sync(tickers, date)
|
analyst_results = await self._run_analysts_parallel(
|
||||||
|
tickers,
|
||||||
|
date,
|
||||||
|
active_analysts=active_analysts,
|
||||||
|
)
|
||||||
|
|
||||||
# Phase 1.2: Risk Manager
|
# Phase 1.2: Risk Manager
|
||||||
_log("Phase 1.2: Risk assessment")
|
_log("Phase 1.2: Risk assessment")
|
||||||
|
self._runtime_update_status(self.risk_manager, "risk_assessment")
|
||||||
risk_assessment = await self._run_risk_manager_with_sync(
|
risk_assessment = await self._run_risk_manager_with_sync(
|
||||||
tickers,
|
tickers,
|
||||||
date,
|
date,
|
||||||
@@ -145,6 +198,7 @@ class TradingPipeline:
|
|||||||
final_predictions = await self._collect_final_predictions(
|
final_predictions = await self._collect_final_predictions(
|
||||||
tickers,
|
tickers,
|
||||||
date,
|
date,
|
||||||
|
active_analysts=active_analysts,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Record final predictions for leaderboard ranking
|
# Record final predictions for leaderboard ranking
|
||||||
@@ -161,6 +215,7 @@ class TradingPipeline:
|
|||||||
|
|
||||||
# Phase 3: PM makes decisions
|
# Phase 3: PM makes decisions
|
||||||
_log("Phase 3.1: PM makes decisions")
|
_log("Phase 3.1: PM makes decisions")
|
||||||
|
self._runtime_update_status(self.pm, "decision_phase")
|
||||||
pm_result = await self._run_pm_with_sync(
|
pm_result = await self._run_pm_with_sync(
|
||||||
tickers,
|
tickers,
|
||||||
date,
|
date,
|
||||||
@@ -169,10 +224,17 @@ class TradingPipeline:
|
|||||||
risk_assessment,
|
risk_assessment,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Phase 4: Execute decisions
|
|
||||||
_log("Phase 4: Executing trades")
|
|
||||||
decisions = pm_result.get("decisions", {})
|
decisions = pm_result.get("decisions", {})
|
||||||
|
execution_result = {
|
||||||
|
"executed_trades": [],
|
||||||
|
"portfolio": self.pm.get_portfolio_state(),
|
||||||
|
}
|
||||||
|
if execute_decisions:
|
||||||
|
_log("Phase 4: Executing trades")
|
||||||
|
self._runtime_update_status(self.pm, "executing")
|
||||||
execution_result = self._execute_decisions(decisions, prices, date)
|
execution_result = self._execute_decisions(decisions, prices, date)
|
||||||
|
else:
|
||||||
|
_log("Phase 4: Skipping trade execution")
|
||||||
|
|
||||||
# Live mode: wait for market close before settlement
|
# Live mode: wait for market close before settlement
|
||||||
if get_close_prices_fn:
|
if get_close_prices_fn:
|
||||||
@@ -184,6 +246,10 @@ class TradingPipeline:
|
|||||||
settlement_result = None
|
settlement_result = None
|
||||||
if close_prices and self.settlement_coordinator:
|
if close_prices and self.settlement_coordinator:
|
||||||
_log("Phase 5: Daily review and generate memories")
|
_log("Phase 5: Daily review and generate memories")
|
||||||
|
self._runtime_batch_status(
|
||||||
|
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||||
|
"settlement",
|
||||||
|
)
|
||||||
|
|
||||||
agent_trajectories = await self._capture_agent_trajectories()
|
agent_trajectories = await self._capture_agent_trajectories()
|
||||||
|
|
||||||
@@ -214,8 +280,17 @@ class TradingPipeline:
|
|||||||
settlement_result=settlement_result,
|
settlement_result=settlement_result,
|
||||||
conference_summary=self.conference_summary,
|
conference_summary=self.conference_summary,
|
||||||
)
|
)
|
||||||
|
self._runtime_batch_status(
|
||||||
|
[self.risk_manager] + self._all_analysts() + [self.pm],
|
||||||
|
"reflection",
|
||||||
|
)
|
||||||
|
|
||||||
_log(f"Cycle complete: {date}")
|
_log(f"Cycle complete: {date}")
|
||||||
|
self._runtime_batch_status(
|
||||||
|
self._all_analysts() + [self.risk_manager, self.pm],
|
||||||
|
"idle",
|
||||||
|
)
|
||||||
|
self._runtime_log_event("cycle:end", {"tickers": tickers, "date": date})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"analyst_results": analyst_results,
|
"analyst_results": analyst_results,
|
||||||
@@ -226,12 +301,18 @@ class TradingPipeline:
|
|||||||
"settlement_result": settlement_result,
|
"settlement_result": settlement_result,
|
||||||
}
|
}
|
||||||
|
|
||||||
def reload_runtime_assets(self) -> Dict[str, Any]:
|
def reload_runtime_assets(
|
||||||
"""Reload prompt assets, bootstrap config, and active skills for all agents."""
|
self,
|
||||||
|
runtime_config: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Reload prompt assets and safe in-process runtime settings."""
|
||||||
from backend.agents.skills_manager import SkillsManager
|
from backend.agents.skills_manager import SkillsManager
|
||||||
from backend.agents.toolkit_factory import load_agent_profiles
|
from backend.agents.toolkit_factory import load_agent_profiles
|
||||||
|
|
||||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||||
|
if runtime_config and "max_comm_cycles" in runtime_config:
|
||||||
|
self.max_comm_cycles = int(runtime_config["max_comm_cycles"])
|
||||||
|
|
||||||
skills_manager = SkillsManager()
|
skills_manager = SkillsManager()
|
||||||
profiles = load_agent_profiles()
|
profiles = load_agent_profiles()
|
||||||
active_skill_map = skills_manager.prepare_active_skills(
|
active_skill_map = skills_manager.prepare_active_skills(
|
||||||
@@ -242,7 +323,7 @@ class TradingPipeline:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
for analyst in self.analysts:
|
for analyst in self._all_analysts():
|
||||||
analyst.reload_runtime_assets(
|
analyst.reload_runtime_assets(
|
||||||
active_skill_dirs=active_skill_map.get(analyst.name, []),
|
active_skill_dirs=active_skill_map.get(analyst.name, []),
|
||||||
)
|
)
|
||||||
@@ -256,17 +337,18 @@ class TradingPipeline:
|
|||||||
|
|
||||||
return {
|
return {
|
||||||
"config_name": config_name,
|
"config_name": config_name,
|
||||||
"reloaded_agents": [agent.name for agent in self.analysts]
|
"reloaded_agents": [agent.name for agent in self._all_analysts()]
|
||||||
+ ["risk_manager", "portfolio_manager"],
|
+ ["risk_manager", "portfolio_manager"],
|
||||||
"active_skills": {
|
"active_skills": {
|
||||||
agent_id: [path.name for path in paths]
|
agent_id: [path.name for path in paths]
|
||||||
for agent_id, paths in active_skill_map.items()
|
for agent_id, paths in active_skill_map.items()
|
||||||
},
|
},
|
||||||
|
"max_comm_cycles": self.max_comm_cycles,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _clear_all_agent_memory(self):
|
async def _clear_all_agent_memory(self):
|
||||||
"""Clear short-term memory for all agents"""
|
"""Clear short-term memory for all agents"""
|
||||||
for analyst in self.analysts:
|
for analyst in self._all_analysts():
|
||||||
await analyst.memory.clear()
|
await analyst.memory.clear()
|
||||||
|
|
||||||
await self.risk_manager.memory.clear()
|
await self.risk_manager.memory.clear()
|
||||||
@@ -348,7 +430,7 @@ class TradingPipeline:
|
|||||||
trajectories = {}
|
trajectories = {}
|
||||||
|
|
||||||
# Capture analyst trajectories
|
# Capture analyst trajectories
|
||||||
for analyst in self.analysts:
|
for analyst in self._all_analysts():
|
||||||
try:
|
try:
|
||||||
msgs = await analyst.memory.get_memory()
|
msgs = await analyst.memory.get_memory()
|
||||||
if msgs:
|
if msgs:
|
||||||
@@ -558,7 +640,7 @@ class TradingPipeline:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Record for analysts
|
# Record for analysts
|
||||||
for analyst in self.analysts:
|
for analyst in self._all_analysts():
|
||||||
if (
|
if (
|
||||||
hasattr(analyst, "long_term_memory")
|
hasattr(analyst, "long_term_memory")
|
||||||
and analyst.long_term_memory is not None
|
and analyst.long_term_memory is not None
|
||||||
@@ -677,7 +759,22 @@ class TradingPipeline:
|
|||||||
date=date,
|
date=date,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Run discussion cycles (no new MsgHub - use parent's)
|
# Conference participants: analysts + PM
|
||||||
|
conference_participants = self._get_active_analysts() + [self.pm]
|
||||||
|
|
||||||
|
# Use TeamMsgHub for conference if available
|
||||||
|
if TEAM_COORD_AVAILABLE and TeamMsgHub is not None:
|
||||||
|
_log(
|
||||||
|
f"Phase 2.1: Conference using TeamMsgHub with "
|
||||||
|
f"{len(conference_participants)} participants"
|
||||||
|
)
|
||||||
|
conference_hub = TeamMsgHub(participants=conference_participants)
|
||||||
|
else:
|
||||||
|
_log("Phase 2.1: Conference using standard MsgHub context")
|
||||||
|
conference_hub = None
|
||||||
|
|
||||||
|
# Run discussion cycles
|
||||||
|
async with conference_hub if conference_hub else nullcontext(None):
|
||||||
for cycle in range(self.max_comm_cycles):
|
for cycle in range(self.max_comm_cycles):
|
||||||
_log(
|
_log(
|
||||||
"Phase 2.1: Conference discussion - "
|
"Phase 2.1: Conference discussion - "
|
||||||
@@ -710,8 +807,8 @@ class TradingPipeline:
|
|||||||
content=pm_content,
|
content=pm_content,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Analysts share perspectives
|
# Analysts share perspectives (supports per-round active team updates)
|
||||||
for analyst in self.analysts:
|
for analyst in self._get_active_analysts():
|
||||||
analyst_prompt = self._build_analyst_discussion_prompt(
|
analyst_prompt = self._build_analyst_discussion_prompt(
|
||||||
cycle=cycle,
|
cycle=cycle,
|
||||||
tickers=tickers,
|
tickers=tickers,
|
||||||
@@ -838,6 +935,7 @@ class TradingPipeline:
|
|||||||
self,
|
self,
|
||||||
tickers: List[str],
|
tickers: List[str],
|
||||||
date: str,
|
date: str,
|
||||||
|
active_analysts: Optional[List[Any]] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Collect final predictions from all analysts as simple text responses.
|
Collect final predictions from all analysts as simple text responses.
|
||||||
@@ -845,14 +943,15 @@ class TradingPipeline:
|
|||||||
"""
|
"""
|
||||||
_log(
|
_log(
|
||||||
"Phase 2.2: Analysts generate final structured predictions\n"
|
"Phase 2.2: Analysts generate final structured predictions\n"
|
||||||
f" Starting _collect_final_predictions for {len(self.analysts)} analysts",
|
f" Starting _collect_final_predictions for {len(active_analysts or self.analysts)} analysts",
|
||||||
)
|
)
|
||||||
final_predictions = []
|
final_predictions = []
|
||||||
|
|
||||||
for i, analyst in enumerate(self.analysts):
|
analysts = active_analysts or self.analysts
|
||||||
|
for i, analyst in enumerate(analysts):
|
||||||
_log(
|
_log(
|
||||||
"Phase 2.2: Analysts generate final structured predictions\n"
|
"Phase 2.2: Analysts generate final structured predictions\n"
|
||||||
f" Collecting prediction from analyst {i+1}/{len(self.analysts)}: {analyst.name}",
|
f" Collecting prediction from analyst {i+1}/{len(analysts)}: {analyst.name}",
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
@@ -948,11 +1047,13 @@ class TradingPipeline:
|
|||||||
self,
|
self,
|
||||||
tickers: List[str],
|
tickers: List[str],
|
||||||
date: str,
|
date: str,
|
||||||
|
active_analysts: Optional[List[Any]] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Run all analysts with real-time sync after each completion"""
|
"""Run all analysts with real-time sync after each completion"""
|
||||||
results = []
|
results = []
|
||||||
|
analysts = active_analysts or self.analysts
|
||||||
|
|
||||||
for analyst in self.analysts:
|
for analyst in analysts:
|
||||||
content = (
|
content = (
|
||||||
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||||
f"Provide investment signals with confidence scores and reasoning."
|
f"Provide investment signals with confidence scores and reasoning."
|
||||||
@@ -982,15 +1083,107 @@ class TradingPipeline:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
async def _run_analysts_parallel(
|
||||||
|
self,
|
||||||
|
tickers: List[str],
|
||||||
|
date: str,
|
||||||
|
active_analysts: Optional[List[Any]] = None,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Run all analysts in parallel using TeamCoordinator.
|
||||||
|
|
||||||
|
This method replaces the sequential analyst loop with parallel execution
|
||||||
|
using the TeamCoordinator for orchestration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tickers: List of stock tickers to analyze
|
||||||
|
date: Trading date
|
||||||
|
active_analysts: Optional list of analysts to run
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of analyst result dictionaries
|
||||||
|
"""
|
||||||
|
analysts = active_analysts or self.analysts
|
||||||
|
|
||||||
|
if not analysts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not TEAM_COORD_AVAILABLE:
|
||||||
|
_log("TeamCoordinator not available, falling back to sequential execution")
|
||||||
|
return await self._run_analysts_with_sync(
|
||||||
|
tickers=tickers,
|
||||||
|
date=date,
|
||||||
|
active_analysts=active_analysts,
|
||||||
|
)
|
||||||
|
|
||||||
|
_log(
|
||||||
|
f"Phase 1.1: Running {len(analysts)} analysts in parallel "
|
||||||
|
f"[{', '.join(a.name for a in analysts)}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build the analyst prompt
|
||||||
|
content = (
|
||||||
|
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||||
|
f"Provide investment signals with confidence scores and reasoning."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create coordinator for parallel execution
|
||||||
|
coordinator = TeamCoordinator(
|
||||||
|
participants=analysts,
|
||||||
|
task_content=content,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run analysts in parallel via TeamCoordinator
|
||||||
|
results = await coordinator.run_phase(
|
||||||
|
"analyst_analysis",
|
||||||
|
metadata={"tickers": tickers, "date": date},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process results and sync
|
||||||
|
processed_results = []
|
||||||
|
for i, (analyst, result) in enumerate(zip(analysts, results)):
|
||||||
|
if result is not None:
|
||||||
|
extracted = self._extract_result_from_msg(result)
|
||||||
|
processed_results.append(extracted)
|
||||||
|
|
||||||
|
# Sync retrieved memory
|
||||||
|
await self._sync_memory_if_retrieved(analyst)
|
||||||
|
|
||||||
|
# Broadcast agent result via StateSync
|
||||||
|
if self.state_sync:
|
||||||
|
text_content = self._extract_text_content(result.content)
|
||||||
|
await self.state_sync.on_agent_complete(
|
||||||
|
agent_id=analyst.name,
|
||||||
|
content=text_content,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"Analyst %s returned no result",
|
||||||
|
analyst.name,
|
||||||
|
)
|
||||||
|
processed_results.append({
|
||||||
|
"agent": analyst.name,
|
||||||
|
"content": "",
|
||||||
|
"success": False,
|
||||||
|
})
|
||||||
|
|
||||||
|
_log(
|
||||||
|
f"Phase 1.1: Parallel analyst execution complete "
|
||||||
|
f"({len(processed_results)}/{len(analysts)} successful)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return processed_results
|
||||||
|
|
||||||
async def _run_analysts(
|
async def _run_analysts(
|
||||||
self,
|
self,
|
||||||
tickers: List[str],
|
tickers: List[str],
|
||||||
date: str,
|
date: str,
|
||||||
|
active_analysts: Optional[List[Any]] = None,
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""Run all analysts (without sync, for backward compatibility)"""
|
"""Run all analysts (without sync, for backward compatibility)"""
|
||||||
results = []
|
results = []
|
||||||
|
analysts = active_analysts or self.analysts
|
||||||
|
|
||||||
for analyst in self.analysts:
|
for analyst in analysts:
|
||||||
content = (
|
content = (
|
||||||
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
f"Analyze the following stocks for date {date}: {', '.join(tickers)}. "
|
||||||
f"Provide investment signals with confidence scores and reasoning."
|
f"Provide investment signals with confidence scores and reasoning."
|
||||||
@@ -1299,3 +1492,199 @@ class TradingPipeline:
|
|||||||
if decision_texts:
|
if decision_texts:
|
||||||
return "Decisions: " + "; ".join(decision_texts)
|
return "Decisions: " + "; ".join(decision_texts)
|
||||||
return "Portfolio analysis completed. No trades recommended."
|
return "Portfolio analysis completed. No trades recommended."
|
||||||
|
|
||||||
|
def load_agents_from_workspace(
|
||||||
|
self,
|
||||||
|
workspace_id: str,
|
||||||
|
agent_factory: Optional[Any] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Load agents from workspace using AgentFactory.
|
||||||
|
|
||||||
|
This method supports the new EvoAgent architecture by loading
|
||||||
|
agents from a workspace instead of using hardcoded agents.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace identifier
|
||||||
|
agent_factory: Optional AgentFactory instance (uses self.agent_factory if None)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with loaded agents:
|
||||||
|
{
|
||||||
|
"analysts": List[EvoAgent],
|
||||||
|
"risk_manager": EvoAgent,
|
||||||
|
"portfolio_manager": EvoAgent,
|
||||||
|
}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If workspace doesn't exist or no agents found
|
||||||
|
"""
|
||||||
|
factory = agent_factory or self.agent_factory
|
||||||
|
if factory is None:
|
||||||
|
from backend.agents import AgentFactory
|
||||||
|
factory = AgentFactory()
|
||||||
|
|
||||||
|
# Check workspace exists
|
||||||
|
if not factory.workspaces_root.exists():
|
||||||
|
raise ValueError(f"Workspaces root does not exist: {factory.workspaces_root}")
|
||||||
|
|
||||||
|
workspace_dir = factory.workspaces_root / workspace_id
|
||||||
|
if not workspace_dir.exists():
|
||||||
|
raise ValueError(f"Workspace '{workspace_id}' does not exist")
|
||||||
|
|
||||||
|
# Load agents from workspace
|
||||||
|
agents_data = factory.list_agents(workspace_id=workspace_id)
|
||||||
|
|
||||||
|
if not agents_data:
|
||||||
|
raise ValueError(f"No agents found in workspace '{workspace_id}'")
|
||||||
|
|
||||||
|
# Categorize agents by type
|
||||||
|
analysts = []
|
||||||
|
risk_manager = None
|
||||||
|
portfolio_manager = None
|
||||||
|
|
||||||
|
for agent_data in agents_data:
|
||||||
|
agent_type = agent_data.get("agent_type", "unknown")
|
||||||
|
agent_id = agent_data.get("agent_id")
|
||||||
|
|
||||||
|
# Load full agent configuration
|
||||||
|
config_path = Path(agent_data.get("config_path", ""))
|
||||||
|
if config_path.exists():
|
||||||
|
agent = factory.load_agent(agent_id, workspace_id)
|
||||||
|
|
||||||
|
if agent_type.endswith("_analyst"):
|
||||||
|
analysts.append(agent)
|
||||||
|
elif agent_type == "risk_manager":
|
||||||
|
risk_manager = agent
|
||||||
|
elif agent_type == "portfolio_manager":
|
||||||
|
portfolio_manager = agent
|
||||||
|
|
||||||
|
if not analysts:
|
||||||
|
raise ValueError(f"No analysts found in workspace '{workspace_id}'")
|
||||||
|
if risk_manager is None:
|
||||||
|
raise ValueError(f"No risk_manager found in workspace '{workspace_id}'")
|
||||||
|
if portfolio_manager is None:
|
||||||
|
raise ValueError(f"No portfolio_manager found in workspace '{workspace_id}'")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"analysts": analysts,
|
||||||
|
"risk_manager": risk_manager,
|
||||||
|
"portfolio_manager": portfolio_manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
def reload_agents_from_workspace(self, workspace_id: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Reload all agents from workspace.
|
||||||
|
|
||||||
|
This updates self.analysts, self.risk_manager, and self.pm
|
||||||
|
with agents loaded from the specified workspace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workspace_id: Workspace ID (uses self.workspace_id if None)
|
||||||
|
"""
|
||||||
|
ws_id = workspace_id or self.workspace_id
|
||||||
|
if not ws_id:
|
||||||
|
raise ValueError("No workspace_id specified")
|
||||||
|
|
||||||
|
loaded = self.load_agents_from_workspace(ws_id)
|
||||||
|
|
||||||
|
self.analysts = loaded["analysts"]
|
||||||
|
self.risk_manager = loaded["risk_manager"]
|
||||||
|
self.pm = loaded["portfolio_manager"]
|
||||||
|
self.workspace_id = ws_id
|
||||||
|
|
||||||
|
logger.info(f"Reloaded {len(self.analysts)} analysts from workspace '{ws_id}'")
|
||||||
|
|
||||||
|
def _runtime_update_status(self, agent: Any, status: str) -> None:
|
||||||
|
if not self.runtime_manager:
|
||||||
|
return
|
||||||
|
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||||
|
if not agent_id:
|
||||||
|
return
|
||||||
|
self.runtime_manager.update_agent_status(agent_id, status, self._session_key)
|
||||||
|
|
||||||
|
def _runtime_batch_status(self, agents: List[Any], status: str) -> None:
|
||||||
|
for agent in agents:
|
||||||
|
self._runtime_update_status(agent, status)
|
||||||
|
|
||||||
|
def _all_analysts(self) -> List[Any]:
|
||||||
|
"""Return static analysts plus runtime-created analysts."""
|
||||||
|
return list(self.analysts) + list(self._dynamic_analysts.values())
|
||||||
|
|
||||||
|
def _create_runtime_analyst(self, agent_id: str, analyst_type: str) -> str:
|
||||||
|
"""Create one runtime analyst instance."""
|
||||||
|
if analyst_type not in ANALYST_TYPES:
|
||||||
|
return (
|
||||||
|
f"Unknown analyst_type '{analyst_type}'. "
|
||||||
|
f"Available: {', '.join(ANALYST_TYPES.keys())}"
|
||||||
|
)
|
||||||
|
if agent_id in {agent.name for agent in self._all_analysts()}:
|
||||||
|
return f"Analyst '{agent_id}' already exists."
|
||||||
|
|
||||||
|
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||||
|
project_root = Path(__file__).resolve().parents[2]
|
||||||
|
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
|
||||||
|
persona = personas.get(analyst_type, {})
|
||||||
|
WorkspaceManager(project_root=project_root).ensure_agent_assets(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
role_seed=persona.get("description", "").strip(),
|
||||||
|
style_seed="\n".join(f"- {item}" for item in persona.get("focus", [])),
|
||||||
|
policy_seed=(
|
||||||
|
"State a clear signal, confidence, and the conditions "
|
||||||
|
"that would invalidate the thesis."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = AnalystAgent(
|
||||||
|
analyst_type=analyst_type,
|
||||||
|
toolkit=create_agent_toolkit(
|
||||||
|
agent_id=agent_id,
|
||||||
|
config_name=config_name,
|
||||||
|
active_skill_dirs=[],
|
||||||
|
),
|
||||||
|
model=get_agent_model(analyst_type),
|
||||||
|
formatter=get_agent_formatter(analyst_type),
|
||||||
|
agent_id=agent_id,
|
||||||
|
config={"config_name": config_name},
|
||||||
|
)
|
||||||
|
self._dynamic_analysts[agent_id] = agent
|
||||||
|
update_active_analysts(
|
||||||
|
project_root=project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
available_analysts=[item.name for item in self._all_analysts()],
|
||||||
|
add=[agent_id],
|
||||||
|
)
|
||||||
|
return f"Created runtime analyst '{agent_id}' ({analyst_type})."
|
||||||
|
|
||||||
|
def _remove_runtime_analyst(self, agent_id: str) -> str:
|
||||||
|
"""Remove one runtime-created analyst instance."""
|
||||||
|
if agent_id not in self._dynamic_analysts:
|
||||||
|
return f"Runtime analyst '{agent_id}' not found."
|
||||||
|
self._dynamic_analysts.pop(agent_id, None)
|
||||||
|
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||||
|
project_root = Path(__file__).resolve().parents[2]
|
||||||
|
update_active_analysts(
|
||||||
|
project_root=project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
available_analysts=[item.name for item in self._all_analysts()],
|
||||||
|
remove=[agent_id],
|
||||||
|
)
|
||||||
|
return f"Removed runtime analyst '{agent_id}'."
|
||||||
|
|
||||||
|
def _get_active_analysts(self) -> List[Any]:
|
||||||
|
"""Resolve active analyst participants from run-scoped team pipeline config."""
|
||||||
|
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||||
|
project_root = Path(__file__).resolve().parents[2]
|
||||||
|
analyst_map = {agent.name: agent for agent in self._all_analysts()}
|
||||||
|
active_ids = resolve_active_analysts(
|
||||||
|
project_root=project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
available_analysts=list(analyst_map.keys()),
|
||||||
|
)
|
||||||
|
return [analyst_map[agent_id] for agent_id in active_ids if agent_id in analyst_map]
|
||||||
|
|
||||||
|
def _runtime_log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> None:
|
||||||
|
if not self.runtime_manager:
|
||||||
|
return
|
||||||
|
self.runtime_manager.log_event(event, details)
|
||||||
|
|||||||
489
backend/core/pipeline_runner.py
Normal file
489
backend/core/pipeline_runner.py
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Pipeline Runner - Independent trading pipeline execution
|
||||||
|
|
||||||
|
This module provides functions to start/stop trading pipelines
|
||||||
|
that can be called from the REST API.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional, Callable
|
||||||
|
|
||||||
|
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||||
|
from backend.agents.prompt_loader import get_prompt_loader
|
||||||
|
from backend.agents.workspace_manager import WorkspaceManager
|
||||||
|
from backend.config.constants import ANALYST_TYPES
|
||||||
|
from backend.core.pipeline import TradingPipeline
|
||||||
|
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||||
|
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||||
|
from backend.runtime.manager import (
|
||||||
|
TradingRuntimeManager,
|
||||||
|
set_global_runtime_manager,
|
||||||
|
clear_global_runtime_manager,
|
||||||
|
set_shutdown_event,
|
||||||
|
clear_shutdown_event,
|
||||||
|
is_shutdown_requested,
|
||||||
|
)
|
||||||
|
from backend.services.market import MarketService
|
||||||
|
from backend.services.storage import StorageService
|
||||||
|
from backend.services.gateway import Gateway
|
||||||
|
from backend.utils.settlement import SettlementCoordinator
|
||||||
|
|
||||||
|
_prompt_loader = get_prompt_loader()
|
||||||
|
|
||||||
|
# Global gateway reference for cleanup
|
||||||
|
_gateway_instance: Optional[Gateway] = None
|
||||||
|
|
||||||
|
|
||||||
|
def _set_gateway(gateway: Optional[Gateway]) -> None:
|
||||||
|
"""Set global gateway reference."""
|
||||||
|
global _gateway_instance
|
||||||
|
_gateway_instance = gateway
|
||||||
|
|
||||||
|
|
||||||
|
def stop_gateway() -> None:
|
||||||
|
"""Stop the running gateway if exists."""
|
||||||
|
global _gateway_instance
|
||||||
|
if _gateway_instance is not None:
|
||||||
|
try:
|
||||||
|
_gateway_instance.stop()
|
||||||
|
except Exception as e:
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).error(f"Error stopping gateway: {e}")
|
||||||
|
finally:
|
||||||
|
_gateway_instance = None
|
||||||
|
|
||||||
|
|
||||||
|
def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
|
||||||
|
"""Create ReMeTaskLongTermMemory for an agent."""
|
||||||
|
try:
|
||||||
|
from agentscope.memory import ReMeTaskLongTermMemory
|
||||||
|
from agentscope.model import DashScopeChatModel
|
||||||
|
from agentscope.embedding import DashScopeTextEmbedding
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
api_key = os.getenv("MEMORY_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
return None
|
||||||
|
|
||||||
|
memory_dir = str(run_dir / "memory")
|
||||||
|
|
||||||
|
return ReMeTaskLongTermMemory(
|
||||||
|
agent_name=agent_name,
|
||||||
|
user_name=agent_name,
|
||||||
|
model=DashScopeChatModel(
|
||||||
|
model_name=os.getenv("MEMORY_MODEL_NAME", "qwen3-max"),
|
||||||
|
api_key=api_key,
|
||||||
|
stream=False,
|
||||||
|
),
|
||||||
|
embedding_model=DashScopeTextEmbedding(
|
||||||
|
model_name=os.getenv("MEMORY_EMBEDDING_MODEL", "text-embedding-v4"),
|
||||||
|
api_key=api_key,
|
||||||
|
dimensions=1024,
|
||||||
|
),
|
||||||
|
**{
|
||||||
|
"vector_store.default.backend": "local",
|
||||||
|
"vector_store.default.params.store_dir": memory_dir,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_agents(
|
||||||
|
run_id: str,
|
||||||
|
run_dir: Path,
|
||||||
|
initial_cash: float,
|
||||||
|
margin_requirement: float,
|
||||||
|
enable_long_term_memory: bool = False,
|
||||||
|
):
|
||||||
|
"""Create all agents for the system."""
|
||||||
|
analysts = []
|
||||||
|
long_term_memories = []
|
||||||
|
|
||||||
|
# Initialize workspace manager and assets
|
||||||
|
workspace_manager = WorkspaceManager()
|
||||||
|
workspace_manager.initialize_default_assets(
|
||||||
|
config_name=run_id,
|
||||||
|
agent_ids=list(ANALYST_TYPES.keys()) + ["risk_manager", "portfolio_manager"],
|
||||||
|
analyst_personas=_prompt_loader.load_yaml_config("analyst", "personas"),
|
||||||
|
)
|
||||||
|
|
||||||
|
profiles = load_agent_profiles()
|
||||||
|
skills_manager = SkillsManager()
|
||||||
|
active_skill_map = skills_manager.prepare_active_skills(
|
||||||
|
config_name=run_id,
|
||||||
|
agent_defaults={
|
||||||
|
agent_id: profile.get("skills", [])
|
||||||
|
for agent_id, profile in profiles.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create analyst agents
|
||||||
|
for analyst_type in ANALYST_TYPES:
|
||||||
|
model = get_agent_model(analyst_type)
|
||||||
|
formatter = get_agent_formatter(analyst_type)
|
||||||
|
toolkit = create_agent_toolkit(
|
||||||
|
analyst_type,
|
||||||
|
run_id,
|
||||||
|
active_skill_dirs=active_skill_map.get(analyst_type, []),
|
||||||
|
)
|
||||||
|
|
||||||
|
long_term_memory = None
|
||||||
|
if enable_long_term_memory:
|
||||||
|
long_term_memory = create_long_term_memory(analyst_type, run_id, run_dir)
|
||||||
|
if long_term_memory:
|
||||||
|
long_term_memories.append(long_term_memory)
|
||||||
|
|
||||||
|
analyst = AnalystAgent(
|
||||||
|
analyst_type=analyst_type,
|
||||||
|
toolkit=toolkit,
|
||||||
|
model=model,
|
||||||
|
formatter=formatter,
|
||||||
|
agent_id=analyst_type,
|
||||||
|
config={"config_name": run_id},
|
||||||
|
long_term_memory=long_term_memory,
|
||||||
|
)
|
||||||
|
analysts.append(analyst)
|
||||||
|
|
||||||
|
# Create risk manager
|
||||||
|
risk_long_term_memory = None
|
||||||
|
if enable_long_term_memory:
|
||||||
|
risk_long_term_memory = create_long_term_memory("risk_manager", run_id, run_dir)
|
||||||
|
if risk_long_term_memory:
|
||||||
|
long_term_memories.append(risk_long_term_memory)
|
||||||
|
|
||||||
|
risk_manager = RiskAgent(
|
||||||
|
model=get_agent_model("risk_manager"),
|
||||||
|
formatter=get_agent_formatter("risk_manager"),
|
||||||
|
name="risk_manager",
|
||||||
|
config={"config_name": run_id},
|
||||||
|
long_term_memory=risk_long_term_memory,
|
||||||
|
toolkit=create_agent_toolkit(
|
||||||
|
"risk_manager",
|
||||||
|
run_id,
|
||||||
|
active_skill_dirs=active_skill_map.get("risk_manager", []),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create portfolio manager
|
||||||
|
pm_long_term_memory = None
|
||||||
|
if enable_long_term_memory:
|
||||||
|
pm_long_term_memory = create_long_term_memory("portfolio_manager", run_id, run_dir)
|
||||||
|
if pm_long_term_memory:
|
||||||
|
long_term_memories.append(pm_long_term_memory)
|
||||||
|
|
||||||
|
portfolio_manager = PMAgent(
|
||||||
|
name="portfolio_manager",
|
||||||
|
model=get_agent_model("portfolio_manager"),
|
||||||
|
formatter=get_agent_formatter("portfolio_manager"),
|
||||||
|
initial_cash=initial_cash,
|
||||||
|
margin_requirement=margin_requirement,
|
||||||
|
config={"config_name": run_id},
|
||||||
|
long_term_memory=pm_long_term_memory,
|
||||||
|
toolkit_factory=create_agent_toolkit,
|
||||||
|
toolkit_factory_kwargs={
|
||||||
|
"active_skill_dirs": active_skill_map.get("portfolio_manager", []),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return analysts, risk_manager, portfolio_manager, long_term_memories
|
||||||
|
|
||||||
|
|
||||||
|
async def run_pipeline(
|
||||||
|
run_id: str,
|
||||||
|
run_dir: Path,
|
||||||
|
bootstrap: Dict[str, Any],
|
||||||
|
stop_event: asyncio.Event,
|
||||||
|
message_callback: Optional[Callable[[str, Any], None]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Run the trading pipeline with the given configuration.
|
||||||
|
|
||||||
|
Service Startup Order:
|
||||||
|
Phase 1: WebSocket Server - Frontend can connect
|
||||||
|
Phase 2: Market Service - Price data starts flowing
|
||||||
|
Phase 3: Agent Runtime - Create all agents
|
||||||
|
Phase 4: Pipeline & Scheduler - Trading logic ready
|
||||||
|
Phase 5: Gateway Fully Operational - All systems running
|
||||||
|
|
||||||
|
Args:
|
||||||
|
run_id: Unique run identifier (timestamp)
|
||||||
|
run_dir: Run directory path
|
||||||
|
bootstrap: Bootstrap configuration
|
||||||
|
stop_event: Event to signal pipeline stop
|
||||||
|
message_callback: Optional callback for sending messages to clients
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Set global shutdown event
|
||||||
|
set_shutdown_event(stop_event)
|
||||||
|
|
||||||
|
logger.info(f"[Pipeline {run_id}] ======================================")
|
||||||
|
logger.info(f"[Pipeline {run_id}] Starting with 5-phase initialization...")
|
||||||
|
logger.info(f"[Pipeline {run_id}] ======================================")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract config values
|
||||||
|
tickers = bootstrap.get("tickers", ["AAPL", "MSFT"])
|
||||||
|
initial_cash = float(bootstrap.get("initial_cash", 100000.0))
|
||||||
|
margin_requirement = float(bootstrap.get("margin_requirement", 0.0))
|
||||||
|
max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2))
|
||||||
|
schedule_mode = bootstrap.get("schedule_mode", "daily")
|
||||||
|
trigger_time = bootstrap.get("trigger_time", "09:30")
|
||||||
|
interval_minutes = int(bootstrap.get("interval_minutes", 60))
|
||||||
|
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0))
|
||||||
|
mode = bootstrap.get("mode", "live")
|
||||||
|
start_date = bootstrap.get("start_date")
|
||||||
|
end_date = bootstrap.get("end_date")
|
||||||
|
enable_memory = bootstrap.get("enable_memory", False)
|
||||||
|
enable_mock = bootstrap.get("enable_mock", False)
|
||||||
|
|
||||||
|
is_backtest = mode == "backtest"
|
||||||
|
is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# PHASE 0: Initialize runtime manager
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 0/5] Initializing runtime manager...")
|
||||||
|
|
||||||
|
from backend.api.runtime import runtime_manager
|
||||||
|
|
||||||
|
if runtime_manager is None:
|
||||||
|
runtime_manager = TradingRuntimeManager(
|
||||||
|
config_name=run_id,
|
||||||
|
run_dir=run_dir,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
)
|
||||||
|
runtime_manager.prepare_run()
|
||||||
|
|
||||||
|
set_global_runtime_manager(runtime_manager)
|
||||||
|
|
||||||
|
# Register runtime manager with API
|
||||||
|
from backend.api.runtime import register_runtime_manager
|
||||||
|
register_runtime_manager(runtime_manager)
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
|
||||||
|
# These will be started by Gateway in the correct order
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 1-2/5] Creating infrastructure services...")
|
||||||
|
|
||||||
|
# Create storage service
|
||||||
|
storage_service = StorageService(
|
||||||
|
dashboard_dir=run_dir / "team_dashboard",
|
||||||
|
initial_cash=initial_cash,
|
||||||
|
config_name=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not storage_service.files["summary"].exists():
|
||||||
|
storage_service.initialize_empty_dashboard()
|
||||||
|
else:
|
||||||
|
storage_service.update_leaderboard_model_info()
|
||||||
|
|
||||||
|
# Create market service (data source)
|
||||||
|
market_service = MarketService(
|
||||||
|
tickers=tickers,
|
||||||
|
poll_interval=10,
|
||||||
|
mock_mode=is_mock and not is_backtest,
|
||||||
|
backtest_mode=is_backtest,
|
||||||
|
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
|
||||||
|
backtest_start_date=start_date if is_backtest else None,
|
||||||
|
backtest_end_date=end_date if is_backtest else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# PHASE 3: Create Agent Runtime
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 3/5] Creating agent runtime...")
|
||||||
|
|
||||||
|
analysts, risk_manager, pm, long_term_memories = create_agents(
|
||||||
|
run_id=run_id,
|
||||||
|
run_dir=run_dir,
|
||||||
|
initial_cash=initial_cash,
|
||||||
|
margin_requirement=margin_requirement,
|
||||||
|
enable_long_term_memory=enable_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register agents with runtime manager
|
||||||
|
for agent in analysts + [risk_manager, pm]:
|
||||||
|
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||||
|
if agent_id:
|
||||||
|
runtime_manager.register_agent(agent_id)
|
||||||
|
|
||||||
|
# Load portfolio state
|
||||||
|
portfolio_state = storage_service.load_portfolio_state()
|
||||||
|
pm.load_portfolio_state(portfolio_state)
|
||||||
|
|
||||||
|
# Create settlement coordinator
|
||||||
|
settlement_coordinator = SettlementCoordinator(
|
||||||
|
storage=storage_service,
|
||||||
|
initial_capital=initial_cash,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# PHASE 4: Create Pipeline & Scheduler
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 4/5] Creating pipeline and scheduler...")
|
||||||
|
|
||||||
|
# Create pipeline
|
||||||
|
pipeline = TradingPipeline(
|
||||||
|
analysts=analysts,
|
||||||
|
risk_manager=risk_manager,
|
||||||
|
portfolio_manager=pm,
|
||||||
|
settlement_coordinator=settlement_coordinator,
|
||||||
|
max_comm_cycles=max_comm_cycles,
|
||||||
|
runtime_manager=runtime_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create scheduler
|
||||||
|
scheduler_callback = None
|
||||||
|
live_scheduler = None
|
||||||
|
|
||||||
|
if is_backtest:
|
||||||
|
backtest_scheduler = BacktestScheduler(
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
trading_calendar="NYSE",
|
||||||
|
delay_between_days=0.5,
|
||||||
|
)
|
||||||
|
trading_dates = backtest_scheduler.get_trading_dates()
|
||||||
|
|
||||||
|
async def scheduler_callback_fn(callback):
|
||||||
|
await backtest_scheduler.start(callback)
|
||||||
|
|
||||||
|
scheduler_callback = scheduler_callback_fn
|
||||||
|
else:
|
||||||
|
# Live mode
|
||||||
|
live_scheduler = Scheduler(
|
||||||
|
mode=schedule_mode,
|
||||||
|
trigger_time=trigger_time,
|
||||||
|
interval_minutes=interval_minutes,
|
||||||
|
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
|
||||||
|
config={"config_name": run_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def scheduler_callback_fn(callback):
|
||||||
|
await live_scheduler.start(callback)
|
||||||
|
|
||||||
|
scheduler_callback = scheduler_callback_fn
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# PHASE 5: Start Gateway (WebSocket → Market → Scheduler)
|
||||||
|
# Gateway.start() will handle the final startup sequence:
|
||||||
|
# - WebSocket Server first (frontend can connect)
|
||||||
|
# - Market Service second (price data flows)
|
||||||
|
# - Scheduler last (trading begins)
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 5/5] Starting Gateway (WebSocket → Market → Scheduler)...")
|
||||||
|
|
||||||
|
gateway = Gateway(
|
||||||
|
market_service=market_service,
|
||||||
|
storage_service=storage_service,
|
||||||
|
pipeline=pipeline,
|
||||||
|
scheduler_callback=scheduler_callback,
|
||||||
|
config={
|
||||||
|
"mode": mode,
|
||||||
|
"mock_mode": is_mock,
|
||||||
|
"backtest_mode": is_backtest,
|
||||||
|
"tickers": tickers,
|
||||||
|
"config_name": run_id,
|
||||||
|
"schedule_mode": schedule_mode,
|
||||||
|
"interval_minutes": interval_minutes,
|
||||||
|
"trigger_time": trigger_time,
|
||||||
|
"heartbeat_interval": heartbeat_interval,
|
||||||
|
"initial_cash": initial_cash,
|
||||||
|
"margin_requirement": margin_requirement,
|
||||||
|
"max_comm_cycles": max_comm_cycles,
|
||||||
|
"enable_memory": enable_memory,
|
||||||
|
},
|
||||||
|
scheduler=live_scheduler,
|
||||||
|
)
|
||||||
|
_set_gateway(gateway)
|
||||||
|
|
||||||
|
# Start pipeline execution
|
||||||
|
async with AsyncExitStack() as stack:
|
||||||
|
# Enter long-term memory contexts
|
||||||
|
for memory in long_term_memories:
|
||||||
|
await stack.enter_async_context(memory)
|
||||||
|
|
||||||
|
# Start Gateway - this will execute the 4-phase startup:
|
||||||
|
# Phase 1: WebSocket Server (frontend can connect immediately)
|
||||||
|
# Phase 2: Market Service (price updates start flowing)
|
||||||
|
# Phase 3: Market Status Monitor
|
||||||
|
# Phase 4: Scheduler (trading cycles begin)
|
||||||
|
gateway_task = asyncio.create_task(
|
||||||
|
gateway.start(host="0.0.0.0", port=8765)
|
||||||
|
)
|
||||||
|
logger.info("[Pipeline] Gateway startup initiated on ws://localhost:8765")
|
||||||
|
|
||||||
|
# Wait for Gateway to fully initialize all phases
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# Define the trading cycle callback
|
||||||
|
async def trading_cycle(session_key: str) -> None:
|
||||||
|
"""Execute one trading cycle."""
|
||||||
|
if is_shutdown_requested():
|
||||||
|
return
|
||||||
|
|
||||||
|
runtime_manager.set_session_key(session_key)
|
||||||
|
runtime_manager.log_event("cycle:start", {"session": session_key})
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Fetch market data
|
||||||
|
market_data = await market_service.get_all_data()
|
||||||
|
|
||||||
|
# Run pipeline
|
||||||
|
await pipeline.run_cycle(
|
||||||
|
session_key=session_key,
|
||||||
|
market_data=market_data,
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime_manager.log_event("cycle:complete", {"session": session_key})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
runtime_manager.log_event("cycle:error", {"error": str(e)})
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Start scheduler
|
||||||
|
if scheduler_callback:
|
||||||
|
await scheduler_callback(trading_cycle)
|
||||||
|
|
||||||
|
# Wait for stop signal
|
||||||
|
while not stop_event.is_set():
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
# Cancel gateway task
|
||||||
|
if not gateway_task.done():
|
||||||
|
gateway_task.cancel()
|
||||||
|
try:
|
||||||
|
await gateway_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Handle cancellation gracefully
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# Cleanup
|
||||||
|
logger.info("[Pipeline] Cleaning up...")
|
||||||
|
|
||||||
|
# Stop Gateway
|
||||||
|
try:
|
||||||
|
stop_gateway()
|
||||||
|
logger.info("[Pipeline] Gateway stopped")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Pipeline] Error stopping gateway: {e}")
|
||||||
|
|
||||||
|
clear_shutdown_event()
|
||||||
|
clear_global_runtime_manager()
|
||||||
|
from backend.api.runtime import unregister_runtime_manager
|
||||||
|
unregister_runtime_manager()
|
||||||
|
logger.info("[Pipeline] Cleanup complete")
|
||||||
@@ -4,7 +4,7 @@ Scheduler - Market-aware trigger system for trading cycles
|
|||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, time, timedelta
|
||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
@@ -28,16 +28,21 @@ class Scheduler:
|
|||||||
mode: str = "daily",
|
mode: str = "daily",
|
||||||
trigger_time: Optional[str] = None,
|
trigger_time: Optional[str] = None,
|
||||||
interval_minutes: Optional[int] = None,
|
interval_minutes: Optional[int] = None,
|
||||||
|
heartbeat_interval: Optional[int] = None,
|
||||||
config: Optional[dict] = None,
|
config: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
self.mode = mode
|
self.mode = mode
|
||||||
self.trigger_time = trigger_time or "09:30" # NYSE timezone
|
self.trigger_time = trigger_time or "09:30" # NYSE timezone
|
||||||
self.trigger_now = self.trigger_time == "now"
|
self.trigger_now = self.trigger_time == "now"
|
||||||
self.interval_minutes = interval_minutes or 60
|
self.interval_minutes = interval_minutes or 60
|
||||||
|
self.heartbeat_interval = heartbeat_interval # e.g. 3600 = 1 hour
|
||||||
self.config = config or {}
|
self.config = config or {}
|
||||||
|
|
||||||
self.running = False
|
self.running = False
|
||||||
self._task: Optional[asyncio.Task] = None
|
self._task: Optional[asyncio.Task] = None
|
||||||
|
self._heartbeat_task: Optional[asyncio.Task] = None
|
||||||
|
self._callback: Optional[Callable] = None
|
||||||
|
self._heartbeat_callback: Optional[Callable] = None
|
||||||
|
|
||||||
def _now_nyse(self) -> datetime:
|
def _now_nyse(self) -> datetime:
|
||||||
"""Get current time in NYSE timezone"""
|
"""Get current time in NYSE timezone"""
|
||||||
@@ -52,6 +57,15 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
return len(valid_days) > 0
|
return len(valid_days) > 0
|
||||||
|
|
||||||
|
def _is_trading_hours(self, now: datetime) -> bool:
|
||||||
|
"""Check if current time is within NYSE trading hours (9:30-16:00 ET)."""
|
||||||
|
market_time = now.time()
|
||||||
|
return time(9, 30) <= market_time <= time(16, 0)
|
||||||
|
|
||||||
|
def set_heartbeat_callback(self, callback: Callable) -> None:
|
||||||
|
"""Register callback for heartbeat triggers."""
|
||||||
|
self._heartbeat_callback = callback
|
||||||
|
|
||||||
def _next_trading_day(self, from_date: datetime) -> datetime:
|
def _next_trading_day(self, from_date: datetime) -> datetime:
|
||||||
"""Find the next trading day from given date"""
|
"""Find the next trading day from given date"""
|
||||||
check_date = from_date
|
check_date = from_date
|
||||||
@@ -68,18 +82,100 @@ class Scheduler:
|
|||||||
return
|
return
|
||||||
|
|
||||||
self.running = True
|
self.running = True
|
||||||
|
self._callback = callback
|
||||||
|
self._schedule_task()
|
||||||
|
|
||||||
if self.mode == "daily":
|
# Start heartbeat loop if configured
|
||||||
self._task = asyncio.create_task(self._run_daily(callback))
|
if self.heartbeat_interval and self._heartbeat_callback:
|
||||||
elif self.mode == "intraday":
|
self._heartbeat_task = asyncio.create_task(self._run_heartbeat_loop())
|
||||||
self._task = asyncio.create_task(self._run_intraday(callback))
|
logger.info(
|
||||||
else:
|
f"Heartbeat loop started: interval={self.heartbeat_interval}s",
|
||||||
raise ValueError(f"Unknown scheduler mode: {self.mode}")
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Scheduler started: mode={self.mode}, timezone=America/New_York",
|
f"Scheduler started: mode={self.mode}, timezone=America/New_York",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _schedule_task(self):
|
||||||
|
"""Create the active scheduler task for the current mode."""
|
||||||
|
if not self._callback:
|
||||||
|
raise ValueError("Scheduler callback is not set")
|
||||||
|
|
||||||
|
if self._task:
|
||||||
|
self._task.cancel()
|
||||||
|
self._task = None
|
||||||
|
|
||||||
|
if self.mode == "daily":
|
||||||
|
self._task = asyncio.create_task(self._run_daily(self._callback))
|
||||||
|
elif self.mode == "intraday":
|
||||||
|
self._task = asyncio.create_task(
|
||||||
|
self._run_intraday(self._callback),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown scheduler mode: {self.mode}")
|
||||||
|
|
||||||
|
def reconfigure(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
mode: Optional[str] = None,
|
||||||
|
trigger_time: Optional[str] = None,
|
||||||
|
interval_minutes: Optional[int] = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Update scheduler parameters in-place and restart its timing loop."""
|
||||||
|
changed = False
|
||||||
|
|
||||||
|
if mode and mode != self.mode:
|
||||||
|
self.mode = mode
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if trigger_time and trigger_time != self.trigger_time:
|
||||||
|
self.trigger_time = trigger_time
|
||||||
|
self.trigger_now = self.trigger_time == "now"
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if (
|
||||||
|
interval_minutes is not None
|
||||||
|
and interval_minutes > 0
|
||||||
|
and interval_minutes != self.interval_minutes
|
||||||
|
):
|
||||||
|
self.interval_minutes = interval_minutes
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
if changed and self.running and self._callback:
|
||||||
|
self._schedule_task()
|
||||||
|
logger.info(
|
||||||
|
"Scheduler reconfigured: mode=%s, trigger_time=%s, interval_minutes=%s",
|
||||||
|
self.mode,
|
||||||
|
self.trigger_time,
|
||||||
|
self.interval_minutes,
|
||||||
|
)
|
||||||
|
|
||||||
|
return changed
|
||||||
|
|
||||||
|
async def _run_heartbeat_loop(self):
|
||||||
|
"""Run heartbeat checks on a separate interval during trading hours."""
|
||||||
|
while self.running:
|
||||||
|
now = self._now_nyse()
|
||||||
|
if self._is_trading_day(now) and self._is_trading_hours(now):
|
||||||
|
if self._heartbeat_callback:
|
||||||
|
try:
|
||||||
|
current_date = now.strftime("%Y-%m-%d")
|
||||||
|
logger.debug(
|
||||||
|
f"[Heartbeat] Triggering heartbeat check for {current_date}",
|
||||||
|
)
|
||||||
|
await self._heartbeat_callback(date=current_date)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[Heartbeat] Callback failed: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"[Heartbeat] Callback not set, skipping heartbeat",
|
||||||
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(self.heartbeat_interval)
|
||||||
|
|
||||||
async def _run_daily(self, callback: Callable):
|
async def _run_daily(self, callback: Callable):
|
||||||
"""Run once per trading day at specified time (NYSE timezone)"""
|
"""Run once per trading day at specified time (NYSE timezone)"""
|
||||||
first_run = True
|
first_run = True
|
||||||
@@ -154,6 +250,9 @@ class Scheduler:
|
|||||||
if self._task:
|
if self._task:
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
self._task = None
|
self._task = None
|
||||||
|
if self._heartbeat_task:
|
||||||
|
self._heartbeat_task.cancel()
|
||||||
|
self._heartbeat_task = None
|
||||||
logger.info("Scheduler stopped")
|
logger.info("Scheduler stopped")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,10 @@ class StateSync:
|
|||||||
"""Set current simulation date for backtest-compatible timestamps"""
|
"""Set current simulation date for backtest-compatible timestamps"""
|
||||||
self._simulation_date = date
|
self._simulation_date = date
|
||||||
|
|
||||||
|
def clear_simulation_date(self):
|
||||||
|
"""Disable backtest timestamp simulation and use wall-clock time."""
|
||||||
|
self._simulation_date = None
|
||||||
|
|
||||||
def _get_timestamp_ms(self) -> int:
|
def _get_timestamp_ms(self) -> int:
|
||||||
"""
|
"""
|
||||||
Get timestamp in milliseconds.
|
Get timestamp in milliseconds.
|
||||||
@@ -97,12 +101,24 @@ class StateSync:
|
|||||||
if not self._enabled:
|
if not self._enabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Ensure timestamp exists (use simulation date if in backtest mode)
|
# Ensure timestamp exists. Prefer explicit millisecond timestamps so
|
||||||
|
# frontend displays local wall time correctly instead of date-only UTC.
|
||||||
if "timestamp" not in event:
|
if "timestamp" not in event:
|
||||||
|
ts_ms = event.get("ts")
|
||||||
|
if ts_ms is not None:
|
||||||
|
try:
|
||||||
|
event["timestamp"] = datetime.fromtimestamp(
|
||||||
|
float(ts_ms) / 1000.0,
|
||||||
|
).isoformat()
|
||||||
|
except (TypeError, ValueError, OSError):
|
||||||
if self._simulation_date:
|
if self._simulation_date:
|
||||||
event["timestamp"] = f"{self._simulation_date}"
|
event["timestamp"] = f"{self._simulation_date}"
|
||||||
else:
|
else:
|
||||||
event["timestamp"] = datetime.now().isoformat()
|
event["timestamp"] = datetime.now().isoformat()
|
||||||
|
elif self._simulation_date:
|
||||||
|
event["timestamp"] = f"{self._simulation_date}"
|
||||||
|
else:
|
||||||
|
event["timestamp"] = datetime.now().isoformat()
|
||||||
|
|
||||||
# Persist to feed_history
|
# Persist to feed_history
|
||||||
if persist:
|
if persist:
|
||||||
@@ -238,9 +254,12 @@ class StateSync:
|
|||||||
"""Called at start of trading cycle"""
|
"""Called at start of trading cycle"""
|
||||||
self._state["current_date"] = date
|
self._state["current_date"] = date
|
||||||
self._state["status"] = "running"
|
self._state["status"] = "running"
|
||||||
|
if self._state.get("server_mode") == "backtest":
|
||||||
self.set_simulation_date(
|
self.set_simulation_date(
|
||||||
date,
|
date,
|
||||||
) # Set for backtest-compatible timestamps
|
) # Set for backtest-compatible timestamps
|
||||||
|
else:
|
||||||
|
self.clear_simulation_date()
|
||||||
|
|
||||||
await self.emit(
|
await self.emit(
|
||||||
{
|
{
|
||||||
@@ -411,7 +430,9 @@ class StateSync:
|
|||||||
|
|
||||||
Useful for: frontend reconnection or restoring from saved state
|
Useful for: frontend reconnection or restoring from saved state
|
||||||
"""
|
"""
|
||||||
feed_history = self._state.get("feed_history", [])
|
feed_history = self.storage.runtime_db.get_recent_feed_events(
|
||||||
|
limit=self.storage.max_feed_history,
|
||||||
|
) or self._state.get("feed_history", [])
|
||||||
|
|
||||||
# feed_history is newest-first, need to reverse for chronological replay # noqa: E501
|
# feed_history is newest-first, need to reverse for chronological replay # noqa: E501
|
||||||
for event in reversed(feed_history):
|
for event in reversed(feed_history):
|
||||||
@@ -434,11 +455,22 @@ class StateSync:
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary suitable for sending to frontend
|
Dictionary suitable for sending to frontend
|
||||||
"""
|
"""
|
||||||
|
feed_history = self.storage.runtime_db.get_recent_feed_events(
|
||||||
|
limit=self.storage.max_feed_history,
|
||||||
|
) or self._state.get("feed_history", [])
|
||||||
|
last_day_history = self.storage.runtime_db.get_last_day_feed_events(
|
||||||
|
current_date=self._state.get("current_date"),
|
||||||
|
limit=self.storage.max_feed_history,
|
||||||
|
) or self._state.get("last_day_history", [])
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"server_mode": self._state.get("server_mode", "live"),
|
"server_mode": self._state.get("server_mode", "live"),
|
||||||
"is_mock_mode": self._state.get("is_mock_mode", False),
|
"is_mock_mode": self._state.get("is_mock_mode", False),
|
||||||
"is_backtest": self._state.get("is_backtest", False),
|
"is_backtest": self._state.get("is_backtest", False),
|
||||||
"feed_history": self._state.get("feed_history", []),
|
"tickers": self._state.get("tickers"),
|
||||||
|
"runtime_config": self._state.get("runtime_config"),
|
||||||
|
"feed_history": feed_history,
|
||||||
|
"last_day_history": last_day_history,
|
||||||
"current_date": self._state.get("current_date"),
|
"current_date": self._state.get("current_date"),
|
||||||
"trading_days_total": self._state.get("trading_days_total", 0),
|
"trading_days_total": self._state.get("trading_days_total", 0),
|
||||||
"trading_days_completed": self._state.get(
|
"trading_days_completed": self._state.get(
|
||||||
@@ -452,6 +484,7 @@ class StateSync:
|
|||||||
"portfolio": self._state.get("portfolio", {}),
|
"portfolio": self._state.get("portfolio", {}),
|
||||||
"realtime_prices": self._state.get("realtime_prices", {}),
|
"realtime_prices": self._state.get("realtime_prices", {}),
|
||||||
"data_sources": self._state.get("data_sources", {}),
|
"data_sources": self._state.get("data_sources", {}),
|
||||||
|
"price_history": self._state.get("price_history", {}),
|
||||||
}
|
}
|
||||||
|
|
||||||
if include_dashboard:
|
if include_dashboard:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from datetime import datetime
|
|||||||
from typing import Callable, Dict, List, Optional
|
from typing import Callable, Dict, List, Optional
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
from backend.data.provider_utils import normalize_symbol
|
from backend.data.provider_utils import normalize_symbol
|
||||||
from backend.data.provider_router import get_provider_router
|
from backend.data.provider_router import get_provider_router
|
||||||
|
|
||||||
@@ -26,6 +27,7 @@ class HistoricalPriceManager:
|
|||||||
self.close_prices = {}
|
self.close_prices = {}
|
||||||
self.running = False
|
self.running = False
|
||||||
self._router = get_provider_router()
|
self._router = get_provider_router()
|
||||||
|
self._market_store = MarketStore()
|
||||||
|
|
||||||
def subscribe(
|
def subscribe(
|
||||||
self,
|
self,
|
||||||
@@ -58,21 +60,48 @@ class HistoricalPriceManager:
|
|||||||
logger.warning(f"Failed to load CSV for {symbol}: {e}")
|
logger.warning(f"Failed to load CSV for {symbol}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _load_from_market_db(
|
||||||
|
self,
|
||||||
|
symbol: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> Optional[pd.DataFrame]:
|
||||||
|
"""Load price data from the long-lived market research database."""
|
||||||
|
try:
|
||||||
|
rows = self._market_store.get_ohlc(symbol, start_date, end_date)
|
||||||
|
if not rows:
|
||||||
|
return None
|
||||||
|
df = pd.DataFrame(rows)
|
||||||
|
if df.empty or "date" not in df.columns:
|
||||||
|
return None
|
||||||
|
df["Date"] = pd.to_datetime(df["date"])
|
||||||
|
df.set_index("Date", inplace=True)
|
||||||
|
df.sort_index(inplace=True)
|
||||||
|
return df
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to load market DB data for {symbol}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def preload_data(self, start_date: str, end_date: str):
|
def preload_data(self, start_date: str, end_date: str):
|
||||||
"""Preload historical data from local CSV files."""
|
"""Preload historical data from market DB first, then local CSV."""
|
||||||
logger.info(f"Preloading data: {start_date} to {end_date}")
|
logger.info(f"Preloading data: {start_date} to {end_date}")
|
||||||
|
|
||||||
for symbol in self.subscribed_symbols:
|
for symbol in self.subscribed_symbols:
|
||||||
if symbol in self._price_cache:
|
if symbol in self._price_cache:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Load from local CSV file directly
|
df = self._load_from_market_db(symbol, start_date, end_date)
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
self._price_cache[symbol] = df
|
||||||
|
logger.info(f"Loaded {symbol} from market DB: {len(df)} records")
|
||||||
|
continue
|
||||||
|
|
||||||
df = self._load_from_csv(symbol)
|
df = self._load_from_csv(symbol)
|
||||||
if df is not None and not df.empty:
|
if df is not None and not df.empty:
|
||||||
self._price_cache[symbol] = df
|
self._price_cache[symbol] = df
|
||||||
logger.info(f"Loaded {symbol} from CSV: {len(df)} records")
|
logger.info(f"Loaded {symbol} from CSV: {len(df)} records")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"No CSV data for {symbol}")
|
logger.warning(f"No market DB or CSV data for {symbol}")
|
||||||
|
|
||||||
def set_date(self, date: str):
|
def set_date(self, date: str):
|
||||||
"""Set current trading date and update prices"""
|
"""Set current trading date and update prices"""
|
||||||
|
|||||||
149
backend/data/market_ingest.py
Normal file
149
backend/data/market_ingest.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Ingest Polygon market data into the long-lived research warehouse."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
from backend.data.news_alignment import align_news_for_symbol
|
||||||
|
from backend.data.polygon_client import (
|
||||||
|
fetch_news,
|
||||||
|
fetch_ohlc,
|
||||||
|
fetch_ticker_details,
|
||||||
|
)
|
||||||
|
from backend.data.provider_utils import normalize_symbol
|
||||||
|
|
||||||
|
|
||||||
|
def _today_utc() -> str:
|
||||||
|
return datetime.now(timezone.utc).date().isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def _default_start(years: int = 2) -> str:
|
||||||
|
return (datetime.now(timezone.utc).date() - timedelta(days=years * 366)).isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_ticker_history(
|
||||||
|
symbol: str,
|
||||||
|
*,
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
store: MarketStore | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Fetch and persist Polygon OHLC + news for a ticker."""
|
||||||
|
ticker = normalize_symbol(symbol)
|
||||||
|
start = start_date or _default_start()
|
||||||
|
end = end_date or _today_utc()
|
||||||
|
market_store = store or MarketStore()
|
||||||
|
|
||||||
|
details = fetch_ticker_details(ticker)
|
||||||
|
market_store.upsert_ticker(
|
||||||
|
symbol=ticker,
|
||||||
|
name=details.get("name"),
|
||||||
|
sector=details.get("sic_description"),
|
||||||
|
is_active=bool(details.get("active", True)),
|
||||||
|
)
|
||||||
|
|
||||||
|
ohlc_rows = fetch_ohlc(ticker, start, end)
|
||||||
|
news_rows = fetch_news(ticker, start, end)
|
||||||
|
price_count = market_store.upsert_ohlc(ticker, ohlc_rows, source="polygon")
|
||||||
|
news_count = market_store.upsert_news(ticker, news_rows, source="polygon")
|
||||||
|
aligned_count = align_news_for_symbol(market_store, ticker)
|
||||||
|
market_store.update_fetch_watermark(symbol=ticker, price_date=end, news_date=end)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"symbol": ticker,
|
||||||
|
"start_date": start,
|
||||||
|
"end_date": end,
|
||||||
|
"prices": price_count,
|
||||||
|
"news": news_count,
|
||||||
|
"aligned": aligned_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def update_ticker_incremental(
|
||||||
|
symbol: str,
|
||||||
|
*,
|
||||||
|
end_date: str | None = None,
|
||||||
|
store: MarketStore | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Incrementally fetch OHLC + news since the last watermark."""
|
||||||
|
ticker = normalize_symbol(symbol)
|
||||||
|
market_store = store or MarketStore()
|
||||||
|
watermarks = market_store.get_ticker_watermarks(ticker)
|
||||||
|
end = end_date or _today_utc()
|
||||||
|
start_prices = (
|
||||||
|
(datetime.fromisoformat(watermarks["last_price_fetch"]) + timedelta(days=1)).date().isoformat()
|
||||||
|
if watermarks.get("last_price_fetch")
|
||||||
|
else _default_start()
|
||||||
|
)
|
||||||
|
start_news = (
|
||||||
|
(datetime.fromisoformat(watermarks["last_news_fetch"]) + timedelta(days=1)).date().isoformat()
|
||||||
|
if watermarks.get("last_news_fetch")
|
||||||
|
else _default_start()
|
||||||
|
)
|
||||||
|
|
||||||
|
details = fetch_ticker_details(ticker)
|
||||||
|
market_store.upsert_ticker(
|
||||||
|
symbol=ticker,
|
||||||
|
name=details.get("name"),
|
||||||
|
sector=details.get("sic_description"),
|
||||||
|
is_active=bool(details.get("active", True)),
|
||||||
|
)
|
||||||
|
|
||||||
|
ohlc_rows = [] if start_prices > end else fetch_ohlc(ticker, start_prices, end)
|
||||||
|
news_rows = [] if start_news > end else fetch_news(ticker, start_news, end)
|
||||||
|
price_count = market_store.upsert_ohlc(ticker, ohlc_rows, source="polygon") if ohlc_rows else 0
|
||||||
|
news_count = market_store.upsert_news(ticker, news_rows, source="polygon") if news_rows else 0
|
||||||
|
aligned_count = align_news_for_symbol(market_store, ticker)
|
||||||
|
market_store.update_fetch_watermark(
|
||||||
|
symbol=ticker,
|
||||||
|
price_date=end if ohlc_rows or watermarks.get("last_price_fetch") else None,
|
||||||
|
news_date=end if news_rows or watermarks.get("last_news_fetch") else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"symbol": ticker,
|
||||||
|
"start_price_date": start_prices,
|
||||||
|
"start_news_date": start_news,
|
||||||
|
"end_date": end,
|
||||||
|
"prices": price_count,
|
||||||
|
"news": news_count,
|
||||||
|
"aligned": aligned_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_symbols(
|
||||||
|
symbols: Iterable[str],
|
||||||
|
*,
|
||||||
|
mode: str = "incremental",
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
store: MarketStore | None = None,
|
||||||
|
) -> list[dict]:
|
||||||
|
"""Fetch Polygon data for a list of tickers."""
|
||||||
|
market_store = store or MarketStore()
|
||||||
|
results = []
|
||||||
|
for symbol in symbols:
|
||||||
|
ticker = normalize_symbol(symbol)
|
||||||
|
if not ticker:
|
||||||
|
continue
|
||||||
|
if mode == "full":
|
||||||
|
results.append(
|
||||||
|
ingest_ticker_history(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
store=market_store,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
results.append(
|
||||||
|
update_ticker_incremental(
|
||||||
|
ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
store=market_store,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
1074
backend/data/market_store.py
Normal file
1074
backend/data/market_store.py
Normal file
File diff suppressed because it is too large
Load Diff
64
backend/data/news_alignment.py
Normal file
64
backend/data/news_alignment.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Align persisted news to the nearest NYSE trading date."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import time
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import pandas_market_calendars as mcal
|
||||||
|
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
|
||||||
|
|
||||||
|
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||||
|
|
||||||
|
|
||||||
|
def _next_trading_day(date_str: str) -> str:
|
||||||
|
start = pd.Timestamp(date_str).tz_localize(None)
|
||||||
|
sessions = NYSE_CALENDAR.valid_days(
|
||||||
|
start_date=(start - pd.Timedelta(days=1)).strftime("%Y-%m-%d"),
|
||||||
|
end_date=(start + pd.Timedelta(days=10)).strftime("%Y-%m-%d"),
|
||||||
|
)
|
||||||
|
future = [
|
||||||
|
pd.Timestamp(day).tz_localize(None).strftime("%Y-%m-%d")
|
||||||
|
for day in sessions
|
||||||
|
if pd.Timestamp(day).tz_localize(None) >= start
|
||||||
|
]
|
||||||
|
return future[0] if future else date_str
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_trade_date(published_utc: str | None) -> str | None:
|
||||||
|
"""Map a published timestamp to an NYSE trade date."""
|
||||||
|
if not published_utc:
|
||||||
|
return None
|
||||||
|
timestamp = pd.to_datetime(published_utc, utc=True, errors="coerce")
|
||||||
|
if pd.isna(timestamp):
|
||||||
|
return None
|
||||||
|
nyse_time = timestamp.tz_convert("America/New_York")
|
||||||
|
candidate = nyse_time.date().isoformat()
|
||||||
|
valid_days = NYSE_CALENDAR.valid_days(start_date=candidate, end_date=candidate)
|
||||||
|
if len(valid_days) == 0:
|
||||||
|
return _next_trading_day(candidate)
|
||||||
|
if nyse_time.time() >= time(16, 0):
|
||||||
|
return _next_trading_day((nyse_time + pd.Timedelta(days=1)).date().isoformat())
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
|
||||||
|
def align_news_for_symbol(store: MarketStore, symbol: str, *, limit: int = 5000) -> int:
|
||||||
|
"""Fill missing trade_date values for one ticker."""
|
||||||
|
pending = store.get_news_without_trade_date(symbol, limit=limit)
|
||||||
|
updates = []
|
||||||
|
for row in pending:
|
||||||
|
trade_date = resolve_trade_date(row.get("published_utc"))
|
||||||
|
if trade_date:
|
||||||
|
updates.append(
|
||||||
|
{
|
||||||
|
"news_id": row["news_id"],
|
||||||
|
"symbol": row["symbol"],
|
||||||
|
"trade_date": trade_date,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if not updates:
|
||||||
|
return 0
|
||||||
|
return store.set_trade_dates(updates)
|
||||||
161
backend/data/polygon_client.py
Normal file
161
backend/data/polygon_client.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Polygon client used for long-lived market research ingestion."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
BASE = "https://api.polygon.io"
|
||||||
|
|
||||||
|
|
||||||
|
def _headers() -> dict[str, str]:
|
||||||
|
api_key = os.getenv("POLYGON_API_KEY", "").strip()
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("Missing required API key: POLYGON_API_KEY")
|
||||||
|
return {"Authorization": f"Bearer {api_key}"}
|
||||||
|
|
||||||
|
|
||||||
|
def http_get(
|
||||||
|
url: str,
|
||||||
|
params: Optional[dict[str, Any]] = None,
|
||||||
|
*,
|
||||||
|
max_retries: int = 8,
|
||||||
|
backoff: float = 2.0,
|
||||||
|
) -> requests.Response:
|
||||||
|
"""HTTP GET with exponential backoff and 429 handling."""
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
url,
|
||||||
|
params=params or {},
|
||||||
|
headers=_headers(),
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
except requests.RequestException:
|
||||||
|
time.sleep((backoff**attempt) + 0.5)
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
raise
|
||||||
|
continue
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
retry_after = response.headers.get("Retry-After")
|
||||||
|
wait = (
|
||||||
|
float(retry_after)
|
||||||
|
if retry_after and retry_after.isdigit()
|
||||||
|
else min((backoff**attempt) + 1.0, 60.0)
|
||||||
|
)
|
||||||
|
time.sleep(wait)
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
response.raise_for_status()
|
||||||
|
continue
|
||||||
|
|
||||||
|
if 500 <= response.status_code < 600:
|
||||||
|
time.sleep(min((backoff**attempt) + 1.0, 60.0))
|
||||||
|
if attempt == max_retries - 1:
|
||||||
|
response.raise_for_status()
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
return response
|
||||||
|
raise RuntimeError("Unreachable")
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_ticker_details(symbol: str) -> dict[str, Any]:
|
||||||
|
"""Fetch company metadata from Polygon."""
|
||||||
|
response = http_get(f"{BASE}/v3/reference/tickers/{symbol}")
|
||||||
|
return response.json().get("results", {}) or {}
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_ohlc(symbol: str, start_date: str, end_date: str) -> list[dict[str, Any]]:
|
||||||
|
"""Fetch daily OHLC data from Polygon."""
|
||||||
|
response = http_get(
|
||||||
|
f"{BASE}/v2/aggs/ticker/{symbol}/range/1/day/{start_date}/{end_date}",
|
||||||
|
params={"adjusted": "true", "sort": "asc", "limit": 50000},
|
||||||
|
)
|
||||||
|
results = response.json().get("results") or []
|
||||||
|
rows: list[dict[str, Any]] = []
|
||||||
|
for item in results:
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"date": datetime.fromtimestamp(
|
||||||
|
int(item["t"]) / 1000,
|
||||||
|
tz=timezone.utc,
|
||||||
|
).date().isoformat(),
|
||||||
|
"open": item.get("o"),
|
||||||
|
"high": item.get("h"),
|
||||||
|
"low": item.get("l"),
|
||||||
|
"close": item.get("c"),
|
||||||
|
"volume": item.get("v"),
|
||||||
|
"vwap": item.get("vw"),
|
||||||
|
"transactions": item.get("n"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_news(
|
||||||
|
symbol: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
*,
|
||||||
|
per_page: int = 50,
|
||||||
|
page_sleep: float = 1.2,
|
||||||
|
max_pages: Optional[int] = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Fetch all Polygon news for a ticker, with pagination."""
|
||||||
|
url = f"{BASE}/v2/reference/news"
|
||||||
|
params = {
|
||||||
|
"ticker": symbol,
|
||||||
|
"published_utc.gte": start_date,
|
||||||
|
"published_utc.lte": end_date,
|
||||||
|
"limit": per_page,
|
||||||
|
"order": "asc",
|
||||||
|
}
|
||||||
|
next_url: Optional[str] = None
|
||||||
|
pages = 0
|
||||||
|
all_articles: list[dict[str, Any]] = []
|
||||||
|
seen_ids: set[str] = set()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
response = http_get(next_url or url, params=None if next_url else params)
|
||||||
|
data = response.json()
|
||||||
|
results = data.get("results") or []
|
||||||
|
if not results:
|
||||||
|
break
|
||||||
|
|
||||||
|
for item in results:
|
||||||
|
article_id = item.get("id")
|
||||||
|
if article_id and article_id in seen_ids:
|
||||||
|
continue
|
||||||
|
all_articles.append(
|
||||||
|
{
|
||||||
|
"id": article_id,
|
||||||
|
"publisher": (item.get("publisher") or {}).get("name"),
|
||||||
|
"title": item.get("title"),
|
||||||
|
"author": item.get("author"),
|
||||||
|
"published_utc": item.get("published_utc"),
|
||||||
|
"amp_url": item.get("amp_url"),
|
||||||
|
"article_url": item.get("article_url"),
|
||||||
|
"tickers": item.get("tickers"),
|
||||||
|
"description": item.get("description"),
|
||||||
|
"insights": item.get("insights"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
if article_id:
|
||||||
|
seen_ids.add(article_id)
|
||||||
|
|
||||||
|
next_url = data.get("next_url")
|
||||||
|
pages += 1
|
||||||
|
if max_pages is not None and pages >= max_pages:
|
||||||
|
break
|
||||||
|
if not next_url:
|
||||||
|
break
|
||||||
|
time.sleep(page_sleep)
|
||||||
|
|
||||||
|
return all_articles
|
||||||
@@ -11,7 +11,7 @@ import pandas as pd
|
|||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
|
|
||||||
from backend.config.data_config import DataSource, get_data_sources
|
from backend.config.data_config import DataSource, get_data_sources
|
||||||
from backend.data.schema import (
|
from shared.schema import (
|
||||||
CompanyFactsResponse,
|
CompanyFactsResponse,
|
||||||
CompanyNews,
|
CompanyNews,
|
||||||
CompanyNewsResponse,
|
CompanyNewsResponse,
|
||||||
@@ -30,6 +30,25 @@ logger = logging.getLogger(__name__)
|
|||||||
_DATA_DIR = Path(__file__).parent / "ret_data"
|
_DATA_DIR = Path(__file__).parent / "ret_data"
|
||||||
|
|
||||||
|
|
||||||
|
def _format_provider_error(exc: Exception) -> str:
|
||||||
|
"""Condense common provider failures into short, readable messages."""
|
||||||
|
message = str(exc).strip().replace("\n", " ")
|
||||||
|
if "429" in message:
|
||||||
|
return "rate limit reached"
|
||||||
|
if "402" in message:
|
||||||
|
return "insufficient credits"
|
||||||
|
if "422" in message or "Missing parameters" in message:
|
||||||
|
return "invalid request parameters"
|
||||||
|
if "Quote not found" in message:
|
||||||
|
return "quote not found"
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def _has_valid_ticker(ticker: str) -> bool:
|
||||||
|
"""Return whether the normalized ticker is non-empty."""
|
||||||
|
return bool((ticker or "").strip())
|
||||||
|
|
||||||
|
|
||||||
class DataProviderRouter:
|
class DataProviderRouter:
|
||||||
"""Route data requests across configured providers with fallbacks."""
|
"""Route data requests across configured providers with fallbacks."""
|
||||||
|
|
||||||
@@ -56,6 +75,8 @@ class DataProviderRouter:
|
|||||||
end_date: str,
|
end_date: str,
|
||||||
) -> tuple[list[Price], DataSource]:
|
) -> tuple[list[Price], DataSource]:
|
||||||
"""Fetch prices using preferred providers with fallback."""
|
"""Fetch prices using preferred providers with fallback."""
|
||||||
|
if not _has_valid_ticker(ticker):
|
||||||
|
return [], "local_csv"
|
||||||
last_error: Optional[Exception] = None
|
last_error: Optional[Exception] = None
|
||||||
|
|
||||||
for source in self.price_sources():
|
for source in self.price_sources():
|
||||||
@@ -78,7 +99,12 @@ class DataProviderRouter:
|
|||||||
return prices, source
|
return prices, source
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
last_error = exc
|
last_error = exc
|
||||||
logger.warning("Price source %s failed for %s: %s", source, ticker, exc)
|
logger.warning(
|
||||||
|
"Price source %s failed for %s: %s",
|
||||||
|
source,
|
||||||
|
ticker,
|
||||||
|
_format_provider_error(exc),
|
||||||
|
)
|
||||||
|
|
||||||
if last_error:
|
if last_error:
|
||||||
raise last_error
|
raise last_error
|
||||||
@@ -92,6 +118,8 @@ class DataProviderRouter:
|
|||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> tuple[list[FinancialMetrics], DataSource]:
|
) -> tuple[list[FinancialMetrics], DataSource]:
|
||||||
"""Fetch financial metrics with API provider fallback."""
|
"""Fetch financial metrics with API provider fallback."""
|
||||||
|
if not _has_valid_ticker(ticker):
|
||||||
|
return [], "local_csv"
|
||||||
last_error: Optional[Exception] = None
|
last_error: Optional[Exception] = None
|
||||||
|
|
||||||
for source in self.api_sources():
|
for source in self.api_sources():
|
||||||
@@ -126,7 +154,7 @@ class DataProviderRouter:
|
|||||||
"Financial metrics source %s failed for %s: %s",
|
"Financial metrics source %s failed for %s: %s",
|
||||||
source,
|
source,
|
||||||
ticker,
|
ticker,
|
||||||
exc,
|
_format_provider_error(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
if last_error:
|
if last_error:
|
||||||
@@ -142,6 +170,8 @@ class DataProviderRouter:
|
|||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> list[LineItem]:
|
) -> list[LineItem]:
|
||||||
"""Line items are only supported via Financial Datasets."""
|
"""Line items are only supported via Financial Datasets."""
|
||||||
|
if not _has_valid_ticker(ticker):
|
||||||
|
return []
|
||||||
if "financial_datasets" not in self.api_sources():
|
if "financial_datasets" not in self.api_sources():
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
@@ -155,7 +185,11 @@ class DataProviderRouter:
|
|||||||
self._record_success("line_items", "financial_datasets")
|
self._record_success("line_items", "financial_datasets")
|
||||||
return results
|
return results
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("Line items source failed for %s: %s", ticker, exc)
|
logger.warning(
|
||||||
|
"Line items source failed for %s: %s",
|
||||||
|
ticker,
|
||||||
|
_format_provider_error(exc),
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_insider_trades(
|
def get_insider_trades(
|
||||||
@@ -166,6 +200,8 @@ class DataProviderRouter:
|
|||||||
limit: int = 1000,
|
limit: int = 1000,
|
||||||
) -> tuple[list[InsiderTrade], DataSource]:
|
) -> tuple[list[InsiderTrade], DataSource]:
|
||||||
"""Fetch insider trades with provider fallback."""
|
"""Fetch insider trades with provider fallback."""
|
||||||
|
if not _has_valid_ticker(ticker):
|
||||||
|
return [], "local_csv"
|
||||||
last_error: Optional[Exception] = None
|
last_error: Optional[Exception] = None
|
||||||
|
|
||||||
for source in self.api_sources():
|
for source in self.api_sources():
|
||||||
@@ -193,7 +229,7 @@ class DataProviderRouter:
|
|||||||
"Insider trades source %s failed for %s: %s",
|
"Insider trades source %s failed for %s: %s",
|
||||||
source,
|
source,
|
||||||
ticker,
|
ticker,
|
||||||
exc,
|
_format_provider_error(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
if last_error:
|
if last_error:
|
||||||
@@ -208,6 +244,8 @@ class DataProviderRouter:
|
|||||||
limit: int = 1000,
|
limit: int = 1000,
|
||||||
) -> tuple[list[CompanyNews], DataSource]:
|
) -> tuple[list[CompanyNews], DataSource]:
|
||||||
"""Fetch company news with provider fallback."""
|
"""Fetch company news with provider fallback."""
|
||||||
|
if not _has_valid_ticker(ticker):
|
||||||
|
return [], "local_csv"
|
||||||
last_error: Optional[Exception] = None
|
last_error: Optional[Exception] = None
|
||||||
|
|
||||||
for source in self.api_sources():
|
for source in self.api_sources():
|
||||||
@@ -244,7 +282,7 @@ class DataProviderRouter:
|
|||||||
"Company news source %s failed for %s: %s",
|
"Company news source %s failed for %s: %s",
|
||||||
source,
|
source,
|
||||||
ticker,
|
ticker,
|
||||||
exc,
|
_format_provider_error(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
if last_error:
|
if last_error:
|
||||||
@@ -258,6 +296,8 @@ class DataProviderRouter:
|
|||||||
metrics_lookup,
|
metrics_lookup,
|
||||||
) -> tuple[Optional[float], DataSource]:
|
) -> tuple[Optional[float], DataSource]:
|
||||||
"""Fetch market cap using facts API or financial metrics fallback."""
|
"""Fetch market cap using facts API or financial metrics fallback."""
|
||||||
|
if not _has_valid_ticker(ticker):
|
||||||
|
return None, "local_csv"
|
||||||
today = datetime.datetime.now().strftime("%Y-%m-%d")
|
today = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||||
if end_date == today and "financial_datasets" in self.api_sources():
|
if end_date == today and "financial_datasets" in self.api_sources():
|
||||||
try:
|
try:
|
||||||
@@ -267,7 +307,7 @@ class DataProviderRouter:
|
|||||||
logger.warning(
|
logger.warning(
|
||||||
"Market cap facts source failed for %s: %s",
|
"Market cap facts source failed for %s: %s",
|
||||||
ticker,
|
ticker,
|
||||||
exc,
|
_format_provider_error(exc),
|
||||||
)
|
)
|
||||||
|
|
||||||
metrics, source = metrics_lookup(ticker, end_date)
|
metrics, source = metrics_lookup(ticker, end_date)
|
||||||
|
|||||||
@@ -1,184 +1,50 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
from pydantic import BaseModel
|
"""Compatibility schema bridge.
|
||||||
|
|
||||||
|
This module preserves the legacy ``backend.data.schema`` import path while
|
||||||
|
delegating the actual schema definitions to ``shared.schema``. Keeping one
|
||||||
|
canonical DTO set avoids drift as the monolith is split into service-specific
|
||||||
|
packages.
|
||||||
|
"""
|
||||||
|
|
||||||
class Price(BaseModel):
|
from shared.schema import (
|
||||||
open: float
|
AgentStateData,
|
||||||
close: float
|
AgentStateMetadata,
|
||||||
high: float
|
AnalystSignal,
|
||||||
low: float
|
CompanyFacts,
|
||||||
volume: int
|
CompanyFactsResponse,
|
||||||
time: str
|
CompanyNews,
|
||||||
|
CompanyNewsResponse,
|
||||||
|
FinancialMetrics,
|
||||||
|
FinancialMetricsResponse,
|
||||||
|
InsiderTrade,
|
||||||
|
InsiderTradeResponse,
|
||||||
|
LineItem,
|
||||||
|
LineItemResponse,
|
||||||
|
Portfolio,
|
||||||
|
Position,
|
||||||
|
Price,
|
||||||
|
PriceResponse,
|
||||||
|
TickerAnalysis,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
class PriceResponse(BaseModel):
|
"Price",
|
||||||
ticker: str
|
"PriceResponse",
|
||||||
prices: list[Price]
|
"FinancialMetrics",
|
||||||
|
"FinancialMetricsResponse",
|
||||||
|
"LineItem",
|
||||||
class FinancialMetrics(BaseModel):
|
"LineItemResponse",
|
||||||
ticker: str
|
"InsiderTrade",
|
||||||
report_period: str
|
"InsiderTradeResponse",
|
||||||
period: str
|
"CompanyNews",
|
||||||
currency: str
|
"CompanyNewsResponse",
|
||||||
market_cap: float | None
|
"CompanyFacts",
|
||||||
enterprise_value: float | None
|
"CompanyFactsResponse",
|
||||||
price_to_earnings_ratio: float | None
|
"Position",
|
||||||
price_to_book_ratio: float | None
|
"Portfolio",
|
||||||
price_to_sales_ratio: float | None
|
"AnalystSignal",
|
||||||
enterprise_value_to_ebitda_ratio: float | None
|
"TickerAnalysis",
|
||||||
enterprise_value_to_revenue_ratio: float | None
|
"AgentStateData",
|
||||||
free_cash_flow_yield: float | None
|
"AgentStateMetadata",
|
||||||
peg_ratio: float | None
|
]
|
||||||
gross_margin: float | None
|
|
||||||
operating_margin: float | None
|
|
||||||
net_margin: float | None
|
|
||||||
return_on_equity: float | None
|
|
||||||
return_on_assets: float | None
|
|
||||||
return_on_invested_capital: float | None
|
|
||||||
asset_turnover: float | None
|
|
||||||
inventory_turnover: float | None
|
|
||||||
receivables_turnover: float | None
|
|
||||||
days_sales_outstanding: float | None
|
|
||||||
operating_cycle: float | None
|
|
||||||
working_capital_turnover: float | None
|
|
||||||
current_ratio: float | None
|
|
||||||
quick_ratio: float | None
|
|
||||||
cash_ratio: float | None
|
|
||||||
operating_cash_flow_ratio: float | None
|
|
||||||
debt_to_equity: float | None
|
|
||||||
debt_to_assets: float | None
|
|
||||||
interest_coverage: float | None
|
|
||||||
revenue_growth: float | None
|
|
||||||
earnings_growth: float | None
|
|
||||||
book_value_growth: float | None
|
|
||||||
earnings_per_share_growth: float | None
|
|
||||||
free_cash_flow_growth: float | None
|
|
||||||
operating_income_growth: float | None
|
|
||||||
ebitda_growth: float | None
|
|
||||||
payout_ratio: float | None
|
|
||||||
earnings_per_share: float | None
|
|
||||||
book_value_per_share: float | None
|
|
||||||
free_cash_flow_per_share: float | None
|
|
||||||
|
|
||||||
|
|
||||||
class FinancialMetricsResponse(BaseModel):
|
|
||||||
financial_metrics: list[FinancialMetrics]
|
|
||||||
|
|
||||||
|
|
||||||
class LineItem(BaseModel):
|
|
||||||
ticker: str
|
|
||||||
report_period: str
|
|
||||||
period: str
|
|
||||||
currency: str
|
|
||||||
|
|
||||||
# Allow additional fields dynamically
|
|
||||||
model_config = {"extra": "allow"}
|
|
||||||
|
|
||||||
|
|
||||||
class LineItemResponse(BaseModel):
|
|
||||||
search_results: list[LineItem]
|
|
||||||
|
|
||||||
|
|
||||||
class InsiderTrade(BaseModel):
|
|
||||||
ticker: str
|
|
||||||
issuer: str | None
|
|
||||||
name: str | None
|
|
||||||
title: str | None
|
|
||||||
is_board_director: bool | None
|
|
||||||
transaction_date: str | None
|
|
||||||
transaction_shares: float | None
|
|
||||||
transaction_price_per_share: float | None
|
|
||||||
transaction_value: float | None
|
|
||||||
shares_owned_before_transaction: float | None
|
|
||||||
shares_owned_after_transaction: float | None
|
|
||||||
security_title: str | None
|
|
||||||
filing_date: str
|
|
||||||
|
|
||||||
|
|
||||||
class InsiderTradeResponse(BaseModel):
|
|
||||||
insider_trades: list[InsiderTrade]
|
|
||||||
|
|
||||||
|
|
||||||
class CompanyNews(BaseModel):
|
|
||||||
category: str | None = None
|
|
||||||
ticker: str
|
|
||||||
title: str
|
|
||||||
related: str | None = None
|
|
||||||
source: str
|
|
||||||
date: str | None = None
|
|
||||||
url: str
|
|
||||||
summary: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CompanyNewsResponse(BaseModel):
|
|
||||||
news: list[CompanyNews]
|
|
||||||
|
|
||||||
|
|
||||||
class CompanyFacts(BaseModel):
|
|
||||||
ticker: str
|
|
||||||
name: str
|
|
||||||
cik: str | None = None
|
|
||||||
industry: str | None = None
|
|
||||||
sector: str | None = None
|
|
||||||
category: str | None = None
|
|
||||||
exchange: str | None = None
|
|
||||||
is_active: bool | None = None
|
|
||||||
listing_date: str | None = None
|
|
||||||
location: str | None = None
|
|
||||||
market_cap: float | None = None
|
|
||||||
number_of_employees: int | None = None
|
|
||||||
sec_filings_url: str | None = None
|
|
||||||
sic_code: str | None = None
|
|
||||||
sic_industry: str | None = None
|
|
||||||
sic_sector: str | None = None
|
|
||||||
website_url: str | None = None
|
|
||||||
weighted_average_shares: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class CompanyFactsResponse(BaseModel):
|
|
||||||
company_facts: CompanyFacts
|
|
||||||
|
|
||||||
|
|
||||||
class Position(BaseModel):
|
|
||||||
"""Position information - for Portfolio mode"""
|
|
||||||
|
|
||||||
long: int = 0 # Long position quantity (shares)
|
|
||||||
short: int = 0 # Short position quantity (shares)
|
|
||||||
long_cost_basis: float = 0.0 # Long position average cost
|
|
||||||
short_cost_basis: float = 0.0 # Short position average cost
|
|
||||||
|
|
||||||
|
|
||||||
class Portfolio(BaseModel):
|
|
||||||
"""Portfolio - for Portfolio mode"""
|
|
||||||
|
|
||||||
cash: float = 100000.0 # Available cash
|
|
||||||
positions: dict[str, Position] = {} # ticker -> Position mapping
|
|
||||||
# Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin)
|
|
||||||
margin_requirement: float = 0.0
|
|
||||||
margin_used: float = 0.0 # Margin used
|
|
||||||
|
|
||||||
|
|
||||||
class AnalystSignal(BaseModel):
|
|
||||||
signal: str | None = None
|
|
||||||
confidence: float | None = None
|
|
||||||
reasoning: dict | str | None = None
|
|
||||||
max_position_size: float | None = None # For risk management signals
|
|
||||||
|
|
||||||
|
|
||||||
class TickerAnalysis(BaseModel):
|
|
||||||
ticker: str
|
|
||||||
analyst_signals: dict[str, AnalystSignal] # agent_name -> signal mapping
|
|
||||||
|
|
||||||
|
|
||||||
class AgentStateData(BaseModel):
|
|
||||||
tickers: list[str]
|
|
||||||
portfolio: Portfolio
|
|
||||||
start_date: str
|
|
||||||
end_date: str
|
|
||||||
ticker_analyses: dict[str, TickerAnalysis] # ticker -> analysis mapping
|
|
||||||
|
|
||||||
|
|
||||||
class AgentStateMetadata(BaseModel):
|
|
||||||
show_reasoning: bool = False
|
|
||||||
model_config = {"extra": "allow"}
|
|
||||||
|
|||||||
2
backend/domains/__init__.py
Normal file
2
backend/domains/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Domain modules for split service internals."""
|
||||||
277
backend/domains/news.py
Normal file
277
backend/domains/news.py
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""News/explain domain helpers shared by app surfaces and gateway fallbacks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
from backend.data.market_ingest import update_ticker_incremental
|
||||||
|
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||||
|
from backend.explain.range_explainer import build_range_explanation
|
||||||
|
from backend.explain.similarity_service import find_similar_days
|
||||||
|
from backend.explain.story_service import get_or_create_stock_story
|
||||||
|
|
||||||
|
|
||||||
|
def news_rows_need_enrichment(rows: list[dict[str, Any]]) -> bool:
|
||||||
|
"""Return whether news rows are missing explain-oriented analysis fields."""
|
||||||
|
if not rows:
|
||||||
|
return True
|
||||||
|
return all(
|
||||||
|
not row.get("sentiment")
|
||||||
|
and not row.get("relevance")
|
||||||
|
and not row.get("key_discussion")
|
||||||
|
for row in rows
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_news_fresh(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
target_date: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Refresh raw news incrementally when stored watermarks are stale."""
|
||||||
|
normalized_target = str(target_date or "").strip()[:10]
|
||||||
|
if not normalized_target:
|
||||||
|
return {
|
||||||
|
"ticker": ticker,
|
||||||
|
"target_date": None,
|
||||||
|
"last_news_fetch": None,
|
||||||
|
"refreshed": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
watermarks = store.get_ticker_watermarks(ticker)
|
||||||
|
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
|
||||||
|
refreshed = False
|
||||||
|
if not last_news_fetch or last_news_fetch < normalized_target:
|
||||||
|
update_ticker_incremental(
|
||||||
|
ticker,
|
||||||
|
end_date=normalized_target,
|
||||||
|
store=store,
|
||||||
|
)
|
||||||
|
refreshed = True
|
||||||
|
watermarks = store.get_ticker_watermarks(ticker)
|
||||||
|
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"ticker": ticker,
|
||||||
|
"target_date": normalized_target,
|
||||||
|
"last_news_fetch": last_news_fetch or None,
|
||||||
|
"refreshed": refreshed,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_enriched_news(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date)
|
||||||
|
rows = store.get_news_items_enriched(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
if news_rows_need_enrichment(rows):
|
||||||
|
enrich_news_for_symbol(
|
||||||
|
store,
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
rows = store.get_news_items_enriched(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return {"ticker": ticker, "news": rows, "freshness": freshness}
|
||||||
|
|
||||||
|
|
||||||
|
def get_news_for_date(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
freshness = ensure_news_fresh(store, ticker=ticker, target_date=date)
|
||||||
|
rows = store.get_news_items_enriched(
|
||||||
|
ticker,
|
||||||
|
trade_date=date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
if news_rows_need_enrichment(rows):
|
||||||
|
enrich_news_for_symbol(
|
||||||
|
store,
|
||||||
|
ticker,
|
||||||
|
start_date=date,
|
||||||
|
end_date=date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
rows = store.get_news_items_enriched(
|
||||||
|
ticker,
|
||||||
|
trade_date=date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return {"ticker": ticker, "date": date, "news": rows, "freshness": freshness}
|
||||||
|
|
||||||
|
|
||||||
|
def get_news_timeline(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date)
|
||||||
|
timeline = store.get_news_timeline_enriched(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
if not timeline:
|
||||||
|
enrich_news_for_symbol(
|
||||||
|
store,
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=200,
|
||||||
|
)
|
||||||
|
timeline = store.get_news_timeline_enriched(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"ticker": ticker,
|
||||||
|
"timeline": timeline,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"freshness": freshness,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_news_categories(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
limit: int = 200,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date)
|
||||||
|
rows = store.get_news_items_enriched(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
if news_rows_need_enrichment(rows):
|
||||||
|
enrich_news_for_symbol(
|
||||||
|
store,
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
categories = store.get_news_categories_enriched(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return {"ticker": ticker, "categories": categories, "freshness": freshness}
|
||||||
|
|
||||||
|
|
||||||
|
def get_similar_days_payload(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
date: str,
|
||||||
|
n_similar: int = 5,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
freshness = ensure_news_fresh(store, ticker=ticker, target_date=date)
|
||||||
|
result = find_similar_days(
|
||||||
|
store,
|
||||||
|
symbol=ticker,
|
||||||
|
target_date=date,
|
||||||
|
top_k=n_similar,
|
||||||
|
)
|
||||||
|
result["freshness"] = freshness
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_story_payload(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
as_of_date: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
freshness = ensure_news_fresh(store, ticker=ticker, target_date=as_of_date)
|
||||||
|
enrich_news_for_symbol(
|
||||||
|
store,
|
||||||
|
ticker,
|
||||||
|
end_date=as_of_date,
|
||||||
|
limit=80,
|
||||||
|
)
|
||||||
|
result = get_or_create_stock_story(
|
||||||
|
store,
|
||||||
|
symbol=ticker,
|
||||||
|
as_of_date=as_of_date,
|
||||||
|
)
|
||||||
|
result["freshness"] = freshness
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def get_range_explain_payload(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
article_ids: list[str] | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date)
|
||||||
|
news_rows = []
|
||||||
|
if article_ids:
|
||||||
|
news_rows = store.get_news_by_ids_enriched(ticker, article_ids)
|
||||||
|
if not news_rows:
|
||||||
|
news_rows = store.get_news_items_enriched(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
if news_rows_need_enrichment(news_rows):
|
||||||
|
enrich_news_for_symbol(
|
||||||
|
store,
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
news_rows = (
|
||||||
|
store.get_news_by_ids_enriched(ticker, article_ids)
|
||||||
|
if article_ids
|
||||||
|
else store.get_news_items_enriched(
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = build_range_explanation(
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
news_rows=news_rows,
|
||||||
|
)
|
||||||
|
return {"ticker": ticker, "result": result, "freshness": freshness}
|
||||||
106
backend/domains/trading.py
Normal file
106
backend/domains/trading.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Trading domain helpers shared by app surfaces and gateway fallbacks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.services.market import MarketService
|
||||||
|
from backend.tools.data_tools import (
|
||||||
|
get_company_news,
|
||||||
|
get_financial_metrics,
|
||||||
|
get_insider_trades,
|
||||||
|
get_market_cap,
|
||||||
|
get_prices,
|
||||||
|
search_line_items,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prices_payload(*, ticker: str, start_date: str, end_date: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"ticker": ticker,
|
||||||
|
"prices": get_prices(ticker, start_date, end_date),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_financials_payload(
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
end_date: str,
|
||||||
|
period: str = "ttm",
|
||||||
|
limit: int = 10,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"financial_metrics": get_financial_metrics(
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
period=period,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_news_payload(
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
end_date: str,
|
||||||
|
start_date: str | None = None,
|
||||||
|
limit: int = 1000,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"news": get_company_news(
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
start_date=start_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_insider_trades_payload(
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
end_date: str,
|
||||||
|
start_date: str | None = None,
|
||||||
|
limit: int = 1000,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"insider_trades": get_insider_trades(
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
start_date=start_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_market_status_payload() -> dict[str, Any]:
|
||||||
|
market_service = MarketService(tickers=[])
|
||||||
|
return market_service.get_market_status()
|
||||||
|
|
||||||
|
|
||||||
|
def get_market_cap_payload(*, ticker: str, end_date: str) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"ticker": ticker,
|
||||||
|
"end_date": end_date,
|
||||||
|
"market_cap": get_market_cap(ticker, end_date),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_line_items_payload(
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
line_items: list[str],
|
||||||
|
end_date: str,
|
||||||
|
period: str = "ttm",
|
||||||
|
limit: int = 10,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"search_results": search_line_items(
|
||||||
|
ticker=ticker,
|
||||||
|
line_items=line_items,
|
||||||
|
end_date=end_date,
|
||||||
|
period=period,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
}
|
||||||
2
backend/enrich/__init__.py
Normal file
2
backend/enrich/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
"""News enrichment utilities for explain-oriented market research."""
|
||||||
|
|
||||||
301
backend/enrich/llm_enricher.py
Normal file
301
backend/enrich/llm_enricher.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Optional AgentScope-backed news enrichment with safe local fallback."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from backend.config.env_config import canonicalize_model_provider, get_env_bool, get_env_str
|
||||||
|
from backend.llm.models import create_model
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EnrichedNewsItem(BaseModel):
|
||||||
|
"""Structured output schema for one enriched article."""
|
||||||
|
|
||||||
|
id: str = Field(description="The source article id")
|
||||||
|
relevance: str = Field(description="One of high, medium, low")
|
||||||
|
sentiment: str = Field(description="One of positive, negative, neutral")
|
||||||
|
key_discussion: str = Field(description="Concise core discussion")
|
||||||
|
summary: str = Field(description="Concise factual summary")
|
||||||
|
reason_growth: str = Field(description="Growth-oriented reason if present")
|
||||||
|
reason_decrease: str = Field(description="Downside-oriented reason if present")
|
||||||
|
|
||||||
|
|
||||||
|
class EnrichedNewsBatch(BaseModel):
|
||||||
|
"""Structured output schema for a batch of enriched articles."""
|
||||||
|
|
||||||
|
items: list[EnrichedNewsItem]
|
||||||
|
|
||||||
|
|
||||||
|
class RangeAnalysisPayload(BaseModel):
|
||||||
|
"""Structured output schema for range explanation text."""
|
||||||
|
|
||||||
|
summary: str = Field(description="Concise Chinese range summary for the selected window")
|
||||||
|
trend_analysis: str = Field(description="Concise Chinese trend explanation for the selected window")
|
||||||
|
bullish_factors: list[str] = Field(description="Top bullish factors in Chinese")
|
||||||
|
bearish_factors: list[str] = Field(description="Top bearish factors in Chinese")
|
||||||
|
|
||||||
|
|
||||||
|
def get_explain_model_info() -> dict[str, str]:
|
||||||
|
"""Resolve provider/model used by explain enrichment."""
|
||||||
|
provider = canonicalize_model_provider(
|
||||||
|
get_env_str("EXPLAIN_ENRICH_MODEL_PROVIDER")
|
||||||
|
or get_env_str("MODEL_PROVIDER", "OPENAI"),
|
||||||
|
)
|
||||||
|
model_name = get_env_str("EXPLAIN_ENRICH_MODEL_NAME") or get_env_str(
|
||||||
|
"MODEL_NAME",
|
||||||
|
"gpt-4o-mini",
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"provider": provider,
|
||||||
|
"model_name": model_name,
|
||||||
|
"label": f"{provider}:{model_name}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_enrichment_payload(payload: Any) -> dict[str, Any] | None:
|
||||||
|
if isinstance(payload, BaseModel):
|
||||||
|
payload = payload.model_dump()
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"relevance": str(payload.get("relevance") or "").strip().lower() or None,
|
||||||
|
"sentiment": str(payload.get("sentiment") or "").strip().lower() or None,
|
||||||
|
"key_discussion": str(payload.get("key_discussion") or "").strip() or None,
|
||||||
|
"summary": str(payload.get("summary") or "").strip() or None,
|
||||||
|
"reason_growth": str(payload.get("reason_growth") or "").strip() or None,
|
||||||
|
"reason_decrease": str(payload.get("reason_decrease") or "").strip() or None,
|
||||||
|
"raw_json": payload,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _run_async(coro: Any) -> Any:
|
||||||
|
"""Run an async AgentScope model call from sync code, even inside a running loop."""
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
return asyncio.run(coro)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
future = executor.submit(asyncio.run, coro)
|
||||||
|
return future.result()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_explain_model():
|
||||||
|
"""Create an AgentScope model for explain enrichment."""
|
||||||
|
model_info = get_explain_model_info()
|
||||||
|
return create_model(
|
||||||
|
model_name=model_info["model_name"],
|
||||||
|
provider=model_info["provider"],
|
||||||
|
stream=False,
|
||||||
|
generate_kwargs={"temperature": 0.1},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def llm_enrichment_enabled() -> bool:
|
||||||
|
"""Return whether AgentScope-backed LLM enrichment should be attempted."""
|
||||||
|
if not get_env_bool("EXPLAIN_ENRICH_USE_LLM", False):
|
||||||
|
return False
|
||||||
|
provider = get_explain_model_info()["provider"]
|
||||||
|
provider_key_map = {
|
||||||
|
"OPENAI": "OPENAI_API_KEY",
|
||||||
|
"ANTHROPIC": "ANTHROPIC_API_KEY",
|
||||||
|
"DASHSCOPE": "DASHSCOPE_API_KEY",
|
||||||
|
"ALIBABA": "DASHSCOPE_API_KEY",
|
||||||
|
"GEMINI": "GOOGLE_API_KEY",
|
||||||
|
"GOOGLE": "GOOGLE_API_KEY",
|
||||||
|
"DEEPSEEK": "DEEPSEEK_API_KEY",
|
||||||
|
"GROQ": "GROQ_API_KEY",
|
||||||
|
"OPENROUTER": "OPENROUTER_API_KEY",
|
||||||
|
}
|
||||||
|
env_key = provider_key_map.get(provider)
|
||||||
|
return bool(get_env_str(env_key)) if env_key else provider == "OLLAMA"
|
||||||
|
|
||||||
|
|
||||||
|
def llm_range_analysis_enabled() -> bool:
|
||||||
|
"""Return whether LLM range analysis should be attempted."""
|
||||||
|
raw_value = get_env_str("EXPLAIN_RANGE_USE_LLM")
|
||||||
|
if raw_value is not None and str(raw_value).strip() != "":
|
||||||
|
return get_env_bool("EXPLAIN_RANGE_USE_LLM", False) and llm_enrichment_enabled()
|
||||||
|
return llm_enrichment_enabled()
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_news_row_with_llm(row: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Generate explain-oriented structured analysis for one article."""
|
||||||
|
if not llm_enrichment_enabled():
|
||||||
|
return None
|
||||||
|
|
||||||
|
model = _get_explain_model()
|
||||||
|
title = str(row.get("title") or "").strip()
|
||||||
|
summary = str(row.get("summary") or "").strip()
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You produce concise structured financial news analysis. "
|
||||||
|
"Use only the requested fields and keep content factual."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"Analyze this stock-news article for an explain UI.\n"
|
||||||
|
"Rules:\n"
|
||||||
|
"- relevance must be one of: high, medium, low\n"
|
||||||
|
"- sentiment must be one of: positive, negative, neutral\n"
|
||||||
|
"- keep each text field concise and factual\n"
|
||||||
|
f"- article id: {str(row.get('id') or '').strip()}\n"
|
||||||
|
f"Title: {title}\n"
|
||||||
|
f"Summary: {summary}\n"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
response = _run_async(model(messages=messages, structured_model=EnrichedNewsItem))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM enrichment failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
payload = _normalize_enrichment_payload(getattr(response, "metadata", None))
|
||||||
|
if payload:
|
||||||
|
payload.setdefault("raw_json", {})
|
||||||
|
payload["raw_json"]["model_provider"] = get_explain_model_info()["provider"]
|
||||||
|
payload["raw_json"]["model_name"] = get_explain_model_info()["model_name"]
|
||||||
|
payload["raw_json"]["model_label"] = get_explain_model_info()["label"]
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_news_rows_with_llm(rows: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Generate structured analysis for multiple articles in one request."""
|
||||||
|
if not llm_enrichment_enabled() or not rows:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
payload_rows = [
|
||||||
|
{
|
||||||
|
"id": str(row.get("id") or "").strip(),
|
||||||
|
"title": str(row.get("title") or "").strip(),
|
||||||
|
"summary": str(row.get("summary") or "").strip(),
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
if str(row.get("id") or "").strip()
|
||||||
|
]
|
||||||
|
if not payload_rows:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
model = _get_explain_model()
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You produce concise structured financial news analysis in JSON. "
|
||||||
|
"Preserve ids exactly and do not invent extra items."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"Analyze these stock-news articles for an explain UI.\n"
|
||||||
|
"For each item return: id, relevance, sentiment, key_discussion, summary, "
|
||||||
|
"reason_growth, reason_decrease.\n"
|
||||||
|
"Rules:\n"
|
||||||
|
"- relevance must be one of: high, medium, low\n"
|
||||||
|
"- sentiment must be one of: positive, negative, neutral\n"
|
||||||
|
"- keep all text concise and factual\n"
|
||||||
|
f"Articles: {payload_rows}"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
response = _run_async(
|
||||||
|
model(messages=messages, structured_model=EnrichedNewsBatch),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
metadata = getattr(response, "metadata", None)
|
||||||
|
if isinstance(metadata, BaseModel):
|
||||||
|
metadata = metadata.model_dump()
|
||||||
|
items = metadata.get("items") if isinstance(metadata, dict) else None
|
||||||
|
if not isinstance(items, list):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
results: dict[str, dict[str, Any]] = {}
|
||||||
|
for item in items:
|
||||||
|
normalized = _normalize_enrichment_payload(item)
|
||||||
|
news_id = str((item.model_dump() if isinstance(item, BaseModel) else item).get("id") or "").strip() if isinstance(item, (dict, BaseModel)) else ""
|
||||||
|
if normalized and news_id:
|
||||||
|
normalized.setdefault("raw_json", {})
|
||||||
|
normalized["raw_json"]["model_provider"] = get_explain_model_info()["provider"]
|
||||||
|
normalized["raw_json"]["model_name"] = get_explain_model_info()["model_name"]
|
||||||
|
normalized["raw_json"]["model_label"] = get_explain_model_info()["label"]
|
||||||
|
results[news_id] = normalized
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_range_with_llm(payload: dict[str, Any]) -> dict[str, Any] | None:
|
||||||
|
"""Generate explain-oriented range summary and factor refinement."""
|
||||||
|
if not llm_range_analysis_enabled():
|
||||||
|
return None
|
||||||
|
|
||||||
|
model = _get_explain_model()
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
"You write concise Chinese stock range analysis for an explain UI. "
|
||||||
|
"Use only the supplied facts. Keep the tone factual and analyst-like."
|
||||||
|
),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"请基于给定事实生成区间分析。\n"
|
||||||
|
"输出字段:summary, trend_analysis, bullish_factors, bearish_factors。\n"
|
||||||
|
"要求:\n"
|
||||||
|
"- 全部使用简体中文\n"
|
||||||
|
"- summary 1到2句,概括区间走势、新闻密度和主导主题\n"
|
||||||
|
"- trend_analysis 1句,解释区间内部阶段变化\n"
|
||||||
|
"- bullish_factors 和 bearish_factors 各返回最多3条短句\n"
|
||||||
|
"- 不要编造未提供的信息\n"
|
||||||
|
f"事实数据: {payload}"
|
||||||
|
),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
try:
|
||||||
|
response = _run_async(
|
||||||
|
model(messages=messages, structured_model=RangeAnalysisPayload),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"LLM enrichment failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
metadata = getattr(response, "metadata", None)
|
||||||
|
if isinstance(metadata, BaseModel):
|
||||||
|
metadata = metadata.model_dump()
|
||||||
|
if not isinstance(metadata, dict):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"summary": str(metadata.get("summary") or "").strip() or None,
|
||||||
|
"trend_analysis": str(metadata.get("trend_analysis") or "").strip() or None,
|
||||||
|
"bullish_factors": [
|
||||||
|
str(item).strip()
|
||||||
|
for item in list(metadata.get("bullish_factors") or [])
|
||||||
|
if str(item).strip()
|
||||||
|
][:3],
|
||||||
|
"bearish_factors": [
|
||||||
|
str(item).strip()
|
||||||
|
for item in list(metadata.get("bearish_factors") or [])
|
||||||
|
if str(item).strip()
|
||||||
|
][:3],
|
||||||
|
"model_provider": get_explain_model_info()["provider"],
|
||||||
|
"model_name": get_explain_model_info()["model_name"],
|
||||||
|
"model_label": get_explain_model_info()["label"],
|
||||||
|
}
|
||||||
362
backend/enrich/news_enricher.py
Normal file
362
backend/enrich/news_enricher.py
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Lightweight news enrichment for explain-oriented market analysis."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.config.env_config import get_env_int
|
||||||
|
from backend.enrich.llm_enricher import (
|
||||||
|
analyze_news_row_with_llm,
|
||||||
|
analyze_news_rows_with_llm,
|
||||||
|
llm_enrichment_enabled,
|
||||||
|
)
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
|
||||||
|
|
||||||
|
POSITIVE_KEYWORDS = (
|
||||||
|
"beat", "surge", "gain", "growth", "record", "upgrade", "strong",
|
||||||
|
"partnership", "approved", "launch", "expands", "profit",
|
||||||
|
)
|
||||||
|
NEGATIVE_KEYWORDS = (
|
||||||
|
"miss", "drop", "fall", "cut", "downgrade", "weak", "warning",
|
||||||
|
"delay", "lawsuit", "probe", "tariff", "decline", "layoff",
|
||||||
|
)
|
||||||
|
HIGH_RELEVANCE_KEYWORDS = (
|
||||||
|
"earnings", "guidance", "profit", "revenue", "ceo", "fda", "tariff",
|
||||||
|
"regulation", "acquisition", "buyback", "forecast", "launch",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dedupe_key(row: dict[str, Any]) -> str:
|
||||||
|
trade_date = str(row.get("trade_date") or row.get("date") or "")[:10]
|
||||||
|
title = str(row.get("title") or "").strip().lower()
|
||||||
|
summary = str(row.get("summary") or "").strip().lower()[:160]
|
||||||
|
raw = f"{trade_date}::{title}::{summary}"
|
||||||
|
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_rows(rows: list[dict[str, Any]], size: int) -> list[list[dict[str, Any]]]:
|
||||||
|
chunk_size = max(1, int(size))
|
||||||
|
return [rows[index:index + chunk_size] for index in range(0, len(rows), chunk_size)]
|
||||||
|
|
||||||
|
|
||||||
|
def classify_news_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Return a lightweight explain-oriented analysis for one article."""
|
||||||
|
llm_result = analyze_news_row_with_llm(row)
|
||||||
|
if isinstance(llm_result, dict):
|
||||||
|
merged = dict(llm_result)
|
||||||
|
merged.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
|
||||||
|
merged.setdefault("raw_json", row)
|
||||||
|
merged["analysis_source"] = "llm"
|
||||||
|
return merged
|
||||||
|
|
||||||
|
title = str(row.get("title") or "").strip()
|
||||||
|
summary = str(row.get("summary") or "").strip()
|
||||||
|
text = f"{title} {summary}".lower()
|
||||||
|
|
||||||
|
positive_hits = [keyword for keyword in POSITIVE_KEYWORDS if keyword in text]
|
||||||
|
negative_hits = [keyword for keyword in NEGATIVE_KEYWORDS if keyword in text]
|
||||||
|
relevance_hits = [keyword for keyword in HIGH_RELEVANCE_KEYWORDS if keyword in text]
|
||||||
|
|
||||||
|
if len(positive_hits) > len(negative_hits):
|
||||||
|
sentiment = "positive"
|
||||||
|
elif len(negative_hits) > len(positive_hits):
|
||||||
|
sentiment = "negative"
|
||||||
|
else:
|
||||||
|
sentiment = "neutral"
|
||||||
|
|
||||||
|
relevance = "high" if relevance_hits else "medium" if title else "low"
|
||||||
|
summary_text = summary or title
|
||||||
|
key_discussion = ""
|
||||||
|
if relevance_hits:
|
||||||
|
key_discussion = f"核心主题集中在 {', '.join(relevance_hits[:3])}"
|
||||||
|
elif summary_text:
|
||||||
|
key_discussion = summary_text[:160]
|
||||||
|
|
||||||
|
reason_growth = ""
|
||||||
|
reason_decrease = ""
|
||||||
|
if sentiment == "positive":
|
||||||
|
reason_growth = summary_text[:200]
|
||||||
|
elif sentiment == "negative":
|
||||||
|
reason_decrease = summary_text[:200]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"relevance": relevance,
|
||||||
|
"sentiment": sentiment,
|
||||||
|
"key_discussion": key_discussion,
|
||||||
|
"summary": summary_text[:280],
|
||||||
|
"reason_growth": reason_growth,
|
||||||
|
"reason_decrease": reason_decrease,
|
||||||
|
"analysis_source": "local",
|
||||||
|
"raw_json": row,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def attach_forward_returns(
|
||||||
|
*,
|
||||||
|
news_rows: list[dict[str, Any]],
|
||||||
|
ohlc_rows: list[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Attach forward-return labels to each analyzed row."""
|
||||||
|
if not ohlc_rows:
|
||||||
|
return news_rows
|
||||||
|
|
||||||
|
closes_by_date = {
|
||||||
|
str(row.get("date")): float(row.get("close"))
|
||||||
|
for row in ohlc_rows
|
||||||
|
if row.get("date") is not None and row.get("close") is not None
|
||||||
|
}
|
||||||
|
ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date") is not None]
|
||||||
|
date_index = {date: idx for idx, date in enumerate(ordered_dates)}
|
||||||
|
|
||||||
|
horizons = {
|
||||||
|
"ret_t0": 0,
|
||||||
|
"ret_t1": 1,
|
||||||
|
"ret_t3": 3,
|
||||||
|
"ret_t5": 5,
|
||||||
|
"ret_t10": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
enriched: list[dict[str, Any]] = []
|
||||||
|
for row in news_rows:
|
||||||
|
trade_date = str(row.get("trade_date") or "")[:10]
|
||||||
|
base_close = closes_by_date.get(trade_date)
|
||||||
|
if not trade_date or base_close in (None, 0):
|
||||||
|
enriched.append(row)
|
||||||
|
continue
|
||||||
|
|
||||||
|
next_row = dict(row)
|
||||||
|
base_index = date_index.get(trade_date)
|
||||||
|
if base_index is None:
|
||||||
|
enriched.append(next_row)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for field, offset in horizons.items():
|
||||||
|
target_index = base_index + offset
|
||||||
|
if target_index >= len(ordered_dates):
|
||||||
|
next_row[field] = None
|
||||||
|
continue
|
||||||
|
target_close = closes_by_date.get(ordered_dates[target_index])
|
||||||
|
next_row[field] = (
|
||||||
|
(float(target_close) - float(base_close)) / float(base_close)
|
||||||
|
if target_close not in (None, 0)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
enriched.append(next_row)
|
||||||
|
return enriched
|
||||||
|
|
||||||
|
|
||||||
|
def build_analysis_rows(
|
||||||
|
*,
|
||||||
|
symbol: str,
|
||||||
|
news_rows: list[dict[str, Any]],
|
||||||
|
ohlc_rows: list[dict[str, Any]],
|
||||||
|
) -> tuple[list[dict[str, Any]], dict[str, int]]:
|
||||||
|
"""Transform raw news rows into market_store news_analysis payloads plus stats."""
|
||||||
|
llm_results: dict[str, dict[str, Any]] = {}
|
||||||
|
if llm_enrichment_enabled():
|
||||||
|
batch_size = get_env_int("EXPLAIN_ENRICH_BATCH_SIZE", 8)
|
||||||
|
for chunk in _chunk_rows(news_rows, batch_size):
|
||||||
|
llm_results.update(analyze_news_rows_with_llm(chunk))
|
||||||
|
|
||||||
|
staged_rows: list[dict[str, Any]] = []
|
||||||
|
seen_dedupe_keys: set[str] = set()
|
||||||
|
deduped_count = 0
|
||||||
|
llm_count = 0
|
||||||
|
local_count = 0
|
||||||
|
for row in news_rows:
|
||||||
|
news_id = str(row.get("id") or "").strip()
|
||||||
|
if not news_id:
|
||||||
|
continue
|
||||||
|
dedupe_key = _dedupe_key(row)
|
||||||
|
if dedupe_key in seen_dedupe_keys:
|
||||||
|
deduped_count += 1
|
||||||
|
continue
|
||||||
|
seen_dedupe_keys.add(dedupe_key)
|
||||||
|
batch_result = llm_results.get(news_id)
|
||||||
|
if isinstance(batch_result, dict):
|
||||||
|
analysis = dict(batch_result)
|
||||||
|
analysis.setdefault("summary", str(row.get("summary") or row.get("title") or "")[:280])
|
||||||
|
analysis.setdefault("raw_json", row)
|
||||||
|
analysis["analysis_source"] = "llm"
|
||||||
|
llm_count += 1
|
||||||
|
else:
|
||||||
|
analysis = classify_news_row(row)
|
||||||
|
if analysis.get("analysis_source") == "llm":
|
||||||
|
llm_count += 1
|
||||||
|
else:
|
||||||
|
local_count += 1
|
||||||
|
staged_rows.append(
|
||||||
|
{
|
||||||
|
"news_id": news_id,
|
||||||
|
"trade_date": str(row.get("trade_date") or "")[:10] or None,
|
||||||
|
**analysis,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
attach_forward_returns(news_rows=staged_rows, ohlc_rows=ohlc_rows),
|
||||||
|
{
|
||||||
|
"deduped_count": deduped_count,
|
||||||
|
"llm_count": llm_count,
|
||||||
|
"local_count": local_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def enrich_news_for_symbol(
|
||||||
|
store: MarketStore,
|
||||||
|
symbol: str,
|
||||||
|
*,
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
limit: int = 200,
|
||||||
|
analysis_source: str = "local",
|
||||||
|
skip_existing: bool = True,
|
||||||
|
only_reanalyze_local: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Read raw market news, compute explain fields, and persist them."""
|
||||||
|
normalized_symbol = str(symbol or "").strip().upper()
|
||||||
|
if not normalized_symbol:
|
||||||
|
return {"symbol": "", "analyzed": 0}
|
||||||
|
|
||||||
|
news_rows = store.get_news_items(
|
||||||
|
normalized_symbol,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
total_news_count = len(news_rows)
|
||||||
|
skipped_existing_count = 0
|
||||||
|
analyzed_sources: dict[str, str] = {}
|
||||||
|
skipped_missing_analysis_count = 0
|
||||||
|
skipped_non_local_count = 0
|
||||||
|
if news_rows and only_reanalyze_local:
|
||||||
|
analyzed_sources = store.get_analyzed_news_sources(
|
||||||
|
normalized_symbol,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
skipped_missing_analysis_count = sum(
|
||||||
|
1
|
||||||
|
for row in news_rows
|
||||||
|
if str(row.get("id") or "").strip() not in analyzed_sources
|
||||||
|
)
|
||||||
|
skipped_non_local_count = sum(
|
||||||
|
1
|
||||||
|
for row in news_rows
|
||||||
|
if str(row.get("id") or "").strip() in analyzed_sources
|
||||||
|
and analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
|
||||||
|
)
|
||||||
|
skipped_existing_count = sum(
|
||||||
|
1
|
||||||
|
for row in news_rows
|
||||||
|
if str(row.get("id") or "").strip() not in analyzed_sources
|
||||||
|
or analyzed_sources.get(str(row.get("id") or "").strip()) != "local"
|
||||||
|
)
|
||||||
|
news_rows = [
|
||||||
|
row for row in news_rows
|
||||||
|
if analyzed_sources.get(str(row.get("id") or "").strip()) == "local"
|
||||||
|
]
|
||||||
|
elif skip_existing and news_rows:
|
||||||
|
analyzed_ids = store.get_analyzed_news_ids(
|
||||||
|
normalized_symbol,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
skipped_existing_count = sum(
|
||||||
|
1
|
||||||
|
for row in news_rows
|
||||||
|
if str(row.get("id") or "").strip() in analyzed_ids
|
||||||
|
)
|
||||||
|
news_rows = [
|
||||||
|
row for row in news_rows
|
||||||
|
if str(row.get("id") or "").strip() not in analyzed_ids
|
||||||
|
]
|
||||||
|
ohlc_start = start_date or (news_rows[-1]["trade_date"] if news_rows and news_rows[-1].get("trade_date") else None)
|
||||||
|
ohlc_end = end_date or (news_rows[0]["trade_date"] if news_rows and news_rows[0].get("trade_date") else None)
|
||||||
|
ohlc_rows = (
|
||||||
|
store.get_ohlc(normalized_symbol, ohlc_start, ohlc_end)
|
||||||
|
if ohlc_start and ohlc_end
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
analysis_rows, stats = build_analysis_rows(
|
||||||
|
symbol=normalized_symbol,
|
||||||
|
news_rows=news_rows,
|
||||||
|
ohlc_rows=ohlc_rows,
|
||||||
|
)
|
||||||
|
analyzed = store.upsert_news_analysis(
|
||||||
|
normalized_symbol,
|
||||||
|
analysis_rows,
|
||||||
|
analysis_source=analysis_source,
|
||||||
|
)
|
||||||
|
upgraded_dates = sorted(
|
||||||
|
{
|
||||||
|
str(row.get("trade_date") or "")[:10]
|
||||||
|
for row in analysis_rows
|
||||||
|
if str(row.get("analysis_source") or "").strip().lower() == "llm"
|
||||||
|
and str(row.get("trade_date") or "").strip()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
remaining_local_titles = [
|
||||||
|
str(row.get("title") or row.get("news_id") or "").strip()
|
||||||
|
for row in news_rows
|
||||||
|
for analyzed_row in analysis_rows
|
||||||
|
if str(analyzed_row.get("news_id") or "").strip() == str(row.get("id") or "").strip()
|
||||||
|
and str(analyzed_row.get("analysis_source") or "").strip().lower() == "local"
|
||||||
|
][:5]
|
||||||
|
return {
|
||||||
|
"symbol": normalized_symbol,
|
||||||
|
"analyzed": analyzed,
|
||||||
|
"news_count": total_news_count,
|
||||||
|
"queued_count": len(news_rows),
|
||||||
|
"skipped_existing_count": skipped_existing_count,
|
||||||
|
"deduped_count": stats["deduped_count"],
|
||||||
|
"llm_count": stats["llm_count"],
|
||||||
|
"local_count": stats["local_count"],
|
||||||
|
"only_reanalyze_local": only_reanalyze_local,
|
||||||
|
"upgraded_local_to_llm_count": (
|
||||||
|
stats["llm_count"]
|
||||||
|
if only_reanalyze_local
|
||||||
|
else 0
|
||||||
|
),
|
||||||
|
"execution_summary": {
|
||||||
|
"upgraded_dates": upgraded_dates[:5],
|
||||||
|
"remaining_local_titles": remaining_local_titles,
|
||||||
|
"skipped_missing_analysis_count": skipped_missing_analysis_count,
|
||||||
|
"skipped_non_local_count": skipped_non_local_count,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def enrich_symbols(
|
||||||
|
store: MarketStore,
|
||||||
|
symbols: list[str],
|
||||||
|
*,
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
limit: int = 200,
|
||||||
|
analysis_source: str = "local",
|
||||||
|
skip_existing: bool = True,
|
||||||
|
only_reanalyze_local: bool = False,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Batch enrich multiple symbols for explain-oriented news analysis."""
|
||||||
|
results = []
|
||||||
|
for symbol in symbols:
|
||||||
|
normalized_symbol = str(symbol or "").strip().upper()
|
||||||
|
if not normalized_symbol:
|
||||||
|
continue
|
||||||
|
results.append(
|
||||||
|
enrich_news_for_symbol(
|
||||||
|
store,
|
||||||
|
normalized_symbol,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
analysis_source=analysis_source,
|
||||||
|
skip_existing=skip_existing,
|
||||||
|
only_reanalyze_local=only_reanalyze_local,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
2
backend/explain/__init__.py
Normal file
2
backend/explain/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Explain-oriented services for stock narratives and news research."""
|
||||||
69
backend/explain/category_engine.py
Normal file
69
backend/explain/category_engine.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Rule-based news categorization for explain UI."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, Iterable
|
||||||
|
|
||||||
|
|
||||||
|
CATEGORY_KEYWORDS = {
|
||||||
|
"market": [
|
||||||
|
"market", "stock", "rally", "sell-off", "selloff", "trading",
|
||||||
|
"wall street", "s&p", "nasdaq", "dow", "index", "bull", "bear",
|
||||||
|
"correction", "volatility",
|
||||||
|
],
|
||||||
|
"policy": [
|
||||||
|
"regulation", "fed", "federal reserve", "tariff", "sanction",
|
||||||
|
"interest rate", "policy", "government", "congress", "sec",
|
||||||
|
"trade war", "ban", "legislation", "tax",
|
||||||
|
],
|
||||||
|
"earnings": [
|
||||||
|
"earnings", "revenue", "profit", "quarter", "eps", "guidance",
|
||||||
|
"forecast", "income", "sales", "beat", "miss", "outlook",
|
||||||
|
"financial results",
|
||||||
|
],
|
||||||
|
"product_tech": [
|
||||||
|
"product", "ai", "chip", "cloud", "launch", "patent",
|
||||||
|
"technology", "innovation", "release", "platform", "model",
|
||||||
|
"software", "hardware", "gpu", "autonomous",
|
||||||
|
],
|
||||||
|
"competition": [
|
||||||
|
"competitor", "rival", "market share", "overtake", "compete",
|
||||||
|
"competition", "vs", "versus", "battle", "challenge",
|
||||||
|
],
|
||||||
|
"management": [
|
||||||
|
"ceo", "executive", "resign", "layoff", "restructure",
|
||||||
|
"management", "leadership", "appoint", "hire", "board",
|
||||||
|
"chairman",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def categorize_news_rows(rows: Iterable[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""Bucket news rows by keyword categories."""
|
||||||
|
categories: Dict[str, Dict[str, Any]] = {
|
||||||
|
key: {
|
||||||
|
"label": key,
|
||||||
|
"count": 0,
|
||||||
|
"article_ids": [],
|
||||||
|
}
|
||||||
|
for key in CATEGORY_KEYWORDS
|
||||||
|
}
|
||||||
|
|
||||||
|
for row in rows:
|
||||||
|
text = " ".join(
|
||||||
|
[
|
||||||
|
str(row.get("title") or ""),
|
||||||
|
str(row.get("summary") or ""),
|
||||||
|
str(row.get("related") or ""),
|
||||||
|
str(row.get("category") or ""),
|
||||||
|
]
|
||||||
|
).lower()
|
||||||
|
article_id = row.get("id")
|
||||||
|
for category, keywords in CATEGORY_KEYWORDS.items():
|
||||||
|
if any(keyword in text for keyword in keywords):
|
||||||
|
categories[category]["count"] += 1
|
||||||
|
if article_id:
|
||||||
|
categories[category]["article_ids"].append(article_id)
|
||||||
|
|
||||||
|
return categories
|
||||||
214
backend/explain/range_explainer.py
Normal file
214
backend/explain/range_explainer.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Local range explanation built from price and persisted news."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from backend.enrich.llm_enricher import analyze_range_with_llm
|
||||||
|
from backend.explain.category_engine import categorize_news_rows
|
||||||
|
from backend.tools.data_tools import get_prices
|
||||||
|
|
||||||
|
|
||||||
|
def _rank_event_score(row: Dict[str, Any]) -> float:
|
||||||
|
relevance = str(row.get("relevance") or "").strip().lower()
|
||||||
|
relevance_score = {"high": 3.0, "relevant": 3.0, "medium": 2.0, "low": 1.0}.get(
|
||||||
|
relevance,
|
||||||
|
0.5,
|
||||||
|
)
|
||||||
|
impact_score = abs(float(row.get("ret_t0") or 0.0)) * 100
|
||||||
|
return relevance_score + impact_score
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_bullish_factors(
|
||||||
|
news_rows: list[Dict[str, Any]],
|
||||||
|
*,
|
||||||
|
limit: int = 5,
|
||||||
|
) -> list[str]:
|
||||||
|
factors = []
|
||||||
|
for row in news_rows:
|
||||||
|
if str(row.get("sentiment") or "").strip().lower() != "positive":
|
||||||
|
continue
|
||||||
|
candidate = row.get("reason_growth") or row.get("key_discussion") or row.get("summary") or row.get("title")
|
||||||
|
if candidate:
|
||||||
|
factors.append(str(candidate).strip())
|
||||||
|
seen = set()
|
||||||
|
output = []
|
||||||
|
for factor in factors:
|
||||||
|
if factor in seen:
|
||||||
|
continue
|
||||||
|
seen.add(factor)
|
||||||
|
output.append(factor[:200])
|
||||||
|
if len(output) >= limit:
|
||||||
|
break
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_bearish_factors(
|
||||||
|
news_rows: list[Dict[str, Any]],
|
||||||
|
*,
|
||||||
|
limit: int = 5,
|
||||||
|
) -> list[str]:
|
||||||
|
factors = []
|
||||||
|
for row in news_rows:
|
||||||
|
if str(row.get("sentiment") or "").strip().lower() != "negative":
|
||||||
|
continue
|
||||||
|
candidate = row.get("reason_decrease") or row.get("key_discussion") or row.get("summary") or row.get("title")
|
||||||
|
if candidate:
|
||||||
|
factors.append(str(candidate).strip())
|
||||||
|
seen = set()
|
||||||
|
output = []
|
||||||
|
for factor in factors:
|
||||||
|
if factor in seen:
|
||||||
|
continue
|
||||||
|
seen.add(factor)
|
||||||
|
output.append(factor[:200])
|
||||||
|
if len(output) >= limit:
|
||||||
|
break
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def build_trend_analysis(prices: list[Any]) -> str:
|
||||||
|
if len(prices) < 2:
|
||||||
|
return "区间样本较短,暂不具备足够趋势信息。"
|
||||||
|
if len(prices) < 3:
|
||||||
|
open_price = float(prices[0].open)
|
||||||
|
close_price = float(prices[-1].close)
|
||||||
|
change = ((close_price - open_price) / open_price) * 100 if open_price else 0.0
|
||||||
|
return f"短区间内价格变动 {change:+.2f}%,趋势信息有限。"
|
||||||
|
|
||||||
|
mid = len(prices) // 2
|
||||||
|
first_open = float(prices[0].open)
|
||||||
|
first_close = float(prices[mid].close)
|
||||||
|
second_open = float(prices[mid].open)
|
||||||
|
second_close = float(prices[-1].close)
|
||||||
|
first_half = ((first_close - first_open) / first_open) * 100 if first_open else 0.0
|
||||||
|
second_half = ((second_close - second_open) / second_open) * 100 if second_open else 0.0
|
||||||
|
return (
|
||||||
|
f"前半段{'上涨' if first_half >= 0 else '下跌'} {abs(first_half):.2f}%,"
|
||||||
|
f"后半段{'上涨' if second_half >= 0 else '下跌'} {abs(second_half):.2f}%,"
|
||||||
|
"说明价格驱动在区间内部出现了阶段性切换。"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_range_explanation(
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
start_date: str,
|
||||||
|
end_date: str,
|
||||||
|
news_rows: list[Dict[str, Any]],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Explain a price range with local price and news heuristics."""
|
||||||
|
prices = get_prices(ticker, start_date, end_date)
|
||||||
|
if not prices:
|
||||||
|
return {
|
||||||
|
"symbol": ticker,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"error": "No OHLC data for this range",
|
||||||
|
}
|
||||||
|
|
||||||
|
open_price = float(prices[0].open)
|
||||||
|
close_price = float(prices[-1].close)
|
||||||
|
high_price = max(float(price.high) for price in prices)
|
||||||
|
low_price = min(float(price.low) for price in prices)
|
||||||
|
total_volume = sum(int(price.volume) for price in prices)
|
||||||
|
price_change_pct = (
|
||||||
|
((close_price - open_price) / open_price) * 100 if open_price else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
categories = categorize_news_rows(news_rows)
|
||||||
|
news_count = len(news_rows)
|
||||||
|
dominant_categories = sorted(
|
||||||
|
(
|
||||||
|
{"category": key, "count": value["count"]}
|
||||||
|
for key, value in categories.items()
|
||||||
|
if value["count"] > 0
|
||||||
|
),
|
||||||
|
key=lambda item: item["count"],
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
direction = "上涨" if price_change_pct > 0 else "下跌" if price_change_pct < 0 else "横盘"
|
||||||
|
category_text = (
|
||||||
|
f"主要主题集中在 {', '.join(item['category'] for item in dominant_categories[:3])}。"
|
||||||
|
if dominant_categories
|
||||||
|
else "区间内未识别出明显的主题聚类。"
|
||||||
|
)
|
||||||
|
summary = (
|
||||||
|
f"{ticker} 在 {start_date} 至 {end_date} 区间内{direction} {abs(price_change_pct):.2f}%,"
|
||||||
|
f"区间覆盖 {len(prices)} 个交易日,关联新闻 {news_count} 条。{category_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
bullish_factors = summarize_bullish_factors(news_rows)
|
||||||
|
bearish_factors = summarize_bearish_factors(news_rows)
|
||||||
|
trend_analysis = build_trend_analysis(prices)
|
||||||
|
llm_source = "local"
|
||||||
|
|
||||||
|
range_payload = {
|
||||||
|
"ticker": ticker,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"price_change_pct": round(price_change_pct, 2),
|
||||||
|
"trading_days": len(prices),
|
||||||
|
"news_count": news_count,
|
||||||
|
"dominant_categories": dominant_categories[:5],
|
||||||
|
"bullish_factors": bullish_factors[:3],
|
||||||
|
"bearish_factors": bearish_factors[:3],
|
||||||
|
"trend_analysis": trend_analysis,
|
||||||
|
"top_news": [
|
||||||
|
{
|
||||||
|
"date": row.get("trade_date") or str(row.get("date") or "")[:10],
|
||||||
|
"title": row.get("title") or "",
|
||||||
|
"summary": row.get("summary") or "",
|
||||||
|
"sentiment": row.get("sentiment") or "",
|
||||||
|
"relevance": row.get("relevance") or "",
|
||||||
|
"ret_t0": row.get("ret_t0"),
|
||||||
|
}
|
||||||
|
for row in sorted(news_rows, key=_rank_event_score, reverse=True)[:5]
|
||||||
|
],
|
||||||
|
}
|
||||||
|
llm_analysis = analyze_range_with_llm(range_payload)
|
||||||
|
if isinstance(llm_analysis, dict):
|
||||||
|
summary = llm_analysis.get("summary") or summary
|
||||||
|
trend_analysis = llm_analysis.get("trend_analysis") or trend_analysis
|
||||||
|
bullish_factors = llm_analysis.get("bullish_factors") or bullish_factors
|
||||||
|
bearish_factors = llm_analysis.get("bearish_factors") or bearish_factors
|
||||||
|
llm_source = "llm"
|
||||||
|
|
||||||
|
key_events = [
|
||||||
|
{
|
||||||
|
"date": row.get("trade_date") or str(row.get("date") or "")[:10],
|
||||||
|
"title": row.get("title") or "Untitled news",
|
||||||
|
"summary": row.get("summary") or "",
|
||||||
|
"category": row.get("category") or "",
|
||||||
|
"id": row.get("id"),
|
||||||
|
"sentiment": row.get("sentiment"),
|
||||||
|
"ret_t0": row.get("ret_t0"),
|
||||||
|
}
|
||||||
|
for row in sorted(news_rows, key=_rank_event_score, reverse=True)[:8]
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"symbol": ticker,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"price_change_pct": round(price_change_pct, 2),
|
||||||
|
"open_price": open_price,
|
||||||
|
"close_price": close_price,
|
||||||
|
"high_price": high_price,
|
||||||
|
"low_price": low_price,
|
||||||
|
"total_volume": total_volume,
|
||||||
|
"trading_days": len(prices),
|
||||||
|
"news_count": news_count,
|
||||||
|
"dominant_categories": dominant_categories[:5],
|
||||||
|
"analysis": {
|
||||||
|
"summary": summary,
|
||||||
|
"key_events": key_events,
|
||||||
|
"bullish_factors": bullish_factors,
|
||||||
|
"bearish_factors": bearish_factors,
|
||||||
|
"trend_analysis": trend_analysis,
|
||||||
|
"analysis_source": llm_source,
|
||||||
|
"analysis_model_label": llm_analysis.get("model_label") if isinstance(llm_analysis, dict) else None,
|
||||||
|
},
|
||||||
|
}
|
||||||
202
backend/explain/similarity_service.py
Normal file
202
backend/explain/similarity_service.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Same-ticker historical similar day search for explain view."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from math import sqrt
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_float(value: Any, default: float = 0.0) -> float:
|
||||||
|
try:
|
||||||
|
parsed = float(value)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return default
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
|
||||||
|
def build_daily_feature_rows(
|
||||||
|
*,
|
||||||
|
symbol: str,
|
||||||
|
ohlc_rows: list[dict[str, Any]],
|
||||||
|
news_rows: list[dict[str, Any]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Aggregate price/news context into daily feature rows."""
|
||||||
|
price_by_date = {str(row.get("date")): row for row in ohlc_rows if row.get("date")}
|
||||||
|
ordered_dates = [str(row.get("date")) for row in ohlc_rows if row.get("date")]
|
||||||
|
|
||||||
|
news_by_date: dict[str, list[dict[str, Any]]] = {}
|
||||||
|
for row in news_rows:
|
||||||
|
trade_date = str(row.get("trade_date") or "")[:10] or str(row.get("date") or "")[:10]
|
||||||
|
if not trade_date:
|
||||||
|
continue
|
||||||
|
news_by_date.setdefault(trade_date, []).append(row)
|
||||||
|
|
||||||
|
features: list[dict[str, Any]] = []
|
||||||
|
previous_close: float | None = None
|
||||||
|
for idx, date in enumerate(ordered_dates):
|
||||||
|
price_row = price_by_date[date]
|
||||||
|
close_price = _safe_float(price_row.get("close"))
|
||||||
|
open_price = _safe_float(price_row.get("open"), close_price)
|
||||||
|
day_news = news_by_date.get(date, [])
|
||||||
|
positive_count = sum(1 for item in day_news if str(item.get("sentiment") or "").lower() == "positive")
|
||||||
|
negative_count = sum(1 for item in day_news if str(item.get("sentiment") or "").lower() == "negative")
|
||||||
|
high_relevance_count = sum(
|
||||||
|
1 for item in day_news if str(item.get("relevance") or "").lower() in {"high", "relevant"}
|
||||||
|
)
|
||||||
|
ret_1d = (
|
||||||
|
((close_price - previous_close) / previous_close)
|
||||||
|
if previous_close not in (None, 0)
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
intraday_ret = ((close_price - open_price) / open_price) if open_price else 0.0
|
||||||
|
sentiment_score = (
|
||||||
|
(positive_count - negative_count) / max(len(day_news), 1)
|
||||||
|
if day_news
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
future_t1 = None
|
||||||
|
future_t3 = None
|
||||||
|
if idx + 1 < len(ordered_dates) and close_price:
|
||||||
|
next_close = _safe_float(price_by_date[ordered_dates[idx + 1]].get("close"))
|
||||||
|
future_t1 = ((next_close - close_price) / close_price) if next_close else None
|
||||||
|
if idx + 3 < len(ordered_dates) and close_price:
|
||||||
|
next_close = _safe_float(price_by_date[ordered_dates[idx + 3]].get("close"))
|
||||||
|
future_t3 = ((next_close - close_price) / close_price) if next_close else None
|
||||||
|
|
||||||
|
features.append(
|
||||||
|
{
|
||||||
|
"date": date,
|
||||||
|
"symbol": symbol,
|
||||||
|
"n_articles": len(day_news),
|
||||||
|
"positive_count": positive_count,
|
||||||
|
"negative_count": negative_count,
|
||||||
|
"high_relevance_count": high_relevance_count,
|
||||||
|
"sentiment_score": sentiment_score,
|
||||||
|
"ret_1d": ret_1d,
|
||||||
|
"intraday_ret": intraday_ret,
|
||||||
|
"close": close_price,
|
||||||
|
"ret_t1_after": future_t1,
|
||||||
|
"ret_t3_after": future_t3,
|
||||||
|
"news": [
|
||||||
|
{
|
||||||
|
"title": row.get("title") or "",
|
||||||
|
"sentiment": row.get("sentiment") or "neutral",
|
||||||
|
}
|
||||||
|
for row in day_news[:3]
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
previous_close = close_price
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def compute_similarity_scores(
|
||||||
|
target_vector: list[float],
|
||||||
|
candidate_vectors: list[tuple[str, list[float], dict[str, Any]]],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return sorted similarity matches based on normalized Euclidean distance."""
|
||||||
|
if not candidate_vectors:
|
||||||
|
return []
|
||||||
|
dimensions = len(target_vector)
|
||||||
|
ranges = []
|
||||||
|
for dimension in range(dimensions):
|
||||||
|
values = [vector[1][dimension] for vector in candidate_vectors] + [target_vector[dimension]]
|
||||||
|
min_value = min(values)
|
||||||
|
max_value = max(values)
|
||||||
|
ranges.append(max(max_value - min_value, 1e-9))
|
||||||
|
|
||||||
|
scored = []
|
||||||
|
for date, vector, payload in candidate_vectors:
|
||||||
|
distance = sqrt(
|
||||||
|
sum(
|
||||||
|
((target_vector[i] - vector[i]) / ranges[i]) ** 2
|
||||||
|
for i in range(dimensions)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
similarity = 1.0 / (1.0 + distance)
|
||||||
|
scored.append(
|
||||||
|
{
|
||||||
|
"date": date,
|
||||||
|
"score": round(similarity, 4),
|
||||||
|
**payload,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return sorted(scored, key=lambda item: item["score"], reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
def find_similar_days(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
symbol: str,
|
||||||
|
target_date: str,
|
||||||
|
top_k: int = 10,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Find same-ticker historical days most similar to a target day."""
|
||||||
|
cached = store.get_similar_day_cache(symbol, target_date=target_date)
|
||||||
|
if cached and cached.get("payload"):
|
||||||
|
return cached["payload"]
|
||||||
|
|
||||||
|
ohlc_rows = store.get_ohlc(symbol, "1900-01-01", target_date)
|
||||||
|
news_rows = store.get_news_items_enriched(symbol, end_date=target_date, limit=500)
|
||||||
|
daily_rows = build_daily_feature_rows(symbol=symbol, ohlc_rows=ohlc_rows, news_rows=news_rows)
|
||||||
|
feature_map = {row["date"]: row for row in daily_rows}
|
||||||
|
target_row = feature_map.get(target_date)
|
||||||
|
if not target_row:
|
||||||
|
return {
|
||||||
|
"symbol": symbol,
|
||||||
|
"target_date": target_date,
|
||||||
|
"items": [],
|
||||||
|
"error": "No feature row for target date",
|
||||||
|
}
|
||||||
|
|
||||||
|
vector_keys = [
|
||||||
|
"sentiment_score",
|
||||||
|
"n_articles",
|
||||||
|
"positive_count",
|
||||||
|
"negative_count",
|
||||||
|
"high_relevance_count",
|
||||||
|
"ret_1d",
|
||||||
|
"intraday_ret",
|
||||||
|
]
|
||||||
|
target_vector = [_safe_float(target_row.get(key)) for key in vector_keys]
|
||||||
|
candidates = []
|
||||||
|
for row in daily_rows:
|
||||||
|
date = row["date"]
|
||||||
|
if date == target_date:
|
||||||
|
continue
|
||||||
|
payload = {
|
||||||
|
"n_articles": row["n_articles"],
|
||||||
|
"sentiment_score": round(row["sentiment_score"], 4),
|
||||||
|
"ret_1d": round(row["ret_1d"] * 100, 2),
|
||||||
|
"intraday_ret": round(row["intraday_ret"] * 100, 2),
|
||||||
|
"ret_t1_after": round(row["ret_t1_after"] * 100, 2) if row["ret_t1_after"] is not None else None,
|
||||||
|
"ret_t3_after": round(row["ret_t3_after"] * 100, 2) if row["ret_t3_after"] is not None else None,
|
||||||
|
"top_reasons": [item["title"] for item in row["news"][:2] if item.get("title")],
|
||||||
|
"news": row["news"],
|
||||||
|
}
|
||||||
|
candidates.append(
|
||||||
|
(
|
||||||
|
date,
|
||||||
|
[_safe_float(row.get(key)) for key in vector_keys],
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
items = compute_similarity_scores(target_vector, candidates)[: max(1, min(int(top_k), 20))]
|
||||||
|
result = {
|
||||||
|
"symbol": symbol,
|
||||||
|
"target_date": target_date,
|
||||||
|
"target_features": {
|
||||||
|
"sentiment_score": round(target_row["sentiment_score"], 4),
|
||||||
|
"n_articles": target_row["n_articles"],
|
||||||
|
"ret_1d": round(target_row["ret_1d"] * 100, 2),
|
||||||
|
"intraday_ret": round(target_row["intraday_ret"] * 100, 2),
|
||||||
|
"high_relevance_count": target_row["high_relevance_count"],
|
||||||
|
},
|
||||||
|
"items": items,
|
||||||
|
}
|
||||||
|
store.upsert_similar_day_cache(symbol, target_date=target_date, payload=result, source="local")
|
||||||
|
return result
|
||||||
127
backend/explain/story_service.py
Normal file
127
backend/explain/story_service.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Stock story generation for explain view."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
|
||||||
|
|
||||||
|
def build_stock_story(
|
||||||
|
*,
|
||||||
|
symbol: str,
|
||||||
|
as_of_date: str,
|
||||||
|
price_rows: list[dict[str, Any]],
|
||||||
|
news_rows: list[dict[str, Any]],
|
||||||
|
) -> str:
|
||||||
|
"""Build a compact markdown story from enriched news and recent price action."""
|
||||||
|
lines = [f"## {symbol} Story", f"As of `{as_of_date}`"]
|
||||||
|
if not price_rows:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("No OHLC data available for story generation.")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
open_price = float(price_rows[0].get("open") or price_rows[0].get("close") or 0.0)
|
||||||
|
close_price = float(price_rows[-1].get("close") or 0.0)
|
||||||
|
price_change = ((close_price - open_price) / open_price) * 100 if open_price else 0.0
|
||||||
|
high_price = max(float(row.get("high") or row.get("close") or 0.0) for row in price_rows)
|
||||||
|
low_price = min(float(row.get("low") or row.get("close") or 0.0) for row in price_rows)
|
||||||
|
|
||||||
|
lines.append("")
|
||||||
|
lines.append(
|
||||||
|
f"The stock moved {'up' if price_change >= 0 else 'down'} "
|
||||||
|
f"{abs(price_change):.2f}% over the recent window, trading between "
|
||||||
|
f"${low_price:.2f} and ${high_price:.2f}."
|
||||||
|
)
|
||||||
|
|
||||||
|
positive = [row for row in news_rows if str(row.get("sentiment") or "").lower() == "positive"]
|
||||||
|
negative = [row for row in news_rows if str(row.get("sentiment") or "").lower() == "negative"]
|
||||||
|
lines.append("")
|
||||||
|
lines.append(
|
||||||
|
f"Recent coverage included {len(news_rows)} relevant articles "
|
||||||
|
f"({len(positive)} positive / {len(negative)} negative)."
|
||||||
|
)
|
||||||
|
|
||||||
|
if news_rows:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("### Key Moments")
|
||||||
|
ranked_rows = sorted(
|
||||||
|
news_rows,
|
||||||
|
key=lambda row: (
|
||||||
|
0 if str(row.get("relevance") or "").lower() in {"high", "relevant"} else 1,
|
||||||
|
-abs(float(row.get("ret_t0") or 0.0)),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for row in ranked_rows[:5]:
|
||||||
|
trade_date = row.get("trade_date") or str(row.get("date") or "")[:10]
|
||||||
|
title = row.get("title") or "Untitled"
|
||||||
|
key_discussion = row.get("key_discussion") or row.get("summary") or ""
|
||||||
|
sentiment = str(row.get("sentiment") or "neutral").lower()
|
||||||
|
lines.append(
|
||||||
|
f"- `{trade_date}` [{sentiment}] {title}: {str(key_discussion).strip()[:220]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if positive:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("### Bullish Threads")
|
||||||
|
for row in positive[:3]:
|
||||||
|
reason = row.get("reason_growth") or row.get("key_discussion") or row.get("summary") or row.get("title")
|
||||||
|
lines.append(f"- {str(reason).strip()[:220]}")
|
||||||
|
|
||||||
|
if negative:
|
||||||
|
lines.append("")
|
||||||
|
lines.append("### Bearish Threads")
|
||||||
|
for row in negative[:3]:
|
||||||
|
reason = row.get("reason_decrease") or row.get("key_discussion") or row.get("summary") or row.get("title")
|
||||||
|
lines.append(f"- {str(reason).strip()[:220]}")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_create_stock_story(
|
||||||
|
store: MarketStore,
|
||||||
|
*,
|
||||||
|
symbol: str,
|
||||||
|
as_of_date: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return cached story or build a new one from recent market context."""
|
||||||
|
cached = store.get_story_cache(symbol, as_of_date=as_of_date)
|
||||||
|
if cached:
|
||||||
|
return {
|
||||||
|
"symbol": symbol,
|
||||||
|
"as_of_date": as_of_date,
|
||||||
|
"story": cached.get("content") or "",
|
||||||
|
"source": cached.get("source") or "cache",
|
||||||
|
}
|
||||||
|
|
||||||
|
start_date = None
|
||||||
|
if len(as_of_date) >= 10:
|
||||||
|
target_date = datetime.strptime(as_of_date[:10], "%Y-%m-%d").date()
|
||||||
|
start_date = (target_date - timedelta(days=29)).isoformat()
|
||||||
|
|
||||||
|
price_rows = (
|
||||||
|
store.get_ohlc(symbol, start_date, as_of_date)
|
||||||
|
if start_date
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
news_rows = store.get_news_items_enriched(
|
||||||
|
symbol,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=as_of_date,
|
||||||
|
limit=40,
|
||||||
|
)
|
||||||
|
story = build_stock_story(
|
||||||
|
symbol=symbol,
|
||||||
|
as_of_date=as_of_date,
|
||||||
|
price_rows=price_rows,
|
||||||
|
news_rows=news_rows,
|
||||||
|
)
|
||||||
|
store.upsert_story_cache(symbol, as_of_date=as_of_date, content=story, source="local")
|
||||||
|
return {
|
||||||
|
"symbol": symbol,
|
||||||
|
"as_of_date": as_of_date,
|
||||||
|
"story": story,
|
||||||
|
"source": "local",
|
||||||
|
}
|
||||||
251
backend/gateway_server.py
Normal file
251
backend/gateway_server.py
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Gateway Server - Entry point for Gateway subprocess.
|
||||||
|
|
||||||
|
This module is launched as a subprocess by the Control Plane (FastAPI)
|
||||||
|
to run the Data Plane (Gateway + Pipeline).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from contextlib import AsyncExitStack
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||||
|
from backend.agents.prompt_loader import get_prompt_loader
|
||||||
|
from backend.agents.workspace_manager import WorkspaceManager
|
||||||
|
from backend.config.constants import ANALYST_TYPES
|
||||||
|
from backend.core.pipeline import TradingPipeline
|
||||||
|
from backend.core.pipeline_runner import create_agents, create_long_term_memory
|
||||||
|
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||||
|
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||||
|
from backend.runtime.manager import (
|
||||||
|
TradingRuntimeManager,
|
||||||
|
set_global_runtime_manager,
|
||||||
|
clear_global_runtime_manager,
|
||||||
|
)
|
||||||
|
from backend.services.gateway import Gateway
|
||||||
|
from backend.services.market import MarketService
|
||||||
|
from backend.services.storage import StorageService
|
||||||
|
from backend.utils.settlement import SettlementCoordinator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
_prompt_loader = get_prompt_loader()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_gateway(
|
||||||
|
run_id: str,
|
||||||
|
run_dir: Path,
|
||||||
|
bootstrap: dict,
|
||||||
|
port: int
|
||||||
|
):
|
||||||
|
"""Run Gateway with Pipeline."""
|
||||||
|
|
||||||
|
# Extract config
|
||||||
|
tickers = bootstrap.get("tickers", ["AAPL", "MSFT"])
|
||||||
|
initial_cash = float(bootstrap.get("initial_cash", 100000.0))
|
||||||
|
margin_requirement = float(bootstrap.get("margin_requirement", 0.0))
|
||||||
|
max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2))
|
||||||
|
schedule_mode = bootstrap.get("schedule_mode", "daily")
|
||||||
|
trigger_time = bootstrap.get("trigger_time", "09:30")
|
||||||
|
interval_minutes = int(bootstrap.get("interval_minutes", 60))
|
||||||
|
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0)) # 0 = disabled
|
||||||
|
mode = bootstrap.get("mode", "live")
|
||||||
|
start_date = bootstrap.get("start_date")
|
||||||
|
end_date = bootstrap.get("end_date")
|
||||||
|
enable_memory = bootstrap.get("enable_memory", False)
|
||||||
|
poll_interval = int(bootstrap.get("poll_interval", 10))
|
||||||
|
enable_mock = bootstrap.get("enable_mock", False)
|
||||||
|
|
||||||
|
is_backtest = mode == "backtest"
|
||||||
|
is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
|
||||||
|
|
||||||
|
logger.info(f"[Gateway Server] Starting run {run_id} on port {port}")
|
||||||
|
|
||||||
|
# Create runtime manager
|
||||||
|
runtime_manager = TradingRuntimeManager(
|
||||||
|
config_name=run_id,
|
||||||
|
run_dir=run_dir,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
)
|
||||||
|
runtime_manager.prepare_run()
|
||||||
|
set_global_runtime_manager(runtime_manager)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with AsyncExitStack() as stack:
|
||||||
|
# Create services
|
||||||
|
market_service = MarketService(
|
||||||
|
tickers=tickers,
|
||||||
|
poll_interval=poll_interval,
|
||||||
|
mock_mode=is_mock and not is_backtest,
|
||||||
|
backtest_mode=is_backtest,
|
||||||
|
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
|
||||||
|
backtest_start_date=start_date if is_backtest else None,
|
||||||
|
backtest_end_date=end_date if is_backtest else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
storage_service = StorageService(
|
||||||
|
dashboard_dir=run_dir / "team_dashboard",
|
||||||
|
initial_cash=initial_cash,
|
||||||
|
config_name=run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not storage_service.files["summary"].exists():
|
||||||
|
storage_service.initialize_empty_dashboard()
|
||||||
|
else:
|
||||||
|
storage_service.update_leaderboard_model_info()
|
||||||
|
|
||||||
|
# Create agents
|
||||||
|
analysts, risk_manager, pm, long_term_memories = create_agents(
|
||||||
|
run_id=run_id,
|
||||||
|
run_dir=run_dir,
|
||||||
|
initial_cash=initial_cash,
|
||||||
|
margin_requirement=margin_requirement,
|
||||||
|
enable_long_term_memory=enable_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register agents
|
||||||
|
for agent in analysts + [risk_manager, pm]:
|
||||||
|
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||||
|
if agent_id:
|
||||||
|
runtime_manager.register_agent(agent_id)
|
||||||
|
|
||||||
|
# Load portfolio state
|
||||||
|
portfolio_state = storage_service.load_portfolio_state()
|
||||||
|
pm.load_portfolio_state(portfolio_state)
|
||||||
|
|
||||||
|
# Create settlement coordinator
|
||||||
|
settlement_coordinator = SettlementCoordinator(
|
||||||
|
storage=storage_service,
|
||||||
|
initial_capital=initial_cash,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create pipeline
|
||||||
|
pipeline = TradingPipeline(
|
||||||
|
analysts=analysts,
|
||||||
|
risk_manager=risk_manager,
|
||||||
|
portfolio_manager=pm,
|
||||||
|
settlement_coordinator=settlement_coordinator,
|
||||||
|
max_comm_cycles=max_comm_cycles,
|
||||||
|
runtime_manager=runtime_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create scheduler
|
||||||
|
scheduler_callback = None
|
||||||
|
live_scheduler = None
|
||||||
|
|
||||||
|
if is_backtest:
|
||||||
|
backtest_scheduler = BacktestScheduler(
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
trading_calendar="NYSE",
|
||||||
|
delay_between_days=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def scheduler_callback_fn(callback):
|
||||||
|
await backtest_scheduler.start(callback)
|
||||||
|
|
||||||
|
scheduler_callback = scheduler_callback_fn
|
||||||
|
else:
|
||||||
|
live_scheduler = Scheduler(
|
||||||
|
mode=schedule_mode,
|
||||||
|
trigger_time=trigger_time,
|
||||||
|
interval_minutes=interval_minutes,
|
||||||
|
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
|
||||||
|
config={"config_name": run_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def scheduler_callback_fn(callback):
|
||||||
|
await live_scheduler.start(callback)
|
||||||
|
|
||||||
|
scheduler_callback = scheduler_callback_fn
|
||||||
|
|
||||||
|
# Enter long-term memory contexts
|
||||||
|
for memory in long_term_memories:
|
||||||
|
await stack.enter_async_context(memory)
|
||||||
|
|
||||||
|
# Create Gateway
|
||||||
|
gateway = Gateway(
|
||||||
|
market_service=market_service,
|
||||||
|
storage_service=storage_service,
|
||||||
|
pipeline=pipeline,
|
||||||
|
scheduler_callback=scheduler_callback,
|
||||||
|
config={
|
||||||
|
"mode": mode,
|
||||||
|
"mock_mode": is_mock,
|
||||||
|
"backtest_mode": is_backtest,
|
||||||
|
"tickers": tickers,
|
||||||
|
"config_name": run_id,
|
||||||
|
"schedule_mode": schedule_mode,
|
||||||
|
"interval_minutes": interval_minutes,
|
||||||
|
"trigger_time": trigger_time,
|
||||||
|
"heartbeat_interval": heartbeat_interval,
|
||||||
|
"initial_cash": initial_cash,
|
||||||
|
"margin_requirement": margin_requirement,
|
||||||
|
"max_comm_cycles": max_comm_cycles,
|
||||||
|
"enable_memory": enable_memory,
|
||||||
|
},
|
||||||
|
scheduler=live_scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start Gateway (blocks until shutdown)
|
||||||
|
logger.info(f"[Gateway Server] Gateway starting on port {port}")
|
||||||
|
await gateway.start(host="0.0.0.0", port=port)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("[Gateway Server] Cancelled")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
logger.info("[Gateway Server] Cleaning up")
|
||||||
|
clear_global_runtime_manager()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point."""
|
||||||
|
parser = argparse.ArgumentParser(description="Gateway Server")
|
||||||
|
parser.add_argument("--run-id", required=True, help="Run identifier")
|
||||||
|
parser.add_argument("--run-dir", required=True, help="Run directory path")
|
||||||
|
parser.add_argument("--port", type=int, default=8765, help="WebSocket port")
|
||||||
|
parser.add_argument("--bootstrap", required=True, help="Bootstrap config as JSON")
|
||||||
|
parser.add_argument("--verbose", action="store_true", help="Verbose logging")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
level = logging.DEBUG if args.verbose else logging.INFO
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
format="%(asctime)s | %(levelname)-7s | %(name)s:%(lineno)d - %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse bootstrap
|
||||||
|
bootstrap = json.loads(args.bootstrap)
|
||||||
|
run_dir = Path(args.run_dir)
|
||||||
|
|
||||||
|
# Run
|
||||||
|
try:
|
||||||
|
asyncio.run(run_gateway(
|
||||||
|
run_id=args.run_id,
|
||||||
|
run_dir=run_dir,
|
||||||
|
bootstrap=bootstrap,
|
||||||
|
port=args.port
|
||||||
|
))
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("[Gateway Server] Interrupted by user")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"[Gateway Server] Fatal error: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -3,9 +3,11 @@
|
|||||||
AgentScope Native Model Factory
|
AgentScope Native Model Factory
|
||||||
Uses native AgentScope model classes for LLM calls
|
Uses native AgentScope model classes for LLM calls
|
||||||
"""
|
"""
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
|
||||||
from agentscope.formatter import (
|
from agentscope.formatter import (
|
||||||
AnthropicChatFormatter,
|
AnthropicChatFormatter,
|
||||||
DashScopeChatFormatter,
|
DashScopeChatFormatter,
|
||||||
@@ -26,6 +28,244 @@ from backend.config.env_config import (
|
|||||||
get_env_str,
|
get_env_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Retry wrapper types
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class RetryChatModel:
|
||||||
|
"""Wraps an AgentScope model with automatic retry for transient errors.
|
||||||
|
|
||||||
|
Based on CoPaw's RetryChatModel design. Handles rate limits, timeouts,
|
||||||
|
and other transient failures with exponential backoff.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_MAX_RETRIES = 3
|
||||||
|
DEFAULT_INITIAL_DELAY = 1.0
|
||||||
|
DEFAULT_MAX_DELAY = 60.0
|
||||||
|
DEFAULT_BACKOFF_MULTIPLIER = 2.0
|
||||||
|
|
||||||
|
# Transient error codes/messages that should trigger retry
|
||||||
|
TRANSIENT_ERROR_KEYWORDS = frozenset([
|
||||||
|
"rate_limit",
|
||||||
|
"429",
|
||||||
|
"timeout",
|
||||||
|
"503",
|
||||||
|
"502",
|
||||||
|
"504",
|
||||||
|
"connection",
|
||||||
|
"temporary",
|
||||||
|
"overloaded",
|
||||||
|
"too_many_requests",
|
||||||
|
])
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Any,
|
||||||
|
max_retries: int = DEFAULT_MAX_RETRIES,
|
||||||
|
initial_delay: float = DEFAULT_INITIAL_DELAY,
|
||||||
|
max_delay: float = DEFAULT_MAX_DELAY,
|
||||||
|
backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER,
|
||||||
|
on_retry: Optional[Callable[[int, Exception, float], None]] = None,
|
||||||
|
):
|
||||||
|
"""Initialize retry wrapper.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The underlying AgentScope model to wrap
|
||||||
|
max_retries: Maximum number of retry attempts
|
||||||
|
initial_delay: Initial delay in seconds before first retry
|
||||||
|
max_delay: Maximum delay between retries
|
||||||
|
backoff_multiplier: Multiplier for exponential backoff
|
||||||
|
on_retry: Optional callback(retry_count, exception, delay) for logging
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
self._max_retries = max_retries
|
||||||
|
self._initial_delay = initial_delay
|
||||||
|
self._max_delay = max_delay
|
||||||
|
self._backoff_multiplier = backoff_multiplier
|
||||||
|
self._on_retry = on_retry
|
||||||
|
self._total_tokens_used = 0
|
||||||
|
self._total_cost = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
return getattr(self._model, "model_name", str(self._model))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens_used(self) -> int:
|
||||||
|
return self._total_tokens_used
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_cost(self) -> float:
|
||||||
|
return self._total_cost
|
||||||
|
|
||||||
|
def _is_transient_error(self, error: Exception) -> bool:
|
||||||
|
"""Check if an error is transient and should be retried.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
error: The exception to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if the error is transient
|
||||||
|
"""
|
||||||
|
error_str = str(error).lower()
|
||||||
|
for keyword in self.TRANSIENT_ERROR_KEYWORDS:
|
||||||
|
if keyword in error_str:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _calculate_delay(self, retry_count: int) -> float:
|
||||||
|
"""Calculate delay for given retry attempt with exponential backoff.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
retry_count: Current retry attempt number (1-based)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Delay in seconds
|
||||||
|
"""
|
||||||
|
delay = self._initial_delay * (self._backoff_multiplier ** (retry_count - 1))
|
||||||
|
return min(delay, self._max_delay)
|
||||||
|
|
||||||
|
def _call_with_retry(self, func: Callable[..., T], *args, **kwargs) -> T:
|
||||||
|
"""Call a function with retry logic for transient errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Function to call
|
||||||
|
*args: Positional arguments
|
||||||
|
**kwargs: Keyword arguments
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Result from func
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Last exception if all retries exhausted
|
||||||
|
"""
|
||||||
|
last_error: Optional[Exception] = None
|
||||||
|
|
||||||
|
for attempt in range(1, self._max_retries + 1):
|
||||||
|
try:
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Track usage if available
|
||||||
|
if hasattr(result, "usage") and result.usage:
|
||||||
|
usage = result.usage
|
||||||
|
self._total_tokens_used += getattr(usage, "total_tokens", 0)
|
||||||
|
self._total_cost += getattr(usage, "cost", 0.0)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
last_error = e
|
||||||
|
|
||||||
|
if attempt >= self._max_retries:
|
||||||
|
logger.error(
|
||||||
|
"RetryChatModel: Max retries (%d) exhausted for %s",
|
||||||
|
self._max_retries,
|
||||||
|
self.model_name,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
if not self._is_transient_error(e):
|
||||||
|
logger.warning(
|
||||||
|
"RetryChatModel: Non-transient error, not retrying: %s",
|
||||||
|
str(e),
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
delay = self._calculate_delay(attempt)
|
||||||
|
logger.warning(
|
||||||
|
"RetryChatModel: Transient error on attempt %d/%d, "
|
||||||
|
"retrying in %.1fs: %s",
|
||||||
|
attempt,
|
||||||
|
self._max_retries,
|
||||||
|
delay,
|
||||||
|
str(e)[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._on_retry:
|
||||||
|
self._on_retry(attempt, e, delay)
|
||||||
|
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
if last_error is not None:
|
||||||
|
raise last_error
|
||||||
|
raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
"""Forward calls to the wrapped model with retry logic."""
|
||||||
|
return self._call_with_retry(self._model, *args, **kwargs)
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
"""Proxy attribute access to the wrapped model."""
|
||||||
|
return getattr(self._model, name)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenRecordingModelWrapper:
|
||||||
|
"""Wraps a model to track token usage per provider.
|
||||||
|
|
||||||
|
Based on CoPaw's TokenRecordingModelWrapper design.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: Any):
|
||||||
|
"""Initialize token recorder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The underlying AgentScope model to wrap
|
||||||
|
"""
|
||||||
|
self._model = model
|
||||||
|
self._total_tokens = 0
|
||||||
|
self._prompt_tokens = 0
|
||||||
|
self._completion_tokens = 0
|
||||||
|
self._total_cost = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
return getattr(self._model, "model_name", str(self._model))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_tokens(self) -> int:
|
||||||
|
return self._total_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prompt_tokens(self) -> int:
|
||||||
|
return self._prompt_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def completion_tokens(self) -> int:
|
||||||
|
return self._completion_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_cost(self) -> float:
|
||||||
|
return self._total_cost
|
||||||
|
|
||||||
|
def record_usage(self, usage: Any) -> None:
|
||||||
|
"""Record token usage from a model response.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
usage: Usage object from model response
|
||||||
|
"""
|
||||||
|
if usage is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._prompt_tokens += getattr(usage, "prompt_tokens", 0)
|
||||||
|
self._completion_tokens += getattr(usage, "completion_tokens", 0)
|
||||||
|
self._total_tokens += getattr(usage, "total_tokens", 0)
|
||||||
|
self._total_cost += getattr(usage, "cost", 0.0)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs) -> Any:
|
||||||
|
"""Forward calls and record usage."""
|
||||||
|
result = self._model(*args, **kwargs)
|
||||||
|
|
||||||
|
if hasattr(result, "usage") and result.usage:
|
||||||
|
self.record_usage(result.usage)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
"""Proxy attribute access to the wrapped model."""
|
||||||
|
return getattr(self._model, name)
|
||||||
|
|
||||||
|
|
||||||
class ModelProvider(Enum):
|
class ModelProvider(Enum):
|
||||||
"""Supported model providers"""
|
"""Supported model providers"""
|
||||||
|
|||||||
@@ -16,55 +16,48 @@ from dotenv import load_dotenv
|
|||||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||||
from backend.agents.skills_manager import SkillsManager
|
from backend.agents.skills_manager import SkillsManager
|
||||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||||
from backend.agents.prompt_loader import PromptLoader
|
from backend.agents.prompt_loader import get_prompt_loader
|
||||||
from backend.agents.workspace_manager import WorkspaceManager
|
from backend.agents.workspace_manager import WorkspaceManager
|
||||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
from backend.config.bootstrap_config import resolve_runtime_config
|
||||||
from backend.config.constants import ANALYST_TYPES
|
from backend.config.constants import ANALYST_TYPES
|
||||||
from backend.config.env_config import get_env_float, get_env_int, get_env_list
|
|
||||||
from backend.core.pipeline import TradingPipeline
|
from backend.core.pipeline import TradingPipeline
|
||||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||||
from backend.utils.settlement import SettlementCoordinator
|
|
||||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||||
|
from backend.api.runtime import register_runtime_manager, unregister_runtime_manager
|
||||||
|
from backend.runtime.manager import (
|
||||||
|
TradingRuntimeManager,
|
||||||
|
set_global_runtime_manager,
|
||||||
|
clear_global_runtime_manager,
|
||||||
|
)
|
||||||
from backend.services.gateway import Gateway
|
from backend.services.gateway import Gateway
|
||||||
from backend.services.market import MarketService
|
from backend.services.market import MarketService
|
||||||
from backend.services.storage import StorageService
|
from backend.services.storage import StorageService
|
||||||
|
from backend.utils.settlement import SettlementCoordinator
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
loguru.logger.disable("flowllm")
|
loguru.logger.disable("flowllm")
|
||||||
loguru.logger.disable("reme_ai")
|
loguru.logger.disable("reme_ai")
|
||||||
_prompt_loader = PromptLoader()
|
_prompt_loader = get_prompt_loader()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_run_dir(config_name: str) -> Path:
|
||||||
|
"""Return the canonical run-scoped directory for a config."""
|
||||||
|
project_root = Path(__file__).resolve().parents[1]
|
||||||
|
return WorkspaceManager(project_root=project_root).get_run_dir(config_name)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_runtime_config(args) -> dict:
|
def _resolve_runtime_config(args) -> dict:
|
||||||
"""Merge env defaults with run-scoped bootstrap config."""
|
"""Merge env defaults with run-scoped bootstrap config."""
|
||||||
project_root = Path(__file__).resolve().parents[1]
|
project_root = Path(__file__).resolve().parents[1]
|
||||||
bootstrap = get_bootstrap_config_for_run(project_root, args.config_name)
|
return resolve_runtime_config(
|
||||||
|
project_root=project_root,
|
||||||
return {
|
config_name=args.config_name,
|
||||||
"tickers": bootstrap.get("tickers")
|
enable_memory=args.enable_memory,
|
||||||
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
|
schedule_mode=args.schedule_mode,
|
||||||
"initial_cash": float(
|
interval_minutes=args.interval_minutes,
|
||||||
bootstrap.get(
|
trigger_time=args.trigger_time,
|
||||||
"initial_cash",
|
)
|
||||||
get_env_float("INITIAL_CASH", 100000.0),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
"margin_requirement": float(
|
|
||||||
bootstrap.get(
|
|
||||||
"margin_requirement",
|
|
||||||
get_env_float("MARGIN_REQUIREMENT", 0.0),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
"max_comm_cycles": int(
|
|
||||||
bootstrap.get(
|
|
||||||
"max_comm_cycles",
|
|
||||||
get_env_int("MAX_COMM_CYCLES", 2),
|
|
||||||
),
|
|
||||||
),
|
|
||||||
"enable_memory": args.enable_memory
|
|
||||||
or bool(bootstrap.get("enable_memory", False)),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def create_long_term_memory(agent_name: str, config_name: str):
|
def create_long_term_memory(agent_name: str, config_name: str):
|
||||||
@@ -82,7 +75,7 @@ def create_long_term_memory(agent_name: str, config_name: str):
|
|||||||
logger.warning("MEMORY_API_KEY not set, long-term memory disabled")
|
logger.warning("MEMORY_API_KEY not set, long-term memory disabled")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
memory_dir = str(Path(config_name) / "memory")
|
memory_dir = str(_get_run_dir(config_name) / "memory")
|
||||||
|
|
||||||
return ReMeTaskLongTermMemory(
|
return ReMeTaskLongTermMemory(
|
||||||
agent_name=agent_name,
|
agent_name=agent_name,
|
||||||
@@ -226,6 +219,15 @@ async def run_with_gateway(args):
|
|||||||
initial_cash = runtime_config["initial_cash"]
|
initial_cash = runtime_config["initial_cash"]
|
||||||
margin_requirement = runtime_config["margin_requirement"]
|
margin_requirement = runtime_config["margin_requirement"]
|
||||||
|
|
||||||
|
runtime_manager = TradingRuntimeManager(
|
||||||
|
config_name=config_name,
|
||||||
|
run_dir=_get_run_dir(config_name),
|
||||||
|
bootstrap=runtime_config,
|
||||||
|
)
|
||||||
|
runtime_manager.prepare_run()
|
||||||
|
set_global_runtime_manager(runtime_manager)
|
||||||
|
register_runtime_manager(runtime_manager)
|
||||||
|
|
||||||
# Create market service
|
# Create market service
|
||||||
market_service = MarketService(
|
market_service = MarketService(
|
||||||
tickers=tickers,
|
tickers=tickers,
|
||||||
@@ -241,7 +243,7 @@ async def run_with_gateway(args):
|
|||||||
|
|
||||||
# Create storage service
|
# Create storage service
|
||||||
storage_service = StorageService(
|
storage_service = StorageService(
|
||||||
dashboard_dir=Path(config_name) / "team_dashboard",
|
dashboard_dir=_get_run_dir(config_name) / "team_dashboard",
|
||||||
initial_cash=initial_cash,
|
initial_cash=initial_cash,
|
||||||
config_name=config_name,
|
config_name=config_name,
|
||||||
)
|
)
|
||||||
@@ -258,6 +260,10 @@ async def run_with_gateway(args):
|
|||||||
margin_requirement=margin_requirement,
|
margin_requirement=margin_requirement,
|
||||||
enable_long_term_memory=runtime_config["enable_memory"],
|
enable_long_term_memory=runtime_config["enable_memory"],
|
||||||
)
|
)
|
||||||
|
for agent in analysts + [risk_manager, pm]:
|
||||||
|
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||||
|
if agent_id:
|
||||||
|
runtime_manager.register_agent(agent_id)
|
||||||
portfolio_state = storage_service.load_portfolio_state()
|
portfolio_state = storage_service.load_portfolio_state()
|
||||||
pm.load_portfolio_state(portfolio_state)
|
pm.load_portfolio_state(portfolio_state)
|
||||||
|
|
||||||
@@ -272,11 +278,13 @@ async def run_with_gateway(args):
|
|||||||
portfolio_manager=pm,
|
portfolio_manager=pm,
|
||||||
settlement_coordinator=settlement_coordinator,
|
settlement_coordinator=settlement_coordinator,
|
||||||
max_comm_cycles=runtime_config["max_comm_cycles"],
|
max_comm_cycles=runtime_config["max_comm_cycles"],
|
||||||
|
runtime_manager=runtime_manager,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create scheduler callback
|
# Create scheduler callback
|
||||||
scheduler_callback = None
|
scheduler_callback = None
|
||||||
trading_dates = []
|
trading_dates = []
|
||||||
|
live_scheduler = None
|
||||||
|
|
||||||
if is_backtest:
|
if is_backtest:
|
||||||
backtest_scheduler = BacktestScheduler(
|
backtest_scheduler = BacktestScheduler(
|
||||||
@@ -292,10 +300,11 @@ async def run_with_gateway(args):
|
|||||||
|
|
||||||
scheduler_callback = scheduler_callback_fn
|
scheduler_callback = scheduler_callback_fn
|
||||||
else:
|
else:
|
||||||
# Live mode: use daily scheduler with NYSE timezone
|
# Live mode: use daily or intraday scheduler with NYSE timezone
|
||||||
live_scheduler = Scheduler(
|
live_scheduler = Scheduler(
|
||||||
mode="daily",
|
mode=runtime_config["schedule_mode"],
|
||||||
trigger_time=args.trigger_time,
|
trigger_time=runtime_config["trigger_time"],
|
||||||
|
interval_minutes=runtime_config["interval_minutes"],
|
||||||
config={"config_name": config_name},
|
config={"config_name": config_name},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -316,7 +325,15 @@ async def run_with_gateway(args):
|
|||||||
"backtest_mode": is_backtest,
|
"backtest_mode": is_backtest,
|
||||||
"tickers": tickers,
|
"tickers": tickers,
|
||||||
"config_name": config_name,
|
"config_name": config_name,
|
||||||
|
"schedule_mode": runtime_config["schedule_mode"],
|
||||||
|
"interval_minutes": runtime_config["interval_minutes"],
|
||||||
|
"trigger_time": runtime_config["trigger_time"],
|
||||||
|
"initial_cash": initial_cash,
|
||||||
|
"margin_requirement": margin_requirement,
|
||||||
|
"max_comm_cycles": runtime_config["max_comm_cycles"],
|
||||||
|
"enable_memory": runtime_config["enable_memory"],
|
||||||
},
|
},
|
||||||
|
scheduler=live_scheduler if not is_backtest else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_backtest:
|
if is_backtest:
|
||||||
@@ -324,9 +341,13 @@ async def run_with_gateway(args):
|
|||||||
|
|
||||||
# Start long-term memory contexts and run gateway
|
# Start long-term memory contexts and run gateway
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as stack:
|
||||||
|
try:
|
||||||
for memory in long_term_memories:
|
for memory in long_term_memories:
|
||||||
await stack.enter_async_context(memory)
|
await stack.enter_async_context(memory)
|
||||||
await gateway.start(host=args.host, port=args.port)
|
await gateway.start(host=args.host, port=args.port)
|
||||||
|
finally:
|
||||||
|
unregister_runtime_manager()
|
||||||
|
clear_global_runtime_manager()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -337,7 +358,13 @@ def main():
|
|||||||
parser.add_argument("--config-name", default="mock")
|
parser.add_argument("--config-name", default="mock")
|
||||||
parser.add_argument("--host", default="0.0.0.0")
|
parser.add_argument("--host", default="0.0.0.0")
|
||||||
parser.add_argument("--port", type=int, default=8765)
|
parser.add_argument("--port", type=int, default=8765)
|
||||||
|
parser.add_argument(
|
||||||
|
"--schedule-mode",
|
||||||
|
choices=["daily", "intraday"],
|
||||||
|
default="daily",
|
||||||
|
)
|
||||||
parser.add_argument("--trigger-time", default="09:30") # NYSE market open
|
parser.add_argument("--trigger-time", default="09:30") # NYSE market open
|
||||||
|
parser.add_argument("--interval-minutes", type=int, default=60)
|
||||||
parser.add_argument("--poll-interval", type=int, default=10)
|
parser.add_argument("--poll-interval", type=int, default=10)
|
||||||
parser.add_argument("--start-date")
|
parser.add_argument("--start-date")
|
||||||
parser.add_argument("--end-date")
|
parser.add_argument("--end-date")
|
||||||
|
|||||||
41
backend/process/models.py
Normal file
41
backend/process/models.py
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Data models for lightweight process supervision."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessRunState(str, Enum):
|
||||||
|
"""Execution state for supervised runs."""
|
||||||
|
|
||||||
|
PENDING = "pending"
|
||||||
|
RUNNING = "running"
|
||||||
|
COMPLETED = "completed"
|
||||||
|
FAILED = "failed"
|
||||||
|
CANCELLED = "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessRun:
|
||||||
|
"""Represents a supervised process run."""
|
||||||
|
|
||||||
|
run_id: str
|
||||||
|
command: str
|
||||||
|
scope_key: str
|
||||||
|
state: ProcessRunState = ProcessRunState.PENDING
|
||||||
|
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
created_at: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"run_id": self.run_id,
|
||||||
|
"command": self.command,
|
||||||
|
"scope_key": self.scope_key,
|
||||||
|
"state": self.state.value,
|
||||||
|
"metadata": self.metadata,
|
||||||
|
"created_at": self.created_at.isoformat(),
|
||||||
|
"updated_at": self.updated_at.isoformat(),
|
||||||
|
}
|
||||||
35
backend/process/registry.py
Normal file
35
backend/process/registry.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Registry for managing supervised process metadata."""
|
||||||
|
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Dict, Iterable, Optional
|
||||||
|
|
||||||
|
from .models import ProcessRun
|
||||||
|
|
||||||
|
|
||||||
|
class RunRegistry:
|
||||||
|
"""In-memory registry for tracked process runs."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._runs: Dict[str, ProcessRun] = {}
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def add(self, run: ProcessRun) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._runs[run.run_id] = run
|
||||||
|
|
||||||
|
def get(self, run_id: str) -> Optional[ProcessRun]:
|
||||||
|
with self._lock:
|
||||||
|
return self._runs.get(run_id)
|
||||||
|
|
||||||
|
def list(self) -> Iterable[ProcessRun]:
|
||||||
|
with self._lock:
|
||||||
|
return list(self._runs.values())
|
||||||
|
|
||||||
|
def update(self, run: ProcessRun) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._runs[run.run_id] = run
|
||||||
|
|
||||||
|
def remove(self, run_id: str) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._runs.pop(run_id, None)
|
||||||
61
backend/process/supervisor.py
Normal file
61
backend/process/supervisor.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Minimal supervisor for scripted tasks and long-running utilities."""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, Iterable, Optional
|
||||||
|
|
||||||
|
from .models import ProcessRun, ProcessRunState
|
||||||
|
from .registry import RunRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessSupervisor:
|
||||||
|
"""Tracks supervised runs without executing real processes yet."""
|
||||||
|
|
||||||
|
def __init__(self, registry: Optional[RunRegistry] = None) -> None:
|
||||||
|
self.registry = registry or RunRegistry()
|
||||||
|
|
||||||
|
def spawn(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
command: str,
|
||||||
|
scope_key: str,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> ProcessRun:
|
||||||
|
run = ProcessRun(
|
||||||
|
run_id=run_id,
|
||||||
|
command=command,
|
||||||
|
scope_key=scope_key,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
run.state = ProcessRunState.RUNNING
|
||||||
|
run.updated_at = datetime.utcnow()
|
||||||
|
self.registry.add(run)
|
||||||
|
return run
|
||||||
|
|
||||||
|
def update_state(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
state: ProcessRunState,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Optional[ProcessRun]:
|
||||||
|
run = self.registry.get(run_id)
|
||||||
|
if not run:
|
||||||
|
return None
|
||||||
|
run.state = state
|
||||||
|
run.metadata.update(metadata or {})
|
||||||
|
run.updated_at = datetime.utcnow()
|
||||||
|
self.registry.update(run)
|
||||||
|
return run
|
||||||
|
|
||||||
|
def cancel(self, run_id: str, reason: Optional[str] = None) -> Optional[ProcessRun]:
|
||||||
|
run = self.registry.get(run_id)
|
||||||
|
if not run:
|
||||||
|
return None
|
||||||
|
run.state = ProcessRunState.CANCELLED
|
||||||
|
run.metadata.setdefault("cancel_reason", reason or "manual")
|
||||||
|
run.updated_at = datetime.utcnow()
|
||||||
|
self.registry.update(run)
|
||||||
|
return run
|
||||||
|
|
||||||
|
def list_runs(self) -> Iterable[ProcessRun]:
|
||||||
|
return self.registry.list()
|
||||||
13
backend/runtime/__init__.py
Normal file
13
backend/runtime/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from .agent_runtime import AgentRuntimeState
|
||||||
|
from .context import TradingRunContext
|
||||||
|
from .manager import TradingRuntimeManager
|
||||||
|
from .registry import RuntimeRegistry
|
||||||
|
from .session import TradingSessionKey
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AgentRuntimeState",
|
||||||
|
"TradingRunContext",
|
||||||
|
"TradingRuntimeManager",
|
||||||
|
"RuntimeRegistry",
|
||||||
|
"TradingSessionKey",
|
||||||
|
]
|
||||||
26
backend/runtime/agent_runtime.py
Normal file
26
backend/runtime/agent_runtime.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentRuntimeState:
|
||||||
|
agent_id: str
|
||||||
|
status: str = "idle"
|
||||||
|
last_session: str | None = None
|
||||||
|
last_updated: datetime = field(default_factory=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
|
def update(self, status: str, session_key: str | None = None) -> None:
|
||||||
|
self.status = status
|
||||||
|
self.last_session = session_key
|
||||||
|
self.last_updated = datetime.now(UTC)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"agent_id": self.agent_id,
|
||||||
|
"status": self.status,
|
||||||
|
"last_session": self.last_session,
|
||||||
|
"last_updated": self.last_updated.isoformat(),
|
||||||
|
}
|
||||||
15
backend/runtime/context.py
Normal file
15
backend/runtime/context.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TradingRunContext:
|
||||||
|
config_name: str
|
||||||
|
run_dir: Path
|
||||||
|
bootstrap_values: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
def describe(self) -> str:
|
||||||
|
return f"Run {self.config_name} @ {self.run_dir}"
|
||||||
158
backend/runtime/manager.py
Normal file
158
backend/runtime/manager.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from datetime import datetime, UTC
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from .agent_runtime import AgentRuntimeState
|
||||||
|
from .context import TradingRunContext
|
||||||
|
from .registry import RuntimeRegistry
|
||||||
|
|
||||||
|
_global_runtime_manager: Optional["TradingRuntimeManager"] = None
|
||||||
|
_shutdown_event: Optional[asyncio.Event] = None
|
||||||
|
|
||||||
|
|
||||||
|
def set_global_runtime_manager(manager: "TradingRuntimeManager") -> None:
|
||||||
|
global _global_runtime_manager
|
||||||
|
_global_runtime_manager = manager
|
||||||
|
|
||||||
|
|
||||||
|
def clear_global_runtime_manager() -> None:
|
||||||
|
global _global_runtime_manager
|
||||||
|
_global_runtime_manager = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_global_runtime_manager() -> Optional["TradingRuntimeManager"]:
|
||||||
|
return _global_runtime_manager
|
||||||
|
|
||||||
|
|
||||||
|
def set_shutdown_event(event: asyncio.Event) -> None:
|
||||||
|
"""Set the global shutdown event for signaling runtime stop."""
|
||||||
|
global _shutdown_event
|
||||||
|
_shutdown_event = event
|
||||||
|
|
||||||
|
|
||||||
|
def clear_shutdown_event() -> None:
|
||||||
|
"""Clear the global shutdown event."""
|
||||||
|
global _shutdown_event
|
||||||
|
_shutdown_event = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_shutdown_event() -> Optional[asyncio.Event]:
|
||||||
|
"""Get the global shutdown event if set."""
|
||||||
|
return _shutdown_event
|
||||||
|
|
||||||
|
|
||||||
|
def is_shutdown_requested() -> bool:
|
||||||
|
"""Check if shutdown has been requested."""
|
||||||
|
return _shutdown_event is not None and _shutdown_event.is_set()
|
||||||
|
|
||||||
|
|
||||||
|
class TradingRuntimeManager:
|
||||||
|
def __init__(self, config_name: str, run_dir: Path, bootstrap: Optional[Dict[str, Any]] = None) -> None:
|
||||||
|
self.config_name = config_name
|
||||||
|
self.run_dir = run_dir
|
||||||
|
self.bootstrap = bootstrap or {}
|
||||||
|
self.context: Optional[TradingRunContext] = None
|
||||||
|
self.registry = RuntimeRegistry()
|
||||||
|
self.current_session_key: Optional[str] = None
|
||||||
|
self.events: List[Dict[str, Any]] = []
|
||||||
|
self.pending_approvals: Dict[str, Dict[str, Any]] = {}
|
||||||
|
self.snapshot_path = self.run_dir / "state" / "runtime_state.json"
|
||||||
|
|
||||||
|
def prepare_run(self) -> TradingRunContext:
|
||||||
|
self.run_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.context = TradingRunContext(
|
||||||
|
config_name=self.config_name,
|
||||||
|
run_dir=self.run_dir,
|
||||||
|
bootstrap_values=self.bootstrap,
|
||||||
|
)
|
||||||
|
self._persist_snapshot()
|
||||||
|
return self.context
|
||||||
|
|
||||||
|
def set_session_key(self, session_key: str) -> None:
|
||||||
|
self.current_session_key = session_key
|
||||||
|
self._persist_snapshot()
|
||||||
|
|
||||||
|
def log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||||
|
entry = {
|
||||||
|
"timestamp": datetime.now(UTC).isoformat(),
|
||||||
|
"event": event,
|
||||||
|
"details": details or {},
|
||||||
|
"session": self.current_session_key,
|
||||||
|
}
|
||||||
|
self.events.append(entry)
|
||||||
|
self._persist_snapshot()
|
||||||
|
return entry
|
||||||
|
|
||||||
|
def register_agent(self, agent_id: str) -> AgentRuntimeState:
|
||||||
|
state = AgentRuntimeState(agent_id=agent_id)
|
||||||
|
self.registry.register(agent_id, state)
|
||||||
|
self._persist_snapshot()
|
||||||
|
return state
|
||||||
|
|
||||||
|
def register_pending_approval(self, approval_id: str, payload: Dict[str, Any]) -> None:
|
||||||
|
payload.setdefault("status", "pending")
|
||||||
|
payload.setdefault("created_at", datetime.now(UTC).isoformat())
|
||||||
|
self.pending_approvals[approval_id] = payload
|
||||||
|
self._persist_snapshot()
|
||||||
|
|
||||||
|
def update_agent_status(
|
||||||
|
self,
|
||||||
|
agent_id: str,
|
||||||
|
status: str,
|
||||||
|
session_key: Optional[str] = None,
|
||||||
|
) -> AgentRuntimeState:
|
||||||
|
state = self.registry.get(agent_id)
|
||||||
|
if state is None:
|
||||||
|
state = self.register_agent(agent_id)
|
||||||
|
effective_session = session_key or self.current_session_key
|
||||||
|
state.update(status, effective_session)
|
||||||
|
self._persist_snapshot()
|
||||||
|
return state
|
||||||
|
|
||||||
|
def get_agent_state(self, agent_id: str) -> Optional[AgentRuntimeState]:
|
||||||
|
return self.registry.get(agent_id)
|
||||||
|
|
||||||
|
def list_agents(self) -> list[str]:
|
||||||
|
return self.registry.list_agents()
|
||||||
|
|
||||||
|
def resolve_pending_approval(self, approval_id: str, resolved_by: str, status: str) -> None:
|
||||||
|
entry = self.pending_approvals.get(approval_id)
|
||||||
|
if not entry:
|
||||||
|
return
|
||||||
|
entry["status"] = status
|
||||||
|
entry["resolved_at"] = datetime.now(UTC).isoformat()
|
||||||
|
entry["resolved_by"] = resolved_by
|
||||||
|
self._persist_snapshot()
|
||||||
|
|
||||||
|
def list_pending_approvals(self) -> List[Dict[str, Any]]:
|
||||||
|
return list(self.pending_approvals.values())
|
||||||
|
|
||||||
|
def build_snapshot(self) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"context": {
|
||||||
|
"config_name": self.context.config_name,
|
||||||
|
"run_dir": str(self.context.run_dir),
|
||||||
|
"bootstrap_values": self.context.bootstrap_values,
|
||||||
|
}
|
||||||
|
if self.context
|
||||||
|
else None,
|
||||||
|
"current_session_key": self.current_session_key,
|
||||||
|
"agents": [
|
||||||
|
state.to_dict()
|
||||||
|
for agent_id in self.registry.list_agents()
|
||||||
|
if (state := self.registry.get(agent_id)) is not None
|
||||||
|
],
|
||||||
|
"events": self.events,
|
||||||
|
"pending_approvals": self.list_pending_approvals(),
|
||||||
|
}
|
||||||
|
|
||||||
|
def _persist_snapshot(self) -> None:
|
||||||
|
self.snapshot_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.snapshot_path.write_text(
|
||||||
|
json.dumps(self.build_snapshot(), ensure_ascii=False, indent=2),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
20
backend/runtime/registry.py
Normal file
20
backend/runtime/registry.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeRegistry:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._states: Dict[str, "AgentRuntimeState"] = {}
|
||||||
|
|
||||||
|
def register(self, agent_id: str, state: "AgentRuntimeState") -> None:
|
||||||
|
self._states[agent_id] = state
|
||||||
|
|
||||||
|
def get(self, agent_id: str) -> Optional["AgentRuntimeState"]:
|
||||||
|
return self._states.get(agent_id)
|
||||||
|
|
||||||
|
def list_agents(self) -> list[str]:
|
||||||
|
return list(self._states.keys())
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._states.clear()
|
||||||
14
backend/runtime/session.py
Normal file
14
backend/runtime/session.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TradingSessionKey:
|
||||||
|
date: str
|
||||||
|
ticker: str | None = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if not self.date:
|
||||||
|
raise ValueError("Session must have a date")
|
||||||
|
|
||||||
|
def key(self) -> str:
|
||||||
|
return f"{self.date}:{self.ticker or 'all'}"
|
||||||
@@ -5,21 +5,43 @@ WebSocket Gateway for frontend communication
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Set
|
from typing import Any, Callable, Dict, List, Optional, Set
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
from websockets.server import WebSocketServerProtocol
|
from websockets.asyncio.server import ServerConnection
|
||||||
|
|
||||||
from backend.utils.msg_adapter import FrontendAdapter
|
from backend.data.provider_utils import normalize_symbol
|
||||||
|
from backend.domains import news as news_domain
|
||||||
|
from backend.llm.models import get_agent_model_info
|
||||||
from backend.utils.terminal_dashboard import get_dashboard
|
from backend.utils.terminal_dashboard import get_dashboard
|
||||||
from backend.core.pipeline import TradingPipeline
|
from backend.core.pipeline import TradingPipeline
|
||||||
from backend.core.state_sync import StateSync
|
from backend.core.state_sync import StateSync
|
||||||
from backend.services.market import MarketService
|
from backend.services.market import MarketService
|
||||||
from backend.services.storage import StorageService
|
from backend.services.storage import StorageService
|
||||||
from backend.data.provider_router import get_provider_router
|
from backend.data.provider_router import get_provider_router
|
||||||
|
from backend.tools.technical_signals import StockTechnicalAnalyzer
|
||||||
|
from backend.core.scheduler import Scheduler
|
||||||
|
from backend.services import gateway_admin_handlers
|
||||||
|
from backend.services import gateway_cycle_support
|
||||||
|
from backend.services import gateway_runtime_support
|
||||||
|
from backend.services import gateway_stock_handlers
|
||||||
|
from shared.client import NewsServiceClient
|
||||||
|
from shared.client import TradingServiceClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
EDITABLE_AGENT_WORKSPACE_FILES = {
|
||||||
|
"SOUL.md",
|
||||||
|
"PROFILE.md",
|
||||||
|
"AGENTS.md",
|
||||||
|
"MEMORY.md",
|
||||||
|
"POLICY.md",
|
||||||
|
"HEARTBEAT.md",
|
||||||
|
"ROLE.md",
|
||||||
|
"STYLE.md",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class Gateway:
|
class Gateway:
|
||||||
@@ -32,12 +54,14 @@ class Gateway:
|
|||||||
pipeline: TradingPipeline,
|
pipeline: TradingPipeline,
|
||||||
state_sync: Optional[StateSync] = None,
|
state_sync: Optional[StateSync] = None,
|
||||||
scheduler_callback: Optional[Callable] = None,
|
scheduler_callback: Optional[Callable] = None,
|
||||||
|
scheduler: Optional[Scheduler] = None,
|
||||||
config: Dict[str, Any] = None,
|
config: Dict[str, Any] = None,
|
||||||
):
|
):
|
||||||
self.market_service = market_service
|
self.market_service = market_service
|
||||||
self.storage = storage_service
|
self.storage = storage_service
|
||||||
self.pipeline = pipeline
|
self.pipeline = pipeline
|
||||||
self.scheduler_callback = scheduler_callback
|
self.scheduler_callback = scheduler_callback
|
||||||
|
self.scheduler = scheduler
|
||||||
self.config = config or {}
|
self.config = config or {}
|
||||||
|
|
||||||
self.mode = self.config.get("mode", "live")
|
self.mode = self.config.get("mode", "live")
|
||||||
@@ -51,21 +75,31 @@ class Gateway:
|
|||||||
self.state_sync.set_broadcast_fn(self.broadcast)
|
self.state_sync.set_broadcast_fn(self.broadcast)
|
||||||
self.pipeline.state_sync = self.state_sync
|
self.pipeline.state_sync = self.state_sync
|
||||||
|
|
||||||
self.connected_clients: Set[WebSocketServerProtocol] = set()
|
self.connected_clients: Set[ServerConnection] = set()
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
|
self._cycle_lock = asyncio.Lock()
|
||||||
self._backtest_task: Optional[asyncio.Task] = None
|
self._backtest_task: Optional[asyncio.Task] = None
|
||||||
|
self._manual_cycle_task: Optional[asyncio.Task] = None
|
||||||
self._backtest_start_date: Optional[str] = None
|
self._backtest_start_date: Optional[str] = None
|
||||||
self._backtest_end_date: Optional[str] = None
|
self._backtest_end_date: Optional[str] = None
|
||||||
self._dashboard = get_dashboard()
|
self._dashboard = get_dashboard()
|
||||||
self._market_status_task: Optional[asyncio.Task] = None
|
self._market_status_task: Optional[asyncio.Task] = None
|
||||||
|
self._watchlist_ingest_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
# Session tracking for live returns
|
# Session tracking for live returns
|
||||||
self._session_start_portfolio_value: Optional[float] = None
|
self._session_start_portfolio_value: Optional[float] = None
|
||||||
self._provider_router = get_provider_router()
|
self._provider_router = get_provider_router()
|
||||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
|
self._project_root = Path(__file__).resolve().parents[2]
|
||||||
|
self._technical_analyzer = StockTechnicalAnalyzer()
|
||||||
|
|
||||||
async def start(self, host: str = "0.0.0.0", port: int = 8766):
|
async def start(self, host: str = "0.0.0.0", port: int = 8766):
|
||||||
"""Start gateway server"""
|
"""Start gateway server with proper initialization order.
|
||||||
|
|
||||||
|
Phase 1: Start WebSocket server first so frontend can connect immediately
|
||||||
|
Phase 2: Start market data service (pushes data to connected clients)
|
||||||
|
Phase 3: Start scheduler last (triggers trading cycles)
|
||||||
|
"""
|
||||||
logger.info(f"Starting gateway on {host}:{port}")
|
logger.info(f"Starting gateway on {host}:{port}")
|
||||||
self._loop = asyncio.get_running_loop()
|
self._loop = asyncio.get_running_loop()
|
||||||
self._provider_router.add_listener(self._on_provider_usage_changed)
|
self._provider_router.add_listener(self._on_provider_usage_changed)
|
||||||
@@ -87,13 +121,31 @@ class Gateway:
|
|||||||
self._dashboard.start()
|
self._dashboard.start()
|
||||||
|
|
||||||
self.state_sync.load_state()
|
self.state_sync.load_state()
|
||||||
self.state_sync.update_state("status", "running")
|
self.market_service.set_price_recorder(self.storage.record_price_point)
|
||||||
|
self.state_sync.update_state("status", "initializing")
|
||||||
self.state_sync.update_state("server_mode", self.mode)
|
self.state_sync.update_state("server_mode", self.mode)
|
||||||
self.state_sync.update_state("is_backtest", self.is_backtest)
|
self.state_sync.update_state("is_backtest", self.is_backtest)
|
||||||
self.state_sync.update_state(
|
self.state_sync.update_state(
|
||||||
"is_mock_mode",
|
"is_mock_mode",
|
||||||
self.config.get("mock_mode", False),
|
self.config.get("mock_mode", False),
|
||||||
)
|
)
|
||||||
|
self.state_sync.update_state("tickers", self.config.get("tickers", []))
|
||||||
|
self.state_sync.update_state(
|
||||||
|
"runtime_config",
|
||||||
|
{
|
||||||
|
"tickers": self.config.get("tickers", []),
|
||||||
|
"schedule_mode": self.config.get("schedule_mode", "daily"),
|
||||||
|
"interval_minutes": self.config.get("interval_minutes", 60),
|
||||||
|
"trigger_time": self.config.get("trigger_time", "09:30"),
|
||||||
|
"initial_cash": self.config.get(
|
||||||
|
"initial_cash",
|
||||||
|
self.storage.initial_cash,
|
||||||
|
),
|
||||||
|
"margin_requirement": self.config.get("margin_requirement"),
|
||||||
|
"max_comm_cycles": self.config.get("max_comm_cycles"),
|
||||||
|
"enable_memory": self.config.get("enable_memory", False),
|
||||||
|
},
|
||||||
|
)
|
||||||
self.state_sync.update_state(
|
self.state_sync.update_state(
|
||||||
"data_sources",
|
"data_sources",
|
||||||
self._provider_router.get_usage_snapshot(),
|
self._provider_router.get_usage_snapshot(),
|
||||||
@@ -117,27 +169,71 @@ class Gateway:
|
|||||||
f"{summary.get('totalAssetValue', 0):,.2f}",
|
f"{summary.get('totalAssetValue', 0):,.2f}",
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.market_service.start(broadcast_func=self.broadcast)
|
# ======================================================================
|
||||||
|
# PHASE 1: Start WebSocket server first
|
||||||
|
# This allows frontend to connect immediately and receive status updates
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 1/4] Starting WebSocket server...")
|
||||||
|
self.state_sync.update_state("status", "websocket_ready")
|
||||||
|
|
||||||
if self.scheduler_callback:
|
# Create server but don't block yet - we'll serve inside the context manager
|
||||||
await self.scheduler_callback(callback=self.on_strategy_trigger)
|
server = await websockets.serve(
|
||||||
|
|
||||||
# Start market status monitoring (only for live mode)
|
|
||||||
if not self.is_backtest:
|
|
||||||
self._market_status_task = asyncio.create_task(
|
|
||||||
self._market_status_monitor(),
|
|
||||||
)
|
|
||||||
|
|
||||||
async with websockets.serve(
|
|
||||||
self.handle_client,
|
self.handle_client,
|
||||||
host,
|
host,
|
||||||
port,
|
port,
|
||||||
ping_interval=30,
|
ping_interval=30,
|
||||||
ping_timeout=60,
|
ping_timeout=60,
|
||||||
):
|
|
||||||
logger.info(
|
|
||||||
f"Gateway started: ws://{host}:{port}, mode={self.mode}",
|
|
||||||
)
|
)
|
||||||
|
logger.info(f"WebSocket server ready: ws://{host}:{port}")
|
||||||
|
|
||||||
|
# Give a brief moment for any existing clients to reconnect
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# PHASE 2: Start market data service
|
||||||
|
# Now frontend is connected, start pushing price updates
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 2/4] Starting market data service...")
|
||||||
|
self.state_sync.update_state("status", "market_service_starting")
|
||||||
|
await self.market_service.start(broadcast_func=self.broadcast)
|
||||||
|
self.state_sync.update_state("status", "market_service_ready")
|
||||||
|
logger.info("Market data service ready - price updates active")
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# PHASE 3: Start market status monitoring
|
||||||
|
# Monitors market open/close and broadcasts status
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 3/4] Starting market status monitoring...")
|
||||||
|
if not self.is_backtest:
|
||||||
|
self._market_status_task = asyncio.create_task(
|
||||||
|
self._market_status_monitor(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ======================================================================
|
||||||
|
# PHASE 4: Start scheduler last
|
||||||
|
# Only start trading after everything else is ready
|
||||||
|
# ======================================================================
|
||||||
|
logger.info("[Phase 4/4] Starting scheduler...")
|
||||||
|
self.state_sync.update_state("status", "scheduler_starting")
|
||||||
|
|
||||||
|
if self.scheduler:
|
||||||
|
# Wire up heartbeat callback if heartbeat is configured
|
||||||
|
heartbeat_interval = self.config.get("heartbeat_interval", 0)
|
||||||
|
if heartbeat_interval and heartbeat_interval > 0:
|
||||||
|
self.scheduler.set_heartbeat_callback(self.on_heartbeat_trigger)
|
||||||
|
logger.info(
|
||||||
|
f"[Heartbeat] Registered heartbeat callback (interval={heartbeat_interval}s)",
|
||||||
|
)
|
||||||
|
await self.scheduler.start(self.on_strategy_trigger)
|
||||||
|
elif self.scheduler_callback:
|
||||||
|
await self.scheduler_callback(callback=self.on_strategy_trigger)
|
||||||
|
|
||||||
|
self.state_sync.update_state("status", "running")
|
||||||
|
logger.info(
|
||||||
|
f"Gateway fully operational: ws://{host}:{port}, mode={self.mode}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep server running
|
||||||
await asyncio.Future()
|
await asyncio.Future()
|
||||||
|
|
||||||
def _on_provider_usage_changed(self, snapshot: Dict[str, Any]):
|
def _on_provider_usage_changed(self, snapshot: Dict[str, Any]):
|
||||||
@@ -159,7 +255,63 @@ class Gateway:
|
|||||||
def state(self) -> Dict[str, Any]:
|
def state(self) -> Dict[str, Any]:
|
||||||
return self.state_sync.state
|
return self.state_sync.state
|
||||||
|
|
||||||
async def handle_client(self, websocket: WebSocketServerProtocol):
|
@staticmethod
|
||||||
|
def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool:
|
||||||
|
return news_domain.news_rows_need_enrichment(rows)
|
||||||
|
|
||||||
|
def _news_service_url(self) -> str | None:
|
||||||
|
"""Return configured news-service base URL, if any."""
|
||||||
|
candidate = self.config.get("news_service_url") or os.getenv(
|
||||||
|
"NEWS_SERVICE_URL",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
value = str(candidate or "").strip()
|
||||||
|
return value or None
|
||||||
|
|
||||||
|
def _trading_service_url(self) -> str | None:
|
||||||
|
"""Return configured trading-service base URL, if any."""
|
||||||
|
candidate = self.config.get("trading_service_url") or os.getenv(
|
||||||
|
"TRADING_SERVICE_URL",
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
value = str(candidate or "").strip()
|
||||||
|
return value or None
|
||||||
|
|
||||||
|
async def _call_news_service(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
callback: Callable[[NewsServiceClient], Any],
|
||||||
|
) -> Any | None:
|
||||||
|
"""Call news-service when configured, otherwise return None."""
|
||||||
|
service_url = self._news_service_url()
|
||||||
|
if not service_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with NewsServiceClient(service_url) as client:
|
||||||
|
return await callback(client)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("news-service %s failed: %s", action, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _call_trading_service(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
callback: Callable[[TradingServiceClient], Any],
|
||||||
|
) -> Any | None:
|
||||||
|
"""Call trading-service when configured, otherwise return None."""
|
||||||
|
service_url = self._trading_service_url()
|
||||||
|
if not service_url:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with TradingServiceClient(service_url) as client:
|
||||||
|
return await callback(client)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("trading-service %s failed: %s", action, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def handle_client(self, websocket: ServerConnection):
|
||||||
"""Handle WebSocket client connection"""
|
"""Handle WebSocket client connection"""
|
||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.connected_clients.add(websocket)
|
self.connected_clients.add(websocket)
|
||||||
@@ -170,7 +322,9 @@ class Gateway:
|
|||||||
async with self.lock:
|
async with self.lock:
|
||||||
self.connected_clients.discard(websocket)
|
self.connected_clients.discard(websocket)
|
||||||
|
|
||||||
async def _send_initial_state(self, websocket: WebSocketServerProtocol):
|
async def _send_initial_state(self, websocket: ServerConnection):
|
||||||
|
try:
|
||||||
|
logger.info("[Gateway] Sending initial state to client...")
|
||||||
state_payload = self.state_sync.get_initial_state_payload(
|
state_payload = self.state_sync.get_initial_state_payload(
|
||||||
include_dashboard=True,
|
include_dashboard=True,
|
||||||
)
|
)
|
||||||
@@ -195,10 +349,23 @@ class Gateway:
|
|||||||
default=str,
|
default=str,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
logger.info("[Gateway] Initial state sent successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"[Gateway] Failed to send initial state: {e}")
|
||||||
|
# Send error response so client knows something went wrong
|
||||||
|
try:
|
||||||
|
await websocket.send(
|
||||||
|
json.dumps(
|
||||||
|
{"type": "error", "message": "Failed to load initial state"},
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to send error response to client: {e}")
|
||||||
|
|
||||||
async def _handle_client_messages(
|
async def _handle_client_messages(
|
||||||
self,
|
self,
|
||||||
websocket: WebSocketServerProtocol,
|
websocket: ServerConnection,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
async for message in websocket:
|
async for message in websocket:
|
||||||
@@ -219,14 +386,148 @@ class Gateway:
|
|||||||
await self._send_initial_state(websocket)
|
await self._send_initial_state(websocket)
|
||||||
elif msg_type == "start_backtest":
|
elif msg_type == "start_backtest":
|
||||||
await self._handle_start_backtest(data)
|
await self._handle_start_backtest(data)
|
||||||
|
elif msg_type == "trigger_strategy":
|
||||||
|
await self._handle_manual_trigger(websocket, data)
|
||||||
|
elif msg_type == "update_runtime_config":
|
||||||
|
await self._handle_update_runtime_config(websocket, data)
|
||||||
elif msg_type == "reload_runtime_assets":
|
elif msg_type == "reload_runtime_assets":
|
||||||
await self._handle_reload_runtime_assets()
|
await self._handle_reload_runtime_assets()
|
||||||
|
elif msg_type == "update_watchlist":
|
||||||
|
await self._handle_update_watchlist(websocket, data)
|
||||||
|
elif msg_type == "get_agent_skills":
|
||||||
|
await self._handle_get_agent_skills(websocket, data)
|
||||||
|
elif msg_type == "get_agent_profile":
|
||||||
|
await self._handle_get_agent_profile(websocket, data)
|
||||||
|
elif msg_type == "get_skill_detail":
|
||||||
|
await self._handle_get_skill_detail(websocket, data)
|
||||||
|
elif msg_type == "create_agent_local_skill":
|
||||||
|
await self._handle_create_agent_local_skill(websocket, data)
|
||||||
|
elif msg_type == "update_agent_local_skill":
|
||||||
|
await self._handle_update_agent_local_skill(websocket, data)
|
||||||
|
elif msg_type == "delete_agent_local_skill":
|
||||||
|
await self._handle_delete_agent_local_skill(websocket, data)
|
||||||
|
elif msg_type == "remove_agent_skill":
|
||||||
|
await self._handle_remove_agent_skill(websocket, data)
|
||||||
|
elif msg_type == "update_agent_skill":
|
||||||
|
await self._handle_update_agent_skill(websocket, data)
|
||||||
|
elif msg_type == "get_agent_workspace_file":
|
||||||
|
await self._handle_get_agent_workspace_file(websocket, data)
|
||||||
|
elif msg_type == "update_agent_workspace_file":
|
||||||
|
await self._handle_update_agent_workspace_file(websocket, data)
|
||||||
|
elif msg_type == "get_stock_history":
|
||||||
|
await self._handle_get_stock_history(websocket, data)
|
||||||
|
elif msg_type == "get_stock_explain_events":
|
||||||
|
await self._handle_get_stock_explain_events(websocket, data)
|
||||||
|
elif msg_type == "get_stock_news":
|
||||||
|
await self._handle_get_stock_news(websocket, data)
|
||||||
|
elif msg_type == "get_stock_news_for_date":
|
||||||
|
await self._handle_get_stock_news_for_date(websocket, data)
|
||||||
|
elif msg_type == "get_stock_news_timeline":
|
||||||
|
await self._handle_get_stock_news_timeline(websocket, data)
|
||||||
|
elif msg_type == "get_stock_news_categories":
|
||||||
|
await self._handle_get_stock_news_categories(websocket, data)
|
||||||
|
elif msg_type == "get_stock_range_explain":
|
||||||
|
await self._handle_get_stock_range_explain(websocket, data)
|
||||||
|
elif msg_type == "get_stock_insider_trades":
|
||||||
|
await self._handle_get_stock_insider_trades(websocket, data)
|
||||||
|
elif msg_type == "get_stock_story":
|
||||||
|
await self._handle_get_stock_story(websocket, data)
|
||||||
|
elif msg_type == "get_stock_similar_days":
|
||||||
|
await self._handle_get_stock_similar_days(websocket, data)
|
||||||
|
elif msg_type == "get_stock_technical_indicators":
|
||||||
|
await self._handle_get_stock_technical_indicators(websocket, data)
|
||||||
|
elif msg_type == "run_stock_enrich":
|
||||||
|
await self._handle_run_stock_enrich(websocket, data)
|
||||||
|
|
||||||
except websockets.ConnectionClosed:
|
except websockets.ConnectionClosed:
|
||||||
pass
|
pass
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def _handle_get_stock_history(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_history(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_explain_events(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_explain_events(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_news(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_news(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_news_for_date(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_news_for_date(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_news_timeline(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_news_timeline(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_news_categories(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_news_categories(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_range_explain(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_range_explain(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_insider_trades(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_insider_trades(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_story(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_story(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_similar_days(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_similar_days(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_stock_technical_indicators(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_get_stock_technical_indicators(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_run_stock_enrich(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
):
|
||||||
|
await gateway_stock_handlers.handle_run_stock_enrich(self, websocket, data)
|
||||||
|
|
||||||
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
||||||
if not self.is_backtest:
|
if not self.is_backtest:
|
||||||
return
|
return
|
||||||
@@ -238,18 +539,171 @@ class Gateway:
|
|||||||
task.add_done_callback(self._handle_backtest_exception)
|
task.add_done_callback(self._handle_backtest_exception)
|
||||||
self._backtest_task = task
|
self._backtest_task = task
|
||||||
|
|
||||||
async def _handle_reload_runtime_assets(self):
|
async def _handle_manual_trigger(
|
||||||
"""Reload prompt assets and active skills without restarting the server."""
|
self,
|
||||||
result = self.pipeline.reload_runtime_assets()
|
websocket: ServerConnection,
|
||||||
await self.state_sync.on_system_message(
|
data: Dict[str, Any],
|
||||||
"Runtime assets reloaded.",
|
) -> None:
|
||||||
)
|
"""Run one live/mock trading cycle on demand."""
|
||||||
await self.broadcast(
|
if self.is_backtest:
|
||||||
|
await websocket.send(
|
||||||
|
json.dumps(
|
||||||
{
|
{
|
||||||
"type": "runtime_assets_reloaded",
|
"type": "error",
|
||||||
**result,
|
"message": "Manual trigger is only available in live/mock mode.",
|
||||||
},
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._cycle_lock.locked()
|
||||||
|
or (
|
||||||
|
self._manual_cycle_task is not None
|
||||||
|
and not self._manual_cycle_task.done()
|
||||||
|
)
|
||||||
|
):
|
||||||
|
await websocket.send(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"message": "A trading cycle is already running.",
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await self.state_sync.on_system_message("已有任务在运行,已忽略手动触发")
|
||||||
|
return
|
||||||
|
|
||||||
|
requested_date = data.get("date")
|
||||||
|
await self.state_sync.on_system_message("收到手动触发请求,准备开始新一轮分析与决策")
|
||||||
|
task = asyncio.create_task(
|
||||||
|
self.on_strategy_trigger(
|
||||||
|
date=requested_date or datetime.now().strftime("%Y-%m-%d"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
task.add_done_callback(self._handle_manual_cycle_exception)
|
||||||
|
self._manual_cycle_task = task
|
||||||
|
|
||||||
|
async def _handle_reload_runtime_assets(self):
|
||||||
|
await gateway_admin_handlers.handle_reload_runtime_assets(self)
|
||||||
|
|
||||||
|
async def _handle_update_runtime_config(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_update_runtime_config(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_update_watchlist(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_update_watchlist(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_agent_skills(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_get_agent_skills(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_agent_profile(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_get_agent_profile(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_skill_detail(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_get_skill_detail(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_create_agent_local_skill(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_create_agent_local_skill(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_update_agent_local_skill(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_update_agent_local_skill(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_delete_agent_local_skill(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_delete_agent_local_skill(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_remove_agent_skill(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_remove_agent_skill(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_update_agent_skill(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_update_agent_skill(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_get_agent_workspace_file(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_get_agent_workspace_file(self, websocket, data)
|
||||||
|
|
||||||
|
async def _handle_update_agent_workspace_file(
|
||||||
|
self,
|
||||||
|
websocket: ServerConnection,
|
||||||
|
data: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
await gateway_admin_handlers.handle_update_agent_workspace_file(self, websocket, data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_watchlist(raw_tickers: Any) -> List[str]:
|
||||||
|
return gateway_runtime_support.normalize_watchlist(raw_tickers)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_agent_workspace_filename(raw_name: Any) -> Optional[str]:
|
||||||
|
return gateway_runtime_support.normalize_agent_workspace_filename(
|
||||||
|
raw_name,
|
||||||
|
allowlist=EDITABLE_AGENT_WORKSPACE_FILES,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_runtime_config(
|
||||||
|
self,
|
||||||
|
runtime_config: Dict[str, Any],
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
return gateway_runtime_support.apply_runtime_config(self, runtime_config)
|
||||||
|
|
||||||
|
def _sync_runtime_state(self) -> None:
|
||||||
|
gateway_runtime_support.sync_runtime_state(self)
|
||||||
|
|
||||||
|
def _schedule_watchlist_market_store_refresh(
|
||||||
|
self,
|
||||||
|
tickers: List[str],
|
||||||
|
) -> None:
|
||||||
|
gateway_cycle_support.schedule_watchlist_market_store_refresh(self, tickers)
|
||||||
|
|
||||||
|
async def _refresh_market_store_for_watchlist(
|
||||||
|
self,
|
||||||
|
tickers: List[str],
|
||||||
|
) -> None:
|
||||||
|
await gateway_cycle_support.refresh_market_store_for_watchlist(self, tickers)
|
||||||
|
|
||||||
async def broadcast(self, message: Dict[str, Any]):
|
async def broadcast(self, message: Dict[str, Any]):
|
||||||
"""Broadcast message to all connected clients"""
|
"""Broadcast message to all connected clients"""
|
||||||
@@ -269,7 +723,7 @@ class Gateway:
|
|||||||
|
|
||||||
async def _send_to_client(
|
async def _send_to_client(
|
||||||
self,
|
self,
|
||||||
client: WebSocketServerProtocol,
|
client: ServerConnection,
|
||||||
message: str,
|
message: str,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
@@ -279,248 +733,39 @@ class Gateway:
|
|||||||
self.connected_clients.discard(client)
|
self.connected_clients.discard(client)
|
||||||
|
|
||||||
async def _market_status_monitor(self):
|
async def _market_status_monitor(self):
|
||||||
"""Periodically check and broadcast market status changes"""
|
await gateway_cycle_support.market_status_monitor(self)
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
await self.market_service.check_and_broadcast_market_status()
|
|
||||||
|
|
||||||
# On market open, start live session tracking
|
|
||||||
status = self.market_service.get_market_status()
|
|
||||||
if (
|
|
||||||
status["status"] == "open"
|
|
||||||
and not self.storage.is_live_session_active
|
|
||||||
):
|
|
||||||
self.storage.start_live_session()
|
|
||||||
summary = self.storage.load_file("summary") or {}
|
|
||||||
self._session_start_portfolio_value = summary.get(
|
|
||||||
"totalAssetValue",
|
|
||||||
self.storage.initial_cash,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"Session start portfolio: "
|
|
||||||
f"${self._session_start_portfolio_value:,.2f}",
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
status["status"] != "open"
|
|
||||||
and self.storage.is_live_session_active
|
|
||||||
):
|
|
||||||
self.storage.end_live_session()
|
|
||||||
self._session_start_portfolio_value = None
|
|
||||||
|
|
||||||
# Update and broadcast live returns if session is active
|
|
||||||
if self.storage.is_live_session_active:
|
|
||||||
await self._update_and_broadcast_live_returns()
|
|
||||||
|
|
||||||
await asyncio.sleep(60) # Check every minute
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Market status monitor error: {e}")
|
|
||||||
await asyncio.sleep(60)
|
|
||||||
|
|
||||||
async def _update_and_broadcast_live_returns(self):
|
async def _update_and_broadcast_live_returns(self):
|
||||||
"""Calculate and broadcast live returns for current session"""
|
await gateway_cycle_support.update_and_broadcast_live_returns(self)
|
||||||
if not self.storage.is_live_session_active:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get current prices and calculate portfolio value
|
|
||||||
prices = self.market_service.get_all_prices()
|
|
||||||
if not prices or not any(p > 0 for p in prices.values()):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Load current internal state to get baseline values
|
|
||||||
state = self.storage.load_internal_state()
|
|
||||||
|
|
||||||
# Get latest values from history (if available)
|
|
||||||
equity_history = state.get("equity_history", [])
|
|
||||||
baseline_history = state.get("baseline_history", [])
|
|
||||||
baseline_vw_history = state.get("baseline_vw_history", [])
|
|
||||||
momentum_history = state.get("momentum_history", [])
|
|
||||||
|
|
||||||
current_equity = equity_history[-1]["v"] if equity_history else None
|
|
||||||
current_baseline = (
|
|
||||||
baseline_history[-1]["v"] if baseline_history else None
|
|
||||||
)
|
|
||||||
current_baseline_vw = (
|
|
||||||
baseline_vw_history[-1]["v"] if baseline_vw_history else None
|
|
||||||
)
|
|
||||||
current_momentum = (
|
|
||||||
momentum_history[-1]["v"] if momentum_history else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update live returns with current values
|
|
||||||
point = self.storage.update_live_returns(
|
|
||||||
current_equity=current_equity,
|
|
||||||
current_baseline=current_baseline,
|
|
||||||
current_baseline_vw=current_baseline_vw,
|
|
||||||
current_momentum=current_momentum,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Broadcast if we have new data
|
|
||||||
if point:
|
|
||||||
live_returns = self.storage.get_live_returns()
|
|
||||||
await self.broadcast(
|
|
||||||
{
|
|
||||||
"type": "team_summary",
|
|
||||||
"equity_return": live_returns["equity_return"],
|
|
||||||
"baseline_return": live_returns["baseline_return"],
|
|
||||||
"baseline_vw_return": live_returns["baseline_vw_return"],
|
|
||||||
"momentum_return": live_returns["momentum_return"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
async def on_strategy_trigger(self, date: str):
|
async def on_strategy_trigger(self, date: str):
|
||||||
"""Handle trading cycle trigger"""
|
await gateway_cycle_support.on_strategy_trigger(self, date)
|
||||||
logger.info(f"Strategy triggered for {date}")
|
|
||||||
|
|
||||||
tickers = self.config.get("tickers", [])
|
async def on_heartbeat_trigger(self, date: str):
|
||||||
|
await gateway_cycle_support.on_heartbeat_trigger(self, date)
|
||||||
if self.is_backtest:
|
|
||||||
await self._run_backtest_cycle(date, tickers)
|
|
||||||
else:
|
|
||||||
await self._run_live_cycle(date, tickers)
|
|
||||||
|
|
||||||
async def _run_backtest_cycle(self, date: str, tickers: List[str]):
|
async def _run_backtest_cycle(self, date: str, tickers: List[str]):
|
||||||
"""Run backtest cycle with pre-loaded prices"""
|
await gateway_cycle_support.run_backtest_cycle(self, date, tickers)
|
||||||
self.market_service.set_backtest_date(date)
|
|
||||||
await self.market_service.emit_market_open()
|
|
||||||
|
|
||||||
await self.state_sync.on_cycle_start(date)
|
|
||||||
self._dashboard.update(date=date, status="Analyzing...")
|
|
||||||
|
|
||||||
prices = self.market_service.get_open_prices()
|
|
||||||
close_prices = self.market_service.get_close_prices()
|
|
||||||
market_caps = self._get_market_caps(tickers, date)
|
|
||||||
|
|
||||||
result = await self.pipeline.run_cycle(
|
|
||||||
tickers=tickers,
|
|
||||||
date=date,
|
|
||||||
prices=prices,
|
|
||||||
close_prices=close_prices,
|
|
||||||
market_caps=market_caps,
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.market_service.emit_market_close()
|
|
||||||
settlement_result = result.get("settlement_result")
|
|
||||||
self._save_cycle_results(result, date, close_prices, settlement_result)
|
|
||||||
await self._broadcast_portfolio_updates(result, close_prices)
|
|
||||||
await self._finalize_cycle(date)
|
|
||||||
|
|
||||||
async def _run_live_cycle(self, date: str, tickers: List[str]):
|
async def _run_live_cycle(self, date: str, tickers: List[str]):
|
||||||
"""
|
await gateway_cycle_support.run_live_cycle(self, date, tickers)
|
||||||
Run live cycle with real market timing.
|
|
||||||
|
|
||||||
- Analysis runs immediately
|
|
||||||
- Execution waits for market open
|
|
||||||
(or uses current prices if already open)
|
|
||||||
- Settlement waits for market close
|
|
||||||
"""
|
|
||||||
# Get actual trading date (might be next trading day if weekend)
|
|
||||||
trading_date = self.market_service.get_live_trading_date()
|
|
||||||
logger.info(
|
|
||||||
f"Live cycle: triggered={date}, trading_date={trading_date}",
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.state_sync.on_cycle_start(trading_date)
|
|
||||||
self._dashboard.update(date=trading_date, status="Analyzing...")
|
|
||||||
|
|
||||||
market_caps = self._get_market_caps(tickers, trading_date)
|
|
||||||
|
|
||||||
# Run pipeline with async price callbacks
|
|
||||||
result = await self.pipeline.run_cycle(
|
|
||||||
tickers=tickers,
|
|
||||||
date=trading_date,
|
|
||||||
market_caps=market_caps,
|
|
||||||
get_open_prices_fn=self.market_service.wait_for_open_prices,
|
|
||||||
get_close_prices_fn=self.market_service.wait_for_close_prices,
|
|
||||||
)
|
|
||||||
|
|
||||||
close_prices = self.market_service.get_all_prices()
|
|
||||||
settlement_result = result.get("settlement_result")
|
|
||||||
self._save_cycle_results(
|
|
||||||
result,
|
|
||||||
trading_date,
|
|
||||||
close_prices,
|
|
||||||
settlement_result,
|
|
||||||
)
|
|
||||||
await self._broadcast_portfolio_updates(result, close_prices)
|
|
||||||
await self._finalize_cycle(trading_date)
|
|
||||||
|
|
||||||
async def _finalize_cycle(self, date: str):
|
async def _finalize_cycle(self, date: str):
|
||||||
"""Finalize cycle: broadcast state and update dashboard"""
|
await gateway_cycle_support.finalize_cycle(self, date)
|
||||||
summary = self.storage.load_file("summary") or {}
|
|
||||||
|
|
||||||
# Include live returns if session is active
|
async def _get_market_caps(
|
||||||
if self.storage.is_live_session_active:
|
|
||||||
live_returns = self.storage.get_live_returns()
|
|
||||||
summary.update(live_returns)
|
|
||||||
|
|
||||||
await self.state_sync.on_cycle_end(date, portfolio_summary=summary)
|
|
||||||
|
|
||||||
holdings = self.storage.load_file("holdings") or []
|
|
||||||
trades = self.storage.load_file("trades") or []
|
|
||||||
leaderboard = self.storage.load_file("leaderboard") or []
|
|
||||||
|
|
||||||
if leaderboard:
|
|
||||||
await self.state_sync.on_leaderboard_update(leaderboard)
|
|
||||||
|
|
||||||
self._dashboard.update(
|
|
||||||
date=date,
|
|
||||||
status="Running",
|
|
||||||
portfolio=summary,
|
|
||||||
holdings=holdings,
|
|
||||||
trades=trades,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_market_caps(
|
|
||||||
self,
|
self,
|
||||||
tickers: List[str],
|
tickers: List[str],
|
||||||
date: str,
|
date: str,
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
"""
|
return await gateway_cycle_support.get_market_caps(self, tickers, date)
|
||||||
Get market caps for tickers (stub implementation)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tickers: List of tickers
|
|
||||||
date: Trading date
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict mapping ticker to market cap
|
|
||||||
"""
|
|
||||||
from ..tools.data_tools import get_market_cap
|
|
||||||
|
|
||||||
market_caps = {}
|
|
||||||
for ticker in tickers:
|
|
||||||
try:
|
|
||||||
market_cap = get_market_cap(ticker, date)
|
|
||||||
if market_cap:
|
|
||||||
market_caps[ticker] = market_cap
|
|
||||||
else:
|
|
||||||
market_caps[ticker] = 1e9
|
|
||||||
except Exception:
|
|
||||||
market_caps[ticker] = 1e9
|
|
||||||
|
|
||||||
return market_caps
|
|
||||||
|
|
||||||
async def _broadcast_portfolio_updates(
|
async def _broadcast_portfolio_updates(
|
||||||
self,
|
self,
|
||||||
result: Dict[str, Any],
|
result: Dict[str, Any],
|
||||||
prices: Dict[str, float],
|
prices: Dict[str, float],
|
||||||
):
|
):
|
||||||
portfolio = result.get("portfolio", {})
|
await gateway_cycle_support.broadcast_portfolio_updates(self, result, prices)
|
||||||
|
|
||||||
if portfolio:
|
|
||||||
holdings = FrontendAdapter.build_holdings(portfolio, prices)
|
|
||||||
if holdings:
|
|
||||||
await self.state_sync.on_holdings_update(holdings)
|
|
||||||
|
|
||||||
stats = FrontendAdapter.build_stats(portfolio, prices)
|
|
||||||
if stats:
|
|
||||||
await self.state_sync.on_stats_update(stats)
|
|
||||||
|
|
||||||
executed_trades = result.get("executed_trades", [])
|
|
||||||
if executed_trades:
|
|
||||||
await self.state_sync.on_trades_executed(executed_trades)
|
|
||||||
|
|
||||||
def _save_cycle_results(
|
def _save_cycle_results(
|
||||||
self,
|
self,
|
||||||
@@ -529,84 +774,25 @@ class Gateway:
|
|||||||
prices: Dict[str, float],
|
prices: Dict[str, float],
|
||||||
settlement_result: Optional[Dict[str, Any]] = None,
|
settlement_result: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
portfolio = result.get("portfolio", {})
|
gateway_cycle_support.save_cycle_results(
|
||||||
executed_trades = result.get("executed_trades", [])
|
self,
|
||||||
|
result,
|
||||||
# Extract baseline values from settlement result
|
date,
|
||||||
baseline_values = None
|
prices,
|
||||||
if settlement_result:
|
settlement_result,
|
||||||
baseline_values = settlement_result.get("baseline_values")
|
|
||||||
|
|
||||||
if portfolio:
|
|
||||||
self.storage.update_dashboard_after_cycle(
|
|
||||||
portfolio=portfolio,
|
|
||||||
prices=prices,
|
|
||||||
date=date,
|
|
||||||
executed_trades=executed_trades,
|
|
||||||
baseline_values=baseline_values,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _run_backtest_dates(self, dates: List[str]):
|
async def _run_backtest_dates(self, dates: List[str]):
|
||||||
self.state_sync.set_backtest_dates(dates)
|
await gateway_cycle_support.run_backtest_dates(self, dates)
|
||||||
self._dashboard.update(days_total=len(dates), days_completed=0)
|
|
||||||
|
|
||||||
await self.state_sync.on_system_message(
|
|
||||||
f"Starting backtest - {len(dates)} trading days",
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
for i, date in enumerate(dates):
|
|
||||||
self._dashboard.update(days_completed=i)
|
|
||||||
await self.on_strategy_trigger(date=date)
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
await self.state_sync.on_system_message(
|
|
||||||
f"Backtest complete - {len(dates)} days",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Update dashboard with final state
|
|
||||||
summary = self.storage.load_file("summary") or {}
|
|
||||||
self._dashboard.update(
|
|
||||||
status="Complete",
|
|
||||||
portfolio=summary,
|
|
||||||
days_completed=len(dates),
|
|
||||||
)
|
|
||||||
self._dashboard.stop()
|
|
||||||
self._dashboard.print_final_summary()
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Backtest failed: {type(e).__name__}: {str(e)}"
|
|
||||||
logger.error(error_msg, exc_info=True)
|
|
||||||
await self.state_sync.on_system_message(error_msg)
|
|
||||||
self._dashboard.update(status=f"Failed: {str(e)}")
|
|
||||||
self._dashboard.stop()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
self._backtest_task = None
|
|
||||||
|
|
||||||
def _handle_backtest_exception(self, task: asyncio.Task):
|
def _handle_backtest_exception(self, task: asyncio.Task):
|
||||||
"""Handle exceptions from backtest task"""
|
gateway_cycle_support.handle_backtest_exception(self, task)
|
||||||
try:
|
|
||||||
task.result()
|
def _handle_manual_cycle_exception(self, task: asyncio.Task):
|
||||||
except asyncio.CancelledError:
|
gateway_cycle_support.handle_manual_cycle_exception(self, task)
|
||||||
logger.info("Backtest task was cancelled")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Backtest task failed with exception:{type(e).__name__}:{e}",
|
|
||||||
exc_info=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_backtest_dates(self, dates: List[str]):
|
def set_backtest_dates(self, dates: List[str]):
|
||||||
self.state_sync.set_backtest_dates(dates)
|
gateway_cycle_support.set_backtest_dates(self, dates)
|
||||||
if dates:
|
|
||||||
self._backtest_start_date = dates[0]
|
|
||||||
self._backtest_end_date = dates[-1]
|
|
||||||
self._dashboard.days_total = len(dates)
|
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.state_sync.save_state()
|
gateway_cycle_support.stop_gateway(self)
|
||||||
self.market_service.stop()
|
|
||||||
if self._backtest_task:
|
|
||||||
self._backtest_task.cancel()
|
|
||||||
if self._market_status_task:
|
|
||||||
self._market_status_task.cancel()
|
|
||||||
self._dashboard.stop()
|
|
||||||
|
|||||||
419
backend/services/gateway_admin_handlers.py
Normal file
419
backend/services/gateway_admin_handlers.py
Normal file
@@ -0,0 +1,419 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Runtime/workspace/skills handlers extracted from the main Gateway module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||||
|
from backend.agents.skills_manager import SkillsManager
|
||||||
|
from backend.agents.toolkit_factory import load_agent_profiles
|
||||||
|
from backend.config.bootstrap_config import (
|
||||||
|
get_bootstrap_config_for_run,
|
||||||
|
resolve_runtime_config,
|
||||||
|
update_bootstrap_values_for_run,
|
||||||
|
)
|
||||||
|
from backend.data.market_ingest import ingest_symbols
|
||||||
|
from backend.llm.models import get_agent_model_info
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_reload_runtime_assets(gateway: Any) -> None:
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
runtime_config = resolve_runtime_config(
|
||||||
|
project_root=gateway._project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
enable_memory=gateway.config.get("enable_memory", False),
|
||||||
|
schedule_mode=gateway.config.get("schedule_mode", "daily"),
|
||||||
|
interval_minutes=gateway.config.get("interval_minutes", 60),
|
||||||
|
trigger_time=gateway.config.get("trigger_time", "09:30"),
|
||||||
|
)
|
||||||
|
result = gateway.pipeline.reload_runtime_assets(runtime_config=runtime_config)
|
||||||
|
runtime_updates = gateway._apply_runtime_config(runtime_config)
|
||||||
|
await gateway.state_sync.on_system_message("Runtime assets reloaded.")
|
||||||
|
await gateway.broadcast({"type": "runtime_assets_reloaded", **result, **runtime_updates})
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_update_runtime_config(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
|
||||||
|
schedule_mode = str(data.get("schedule_mode", "")).strip().lower()
|
||||||
|
if schedule_mode:
|
||||||
|
if schedule_mode not in {"daily", "intraday"}:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "schedule_mode must be 'daily' or 'intraday'."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
updates["schedule_mode"] = schedule_mode
|
||||||
|
|
||||||
|
interval_minutes = data.get("interval_minutes")
|
||||||
|
if interval_minutes is not None:
|
||||||
|
try:
|
||||||
|
parsed_interval = int(interval_minutes)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
parsed_interval = 0
|
||||||
|
if parsed_interval <= 0:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "interval_minutes must be a positive integer."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
updates["interval_minutes"] = parsed_interval
|
||||||
|
|
||||||
|
trigger_time = data.get("trigger_time")
|
||||||
|
if trigger_time is not None:
|
||||||
|
raw_trigger = str(trigger_time).strip()
|
||||||
|
if raw_trigger and raw_trigger != "now":
|
||||||
|
try:
|
||||||
|
datetime.strptime(raw_trigger, "%H:%M")
|
||||||
|
except ValueError:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "trigger_time must use HH:MM or 'now'."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
updates["trigger_time"] = raw_trigger or "09:30"
|
||||||
|
|
||||||
|
max_comm_cycles = data.get("max_comm_cycles")
|
||||||
|
if max_comm_cycles is not None:
|
||||||
|
try:
|
||||||
|
parsed_cycles = int(max_comm_cycles)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
parsed_cycles = 0
|
||||||
|
if parsed_cycles <= 0:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "max_comm_cycles must be a positive integer."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
updates["max_comm_cycles"] = parsed_cycles
|
||||||
|
|
||||||
|
initial_cash = data.get("initial_cash")
|
||||||
|
if initial_cash is not None:
|
||||||
|
try:
|
||||||
|
parsed_initial_cash = float(initial_cash)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
parsed_initial_cash = 0.0
|
||||||
|
if parsed_initial_cash <= 0:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "initial_cash must be a positive number."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
updates["initial_cash"] = parsed_initial_cash
|
||||||
|
|
||||||
|
margin_requirement = data.get("margin_requirement")
|
||||||
|
if margin_requirement is not None:
|
||||||
|
try:
|
||||||
|
parsed_margin_requirement = float(margin_requirement)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
parsed_margin_requirement = -1.0
|
||||||
|
if parsed_margin_requirement < 0:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "margin_requirement must be a non-negative number."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
updates["margin_requirement"] = parsed_margin_requirement
|
||||||
|
|
||||||
|
enable_memory = data.get("enable_memory")
|
||||||
|
if enable_memory is not None:
|
||||||
|
updates["enable_memory"] = bool(enable_memory)
|
||||||
|
|
||||||
|
if not updates:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "No runtime settings were provided."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
update_bootstrap_values_for_run(
|
||||||
|
project_root=gateway._project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
updates=updates,
|
||||||
|
)
|
||||||
|
await gateway.state_sync.on_system_message("运行时调度配置已保存,正在热更新")
|
||||||
|
await handle_reload_runtime_assets(gateway)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_update_watchlist(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
tickers = gateway._normalize_watchlist(data.get("tickers"))
|
||||||
|
if not tickers:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "update_watchlist requires at least one valid ticker."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
update_bootstrap_values_for_run(
|
||||||
|
project_root=gateway._project_root,
|
||||||
|
config_name=config_name,
|
||||||
|
updates={"tickers": tickers},
|
||||||
|
)
|
||||||
|
await gateway.state_sync.on_system_message(f"Watchlist updated: {', '.join(tickers)}")
|
||||||
|
await gateway.broadcast({"type": "watchlist_updated", "config_name": config_name, "tickers": tickers})
|
||||||
|
await handle_reload_runtime_assets(gateway)
|
||||||
|
gateway._schedule_watchlist_market_store_refresh(tickers)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_agent_skills(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
if not agent_id:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "get_agent_skills requires agent_id."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||||
|
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
|
||||||
|
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=config_name, agent_id=agent_id, default_skills=[]))
|
||||||
|
enabled = set(agent_config.enabled_skills)
|
||||||
|
disabled = set(agent_config.disabled_skills)
|
||||||
|
|
||||||
|
payload = []
|
||||||
|
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id):
|
||||||
|
if item.skill_name in disabled:
|
||||||
|
status = "disabled"
|
||||||
|
elif item.skill_name in enabled:
|
||||||
|
status = "enabled"
|
||||||
|
elif item.skill_name in resolved_skills:
|
||||||
|
status = "active"
|
||||||
|
else:
|
||||||
|
status = "available"
|
||||||
|
payload.append({
|
||||||
|
"skill_name": item.skill_name,
|
||||||
|
"name": item.name,
|
||||||
|
"description": item.description,
|
||||||
|
"version": item.version,
|
||||||
|
"source": item.source,
|
||||||
|
"tools": item.tools,
|
||||||
|
"status": status,
|
||||||
|
})
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "agent_skills_loaded",
|
||||||
|
"config_name": config_name,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"skills": payload,
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_agent_profile(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
if not agent_id:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "get_agent_profile requires agent_id."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||||
|
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
|
||||||
|
profiles = load_agent_profiles()
|
||||||
|
profile = profiles.get(agent_id, {})
|
||||||
|
bootstrap = get_bootstrap_config_for_run(gateway._project_root, config_name)
|
||||||
|
override = bootstrap.agent_override(agent_id)
|
||||||
|
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
|
||||||
|
if not isinstance(active_tool_groups, list):
|
||||||
|
active_tool_groups = []
|
||||||
|
disabled_tool_groups = agent_config.disabled_tool_groups
|
||||||
|
if disabled_tool_groups:
|
||||||
|
disabled_set = set(disabled_tool_groups)
|
||||||
|
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
|
||||||
|
|
||||||
|
default_skills = profile.get("skills", [])
|
||||||
|
if not isinstance(default_skills, list):
|
||||||
|
default_skills = []
|
||||||
|
resolved_skills = skills_manager.resolve_agent_skill_names(
|
||||||
|
config_name=config_name,
|
||||||
|
agent_id=agent_id,
|
||||||
|
default_skills=default_skills,
|
||||||
|
)
|
||||||
|
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
|
||||||
|
model_name, model_provider = get_agent_model_info(agent_id)
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "agent_profile_loaded",
|
||||||
|
"config_name": config_name,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"profile": {
|
||||||
|
"model_name": model_name,
|
||||||
|
"model_provider": model_provider,
|
||||||
|
"prompt_files": prompt_files,
|
||||||
|
"default_skills": default_skills,
|
||||||
|
"resolved_skills": resolved_skills,
|
||||||
|
"active_tool_groups": active_tool_groups,
|
||||||
|
"disabled_tool_groups": disabled_tool_groups,
|
||||||
|
"enabled_skills": agent_config.enabled_skills,
|
||||||
|
"disabled_skills": agent_config.disabled_skills,
|
||||||
|
},
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_skill_detail(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
skill_name = str(data.get("skill_name", "")).strip()
|
||||||
|
if not skill_name:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "get_skill_detail requires skill_name."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
try:
|
||||||
|
if agent_id:
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
detail = skills_manager.load_agent_skill_document(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
|
||||||
|
else:
|
||||||
|
detail = skills_manager.load_skill_document(skill_name)
|
||||||
|
except FileNotFoundError:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "skill_detail_loaded",
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"skill": detail,
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_create_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
skill_name = str(data.get("skill_name", "")).strip()
|
||||||
|
if not agent_id or not skill_name:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "create_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
try:
|
||||||
|
skills_manager.create_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
|
||||||
|
except (ValueError, FileExistsError) as exc:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
await gateway.state_sync.on_system_message(f"Created local skill {skill_name} for {agent_id}")
|
||||||
|
await gateway._handle_reload_runtime_assets()
|
||||||
|
await websocket.send(json.dumps({"type": "agent_local_skill_created", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
|
||||||
|
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||||
|
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_update_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
skill_name = str(data.get("skill_name", "")).strip()
|
||||||
|
content = data.get("content")
|
||||||
|
if not agent_id or not skill_name or not isinstance(content, str):
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "update_agent_local_skill requires agent_id, skill_name, and string content."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
try:
|
||||||
|
skills_manager.update_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name, content=content)
|
||||||
|
except (ValueError, FileNotFoundError) as exc:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
await gateway.state_sync.on_system_message(f"Updated local skill {skill_name} for {agent_id}")
|
||||||
|
await gateway._handle_reload_runtime_assets()
|
||||||
|
await websocket.send(json.dumps({"type": "agent_local_skill_updated", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
|
||||||
|
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_delete_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
skill_name = str(data.get("skill_name", "")).strip()
|
||||||
|
if not agent_id or not skill_name:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "delete_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
try:
|
||||||
|
skills_manager.delete_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
|
||||||
|
skills_manager.forget_agent_skill_overrides(config_name=config_name, agent_id=agent_id, skill_names=[skill_name])
|
||||||
|
except (ValueError, FileNotFoundError) as exc:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
await gateway.state_sync.on_system_message(f"Deleted local skill {skill_name} for {agent_id}")
|
||||||
|
await gateway._handle_reload_runtime_assets()
|
||||||
|
await websocket.send(json.dumps({"type": "agent_local_skill_deleted", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
|
||||||
|
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_remove_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
skill_name = str(data.get("skill_name", "")).strip()
|
||||||
|
if not agent_id or not skill_name:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "remove_agent_skill requires agent_id and skill_name."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
skill_names = {
|
||||||
|
item.skill_name
|
||||||
|
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
|
||||||
|
if item.source != "local"
|
||||||
|
}
|
||||||
|
if skill_name not in skill_names:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": f"Unknown shared skill: {skill_name}"}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
|
||||||
|
await gateway.state_sync.on_system_message(f"Removed shared skill {skill_name} from {agent_id}")
|
||||||
|
await gateway._handle_reload_runtime_assets()
|
||||||
|
await websocket.send(json.dumps({"type": "agent_skill_removed", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
|
||||||
|
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_update_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
skill_name = str(data.get("skill_name", "")).strip()
|
||||||
|
enabled = data.get("enabled")
|
||||||
|
if not agent_id or not skill_name or not isinstance(enabled, bool):
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
skill_names = {item.skill_name for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)}
|
||||||
|
if skill_name not in skill_names:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
if enabled:
|
||||||
|
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, enable=[skill_name])
|
||||||
|
await gateway.state_sync.on_system_message(f"Enabled skill {skill_name} for {agent_id}")
|
||||||
|
else:
|
||||||
|
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
|
||||||
|
await gateway.state_sync.on_system_message(f"Disabled skill {skill_name} for {agent_id}")
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "agent_skill_updated",
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"skill_name": skill_name,
|
||||||
|
"enabled": enabled,
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
await gateway._handle_reload_runtime_assets()
|
||||||
|
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
|
||||||
|
if not agent_id or not filename:
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "get_agent_workspace_file requires agent_id and supported filename."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||||
|
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = asset_dir / filename
|
||||||
|
content = path.read_text(encoding="utf-8") if path.exists() else ""
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "agent_workspace_file_loaded",
|
||||||
|
"config_name": config_name,
|
||||||
|
"agent_id": agent_id,
|
||||||
|
"filename": filename,
|
||||||
|
"content": content,
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_update_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
agent_id = str(data.get("agent_id", "")).strip()
|
||||||
|
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
|
||||||
|
content = data.get("content")
|
||||||
|
if not agent_id or not filename or not isinstance(content, str):
|
||||||
|
await websocket.send(json.dumps({"type": "error", "message": "update_agent_workspace_file requires agent_id, supported filename, and string content."}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
config_name = gateway.config.get("config_name", "default")
|
||||||
|
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||||
|
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||||
|
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = asset_dir / filename
|
||||||
|
path.write_text(content, encoding="utf-8")
|
||||||
|
await gateway.state_sync.on_system_message(f"Updated {filename} for {agent_id}")
|
||||||
|
await websocket.send(json.dumps({"type": "agent_workspace_file_updated", "agent_id": agent_id, "filename": filename}, ensure_ascii=False))
|
||||||
|
await gateway._handle_reload_runtime_assets()
|
||||||
|
await handle_get_agent_workspace_file(gateway, websocket, {"agent_id": agent_id, "filename": filename})
|
||||||
373
backend/services/gateway_cycle_support.py
Normal file
373
backend/services/gateway_cycle_support.py
Normal file
@@ -0,0 +1,373 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Cycle and monitoring helpers extracted from the main Gateway module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.data.market_ingest import ingest_symbols
|
||||||
|
from backend.domains import trading as trading_domain
|
||||||
|
from backend.utils.msg_adapter import FrontendAdapter
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
|
||||||
|
"""Kick off a non-blocking market-store refresh for an updated watchlist."""
|
||||||
|
if not tickers:
|
||||||
|
return
|
||||||
|
if gateway._watchlist_ingest_task and not gateway._watchlist_ingest_task.done():
|
||||||
|
gateway._watchlist_ingest_task.cancel()
|
||||||
|
gateway._watchlist_ingest_task = asyncio.create_task(
|
||||||
|
refresh_market_store_for_watchlist(gateway, tickers),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
|
||||||
|
"""Refresh the long-lived market store after a watchlist update."""
|
||||||
|
try:
|
||||||
|
await gateway.state_sync.on_system_message(
|
||||||
|
f"正在同步自选股市场数据: {', '.join(tickers)}",
|
||||||
|
)
|
||||||
|
results = await asyncio.to_thread(
|
||||||
|
ingest_symbols,
|
||||||
|
tickers,
|
||||||
|
mode="incremental",
|
||||||
|
)
|
||||||
|
summary = ", ".join(
|
||||||
|
f"{item['symbol']} prices={item['prices']} news={item['news']}"
|
||||||
|
for item in results
|
||||||
|
)
|
||||||
|
await gateway.state_sync.on_system_message(
|
||||||
|
f"自选股市场数据已同步: {summary}",
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Watchlist market store refresh failed: %s", exc)
|
||||||
|
await gateway.state_sync.on_system_message(
|
||||||
|
f"自选股市场数据同步失败: {exc}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def market_status_monitor(gateway: Any) -> None:
|
||||||
|
"""Periodically check and broadcast market status changes."""
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
await gateway.market_service.check_and_broadcast_market_status()
|
||||||
|
|
||||||
|
status = gateway.market_service.get_market_status()
|
||||||
|
if status["status"] == "open" and not gateway.storage.is_live_session_active:
|
||||||
|
gateway.storage.start_live_session()
|
||||||
|
summary = gateway.storage.load_file("summary") or {}
|
||||||
|
gateway._session_start_portfolio_value = summary.get(
|
||||||
|
"totalAssetValue",
|
||||||
|
gateway.storage.initial_cash,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"Session start portfolio: $%s",
|
||||||
|
f"{gateway._session_start_portfolio_value:,.2f}",
|
||||||
|
)
|
||||||
|
elif status["status"] != "open" and gateway.storage.is_live_session_active:
|
||||||
|
gateway.storage.end_live_session()
|
||||||
|
gateway._session_start_portfolio_value = None
|
||||||
|
|
||||||
|
if gateway.storage.is_live_session_active:
|
||||||
|
await update_and_broadcast_live_returns(gateway)
|
||||||
|
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Market status monitor error: %s", exc)
|
||||||
|
await asyncio.sleep(60)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_and_broadcast_live_returns(gateway: Any) -> None:
|
||||||
|
"""Calculate and broadcast live returns for current session."""
|
||||||
|
if not gateway.storage.is_live_session_active:
|
||||||
|
return
|
||||||
|
|
||||||
|
prices = gateway.market_service.get_all_prices()
|
||||||
|
if not prices or not any(p > 0 for p in prices.values()):
|
||||||
|
return
|
||||||
|
|
||||||
|
state = gateway.storage.load_internal_state()
|
||||||
|
equity_history = state.get("equity_history", [])
|
||||||
|
baseline_history = state.get("baseline_history", [])
|
||||||
|
baseline_vw_history = state.get("baseline_vw_history", [])
|
||||||
|
momentum_history = state.get("momentum_history", [])
|
||||||
|
|
||||||
|
current_equity = equity_history[-1]["v"] if equity_history else None
|
||||||
|
current_baseline = baseline_history[-1]["v"] if baseline_history else None
|
||||||
|
current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None
|
||||||
|
current_momentum = momentum_history[-1]["v"] if momentum_history else None
|
||||||
|
|
||||||
|
point = gateway.storage.update_live_returns(
|
||||||
|
current_equity=current_equity,
|
||||||
|
current_baseline=current_baseline,
|
||||||
|
current_baseline_vw=current_baseline_vw,
|
||||||
|
current_momentum=current_momentum,
|
||||||
|
)
|
||||||
|
if point:
|
||||||
|
live_returns = gateway.storage.get_live_returns()
|
||||||
|
await gateway.broadcast(
|
||||||
|
{
|
||||||
|
"type": "team_summary",
|
||||||
|
"equity_return": live_returns["equity_return"],
|
||||||
|
"baseline_return": live_returns["baseline_return"],
|
||||||
|
"baseline_vw_return": live_returns["baseline_vw_return"],
|
||||||
|
"momentum_return": live_returns["momentum_return"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_strategy_trigger(gateway: Any, date: str) -> None:
|
||||||
|
"""Handle trading cycle trigger."""
|
||||||
|
if gateway._cycle_lock.locked():
|
||||||
|
logger.warning("Trading cycle already running, skipping trigger for %s", date)
|
||||||
|
await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}")
|
||||||
|
return
|
||||||
|
|
||||||
|
async with gateway._cycle_lock:
|
||||||
|
logger.info("Strategy triggered for %s", date)
|
||||||
|
tickers = gateway.config.get("tickers", [])
|
||||||
|
if gateway.is_backtest:
|
||||||
|
await run_backtest_cycle(gateway, date, tickers)
|
||||||
|
else:
|
||||||
|
await run_live_cycle(gateway, date, tickers)
|
||||||
|
|
||||||
|
|
||||||
|
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
|
||||||
|
"""Run lightweight heartbeat check for all analysts."""
|
||||||
|
logger.info("[Heartbeat] Running heartbeat check for %s", date)
|
||||||
|
analysts = gateway.pipeline._all_analysts()
|
||||||
|
|
||||||
|
for analyst in analysts:
|
||||||
|
try:
|
||||||
|
ws_id = getattr(analyst, "workspace_id", None)
|
||||||
|
if ws_id:
|
||||||
|
from backend.agents.workspace_manager import get_workspace_dir
|
||||||
|
from pathlib import Path
|
||||||
|
from agentscope.message import Msg
|
||||||
|
|
||||||
|
ws_dir = get_workspace_dir(ws_id)
|
||||||
|
if ws_dir:
|
||||||
|
hb_path = Path(ws_dir) / "HEARTBEAT.md"
|
||||||
|
if hb_path.exists():
|
||||||
|
content = hb_path.read_text(encoding="utf-8").strip()
|
||||||
|
if content:
|
||||||
|
hb_task = f"# 定期主动检查\n\n{content}\n\n请执行上述检查并报告结果。"
|
||||||
|
logger.info("[Heartbeat] Running heartbeat for %s", analyst.name)
|
||||||
|
msg = Msg(role="user", content=hb_task, name="system")
|
||||||
|
await analyst.reply([msg])
|
||||||
|
logger.info("[Heartbeat] %s heartbeat complete", analyst.name)
|
||||||
|
continue
|
||||||
|
logger.debug("[Heartbeat] No HEARTBEAT.md for %s, skipping", analyst.name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
|
||||||
|
gateway.market_service.set_backtest_date(date)
|
||||||
|
await gateway.market_service.emit_market_open()
|
||||||
|
|
||||||
|
await gateway.state_sync.on_cycle_start(date)
|
||||||
|
gateway._dashboard.update(date=date, status="Analyzing...")
|
||||||
|
|
||||||
|
prices = gateway.market_service.get_open_prices()
|
||||||
|
close_prices = gateway.market_service.get_close_prices()
|
||||||
|
market_caps = await get_market_caps(gateway, tickers, date)
|
||||||
|
|
||||||
|
result = await gateway.pipeline.run_cycle(
|
||||||
|
tickers=tickers,
|
||||||
|
date=date,
|
||||||
|
prices=prices,
|
||||||
|
close_prices=close_prices,
|
||||||
|
market_caps=market_caps,
|
||||||
|
)
|
||||||
|
|
||||||
|
await gateway.market_service.emit_market_close()
|
||||||
|
settlement_result = result.get("settlement_result")
|
||||||
|
save_cycle_results(gateway, result, date, close_prices, settlement_result)
|
||||||
|
await broadcast_portfolio_updates(gateway, result, close_prices)
|
||||||
|
await finalize_cycle(gateway, date)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
|
||||||
|
trading_date = gateway.market_service.get_live_trading_date()
|
||||||
|
logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)
|
||||||
|
|
||||||
|
await gateway.state_sync.on_cycle_start(trading_date)
|
||||||
|
gateway._dashboard.update(date=trading_date, status="Analyzing...")
|
||||||
|
|
||||||
|
market_caps = await get_market_caps(gateway, tickers, trading_date)
|
||||||
|
schedule_mode = gateway.config.get("schedule_mode", "daily")
|
||||||
|
market_status = gateway.market_service.get_market_status()
|
||||||
|
current_prices = gateway.market_service.get_all_prices()
|
||||||
|
|
||||||
|
if schedule_mode == "intraday":
|
||||||
|
execute_decisions = market_status.get("status") == "open"
|
||||||
|
if execute_decisions:
|
||||||
|
await gateway.state_sync.on_system_message("定时任务触发:当前处于交易时段,本轮将执行交易决策")
|
||||||
|
else:
|
||||||
|
await gateway.state_sync.on_system_message("定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易")
|
||||||
|
|
||||||
|
result = await gateway.pipeline.run_cycle(
|
||||||
|
tickers=tickers,
|
||||||
|
date=trading_date,
|
||||||
|
prices=current_prices,
|
||||||
|
market_caps=market_caps,
|
||||||
|
execute_decisions=execute_decisions,
|
||||||
|
)
|
||||||
|
close_prices = current_prices
|
||||||
|
else:
|
||||||
|
result = await gateway.pipeline.run_cycle(
|
||||||
|
tickers=tickers,
|
||||||
|
date=trading_date,
|
||||||
|
market_caps=market_caps,
|
||||||
|
get_open_prices_fn=gateway.market_service.wait_for_open_prices,
|
||||||
|
get_close_prices_fn=gateway.market_service.wait_for_close_prices,
|
||||||
|
)
|
||||||
|
close_prices = gateway.market_service.get_all_prices()
|
||||||
|
|
||||||
|
settlement_result = result.get("settlement_result")
|
||||||
|
save_cycle_results(gateway, result, trading_date, close_prices, settlement_result)
|
||||||
|
await broadcast_portfolio_updates(gateway, result, close_prices)
|
||||||
|
await finalize_cycle(gateway, trading_date)
|
||||||
|
|
||||||
|
|
||||||
|
async def finalize_cycle(gateway: Any, date: str) -> None:
|
||||||
|
summary = gateway.storage.load_file("summary") or {}
|
||||||
|
if gateway.storage.is_live_session_active:
|
||||||
|
summary.update(gateway.storage.get_live_returns())
|
||||||
|
|
||||||
|
await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary)
|
||||||
|
holdings = gateway.storage.load_file("holdings") or []
|
||||||
|
trades = gateway.storage.load_file("trades") or []
|
||||||
|
leaderboard = gateway.storage.load_file("leaderboard") or []
|
||||||
|
if leaderboard:
|
||||||
|
await gateway.state_sync.on_leaderboard_update(leaderboard)
|
||||||
|
gateway._dashboard.update(date=date, status="Running", portfolio=summary, holdings=holdings, trades=trades)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
|
||||||
|
market_caps: dict[str, float] = {}
|
||||||
|
for ticker in tickers:
|
||||||
|
try:
|
||||||
|
market_cap = None
|
||||||
|
response = await gateway._call_trading_service(
|
||||||
|
f"get_market_cap for {ticker}",
|
||||||
|
lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
market_cap = response.get("market_cap")
|
||||||
|
if market_cap is None:
|
||||||
|
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
|
||||||
|
market_cap = payload.get("market_cap")
|
||||||
|
market_caps[ticker] = market_cap if market_cap else 1e9
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
|
||||||
|
market_caps[ticker] = 1e9
|
||||||
|
return market_caps
|
||||||
|
|
||||||
|
|
||||||
|
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
|
||||||
|
portfolio = result.get("portfolio", {})
|
||||||
|
if portfolio:
|
||||||
|
holdings = FrontendAdapter.build_holdings(portfolio, prices)
|
||||||
|
if holdings:
|
||||||
|
await gateway.state_sync.on_holdings_update(holdings)
|
||||||
|
stats = FrontendAdapter.build_stats(portfolio, prices)
|
||||||
|
if stats:
|
||||||
|
await gateway.state_sync.on_stats_update(stats)
|
||||||
|
|
||||||
|
executed_trades = result.get("executed_trades", [])
|
||||||
|
if executed_trades:
|
||||||
|
await gateway.state_sync.on_trades_executed(executed_trades)
|
||||||
|
|
||||||
|
|
||||||
|
def save_cycle_results(
|
||||||
|
gateway: Any,
|
||||||
|
result: dict[str, Any],
|
||||||
|
date: str,
|
||||||
|
prices: dict[str, float],
|
||||||
|
settlement_result: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
portfolio = result.get("portfolio", {})
|
||||||
|
executed_trades = result.get("executed_trades", [])
|
||||||
|
baseline_values = settlement_result.get("baseline_values") if settlement_result else None
|
||||||
|
if portfolio:
|
||||||
|
gateway.storage.update_dashboard_after_cycle(
|
||||||
|
portfolio=portfolio,
|
||||||
|
prices=prices,
|
||||||
|
date=date,
|
||||||
|
executed_trades=executed_trades,
|
||||||
|
baseline_values=baseline_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
|
||||||
|
gateway.state_sync.set_backtest_dates(dates)
|
||||||
|
gateway._dashboard.update(days_total=len(dates), days_completed=0)
|
||||||
|
await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
|
||||||
|
try:
|
||||||
|
for i, date in enumerate(dates):
|
||||||
|
gateway._dashboard.update(days_completed=i)
|
||||||
|
await gateway.on_strategy_trigger(date=date)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
|
||||||
|
summary = gateway.storage.load_file("summary") or {}
|
||||||
|
gateway._dashboard.update(status="Complete", portfolio=summary, days_completed=len(dates))
|
||||||
|
gateway._dashboard.stop()
|
||||||
|
gateway._dashboard.print_final_summary()
|
||||||
|
except Exception as exc:
|
||||||
|
error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
|
||||||
|
logger.error(error_msg, exc_info=True)
|
||||||
|
asyncio.create_task(gateway.state_sync.on_system_message(error_msg))
|
||||||
|
gateway._dashboard.update(status=f"Failed: {str(exc)}")
|
||||||
|
gateway._dashboard.stop()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
gateway._backtest_task = None
|
||||||
|
|
||||||
|
|
||||||
|
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
|
||||||
|
try:
|
||||||
|
task.result()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Backtest task was cancelled")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Backtest task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
|
||||||
|
gateway._manual_cycle_task = None
|
||||||
|
try:
|
||||||
|
task.result()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("Manual cycle task was cancelled")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Manual cycle task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
|
||||||
|
gateway.state_sync.set_backtest_dates(dates)
|
||||||
|
if dates:
|
||||||
|
gateway._backtest_start_date = dates[0]
|
||||||
|
gateway._backtest_end_date = dates[-1]
|
||||||
|
gateway._dashboard.days_total = len(dates)
|
||||||
|
|
||||||
|
|
||||||
|
def stop_gateway(gateway: Any) -> None:
|
||||||
|
gateway.state_sync.save_state()
|
||||||
|
gateway.market_service.stop()
|
||||||
|
if gateway._backtest_task:
|
||||||
|
gateway._backtest_task.cancel()
|
||||||
|
if gateway._market_status_task:
|
||||||
|
gateway._market_status_task.cancel()
|
||||||
|
if gateway._watchlist_ingest_task:
|
||||||
|
gateway._watchlist_ingest_task.cancel()
|
||||||
|
gateway._dashboard.stop()
|
||||||
174
backend/services/gateway_runtime_support.py
Normal file
174
backend/services/gateway_runtime_support.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Runtime/state support helpers extracted from the main Gateway module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.data.provider_utils import normalize_symbol
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_watchlist(raw_tickers: Any) -> list[str]:
|
||||||
|
"""Parse watchlist payloads from websocket messages."""
|
||||||
|
if raw_tickers is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if isinstance(raw_tickers, str):
|
||||||
|
candidates = raw_tickers.split(",")
|
||||||
|
elif isinstance(raw_tickers, list):
|
||||||
|
candidates = raw_tickers
|
||||||
|
else:
|
||||||
|
candidates = [raw_tickers]
|
||||||
|
|
||||||
|
tickers: list[str] = []
|
||||||
|
for candidate in candidates:
|
||||||
|
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
|
||||||
|
if symbol and symbol not in tickers:
|
||||||
|
tickers.append(symbol)
|
||||||
|
return tickers
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_agent_workspace_filename(
|
||||||
|
raw_name: Any,
|
||||||
|
*,
|
||||||
|
allowlist: set[str],
|
||||||
|
) -> str | None:
|
||||||
|
"""Restrict editable workspace files to a safe allowlist."""
|
||||||
|
filename = str(raw_name or "").strip()
|
||||||
|
if filename in allowlist:
|
||||||
|
return filename
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def apply_runtime_config(gateway: Any, runtime_config: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Apply runtime config to gateway-owned services and state."""
|
||||||
|
warnings: list[str] = []
|
||||||
|
|
||||||
|
ticker_changes = gateway.market_service.update_tickers(
|
||||||
|
runtime_config.get("tickers", []),
|
||||||
|
)
|
||||||
|
gateway.config["tickers"] = ticker_changes["active"]
|
||||||
|
|
||||||
|
gateway.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
|
||||||
|
gateway.config["max_comm_cycles"] = gateway.pipeline.max_comm_cycles
|
||||||
|
gateway.config["schedule_mode"] = runtime_config.get(
|
||||||
|
"schedule_mode",
|
||||||
|
gateway.config.get("schedule_mode", "daily"),
|
||||||
|
)
|
||||||
|
gateway.config["interval_minutes"] = int(
|
||||||
|
runtime_config.get(
|
||||||
|
"interval_minutes",
|
||||||
|
gateway.config.get("interval_minutes", 60),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
gateway.config["trigger_time"] = runtime_config.get(
|
||||||
|
"trigger_time",
|
||||||
|
gateway.config.get("trigger_time", "09:30"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if gateway.scheduler:
|
||||||
|
gateway.scheduler.reconfigure(
|
||||||
|
mode=gateway.config["schedule_mode"],
|
||||||
|
trigger_time=gateway.config["trigger_time"],
|
||||||
|
interval_minutes=gateway.config["interval_minutes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
pm_apply_result = gateway.pipeline.pm.apply_runtime_portfolio_config(
|
||||||
|
margin_requirement=runtime_config["margin_requirement"],
|
||||||
|
)
|
||||||
|
gateway.config["margin_requirement"] = gateway.pipeline.pm.portfolio.get(
|
||||||
|
"margin_requirement",
|
||||||
|
runtime_config["margin_requirement"],
|
||||||
|
)
|
||||||
|
|
||||||
|
requested_initial_cash = float(runtime_config["initial_cash"])
|
||||||
|
current_initial_cash = float(gateway.storage.initial_cash)
|
||||||
|
initial_cash_applied = requested_initial_cash == current_initial_cash
|
||||||
|
if not initial_cash_applied:
|
||||||
|
if (
|
||||||
|
gateway.storage.can_apply_initial_cash()
|
||||||
|
and gateway.pipeline.pm.can_apply_initial_cash()
|
||||||
|
):
|
||||||
|
initial_cash_applied = gateway.storage.apply_initial_cash(
|
||||||
|
requested_initial_cash,
|
||||||
|
)
|
||||||
|
if initial_cash_applied:
|
||||||
|
gateway.pipeline.pm.apply_runtime_portfolio_config(
|
||||||
|
initial_cash=requested_initial_cash,
|
||||||
|
)
|
||||||
|
gateway.config["initial_cash"] = gateway.storage.initial_cash
|
||||||
|
else:
|
||||||
|
warnings.append(
|
||||||
|
"initial_cash changed in BOOTSTRAP.md but was not applied "
|
||||||
|
"because the run already has positions, margin usage, or trades.",
|
||||||
|
)
|
||||||
|
|
||||||
|
requested_enable_memory = bool(runtime_config["enable_memory"])
|
||||||
|
current_enable_memory = bool(gateway.config.get("enable_memory", False))
|
||||||
|
if requested_enable_memory != current_enable_memory:
|
||||||
|
warnings.append(
|
||||||
|
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
|
||||||
|
"because long-term memory contexts are created at startup.",
|
||||||
|
)
|
||||||
|
|
||||||
|
sync_runtime_state(gateway)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"runtime_config_requested": runtime_config,
|
||||||
|
"runtime_config_applied": {
|
||||||
|
"tickers": list(gateway.config.get("tickers", [])),
|
||||||
|
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
|
||||||
|
"interval_minutes": gateway.config.get("interval_minutes", 60),
|
||||||
|
"trigger_time": gateway.config.get("trigger_time", "09:30"),
|
||||||
|
"initial_cash": gateway.storage.initial_cash,
|
||||||
|
"margin_requirement": gateway.config["margin_requirement"],
|
||||||
|
"max_comm_cycles": gateway.config["max_comm_cycles"],
|
||||||
|
"enable_memory": gateway.config.get("enable_memory", False),
|
||||||
|
},
|
||||||
|
"runtime_config_status": {
|
||||||
|
"tickers": True,
|
||||||
|
"schedule_mode": True,
|
||||||
|
"interval_minutes": True,
|
||||||
|
"trigger_time": True,
|
||||||
|
"initial_cash": initial_cash_applied,
|
||||||
|
"margin_requirement": pm_apply_result["margin_requirement"],
|
||||||
|
"max_comm_cycles": True,
|
||||||
|
"enable_memory": requested_enable_memory == current_enable_memory,
|
||||||
|
},
|
||||||
|
"ticker_changes": ticker_changes,
|
||||||
|
"runtime_config_warnings": warnings,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def sync_runtime_state(gateway: Any) -> None:
|
||||||
|
"""Refresh persisted state and dashboard after runtime config changes."""
|
||||||
|
gateway.state_sync.update_state("tickers", gateway.config.get("tickers", []))
|
||||||
|
gateway.state_sync.update_state(
|
||||||
|
"runtime_config",
|
||||||
|
{
|
||||||
|
"tickers": gateway.config.get("tickers", []),
|
||||||
|
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
|
||||||
|
"interval_minutes": gateway.config.get("interval_minutes", 60),
|
||||||
|
"trigger_time": gateway.config.get("trigger_time", "09:30"),
|
||||||
|
"initial_cash": gateway.storage.initial_cash,
|
||||||
|
"margin_requirement": gateway.config.get("margin_requirement"),
|
||||||
|
"max_comm_cycles": gateway.config.get("max_comm_cycles"),
|
||||||
|
"enable_memory": gateway.config.get("enable_memory", False),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
gateway.storage.update_server_state_from_dashboard(gateway.state_sync.state)
|
||||||
|
gateway.state_sync.save_state()
|
||||||
|
|
||||||
|
gateway._dashboard.tickers = list(gateway.config.get("tickers", []))
|
||||||
|
gateway._dashboard.initial_cash = gateway.storage.initial_cash
|
||||||
|
gateway._dashboard.enable_memory = bool(gateway.config.get("enable_memory", False))
|
||||||
|
|
||||||
|
summary = gateway.storage.load_file("summary") or {}
|
||||||
|
holdings = gateway.storage.load_file("holdings") or []
|
||||||
|
trades = gateway.storage.load_file("trades") or []
|
||||||
|
gateway._dashboard.update(
|
||||||
|
portfolio=summary,
|
||||||
|
holdings=holdings,
|
||||||
|
trades=trades,
|
||||||
|
)
|
||||||
711
backend/services/gateway_stock_handlers.py
Normal file
711
backend/services/gateway_stock_handlers.py
Normal file
@@ -0,0 +1,711 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Stock-related Gateway handlers extracted from the main Gateway module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from backend.data.provider_utils import normalize_symbol
|
||||||
|
from backend.domains import news as news_domain
|
||||||
|
from backend.domains import trading as trading_domain
|
||||||
|
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||||
|
from backend.enrich.llm_enricher import llm_enrichment_enabled
|
||||||
|
from backend.tools.data_tools import prices_to_df
|
||||||
|
from shared.client import NewsServiceClient, TradingServiceClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
if not ticker:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_history_loaded",
|
||||||
|
"ticker": "",
|
||||||
|
"prices": [],
|
||||||
|
"source": None,
|
||||||
|
"error": "invalid ticker",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
lookback_days = data.get("lookback_days", 90)
|
||||||
|
try:
|
||||||
|
lookback_days = max(7, min(int(lookback_days), 365))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
lookback_days = 90
|
||||||
|
|
||||||
|
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
|
||||||
|
try:
|
||||||
|
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
except ValueError:
|
||||||
|
end_dt = datetime.now()
|
||||||
|
end_date = end_dt.strftime("%Y-%m-%d")
|
||||||
|
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
prices = []
|
||||||
|
source = "polygon"
|
||||||
|
response = await gateway._call_trading_service(
|
||||||
|
"get_prices for history",
|
||||||
|
lambda client: client.get_prices(ticker=ticker, start_date=start_date, end_date=end_date),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
prices = response.prices
|
||||||
|
source = "trading_service"
|
||||||
|
|
||||||
|
if not prices:
|
||||||
|
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
|
||||||
|
if not prices:
|
||||||
|
payload = await asyncio.to_thread(
|
||||||
|
trading_domain.get_prices_payload,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
prices = payload.get("prices") or []
|
||||||
|
usage_snapshot = gateway._provider_router.get_usage_snapshot()
|
||||||
|
source = usage_snapshot.get("last_success", {}).get("prices")
|
||||||
|
if prices:
|
||||||
|
await asyncio.to_thread(
|
||||||
|
gateway.storage.market_store.upsert_ohlc,
|
||||||
|
ticker,
|
||||||
|
[price.model_dump() for price in prices],
|
||||||
|
source=source or "provider",
|
||||||
|
)
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_history_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"prices": [price if isinstance(price, dict) else price.model_dump() for price in prices][-120:],
|
||||||
|
"source": source,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_explain_events(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
snapshot = gateway.storage.runtime_db.get_stock_explain_snapshot(ticker)
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_explain_events_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"events": snapshot.get("events", []),
|
||||||
|
"signals": snapshot.get("signals", []),
|
||||||
|
"trades": snapshot.get("trades", []),
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_news(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
if not ticker:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_news_loaded",
|
||||||
|
"ticker": "",
|
||||||
|
"news": [],
|
||||||
|
"source": None,
|
||||||
|
"error": "invalid ticker",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
lookback_days = data.get("lookback_days", 30)
|
||||||
|
limit = data.get("limit", 12)
|
||||||
|
try:
|
||||||
|
lookback_days = max(7, min(int(lookback_days), 180))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
lookback_days = 30
|
||||||
|
try:
|
||||||
|
limit = max(1, min(int(limit), 30))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
limit = 12
|
||||||
|
|
||||||
|
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
|
||||||
|
try:
|
||||||
|
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
except ValueError:
|
||||||
|
end_dt = datetime.now()
|
||||||
|
end_date = end_dt.strftime("%Y-%m-%d")
|
||||||
|
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
news_rows = []
|
||||||
|
source = "polygon"
|
||||||
|
response = await gateway._call_news_service(
|
||||||
|
"get_enriched_news",
|
||||||
|
lambda client: client.get_enriched_news(
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
news_rows = response.get("news") or []
|
||||||
|
source = "news_service"
|
||||||
|
|
||||||
|
if not news_rows:
|
||||||
|
payload = await asyncio.to_thread(
|
||||||
|
news_domain.get_enriched_news,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=max(limit, 50),
|
||||||
|
)
|
||||||
|
news_rows = (payload.get("news") or [])[-limit:]
|
||||||
|
source = "market_store"
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_news_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"news": news_rows[-limit:],
|
||||||
|
"source": source,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_news_for_date(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
trade_date = str(data.get("date") or "").strip()
|
||||||
|
if not ticker or not trade_date:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_news_for_date_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": trade_date,
|
||||||
|
"news": [],
|
||||||
|
"error": "ticker and date are required",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
limit = data.get("limit", 20)
|
||||||
|
try:
|
||||||
|
limit = max(1, min(int(limit), 50))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
limit = 20
|
||||||
|
|
||||||
|
source = "market_store"
|
||||||
|
news_rows = []
|
||||||
|
response = await gateway._call_news_service(
|
||||||
|
"get_news_for_date",
|
||||||
|
lambda client: client.get_news_for_date(ticker=ticker, date=trade_date, limit=limit),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
news_rows = response.get("news") or []
|
||||||
|
source = "news_service"
|
||||||
|
|
||||||
|
if not news_rows:
|
||||||
|
payload = await asyncio.to_thread(
|
||||||
|
news_domain.get_news_for_date,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
date=trade_date,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
news_rows = payload.get("news") or []
|
||||||
|
source = "market_store"
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_news_for_date_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": trade_date,
|
||||||
|
"news": news_rows,
|
||||||
|
"source": source,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_news_timeline(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
if not ticker:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_news_timeline_loaded",
|
||||||
|
"ticker": "",
|
||||||
|
"timeline": [],
|
||||||
|
"error": "invalid ticker",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
lookback_days = data.get("lookback_days", 90)
|
||||||
|
try:
|
||||||
|
lookback_days = max(7, min(int(lookback_days), 365))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
lookback_days = 90
|
||||||
|
|
||||||
|
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
|
||||||
|
try:
|
||||||
|
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
except ValueError:
|
||||||
|
end_dt = datetime.now()
|
||||||
|
end_date = end_dt.strftime("%Y-%m-%d")
|
||||||
|
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
timeline = []
|
||||||
|
response = await gateway._call_news_service(
|
||||||
|
"get_news_timeline",
|
||||||
|
lambda client: client.get_news_timeline(ticker=ticker, start_date=start_date, end_date=end_date),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
timeline = response.get("timeline") or []
|
||||||
|
|
||||||
|
if not timeline:
|
||||||
|
payload = await asyncio.to_thread(
|
||||||
|
news_domain.get_news_timeline,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
)
|
||||||
|
timeline = payload.get("timeline") or []
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_news_timeline_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"timeline": timeline,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_news_categories(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
if not ticker:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_news_categories_loaded",
|
||||||
|
"ticker": "",
|
||||||
|
"categories": {},
|
||||||
|
"error": "invalid ticker",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
lookback_days = data.get("lookback_days", 90)
|
||||||
|
try:
|
||||||
|
lookback_days = max(7, min(int(lookback_days), 365))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
lookback_days = 90
|
||||||
|
|
||||||
|
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
|
||||||
|
try:
|
||||||
|
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||||
|
except ValueError:
|
||||||
|
end_dt = datetime.now()
|
||||||
|
end_date = end_dt.strftime("%Y-%m-%d")
|
||||||
|
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
categories = {}
|
||||||
|
response = await gateway._call_news_service(
|
||||||
|
"get_categories",
|
||||||
|
lambda client: client.get_categories(
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=200,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
categories = response.get("categories") or {}
|
||||||
|
|
||||||
|
if not categories:
|
||||||
|
payload = await asyncio.to_thread(
|
||||||
|
news_domain.get_news_categories,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=200,
|
||||||
|
)
|
||||||
|
categories = payload.get("categories") or {}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_news_categories_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"categories": categories,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_range_explain(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
start_date = str(data.get("start_date") or "").strip()
|
||||||
|
end_date = str(data.get("end_date") or "").strip()
|
||||||
|
if not ticker or not start_date or not end_date:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_range_explain_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"result": {"error": "ticker, start_date, end_date are required"},
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
article_ids = data.get("article_ids")
|
||||||
|
result = None
|
||||||
|
response = await gateway._call_news_service(
|
||||||
|
"get_range_explain",
|
||||||
|
lambda client: client.get_range_explain(
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
article_ids=article_ids if isinstance(article_ids, list) else None,
|
||||||
|
limit=100,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
result = response.get("result")
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
payload = await asyncio.to_thread(
|
||||||
|
news_domain.get_range_explain_payload,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
article_ids=article_ids if isinstance(article_ids, list) else None,
|
||||||
|
limit=100,
|
||||||
|
)
|
||||||
|
result = payload.get("result")
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_range_explain_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"result": result,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
if not ticker:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_insider_trades_loaded",
|
||||||
|
"ticker": "",
|
||||||
|
"trades": [],
|
||||||
|
"error": "invalid ticker",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
end_date = str(data.get("end_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
|
||||||
|
start_date = str(data.get("start_date") or "").strip()[:10]
|
||||||
|
limit = int(data.get("limit", 50))
|
||||||
|
|
||||||
|
trades = []
|
||||||
|
response = await gateway._call_trading_service(
|
||||||
|
"get_insider_trades",
|
||||||
|
lambda client: client.get_insider_trades(
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
start_date=start_date if start_date else None,
|
||||||
|
limit=limit,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
trades = response.insider_trades
|
||||||
|
|
||||||
|
if not trades:
|
||||||
|
payload = await asyncio.to_thread(
|
||||||
|
trading_domain.get_insider_trades_payload,
|
||||||
|
ticker=ticker,
|
||||||
|
end_date=end_date,
|
||||||
|
start_date=start_date if start_date else None,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
trades = payload.get("insider_trades") or []
|
||||||
|
|
||||||
|
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
|
||||||
|
formatted_trades = [{
|
||||||
|
"ticker": t.ticker,
|
||||||
|
"name": t.name,
|
||||||
|
"title": t.title,
|
||||||
|
"is_board_director": t.is_board_director,
|
||||||
|
"transaction_date": t.transaction_date,
|
||||||
|
"transaction_shares": t.transaction_shares,
|
||||||
|
"transaction_price_per_share": t.transaction_price_per_share,
|
||||||
|
"transaction_value": t.transaction_value,
|
||||||
|
"shares_owned_before_transaction": t.shares_owned_before_transaction,
|
||||||
|
"shares_owned_after_transaction": t.shares_owned_after_transaction,
|
||||||
|
"security_title": t.security_title,
|
||||||
|
"filing_date": t.filing_date,
|
||||||
|
"holding_change": (
|
||||||
|
(t.shares_owned_after_transaction or 0) - (t.shares_owned_before_transaction or 0)
|
||||||
|
if t.shares_owned_after_transaction and t.shares_owned_before_transaction else None
|
||||||
|
),
|
||||||
|
"is_buy": ((t.transaction_shares or 0) > 0) if t.transaction_shares is not None else None,
|
||||||
|
} for t in sorted_trades]
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_insider_trades_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"start_date": start_date or None,
|
||||||
|
"end_date": end_date,
|
||||||
|
"trades": formatted_trades,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_story(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
if not ticker:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_story_loaded",
|
||||||
|
"ticker": "",
|
||||||
|
"story": "",
|
||||||
|
"error": "invalid ticker",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
as_of_date = str(data.get("as_of_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
|
||||||
|
result = await gateway._call_news_service(
|
||||||
|
"get_story",
|
||||||
|
lambda client: client.get_story(ticker=ticker, as_of_date=as_of_date),
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
result = await asyncio.to_thread(
|
||||||
|
news_domain.get_story_payload,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
as_of_date=as_of_date,
|
||||||
|
)
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_story_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"as_of_date": as_of_date,
|
||||||
|
"story": result.get("story") or "",
|
||||||
|
"source": result.get("source") or "local",
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
target_date = str(data.get("date") or "").strip()[:10]
|
||||||
|
if not ticker or not target_date:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_similar_days_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": target_date,
|
||||||
|
"items": [],
|
||||||
|
"error": "ticker and date are required",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
top_k = data.get("top_k", 8)
|
||||||
|
try:
|
||||||
|
top_k = max(1, min(int(top_k), 20))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
top_k = 8
|
||||||
|
|
||||||
|
result = await gateway._call_news_service(
|
||||||
|
"get_similar_days",
|
||||||
|
lambda client: client.get_similar_days(ticker=ticker, date=target_date, n_similar=top_k),
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
result = await asyncio.to_thread(
|
||||||
|
news_domain.get_similar_days_payload,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
date=target_date,
|
||||||
|
n_similar=top_k,
|
||||||
|
)
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_similar_days_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"date": target_date,
|
||||||
|
**result,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
if not ticker:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_technical_indicators_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"indicators": None,
|
||||||
|
"error": "ticker is required",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
end_date = datetime.now()
|
||||||
|
start_date = end_date - timedelta(days=250)
|
||||||
|
|
||||||
|
prices = None
|
||||||
|
response = await gateway._call_trading_service(
|
||||||
|
"get_prices",
|
||||||
|
lambda client: client.get_prices(
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date.strftime("%Y-%m-%d"),
|
||||||
|
end_date=end_date.strftime("%Y-%m-%d"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if response is not None:
|
||||||
|
prices = response.prices
|
||||||
|
|
||||||
|
if prices is None:
|
||||||
|
payload = trading_domain.get_prices_payload(
|
||||||
|
ticker=ticker,
|
||||||
|
start_date=start_date.strftime("%Y-%m-%d"),
|
||||||
|
end_date=end_date.strftime("%Y-%m-%d"),
|
||||||
|
)
|
||||||
|
prices = payload.get("prices") or []
|
||||||
|
|
||||||
|
if not prices or len(prices) < 20:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_technical_indicators_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"indicators": None,
|
||||||
|
"error": "Insufficient price data",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
df = prices_to_df(prices)
|
||||||
|
signal = gateway._technical_analyzer.analyze(ticker, df)
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
df_sorted = df.sort_values("time").reset_index(drop=True)
|
||||||
|
df_sorted["returns"] = df_sorted["close"].pct_change()
|
||||||
|
vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
|
||||||
|
vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
|
||||||
|
vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
|
||||||
|
ma_distance = {}
|
||||||
|
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
|
||||||
|
ma_value = getattr(signal, ma_key, None)
|
||||||
|
ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None
|
||||||
|
|
||||||
|
indicators = {
|
||||||
|
"ticker": ticker,
|
||||||
|
"current_price": signal.current_price,
|
||||||
|
"ma": {
|
||||||
|
"ma5": signal.ma5,
|
||||||
|
"ma10": signal.ma10,
|
||||||
|
"ma20": signal.ma20,
|
||||||
|
"ma50": signal.ma50,
|
||||||
|
"ma200": signal.ma200,
|
||||||
|
"distance": ma_distance,
|
||||||
|
},
|
||||||
|
"rsi": {
|
||||||
|
"rsi14": signal.rsi14,
|
||||||
|
"status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
|
||||||
|
},
|
||||||
|
"macd": {
|
||||||
|
"macd": signal.macd,
|
||||||
|
"signal": signal.macd_signal,
|
||||||
|
"histogram": signal.macd - signal.macd_signal,
|
||||||
|
},
|
||||||
|
"bollinger": {
|
||||||
|
"upper": signal.bollinger_upper,
|
||||||
|
"mid": signal.bollinger_mid,
|
||||||
|
"lower": signal.bollinger_lower,
|
||||||
|
},
|
||||||
|
"volatility": {
|
||||||
|
"vol_10d": vol_10,
|
||||||
|
"vol_20d": vol_20,
|
||||||
|
"vol_60d": vol_60,
|
||||||
|
"annualized": signal.annualized_volatility_pct,
|
||||||
|
"risk_level": signal.risk_level,
|
||||||
|
},
|
||||||
|
"trend": signal.trend,
|
||||||
|
"mean_reversion": signal.mean_reversion_signal,
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_technical_indicators_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"indicators": indicators,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.exception("Error getting technical indicators for %s", ticker)
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_technical_indicators_loaded",
|
||||||
|
"ticker": ticker,
|
||||||
|
"indicators": None,
|
||||||
|
"error": str(exc),
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||||
|
ticker = normalize_symbol(data.get("ticker", ""))
|
||||||
|
start_date = str(data.get("start_date") or "").strip()[:10]
|
||||||
|
end_date = str(data.get("end_date") or "").strip()[:10]
|
||||||
|
story_date = str(data.get("story_date") or end_date or "").strip()[:10]
|
||||||
|
target_date = str(data.get("target_date") or "").strip()[:10]
|
||||||
|
force = bool(data.get("force", False))
|
||||||
|
rebuild_story = bool(data.get("rebuild_story", True))
|
||||||
|
rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
|
||||||
|
only_local_to_llm = bool(data.get("only_local_to_llm", False))
|
||||||
|
limit = data.get("limit", 200)
|
||||||
|
|
||||||
|
try:
|
||||||
|
limit = max(10, min(int(limit), 500))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
limit = 200
|
||||||
|
|
||||||
|
if not ticker or not start_date or not end_date:
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_enrich_completed",
|
||||||
|
"ticker": ticker,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"error": "ticker, start_date, end_date are required",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
if only_local_to_llm and not llm_enrichment_enabled():
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_enrich_completed",
|
||||||
|
"ticker": ticker,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
|
||||||
|
}, ensure_ascii=False))
|
||||||
|
return
|
||||||
|
|
||||||
|
result = await asyncio.to_thread(
|
||||||
|
enrich_news_for_symbol,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker,
|
||||||
|
start_date=start_date,
|
||||||
|
end_date=end_date,
|
||||||
|
limit=limit,
|
||||||
|
skip_existing=not force,
|
||||||
|
only_reanalyze_local=only_local_to_llm,
|
||||||
|
)
|
||||||
|
|
||||||
|
story_status = None
|
||||||
|
if rebuild_story and story_date:
|
||||||
|
await asyncio.to_thread(gateway.storage.market_store.delete_story_cache, ticker, as_of_date=story_date)
|
||||||
|
story_result = await asyncio.to_thread(
|
||||||
|
news_domain.get_story_payload,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
as_of_date=story_date,
|
||||||
|
)
|
||||||
|
story_status = {"as_of_date": story_date, "source": story_result.get("source") or "local"}
|
||||||
|
|
||||||
|
similar_status = None
|
||||||
|
if rebuild_similar_days and target_date:
|
||||||
|
await asyncio.to_thread(gateway.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date)
|
||||||
|
similar_result = await asyncio.to_thread(
|
||||||
|
news_domain.get_similar_days_payload,
|
||||||
|
gateway.storage.market_store,
|
||||||
|
ticker=ticker,
|
||||||
|
date=target_date,
|
||||||
|
n_similar=8,
|
||||||
|
)
|
||||||
|
similar_status = {
|
||||||
|
"target_date": target_date,
|
||||||
|
"count": len(similar_result.get("items") or []),
|
||||||
|
"error": similar_result.get("error"),
|
||||||
|
}
|
||||||
|
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "stock_enrich_completed",
|
||||||
|
"ticker": ticker,
|
||||||
|
"start_date": start_date,
|
||||||
|
"end_date": end_date,
|
||||||
|
"story_date": story_date or None,
|
||||||
|
"target_date": target_date or None,
|
||||||
|
"force": force,
|
||||||
|
"only_local_to_llm": only_local_to_llm,
|
||||||
|
"stats": result,
|
||||||
|
"story_status": story_status,
|
||||||
|
"similar_status": similar_status,
|
||||||
|
}, ensure_ascii=False, default=str))
|
||||||
@@ -54,6 +54,7 @@ class MarketService:
|
|||||||
self.running = False
|
self.running = False
|
||||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
self._broadcast_func: Optional[Callable] = None
|
self._broadcast_func: Optional[Callable] = None
|
||||||
|
self._price_record_func: Optional[Callable[..., None]] = None
|
||||||
self._price_manager: Optional[Any] = None
|
self._price_manager: Optional[Any] = None
|
||||||
self._current_date: Optional[str] = None
|
self._current_date: Optional[str] = None
|
||||||
|
|
||||||
@@ -64,6 +65,18 @@ class MarketService:
|
|||||||
self._session_start_values: Optional[Dict[str, float]] = None
|
self._session_start_values: Optional[Dict[str, float]] = None
|
||||||
self._session_start_timestamp: Optional[int] = None
|
self._session_start_timestamp: Optional[int] = None
|
||||||
|
|
||||||
|
def get_live_quote_provider(self) -> Optional[str]:
|
||||||
|
"""Return the active live quote provider for UI/debugging."""
|
||||||
|
if self.backtest_mode:
|
||||||
|
return "backtest"
|
||||||
|
if self.mock_mode:
|
||||||
|
return "mock"
|
||||||
|
if self._price_manager and hasattr(self._price_manager, "provider"):
|
||||||
|
provider = getattr(self._price_manager, "provider", None)
|
||||||
|
if isinstance(provider, str) and provider.strip():
|
||||||
|
return provider.strip().lower()
|
||||||
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mode_name(self) -> str:
|
def mode_name(self) -> str:
|
||||||
if self.backtest_mode:
|
if self.backtest_mode:
|
||||||
@@ -92,6 +105,10 @@ class MarketService:
|
|||||||
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
|
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def set_price_recorder(self, recorder: Optional[Callable[..., None]]):
|
||||||
|
"""Register an optional callback for persisting runtime price points."""
|
||||||
|
self._price_record_func = recorder
|
||||||
|
|
||||||
def _make_price_callback(self) -> Callable:
|
def _make_price_callback(self) -> Callable:
|
||||||
"""Create thread-safe price callback"""
|
"""Create thread-safe price callback"""
|
||||||
|
|
||||||
@@ -169,6 +186,24 @@ class MarketService:
|
|||||||
((price - open_price) / open_price) * 100 if open_price > 0 else 0
|
((price - open_price) / open_price) * 100 if open_price > 0 else 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self._price_record_func:
|
||||||
|
try:
|
||||||
|
self._price_record_func(
|
||||||
|
ticker=symbol,
|
||||||
|
timestamp=str(price_data.get("timestamp") or datetime.now().isoformat()),
|
||||||
|
price=float(price),
|
||||||
|
open_price=float(open_price) if open_price is not None else None,
|
||||||
|
ret=float(ret),
|
||||||
|
source=self.mode_name.lower(),
|
||||||
|
meta=price_data,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Failed to record price point for %s: %s",
|
||||||
|
symbol,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
await self._broadcast_func(
|
await self._broadcast_func(
|
||||||
{
|
{
|
||||||
"type": "price_update",
|
"type": "price_update",
|
||||||
@@ -205,6 +240,43 @@ class MarketService:
|
|||||||
self._loop = None
|
self._loop = None
|
||||||
self._broadcast_func = None
|
self._broadcast_func = None
|
||||||
|
|
||||||
|
def update_tickers(self, tickers: List[str]) -> Dict[str, List[str]]:
|
||||||
|
"""Hot-update subscribed tickers without restarting the service."""
|
||||||
|
normalized: List[str] = []
|
||||||
|
for ticker in tickers:
|
||||||
|
symbol = normalize_symbol(ticker)
|
||||||
|
if symbol and symbol not in normalized:
|
||||||
|
normalized.append(symbol)
|
||||||
|
|
||||||
|
previous = list(self.tickers)
|
||||||
|
removed = [ticker for ticker in previous if ticker not in normalized]
|
||||||
|
added = [ticker for ticker in normalized if ticker not in previous]
|
||||||
|
self.tickers = normalized
|
||||||
|
|
||||||
|
if self._price_manager:
|
||||||
|
if removed:
|
||||||
|
self._price_manager.unsubscribe(removed)
|
||||||
|
if added:
|
||||||
|
if self.mock_mode:
|
||||||
|
self._price_manager.subscribe(
|
||||||
|
added,
|
||||||
|
base_prices={ticker: 100.0 for ticker in added},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._price_manager.subscribe(added)
|
||||||
|
|
||||||
|
if self.backtest_mode and self._current_date:
|
||||||
|
self._price_manager.set_date(self._current_date)
|
||||||
|
|
||||||
|
for ticker in removed:
|
||||||
|
self.cache.pop(ticker, None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"added": added,
|
||||||
|
"removed": removed,
|
||||||
|
"active": list(self.tickers),
|
||||||
|
}
|
||||||
|
|
||||||
# Backtest methods
|
# Backtest methods
|
||||||
def set_backtest_date(self, date: str):
|
def set_backtest_date(self, date: str):
|
||||||
"""Set current backtest date"""
|
"""Set current backtest date"""
|
||||||
@@ -472,6 +544,7 @@ class MarketService:
|
|||||||
"status": MarketStatus.OPEN,
|
"status": MarketStatus.OPEN,
|
||||||
"status_text": "Backtest Mode",
|
"status_text": "Backtest Mode",
|
||||||
"is_trading_day": True,
|
"is_trading_day": True,
|
||||||
|
"live_quote_provider": self.get_live_quote_provider(),
|
||||||
}
|
}
|
||||||
|
|
||||||
now = self._now_nyse()
|
now = self._now_nyse()
|
||||||
@@ -484,6 +557,7 @@ class MarketService:
|
|||||||
"status": MarketStatus.CLOSED,
|
"status": MarketStatus.CLOSED,
|
||||||
"status_text": "Market Closed (Non-trading Day)",
|
"status_text": "Market Closed (Non-trading Day)",
|
||||||
"is_trading_day": False,
|
"is_trading_day": False,
|
||||||
|
"live_quote_provider": self.get_live_quote_provider(),
|
||||||
}
|
}
|
||||||
|
|
||||||
market_open, market_close = self._get_market_hours(today)
|
market_open, market_close = self._get_market_hours(today)
|
||||||
@@ -493,6 +567,7 @@ class MarketService:
|
|||||||
"status": MarketStatus.CLOSED,
|
"status": MarketStatus.CLOSED,
|
||||||
"status_text": "Market Closed",
|
"status_text": "Market Closed",
|
||||||
"is_trading_day": is_trading,
|
"is_trading_day": is_trading,
|
||||||
|
"live_quote_provider": self.get_live_quote_provider(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Determine status based on current time
|
# Determine status based on current time
|
||||||
@@ -503,6 +578,7 @@ class MarketService:
|
|||||||
"is_trading_day": True,
|
"is_trading_day": True,
|
||||||
"market_open": market_open.isoformat(),
|
"market_open": market_open.isoformat(),
|
||||||
"market_close": market_close.isoformat(),
|
"market_close": market_close.isoformat(),
|
||||||
|
"live_quote_provider": self.get_live_quote_provider(),
|
||||||
}
|
}
|
||||||
elif now > market_close:
|
elif now > market_close:
|
||||||
return {
|
return {
|
||||||
@@ -511,6 +587,7 @@ class MarketService:
|
|||||||
"is_trading_day": True,
|
"is_trading_day": True,
|
||||||
"market_open": market_open.isoformat(),
|
"market_open": market_open.isoformat(),
|
||||||
"market_close": market_close.isoformat(),
|
"market_close": market_close.isoformat(),
|
||||||
|
"live_quote_provider": self.get_live_quote_provider(),
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
@@ -519,6 +596,7 @@ class MarketService:
|
|||||||
"is_trading_day": True,
|
"is_trading_day": True,
|
||||||
"market_open": market_open.isoformat(),
|
"market_open": market_open.isoformat(),
|
||||||
"market_close": market_close.isoformat(),
|
"market_close": market_close.isoformat(),
|
||||||
|
"live_quote_provider": self.get_live_quote_provider(),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def check_and_broadcast_market_status(self):
|
async def check_and_broadcast_market_status(self):
|
||||||
|
|||||||
280
backend/services/research_db.py
Normal file
280
backend/services/research_db.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Query-oriented storage for explain/research data."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Iterable
|
||||||
|
|
||||||
|
from shared.schema import CompanyNews
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS news_items (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
ticker TEXT NOT NULL,
|
||||||
|
published_at TEXT,
|
||||||
|
trade_date TEXT,
|
||||||
|
source TEXT,
|
||||||
|
title TEXT NOT NULL,
|
||||||
|
summary TEXT,
|
||||||
|
url TEXT,
|
||||||
|
related TEXT,
|
||||||
|
category TEXT,
|
||||||
|
raw_json TEXT NOT NULL,
|
||||||
|
ingest_run_date TEXT,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_news_items_ticker_date
|
||||||
|
ON news_items (ticker, trade_date DESC, published_at DESC);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _json_dumps(value: Any) -> str:
|
||||||
|
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_news_id(ticker: str, item: CompanyNews, fallback_index: int) -> str:
|
||||||
|
base = item.url or item.title or f"{ticker}-{fallback_index}"
|
||||||
|
return f"{ticker}:{base}"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_trade_date(date_value: str | None) -> str | None:
|
||||||
|
if not date_value:
|
||||||
|
return None
|
||||||
|
normalized = str(date_value).strip()
|
||||||
|
if not normalized:
|
||||||
|
return None
|
||||||
|
if "T" in normalized:
|
||||||
|
return normalized.split("T", 1)[0]
|
||||||
|
if " " in normalized:
|
||||||
|
return normalized.split(" ", 1)[0]
|
||||||
|
return normalized[:10]
|
||||||
|
|
||||||
|
|
||||||
|
class ResearchDb:
|
||||||
|
"""Small SQLite helper for explain-oriented news storage."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: Path):
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._init_db()
|
||||||
|
|
||||||
|
def _connect(self) -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _init_db(self):
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.executescript(SCHEMA)
|
||||||
|
|
||||||
|
def upsert_news_items(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
items: Iterable[CompanyNews],
|
||||||
|
ingest_run_date: str | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Persist provider news and return normalized rows."""
|
||||||
|
normalized_rows: list[dict[str, Any]] = []
|
||||||
|
timestamp = datetime.utcnow().isoformat(timespec="seconds")
|
||||||
|
symbol = str(ticker or "").strip().upper()
|
||||||
|
if not symbol:
|
||||||
|
return normalized_rows
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
for index, item in enumerate(items):
|
||||||
|
news_id = _resolve_news_id(symbol, item, index)
|
||||||
|
trade_date = _resolve_trade_date(item.date)
|
||||||
|
payload = item.model_dump()
|
||||||
|
row = {
|
||||||
|
"id": news_id,
|
||||||
|
"ticker": symbol,
|
||||||
|
"published_at": item.date,
|
||||||
|
"trade_date": trade_date,
|
||||||
|
"source": item.source,
|
||||||
|
"title": item.title,
|
||||||
|
"summary": item.summary,
|
||||||
|
"url": item.url,
|
||||||
|
"related": item.related,
|
||||||
|
"category": item.category,
|
||||||
|
"raw_json": _json_dumps(payload),
|
||||||
|
"ingest_run_date": ingest_run_date,
|
||||||
|
"created_at": timestamp,
|
||||||
|
}
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO news_items
|
||||||
|
(id, ticker, published_at, trade_date, source, title, summary, url,
|
||||||
|
related, category, raw_json, ingest_run_date, created_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(id) DO UPDATE SET
|
||||||
|
ticker = excluded.ticker,
|
||||||
|
published_at = excluded.published_at,
|
||||||
|
trade_date = excluded.trade_date,
|
||||||
|
source = excluded.source,
|
||||||
|
title = excluded.title,
|
||||||
|
summary = excluded.summary,
|
||||||
|
url = excluded.url,
|
||||||
|
related = excluded.related,
|
||||||
|
category = excluded.category,
|
||||||
|
raw_json = excluded.raw_json,
|
||||||
|
ingest_run_date = excluded.ingest_run_date
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
row["id"],
|
||||||
|
row["ticker"],
|
||||||
|
row["published_at"],
|
||||||
|
row["trade_date"],
|
||||||
|
row["source"],
|
||||||
|
row["title"],
|
||||||
|
row["summary"],
|
||||||
|
row["url"],
|
||||||
|
row["related"],
|
||||||
|
row["category"],
|
||||||
|
row["raw_json"],
|
||||||
|
row["ingest_run_date"],
|
||||||
|
row["created_at"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
normalized_rows.append(row)
|
||||||
|
return normalized_rows
|
||||||
|
|
||||||
|
def get_news_items(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
limit: int = 20,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return normalized news rows for explain UI."""
|
||||||
|
symbol = str(ticker or "").strip().upper()
|
||||||
|
if not symbol:
|
||||||
|
return []
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT id, ticker, published_at, trade_date, source, title, summary,
|
||||||
|
url, related, category
|
||||||
|
FROM news_items
|
||||||
|
WHERE ticker = ?
|
||||||
|
"""
|
||||||
|
params: list[Any] = [symbol]
|
||||||
|
if start_date:
|
||||||
|
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
|
||||||
|
params.append(start_date)
|
||||||
|
if end_date:
|
||||||
|
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
|
||||||
|
params.append(end_date)
|
||||||
|
sql += " ORDER BY COALESCE(published_at, trade_date) DESC LIMIT ?"
|
||||||
|
params.append(max(1, int(limit)))
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(sql, params).fetchall()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": row["id"],
|
||||||
|
"ticker": row["ticker"],
|
||||||
|
"date": row["published_at"] or row["trade_date"],
|
||||||
|
"trade_date": row["trade_date"],
|
||||||
|
"source": row["source"],
|
||||||
|
"title": row["title"],
|
||||||
|
"summary": row["summary"],
|
||||||
|
"url": row["url"],
|
||||||
|
"related": row["related"],
|
||||||
|
"category": row["category"],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_news_timeline(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
start_date: str | None = None,
|
||||||
|
end_date: str | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Aggregate news counts per trade date for chart markers."""
|
||||||
|
symbol = str(ticker or "").strip().upper()
|
||||||
|
if not symbol:
|
||||||
|
return []
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT COALESCE(trade_date, substr(published_at, 1, 10)) AS date,
|
||||||
|
COUNT(*) AS count,
|
||||||
|
COUNT(DISTINCT source) AS source_count,
|
||||||
|
MAX(title) AS top_title
|
||||||
|
FROM news_items
|
||||||
|
WHERE ticker = ?
|
||||||
|
"""
|
||||||
|
params: list[Any] = [symbol]
|
||||||
|
if start_date:
|
||||||
|
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) >= ?"
|
||||||
|
params.append(start_date)
|
||||||
|
if end_date:
|
||||||
|
sql += " AND COALESCE(trade_date, substr(published_at, 1, 10)) <= ?"
|
||||||
|
params.append(end_date)
|
||||||
|
sql += """
|
||||||
|
GROUP BY COALESCE(trade_date, substr(published_at, 1, 10))
|
||||||
|
ORDER BY date ASC
|
||||||
|
"""
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(sql, params).fetchall()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"date": row["date"],
|
||||||
|
"count": int(row["count"] or 0),
|
||||||
|
"source_count": int(row["source_count"] or 0),
|
||||||
|
"top_title": row["top_title"] or "",
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
if row["date"]
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_news_by_ids(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
article_ids: Iterable[str],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return selected persisted news items."""
|
||||||
|
symbol = str(ticker or "").strip().upper()
|
||||||
|
ids = [str(article_id).strip() for article_id in article_ids if str(article_id).strip()]
|
||||||
|
if not symbol or not ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
placeholders = ",".join("?" for _ in ids)
|
||||||
|
sql = f"""
|
||||||
|
SELECT id, ticker, published_at, trade_date, source, title, summary,
|
||||||
|
url, related, category
|
||||||
|
FROM news_items
|
||||||
|
WHERE ticker = ? AND id IN ({placeholders})
|
||||||
|
ORDER BY COALESCE(published_at, trade_date) DESC
|
||||||
|
"""
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(sql, [symbol, *ids]).fetchall()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": row["id"],
|
||||||
|
"ticker": row["ticker"],
|
||||||
|
"date": row["published_at"] or row["trade_date"],
|
||||||
|
"trade_date": row["trade_date"],
|
||||||
|
"source": row["source"],
|
||||||
|
"title": row["title"],
|
||||||
|
"summary": row["summary"],
|
||||||
|
"url": row["url"],
|
||||||
|
"related": row["related"],
|
||||||
|
"category": row["category"],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
512
backend/services/runtime_db.py
Normal file
512
backend/services/runtime_db.py
Normal file
@@ -0,0 +1,512 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""Run-scoped SQLite storage for query-oriented runtime history."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Iterable, Optional
|
||||||
|
|
||||||
|
|
||||||
|
SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS events (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
event_type TEXT NOT NULL,
|
||||||
|
timestamp TEXT,
|
||||||
|
agent_id TEXT,
|
||||||
|
agent_name TEXT,
|
||||||
|
ticker TEXT,
|
||||||
|
title TEXT,
|
||||||
|
content TEXT,
|
||||||
|
payload_json TEXT NOT NULL,
|
||||||
|
run_date TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_events_type_time ON events(event_type, timestamp DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_events_ticker_time ON events(ticker, timestamp DESC);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS trades (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
ticker TEXT NOT NULL,
|
||||||
|
side TEXT,
|
||||||
|
qty REAL,
|
||||||
|
price REAL,
|
||||||
|
timestamp TEXT,
|
||||||
|
trading_date TEXT,
|
||||||
|
agent_id TEXT,
|
||||||
|
meta_json TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_trades_ticker_time ON trades(ticker, timestamp DESC);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS signals (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
ticker TEXT NOT NULL,
|
||||||
|
agent_id TEXT,
|
||||||
|
agent_name TEXT,
|
||||||
|
role TEXT,
|
||||||
|
signal TEXT,
|
||||||
|
confidence REAL,
|
||||||
|
reasoning_json TEXT,
|
||||||
|
reasons_json TEXT,
|
||||||
|
risks_json TEXT,
|
||||||
|
invalidation TEXT,
|
||||||
|
next_action TEXT,
|
||||||
|
intrinsic_value REAL,
|
||||||
|
fair_value_range_json TEXT,
|
||||||
|
value_gap_pct REAL,
|
||||||
|
valuation_methods_json TEXT,
|
||||||
|
real_return REAL,
|
||||||
|
is_correct TEXT,
|
||||||
|
trade_date TEXT,
|
||||||
|
created_at TEXT,
|
||||||
|
meta_json TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_signals_ticker_date ON signals(ticker, trade_date DESC);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_signals_agent_date ON signals(agent_id, trade_date DESC);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS price_points (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
ticker TEXT NOT NULL,
|
||||||
|
timestamp TEXT NOT NULL,
|
||||||
|
price REAL NOT NULL,
|
||||||
|
open_price REAL,
|
||||||
|
ret REAL,
|
||||||
|
source TEXT,
|
||||||
|
meta_json TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_price_points_ticker_time ON price_points(ticker, timestamp DESC);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _json_dumps(value: Any) -> str:
|
||||||
|
return json.dumps(value, ensure_ascii=False, sort_keys=True, default=str)
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_key(*parts: Any) -> str:
|
||||||
|
raw = "::".join("" if part is None else str(part) for part in parts)
|
||||||
|
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeDb:
|
||||||
|
"""Small SQLite helper for append-mostly runtime data."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: Path):
|
||||||
|
self.db_path = Path(db_path)
|
||||||
|
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._init_db()
|
||||||
|
|
||||||
|
def _connect(self) -> sqlite3.Connection:
|
||||||
|
conn = sqlite3.connect(self.db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _init_db(self):
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.executescript(SCHEMA)
|
||||||
|
|
||||||
|
def insert_event(self, event: Dict[str, Any]):
|
||||||
|
payload = dict(event or {})
|
||||||
|
if not payload:
|
||||||
|
return
|
||||||
|
|
||||||
|
event_id = payload.get("id") or _hash_key(
|
||||||
|
payload.get("type"),
|
||||||
|
payload.get("timestamp"),
|
||||||
|
payload.get("agentId") or payload.get("agent_id"),
|
||||||
|
payload.get("content"),
|
||||||
|
payload.get("title"),
|
||||||
|
)
|
||||||
|
ticker = payload.get("ticker")
|
||||||
|
if not ticker and isinstance(payload.get("tickers"), list) and len(payload["tickers"]) == 1:
|
||||||
|
ticker = payload["tickers"][0]
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR IGNORE INTO events
|
||||||
|
(id, event_type, timestamp, agent_id, agent_name, ticker, title, content, payload_json, run_date)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
event_id,
|
||||||
|
payload.get("type"),
|
||||||
|
payload.get("timestamp"),
|
||||||
|
payload.get("agentId") or payload.get("agent_id"),
|
||||||
|
payload.get("agentName") or payload.get("agent_name"),
|
||||||
|
ticker,
|
||||||
|
payload.get("title"),
|
||||||
|
payload.get("content"),
|
||||||
|
_json_dumps(payload),
|
||||||
|
payload.get("date") or payload.get("trading_date") or payload.get("run_date"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_recent_feed_events(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
limit: int = 200,
|
||||||
|
event_types: Optional[Iterable[str]] = None,
|
||||||
|
) -> list[Dict[str, Any]]:
|
||||||
|
"""Return recent persisted feed events in newest-first order."""
|
||||||
|
event_types = tuple(event_types or ())
|
||||||
|
sql = """
|
||||||
|
SELECT payload_json
|
||||||
|
FROM events
|
||||||
|
"""
|
||||||
|
params: list[Any] = []
|
||||||
|
if event_types:
|
||||||
|
placeholders = ",".join("?" for _ in event_types)
|
||||||
|
sql += f" WHERE event_type IN ({placeholders})"
|
||||||
|
params.extend(event_types)
|
||||||
|
sql += " ORDER BY timestamp DESC LIMIT ?"
|
||||||
|
params.append(max(1, int(limit)))
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
rows = conn.execute(sql, params).fetchall()
|
||||||
|
|
||||||
|
items: list[Dict[str, Any]] = []
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
payload = {}
|
||||||
|
if payload:
|
||||||
|
items.append(payload)
|
||||||
|
return items
|
||||||
|
|
||||||
|
def get_last_day_feed_events(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
current_date: Optional[str] = None,
|
||||||
|
limit: int = 200,
|
||||||
|
event_types: Optional[Iterable[str]] = None,
|
||||||
|
) -> list[Dict[str, Any]]:
|
||||||
|
"""Return latest trading day events in newest-first order for replay."""
|
||||||
|
event_types = tuple(event_types or ())
|
||||||
|
target_date = str(current_date or "").strip() or None
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
if not target_date:
|
||||||
|
row = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT run_date
|
||||||
|
FROM events
|
||||||
|
WHERE run_date IS NOT NULL AND TRIM(run_date) != ''
|
||||||
|
ORDER BY run_date DESC
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
).fetchone()
|
||||||
|
target_date = row["run_date"] if row else None
|
||||||
|
|
||||||
|
if not target_date:
|
||||||
|
return []
|
||||||
|
|
||||||
|
sql = """
|
||||||
|
SELECT payload_json
|
||||||
|
FROM events
|
||||||
|
WHERE run_date = ?
|
||||||
|
"""
|
||||||
|
params: list[Any] = [target_date]
|
||||||
|
if event_types:
|
||||||
|
placeholders = ",".join("?" for _ in event_types)
|
||||||
|
sql += f" AND event_type IN ({placeholders})"
|
||||||
|
params.extend(event_types)
|
||||||
|
sql += " ORDER BY timestamp DESC LIMIT ?"
|
||||||
|
params.append(max(1, int(limit)))
|
||||||
|
rows = conn.execute(sql, params).fetchall()
|
||||||
|
|
||||||
|
items: list[Dict[str, Any]] = []
|
||||||
|
for row in rows:
|
||||||
|
try:
|
||||||
|
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
payload = {}
|
||||||
|
if payload:
|
||||||
|
items.append(payload)
|
||||||
|
return items
|
||||||
|
|
||||||
|
def upsert_trade(self, trade: Dict[str, Any]):
|
||||||
|
payload = dict(trade or {})
|
||||||
|
if not payload:
|
||||||
|
return
|
||||||
|
|
||||||
|
trade_id = payload.get("id") or _hash_key(
|
||||||
|
payload.get("ticker"),
|
||||||
|
payload.get("timestamp") or payload.get("ts"),
|
||||||
|
payload.get("side"),
|
||||||
|
payload.get("qty"),
|
||||||
|
payload.get("price"),
|
||||||
|
)
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO trades
|
||||||
|
(id, ticker, side, qty, price, timestamp, trading_date, agent_id, meta_json)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
trade_id,
|
||||||
|
payload.get("ticker"),
|
||||||
|
payload.get("side"),
|
||||||
|
payload.get("qty"),
|
||||||
|
payload.get("price"),
|
||||||
|
payload.get("timestamp") or payload.get("ts"),
|
||||||
|
payload.get("trading_date"),
|
||||||
|
payload.get("agentId") or payload.get("agent_id"),
|
||||||
|
_json_dumps(payload),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def upsert_signal(self, signal: Dict[str, Any], *, agent_id: str, agent_name: str, role: str):
|
||||||
|
payload = dict(signal or {})
|
||||||
|
ticker = payload.get("ticker")
|
||||||
|
if not ticker:
|
||||||
|
return
|
||||||
|
|
||||||
|
signal_id = _hash_key(
|
||||||
|
agent_id,
|
||||||
|
ticker,
|
||||||
|
payload.get("date"),
|
||||||
|
payload.get("signal"),
|
||||||
|
payload.get("confidence"),
|
||||||
|
)
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR REPLACE INTO signals
|
||||||
|
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
|
||||||
|
reasons_json, risks_json, invalidation, next_action, intrinsic_value,
|
||||||
|
fair_value_range_json, value_gap_pct, valuation_methods_json,
|
||||||
|
real_return, is_correct, trade_date, created_at, meta_json)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
signal_id,
|
||||||
|
ticker,
|
||||||
|
agent_id,
|
||||||
|
agent_name,
|
||||||
|
role,
|
||||||
|
payload.get("signal"),
|
||||||
|
payload.get("confidence"),
|
||||||
|
_json_dumps(payload.get("reasoning")),
|
||||||
|
_json_dumps(payload.get("reasons")),
|
||||||
|
_json_dumps(payload.get("risks")),
|
||||||
|
payload.get("invalidation"),
|
||||||
|
payload.get("next_action"),
|
||||||
|
payload.get("intrinsic_value"),
|
||||||
|
_json_dumps(payload.get("fair_value_range")),
|
||||||
|
payload.get("value_gap_pct"),
|
||||||
|
_json_dumps(payload.get("valuation_methods")),
|
||||||
|
payload.get("real_return"),
|
||||||
|
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
|
||||||
|
payload.get("date"),
|
||||||
|
payload.get("created_at") or payload.get("date"),
|
||||||
|
_json_dumps(payload),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def replace_signals_for_leaderboard(self, leaderboard: Iterable[Dict[str, Any]]):
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute("DELETE FROM signals")
|
||||||
|
for agent in leaderboard:
|
||||||
|
agent_id = agent.get("agentId")
|
||||||
|
agent_name = agent.get("name")
|
||||||
|
role = agent.get("role")
|
||||||
|
for signal in agent.get("signals", []) or []:
|
||||||
|
payload = dict(signal or {})
|
||||||
|
ticker = payload.get("ticker")
|
||||||
|
if not ticker:
|
||||||
|
continue
|
||||||
|
signal_id = _hash_key(
|
||||||
|
agent_id,
|
||||||
|
ticker,
|
||||||
|
payload.get("date"),
|
||||||
|
payload.get("signal"),
|
||||||
|
payload.get("confidence"),
|
||||||
|
)
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO signals
|
||||||
|
(id, ticker, agent_id, agent_name, role, signal, confidence, reasoning_json,
|
||||||
|
reasons_json, risks_json, invalidation, next_action, intrinsic_value,
|
||||||
|
fair_value_range_json, value_gap_pct, valuation_methods_json,
|
||||||
|
real_return, is_correct, trade_date, created_at, meta_json)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
signal_id,
|
||||||
|
ticker,
|
||||||
|
agent_id,
|
||||||
|
agent_name,
|
||||||
|
role,
|
||||||
|
payload.get("signal"),
|
||||||
|
payload.get("confidence"),
|
||||||
|
_json_dumps(payload.get("reasoning")),
|
||||||
|
_json_dumps(payload.get("reasons")),
|
||||||
|
_json_dumps(payload.get("risks")),
|
||||||
|
payload.get("invalidation"),
|
||||||
|
payload.get("next_action"),
|
||||||
|
payload.get("intrinsic_value"),
|
||||||
|
_json_dumps(payload.get("fair_value_range")),
|
||||||
|
payload.get("value_gap_pct"),
|
||||||
|
_json_dumps(payload.get("valuation_methods")),
|
||||||
|
payload.get("real_return"),
|
||||||
|
None if payload.get("is_correct") is None else str(payload.get("is_correct")),
|
||||||
|
payload.get("date"),
|
||||||
|
payload.get("created_at") or payload.get("date"),
|
||||||
|
_json_dumps(payload),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def insert_price_point(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
timestamp: str,
|
||||||
|
price: float,
|
||||||
|
open_price: Optional[float] = None,
|
||||||
|
ret: Optional[float] = None,
|
||||||
|
source: Optional[str] = None,
|
||||||
|
meta: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
price_id = _hash_key(ticker, timestamp, price, open_price, ret)
|
||||||
|
with self._connect() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT OR IGNORE INTO price_points
|
||||||
|
(id, ticker, timestamp, price, open_price, ret, source, meta_json)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
price_id,
|
||||||
|
ticker,
|
||||||
|
timestamp,
|
||||||
|
price,
|
||||||
|
open_price,
|
||||||
|
ret,
|
||||||
|
source,
|
||||||
|
_json_dumps(meta or {}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_stock_explain_snapshot(
|
||||||
|
self,
|
||||||
|
ticker: str,
|
||||||
|
*,
|
||||||
|
limit_events: int = 24,
|
||||||
|
limit_trades: int = 12,
|
||||||
|
limit_signals: int = 12,
|
||||||
|
) -> Dict[str, list[Dict[str, Any]]]:
|
||||||
|
"""Fetch query-oriented history for a single ticker."""
|
||||||
|
symbol = str(ticker or "").strip().upper()
|
||||||
|
if not symbol:
|
||||||
|
return {"events": [], "trades": [], "signals": []}
|
||||||
|
|
||||||
|
with self._connect() as conn:
|
||||||
|
trade_rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM trades
|
||||||
|
WHERE ticker = ?
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(symbol, limit_trades),
|
||||||
|
).fetchall()
|
||||||
|
signal_rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM signals
|
||||||
|
WHERE ticker = ?
|
||||||
|
ORDER BY trade_date DESC, created_at DESC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(symbol, limit_signals),
|
||||||
|
).fetchall()
|
||||||
|
event_rows = conn.execute(
|
||||||
|
"""
|
||||||
|
SELECT * FROM events
|
||||||
|
WHERE payload_json LIKE ? OR content LIKE ? OR title LIKE ? OR ticker = ?
|
||||||
|
ORDER BY timestamp DESC
|
||||||
|
LIMIT ?
|
||||||
|
""",
|
||||||
|
(f"%{symbol}%", f"%{symbol}%", f"%{symbol}%", symbol, limit_events * 3),
|
||||||
|
).fetchall()
|
||||||
|
|
||||||
|
normalized_events = []
|
||||||
|
seen_event_ids: set[str] = set()
|
||||||
|
for row in event_rows:
|
||||||
|
payload = json.loads(row["payload_json"]) if row["payload_json"] else {}
|
||||||
|
content = str(row["content"] or payload.get("content") or "")
|
||||||
|
title = str(row["title"] or payload.get("title") or "")
|
||||||
|
if symbol not in f"{title} {content}".upper() and str(row["ticker"] or "").upper() != symbol:
|
||||||
|
continue
|
||||||
|
event_id = row["id"]
|
||||||
|
if event_id in seen_event_ids:
|
||||||
|
continue
|
||||||
|
seen_event_ids.add(event_id)
|
||||||
|
normalized_events.append(
|
||||||
|
{
|
||||||
|
"id": event_id,
|
||||||
|
"type": "mention",
|
||||||
|
"timestamp": row["timestamp"],
|
||||||
|
"title": title or f"{row['agent_name'] or '未知角色'}提及 {symbol}",
|
||||||
|
"meta": payload.get("conferenceTitle")
|
||||||
|
or payload.get("feedType")
|
||||||
|
or row["event_type"],
|
||||||
|
"body": content,
|
||||||
|
"tone": "neutral",
|
||||||
|
"agent": row["agent_name"] or payload.get("agentName") or payload.get("agent"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if len(normalized_events) >= limit_events:
|
||||||
|
break
|
||||||
|
|
||||||
|
normalized_trades = [
|
||||||
|
{
|
||||||
|
"id": row["id"],
|
||||||
|
"type": "trade",
|
||||||
|
"timestamp": row["timestamp"],
|
||||||
|
"title": f"{row['side']} {int(row['qty'] or 0)} 股",
|
||||||
|
"meta": "交易执行",
|
||||||
|
"body": f"成交价 ${float(row['price'] or 0):.2f}",
|
||||||
|
"tone": "positive" if row["side"] == "LONG" else "negative" if row["side"] == "SHORT" else "neutral",
|
||||||
|
}
|
||||||
|
for row in trade_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
normalized_signals = [
|
||||||
|
{
|
||||||
|
"id": row["id"],
|
||||||
|
"type": "signal",
|
||||||
|
"timestamp": f"{row['trade_date']}T08:00:00" if row["trade_date"] else row["created_at"],
|
||||||
|
"title": f"{row['agent_name']} 给出{row['signal'] or '中性'}信号",
|
||||||
|
"meta": row["role"],
|
||||||
|
"body": (
|
||||||
|
f"后验收益 {float(row['real_return']) * 100:+.2f}%"
|
||||||
|
if row["real_return"] is not None
|
||||||
|
else "该信号暂未完成后验评估"
|
||||||
|
),
|
||||||
|
"tone": "positive" if str(row["signal"] or "").lower() in {"bullish", "buy", "long"} else "negative" if str(row["signal"] or "").lower() in {"bearish", "sell", "short"} else "neutral",
|
||||||
|
# Extended signal fields
|
||||||
|
"signal": row["signal"],
|
||||||
|
"confidence": row["confidence"],
|
||||||
|
"reasoning": json.loads(row["reasoning_json"]) if row["reasoning_json"] else None,
|
||||||
|
"reasons": json.loads(row["reasons_json"]) if row["reasons_json"] else None,
|
||||||
|
"risks": json.loads(row["risks_json"]) if row["risks_json"] else None,
|
||||||
|
"invalidation": row["invalidation"],
|
||||||
|
"next_action": row["next_action"],
|
||||||
|
"intrinsic_value": row["intrinsic_value"],
|
||||||
|
"fair_value_range": json.loads(row["fair_value_range_json"]) if row["fair_value_range_json"] else None,
|
||||||
|
"value_gap_pct": row["value_gap_pct"],
|
||||||
|
"valuation_methods": json.loads(row["valuation_methods_json"]) if row["valuation_methods_json"] else None,
|
||||||
|
}
|
||||||
|
for row in signal_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"events": normalized_events,
|
||||||
|
"trades": normalized_trades,
|
||||||
|
"signals": normalized_signals,
|
||||||
|
}
|
||||||
@@ -10,6 +10,10 @@ from datetime import datetime
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from backend.data.market_store import MarketStore
|
||||||
|
from .research_db import ResearchDb
|
||||||
|
from .runtime_db import RuntimeDb
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -61,6 +65,9 @@ class StorageService:
|
|||||||
self.state_dir = self.dashboard_dir.parent / "state"
|
self.state_dir = self.dashboard_dir.parent / "state"
|
||||||
self.state_dir.mkdir(parents=True, exist_ok=True)
|
self.state_dir.mkdir(parents=True, exist_ok=True)
|
||||||
self.server_state_file = self.state_dir / "server_state.json"
|
self.server_state_file = self.state_dir / "server_state.json"
|
||||||
|
self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
|
||||||
|
self.research_db = ResearchDb(self.state_dir / "research.db")
|
||||||
|
self.market_store = MarketStore()
|
||||||
|
|
||||||
# Feed history (for agent messages)
|
# Feed history (for agent messages)
|
||||||
self.max_feed_history = 200
|
self.max_feed_history = 200
|
||||||
@@ -114,6 +121,11 @@ class StorageService:
|
|||||||
try:
|
try:
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||||
|
if file_type == "leaderboard" and isinstance(data, list):
|
||||||
|
self.runtime_db.replace_signals_for_leaderboard(data)
|
||||||
|
elif file_type == "trades" and isinstance(data, list):
|
||||||
|
for trade in data:
|
||||||
|
self.runtime_db.upsert_trade(trade)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save {file_type}.json: {e}")
|
logger.error(f"Failed to save {file_type}.json: {e}")
|
||||||
|
|
||||||
@@ -211,6 +223,7 @@ class StorageService:
|
|||||||
try:
|
try:
|
||||||
with open(self.internal_state_file, "w", encoding="utf-8") as f:
|
with open(self.internal_state_file, "w", encoding="utf-8") as f:
|
||||||
json.dump(state, f, indent=2, ensure_ascii=False)
|
json.dump(state, f, indent=2, ensure_ascii=False)
|
||||||
|
self._sync_price_history_to_db(state.get("price_history", {}))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to save internal state: {e}")
|
logger.error(f"Failed to save internal state: {e}")
|
||||||
|
|
||||||
@@ -231,6 +244,41 @@ class StorageService:
|
|||||||
"margin_requirement": 0.25, # Default 25% margin requirement
|
"margin_requirement": 0.25, # Default 25% margin requirement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _portfolio_is_pristine(portfolio_state: Dict[str, Any]) -> bool:
|
||||||
|
"""Return whether the persisted portfolio can be safely rebased."""
|
||||||
|
positions = portfolio_state.get("positions", {})
|
||||||
|
has_positions = any(
|
||||||
|
position.get("long", 0) or position.get("short", 0)
|
||||||
|
for position in positions.values()
|
||||||
|
)
|
||||||
|
margin_used = float(portfolio_state.get("margin_used", 0.0) or 0.0)
|
||||||
|
return not has_positions and margin_used == 0.0
|
||||||
|
|
||||||
|
def can_apply_initial_cash(self) -> bool:
|
||||||
|
"""Only allow initial cash changes before the run has traded."""
|
||||||
|
state = self.load_internal_state()
|
||||||
|
if not self._portfolio_is_pristine(state.get("portfolio_state", {})):
|
||||||
|
return False
|
||||||
|
if state.get("all_trades"):
|
||||||
|
return False
|
||||||
|
return len(state.get("equity_history", [])) <= 1
|
||||||
|
|
||||||
|
def apply_initial_cash(self, initial_cash: float) -> bool:
|
||||||
|
"""Rebase storage state to a new initial cash when the run is pristine."""
|
||||||
|
if not self.can_apply_initial_cash():
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.initial_cash = float(initial_cash)
|
||||||
|
if self.internal_state_file.exists():
|
||||||
|
self.internal_state_file.unlink()
|
||||||
|
|
||||||
|
self.initialize_empty_dashboard()
|
||||||
|
state = self.load_server_state()
|
||||||
|
self.update_server_state_from_dashboard(state)
|
||||||
|
self.save_server_state(state)
|
||||||
|
return True
|
||||||
|
|
||||||
def save_portfolio_state(self, portfolio: Dict[str, Any]):
|
def save_portfolio_state(self, portfolio: Dict[str, Any]):
|
||||||
"""
|
"""
|
||||||
Save portfolio state to internal state
|
Save portfolio state to internal state
|
||||||
@@ -750,6 +798,7 @@ class StorageService:
|
|||||||
"last_day_history": [],
|
"last_day_history": [],
|
||||||
"trading_days_total": 0,
|
"trading_days_total": 0,
|
||||||
"trading_days_completed": 0,
|
"trading_days_completed": 0,
|
||||||
|
"price_history": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
if not self.server_state_file.exists():
|
if not self.server_state_file.exists():
|
||||||
@@ -771,6 +820,11 @@ class StorageService:
|
|||||||
)
|
)
|
||||||
logger.info(f"Trades: {len(saved_state.get('trades', []))} records")
|
logger.info(f"Trades: {len(saved_state.get('trades', []))} records")
|
||||||
|
|
||||||
|
for event in saved_state.get("feed_history", []):
|
||||||
|
self.runtime_db.insert_event(event)
|
||||||
|
for trade in saved_state.get("trades", []):
|
||||||
|
self.runtime_db.upsert_trade(trade)
|
||||||
|
|
||||||
return saved_state
|
return saved_state
|
||||||
|
|
||||||
def save_server_state(self, state: Dict[str, Any]):
|
def save_server_state(self, state: Dict[str, Any]):
|
||||||
@@ -852,6 +906,7 @@ class StorageService:
|
|||||||
state["feed_history"] = []
|
state["feed_history"] = []
|
||||||
|
|
||||||
state["feed_history"].insert(0, feed_msg)
|
state["feed_history"].insert(0, feed_msg)
|
||||||
|
self.runtime_db.insert_event(feed_msg)
|
||||||
|
|
||||||
# Trim to max size
|
# Trim to max size
|
||||||
if len(state["feed_history"]) > self.max_feed_history:
|
if len(state["feed_history"]) > self.max_feed_history:
|
||||||
@@ -861,6 +916,69 @@ class StorageService:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def record_price_point(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
ticker: str,
|
||||||
|
timestamp: str,
|
||||||
|
price: float,
|
||||||
|
open_price: Optional[float] = None,
|
||||||
|
ret: Optional[float] = None,
|
||||||
|
source: Optional[str] = None,
|
||||||
|
meta: Optional[Dict[str, Any]] = None,
|
||||||
|
):
|
||||||
|
"""Persist a runtime price point for later query-oriented reads."""
|
||||||
|
if not ticker or not timestamp:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
self.runtime_db.insert_price_point(
|
||||||
|
ticker=ticker,
|
||||||
|
timestamp=timestamp,
|
||||||
|
price=price,
|
||||||
|
open_price=open_price,
|
||||||
|
ret=ret,
|
||||||
|
source=source,
|
||||||
|
meta=meta,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to record price point for %s: %s", ticker, exc)
|
||||||
|
|
||||||
|
def _sync_price_history_to_db(self, price_history: Dict[str, Any]):
|
||||||
|
"""Backfill structured price points from serialized internal state."""
|
||||||
|
if not isinstance(price_history, dict):
|
||||||
|
return
|
||||||
|
for ticker, points in price_history.items():
|
||||||
|
if not ticker or not isinstance(points, list):
|
||||||
|
continue
|
||||||
|
for point in points:
|
||||||
|
if isinstance(point, (list, tuple)) and len(point) >= 2:
|
||||||
|
timestamp, price = point[0], point[1]
|
||||||
|
try:
|
||||||
|
self.record_price_point(
|
||||||
|
ticker=str(ticker),
|
||||||
|
timestamp=str(timestamp),
|
||||||
|
price=float(price),
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
elif isinstance(point, dict):
|
||||||
|
timestamp = point.get("timestamp") or point.get("label") or point.get("date")
|
||||||
|
price = point.get("price") or point.get("close") or point.get("value")
|
||||||
|
if not timestamp or price is None:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
self.record_price_point(
|
||||||
|
ticker=str(ticker),
|
||||||
|
timestamp=str(timestamp),
|
||||||
|
price=float(price),
|
||||||
|
open_price=point.get("open"),
|
||||||
|
ret=point.get("ret"),
|
||||||
|
source=point.get("source"),
|
||||||
|
meta=point,
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
continue
|
||||||
|
|
||||||
def _get_default_stats(self) -> Dict[str, Any]:
|
def _get_default_stats(self) -> Dict[str, Any]:
|
||||||
"""Get default stats structure"""
|
"""Get default stats structure"""
|
||||||
return {
|
return {
|
||||||
@@ -889,6 +1007,7 @@ class StorageService:
|
|||||||
stats = self.load_file("stats") or self._get_default_stats()
|
stats = self.load_file("stats") or self._get_default_stats()
|
||||||
trades = self.load_file("trades") or []
|
trades = self.load_file("trades") or []
|
||||||
leaderboard = self.load_file("leaderboard") or []
|
leaderboard = self.load_file("leaderboard") or []
|
||||||
|
internal_state = self.load_internal_state()
|
||||||
|
|
||||||
# Update state
|
# Update state
|
||||||
state["portfolio"] = {
|
state["portfolio"] = {
|
||||||
@@ -910,6 +1029,9 @@ class StorageService:
|
|||||||
state["stats"] = stats
|
state["stats"] = stats
|
||||||
state["trades"] = trades
|
state["trades"] = trades
|
||||||
state["leaderboard"] = leaderboard
|
state["leaderboard"] = leaderboard
|
||||||
|
state["price_history"] = internal_state.get("price_history", {})
|
||||||
|
self.runtime_db.replace_signals_for_leaderboard(leaderboard)
|
||||||
|
self._sync_price_history_to_db(state["price_history"])
|
||||||
|
|
||||||
# ========== Live Returns Tracking ==========
|
# ========== Live Returns Tracking ==========
|
||||||
|
|
||||||
|
|||||||
119
backend/skills/SKILL_TEMPLATE.md
Normal file
119
backend/skills/SKILL_TEMPLATE.md
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
# Skill Template (Anthropic + AgentScope Aligned)
|
||||||
|
|
||||||
|
> 用于定义可执行、可路由、可评估的技能规范。
|
||||||
|
> 建议所有 `SKILL.md` 至少覆盖以下 6 个部分。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Frontmatter Spec
|
||||||
|
|
||||||
|
All `SKILL.md` files should begin with a YAML frontmatter block:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
---
|
||||||
|
name: skill_name # Required. Unique identifier for the skill.
|
||||||
|
description: ... # Required. One-line description of the skill.
|
||||||
|
version: "1.0.0" # Optional. Semantic version string.
|
||||||
|
tools: [...] # Optional. Tools provided or used by this skill.
|
||||||
|
allowed_tools: [...] # Optional. List of tool names permitted when this skill is active.
|
||||||
|
denied_tools: [...] # Optional. List of tool names denied when this skill is active.
|
||||||
|
---
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontmatter Fields
|
||||||
|
|
||||||
|
| Field | Type | Description |
|
||||||
|
|-------|------|-------------|
|
||||||
|
| `name` | string | Unique skill identifier (kebab-case recommended). |
|
||||||
|
| `description` | string | Human-readable one-line description. |
|
||||||
|
| `version` | string | Semantic version (e.g., `"1.0.0"`). |
|
||||||
|
| `tools` | list[string] | Tools provided by or associated with this skill. |
|
||||||
|
| `allowed_tools` | list[string] | Enumerates which tools are **permitted** when this skill is active. If set, only these tools may be used. |
|
||||||
|
| `denied_tools` | list[string] | Enumerates which tools are **forbidden** when this skill is active. Denied tools take precedence over `allowed_tools`. |
|
||||||
|
|
||||||
|
### Tool Restriction Rules
|
||||||
|
|
||||||
|
- If **only** `allowed_tools` is set: only those tools are accessible.
|
||||||
|
- If **only** `denied_tools` is set: all tools except those are accessible.
|
||||||
|
- If **both** are set: `allowed_tools` defines the initial set, then `denied_tools` removes from it.
|
||||||
|
- **Denial takes precedence**: a tool in `denied_tools` is always blocked even if also in `allowed_tools`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1) When to use
|
||||||
|
|
||||||
|
- 明确触发条件(任务类型、关键词、场景)。
|
||||||
|
- 明确不应使用该技能的边界(避免误触发)。
|
||||||
|
|
||||||
|
## 2) Required inputs
|
||||||
|
|
||||||
|
- 列出最小必要输入(如 `tickers`、价格、组合状态、风险约束)。
|
||||||
|
- 声明输入缺失时的处理规则(终止 / 降级 / 请求补充)。
|
||||||
|
|
||||||
|
## 3) Decision procedure
|
||||||
|
|
||||||
|
- 采用固定步骤,确保可复现。
|
||||||
|
- 每一步说明目标、判据和产物(例如中间结论)。
|
||||||
|
- 标明冲突处理逻辑(信号冲突、数据冲突、置信度冲突)。
|
||||||
|
|
||||||
|
## 4) Tool call policy
|
||||||
|
|
||||||
|
- 说明优先使用哪些工具组与工具。
|
||||||
|
- 规定何时可以“无工具直接结论”,何时必须工具先证据后结论。
|
||||||
|
- 规定工具失败、超时、返回异常时的替代动作。
|
||||||
|
|
||||||
|
## 5) Output schema
|
||||||
|
|
||||||
|
- 定义标准输出字段,便于下游 Agent 消费与评估。
|
||||||
|
- 推荐包含:`signal`、`confidence`、`reasons`、`risks`、`invalidation`、`next_action`。
|
||||||
|
- 若是组合决策技能,必须包含每个 ticker 的 `action` 与 `quantity`。
|
||||||
|
|
||||||
|
## 6) Failure fallback
|
||||||
|
|
||||||
|
- 规定在数据不足、信号冲突、风险超限、工具不可用时的降级策略。
|
||||||
|
- 默认优先“保守 + 可解释 + 可执行”的输出。
|
||||||
|
|
||||||
|
## Optional: Evaluation hooks
|
||||||
|
|
||||||
|
定义技能的可评估指标,用于后续记忆/反思阶段写入长期经验。
|
||||||
|
|
||||||
|
### 支持的指标类型
|
||||||
|
|
||||||
|
| 指标类型 | 描述 | 适用技能 |
|
||||||
|
|---------|------|---------|
|
||||||
|
| `hit_rate` | 信号命中率 - 决策信号与实际结果的符合程度 | sentiment_review, technical_review |
|
||||||
|
| `risk_violation` | 风控违例率 - 触发风控规则的次数 | risk_review, portfolio_decisioning |
|
||||||
|
| `position_deviation` | 仓位偏离率 - 建议仓位与实际执行仓位的偏差 | portfolio_decisioning |
|
||||||
|
| `pnl_attribution` | P&L 归因一致性 - 收益归因与实际收益的匹配度 | fundamental_review, valuation_review |
|
||||||
|
| `signal_consistency` | 信号一致性 - 多来源信号的一致程度 | sentiment_review |
|
||||||
|
| `decision_latency` | 决策延迟 - 从输入到决策的耗时 | portfolio_decisioning |
|
||||||
|
| `tool_usage` | 工具使用率 - 工具调用次数与成功率的比值 | 所有技能 |
|
||||||
|
| `custom` | 自定义指标 | 特定业务场景 |
|
||||||
|
|
||||||
|
### 使用方式
|
||||||
|
|
||||||
|
```python
|
||||||
|
from backend.agents.base.evaluation_hook import EvaluationHook, MetricType
|
||||||
|
|
||||||
|
# 在技能执行开始时
|
||||||
|
evaluation_hook.start_evaluation(
|
||||||
|
skill_name="technical_review",
|
||||||
|
inputs={"tickers": ["AAPL"], "prices": {...}}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在技能执行过程中添加指标
|
||||||
|
evaluation_hook.add_metric(
|
||||||
|
name="signal_confidence",
|
||||||
|
metric_type=MetricType.HIT_RATE,
|
||||||
|
value=0.85,
|
||||||
|
metadata={"method": "rsi", "threshold": 30}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在技能完成时记录结果
|
||||||
|
evaluation_hook.record_outputs({"signal": "buy", "confidence": 0.8})
|
||||||
|
evaluation_hook.complete_evaluation(success=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 评估结果存储
|
||||||
|
|
||||||
|
评估结果自动保存到 `runs/{run_id}/evaluations/{agent_id}/{skill_name}_{timestamp}.json`
|
||||||
1
backend/skills/__init__.py
Normal file
1
backend/skills/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
1
backend/skills/builtin/__init__.py
Normal file
1
backend/skills/builtin/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
@@ -1,21 +1,49 @@
|
|||||||
---
|
---
|
||||||
name: fundamental_review
|
name: 基本面分析
|
||||||
description: Review a company from a fundamentals-first perspective before issuing a trading signal.
|
description: 当用户要求“基本面分析”“看财务质量”“分析盈利能力”“判断公司质量”或“评估长期盈利韧性”时,应使用此技能。
|
||||||
|
version: 1.0.0
|
||||||
---
|
---
|
||||||
|
|
||||||
# Fundamental Review
|
# 基本面分析
|
||||||
|
|
||||||
Use this skill when the task requires judging business quality, balance-sheet strength, profitability, or long-term earnings durability.
|
当用户希望从公司质量、资产负债表强度、盈利能力或长期盈利韧性出发判断标的时,使用这个技能。
|
||||||
|
|
||||||
## Workflow
|
## 1) When to use
|
||||||
|
|
||||||
1. Check profitability, growth, financial health, and efficiency before forming a conclusion.
|
- 适用于需要判断“公司基本面质量是否支撑当前估值/交易观点”的任务。
|
||||||
2. Separate durable business quality from short-term noise.
|
- 优先在中长期视角下使用(财务稳健性、盈利韧性、成长持续性)。
|
||||||
3. State what would invalidate the thesis.
|
- 当任务明确以短线事件驱动为主时,不应单独依赖本技能,应与情绪/技术信号联合。
|
||||||
4. End with a clear signal, confidence, and the main drivers behind that signal.
|
|
||||||
|
|
||||||
## Guardrails
|
## 2) Required inputs
|
||||||
|
|
||||||
- Do not rely on one metric in isolation.
|
- 最少输入:`tickers`、关键财务指标(盈利、成长、偿债、效率)。
|
||||||
- Call out missing data explicitly.
|
- 推荐输入:行业背景、公司阶段、近期重大事件。
|
||||||
- Prefer conservative conclusions when financial quality is mixed.
|
- 若关键数据缺失(例如利润质量或现金流质量无法判断),必须在结论中显式标注“不足信息风险”,并降低置信度。
|
||||||
|
|
||||||
|
## 3) Decision procedure
|
||||||
|
|
||||||
|
1. 先做四维诊断:盈利能力、成长质量、财务健康度、经营效率。
|
||||||
|
2. 区分“结构性优势”与“周期性改善/短期噪音”。
|
||||||
|
3. 识别关键风险与失效条件(invalidation),明确什么情况会推翻当前判断。
|
||||||
|
4. 合成最终观点:`signal + confidence + drivers + risks`。
|
||||||
|
|
||||||
|
## 4) Tool call policy
|
||||||
|
|
||||||
|
- 优先使用基本面与财务相关工具组获取证据,再形成结论。
|
||||||
|
- 在数据完备且任务允许时,可补充估值相关工具进行交叉验证。
|
||||||
|
- 若工具失败或返回异常:保留已验证证据,明确未验证部分,不允许伪造数据。
|
||||||
|
|
||||||
|
## 5) Output schema
|
||||||
|
|
||||||
|
- `signal`: `bullish | bearish | neutral`
|
||||||
|
- `confidence`: `0-100`
|
||||||
|
- `reasons`: 2-4 条核心驱动
|
||||||
|
- `risks`: 1-3 条关键风险
|
||||||
|
- `invalidation`: 触发观点失效的条件
|
||||||
|
- `next_action`: 对 PM 的可执行建议(如“仅小仓位试错/等待下一季报确认”)
|
||||||
|
|
||||||
|
## 6) Failure fallback
|
||||||
|
|
||||||
|
- 数据稀疏或矛盾时:默认 `neutral` 或低置信度方向结论。
|
||||||
|
- 不允许因单一亮点指标给出高置信度信号。
|
||||||
|
- 当财务质量优劣混杂时,优先保守结论并附加“需补充验证”的下一步建议。
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user