Compare commits
18 Commits
4b5ac86b83
...
codex/chec
| Author | SHA1 | Date | |
|---|---|---|---|
| 9bcc4221a4 | |||
| fecf8a9466 | |||
| 86eb8c37a9 | |||
| 1f9063edad | |||
| 7e7a58769a | |||
| 16bb3c4211 | |||
| da6d642aaa | |||
| 8d6c3c5647 | |||
| 6413edf8c9 | |||
| c5eaf2b5ad | |||
| 032c37538f | |||
| 456748b01e | |||
| 609b509446 | |||
| 38102d0805 | |||
| 3448667b79 | |||
| 0f1bc2bb39 | |||
| 06a23c32a4 | |||
| 5b925fbe02 |
41
.env.example
Normal file
41
.env.example
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copy this file to `.env` for local development.
|
||||
# Keep `.env` untracked and never paste real secrets into tracked files.
|
||||
|
||||
# ================== General Configuration | 通用配置 ==================
|
||||
TICKERS=AAPL,MSFT,GOOGL,NVDA,TSLA,META,AMZN
|
||||
|
||||
# Financial Data API
|
||||
# At least `FINANCIAL_DATASETS_API_KEY` is required when using `FIN_DATA_SOURCE=financial_datasets`.
|
||||
# `FINNHUB_API_KEY` is recommended for `FIN_DATA_SOURCE=finnhub` and required for live mode.
|
||||
FIN_DATA_SOURCE=finnhub
|
||||
ENABLED_DATA_SOURCES=financial_datasets,finnhub,yfinance,local_csv
|
||||
FINANCIAL_DATASETS_API_KEY=
|
||||
FINNHUB_API_KEY=
|
||||
POLYGON_API_KEY=
|
||||
MARKET_DB_PATH=
|
||||
|
||||
# Model API
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
EXPLAIN_ENRICH_USE_LLM=false
|
||||
EXPLAIN_ENRICH_MODEL_PROVIDER=
|
||||
EXPLAIN_ENRICH_MODEL_NAME=
|
||||
EXPLAIN_RANGE_USE_LLM=
|
||||
|
||||
# Memory module
|
||||
MEMORY_API_KEY=
|
||||
|
||||
# ================== Agent-Specific Model Configuration | Agent特定模型配置 ==================
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
|
||||
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# ================== Advanced Configuration | 高阶配置 ==================
|
||||
MAX_COMM_CYCLES=2
|
||||
MARGIN_REQUIREMENT=0.5
|
||||
DATA_START_DATE=2022-01-01
|
||||
AUTO_UPDATE_DATA=true
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -54,10 +54,13 @@ outputs/
|
||||
/smoke_live_mock/
|
||||
|
||||
# Local tooling state
|
||||
/.omc/
|
||||
.omc/
|
||||
/.pydeps/
|
||||
/referance/
|
||||
|
||||
# Run outputs
|
||||
/runs/
|
||||
|
||||
# Data files
|
||||
backend/data/ret_data/
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"lastScanned": 1773304964541,
|
||||
"projectRoot": "/Users/cillin/workspeace/agentscope-samples/evotraders",
|
||||
"lastScanned": 1774313111650,
|
||||
"projectRoot": "/Users/cillin/workspeace/evotraders",
|
||||
"techStack": {
|
||||
"languages": [
|
||||
{
|
||||
@@ -40,7 +40,8 @@
|
||||
"isMonorepo": false,
|
||||
"workspaces": [],
|
||||
"mainDirectories": [
|
||||
"docs"
|
||||
"docs",
|
||||
"scripts"
|
||||
],
|
||||
"gitBranches": {
|
||||
"defaultBranch": "main",
|
||||
@@ -52,26 +53,54 @@
|
||||
"backend": {
|
||||
"path": "backend",
|
||||
"purpose": null,
|
||||
"fileCount": 3,
|
||||
"lastAccessed": 1773304964533,
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1774313111639,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"cli.py",
|
||||
"gateway_server.py",
|
||||
"main.py"
|
||||
]
|
||||
},
|
||||
"backtest": {
|
||||
"path": "backtest",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774313111640,
|
||||
"keyFiles": []
|
||||
},
|
||||
"data": {
|
||||
"path": "data",
|
||||
"purpose": "Data files",
|
||||
"fileCount": 3,
|
||||
"lastAccessed": 1774313111640,
|
||||
"keyFiles": [
|
||||
"market_research.db",
|
||||
"market_research.db-shm",
|
||||
"market_research.db-wal"
|
||||
]
|
||||
},
|
||||
"deploy": {
|
||||
"path": "deploy",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774313111640,
|
||||
"keyFiles": []
|
||||
},
|
||||
"docs": {
|
||||
"path": "docs",
|
||||
"purpose": "Documentation",
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1773304964533,
|
||||
"keyFiles": []
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1774313111641,
|
||||
"keyFiles": [
|
||||
"compat-removal-plan.md"
|
||||
]
|
||||
},
|
||||
"evotraders.egg-info": {
|
||||
"path": "evotraders.egg-info",
|
||||
"purpose": null,
|
||||
"fileCount": 6,
|
||||
"lastAccessed": 1773304964534,
|
||||
"lastAccessed": 1774313111641,
|
||||
"keyFiles": [
|
||||
"PKG-INFO",
|
||||
"SOURCES.txt",
|
||||
@@ -83,8 +112,8 @@
|
||||
"frontend": {
|
||||
"path": "frontend",
|
||||
"purpose": null,
|
||||
"fileCount": 12,
|
||||
"lastAccessed": 1773304964535,
|
||||
"fileCount": 13,
|
||||
"lastAccessed": 1774313111641,
|
||||
"keyFiles": [
|
||||
"README.md",
|
||||
"components.json",
|
||||
@@ -93,239 +122,386 @@
|
||||
"index.css"
|
||||
]
|
||||
},
|
||||
"live": {
|
||||
"path": "live",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774313111642,
|
||||
"keyFiles": []
|
||||
},
|
||||
"logs": {
|
||||
"path": "logs",
|
||||
"purpose": null,
|
||||
"fileCount": 6,
|
||||
"lastAccessed": 1774313111642,
|
||||
"keyFiles": [
|
||||
"2026-03-16_00-48-03.log",
|
||||
"2026-03-18_23-17-29.log",
|
||||
"2026-03-18_23-17-30.log",
|
||||
"2026-03-19_00-18-04.log",
|
||||
"2026-03-19_00-34-21.log"
|
||||
]
|
||||
},
|
||||
"reference": {
|
||||
"path": "reference",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774313111643,
|
||||
"keyFiles": []
|
||||
},
|
||||
"runs": {
|
||||
"path": "runs",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774313111643,
|
||||
"keyFiles": []
|
||||
},
|
||||
"scripts": {
|
||||
"path": "scripts",
|
||||
"purpose": "Build/utility scripts",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1774313111644,
|
||||
"keyFiles": [
|
||||
"run_prod.sh"
|
||||
]
|
||||
},
|
||||
"services": {
|
||||
"path": "services",
|
||||
"purpose": "Business logic services",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1774313111644,
|
||||
"keyFiles": [
|
||||
"README.md"
|
||||
]
|
||||
},
|
||||
"shared": {
|
||||
"path": "shared",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774313111644,
|
||||
"keyFiles": []
|
||||
},
|
||||
"workspaces": {
|
||||
"path": "workspaces",
|
||||
"purpose": null,
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1774313111645,
|
||||
"keyFiles": []
|
||||
},
|
||||
"backend/api": {
|
||||
"path": "backend/api",
|
||||
"purpose": "API routes",
|
||||
"fileCount": 5,
|
||||
"lastAccessed": 1774313111645,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"agents.py",
|
||||
"guard.py"
|
||||
]
|
||||
},
|
||||
"backend/config": {
|
||||
"path": "backend/config",
|
||||
"purpose": "Configuration files",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964535,
|
||||
"fileCount": 6,
|
||||
"lastAccessed": 1774313111646,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"constants.py",
|
||||
"data_config.py"
|
||||
"agent_profiles.yaml",
|
||||
"bootstrap_config.py"
|
||||
]
|
||||
},
|
||||
"backend/data": {
|
||||
"path": "backend/data",
|
||||
"purpose": "Data files",
|
||||
"fileCount": 7,
|
||||
"lastAccessed": 1773304964536,
|
||||
"fileCount": 13,
|
||||
"lastAccessed": 1774313111647,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"cache.py",
|
||||
"historical_price_manager.py"
|
||||
]
|
||||
},
|
||||
"backend/services": {
|
||||
"path": "backend/services",
|
||||
"purpose": "Business logic services",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964536,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"gateway.py",
|
||||
"market.py"
|
||||
]
|
||||
},
|
||||
"backend/tests": {
|
||||
"path": "backend/tests",
|
||||
"purpose": "Test files",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964536,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"test_agents.py",
|
||||
"test_market_service.py"
|
||||
]
|
||||
},
|
||||
"docs/assets": {
|
||||
"path": "docs/assets",
|
||||
"purpose": "Static assets",
|
||||
"fileCount": 5,
|
||||
"lastAccessed": 1773304964536,
|
||||
"lastAccessed": 1774313111647,
|
||||
"keyFiles": [
|
||||
"dashboard.jpg",
|
||||
"evotraders_demo.gif",
|
||||
"evotraders_logo.jpg"
|
||||
]
|
||||
},
|
||||
"frontend/public": {
|
||||
"path": "frontend/public",
|
||||
"purpose": "Public files",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1773304964538,
|
||||
"frontend/dist": {
|
||||
"path": "frontend/dist",
|
||||
"purpose": "Distribution/build output",
|
||||
"fileCount": 2,
|
||||
"lastAccessed": 1774313111647,
|
||||
"keyFiles": [
|
||||
"index.html",
|
||||
"trading_logo.png"
|
||||
]
|
||||
},
|
||||
"frontend/node_modules": {
|
||||
"path": "frontend/node_modules",
|
||||
"purpose": "Dependencies",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1774313111650,
|
||||
"keyFiles": []
|
||||
}
|
||||
},
|
||||
"hotPaths": [
|
||||
{
|
||||
"path": "frontend/src/components/StatisticsView.jsx",
|
||||
"accessCount": 22,
|
||||
"lastAccessed": 1773310044545,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AgentCard.jsx",
|
||||
"accessCount": 17,
|
||||
"lastAccessed": 1773309995177,
|
||||
"type": "file"
|
||||
"path": "CLAUDE.md",
|
||||
"accessCount": 15,
|
||||
"lastAccessed": 1774342728155,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/App.jsx",
|
||||
"accessCount": 12,
|
||||
"lastAccessed": 1773309849392,
|
||||
"accessCount": 10,
|
||||
"lastAccessed": 1774339397617,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AgentFeed.jsx",
|
||||
"accessCount": 12,
|
||||
"lastAccessed": 1773309960022,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": ".env",
|
||||
"accessCount": 7,
|
||||
"lastAccessed": 1773308950505,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/RoomView.jsx",
|
||||
"accessCount": 7,
|
||||
"lastAccessed": 1773309864236,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools/analysis_tools.py",
|
||||
"accessCount": 5,
|
||||
"lastAccessed": 1773312271446,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/Header.jsx",
|
||||
"path": "frontend/src/hooks/useWebsocketSessionSync.js",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773309827069,
|
||||
"lastAccessed": 1774313470024,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AboutModal.jsx",
|
||||
"path": "",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773310093371,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/analyst/personas.yaml",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312049213,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/analyst/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312049696,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/portfolio_manager/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312050326,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/risk_manager/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312050782,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/config/constants.js",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773309824671,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/RulesView.jsx",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773310061939,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773312200721,
|
||||
"lastAccessed": 1774339108220,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/services/gateway.py",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1774339389171,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/main.py",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1774342613364,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/runtimeStore.js",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1773312232905,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "README.md",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305013217,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "README_zh.md",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305013274,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "env.template",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305019965,
|
||||
"lastAccessed": 1774317990919,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/services/websocket.js",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774318009819,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/core/pipeline_runner.py",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774339367538,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/runtime/manager.py",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1774339367572,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/marketStore.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309324302,
|
||||
"lastAccessed": 1774313140483,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks/useFeedProcessor.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313148279,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/Header.jsx",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313156696,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/TraderView.jsx",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313156753,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/uiStore.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313187460,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/portfolioStore.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313187511,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/store/agentStore.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313187573,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks/useWebSocketConnection.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313279414,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks/useStockDataRequests.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313319716,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/hooks/useAgentDataRequests.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313347455,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AppShell.jsx",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774313396331,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "start-dev.sh",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774317979859,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/apps/agent_service.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774317984348,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "shared/client/trading_client.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774317984365,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/apps/trading_service.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774317984408,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "pyproject.toml",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774317990970,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/factory.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774318009867,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/config/constants.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774318009922,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/api/__init__.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774318009973,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "README.md",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339107381,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/runtime/registry.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339380024,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/runtime/session.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339380084,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/runtime/context.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339380120,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/runtime/agent_runtime.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339380185,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/process/supervisor.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339389110,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/core/pipeline.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339389187,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/process/models.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339397557,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/process/registry.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774339397577,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/config/env_config.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1774342678236,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/config/data_config.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309324414,
|
||||
"lastAccessed": 1774342678253,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/cli.py",
|
||||
"path": "frontend/env.template",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309336899,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/portfolio_manager.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311956562,
|
||||
"lastAccessed": 1774342678290,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/risk_manager.py",
|
||||
"path": "env.template",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311956760,
|
||||
"lastAccessed": 1774342678310,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/analyst.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311963222,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773312289643,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools/data_tools.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773312293851,
|
||||
"type": "directory"
|
||||
}
|
||||
],
|
||||
"userDirectives": []
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"timestamp": "2026-03-12T20:33:59.497Z",
|
||||
"timestamp": "2026-03-24T07:58:12.123Z",
|
||||
"backgroundTasks": [],
|
||||
"sessionStartTimestamp": "2026-03-12T14:19:33.615Z",
|
||||
"sessionId": "73b0d597-0141-4873-9d0e-2b60e4e0635e"
|
||||
"sessionStartTimestamp": "2026-03-24T07:58:09.417Z",
|
||||
"sessionId": "fda34772-7bd2-402e-86b2-d656296416f3"
|
||||
}
|
||||
@@ -1 +1 @@
|
||||
{"session_id":"73b0d597-0141-4873-9d0e-2b60e4e0635e","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-agentscope-samples-evotraders/73b0d597-0141-4873-9d0e-2b60e4e0635e.jsonl","cwd":"/Users/cillin/workspeace/agentscope-samples/evotraders","model":{"id":"kimi-for-coding","display_name":"kimi-for-coding"},"workspace":{"current_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","project_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","added_dirs":["/Users/cillin/workspeace/agentscope-samples/EvoTraders","/Users/cillin/workspeace/agentscope-samples/evotraders"]},"version":"2.1.63","output_style":{"name":"default"},"cost":{"total_cost_usd":6.822239999999999,"total_duration_ms":42679588,"total_api_duration_ms":1223637,"total_lines_added":275,"total_lines_removed":240},"context_window":{"total_input_tokens":654274,"total_output_tokens":27014,"context_window_size":200000,"current_usage":{"input_tokens":48465,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0},"used_percentage":24,"remaining_percentage":76},"exceeds_200k_tokens":false}
|
||||
{"session_id":"fda34772-7bd2-402e-86b2-d656296416f3","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-evotraders/fda34772-7bd2-402e-86b2-d656296416f3.jsonl","cwd":"/Users/cillin/workspeace/evotraders","model":{"id":"MiniMax-M2.7-highspeed","display_name":"MiniMax-M2.7-highspeed"},"workspace":{"current_dir":"/Users/cillin/workspeace/evotraders","project_dir":"/Users/cillin/workspeace/evotraders","added_dirs":[]},"version":"2.1.78","output_style":{"name":"default"},"cost":{"total_cost_usd":36.63980749999998,"total_duration_ms":69778027,"total_api_duration_ms":2925118,"total_lines_added":3056,"total_lines_removed":4537},"context_window":{"total_input_tokens":910503,"total_output_tokens":145207,"context_window_size":200000,"current_usage":{"input_tokens":507,"output_tokens":247,"cache_creation_input_tokens":4132,"cache_read_input_tokens":96553},"used_percentage":51,"remaining_percentage":49},"exceeds_200k_tokens":false}
|
||||
@@ -1,3 +1,3 @@
|
||||
{
|
||||
"lastSentAt": "2026-03-12T20:31:37.362Z"
|
||||
"lastSentAt": "2026-03-24T08:58:57.965Z"
|
||||
}
|
||||
@@ -1,26 +1,26 @@
|
||||
{
|
||||
"agents": [
|
||||
{
|
||||
"agent_id": "a4090d26a45ac828d",
|
||||
"agent_type": "oh-my-claudecode:executor",
|
||||
"started_at": "2026-03-12T10:02:38.238Z",
|
||||
"agent_id": "abeaf609b74a2b7ee",
|
||||
"agent_type": "Explore",
|
||||
"started_at": "2026-03-24T08:01:40.015Z",
|
||||
"parent_mode": "none",
|
||||
"status": "completed",
|
||||
"completed_at": "2026-03-12T10:10:59.192Z",
|
||||
"duration_ms": 500954
|
||||
"completed_at": "2026-03-24T08:02:31.822Z",
|
||||
"duration_ms": 51807
|
||||
},
|
||||
{
|
||||
"agent_id": "af87583ef76a4df30",
|
||||
"agent_type": "oh-my-claudecode:executor",
|
||||
"started_at": "2026-03-12T10:40:04.409Z",
|
||||
"agent_id": "afb6750eaae72bc72",
|
||||
"agent_type": "Explore",
|
||||
"started_at": "2026-03-24T08:56:21.471Z",
|
||||
"parent_mode": "none",
|
||||
"status": "completed",
|
||||
"completed_at": "2026-03-12T10:41:17.387Z",
|
||||
"duration_ms": 72978
|
||||
"completed_at": "2026-03-24T08:57:27.856Z",
|
||||
"duration_ms": 66385
|
||||
}
|
||||
],
|
||||
"total_spawned": 2,
|
||||
"total_completed": 2,
|
||||
"total_failed": 0,
|
||||
"last_updated": "2026-03-12T10:41:17.490Z"
|
||||
"last_updated": "2026-03-24T08:59:06.380Z"
|
||||
}
|
||||
378
CLAUDE.md
Normal file
378
CLAUDE.md
Normal file
@@ -0,0 +1,378 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
本文件为 Claude Code (claude.ai/code) 在此代码库中工作时提供指导。
|
||||
|
||||
## 项目概述
|
||||
|
||||
EvoTraders 是一个自进化多智能体交易系统,由 6 个 AI Agent(4 名分析师 + 投资经理 + 风控经理)协作完成交易决策。Agent 基于 AgentScope 框架构建,配合 ReMe 记忆系统实现持续学习。
|
||||
|
||||
## 常用命令
|
||||
|
||||
### Backend (Python)
|
||||
|
||||
```bash
|
||||
# 安装依赖
|
||||
uv pip install -e .
|
||||
|
||||
# 运行命令
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 # 回测模式
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory # 带记忆回测
|
||||
evotraders live # 实盘交易
|
||||
evotraders live --mock # 模拟/测试模式
|
||||
evotraders live -t 22:30 # 定时每日交易
|
||||
evotraders frontend # 启动可视化界面
|
||||
|
||||
# 开发服务器
|
||||
./start-dev.sh # 启动全部 4 个微服务 (agent, runtime, trading, news)
|
||||
|
||||
# Gateway WebSocket 服务器
|
||||
python backend/main.py --mode live --config-name mock --mock
|
||||
|
||||
# 单独启动微服务
|
||||
python -m uvicorn backend.apps.runtime_service:app --host 0.0.0.0 --port 8003 --reload
|
||||
python -m uvicorn backend.apps.agent_service:app --host 0.0.0.0 --port 8000 --reload
|
||||
python -m uvicorn backend.apps.trading_service:app --host 0.0.0.0 --port 8001 --reload
|
||||
python -m uvicorn backend.apps.news_service:app --host 0.0.0.0 --port 8002 --reload
|
||||
|
||||
# 测试
|
||||
pytest backend/tests # 运行全部测试
|
||||
pytest backend/tests/test_news_service_app.py -v # 运行单个测试
|
||||
```
|
||||
|
||||
### Frontend (React)
|
||||
|
||||
```bash
|
||||
cd frontend
|
||||
npm run dev # Vite 开发服务器 (http://localhost:5173)
|
||||
npm run build # 生产构建
|
||||
npm run lint # ESLint 检查
|
||||
npm run lint:fix # ESLint 自动修复
|
||||
npm run test # Vitest 单元测试
|
||||
```
|
||||
|
||||
## 架构概览
|
||||
|
||||
### 系统分层
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Frontend (React) │
|
||||
│ WebSocket ws://localhost:8765 连接 Gateway │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Gateway (backend/services/gateway.py) │
|
||||
│ WebSocket 服务器,编排 Pipeline,4 阶段启动 │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│ │ │ │
|
||||
▼ ▼ ▼ ▼
|
||||
┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────────┐
|
||||
│ Market │ │ Storage │ │ Pipeline │ │ Scheduler │
|
||||
│ Service │ │ Service │ │ │ │ │
|
||||
└────────────┘ └────────────┘ └────────────┘ └────────────┘
|
||||
│
|
||||
┌──────────────────────┼──────────────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌──────────┐ ┌──────────┐
|
||||
│ Analysts │ │ PM │ │ Risk │
|
||||
│ (4 个) │ │ │ │ Manager │
|
||||
└──────────┘ └──────────┘ └──────────┘
|
||||
```
|
||||
|
||||
### 微服务架构 (`backend/apps/`)
|
||||
|
||||
| 服务 | 端口 | 职责 |
|
||||
|------|------|------|
|
||||
| runtime_service | 8003 | 运行时配置、任务启动、Pipeline Runner |
|
||||
| agent_service | 8000 | Agent 生命周期、工作区管理 |
|
||||
| trading_service | 8001 | 市场数据、交易操作 |
|
||||
| news_service | 8002 | 新闻、新闻富化、解释功能 |
|
||||
|
||||
### Gateway 4 阶段启动 (`backend/services/gateway.py`)
|
||||
|
||||
1. **WebSocket Server** - 前端立即可连接
|
||||
2. **Market Service** - 价格数据开始推送
|
||||
3. **Market Status Monitor** - 市场状态监控
|
||||
4. **Scheduler** - 交易周期开始
|
||||
|
||||
### 运行时管理层 (`backend/runtime/`)
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `manager.py` | TradingRuntimeManager - 全局运行时管理器,agent 注册、会话、事件快照 |
|
||||
| `agent_runtime.py` | AgentRuntimeState - 单 agent 状态(status、last_session) |
|
||||
| `context.py` | TradingRunContext - 运行上下文 |
|
||||
| `session.py` | TradingSessionKey - 交易日会话键 |
|
||||
| `registry.py` | RuntimeRegistry - agent 状态注册表 |
|
||||
|
||||
快照持久化到 `runs/<run_id>/state/runtime_state.json`。
|
||||
|
||||
### Pipeline 执行 (`backend/core/`)
|
||||
|
||||
| 文件 | 职责 |
|
||||
|------|------|
|
||||
| `pipeline.py` | TradingPipeline - 核心编排器(分析→沟通→决策→执行→评估) |
|
||||
| `pipeline_runner.py` | REST API 触发的独立执行,5 阶段启动 |
|
||||
| `scheduler.py` | BacktestScheduler、Scheduler - 回测/实盘调度 |
|
||||
| `state_sync.py` | StateSync - 状态同步和广播 |
|
||||
|
||||
## 后端结构
|
||||
|
||||
```
|
||||
backend/
|
||||
├── agents/ # 多智能体实现
|
||||
│ ├── analyst.py # AnalystAgent 基类
|
||||
│ ├── portfolio_manager.py # PMAgent 投资经理
|
||||
│ ├── risk_manager.py # RiskAgent 风控经理
|
||||
│ ├── factory.py # Agent 实例工厂
|
||||
│ ├── toolkit_factory.py # 工具集工厂
|
||||
│ ├── skills_manager.py # 技能加载管理
|
||||
│ ├── workspace_manager.py # 工作区管理
|
||||
│ ├── skill_loader.py # 技能加载器
|
||||
│ ├── agent_workspace.py # Agent 工作区
|
||||
│ ├── prompt_loader.py # Prompt 加载器
|
||||
│ ├── prompt_factory.py # Prompt 工厂
|
||||
│ ├── skill_metadata.py # 技能元数据
|
||||
│ ├── registry.py # Agent 注册表
|
||||
│ ├── team_pipeline_config.py # 团队 Pipeline 配置
|
||||
│ ├── compat.py # 兼容性层
|
||||
│ ├── templates.py # 模板
|
||||
│ ├── workspace.py # 工作区
|
||||
│ ├── base/ # 核心类、Hooks
|
||||
│ │ ├── evo_agent.py # 基于 AgentScope 的核心实现
|
||||
│ │ └── hooks.py # 生命周期 Hooks
|
||||
│ └── prompts/ # Agent 提示词
|
||||
│ └── analyst/personas.yaml
|
||||
│
|
||||
├── apps/ # 微服务入口
|
||||
│ ├── runtime_service.py # 运行时服务(端口 8003)
|
||||
│ ├── agent_service.py # Agent 服务(端口 8000)
|
||||
│ ├── trading_service.py # 交易服务(端口 8001)
|
||||
│ ├── news_service.py # 新闻服务(端口 8002)
|
||||
│ └── cors.py
|
||||
│
|
||||
├── runtime/ # 运行时管理层
|
||||
│ ├── manager.py # TradingRuntimeManager
|
||||
│ ├── agent_runtime.py # AgentRuntimeState
|
||||
│ ├── context.py # TradingRunContext
|
||||
│ ├── session.py # TradingSessionKey
|
||||
│ └── registry.py # RuntimeRegistry
|
||||
│
|
||||
├── process/ # 进程监管层
|
||||
│ ├── supervisor.py # ProcessSupervisor
|
||||
│ ├── registry.py # RunRegistry
|
||||
│ └── models.py # ProcessRun、ProcessRunState
|
||||
│
|
||||
├── core/ # Pipeline 执行
|
||||
│ ├── pipeline.py # TradingPipeline(核心编排器)
|
||||
│ ├── pipeline_runner.py # 独立 Pipeline 执行
|
||||
│ ├── scheduler.py # 调度器
|
||||
│ └── state_sync.py # 状态同步
|
||||
│
|
||||
├── services/ # Gateway 和服务
|
||||
│ ├── gateway.py # WebSocket 网关
|
||||
│ ├── gateway_*.py # Gateway 子模块
|
||||
│ ├── market.py # 市场数据服务
|
||||
│ ├── storage.py # 存储服务
|
||||
│ ├── runtime_db.py # 运行时数据库
|
||||
│ └── research_db.py # 研究数据库
|
||||
│
|
||||
├── data/ # 市场数据处理
|
||||
│ ├── provider_router.py # 数据源路由
|
||||
│ ├── provider_utils.py # 数据源工具
|
||||
│ ├── market_store.py # 市场数据存储
|
||||
│ ├── market_ingest.py # 数据采集
|
||||
│ ├── cache.py # 缓存
|
||||
│ ├── schema.py # 数据 schema
|
||||
│ ├── historical_price_manager.py # 历史价格管理
|
||||
│ ├── polling_price_manager.py # 轮询价格管理
|
||||
│ ├── mock_price_manager.py # Mock 价格管理
|
||||
│ ├── news_alignment.py # 新闻对齐
|
||||
│ ├── polygon_client.py # Polygon.io 客户端
|
||||
│ └── ret_data_updater.py # 离线数据更新
|
||||
│
|
||||
├── config/ # 配置
|
||||
│ ├── constants.py # Agent 配置、显示名称
|
||||
│ ├── bootstrap_config.py # 启动配置解析
|
||||
│ ├── env_config.py # 环境变量配置
|
||||
│ ├── data_config.py # 数据源配置
|
||||
│ └── agent_profiles.yaml # Agent Profile 配置
|
||||
│
|
||||
├── domains/ # 领域业务逻辑
|
||||
│ ├── news.py
|
||||
│ └── trading.py
|
||||
│
|
||||
├── llm/ # LLM 集成
|
||||
│ └── models.py # RetryChatModel、TokenRecordingModelWrapper
|
||||
│
|
||||
├── skills/ # 技能定义
|
||||
├── tools/ # 交易和分析工具
|
||||
├── enrich/ # LLM 响应富化
|
||||
├── explain/ # 交易决策解释
|
||||
├── utils/ # 工具函数
|
||||
│ ├── settlement.py # 结算协调器
|
||||
│ ├── trade_executor.py # 交易执行器
|
||||
│ ├── terminal_dashboard.py # 终端仪表板
|
||||
│ ├── analyst_tracker.py # 分析师追踪
|
||||
│ ├── baselines.py # 基准线
|
||||
│ ├── msg_adapter.py # 消息适配器
|
||||
│ └── progress.py # 进度追踪
|
||||
│
|
||||
├── api/ # FastAPI 端点
|
||||
│ └── runtime.py
|
||||
│
|
||||
└── main.py # 主入口点
|
||||
```
|
||||
|
||||
## 前端结构
|
||||
|
||||
```
|
||||
frontend/src/
|
||||
├── App.jsx # 主应用(LiveTradingApp)
|
||||
├── AppShell.jsx # App 外壳(布局、侧边栏)
|
||||
├── components/
|
||||
│ ├── RuntimeView.jsx # 交易运行时 UI
|
||||
│ ├── TraderView.jsx # 交易员界面
|
||||
│ ├── RoomView.jsx # 聊天室视图
|
||||
│ ├── StockExplainView.jsx # 股票解释视图
|
||||
│ ├── RuntimeSettingsPanel.jsx # 运行时设置面板
|
||||
│ ├── RuntimeLogsModal.jsx # 运行时日志弹窗
|
||||
│ ├── WatchlistPanel.jsx # 关注列表
|
||||
│ ├── PerformanceView.jsx # 绩效视图
|
||||
│ ├── StatisticsView.jsx # 统计视图
|
||||
│ ├── NetValueChart.jsx # 净值曲线图
|
||||
│ ├── AgentCard.jsx # Agent 卡片
|
||||
│ ├── AgentFeed.jsx # Agent 动态
|
||||
│ ├── Header.jsx # 头部
|
||||
│ ├── MarkdownModal.jsx # Markdown 弹窗
|
||||
│ ├── StockLogo.jsx # 股票 Logo
|
||||
│ └── explain/ # 解释组件
|
||||
│ ├── ExplainNewsSection.jsx
|
||||
│ ├── ExplainRangeSection.jsx
|
||||
│ ├── ExplainSimilarDaysSection.jsx
|
||||
│ ├── ExplainStorySection.jsx
|
||||
│ └── useExplainModel.js
|
||||
├── hooks/ # React Hooks
|
||||
│ ├── useWebSocketConnection.js # WebSocket 连接管理
|
||||
│ ├── useRuntimeControls.js # 运行时配置管理
|
||||
│ ├── useAgentDataRequests.js # Agent 数据请求
|
||||
│ ├── useStockDataRequests.js # 股票数据请求
|
||||
│ ├── useStockExplainData.js # 股票解释数据
|
||||
│ ├── useAgentWorkspacePanel.js # Agent 工作区面板
|
||||
│ ├── useWebsocketSessionSync.js # WebSocket 会话同步
|
||||
│ └── useFeedProcessor.js # Feed 事件处理
|
||||
├── store/ # Zustand 状态管理
|
||||
│ ├── runtimeStore.js # 连接状态、运行时配置
|
||||
│ ├── marketStore.js # 市场数据、股票价格
|
||||
│ ├── portfolioStore.js # 组合、持仓、交易
|
||||
│ ├── agentStore.js # Agent 技能、工作区
|
||||
│ └── uiStore.js # UI 状态、视图切换
|
||||
├── services/
|
||||
│ ├── websocket.js # WebSocket 客户端
|
||||
│ ├── runtimeApi.js # 运行时 API
|
||||
│ ├── runtimeControls.js # 运行时控制
|
||||
│ ├── newsApi.js # 新闻 API
|
||||
│ └── tradingApi.js # 交易 API
|
||||
├── utils/
|
||||
│ ├── formatters.js # 格式化工具
|
||||
│ └── modelIcons.js # 模型图标
|
||||
└── config/
|
||||
└── constants.js # Agent 定义、配置
|
||||
```
|
||||
|
||||
## Agent 系统
|
||||
|
||||
### 6 种 Agent 角色
|
||||
|
||||
| 角色 ID | 名称 | 职责 |
|
||||
|---------|------|------|
|
||||
| `fundamentals_analyst` | 基本面分析师 | 财务健康、盈利能力、成长质量 |
|
||||
| `technical_analyst` | 技术分析师 | 价格趋势、技术指标、动量分析 |
|
||||
| `sentiment_analyst` | 情绪分析师 | 市场情绪、新闻情绪、内幕交易 |
|
||||
| `valuation_analyst` | 估值分析师 | DCF、EV/EBITDA、intrinsic value |
|
||||
| `portfolio_manager` | 投资经理 | 决策执行、交易协调 |
|
||||
| `risk_manager` | 风控经理 | 实时价格/波动率监控、仓位限制 |
|
||||
|
||||
### 添加自定义分析师
|
||||
|
||||
1. `backend/agents/prompts/analyst/personas.yaml` 注册
|
||||
2. `backend/config/constants.py` 的 `ANALYST_TYPES` 字典添加
|
||||
3. `frontend/src/config/constants.js` 可选更新
|
||||
|
||||
### LLM 模型封装 (`backend/llm/models.py`)
|
||||
|
||||
- **RetryChatModel**: 自动重试瞬态 LLM 错误,指数退避
|
||||
- **TokenRecordingModelWrapper**: 追踪 token 消耗和成本
|
||||
|
||||
## 技能系统 (`backend/skills/`)
|
||||
|
||||
技能定义在 `SKILL.md`,包含 `instructions`、`triggers`、`parameters`、`available_tools`。
|
||||
|
||||
技能管理器支持 6 种作用域:builtin、customized、installed、active、disabled、local。
|
||||
|
||||
## 运行时数据布局
|
||||
|
||||
- `data/market_research.db` - 持久研究数据
|
||||
- `runs/<run_id>/` - 每次任务运行的状态
|
||||
- `runs/<run_id>/team_dashboard/*.json` - 仪表板导出层(非权威源)
|
||||
- `runs/<run_id>/state/runtime_state.json` - 运行时快照
|
||||
- 运行时 API 优先使用 `server_state.json` 和 `runtime.db`
|
||||
|
||||
```bash
|
||||
RUNS_RETENTION_COUNT=20 # 时间戳格式文件夹自动清理
|
||||
```
|
||||
|
||||
## 环境配置
|
||||
|
||||
### Backend (`env.template`)
|
||||
|
||||
```bash
|
||||
# 金融数据源(支持多源fallback)
|
||||
FIN_DATA_SOURCE=finnhub|financial_datasets|yfinance|local_csv
|
||||
ENABLED_DATA_SOURCES=financial_datasets,finnhub,yfinance,local_csv
|
||||
FINANCIAL_DATASETS_API_KEY= # 回测必需
|
||||
FINNHUB_API_KEY= # 实盘必需
|
||||
POLYGON_API_KEY= # Polygon市场库采集可选
|
||||
|
||||
# LLM 配置
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# Agent 特定模型
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
|
||||
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# ReMe 记忆系统
|
||||
MEMORY_API_KEY=
|
||||
MEMORY_MODEL_NAME=qwen3-max
|
||||
MEMORY_EMBEDDING_MODEL=text-embedding-v4
|
||||
|
||||
# 交易参数
|
||||
MAX_COMM_CYCLES=2
|
||||
MARGIN_REQUIREMENT=0.5
|
||||
DATA_START_DATE=2022-01-01
|
||||
AUTO_UPDATE_DATA=true
|
||||
```
|
||||
|
||||
### Frontend (`frontend/env.template`)
|
||||
|
||||
```bash
|
||||
VITE_WS_URL=ws://localhost:8765
|
||||
```
|
||||
|
||||
## 关键依赖
|
||||
|
||||
- **AgentScope** - 多智能体框架
|
||||
- **ReMe** - 持续学习记忆系统
|
||||
- **FastAPI** + **uvicorn** - 后端 API
|
||||
- **websockets** - 实时通信
|
||||
- **React 19** + **Vite** + **TailwindCSS** - 前端
|
||||
- **Zustand** - 状态管理
|
||||
15
README.md
15
README.md
@@ -110,6 +110,21 @@ evotraders frontend # Default connects to port 8765, you can modi
|
||||
|
||||
Visit `http://localhost:5173/` to view the trading room, select a date and click Run/Replay to observe the decision-making process.
|
||||
|
||||
### Runtime Data Layout
|
||||
|
||||
- Long-lived research data is stored in `data/market_research.db`
|
||||
- Each task run writes run-scoped state under `runs/<run_id>/`
|
||||
- `runs/<run_id>/team_dashboard/*.json` is an export/compatibility layer for dashboard views, not the authoritative runtime source of truth
|
||||
- Runtime APIs prefer active runtime state, `server_state.json`, and `runtime.db`
|
||||
|
||||
Optional retention control:
|
||||
|
||||
```bash
|
||||
RUNS_RETENTION_COUNT=20
|
||||
```
|
||||
|
||||
Only timestamped run folders like `YYYYMMDD_HHMMSS` are pruned automatically when starting a new runtime. Named runs such as `smoke_fullstack` or `test_*` are preserved.
|
||||
|
||||
---
|
||||
|
||||
## System Architecture
|
||||
|
||||
48
README_zh.md
48
README_zh.md
@@ -117,6 +117,54 @@ evotraders frontend # 默认连接 8765 端口, 你可以修改 .
|
||||
|
||||
访问 `http://localhost:5173/` 查看交易大厅,选择日期并点击 Run/Replay 观察决策过程。
|
||||
|
||||
### 迁移期服务边界说明
|
||||
|
||||
当前仓库正处于从模块化单体向独立服务迁移的阶段,当前默认开发路径已经切到独立 app surface:
|
||||
|
||||
- `backend.apps.agent_service`
|
||||
- `backend.apps.runtime_service`
|
||||
- `backend.apps.trading_service`
|
||||
- `backend.apps.news_service`
|
||||
|
||||
当前本地开发默认推荐直接运行拆分后的服务:
|
||||
|
||||
```bash
|
||||
./start-dev.sh split
|
||||
|
||||
# 或分别手动启动
|
||||
python -m uvicorn backend.apps.agent_service:app --port 8000 --reload
|
||||
python -m uvicorn backend.apps.runtime_service:app --port 8003 --reload
|
||||
python -m uvicorn backend.apps.trading_service:app --port 8001 --reload
|
||||
python -m uvicorn backend.apps.news_service:app --port 8002 --reload
|
||||
```
|
||||
|
||||
迁移期关键环境变量:
|
||||
|
||||
```bash
|
||||
# 后端 Gateway 优先走独立服务读取
|
||||
NEWS_SERVICE_URL=http://localhost:8002
|
||||
TRADING_SERVICE_URL=http://localhost:8001
|
||||
|
||||
# 前端浏览器直连控制面 / 运行时面
|
||||
VITE_CONTROL_API_BASE_URL=http://localhost:8000/api
|
||||
VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime
|
||||
|
||||
# 前端浏览器优先直连独立服务
|
||||
VITE_NEWS_SERVICE_URL=http://localhost:8002
|
||||
VITE_TRADING_SERVICE_URL=http://localhost:8001
|
||||
```
|
||||
|
||||
目前前端已支持直连 `news-service` 的 explain 只读路径包括:
|
||||
|
||||
- runtime panel / gateway port 查询已可独立指向 `runtime-service`
|
||||
- story
|
||||
- similar days
|
||||
- range explain
|
||||
- news for date
|
||||
- news categories
|
||||
|
||||
如果没有配置这些变量,系统会继续走当前保留的本地回退逻辑。
|
||||
|
||||
---
|
||||
|
||||
## 系统架构
|
||||
|
||||
452
backend/agents/base/evaluation_hook.py
Normal file
452
backend/agents/base/evaluation_hook.py
Normal file
@@ -0,0 +1,452 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Evaluation hooks system for skills.
|
||||
|
||||
Provides evaluation metric collection and storage for skill performance tracking.
|
||||
Based on the evaluation hooks design in SKILL_TEMPLATE.md.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetricType(Enum):
|
||||
"""Types of evaluation metrics."""
|
||||
HIT_RATE = "hit_rate" # 信号命中率
|
||||
RISK_VIOLATION = "risk_violation" # 风控违例率
|
||||
POSITION_DEVIATION = "position_deviation" # 仓位偏离率
|
||||
PnL_ATTRIBUTION = "pnl_attribution" # P&L 归因一致性
|
||||
SIGNAL_CONSISTENCY = "signal_consistency" # 信号一致性
|
||||
DECISION_LATENCY = "decision_latency" # 决策延迟
|
||||
TOOL_USAGE = "tool_usage" # 工具使用率
|
||||
CUSTOM = "custom" # 自定义指标
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationMetric:
|
||||
"""A single evaluation metric."""
|
||||
name: str
|
||||
metric_type: MetricType
|
||||
value: float
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"name": self.name,
|
||||
"metric_type": self.metric_type.value,
|
||||
"value": self.value,
|
||||
"timestamp": self.timestamp,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationResult:
|
||||
"""Evaluation result for a skill execution."""
|
||||
skill_name: str
|
||||
run_id: str
|
||||
agent_id: str
|
||||
metrics: List[EvaluationMetric] = field(default_factory=list)
|
||||
inputs: Dict[str, Any] = field(default_factory=dict)
|
||||
outputs: Dict[str, Any] = field(default_factory=dict)
|
||||
decision: Optional[str] = None
|
||||
success: bool = True
|
||||
error_message: Optional[str] = None
|
||||
started_at: Optional[str] = None
|
||||
completed_at: Optional[str] = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"skill_name": self.skill_name,
|
||||
"run_id": self.run_id,
|
||||
"agent_id": self.agent_id,
|
||||
"metrics": [m.to_dict() for m in self.metrics],
|
||||
"inputs": self.inputs,
|
||||
"outputs": self.outputs,
|
||||
"decision": self.decision,
|
||||
"success": self.success,
|
||||
"error_message": self.error_message,
|
||||
"started_at": self.started_at,
|
||||
"completed_at": self.completed_at,
|
||||
}
|
||||
|
||||
|
||||
class EvaluationHook:
|
||||
"""Hook for collecting skill evaluation metrics.
|
||||
|
||||
This hook collects and stores evaluation metrics after skill execution
|
||||
for later analysis and memory/reflection stages.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_dir: Path,
|
||||
run_id: str,
|
||||
agent_id: str,
|
||||
):
|
||||
"""Initialize evaluation hook.
|
||||
|
||||
Args:
|
||||
storage_dir: Directory to store evaluation results
|
||||
run_id: Current run identifier
|
||||
agent_id: Current agent identifier
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.run_id = run_id
|
||||
self.agent_id = agent_id
|
||||
self._current_evaluation: Optional[EvaluationResult] = None
|
||||
|
||||
def start_evaluation(
|
||||
self,
|
||||
skill_name: str,
|
||||
inputs: Dict[str, Any],
|
||||
) -> None:
|
||||
"""Start a new evaluation session.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill being evaluated
|
||||
inputs: Input parameters for the skill
|
||||
"""
|
||||
self._current_evaluation = EvaluationResult(
|
||||
skill_name=skill_name,
|
||||
run_id=self.run_id,
|
||||
agent_id=self.agent_id,
|
||||
inputs=inputs,
|
||||
started_at=datetime.now().isoformat(),
|
||||
)
|
||||
logger.debug(f"Started evaluation for skill: {skill_name}")
|
||||
|
||||
def add_metric(
|
||||
self,
|
||||
name: str,
|
||||
metric_type: MetricType,
|
||||
value: float,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Add an evaluation metric.
|
||||
|
||||
Args:
|
||||
name: Metric name
|
||||
metric_type: Type of metric
|
||||
value: Metric value
|
||||
metadata: Additional metadata
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation session, ignoring metric")
|
||||
return
|
||||
|
||||
metric = EvaluationMetric(
|
||||
name=name,
|
||||
metric_type=metric_type,
|
||||
value=value,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
self._current_evaluation.metrics.append(metric)
|
||||
logger.debug(f"Added metric: {name} = {value}")
|
||||
|
||||
def add_metrics(self, metrics: List[EvaluationMetric]) -> None:
|
||||
"""Add multiple evaluation metrics at once.
|
||||
|
||||
Args:
|
||||
metrics: List of metrics to add
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation session, ignoring metrics")
|
||||
return
|
||||
|
||||
self._current_evaluation.metrics.extend(metrics)
|
||||
|
||||
def record_outputs(self, outputs: Dict[str, Any]) -> None:
|
||||
"""Record skill outputs.
|
||||
|
||||
Args:
|
||||
outputs: Output from skill execution
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation session, ignoring outputs")
|
||||
return
|
||||
|
||||
self._current_evaluation.outputs = outputs
|
||||
|
||||
def record_decision(self, decision: str) -> None:
|
||||
"""Record the final decision.
|
||||
|
||||
Args:
|
||||
decision: Final decision made by the skill
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation session, ignoring decision")
|
||||
return
|
||||
|
||||
self._current_evaluation.decision = decision
|
||||
|
||||
def complete_evaluation(
|
||||
self,
|
||||
success: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
) -> Optional[EvaluationResult]:
|
||||
"""Complete the evaluation session and persist results.
|
||||
|
||||
Args:
|
||||
success: Whether the skill execution was successful
|
||||
error_message: Error message if failed
|
||||
|
||||
Returns:
|
||||
The completed evaluation result, or None if no active evaluation
|
||||
"""
|
||||
if self._current_evaluation is None:
|
||||
logger.warning("No active evaluation to complete")
|
||||
return None
|
||||
|
||||
self._current_evaluation.success = success
|
||||
self._current_evaluation.error_message = error_message
|
||||
self._current_evaluation.completed_at = datetime.now().isoformat()
|
||||
|
||||
# Persist to storage
|
||||
result = self._persist_evaluation(self._current_evaluation)
|
||||
|
||||
self._current_evaluation = None
|
||||
logger.debug(f"Completed evaluation for skill: {result.skill_name}")
|
||||
|
||||
return result
|
||||
|
||||
def _persist_evaluation(self, evaluation: EvaluationResult) -> EvaluationResult:
|
||||
"""Persist evaluation result to storage.
|
||||
|
||||
Args:
|
||||
evaluation: Evaluation result to persist
|
||||
|
||||
Returns:
|
||||
The persisted evaluation
|
||||
"""
|
||||
# Create run-specific directory
|
||||
run_dir = self.storage_dir / self.run_id
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create agent-specific subdirectory
|
||||
agent_dir = run_dir / self.agent_id
|
||||
agent_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{evaluation.skill_name}_{timestamp}.json"
|
||||
filepath = agent_dir / filename
|
||||
|
||||
# Write evaluation result
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(evaluation.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"Persisted evaluation to: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist evaluation: {e}")
|
||||
|
||||
return evaluation
|
||||
|
||||
def cancel_evaluation(self) -> None:
|
||||
"""Cancel the current evaluation session without saving."""
|
||||
if self._current_evaluation is not None:
|
||||
logger.debug(f"Cancelled evaluation for: {self._current_evaluation.skill_name}")
|
||||
self._current_evaluation = None
|
||||
|
||||
|
||||
class EvaluationCollector:
|
||||
"""Collector for aggregating evaluation metrics across runs.
|
||||
|
||||
Provides methods to query and analyze evaluation results.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_dir: Path):
|
||||
"""Initialize evaluation collector.
|
||||
|
||||
Args:
|
||||
storage_dir: Root directory containing evaluation results
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
|
||||
def get_run_evaluations(
|
||||
self,
|
||||
run_id: str,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> List[EvaluationResult]:
|
||||
"""Get all evaluations for a run.
|
||||
|
||||
Args:
|
||||
run_id: Run identifier
|
||||
agent_id: Optional agent identifier to filter by
|
||||
|
||||
Returns:
|
||||
List of evaluation results
|
||||
"""
|
||||
run_dir = self.storage_dir / run_id
|
||||
if not run_dir.exists():
|
||||
return []
|
||||
|
||||
evaluations = []
|
||||
|
||||
agent_dirs = [run_dir / agent_id] if agent_id else run_dir.iterdir()
|
||||
|
||||
for agent_dir in agent_dirs:
|
||||
if not agent_dir.is_dir():
|
||||
continue
|
||||
|
||||
for eval_file in agent_dir.glob("*.json"):
|
||||
try:
|
||||
with open(eval_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
evaluations.append(self._parse_evaluation(data))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load evaluation {eval_file}: {e}")
|
||||
|
||||
return evaluations
|
||||
|
||||
def get_skill_metrics(
|
||||
self,
|
||||
skill_name: str,
|
||||
run_ids: Optional[List[str]] = None,
|
||||
) -> List[EvaluationMetric]:
|
||||
"""Get all metrics for a specific skill.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill
|
||||
run_ids: Optional list of run IDs to filter by
|
||||
|
||||
Returns:
|
||||
List of metrics for the skill
|
||||
"""
|
||||
metrics = []
|
||||
|
||||
if run_ids is None:
|
||||
run_ids = [d.name for d in self.storage_dir.iterdir() if d.is_dir()]
|
||||
|
||||
for run_id in run_ids:
|
||||
evaluations = self.get_run_evaluations(run_id)
|
||||
for eval_result in evaluations:
|
||||
if eval_result.skill_name == skill_name:
|
||||
metrics.extend(eval_result.metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
def calculate_skill_stats(
|
||||
self,
|
||||
skill_name: str,
|
||||
metric_type: MetricType,
|
||||
run_ids: Optional[List[str]] = None,
|
||||
) -> Dict[str, float]:
|
||||
"""Calculate statistics for a specific metric type.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill
|
||||
metric_type: Type of metric to calculate
|
||||
run_ids: Optional list of run IDs to filter by
|
||||
|
||||
Returns:
|
||||
Dictionary with min, max, avg, count statistics
|
||||
"""
|
||||
metrics = self.get_skill_metrics(skill_name, run_ids)
|
||||
filtered = [m for m in metrics if m.metric_type == metric_type]
|
||||
|
||||
if not filtered:
|
||||
return {"count": 0}
|
||||
|
||||
values = [m.value for m in filtered]
|
||||
return {
|
||||
"count": len(values),
|
||||
"min": min(values),
|
||||
"max": max(values),
|
||||
"avg": sum(values) / len(values),
|
||||
}
|
||||
|
||||
def _parse_evaluation(self, data: Dict[str, Any]) -> EvaluationResult:
|
||||
"""Parse evaluation data into EvaluationResult.
|
||||
|
||||
Args:
|
||||
data: Raw evaluation data
|
||||
|
||||
Returns:
|
||||
Parsed EvaluationResult
|
||||
"""
|
||||
metrics = []
|
||||
for m in data.get("metrics", []):
|
||||
metrics.append(EvaluationMetric(
|
||||
name=m["name"],
|
||||
metric_type=MetricType(m["metric_type"]),
|
||||
value=m["value"],
|
||||
timestamp=m.get("timestamp", ""),
|
||||
metadata=m.get("metadata", {}),
|
||||
))
|
||||
|
||||
return EvaluationResult(
|
||||
skill_name=data["skill_name"],
|
||||
run_id=data["run_id"],
|
||||
agent_id=data["agent_id"],
|
||||
metrics=metrics,
|
||||
inputs=data.get("inputs", {}),
|
||||
outputs=data.get("outputs", {}),
|
||||
decision=data.get("decision"),
|
||||
success=data.get("success", True),
|
||||
error_message=data.get("error_message"),
|
||||
started_at=data.get("started_at"),
|
||||
completed_at=data.get("completed_at"),
|
||||
)
|
||||
|
||||
|
||||
def parse_evaluation_hooks(skill_dir: Path) -> Dict[str, Any]:
|
||||
"""Parse evaluation hooks from SKILL.md.
|
||||
|
||||
Extracts the Optional: Evaluation hooks section from skill documentation.
|
||||
|
||||
Args:
|
||||
skill_dir: Skill directory path
|
||||
|
||||
Returns:
|
||||
Dictionary containing evaluation hook definitions
|
||||
"""
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
if not skill_md.exists():
|
||||
return {}
|
||||
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
|
||||
# Extract evaluation hooks section
|
||||
if "## Optional: Evaluation hooks" in content:
|
||||
start = content.find("## Optional: Evaluation hooks")
|
||||
# Find the next ## section or end of file
|
||||
next_section = content.find("\n## ", start + 1)
|
||||
if next_section == -1:
|
||||
eval_section = content[start:]
|
||||
else:
|
||||
eval_section = content[start:next_section]
|
||||
|
||||
# Parse metrics from the section
|
||||
metrics = []
|
||||
for metric_type in MetricType:
|
||||
if metric_type.value.replace("_", " ") in eval_section.lower():
|
||||
metrics.append(metric_type.value)
|
||||
|
||||
return {
|
||||
"supported_metrics": metrics,
|
||||
"section_content": eval_section.strip(),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse evaluation hooks: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MetricType",
|
||||
"EvaluationMetric",
|
||||
"EvaluationResult",
|
||||
"EvaluationHook",
|
||||
"EvaluationCollector",
|
||||
"parse_evaluation_hooks",
|
||||
]
|
||||
@@ -470,7 +470,7 @@ class EvoAgent(ToolGuardMixin, ReActAgent):
|
||||
"""
|
||||
return self._messenger
|
||||
|
||||
def delegate_task(
|
||||
async def delegate_task(
|
||||
self,
|
||||
task_type: str,
|
||||
task_data: Dict[str, Any],
|
||||
@@ -493,7 +493,7 @@ class EvoAgent(ToolGuardMixin, ReActAgent):
|
||||
}
|
||||
|
||||
try:
|
||||
return self._task_delegator.delegate_task(
|
||||
return await self._task_delegator.delegate_task(
|
||||
task_type=task_type,
|
||||
task_data=task_data,
|
||||
target_agent=target_agent,
|
||||
|
||||
489
backend/agents/base/skill_adaptation_hook.py
Normal file
489
backend/agents/base/skill_adaptation_hook.py
Normal file
@@ -0,0 +1,489 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Skill adaptation hook for automatic evaluation-to-iteration闭环.
|
||||
|
||||
Monitors evaluation metrics against configurable thresholds and triggers
|
||||
automatic skill reload or logs warnings when thresholds are breached.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from .evaluation_hook import (
|
||||
EvaluationCollector,
|
||||
EvaluationResult,
|
||||
MetricType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AdaptationAction(Enum):
|
||||
"""Actions to take when threshold is breached."""
|
||||
RELOAD = "reload" # 自动重新加载技能
|
||||
WARN = "warn" # 记录警告供人工审核
|
||||
BOTH = "both" # 同时执行重载和警告
|
||||
NONE = "none" # 不做任何操作
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdaptationThreshold:
|
||||
"""Threshold configuration for a metric."""
|
||||
metric_type: MetricType
|
||||
operator: str = "lt" # lt (less than), gt (greater than), lte, gte, eq
|
||||
value: float = 0.0
|
||||
window_size: int = 10 # 移动窗口大小,用于计算滑动平均
|
||||
min_samples: int = 5 # 最少样本数才触发检查
|
||||
action: AdaptationAction = AdaptationAction.WARN
|
||||
cooldown_seconds: int = 300 # 触发后的冷却时间
|
||||
|
||||
def evaluate(self, current_value: float) -> bool:
|
||||
"""Evaluate if threshold is breached."""
|
||||
ops = {
|
||||
"lt": lambda x, y: x < y,
|
||||
"lte": lambda x, y: x <= y,
|
||||
"gt": lambda x, y: x > y,
|
||||
"gte": lambda x, y: x >= y,
|
||||
"eq": lambda x, y: x == y,
|
||||
}
|
||||
op_func = ops.get(self.operator)
|
||||
if op_func is None:
|
||||
logger.warning(f"Unknown operator: {self.operator}")
|
||||
return False
|
||||
return op_func(current_value, self.value)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"metric_type": self.metric_type.value,
|
||||
"operator": self.operator,
|
||||
"value": self.value,
|
||||
"window_size": self.window_size,
|
||||
"min_samples": self.min_samples,
|
||||
"action": self.action.value,
|
||||
"cooldown_seconds": self.cooldown_seconds,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdaptationEvent:
|
||||
"""Record of an adaptation trigger event."""
|
||||
timestamp: str
|
||||
skill_name: str
|
||||
metric_type: MetricType
|
||||
threshold: AdaptationThreshold
|
||||
current_value: float
|
||||
avg_value: float
|
||||
action_taken: AdaptationAction
|
||||
details: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"timestamp": self.timestamp,
|
||||
"skill_name": self.skill_name,
|
||||
"metric_type": self.metric_type.value,
|
||||
"threshold": self.threshold.to_dict(),
|
||||
"current_value": self.current_value,
|
||||
"avg_value": self.avg_value,
|
||||
"action_taken": self.action_taken.value,
|
||||
"details": self.details,
|
||||
}
|
||||
|
||||
|
||||
class SkillAdaptationHook:
|
||||
"""Hook for monitoring evaluation metrics and triggering skill adaptation.
|
||||
|
||||
This hook wraps EvaluationHook to add threshold-based adaptation logic.
|
||||
When metrics breach configured thresholds, it can:
|
||||
- Automatically reload skills via SkillsManager
|
||||
- Log warnings for human review
|
||||
- Both
|
||||
"""
|
||||
|
||||
# Default thresholds for common metrics
|
||||
DEFAULT_THRESHOLDS: List[AdaptationThreshold] = [
|
||||
AdaptationThreshold(
|
||||
metric_type=MetricType.HIT_RATE,
|
||||
operator="lt",
|
||||
value=0.5,
|
||||
action=AdaptationAction.WARN,
|
||||
cooldown_seconds=600,
|
||||
),
|
||||
AdaptationThreshold(
|
||||
metric_type=MetricType.RISK_VIOLATION,
|
||||
operator="gt",
|
||||
value=0.1,
|
||||
action=AdaptationAction.WARN,
|
||||
cooldown_seconds=300,
|
||||
),
|
||||
AdaptationThreshold(
|
||||
metric_type=MetricType.DECISION_LATENCY,
|
||||
operator="gt",
|
||||
value=5000, # 5 seconds
|
||||
action=AdaptationAction.WARN,
|
||||
cooldown_seconds=300,
|
||||
),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
storage_dir: Path,
|
||||
run_id: str,
|
||||
agent_id: str,
|
||||
thresholds: Optional[List[AdaptationThreshold]] = None,
|
||||
collector: Optional[EvaluationCollector] = None,
|
||||
):
|
||||
"""Initialize skill adaptation hook.
|
||||
|
||||
Args:
|
||||
storage_dir: Directory to store adaptation events
|
||||
run_id: Current run identifier
|
||||
agent_id: Current agent identifier
|
||||
thresholds: Custom threshold configurations (uses defaults if None)
|
||||
collector: Optional EvaluationCollector for historical data
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self.run_id = run_id
|
||||
self.agent_id = agent_id
|
||||
self.thresholds = thresholds or self.DEFAULT_THRESHOLDS
|
||||
self.collector = collector or EvaluationCollector(storage_dir)
|
||||
|
||||
# Track cooldowns to prevent rapid re-triggering
|
||||
self._cooldowns: Dict[str, datetime] = {}
|
||||
|
||||
# Store recent metrics in memory for quick access
|
||||
self._recent_metrics: Dict[str, List[float]] = {}
|
||||
|
||||
# Pending adaptation events
|
||||
self._pending_events: List[AdaptationEvent] = []
|
||||
|
||||
def check_threshold(
|
||||
self,
|
||||
skill_name: str,
|
||||
metric_type: MetricType,
|
||||
current_value: float,
|
||||
) -> Optional[AdaptationEvent]:
|
||||
"""Check if a metric breaches any threshold.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill
|
||||
metric_type: Type of metric
|
||||
current_value: Current metric value
|
||||
|
||||
Returns:
|
||||
AdaptationEvent if threshold breached, None otherwise
|
||||
"""
|
||||
# Find applicable thresholds
|
||||
applicable_thresholds = [
|
||||
t for t in self.thresholds
|
||||
if t.metric_type == metric_type
|
||||
]
|
||||
|
||||
if not applicable_thresholds:
|
||||
return None
|
||||
|
||||
# Check cooldown
|
||||
cooldown_key = f"{skill_name}:{metric_type.value}"
|
||||
now = datetime.now()
|
||||
last_trigger = self._cooldowns.get(cooldown_key)
|
||||
|
||||
# Store current value first for avg calculation
|
||||
self._store_metric(cooldown_key, current_value)
|
||||
|
||||
for threshold in applicable_thresholds:
|
||||
if last_trigger:
|
||||
elapsed = (now - last_trigger).total_seconds()
|
||||
if elapsed < threshold.cooldown_seconds:
|
||||
continue
|
||||
|
||||
# Evaluate threshold
|
||||
if threshold.evaluate(current_value):
|
||||
# Calculate moving average
|
||||
avg_value = self._calculate_avg(skill_name, metric_type, current_value)
|
||||
|
||||
# Check minimum samples (allow immediate trigger if min_samples <= 1)
|
||||
sample_count = len(self._recent_metrics.get(cooldown_key, []))
|
||||
if threshold.min_samples > 1 and sample_count < threshold.min_samples:
|
||||
# Not enough samples yet
|
||||
continue
|
||||
|
||||
# Trigger adaptation
|
||||
event = AdaptationEvent(
|
||||
timestamp=now.isoformat(),
|
||||
skill_name=skill_name,
|
||||
metric_type=metric_type,
|
||||
threshold=threshold,
|
||||
current_value=current_value,
|
||||
avg_value=avg_value,
|
||||
action_taken=threshold.action,
|
||||
details={
|
||||
"run_id": self.run_id,
|
||||
"agent_id": self.agent_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Update cooldown
|
||||
self._cooldowns[cooldown_key] = now
|
||||
|
||||
# Persist event
|
||||
self._persist_event(event)
|
||||
|
||||
logger.info(
|
||||
f"Threshold breached for {skill_name}.{metric_type.value}: "
|
||||
f"current={current_value}, avg={avg_value}, action={threshold.action.value}"
|
||||
)
|
||||
|
||||
return event
|
||||
|
||||
return None
|
||||
|
||||
def _calculate_avg(
|
||||
self,
|
||||
skill_name: str,
|
||||
metric_type: MetricType,
|
||||
current_value: float,
|
||||
) -> float:
|
||||
"""Calculate moving average for a metric."""
|
||||
key = f"{skill_name}:{metric_type.value}"
|
||||
values = self._recent_metrics.get(key, [])
|
||||
if not values:
|
||||
return current_value
|
||||
return sum(values) / len(values)
|
||||
|
||||
def _store_metric(self, key: str, value: float) -> None:
|
||||
"""Store metric value with sliding window."""
|
||||
if key not in self._recent_metrics:
|
||||
self._recent_metrics[key] = []
|
||||
self._recent_metrics[key].append(value)
|
||||
# Keep only last 100 values
|
||||
if len(self._recent_metrics[key]) > 100:
|
||||
self._recent_metrics[key] = self._recent_metrics[key][-100:]
|
||||
|
||||
def _persist_event(self, event: AdaptationEvent) -> None:
|
||||
"""Persist adaptation event to storage."""
|
||||
run_dir = self.storage_dir / self.run_id / "adaptations"
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"{event.skill_name}_{event.metric_type.value}_{timestamp}.json"
|
||||
filepath = run_dir / filename
|
||||
|
||||
try:
|
||||
with open(filepath, "w", encoding="utf-8") as f:
|
||||
json.dump(event.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
logger.debug(f"Persisted adaptation event to: {filepath}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to persist adaptation event: {e}")
|
||||
|
||||
# Also add to pending list
|
||||
self._pending_events.append(event)
|
||||
|
||||
def get_pending_warnings(self) -> List[AdaptationEvent]:
|
||||
"""Get all pending warning events that need human review."""
|
||||
return [
|
||||
e for e in self._pending_events
|
||||
if e.action_taken in (AdaptationAction.WARN, AdaptationAction.BOTH)
|
||||
]
|
||||
|
||||
def clear_pending_warnings(self) -> None:
|
||||
"""Clear pending warnings after they have been reviewed."""
|
||||
self._pending_events = [
|
||||
e for e in self._pending_events
|
||||
if e.action_taken == AdaptationAction.RELOAD
|
||||
]
|
||||
|
||||
def get_recent_events(
|
||||
self,
|
||||
skill_name: Optional[str] = None,
|
||||
metric_type: Optional[MetricType] = None,
|
||||
limit: int = 50,
|
||||
) -> List[AdaptationEvent]:
|
||||
"""Get recent adaptation events.
|
||||
|
||||
Args:
|
||||
skill_name: Optional filter by skill name
|
||||
metric_type: Optional filter by metric type
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of recent adaptation events
|
||||
"""
|
||||
events_dir = self.storage_dir / self.run_id / "adaptations"
|
||||
if not events_dir.exists():
|
||||
return []
|
||||
|
||||
events = []
|
||||
for eval_file in sorted(events_dir.glob("*.json"), reverse=True)[:limit]:
|
||||
try:
|
||||
with open(eval_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
event = self._parse_event(data)
|
||||
if skill_name and event.skill_name != skill_name:
|
||||
continue
|
||||
if metric_type and event.metric_type != metric_type:
|
||||
continue
|
||||
events.append(event)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load adaptation event {eval_file}: {e}")
|
||||
|
||||
return events
|
||||
|
||||
def _parse_event(self, data: Dict[str, Any]) -> AdaptationEvent:
|
||||
"""Parse adaptation event from JSON data."""
|
||||
threshold_data = data.get("threshold", {})
|
||||
metric_type = MetricType(threshold_data.get("metric_type", "custom"))
|
||||
|
||||
threshold = AdaptationThreshold(
|
||||
metric_type=metric_type,
|
||||
operator=threshold_data.get("operator", "lt"),
|
||||
value=threshold_data.get("value", 0.0),
|
||||
window_size=threshold_data.get("window_size", 10),
|
||||
min_samples=threshold_data.get("min_samples", 5),
|
||||
action=AdaptationAction(threshold_data.get("action", "warn")),
|
||||
cooldown_seconds=threshold_data.get("cooldown_seconds", 300),
|
||||
)
|
||||
|
||||
return AdaptationEvent(
|
||||
timestamp=data.get("timestamp", ""),
|
||||
skill_name=data.get("skill_name", ""),
|
||||
metric_type=metric_type,
|
||||
threshold=threshold,
|
||||
current_value=data.get("current_value", 0.0),
|
||||
avg_value=data.get("avg_value", 0.0),
|
||||
action_taken=AdaptationAction(data.get("action_taken", "warn")),
|
||||
details=data.get("details", {}),
|
||||
)
|
||||
|
||||
def add_threshold(self, threshold: AdaptationThreshold) -> None:
|
||||
"""Add a new threshold configuration."""
|
||||
self.thresholds.append(threshold)
|
||||
|
||||
def remove_threshold(self, metric_type: MetricType) -> None:
|
||||
"""Remove all thresholds for a specific metric type."""
|
||||
self.thresholds = [
|
||||
t for t in self.thresholds
|
||||
if t.metric_type != metric_type
|
||||
]
|
||||
|
||||
def update_threshold(
|
||||
self,
|
||||
metric_type: MetricType,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Update threshold configuration for a metric type."""
|
||||
for threshold in self.thresholds:
|
||||
if threshold.metric_type == metric_type:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(threshold, key):
|
||||
setattr(threshold, key, value)
|
||||
|
||||
def get_thresholds(self) -> List[AdaptationThreshold]:
|
||||
"""Get current threshold configurations."""
|
||||
return list(self.thresholds)
|
||||
|
||||
def is_in_cooldown(self, skill_name: str, metric_type: MetricType) -> bool:
|
||||
"""Check if a skill/metric combination is in cooldown period."""
|
||||
key = f"{skill_name}:{metric_type.value}"
|
||||
last_trigger = self._cooldowns.get(key)
|
||||
if not last_trigger:
|
||||
return False
|
||||
|
||||
# Find the threshold for this metric type
|
||||
for threshold in self.thresholds:
|
||||
if threshold.metric_type == metric_type:
|
||||
elapsed = (datetime.now() - last_trigger).total_seconds()
|
||||
return elapsed < threshold.cooldown_seconds
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class AdaptationManager:
|
||||
"""Manager for coordinating skill adaptation across multiple agents.
|
||||
|
||||
Provides centralized tracking of adaptation events and skill reloads.
|
||||
"""
|
||||
|
||||
def __init__(self, storage_dir: Path):
|
||||
"""Initialize adaptation manager.
|
||||
|
||||
Args:
|
||||
storage_dir: Root directory for storing adaptation data
|
||||
"""
|
||||
self.storage_dir = Path(storage_dir)
|
||||
self._hooks: Dict[str, SkillAdaptationHook] = {}
|
||||
|
||||
def get_hook(
|
||||
self,
|
||||
run_id: str,
|
||||
agent_id: str,
|
||||
thresholds: Optional[List[AdaptationThreshold]] = None,
|
||||
) -> SkillAdaptationHook:
|
||||
"""Get or create an adaptation hook for an agent.
|
||||
|
||||
Args:
|
||||
run_id: Run identifier
|
||||
agent_id: Agent identifier
|
||||
thresholds: Optional custom thresholds
|
||||
|
||||
Returns:
|
||||
SkillAdaptationHook instance
|
||||
"""
|
||||
key = f"{run_id}:{agent_id}"
|
||||
if key not in self._hooks:
|
||||
self._hooks[key] = SkillAdaptationHook(
|
||||
storage_dir=self.storage_dir,
|
||||
run_id=run_id,
|
||||
agent_id=agent_id,
|
||||
thresholds=thresholds,
|
||||
)
|
||||
return self._hooks[key]
|
||||
|
||||
def get_all_pending_warnings(self) -> List[AdaptationEvent]:
|
||||
"""Get all pending warnings from all hooks."""
|
||||
warnings = []
|
||||
for hook in self._hooks.values():
|
||||
warnings.extend(hook.get_pending_warnings())
|
||||
return warnings
|
||||
|
||||
def get_run_adaptations(self, run_id: str) -> List[AdaptationEvent]:
|
||||
"""Get all adaptation events for a run."""
|
||||
events = []
|
||||
for hook in self._hooks.values():
|
||||
if hook.run_id == run_id:
|
||||
events.extend(hook.get_recent_events())
|
||||
return events
|
||||
|
||||
|
||||
# Global manager instance
|
||||
_adaptation_manager: Optional[AdaptationManager] = None
|
||||
|
||||
|
||||
def get_adaptation_manager(storage_dir: Optional[Path] = None) -> AdaptationManager:
|
||||
"""Get global adaptation manager instance.
|
||||
|
||||
Args:
|
||||
storage_dir: Optional storage directory (required on first call)
|
||||
|
||||
Returns:
|
||||
AdaptationManager instance
|
||||
"""
|
||||
global _adaptation_manager
|
||||
if _adaptation_manager is None:
|
||||
if storage_dir is None:
|
||||
raise ValueError("storage_dir required on first initialization")
|
||||
_adaptation_manager = AdaptationManager(storage_dir)
|
||||
return _adaptation_manager
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AdaptationAction",
|
||||
"AdaptationThreshold",
|
||||
"AdaptationEvent",
|
||||
"SkillAdaptationHook",
|
||||
"AdaptationManager",
|
||||
"get_adaptation_manager",
|
||||
]
|
||||
@@ -289,6 +289,7 @@ class ToolGuardMixin:
|
||||
self._approval_timeout = approval_timeout
|
||||
self._pending_approval: Optional[ToolApprovalRequest] = None
|
||||
self._approval_callback: Optional[Callable[[ToolApprovalRequest], None]] = None
|
||||
self._approval_lock = asyncio.Lock()
|
||||
|
||||
def set_approval_callback(
|
||||
self,
|
||||
@@ -383,73 +384,80 @@ class ToolGuardMixin:
|
||||
Returns:
|
||||
True if approved, False otherwise
|
||||
"""
|
||||
record = TOOL_GUARD_STORE.create_pending(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
agent_id=getattr(self, "agent_id", "unknown"),
|
||||
workspace_id=getattr(self, "workspace_id", "default"),
|
||||
session_id=getattr(self, "session_id", None),
|
||||
findings=default_findings_for_tool(tool_name),
|
||||
)
|
||||
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.register_pending_approval(
|
||||
record.approval_id,
|
||||
{
|
||||
"tool_name": record.tool_name,
|
||||
"agent_id": record.agent_id,
|
||||
"workspace_id": record.workspace_id,
|
||||
"session_id": record.session_id,
|
||||
"tool_input": record.tool_input,
|
||||
},
|
||||
async with self._approval_lock:
|
||||
record = TOOL_GUARD_STORE.create_pending(
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
agent_id=getattr(self, "agent_id", "unknown"),
|
||||
workspace_id=getattr(self, "workspace_id", "default"),
|
||||
session_id=getattr(self, "session_id", None),
|
||||
findings=default_findings_for_tool(tool_name),
|
||||
)
|
||||
|
||||
self._pending_approval = ToolApprovalRequest(
|
||||
approval_id=record.approval_id,
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
tool_call_id=tool_call_id,
|
||||
session_id=getattr(self, "session_id", None),
|
||||
)
|
||||
record.pending_request = self._pending_approval
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.register_pending_approval(
|
||||
record.approval_id,
|
||||
{
|
||||
"tool_name": record.tool_name,
|
||||
"agent_id": record.agent_id,
|
||||
"workspace_id": record.workspace_id,
|
||||
"session_id": record.session_id,
|
||||
"tool_input": record.tool_input,
|
||||
},
|
||||
)
|
||||
|
||||
# Notify via callback if set
|
||||
if self._approval_callback:
|
||||
self._approval_callback(self._pending_approval)
|
||||
self._pending_approval = ToolApprovalRequest(
|
||||
approval_id=record.approval_id,
|
||||
tool_name=tool_name,
|
||||
tool_input=tool_input,
|
||||
tool_call_id=tool_call_id,
|
||||
session_id=getattr(self, "session_id", None),
|
||||
)
|
||||
record.pending_request = self._pending_approval
|
||||
|
||||
# Wait for approval
|
||||
approval_request = self._pending_approval
|
||||
# Notify via callback if set
|
||||
if self._approval_callback:
|
||||
self._approval_callback(self._pending_approval)
|
||||
|
||||
# Wait for approval (lock is released during wait, re-acquired after)
|
||||
approval_request = self._pending_approval
|
||||
|
||||
# Wait for approval outside the lock to allow concurrent approval
|
||||
approved = await approval_request.wait_for_approval(
|
||||
timeout=self._approval_timeout
|
||||
)
|
||||
|
||||
if approval_request:
|
||||
status = (
|
||||
ApprovalStatus.APPROVED
|
||||
if approval_request.approved is True
|
||||
else ApprovalStatus.DENIED
|
||||
if approval_request.approved is False
|
||||
else ApprovalStatus.EXPIRED
|
||||
)
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
approval_request.approval_id,
|
||||
status,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
approval_request.approval_id,
|
||||
resolved_by="agent",
|
||||
status=status.value,
|
||||
async with self._approval_lock:
|
||||
if approval_request:
|
||||
status = (
|
||||
ApprovalStatus.APPROVED
|
||||
if approval_request.approved is True
|
||||
else ApprovalStatus.DENIED
|
||||
if approval_request.approved is False
|
||||
else ApprovalStatus.EXPIRED
|
||||
)
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
approval_request.approval_id,
|
||||
status,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
approval_request.approval_id,
|
||||
resolved_by="agent",
|
||||
status=status.value,
|
||||
)
|
||||
|
||||
# Only clear if this is still the same request
|
||||
if self._pending_approval is approval_request:
|
||||
self._pending_approval = None
|
||||
|
||||
self._pending_approval = None
|
||||
return approved
|
||||
|
||||
def approve_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||
async def approve_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||
"""Approve a pending guard request.
|
||||
|
||||
This method is called externally to approve a tool call
|
||||
@@ -461,28 +469,29 @@ class ToolGuardMixin:
|
||||
Returns:
|
||||
True if a request was approved, False if no pending request
|
||||
"""
|
||||
if self._pending_approval is None:
|
||||
logger.warning("No pending approval request to approve")
|
||||
return False
|
||||
async with self._approval_lock:
|
||||
if self._pending_approval is None:
|
||||
logger.warning("No pending approval request to approve")
|
||||
return False
|
||||
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
self._pending_approval.approval_id,
|
||||
ApprovalStatus.APPROVED,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
self._pending_approval.approval_id,
|
||||
ApprovalStatus.APPROVED,
|
||||
resolved_by="agent",
|
||||
status=ApprovalStatus.APPROVED.value,
|
||||
notify_request=False,
|
||||
)
|
||||
self._pending_approval.approve()
|
||||
logger.info("Approved tool call: %s", self._pending_approval.tool_name)
|
||||
return True
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
self._pending_approval.approval_id,
|
||||
resolved_by="agent",
|
||||
status=ApprovalStatus.APPROVED.value,
|
||||
)
|
||||
self._pending_approval.approve()
|
||||
logger.info("Approved tool call: %s", self._pending_approval.tool_name)
|
||||
return True
|
||||
|
||||
def deny_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||
async def deny_guard_call(self, request_id: Optional[str] = None) -> bool:
|
||||
"""Deny a pending guard request.
|
||||
|
||||
This method is called externally to deny a tool call
|
||||
@@ -494,26 +503,27 @@ class ToolGuardMixin:
|
||||
Returns:
|
||||
True if a request was denied, False if no pending request
|
||||
"""
|
||||
if self._pending_approval is None:
|
||||
logger.warning("No pending approval request to deny")
|
||||
return False
|
||||
async with self._approval_lock:
|
||||
if self._pending_approval is None:
|
||||
logger.warning("No pending approval request to deny")
|
||||
return False
|
||||
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
self._pending_approval.approval_id,
|
||||
ApprovalStatus.DENIED,
|
||||
resolved_by="agent",
|
||||
notify_request=False,
|
||||
)
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
TOOL_GUARD_STORE.set_status(
|
||||
self._pending_approval.approval_id,
|
||||
ApprovalStatus.DENIED,
|
||||
resolved_by="agent",
|
||||
status=ApprovalStatus.DENIED.value,
|
||||
notify_request=False,
|
||||
)
|
||||
self._pending_approval.deny()
|
||||
logger.info("Denied tool call: %s", self._pending_approval.tool_name)
|
||||
return True
|
||||
manager = get_global_runtime_manager()
|
||||
if manager:
|
||||
manager.resolve_pending_approval(
|
||||
self._pending_approval.approval_id,
|
||||
resolved_by="agent",
|
||||
status=ApprovalStatus.DENIED.value,
|
||||
)
|
||||
self._pending_approval.deny()
|
||||
logger.info("Denied tool call: %s", self._pending_approval.tool_name)
|
||||
return True
|
||||
|
||||
async def _acting(self, tool_call) -> dict | None:
|
||||
"""Intercept sensitive tool calls before execution.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Agent Factory - Dynamic creation and management of EvoAgents."""
|
||||
"""Agent Factory - Dynamic creation and management of AgentConfigs."""
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
@@ -37,8 +37,8 @@ class RoleConfig:
|
||||
self.constraints = []
|
||||
|
||||
|
||||
class EvoAgent:
|
||||
"""Represents a configured agent instance."""
|
||||
class AgentConfig:
|
||||
"""Represents a configured agent instance (data class)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -185,7 +185,7 @@ class AgentFactory:
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
role_config: Optional[RoleConfig] = None,
|
||||
clone_from: Optional[str] = None,
|
||||
) -> EvoAgent:
|
||||
) -> AgentConfig:
|
||||
"""Create a new agent.
|
||||
|
||||
Args:
|
||||
@@ -197,7 +197,7 @@ class AgentFactory:
|
||||
clone_from: Path to existing agent to clone from (optional)
|
||||
|
||||
Returns:
|
||||
EvoAgent instance
|
||||
AgentConfig instance
|
||||
|
||||
Raises:
|
||||
ValueError: If agent already exists or workspace doesn't exist
|
||||
@@ -234,7 +234,7 @@ class AgentFactory:
|
||||
config_path = agent_dir / "agent.yaml"
|
||||
self._write_agent_yaml(config_path, agent_id, agent_type, model_config)
|
||||
|
||||
return EvoAgent(
|
||||
return AgentConfig(
|
||||
agent_id=agent_id,
|
||||
agent_type=agent_type,
|
||||
workspace_id=workspace_id,
|
||||
@@ -267,7 +267,7 @@ class AgentFactory:
|
||||
new_agent_id: str,
|
||||
target_workspace_id: Optional[str] = None,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
) -> EvoAgent:
|
||||
) -> AgentConfig:
|
||||
"""Clone an existing agent.
|
||||
|
||||
Args:
|
||||
@@ -278,7 +278,7 @@ class AgentFactory:
|
||||
model_config: Optional new model configuration
|
||||
|
||||
Returns:
|
||||
EvoAgent instance for the cloned agent
|
||||
AgentConfig instance for the cloned agent
|
||||
"""
|
||||
target_workspace_id = target_workspace_id or source_workspace_id
|
||||
source_dir = self.workspaces_root / source_workspace_id / "agents" / source_agent_id
|
||||
|
||||
@@ -6,10 +6,10 @@ from typing import Any, Optional
|
||||
|
||||
from .agent_workspace import load_agent_workspace_config
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
from .prompt_loader import PromptLoader
|
||||
from .prompt_loader import get_prompt_loader
|
||||
from .skills_manager import SkillsManager
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
|
||||
def _read_file_if_exists(path: Path) -> str:
|
||||
|
||||
@@ -10,6 +10,17 @@ from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
# Singleton instance
|
||||
_prompt_loader_instance: Optional["PromptLoader"] = None
|
||||
|
||||
|
||||
def get_prompt_loader() -> "PromptLoader":
|
||||
"""Get the singleton PromptLoader instance."""
|
||||
global _prompt_loader_instance
|
||||
if _prompt_loader_instance is None:
|
||||
_prompt_loader_instance = PromptLoader()
|
||||
return _prompt_loader_instance
|
||||
|
||||
|
||||
class PromptLoader:
|
||||
"""Unified Prompt loader"""
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
import shutil
|
||||
import tempfile
|
||||
import zipfile
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import urlretrieve
|
||||
@@ -39,6 +40,9 @@ class SkillsManager:
|
||||
self.project_root / "backend" / "skills" / "customized"
|
||||
)
|
||||
self.runs_root = self.project_root / "runs"
|
||||
self._lock = Lock()
|
||||
# Instance-level pending skill changes (thread-safe via self._lock)
|
||||
self._pending_skill_changes: Dict[str, Set[Path]] = {}
|
||||
|
||||
def get_active_root(self, config_name: str) -> Path:
|
||||
return self.runs_root / config_name / "skills" / "active"
|
||||
@@ -737,7 +741,7 @@ class SkillsManager:
|
||||
if local_root.exists():
|
||||
watched_paths.append(local_root)
|
||||
|
||||
handler = _SkillsChangeHandler(watched_paths, callback)
|
||||
handler = _SkillsChangeHandler(watched_paths, self._pending_skill_changes, callback, self._lock)
|
||||
observer = Observer()
|
||||
for path in watched_paths:
|
||||
observer.schedule(handler, str(path), recursive=True)
|
||||
@@ -759,16 +763,19 @@ class SkillsManager:
|
||||
Map of agent_id -> list of reloaded skill paths, or empty dict
|
||||
if no changes were detected.
|
||||
"""
|
||||
changed = self._pending_skill_changes.get(config_name)
|
||||
if not changed:
|
||||
return {}
|
||||
with self._lock:
|
||||
changed = self._pending_skill_changes.get(config_name)
|
||||
if not changed:
|
||||
return {}
|
||||
|
||||
self._pending_skill_changes[config_name] = set()
|
||||
|
||||
self._pending_skill_changes[config_name] = set()
|
||||
return self.prepare_active_skills(config_name, agent_defaults)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Internal change-tracking state (populated by _SkillsChangeHandler)
|
||||
# -------------------------------------------------------------------------
|
||||
# Legacy class-level reference kept for migration compatibility
|
||||
_pending_skill_changes: Dict[str, Set[Path]] = {}
|
||||
|
||||
def _resolve_disabled_skill_names(
|
||||
@@ -820,11 +827,15 @@ class _SkillsChangeHandler(FileSystemEventHandler):
|
||||
def __init__(
|
||||
self,
|
||||
watched_paths: List[Path],
|
||||
pending_changes: Dict[str, Set[Path]],
|
||||
callback: Optional[Any] = None,
|
||||
lock: Optional[Lock] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._watched_paths = watched_paths
|
||||
self._pending_changes = pending_changes
|
||||
self._callback = callback
|
||||
self._lock = lock
|
||||
|
||||
def on_any_event(self, event: FileSystemEvent) -> None:
|
||||
if event.is_directory:
|
||||
@@ -832,9 +843,12 @@ class _SkillsChangeHandler(FileSystemEventHandler):
|
||||
src_path = Path(event.src_path)
|
||||
for watched in self._watched_paths:
|
||||
if src_path.is_relative_to(watched):
|
||||
SkillsManager._pending_skill_changes.setdefault(
|
||||
self._run_id_from_path(src_path), set()
|
||||
).add(src_path)
|
||||
run_id = self._run_id_from_path(src_path)
|
||||
if self._lock:
|
||||
with self._lock:
|
||||
self._pending_changes.setdefault(run_id, set()).add(src_path)
|
||||
else:
|
||||
self._pending_changes.setdefault(run_id, set()).add(src_path)
|
||||
if self._callback:
|
||||
self._callback([src_path])
|
||||
break
|
||||
|
||||
18
backend/agents/team/__init__.py
Normal file
18
backend/agents/team/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Team module for multi-agent orchestration.
|
||||
|
||||
Provides inter-agent communication, task delegation, and coordination
|
||||
for subagent spawning and lifecycle management.
|
||||
"""
|
||||
|
||||
from .messenger import AgentMessenger
|
||||
from .task_delegator import TaskDelegator
|
||||
from .team_coordinator import TeamCoordinator
|
||||
from .registry import AgentRegistry
|
||||
|
||||
__all__ = [
|
||||
"AgentMessenger",
|
||||
"TaskDelegator",
|
||||
"TeamCoordinator",
|
||||
"AgentRegistry",
|
||||
]
|
||||
225
backend/agents/team/messenger.py
Normal file
225
backend/agents/team/messenger.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""AgentMessenger - Pub/sub inter-agent communication.
|
||||
|
||||
Provides broadcast(), send(), and subscribe() for message passing
|
||||
between agents using AgentScope's Msg format.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentMessenger:
|
||||
"""Pub/sub messenger for inter-agent communication.
|
||||
|
||||
Supports:
|
||||
- broadcast(): Send message to all subscribers
|
||||
- send(): Send message to specific agent
|
||||
- subscribe(): Register callback for agent messages
|
||||
- announce(): Send system-wide announcement
|
||||
- enable_auto_broadcast: Auto-broadcast agent replies to all participants
|
||||
|
||||
Messages use AgentScope's Msg format for compatibility.
|
||||
"""
|
||||
|
||||
def __init__(self, enable_auto_broadcast: bool = False):
|
||||
"""Initialize the messenger.
|
||||
|
||||
Args:
|
||||
enable_auto_broadcast: If True, agent replies are automatically
|
||||
broadcast to all subscribed agents.
|
||||
"""
|
||||
self._subscriptions: Dict[str, List[Callable[[Msg], None]]] = {}
|
||||
self._inbox: Dict[str, List[Msg]] = {}
|
||||
self._locks: Dict[str, asyncio.Lock] = {}
|
||||
self._enable_auto_broadcast = enable_auto_broadcast
|
||||
self._participants: Set[str] = set()
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
agent_id: str,
|
||||
callback: Callable[[Msg], None],
|
||||
) -> None:
|
||||
"""Subscribe an agent to receive messages.
|
||||
|
||||
Args:
|
||||
agent_id: Target agent identifier
|
||||
callback: Async function to call when message received
|
||||
"""
|
||||
if agent_id not in self._subscriptions:
|
||||
self._subscriptions[agent_id] = []
|
||||
self._subscriptions[agent_id].append(callback)
|
||||
logger.debug("Agent %s subscribed to messages", agent_id)
|
||||
|
||||
def unsubscribe(self, agent_id: str, callback: Callable[[Msg], None]) -> None:
|
||||
"""Unsubscribe an agent from messages.
|
||||
|
||||
Args:
|
||||
agent_id: Target agent identifier
|
||||
callback: Callback to remove
|
||||
"""
|
||||
if agent_id in self._subscriptions:
|
||||
try:
|
||||
self._subscriptions[agent_id].remove(callback)
|
||||
logger.debug("Agent %s unsubscribed from messages", agent_id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
async def send(
|
||||
self,
|
||||
to_agent: str,
|
||||
message: Msg,
|
||||
) -> None:
|
||||
"""Send message to specific agent.
|
||||
|
||||
Args:
|
||||
to_agent: Target agent identifier
|
||||
message: Message to send (uses Msg format)
|
||||
"""
|
||||
async def _deliver():
|
||||
if to_agent in self._subscriptions:
|
||||
for callback in self._subscriptions[to_agent]:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(message)
|
||||
else:
|
||||
callback(message)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error delivering message to %s: %s",
|
||||
to_agent,
|
||||
e,
|
||||
)
|
||||
|
||||
await _deliver()
|
||||
|
||||
async def broadcast(self, message: Msg) -> None:
|
||||
"""Broadcast message to all subscribed agents.
|
||||
|
||||
Args:
|
||||
message: Message to broadcast (uses Msg format)
|
||||
"""
|
||||
delivery_tasks = []
|
||||
for agent_id, callbacks in self._subscriptions.items():
|
||||
for callback in callbacks:
|
||||
async def _deliver(cb=callback, aid=agent_id):
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(cb):
|
||||
await cb(message)
|
||||
else:
|
||||
cb(message)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error broadcasting to %s: %s",
|
||||
aid,
|
||||
e,
|
||||
)
|
||||
delivery_tasks.append(_deliver())
|
||||
|
||||
if delivery_tasks:
|
||||
await asyncio.gather(*delivery_tasks)
|
||||
|
||||
def inbox(self, agent_id: str) -> List[Msg]:
|
||||
"""Get and clear inbox for agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
List of messages in inbox
|
||||
"""
|
||||
messages = self._inbox.get(agent_id, [])
|
||||
self._inbox[agent_id] = []
|
||||
return messages
|
||||
|
||||
def inbox_count(self, agent_id: str) -> int:
|
||||
"""Count messages in agent's inbox without clearing.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Number of messages waiting
|
||||
"""
|
||||
return len(self._inbox.get(agent_id, []))
|
||||
|
||||
def add_participant(self, agent_id: str) -> None:
|
||||
"""Add a participant to the messenger.
|
||||
|
||||
Participants are the agents that can receive auto-broadcast messages.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to add
|
||||
"""
|
||||
self._participants.add(agent_id)
|
||||
logger.debug("Agent %s added as participant", agent_id)
|
||||
|
||||
def remove_participant(self, agent_id: str) -> None:
|
||||
"""Remove a participant from the messenger.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to remove
|
||||
"""
|
||||
self._participants.discard(agent_id)
|
||||
logger.debug("Agent %s removed from participants", agent_id)
|
||||
|
||||
@property
|
||||
def enable_auto_broadcast(self) -> bool:
|
||||
"""Check if auto_broadcast is enabled."""
|
||||
return self._enable_auto_broadcast
|
||||
|
||||
@enable_auto_broadcast.setter
|
||||
def enable_auto_broadcast(self, value: bool) -> None:
|
||||
"""Enable or disable auto_broadcast."""
|
||||
self._enable_auto_broadcast = value
|
||||
logger.debug("Auto_broadcast set to %s", value)
|
||||
|
||||
async def announce(self, message: Msg) -> None:
|
||||
"""Send a system-wide announcement to all participants.
|
||||
|
||||
Unlike broadcast(), announce() sends a message from the system/host
|
||||
to all participants without requiring prior subscription.
|
||||
|
||||
Args:
|
||||
message: Announcement message (uses Msg format)
|
||||
"""
|
||||
logger.info("System announcement: %s", message.content)
|
||||
await self.broadcast(message)
|
||||
|
||||
async def auto_broadcast(self, message: Msg) -> None:
|
||||
"""Auto-broadcast message to all participants.
|
||||
|
||||
This is called internally when enable_auto_broadcast is True.
|
||||
Broadcasts to all registered participants.
|
||||
|
||||
Args:
|
||||
message: Message to auto-broadcast (uses Msg format)
|
||||
"""
|
||||
if not self._enable_auto_broadcast:
|
||||
return
|
||||
|
||||
# Broadcast to all participants
|
||||
for participant_id in self._participants:
|
||||
if participant_id in self._subscriptions:
|
||||
for callback in self._subscriptions[participant_id]:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(message)
|
||||
else:
|
||||
callback(message)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error auto-broadcasting to %s: %s",
|
||||
participant_id,
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["AgentMessenger"]
|
||||
188
backend/agents/team/registry.py
Normal file
188
backend/agents/team/registry.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""AgentRegistry - Agent registration and lookup by role.
|
||||
|
||||
Provides register(), unregister(), and get_by_role() for agent
|
||||
discovery and management.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
"""Registry for agent instances with role-based lookup.
|
||||
|
||||
Supports:
|
||||
- register(): Add agent with roles
|
||||
- unregister(): Remove agent
|
||||
- get_by_role(): Find agents by role
|
||||
- get_by_id(): Get specific agent
|
||||
|
||||
Each agent can have multiple roles for flexible dispatch.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._agents: Dict[str, Any] = {}
|
||||
self._roles: Dict[str, List[str]] = {}
|
||||
self._agent_roles: Dict[str, List[str]] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
agent_id: str,
|
||||
agent: Any,
|
||||
roles: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Register an agent with optional roles.
|
||||
|
||||
Args:
|
||||
agent_id: Unique agent identifier
|
||||
agent: Agent instance
|
||||
roles: Optional list of role strings
|
||||
"""
|
||||
self._agents[agent_id] = agent
|
||||
self._agent_roles[agent_id] = roles or []
|
||||
|
||||
for role in self._agent_roles[agent_id]:
|
||||
if role not in self._roles:
|
||||
self._roles[role] = []
|
||||
if agent_id not in self._roles[role]:
|
||||
self._roles[role].append(agent_id)
|
||||
|
||||
logger.info(
|
||||
"Registered agent %s with roles %s",
|
||||
agent_id,
|
||||
self._agent_roles[agent_id],
|
||||
)
|
||||
|
||||
def unregister(self, agent_id: str) -> bool:
|
||||
"""Unregister an agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to remove
|
||||
|
||||
Returns:
|
||||
True if agent was removed
|
||||
"""
|
||||
if agent_id not in self._agents:
|
||||
return False
|
||||
|
||||
roles = self._agent_roles.pop(agent_id, [])
|
||||
for role in roles:
|
||||
if role in self._roles:
|
||||
try:
|
||||
self._roles[role].remove(agent_id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
del self._agents[agent_id]
|
||||
logger.info("Unregistered agent: %s", agent_id)
|
||||
return True
|
||||
|
||||
def get_by_id(self, agent_id: str) -> Optional[Any]:
|
||||
"""Get agent by ID.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Agent instance or None
|
||||
"""
|
||||
return self._agents.get(agent_id)
|
||||
|
||||
def get_by_role(self, role: str) -> List[Any]:
|
||||
"""Get all agents with a given role.
|
||||
|
||||
Args:
|
||||
role: Role string to search for
|
||||
|
||||
Returns:
|
||||
List of agent instances with the role
|
||||
"""
|
||||
agent_ids = self._roles.get(role, [])
|
||||
return [self._agents[aid] for aid in agent_ids if aid in self._agents]
|
||||
|
||||
def get_by_roles(self, roles: List[str]) -> List[Any]:
|
||||
"""Get agents matching ANY of the given roles.
|
||||
|
||||
Args:
|
||||
roles: List of role strings
|
||||
|
||||
Returns:
|
||||
List of unique agent instances matching any role
|
||||
"""
|
||||
seen = set()
|
||||
result = []
|
||||
for role in roles:
|
||||
for agent in self.get_by_role(role):
|
||||
if id(agent) not in seen:
|
||||
seen.add(id(agent))
|
||||
result.append(agent)
|
||||
return result
|
||||
|
||||
def list_agents(self) -> List[str]:
|
||||
"""List all registered agent IDs.
|
||||
|
||||
Returns:
|
||||
List of agent identifiers
|
||||
"""
|
||||
return list(self._agents.keys())
|
||||
|
||||
def list_roles(self) -> List[str]:
|
||||
"""List all registered roles.
|
||||
|
||||
Returns:
|
||||
List of role strings
|
||||
"""
|
||||
return list(self._roles.keys())
|
||||
|
||||
def list_roles_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List roles for specific agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
List of role strings
|
||||
"""
|
||||
return list(self._agent_roles.get(agent_id, []))
|
||||
|
||||
def update_roles(self, agent_id: str, roles: List[str]) -> None:
|
||||
"""Update roles for an existing agent.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
roles: New list of roles
|
||||
"""
|
||||
if agent_id not in self._agents:
|
||||
raise KeyError(f"Agent not registered: {agent_id}")
|
||||
|
||||
old_roles = self._agent_roles.get(agent_id, [])
|
||||
for role in old_roles:
|
||||
if role in self._roles:
|
||||
try:
|
||||
self._roles[role].remove(agent_id)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self._agent_roles[agent_id] = roles
|
||||
for role in roles:
|
||||
if role not in self._roles:
|
||||
self._roles[role] = []
|
||||
if agent_id not in self._roles[role]:
|
||||
self._roles[role].append(agent_id)
|
||||
|
||||
logger.info("Updated roles for agent %s: %s", agent_id, roles)
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, Any]:
|
||||
"""Get copy of registered agents dict."""
|
||||
return dict(self._agents)
|
||||
|
||||
|
||||
__all__ = ["AgentRegistry"]
|
||||
620
backend/agents/team/task_delegator.py
Normal file
620
backend/agents/team/task_delegator.py
Normal file
@@ -0,0 +1,620 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TaskDelegator - Subagent spawning and task delegation.
|
||||
|
||||
Provides delegate() and delegate_parallel() for spawning subagents
|
||||
with separate context and memory. Supports runtime dynamic subagent
|
||||
definition via task_data with description, prompt, and tools.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default timeout for subagent execution (seconds)
|
||||
DEFAULT_EXECUTION_TIMEOUT = 120.0
|
||||
|
||||
|
||||
# Type alias for subagent specification
|
||||
SubagentSpec = Dict[str, Any]
|
||||
"""Subagent specification format:
|
||||
{
|
||||
"description": "Expert code reviewer...",
|
||||
"prompt": "Analyze code quality...",
|
||||
"tools": ["Read", "Glob", "Grep"], # Optional: list of tool names
|
||||
"model": "gpt-4o", # Optional: model name
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
class TaskDelegator:
|
||||
"""Delegates tasks to subagents with isolated context.
|
||||
|
||||
Supports:
|
||||
- delegate(): Spawn single subagent for task
|
||||
- delegate_parallel(): Spawn multiple subagents concurrently
|
||||
- delegate_task(): Delegate with dynamic subagent definition from task_data
|
||||
|
||||
Each subagent gets its own memory/context to prevent
|
||||
cross-contamination.
|
||||
|
||||
Dynamic Subagent Definition:
|
||||
task_data can include an "agents" dict to define subagents inline:
|
||||
|
||||
task_data = {
|
||||
"task": "Review the code changes",
|
||||
"agents": {
|
||||
"code-reviewer": {
|
||||
"description": "Expert code reviewer for quality and security.",
|
||||
"prompt": "Analyze code quality and suggest improvements.",
|
||||
"tools": ["Read", "Glob", "Grep"],
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, agent: Any):
|
||||
"""Initialize TaskDelegator.
|
||||
|
||||
Args:
|
||||
agent: Parent EvoAgent instance for accessing model, formatter, workspace
|
||||
"""
|
||||
self._agent = agent
|
||||
# Get messenger from parent agent if available
|
||||
self._messenger = getattr(agent, "messenger", None)
|
||||
self._registry = getattr(agent, "_registry", None)
|
||||
self._subagents: Dict[str, Any] = {}
|
||||
self._dynamic_subagents: Dict[str, SubagentSpec] = {}
|
||||
self._tasks: Dict[str, asyncio.Task] = {}
|
||||
|
||||
# Extract model and formatter from parent agent
|
||||
self._model = getattr(agent, "model", None)
|
||||
self._formatter = getattr(agent, "formatter", None)
|
||||
self._workspace_dir = getattr(agent, "workspace_dir", None)
|
||||
self._config_name = getattr(agent, "config_name", None)
|
||||
|
||||
async def delegate(
|
||||
self,
|
||||
agent_id: str,
|
||||
task: Callable[..., Awaitable[Msg]],
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
) -> asyncio.Task:
|
||||
"""Delegate task to a single subagent.
|
||||
|
||||
Args:
|
||||
agent_id: Unique identifier for this subagent instance
|
||||
task: Async function representing the task
|
||||
context: Optional context dict for the subagent
|
||||
|
||||
Returns:
|
||||
asyncio.Task for the delegated task
|
||||
"""
|
||||
async def _run_with_context():
|
||||
result = await task(context or {})
|
||||
return result
|
||||
|
||||
self._tasks[agent_id] = asyncio.create_task(_run_with_context())
|
||||
logger.info("Delegated task to subagent: %s", agent_id)
|
||||
return self._tasks[agent_id]
|
||||
|
||||
async def delegate_parallel(
|
||||
self,
|
||||
tasks: List[Dict[str, Any]],
|
||||
) -> List[asyncio.Task]:
|
||||
"""Delegate multiple tasks in parallel.
|
||||
|
||||
Args:
|
||||
tasks: List of task dicts with keys:
|
||||
- agent_id: Unique identifier
|
||||
- task: Async function to execute
|
||||
- context: Optional context dict
|
||||
|
||||
Returns:
|
||||
List of asyncio.Task for all delegated tasks
|
||||
"""
|
||||
async def _run_task(task_def: Dict[str, Any]):
|
||||
agent_id = task_def["agent_id"]
|
||||
task_func = task_def["task"]
|
||||
context = task_def.get("context", {})
|
||||
|
||||
async def _run_with_context():
|
||||
return await task_func(context)
|
||||
|
||||
self._tasks[agent_id] = asyncio.create_task(_run_with_context())
|
||||
return self._tasks[agent_id]
|
||||
|
||||
gathered_tasks = await asyncio.gather(
|
||||
*[_run_task(t) for t in tasks],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
valid_tasks = [t for t in gathered_tasks if isinstance(t, asyncio.Task)]
|
||||
logger.info(
|
||||
"Delegated %d tasks in parallel (%d succeeded)",
|
||||
len(tasks),
|
||||
len(valid_tasks),
|
||||
)
|
||||
return valid_tasks
|
||||
|
||||
async def wait_for(self, agent_id: str, timeout: Optional[float] = None) -> Any:
|
||||
"""Wait for subagent task to complete.
|
||||
|
||||
Args:
|
||||
agent_id: Subagent identifier
|
||||
timeout: Optional timeout in seconds
|
||||
|
||||
Returns:
|
||||
Task result
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If task doesn't complete in time
|
||||
KeyError: If agent_id not found
|
||||
"""
|
||||
if agent_id not in self._tasks:
|
||||
raise KeyError(f"Unknown subagent: {agent_id}")
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._tasks[agent_id],
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Task %s timed out after %s seconds", agent_id, timeout)
|
||||
raise
|
||||
|
||||
async def cancel(self, agent_id: str) -> bool:
|
||||
"""Cancel a subagent task.
|
||||
|
||||
Args:
|
||||
agent_id: Subagent identifier
|
||||
|
||||
Returns:
|
||||
True if task was cancelled
|
||||
"""
|
||||
if agent_id in self._tasks:
|
||||
self._tasks[agent_id].cancel()
|
||||
del self._tasks[agent_id]
|
||||
logger.info("Cancelled subagent task: %s", agent_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
def list_tasks(self) -> List[str]:
|
||||
"""List active subagent task IDs.
|
||||
|
||||
Returns:
|
||||
List of agent_ids with pending tasks
|
||||
"""
|
||||
return list(self._tasks.keys())
|
||||
|
||||
@property
|
||||
def tasks(self) -> Dict[str, asyncio.Task]:
|
||||
"""Get copy of active tasks dict."""
|
||||
return dict(self._tasks)
|
||||
|
||||
async def delegate_task(
|
||||
self,
|
||||
task_type: str,
|
||||
task_data: Dict[str, Any],
|
||||
target_agent: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Delegate a task with optional dynamic subagent definition.
|
||||
|
||||
Supports runtime subagent definition via task_data["agents"]:
|
||||
|
||||
task_data = {
|
||||
"task": "Review code changes",
|
||||
"agents": {
|
||||
"code-reviewer": {
|
||||
"description": "Expert code reviewer...",
|
||||
"prompt": "Analyze code quality...",
|
||||
"tools": ["Read", "Glob", "Grep"],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Args:
|
||||
task_type: Type of task (e.g., "analysis", "review", "research")
|
||||
task_data: Task payload, may include "agents" for dynamic subagent def
|
||||
target_agent: Optional specific agent ID to delegate to
|
||||
|
||||
Returns:
|
||||
Dict with "success" and result/error
|
||||
"""
|
||||
try:
|
||||
# Extract dynamic subagent definitions from task_data
|
||||
agents_def = task_data.get("agents", {})
|
||||
|
||||
if agents_def:
|
||||
# Register dynamic subagents
|
||||
for agent_name, agent_spec in agents_def.items():
|
||||
self._dynamic_subagents[agent_name] = agent_spec
|
||||
logger.info(
|
||||
"Registered dynamic subagent: %s (description: %s)",
|
||||
agent_name,
|
||||
agent_spec.get("description", "")[:50],
|
||||
)
|
||||
|
||||
# Determine target agent
|
||||
effective_target = target_agent
|
||||
if not effective_target:
|
||||
# Use first available dynamic subagent or default
|
||||
if agents_def:
|
||||
effective_target = next(iter(agents_def.keys()))
|
||||
else:
|
||||
effective_target = "default"
|
||||
|
||||
# Execute the task (async)
|
||||
task_result = await self._execute_task(
|
||||
task_type=task_type,
|
||||
task_data=task_data,
|
||||
target_agent=effective_target,
|
||||
)
|
||||
|
||||
# Clean up dynamic subagents after execution
|
||||
for agent_name in agents_def.keys():
|
||||
self._dynamic_subagents.pop(agent_name, None)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": task_result,
|
||||
"subagents_used": list(agents_def.keys()) if agents_def else [],
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Task delegation failed: %s", e)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def _execute_task(
|
||||
self,
|
||||
task_type: str,
|
||||
task_data: Dict[str, Any],
|
||||
target_agent: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute the delegated task with a real subagent.
|
||||
|
||||
Args:
|
||||
task_type: Type of task
|
||||
task_data: Task payload
|
||||
target_agent: Target agent identifier
|
||||
|
||||
Returns:
|
||||
Task execution result with success/failure info
|
||||
"""
|
||||
task_content = task_data.get("task", task_data.get("prompt", ""))
|
||||
timeout = task_data.get("timeout", DEFAULT_EXECUTION_TIMEOUT)
|
||||
|
||||
# Check if we have a dynamic subagent spec for this target
|
||||
agent_spec = self._dynamic_subagents.get(target_agent)
|
||||
|
||||
if agent_spec:
|
||||
logger.info(
|
||||
"Executing task '%s' with dynamic subagent '%s'",
|
||||
task_type,
|
||||
target_agent,
|
||||
)
|
||||
return await self._create_and_run_subagent(
|
||||
agent_name=target_agent,
|
||||
agent_spec=agent_spec,
|
||||
task_content=task_content,
|
||||
task_type=task_type,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Fallback: try to use parent agent's model to process the task directly
|
||||
logger.info(
|
||||
"Executing task '%s' with parent agent '%s' (no dynamic subagent)",
|
||||
task_type,
|
||||
target_agent,
|
||||
)
|
||||
return await self._run_with_parent_agent(
|
||||
task_content=task_content,
|
||||
task_type=task_type,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
async def _create_and_run_subagent(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_spec: SubagentSpec,
|
||||
task_content: str,
|
||||
task_type: str,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create and run a dynamic subagent.
|
||||
|
||||
Args:
|
||||
agent_name: Name identifier for the subagent
|
||||
agent_spec: Subagent specification (description, prompt, tools, model)
|
||||
task_content: Task prompt to send to the subagent
|
||||
task_type: Type of task
|
||||
timeout: Execution timeout in seconds
|
||||
|
||||
Returns:
|
||||
Dict with execution results
|
||||
"""
|
||||
subagent_id = f"subagent_{agent_name}_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
try:
|
||||
# Create subagent instance
|
||||
subagent = await self._create_subagent(
|
||||
subagent_id=subagent_id,
|
||||
agent_spec=agent_spec,
|
||||
)
|
||||
|
||||
if subagent is None:
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"subagent": agent_name,
|
||||
"status": "failed",
|
||||
"error": "Failed to create subagent",
|
||||
"message": f"Could not instantiate subagent '{agent_name}'",
|
||||
}
|
||||
|
||||
# Store for potential cleanup
|
||||
self._subagents[subagent_id] = subagent
|
||||
|
||||
# Execute with timeout
|
||||
result = await asyncio.wait_for(
|
||||
self._run_subagent(subagent, task_content),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Extract response content
|
||||
response_content = ""
|
||||
if isinstance(result, Msg):
|
||||
response_content = result.content
|
||||
elif hasattr(result, "content"):
|
||||
response_content = str(result.content)
|
||||
elif isinstance(result, dict):
|
||||
response_content = result.get("content", str(result))
|
||||
else:
|
||||
response_content = str(result)
|
||||
|
||||
logger.info(
|
||||
"Subagent '%s' completed task '%s' successfully",
|
||||
agent_name,
|
||||
task_type,
|
||||
)
|
||||
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"subagent": {
|
||||
"name": agent_name,
|
||||
"id": subagent_id,
|
||||
"description": agent_spec.get("description", ""),
|
||||
},
|
||||
"status": "completed",
|
||||
"response": response_content,
|
||||
"message": f"Task '{task_type}' executed with subagent '{agent_name}'",
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Subagent '%s' timed out after %.1f seconds for task '%s'",
|
||||
agent_name,
|
||||
timeout,
|
||||
task_type,
|
||||
)
|
||||
# Cancel the task if still running
|
||||
if subagent_id in self._subagents:
|
||||
self._subagents.pop(subagent_id, None)
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"subagent": agent_name,
|
||||
"status": "timeout",
|
||||
"error": f"Execution timed out after {timeout} seconds",
|
||||
"message": f"Task '{task_type}' timed out for subagent '{agent_name}'",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Subagent '%s' failed for task '%s': %s",
|
||||
agent_name,
|
||||
task_type,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
# Cleanup on failure
|
||||
if subagent_id in self._subagents:
|
||||
self._subagents.pop(subagent_id, None)
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"subagent": agent_name,
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": f"Task '{task_type}' failed for subagent '{agent_name}': {e}",
|
||||
}
|
||||
|
||||
async def _create_subagent(
|
||||
self,
|
||||
subagent_id: str,
|
||||
agent_spec: SubagentSpec,
|
||||
) -> Optional[Any]:
|
||||
"""Create a subagent instance.
|
||||
|
||||
Uses the parent agent's model/formatter to create a lightweight
|
||||
subagent for task execution.
|
||||
|
||||
Args:
|
||||
subagent_id: Unique identifier for the subagent
|
||||
agent_spec: Subagent specification
|
||||
|
||||
Returns:
|
||||
Subagent instance or None if creation fails
|
||||
"""
|
||||
try:
|
||||
# Import here to avoid circular imports
|
||||
from agentscope.memory import InMemoryMemory
|
||||
|
||||
# Get model and formatter from parent
|
||||
model = self._model
|
||||
formatter = self._formatter
|
||||
|
||||
if model is None:
|
||||
logger.error("Cannot create subagent: parent agent has no model")
|
||||
return None
|
||||
|
||||
# Build system prompt from agent spec
|
||||
description = agent_spec.get("description", "")
|
||||
prompt_template = agent_spec.get("prompt", "")
|
||||
system_prompt = f"""You are {description}
|
||||
|
||||
{prompt_template}
|
||||
|
||||
Your task is to complete the user's request below.
|
||||
"""
|
||||
|
||||
# Create a minimal ReActAgent as the subagent
|
||||
from agentscope.agent import ReActAgent
|
||||
|
||||
subagent = ReActAgent(
|
||||
name=subagent_id,
|
||||
model=model,
|
||||
sys_prompt=system_prompt,
|
||||
toolkit=None, # Could load tools from agent_spec.get("tools", [])
|
||||
memory=InMemoryMemory(),
|
||||
formatter=formatter,
|
||||
max_iters=agent_spec.get("max_iters", 5),
|
||||
)
|
||||
|
||||
logger.debug("Created subagent: %s", subagent_id)
|
||||
return subagent
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create subagent '%s': %s",
|
||||
subagent_id,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _run_subagent(
|
||||
self,
|
||||
subagent: Any,
|
||||
task_content: str,
|
||||
) -> Any:
|
||||
"""Run a subagent with the given task.
|
||||
|
||||
Args:
|
||||
subagent: Subagent instance
|
||||
task_content: Task prompt
|
||||
|
||||
Returns:
|
||||
Agent response (Msg or similar)
|
||||
"""
|
||||
from agentscope.message import Msg
|
||||
|
||||
# Create message for the subagent
|
||||
task_msg = Msg(
|
||||
name="user",
|
||||
content=task_content,
|
||||
role="user",
|
||||
)
|
||||
|
||||
# Execute the agent
|
||||
response = await subagent.reply(task_msg)
|
||||
return response
|
||||
|
||||
async def _run_with_parent_agent(
|
||||
self,
|
||||
task_content: str,
|
||||
task_type: str,
|
||||
timeout: float,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run task using the parent agent directly.
|
||||
|
||||
Used when no dynamic subagent is defined.
|
||||
|
||||
Args:
|
||||
task_content: Task prompt
|
||||
task_type: Type of task
|
||||
timeout: Execution timeout
|
||||
|
||||
Returns:
|
||||
Dict with execution results
|
||||
"""
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
self._agent.reply(Msg(
|
||||
name="user",
|
||||
content=task_content,
|
||||
role="user",
|
||||
)),
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
response_content = ""
|
||||
if isinstance(result, Msg):
|
||||
response_content = result.content
|
||||
elif hasattr(result, "content"):
|
||||
response_content = str(result.content)
|
||||
else:
|
||||
response_content = str(result)
|
||||
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"status": "completed",
|
||||
"response": response_content,
|
||||
"message": f"Task '{task_type}' executed with parent agent",
|
||||
}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"status": "timeout",
|
||||
"error": f"Execution timed out after {timeout} seconds",
|
||||
"message": f"Task '{task_type}' timed out",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Parent agent failed for task '%s': %s",
|
||||
task_type,
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
return {
|
||||
"task_type": task_type,
|
||||
"task": task_content,
|
||||
"status": "error",
|
||||
"error": str(e),
|
||||
"message": f"Task '{task_type}' failed: {e}",
|
||||
}
|
||||
|
||||
def get_dynamic_subagent(self, name: str) -> Optional[SubagentSpec]:
|
||||
"""Get a dynamically defined subagent specification.
|
||||
|
||||
Args:
|
||||
name: Subagent name
|
||||
|
||||
Returns:
|
||||
Subagent spec dict or None if not found
|
||||
"""
|
||||
return self._dynamic_subagents.get(name)
|
||||
|
||||
def list_dynamic_subagents(self) -> List[str]:
|
||||
"""List all registered dynamic subagent names.
|
||||
|
||||
Returns:
|
||||
List of subagent names
|
||||
"""
|
||||
return list(self._dynamic_subagents.keys())
|
||||
|
||||
|
||||
__all__ = ["TaskDelegator", "SubagentSpec"]
|
||||
389
backend/agents/team/team_coordinator.py
Normal file
389
backend/agents/team/team_coordinator.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""TeamCoordinator - Agent lifecycle management and execution.
|
||||
|
||||
Provides run_parallel() using asyncio.gather() and run_sequential()
|
||||
for coordinating multiple agents.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TeamCoordinator:
|
||||
"""Coordinates agent lifecycle and execution.
|
||||
|
||||
Supports:
|
||||
- run_parallel(): Execute multiple agents concurrently with asyncio.gather()
|
||||
- run_sequential(): Execute agents one after another
|
||||
- run_phase(): Execute a named phase with registered agents
|
||||
- register_agent(): Add agent to coordinator
|
||||
- unregister_agent(): Remove agent from coordinator
|
||||
|
||||
Each agent maintains separate context/memory.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
participants: Optional[List[Any]] = None,
|
||||
task_content: Optional[str] = None,
|
||||
messenger: Optional[Any] = None,
|
||||
registry: Optional[Any] = None,
|
||||
):
|
||||
"""Initialize TeamCoordinator.
|
||||
|
||||
Args:
|
||||
participants: List of agent instances to coordinate
|
||||
task_content: Task description content for the agents
|
||||
messenger: AgentMessenger for communication (optional)
|
||||
registry: AgentRegistry for agent lookup (optional)
|
||||
"""
|
||||
self._participants = participants or []
|
||||
self._task_content = task_content or ""
|
||||
self._messenger = messenger
|
||||
self._registry = registry
|
||||
self._agents: Dict[str, Any] = {}
|
||||
self._running_tasks: Dict[str, asyncio.Task] = {}
|
||||
# Auto-register participants
|
||||
for agent in self._participants:
|
||||
if hasattr(agent, "name"):
|
||||
self._agents[agent.name] = agent
|
||||
elif hasattr(agent, "id"):
|
||||
self._agents[agent.id] = agent
|
||||
|
||||
def register_agent(self, agent_id: str, agent: Any) -> None:
|
||||
"""Register an agent with the coordinator.
|
||||
|
||||
Args:
|
||||
agent_id: Unique agent identifier
|
||||
agent: Agent instance
|
||||
"""
|
||||
self._agents[agent_id] = agent
|
||||
logger.info("Registered agent: %s", agent_id)
|
||||
|
||||
def unregister_agent(self, agent_id: str) -> None:
|
||||
"""Unregister an agent from the coordinator.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier to remove
|
||||
"""
|
||||
if agent_id in self._agents:
|
||||
del self._agents[agent_id]
|
||||
logger.info("Unregistered agent: %s", agent_id)
|
||||
|
||||
def get_agent(self, agent_id: str) -> Any:
|
||||
"""Get registered agent by ID.
|
||||
|
||||
Args:
|
||||
agent_id: Agent identifier
|
||||
|
||||
Returns:
|
||||
Agent instance
|
||||
"""
|
||||
return self._agents.get(agent_id)
|
||||
|
||||
def list_agents(self) -> List[str]:
|
||||
"""List all registered agent IDs.
|
||||
|
||||
Returns:
|
||||
List of agent identifiers
|
||||
"""
|
||||
return list(self._agents.keys())
|
||||
|
||||
async def run_parallel(
|
||||
self,
|
||||
agent_ids: List[str],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run multiple agents in parallel using asyncio.gather().
|
||||
|
||||
Args:
|
||||
agent_ids: List of agent IDs to run concurrently
|
||||
initial_message: Optional initial message to broadcast
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
async def _run_agent(aid: str) -> tuple[str, Any]:
|
||||
agent = self._agents.get(aid)
|
||||
if agent is None:
|
||||
logger.error("Agent %s not found", aid)
|
||||
return (aid, None)
|
||||
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
if initial_message:
|
||||
result = await agent.reply(initial_message)
|
||||
else:
|
||||
result = await agent.reply()
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
logger.info("Agent %s completed successfully", aid)
|
||||
return (aid, result)
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed: %s", aid, e)
|
||||
return (aid, {"error": str(e)})
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[_run_agent(aid) for aid in agent_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
output: Dict[str, Any] = {}
|
||||
for result in results:
|
||||
if isinstance(result, tuple):
|
||||
agent_id, agent_result = result
|
||||
output[agent_id] = agent_result
|
||||
else:
|
||||
logger.error("Unexpected result from asyncio.gather: %s", result)
|
||||
|
||||
logger.info("Parallel run completed for %d agents", len(agent_ids))
|
||||
return output
|
||||
|
||||
async def run_sequential(
|
||||
self,
|
||||
agent_ids: List[str],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run agents one after another in order.
|
||||
|
||||
Args:
|
||||
agent_ids: List of agent IDs to run in sequence
|
||||
initial_message: Optional initial message for first agent
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
output: Dict[str, Any] = {}
|
||||
current_message = initial_message
|
||||
|
||||
for agent_id in agent_ids:
|
||||
agent = self._agents.get(agent_id)
|
||||
if agent is None:
|
||||
logger.error("Agent %s not found", agent_id)
|
||||
output[agent_id] = {"error": "Agent not found"}
|
||||
continue
|
||||
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
result = await agent.reply(current_message)
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
|
||||
output[agent_id] = result
|
||||
current_message = result
|
||||
logger.info("Agent %s completed sequentially", agent_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed: %s", agent_id, e)
|
||||
output[agent_id] = {"error": str(e)}
|
||||
break
|
||||
|
||||
logger.info("Sequential run completed for %d agents", len(agent_ids))
|
||||
return output
|
||||
|
||||
async def run_phase(
|
||||
self,
|
||||
phase_name: str,
|
||||
agent_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Any]:
|
||||
"""Execute a named phase with registered agents.
|
||||
|
||||
Args:
|
||||
phase_name: Name of the phase (e.g., "analyst_analysis")
|
||||
agent_ids: Optional list of agent IDs; if None, uses all registered
|
||||
metadata: Optional metadata to include in the message (e.g., tickers, date)
|
||||
|
||||
Returns:
|
||||
List of results from each agent
|
||||
"""
|
||||
if agent_ids is None:
|
||||
agent_ids = list(self._agents.keys())
|
||||
|
||||
_agent_ids = [aid for aid in agent_ids if aid in self._agents]
|
||||
|
||||
logger.info(
|
||||
"Running phase '%s' with %d agents: %s",
|
||||
phase_name,
|
||||
len(_agent_ids),
|
||||
_agent_ids,
|
||||
)
|
||||
|
||||
# Create messages for each agent
|
||||
results: List[Any] = []
|
||||
for agent_id in _agent_ids:
|
||||
agent = self._agents[agent_id]
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
# Create a message for the agent with proper structure
|
||||
msg = Msg(
|
||||
name="system",
|
||||
content=self._task_content or f"Please execute phase: {phase_name}",
|
||||
role="user",
|
||||
metadata=metadata,
|
||||
)
|
||||
result = await agent.reply(msg)
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
results.append(result)
|
||||
logger.info("Phase '%s': Agent %s completed", phase_name, agent_id)
|
||||
except Exception as e:
|
||||
logger.error("Phase '%s': Agent %s failed: %s", phase_name, agent_id, e)
|
||||
results.append(None)
|
||||
|
||||
logger.info("Phase '%s' completed with %d results", phase_name, len(results))
|
||||
return results
|
||||
|
||||
async def run_with_dependencies(
|
||||
self,
|
||||
agent_tasks: Dict[str, List[str]],
|
||||
initial_message: Optional[Msg] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Run agents respecting dependency graph.
|
||||
|
||||
Args:
|
||||
agent_tasks: Dict mapping agent_id to list of prerequisite agent_ids
|
||||
initial_message: Optional initial message
|
||||
|
||||
Returns:
|
||||
Dict mapping agent_id to result
|
||||
"""
|
||||
completed: Dict[str, Any] = {}
|
||||
remaining = set(agent_tasks.keys())
|
||||
|
||||
while remaining:
|
||||
ready = [
|
||||
aid for aid in remaining
|
||||
if all(dep in completed for dep in agent_tasks.get(aid, []))
|
||||
]
|
||||
|
||||
if not ready:
|
||||
logger.error("Circular dependency detected in agent tasks")
|
||||
for aid in remaining:
|
||||
completed[aid] = {"error": "Circular dependency"}
|
||||
break
|
||||
|
||||
results = await self.run_parallel(ready, initial_message)
|
||||
completed.update(results)
|
||||
|
||||
for aid in ready:
|
||||
remaining.discard(aid)
|
||||
initial_message = results.get(aid)
|
||||
|
||||
return completed
|
||||
|
||||
async def fanout_pipeline(
|
||||
self,
|
||||
agents: List[Any],
|
||||
msg: Optional[Msg] = None,
|
||||
) -> List[Msg]:
|
||||
"""Fanout a message to multiple agents concurrently and collect all responses.
|
||||
|
||||
Similar to AgentScope's fanout_pipeline, this sends the same message
|
||||
to all specified agents and returns a list of all agent responses.
|
||||
|
||||
Args:
|
||||
agents: List of agent instances to fanout the message to
|
||||
msg: Message to send to all agents (optional)
|
||||
|
||||
Returns:
|
||||
List of Msg responses from each agent (in the same order as input agents)
|
||||
|
||||
Example:
|
||||
>>> responses = await fanout_pipeline(
|
||||
... agents=[alice, bob, charlie],
|
||||
... msg=question,
|
||||
... )
|
||||
>>> # responses is a list of Msg responses from each agent
|
||||
"""
|
||||
async def _fanout_to_agent(agent: Any) -> Optional[Msg]:
|
||||
"""Send message to a single agent and return its response."""
|
||||
try:
|
||||
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
|
||||
result = await agent.reply(msg) if msg is not None else await agent.reply()
|
||||
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
|
||||
result = await agent.run()
|
||||
else:
|
||||
result = await agent()
|
||||
|
||||
# Convert result to Msg if needed
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, Msg):
|
||||
return result
|
||||
# If result is a dict with content, wrap it
|
||||
if isinstance(result, dict) and "content" in result:
|
||||
return Msg(
|
||||
name=getattr(agent, "name", "unknown"),
|
||||
content=result.get("content", ""),
|
||||
role="assistant",
|
||||
metadata=result.get("metadata"),
|
||||
)
|
||||
# Otherwise wrap the result
|
||||
return Msg(
|
||||
name=getattr(agent, "name", "unknown"),
|
||||
content=str(result),
|
||||
role="assistant",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Agent %s failed in fanout_pipeline: %s",
|
||||
getattr(agent, "name", "unknown"), e)
|
||||
return None
|
||||
|
||||
# Run all agents concurrently
|
||||
results = await asyncio.gather(
|
||||
*[_fanout_to_agent(agent) for agent in agents],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
# Filter out exceptions and keep only valid responses
|
||||
responses: List[Msg] = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error("Fanout to agent %d failed: %s", i, result)
|
||||
responses.append(None) # type: ignore[arg-type]
|
||||
else:
|
||||
responses.append(result) # type: ignore[arg-type]
|
||||
|
||||
logger.info("Fanout pipeline completed for %d agents", len(agents))
|
||||
return responses
|
||||
|
||||
async def shutdown(self, timeout: Optional[float] = 5.0) -> None:
|
||||
"""Shutdown all running agents gracefully.
|
||||
|
||||
Args:
|
||||
timeout: Timeout for graceful shutdown
|
||||
"""
|
||||
logger.info("Shutting down TeamCoordinator...")
|
||||
|
||||
cancel_tasks = [
|
||||
asyncio.create_task(asyncio.wait_for(task, timeout=timeout))
|
||||
for task in self._running_tasks.values()
|
||||
]
|
||||
|
||||
if cancel_tasks:
|
||||
await asyncio.gather(*cancel_tasks, return_exceptions=True)
|
||||
|
||||
self._running_tasks.clear()
|
||||
logger.info("TeamCoordinator shutdown complete")
|
||||
|
||||
@property
|
||||
def agents(self) -> Dict[str, Any]:
|
||||
"""Get copy of registered agents dict."""
|
||||
return dict(self._agents)
|
||||
|
||||
|
||||
__all__ = ["TeamCoordinator"]
|
||||
132
backend/agents/team_pipeline_config.py
Normal file
132
backend/agents/team_pipeline_config.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Run-scoped team pipeline configuration helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Dict, Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
DEFAULT_FILENAME = "TEAM_PIPELINE.yaml"
|
||||
|
||||
|
||||
def team_pipeline_path(project_root: Path, config_name: str) -> Path:
|
||||
"""Return run-scoped team pipeline config path."""
|
||||
return project_root / "runs" / config_name / DEFAULT_FILENAME
|
||||
|
||||
|
||||
def ensure_team_pipeline_config(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
default_analysts: Iterable[str],
|
||||
) -> Path:
|
||||
"""Ensure TEAM_PIPELINE.yaml exists for one run."""
|
||||
path = team_pipeline_path(project_root, config_name)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
if path.exists():
|
||||
return path
|
||||
|
||||
payload = {
|
||||
"version": 1,
|
||||
"controller_agent": "portfolio_manager",
|
||||
"discussion": {
|
||||
"allow_dynamic_team_update": True,
|
||||
"active_analysts": list(default_analysts),
|
||||
},
|
||||
"decision": {
|
||||
"require_risk_manager": True,
|
||||
},
|
||||
}
|
||||
path.write_text(
|
||||
yaml.safe_dump(payload, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def load_team_pipeline_config(project_root: Path, config_name: str) -> Dict[str, Any]:
|
||||
"""Load TEAM_PIPELINE.yaml and return parsed dict."""
|
||||
path = team_pipeline_path(project_root, config_name)
|
||||
if not path.exists():
|
||||
return {}
|
||||
parsed = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
def save_team_pipeline_config(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
config: Dict[str, Any],
|
||||
) -> Path:
|
||||
"""Persist TEAM_PIPELINE.yaml."""
|
||||
path = team_pipeline_path(project_root, config_name)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
yaml.safe_dump(config, allow_unicode=True, sort_keys=False),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def resolve_active_analysts(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
available_analysts: Iterable[str],
|
||||
) -> List[str]:
|
||||
"""Resolve active analysts from TEAM_PIPELINE.yaml."""
|
||||
available = [item for item in available_analysts]
|
||||
parsed = load_team_pipeline_config(project_root, config_name)
|
||||
discussion = parsed.get("discussion", {}) if isinstance(parsed, dict) else {}
|
||||
configured = discussion.get("active_analysts", [])
|
||||
if not isinstance(configured, list) or not configured:
|
||||
return available
|
||||
|
||||
active = [item for item in configured if item in available]
|
||||
return active or available
|
||||
|
||||
|
||||
def update_active_analysts(
|
||||
project_root: Path,
|
||||
config_name: str,
|
||||
available_analysts: Iterable[str],
|
||||
*,
|
||||
add: Iterable[str] | None = None,
|
||||
remove: Iterable[str] | None = None,
|
||||
set_to: Iterable[str] | None = None,
|
||||
) -> List[str]:
|
||||
"""Update active analysts and persist TEAM_PIPELINE.yaml."""
|
||||
available = [item for item in available_analysts]
|
||||
ensure_team_pipeline_config(project_root, config_name, available)
|
||||
parsed = load_team_pipeline_config(project_root, config_name)
|
||||
discussion = parsed.setdefault("discussion", {})
|
||||
if not isinstance(discussion, dict):
|
||||
discussion = {}
|
||||
parsed["discussion"] = discussion
|
||||
|
||||
current = discussion.get("active_analysts", [])
|
||||
if not isinstance(current, list):
|
||||
current = []
|
||||
current = [item for item in current if item in available]
|
||||
if not current:
|
||||
current = list(available)
|
||||
|
||||
if set_to is not None:
|
||||
target = [item for item in set_to if item in available]
|
||||
current = target or current
|
||||
|
||||
for item in add or []:
|
||||
if item in available and item not in current:
|
||||
current.append(item)
|
||||
|
||||
for item in remove or []:
|
||||
current = [existing for existing in current if existing != item]
|
||||
|
||||
if not current:
|
||||
current = [available[0]] if available else []
|
||||
|
||||
discussion["active_analysts"] = current
|
||||
save_team_pipeline_config(project_root, config_name, parsed)
|
||||
return current
|
||||
|
||||
@@ -129,6 +129,33 @@ class RunWorkspaceManager:
|
||||
)
|
||||
return asset_dir
|
||||
|
||||
def load_agent_file(
|
||||
self,
|
||||
*,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
) -> str:
|
||||
"""Load one run-scoped agent workspace file."""
|
||||
path = self.get_agent_asset_dir(config_name, agent_id) / filename
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {filename}")
|
||||
return path.read_text(encoding="utf-8")
|
||||
|
||||
def update_agent_file(
|
||||
self,
|
||||
*,
|
||||
config_name: str,
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
content: str,
|
||||
) -> None:
|
||||
"""Write one run-scoped agent workspace file."""
|
||||
asset_dir = self.get_agent_asset_dir(config_name, agent_id)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = asset_dir / filename
|
||||
path.write_text(content, encoding="utf-8")
|
||||
|
||||
def initialize_default_assets(
|
||||
self,
|
||||
config_name: str,
|
||||
|
||||
@@ -13,8 +13,13 @@ from typing import Any, Dict, List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Body, UploadFile, File, Form
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||
from backend.agents import AgentFactory, get_registry
|
||||
from backend.agents.workspace_manager import RunWorkspaceManager
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import load_agent_profiles
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
from backend.llm.models import get_agent_model_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -47,6 +52,14 @@ class InstallExternalSkillRequest(BaseModel):
|
||||
activate: bool = Field(True, description="Whether to enable skill immediately")
|
||||
|
||||
|
||||
class LocalSkillRequest(BaseModel):
|
||||
skill_name: str = Field(..., description="Local skill name")
|
||||
|
||||
|
||||
class LocalSkillContentRequest(BaseModel):
|
||||
content: str = Field(..., description="Updated SKILL.md content")
|
||||
|
||||
|
||||
class AgentResponse(BaseModel):
|
||||
"""Agent information response."""
|
||||
agent_id: str
|
||||
@@ -63,6 +76,24 @@ class AgentFileResponse(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class AgentProfileResponse(BaseModel):
|
||||
agent_id: str
|
||||
workspace_id: str
|
||||
profile: Dict[str, Any]
|
||||
|
||||
|
||||
class AgentSkillsResponse(BaseModel):
|
||||
agent_id: str
|
||||
workspace_id: str
|
||||
skills: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class SkillDetailResponse(BaseModel):
|
||||
agent_id: str
|
||||
workspace_id: str
|
||||
skill: Dict[str, Any]
|
||||
|
||||
|
||||
# Dependencies
|
||||
def get_agent_factory():
|
||||
"""Get AgentFactory instance."""
|
||||
@@ -70,8 +101,8 @@ def get_agent_factory():
|
||||
|
||||
|
||||
def get_workspace_manager():
|
||||
"""Get WorkspaceManager instance."""
|
||||
return WorkspaceManager()
|
||||
"""Get run-scoped workspace manager instance."""
|
||||
return RunWorkspaceManager()
|
||||
|
||||
|
||||
def get_skills_manager():
|
||||
@@ -199,6 +230,108 @@ async def get_agent(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/profile", response_model=AgentProfileResponse)
|
||||
async def get_agent_profile(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skills_manager: SkillsManager = Depends(get_skills_manager),
|
||||
):
|
||||
asset_dir = skills_manager.get_agent_asset_dir(workspace_id, agent_id)
|
||||
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
|
||||
profiles = load_agent_profiles()
|
||||
profile = profiles.get(agent_id, {})
|
||||
bootstrap = get_bootstrap_config_for_run(skills_manager.project_root, workspace_id)
|
||||
override = bootstrap.agent_override(agent_id)
|
||||
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
|
||||
if not isinstance(active_tool_groups, list):
|
||||
active_tool_groups = []
|
||||
disabled_tool_groups = agent_config.disabled_tool_groups
|
||||
if disabled_tool_groups:
|
||||
disabled_set = set(disabled_tool_groups)
|
||||
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
|
||||
|
||||
default_skills = profile.get("skills", [])
|
||||
if not isinstance(default_skills, list):
|
||||
default_skills = []
|
||||
resolved_skills = skills_manager.resolve_agent_skill_names(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
default_skills=default_skills,
|
||||
)
|
||||
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
|
||||
model_name, model_provider = get_agent_model_info(agent_id)
|
||||
|
||||
return AgentProfileResponse(
|
||||
agent_id=agent_id,
|
||||
workspace_id=workspace_id,
|
||||
profile={
|
||||
"model_name": model_name,
|
||||
"model_provider": model_provider,
|
||||
"prompt_files": prompt_files,
|
||||
"default_skills": default_skills,
|
||||
"resolved_skills": resolved_skills,
|
||||
"active_tool_groups": active_tool_groups,
|
||||
"disabled_tool_groups": disabled_tool_groups,
|
||||
"enabled_skills": agent_config.enabled_skills,
|
||||
"disabled_skills": agent_config.disabled_skills,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/skills", response_model=AgentSkillsResponse)
|
||||
async def get_agent_skills(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skills_manager: SkillsManager = Depends(get_skills_manager),
|
||||
):
|
||||
agent_asset_dir = skills_manager.get_agent_asset_dir(workspace_id, agent_id)
|
||||
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
|
||||
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=workspace_id, agent_id=agent_id, default_skills=[]))
|
||||
enabled = set(agent_config.enabled_skills)
|
||||
disabled = set(agent_config.disabled_skills)
|
||||
|
||||
payload = []
|
||||
for item in skills_manager.list_agent_skill_catalog(workspace_id, agent_id):
|
||||
if item.skill_name in disabled:
|
||||
status = "disabled"
|
||||
elif item.skill_name in enabled:
|
||||
status = "enabled"
|
||||
elif item.skill_name in resolved_skills:
|
||||
status = "active"
|
||||
else:
|
||||
status = "available"
|
||||
payload.append({
|
||||
"skill_name": item.skill_name,
|
||||
"name": item.name,
|
||||
"description": item.description,
|
||||
"version": item.version,
|
||||
"source": item.source,
|
||||
"tools": item.tools,
|
||||
"status": status,
|
||||
})
|
||||
|
||||
return AgentSkillsResponse(agent_id=agent_id, workspace_id=workspace_id, skills=payload)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/skills/{skill_name}", response_model=SkillDetailResponse)
|
||||
async def get_agent_skill_detail(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
skills_manager: SkillsManager = Depends(get_skills_manager),
|
||||
):
|
||||
try:
|
||||
detail = skills_manager.load_agent_skill_document(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_name=skill_name,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
raise HTTPException(status_code=404, detail=f"Unknown skill: {skill_name}")
|
||||
|
||||
return SkillDetailResponse(agent_id=agent_id, workspace_id=workspace_id, skill=detail)
|
||||
|
||||
|
||||
@router.delete("/{agent_id}")
|
||||
async def delete_agent(
|
||||
workspace_id: str,
|
||||
@@ -386,6 +519,85 @@ async def install_external_skill(
|
||||
}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/local")
|
||||
async def create_local_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
request: LocalSkillRequest,
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
try:
|
||||
skills_manager.create_agent_local_skill(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_name=request.skill_name,
|
||||
)
|
||||
except (ValueError, FileExistsError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
return {"message": f"Created local skill '{request.skill_name}' for '{agent_id}'"}
|
||||
|
||||
|
||||
@router.put("/{agent_id}/skills/local/{skill_name}")
|
||||
async def update_local_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
request: LocalSkillContentRequest,
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
try:
|
||||
skills_manager.update_agent_local_skill(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_name=skill_name,
|
||||
content=request.content,
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
return {"message": f"Updated local skill '{skill_name}' for '{agent_id}'"}
|
||||
|
||||
|
||||
@router.delete("/{agent_id}/skills/local/{skill_name}")
|
||||
async def delete_local_skill(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
skill_name: str,
|
||||
registry=Depends(get_registry),
|
||||
):
|
||||
agent_info = registry.get(agent_id)
|
||||
if not agent_info or agent_info.workspace_id != workspace_id:
|
||||
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
|
||||
|
||||
skills_manager = SkillsManager()
|
||||
try:
|
||||
skills_manager.delete_agent_local_skill(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_name=skill_name,
|
||||
)
|
||||
skills_manager.forget_agent_skill_overrides(
|
||||
config_name=workspace_id,
|
||||
agent_id=agent_id,
|
||||
skill_names=[skill_name],
|
||||
)
|
||||
except (ValueError, FileNotFoundError) as exc:
|
||||
raise HTTPException(status_code=400, detail=str(exc))
|
||||
|
||||
return {"message": f"Deleted local skill '{skill_name}' for '{agent_id}'"}
|
||||
|
||||
|
||||
@router.post("/{agent_id}/skills/upload")
|
||||
async def upload_external_skill(
|
||||
workspace_id: str,
|
||||
@@ -441,7 +653,7 @@ async def get_agent_file(
|
||||
workspace_id: str,
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
workspace_manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||
workspace_manager: RunWorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
Read an agent's workspace file.
|
||||
@@ -471,7 +683,7 @@ async def update_agent_file(
|
||||
agent_id: str,
|
||||
filename: str,
|
||||
content: str = Body(..., media_type="text/plain"),
|
||||
workspace_manager: WorkspaceManager = Depends(get_workspace_manager),
|
||||
workspace_manager: RunWorkspaceManager = Depends(get_workspace_manager),
|
||||
):
|
||||
"""
|
||||
Update an agent's workspace file.
|
||||
|
||||
@@ -8,6 +8,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
@@ -16,20 +17,124 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from backend.runtime.agent_runtime import AgentRuntimeState
|
||||
from backend.runtime.manager import TradingRuntimeManager, get_global_runtime_manager
|
||||
from backend.config.bootstrap_config import (
|
||||
resolve_runtime_config,
|
||||
update_bootstrap_values_for_run,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/runtime", tags=["runtime"])
|
||||
|
||||
runtime_manager: Optional[TradingRuntimeManager] = None
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
|
||||
# Gateway process management
|
||||
_gateway_process: Optional[subprocess.Popen] = None
|
||||
_gateway_port: int = 8765
|
||||
|
||||
class RuntimeState:
|
||||
"""Thread-safe singleton for managing runtime state.
|
||||
|
||||
Encapsulates runtime_manager, _gateway_process, and _gateway_port
|
||||
with asyncio.Lock protection for concurrent access.
|
||||
"""
|
||||
|
||||
_instance: Optional["RuntimeState"] = None
|
||||
_lock: "threading.Lock" = __import__("threading").Lock()
|
||||
|
||||
def __new__(cls) -> "RuntimeState":
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
self._runtime_manager: Optional[Any] = None
|
||||
self._gateway_process: Optional[subprocess.Popen] = None
|
||||
self._gateway_port: int = 8765
|
||||
self._state_lock = asyncio.Lock()
|
||||
self._initialized = True
|
||||
|
||||
@property
|
||||
async def lock(self) -> asyncio.Lock:
|
||||
"""Get the asyncio lock for state synchronization."""
|
||||
return self._state_lock
|
||||
|
||||
@property
|
||||
def runtime_manager(self) -> Optional[Any]:
|
||||
"""Get the runtime manager (no lock - read only)."""
|
||||
return self._runtime_manager
|
||||
|
||||
@runtime_manager.setter
|
||||
def runtime_manager(self, value: Optional[Any]) -> None:
|
||||
"""Set the runtime manager."""
|
||||
self._runtime_manager = value
|
||||
|
||||
@property
|
||||
def gateway_process(self) -> Optional[subprocess.Popen]:
|
||||
"""Get the gateway process (no lock - read only)."""
|
||||
return self._gateway_process
|
||||
|
||||
@gateway_process.setter
|
||||
def gateway_process(self, value: Optional[subprocess.Popen]) -> None:
|
||||
"""Set the gateway process."""
|
||||
self._gateway_process = value
|
||||
|
||||
@property
|
||||
def gateway_port(self) -> int:
|
||||
"""Get the gateway port."""
|
||||
return self._gateway_port
|
||||
|
||||
@gateway_port.setter
|
||||
def gateway_port(self, value: int) -> None:
|
||||
"""Set the gateway port."""
|
||||
self._gateway_port = value
|
||||
|
||||
async def set_runtime_manager(self, manager: Any) -> None:
|
||||
"""Set runtime manager with lock protection."""
|
||||
async with self._state_lock:
|
||||
self._runtime_manager = manager
|
||||
|
||||
async def get_runtime_manager(self) -> Optional[Any]:
|
||||
"""Get runtime manager with lock protection."""
|
||||
async with self._state_lock:
|
||||
return self._runtime_manager
|
||||
|
||||
async def set_gateway_process(self, process: Optional[subprocess.Popen]) -> None:
|
||||
"""Set gateway process with lock protection."""
|
||||
async with self._state_lock:
|
||||
self._gateway_process = process
|
||||
|
||||
async def get_gateway_process(self) -> Optional[subprocess.Popen]:
|
||||
"""Get gateway process with lock protection."""
|
||||
async with self._state_lock:
|
||||
return self._gateway_process
|
||||
|
||||
async def set_gateway_port(self, port: int) -> None:
|
||||
"""Set gateway port with lock protection."""
|
||||
async with self._state_lock:
|
||||
self._gateway_port = port
|
||||
|
||||
async def get_gateway_port(self) -> int:
|
||||
"""Get gateway port with lock protection."""
|
||||
async with self._state_lock:
|
||||
return self._gateway_port
|
||||
|
||||
|
||||
# Singleton instance
|
||||
_runtime_state = RuntimeState()
|
||||
|
||||
|
||||
def get_runtime_state() -> RuntimeState:
|
||||
"""Get the RuntimeState singleton instance."""
|
||||
return _runtime_state
|
||||
|
||||
|
||||
# Backward compatibility: module-level runtime_manager for external imports
|
||||
# This is set by register_runtime_manager() for backward compatibility
|
||||
runtime_manager: Optional[Any] = None
|
||||
|
||||
|
||||
class RunContextResponse(BaseModel):
|
||||
@@ -62,6 +167,8 @@ class RuntimeEventsResponse(BaseModel):
|
||||
|
||||
class LaunchConfig(BaseModel):
|
||||
"""Configuration for launching a new trading task."""
|
||||
launch_mode: str = Field(default="fresh", description="启动形式: fresh, restore")
|
||||
restore_run_id: Optional[str] = Field(default=None, description="历史任务 run_id,用于恢复启动")
|
||||
tickers: List[str] = Field(default_factory=list, description="股票池")
|
||||
schedule_mode: str = Field(default="daily", description="调度模式: daily, interval")
|
||||
interval_minutes: int = Field(default=60, ge=1, description="间隔分钟数")
|
||||
@@ -74,7 +181,6 @@ class LaunchConfig(BaseModel):
|
||||
start_date: Optional[str] = Field(default=None, description="回测开始日期 YYYY-MM-DD")
|
||||
end_date: Optional[str] = Field(default=None, description="回测结束日期 YYYY-MM-DD")
|
||||
poll_interval: int = Field(default=10, ge=1, le=300, description="市场数据轮询间隔(秒)")
|
||||
enable_mock: bool = Field(default=False, description="是否启用模拟模式(使用模拟价格数据)")
|
||||
|
||||
|
||||
class LaunchResponse(BaseModel):
|
||||
@@ -85,17 +191,61 @@ class LaunchResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class RuntimeHistoryItem(BaseModel):
|
||||
run_id: str
|
||||
run_dir: str
|
||||
updated_at: Optional[str] = None
|
||||
total_trades: int = 0
|
||||
total_asset_value: Optional[float] = None
|
||||
bootstrap: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class RuntimeHistoryResponse(BaseModel):
|
||||
runs: List[RuntimeHistoryItem]
|
||||
|
||||
|
||||
class StopResponse(BaseModel):
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class CleanupResponse(BaseModel):
|
||||
status: str
|
||||
kept: int
|
||||
pruned_run_ids: List[str]
|
||||
|
||||
|
||||
class GatewayStatusResponse(BaseModel):
|
||||
is_running: bool
|
||||
port: int
|
||||
run_id: Optional[str] = None
|
||||
|
||||
|
||||
class RuntimeConfigResponse(BaseModel):
|
||||
run_id: str
|
||||
is_running: bool
|
||||
gateway_port: int
|
||||
bootstrap: Dict[str, Any]
|
||||
resolved: Dict[str, Any]
|
||||
|
||||
|
||||
class RuntimeLogResponse(BaseModel):
|
||||
run_id: Optional[str] = None
|
||||
is_running: bool
|
||||
log_path: Optional[str] = None
|
||||
content: str = ""
|
||||
|
||||
|
||||
class UpdateRuntimeConfigRequest(BaseModel):
|
||||
schedule_mode: Optional[str] = None
|
||||
interval_minutes: Optional[int] = Field(default=None, ge=1)
|
||||
trigger_time: Optional[str] = None
|
||||
max_comm_cycles: Optional[int] = Field(default=None, ge=1)
|
||||
initial_cash: Optional[float] = Field(default=None, gt=0)
|
||||
margin_requirement: Optional[float] = Field(default=None, ge=0)
|
||||
enable_memory: Optional[bool] = None
|
||||
|
||||
|
||||
def _generate_run_id() -> str:
|
||||
"""Generate timestamp-based run ID: YYYYMMDD_HHMMSS"""
|
||||
return datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
@@ -106,6 +256,128 @@ def _get_run_dir(run_id: str) -> Path:
|
||||
return PROJECT_ROOT / "runs" / run_id
|
||||
|
||||
|
||||
def _load_run_snapshot(run_id: str) -> Dict[str, Any]:
|
||||
"""Load a specific run snapshot by run_id."""
|
||||
snapshot_path = _get_run_dir(run_id) / "state" / "runtime_state.json"
|
||||
if not snapshot_path.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Run snapshot not found: {run_id}")
|
||||
return json.loads(snapshot_path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _copy_path_if_exists(src: Path, dst: Path) -> None:
|
||||
if not src.exists():
|
||||
return
|
||||
if src.is_dir():
|
||||
shutil.copytree(src, dst, dirs_exist_ok=True)
|
||||
else:
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src, dst)
|
||||
|
||||
|
||||
def _restore_run_assets(source_run_id: str, target_run_dir: Path) -> None:
|
||||
"""Seed a fresh run directory from a historical run snapshot."""
|
||||
source_run_dir = _get_run_dir(source_run_id)
|
||||
if not source_run_dir.exists():
|
||||
raise HTTPException(status_code=404, detail=f"Source run not found: {source_run_id}")
|
||||
|
||||
for relative in [
|
||||
"team_dashboard",
|
||||
"agents",
|
||||
"skills",
|
||||
"memory",
|
||||
"state/server_state.json",
|
||||
"state/runtime.db",
|
||||
"state/research.db",
|
||||
]:
|
||||
_copy_path_if_exists(source_run_dir / relative, target_run_dir / relative)
|
||||
|
||||
|
||||
def _list_runs(limit: int = 50) -> list[RuntimeHistoryItem]:
|
||||
runs_root = PROJECT_ROOT / "runs"
|
||||
if not runs_root.exists():
|
||||
return []
|
||||
|
||||
items: list[RuntimeHistoryItem] = []
|
||||
run_dirs = sorted(
|
||||
[path for path in runs_root.iterdir() if path.is_dir()],
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
for run_dir in run_dirs[: max(1, int(limit))]:
|
||||
run_id = run_dir.name
|
||||
runtime_state_path = run_dir / "state" / "runtime_state.json"
|
||||
summary_path = run_dir / "team_dashboard" / "summary.json"
|
||||
|
||||
bootstrap: Dict[str, Any] = {}
|
||||
updated_at: Optional[str] = None
|
||||
total_trades = 0
|
||||
total_asset_value: Optional[float] = None
|
||||
|
||||
if runtime_state_path.exists():
|
||||
try:
|
||||
snapshot = json.loads(runtime_state_path.read_text(encoding="utf-8"))
|
||||
context = snapshot.get("context") or {}
|
||||
bootstrap = dict(context.get("bootstrap_values") or {})
|
||||
updated_at = snapshot.get("events", [{}])[-1].get("timestamp") if snapshot.get("events") else None
|
||||
except Exception:
|
||||
bootstrap = {}
|
||||
|
||||
if summary_path.exists():
|
||||
try:
|
||||
summary = json.loads(summary_path.read_text(encoding="utf-8"))
|
||||
total_trades = int(summary.get("totalTrades") or 0)
|
||||
total_asset_value = float(summary.get("totalAssetValue")) if summary.get("totalAssetValue") is not None else None
|
||||
except Exception:
|
||||
total_trades = 0
|
||||
total_asset_value = None
|
||||
|
||||
items.append(
|
||||
RuntimeHistoryItem(
|
||||
run_id=run_id,
|
||||
run_dir=str(run_dir),
|
||||
updated_at=updated_at,
|
||||
total_trades=total_trades,
|
||||
total_asset_value=total_asset_value,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
)
|
||||
|
||||
return items
|
||||
|
||||
|
||||
def _is_timestamped_run_dir(path: Path) -> bool:
|
||||
try:
|
||||
datetime.strptime(path.name, "%Y%m%d_%H%M%S")
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _prune_old_timestamped_runs(*, keep: int = 20, exclude_run_ids: Optional[set[str]] = None) -> list[str]:
|
||||
"""Prune old timestamped run directories, preserving the newest N and excluded ids."""
|
||||
exclude = exclude_run_ids or set()
|
||||
runs_root = PROJECT_ROOT / "runs"
|
||||
if not runs_root.exists():
|
||||
return []
|
||||
|
||||
candidates = sorted(
|
||||
[
|
||||
path
|
||||
for path in runs_root.iterdir()
|
||||
if path.is_dir() and _is_timestamped_run_dir(path) and path.name not in exclude
|
||||
],
|
||||
key=lambda path: path.name,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
pruned: list[str] = []
|
||||
for path in candidates[max(0, keep):]:
|
||||
shutil.rmtree(path, ignore_errors=True)
|
||||
pruned.append(path.name)
|
||||
return pruned
|
||||
|
||||
|
||||
def _find_available_port(start_port: int = 8765, max_port: int = 9000) -> int:
|
||||
"""Find an available port for Gateway."""
|
||||
import socket
|
||||
@@ -118,31 +390,31 @@ def _find_available_port(start_port: int = 8765, max_port: int = 9000) -> int:
|
||||
|
||||
def _is_gateway_running() -> bool:
|
||||
"""Check if Gateway process is running."""
|
||||
global _gateway_process
|
||||
if _gateway_process is None:
|
||||
process = _runtime_state.gateway_process
|
||||
if process is None:
|
||||
return False
|
||||
return _gateway_process.poll() is None
|
||||
return process.poll() is None
|
||||
|
||||
|
||||
def _stop_gateway() -> bool:
|
||||
"""Stop the Gateway process."""
|
||||
global _gateway_process
|
||||
if _gateway_process is None:
|
||||
process = _runtime_state.gateway_process
|
||||
if process is None:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Try graceful shutdown first
|
||||
_gateway_process.terminate()
|
||||
process.terminate()
|
||||
try:
|
||||
_gateway_process.wait(timeout=5)
|
||||
process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
# Force kill if graceful shutdown fails
|
||||
_gateway_process.kill()
|
||||
_gateway_process.wait()
|
||||
process.kill()
|
||||
process.wait()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during gateway shutdown: {e}")
|
||||
finally:
|
||||
_gateway_process = None
|
||||
_runtime_state.gateway_process = None
|
||||
|
||||
return True
|
||||
|
||||
@@ -167,29 +439,29 @@ def _start_gateway_process(
|
||||
"--bootstrap", json.dumps(bootstrap)
|
||||
]
|
||||
|
||||
# Start process
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=PROJECT_ROOT
|
||||
)
|
||||
log_path = run_dir / "logs" / "gateway.log"
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
log_file = log_path.open("ab")
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=PROJECT_ROOT
|
||||
)
|
||||
finally:
|
||||
log_file.close()
|
||||
|
||||
return process
|
||||
|
||||
|
||||
@router.get("/context", response_model=RunContextResponse)
|
||||
async def get_run_context() -> RunContextResponse:
|
||||
"""Return the most recent run context."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No run context available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context")
|
||||
"""Return active runtime context, or latest persisted context when stopped."""
|
||||
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
|
||||
context = snapshot.get("context")
|
||||
if context is None:
|
||||
raise HTTPException(status_code=404, detail="Run context is not ready")
|
||||
|
||||
@@ -202,15 +474,9 @@ async def get_run_context() -> RunContextResponse:
|
||||
|
||||
@router.get("/agents", response_model=RuntimeAgentsResponse)
|
||||
async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||
"""Return agent states from the most recent run."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
agents = latest.get("agents", [])
|
||||
"""Return agent states from the active runtime, or latest persisted run."""
|
||||
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
|
||||
agents = snapshot.get("agents", [])
|
||||
|
||||
return RuntimeAgentsResponse(
|
||||
agents=[RuntimeAgentState(**a) for a in agents]
|
||||
@@ -219,58 +485,219 @@ async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||
|
||||
@router.get("/events", response_model=RuntimeEventsResponse)
|
||||
async def get_runtime_events() -> RuntimeEventsResponse:
|
||||
"""Return events from the most recent run."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
events = latest.get("events", [])
|
||||
"""Return events from the active runtime, or latest persisted run."""
|
||||
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
|
||||
events = snapshot.get("events", [])
|
||||
|
||||
return RuntimeEventsResponse(
|
||||
events=[RuntimeEvent(**e) for e in events]
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history", response_model=RuntimeHistoryResponse)
|
||||
async def get_runtime_history(limit: int = 20) -> RuntimeHistoryResponse:
|
||||
"""List recent historical runs for restore/start selection."""
|
||||
return RuntimeHistoryResponse(runs=_list_runs(limit=limit))
|
||||
|
||||
|
||||
@router.get("/gateway/status", response_model=GatewayStatusResponse)
|
||||
async def get_gateway_status() -> GatewayStatusResponse:
|
||||
"""Get Gateway process status and port."""
|
||||
global _gateway_port
|
||||
|
||||
is_running = _is_gateway_running()
|
||||
run_id = None
|
||||
|
||||
if is_running:
|
||||
# Try to find run_id from runtime state
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
if snapshots:
|
||||
try:
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
run_id = latest.get("context", {}).get("config_name")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse latest snapshot: {e}")
|
||||
try:
|
||||
run_id = _get_active_runtime_context().get("config_name")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to resolve active runtime context: {e}")
|
||||
|
||||
return GatewayStatusResponse(
|
||||
is_running=is_running,
|
||||
port=_gateway_port,
|
||||
port=_runtime_state.gateway_port,
|
||||
run_id=run_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/gateway/port")
|
||||
async def get_gateway_port() -> Dict[str, Any]:
|
||||
async def get_gateway_port(request: Request) -> Dict[str, Any]:
|
||||
"""Get WebSocket Gateway port for frontend connection."""
|
||||
global _gateway_port
|
||||
gateway_port = _runtime_state.gateway_port
|
||||
return {
|
||||
"port": _gateway_port,
|
||||
"port": gateway_port,
|
||||
"is_running": _is_gateway_running(),
|
||||
"ws_url": f"ws://localhost:{_gateway_port}"
|
||||
"ws_url": _build_gateway_ws_url(request, gateway_port),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/logs", response_model=RuntimeLogResponse)
|
||||
async def get_runtime_logs() -> RuntimeLogResponse:
|
||||
"""Return current runtime log tail, or the latest run log if runtime is stopped."""
|
||||
try:
|
||||
context = _get_active_runtime_context() if _is_gateway_running() else _get_runtime_context_from_latest_snapshot()
|
||||
except HTTPException:
|
||||
return RuntimeLogResponse(is_running=False, content="")
|
||||
|
||||
run_id = str(context.get("config_name") or "").strip() or None
|
||||
log_path = _get_gateway_log_path_for_run(run_id) if run_id else None
|
||||
content = _read_log_tail(log_path) if log_path else ""
|
||||
|
||||
return RuntimeLogResponse(
|
||||
run_id=run_id,
|
||||
is_running=_is_gateway_running(),
|
||||
log_path=str(log_path) if log_path else None,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _build_gateway_ws_url(request: Request, port: int) -> str:
|
||||
"""Build a proxy-safe Gateway WebSocket URL."""
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "").split(",")[0].strip()
|
||||
scheme = forwarded_proto or request.url.scheme
|
||||
ws_scheme = "wss" if scheme == "https" else "ws"
|
||||
|
||||
forwarded_host = request.headers.get("x-forwarded-host", "").split(",")[0].strip()
|
||||
host = forwarded_host or request.url.hostname or "localhost"
|
||||
if ":" in host and not host.startswith("["):
|
||||
host = host.split(":", 1)[0]
|
||||
|
||||
return f"{ws_scheme}://{host}:{port}"
|
||||
|
||||
|
||||
def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
||||
"""Load the latest persisted runtime snapshot."""
|
||||
snapshots = sorted(
|
||||
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||
return json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _get_active_runtime_snapshot() -> Dict[str, Any]:
|
||||
"""Return the active runtime snapshot, preferring in-memory manager state."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is not None and hasattr(manager, "build_snapshot"):
|
||||
snapshot = manager.build_snapshot()
|
||||
context = snapshot.get("context") or {}
|
||||
if context.get("config_name"):
|
||||
return snapshot
|
||||
|
||||
return _load_latest_runtime_snapshot()
|
||||
|
||||
|
||||
def _get_runtime_context_from_latest_snapshot() -> Dict[str, Any]:
|
||||
"""Return the latest persisted runtime context regardless of active process state."""
|
||||
latest = _load_latest_runtime_snapshot()
|
||||
context = latest.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
raise HTTPException(status_code=404, detail="No runtime context available")
|
||||
return context
|
||||
|
||||
|
||||
def _get_gateway_log_path_for_run(run_id: str) -> Path:
|
||||
return _get_run_dir(run_id) / "logs" / "gateway.log"
|
||||
|
||||
|
||||
def _read_log_tail(path: Path, max_chars: int = 120_000) -> str:
|
||||
if not path.exists() or not path.is_file():
|
||||
return ""
|
||||
text = path.read_text(encoding="utf-8", errors="replace")
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
return text[-max_chars:]
|
||||
|
||||
|
||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context from the latest snapshot."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
snapshot = _get_active_runtime_snapshot()
|
||||
context = snapshot.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
raise HTTPException(status_code=404, detail="No runtime context available")
|
||||
return context
|
||||
|
||||
|
||||
def _get_active_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context, preferring in-memory runtime manager state."""
|
||||
return _get_current_runtime_context()
|
||||
|
||||
|
||||
def _resolve_runtime_response(run_id: str) -> RuntimeConfigResponse:
|
||||
"""Build a normalized runtime config response for the active run."""
|
||||
context = _get_current_runtime_context()
|
||||
bootstrap = dict(context.get("bootstrap_values") or {})
|
||||
resolved = resolve_runtime_config(
|
||||
project_root=PROJECT_ROOT,
|
||||
config_name=run_id,
|
||||
enable_memory=bool(bootstrap.get("enable_memory", False)),
|
||||
schedule_mode=str(bootstrap.get("schedule_mode", "daily")),
|
||||
interval_minutes=int(bootstrap.get("interval_minutes", 60) or 60),
|
||||
trigger_time=str(bootstrap.get("trigger_time", "09:30") or "09:30"),
|
||||
)
|
||||
return RuntimeConfigResponse(
|
||||
run_id=run_id,
|
||||
is_running=True,
|
||||
gateway_port=_runtime_state.gateway_port,
|
||||
bootstrap=bootstrap,
|
||||
resolved=resolved,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_runtime_config_updates(
|
||||
request: UpdateRuntimeConfigRequest,
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate and normalize runtime config updates."""
|
||||
updates: Dict[str, Any] = {}
|
||||
|
||||
if request.schedule_mode is not None:
|
||||
schedule_mode = str(request.schedule_mode).strip().lower()
|
||||
if schedule_mode not in {"daily", "intraday"}:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="schedule_mode must be 'daily' or 'intraday'",
|
||||
)
|
||||
updates["schedule_mode"] = schedule_mode
|
||||
|
||||
if request.interval_minutes is not None:
|
||||
updates["interval_minutes"] = int(request.interval_minutes)
|
||||
|
||||
if request.trigger_time is not None:
|
||||
trigger_time = str(request.trigger_time).strip()
|
||||
if trigger_time and trigger_time != "now":
|
||||
try:
|
||||
datetime.strptime(trigger_time, "%H:%M")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="trigger_time must use HH:MM or 'now'",
|
||||
) from exc
|
||||
updates["trigger_time"] = trigger_time or "09:30"
|
||||
|
||||
if request.max_comm_cycles is not None:
|
||||
updates["max_comm_cycles"] = int(request.max_comm_cycles)
|
||||
|
||||
if request.initial_cash is not None:
|
||||
updates["initial_cash"] = float(request.initial_cash)
|
||||
|
||||
if request.margin_requirement is not None:
|
||||
updates["margin_requirement"] = float(request.margin_requirement)
|
||||
|
||||
if request.enable_memory is not None:
|
||||
updates["enable_memory"] = bool(request.enable_memory)
|
||||
|
||||
if not updates:
|
||||
raise HTTPException(status_code=400, detail="No runtime config updates provided")
|
||||
|
||||
return updates
|
||||
|
||||
|
||||
@router.post("/start", response_model=LaunchResponse)
|
||||
async def start_runtime(
|
||||
config: LaunchConfig,
|
||||
@@ -284,33 +711,59 @@ async def start_runtime(
|
||||
4. Start Gateway as subprocess (Data Plane)
|
||||
5. Return Gateway port for WebSocket connection
|
||||
"""
|
||||
global _gateway_process, _gateway_port
|
||||
# Lazy import to avoid circular dependency
|
||||
from backend.runtime.manager import TradingRuntimeManager
|
||||
|
||||
# 1. Stop existing Gateway
|
||||
if _is_gateway_running():
|
||||
_stop_gateway()
|
||||
await asyncio.sleep(1) # Wait for port release
|
||||
|
||||
# 2. Generate run ID and directory
|
||||
run_id = _generate_run_id()
|
||||
run_dir = _get_run_dir(run_id)
|
||||
launch_mode = str(config.launch_mode or "fresh").strip().lower()
|
||||
if launch_mode not in {"fresh", "restore"}:
|
||||
raise HTTPException(status_code=400, detail="launch_mode must be 'fresh' or 'restore'")
|
||||
|
||||
# 3. Prepare bootstrap config
|
||||
bootstrap = {
|
||||
"tickers": config.tickers,
|
||||
"schedule_mode": config.schedule_mode,
|
||||
"interval_minutes": config.interval_minutes,
|
||||
"trigger_time": config.trigger_time,
|
||||
"max_comm_cycles": config.max_comm_cycles,
|
||||
"initial_cash": config.initial_cash,
|
||||
"margin_requirement": config.margin_requirement,
|
||||
"enable_memory": config.enable_memory,
|
||||
"mode": config.mode,
|
||||
"start_date": config.start_date,
|
||||
"end_date": config.end_date,
|
||||
"poll_interval": config.poll_interval,
|
||||
"enable_mock": config.enable_mock,
|
||||
}
|
||||
# 2. Resolve run ID, directory, and bootstrap
|
||||
if launch_mode == "restore":
|
||||
restore_run_id = str(config.restore_run_id or "").strip()
|
||||
if not restore_run_id:
|
||||
raise HTTPException(status_code=400, detail="restore_run_id is required when launch_mode=restore")
|
||||
snapshot = _load_run_snapshot(restore_run_id)
|
||||
context = snapshot.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
raise HTTPException(status_code=404, detail=f"Run context not found: {restore_run_id}")
|
||||
run_id = restore_run_id
|
||||
run_dir = _get_run_dir(run_id)
|
||||
bootstrap = dict(context.get("bootstrap_values") or {})
|
||||
bootstrap["launch_mode"] = "restore"
|
||||
bootstrap["restore_run_id"] = restore_run_id
|
||||
else:
|
||||
run_id = _generate_run_id()
|
||||
run_dir = _get_run_dir(run_id)
|
||||
bootstrap = {
|
||||
"launch_mode": "fresh",
|
||||
"restore_run_id": None,
|
||||
"tickers": config.tickers,
|
||||
"schedule_mode": config.schedule_mode,
|
||||
"interval_minutes": config.interval_minutes,
|
||||
"trigger_time": config.trigger_time,
|
||||
"max_comm_cycles": config.max_comm_cycles,
|
||||
"initial_cash": config.initial_cash,
|
||||
"margin_requirement": config.margin_requirement,
|
||||
"enable_memory": config.enable_memory,
|
||||
"mode": config.mode,
|
||||
"start_date": config.start_date,
|
||||
"end_date": config.end_date,
|
||||
"poll_interval": config.poll_interval,
|
||||
}
|
||||
|
||||
retention_keep = max(1, int(os.getenv("RUNS_RETENTION_COUNT", "20") or "20"))
|
||||
pruned_run_ids = _prune_old_timestamped_runs(
|
||||
keep=retention_keep,
|
||||
exclude_run_ids={run_id},
|
||||
)
|
||||
if pruned_run_ids:
|
||||
logger.info("Pruned old run directories: %s", ", ".join(pruned_run_ids))
|
||||
|
||||
# 4. Create runtime manager
|
||||
manager = TradingRuntimeManager(
|
||||
@@ -325,25 +778,28 @@ async def start_runtime(
|
||||
_write_bootstrap_md(run_dir, bootstrap)
|
||||
|
||||
# 6. Find available port and start Gateway process
|
||||
_gateway_port = _find_available_port(start_port=8765)
|
||||
gateway_port = _find_available_port(start_port=8765)
|
||||
_runtime_state.gateway_port = gateway_port
|
||||
|
||||
try:
|
||||
_gateway_process = _start_gateway_process(
|
||||
process = _start_gateway_process(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
port=_gateway_port
|
||||
port=gateway_port
|
||||
)
|
||||
_runtime_state.gateway_process = process
|
||||
|
||||
# Wait briefly to check if process started successfully
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if not _is_gateway_running():
|
||||
stdout, stderr = _gateway_process.communicate(timeout=1)
|
||||
_gateway_process = None
|
||||
_runtime_state.gateway_process = None
|
||||
log_path = _get_gateway_log_path_for_run(run_id)
|
||||
log_tail = _read_log_tail(log_path, max_chars=4000)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Gateway failed to start: {stderr.decode() if stderr else 'Unknown error'}"
|
||||
detail=f"Gateway failed to start: {log_tail or 'Unknown error'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -354,16 +810,44 @@ async def start_runtime(
|
||||
run_id=run_id,
|
||||
status="started",
|
||||
run_dir=str(run_dir),
|
||||
gateway_port=_gateway_port,
|
||||
message=f"Runtime started with run_id: {run_id}, Gateway on port: {_gateway_port}",
|
||||
gateway_port=gateway_port,
|
||||
message=f"Runtime started with run_id: {run_id}, Gateway on port: {gateway_port}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config", response_model=RuntimeConfigResponse)
|
||||
async def get_runtime_config() -> RuntimeConfigResponse:
|
||||
"""Return the current runtime bootstrap and resolved settings."""
|
||||
context = _get_current_runtime_context()
|
||||
return _resolve_runtime_response(context["config_name"])
|
||||
|
||||
|
||||
@router.put("/config", response_model=RuntimeConfigResponse)
|
||||
async def update_runtime_config(
|
||||
request: UpdateRuntimeConfigRequest,
|
||||
) -> RuntimeConfigResponse:
|
||||
"""Persist selected runtime configuration updates for the active run."""
|
||||
context = _get_current_runtime_context()
|
||||
run_id = context["config_name"]
|
||||
updates = _normalize_runtime_config_updates(request)
|
||||
updated = update_bootstrap_values_for_run(PROJECT_ROOT, run_id, updates)
|
||||
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is not None and getattr(manager, "config_name", None) == run_id:
|
||||
manager.bootstrap.update(updates)
|
||||
if getattr(manager, "context", None) is not None:
|
||||
manager.context.bootstrap_values.update(updates)
|
||||
if hasattr(manager, "_persist_snapshot"):
|
||||
manager._persist_snapshot()
|
||||
|
||||
response = _resolve_runtime_response(run_id)
|
||||
response.bootstrap = dict(updated.values)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/stop", response_model=StopResponse)
|
||||
async def stop_runtime(force: bool = True) -> StopResponse:
|
||||
"""Stop the current running runtime."""
|
||||
global _gateway_process
|
||||
|
||||
was_running = _is_gateway_running()
|
||||
|
||||
if not was_running:
|
||||
@@ -381,6 +865,25 @@ async def stop_runtime(force: bool = True) -> StopResponse:
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cleanup", response_model=CleanupResponse)
|
||||
async def cleanup_old_runs(keep: int = 20) -> CleanupResponse:
|
||||
"""Prune old timestamped run directories while preserving named runs."""
|
||||
keep_count = max(1, int(keep))
|
||||
exclude: set[str] = set()
|
||||
|
||||
if _is_gateway_running():
|
||||
try:
|
||||
active_context = _get_active_runtime_context()
|
||||
active_run_id = str(active_context.get("config_name") or "").strip()
|
||||
if active_run_id:
|
||||
exclude.add(active_run_id)
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
pruned = _prune_old_timestamped_runs(keep=keep_count, exclude_run_ids=exclude)
|
||||
return CleanupResponse(status="ok", kept=keep_count, pruned_run_ids=pruned)
|
||||
|
||||
|
||||
@router.post("/restart")
|
||||
async def restart_runtime(
|
||||
config: LaunchConfig,
|
||||
@@ -407,35 +910,31 @@ async def get_current_runtime():
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
# Find latest runtime state
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context", {})
|
||||
context = _get_active_runtime_context()
|
||||
|
||||
return {
|
||||
"run_id": context.get("config_name"),
|
||||
"run_dir": context.get("run_dir"),
|
||||
"is_running": True,
|
||||
"gateway_port": _gateway_port,
|
||||
"gateway_port": _runtime_state.gateway_port,
|
||||
"bootstrap": context.get("bootstrap_values", {}),
|
||||
}
|
||||
|
||||
|
||||
def register_runtime_manager(manager: TradingRuntimeManager) -> None:
|
||||
def register_runtime_manager(manager: Any) -> None:
|
||||
"""Allow other modules to expose the runtime manager to the API."""
|
||||
global runtime_manager
|
||||
runtime_manager = manager
|
||||
# Also update the RuntimeState singleton for internal consistency
|
||||
_runtime_state.runtime_manager = manager
|
||||
|
||||
|
||||
def unregister_runtime_manager() -> None:
|
||||
"""Drop the runtime manager reference."""
|
||||
global runtime_manager
|
||||
runtime_manager = None
|
||||
# Also update the RuntimeState singleton for internal consistency
|
||||
_runtime_state.runtime_manager = None
|
||||
|
||||
|
||||
def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None:
|
||||
|
||||
115
backend/app.py
115
backend/app.py
@@ -1,115 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
FastAPI Application - REST API for EvoTraders
|
||||
|
||||
Provides HTTP endpoints for:
|
||||
- Agent management
|
||||
- Workspace management
|
||||
- Tool guard operations
|
||||
- Health checks
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.api import agents_router, workspaces_router, guard_router, runtime_router
|
||||
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||
|
||||
|
||||
# Global instances (initialized on startup)
|
||||
agent_factory: AgentFactory | None = None
|
||||
workspace_manager: WorkspaceManager | None = None
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
"""
|
||||
Application lifespan manager.
|
||||
|
||||
Initializes global services on startup and cleans up on shutdown.
|
||||
"""
|
||||
global agent_factory, workspace_manager
|
||||
|
||||
# Startup: Initialize services
|
||||
project_root = Path(__file__).parent.parent
|
||||
|
||||
# Initialize workspace manager
|
||||
workspace_manager = WorkspaceManager(project_root=project_root)
|
||||
|
||||
# Initialize agent factory
|
||||
agent_factory = AgentFactory(project_root=project_root)
|
||||
|
||||
# Ensure workspaces root exists
|
||||
agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get or create global registry
|
||||
registry = get_registry()
|
||||
|
||||
print(f"✓ EvoTraders API started")
|
||||
print(f" - Workspaces root: {agent_factory.workspaces_root}")
|
||||
print(f" - Registered agents: {registry.get_agent_count()}")
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown: Cleanup
|
||||
print("✓ EvoTraders API shutting down")
|
||||
|
||||
|
||||
# Create FastAPI application
|
||||
app = FastAPI(
|
||||
title="EvoTraders API",
|
||||
description="REST API for the EvoTraders multi-agent trading system",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Configure appropriately for production
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# Health check endpoint
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint."""
|
||||
registry = get_registry()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"version": "0.1.0",
|
||||
"agents_registered": registry.get_agent_count(),
|
||||
"workspaces_available": len(workspace_manager.list_workspaces()) if workspace_manager else 0,
|
||||
}
|
||||
|
||||
|
||||
# API status endpoint
|
||||
@app.get("/api/status")
|
||||
async def api_status():
|
||||
"""Get API status and system information."""
|
||||
registry = get_registry()
|
||||
stats = registry.get_stats()
|
||||
|
||||
return {
|
||||
"status": "operational",
|
||||
"registry": stats,
|
||||
}
|
||||
|
||||
|
||||
# Include routers
|
||||
app.include_router(workspaces_router)
|
||||
app.include_router(agents_router)
|
||||
app.include_router(guard_router)
|
||||
app.include_router(runtime_router)
|
||||
|
||||
|
||||
# Main entry point for running with uvicorn
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
30
backend/apps/__init__.py
Normal file
30
backend/apps/__init__.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Application surfaces for progressive service extraction."""
|
||||
|
||||
from .agent_service import app as agent_app
|
||||
from .agent_service import create_app as create_agent_app
|
||||
from .news_service import app as news_app
|
||||
from .news_service import create_app as create_news_app
|
||||
from .runtime_service import app as runtime_app
|
||||
from .runtime_service import create_app as create_runtime_app
|
||||
from .trading_service import app as trading_app
|
||||
from .trading_service import create_app as create_trading_app
|
||||
from .cors import add_cors_middleware, get_cors_origins
|
||||
|
||||
app = agent_app
|
||||
create_app = create_agent_app
|
||||
|
||||
__all__ = [
|
||||
"app",
|
||||
"create_app",
|
||||
"agent_app",
|
||||
"create_agent_app",
|
||||
"news_app",
|
||||
"create_news_app",
|
||||
"runtime_app",
|
||||
"create_runtime_app",
|
||||
"trading_app",
|
||||
"create_trading_app",
|
||||
"add_cors_middleware",
|
||||
"get_cors_origins",
|
||||
]
|
||||
89
backend/apps/agent_service.py
Normal file
89
backend/apps/agent_service.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Agent control-plane FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
|
||||
from backend.api import agents_router, guard_router, workspaces_router
|
||||
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||
|
||||
# Global instances (initialized on startup)
|
||||
agent_factory: AgentFactory | None = None
|
||||
workspace_manager: WorkspaceManager | None = None
|
||||
|
||||
|
||||
def create_app(project_root: Path | None = None) -> FastAPI:
|
||||
"""Create the agent control-plane app."""
|
||||
resolved_project_root = project_root or Path(__file__).resolve().parents[2]
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Initialize workspace and registry state for the control plane."""
|
||||
global agent_factory, workspace_manager
|
||||
|
||||
workspace_manager = WorkspaceManager(project_root=resolved_project_root)
|
||||
agent_factory = AgentFactory(project_root=resolved_project_root)
|
||||
agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
registry = get_registry()
|
||||
print("✓ EvoTraders API started")
|
||||
print(f" - Workspaces root: {agent_factory.workspaces_root}")
|
||||
print(f" - Registered agents: {registry.get_agent_count()}")
|
||||
|
||||
yield
|
||||
|
||||
print("✓ EvoTraders API shutting down")
|
||||
|
||||
app = FastAPI(
|
||||
title="EvoTraders Agent Service",
|
||||
description="REST API for the EvoTraders multi-agent control plane",
|
||||
version="0.1.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, object]:
|
||||
"""Health check endpoint."""
|
||||
registry = get_registry()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"version": "0.1.0",
|
||||
"agents_registered": registry.get_agent_count(),
|
||||
"workspaces_available": (
|
||||
len(workspace_manager.list_workspaces())
|
||||
if workspace_manager
|
||||
else 0
|
||||
),
|
||||
}
|
||||
|
||||
@app.get("/api/status")
|
||||
async def api_status() -> dict[str, object]:
|
||||
"""Get API status and registry information."""
|
||||
registry = get_registry()
|
||||
return {
|
||||
"status": "operational",
|
||||
"registry": registry.get_stats(),
|
||||
}
|
||||
|
||||
app.include_router(workspaces_router)
|
||||
app.include_router(agents_router)
|
||||
app.include_router(guard_router)
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
30
backend/apps/cors.py
Normal file
30
backend/apps/cors.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Shared CORS configuration for all microservice apps."""
|
||||
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
def get_cors_origins() -> Sequence[str]:
|
||||
"""Get allowed CORS origins from environment variable.
|
||||
|
||||
Defaults to ["*"] for backward compatibility.
|
||||
Set CORS_ALLOWED_ORIGINS env var (comma-separated) in production.
|
||||
"""
|
||||
origins = os.getenv("CORS_ALLOWED_ORIGINS", "").strip()
|
||||
if not origins:
|
||||
return ["*"]
|
||||
return [o.strip() for o in origins.split(",") if o.strip()]
|
||||
|
||||
|
||||
def add_cors_middleware(app: "FastAPI") -> None:
|
||||
"""Add CORS middleware to app with environment-configured origins."""
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
154
backend/apps/news_service.py
Normal file
154
backend/apps/news_service.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""News and explain FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends, FastAPI, Query
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.domains import news as news_domain
|
||||
|
||||
|
||||
def get_market_store() -> MarketStore:
|
||||
"""Get the MarketStore singleton dependency."""
|
||||
return MarketStore.get_instance()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create the news/explain service app."""
|
||||
app = FastAPI(
|
||||
title="EvoTraders News Service",
|
||||
description="Read-only news enrichment and explain service surface extracted from the monolith",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
return {"status": "healthy", "service": "news-service"}
|
||||
|
||||
@app.get("/api/enriched-news")
|
||||
async def api_get_enriched_news(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str | None = Query(None),
|
||||
end_date: str | None = Query(None),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_enriched_news(
|
||||
store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/news-for-date")
|
||||
async def api_get_news_for_date(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
date: str = Query(...),
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_news_for_date(
|
||||
store,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/news-timeline")
|
||||
async def api_get_news_timeline(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str = Query(...),
|
||||
end_date: str = Query(...),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_news_timeline(
|
||||
store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/categories")
|
||||
async def api_get_categories(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str | None = Query(None),
|
||||
end_date: str | None = Query(None),
|
||||
limit: int = Query(200, ge=1, le=1000),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_news_categories(
|
||||
store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/similar-days")
|
||||
async def api_get_similar_days(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
date: str = Query(...),
|
||||
n_similar: int = Query(5, ge=1, le=20),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_similar_days_payload(
|
||||
store,
|
||||
ticker=ticker,
|
||||
date=date,
|
||||
n_similar=n_similar,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/stories/{ticker}")
|
||||
async def api_get_story(
|
||||
ticker: str,
|
||||
as_of_date: str = Query(...),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_story_payload(
|
||||
store,
|
||||
ticker=ticker,
|
||||
as_of_date=as_of_date,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
@app.get("/api/range-explain")
|
||||
async def api_get_range_explain(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str = Query(...),
|
||||
end_date: str = Query(...),
|
||||
article_ids: list[str] = Query(default=[]),
|
||||
limit: int = Query(100, ge=1, le=500),
|
||||
store: MarketStore = Depends(get_market_store),
|
||||
) -> dict[str, Any]:
|
||||
return news_domain.get_range_explain_payload(
|
||||
store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
article_ids=article_ids,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8002)
|
||||
62
backend/apps/runtime_service.py
Normal file
62
backend/apps/runtime_service.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Dedicated runtime service FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from backend.api import runtime_router
|
||||
from backend.api.runtime import get_runtime_state
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create the runtime service app."""
|
||||
app = FastAPI(
|
||||
title="EvoTraders Runtime Service",
|
||||
description="Runtime lifecycle and gateway service surface extracted from the monolith",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, object]:
|
||||
"""Health check for the runtime service."""
|
||||
runtime_state = get_runtime_state()
|
||||
process = runtime_state.gateway_process
|
||||
is_running = process is not None and process.poll() is None
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "runtime-service",
|
||||
"gateway_running": is_running,
|
||||
"gateway_port": runtime_state.gateway_port,
|
||||
}
|
||||
|
||||
@app.get("/api/status")
|
||||
async def api_status() -> dict[str, object]:
|
||||
"""Service-level status payload for runtime orchestration."""
|
||||
runtime_state = get_runtime_state()
|
||||
process = runtime_state.gateway_process
|
||||
is_running = process is not None and process.poll() is None
|
||||
return {
|
||||
"status": "operational",
|
||||
"service": "runtime-service",
|
||||
"runtime": {
|
||||
"gateway_running": is_running,
|
||||
"gateway_port": runtime_state.gateway_port,
|
||||
"has_runtime_manager": runtime_state.runtime_manager is not None,
|
||||
},
|
||||
}
|
||||
|
||||
app.include_router(runtime_router)
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8003)
|
||||
136
backend/apps/trading_service.py
Normal file
136
backend/apps/trading_service.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Trading data FastAPI surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Query
|
||||
from backend.apps.cors import add_cors_middleware
|
||||
|
||||
from backend.domains import trading as trading_domain
|
||||
from shared.schema import (
|
||||
CompanyNewsResponse,
|
||||
FinancialMetricsResponse,
|
||||
InsiderTradeResponse,
|
||||
LineItemResponse,
|
||||
PriceResponse,
|
||||
)
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
"""Create the trading data service app."""
|
||||
app = FastAPI(
|
||||
title="EvoTraders Trading Service",
|
||||
description="Read-only trading data service surface extracted from the monolith",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
add_cors_middleware(app)
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, str]:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "healthy", "service": "trading-service"}
|
||||
|
||||
@app.get("/api/prices", response_model=PriceResponse)
|
||||
async def api_get_prices(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
start_date: str = Query(...),
|
||||
end_date: str = Query(...),
|
||||
) -> PriceResponse:
|
||||
payload = trading_domain.get_prices_payload(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
|
||||
|
||||
@app.get("/api/financials", response_model=FinancialMetricsResponse)
|
||||
async def api_get_financials(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
end_date: str = Query(...),
|
||||
period: str = Query("ttm"),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
) -> FinancialMetricsResponse:
|
||||
payload = trading_domain.get_financials_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
|
||||
|
||||
@app.get("/api/news", response_model=CompanyNewsResponse)
|
||||
async def api_get_news(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
end_date: str = Query(...),
|
||||
start_date: str | None = Query(None),
|
||||
limit: int = Query(1000, ge=1, le=5000),
|
||||
) -> CompanyNewsResponse:
|
||||
payload = trading_domain.get_news_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
return CompanyNewsResponse(news=payload["news"])
|
||||
|
||||
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
|
||||
async def api_get_insider_trades(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
end_date: str = Query(...),
|
||||
start_date: str | None = Query(None),
|
||||
limit: int = Query(1000, ge=1, le=5000),
|
||||
) -> InsiderTradeResponse:
|
||||
payload = trading_domain.get_insider_trades_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
|
||||
|
||||
@app.get("/api/market/status")
|
||||
async def api_get_market_status() -> dict[str, Any]:
|
||||
"""Return current market status using the existing market service logic."""
|
||||
return trading_domain.get_market_status_payload()
|
||||
|
||||
@app.get("/api/market-cap")
|
||||
async def api_get_market_cap(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
end_date: str = Query(...),
|
||||
) -> dict[str, Any]:
|
||||
"""Return market cap for one ticker/date."""
|
||||
return trading_domain.get_market_cap_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
@app.get("/api/line-items", response_model=LineItemResponse)
|
||||
async def api_get_line_items(
|
||||
ticker: str = Query(..., min_length=1),
|
||||
line_items: list[str] = Query(...),
|
||||
end_date: str = Query(...),
|
||||
period: str = Query("ttm"),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
) -> LineItemResponse:
|
||||
payload = trading_domain.get_line_items_payload(
|
||||
ticker=ticker,
|
||||
line_items=line_items,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
return LineItemResponse(search_results=payload["search_results"])
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8001)
|
||||
@@ -29,7 +29,7 @@ from rich.table import Table
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.team_pipeline_config import (
|
||||
ensure_team_pipeline_config,
|
||||
@@ -55,7 +55,7 @@ team_app = typer.Typer(help="Inspect and manage run-scoped team pipeline config.
|
||||
app.add_typer(team_app, name="team")
|
||||
|
||||
console = Console()
|
||||
_prompt_loader = PromptLoader()
|
||||
_prompt_loader = get_prompt_loader()
|
||||
load_dotenv()
|
||||
|
||||
|
||||
@@ -1019,11 +1019,6 @@ def backtest(
|
||||
|
||||
@app.command()
|
||||
def live(
|
||||
mock: bool = typer.Option(
|
||||
False,
|
||||
"--mock",
|
||||
help="Use mock mode with simulated prices (for testing)",
|
||||
),
|
||||
config_name: str = typer.Option(
|
||||
"live",
|
||||
"--config-name",
|
||||
@@ -1078,7 +1073,6 @@ def live(
|
||||
|
||||
Example:
|
||||
evotraders live # Run immediately (default)
|
||||
evotraders live --mock # Mock mode
|
||||
evotraders live -t 22:30 # Run at 22:30 local time daily
|
||||
evotraders live --schedule-mode intraday --interval-minutes 60
|
||||
evotraders live --trigger-time now # Run immediately
|
||||
@@ -1086,33 +1080,31 @@ def live(
|
||||
"""
|
||||
schedule_mode = str(_normalize_typer_value(schedule_mode, "daily"))
|
||||
interval_minutes = int(_normalize_typer_value(interval_minutes, 60))
|
||||
mode_name = "MOCK" if mock else "LIVE"
|
||||
console.print(
|
||||
Panel.fit(
|
||||
f"[bold cyan]EvoTraders {mode_name} Mode[/bold cyan]",
|
||||
"[bold cyan]EvoTraders LIVE Mode[/bold cyan]",
|
||||
border_style="cyan",
|
||||
),
|
||||
)
|
||||
|
||||
# Check for required API key in live mode
|
||||
if not mock:
|
||||
env_file = get_project_root() / ".env"
|
||||
if not env_file.exists():
|
||||
console.print("\n[yellow]Warning: .env file not found[/yellow]")
|
||||
console.print("Creating from template...\n")
|
||||
template = get_project_root() / "env.template"
|
||||
if template.exists():
|
||||
shutil.copy(template, env_file)
|
||||
console.print("[green].env file created[/green]")
|
||||
console.print(
|
||||
"\n[red]Error: Please edit .env and set FINNHUB_API_KEY[/red]",
|
||||
)
|
||||
console.print(
|
||||
"Get your free API key at: https://finnhub.io/register\n",
|
||||
)
|
||||
else:
|
||||
console.print("[red]Error: env.template not found[/red]")
|
||||
raise typer.Exit(1)
|
||||
env_file = get_project_root() / ".env"
|
||||
if not env_file.exists():
|
||||
console.print("\n[yellow]Warning: .env file not found[/yellow]")
|
||||
console.print("Creating from template...\n")
|
||||
template = get_project_root() / "env.template"
|
||||
if template.exists():
|
||||
shutil.copy(template, env_file)
|
||||
console.print("[green].env file created[/green]")
|
||||
console.print(
|
||||
"\n[red]Error: Please edit .env and set FINNHUB_API_KEY[/red]",
|
||||
)
|
||||
console.print(
|
||||
"Get your free API key at: https://finnhub.io/register\n",
|
||||
)
|
||||
else:
|
||||
console.print("[red]Error: env.template not found[/red]")
|
||||
raise typer.Exit(1)
|
||||
|
||||
# Handle historical data cleanup
|
||||
handle_history_cleanup(config_name, auto_clean=clean)
|
||||
@@ -1168,12 +1160,9 @@ def live(
|
||||
|
||||
# Display configuration
|
||||
console.print("\n[bold]Configuration:[/bold]")
|
||||
if mock:
|
||||
console.print(" Mode: [yellow]MOCK[/yellow] (Simulated prices)")
|
||||
else:
|
||||
console.print(
|
||||
" Mode: [green]LIVE[/green] (Real-time prices via Finnhub)",
|
||||
)
|
||||
console.print(
|
||||
" Mode: [green]LIVE[/green] (Real-time prices via Finnhub)",
|
||||
)
|
||||
console.print(f" Config: {config_name}")
|
||||
console.print(f" Server: {host}:{port}")
|
||||
console.print(f" Poll Interval: {poll_interval}s")
|
||||
@@ -1188,22 +1177,17 @@ def live(
|
||||
project_root = get_project_root()
|
||||
os.chdir(project_root)
|
||||
|
||||
# Data update (if not mock mode)
|
||||
if not mock:
|
||||
run_data_updater(project_root)
|
||||
auto_update_market_store(
|
||||
config_name,
|
||||
end_date=nyse_now.date().isoformat(),
|
||||
)
|
||||
auto_enrich_market_store(
|
||||
config_name,
|
||||
end_date=nyse_now.date().isoformat(),
|
||||
force=False,
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"\n[dim]Mock mode enabled - skipping data update[/dim]\n",
|
||||
)
|
||||
# Data update
|
||||
run_data_updater(project_root)
|
||||
auto_update_market_store(
|
||||
config_name,
|
||||
end_date=nyse_now.date().isoformat(),
|
||||
)
|
||||
auto_enrich_market_store(
|
||||
config_name,
|
||||
end_date=nyse_now.date().isoformat(),
|
||||
force=False,
|
||||
)
|
||||
|
||||
# Build command using backend.main
|
||||
cmd = [
|
||||
@@ -1229,8 +1213,6 @@ def live(
|
||||
str(interval_minutes),
|
||||
]
|
||||
|
||||
if mock:
|
||||
cmd.append("--mock")
|
||||
if enable_memory:
|
||||
cmd.append("--enable-memory")
|
||||
|
||||
|
||||
@@ -76,27 +76,19 @@ def _resolve_config() -> DataSourceConfig:
|
||||
"""
|
||||
Resolve data source configuration based on available API keys.
|
||||
|
||||
Priority:
|
||||
1. FINNHUB_API_KEY (if set)
|
||||
2. FINANCIAL_DATASETS_API_KEY (if set)
|
||||
3. Raises error if neither is available
|
||||
The effective source should always match the first item in the resolved
|
||||
ordered source list.
|
||||
"""
|
||||
sources = _ordered_sources()
|
||||
if "finnhub" in sources:
|
||||
return DataSourceConfig(
|
||||
source="finnhub",
|
||||
api_key=os.getenv("FINNHUB_API_KEY", "").strip(),
|
||||
sources=sources,
|
||||
)
|
||||
if "financial_datasets" in sources:
|
||||
return DataSourceConfig(
|
||||
source="financial_datasets",
|
||||
api_key=os.getenv("FINANCIAL_DATASETS_API_KEY", "").strip(),
|
||||
sources=sources,
|
||||
)
|
||||
if "yfinance" in sources:
|
||||
return DataSourceConfig(source="yfinance", api_key="", sources=sources)
|
||||
return DataSourceConfig(source="local_csv", api_key="", sources=sources)
|
||||
source = sources[0] if sources else "local_csv"
|
||||
|
||||
api_key = ""
|
||||
if source == "finnhub":
|
||||
api_key = os.getenv("FINNHUB_API_KEY", "").strip()
|
||||
elif source == "financial_datasets":
|
||||
api_key = os.getenv("FINANCIAL_DATASETS_API_KEY", "").strip()
|
||||
|
||||
return DataSourceConfig(source=source, api_key=api_key, sources=sources)
|
||||
|
||||
|
||||
def get_config() -> DataSourceConfig:
|
||||
|
||||
@@ -1,7 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Core pipeline and orchestration logic"""
|
||||
"""Core pipeline and orchestration logic.
|
||||
|
||||
Keep ``pipeline_runner`` behind lazy wrappers so importing ``backend.core`` does
|
||||
not immediately pull in the gateway runtime graph.
|
||||
"""
|
||||
|
||||
from .pipeline import TradingPipeline
|
||||
from .state_sync import StateSync
|
||||
|
||||
__all__ = ["TradingPipeline", "StateSync"]
|
||||
|
||||
def create_agents(*args, **kwargs):
|
||||
from .pipeline_runner import create_agents as _create_agents
|
||||
|
||||
return _create_agents(*args, **kwargs)
|
||||
|
||||
|
||||
def create_long_term_memory(*args, **kwargs):
|
||||
from .pipeline_runner import create_long_term_memory as _create_long_term_memory
|
||||
|
||||
return _create_long_term_memory(*args, **kwargs)
|
||||
|
||||
|
||||
def stop_gateway(*args, **kwargs):
|
||||
from .pipeline_runner import stop_gateway as _stop_gateway
|
||||
|
||||
return _stop_gateway(*args, **kwargs)
|
||||
|
||||
__all__ = [
|
||||
"TradingPipeline",
|
||||
"StateSync",
|
||||
"create_agents",
|
||||
"create_long_term_memory",
|
||||
"stop_gateway",
|
||||
]
|
||||
|
||||
@@ -30,7 +30,7 @@ from backend.agents.team_pipeline_config import (
|
||||
from backend.agents import AnalystAgent
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
|
||||
@@ -1623,7 +1623,7 @@ class TradingPipeline:
|
||||
|
||||
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
|
||||
project_root = Path(__file__).resolve().parents[2]
|
||||
personas = PromptLoader().load_yaml_config("analyst", "personas")
|
||||
personas = get_prompt_loader().load_yaml_config("analyst", "personas")
|
||||
persona = personas.get(analyst_type, {})
|
||||
WorkspaceManager(project_root=project_root).ensure_agent_assets(
|
||||
config_name=config_name,
|
||||
|
||||
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Callable
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
@@ -36,7 +36,7 @@ from backend.services.storage import StorageService
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
# Global gateway reference for cleanup
|
||||
_gateway_instance: Optional[Gateway] = None
|
||||
@@ -244,10 +244,8 @@ async def run_pipeline(
|
||||
start_date = bootstrap.get("start_date")
|
||||
end_date = bootstrap.get("end_date")
|
||||
enable_memory = bootstrap.get("enable_memory", False)
|
||||
enable_mock = bootstrap.get("enable_mock", False)
|
||||
|
||||
is_backtest = mode == "backtest"
|
||||
is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 0: Initialize runtime manager
|
||||
@@ -266,10 +264,6 @@ async def run_pipeline(
|
||||
|
||||
set_global_runtime_manager(runtime_manager)
|
||||
|
||||
# Register runtime manager with API
|
||||
from backend.api.runtime import register_runtime_manager
|
||||
register_runtime_manager(runtime_manager)
|
||||
|
||||
# ======================================================================
|
||||
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
|
||||
# These will be started by Gateway in the correct order
|
||||
@@ -292,9 +286,8 @@ async def run_pipeline(
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=10,
|
||||
mock_mode=is_mock and not is_backtest,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
|
||||
backtest_start_date=start_date if is_backtest else None,
|
||||
backtest_end_date=end_date if is_backtest else None,
|
||||
)
|
||||
@@ -391,7 +384,6 @@ async def run_pipeline(
|
||||
scheduler_callback=scheduler_callback,
|
||||
config={
|
||||
"mode": mode,
|
||||
"mock_mode": is_mock,
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": run_id,
|
||||
|
||||
@@ -465,7 +465,6 @@ class StateSync:
|
||||
|
||||
payload = {
|
||||
"server_mode": self._state.get("server_mode", "live"),
|
||||
"is_mock_mode": self._state.get("is_mock_mode", False),
|
||||
"is_backtest": self._state.get("is_backtest", False),
|
||||
"tickers": self._state.get("tickers"),
|
||||
"runtime_config": self._state.get("runtime_config"),
|
||||
@@ -488,12 +487,13 @@ class StateSync:
|
||||
}
|
||||
|
||||
if include_dashboard:
|
||||
dashboard_snapshot = self.storage.build_dashboard_snapshot_from_state(self._state)
|
||||
payload["dashboard"] = {
|
||||
"summary": self.storage.load_file("summary"),
|
||||
"holdings": self.storage.load_file("holdings"),
|
||||
"stats": self.storage.load_file("stats"),
|
||||
"trades": self.storage.load_file("trades"),
|
||||
"leaderboard": self.storage.load_file("leaderboard"),
|
||||
"summary": dashboard_snapshot.get("summary"),
|
||||
"holdings": dashboard_snapshot.get("holdings"),
|
||||
"stats": dashboard_snapshot.get("stats"),
|
||||
"trades": dashboard_snapshot.get("trades"),
|
||||
"leaderboard": dashboard_snapshot.get("leaderboard"),
|
||||
}
|
||||
|
||||
return payload
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from backend.data.historical_price_manager import HistoricalPriceManager
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
__all__ = ["MockPriceManager", "PollingPriceManager", "HistoricalPriceManager"]
|
||||
__all__ = ["PollingPriceManager", "HistoricalPriceManager"]
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Iterable
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.data.news_alignment import align_news_for_symbol
|
||||
from backend.data.provider_router import DataProviderRouter
|
||||
from backend.data.polygon_client import (
|
||||
fetch_news,
|
||||
fetch_ohlc,
|
||||
@@ -24,6 +25,35 @@ def _default_start(years: int = 2) -> str:
|
||||
return (datetime.now(timezone.utc).date() - timedelta(days=years * 366)).isoformat()
|
||||
|
||||
|
||||
def _normalize_provider_news_rows(ticker: str, news_items: Iterable[Any]) -> list[dict]:
|
||||
rows: list[dict] = []
|
||||
for item in news_items:
|
||||
payload = item.model_dump() if hasattr(item, "model_dump") else dict(item or {})
|
||||
related = payload.get("related")
|
||||
if isinstance(related, str):
|
||||
related_list = [value.strip().upper() for value in related.split(",") if value.strip()]
|
||||
elif isinstance(related, list):
|
||||
related_list = [str(value).strip().upper() for value in related if str(value).strip()]
|
||||
else:
|
||||
related_list = []
|
||||
if ticker not in related_list:
|
||||
related_list.append(ticker)
|
||||
rows.append(
|
||||
{
|
||||
"title": payload.get("title"),
|
||||
"description": payload.get("summary"),
|
||||
"summary": payload.get("summary"),
|
||||
"article_url": payload.get("url"),
|
||||
"published_utc": payload.get("date"),
|
||||
"publisher": payload.get("source"),
|
||||
"tickers": related_list,
|
||||
"category": payload.get("category"),
|
||||
"raw_json": payload,
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def ingest_ticker_history(
|
||||
symbol: str,
|
||||
*,
|
||||
@@ -114,6 +144,80 @@ def update_ticker_incremental(
|
||||
}
|
||||
|
||||
|
||||
def refresh_news_incremental(
|
||||
symbol: str,
|
||||
*,
|
||||
end_date: str | None = None,
|
||||
store: MarketStore | None = None,
|
||||
) -> dict:
|
||||
"""Incrementally fetch company news using the configured provider router."""
|
||||
ticker = normalize_symbol(symbol)
|
||||
market_store = store or MarketStore()
|
||||
watermarks = market_store.get_ticker_watermarks(ticker)
|
||||
end = end_date or _today_utc()
|
||||
start_news = (
|
||||
(datetime.fromisoformat(watermarks["last_news_fetch"]) + timedelta(days=1)).date().isoformat()
|
||||
if watermarks.get("last_news_fetch")
|
||||
else _default_start()
|
||||
)
|
||||
|
||||
if start_news > end:
|
||||
return {
|
||||
"symbol": ticker,
|
||||
"start_news_date": start_news,
|
||||
"end_date": end,
|
||||
"news": 0,
|
||||
"aligned": 0,
|
||||
}
|
||||
|
||||
router = DataProviderRouter()
|
||||
news_items, source = router.get_company_news(
|
||||
ticker=ticker,
|
||||
start_date=start_news,
|
||||
end_date=end,
|
||||
limit=1000,
|
||||
)
|
||||
news_rows = _normalize_provider_news_rows(ticker, news_items)
|
||||
news_count = market_store.upsert_news(ticker, news_rows, source=source) if news_rows else 0
|
||||
aligned_count = align_news_for_symbol(market_store, ticker)
|
||||
market_store.update_fetch_watermark(
|
||||
symbol=ticker,
|
||||
news_date=end if news_rows or watermarks.get("last_news_fetch") else None,
|
||||
)
|
||||
|
||||
return {
|
||||
"symbol": ticker,
|
||||
"start_news_date": start_news,
|
||||
"end_date": end,
|
||||
"news": news_count,
|
||||
"aligned": aligned_count,
|
||||
"source": source,
|
||||
}
|
||||
|
||||
|
||||
def refresh_news_for_symbols(
|
||||
symbols: Iterable[str],
|
||||
*,
|
||||
end_date: str | None = None,
|
||||
store: MarketStore | None = None,
|
||||
) -> list[dict]:
|
||||
"""Incrementally refresh company news for a list of tickers."""
|
||||
market_store = store or MarketStore()
|
||||
results = []
|
||||
for symbol in symbols:
|
||||
ticker = normalize_symbol(symbol)
|
||||
if not ticker:
|
||||
continue
|
||||
results.append(
|
||||
refresh_news_incremental(
|
||||
ticker,
|
||||
end_date=end_date,
|
||||
store=market_store,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def ingest_symbols(
|
||||
symbols: Iterable[str],
|
||||
*,
|
||||
|
||||
@@ -9,7 +9,7 @@ import os
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable
|
||||
from typing import Any, Iterable, Optional
|
||||
|
||||
|
||||
SCHEMA = """
|
||||
@@ -147,12 +147,30 @@ def _utc_timestamp() -> str:
|
||||
|
||||
|
||||
class MarketStore:
|
||||
"""SQLite-backed market research warehouse."""
|
||||
"""SQLite-backed market research warehouse. Use get_instance() for the singleton."""
|
||||
|
||||
_instance: Optional["MarketStore"] = None
|
||||
|
||||
def __new__(cls, db_path: Path | None = None) -> "MarketStore":
|
||||
if cls._instance is not None:
|
||||
if db_path is None or cls._instance.db_path == Path(db_path or get_market_db_path()):
|
||||
return cls._instance
|
||||
instance = super().__new__(cls)
|
||||
cls._instance = instance
|
||||
return instance
|
||||
|
||||
def __init__(self, db_path: Path | None = None):
|
||||
if getattr(self, "_initialized", False):
|
||||
return
|
||||
self.db_path = Path(db_path or get_market_db_path())
|
||||
self.db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
self._initialized = True
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, db_path: Path | None = None) -> "MarketStore":
|
||||
"""Get the MarketStore singleton instance."""
|
||||
return cls(db_path)
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
|
||||
@@ -1,244 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Mock Price Manager - For testing during non-trading hours
|
||||
Generates virtual real-time price data
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockPriceManager:
|
||||
"""Mock Price Manager - Generates virtual prices for testing"""
|
||||
|
||||
def __init__(self, poll_interval: int = 10, volatility: float = 0.5):
|
||||
"""
|
||||
Args:
|
||||
poll_interval: Price update interval in seconds
|
||||
volatility: Price volatility percentage
|
||||
"""
|
||||
if poll_interval is None:
|
||||
poll_interval = int(os.getenv("MOCK_POLL_INTERVAL", "5"))
|
||||
if volatility is None:
|
||||
volatility = float(os.getenv("MOCK_VOLATILITY", "0.5"))
|
||||
|
||||
self.poll_interval = poll_interval
|
||||
self.volatility = volatility
|
||||
|
||||
self.subscribed_symbols: List[str] = []
|
||||
self.base_prices: Dict[str, float] = {}
|
||||
self.open_prices: Dict[str, float] = {}
|
||||
self.latest_prices: Dict[str, float] = {}
|
||||
self.price_callbacks: List[Callable] = []
|
||||
|
||||
self.running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
self.default_base_prices = {
|
||||
"AAPL": 237.50,
|
||||
"MSFT": 425.30,
|
||||
"GOOGL": 161.50,
|
||||
"AMZN": 218.45,
|
||||
"NVDA": 950.00,
|
||||
"META": 573.22,
|
||||
"TSLA": 342.15,
|
||||
"AMD": 168.90,
|
||||
"NFLX": 688.25,
|
||||
"INTC": 42.18,
|
||||
"COIN": 285.50,
|
||||
"PLTR": 45.80,
|
||||
"BABA": 88.30,
|
||||
"DIS": 112.50,
|
||||
"BKNG": 4850.00,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"MockPriceManager initialized (interval: {self.poll_interval}s, "
|
||||
f"volatility: {self.volatility}%)",
|
||||
)
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
symbols: List[str],
|
||||
base_prices: Dict[str, float] = None,
|
||||
):
|
||||
"""Subscribe to stock symbols"""
|
||||
for symbol in symbols:
|
||||
symbol = normalize_symbol(symbol)
|
||||
if symbol not in self.subscribed_symbols:
|
||||
self.subscribed_symbols.append(symbol)
|
||||
|
||||
if base_prices and symbol in base_prices:
|
||||
base_price = base_prices[symbol]
|
||||
elif symbol in self.default_base_prices:
|
||||
base_price = self.default_base_prices[symbol]
|
||||
else:
|
||||
base_price = random.uniform(50, 500)
|
||||
|
||||
self.base_prices[symbol] = base_price
|
||||
self.open_prices[symbol] = base_price
|
||||
self.latest_prices[symbol] = base_price
|
||||
|
||||
logger.info(
|
||||
f"Subscribed to mock price: {symbol} (base: ${base_price:.2f})", # noqa: E501
|
||||
)
|
||||
|
||||
def unsubscribe(self, symbols: List[str]):
|
||||
"""Unsubscribe from symbols"""
|
||||
for symbol in symbols:
|
||||
symbol = normalize_symbol(symbol)
|
||||
if symbol in self.subscribed_symbols:
|
||||
self.subscribed_symbols.remove(symbol)
|
||||
self.base_prices.pop(symbol, None)
|
||||
self.open_prices.pop(symbol, None)
|
||||
self.latest_prices.pop(symbol, None)
|
||||
logger.info(f"Unsubscribed: {symbol}")
|
||||
|
||||
def add_price_callback(self, callback: Callable):
|
||||
"""Add price update callback"""
|
||||
self.price_callbacks.append(callback)
|
||||
|
||||
def _generate_price_update(self, symbol: str) -> float:
|
||||
"""Generate price update based on random walk"""
|
||||
current_price = self.latest_prices.get(
|
||||
symbol,
|
||||
self.base_prices[symbol],
|
||||
)
|
||||
|
||||
change_percent = random.uniform(-self.volatility, self.volatility)
|
||||
new_price = current_price * (1 + change_percent / 100)
|
||||
|
||||
# 10% chance of larger movement
|
||||
if random.random() < 0.1:
|
||||
trend_factor = random.uniform(-2, 2)
|
||||
new_price = new_price * (1 + trend_factor / 100)
|
||||
|
||||
# Limit intraday movement to +/-10%
|
||||
open_price = self.open_prices[symbol]
|
||||
max_price = open_price * 1.10
|
||||
min_price = open_price * 0.90
|
||||
new_price = max(min_price, min(max_price, new_price))
|
||||
|
||||
return new_price
|
||||
|
||||
def _update_prices(self):
|
||||
"""Update prices for all subscribed stocks"""
|
||||
timestamp = int(time.time() * 1000)
|
||||
|
||||
for symbol in self.subscribed_symbols:
|
||||
try:
|
||||
new_price = self._generate_price_update(symbol)
|
||||
self.latest_prices[symbol] = new_price
|
||||
|
||||
open_price = self.open_prices[symbol]
|
||||
ret = ((new_price - open_price) / open_price) * 100
|
||||
|
||||
price_data = {
|
||||
"symbol": symbol,
|
||||
"price": new_price,
|
||||
"timestamp": timestamp,
|
||||
"volume": random.randint(1000000, 10000000),
|
||||
"open": open_price,
|
||||
"high": max(new_price, open_price),
|
||||
"low": min(new_price, open_price),
|
||||
"previous_close": open_price,
|
||||
"ret": ret,
|
||||
}
|
||||
|
||||
for callback in self.price_callbacks:
|
||||
try:
|
||||
callback(price_data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Mock price callback error ({symbol}): {e}",
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Mock {symbol}: ${new_price:.2f} [ret: {ret:+.2f}%]",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate mock price ({symbol}): {e}")
|
||||
|
||||
def _polling_loop(self):
|
||||
"""Main polling loop"""
|
||||
logger.info(
|
||||
f"Mock price generation started (interval: {self.poll_interval}s)",
|
||||
)
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
start_time = time.time()
|
||||
self._update_prices()
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, self.poll_interval - elapsed)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Mock polling loop error: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
def start(self):
|
||||
"""Start mock price generation"""
|
||||
if self.running:
|
||||
logger.warning("Mock price manager already running")
|
||||
return
|
||||
|
||||
if not self.subscribed_symbols:
|
||||
logger.warning("No stocks subscribed")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._thread = threading.Thread(target=self._polling_loop, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
logger.info(
|
||||
f"Mock price manager started: {', '.join(self.subscribed_symbols)}", # noqa: E501
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
"""Stop mock price generation"""
|
||||
self.running = False
|
||||
if self._thread:
|
||||
self._thread.join(timeout=5)
|
||||
logger.info("Mock price manager stopped")
|
||||
|
||||
def get_latest_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get latest price for symbol"""
|
||||
return self.latest_prices.get(symbol)
|
||||
|
||||
def get_all_latest_prices(self) -> Dict[str, float]:
|
||||
"""Get all latest prices"""
|
||||
return self.latest_prices.copy()
|
||||
|
||||
def get_open_price(self, symbol: str) -> Optional[float]:
|
||||
"""Get open price for symbol"""
|
||||
return self.open_prices.get(symbol)
|
||||
|
||||
def reset_open_prices(self):
|
||||
"""Reset open prices for new trading day"""
|
||||
for symbol in self.subscribed_symbols:
|
||||
last_close = self.latest_prices[symbol]
|
||||
gap_percent = random.uniform(-1, 1)
|
||||
new_open = last_close * (1 + gap_percent / 100)
|
||||
self.open_prices[symbol] = new_open
|
||||
self.latest_prices[symbol] = new_open
|
||||
logger.info("Open prices reset")
|
||||
|
||||
def set_base_price(self, symbol: str, price: float):
|
||||
"""Manually set base price for testing"""
|
||||
if symbol in self.subscribed_symbols:
|
||||
self.base_prices[symbol] = price
|
||||
self.open_prices[symbol] = price
|
||||
self.latest_prices[symbol] = price
|
||||
logger.info(f"{symbol} base price set to: ${price:.2f}")
|
||||
else:
|
||||
logger.warning(f"{symbol} not subscribed")
|
||||
@@ -15,6 +15,9 @@ from backend.data.provider_utils import normalize_symbol
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SUPPRESSED_LOG_EVERY = 20
|
||||
|
||||
|
||||
class PollingPriceManager:
|
||||
"""Polling-based price manager using Finnhub or yfinance."""
|
||||
|
||||
@@ -43,6 +46,7 @@ class PollingPriceManager:
|
||||
self.latest_prices: Dict[str, float] = {}
|
||||
self.open_prices: Dict[str, float] = {}
|
||||
self.price_callbacks: List[Callable] = []
|
||||
self._failure_counts: Dict[str, int] = {}
|
||||
|
||||
self.running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
@@ -77,6 +81,8 @@ class PollingPriceManager:
|
||||
for symbol in self.subscribed_symbols:
|
||||
try:
|
||||
quote_data = self._fetch_quote(symbol)
|
||||
if not isinstance(quote_data, dict):
|
||||
raise ValueError(f"{symbol}: Empty quote payload")
|
||||
|
||||
current_price = quote_data.get("c")
|
||||
open_price = quote_data.get("o")
|
||||
@@ -103,6 +109,13 @@ class PollingPriceManager:
|
||||
)
|
||||
|
||||
self.latest_prices[symbol] = current_price
|
||||
previous_failures = self._failure_counts.pop(symbol, 0)
|
||||
if previous_failures > 0:
|
||||
logger.info(
|
||||
"%s quote polling recovered after %d consecutive failures",
|
||||
symbol,
|
||||
previous_failures,
|
||||
)
|
||||
|
||||
price_data = {
|
||||
"symbol": symbol,
|
||||
@@ -128,7 +141,20 @@ class PollingPriceManager:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch {symbol} price: {e}")
|
||||
failure_count = self._failure_counts.get(symbol, 0) + 1
|
||||
self._failure_counts[symbol] = failure_count
|
||||
message = f"Failed to fetch {symbol} price: {e}"
|
||||
|
||||
if failure_count == 1:
|
||||
logger.warning(message)
|
||||
elif failure_count % _SUPPRESSED_LOG_EVERY == 0:
|
||||
logger.warning(
|
||||
"%s (repeated %d times; suppressing intermediate failures)",
|
||||
message,
|
||||
failure_count,
|
||||
)
|
||||
else:
|
||||
logger.debug(message)
|
||||
|
||||
def _fetch_quote(self, symbol: str) -> Dict[str, float]:
|
||||
"""Fetch a normalized quote payload from the configured provider."""
|
||||
@@ -136,7 +162,10 @@ class PollingPriceManager:
|
||||
return self._fetch_yfinance_quote(symbol)
|
||||
if not self.finnhub_client:
|
||||
raise ValueError("Finnhub API key required for finnhub polling")
|
||||
return self.finnhub_client.quote(symbol)
|
||||
quote = self.finnhub_client.quote(symbol)
|
||||
if not isinstance(quote, dict):
|
||||
raise ValueError(f"{symbol}: Invalid Finnhub quote payload")
|
||||
return quote
|
||||
|
||||
def _fetch_yfinance_quote(self, symbol: str) -> Dict[str, float]:
|
||||
"""Fetch quote data from yfinance and normalize to Finnhub-like keys."""
|
||||
@@ -162,6 +191,8 @@ class PollingPriceManager:
|
||||
|
||||
if current_price is None:
|
||||
history = ticker.history(period="1d", interval="1m", auto_adjust=False)
|
||||
if history is None:
|
||||
raise ValueError(f"{symbol}: yfinance returned no history frame")
|
||||
if history.empty:
|
||||
raise ValueError(f"{symbol}: No yfinance quote data")
|
||||
latest = history.iloc[-1]
|
||||
|
||||
@@ -11,7 +11,7 @@ import pandas as pd
|
||||
import yfinance as yf
|
||||
|
||||
from backend.config.data_config import DataSource, get_data_sources
|
||||
from backend.data.schema import (
|
||||
from shared.schema import (
|
||||
CompanyFactsResponse,
|
||||
CompanyNews,
|
||||
CompanyNewsResponse,
|
||||
|
||||
@@ -1,194 +1,50 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pydantic import BaseModel
|
||||
"""Compatibility schema bridge.
|
||||
|
||||
This module preserves the legacy ``backend.data.schema`` import path while
|
||||
delegating the actual schema definitions to ``shared.schema``. Keeping one
|
||||
canonical DTO set avoids drift as the monolith is split into service-specific
|
||||
packages.
|
||||
"""
|
||||
|
||||
class Price(BaseModel):
|
||||
open: float
|
||||
close: float
|
||||
high: float
|
||||
low: float
|
||||
volume: int
|
||||
time: str
|
||||
from shared.schema import (
|
||||
AgentStateData,
|
||||
AgentStateMetadata,
|
||||
AnalystSignal,
|
||||
CompanyFacts,
|
||||
CompanyFactsResponse,
|
||||
CompanyNews,
|
||||
CompanyNewsResponse,
|
||||
FinancialMetrics,
|
||||
FinancialMetricsResponse,
|
||||
InsiderTrade,
|
||||
InsiderTradeResponse,
|
||||
LineItem,
|
||||
LineItemResponse,
|
||||
Portfolio,
|
||||
Position,
|
||||
Price,
|
||||
PriceResponse,
|
||||
TickerAnalysis,
|
||||
)
|
||||
|
||||
|
||||
class PriceResponse(BaseModel):
|
||||
ticker: str
|
||||
prices: list[Price]
|
||||
|
||||
|
||||
class FinancialMetrics(BaseModel):
|
||||
ticker: str
|
||||
report_period: str
|
||||
period: str
|
||||
currency: str
|
||||
market_cap: float | None
|
||||
enterprise_value: float | None
|
||||
price_to_earnings_ratio: float | None
|
||||
price_to_book_ratio: float | None
|
||||
price_to_sales_ratio: float | None
|
||||
enterprise_value_to_ebitda_ratio: float | None
|
||||
enterprise_value_to_revenue_ratio: float | None
|
||||
free_cash_flow_yield: float | None
|
||||
peg_ratio: float | None
|
||||
gross_margin: float | None
|
||||
operating_margin: float | None
|
||||
net_margin: float | None
|
||||
return_on_equity: float | None
|
||||
return_on_assets: float | None
|
||||
return_on_invested_capital: float | None
|
||||
asset_turnover: float | None
|
||||
inventory_turnover: float | None
|
||||
receivables_turnover: float | None
|
||||
days_sales_outstanding: float | None
|
||||
operating_cycle: float | None
|
||||
working_capital_turnover: float | None
|
||||
current_ratio: float | None
|
||||
quick_ratio: float | None
|
||||
cash_ratio: float | None
|
||||
operating_cash_flow_ratio: float | None
|
||||
debt_to_equity: float | None
|
||||
debt_to_assets: float | None
|
||||
interest_coverage: float | None
|
||||
revenue_growth: float | None
|
||||
earnings_growth: float | None
|
||||
book_value_growth: float | None
|
||||
earnings_per_share_growth: float | None
|
||||
free_cash_flow_growth: float | None
|
||||
operating_income_growth: float | None
|
||||
ebitda_growth: float | None
|
||||
payout_ratio: float | None
|
||||
earnings_per_share: float | None
|
||||
book_value_per_share: float | None
|
||||
free_cash_flow_per_share: float | None
|
||||
|
||||
|
||||
class FinancialMetricsResponse(BaseModel):
|
||||
financial_metrics: list[FinancialMetrics]
|
||||
|
||||
|
||||
class LineItem(BaseModel):
|
||||
ticker: str
|
||||
report_period: str
|
||||
period: str
|
||||
currency: str
|
||||
|
||||
# Allow additional fields dynamically
|
||||
model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class LineItemResponse(BaseModel):
|
||||
search_results: list[LineItem]
|
||||
|
||||
|
||||
class InsiderTrade(BaseModel):
|
||||
ticker: str
|
||||
issuer: str | None
|
||||
name: str | None
|
||||
title: str | None
|
||||
is_board_director: bool | None
|
||||
transaction_date: str | None
|
||||
transaction_shares: float | None
|
||||
transaction_price_per_share: float | None
|
||||
transaction_value: float | None
|
||||
shares_owned_before_transaction: float | None
|
||||
shares_owned_after_transaction: float | None
|
||||
security_title: str | None
|
||||
filing_date: str
|
||||
|
||||
|
||||
class InsiderTradeResponse(BaseModel):
|
||||
insider_trades: list[InsiderTrade]
|
||||
|
||||
|
||||
class CompanyNews(BaseModel):
|
||||
category: str | None = None
|
||||
ticker: str
|
||||
title: str
|
||||
related: str | None = None
|
||||
source: str
|
||||
date: str | None = None
|
||||
url: str
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class CompanyNewsResponse(BaseModel):
|
||||
news: list[CompanyNews]
|
||||
|
||||
|
||||
class CompanyFacts(BaseModel):
|
||||
ticker: str
|
||||
name: str
|
||||
cik: str | None = None
|
||||
industry: str | None = None
|
||||
sector: str | None = None
|
||||
category: str | None = None
|
||||
exchange: str | None = None
|
||||
is_active: bool | None = None
|
||||
listing_date: str | None = None
|
||||
location: str | None = None
|
||||
market_cap: float | None = None
|
||||
number_of_employees: int | None = None
|
||||
sec_filings_url: str | None = None
|
||||
sic_code: str | None = None
|
||||
sic_industry: str | None = None
|
||||
sic_sector: str | None = None
|
||||
website_url: str | None = None
|
||||
weighted_average_shares: int | None = None
|
||||
|
||||
|
||||
class CompanyFactsResponse(BaseModel):
|
||||
company_facts: CompanyFacts
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
"""Position information - for Portfolio mode"""
|
||||
|
||||
long: int = 0 # Long position quantity (shares)
|
||||
short: int = 0 # Short position quantity (shares)
|
||||
long_cost_basis: float = 0.0 # Long position average cost
|
||||
short_cost_basis: float = 0.0 # Short position average cost
|
||||
|
||||
|
||||
class Portfolio(BaseModel):
|
||||
"""Portfolio - for Portfolio mode"""
|
||||
|
||||
cash: float = 100000.0 # Available cash
|
||||
positions: dict[str, Position] = {} # ticker -> Position mapping
|
||||
# Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin)
|
||||
margin_requirement: float = 0.0
|
||||
margin_used: float = 0.0 # Margin used
|
||||
|
||||
|
||||
class AnalystSignal(BaseModel):
|
||||
signal: str | None = None
|
||||
confidence: float | None = None
|
||||
reasoning: dict | str | None = None
|
||||
# Extended fields for richer signal information
|
||||
reasons: list[str] | None = None # Core drivers/reasons for the signal
|
||||
risks: list[str] | None = None # Key risk factors
|
||||
invalidation: str | None = None # Conditions that would invalidate the thesis
|
||||
next_action: str | None = None # Suggested next action for PM
|
||||
# Valuation-related fields
|
||||
intrinsic_value: float | None = None # DCF intrinsic value
|
||||
fair_value_range: dict | None = None # {bear, base, bull} fair value range
|
||||
value_gap_pct: float | None = None # Value gap percentage
|
||||
valuation_methods: list[str] | None = None # List of valuation methods used
|
||||
max_position_size: float | None = None # For risk management signals
|
||||
|
||||
|
||||
class TickerAnalysis(BaseModel):
|
||||
ticker: str
|
||||
analyst_signals: dict[str, AnalystSignal] # agent_name -> signal mapping
|
||||
|
||||
|
||||
class AgentStateData(BaseModel):
|
||||
tickers: list[str]
|
||||
portfolio: Portfolio
|
||||
start_date: str
|
||||
end_date: str
|
||||
ticker_analyses: dict[str, TickerAnalysis] # ticker -> analysis mapping
|
||||
|
||||
|
||||
class AgentStateMetadata(BaseModel):
|
||||
show_reasoning: bool = False
|
||||
model_config = {"extra": "allow"}
|
||||
__all__ = [
|
||||
"Price",
|
||||
"PriceResponse",
|
||||
"FinancialMetrics",
|
||||
"FinancialMetricsResponse",
|
||||
"LineItem",
|
||||
"LineItemResponse",
|
||||
"InsiderTrade",
|
||||
"InsiderTradeResponse",
|
||||
"CompanyNews",
|
||||
"CompanyNewsResponse",
|
||||
"CompanyFacts",
|
||||
"CompanyFactsResponse",
|
||||
"Position",
|
||||
"Portfolio",
|
||||
"AnalystSignal",
|
||||
"TickerAnalysis",
|
||||
"AgentStateData",
|
||||
"AgentStateMetadata",
|
||||
]
|
||||
|
||||
2
backend/domains/__init__.py
Normal file
2
backend/domains/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Domain modules for split service internals."""
|
||||
320
backend/domains/news.py
Normal file
320
backend/domains/news.py
Normal file
@@ -0,0 +1,320 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""News/explain domain helpers shared by app surfaces and gateway fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.data.market_ingest import update_ticker_incremental
|
||||
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||
from backend.explain.range_explainer import build_range_explanation
|
||||
from backend.explain.similarity_service import find_similar_days
|
||||
from backend.explain.story_service import get_or_create_stock_story
|
||||
|
||||
|
||||
def news_rows_need_enrichment(rows: list[dict[str, Any]]) -> bool:
|
||||
"""Return whether news rows are missing explain-oriented analysis fields."""
|
||||
if not rows:
|
||||
return True
|
||||
return all(
|
||||
not row.get("sentiment")
|
||||
and not row.get("relevance")
|
||||
and not row.get("key_discussion")
|
||||
for row in rows
|
||||
)
|
||||
|
||||
|
||||
def ensure_news_fresh(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
target_date: str | None = None,
|
||||
refresh_if_stale: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Refresh raw news incrementally when stored watermarks are stale."""
|
||||
normalized_target = str(target_date or "").strip()[:10]
|
||||
if not normalized_target:
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"target_date": None,
|
||||
"last_news_fetch": None,
|
||||
"refreshed": False,
|
||||
}
|
||||
|
||||
watermarks = store.get_ticker_watermarks(ticker)
|
||||
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
|
||||
refreshed = False
|
||||
if refresh_if_stale and (not last_news_fetch or last_news_fetch < normalized_target):
|
||||
update_ticker_incremental(
|
||||
ticker,
|
||||
end_date=normalized_target,
|
||||
store=store,
|
||||
)
|
||||
refreshed = True
|
||||
watermarks = store.get_ticker_watermarks(ticker)
|
||||
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
|
||||
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"target_date": normalized_target,
|
||||
"last_news_fetch": last_news_fetch or None,
|
||||
"refreshed": refreshed,
|
||||
}
|
||||
|
||||
|
||||
def get_enriched_news(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 100,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=end_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
if news_rows_need_enrichment(rows):
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
return {"ticker": ticker, "news": rows, "freshness": freshness}
|
||||
|
||||
|
||||
def get_news_for_date(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
date: str,
|
||||
limit: int = 20,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
trade_date=date,
|
||||
limit=limit,
|
||||
)
|
||||
if news_rows_need_enrichment(rows):
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=date,
|
||||
end_date=date,
|
||||
limit=limit,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
trade_date=date,
|
||||
limit=limit,
|
||||
)
|
||||
return {"ticker": ticker, "date": date, "news": rows, "freshness": freshness}
|
||||
|
||||
|
||||
def get_news_timeline(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=end_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
timeline = store.get_news_timeline_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
if not timeline:
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=200,
|
||||
)
|
||||
timeline = store.get_news_timeline_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"timeline": timeline,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"freshness": freshness,
|
||||
}
|
||||
|
||||
|
||||
def get_news_categories(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str | None = None,
|
||||
end_date: str | None = None,
|
||||
limit: int = 200,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=end_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
if news_rows_need_enrichment(rows):
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
categories = store.get_news_categories_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
return {"ticker": ticker, "categories": categories, "freshness": freshness}
|
||||
|
||||
|
||||
def get_similar_days_payload(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
date: str,
|
||||
n_similar: int = 5,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
result = find_similar_days(
|
||||
store,
|
||||
symbol=ticker,
|
||||
target_date=date,
|
||||
top_k=n_similar,
|
||||
)
|
||||
result["freshness"] = freshness
|
||||
return result
|
||||
|
||||
|
||||
def get_story_payload(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
as_of_date: str,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=as_of_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
end_date=as_of_date,
|
||||
limit=80,
|
||||
)
|
||||
result = get_or_create_stock_story(
|
||||
store,
|
||||
symbol=ticker,
|
||||
as_of_date=as_of_date,
|
||||
)
|
||||
result["freshness"] = freshness
|
||||
return result
|
||||
|
||||
|
||||
def get_range_explain_payload(
|
||||
store: MarketStore,
|
||||
*,
|
||||
ticker: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
article_ids: list[str] | None = None,
|
||||
limit: int = 100,
|
||||
refresh_if_stale: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
freshness = ensure_news_fresh(
|
||||
store,
|
||||
ticker=ticker,
|
||||
target_date=end_date,
|
||||
refresh_if_stale=refresh_if_stale,
|
||||
)
|
||||
news_rows = []
|
||||
if article_ids:
|
||||
news_rows = store.get_news_by_ids_enriched(ticker, article_ids)
|
||||
if not news_rows:
|
||||
news_rows = store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
if news_rows_need_enrichment(news_rows):
|
||||
enrich_news_for_symbol(
|
||||
store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
news_rows = (
|
||||
store.get_news_by_ids_enriched(ticker, article_ids)
|
||||
if article_ids
|
||||
else store.get_news_items_enriched(
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
)
|
||||
)
|
||||
result = build_range_explanation(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
news_rows=news_rows,
|
||||
)
|
||||
return {"ticker": ticker, "result": result, "freshness": freshness}
|
||||
106
backend/domains/trading.py
Normal file
106
backend/domains/trading.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Trading domain helpers shared by app surfaces and gateway fallbacks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.services.market import MarketService
|
||||
from backend.tools.data_tools import (
|
||||
get_company_news,
|
||||
get_financial_metrics,
|
||||
get_insider_trades,
|
||||
get_market_cap,
|
||||
get_prices,
|
||||
search_line_items,
|
||||
)
|
||||
|
||||
|
||||
def get_prices_payload(*, ticker: str, start_date: str, end_date: str) -> dict[str, Any]:
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"prices": get_prices(ticker, start_date, end_date),
|
||||
}
|
||||
|
||||
|
||||
def get_financials_payload(
|
||||
*,
|
||||
ticker: str,
|
||||
end_date: str,
|
||||
period: str = "ttm",
|
||||
limit: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"financial_metrics": get_financial_metrics(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_news_payload(
|
||||
*,
|
||||
ticker: str,
|
||||
end_date: str,
|
||||
start_date: str | None = None,
|
||||
limit: int = 1000,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"news": get_company_news(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_insider_trades_payload(
|
||||
*,
|
||||
ticker: str,
|
||||
end_date: str,
|
||||
start_date: str | None = None,
|
||||
limit: int = 1000,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"insider_trades": get_insider_trades(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def get_market_status_payload() -> dict[str, Any]:
|
||||
market_service = MarketService(tickers=[])
|
||||
return market_service.get_market_status()
|
||||
|
||||
|
||||
def get_market_cap_payload(*, ticker: str, end_date: str) -> dict[str, Any]:
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"end_date": end_date,
|
||||
"market_cap": get_market_cap(ticker, end_date),
|
||||
}
|
||||
|
||||
|
||||
def get_line_items_payload(
|
||||
*,
|
||||
ticker: str,
|
||||
line_items: list[str],
|
||||
end_date: str,
|
||||
period: str = "ttm",
|
||||
limit: int = 10,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"search_results": search_line_items(
|
||||
ticker=ticker,
|
||||
line_items=line_items,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
}
|
||||
308
backend/gateway_server.py
Normal file
308
backend/gateway_server.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Gateway Server - Entry point for Gateway subprocess.
|
||||
|
||||
This module is launched as a subprocess by the Control Plane (FastAPI)
|
||||
to run the Data Plane (Gateway + Pipeline).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.pipeline_runner import create_agents, create_long_term_memory
|
||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.runtime.manager import (
|
||||
TradingRuntimeManager,
|
||||
set_global_runtime_manager,
|
||||
clear_global_runtime_manager,
|
||||
)
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
|
||||
INFO_LOGGER_PREFIXES = (
|
||||
"backend.agents",
|
||||
"backend.core.pipeline",
|
||||
"backend.core.scheduler",
|
||||
"backend.services.gateway_cycle_support",
|
||||
"backend.utils.terminal_dashboard",
|
||||
)
|
||||
|
||||
NOISY_LOGGER_LEVELS = {
|
||||
"aiohttp": logging.WARNING,
|
||||
"asyncio": logging.WARNING,
|
||||
"dashscope": logging.WARNING,
|
||||
"finnhub": logging.WARNING,
|
||||
"httpcore": logging.WARNING,
|
||||
"httpx": logging.WARNING,
|
||||
"urllib3": logging.WARNING,
|
||||
"websockets": logging.WARNING,
|
||||
"yfinance": logging.WARNING,
|
||||
"backend.data.polling_price_manager": logging.WARNING,
|
||||
"backend.services.gateway": logging.WARNING,
|
||||
"backend.services.market": logging.WARNING,
|
||||
"backend.services.storage": logging.WARNING,
|
||||
}
|
||||
|
||||
|
||||
class SuppressNoisyInfoFilter(logging.Filter):
|
||||
"""Filter out low-signal library INFO logs while keeping warnings/errors."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.levelno >= logging.WARNING:
|
||||
return True
|
||||
|
||||
message = record.getMessage()
|
||||
if record.name == "httpx" and message.startswith("HTTP Request:"):
|
||||
return False
|
||||
if record.name.startswith("websockets") and "connection open" in message:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def configure_gateway_logging(verbose: bool = False) -> None:
|
||||
"""Configure gateway logging with low-noise defaults for runtime logs."""
|
||||
root_level = logging.DEBUG if verbose else logging.WARNING
|
||||
logging.basicConfig(
|
||||
level=root_level,
|
||||
format="%(asctime)s | %(levelname)-7s | %(name)s:%(lineno)d - %(message)s",
|
||||
force=True,
|
||||
)
|
||||
|
||||
if not verbose:
|
||||
suppress_filter = SuppressNoisyInfoFilter()
|
||||
for handler in logging.getLogger().handlers:
|
||||
handler.addFilter(suppress_filter)
|
||||
|
||||
for logger_name, level in NOISY_LOGGER_LEVELS.items():
|
||||
logging.getLogger(logger_name).setLevel(logging.DEBUG if verbose else level)
|
||||
|
||||
if not verbose:
|
||||
for prefix in INFO_LOGGER_PREFIXES:
|
||||
logging.getLogger(prefix).setLevel(logging.INFO)
|
||||
|
||||
logging.getLogger(__name__).setLevel(logging.INFO if not verbose else logging.DEBUG)
|
||||
|
||||
|
||||
async def run_gateway(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
bootstrap: dict,
|
||||
port: int
|
||||
):
|
||||
"""Run Gateway with Pipeline."""
|
||||
|
||||
# Extract config
|
||||
tickers = bootstrap.get("tickers", ["AAPL", "MSFT"])
|
||||
initial_cash = float(bootstrap.get("initial_cash", 100000.0))
|
||||
margin_requirement = float(bootstrap.get("margin_requirement", 0.0))
|
||||
max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2))
|
||||
schedule_mode = bootstrap.get("schedule_mode", "daily")
|
||||
trigger_time = bootstrap.get("trigger_time", "09:30")
|
||||
interval_minutes = int(bootstrap.get("interval_minutes", 60))
|
||||
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0)) # 0 = disabled
|
||||
mode = bootstrap.get("mode", "live")
|
||||
start_date = bootstrap.get("start_date")
|
||||
end_date = bootstrap.get("end_date")
|
||||
enable_memory = bootstrap.get("enable_memory", False)
|
||||
poll_interval = int(bootstrap.get("poll_interval", 10))
|
||||
|
||||
is_backtest = mode == "backtest"
|
||||
|
||||
logger.info(f"[Gateway Server] Starting run {run_id} on port {port}")
|
||||
|
||||
# Create runtime manager
|
||||
runtime_manager = TradingRuntimeManager(
|
||||
config_name=run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
)
|
||||
runtime_manager.prepare_run()
|
||||
set_global_runtime_manager(runtime_manager)
|
||||
|
||||
try:
|
||||
async with AsyncExitStack() as stack:
|
||||
# Create services
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=poll_interval,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
|
||||
backtest_start_date=start_date if is_backtest else None,
|
||||
backtest_end_date=end_date if is_backtest else None,
|
||||
)
|
||||
|
||||
storage_service = StorageService(
|
||||
dashboard_dir=run_dir / "team_dashboard",
|
||||
initial_cash=initial_cash,
|
||||
config_name=run_id,
|
||||
)
|
||||
|
||||
if not storage_service.files["summary"].exists():
|
||||
storage_service.initialize_empty_dashboard()
|
||||
else:
|
||||
storage_service.update_leaderboard_model_info()
|
||||
|
||||
# Create agents
|
||||
analysts, risk_manager, pm, long_term_memories = create_agents(
|
||||
run_id=run_id,
|
||||
run_dir=run_dir,
|
||||
initial_cash=initial_cash,
|
||||
margin_requirement=margin_requirement,
|
||||
enable_long_term_memory=enable_memory,
|
||||
)
|
||||
|
||||
# Register agents
|
||||
for agent in analysts + [risk_manager, pm]:
|
||||
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
|
||||
if agent_id:
|
||||
runtime_manager.register_agent(agent_id)
|
||||
|
||||
# Load portfolio state
|
||||
portfolio_state = storage_service.load_portfolio_state()
|
||||
pm.load_portfolio_state(portfolio_state)
|
||||
|
||||
# Create settlement coordinator
|
||||
settlement_coordinator = SettlementCoordinator(
|
||||
storage=storage_service,
|
||||
initial_capital=initial_cash,
|
||||
)
|
||||
|
||||
# Create pipeline
|
||||
pipeline = TradingPipeline(
|
||||
analysts=analysts,
|
||||
risk_manager=risk_manager,
|
||||
portfolio_manager=pm,
|
||||
settlement_coordinator=settlement_coordinator,
|
||||
max_comm_cycles=max_comm_cycles,
|
||||
runtime_manager=runtime_manager,
|
||||
)
|
||||
|
||||
# Create scheduler
|
||||
scheduler_callback = None
|
||||
live_scheduler = None
|
||||
|
||||
if is_backtest:
|
||||
backtest_scheduler = BacktestScheduler(
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
trading_calendar="NYSE",
|
||||
delay_between_days=0.5,
|
||||
)
|
||||
|
||||
async def scheduler_callback_fn(callback):
|
||||
await backtest_scheduler.start(callback)
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
else:
|
||||
live_scheduler = Scheduler(
|
||||
mode=schedule_mode,
|
||||
trigger_time=trigger_time,
|
||||
interval_minutes=interval_minutes,
|
||||
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
|
||||
config={"config_name": run_id},
|
||||
)
|
||||
|
||||
async def scheduler_callback_fn(callback):
|
||||
await live_scheduler.start(callback)
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
|
||||
# Enter long-term memory contexts
|
||||
for memory in long_term_memories:
|
||||
await stack.enter_async_context(memory)
|
||||
|
||||
# Create Gateway
|
||||
gateway = Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage_service,
|
||||
pipeline=pipeline,
|
||||
scheduler_callback=scheduler_callback,
|
||||
config={
|
||||
"mode": mode,
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": run_id,
|
||||
"schedule_mode": schedule_mode,
|
||||
"interval_minutes": interval_minutes,
|
||||
"trigger_time": trigger_time,
|
||||
"heartbeat_interval": heartbeat_interval,
|
||||
"initial_cash": initial_cash,
|
||||
"margin_requirement": margin_requirement,
|
||||
"max_comm_cycles": max_comm_cycles,
|
||||
"enable_memory": enable_memory,
|
||||
},
|
||||
scheduler=live_scheduler,
|
||||
)
|
||||
|
||||
# Start Gateway (blocks until shutdown)
|
||||
logger.info(f"[Gateway Server] Gateway starting on port {port}")
|
||||
await gateway.start(host="0.0.0.0", port=port)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("[Gateway Server] Cancelled")
|
||||
raise
|
||||
finally:
|
||||
logger.info("[Gateway Server] Cleaning up")
|
||||
clear_global_runtime_manager()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser(description="Gateway Server")
|
||||
parser.add_argument("--run-id", required=True, help="Run identifier")
|
||||
parser.add_argument("--run-dir", required=True, help="Run directory path")
|
||||
parser.add_argument("--port", type=int, default=8765, help="WebSocket port")
|
||||
parser.add_argument("--bootstrap", required=True, help="Bootstrap config as JSON")
|
||||
parser.add_argument("--verbose", action="store_true", help="Verbose logging")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
configure_gateway_logging(verbose=args.verbose)
|
||||
|
||||
# Parse bootstrap
|
||||
bootstrap = json.loads(args.bootstrap)
|
||||
run_dir = Path(args.run_dir)
|
||||
|
||||
# Run
|
||||
try:
|
||||
asyncio.run(run_gateway(
|
||||
run_id=args.run_id,
|
||||
run_dir=run_dir,
|
||||
bootstrap=bootstrap,
|
||||
port=args.port
|
||||
))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("[Gateway Server] Interrupted by user")
|
||||
except Exception as e:
|
||||
logger.exception(f"[Gateway Server] Fatal error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -3,6 +3,8 @@
|
||||
AgentScope Native Model Factory
|
||||
Uses native AgentScope model classes for LLM calls
|
||||
"""
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
@@ -34,6 +36,27 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _usage_value(usage: Any, key: str, default: Any = 0) -> Any:
|
||||
"""Read usage fields from both object-style and dict-style usage payloads."""
|
||||
if usage is None:
|
||||
return default
|
||||
if isinstance(usage, dict):
|
||||
return usage.get(key, default)
|
||||
try:
|
||||
return getattr(usage, key)
|
||||
except (AttributeError, KeyError):
|
||||
return default
|
||||
|
||||
|
||||
def _usage_total_tokens(usage: Any) -> int:
|
||||
total = _usage_value(usage, "total_tokens", None)
|
||||
if total is not None:
|
||||
return int(total or 0)
|
||||
input_tokens = _usage_value(usage, "input_tokens", 0)
|
||||
output_tokens = _usage_value(usage, "output_tokens", 0)
|
||||
return int((input_tokens or 0) + (output_tokens or 0))
|
||||
|
||||
|
||||
class RetryChatModel:
|
||||
"""Wraps an AgentScope model with automatic retry for transient errors.
|
||||
|
||||
@@ -55,6 +78,7 @@ class RetryChatModel:
|
||||
"502",
|
||||
"504",
|
||||
"connection",
|
||||
"disconnected",
|
||||
"temporary",
|
||||
"overloaded",
|
||||
"too_many_requests",
|
||||
@@ -150,8 +174,8 @@ class RetryChatModel:
|
||||
# Track usage if available
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
usage = result.usage
|
||||
self._total_tokens_used += getattr(usage, "total_tokens", 0)
|
||||
self._total_cost += getattr(usage, "cost", 0.0)
|
||||
self._total_tokens_used += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
return result
|
||||
|
||||
@@ -192,9 +216,66 @@ class RetryChatModel:
|
||||
raise last_error
|
||||
raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")
|
||||
|
||||
async def _call_with_retry_async(self, func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""Call an async function with retry logic for transient errors."""
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(1, self._max_retries + 1):
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
usage = result.usage
|
||||
self._total_tokens_used += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt >= self._max_retries:
|
||||
logger.error(
|
||||
"RetryChatModel: Max retries (%d) exhausted for %s",
|
||||
self._max_retries,
|
||||
self.model_name,
|
||||
)
|
||||
break
|
||||
|
||||
if not self._is_transient_error(e):
|
||||
logger.warning(
|
||||
"RetryChatModel: Non-transient error, not retrying: %s",
|
||||
str(e),
|
||||
)
|
||||
break
|
||||
|
||||
delay = self._calculate_delay(attempt)
|
||||
logger.warning(
|
||||
"RetryChatModel: Transient async error on attempt %d/%d, "
|
||||
"retrying in %.1fs: %s",
|
||||
attempt,
|
||||
self._max_retries,
|
||||
delay,
|
||||
str(e)[:200],
|
||||
)
|
||||
|
||||
if self._on_retry:
|
||||
self._on_retry(attempt, e, delay)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise RuntimeError("RetryChatModel: Unexpected async state, no error but no result")
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Forward calls to the wrapped model with retry logic."""
|
||||
return self._call_with_retry(self._model, *args, **kwargs)
|
||||
model_call = getattr(self._model, "__call__", None)
|
||||
if inspect.iscoroutinefunction(self._model) or inspect.iscoroutinefunction(model_call):
|
||||
return self._call_with_retry_async(self._model, *args, **kwargs)
|
||||
|
||||
result = self._model(*args, **kwargs)
|
||||
return result
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Proxy attribute access to the wrapped model."""
|
||||
@@ -248,10 +329,18 @@ class TokenRecordingModelWrapper:
|
||||
if usage is None:
|
||||
return
|
||||
|
||||
self._prompt_tokens += getattr(usage, "prompt_tokens", 0)
|
||||
self._completion_tokens += getattr(usage, "completion_tokens", 0)
|
||||
self._total_tokens += getattr(usage, "total_tokens", 0)
|
||||
self._total_cost += getattr(usage, "cost", 0.0)
|
||||
prompt_tokens = _usage_value(usage, "prompt_tokens", None)
|
||||
completion_tokens = _usage_value(usage, "completion_tokens", None)
|
||||
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = _usage_value(usage, "input_tokens", 0)
|
||||
if completion_tokens is None:
|
||||
completion_tokens = _usage_value(usage, "output_tokens", 0)
|
||||
|
||||
self._prompt_tokens += int(prompt_tokens or 0)
|
||||
self._completion_tokens += int(completion_tokens or 0)
|
||||
self._total_tokens += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Forward calls and record usage."""
|
||||
@@ -401,7 +490,8 @@ def create_model(
|
||||
if host:
|
||||
model_kwargs["host"] = host
|
||||
|
||||
return model_class(**model_kwargs)
|
||||
model = model_class(**model_kwargs)
|
||||
return RetryChatModel(model)
|
||||
|
||||
|
||||
def get_agent_model(agent_id: str, stream: bool = False):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Main Entry Point
|
||||
Supports: backtest, live, mock modes
|
||||
Supports: backtest, live modes
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
@@ -16,7 +16,7 @@ from dotenv import load_dotenv
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.prompt_loader import get_prompt_loader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.bootstrap_config import resolve_runtime_config
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
@@ -38,7 +38,7 @@ load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
loguru.logger.disable("flowllm")
|
||||
loguru.logger.disable("reme_ai")
|
||||
_prompt_loader = PromptLoader()
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
|
||||
def _get_run_dir(config_name: str) -> Path:
|
||||
@@ -226,17 +226,13 @@ async def run_with_gateway(args):
|
||||
)
|
||||
runtime_manager.prepare_run()
|
||||
set_global_runtime_manager(runtime_manager)
|
||||
register_runtime_manager(runtime_manager)
|
||||
|
||||
# Create market service
|
||||
market_service = MarketService(
|
||||
tickers=tickers,
|
||||
poll_interval=args.poll_interval,
|
||||
mock_mode=args.mock and not is_backtest,
|
||||
backtest_mode=is_backtest,
|
||||
api_key=os.getenv("FINNHUB_API_KEY")
|
||||
if not args.mock and not is_backtest
|
||||
else None,
|
||||
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
|
||||
backtest_start_date=args.start_date if is_backtest else None,
|
||||
backtest_end_date=args.end_date if is_backtest else None,
|
||||
)
|
||||
@@ -321,7 +317,6 @@ async def run_with_gateway(args):
|
||||
scheduler_callback=scheduler_callback,
|
||||
config={
|
||||
"mode": args.mode,
|
||||
"mock_mode": args.mock,
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": config_name,
|
||||
@@ -354,8 +349,7 @@ def main():
|
||||
"""Main entry point"""
|
||||
parser = argparse.ArgumentParser(description="Trading System")
|
||||
parser.add_argument("--mode", choices=["live", "backtest"], default="live")
|
||||
parser.add_argument("--mock", action="store_true")
|
||||
parser.add_argument("--config-name", default="mock")
|
||||
parser.add_argument("--config-name", default="live")
|
||||
parser.add_argument("--host", default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=8765)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -13,15 +13,30 @@ from .registry import RuntimeRegistry
|
||||
_global_runtime_manager: Optional["TradingRuntimeManager"] = None
|
||||
_shutdown_event: Optional[asyncio.Event] = None
|
||||
|
||||
# Lazy import to avoid circular dependency
|
||||
_api_runtime = None
|
||||
|
||||
|
||||
def _get_api_runtime():
|
||||
global _api_runtime
|
||||
if _api_runtime is None:
|
||||
from backend.api import runtime as api_runtime_module
|
||||
_api_runtime = api_runtime_module
|
||||
return _api_runtime
|
||||
|
||||
|
||||
def set_global_runtime_manager(manager: "TradingRuntimeManager") -> None:
|
||||
global _global_runtime_manager
|
||||
_global_runtime_manager = manager
|
||||
# Sync to RuntimeState for consistency
|
||||
_get_api_runtime().register_runtime_manager(manager)
|
||||
|
||||
|
||||
def clear_global_runtime_manager() -> None:
|
||||
global _global_runtime_manager
|
||||
_global_runtime_manager = None
|
||||
# Sync to RuntimeState for consistency
|
||||
_get_api_runtime().unregister_runtime_manager()
|
||||
|
||||
|
||||
def get_global_runtime_manager() -> Optional["TradingRuntimeManager"]:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
426
backend/services/gateway_admin_handlers.py
Normal file
426
backend/services/gateway_admin_handlers.py
Normal file
@@ -0,0 +1,426 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Runtime/workspace/skills handlers extracted from the main Gateway module.
|
||||
|
||||
Deprecated note:
|
||||
Agent/workspace/skill read-write operations are being migrated to
|
||||
agent_service REST endpoints. These websocket handlers remain as a
|
||||
compatibility fallback and should not be considered the primary control
|
||||
plane path for frontend reads/writes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from backend.agents.agent_workspace import load_agent_workspace_config
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import load_agent_profiles
|
||||
from backend.config.bootstrap_config import (
|
||||
get_bootstrap_config_for_run,
|
||||
resolve_runtime_config,
|
||||
update_bootstrap_values_for_run,
|
||||
)
|
||||
from backend.data.market_ingest import ingest_symbols
|
||||
from backend.llm.models import get_agent_model_info
|
||||
|
||||
|
||||
async def handle_reload_runtime_assets(gateway: Any) -> None:
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
runtime_config = resolve_runtime_config(
|
||||
project_root=gateway._project_root,
|
||||
config_name=config_name,
|
||||
enable_memory=gateway.config.get("enable_memory", False),
|
||||
schedule_mode=gateway.config.get("schedule_mode", "daily"),
|
||||
interval_minutes=gateway.config.get("interval_minutes", 60),
|
||||
trigger_time=gateway.config.get("trigger_time", "09:30"),
|
||||
)
|
||||
result = gateway.pipeline.reload_runtime_assets(runtime_config=runtime_config)
|
||||
runtime_updates = gateway._apply_runtime_config(runtime_config)
|
||||
await gateway.state_sync.on_system_message("Runtime assets reloaded.")
|
||||
await gateway.broadcast({"type": "runtime_assets_reloaded", **result, **runtime_updates})
|
||||
|
||||
|
||||
async def handle_update_runtime_config(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
updates: dict[str, Any] = {}
|
||||
|
||||
schedule_mode = str(data.get("schedule_mode", "")).strip().lower()
|
||||
if schedule_mode:
|
||||
if schedule_mode not in {"daily", "intraday"}:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "schedule_mode must be 'daily' or 'intraday'."}, ensure_ascii=False))
|
||||
return
|
||||
updates["schedule_mode"] = schedule_mode
|
||||
|
||||
interval_minutes = data.get("interval_minutes")
|
||||
if interval_minutes is not None:
|
||||
try:
|
||||
parsed_interval = int(interval_minutes)
|
||||
except (TypeError, ValueError):
|
||||
parsed_interval = 0
|
||||
if parsed_interval <= 0:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "interval_minutes must be a positive integer."}, ensure_ascii=False))
|
||||
return
|
||||
updates["interval_minutes"] = parsed_interval
|
||||
|
||||
trigger_time = data.get("trigger_time")
|
||||
if trigger_time is not None:
|
||||
raw_trigger = str(trigger_time).strip()
|
||||
if raw_trigger and raw_trigger != "now":
|
||||
try:
|
||||
datetime.strptime(raw_trigger, "%H:%M")
|
||||
except ValueError:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "trigger_time must use HH:MM or 'now'."}, ensure_ascii=False))
|
||||
return
|
||||
updates["trigger_time"] = raw_trigger or "09:30"
|
||||
|
||||
max_comm_cycles = data.get("max_comm_cycles")
|
||||
if max_comm_cycles is not None:
|
||||
try:
|
||||
parsed_cycles = int(max_comm_cycles)
|
||||
except (TypeError, ValueError):
|
||||
parsed_cycles = 0
|
||||
if parsed_cycles <= 0:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "max_comm_cycles must be a positive integer."}, ensure_ascii=False))
|
||||
return
|
||||
updates["max_comm_cycles"] = parsed_cycles
|
||||
|
||||
initial_cash = data.get("initial_cash")
|
||||
if initial_cash is not None:
|
||||
try:
|
||||
parsed_initial_cash = float(initial_cash)
|
||||
except (TypeError, ValueError):
|
||||
parsed_initial_cash = 0.0
|
||||
if parsed_initial_cash <= 0:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "initial_cash must be a positive number."}, ensure_ascii=False))
|
||||
return
|
||||
updates["initial_cash"] = parsed_initial_cash
|
||||
|
||||
margin_requirement = data.get("margin_requirement")
|
||||
if margin_requirement is not None:
|
||||
try:
|
||||
parsed_margin_requirement = float(margin_requirement)
|
||||
except (TypeError, ValueError):
|
||||
parsed_margin_requirement = -1.0
|
||||
if parsed_margin_requirement < 0:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "margin_requirement must be a non-negative number."}, ensure_ascii=False))
|
||||
return
|
||||
updates["margin_requirement"] = parsed_margin_requirement
|
||||
|
||||
enable_memory = data.get("enable_memory")
|
||||
if enable_memory is not None:
|
||||
updates["enable_memory"] = bool(enable_memory)
|
||||
|
||||
if not updates:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "No runtime settings were provided."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
update_bootstrap_values_for_run(
|
||||
project_root=gateway._project_root,
|
||||
config_name=config_name,
|
||||
updates=updates,
|
||||
)
|
||||
await gateway.state_sync.on_system_message("运行时调度配置已保存,正在热更新")
|
||||
await handle_reload_runtime_assets(gateway)
|
||||
|
||||
|
||||
async def handle_update_watchlist(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
tickers = gateway._normalize_watchlist(data.get("tickers"))
|
||||
if not tickers:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "update_watchlist requires at least one valid ticker."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
update_bootstrap_values_for_run(
|
||||
project_root=gateway._project_root,
|
||||
config_name=config_name,
|
||||
updates={"tickers": tickers},
|
||||
)
|
||||
await gateway.state_sync.on_system_message(f"Watchlist updated: {', '.join(tickers)}")
|
||||
await gateway.broadcast({"type": "watchlist_updated", "config_name": config_name, "tickers": tickers})
|
||||
await handle_reload_runtime_assets(gateway)
|
||||
gateway._schedule_watchlist_market_store_refresh(tickers)
|
||||
|
||||
|
||||
async def handle_get_agent_skills(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
if not agent_id:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "get_agent_skills requires agent_id."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
|
||||
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=config_name, agent_id=agent_id, default_skills=[]))
|
||||
enabled = set(agent_config.enabled_skills)
|
||||
disabled = set(agent_config.disabled_skills)
|
||||
|
||||
payload = []
|
||||
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id):
|
||||
if item.skill_name in disabled:
|
||||
status = "disabled"
|
||||
elif item.skill_name in enabled:
|
||||
status = "enabled"
|
||||
elif item.skill_name in resolved_skills:
|
||||
status = "active"
|
||||
else:
|
||||
status = "available"
|
||||
payload.append({
|
||||
"skill_name": item.skill_name,
|
||||
"name": item.name,
|
||||
"description": item.description,
|
||||
"version": item.version,
|
||||
"source": item.source,
|
||||
"tools": item.tools,
|
||||
"status": status,
|
||||
})
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "agent_skills_loaded",
|
||||
"config_name": config_name,
|
||||
"agent_id": agent_id,
|
||||
"skills": payload,
|
||||
}, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_get_agent_profile(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
if not agent_id:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "get_agent_profile requires agent_id."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
|
||||
profiles = load_agent_profiles()
|
||||
profile = profiles.get(agent_id, {})
|
||||
bootstrap = get_bootstrap_config_for_run(gateway._project_root, config_name)
|
||||
override = bootstrap.agent_override(agent_id)
|
||||
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
|
||||
if not isinstance(active_tool_groups, list):
|
||||
active_tool_groups = []
|
||||
disabled_tool_groups = agent_config.disabled_tool_groups
|
||||
if disabled_tool_groups:
|
||||
disabled_set = set(disabled_tool_groups)
|
||||
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
|
||||
|
||||
default_skills = profile.get("skills", [])
|
||||
if not isinstance(default_skills, list):
|
||||
default_skills = []
|
||||
resolved_skills = skills_manager.resolve_agent_skill_names(
|
||||
config_name=config_name,
|
||||
agent_id=agent_id,
|
||||
default_skills=default_skills,
|
||||
)
|
||||
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
|
||||
model_name, model_provider = get_agent_model_info(agent_id)
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "agent_profile_loaded",
|
||||
"config_name": config_name,
|
||||
"agent_id": agent_id,
|
||||
"profile": {
|
||||
"model_name": model_name,
|
||||
"model_provider": model_provider,
|
||||
"prompt_files": prompt_files,
|
||||
"default_skills": default_skills,
|
||||
"resolved_skills": resolved_skills,
|
||||
"active_tool_groups": active_tool_groups,
|
||||
"disabled_tool_groups": disabled_tool_groups,
|
||||
"enabled_skills": agent_config.enabled_skills,
|
||||
"disabled_skills": agent_config.disabled_skills,
|
||||
},
|
||||
}, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_get_skill_detail(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
skill_name = str(data.get("skill_name", "")).strip()
|
||||
if not skill_name:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "get_skill_detail requires skill_name."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
try:
|
||||
if agent_id:
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
detail = skills_manager.load_agent_skill_document(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
|
||||
else:
|
||||
detail = skills_manager.load_skill_document(skill_name)
|
||||
except FileNotFoundError:
|
||||
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "skill_detail_loaded",
|
||||
"agent_id": agent_id,
|
||||
"skill": detail,
|
||||
}, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_create_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
skill_name = str(data.get("skill_name", "")).strip()
|
||||
if not agent_id or not skill_name:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "create_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
try:
|
||||
skills_manager.create_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
|
||||
except (ValueError, FileExistsError) as exc:
|
||||
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
await gateway.state_sync.on_system_message(f"Created local skill {skill_name} for {agent_id}")
|
||||
await gateway._handle_reload_runtime_assets()
|
||||
await websocket.send(json.dumps({"type": "agent_local_skill_created", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
|
||||
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
|
||||
|
||||
|
||||
async def handle_update_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
skill_name = str(data.get("skill_name", "")).strip()
|
||||
content = data.get("content")
|
||||
if not agent_id or not skill_name or not isinstance(content, str):
|
||||
await websocket.send(json.dumps({"type": "error", "message": "update_agent_local_skill requires agent_id, skill_name, and string content."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
try:
|
||||
skills_manager.update_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name, content=content)
|
||||
except (ValueError, FileNotFoundError) as exc:
|
||||
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
await gateway.state_sync.on_system_message(f"Updated local skill {skill_name} for {agent_id}")
|
||||
await gateway._handle_reload_runtime_assets()
|
||||
await websocket.send(json.dumps({"type": "agent_local_skill_updated", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
|
||||
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
|
||||
|
||||
|
||||
async def handle_delete_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
skill_name = str(data.get("skill_name", "")).strip()
|
||||
if not agent_id or not skill_name:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "delete_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
try:
|
||||
skills_manager.delete_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
|
||||
skills_manager.forget_agent_skill_overrides(config_name=config_name, agent_id=agent_id, skill_names=[skill_name])
|
||||
except (ValueError, FileNotFoundError) as exc:
|
||||
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
await gateway.state_sync.on_system_message(f"Deleted local skill {skill_name} for {agent_id}")
|
||||
await gateway._handle_reload_runtime_assets()
|
||||
await websocket.send(json.dumps({"type": "agent_local_skill_deleted", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
|
||||
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||
|
||||
|
||||
async def handle_remove_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
skill_name = str(data.get("skill_name", "")).strip()
|
||||
if not agent_id or not skill_name:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "remove_agent_skill requires agent_id and skill_name."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
skill_names = {
|
||||
item.skill_name
|
||||
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
|
||||
if item.source != "local"
|
||||
}
|
||||
if skill_name not in skill_names:
|
||||
await websocket.send(json.dumps({"type": "error", "message": f"Unknown shared skill: {skill_name}"}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
|
||||
await gateway.state_sync.on_system_message(f"Removed shared skill {skill_name} from {agent_id}")
|
||||
await gateway._handle_reload_runtime_assets()
|
||||
await websocket.send(json.dumps({"type": "agent_skill_removed", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
|
||||
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||
|
||||
|
||||
async def handle_update_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
skill_name = str(data.get("skill_name", "")).strip()
|
||||
enabled = data.get("enabled")
|
||||
if not agent_id or not skill_name or not isinstance(enabled, bool):
|
||||
await websocket.send(json.dumps({"type": "error", "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
skill_names = {item.skill_name for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)}
|
||||
if skill_name not in skill_names:
|
||||
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
if enabled:
|
||||
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, enable=[skill_name])
|
||||
await gateway.state_sync.on_system_message(f"Enabled skill {skill_name} for {agent_id}")
|
||||
else:
|
||||
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
|
||||
await gateway.state_sync.on_system_message(f"Disabled skill {skill_name} for {agent_id}")
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "agent_skill_updated",
|
||||
"agent_id": agent_id,
|
||||
"skill_name": skill_name,
|
||||
"enabled": enabled,
|
||||
}, ensure_ascii=False))
|
||||
await gateway._handle_reload_runtime_assets()
|
||||
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
|
||||
|
||||
|
||||
async def handle_get_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
|
||||
if not agent_id or not filename:
|
||||
await websocket.send(json.dumps({"type": "error", "message": "get_agent_workspace_file requires agent_id and supported filename."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = asset_dir / filename
|
||||
content = path.read_text(encoding="utf-8") if path.exists() else ""
|
||||
await websocket.send(json.dumps({
|
||||
"type": "agent_workspace_file_loaded",
|
||||
"config_name": config_name,
|
||||
"agent_id": agent_id,
|
||||
"filename": filename,
|
||||
"content": content,
|
||||
}, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_update_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
agent_id = str(data.get("agent_id", "")).strip()
|
||||
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
|
||||
content = data.get("content")
|
||||
if not agent_id or not filename or not isinstance(content, str):
|
||||
await websocket.send(json.dumps({"type": "error", "message": "update_agent_workspace_file requires agent_id, supported filename, and string content."}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
config_name = gateway.config.get("config_name", "default")
|
||||
skills_manager = SkillsManager(project_root=gateway._project_root)
|
||||
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
|
||||
asset_dir.mkdir(parents=True, exist_ok=True)
|
||||
path = asset_dir / filename
|
||||
path.write_text(content, encoding="utf-8")
|
||||
await gateway.state_sync.on_system_message(f"Updated {filename} for {agent_id}")
|
||||
await websocket.send(json.dumps({"type": "agent_workspace_file_updated", "agent_id": agent_id, "filename": filename}, ensure_ascii=False))
|
||||
await gateway._handle_reload_runtime_assets()
|
||||
await handle_get_agent_workspace_file(gateway, websocket, {"agent_id": agent_id, "filename": filename})
|
||||
391
backend/services/gateway_cycle_support.py
Normal file
391
backend/services/gateway_cycle_support.py
Normal file
@@ -0,0 +1,391 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Cycle and monitoring helpers extracted from the main Gateway module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_ingest import ingest_symbols, refresh_news_for_symbols
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
|
||||
"""Kick off a non-blocking market-store refresh for an updated watchlist."""
|
||||
if not tickers:
|
||||
return
|
||||
if gateway._watchlist_ingest_task and not gateway._watchlist_ingest_task.done():
|
||||
gateway._watchlist_ingest_task.cancel()
|
||||
gateway._watchlist_ingest_task = asyncio.create_task(
|
||||
refresh_market_store_for_watchlist(gateway, tickers),
|
||||
)
|
||||
|
||||
|
||||
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
|
||||
"""Refresh the long-lived market store after a watchlist update."""
|
||||
try:
|
||||
await gateway.state_sync.on_system_message(
|
||||
f"正在同步自选股市场数据: {', '.join(tickers)}",
|
||||
)
|
||||
results = await asyncio.to_thread(
|
||||
ingest_symbols,
|
||||
tickers,
|
||||
mode="incremental",
|
||||
)
|
||||
summary = ", ".join(
|
||||
f"{item['symbol']} prices={item['prices']} news={item['news']}"
|
||||
for item in results
|
||||
)
|
||||
await gateway.state_sync.on_system_message(
|
||||
f"自选股市场数据已同步: {summary}",
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.warning("Watchlist market store refresh failed: %s", exc)
|
||||
await gateway.state_sync.on_system_message(
|
||||
f"自选股市场数据同步失败: {exc}",
|
||||
)
|
||||
|
||||
|
||||
async def market_status_monitor(gateway: Any) -> None:
|
||||
"""Periodically check and broadcast market status changes."""
|
||||
while True:
|
||||
try:
|
||||
await gateway.market_service.check_and_broadcast_market_status()
|
||||
|
||||
status = gateway.market_service.get_market_status()
|
||||
if status["status"] == "open" and not gateway.storage.is_live_session_active:
|
||||
gateway.storage.start_live_session()
|
||||
summary = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state).get("summary") or {}
|
||||
gateway._session_start_portfolio_value = summary.get(
|
||||
"totalAssetValue",
|
||||
gateway.storage.initial_cash,
|
||||
)
|
||||
logger.info(
|
||||
"Session start portfolio: $%s",
|
||||
f"{gateway._session_start_portfolio_value:,.2f}",
|
||||
)
|
||||
elif status["status"] != "open" and gateway.storage.is_live_session_active:
|
||||
gateway.storage.end_live_session()
|
||||
gateway._session_start_portfolio_value = None
|
||||
|
||||
if gateway.storage.is_live_session_active:
|
||||
await update_and_broadcast_live_returns(gateway)
|
||||
|
||||
await asyncio.sleep(60)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.error("Market status monitor error: %s", exc)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
|
||||
async def update_and_broadcast_live_returns(gateway: Any) -> None:
|
||||
"""Calculate and broadcast live returns for current session."""
|
||||
if not gateway.storage.is_live_session_active:
|
||||
return
|
||||
|
||||
prices = gateway.market_service.get_all_prices()
|
||||
if not prices or not any(p > 0 for p in prices.values()):
|
||||
return
|
||||
|
||||
state = gateway.storage.load_internal_state()
|
||||
equity_history = state.get("equity_history", [])
|
||||
baseline_history = state.get("baseline_history", [])
|
||||
baseline_vw_history = state.get("baseline_vw_history", [])
|
||||
momentum_history = state.get("momentum_history", [])
|
||||
|
||||
current_equity = equity_history[-1]["v"] if equity_history else None
|
||||
current_baseline = baseline_history[-1]["v"] if baseline_history else None
|
||||
current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None
|
||||
current_momentum = momentum_history[-1]["v"] if momentum_history else None
|
||||
|
||||
point = gateway.storage.update_live_returns(
|
||||
current_equity=current_equity,
|
||||
current_baseline=current_baseline,
|
||||
current_baseline_vw=current_baseline_vw,
|
||||
current_momentum=current_momentum,
|
||||
)
|
||||
if point:
|
||||
live_returns = gateway.storage.get_live_returns()
|
||||
await gateway.broadcast(
|
||||
{
|
||||
"type": "team_summary",
|
||||
"equity_return": live_returns["equity_return"],
|
||||
"baseline_return": live_returns["baseline_return"],
|
||||
"baseline_vw_return": live_returns["baseline_vw_return"],
|
||||
"momentum_return": live_returns["momentum_return"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def on_strategy_trigger(gateway: Any, date: str) -> None:
|
||||
"""Handle trading cycle trigger."""
|
||||
if gateway._cycle_lock.locked():
|
||||
logger.warning("Trading cycle already running, skipping trigger for %s", date)
|
||||
await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}")
|
||||
return
|
||||
|
||||
async with gateway._cycle_lock:
|
||||
logger.info("Strategy triggered for %s", date)
|
||||
tickers = gateway.config.get("tickers", [])
|
||||
if gateway.is_backtest:
|
||||
await run_backtest_cycle(gateway, date, tickers)
|
||||
else:
|
||||
await run_live_cycle(gateway, date, tickers)
|
||||
|
||||
|
||||
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
|
||||
"""Run lightweight heartbeat check for all analysts."""
|
||||
logger.info("[Heartbeat] Running heartbeat check for %s", date)
|
||||
analysts = gateway.pipeline._all_analysts()
|
||||
|
||||
for analyst in analysts:
|
||||
try:
|
||||
ws_id = getattr(analyst, "workspace_id", None)
|
||||
if ws_id:
|
||||
from backend.agents.workspace_manager import get_workspace_dir
|
||||
from pathlib import Path
|
||||
from agentscope.message import Msg
|
||||
|
||||
ws_dir = get_workspace_dir(ws_id)
|
||||
if ws_dir:
|
||||
hb_path = Path(ws_dir) / "HEARTBEAT.md"
|
||||
if hb_path.exists():
|
||||
content = hb_path.read_text(encoding="utf-8").strip()
|
||||
if content:
|
||||
hb_task = f"# 定期主动检查\n\n{content}\n\n请执行上述检查并报告结果。"
|
||||
logger.info("[Heartbeat] Running heartbeat for %s", analyst.name)
|
||||
msg = Msg(role="user", content=hb_task, name="system")
|
||||
await analyst.reply([msg])
|
||||
logger.info("[Heartbeat] %s heartbeat complete", analyst.name)
|
||||
continue
|
||||
logger.debug("[Heartbeat] No HEARTBEAT.md for %s, skipping", analyst.name)
|
||||
except Exception as exc:
|
||||
logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True)
|
||||
|
||||
|
||||
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
|
||||
gateway.market_service.set_backtest_date(date)
|
||||
await gateway.market_service.emit_market_open()
|
||||
|
||||
await gateway.state_sync.on_cycle_start(date)
|
||||
gateway._dashboard.update(date=date, status="Analyzing...")
|
||||
|
||||
prices = gateway.market_service.get_open_prices()
|
||||
close_prices = gateway.market_service.get_close_prices()
|
||||
market_caps = await get_market_caps(gateway, tickers, date)
|
||||
|
||||
result = await gateway.pipeline.run_cycle(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
close_prices=close_prices,
|
||||
market_caps=market_caps,
|
||||
)
|
||||
|
||||
await gateway.market_service.emit_market_close()
|
||||
settlement_result = result.get("settlement_result")
|
||||
save_cycle_results(gateway, result, date, close_prices, settlement_result)
|
||||
await broadcast_portfolio_updates(gateway, result, close_prices)
|
||||
await finalize_cycle(gateway, date)
|
||||
|
||||
|
||||
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
|
||||
trading_date = gateway.market_service.get_live_trading_date()
|
||||
logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)
|
||||
|
||||
try:
|
||||
news_refresh = await asyncio.to_thread(
|
||||
refresh_news_for_symbols,
|
||||
tickers,
|
||||
end_date=trading_date,
|
||||
store=gateway.storage.market_store,
|
||||
)
|
||||
logger.info(
|
||||
"News refresh complete: %s",
|
||||
", ".join(
|
||||
f"{item['symbol']} news={item['news']}"
|
||||
for item in news_refresh
|
||||
) or "no symbols",
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Live cycle news refresh failed: %s", exc)
|
||||
|
||||
await gateway.state_sync.on_cycle_start(trading_date)
|
||||
gateway._dashboard.update(date=trading_date, status="Analyzing...")
|
||||
|
||||
market_caps = await get_market_caps(gateway, tickers, trading_date)
|
||||
schedule_mode = gateway.config.get("schedule_mode", "daily")
|
||||
market_status = gateway.market_service.get_market_status()
|
||||
current_prices = gateway.market_service.get_all_prices()
|
||||
|
||||
if schedule_mode == "intraday":
|
||||
execute_decisions = market_status.get("status") == "open"
|
||||
if execute_decisions:
|
||||
await gateway.state_sync.on_system_message("定时任务触发:当前处于交易时段,本轮将执行交易决策")
|
||||
else:
|
||||
await gateway.state_sync.on_system_message("定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易")
|
||||
|
||||
result = await gateway.pipeline.run_cycle(
|
||||
tickers=tickers,
|
||||
date=trading_date,
|
||||
prices=current_prices,
|
||||
market_caps=market_caps,
|
||||
execute_decisions=execute_decisions,
|
||||
)
|
||||
close_prices = current_prices
|
||||
else:
|
||||
result = await gateway.pipeline.run_cycle(
|
||||
tickers=tickers,
|
||||
date=trading_date,
|
||||
market_caps=market_caps,
|
||||
get_open_prices_fn=gateway.market_service.wait_for_open_prices,
|
||||
get_close_prices_fn=gateway.market_service.wait_for_close_prices,
|
||||
)
|
||||
close_prices = gateway.market_service.get_all_prices()
|
||||
|
||||
settlement_result = result.get("settlement_result")
|
||||
save_cycle_results(gateway, result, trading_date, close_prices, settlement_result)
|
||||
await broadcast_portfolio_updates(gateway, result, close_prices)
|
||||
await finalize_cycle(gateway, trading_date)
|
||||
|
||||
|
||||
async def finalize_cycle(gateway: Any, date: str) -> None:
|
||||
dashboard_snapshot = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state)
|
||||
summary = dashboard_snapshot.get("summary") or {}
|
||||
if gateway.storage.is_live_session_active:
|
||||
summary.update(gateway.storage.get_live_returns())
|
||||
|
||||
await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary)
|
||||
holdings = dashboard_snapshot.get("holdings") or []
|
||||
trades = dashboard_snapshot.get("trades") or []
|
||||
leaderboard = dashboard_snapshot.get("leaderboard") or []
|
||||
if leaderboard:
|
||||
await gateway.state_sync.on_leaderboard_update(leaderboard)
|
||||
gateway._dashboard.update(date=date, status="Running", portfolio=summary, holdings=holdings, trades=trades)
|
||||
|
||||
|
||||
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
|
||||
market_caps: dict[str, float] = {}
|
||||
for ticker in tickers:
|
||||
try:
|
||||
market_cap = None
|
||||
response = await gateway._call_trading_service(
|
||||
f"get_market_cap for {ticker}",
|
||||
lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
|
||||
)
|
||||
if response is not None:
|
||||
market_cap = response.get("market_cap")
|
||||
if market_cap is None:
|
||||
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
|
||||
market_cap = payload.get("market_cap")
|
||||
market_caps[ticker] = market_cap if market_cap else 1e9
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
|
||||
market_caps[ticker] = 1e9
|
||||
return market_caps
|
||||
|
||||
|
||||
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
|
||||
portfolio = result.get("portfolio", {})
|
||||
if portfolio:
|
||||
holdings = FrontendAdapter.build_holdings(portfolio, prices)
|
||||
if holdings:
|
||||
await gateway.state_sync.on_holdings_update(holdings)
|
||||
stats = FrontendAdapter.build_stats(portfolio, prices)
|
||||
if stats:
|
||||
await gateway.state_sync.on_stats_update(stats)
|
||||
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
if executed_trades:
|
||||
await gateway.state_sync.on_trades_executed(executed_trades)
|
||||
|
||||
|
||||
def save_cycle_results(
|
||||
gateway: Any,
|
||||
result: dict[str, Any],
|
||||
date: str,
|
||||
prices: dict[str, float],
|
||||
settlement_result: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
portfolio = result.get("portfolio", {})
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
baseline_values = settlement_result.get("baseline_values") if settlement_result else None
|
||||
if portfolio:
|
||||
gateway.storage.update_dashboard_after_cycle(
|
||||
portfolio=portfolio,
|
||||
prices=prices,
|
||||
date=date,
|
||||
executed_trades=executed_trades,
|
||||
baseline_values=baseline_values,
|
||||
)
|
||||
|
||||
|
||||
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
|
||||
gateway.state_sync.set_backtest_dates(dates)
|
||||
gateway._dashboard.update(days_total=len(dates), days_completed=0)
|
||||
await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
|
||||
try:
|
||||
for i, date in enumerate(dates):
|
||||
gateway._dashboard.update(days_completed=i)
|
||||
await gateway.on_strategy_trigger(date=date)
|
||||
await asyncio.sleep(0.1)
|
||||
await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
|
||||
summary = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state).get("summary") or {}
|
||||
gateway._dashboard.update(status="Complete", portfolio=summary, days_completed=len(dates))
|
||||
gateway._dashboard.stop()
|
||||
gateway._dashboard.print_final_summary()
|
||||
except Exception as exc:
|
||||
error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
asyncio.create_task(gateway.state_sync.on_system_message(error_msg))
|
||||
gateway._dashboard.update(status=f"Failed: {str(exc)}")
|
||||
gateway._dashboard.stop()
|
||||
raise
|
||||
finally:
|
||||
gateway._backtest_task = None
|
||||
|
||||
|
||||
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Backtest task was cancelled")
|
||||
except Exception as exc:
|
||||
logger.error("Backtest task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
|
||||
|
||||
|
||||
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
|
||||
gateway._manual_cycle_task = None
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Manual cycle task was cancelled")
|
||||
except Exception as exc:
|
||||
logger.error("Manual cycle task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
|
||||
|
||||
|
||||
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
|
||||
gateway.state_sync.set_backtest_dates(dates)
|
||||
if dates:
|
||||
gateway._backtest_start_date = dates[0]
|
||||
gateway._backtest_end_date = dates[-1]
|
||||
gateway._dashboard.days_total = len(dates)
|
||||
|
||||
|
||||
def stop_gateway(gateway: Any) -> None:
|
||||
gateway.state_sync.save_state()
|
||||
gateway.market_service.stop()
|
||||
if gateway._backtest_task:
|
||||
gateway._backtest_task.cancel()
|
||||
if gateway._market_status_task:
|
||||
gateway._market_status_task.cancel()
|
||||
if gateway._watchlist_ingest_task:
|
||||
gateway._watchlist_ingest_task.cancel()
|
||||
gateway._dashboard.stop()
|
||||
175
backend/services/gateway_runtime_support.py
Normal file
175
backend/services/gateway_runtime_support.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Runtime/state support helpers extracted from the main Gateway module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
|
||||
def normalize_watchlist(raw_tickers: Any) -> list[str]:
|
||||
"""Parse watchlist payloads from websocket messages."""
|
||||
if raw_tickers is None:
|
||||
return []
|
||||
|
||||
if isinstance(raw_tickers, str):
|
||||
candidates = raw_tickers.split(",")
|
||||
elif isinstance(raw_tickers, list):
|
||||
candidates = raw_tickers
|
||||
else:
|
||||
candidates = [raw_tickers]
|
||||
|
||||
tickers: list[str] = []
|
||||
for candidate in candidates:
|
||||
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
|
||||
if symbol and symbol not in tickers:
|
||||
tickers.append(symbol)
|
||||
return tickers
|
||||
|
||||
|
||||
def normalize_agent_workspace_filename(
|
||||
raw_name: Any,
|
||||
*,
|
||||
allowlist: set[str],
|
||||
) -> str | None:
|
||||
"""Restrict editable workspace files to a safe allowlist."""
|
||||
filename = str(raw_name or "").strip()
|
||||
if filename in allowlist:
|
||||
return filename
|
||||
return None
|
||||
|
||||
|
||||
def apply_runtime_config(gateway: Any, runtime_config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Apply runtime config to gateway-owned services and state."""
|
||||
warnings: list[str] = []
|
||||
|
||||
ticker_changes = gateway.market_service.update_tickers(
|
||||
runtime_config.get("tickers", []),
|
||||
)
|
||||
gateway.config["tickers"] = ticker_changes["active"]
|
||||
|
||||
gateway.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
|
||||
gateway.config["max_comm_cycles"] = gateway.pipeline.max_comm_cycles
|
||||
gateway.config["schedule_mode"] = runtime_config.get(
|
||||
"schedule_mode",
|
||||
gateway.config.get("schedule_mode", "daily"),
|
||||
)
|
||||
gateway.config["interval_minutes"] = int(
|
||||
runtime_config.get(
|
||||
"interval_minutes",
|
||||
gateway.config.get("interval_minutes", 60),
|
||||
),
|
||||
)
|
||||
gateway.config["trigger_time"] = runtime_config.get(
|
||||
"trigger_time",
|
||||
gateway.config.get("trigger_time", "09:30"),
|
||||
)
|
||||
|
||||
if gateway.scheduler:
|
||||
gateway.scheduler.reconfigure(
|
||||
mode=gateway.config["schedule_mode"],
|
||||
trigger_time=gateway.config["trigger_time"],
|
||||
interval_minutes=gateway.config["interval_minutes"],
|
||||
)
|
||||
|
||||
pm_apply_result = gateway.pipeline.pm.apply_runtime_portfolio_config(
|
||||
margin_requirement=runtime_config["margin_requirement"],
|
||||
)
|
||||
gateway.config["margin_requirement"] = gateway.pipeline.pm.portfolio.get(
|
||||
"margin_requirement",
|
||||
runtime_config["margin_requirement"],
|
||||
)
|
||||
|
||||
requested_initial_cash = float(runtime_config["initial_cash"])
|
||||
current_initial_cash = float(gateway.storage.initial_cash)
|
||||
initial_cash_applied = requested_initial_cash == current_initial_cash
|
||||
if not initial_cash_applied:
|
||||
if (
|
||||
gateway.storage.can_apply_initial_cash()
|
||||
and gateway.pipeline.pm.can_apply_initial_cash()
|
||||
):
|
||||
initial_cash_applied = gateway.storage.apply_initial_cash(
|
||||
requested_initial_cash,
|
||||
)
|
||||
if initial_cash_applied:
|
||||
gateway.pipeline.pm.apply_runtime_portfolio_config(
|
||||
initial_cash=requested_initial_cash,
|
||||
)
|
||||
gateway.config["initial_cash"] = gateway.storage.initial_cash
|
||||
else:
|
||||
warnings.append(
|
||||
"initial_cash changed in BOOTSTRAP.md but was not applied "
|
||||
"because the run already has positions, margin usage, or trades.",
|
||||
)
|
||||
|
||||
requested_enable_memory = bool(runtime_config["enable_memory"])
|
||||
current_enable_memory = bool(gateway.config.get("enable_memory", False))
|
||||
if requested_enable_memory != current_enable_memory:
|
||||
warnings.append(
|
||||
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
|
||||
"because long-term memory contexts are created at startup.",
|
||||
)
|
||||
|
||||
sync_runtime_state(gateway)
|
||||
|
||||
return {
|
||||
"runtime_config_requested": runtime_config,
|
||||
"runtime_config_applied": {
|
||||
"tickers": list(gateway.config.get("tickers", [])),
|
||||
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
|
||||
"interval_minutes": gateway.config.get("interval_minutes", 60),
|
||||
"trigger_time": gateway.config.get("trigger_time", "09:30"),
|
||||
"initial_cash": gateway.storage.initial_cash,
|
||||
"margin_requirement": gateway.config["margin_requirement"],
|
||||
"max_comm_cycles": gateway.config["max_comm_cycles"],
|
||||
"enable_memory": gateway.config.get("enable_memory", False),
|
||||
},
|
||||
"runtime_config_status": {
|
||||
"tickers": True,
|
||||
"schedule_mode": True,
|
||||
"interval_minutes": True,
|
||||
"trigger_time": True,
|
||||
"initial_cash": initial_cash_applied,
|
||||
"margin_requirement": pm_apply_result["margin_requirement"],
|
||||
"max_comm_cycles": True,
|
||||
"enable_memory": requested_enable_memory == current_enable_memory,
|
||||
},
|
||||
"ticker_changes": ticker_changes,
|
||||
"runtime_config_warnings": warnings,
|
||||
}
|
||||
|
||||
|
||||
def sync_runtime_state(gateway: Any) -> None:
|
||||
"""Refresh persisted state and dashboard after runtime config changes."""
|
||||
gateway.state_sync.update_state("tickers", gateway.config.get("tickers", []))
|
||||
gateway.state_sync.update_state(
|
||||
"runtime_config",
|
||||
{
|
||||
"tickers": gateway.config.get("tickers", []),
|
||||
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
|
||||
"interval_minutes": gateway.config.get("interval_minutes", 60),
|
||||
"trigger_time": gateway.config.get("trigger_time", "09:30"),
|
||||
"initial_cash": gateway.storage.initial_cash,
|
||||
"margin_requirement": gateway.config.get("margin_requirement"),
|
||||
"max_comm_cycles": gateway.config.get("max_comm_cycles"),
|
||||
"enable_memory": gateway.config.get("enable_memory", False),
|
||||
},
|
||||
)
|
||||
|
||||
gateway.storage.update_server_state_from_dashboard(gateway.state_sync.state)
|
||||
gateway.state_sync.save_state()
|
||||
|
||||
gateway._dashboard.tickers = list(gateway.config.get("tickers", []))
|
||||
gateway._dashboard.initial_cash = gateway.storage.initial_cash
|
||||
gateway._dashboard.enable_memory = bool(gateway.config.get("enable_memory", False))
|
||||
|
||||
dashboard_snapshot = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state)
|
||||
summary = dashboard_snapshot.get("summary") or {}
|
||||
holdings = dashboard_snapshot.get("holdings") or []
|
||||
trades = dashboard_snapshot.get("trades") or []
|
||||
gateway._dashboard.update(
|
||||
portfolio=summary,
|
||||
holdings=holdings,
|
||||
trades=trades,
|
||||
)
|
||||
716
backend/services/gateway_stock_handlers.py
Normal file
716
backend/services/gateway_stock_handlers.py
Normal file
@@ -0,0 +1,716 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Stock-related Gateway handlers extracted from the main Gateway module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.domains import news as news_domain
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||
from backend.enrich.llm_enricher import llm_enrichment_enabled
|
||||
from backend.tools.data_tools import prices_to_df
|
||||
from shared.client import NewsServiceClient, TradingServiceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_history_loaded",
|
||||
"ticker": "",
|
||||
"prices": [],
|
||||
"source": None,
|
||||
"error": "invalid ticker",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
lookback_days = data.get("lookback_days", 90)
|
||||
try:
|
||||
lookback_days = max(7, min(int(lookback_days), 365))
|
||||
except (TypeError, ValueError):
|
||||
lookback_days = 90
|
||||
|
||||
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
end_dt = datetime.now()
|
||||
end_date = end_dt.strftime("%Y-%m-%d")
|
||||
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
|
||||
|
||||
prices = []
|
||||
source = "polygon"
|
||||
response = await gateway._call_trading_service(
|
||||
"get_prices for history",
|
||||
lambda client: client.get_prices(ticker=ticker, start_date=start_date, end_date=end_date),
|
||||
)
|
||||
if response is not None:
|
||||
prices = response.prices
|
||||
source = "trading_service"
|
||||
|
||||
if not prices:
|
||||
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
|
||||
if not prices:
|
||||
payload = await asyncio.to_thread(
|
||||
trading_domain.get_prices_payload,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
prices = payload.get("prices") or []
|
||||
usage_snapshot = gateway._provider_router.get_usage_snapshot()
|
||||
source = usage_snapshot.get("last_success", {}).get("prices")
|
||||
if prices:
|
||||
await asyncio.to_thread(
|
||||
gateway.storage.market_store.upsert_ohlc,
|
||||
ticker,
|
||||
[price.model_dump() for price in prices],
|
||||
source=source or "provider",
|
||||
)
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_history_loaded",
|
||||
"ticker": ticker,
|
||||
"prices": [price if isinstance(price, dict) else price.model_dump() for price in prices][-120:],
|
||||
"source": source,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_explain_events(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
snapshot = gateway.storage.runtime_db.get_stock_explain_snapshot(ticker)
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_explain_events_loaded",
|
||||
"ticker": ticker,
|
||||
"events": snapshot.get("events", []),
|
||||
"signals": snapshot.get("signals", []),
|
||||
"trades": snapshot.get("trades", []),
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_news(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_news_loaded",
|
||||
"ticker": "",
|
||||
"news": [],
|
||||
"source": None,
|
||||
"error": "invalid ticker",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
lookback_days = data.get("lookback_days", 30)
|
||||
limit = data.get("limit", 12)
|
||||
try:
|
||||
lookback_days = max(7, min(int(lookback_days), 180))
|
||||
except (TypeError, ValueError):
|
||||
lookback_days = 30
|
||||
try:
|
||||
limit = max(1, min(int(limit), 30))
|
||||
except (TypeError, ValueError):
|
||||
limit = 12
|
||||
|
||||
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
end_dt = datetime.now()
|
||||
end_date = end_dt.strftime("%Y-%m-%d")
|
||||
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
|
||||
|
||||
news_rows = []
|
||||
source = "polygon"
|
||||
response = await gateway._call_news_service(
|
||||
"get_enriched_news",
|
||||
lambda client: client.get_enriched_news(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
),
|
||||
)
|
||||
if response is not None:
|
||||
news_rows = response.get("news") or []
|
||||
source = "news_service"
|
||||
|
||||
if not news_rows:
|
||||
payload = await asyncio.to_thread(
|
||||
news_domain.get_enriched_news,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=max(limit, 50),
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
news_rows = (payload.get("news") or [])[-limit:]
|
||||
source = "market_store"
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_news_loaded",
|
||||
"ticker": ticker,
|
||||
"news": news_rows[-limit:],
|
||||
"source": source,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_news_for_date(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
trade_date = str(data.get("date") or "").strip()
|
||||
if not ticker or not trade_date:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_news_for_date_loaded",
|
||||
"ticker": ticker,
|
||||
"date": trade_date,
|
||||
"news": [],
|
||||
"error": "ticker and date are required",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
limit = data.get("limit", 20)
|
||||
try:
|
||||
limit = max(1, min(int(limit), 50))
|
||||
except (TypeError, ValueError):
|
||||
limit = 20
|
||||
|
||||
source = "market_store"
|
||||
news_rows = []
|
||||
response = await gateway._call_news_service(
|
||||
"get_news_for_date",
|
||||
lambda client: client.get_news_for_date(ticker=ticker, date=trade_date, limit=limit),
|
||||
)
|
||||
if response is not None:
|
||||
news_rows = response.get("news") or []
|
||||
source = "news_service"
|
||||
|
||||
if not news_rows:
|
||||
payload = await asyncio.to_thread(
|
||||
news_domain.get_news_for_date,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
date=trade_date,
|
||||
limit=limit,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
news_rows = payload.get("news") or []
|
||||
source = "market_store"
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_news_for_date_loaded",
|
||||
"ticker": ticker,
|
||||
"date": trade_date,
|
||||
"news": news_rows,
|
||||
"source": source,
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_news_timeline(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_news_timeline_loaded",
|
||||
"ticker": "",
|
||||
"timeline": [],
|
||||
"error": "invalid ticker",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
lookback_days = data.get("lookback_days", 90)
|
||||
try:
|
||||
lookback_days = max(7, min(int(lookback_days), 365))
|
||||
except (TypeError, ValueError):
|
||||
lookback_days = 90
|
||||
|
||||
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
end_dt = datetime.now()
|
||||
end_date = end_dt.strftime("%Y-%m-%d")
|
||||
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
|
||||
|
||||
timeline = []
|
||||
response = await gateway._call_news_service(
|
||||
"get_news_timeline",
|
||||
lambda client: client.get_news_timeline(ticker=ticker, start_date=start_date, end_date=end_date),
|
||||
)
|
||||
if response is not None:
|
||||
timeline = response.get("timeline") or []
|
||||
|
||||
if not timeline:
|
||||
payload = await asyncio.to_thread(
|
||||
news_domain.get_news_timeline,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
timeline = payload.get("timeline") or []
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_news_timeline_loaded",
|
||||
"ticker": ticker,
|
||||
"timeline": timeline,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_news_categories(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_news_categories_loaded",
|
||||
"ticker": "",
|
||||
"categories": {},
|
||||
"error": "invalid ticker",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
lookback_days = data.get("lookback_days", 90)
|
||||
try:
|
||||
lookback_days = max(7, min(int(lookback_days), 365))
|
||||
except (TypeError, ValueError):
|
||||
lookback_days = 90
|
||||
|
||||
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
|
||||
try:
|
||||
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
end_dt = datetime.now()
|
||||
end_date = end_dt.strftime("%Y-%m-%d")
|
||||
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
|
||||
|
||||
categories = {}
|
||||
response = await gateway._call_news_service(
|
||||
"get_categories",
|
||||
lambda client: client.get_categories(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=200,
|
||||
),
|
||||
)
|
||||
if response is not None:
|
||||
categories = response.get("categories") or {}
|
||||
|
||||
if not categories:
|
||||
payload = await asyncio.to_thread(
|
||||
news_domain.get_news_categories,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=200,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
categories = payload.get("categories") or {}
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_news_categories_loaded",
|
||||
"ticker": ticker,
|
||||
"categories": categories,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_range_explain(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
start_date = str(data.get("start_date") or "").strip()
|
||||
end_date = str(data.get("end_date") or "").strip()
|
||||
if not ticker or not start_date or not end_date:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_range_explain_loaded",
|
||||
"ticker": ticker,
|
||||
"result": {"error": "ticker, start_date, end_date are required"},
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
article_ids = data.get("article_ids")
|
||||
result = None
|
||||
response = await gateway._call_news_service(
|
||||
"get_range_explain",
|
||||
lambda client: client.get_range_explain(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
article_ids=article_ids if isinstance(article_ids, list) else None,
|
||||
limit=100,
|
||||
),
|
||||
)
|
||||
if response is not None:
|
||||
result = response.get("result")
|
||||
|
||||
if result is None:
|
||||
payload = await asyncio.to_thread(
|
||||
news_domain.get_range_explain_payload,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
article_ids=article_ids if isinstance(article_ids, list) else None,
|
||||
limit=100,
|
||||
refresh_if_stale=False,
|
||||
)
|
||||
result = payload.get("result")
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_range_explain_loaded",
|
||||
"ticker": ticker,
|
||||
"result": result,
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_insider_trades_loaded",
|
||||
"ticker": "",
|
||||
"trades": [],
|
||||
"error": "invalid ticker",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
end_date = str(data.get("end_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
|
||||
start_date = str(data.get("start_date") or "").strip()[:10]
|
||||
limit = int(data.get("limit", 50))
|
||||
|
||||
trades = []
|
||||
response = await gateway._call_trading_service(
|
||||
"get_insider_trades",
|
||||
lambda client: client.get_insider_trades(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date if start_date else None,
|
||||
limit=limit,
|
||||
),
|
||||
)
|
||||
if response is not None:
|
||||
trades = response.insider_trades
|
||||
|
||||
if not trades:
|
||||
payload = await asyncio.to_thread(
|
||||
trading_domain.get_insider_trades_payload,
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date if start_date else None,
|
||||
limit=limit,
|
||||
)
|
||||
trades = payload.get("insider_trades") or []
|
||||
|
||||
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
|
||||
formatted_trades = [{
|
||||
"ticker": t.ticker,
|
||||
"name": t.name,
|
||||
"title": t.title,
|
||||
"is_board_director": t.is_board_director,
|
||||
"transaction_date": t.transaction_date,
|
||||
"transaction_shares": t.transaction_shares,
|
||||
"transaction_price_per_share": t.transaction_price_per_share,
|
||||
"transaction_value": t.transaction_value,
|
||||
"shares_owned_before_transaction": t.shares_owned_before_transaction,
|
||||
"shares_owned_after_transaction": t.shares_owned_after_transaction,
|
||||
"security_title": t.security_title,
|
||||
"filing_date": t.filing_date,
|
||||
"holding_change": (
|
||||
(t.shares_owned_after_transaction or 0) - (t.shares_owned_before_transaction or 0)
|
||||
if t.shares_owned_after_transaction and t.shares_owned_before_transaction else None
|
||||
),
|
||||
"is_buy": ((t.transaction_shares or 0) > 0) if t.transaction_shares is not None else None,
|
||||
} for t in sorted_trades]
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_insider_trades_loaded",
|
||||
"ticker": ticker,
|
||||
"start_date": start_date or None,
|
||||
"end_date": end_date,
|
||||
"trades": formatted_trades,
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_story(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_story_loaded",
|
||||
"ticker": "",
|
||||
"story": "",
|
||||
"error": "invalid ticker",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
as_of_date = str(data.get("as_of_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
|
||||
result = await gateway._call_news_service(
|
||||
"get_story",
|
||||
lambda client: client.get_story(ticker=ticker, as_of_date=as_of_date),
|
||||
)
|
||||
if result is None:
|
||||
result = await asyncio.to_thread(
|
||||
news_domain.get_story_payload,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
as_of_date=as_of_date,
|
||||
)
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_story_loaded",
|
||||
"ticker": ticker,
|
||||
"as_of_date": as_of_date,
|
||||
"story": result.get("story") or "",
|
||||
"source": result.get("source") or "local",
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
target_date = str(data.get("date") or "").strip()[:10]
|
||||
if not ticker or not target_date:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_similar_days_loaded",
|
||||
"ticker": ticker,
|
||||
"date": target_date,
|
||||
"items": [],
|
||||
"error": "ticker and date are required",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
top_k = data.get("top_k", 8)
|
||||
try:
|
||||
top_k = max(1, min(int(top_k), 20))
|
||||
except (TypeError, ValueError):
|
||||
top_k = 8
|
||||
|
||||
result = await gateway._call_news_service(
|
||||
"get_similar_days",
|
||||
lambda client: client.get_similar_days(ticker=ticker, date=target_date, n_similar=top_k),
|
||||
)
|
||||
if result is None:
|
||||
result = await asyncio.to_thread(
|
||||
news_domain.get_similar_days_payload,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
date=target_date,
|
||||
n_similar=top_k,
|
||||
)
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_similar_days_loaded",
|
||||
"ticker": ticker,
|
||||
"date": target_date,
|
||||
**result,
|
||||
}, ensure_ascii=False, default=str))
|
||||
|
||||
|
||||
async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
if not ticker:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_technical_indicators_loaded",
|
||||
"ticker": ticker,
|
||||
"indicators": None,
|
||||
"error": "ticker is required",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
try:
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=250)
|
||||
|
||||
prices = None
|
||||
response = await gateway._call_trading_service(
|
||||
"get_prices",
|
||||
lambda client: client.get_prices(
|
||||
ticker=ticker,
|
||||
start_date=start_date.strftime("%Y-%m-%d"),
|
||||
end_date=end_date.strftime("%Y-%m-%d"),
|
||||
),
|
||||
)
|
||||
if response is not None:
|
||||
prices = response.prices
|
||||
|
||||
if prices is None:
|
||||
payload = trading_domain.get_prices_payload(
|
||||
ticker=ticker,
|
||||
start_date=start_date.strftime("%Y-%m-%d"),
|
||||
end_date=end_date.strftime("%Y-%m-%d"),
|
||||
)
|
||||
prices = payload.get("prices") or []
|
||||
|
||||
if not prices or len(prices) < 20:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_technical_indicators_loaded",
|
||||
"ticker": ticker,
|
||||
"indicators": None,
|
||||
"error": "Insufficient price data",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
df = prices_to_df(prices)
|
||||
signal = gateway._technical_analyzer.analyze(ticker, df)
|
||||
|
||||
import pandas as pd
|
||||
df_sorted = df.sort_values("time").reset_index(drop=True)
|
||||
df_sorted["returns"] = df_sorted["close"].pct_change()
|
||||
vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
|
||||
vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
|
||||
vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
|
||||
ma_distance = {}
|
||||
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
|
||||
ma_value = getattr(signal, ma_key, None)
|
||||
ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None
|
||||
|
||||
indicators = {
|
||||
"ticker": ticker,
|
||||
"current_price": signal.current_price,
|
||||
"ma": {
|
||||
"ma5": signal.ma5,
|
||||
"ma10": signal.ma10,
|
||||
"ma20": signal.ma20,
|
||||
"ma50": signal.ma50,
|
||||
"ma200": signal.ma200,
|
||||
"distance": ma_distance,
|
||||
},
|
||||
"rsi": {
|
||||
"rsi14": signal.rsi14,
|
||||
"status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
|
||||
},
|
||||
"macd": {
|
||||
"macd": signal.macd,
|
||||
"signal": signal.macd_signal,
|
||||
"histogram": signal.macd - signal.macd_signal,
|
||||
},
|
||||
"bollinger": {
|
||||
"upper": signal.bollinger_upper,
|
||||
"mid": signal.bollinger_mid,
|
||||
"lower": signal.bollinger_lower,
|
||||
},
|
||||
"volatility": {
|
||||
"vol_10d": vol_10,
|
||||
"vol_20d": vol_20,
|
||||
"vol_60d": vol_60,
|
||||
"annualized": signal.annualized_volatility_pct,
|
||||
"risk_level": signal.risk_level,
|
||||
},
|
||||
"trend": signal.trend,
|
||||
"mean_reversion": signal.mean_reversion_signal,
|
||||
}
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_technical_indicators_loaded",
|
||||
"ticker": ticker,
|
||||
"indicators": indicators,
|
||||
}, ensure_ascii=False, default=str))
|
||||
except Exception as exc:
|
||||
logger.exception("Error getting technical indicators for %s", ticker)
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_technical_indicators_loaded",
|
||||
"ticker": ticker,
|
||||
"indicators": None,
|
||||
"error": str(exc),
|
||||
}, ensure_ascii=False))
|
||||
|
||||
|
||||
async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
|
||||
ticker = normalize_symbol(data.get("ticker", ""))
|
||||
start_date = str(data.get("start_date") or "").strip()[:10]
|
||||
end_date = str(data.get("end_date") or "").strip()[:10]
|
||||
story_date = str(data.get("story_date") or end_date or "").strip()[:10]
|
||||
target_date = str(data.get("target_date") or "").strip()[:10]
|
||||
force = bool(data.get("force", False))
|
||||
rebuild_story = bool(data.get("rebuild_story", True))
|
||||
rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
|
||||
only_local_to_llm = bool(data.get("only_local_to_llm", False))
|
||||
limit = data.get("limit", 200)
|
||||
|
||||
try:
|
||||
limit = max(10, min(int(limit), 500))
|
||||
except (TypeError, ValueError):
|
||||
limit = 200
|
||||
|
||||
if not ticker or not start_date or not end_date:
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_enrich_completed",
|
||||
"ticker": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"error": "ticker, start_date, end_date are required",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
if only_local_to_llm and not llm_enrichment_enabled():
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_enrich_completed",
|
||||
"ticker": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
|
||||
}, ensure_ascii=False))
|
||||
return
|
||||
|
||||
result = await asyncio.to_thread(
|
||||
enrich_news_for_symbol,
|
||||
gateway.storage.market_store,
|
||||
ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
skip_existing=not force,
|
||||
only_reanalyze_local=only_local_to_llm,
|
||||
)
|
||||
|
||||
story_status = None
|
||||
if rebuild_story and story_date:
|
||||
await asyncio.to_thread(gateway.storage.market_store.delete_story_cache, ticker, as_of_date=story_date)
|
||||
story_result = await asyncio.to_thread(
|
||||
news_domain.get_story_payload,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
as_of_date=story_date,
|
||||
)
|
||||
story_status = {"as_of_date": story_date, "source": story_result.get("source") or "local"}
|
||||
|
||||
similar_status = None
|
||||
if rebuild_similar_days and target_date:
|
||||
await asyncio.to_thread(gateway.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date)
|
||||
similar_result = await asyncio.to_thread(
|
||||
news_domain.get_similar_days_payload,
|
||||
gateway.storage.market_store,
|
||||
ticker=ticker,
|
||||
date=target_date,
|
||||
n_similar=8,
|
||||
)
|
||||
similar_status = {
|
||||
"target_date": target_date,
|
||||
"count": len(similar_result.get("items") or []),
|
||||
"error": similar_result.get("error"),
|
||||
}
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
"type": "stock_enrich_completed",
|
||||
"ticker": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"story_date": story_date or None,
|
||||
"target_date": target_date or None,
|
||||
"force": force,
|
||||
"only_local_to_llm": only_local_to_llm,
|
||||
"stats": result,
|
||||
"story_status": story_status,
|
||||
"similar_status": similar_status,
|
||||
}, ensure_ascii=False, default=str))
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Market Data Service
|
||||
Supports live, mock, and backtest modes
|
||||
Supports live and backtest modes
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
@@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
from backend.config.data_config import get_data_source
|
||||
from backend.config.data_config import get_data_sources
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,7 +36,6 @@ class MarketService:
|
||||
self,
|
||||
tickers: List[str],
|
||||
poll_interval: int = 10,
|
||||
mock_mode: bool = False,
|
||||
backtest_mode: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
backtest_start_date: Optional[str] = None,
|
||||
@@ -44,7 +43,6 @@ class MarketService:
|
||||
):
|
||||
self.tickers = [normalize_symbol(ticker) for ticker in tickers]
|
||||
self.poll_interval = poll_interval
|
||||
self.mock_mode = mock_mode
|
||||
self.backtest_mode = backtest_mode
|
||||
self.api_key = api_key
|
||||
self.backtest_start_date = backtest_start_date
|
||||
@@ -69,8 +67,6 @@ class MarketService:
|
||||
"""Return the active live quote provider for UI/debugging."""
|
||||
if self.backtest_mode:
|
||||
return "backtest"
|
||||
if self.mock_mode:
|
||||
return "mock"
|
||||
if self._price_manager and hasattr(self._price_manager, "provider"):
|
||||
provider = getattr(self._price_manager, "provider", None)
|
||||
if isinstance(provider, str) and provider.strip():
|
||||
@@ -81,8 +77,6 @@ class MarketService:
|
||||
def mode_name(self) -> str:
|
||||
if self.backtest_mode:
|
||||
return "BACKTEST"
|
||||
elif self.mock_mode:
|
||||
return "MOCK"
|
||||
return "LIVE"
|
||||
|
||||
async def start(self, broadcast_func: Callable):
|
||||
@@ -96,8 +90,6 @@ class MarketService:
|
||||
|
||||
if self.backtest_mode:
|
||||
self._start_backtest_mode()
|
||||
elif self.mock_mode:
|
||||
self._start_mock_mode()
|
||||
else:
|
||||
self._start_real_mode()
|
||||
|
||||
@@ -125,26 +117,10 @@ class MarketService:
|
||||
|
||||
return callback
|
||||
|
||||
def _start_mock_mode(self):
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
|
||||
self._price_manager = MockPriceManager(
|
||||
poll_interval=self.poll_interval,
|
||||
volatility=0.5,
|
||||
)
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(
|
||||
self.tickers,
|
||||
base_prices={t: 100.0 for t in self.tickers},
|
||||
)
|
||||
self._price_manager.start()
|
||||
|
||||
def _start_real_mode(self):
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
provider = get_data_source()
|
||||
if provider == "local_csv":
|
||||
provider = "yfinance"
|
||||
provider = self._resolve_live_quote_provider()
|
||||
|
||||
if provider == "finnhub" and not self.api_key:
|
||||
raise ValueError("API key required for live mode")
|
||||
@@ -157,6 +133,13 @@ class MarketService:
|
||||
self._price_manager.subscribe(self.tickers)
|
||||
self._price_manager.start()
|
||||
|
||||
def _resolve_live_quote_provider(self) -> str:
|
||||
"""Pick the first configured provider that supports live quote polling."""
|
||||
for provider in get_data_sources():
|
||||
if provider in {"finnhub", "yfinance"}:
|
||||
return provider
|
||||
return "yfinance"
|
||||
|
||||
def _start_backtest_mode(self):
|
||||
from backend.data.historical_price_manager import (
|
||||
HistoricalPriceManager,
|
||||
@@ -257,13 +240,7 @@ class MarketService:
|
||||
if removed:
|
||||
self._price_manager.unsubscribe(removed)
|
||||
if added:
|
||||
if self.mock_mode:
|
||||
self._price_manager.subscribe(
|
||||
added,
|
||||
base_prices={ticker: 100.0 for ticker in added},
|
||||
)
|
||||
else:
|
||||
self._price_manager.subscribe(added)
|
||||
self._price_manager.subscribe(added)
|
||||
|
||||
if self.backtest_mode and self._current_date:
|
||||
self._price_manager.set_date(self._current_date)
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from backend.data.schema import CompanyNews
|
||||
from shared.schema import CompanyNews
|
||||
|
||||
|
||||
SCHEMA = """
|
||||
|
||||
@@ -11,7 +11,6 @@ from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from .research_db import ResearchDb
|
||||
from .runtime_db import RuntimeDb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -22,12 +21,18 @@ class StorageService:
|
||||
Storage service for data persistence
|
||||
|
||||
Responsibilities:
|
||||
1. Load/save dashboard JSON files
|
||||
1. Export dashboard JSON files
|
||||
(summary, holdings, stats, trades, leaderboard)
|
||||
2. Load/save internal state (_internal_state.json)
|
||||
3. Load/save server state (server_state.json) with feed history
|
||||
4. Manage portfolio state persistence
|
||||
5. Support loading from saved state to resume execution
|
||||
|
||||
Notes:
|
||||
- team_dashboard/*.json is treated as an export/compatibility layer
|
||||
rather than the authoritative runtime source of truth.
|
||||
- authoritative runtime reads should prefer in-memory state, server_state,
|
||||
runtime.db, and market_research.db.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -49,7 +54,7 @@ class StorageService:
|
||||
self.initial_cash = initial_cash
|
||||
self.config_name = config_name
|
||||
|
||||
# Dashboard file paths
|
||||
# Dashboard export file paths
|
||||
self.files = {
|
||||
"summary": self.dashboard_dir / "summary.json",
|
||||
"holdings": self.dashboard_dir / "holdings.json",
|
||||
@@ -66,7 +71,6 @@ class StorageService:
|
||||
self.state_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.server_state_file = self.state_dir / "server_state.json"
|
||||
self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
|
||||
self.research_db = ResearchDb(self.state_dir / "research.db")
|
||||
self.market_store = MarketStore()
|
||||
|
||||
# Feed history (for agent messages)
|
||||
@@ -84,16 +88,8 @@ class StorageService:
|
||||
|
||||
logger.info(f"Storage service initialized: {self.dashboard_dir}")
|
||||
|
||||
def load_file(self, file_type: str) -> Optional[Any]:
|
||||
"""
|
||||
Load dashboard JSON file
|
||||
|
||||
Args:
|
||||
file_type: One of: summary, holdings, stats, trades, leaderboard
|
||||
|
||||
Returns:
|
||||
Loaded data or None if file doesn't exist
|
||||
"""
|
||||
def load_export_file(self, file_type: str) -> Optional[Any]:
|
||||
"""Load dashboard export JSON file."""
|
||||
file_path = self.files.get(file_type)
|
||||
if not file_path or not file_path.exists():
|
||||
return None
|
||||
@@ -105,14 +101,12 @@ class StorageService:
|
||||
logger.error(f"Failed to load {file_type}.json: {e}")
|
||||
return None
|
||||
|
||||
def save_file(self, file_type: str, data: Any):
|
||||
"""
|
||||
Save dashboard JSON file
|
||||
def load_file(self, file_type: str) -> Optional[Any]:
|
||||
"""Backward-compatible alias for export-layer JSON reads."""
|
||||
return self.load_export_file(file_type)
|
||||
|
||||
Args:
|
||||
file_type: One of: summary, holdings, stats, trades, leaderboard
|
||||
data: Data to save
|
||||
"""
|
||||
def save_export_file(self, file_type: str, data: Any):
|
||||
"""Save dashboard export JSON file."""
|
||||
file_path = self.files.get(file_type)
|
||||
if not file_path:
|
||||
logger.error(f"Unknown file type: {file_type}")
|
||||
@@ -129,6 +123,48 @@ class StorageService:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save {file_type}.json: {e}")
|
||||
|
||||
def save_file(self, file_type: str, data: Any):
|
||||
"""Backward-compatible alias for export-layer JSON writes."""
|
||||
self.save_export_file(file_type, data)
|
||||
|
||||
def build_dashboard_snapshot_from_state(
|
||||
self,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build dashboard view data from runtime state instead of JSON exports."""
|
||||
runtime_state = state or self.load_server_state()
|
||||
portfolio = dict(runtime_state.get("portfolio") or {})
|
||||
holdings = list(runtime_state.get("holdings") or [])
|
||||
stats = runtime_state.get("stats") or self._get_default_stats()
|
||||
trades = list(runtime_state.get("trades") or [])
|
||||
leaderboard = list(runtime_state.get("leaderboard") or [])
|
||||
|
||||
summary = {
|
||||
"totalAssetValue": portfolio.get("total_value", self.initial_cash),
|
||||
"totalReturn": portfolio.get("pnl_percent", 0.0),
|
||||
"cashPosition": portfolio.get("cash", self.initial_cash),
|
||||
"tickerWeights": stats.get("tickerWeights", {}),
|
||||
"totalTrades": len(trades),
|
||||
"pnlPct": portfolio.get("pnl_percent", 0.0),
|
||||
"balance": portfolio.get("total_value", self.initial_cash),
|
||||
"equity": portfolio.get("equity", []),
|
||||
"baseline": portfolio.get("baseline", []),
|
||||
"baseline_vw": portfolio.get("baseline_vw", []),
|
||||
"momentum": portfolio.get("momentum", []),
|
||||
"equity_return": portfolio.get("equity_return", []),
|
||||
"baseline_return": portfolio.get("baseline_return", []),
|
||||
"baseline_vw_return": portfolio.get("baseline_vw_return", []),
|
||||
"momentum_return": portfolio.get("momentum_return", []),
|
||||
}
|
||||
|
||||
return {
|
||||
"summary": summary,
|
||||
"holdings": holdings,
|
||||
"stats": stats,
|
||||
"trades": trades,
|
||||
"leaderboard": leaderboard,
|
||||
}
|
||||
|
||||
def check_file_updates(self) -> Dict[str, bool]:
|
||||
"""
|
||||
Check which dashboard files have been updated since last check
|
||||
@@ -297,7 +333,7 @@ class StorageService:
|
||||
def initialize_empty_dashboard(self):
|
||||
"""Initialize empty dashboard files with default values"""
|
||||
# Summary
|
||||
self.save_file(
|
||||
self.save_export_file(
|
||||
"summary",
|
||||
{
|
||||
"totalAssetValue": self.initial_cash,
|
||||
@@ -315,10 +351,10 @@ class StorageService:
|
||||
)
|
||||
|
||||
# Holdings
|
||||
self.save_file("holdings", [])
|
||||
self.save_export_file("holdings", [])
|
||||
|
||||
# Stats
|
||||
self.save_file(
|
||||
self.save_export_file(
|
||||
"stats",
|
||||
{
|
||||
"totalAssetValue": self.initial_cash,
|
||||
@@ -335,7 +371,7 @@ class StorageService:
|
||||
)
|
||||
|
||||
# Trades
|
||||
self.save_file("trades", [])
|
||||
self.save_export_file("trades", [])
|
||||
|
||||
# Leaderboard with model info
|
||||
self.generate_leaderboard()
|
||||
@@ -375,7 +411,7 @@ class StorageService:
|
||||
ranking_entries.append(entry)
|
||||
|
||||
leaderboard = team_entries + ranking_entries
|
||||
self.save_file("leaderboard", leaderboard)
|
||||
self.save_export_file("leaderboard", leaderboard)
|
||||
logger.info("Leaderboard generated with model info")
|
||||
|
||||
def update_leaderboard_model_info(self):
|
||||
@@ -398,7 +434,7 @@ class StorageService:
|
||||
entry["modelName"] = model_name
|
||||
entry["modelProvider"] = model_provider
|
||||
|
||||
self.save_file("leaderboard", existing)
|
||||
self.save_export_file("leaderboard", existing)
|
||||
logger.info("Leaderboard model info updated")
|
||||
|
||||
def get_current_timestamp_ms(self, date: str = None) -> int:
|
||||
@@ -653,7 +689,7 @@ class StorageService:
|
||||
"momentum": state.get("momentum_history", []),
|
||||
}
|
||||
|
||||
self.save_file("summary", summary)
|
||||
self.save_export_file("summary", summary)
|
||||
|
||||
def _generate_holdings(
|
||||
self,
|
||||
@@ -715,7 +751,7 @@ class StorageService:
|
||||
# Sort by weight
|
||||
holdings.sort(key=lambda x: abs(x["weight"]), reverse=True)
|
||||
|
||||
self.save_file("holdings", holdings)
|
||||
self.save_export_file("holdings", holdings)
|
||||
|
||||
def _generate_stats(self, state: Dict[str, Any], net_value: float):
|
||||
"""Generate stats.json"""
|
||||
@@ -738,7 +774,7 @@ class StorageService:
|
||||
},
|
||||
}
|
||||
|
||||
self.save_file("stats", stats)
|
||||
self.save_export_file("stats", stats)
|
||||
|
||||
def _generate_trades(self, state: Dict[str, Any]):
|
||||
"""Generate trades.json"""
|
||||
@@ -764,7 +800,7 @@ class StorageService:
|
||||
},
|
||||
)
|
||||
|
||||
self.save_file("trades", trades)
|
||||
self.save_export_file("trades", trades)
|
||||
|
||||
# Server State Management Methods
|
||||
|
||||
@@ -1001,12 +1037,12 @@ class StorageService:
|
||||
Args:
|
||||
state: Server state dictionary to update
|
||||
"""
|
||||
# Load dashboard data
|
||||
summary = self.load_file("summary") or {}
|
||||
holdings = self.load_file("holdings") or []
|
||||
stats = self.load_file("stats") or self._get_default_stats()
|
||||
trades = self.load_file("trades") or []
|
||||
leaderboard = self.load_file("leaderboard") or []
|
||||
dashboard_snapshot = self.build_dashboard_snapshot_from_state(state)
|
||||
summary = dashboard_snapshot.get("summary") or {}
|
||||
holdings = dashboard_snapshot.get("holdings") or []
|
||||
stats = dashboard_snapshot.get("stats") or self._get_default_stats()
|
||||
trades = dashboard_snapshot.get("trades") or []
|
||||
leaderboard = dashboard_snapshot.get("leaderboard") or []
|
||||
internal_state = self.load_internal_state()
|
||||
|
||||
# Update state
|
||||
@@ -1040,7 +1076,6 @@ class StorageService:
|
||||
Start tracking live returns for current trading session.
|
||||
Captures current values as session start baseline.
|
||||
"""
|
||||
summary = self.load_file("summary") or {}
|
||||
state = self.load_internal_state()
|
||||
|
||||
# Capture current values as session start
|
||||
@@ -1052,7 +1087,7 @@ class StorageService:
|
||||
self._session_start_equity = (
|
||||
equity_history[-1]["v"]
|
||||
if equity_history
|
||||
else summary.get("totalAssetValue", self.initial_cash)
|
||||
else self.initial_cash
|
||||
)
|
||||
self._session_start_baseline = (
|
||||
baseline_history[-1]["v"]
|
||||
|
||||
119
backend/skills/SKILL_TEMPLATE.md
Normal file
119
backend/skills/SKILL_TEMPLATE.md
Normal file
@@ -0,0 +1,119 @@
|
||||
# Skill Template (Anthropic + AgentScope Aligned)
|
||||
|
||||
> 用于定义可执行、可路由、可评估的技能规范。
|
||||
> 建议所有 `SKILL.md` 至少覆盖以下 6 个部分。
|
||||
|
||||
---
|
||||
|
||||
## Frontmatter Spec
|
||||
|
||||
All `SKILL.md` files should begin with a YAML frontmatter block:
|
||||
|
||||
```yaml
|
||||
---
|
||||
name: skill_name # Required. Unique identifier for the skill.
|
||||
description: ... # Required. One-line description of the skill.
|
||||
version: "1.0.0" # Optional. Semantic version string.
|
||||
tools: [...] # Optional. Tools provided or used by this skill.
|
||||
allowed_tools: [...] # Optional. List of tool names permitted when this skill is active.
|
||||
denied_tools: [...] # Optional. List of tool names denied when this skill is active.
|
||||
---
|
||||
```
|
||||
|
||||
### Frontmatter Fields
|
||||
|
||||
| Field | Type | Description |
|
||||
|-------|------|-------------|
|
||||
| `name` | string | Unique skill identifier (kebab-case recommended). |
|
||||
| `description` | string | Human-readable one-line description. |
|
||||
| `version` | string | Semantic version (e.g., `"1.0.0"`). |
|
||||
| `tools` | list[string] | Tools provided by or associated with this skill. |
|
||||
| `allowed_tools` | list[string] | Enumerates which tools are **permitted** when this skill is active. If set, only these tools may be used. |
|
||||
| `denied_tools` | list[string] | Enumerates which tools are **forbidden** when this skill is active. Denied tools take precedence over `allowed_tools`. |
|
||||
|
||||
### Tool Restriction Rules
|
||||
|
||||
- If **only** `allowed_tools` is set: only those tools are accessible.
|
||||
- If **only** `denied_tools` is set: all tools except those are accessible.
|
||||
- If **both** are set: `allowed_tools` defines the initial set, then `denied_tools` removes from it.
|
||||
- **Denial takes precedence**: a tool in `denied_tools` is always blocked even if also in `allowed_tools`.
|
||||
|
||||
---
|
||||
|
||||
## 1) When to use
|
||||
|
||||
- 明确触发条件(任务类型、关键词、场景)。
|
||||
- 明确不应使用该技能的边界(避免误触发)。
|
||||
|
||||
## 2) Required inputs
|
||||
|
||||
- 列出最小必要输入(如 `tickers`、价格、组合状态、风险约束)。
|
||||
- 声明输入缺失时的处理规则(终止 / 降级 / 请求补充)。
|
||||
|
||||
## 3) Decision procedure
|
||||
|
||||
- 采用固定步骤,确保可复现。
|
||||
- 每一步说明目标、判据和产物(例如中间结论)。
|
||||
- 标明冲突处理逻辑(信号冲突、数据冲突、置信度冲突)。
|
||||
|
||||
## 4) Tool call policy
|
||||
|
||||
- 说明优先使用哪些工具组与工具。
|
||||
- 规定何时可以“无工具直接结论”,何时必须工具先证据后结论。
|
||||
- 规定工具失败、超时、返回异常时的替代动作。
|
||||
|
||||
## 5) Output schema
|
||||
|
||||
- 定义标准输出字段,便于下游 Agent 消费与评估。
|
||||
- 推荐包含:`signal`、`confidence`、`reasons`、`risks`、`invalidation`、`next_action`。
|
||||
- 若是组合决策技能,必须包含每个 ticker 的 `action` 与 `quantity`。
|
||||
|
||||
## 6) Failure fallback
|
||||
|
||||
- 规定在数据不足、信号冲突、风险超限、工具不可用时的降级策略。
|
||||
- 默认优先“保守 + 可解释 + 可执行”的输出。
|
||||
|
||||
## Optional: Evaluation hooks
|
||||
|
||||
定义技能的可评估指标,用于后续记忆/反思阶段写入长期经验。
|
||||
|
||||
### 支持的指标类型
|
||||
|
||||
| 指标类型 | 描述 | 适用技能 |
|
||||
|---------|------|---------|
|
||||
| `hit_rate` | 信号命中率 - 决策信号与实际结果的符合程度 | sentiment_review, technical_review |
|
||||
| `risk_violation` | 风控违例率 - 触发风控规则的次数 | risk_review, portfolio_decisioning |
|
||||
| `position_deviation` | 仓位偏离率 - 建议仓位与实际执行仓位的偏差 | portfolio_decisioning |
|
||||
| `pnl_attribution` | P&L 归因一致性 - 收益归因与实际收益的匹配度 | fundamental_review, valuation_review |
|
||||
| `signal_consistency` | 信号一致性 - 多来源信号的一致程度 | sentiment_review |
|
||||
| `decision_latency` | 决策延迟 - 从输入到决策的耗时 | portfolio_decisioning |
|
||||
| `tool_usage` | 工具使用率 - 工具调用次数与成功率的比值 | 所有技能 |
|
||||
| `custom` | 自定义指标 | 特定业务场景 |
|
||||
|
||||
### 使用方式
|
||||
|
||||
```python
|
||||
from backend.agents.base.evaluation_hook import EvaluationHook, MetricType
|
||||
|
||||
# 在技能执行开始时
|
||||
evaluation_hook.start_evaluation(
|
||||
skill_name="technical_review",
|
||||
inputs={"tickers": ["AAPL"], "prices": {...}}
|
||||
)
|
||||
|
||||
# 在技能执行过程中添加指标
|
||||
evaluation_hook.add_metric(
|
||||
name="signal_confidence",
|
||||
metric_type=MetricType.HIT_RATE,
|
||||
value=0.85,
|
||||
metadata={"method": "rsi", "threshold": 30}
|
||||
)
|
||||
|
||||
# 在技能完成时记录结果
|
||||
evaluation_hook.record_outputs({"signal": "buy", "confidence": 0.8})
|
||||
evaluation_hook.complete_evaluation(success=True)
|
||||
```
|
||||
|
||||
### 评估结果存储
|
||||
|
||||
评估结果自动保存到 `runs/{run_id}/evaluations/{agent_id}/{skill_name}_{timestamp}.json`
|
||||
104
backend/tests/test_agent_service_app.py
Normal file
104
backend/tests/test_agent_service_app.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the extracted agent service surface."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.apps.agent_service import create_app
|
||||
from backend.api import agents as agents_module
|
||||
|
||||
|
||||
def test_agent_service_routes_include_control_plane_endpoints(tmp_path):
|
||||
app = create_app(project_root=tmp_path)
|
||||
|
||||
paths = {route.path for route in app.routes}
|
||||
|
||||
assert "/health" in paths
|
||||
assert "/api/status" in paths
|
||||
assert "/api/workspaces" in paths
|
||||
assert "/api/guard/pending" in paths
|
||||
|
||||
|
||||
def test_agent_service_excludes_runtime_routes(tmp_path):
|
||||
app = create_app(project_root=tmp_path)
|
||||
paths = {route.path for route in app.routes}
|
||||
|
||||
assert "/api/runtime/start" not in paths
|
||||
assert "/api/runtime/gateway/port" not in paths
|
||||
|
||||
|
||||
def test_agent_service_read_routes(monkeypatch, tmp_path):
|
||||
class _FakeSkillsManager:
|
||||
project_root = tmp_path
|
||||
|
||||
def get_agent_asset_dir(self, config_name, agent_id):
|
||||
return tmp_path / "runs" / config_name / "agents" / agent_id
|
||||
|
||||
def resolve_agent_skill_names(self, config_name, agent_id, default_skills=None):
|
||||
return ["demo_skill"]
|
||||
|
||||
def list_agent_skill_catalog(self, config_name, agent_id):
|
||||
return [
|
||||
type(
|
||||
"Skill",
|
||||
(),
|
||||
{
|
||||
"skill_name": "demo_skill",
|
||||
"name": "Demo Skill",
|
||||
"description": "demo",
|
||||
"version": "1.0.0",
|
||||
"source": "builtin",
|
||||
"tools": [],
|
||||
},
|
||||
)()
|
||||
]
|
||||
|
||||
def load_agent_skill_document(self, config_name, agent_id, skill_name):
|
||||
return {"skill_name": skill_name, "content": "# demo"}
|
||||
|
||||
class _FakeWorkspaceManager:
|
||||
def load_agent_file(self, config_name, agent_id, filename):
|
||||
return f"{config_name}:{agent_id}:{filename}"
|
||||
|
||||
monkeypatch.setattr(agents_module, "load_agent_profiles", lambda: {"portfolio_manager": {"skills": ["demo_skill"]}})
|
||||
monkeypatch.setattr(agents_module, "get_agent_model_info", lambda agent_id: ("deepseek-v3.2", "DASHSCOPE"))
|
||||
monkeypatch.setattr(
|
||||
agents_module,
|
||||
"load_agent_workspace_config",
|
||||
lambda path: type(
|
||||
"Cfg",
|
||||
(),
|
||||
{
|
||||
"active_tool_groups": ["portfolio_ops"],
|
||||
"disabled_tool_groups": [],
|
||||
"enabled_skills": [],
|
||||
"disabled_skills": [],
|
||||
"prompt_files": ["SOUL.md", "MEMORY.md"],
|
||||
},
|
||||
)(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
agents_module,
|
||||
"get_bootstrap_config_for_run",
|
||||
lambda project_root, config_name: type("Bootstrap", (), {"agent_override": lambda self, agent_id: {}})(),
|
||||
)
|
||||
|
||||
app = create_app(project_root=tmp_path)
|
||||
app.dependency_overrides[agents_module.get_skills_manager] = lambda: _FakeSkillsManager()
|
||||
app.dependency_overrides[agents_module.get_workspace_manager] = lambda: _FakeWorkspaceManager()
|
||||
|
||||
with TestClient(app) as client:
|
||||
profile = client.get("/api/workspaces/demo/agents/portfolio_manager/profile")
|
||||
skills = client.get("/api/workspaces/demo/agents/portfolio_manager/skills")
|
||||
detail = client.get("/api/workspaces/demo/agents/portfolio_manager/skills/demo_skill")
|
||||
workspace_file = client.get("/api/workspaces/demo/agents/portfolio_manager/files/MEMORY.md")
|
||||
|
||||
assert profile.status_code == 200
|
||||
assert profile.json()["profile"]["model_name"] == "deepseek-v3.2"
|
||||
assert skills.status_code == 200
|
||||
assert skills.json()["skills"][0]["skill_name"] == "demo_skill"
|
||||
assert detail.status_code == 200
|
||||
assert detail.json()["skill"]["content"] == "# demo"
|
||||
assert workspace_file.status_code == 200
|
||||
assert workspace_file.json()["content"] == "demo:portfolio_manager:MEMORY.md"
|
||||
@@ -34,7 +34,6 @@ def test_live_runs_incremental_market_store_update_before_start(monkeypatch, tmp
|
||||
monkeypatch.setattr(cli.subprocess, "run", fake_run)
|
||||
|
||||
cli.live(
|
||||
mock=False,
|
||||
config_name="smoke_fullstack",
|
||||
host="0.0.0.0",
|
||||
port=8765,
|
||||
|
||||
139
backend/tests/test_data_tools_service_routing.py
Normal file
139
backend/tests/test_data_tools_service_routing.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for data_tools preferring split services when configured."""
|
||||
|
||||
from backend.tools import data_tools
|
||||
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price
|
||||
|
||||
|
||||
def test_data_tools_prefers_trading_service(monkeypatch):
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
|
||||
monkeypatch.setenv("SERVICE_NAME", "agent_service")
|
||||
monkeypatch.setattr(data_tools._cache, "get_prices", lambda key: None)
|
||||
monkeypatch.setattr(data_tools._cache, "get_financial_metrics", lambda key: None)
|
||||
monkeypatch.setattr(data_tools._cache, "get_insider_trades", lambda key: None)
|
||||
monkeypatch.setattr(data_tools._cache, "get_company_news", lambda key: None)
|
||||
|
||||
def fake_service_get_json(base_url, path, *, params):
|
||||
if path == "/api/prices":
|
||||
return {
|
||||
"ticker": "AAPL",
|
||||
"prices": [
|
||||
Price(
|
||||
open=1,
|
||||
close=2,
|
||||
high=3,
|
||||
low=1,
|
||||
volume=10,
|
||||
time="2026-03-16",
|
||||
).model_dump()
|
||||
],
|
||||
}
|
||||
if path == "/api/financials":
|
||||
return {
|
||||
"financial_metrics": [
|
||||
FinancialMetrics(
|
||||
ticker="AAPL",
|
||||
report_period="2026-03-16",
|
||||
period="ttm",
|
||||
currency="USD",
|
||||
market_cap=123.0,
|
||||
enterprise_value=None,
|
||||
price_to_earnings_ratio=None,
|
||||
price_to_book_ratio=None,
|
||||
price_to_sales_ratio=None,
|
||||
enterprise_value_to_ebitda_ratio=None,
|
||||
enterprise_value_to_revenue_ratio=None,
|
||||
free_cash_flow_yield=None,
|
||||
peg_ratio=None,
|
||||
gross_margin=None,
|
||||
operating_margin=None,
|
||||
net_margin=None,
|
||||
return_on_equity=None,
|
||||
return_on_assets=None,
|
||||
return_on_invested_capital=None,
|
||||
asset_turnover=None,
|
||||
inventory_turnover=None,
|
||||
receivables_turnover=None,
|
||||
days_sales_outstanding=None,
|
||||
operating_cycle=None,
|
||||
working_capital_turnover=None,
|
||||
current_ratio=None,
|
||||
quick_ratio=None,
|
||||
cash_ratio=None,
|
||||
operating_cash_flow_ratio=None,
|
||||
debt_to_equity=None,
|
||||
debt_to_assets=None,
|
||||
interest_coverage=None,
|
||||
revenue_growth=None,
|
||||
earnings_growth=None,
|
||||
book_value_growth=None,
|
||||
earnings_per_share_growth=None,
|
||||
free_cash_flow_growth=None,
|
||||
operating_income_growth=None,
|
||||
ebitda_growth=None,
|
||||
payout_ratio=None,
|
||||
earnings_per_share=None,
|
||||
book_value_per_share=None,
|
||||
free_cash_flow_per_share=None,
|
||||
).model_dump()
|
||||
]
|
||||
}
|
||||
if path == "/api/insider-trades":
|
||||
return {
|
||||
"insider_trades": [
|
||||
InsiderTrade(ticker="AAPL", filing_date="2026-03-16").model_dump()
|
||||
]
|
||||
}
|
||||
if path == "/api/news":
|
||||
return {
|
||||
"news": [
|
||||
CompanyNews(
|
||||
ticker="AAPL",
|
||||
title="Title",
|
||||
source="polygon",
|
||||
url="https://example.com",
|
||||
).model_dump()
|
||||
]
|
||||
}
|
||||
if path == "/api/market-cap":
|
||||
return {"ticker": "AAPL", "end_date": "2026-03-16", "market_cap": 2.5e12}
|
||||
if path == "/api/line-items":
|
||||
return {
|
||||
"search_results": [
|
||||
LineItem(
|
||||
ticker="AAPL",
|
||||
report_period="2026-03-16",
|
||||
period="ttm",
|
||||
currency="USD",
|
||||
free_cash_flow=321.0,
|
||||
).model_dump()
|
||||
]
|
||||
}
|
||||
raise AssertionError(path)
|
||||
|
||||
monkeypatch.setattr(data_tools, "_service_get_json", fake_service_get_json)
|
||||
|
||||
prices = data_tools.get_prices("AAPL", "2026-03-01", "2026-03-16")
|
||||
metrics = data_tools.get_financial_metrics("AAPL", "2026-03-16")
|
||||
trades = data_tools.get_insider_trades("AAPL", "2026-03-16")
|
||||
news = data_tools.get_company_news("AAPL", "2026-03-16")
|
||||
market_cap = data_tools.get_market_cap("AAPL", "2026-03-16")
|
||||
line_items = data_tools.search_line_items(
|
||||
"AAPL",
|
||||
["free_cash_flow"],
|
||||
"2026-03-16",
|
||||
)
|
||||
|
||||
assert prices[0].close == 2
|
||||
assert metrics[0].ticker == "AAPL"
|
||||
assert trades[0].ticker == "AAPL"
|
||||
assert news[0].ticker == "AAPL"
|
||||
assert market_cap == 2.5e12
|
||||
assert line_items[0].free_cash_flow == 321.0
|
||||
|
||||
|
||||
def test_data_tools_skips_self_recursion_for_trading_service(monkeypatch):
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
|
||||
monkeypatch.setenv("SERVICE_NAME", "trading_service")
|
||||
|
||||
assert data_tools._trading_service_url() is None
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
from backend.services.gateway import Gateway
|
||||
import backend.services.gateway as gateway_module
|
||||
from shared.schema import InsiderTrade, InsiderTradeResponse, Price, PriceResponse
|
||||
|
||||
|
||||
class DummyWebSocket:
|
||||
@@ -35,6 +36,10 @@ class FakeMarketStore:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def get_ticker_watermarks(self, symbol):
|
||||
self.calls.append(("get_ticker_watermarks", symbol))
|
||||
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
|
||||
|
||||
def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
|
||||
self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
|
||||
return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}]
|
||||
@@ -123,6 +128,75 @@ def make_gateway(market_store=None):
|
||||
)
|
||||
|
||||
|
||||
class FakeNewsClient:
|
||||
def __init__(self, base_url):
|
||||
self.base_url = base_url
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
async def get_categories(self, ticker, start_date=None, end_date=None, limit=200):
|
||||
return {"ticker": ticker, "categories": {"remote": {"count": 2}}}
|
||||
|
||||
async def get_enriched_news(self, ticker, start_date=None, end_date=None, limit=None):
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"news": [
|
||||
{
|
||||
"id": "remote-news-1",
|
||||
"ticker": ticker,
|
||||
"title": "Remote Title",
|
||||
"date": end_date,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
async def get_story(self, ticker, as_of_date):
|
||||
return {"symbol": ticker, "as_of_date": as_of_date, "story": "remote story", "source": "news_service"}
|
||||
|
||||
|
||||
class FakeTradingClient:
|
||||
def __init__(self, base_url):
|
||||
self.base_url = base_url
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return None
|
||||
|
||||
async def get_insider_trades(self, ticker, end_date=None, start_date=None, limit=None):
|
||||
return InsiderTradeResponse(
|
||||
insider_trades=[
|
||||
InsiderTrade(
|
||||
ticker=ticker,
|
||||
name="Remote Insider",
|
||||
filing_date=end_date or "2026-03-16",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
async def get_prices(self, ticker, start_date=None, end_date=None):
|
||||
prices = [
|
||||
Price(
|
||||
open=float(100 + idx),
|
||||
close=float(101 + idx),
|
||||
high=float(102 + idx),
|
||||
low=float(99 + idx),
|
||||
volume=1000 + idx,
|
||||
time=f"2026-01-{idx + 1:02d}",
|
||||
)
|
||||
for idx in range(30)
|
||||
]
|
||||
return PriceResponse(ticker=ticker, prices=prices)
|
||||
|
||||
async def get_market_cap(self, ticker, end_date):
|
||||
return {"ticker": ticker, "end_date": end_date, "market_cap": 2.5e12}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
|
||||
market_store = FakeMarketStore()
|
||||
@@ -135,6 +209,7 @@ async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument(
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")
|
||||
]
|
||||
assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded"
|
||||
@@ -153,6 +228,7 @@ async def test_handle_get_stock_news_categories_uses_market_store_symbol_argumen
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
|
||||
("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200)
|
||||
]
|
||||
@@ -175,7 +251,7 @@ async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.news_domain,
|
||||
"build_range_explanation",
|
||||
fake_build_range_explanation,
|
||||
)
|
||||
@@ -186,6 +262,7 @@ async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100)
|
||||
]
|
||||
assert websocket.messages[-1] == {
|
||||
@@ -207,7 +284,7 @@ async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.news_domain,
|
||||
"build_range_explanation",
|
||||
lambda **kwargs: {"news_count": len(kwargs["news_rows"])},
|
||||
)
|
||||
@@ -222,7 +299,10 @@ async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch)
|
||||
},
|
||||
)
|
||||
|
||||
assert market_store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-99"])]
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_by_ids_enriched", "AAPL", ["news-99"])
|
||||
]
|
||||
assert websocket.messages[-1]["result"]["news_count"] == 1
|
||||
|
||||
|
||||
@@ -238,6 +318,7 @@ async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
|
||||
)
|
||||
|
||||
assert market_store.calls == [
|
||||
("get_ticker_watermarks", "AAPL"),
|
||||
("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
|
||||
]
|
||||
assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded"
|
||||
@@ -251,7 +332,7 @@ async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.news_domain,
|
||||
"enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
|
||||
)
|
||||
@@ -266,6 +347,132 @@ async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
|
||||
assert "AAPL Story" in websocket.messages[-1]["story"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_news_categories_uses_news_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
|
||||
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
|
||||
|
||||
await gateway._handle_get_stock_news_categories(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "lookback_days": 30},
|
||||
)
|
||||
|
||||
assert market_store.calls == []
|
||||
assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
|
||||
assert websocket.messages[-1]["categories"]["remote"]["count"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_story_uses_news_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
|
||||
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
|
||||
|
||||
await gateway._handle_get_stock_story(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "as_of_date": "2026-03-16"},
|
||||
)
|
||||
|
||||
assert market_store.calls == []
|
||||
assert websocket.messages[-1]["type"] == "stock_story_loaded"
|
||||
assert websocket.messages[-1]["story"] == "remote story"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_news_uses_news_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
|
||||
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
|
||||
|
||||
await gateway._handle_get_stock_news(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "lookback_days": 30, "limit": 5},
|
||||
)
|
||||
|
||||
assert market_store.calls == []
|
||||
assert websocket.messages[-1]["type"] == "stock_news_loaded"
|
||||
assert websocket.messages[-1]["source"] == "news_service"
|
||||
assert websocket.messages[-1]["news"][0]["title"] == "Remote Title"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_insider_trades_uses_trading_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
|
||||
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
|
||||
|
||||
await gateway._handle_get_stock_insider_trades(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "end_date": "2026-03-16", "limit": 10},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "stock_insider_trades_loaded"
|
||||
assert websocket.messages[-1]["trades"][0]["name"] == "Remote Insider"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_history_uses_trading_service_client_when_configured(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
gateway = make_gateway(market_store)
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
|
||||
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
|
||||
|
||||
await gateway._handle_get_stock_history(
|
||||
websocket,
|
||||
{"ticker": "AAPL", "lookback_days": 30},
|
||||
)
|
||||
|
||||
assert market_store.calls == []
|
||||
assert websocket.messages[-1]["type"] == "stock_history_loaded"
|
||||
assert websocket.messages[-1]["source"] == "trading_service"
|
||||
assert len(websocket.messages[-1]["prices"]) == 30
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_technical_indicators_uses_trading_service_client_when_configured(monkeypatch):
|
||||
gateway = make_gateway(FakeMarketStore())
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
|
||||
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
|
||||
|
||||
await gateway._handle_get_stock_technical_indicators(
|
||||
websocket,
|
||||
{"ticker": "AAPL"},
|
||||
)
|
||||
|
||||
assert websocket.messages[-1]["type"] == "stock_technical_indicators_loaded"
|
||||
assert websocket.messages[-1]["ticker"] == "AAPL"
|
||||
assert websocket.messages[-1]["indicators"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_market_caps_uses_trading_service_client_when_configured(monkeypatch):
|
||||
gateway = make_gateway(FakeMarketStore())
|
||||
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
|
||||
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
|
||||
|
||||
market_caps = await gateway._get_market_caps(["AAPL", "MSFT"], "2026-03-16")
|
||||
|
||||
assert market_caps == {"AAPL": 2.5e12, "MSFT": 2.5e12}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
|
||||
market_store = FakeMarketStore()
|
||||
@@ -273,7 +480,7 @@ async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.news_domain,
|
||||
"enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
|
||||
)
|
||||
@@ -295,7 +502,12 @@ async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.gateway_stock_handlers,
|
||||
"enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_module.news_domain,
|
||||
"enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
|
||||
)
|
||||
@@ -325,7 +537,7 @@ async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypa
|
||||
gateway = make_gateway(FakeMarketStore())
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(gateway_module, "llm_enrichment_enabled", lambda: False)
|
||||
monkeypatch.setattr(gateway_module.gateway_stock_handlers, "llm_enrichment_enabled", lambda: False)
|
||||
|
||||
await gateway._handle_run_stock_enrich(
|
||||
websocket,
|
||||
@@ -361,7 +573,7 @@ def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
|
||||
|
||||
gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])
|
||||
|
||||
assert captured["coro_name"] == "_refresh_market_store_for_watchlist"
|
||||
assert captured["coro_name"] == "refresh_market_store_for_watchlist"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -369,7 +581,7 @@ async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypa
|
||||
gateway = make_gateway()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.gateway_cycle_support,
|
||||
"ingest_symbols",
|
||||
lambda symbols, mode="incremental": [
|
||||
{"symbol": symbol, "prices": 3, "news": 4, "aligned": 4}
|
||||
@@ -445,12 +657,12 @@ async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatc
|
||||
websocket = DummyWebSocket()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.gateway_admin_handlers,
|
||||
"load_agent_profiles",
|
||||
lambda: {"risk_manager": {"skills": ["risk_review"], "active_tool_groups": ["risk_ops", "legacy_group"]}},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.gateway_admin_handlers,
|
||||
"get_agent_model_info",
|
||||
lambda agent_id: ("gpt-4o-mini", "OPENAI"),
|
||||
)
|
||||
@@ -461,7 +673,7 @@ async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatc
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_module,
|
||||
gateway_module.gateway_admin_handlers,
|
||||
"get_bootstrap_config_for_run",
|
||||
lambda project_root, config_name: _Bootstrap(),
|
||||
)
|
||||
|
||||
220
backend/tests/test_gateway_support_modules.py
Normal file
220
backend/tests/test_gateway_support_modules.py
Normal file
@@ -0,0 +1,220 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Direct tests for Gateway support modules."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.services import gateway_cycle_support, gateway_runtime_support
|
||||
|
||||
|
||||
class _DummyDashboard:
|
||||
def __init__(self):
|
||||
self.updated = []
|
||||
self.tickers = []
|
||||
self.initial_cash = None
|
||||
self.enable_memory = False
|
||||
self.days_total = 0
|
||||
|
||||
def update(self, **kwargs):
|
||||
self.updated.append(kwargs)
|
||||
|
||||
def stop(self):
|
||||
return None
|
||||
|
||||
def print_final_summary(self):
|
||||
return None
|
||||
|
||||
|
||||
class _DummyScheduler:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def reconfigure(self, **kwargs):
|
||||
self.calls.append(kwargs)
|
||||
|
||||
|
||||
class _DummyStateSync:
|
||||
def __init__(self):
|
||||
self.updated = []
|
||||
self.saved = False
|
||||
self.system_messages = []
|
||||
self.backtest_dates = []
|
||||
self.state = {}
|
||||
|
||||
def update_state(self, key, value):
|
||||
self.updated.append((key, value))
|
||||
self.state[key] = value
|
||||
|
||||
def save_state(self):
|
||||
self.saved = True
|
||||
|
||||
async def on_system_message(self, message):
|
||||
self.system_messages.append(message)
|
||||
|
||||
def set_backtest_dates(self, dates):
|
||||
self.backtest_dates = list(dates)
|
||||
|
||||
|
||||
class _DummyStorage:
|
||||
def __init__(self):
|
||||
self.initial_cash = 100000.0
|
||||
self.is_live_session_active = False
|
||||
self.server_state_updates = []
|
||||
|
||||
def can_apply_initial_cash(self):
|
||||
return True
|
||||
|
||||
def apply_initial_cash(self, value):
|
||||
self.initial_cash = value
|
||||
return True
|
||||
|
||||
def update_server_state_from_dashboard(self, state):
|
||||
self.server_state_updates.append(state)
|
||||
|
||||
def load_file(self, name):
|
||||
if name == "summary":
|
||||
return {"totalAssetValue": self.initial_cash}
|
||||
return []
|
||||
|
||||
def build_dashboard_snapshot_from_state(self, state):
|
||||
return {
|
||||
"summary": {"totalAssetValue": self.initial_cash},
|
||||
"holdings": [],
|
||||
"stats": {},
|
||||
"trades": [],
|
||||
"leaderboard": [],
|
||||
}
|
||||
|
||||
|
||||
class _DummyPM:
|
||||
def __init__(self):
|
||||
self.portfolio = {"margin_requirement": 0.0}
|
||||
|
||||
def apply_runtime_portfolio_config(self, margin_requirement=None, initial_cash=None):
|
||||
if margin_requirement is not None:
|
||||
self.portfolio["margin_requirement"] = margin_requirement
|
||||
return {"margin_requirement": True}
|
||||
|
||||
def can_apply_initial_cash(self):
|
||||
return True
|
||||
|
||||
|
||||
class _DummyMarketService:
|
||||
def __init__(self):
|
||||
self.updated = None
|
||||
self.stopped = False
|
||||
|
||||
def update_tickers(self, tickers):
|
||||
self.updated = list(tickers)
|
||||
return {"active": list(tickers), "added": list(tickers), "removed": []}
|
||||
|
||||
def stop(self):
|
||||
self.stopped = True
|
||||
|
||||
|
||||
def make_gateway_stub():
|
||||
pipeline = SimpleNamespace(max_comm_cycles=0, pm=_DummyPM())
|
||||
gateway = SimpleNamespace(
|
||||
market_service=_DummyMarketService(),
|
||||
pipeline=pipeline,
|
||||
scheduler=_DummyScheduler(),
|
||||
config={
|
||||
"tickers": ["AAPL"],
|
||||
"schedule_mode": "daily",
|
||||
"interval_minutes": 60,
|
||||
"trigger_time": "09:30",
|
||||
"enable_memory": False,
|
||||
},
|
||||
storage=_DummyStorage(),
|
||||
state_sync=_DummyStateSync(),
|
||||
_dashboard=_DummyDashboard(),
|
||||
_watchlist_ingest_task=None,
|
||||
_market_status_task=None,
|
||||
_backtest_task=None,
|
||||
_backtest_start_date=None,
|
||||
_backtest_end_date=None,
|
||||
_manual_cycle_task=None,
|
||||
)
|
||||
return gateway
|
||||
|
||||
|
||||
def test_normalize_watchlist_filters_invalid_and_dedupes():
|
||||
assert gateway_runtime_support.normalize_watchlist(["aapl", " AAPL ", "", "msft"]) == ["AAPL", "MSFT"]
|
||||
assert gateway_runtime_support.normalize_watchlist("aapl,msft") == ["AAPL", "MSFT"]
|
||||
|
||||
|
||||
def test_normalize_agent_workspace_filename_obeys_allowlist():
|
||||
allowlist = {"SOUL.md", "PROFILE.md"}
|
||||
assert gateway_runtime_support.normalize_agent_workspace_filename("SOUL.md", allowlist=allowlist) == "SOUL.md"
|
||||
assert gateway_runtime_support.normalize_agent_workspace_filename("README.md", allowlist=allowlist) is None
|
||||
|
||||
|
||||
def test_apply_runtime_config_updates_gateway_state():
|
||||
gateway = make_gateway_stub()
|
||||
|
||||
result = gateway_runtime_support.apply_runtime_config(
|
||||
gateway,
|
||||
{
|
||||
"tickers": ["MSFT", "NVDA"],
|
||||
"schedule_mode": "intraday",
|
||||
"interval_minutes": 30,
|
||||
"trigger_time": "10:30",
|
||||
"initial_cash": 150000.0,
|
||||
"margin_requirement": 0.5,
|
||||
"max_comm_cycles": 4,
|
||||
"enable_memory": False,
|
||||
},
|
||||
)
|
||||
|
||||
assert gateway.config["tickers"] == ["MSFT", "NVDA"]
|
||||
assert gateway.config["schedule_mode"] == "intraday"
|
||||
assert gateway.storage.initial_cash == 150000.0
|
||||
assert result["runtime_config_applied"]["max_comm_cycles"] == 4
|
||||
assert gateway.scheduler.calls[-1] == {
|
||||
"mode": "intraday",
|
||||
"trigger_time": "10:30",
|
||||
"interval_minutes": 30,
|
||||
}
|
||||
|
||||
|
||||
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
|
||||
gateway = make_gateway_stub()
|
||||
captured = {}
|
||||
|
||||
class DummyTask:
|
||||
def done(self):
|
||||
return False
|
||||
|
||||
def cancel(self):
|
||||
captured["cancelled"] = True
|
||||
|
||||
def fake_create_task(coro):
|
||||
captured["name"] = coro.cr_code.co_name
|
||||
coro.close()
|
||||
return DummyTask()
|
||||
|
||||
monkeypatch.setattr(gateway_cycle_support.asyncio, "create_task", fake_create_task)
|
||||
|
||||
gateway_cycle_support.schedule_watchlist_market_store_refresh(gateway, ["AAPL", "MSFT"])
|
||||
|
||||
assert captured["name"] == "refresh_market_store_for_watchlist"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
|
||||
gateway = make_gateway_stub()
|
||||
|
||||
monkeypatch.setattr(
|
||||
gateway_cycle_support,
|
||||
"ingest_symbols",
|
||||
lambda symbols, mode="incremental": [
|
||||
{"symbol": symbol, "prices": 3, "news": 4}
|
||||
for symbol in symbols
|
||||
],
|
||||
)
|
||||
|
||||
await gateway_cycle_support.refresh_market_store_for_watchlist(gateway, ["AAPL", "MSFT"])
|
||||
|
||||
assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
|
||||
assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]
|
||||
69
backend/tests/test_heartbeat_hook.py
Normal file
69
backend/tests/test_heartbeat_hook.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for HeartbeatHook."""
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.agents.base.hooks import HeartbeatHook
|
||||
|
||||
|
||||
class TestHeartbeatHook:
|
||||
"""Tests for HeartbeatHook._read_heartbeat_content."""
|
||||
|
||||
def test_read_heartbeat_content_with_content(self, tmp_path):
|
||||
"""Test reading HEARTBEAT.md when it exists and has content."""
|
||||
ws_dir = tmp_path / "analyst_workspace"
|
||||
ws_dir.mkdir()
|
||||
hb_file = ws_dir / "HEARTBEAT.md"
|
||||
hb_file.write_text("# 定期主动检查\n\n- [ ] 持仓是否健康\n", encoding="utf-8")
|
||||
|
||||
hook = HeartbeatHook(workspace_dir=ws_dir)
|
||||
content = hook._read_heartbeat_content()
|
||||
|
||||
assert content is not None
|
||||
assert "# 定期主动检查" in content
|
||||
assert "持仓是否健康" in content
|
||||
|
||||
def test_read_heartbeat_content_absent(self, tmp_path):
|
||||
"""Test reading when HEARTBEAT.md does not exist."""
|
||||
ws_dir = tmp_path / "analyst_workspace"
|
||||
ws_dir.mkdir()
|
||||
|
||||
hook = HeartbeatHook(workspace_dir=ws_dir)
|
||||
content = hook._read_heartbeat_content()
|
||||
|
||||
assert content is None
|
||||
|
||||
def test_read_heartbeat_content_empty(self, tmp_path):
|
||||
"""Test reading when HEARTBEAT.md is empty."""
|
||||
ws_dir = tmp_path / "analyst_workspace"
|
||||
ws_dir.mkdir()
|
||||
hb_file = ws_dir / "HEARTBEAT.md"
|
||||
hb_file.write_text("", encoding="utf-8")
|
||||
|
||||
hook = HeartbeatHook(workspace_dir=ws_dir)
|
||||
content = hook._read_heartbeat_content()
|
||||
|
||||
assert content is None
|
||||
|
||||
def test_read_heartbeat_content_whitespace_only(self, tmp_path):
|
||||
"""Test reading when HEARTBEAT.md contains only whitespace."""
|
||||
ws_dir = tmp_path / "analyst_workspace"
|
||||
ws_dir.mkdir()
|
||||
hb_file = ws_dir / "HEARTBEAT.md"
|
||||
hb_file.write_text(" \n\n ", encoding="utf-8")
|
||||
|
||||
hook = HeartbeatHook(workspace_dir=ws_dir)
|
||||
content = hook._read_heartbeat_content()
|
||||
|
||||
assert content is None
|
||||
|
||||
def test_completed_flag_path(self, tmp_path):
|
||||
"""Test that completion flag is placed in workspace directory."""
|
||||
ws_dir = tmp_path / "analyst_workspace"
|
||||
ws_dir.mkdir()
|
||||
|
||||
hook = HeartbeatHook(workspace_dir=ws_dir)
|
||||
|
||||
assert hook._completed_flag == ws_dir / ".heartbeat_completed"
|
||||
@@ -2,153 +2,12 @@
|
||||
# pylint: disable=W0212
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import pytest
|
||||
from backend.services.market import MarketService
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
|
||||
class TestMockPriceManager:
|
||||
def test_init_default(self):
|
||||
manager = MockPriceManager()
|
||||
|
||||
assert manager.poll_interval == 10
|
||||
assert manager.volatility == 0.5
|
||||
assert manager.running is False
|
||||
assert len(manager.subscribed_symbols) == 0
|
||||
|
||||
def test_init_custom(self):
|
||||
manager = MockPriceManager(poll_interval=5, volatility=1.0)
|
||||
|
||||
assert manager.poll_interval == 5
|
||||
assert manager.volatility == 1.0
|
||||
|
||||
def test_subscribe(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL", "MSFT"])
|
||||
|
||||
assert "AAPL" in manager.subscribed_symbols
|
||||
assert "MSFT" in manager.subscribed_symbols
|
||||
assert manager.base_prices["AAPL"] == 237.50 # default price
|
||||
assert manager.base_prices["MSFT"] == 425.30 # default price
|
||||
|
||||
def test_subscribe_with_base_prices(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
assert manager.base_prices["AAPL"] == 100.0
|
||||
assert manager.open_prices["AAPL"] == 100.0
|
||||
assert manager.latest_prices["AAPL"] == 100.0
|
||||
|
||||
def test_subscribe_unknown_symbol(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["UNKNOWN"])
|
||||
|
||||
assert "UNKNOWN" in manager.subscribed_symbols
|
||||
assert manager.base_prices["UNKNOWN"] > 0 # random price generated
|
||||
|
||||
def test_unsubscribe(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL", "MSFT"])
|
||||
manager.unsubscribe(["AAPL"])
|
||||
|
||||
assert "AAPL" not in manager.subscribed_symbols
|
||||
assert "MSFT" in manager.subscribed_symbols
|
||||
|
||||
def test_add_price_callback(self):
|
||||
manager = MockPriceManager()
|
||||
callback = MagicMock()
|
||||
manager.add_price_callback(callback)
|
||||
|
||||
assert callback in manager.price_callbacks
|
||||
|
||||
def test_generate_price_update_within_bounds(self):
|
||||
manager = MockPriceManager(volatility=0.5)
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
for _ in range(100):
|
||||
new_price = manager._generate_price_update("AAPL")
|
||||
# Should be within +/-10% of open
|
||||
assert 90.0 <= new_price <= 110.0
|
||||
|
||||
def test_update_prices_triggers_callback(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
callback = MagicMock()
|
||||
manager.add_price_callback(callback)
|
||||
|
||||
manager._update_prices()
|
||||
|
||||
callback.assert_called_once()
|
||||
call_args = callback.call_args[0][0]
|
||||
assert call_args["symbol"] == "AAPL"
|
||||
assert "price" in call_args
|
||||
assert "timestamp" in call_args
|
||||
|
||||
def test_start_stop(self):
|
||||
manager = MockPriceManager(poll_interval=1)
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
manager.start()
|
||||
assert manager.running is True
|
||||
|
||||
time.sleep(0.1) # let thread start
|
||||
|
||||
manager.stop()
|
||||
assert manager.running is False
|
||||
|
||||
def test_start_without_subscription(self):
|
||||
manager = MockPriceManager()
|
||||
manager.start()
|
||||
|
||||
assert (
|
||||
manager.running is False
|
||||
) # should not start without subscriptions
|
||||
|
||||
def test_get_latest_price(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
price = manager.get_latest_price("AAPL")
|
||||
assert price == 100.0
|
||||
|
||||
def test_get_latest_price_unknown(self):
|
||||
manager = MockPriceManager()
|
||||
price = manager.get_latest_price("UNKNOWN")
|
||||
assert price is None
|
||||
|
||||
def test_get_all_latest_prices(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(
|
||||
["AAPL", "MSFT"],
|
||||
base_prices={"AAPL": 100.0, "MSFT": 200.0},
|
||||
)
|
||||
|
||||
prices = manager.get_all_latest_prices()
|
||||
assert prices["AAPL"] == 100.0
|
||||
assert prices["MSFT"] == 200.0
|
||||
|
||||
def test_reset_open_prices(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
manager.latest_prices["AAPL"] = 105.0
|
||||
|
||||
manager.reset_open_prices()
|
||||
|
||||
# Open price should change (based on latest with small gap)
|
||||
assert manager.open_prices["AAPL"] != 100.0
|
||||
|
||||
def test_set_base_price(self):
|
||||
manager = MockPriceManager()
|
||||
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
|
||||
|
||||
manager.set_base_price("AAPL", 150.0)
|
||||
|
||||
assert manager.base_prices["AAPL"] == 150.0
|
||||
assert manager.open_prices["AAPL"] == 150.0
|
||||
assert manager.latest_prices["AAPL"] == 150.0
|
||||
from backend.llm.models import RetryChatModel
|
||||
|
||||
|
||||
class TestPollingPriceManager:
|
||||
@@ -231,37 +90,67 @@ class TestPollingPriceManager:
|
||||
|
||||
assert len(manager.open_prices) == 0
|
||||
|
||||
def test_fetch_prices_suppresses_repeated_failures(self, caplog):
|
||||
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
with patch.object(manager, "_fetch_quote", side_effect=ValueError("empty quote")):
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
for _ in range(3):
|
||||
manager._fetch_prices()
|
||||
|
||||
assert manager._failure_counts["AAPL"] == 3
|
||||
warning_messages = [record.message for record in caplog.records if record.levelno >= logging.WARNING]
|
||||
assert any("Failed to fetch AAPL price: empty quote" in message for message in warning_messages)
|
||||
|
||||
def test_fetch_prices_logs_recovery_after_failure(self, caplog):
|
||||
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
with patch.object(
|
||||
manager,
|
||||
"_fetch_quote",
|
||||
side_effect=[
|
||||
ValueError("temporary outage"),
|
||||
{"c": 100.0, "o": 99.0, "h": 101.0, "l": 98.0, "pc": 99.5, "d": 0.5, "dp": 0.5, "t": 1},
|
||||
],
|
||||
):
|
||||
with caplog.at_level(logging.INFO):
|
||||
manager._fetch_prices()
|
||||
manager._fetch_prices()
|
||||
|
||||
assert "AAPL" not in manager._failure_counts
|
||||
assert any("recovered after 1 consecutive failures" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
class TestRetryChatModel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_retry_recovers_from_disconnect(self):
|
||||
attempts = {"count": 0}
|
||||
|
||||
class FakeAsyncModel:
|
||||
model_name = "fake-async-model"
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
attempts["count"] += 1
|
||||
if attempts["count"] < 2:
|
||||
raise RuntimeError("Server disconnected")
|
||||
return {"ok": True}
|
||||
|
||||
wrapped = RetryChatModel(FakeAsyncModel(), max_retries=2, initial_delay=0.01)
|
||||
result = await wrapped("hello")
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert attempts["count"] == 2
|
||||
|
||||
|
||||
class TestMarketService:
|
||||
def test_init_mock_mode(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL", "MSFT"],
|
||||
poll_interval=10,
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
assert service.tickers == ["AAPL", "MSFT"]
|
||||
assert service.poll_interval == 10
|
||||
assert service.mock_mode is True
|
||||
assert service.running is False
|
||||
|
||||
def test_init_real_mode(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=False,
|
||||
api_key="test_key",
|
||||
)
|
||||
|
||||
assert service.mock_mode is False
|
||||
assert service.api_key == "test_key"
|
||||
|
||||
@patch("backend.services.market.get_data_source", return_value="yfinance")
|
||||
@patch("backend.services.market.get_data_sources", return_value=["yfinance", "local_csv"])
|
||||
@patch.object(PollingPriceManager, "start")
|
||||
def test_start_real_mode_with_yfinance(self, _mock_start, _mock_source):
|
||||
def test_start_real_mode_with_yfinance(self, _mock_start, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=10,
|
||||
mock_mode=False,
|
||||
)
|
||||
|
||||
service._start_real_mode()
|
||||
@@ -269,30 +158,24 @@ class TestMarketService:
|
||||
assert isinstance(service._price_manager, PollingPriceManager)
|
||||
assert service._price_manager.provider == "yfinance"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_mock_mode(self):
|
||||
@patch("backend.services.market.get_data_sources", return_value=["financial_datasets", "yfinance", "local_csv"])
|
||||
@patch.object(PollingPriceManager, "start")
|
||||
def test_start_real_mode_uses_first_supported_live_provider(self, _mock_start, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=10,
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
broadcast_func = AsyncMock()
|
||||
service._start_real_mode()
|
||||
|
||||
await service.start(broadcast_func)
|
||||
assert isinstance(service._price_manager, PollingPriceManager)
|
||||
assert service._price_manager.provider == "yfinance"
|
||||
|
||||
assert service.running is True
|
||||
assert service._price_manager is not None
|
||||
assert isinstance(service._price_manager, MockPriceManager)
|
||||
|
||||
service.stop()
|
||||
|
||||
@patch("backend.services.market.get_data_source", return_value="finnhub")
|
||||
@patch("backend.services.market.get_data_sources", return_value=["finnhub", "yfinance"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_real_mode_without_api_key(self, _mock_source):
|
||||
async def test_start_real_mode_without_api_key(self, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=False,
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
@@ -307,11 +190,12 @@ class TestMarketService:
|
||||
async def test_start_already_running(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=True,
|
||||
backtest_mode=True,
|
||||
)
|
||||
|
||||
broadcast_func = AsyncMock()
|
||||
|
||||
# First start with backtest mode
|
||||
await service.start(broadcast_func)
|
||||
assert service.running is True
|
||||
|
||||
@@ -323,7 +207,7 @@ class TestMarketService:
|
||||
def test_stop(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=True,
|
||||
backtest_mode=True,
|
||||
)
|
||||
service.running = True
|
||||
service._price_manager = MagicMock()
|
||||
@@ -336,7 +220,7 @@ class TestMarketService:
|
||||
def test_stop_when_not_running(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=True,
|
||||
backtest_mode=True,
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
@@ -344,20 +228,20 @@ class TestMarketService:
|
||||
assert service.running is False
|
||||
|
||||
def test_get_price_sync(self):
|
||||
service = MarketService(tickers=["AAPL"], mock_mode=True)
|
||||
service = MarketService(tickers=["AAPL"], backtest_mode=True)
|
||||
service.cache["AAPL"] = {"price": 150.0, "open": 148.0}
|
||||
|
||||
price = service.get_price_sync("AAPL")
|
||||
assert price == 150.0
|
||||
|
||||
def test_get_price_sync_not_found(self):
|
||||
service = MarketService(tickers=["AAPL"], mock_mode=True)
|
||||
service = MarketService(tickers=["AAPL"], backtest_mode=True)
|
||||
|
||||
price = service.get_price_sync("MSFT")
|
||||
assert price is None
|
||||
|
||||
def test_get_all_prices(self):
|
||||
service = MarketService(tickers=["AAPL", "MSFT"], mock_mode=True)
|
||||
service = MarketService(tickers=["AAPL", "MSFT"], backtest_mode=True)
|
||||
service.cache["AAPL"] = {"price": 150.0}
|
||||
service.cache["MSFT"] = {"price": 400.0}
|
||||
|
||||
@@ -368,7 +252,7 @@ class TestMarketService:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_price_update(self):
|
||||
service = MarketService(tickers=["AAPL"], mock_mode=True)
|
||||
service = MarketService(tickers=["AAPL"], backtest_mode=True)
|
||||
service._broadcast_func = AsyncMock()
|
||||
|
||||
price_data = {
|
||||
@@ -388,7 +272,7 @@ class TestMarketService:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_price_update_no_func(self):
|
||||
service = MarketService(tickers=["AAPL"], mock_mode=True)
|
||||
service = MarketService(tickers=["AAPL"], backtest_mode=True)
|
||||
service._broadcast_func = None
|
||||
|
||||
price_data = {"symbol": "AAPL", "price": 150.0, "open": 148.0}
|
||||
@@ -396,67 +280,6 @@ class TestMarketService:
|
||||
# Should not raise
|
||||
await service._broadcast_price_update(price_data)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_price_callback_thread_safety(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=1,
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
received_prices = []
|
||||
|
||||
async def capture_broadcast(msg):
|
||||
received_prices.append(msg)
|
||||
|
||||
await service.start(capture_broadcast)
|
||||
|
||||
# Wait for at least one price update
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
service.stop()
|
||||
|
||||
# Should have received at least one price update
|
||||
assert len(received_prices) >= 1
|
||||
assert received_prices[0]["type"] == "price_update"
|
||||
|
||||
|
||||
class TestMarketServiceIntegration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_mock_cycle(self):
|
||||
service = MarketService(
|
||||
tickers=["AAPL", "MSFT"],
|
||||
poll_interval=1,
|
||||
mock_mode=True,
|
||||
)
|
||||
|
||||
messages = []
|
||||
|
||||
async def collect_messages(msg):
|
||||
messages.append(msg)
|
||||
|
||||
await service.start(collect_messages)
|
||||
|
||||
# Wait for price updates
|
||||
await asyncio.sleep(2.5)
|
||||
|
||||
service.stop()
|
||||
|
||||
# Should have received multiple price updates
|
||||
assert len(messages) >= 2
|
||||
|
||||
# Check message structure
|
||||
symbols_seen = set()
|
||||
for msg in messages:
|
||||
assert msg["type"] == "price_update"
|
||||
assert "symbol" in msg
|
||||
assert "price" in msg
|
||||
assert "ret" in msg
|
||||
symbols_seen.add(msg["symbol"])
|
||||
|
||||
# Should have prices for both tickers
|
||||
assert "AAPL" in symbols_seen or "MSFT" in symbols_seen
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
197
backend/tests/test_news_domain.py
Normal file
197
backend/tests/test_news_domain.py
Normal file
@@ -0,0 +1,197 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for the news domain helpers."""
|
||||
|
||||
from backend.domains import news as news_domain
|
||||
|
||||
|
||||
class _FakeStore:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
def get_ticker_watermarks(self, symbol):
|
||||
self.calls.append(("get_ticker_watermarks", symbol))
|
||||
return {"symbol": symbol, "last_news_fetch": "2026-03-10"}
|
||||
|
||||
def get_news_items_enriched(self, ticker, start_date=None, end_date=None, trade_date=None, limit=100):
|
||||
self.calls.append(("get_news_items_enriched", ticker, start_date, end_date, trade_date, limit))
|
||||
target = trade_date or end_date
|
||||
return [{"id": "n1", "ticker": ticker, "date": target, "trade_date": target}]
|
||||
|
||||
def get_news_timeline_enriched(self, ticker, start_date=None, end_date=None):
|
||||
self.calls.append(("get_news_timeline_enriched", ticker, start_date, end_date))
|
||||
return [{"date": end_date, "count": 1}]
|
||||
|
||||
def get_news_categories_enriched(self, ticker, start_date=None, end_date=None, limit=200):
|
||||
self.calls.append(("get_news_categories_enriched", ticker, start_date, end_date, limit))
|
||||
return {"macro": {"count": 1}}
|
||||
|
||||
def get_news_by_ids_enriched(self, ticker, article_ids):
|
||||
self.calls.append(("get_news_by_ids_enriched", ticker, list(article_ids)))
|
||||
return [{"id": article_ids[0], "ticker": ticker, "date": "2026-03-16"}]
|
||||
|
||||
|
||||
def test_news_rows_need_enrichment_detects_missing_fields():
|
||||
assert news_domain.news_rows_need_enrichment([]) is True
|
||||
assert news_domain.news_rows_need_enrichment([{"sentiment": "", "relevance": "", "key_discussion": ""}]) is True
|
||||
assert news_domain.news_rows_need_enrichment([{"sentiment": "positive"}]) is False
|
||||
|
||||
|
||||
def test_ensure_news_fresh_triggers_incremental_refresh_when_watermark_is_stale(monkeypatch):
|
||||
store = _FakeStore()
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"update_ticker_incremental",
|
||||
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
|
||||
)
|
||||
|
||||
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
|
||||
|
||||
assert calls == [("AAPL", "2026-03-16")]
|
||||
assert payload["target_date"] == "2026-03-16"
|
||||
assert payload["refreshed"] is True
|
||||
|
||||
|
||||
def test_ensure_news_fresh_skips_refresh_when_watermark_is_current(monkeypatch):
|
||||
store = _FakeStore()
|
||||
calls = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
store,
|
||||
"get_ticker_watermarks",
|
||||
lambda symbol: {"symbol": symbol, "last_news_fetch": "2026-03-16"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"update_ticker_incremental",
|
||||
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
|
||||
)
|
||||
|
||||
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
|
||||
|
||||
assert calls == []
|
||||
assert payload["refreshed"] is False
|
||||
|
||||
|
||||
def test_get_enriched_news_returns_rows_without_enrichment_when_present(monkeypatch):
|
||||
store = _FakeStore()
|
||||
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"ensure_news_fresh",
|
||||
lambda store, ticker, target_date=None, refresh_if_stale=False: {
|
||||
"ticker": ticker,
|
||||
"target_date": target_date,
|
||||
"last_news_fetch": target_date,
|
||||
"refreshed": False,
|
||||
},
|
||||
)
|
||||
|
||||
payload = news_domain.get_enriched_news(
|
||||
store,
|
||||
ticker="AAPL",
|
||||
start_date="2026-03-01",
|
||||
end_date="2026-03-16",
|
||||
limit=20,
|
||||
)
|
||||
|
||||
assert payload["ticker"] == "AAPL"
|
||||
assert payload["news"][0]["ticker"] == "AAPL"
|
||||
assert payload["freshness"]["target_date"] is None or payload["freshness"]["target_date"] == "2026-03-16"
|
||||
assert store.calls == [
|
||||
("get_news_items_enriched", "AAPL", "2026-03-01", "2026-03-16", None, 20)
|
||||
]
|
||||
|
||||
|
||||
def test_get_story_and_similar_days_delegate(monkeypatch):
|
||||
store = _FakeStore()
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"ensure_news_fresh",
|
||||
lambda store, ticker, target_date=None, refresh_if_stale=False: {
|
||||
"ticker": ticker,
|
||||
"target_date": target_date,
|
||||
"last_news_fetch": target_date,
|
||||
"refreshed": False,
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"analyzed": 1})
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"get_or_create_stock_story",
|
||||
lambda store, symbol, as_of_date: {"symbol": symbol, "as_of_date": as_of_date, "story": "story"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"find_similar_days",
|
||||
lambda store, symbol, target_date, top_k: {"symbol": symbol, "target_date": target_date, "items": [{"score": 0.9}]},
|
||||
)
|
||||
|
||||
story = news_domain.get_story_payload(store, ticker="AAPL", as_of_date="2026-03-16")
|
||||
similar = news_domain.get_similar_days_payload(store, ticker="AAPL", date="2026-03-16", n_similar=8)
|
||||
|
||||
assert story["story"] == "story"
|
||||
assert "freshness" in story
|
||||
assert similar["items"][0]["score"] == 0.9
|
||||
assert "freshness" in similar
|
||||
|
||||
|
||||
def test_get_enriched_news_defaults_to_read_only_freshness(monkeypatch):
|
||||
store = _FakeStore()
|
||||
ensure_calls = []
|
||||
|
||||
def fake_ensure(store, ticker, target_date=None, refresh_if_stale=False):
|
||||
ensure_calls.append(refresh_if_stale)
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"target_date": target_date,
|
||||
"last_news_fetch": target_date,
|
||||
"refreshed": False,
|
||||
}
|
||||
|
||||
monkeypatch.setattr(news_domain, "ensure_news_fresh", fake_ensure)
|
||||
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
|
||||
|
||||
payload = news_domain.get_enriched_news(
|
||||
store,
|
||||
ticker="AAPL",
|
||||
end_date="2026-03-16",
|
||||
)
|
||||
|
||||
assert payload["ticker"] == "AAPL"
|
||||
assert ensure_calls == [False]
|
||||
|
||||
|
||||
def test_get_range_explain_payload_uses_article_ids(monkeypatch):
|
||||
store = _FakeStore()
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"ensure_news_fresh",
|
||||
lambda store, ticker, target_date=None, refresh_if_stale=False: {
|
||||
"ticker": ticker,
|
||||
"target_date": target_date,
|
||||
"last_news_fetch": target_date,
|
||||
"refreshed": False,
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
|
||||
monkeypatch.setattr(
|
||||
news_domain,
|
||||
"build_range_explanation",
|
||||
lambda ticker, start_date, end_date, news_rows: {"ticker": ticker, "count": len(news_rows)},
|
||||
)
|
||||
|
||||
payload = news_domain.get_range_explain_payload(
|
||||
store,
|
||||
ticker="AAPL",
|
||||
start_date="2026-03-10",
|
||||
end_date="2026-03-16",
|
||||
article_ids=["news-9"],
|
||||
limit=50,
|
||||
)
|
||||
|
||||
assert payload["ticker"] == "AAPL"
|
||||
assert payload["result"] == {"ticker": "AAPL", "count": 1}
|
||||
assert "freshness" in payload
|
||||
assert store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-9"])]
|
||||
180
backend/tests/test_news_service_app.py
Normal file
180
backend/tests/test_news_service_app.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the extracted news service app surface."""
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.apps.news_service import create_app
|
||||
|
||||
|
||||
class _FakeStore:
|
||||
def get_ticker_watermarks(self, symbol):
|
||||
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
|
||||
|
||||
def get_news_timeline_enriched(self, symbol, start_date=None, end_date=None):
|
||||
return [{"date": end_date, "count": 1}]
|
||||
|
||||
def get_news_items(self, symbol, start_date=None, end_date=None, limit=100):
|
||||
return [{"id": "news-raw-1", "ticker": symbol, "title": "Raw Title", "date": end_date}]
|
||||
|
||||
def get_news_items_enriched(self, symbol, start_date=None, end_date=None, trade_date=None, limit=100):
|
||||
return [{"id": "news-1", "ticker": symbol, "title": "Title", "date": trade_date or end_date}]
|
||||
|
||||
def upsert_news_analysis(self, symbol, rows):
|
||||
return len(rows)
|
||||
|
||||
def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
|
||||
return set()
|
||||
|
||||
def get_news_categories_enriched(self, symbol, start_date=None, end_date=None, limit=200):
|
||||
return {"market": {"label": "market", "count": 1, "article_ids": ["news-1"]}}
|
||||
|
||||
def get_news_by_ids_enriched(self, symbol, article_ids):
|
||||
return [{"id": article_ids[0], "ticker": symbol, "title": "Picked"}]
|
||||
|
||||
|
||||
def test_news_service_routes_are_exposed():
|
||||
app = create_app()
|
||||
paths = {route.path for route in app.routes}
|
||||
|
||||
assert "/health" in paths
|
||||
assert "/api/enriched-news" in paths
|
||||
assert "/api/news-for-date" in paths
|
||||
assert "/api/news-timeline" in paths
|
||||
assert "/api/categories" in paths
|
||||
assert "/api/similar-days" in paths
|
||||
assert "/api/stories/{ticker}" in paths
|
||||
assert "/api/range-explain" in paths
|
||||
|
||||
|
||||
def test_news_service_enriched_news_and_categories(monkeypatch):
|
||||
app = create_app()
|
||||
app.dependency_overrides.clear()
|
||||
from backend.apps import news_service as news_service_module
|
||||
|
||||
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.news.enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
news_response = client.get(
|
||||
"/api/enriched-news",
|
||||
params={"ticker": "AAPL", "end_date": "2026-03-23"},
|
||||
)
|
||||
categories_response = client.get(
|
||||
"/api/categories",
|
||||
params={"ticker": "AAPL", "end_date": "2026-03-23"},
|
||||
)
|
||||
|
||||
assert news_response.status_code == 200
|
||||
assert news_response.json()["news"][0]["ticker"] == "AAPL"
|
||||
assert categories_response.status_code == 200
|
||||
assert categories_response.json()["categories"]["market"]["count"] == 1
|
||||
|
||||
|
||||
def test_news_service_news_for_date_and_timeline(monkeypatch):
|
||||
app = create_app()
|
||||
from backend.apps import news_service as news_service_module
|
||||
|
||||
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.news.enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
date_response = client.get(
|
||||
"/api/news-for-date",
|
||||
params={"ticker": "AAPL", "date": "2026-03-23"},
|
||||
)
|
||||
timeline_response = client.get(
|
||||
"/api/news-timeline",
|
||||
params={
|
||||
"ticker": "AAPL",
|
||||
"start_date": "2026-03-01",
|
||||
"end_date": "2026-03-23",
|
||||
},
|
||||
)
|
||||
|
||||
assert date_response.status_code == 200
|
||||
assert date_response.json()["date"] == "2026-03-23"
|
||||
assert timeline_response.status_code == 200
|
||||
assert timeline_response.json()["timeline"][0]["count"] == 1
|
||||
|
||||
|
||||
def test_news_service_similar_days_and_story(monkeypatch):
|
||||
app = create_app()
|
||||
from backend.apps import news_service as news_service_module
|
||||
|
||||
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.news.enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.news.find_similar_days",
|
||||
lambda store, symbol, target_date, top_k: {
|
||||
"symbol": symbol,
|
||||
"target_date": target_date,
|
||||
"items": [{"date": "2026-03-20", "score": 0.9}],
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.news.get_or_create_stock_story",
|
||||
lambda store, symbol, as_of_date: {
|
||||
"symbol": symbol,
|
||||
"as_of_date": as_of_date,
|
||||
"story": "story body",
|
||||
"source": "local",
|
||||
},
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
similar_response = client.get(
|
||||
"/api/similar-days",
|
||||
params={"ticker": "AAPL", "date": "2026-03-23", "n_similar": 3},
|
||||
)
|
||||
story_response = client.get(
|
||||
"/api/stories/AAPL",
|
||||
params={"as_of_date": "2026-03-23"},
|
||||
)
|
||||
|
||||
assert similar_response.status_code == 200
|
||||
assert similar_response.json()["items"][0]["score"] == 0.9
|
||||
assert story_response.status_code == 200
|
||||
assert story_response.json()["story"] == "story body"
|
||||
|
||||
|
||||
def test_news_service_range_explain(monkeypatch):
|
||||
app = create_app()
|
||||
from backend.apps import news_service as news_service_module
|
||||
|
||||
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.news.enrich_news_for_symbol",
|
||||
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.news.build_range_explanation",
|
||||
lambda ticker, start_date, end_date, news_rows: {
|
||||
"symbol": ticker,
|
||||
"news_count": len(news_rows),
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
},
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/api/range-explain",
|
||||
params={
|
||||
"ticker": "AAPL",
|
||||
"start_date": "2026-03-01",
|
||||
"end_date": "2026-03-23",
|
||||
"article_ids": ["news-7"],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["result"]["news_count"] == 1
|
||||
@@ -9,6 +9,7 @@ def test_router_includes_local_csv_fallback(monkeypatch):
|
||||
monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
|
||||
monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
|
||||
monkeypatch.delenv("FIN_DATA_SOURCE", raising=False)
|
||||
monkeypatch.delenv("ENABLED_DATA_SOURCES", raising=False)
|
||||
reset_config()
|
||||
|
||||
router = DataProviderRouter()
|
||||
|
||||
364
backend/tests/test_runtime_service_app.py
Normal file
364
backend/tests/test_runtime_service_app.py
Normal file
@@ -0,0 +1,364 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the extracted runtime service app surface."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.api import runtime as runtime_module
|
||||
from backend.apps.runtime_service import create_app
|
||||
|
||||
|
||||
def test_runtime_service_routes_are_exposed():
|
||||
app = create_app()
|
||||
paths = {route.path for route in app.routes}
|
||||
|
||||
assert "/health" in paths
|
||||
assert "/api/status" in paths
|
||||
assert "/api/runtime/start" in paths
|
||||
assert "/api/runtime/stop" in paths
|
||||
assert "/api/runtime/cleanup" in paths
|
||||
assert "/api/runtime/history" in paths
|
||||
assert "/api/runtime/current" in paths
|
||||
assert "/api/runtime/gateway/port" in paths
|
||||
|
||||
|
||||
def test_runtime_service_health_and_status(monkeypatch):
|
||||
runtime_state = runtime_module.get_runtime_state()
|
||||
runtime_state.gateway_process = None
|
||||
runtime_state.gateway_port = 9876
|
||||
runtime_state.runtime_manager = object()
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
health_response = client.get("/health")
|
||||
status_response = client.get("/api/status")
|
||||
|
||||
assert health_response.status_code == 200
|
||||
assert health_response.json() == {
|
||||
"status": "healthy",
|
||||
"service": "runtime-service",
|
||||
"gateway_running": False,
|
||||
"gateway_port": 9876,
|
||||
}
|
||||
assert status_response.status_code == 200
|
||||
assert status_response.json() == {
|
||||
"status": "operational",
|
||||
"service": "runtime-service",
|
||||
"runtime": {
|
||||
"gateway_running": False,
|
||||
"gateway_port": 9876,
|
||||
"has_runtime_manager": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_runtime_service_gateway_port_endpoint_uses_runtime_router(monkeypatch):
|
||||
runtime_module.get_runtime_state().gateway_port = 9345
|
||||
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.get(
|
||||
"/api/runtime/gateway/port",
|
||||
headers={"host": "runtime.example:8003", "x-forwarded-proto": "https"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"port": 9345,
|
||||
"is_running": True,
|
||||
"ws_url": "wss://runtime.example:9345",
|
||||
}
|
||||
|
||||
|
||||
def test_runtime_service_get_runtime_config(monkeypatch, tmp_path):
|
||||
run_dir = tmp_path / "runs" / "demo"
|
||||
state_dir = run_dir / "state"
|
||||
state_dir.mkdir(parents=True)
|
||||
(run_dir / "BOOTSTRAP.md").write_text(
|
||||
"---\n"
|
||||
"tickers:\n"
|
||||
" - AAPL\n"
|
||||
"schedule_mode: intraday\n"
|
||||
"interval_minutes: 30\n"
|
||||
"trigger_time: '10:00'\n"
|
||||
"max_comm_cycles: 3\n"
|
||||
"enable_memory: true\n"
|
||||
"---\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(state_dir / "runtime_state.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"context": {
|
||||
"config_name": "demo",
|
||||
"run_dir": str(run_dir),
|
||||
"bootstrap_values": {
|
||||
"tickers": ["AAPL"],
|
||||
"schedule_mode": "intraday",
|
||||
"interval_minutes": 30,
|
||||
"trigger_time": "10:00",
|
||||
"max_comm_cycles": 3,
|
||||
"enable_memory": True,
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
|
||||
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
|
||||
runtime_module.get_runtime_state().gateway_port = 8765
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.get("/api/runtime/config")
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["run_id"] == "demo"
|
||||
assert payload["bootstrap"]["schedule_mode"] == "intraday"
|
||||
assert payload["resolved"]["interval_minutes"] == 30
|
||||
assert payload["resolved"]["enable_memory"] is True
|
||||
|
||||
|
||||
def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, tmp_path):
|
||||
run_dir = tmp_path / "runs" / "demo"
|
||||
state_dir = run_dir / "state"
|
||||
state_dir.mkdir(parents=True)
|
||||
(run_dir / "BOOTSTRAP.md").write_text(
|
||||
"---\n"
|
||||
"tickers:\n"
|
||||
" - AAPL\n"
|
||||
"schedule_mode: daily\n"
|
||||
"interval_minutes: 60\n"
|
||||
"trigger_time: '09:30'\n"
|
||||
"max_comm_cycles: 2\n"
|
||||
"---\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
(state_dir / "runtime_state.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"context": {
|
||||
"config_name": "demo",
|
||||
"run_dir": str(run_dir),
|
||||
"bootstrap_values": {
|
||||
"tickers": ["AAPL"],
|
||||
"schedule_mode": "daily",
|
||||
"interval_minutes": 60,
|
||||
"trigger_time": "09:30",
|
||||
"max_comm_cycles": 2,
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
class _DummyContext:
|
||||
def __init__(self):
|
||||
self.bootstrap_values = {
|
||||
"tickers": ["AAPL"],
|
||||
"schedule_mode": "daily",
|
||||
"interval_minutes": 60,
|
||||
"trigger_time": "09:30",
|
||||
"max_comm_cycles": 2,
|
||||
}
|
||||
|
||||
class _DummyManager:
|
||||
def __init__(self):
|
||||
self.config_name = "demo"
|
||||
self.bootstrap = dict(_DummyContext().bootstrap_values)
|
||||
self.context = _DummyContext()
|
||||
|
||||
def _persist_snapshot(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
|
||||
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
|
||||
runtime_module.get_runtime_state().runtime_manager = _DummyManager()
|
||||
runtime_module.get_runtime_state().gateway_port = 8765
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.put(
|
||||
"/api/runtime/config",
|
||||
json={
|
||||
"schedule_mode": "intraday",
|
||||
"interval_minutes": 15,
|
||||
"trigger_time": "10:15",
|
||||
"max_comm_cycles": 4,
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["bootstrap"]["schedule_mode"] == "intraday"
|
||||
assert payload["resolved"]["interval_minutes"] == 15
|
||||
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
def test_prune_old_timestamped_runs_keeps_named_runs(monkeypatch, tmp_path):
|
||||
runs_dir = tmp_path / "runs"
|
||||
runs_dir.mkdir()
|
||||
|
||||
keep_dirs = ["20260324_110000", "20260324_120000"]
|
||||
prune_dir = "20260324_100000"
|
||||
named_dir = "smoke_fullstack"
|
||||
|
||||
for name in [*keep_dirs, prune_dir, named_dir]:
|
||||
(runs_dir / name).mkdir(parents=True)
|
||||
|
||||
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
|
||||
|
||||
pruned = runtime_module._prune_old_timestamped_runs(keep=1, exclude_run_ids={"20260324_120000"})
|
||||
|
||||
assert prune_dir in pruned
|
||||
assert (runs_dir / named_dir).exists()
|
||||
assert (runs_dir / "20260324_120000").exists()
|
||||
assert (runs_dir / "20260324_110000").exists()
|
||||
|
||||
|
||||
def test_runtime_cleanup_endpoint_prunes_old_runs(monkeypatch, tmp_path):
|
||||
runs_dir = tmp_path / "runs"
|
||||
runs_dir.mkdir()
|
||||
|
||||
for name in ["20260324_090000", "20260324_100000", "20260324_110000", "smoke_fullstack"]:
|
||||
(runs_dir / name).mkdir(parents=True)
|
||||
|
||||
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
|
||||
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: False)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.post("/api/runtime/cleanup?keep=1")
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["status"] == "ok"
|
||||
assert sorted(payload["pruned_run_ids"]) == ["20260324_090000", "20260324_100000"]
|
||||
assert (runs_dir / "20260324_110000").exists()
|
||||
assert (runs_dir / "smoke_fullstack").exists()
|
||||
|
||||
|
||||
def test_runtime_history_lists_recent_runs(monkeypatch, tmp_path):
|
||||
run_dir = tmp_path / "runs" / "20260324_120000"
|
||||
(run_dir / "state").mkdir(parents=True)
|
||||
(run_dir / "team_dashboard").mkdir(parents=True)
|
||||
(run_dir / "state" / "runtime_state.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"context": {
|
||||
"config_name": "20260324_120000",
|
||||
"run_dir": str(run_dir),
|
||||
"bootstrap_values": {"tickers": ["AAPL"]},
|
||||
},
|
||||
"events": [],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
(run_dir / "team_dashboard" / "summary.json").write_text(
|
||||
json.dumps({"totalTrades": 3, "totalAssetValue": 123456.0}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.get("/api/runtime/history?limit=5")
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["runs"][0]["run_id"] == "20260324_120000"
|
||||
assert payload["runs"][0]["total_trades"] == 3
|
||||
|
||||
|
||||
def test_restore_run_assets_copies_state(monkeypatch, tmp_path):
|
||||
source_run = tmp_path / "runs" / "20260324_100000"
|
||||
(source_run / "team_dashboard").mkdir(parents=True)
|
||||
(source_run / "state").mkdir(parents=True)
|
||||
(source_run / "agents").mkdir(parents=True)
|
||||
(source_run / "team_dashboard" / "_internal_state.json").write_text("{}", encoding="utf-8")
|
||||
(source_run / "state" / "server_state.json").write_text("{}", encoding="utf-8")
|
||||
|
||||
target_run = tmp_path / "runs" / "20260324_130000"
|
||||
|
||||
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
|
||||
|
||||
runtime_module._restore_run_assets("20260324_100000", target_run)
|
||||
|
||||
assert (target_run / "team_dashboard" / "_internal_state.json").exists()
|
||||
assert (target_run / "state" / "server_state.json").exists()
|
||||
|
||||
|
||||
def test_start_runtime_restore_reuses_historical_run_id(monkeypatch, tmp_path):
|
||||
run_dir = tmp_path / "runs" / "20260324_100000"
|
||||
(run_dir / "state").mkdir(parents=True)
|
||||
(run_dir / "state" / "runtime_state.json").write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"context": {
|
||||
"config_name": "20260324_100000",
|
||||
"run_dir": str(run_dir),
|
||||
"bootstrap_values": {
|
||||
"tickers": ["AAPL"],
|
||||
"schedule_mode": "intraday",
|
||||
"interval_minutes": 30,
|
||||
"trigger_time": "now",
|
||||
"max_comm_cycles": 2,
|
||||
"initial_cash": 100000.0,
|
||||
"margin_requirement": 0.0,
|
||||
"enable_memory": False,
|
||||
"mode": "live",
|
||||
"poll_interval": 10,
|
||||
},
|
||||
}
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
class _DummyManager:
|
||||
def __init__(self, config_name, run_dir, bootstrap):
|
||||
self.config_name = config_name
|
||||
self.run_dir = Path(run_dir)
|
||||
self.bootstrap = bootstrap
|
||||
self.context = None
|
||||
|
||||
def prepare_run(self):
|
||||
self.context = type(
|
||||
"Ctx",
|
||||
(),
|
||||
{
|
||||
"config_name": self.config_name,
|
||||
"run_dir": self.run_dir,
|
||||
"bootstrap_values": self.bootstrap,
|
||||
},
|
||||
)()
|
||||
return self.context
|
||||
|
||||
class _DummyProcess:
|
||||
def poll(self):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
|
||||
monkeypatch.setattr(runtime_module, "_find_available_port", lambda start_port=8765, max_port=9000: 8765)
|
||||
monkeypatch.setattr(runtime_module, "_start_gateway_process", lambda **kwargs: _DummyProcess())
|
||||
monkeypatch.setattr(runtime_module, "_stop_gateway", lambda: True)
|
||||
monkeypatch.setattr("backend.runtime.manager.TradingRuntimeManager", _DummyManager)
|
||||
runtime_state = runtime_module.get_runtime_state()
|
||||
runtime_state.gateway_process = None
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.post(
|
||||
"/api/runtime/start",
|
||||
json={
|
||||
"launch_mode": "restore",
|
||||
"restore_run_id": "20260324_100000",
|
||||
"tickers": [],
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
payload = response.json()
|
||||
assert payload["run_id"] == "20260324_100000"
|
||||
assert payload["run_dir"] == str(run_dir)
|
||||
107
backend/tests/test_service_clients.py
Normal file
107
backend/tests/test_service_clients.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for split-aware shared service clients."""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.client.control_client import ControlPlaneClient
|
||||
from shared.client.runtime_client import RuntimeServiceClient
|
||||
|
||||
|
||||
class _DummyResponse:
|
||||
def __init__(self, payload):
|
||||
self._payload = payload
|
||||
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
def json(self):
|
||||
return self._payload
|
||||
|
||||
|
||||
class _DummyAsyncClient:
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def get(self, path, params=None):
|
||||
self.calls.append(("get", path, params))
|
||||
return _DummyResponse({"path": path, "params": params})
|
||||
|
||||
async def post(self, path, json=None):
|
||||
self.calls.append(("post", path, json))
|
||||
return _DummyResponse({"path": path, "json": json})
|
||||
|
||||
async def put(self, path, json=None):
|
||||
self.calls.append(("put", path, json))
|
||||
return _DummyResponse({"path": path, "json": json})
|
||||
|
||||
async def aclose(self):
|
||||
return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_control_plane_client_hits_current_workspace_and_guard_routes():
|
||||
client = ControlPlaneClient()
|
||||
client._client = _DummyAsyncClient()
|
||||
|
||||
await client.list_workspaces()
|
||||
await client.get_workspace("demo")
|
||||
await client.list_agents("demo")
|
||||
await client.get_agent("demo", "risk_manager")
|
||||
await client.fetch_pending_approvals()
|
||||
await client.approve_pending_approval("ap-1")
|
||||
await client.deny_pending_approval("ap-2", reason="nope")
|
||||
|
||||
assert client._client.calls == [
|
||||
("get", "/workspaces", None),
|
||||
("get", "/workspaces/demo", None),
|
||||
("get", "/workspaces/demo/agents", None),
|
||||
("get", "/workspaces/demo/agents/risk_manager", None),
|
||||
("get", "/guard/pending", None),
|
||||
(
|
||||
"post",
|
||||
"/guard/approve",
|
||||
{
|
||||
"approval_id": "ap-1",
|
||||
"one_time": True,
|
||||
"expires_in_minutes": 30,
|
||||
},
|
||||
),
|
||||
(
|
||||
"post",
|
||||
"/guard/deny",
|
||||
{
|
||||
"approval_id": "ap-2",
|
||||
"reason": "nope",
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_service_client_hits_current_runtime_routes():
|
||||
client = RuntimeServiceClient()
|
||||
client._client = _DummyAsyncClient()
|
||||
|
||||
await client.fetch_context()
|
||||
await client.fetch_agents()
|
||||
await client.fetch_events()
|
||||
await client.fetch_gateway_port()
|
||||
await client.start_runtime({"tickers": ["AAPL"]})
|
||||
await client.stop_runtime(force=True)
|
||||
await client.restart_runtime({"tickers": ["MSFT"]})
|
||||
await client.fetch_current_runtime()
|
||||
await client.get_runtime_config()
|
||||
await client.update_runtime_config({"schedule_mode": "intraday"})
|
||||
|
||||
assert client._client.calls == [
|
||||
("get", "/context", None),
|
||||
("get", "/agents", None),
|
||||
("get", "/events", None),
|
||||
("get", "/gateway/port", None),
|
||||
("post", "/start", {"tickers": ["AAPL"]}),
|
||||
("post", "/stop?force=true", None),
|
||||
("post", "/restart", {"tickers": ["MSFT"]}),
|
||||
("get", "/current", None),
|
||||
("get", "/config", None),
|
||||
("put", "/config", {"schedule_mode": "intraday"}),
|
||||
]
|
||||
32
backend/tests/test_shared_schema_bridge.py
Normal file
32
backend/tests/test_shared_schema_bridge.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Regression coverage for the shared schema bridge."""
|
||||
|
||||
from backend.data import schema as legacy_schema
|
||||
from shared import schema as shared_schema
|
||||
|
||||
|
||||
def test_backend_data_schema_reexports_shared_contracts():
|
||||
assert legacy_schema.Price is shared_schema.Price
|
||||
assert legacy_schema.PriceResponse is shared_schema.PriceResponse
|
||||
assert legacy_schema.FinancialMetrics is shared_schema.FinancialMetrics
|
||||
assert legacy_schema.FinancialMetricsResponse is (
|
||||
shared_schema.FinancialMetricsResponse
|
||||
)
|
||||
assert legacy_schema.LineItem is shared_schema.LineItem
|
||||
assert legacy_schema.LineItemResponse is shared_schema.LineItemResponse
|
||||
assert legacy_schema.InsiderTrade is shared_schema.InsiderTrade
|
||||
assert legacy_schema.InsiderTradeResponse is (
|
||||
shared_schema.InsiderTradeResponse
|
||||
)
|
||||
assert legacy_schema.CompanyNews is shared_schema.CompanyNews
|
||||
assert legacy_schema.CompanyNewsResponse is shared_schema.CompanyNewsResponse
|
||||
assert legacy_schema.CompanyFacts is shared_schema.CompanyFacts
|
||||
assert legacy_schema.CompanyFactsResponse is (
|
||||
shared_schema.CompanyFactsResponse
|
||||
)
|
||||
assert legacy_schema.Position is shared_schema.Position
|
||||
assert legacy_schema.Portfolio is shared_schema.Portfolio
|
||||
assert legacy_schema.AnalystSignal is shared_schema.AnalystSignal
|
||||
assert legacy_schema.TickerAnalysis is shared_schema.TickerAnalysis
|
||||
assert legacy_schema.AgentStateData is shared_schema.AgentStateData
|
||||
assert legacy_schema.AgentStateMetadata is shared_schema.AgentStateMetadata
|
||||
47
backend/tests/test_trading_domain.py
Normal file
47
backend/tests/test_trading_domain.py
Normal file
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for the trading domain helpers."""
|
||||
|
||||
from backend.domains import trading as trading_domain
|
||||
|
||||
|
||||
def test_trading_domain_payload_wrappers(monkeypatch):
|
||||
monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}])
|
||||
monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12)
|
||||
|
||||
assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == {
|
||||
"ticker": "AAPL",
|
||||
"prices": [{"close": 1}],
|
||||
}
|
||||
assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"financial_metrics": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"news": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"insider_trades": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"ticker": "AAPL",
|
||||
"end_date": "2026-03-16",
|
||||
"market_cap": 2.5e12,
|
||||
}
|
||||
|
||||
|
||||
def test_get_market_status_payload_uses_market_service(monkeypatch):
|
||||
class _FakeMarketService:
|
||||
def __init__(self, tickers):
|
||||
self.tickers = tickers
|
||||
|
||||
def get_market_status(self):
|
||||
return {"status": "open", "status_text": "Open"}
|
||||
|
||||
monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)
|
||||
|
||||
assert trading_domain.get_market_status_payload() == {
|
||||
"status": "open",
|
||||
"status_text": "Open",
|
||||
}
|
||||
231
backend/tests/test_trading_service_app.py
Normal file
231
backend/tests/test_trading_service_app.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the extracted trading service app surface."""
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.apps.trading_service import create_app
|
||||
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price
|
||||
|
||||
|
||||
def test_trading_service_routes_are_exposed():
|
||||
app = create_app()
|
||||
|
||||
paths = {route.path for route in app.routes}
|
||||
|
||||
assert "/health" in paths
|
||||
assert "/api/prices" in paths
|
||||
assert "/api/financials" in paths
|
||||
assert "/api/news" in paths
|
||||
assert "/api/insider-trades" in paths
|
||||
assert "/api/market/status" in paths
|
||||
assert "/api/market-cap" in paths
|
||||
assert "/api/line-items" in paths
|
||||
|
||||
|
||||
def test_trading_service_prices_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_prices_payload",
|
||||
lambda ticker, start_date, end_date: {
|
||||
"ticker": ticker,
|
||||
"prices": [
|
||||
Price(
|
||||
open=1.0,
|
||||
close=2.0,
|
||||
high=2.5,
|
||||
low=0.5,
|
||||
volume=100,
|
||||
time="2026-03-20",
|
||||
)
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.get(
|
||||
"/api/prices",
|
||||
params={
|
||||
"ticker": "AAPL",
|
||||
"start_date": "2026-03-01",
|
||||
"end_date": "2026-03-20",
|
||||
},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["ticker"] == "AAPL"
|
||||
assert response.json()["prices"][0]["close"] == 2.0
|
||||
|
||||
|
||||
def test_trading_service_financials_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_financials_payload",
|
||||
lambda ticker, end_date, period, limit: {
|
||||
"financial_metrics": [
|
||||
FinancialMetrics(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
market_cap=123.0,
|
||||
enterprise_value=None,
|
||||
price_to_earnings_ratio=None,
|
||||
price_to_book_ratio=None,
|
||||
price_to_sales_ratio=None,
|
||||
enterprise_value_to_ebitda_ratio=None,
|
||||
enterprise_value_to_revenue_ratio=None,
|
||||
free_cash_flow_yield=None,
|
||||
peg_ratio=None,
|
||||
gross_margin=None,
|
||||
operating_margin=None,
|
||||
net_margin=None,
|
||||
return_on_equity=None,
|
||||
return_on_assets=None,
|
||||
return_on_invested_capital=None,
|
||||
asset_turnover=None,
|
||||
inventory_turnover=None,
|
||||
receivables_turnover=None,
|
||||
days_sales_outstanding=None,
|
||||
operating_cycle=None,
|
||||
working_capital_turnover=None,
|
||||
current_ratio=None,
|
||||
quick_ratio=None,
|
||||
cash_ratio=None,
|
||||
operating_cash_flow_ratio=None,
|
||||
debt_to_equity=None,
|
||||
debt_to_assets=None,
|
||||
interest_coverage=None,
|
||||
revenue_growth=None,
|
||||
earnings_growth=None,
|
||||
book_value_growth=None,
|
||||
earnings_per_share_growth=None,
|
||||
free_cash_flow_growth=None,
|
||||
operating_income_growth=None,
|
||||
ebitda_growth=None,
|
||||
payout_ratio=None,
|
||||
earnings_per_share=None,
|
||||
book_value_per_share=None,
|
||||
free_cash_flow_per_share=None,
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.get(
|
||||
"/api/financials",
|
||||
params={"ticker": "AAPL", "end_date": "2026-03-20"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["financial_metrics"][0]["ticker"] == "AAPL"
|
||||
|
||||
|
||||
def test_trading_service_news_and_insider_endpoints(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_news_payload",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: {
|
||||
"news": [
|
||||
CompanyNews(
|
||||
ticker=ticker,
|
||||
title="News title",
|
||||
source="polygon",
|
||||
url="https://example.com/news",
|
||||
date=end_date,
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_insider_trades_payload",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: {
|
||||
"insider_trades": [
|
||||
InsiderTrade(ticker=ticker, filing_date=end_date)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
news_response = client.get(
|
||||
"/api/news",
|
||||
params={"ticker": "AAPL", "end_date": "2026-03-20"},
|
||||
)
|
||||
insider_response = client.get(
|
||||
"/api/insider-trades",
|
||||
params={"ticker": "AAPL", "end_date": "2026-03-20"},
|
||||
)
|
||||
|
||||
assert news_response.status_code == 200
|
||||
assert news_response.json()["news"][0]["title"] == "News title"
|
||||
assert insider_response.status_code == 200
|
||||
assert insider_response.json()["insider_trades"][0]["ticker"] == "AAPL"
|
||||
|
||||
|
||||
def test_trading_service_market_status_endpoint(monkeypatch):
|
||||
class _FakeMarketService:
|
||||
def get_market_status(self):
|
||||
return {"status": "open", "status_text": "Open"}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_market_status_payload",
|
||||
lambda: _FakeMarketService().get_market_status(),
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.get("/api/market/status")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"status": "open", "status_text": "Open"}
|
||||
|
||||
|
||||
def test_trading_service_market_cap_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_market_cap_payload",
|
||||
lambda ticker, end_date: {
|
||||
"ticker": ticker,
|
||||
"end_date": end_date,
|
||||
"market_cap": 3.5e12,
|
||||
},
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.get(
|
||||
"/api/market-cap",
|
||||
params={"ticker": "AAPL", "end_date": "2026-03-20"},
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {
|
||||
"ticker": "AAPL",
|
||||
"end_date": "2026-03-20",
|
||||
"market_cap": 3.5e12,
|
||||
}
|
||||
|
||||
|
||||
def test_trading_service_line_items_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_line_items_payload",
|
||||
lambda ticker, line_items, end_date, period, limit: {
|
||||
"search_results": [
|
||||
LineItem(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
free_cash_flow=123.0,
|
||||
)
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
response = client.get(
|
||||
"/api/line-items",
|
||||
params=[
|
||||
("ticker", "AAPL"),
|
||||
("line_items", "free_cash_flow"),
|
||||
("end_date", "2026-03-20"),
|
||||
],
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["search_results"][0]["ticker"] == "AAPL"
|
||||
assert response.json()["search_results"][0]["free_cash_flow"] == 123.0
|
||||
@@ -3,13 +3,16 @@
|
||||
# pylint: disable=C0301
|
||||
"""Data fetching tools backed by the unified provider router."""
|
||||
import datetime
|
||||
import os
|
||||
|
||||
import httpx
|
||||
import pandas as pd
|
||||
import pandas_market_calendars as mcal
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
from backend.data.cache import get_cache
|
||||
from backend.data.provider_router import get_provider_router
|
||||
from backend.data.schema import (
|
||||
from shared.schema import (
|
||||
CompanyNews,
|
||||
FinancialMetrics,
|
||||
InsiderTrade,
|
||||
@@ -23,6 +26,31 @@ _cache = get_cache()
|
||||
_router = get_provider_router()
|
||||
|
||||
|
||||
def _service_name() -> str:
|
||||
return str(os.getenv("SERVICE_NAME", "")).strip().lower()
|
||||
|
||||
|
||||
def _trading_service_url() -> str | None:
|
||||
value = str(os.getenv("TRADING_SERVICE_URL", "")).strip().rstrip("/")
|
||||
if not value or _service_name() == "trading_service":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def _news_service_url() -> str | None:
|
||||
value = str(os.getenv("NEWS_SERVICE_URL", "")).strip().rstrip("/")
|
||||
if not value or _service_name() == "news_service":
|
||||
return None
|
||||
return value
|
||||
|
||||
|
||||
def _service_get_json(base_url: str, path: str, *, params: dict[str, object]) -> dict:
|
||||
with httpx.Client(base_url=base_url, timeout=30.0) as client:
|
||||
response = client.get(path, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def get_last_tradeday(date: str) -> str:
|
||||
"""
|
||||
Get the previous trading day for the specified date
|
||||
@@ -104,6 +132,24 @@ def get_prices(
|
||||
if cached_data := _cache.get_prices(cache_key):
|
||||
return [Price(**price) for price in cached_data]
|
||||
|
||||
service_url = _trading_service_url()
|
||||
if service_url:
|
||||
try:
|
||||
payload = _service_get_json(
|
||||
service_url,
|
||||
"/api/prices",
|
||||
params={
|
||||
"ticker": ticker,
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
},
|
||||
)
|
||||
prices = [Price(**price) for price in payload.get("prices", [])]
|
||||
if prices:
|
||||
return prices
|
||||
except Exception as exc:
|
||||
logger.info("Trading service price lookup failed for %s: %s", ticker, exc)
|
||||
|
||||
try:
|
||||
prices, data_source = _router.get_prices(ticker, start_date, end_date)
|
||||
except Exception as exc:
|
||||
@@ -146,6 +192,28 @@ def get_financial_metrics(
|
||||
if cached_data := _cache.get_financial_metrics(cache_key):
|
||||
return [FinancialMetrics(**metric) for metric in cached_data]
|
||||
|
||||
service_url = _trading_service_url()
|
||||
if service_url:
|
||||
try:
|
||||
payload = _service_get_json(
|
||||
service_url,
|
||||
"/api/financials",
|
||||
params={
|
||||
"ticker": ticker,
|
||||
"end_date": end_date,
|
||||
"period": period,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
metrics = [
|
||||
FinancialMetrics(**metric)
|
||||
for metric in payload.get("financial_metrics", [])
|
||||
]
|
||||
if metrics:
|
||||
return metrics
|
||||
except Exception as exc:
|
||||
logger.info("Trading service financial lookup failed for %s: %s", ticker, exc)
|
||||
|
||||
try:
|
||||
financial_metrics, data_source = _router.get_financial_metrics(
|
||||
ticker=ticker,
|
||||
@@ -183,6 +251,22 @@ def search_line_items(
|
||||
ticker = normalize_symbol(ticker)
|
||||
if not ticker:
|
||||
return []
|
||||
|
||||
service_url = _trading_service_url()
|
||||
if service_url:
|
||||
payload = _service_get_json(
|
||||
service_url,
|
||||
"/api/line-items",
|
||||
params={
|
||||
"ticker": ticker,
|
||||
"line_items": line_items,
|
||||
"end_date": end_date,
|
||||
"period": period,
|
||||
"limit": limit,
|
||||
},
|
||||
)
|
||||
return [LineItem(**item) for item in payload.get("search_results", [])]
|
||||
|
||||
return _router.search_line_items(
|
||||
ticker=ticker,
|
||||
line_items=line_items,
|
||||
@@ -213,6 +297,26 @@ def get_insider_trades(
|
||||
if cached_data := _cache.get_insider_trades(cache_key):
|
||||
return [InsiderTrade(**trade) for trade in cached_data]
|
||||
|
||||
service_url = _trading_service_url()
|
||||
if service_url:
|
||||
try:
|
||||
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
|
||||
if start_date:
|
||||
params["start_date"] = start_date
|
||||
payload = _service_get_json(
|
||||
service_url,
|
||||
"/api/insider-trades",
|
||||
params=params,
|
||||
)
|
||||
trades = [
|
||||
InsiderTrade(**trade)
|
||||
for trade in payload.get("insider_trades", [])
|
||||
]
|
||||
if trades:
|
||||
return trades
|
||||
except Exception as exc:
|
||||
logger.info("Trading service insider lookup failed for %s: %s", ticker, exc)
|
||||
|
||||
try:
|
||||
all_trades, data_source = _router.get_insider_trades(
|
||||
ticker=ticker,
|
||||
@@ -248,6 +352,40 @@ def get_company_news(
|
||||
if cached_data := _cache.get_company_news(cache_key):
|
||||
return [CompanyNews(**news) for news in cached_data]
|
||||
|
||||
trading_service_url = _trading_service_url()
|
||||
if trading_service_url:
|
||||
try:
|
||||
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
|
||||
if start_date:
|
||||
params["start_date"] = start_date
|
||||
payload = _service_get_json(
|
||||
trading_service_url,
|
||||
"/api/news",
|
||||
params=params,
|
||||
)
|
||||
news = [CompanyNews(**item) for item in payload.get("news", [])]
|
||||
if news:
|
||||
return news
|
||||
except Exception as exc:
|
||||
logger.info("Trading service news lookup failed for %s: %s", ticker, exc)
|
||||
|
||||
news_service_url = _news_service_url()
|
||||
if news_service_url:
|
||||
try:
|
||||
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
|
||||
if start_date:
|
||||
params["start_date"] = start_date
|
||||
payload = _service_get_json(
|
||||
news_service_url,
|
||||
"/api/enriched-news",
|
||||
params=params,
|
||||
)
|
||||
news = [CompanyNews(**item) for item in payload.get("news", [])]
|
||||
if news:
|
||||
return news
|
||||
except Exception as exc:
|
||||
logger.info("News service lookup failed for %s: %s", ticker, exc)
|
||||
|
||||
try:
|
||||
all_news, data_source = _router.get_company_news(
|
||||
ticker=ticker,
|
||||
@@ -272,6 +410,19 @@ def get_market_cap(ticker: str, end_date: str) -> float | None:
|
||||
if not ticker:
|
||||
return None
|
||||
|
||||
service_url = _trading_service_url()
|
||||
if service_url:
|
||||
try:
|
||||
payload = _service_get_json(
|
||||
service_url,
|
||||
"/api/market-cap",
|
||||
params={"ticker": ticker, "end_date": end_date},
|
||||
)
|
||||
value = payload.get("market_cap")
|
||||
return float(value) if value is not None else None
|
||||
except Exception as exc:
|
||||
logger.info("Trading service market-cap lookup failed for %s: %s", ticker, exc)
|
||||
|
||||
def _metrics_lookup(symbol: str, date: str):
|
||||
for source in _router.api_sources():
|
||||
cache_key = f"{symbol}_ttm_{date}_10_{source}"
|
||||
|
||||
@@ -228,12 +228,12 @@ class SettlementCoordinator:
|
||||
|
||||
all_evaluations = {**analyst_evaluations, **pm_evaluations}
|
||||
|
||||
leaderboard = self.storage.load_file("leaderboard") or []
|
||||
leaderboard = self.storage.load_export_file("leaderboard") or []
|
||||
updated_leaderboard = update_leaderboard_with_evaluations(
|
||||
leaderboard,
|
||||
all_evaluations,
|
||||
)
|
||||
self.storage.save_file("leaderboard", updated_leaderboard)
|
||||
self.storage.save_export_file("leaderboard", updated_leaderboard)
|
||||
|
||||
self._update_summary_with_baselines(
|
||||
date,
|
||||
|
||||
@@ -30,7 +30,6 @@ class TerminalDashboard:
|
||||
self.port = 8765
|
||||
self.poll_interval = 10
|
||||
self.trigger_time = "now"
|
||||
self.mock = False
|
||||
self.enable_memory = False
|
||||
self.local_time = ""
|
||||
self.nyse_time = ""
|
||||
@@ -65,7 +64,6 @@ class TerminalDashboard:
|
||||
port: int,
|
||||
poll_interval: int,
|
||||
trigger_time: str = "now",
|
||||
mock: bool = False,
|
||||
enable_memory: bool = False,
|
||||
local_time: str = "",
|
||||
nyse_time: str = "",
|
||||
@@ -82,7 +80,6 @@ class TerminalDashboard:
|
||||
self.port = port
|
||||
self.poll_interval = poll_interval
|
||||
self.trigger_time = trigger_time
|
||||
self.mock = mock
|
||||
self.enable_memory = enable_memory
|
||||
self.local_time = local_time
|
||||
self.nyse_time = nyse_time
|
||||
@@ -109,8 +106,6 @@ class TerminalDashboard:
|
||||
# Mode line
|
||||
if self.mode == "backtest":
|
||||
mode_str = "[cyan]Backtest[/cyan]"
|
||||
elif self.mock:
|
||||
mode_str = "[yellow]MOCK[/yellow]"
|
||||
else:
|
||||
mode_str = "[green]LIVE[/green]"
|
||||
|
||||
@@ -216,8 +211,6 @@ class TerminalDashboard:
|
||||
title = "[bold cyan]EvoTraders[/bold cyan]"
|
||||
if self.mode == "backtest":
|
||||
title += " [dim]Backtest[/dim]"
|
||||
elif self.mock:
|
||||
title += " [dim]Mock[/dim]"
|
||||
else:
|
||||
title += " [dim]Live[/dim]"
|
||||
|
||||
|
||||
28
docs/compat-removal-plan.md
Normal file
28
docs/compat-removal-plan.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# Compatibility Removal Plan
|
||||
|
||||
This document tracks the remaining migration-only surfaces that still exist
|
||||
after the move to split-first development.
|
||||
|
||||
## Migration-only Surfaces
|
||||
|
||||
None currently remain as dedicated compatibility wrappers.
|
||||
|
||||
## Completed Removals
|
||||
|
||||
### `backend.app`
|
||||
|
||||
- Removed after compatibility startup switched to
|
||||
`backend.apps.combined_service:app` directly.
|
||||
|
||||
### `shared.client.AgentServiceClient`
|
||||
|
||||
- Removed after split-aware clients became the default import surface.
|
||||
- Replacement:
|
||||
- `ControlPlaneClient`
|
||||
- `RuntimeServiceClient`
|
||||
- `TradingServiceClient`
|
||||
- `NewsServiceClient`
|
||||
|
||||
### `backend.apps.combined_service`
|
||||
|
||||
- Removed after split-service mode became the only supported dev startup path.
|
||||
@@ -1,7 +1,31 @@
|
||||
|
||||
## QuickStart
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
## Optional Direct Service Calls
|
||||
|
||||
The frontend still works with the compatibility backend entrypoint by default.
|
||||
In the current test-stage setup, split services are the recommended default.
|
||||
Point the frontend directly at those standalone services:
|
||||
|
||||
```bash
|
||||
VITE_CONTROL_API_BASE_URL=http://localhost:8000/api
|
||||
VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime
|
||||
VITE_NEWS_SERVICE_URL=http://localhost:8002
|
||||
VITE_TRADING_SERVICE_URL=http://localhost:8001
|
||||
```
|
||||
|
||||
Current direct-call coverage:
|
||||
|
||||
- runtime panel + gateway port discovery
|
||||
- `story`
|
||||
- `similar days`
|
||||
- `range explain`
|
||||
- `news for date`
|
||||
- `news categories`
|
||||
|
||||
If these variables are not set, the frontend falls back to the existing
|
||||
WebSocket-driven compatibility flow.
|
||||
|
||||
3084
frontend/src/App.jsx
3084
frontend/src/App.jsx
File diff suppressed because it is too large
Load Diff
@@ -57,7 +57,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
|
||||
background: '#ffffff',
|
||||
borderBottom: '2px solid #000000',
|
||||
boxShadow: '0 4px 12px rgba(0, 0, 0, 0.1)',
|
||||
zIndex: 1000,
|
||||
zIndex: 800,
|
||||
animation: isClosing ? 'slideUp 0.2s ease-out forwards' : 'slideDown 0.25s ease-out'
|
||||
}}>
|
||||
{/* Horizontal scrollable content */}
|
||||
|
||||
@@ -35,14 +35,22 @@ const stripMarkdown = (text) => {
|
||||
.replace(/^[-=]+$/gm, '');
|
||||
};
|
||||
|
||||
const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
|
||||
const AgentFeed = forwardRef(({ feed, leaderboard, agentProfilesByAgent }, ref) => {
|
||||
const feedContentRef = useRef(null);
|
||||
const [highlightedId, setHighlightedId] = useState(null);
|
||||
const [selectedAgent, setSelectedAgent] = useState('all');
|
||||
const [dropdownOpen, setDropdownOpen] = useState(false);
|
||||
|
||||
const getAgentModelInfo = (agentId) => {
|
||||
if (!leaderboard || !agentId) return { modelName: null, modelProvider: null };
|
||||
if (!agentId) return { modelName: null, modelProvider: null };
|
||||
const profile = agentProfilesByAgent?.[agentId];
|
||||
if (profile?.model_name) {
|
||||
return {
|
||||
modelName: profile.model_name,
|
||||
modelProvider: profile.model_provider
|
||||
};
|
||||
}
|
||||
if (!leaderboard) return { modelName: null, modelProvider: null };
|
||||
const agentData = leaderboard.find(lb => lb.id === agentId || lb.agentId === agentId);
|
||||
return {
|
||||
modelName: agentData?.modelName,
|
||||
@@ -52,7 +60,17 @@ const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
|
||||
|
||||
// Get agent info by name
|
||||
const getAgentInfoByName = (agentName) => {
|
||||
if (!leaderboard || !agentName) return null;
|
||||
if (!agentName) return null;
|
||||
const agentConfig = AGENTS.find((agent) => agent.name === agentName);
|
||||
const profile = agentConfig ? agentProfilesByAgent?.[agentConfig.id] : null;
|
||||
if (agentConfig && profile?.model_name) {
|
||||
return {
|
||||
agentId: agentConfig.id,
|
||||
modelName: profile.model_name,
|
||||
modelProvider: profile.model_provider
|
||||
};
|
||||
}
|
||||
if (!leaderboard) return null;
|
||||
const agentData = leaderboard.find(lb => lb.name === agentName || lb.agentName === agentName);
|
||||
if (!agentData) return null;
|
||||
return {
|
||||
|
||||
506
frontend/src/components/AppShell.jsx
Normal file
506
frontend/src/components/AppShell.jsx
Normal file
@@ -0,0 +1,506 @@
|
||||
import React, { Suspense, lazy, useRef, useEffect, useMemo } from 'react';
|
||||
import GlobalStyles from '../styles/GlobalStyles';
|
||||
import Header from './Header.jsx';
|
||||
import RuntimeSettingsPanel from './RuntimeSettingsPanel.jsx';
|
||||
import StockLogo from './StockLogo.jsx';
|
||||
import NetValueChart from './NetValueChart.jsx';
|
||||
import { AGENTS } from '../config/constants';
|
||||
import { useRuntimeStore } from '../store/runtimeStore';
|
||||
import { useUIStore } from '../store/uiStore';
|
||||
import { formatNumber, formatTickerPrice } from '../utils/formatters';
|
||||
|
||||
const RoomView = lazy(() => import('./RoomView'));
|
||||
const AgentFeed = lazy(() => import('./AgentFeed'));
|
||||
const StatisticsView = lazy(() => import('./StatisticsView'));
|
||||
const StockExplainView = lazy(() => import('./StockExplainView.jsx'));
|
||||
const TraderView = lazy(() => import('./TraderView.jsx'));
|
||||
|
||||
function ViewLoadingFallback({ label = '加载中...' }) {
|
||||
return (
|
||||
<div style={{
|
||||
minHeight: 240,
|
||||
height: '100%',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
border: '1px solid #000000',
|
||||
background: '#ffffff',
|
||||
fontSize: 12,
|
||||
fontWeight: 700,
|
||||
letterSpacing: 0.4
|
||||
}}>
|
||||
{label}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* AppShell - Layout shell containing Header, TickerBar, ViewNavBar, View container, and AgentFeed
|
||||
*/
|
||||
export default function AppShell({
|
||||
// Connection & status
|
||||
isConnected,
|
||||
virtualTime,
|
||||
now,
|
||||
marketStatus,
|
||||
serverMode,
|
||||
marketStatusLabel,
|
||||
dataSourceLabel,
|
||||
runtimeSummaryLabel,
|
||||
isUpdating,
|
||||
// Handlers
|
||||
onManualTrigger,
|
||||
onOpenRuntimeLogs,
|
||||
onRuntimeSettingsToggle,
|
||||
// Runtime settings panel props
|
||||
isRuntimeSettingsOpen,
|
||||
isRuntimeConfigSaving,
|
||||
isWatchlistSaving,
|
||||
runtimeConfigFeedback,
|
||||
watchlistFeedback,
|
||||
launchModeDraft,
|
||||
restoreRunIdDraft,
|
||||
runtimeHistoryRuns,
|
||||
scheduleModeDraft,
|
||||
intervalMinutesDraft,
|
||||
triggerTimeDraft,
|
||||
maxCommCyclesDraft,
|
||||
initialCashDraft,
|
||||
marginRequirementDraft,
|
||||
enableMemoryDraft,
|
||||
modeDraft,
|
||||
pollIntervalDraft,
|
||||
startDateDraft,
|
||||
endDateDraft,
|
||||
watchlistDraftSymbols,
|
||||
watchlistInputValue,
|
||||
watchlistSuggestions,
|
||||
onLaunchModeChange,
|
||||
onRestoreRunIdChange,
|
||||
onScheduleModeChange,
|
||||
onIntervalMinutesChange,
|
||||
onTriggerTimeChange,
|
||||
onMaxCommCyclesChange,
|
||||
onInitialCashChange,
|
||||
onMarginRequirementChange,
|
||||
onEnableMemoryChange,
|
||||
onModeChange,
|
||||
onPollIntervalChange,
|
||||
onStartDateChange,
|
||||
onEndDateChange,
|
||||
onWatchlistInputChange,
|
||||
onWatchlistInputKeyDown,
|
||||
onWatchlistAdd,
|
||||
onWatchlistRemove,
|
||||
onWatchlistRestoreCurrent,
|
||||
onWatchlistRestoreDefault,
|
||||
onWatchlistSuggestionClick,
|
||||
onLaunchConfigSave,
|
||||
onRestoreDefaults,
|
||||
// Ticker and portfolio data
|
||||
displayTickers,
|
||||
portfolioData,
|
||||
rollingTickers,
|
||||
// Feed data
|
||||
feed,
|
||||
bubbles,
|
||||
bubbleFor,
|
||||
leaderboard,
|
||||
// Views data
|
||||
currentView,
|
||||
chartTab,
|
||||
holdings,
|
||||
trades,
|
||||
stats,
|
||||
priceHistoryByTicker,
|
||||
ohlcHistoryByTicker,
|
||||
selectedExplainSymbol,
|
||||
onSelectedExplainSymbolChange,
|
||||
historySourceByTicker,
|
||||
explainEventsByTicker,
|
||||
newsByTicker,
|
||||
insiderTradesByTicker,
|
||||
technicalIndicatorsByTicker,
|
||||
currentDate,
|
||||
// Stock request handlers
|
||||
stockRequests,
|
||||
// Agent request handlers
|
||||
agentRequests,
|
||||
agentProfilesByAgent,
|
||||
// Layout
|
||||
leftWidth,
|
||||
isResizing,
|
||||
onMouseDown,
|
||||
agentFeedRef
|
||||
}) {
|
||||
const containerRef = useRef(null);
|
||||
const { setIsRuntimeSettingsOpen, setIsWatchlistPanelOpen } = useRuntimeStore();
|
||||
const { setChartTab, setCurrentView, setIsResizing, setLeftWidth } = useUIStore();
|
||||
|
||||
// Resize handler
|
||||
useEffect(() => {
|
||||
if (!isResizing) return;
|
||||
|
||||
const handleMouseMove = (e) => {
|
||||
if (!containerRef.current) return;
|
||||
const containerRect = containerRef.current.getBoundingClientRect();
|
||||
const newLeftWidth = ((e.clientX - containerRect.left) / containerRect.width) * 100;
|
||||
if (newLeftWidth >= 30 && newLeftWidth <= 85) {
|
||||
setLeftWidth(newLeftWidth);
|
||||
}
|
||||
};
|
||||
|
||||
const handleMouseUp = () => setIsResizing(false);
|
||||
|
||||
document.addEventListener('mousemove', handleMouseMove);
|
||||
document.addEventListener('mouseup', handleMouseUp);
|
||||
|
||||
return () => {
|
||||
document.removeEventListener('mousemove', handleMouseMove);
|
||||
document.removeEventListener('mouseup', handleMouseUp);
|
||||
};
|
||||
}, [isResizing, setIsResizing, setLeftWidth]);
|
||||
|
||||
const handleJumpToMessage = (bubble) => {
|
||||
if (agentFeedRef.current && agentFeedRef.current.scrollToMessage) {
|
||||
agentFeedRef.current.scrollToMessage(bubble);
|
||||
}
|
||||
};
|
||||
|
||||
const viewClassName = useMemo(() => {
|
||||
const base = `view-slider-five ${currentView === 'traders' ? 'show-traders' :
|
||||
currentView === 'room' ? 'show-room' :
|
||||
currentView === 'explain' ? 'show-explain' :
|
||||
currentView === 'statistics' ? 'show-statistics' : 'show-chart'}`;
|
||||
return base;
|
||||
}, [currentView]);
|
||||
|
||||
return (
|
||||
<div className="app">
|
||||
<GlobalStyles />
|
||||
|
||||
{/* Header */}
|
||||
<div className="header">
|
||||
<Header />
|
||||
|
||||
<div className="header-right" style={{ display: 'flex', alignItems: 'center', gap: 24, marginLeft: 'auto', flexWrap: 'wrap', minWidth: 0 }}>
|
||||
{/* Unified Status Indicator */}
|
||||
<div className="header-status-inline">
|
||||
<span className={`status-dot ${isConnected ? (isUpdating ? 'updating' : 'live') : 'offline'}`} />
|
||||
<span className={`status-text ${isConnected ? 'live' : 'offline'}`}>
|
||||
{isConnected ? (isUpdating ? '同步中' : '在线') : '离线'}
|
||||
</span>
|
||||
{marketStatus && (
|
||||
<>
|
||||
<span className="status-sep">·</span>
|
||||
<span className={`market-text ${serverMode === 'backtest' ? 'backtest' : (marketStatus.status === 'open' ? 'open' : 'closed')}`}>
|
||||
{marketStatusLabel}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
{dataSourceLabel && (
|
||||
<>
|
||||
<span className="status-sep">·</span>
|
||||
<span className="market-text backtest">{dataSourceLabel}</span>
|
||||
</>
|
||||
)}
|
||||
{runtimeSummaryLabel && (
|
||||
<>
|
||||
<span className="status-sep">·</span>
|
||||
<span className="market-text backtest" title="当前运行配置">{runtimeSummaryLabel}</span>
|
||||
</>
|
||||
)}
|
||||
<span className="status-sep">·</span>
|
||||
<span className="time-text">{now.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false })}</span>
|
||||
</div>
|
||||
|
||||
{serverMode !== 'backtest' && (
|
||||
<div style={{ display: 'flex', gap: 8, alignItems: 'center' }}>
|
||||
{onOpenRuntimeLogs && (
|
||||
<button
|
||||
onClick={onOpenRuntimeLogs}
|
||||
style={{
|
||||
padding: '6px 12px',
|
||||
borderRadius: 4,
|
||||
background: '#FFFFFF',
|
||||
border: '1px solid #111111',
|
||||
color: '#111111',
|
||||
fontSize: '11px',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
fontWeight: 700,
|
||||
cursor: 'pointer',
|
||||
letterSpacing: '0.4px',
|
||||
textTransform: 'uppercase'
|
||||
}}
|
||||
title="查看当前运行日志"
|
||||
>
|
||||
运行日志
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={onManualTrigger}
|
||||
disabled={!isConnected}
|
||||
style={{
|
||||
padding: '6px 12px',
|
||||
borderRadius: 4,
|
||||
background: isConnected ? '#111111' : '#8a8a8a',
|
||||
border: '1px solid #111111',
|
||||
color: '#FFFFFF',
|
||||
fontSize: '11px',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
fontWeight: 700,
|
||||
cursor: isConnected ? 'pointer' : 'not-allowed',
|
||||
letterSpacing: '0.4px',
|
||||
textTransform: 'uppercase'
|
||||
}}
|
||||
title="手动触发一轮分析与交易决策"
|
||||
>
|
||||
手动运行
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<RuntimeSettingsPanel
|
||||
showTrigger={false}
|
||||
isOpen={isRuntimeSettingsOpen}
|
||||
isConnected={isConnected}
|
||||
isSaving={isRuntimeConfigSaving || isWatchlistSaving}
|
||||
feedback={runtimeConfigFeedback || watchlistFeedback}
|
||||
launchMode={launchModeDraft}
|
||||
restoreRunId={restoreRunIdDraft}
|
||||
runtimeHistoryRuns={runtimeHistoryRuns}
|
||||
scheduleMode={scheduleModeDraft}
|
||||
intervalMinutes={intervalMinutesDraft}
|
||||
triggerTime={triggerTimeDraft}
|
||||
maxCommCycles={maxCommCyclesDraft}
|
||||
initialCash={initialCashDraft}
|
||||
marginRequirement={marginRequirementDraft}
|
||||
enableMemory={enableMemoryDraft}
|
||||
mode={modeDraft}
|
||||
pollInterval={pollIntervalDraft}
|
||||
startDate={startDateDraft}
|
||||
endDate={endDateDraft}
|
||||
watchlistSymbols={watchlistDraftSymbols}
|
||||
watchlistInputValue={watchlistInputValue}
|
||||
watchlistSuggestions={watchlistSuggestions}
|
||||
onToggle={onRuntimeSettingsToggle}
|
||||
onClose={() => setIsRuntimeSettingsOpen(false)}
|
||||
onLaunchModeChange={onLaunchModeChange}
|
||||
onRestoreRunIdChange={onRestoreRunIdChange}
|
||||
onScheduleModeChange={onScheduleModeChange}
|
||||
onIntervalMinutesChange={onIntervalMinutesChange}
|
||||
onTriggerTimeChange={onTriggerTimeChange}
|
||||
onMaxCommCyclesChange={onMaxCommCyclesChange}
|
||||
onInitialCashChange={onInitialCashChange}
|
||||
onMarginRequirementChange={onMarginRequirementChange}
|
||||
onEnableMemoryChange={onEnableMemoryChange}
|
||||
onModeChange={onModeChange}
|
||||
onPollIntervalChange={onPollIntervalChange}
|
||||
onStartDateChange={onStartDateChange}
|
||||
onEndDateChange={onEndDateChange}
|
||||
onWatchlistInputChange={onWatchlistInputChange}
|
||||
onWatchlistInputKeyDown={onWatchlistInputKeyDown}
|
||||
onWatchlistAdd={onWatchlistAdd}
|
||||
onWatchlistRemove={onWatchlistRemove}
|
||||
onWatchlistRestoreCurrent={onWatchlistRestoreCurrent}
|
||||
onWatchlistRestoreDefault={onWatchlistRestoreDefault}
|
||||
onWatchlistSuggestionClick={onWatchlistSuggestionClick}
|
||||
onSave={onLaunchConfigSave}
|
||||
onRestoreDefaults={onRestoreDefaults}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Main Content */}
|
||||
<>
|
||||
{/* Ticker Bar */}
|
||||
<div className="ticker-bar">
|
||||
<div className="ticker-track">
|
||||
{[0, 1].map((groupIdx) => (
|
||||
<div key={groupIdx} className="ticker-group">
|
||||
{displayTickers.map(ticker => (
|
||||
<div key={`${ticker.symbol}-${groupIdx}`} className="ticker-item">
|
||||
<StockLogo ticker={ticker.symbol} size={16} />
|
||||
<span className="ticker-symbol">{ticker.symbol}</span>
|
||||
<span className="ticker-price">
|
||||
<span className={`ticker-price-value ${rollingTickers[ticker.symbol] ? 'rolling' : ''}`}>
|
||||
{ticker.price !== null && ticker.price !== undefined
|
||||
? `$${formatTickerPrice(ticker.price)}` : '-'}
|
||||
</span>
|
||||
</span>
|
||||
<span className={`ticker-change ${
|
||||
ticker.change === null || ticker.change === undefined
|
||||
? '' : ticker.change >= 0 ? 'positive' : 'negative'
|
||||
}`}>
|
||||
{ticker.change !== null && ticker.change !== undefined
|
||||
? `${ticker.change >= 0 ? '+' : ''}${ticker.change.toFixed(2)}%` : '-'}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
<div className="portfolio-value">
|
||||
<span className="portfolio-label">投资组合</span>
|
||||
<span className="portfolio-amount">${formatNumber(portfolioData.netValue)}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="main-container" ref={containerRef}>
|
||||
{/* Left Panel */}
|
||||
<div className="left-panel" style={{ width: `${leftWidth}%` }}>
|
||||
<div className="chart-section">
|
||||
<div className="view-container">
|
||||
<div className="view-nav-bar">
|
||||
<button
|
||||
className={`view-nav-btn ${currentView === 'traders' ? 'active' : ''}`}
|
||||
onClick={() => setCurrentView('traders')}
|
||||
>
|
||||
交易员
|
||||
</button>
|
||||
<button
|
||||
className={`view-nav-btn ${currentView === 'room' ? 'active' : ''}`}
|
||||
onClick={() => setCurrentView('room')}
|
||||
>
|
||||
交易室
|
||||
</button>
|
||||
<button
|
||||
className={`view-nav-btn ${currentView === 'explain' ? 'active' : ''}`}
|
||||
onClick={() => setCurrentView('explain')}
|
||||
>
|
||||
个股分析
|
||||
</button>
|
||||
<button
|
||||
className={`view-nav-btn ${currentView === 'chart' ? 'active' : ''}`}
|
||||
onClick={() => setCurrentView('chart')}
|
||||
>
|
||||
业绩图表
|
||||
</button>
|
||||
<button
|
||||
className={`view-nav-btn ${currentView === 'statistics' ? 'active' : ''}`}
|
||||
onClick={() => setCurrentView('statistics')}
|
||||
>
|
||||
统计
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className={viewClassName}>
|
||||
{/* Traders View */}
|
||||
<div className="view-panel">
|
||||
<Suspense fallback={<ViewLoadingFallback label="加载交易员视图..." />}>
|
||||
<TraderView {...agentRequests} />
|
||||
</Suspense>
|
||||
</div>
|
||||
|
||||
{/* Room View Panel */}
|
||||
<div className="view-panel">
|
||||
<Suspense fallback={<ViewLoadingFallback label="加载交易室..." />}>
|
||||
<RoomView
|
||||
bubbles={bubbles}
|
||||
bubbleFor={bubbleFor}
|
||||
leaderboard={leaderboard}
|
||||
agentProfilesByAgent={agentProfilesByAgent}
|
||||
feed={feed}
|
||||
onJumpToMessage={handleJumpToMessage}
|
||||
onOpenLaunchConfig={() => setIsRuntimeSettingsOpen(true)}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
|
||||
{/* Stock Explain View Panel */}
|
||||
<div className="view-panel">
|
||||
<Suspense fallback={<ViewLoadingFallback label="加载个股分析..." />}>
|
||||
<StockExplainView
|
||||
tickers={displayTickers}
|
||||
holdings={holdings}
|
||||
trades={trades}
|
||||
leaderboard={leaderboard}
|
||||
feed={feed}
|
||||
priceHistoryByTicker={priceHistoryByTicker}
|
||||
ohlcHistoryByTicker={ohlcHistoryByTicker}
|
||||
selectedSymbol={selectedExplainSymbol}
|
||||
onSelectedSymbolChange={onSelectedExplainSymbolChange}
|
||||
selectedHistorySource={historySourceByTicker[selectedExplainSymbol] || null}
|
||||
explainEventsSnapshot={explainEventsByTicker[selectedExplainSymbol] || null}
|
||||
newsSnapshot={newsByTicker[selectedExplainSymbol] || null}
|
||||
insiderTradesSnapshot={insiderTradesByTicker[selectedExplainSymbol] || null}
|
||||
technicalIndicatorsSnapshot={technicalIndicatorsByTicker[selectedExplainSymbol] || null}
|
||||
onRequestHistory={stockRequests?.requestStockHistory}
|
||||
onRequestExplainEvents={stockRequests?.requestStockExplainEvents}
|
||||
onRequestNews={stockRequests?.requestStockNews}
|
||||
onRequestRangeExplain={stockRequests?.requestStockRangeExplain}
|
||||
onRequestNewsForDate={stockRequests?.requestStockNewsForDate}
|
||||
onRequestStory={stockRequests?.requestStockStory}
|
||||
onRequestInsiderTrades={stockRequests?.requestStockInsiderTrades}
|
||||
onRequestTechnicalIndicators={stockRequests?.requestStockTechnicalIndicators}
|
||||
currentDate={currentDate}
|
||||
onRequestSimilarDays={stockRequests?.requestStockSimilarDays}
|
||||
onRequestStockEnrich={stockRequests?.requestStockEnrich}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
|
||||
{/* Chart View Panel */}
|
||||
<div className="view-panel">
|
||||
<div className="chart-container">
|
||||
<div className="chart-tabs-floating">
|
||||
<button
|
||||
className={`chart-tab ${chartTab === 'all' ? 'active' : ''}`}
|
||||
onClick={() => setChartTab('all')}
|
||||
>
|
||||
日线
|
||||
</button>
|
||||
</div>
|
||||
{currentView === 'chart' ? (
|
||||
<NetValueChart
|
||||
equity={portfolioData.equity}
|
||||
baseline={portfolioData.baseline}
|
||||
baseline_vw={portfolioData.baseline_vw}
|
||||
momentum={portfolioData.momentum}
|
||||
strategies={portfolioData.strategies}
|
||||
equity_return={portfolioData.equity_return}
|
||||
baseline_return={portfolioData.baseline_return}
|
||||
baseline_vw_return={portfolioData.baseline_vw_return}
|
||||
momentum_return={portfolioData.momentum_return}
|
||||
chartTab={chartTab}
|
||||
virtualTime={virtualTime}
|
||||
/>
|
||||
) : (
|
||||
<div style={{ height: '100%', minHeight: 320 }} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Statistics View Panel */}
|
||||
<div className="view-panel">
|
||||
<Suspense fallback={<ViewLoadingFallback label="加载统计视图..." />}>
|
||||
<StatisticsView
|
||||
trades={trades}
|
||||
holdings={holdings}
|
||||
stats={stats}
|
||||
portfolioData={portfolioData}
|
||||
baseline_vw={portfolioData.baseline_vw}
|
||||
equity={portfolioData.equity}
|
||||
leaderboard={leaderboard}
|
||||
/>
|
||||
</Suspense>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Resizer */}
|
||||
<div className={`resizer ${isResizing ? 'resizing' : ''}`} onMouseDown={onMouseDown} />
|
||||
|
||||
{/* Right Panel: Agent Feed */}
|
||||
<div className="right-panel" style={{ width: `${100 - leftWidth}%` }}>
|
||||
<Suspense fallback={<ViewLoadingFallback label="加载消息流..." />}>
|
||||
<AgentFeed ref={agentFeedRef} feed={feed} leaderboard={leaderboard} agentProfilesByAgent={agentProfilesByAgent} />
|
||||
</Suspense>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -5,10 +5,10 @@ import { formatNumber, formatFullNumber } from '../utils/formatters';
|
||||
/**
|
||||
* Helper function to get the start time of the most recent trading session
|
||||
* Trading session: 22:30 - next day 05:00
|
||||
* @param {Date|null} virtualTime - Virtual time from server (for mock mode), or null to use real time
|
||||
* @param {Date|null} virtualTime - Virtual time from server, or null to use real time
|
||||
*/
|
||||
function getRecentTradingSessionStart(virtualTime = null) {
|
||||
// Use virtual time if provided (for mock mode), otherwise use real time
|
||||
// Use virtual time if provided, otherwise use real time
|
||||
let now;
|
||||
if (virtualTime) {
|
||||
// Ensure virtualTime is a valid Date object
|
||||
|
||||
@@ -47,7 +47,7 @@ function getRankMedal(rank) {
|
||||
* Supports click and hover (1.5s) to show agent performance cards
|
||||
* Supports replay mode - completely independent from live mode
|
||||
*/
|
||||
export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJumpToMessage, onOpenLaunchConfig }) {
|
||||
export default function RoomView({ bubbles, bubbleFor, leaderboard, agentProfilesByAgent, feed, onJumpToMessage, onOpenLaunchConfig }) {
|
||||
const canvasRef = useRef(null);
|
||||
const containerRef = useRef(null);
|
||||
|
||||
@@ -162,11 +162,14 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
|
||||
const getAgentData = (agentId) => {
|
||||
const agent = AGENTS.find(a => a.id === agentId);
|
||||
if (!agent) return null;
|
||||
const profile = agentProfilesByAgent?.[agentId] || null;
|
||||
|
||||
// If no leaderboard data, return agent with default stats
|
||||
if (!leaderboard || !Array.isArray(leaderboard)) {
|
||||
return {
|
||||
...agent,
|
||||
modelName: profile?.model_name || null,
|
||||
modelProvider: profile?.model_provider || null,
|
||||
bull: { n: 0, win: 0, unknown: 0 },
|
||||
bear: { n: 0, win: 0, unknown: 0 },
|
||||
winRate: null,
|
||||
@@ -181,6 +184,8 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
|
||||
if (!leaderboardData) {
|
||||
return {
|
||||
...agent,
|
||||
modelName: profile?.model_name || null,
|
||||
modelProvider: profile?.model_provider || null,
|
||||
bull: { n: 0, win: 0, unknown: 0 },
|
||||
bear: { n: 0, win: 0, unknown: 0 },
|
||||
winRate: null,
|
||||
@@ -193,6 +198,8 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
|
||||
return {
|
||||
...agent,
|
||||
...leaderboardData,
|
||||
modelName: profile?.model_name || leaderboardData.modelName || null,
|
||||
modelProvider: profile?.model_provider || leaderboardData.modelProvider || null,
|
||||
avatar: agent.avatar // Always use the frontend's avatar URL
|
||||
};
|
||||
};
|
||||
|
||||
190
frontend/src/components/RuntimeLogsModal.jsx
Normal file
190
frontend/src/components/RuntimeLogsModal.jsx
Normal file
@@ -0,0 +1,190 @@
|
||||
import React, { useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
|
||||
export default function RuntimeLogsModal({
|
||||
isOpen,
|
||||
isLoading,
|
||||
logPayload,
|
||||
error,
|
||||
onClose,
|
||||
onRefresh
|
||||
}) {
|
||||
const logRef = useRef(null);
|
||||
const [autoRefresh, setAutoRefresh] = useState(true);
|
||||
const [followTail, setFollowTail] = useState(true);
|
||||
|
||||
const refreshIntervalMs = useMemo(() => 2000, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isOpen || !autoRefresh) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const timerId = window.setInterval(() => {
|
||||
onRefresh();
|
||||
}, refreshIntervalMs);
|
||||
|
||||
return () => window.clearInterval(timerId);
|
||||
}, [autoRefresh, isOpen, onRefresh, refreshIntervalMs]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isOpen || !followTail || !logRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
logRef.current.scrollTop = logRef.current.scrollHeight;
|
||||
}, [followTail, isOpen, logPayload?.content]);
|
||||
|
||||
if (!isOpen) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return createPortal(
|
||||
<div
|
||||
onClick={onClose}
|
||||
style={{
|
||||
position: 'fixed',
|
||||
inset: 0,
|
||||
background: 'rgba(15, 23, 42, 0.32)',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
padding: 24,
|
||||
zIndex: 10000
|
||||
}}
|
||||
>
|
||||
<div
|
||||
onClick={(event) => event.stopPropagation()}
|
||||
style={{
|
||||
width: 'min(980px, 94vw)',
|
||||
maxHeight: '82vh',
|
||||
overflow: 'hidden',
|
||||
borderRadius: 16,
|
||||
border: '1px solid #D9E0E7',
|
||||
background: '#FFFFFF',
|
||||
boxShadow: '0 24px 60px rgba(15, 23, 42, 0.18)',
|
||||
display: 'grid',
|
||||
gridTemplateRows: 'auto auto minmax(0, 1fr)'
|
||||
}}
|
||||
>
|
||||
<div style={{
|
||||
padding: '18px 20px 10px',
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
gap: 12
|
||||
}}>
|
||||
<div style={{ display: 'grid', gap: 4 }}>
|
||||
<div style={{ fontSize: 14, fontWeight: 800, color: '#111111' }}>运行日志</div>
|
||||
<div style={{ fontSize: 11, color: '#6B7280' }}>
|
||||
{logPayload?.run_id ? `任务 ${logPayload.run_id}` : '当前运行任务'}
|
||||
</div>
|
||||
</div>
|
||||
<div style={{ display: 'flex', gap: 8 }}>
|
||||
<button
|
||||
type="button"
|
||||
onClick={onRefresh}
|
||||
style={{
|
||||
padding: '7px 10px',
|
||||
borderRadius: 8,
|
||||
border: '1px solid #D0D7DE',
|
||||
background: '#FFFFFF',
|
||||
color: '#111111',
|
||||
fontSize: 11,
|
||||
fontWeight: 700,
|
||||
cursor: 'pointer'
|
||||
}}
|
||||
>
|
||||
刷新
|
||||
</button>
|
||||
<button
|
||||
type="button"
|
||||
onClick={onClose}
|
||||
style={{
|
||||
padding: '7px 10px',
|
||||
borderRadius: 8,
|
||||
border: '1px solid #111111',
|
||||
background: '#111111',
|
||||
color: '#FFFFFF',
|
||||
fontSize: 11,
|
||||
fontWeight: 700,
|
||||
cursor: 'pointer'
|
||||
}}
|
||||
>
|
||||
关闭
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{
|
||||
padding: '0 20px 12px',
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
gap: 12,
|
||||
alignItems: 'center',
|
||||
flexWrap: 'wrap'
|
||||
}}>
|
||||
<div style={{ fontSize: 11, color: '#6B7280', fontFamily: '"Courier New", monospace' }}>
|
||||
{logPayload?.log_path || '未找到日志文件'}
|
||||
</div>
|
||||
{isLoading ? (
|
||||
<div style={{ fontSize: 11, color: '#2563EB', fontWeight: 700 }}>加载中...</div>
|
||||
) : error ? (
|
||||
<div style={{ fontSize: 11, color: '#B91C1C', fontWeight: 700 }}>{error}</div>
|
||||
) : null}
|
||||
</div>
|
||||
|
||||
<div style={{
|
||||
padding: '0 20px 12px',
|
||||
display: 'flex',
|
||||
gap: 16,
|
||||
alignItems: 'center',
|
||||
flexWrap: 'wrap'
|
||||
}}>
|
||||
<label style={{ display: 'inline-flex', alignItems: 'center', gap: 6, fontSize: 11, color: '#374151', cursor: 'pointer' }}>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={autoRefresh}
|
||||
onChange={(event) => setAutoRefresh(event.target.checked)}
|
||||
/>
|
||||
实时刷新
|
||||
</label>
|
||||
<label style={{ display: 'inline-flex', alignItems: 'center', gap: 6, fontSize: 11, color: '#374151', cursor: 'pointer' }}>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={followTail}
|
||||
onChange={(event) => setFollowTail(event.target.checked)}
|
||||
/>
|
||||
自动滚底
|
||||
</label>
|
||||
</div>
|
||||
|
||||
<div style={{ padding: '0 20px 20px', minHeight: 0 }}>
|
||||
<pre
|
||||
ref={logRef}
|
||||
style={{
|
||||
margin: 0,
|
||||
height: '100%',
|
||||
minHeight: 320,
|
||||
maxHeight: 'calc(82vh - 140px)',
|
||||
overflow: 'auto',
|
||||
borderRadius: 12,
|
||||
border: '1px solid #D9E0E7',
|
||||
background: '#0F172A',
|
||||
color: '#E2E8F0',
|
||||
padding: 16,
|
||||
fontSize: 11,
|
||||
lineHeight: 1.6,
|
||||
fontFamily: '"SFMono-Regular", Menlo, Consolas, "Liberation Mono", monospace',
|
||||
whiteSpace: 'pre-wrap',
|
||||
wordBreak: 'break-word'
|
||||
}}
|
||||
>
|
||||
{logPayload?.content || '暂无日志输出'}
|
||||
</pre>
|
||||
</div>
|
||||
</div>
|
||||
</div>,
|
||||
document.body
|
||||
);
|
||||
}
|
||||
@@ -1,12 +1,24 @@
|
||||
import React from 'react';
|
||||
import { createPortal } from 'react-dom';
|
||||
|
||||
const formatHistorySummary = (run) => {
|
||||
const updatedAt = run?.updated_at ? String(run.updated_at).replace("T", " ").slice(0, 16) : "未知时间";
|
||||
const mode = run?.bootstrap?.mode ? String(run.bootstrap.mode).toUpperCase() : "LIVE";
|
||||
const tickers = Array.isArray(run?.bootstrap?.tickers) ? run.bootstrap.tickers.length : 0;
|
||||
const assetValue = Number(run?.total_asset_value ?? 0).toFixed(2);
|
||||
const trades = Number(run?.total_trades ?? 0);
|
||||
return `${run.run_id} · ${updatedAt} · ${mode} · ${tickers}标的 · ${trades}笔交易 · $${assetValue}`;
|
||||
};
|
||||
|
||||
export default function RuntimeSettingsPanel({
|
||||
showTrigger = true,
|
||||
isOpen,
|
||||
isConnected,
|
||||
isSaving,
|
||||
feedback,
|
||||
launchMode,
|
||||
restoreRunId,
|
||||
runtimeHistoryRuns,
|
||||
scheduleMode,
|
||||
intervalMinutes,
|
||||
triggerTime,
|
||||
@@ -18,13 +30,14 @@ export default function RuntimeSettingsPanel({
|
||||
pollInterval,
|
||||
startDate,
|
||||
endDate,
|
||||
enableMock,
|
||||
watchlistSymbols,
|
||||
watchlistInputValue,
|
||||
watchlistSuggestions,
|
||||
onToggle,
|
||||
onClose,
|
||||
onScheduleModeChange,
|
||||
onLaunchModeChange,
|
||||
onRestoreRunIdChange,
|
||||
onIntervalMinutesChange,
|
||||
onTriggerTimeChange,
|
||||
onMaxCommCyclesChange,
|
||||
@@ -35,7 +48,6 @@ export default function RuntimeSettingsPanel({
|
||||
onPollIntervalChange,
|
||||
onStartDateChange,
|
||||
onEndDateChange,
|
||||
onEnableMockChange,
|
||||
onWatchlistInputChange,
|
||||
onWatchlistInputKeyDown,
|
||||
onWatchlistAdd,
|
||||
@@ -142,6 +154,75 @@ export default function RuntimeSettingsPanel({
|
||||
display: 'grid',
|
||||
gap: 12
|
||||
}}>
|
||||
<div style={{ fontSize: 12, fontWeight: 800, color: '#111111' }}>启动形式</div>
|
||||
<label style={{ display: 'grid', gap: 4 }}>
|
||||
<span style={{ fontSize: '10px', color: '#4B5563', fontWeight: 700 }}>任务模式</span>
|
||||
<select
|
||||
value={launchMode}
|
||||
onChange={(e) => onLaunchModeChange(e.target.value)}
|
||||
style={{
|
||||
padding: '9px 10px',
|
||||
borderRadius: 8,
|
||||
border: '1px solid #D0D7DE',
|
||||
background: '#FFFFFF',
|
||||
color: '#111111',
|
||||
fontSize: '12px'
|
||||
}}
|
||||
>
|
||||
<option value="fresh">重新启动</option>
|
||||
<option value="restore">从历史任务恢复</option>
|
||||
</select>
|
||||
</label>
|
||||
|
||||
{launchMode === 'restore' && (
|
||||
<>
|
||||
<label style={{ display: 'grid', gap: 4 }}>
|
||||
<span style={{ fontSize: '10px', color: '#4B5563', fontWeight: 700 }}>历史任务</span>
|
||||
<select
|
||||
value={restoreRunId}
|
||||
onChange={(e) => onRestoreRunIdChange(e.target.value)}
|
||||
style={{
|
||||
padding: '9px 10px',
|
||||
borderRadius: 8,
|
||||
border: '1px solid #D0D7DE',
|
||||
background: '#FFFFFF',
|
||||
color: '#111111',
|
||||
fontSize: '12px'
|
||||
}}
|
||||
>
|
||||
<option value="">请选择历史任务</option>
|
||||
{runtimeHistoryRuns.map((run) => (
|
||||
<option key={run.run_id} value={run.run_id}>
|
||||
{formatHistorySummary(run)}
|
||||
</option>
|
||||
))}
|
||||
</select>
|
||||
</label>
|
||||
|
||||
<div style={{
|
||||
fontSize: '11px',
|
||||
color: '#6B7280',
|
||||
lineHeight: 1.6,
|
||||
padding: '10px 12px',
|
||||
borderRadius: 8,
|
||||
background: '#FFFFFF',
|
||||
border: '1px dashed #D0D7DE'
|
||||
}}>
|
||||
恢复启动会从所选历史任务复制运行状态、组合、交易记录和 Agent 工作区资产,并以新的任务 ID 继续运行。
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{launchMode === 'fresh' && (
|
||||
<div style={{
|
||||
border: '1px solid #E5EAF1',
|
||||
borderRadius: 12,
|
||||
background: '#FCFDFE',
|
||||
padding: 14,
|
||||
display: 'grid',
|
||||
gap: 12
|
||||
}}>
|
||||
<div style={{ fontSize: 12, fontWeight: 800, color: '#111111' }}>自选股</div>
|
||||
|
||||
<div style={{
|
||||
@@ -272,16 +353,18 @@ export default function RuntimeSettingsPanel({
|
||||
恢复默认
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div style={{
|
||||
border: '1px solid #E5EAF1',
|
||||
borderRadius: 12,
|
||||
background: '#FCFDFE',
|
||||
padding: 14,
|
||||
display: 'grid',
|
||||
gap: 12
|
||||
}}>
|
||||
{launchMode === 'fresh' && (
|
||||
<div style={{
|
||||
border: '1px solid #E5EAF1',
|
||||
borderRadius: 12,
|
||||
background: '#FCFDFE',
|
||||
padding: 14,
|
||||
display: 'grid',
|
||||
gap: 12
|
||||
}}>
|
||||
<div style={{ fontSize: 12, fontWeight: 800, color: '#111111' }}>调度参数</div>
|
||||
<div style={{ display: 'grid', gridTemplateColumns: '1fr 1fr', gap: 8 }}>
|
||||
<label style={{ display: 'grid', gap: 4 }}>
|
||||
@@ -495,22 +578,8 @@ export default function RuntimeSettingsPanel({
|
||||
}}
|
||||
/>
|
||||
</label>
|
||||
|
||||
<label style={{ display: 'flex', alignItems: 'center', gap: 10, marginTop: 2 }}>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={enableMock}
|
||||
onChange={(e) => onEnableMockChange(e.target.checked)}
|
||||
style={{
|
||||
width: 16,
|
||||
height: 16,
|
||||
accentColor: '#0D47A1',
|
||||
cursor: 'pointer'
|
||||
}}
|
||||
/>
|
||||
<span style={{ fontSize: '11px', color: '#111111', fontWeight: 700 }}>启用模拟数据 (Mock)</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div style={{
|
||||
border: '1px solid #E5EAF1',
|
||||
|
||||
@@ -34,6 +34,18 @@ const EVENT_FILTER_OPTIONS = [
|
||||
{ value: 'approval', label: '审批事件' }
|
||||
];
|
||||
|
||||
const SR_ONLY_STYLE = {
|
||||
position: 'absolute',
|
||||
width: 1,
|
||||
height: 1,
|
||||
padding: 0,
|
||||
margin: -1,
|
||||
overflow: 'hidden',
|
||||
clip: 'rect(0, 0, 0, 0)',
|
||||
whiteSpace: 'nowrap',
|
||||
border: 0
|
||||
};
|
||||
|
||||
function metricCard(label, value, accent, helper = null) {
|
||||
return (
|
||||
<div className="stat-card">
|
||||
@@ -722,6 +734,9 @@ export default function RuntimeView() {
|
||||
{sectionTitle(
|
||||
'近期事件',
|
||||
<select
|
||||
id="runtime-event-filter"
|
||||
name="runtime_event_filter"
|
||||
aria-label="筛选近期事件"
|
||||
value={eventFilter}
|
||||
onChange={(event) => setEventFilter(event.target.value)}
|
||||
style={{
|
||||
@@ -739,6 +754,9 @@ export default function RuntimeView() {
|
||||
))}
|
||||
</select>
|
||||
)}
|
||||
<label htmlFor="runtime-event-filter" style={SR_ONLY_STYLE}>
|
||||
筛选近期事件
|
||||
</label>
|
||||
<div style={{
|
||||
display: 'grid',
|
||||
gap: 8,
|
||||
|
||||
@@ -8,12 +8,36 @@ import { formatNumber, formatDateTime } from '../utils/formatters';
|
||||
* Left: Performance Overview (35%) | Right: Holdings + Trades (65%)
|
||||
* No scrolling - content fits within viewport with pagination
|
||||
*/
|
||||
export default function StatisticsView({ trades, holdings, stats, baseline_vw, equity, leaderboard }) {
|
||||
export default function StatisticsView({ trades, holdings, stats, baseline_vw, equity, leaderboard, portfolioData }) {
|
||||
const [holdingsPage, setHoldingsPage] = useState(1);
|
||||
const [tradesPage, setTradesPage] = useState(1);
|
||||
const holdingsPerPage = 5;
|
||||
const tradesPerPage = 8;
|
||||
|
||||
const effectiveStats = React.useMemo(() => {
|
||||
const base = stats && typeof stats === 'object' ? stats : {};
|
||||
const netValue = Number(portfolioData?.netValue ?? 0);
|
||||
const pnl = Number(portfolioData?.pnl ?? 0);
|
||||
const hasPortfolioValue = Number.isFinite(netValue) && netValue > 0;
|
||||
const hasMeaningfulStats = Number(base?.totalAssetValue ?? 0) > 0;
|
||||
|
||||
if (hasMeaningfulStats || !hasPortfolioValue) {
|
||||
return base;
|
||||
}
|
||||
|
||||
const cashHolding = Array.isArray(holdings)
|
||||
? holdings.find((item) => String(item?.ticker || '').toUpperCase() === 'CASH')
|
||||
: null;
|
||||
|
||||
return {
|
||||
...base,
|
||||
totalAssetValue: netValue,
|
||||
totalReturn: pnl,
|
||||
cashPosition: Number(cashHolding?.marketValue ?? cashHolding?.currentPrice ?? 0),
|
||||
totalTrades: Array.isArray(trades) ? trades.length : 0,
|
||||
};
|
||||
}, [holdings, portfolioData, stats, trades]);
|
||||
|
||||
// Calculate pagination for holdings
|
||||
const totalHoldingsPages = Math.ceil(holdings.length / holdingsPerPage);
|
||||
const holdingsStartIndex = (holdingsPage - 1) * holdingsPerPage;
|
||||
@@ -28,12 +52,12 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
|
||||
|
||||
// Calculate excess return (Evatraders return - benchmark value-weighted return)
|
||||
const calculateExcessReturn = () => {
|
||||
if (!stats || !baseline_vw || baseline_vw.length === 0) {
|
||||
if (!effectiveStats || !baseline_vw || baseline_vw.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Get Evatraders return from stats
|
||||
const evatradersReturn = stats.totalReturn || 0; // Already in percentage
|
||||
const evatradersReturn = effectiveStats.totalReturn || 0; // Already in percentage
|
||||
|
||||
// Calculate benchmark return from baseline_vw
|
||||
// baseline_vw format: [{t: timestamp, v: value}, ...] or [value, ...]
|
||||
@@ -130,7 +154,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
|
||||
borderRight: '2px solid #e0e0e0',
|
||||
overflow: 'hidden'
|
||||
}}>
|
||||
{stats ? (
|
||||
{effectiveStats ? (
|
||||
<div style={{
|
||||
padding: '24px',
|
||||
display: 'flex',
|
||||
@@ -179,7 +203,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
|
||||
fontFamily: '"Courier New", monospace',
|
||||
lineHeight: 1
|
||||
}}>
|
||||
${formatNumber(stats.totalAssetValue || 0)}
|
||||
${formatNumber(effectiveStats.totalAssetValue || 0)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -272,10 +296,10 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
|
||||
<div style={{
|
||||
fontSize: 28,
|
||||
fontWeight: 700,
|
||||
color: (stats.totalReturn || 0) >= 0 ? '#00C853' : '#FF1744',
|
||||
color: (effectiveStats.totalReturn || 0) >= 0 ? '#00C853' : '#FF1744',
|
||||
fontFamily: '"Courier New", monospace'
|
||||
}}>
|
||||
{(stats.totalReturn || 0) >= 0 ? '+' : ''}{(stats.totalReturn || 0).toFixed(2)}%
|
||||
{(effectiveStats.totalReturn || 0) >= 0 ? '+' : ''}{(effectiveStats.totalReturn || 0).toFixed(2)}%
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -304,7 +328,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
|
||||
color: '#000000',
|
||||
fontFamily: '"Courier New", monospace'
|
||||
}}>
|
||||
${formatNumber(stats.cashPosition || 0)}
|
||||
${formatNumber(effectiveStats.cashPosition || 0)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -330,13 +354,13 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
|
||||
color: '#000000',
|
||||
fontFamily: '"Courier New", monospace'
|
||||
}}>
|
||||
{stats.totalTrades || 0}
|
||||
{effectiveStats.totalTrades || 0}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Ticker Weights - Compact */}
|
||||
{stats.tickerWeights && Object.keys(stats.tickerWeights).length > 0 && (
|
||||
{effectiveStats?.tickerWeights && Object.keys(effectiveStats.tickerWeights).length > 0 && (
|
||||
<div style={{
|
||||
marginTop: 'auto',
|
||||
paddingTop: 20,
|
||||
@@ -358,7 +382,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
|
||||
gap: 8,
|
||||
maxHeight: 120
|
||||
}}>
|
||||
{Object.entries(stats.tickerWeights).map(([ticker, weight]) => {
|
||||
{Object.entries(effectiveStats.tickerWeights).map(([ticker, weight]) => {
|
||||
const weightValue = Number(weight);
|
||||
const isNegative = weightValue < 0;
|
||||
const displayWeight = (weightValue * 100).toFixed(1);
|
||||
|
||||
@@ -33,6 +33,9 @@ export default function StockExplainView({
|
||||
insiderTradesSnapshot,
|
||||
technicalIndicatorsSnapshot,
|
||||
onRequestRangeExplain,
|
||||
onRequestHistory,
|
||||
onRequestExplainEvents,
|
||||
onRequestNews,
|
||||
onRequestNewsForDate,
|
||||
onRequestStory,
|
||||
onRequestInsiderTrades,
|
||||
@@ -77,6 +80,7 @@ export default function StockExplainView({
|
||||
visibleNews,
|
||||
newsCategories,
|
||||
visibleNewsByCategory,
|
||||
selectedNewsFreshness,
|
||||
selectedRangeWindow,
|
||||
selectedRangeExplain,
|
||||
latestSignal,
|
||||
@@ -141,11 +145,37 @@ export default function StockExplainView({
|
||||
setActiveNewsSentiment('all');
|
||||
}, [selectedSymbol, selectedEventDate]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedSymbol) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (onRequestHistory && (!Array.isArray(ohlcHistoryByTicker?.[selectedSymbol]) || ohlcHistoryByTicker[selectedSymbol].length === 0)) {
|
||||
onRequestHistory(selectedSymbol);
|
||||
}
|
||||
|
||||
if (onRequestExplainEvents && !explainEventsSnapshot) {
|
||||
onRequestExplainEvents(selectedSymbol);
|
||||
}
|
||||
|
||||
if (onRequestNews && (!Array.isArray(newsSnapshot?.items) || newsSnapshot.items.length === 0)) {
|
||||
onRequestNews(selectedSymbol);
|
||||
}
|
||||
}, [
|
||||
explainEventsSnapshot,
|
||||
newsSnapshot,
|
||||
ohlcHistoryByTicker,
|
||||
onRequestExplainEvents,
|
||||
onRequestHistory,
|
||||
onRequestNews,
|
||||
selectedSymbol,
|
||||
]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedSymbol || !selectedEventDate || !onRequestNewsForDate) {
|
||||
return;
|
||||
}
|
||||
if (Array.isArray(newsSnapshot?.byDate?.[selectedEventDate]) && newsSnapshot.byDate[selectedEventDate].length > 0) {
|
||||
if (Object.prototype.hasOwnProperty.call(newsSnapshot?.byDate || {}, selectedEventDate)) {
|
||||
return;
|
||||
}
|
||||
onRequestNewsForDate(selectedSymbol, selectedEventDate);
|
||||
@@ -155,21 +185,21 @@ export default function StockExplainView({
|
||||
if (!selectedSymbol || !onRequestStory || !currentDate) {
|
||||
return;
|
||||
}
|
||||
if (selectedStory?.story) {
|
||||
if (Object.prototype.hasOwnProperty.call(newsSnapshot?.storyCache || {}, currentDate)) {
|
||||
return;
|
||||
}
|
||||
onRequestStory(selectedSymbol, currentDate);
|
||||
}, [currentDate, onRequestStory, selectedStory, selectedSymbol]);
|
||||
}, [currentDate, newsSnapshot, onRequestStory, selectedStory, selectedSymbol]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedSymbol || !selectedEventDate || !onRequestSimilarDays) {
|
||||
return;
|
||||
}
|
||||
if (selectedSimilarDays?.items?.length) {
|
||||
if (Object.prototype.hasOwnProperty.call(newsSnapshot?.similarDaysCache || {}, selectedEventDate)) {
|
||||
return;
|
||||
}
|
||||
onRequestSimilarDays(selectedSymbol, selectedEventDate);
|
||||
}, [onRequestSimilarDays, selectedEventDate, selectedSimilarDays, selectedSymbol]);
|
||||
}, [newsSnapshot, onRequestSimilarDays, selectedEventDate, selectedSimilarDays, selectedSymbol]);
|
||||
|
||||
useEffect(() => {
|
||||
if (!selectedSymbol || !onRequestTechnicalIndicators) {
|
||||
@@ -337,6 +367,7 @@ export default function StockExplainView({
|
||||
newsSnapshot={newsSnapshot}
|
||||
visibleNewsByCategory={visibleNewsByCategory}
|
||||
visibleNews={visibleNews}
|
||||
selectedNewsFreshness={selectedNewsFreshness}
|
||||
activeNewsCategory={activeNewsCategory}
|
||||
onSelectNewsCategory={setActiveNewsCategory}
|
||||
activeNewsSentiment={activeNewsSentiment}
|
||||
|
||||
@@ -38,6 +38,18 @@ export default function TraderView({
|
||||
onWorkspaceFileSave,
|
||||
onUploadExternalSkill
|
||||
}) {
|
||||
const srOnlyStyle = {
|
||||
position: 'absolute',
|
||||
width: 1,
|
||||
height: 1,
|
||||
padding: 0,
|
||||
margin: -1,
|
||||
overflow: 'hidden',
|
||||
clip: 'rect(0, 0, 0, 0)',
|
||||
whiteSpace: 'nowrap',
|
||||
border: 0
|
||||
};
|
||||
|
||||
const [expandedSkillKey, setExpandedSkillKey] = useState(null);
|
||||
const [newLocalSkillName, setNewLocalSkillName] = useState('');
|
||||
const [externalSkillFile, setExternalSkillFile] = useState(null);
|
||||
@@ -460,6 +472,9 @@ export default function TraderView({
|
||||
本地技能 SKILL.md
|
||||
</div>
|
||||
<textarea
|
||||
id={`local-skill-${selectedAgentId}-${skill.skill_name}`}
|
||||
name={`local_skill_${selectedAgentId}_${skill.skill_name}`}
|
||||
aria-label={`${skill.skill_name} 本地技能内容`}
|
||||
value={skillDraft}
|
||||
onChange={(e) => onLocalSkillDraftChange(skill.skill_name, e.target.value)}
|
||||
style={{
|
||||
@@ -557,6 +572,9 @@ export default function TraderView({
|
||||
</div>
|
||||
|
||||
<textarea
|
||||
id={`workspace-editor-${selectedAgentId}-${selectedWorkspaceFile || 'file'}`}
|
||||
name={`workspace_editor_${selectedAgentId}_${selectedWorkspaceFile || 'file'}`}
|
||||
aria-label={`编辑 ${selectedWorkspaceFile || '工作区文件'} 内容`}
|
||||
value={workspaceDraftContent}
|
||||
onChange={(e) => onWorkspaceDraftChange(e.target.value)}
|
||||
placeholder={isWorkspaceFileLoading ? '加载中...' : '输入 markdown 内容'}
|
||||
@@ -687,7 +705,13 @@ export default function TraderView({
|
||||
}}>
|
||||
<div style={{ fontSize: 12, fontWeight: 800, color: '#111111' }}>创建本地技能</div>
|
||||
<div style={{ display: 'flex', gap: 8, alignItems: 'center' }}>
|
||||
<label htmlFor="new-local-skill-name" style={srOnlyStyle}>
|
||||
输入本地技能名称
|
||||
</label>
|
||||
<input
|
||||
id="new-local-skill-name"
|
||||
name="new_local_skill_name"
|
||||
aria-label="输入本地技能名称"
|
||||
value={newLocalSkillName}
|
||||
onChange={(e) => setNewLocalSkillName(e.target.value)}
|
||||
placeholder="输入技能名,例如 event_playbook"
|
||||
@@ -741,7 +765,13 @@ export default function TraderView({
|
||||
支持上传 .zip(包内需包含一个技能目录及 SKILL.md)
|
||||
</div>
|
||||
<div style={{ display: 'flex', gap: 8, alignItems: 'center', flexWrap: 'wrap' }}>
|
||||
<label htmlFor="external-skill-zip" style={srOnlyStyle}>
|
||||
上传外部技能 zip 包
|
||||
</label>
|
||||
<input
|
||||
id="external-skill-zip"
|
||||
name="external_skill_zip"
|
||||
aria-label="上传外部技能 zip 包"
|
||||
type="file"
|
||||
accept=".zip,application/zip"
|
||||
onChange={async (e) => {
|
||||
|
||||
@@ -19,6 +19,18 @@ export default function WatchlistPanel({
|
||||
onSuggestionClick,
|
||||
onSave
|
||||
}) {
|
||||
const srOnlyStyle = {
|
||||
position: 'absolute',
|
||||
width: 1,
|
||||
height: 1,
|
||||
padding: 0,
|
||||
margin: -1,
|
||||
overflow: 'hidden',
|
||||
clip: 'rect(0, 0, 0, 0)',
|
||||
whiteSpace: 'nowrap',
|
||||
border: 0
|
||||
};
|
||||
|
||||
return (
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 8, minWidth: 0, position: 'relative', marginLeft: -6 }}>
|
||||
<button
|
||||
@@ -117,7 +129,13 @@ export default function WatchlistPanel({
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'flex', gap: 8 }}>
|
||||
<label htmlFor="watchlist-symbol-input" style={srOnlyStyle}>
|
||||
输入股票代码
|
||||
</label>
|
||||
<input
|
||||
id="watchlist-symbol-input"
|
||||
name="watchlist_symbol"
|
||||
aria-label="输入股票代码"
|
||||
value={inputValue}
|
||||
onChange={(e) => onInputChange(e.target.value)}
|
||||
onKeyDown={onInputKeyDown}
|
||||
|
||||
107
frontend/src/components/explain/ExplainInsiderSection.jsx
Normal file
107
frontend/src/components/explain/ExplainInsiderSection.jsx
Normal file
@@ -0,0 +1,107 @@
|
||||
import React from 'react';
|
||||
import { formatDateTime, formatNumber } from '../../utils/formatters';
|
||||
|
||||
export default function ExplainInsiderSection({
|
||||
insiderTrades,
|
||||
selectedSymbol,
|
||||
isOpen,
|
||||
onToggle,
|
||||
onRequest,
|
||||
}) {
|
||||
const handleRefresh = () => {
|
||||
if (onRequest) {
|
||||
onRequest(selectedSymbol);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="section">
|
||||
<div className="section-header">
|
||||
<h2 className="section-title">内部人交易</h2>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 12, flexWrap: 'wrap' }}>
|
||||
<div style={{ fontSize: 11, color: '#666666' }}>
|
||||
{insiderTrades.length} 笔内部人交易记录
|
||||
</div>
|
||||
<button
|
||||
onClick={handleRefresh}
|
||||
style={{
|
||||
border: '1px solid #111111',
|
||||
background: '#ffffff',
|
||||
color: '#111111',
|
||||
padding: '5px 8px',
|
||||
fontFamily: 'inherit',
|
||||
fontSize: 10,
|
||||
cursor: 'pointer'
|
||||
}}
|
||||
>
|
||||
刷新
|
||||
</button>
|
||||
<button
|
||||
onClick={onToggle}
|
||||
style={{
|
||||
border: '1px solid #111111',
|
||||
background: isOpen ? '#111111' : '#ffffff',
|
||||
color: isOpen ? '#ffffff' : '#111111',
|
||||
padding: '7px 10px',
|
||||
fontFamily: 'inherit',
|
||||
fontSize: 11,
|
||||
fontWeight: 700,
|
||||
cursor: 'pointer'
|
||||
}}
|
||||
>
|
||||
{isOpen ? '收起' : `展开 ${insiderTrades.length}`}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{!isOpen ? (
|
||||
<div className="empty-state">点击展开查看内部人交易详情</div>
|
||||
) : insiderTrades.length === 0 ? (
|
||||
<div className="empty-state">暂无内部人交易数据</div>
|
||||
) : (
|
||||
<div className="table-wrapper">
|
||||
<table className="data-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>交易日期</th>
|
||||
<th>内部人</th>
|
||||
<th>职位</th>
|
||||
<th>方向</th>
|
||||
<th>股份数</th>
|
||||
<th>价格</th>
|
||||
<th>持仓变化</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{insiderTrades.slice(0, 20).map((trade, index) => {
|
||||
const isBuy = trade.is_buy;
|
||||
const holdingChange = trade.holding_change;
|
||||
return (
|
||||
<tr key={trade.transaction_date + '-' + trade.name + '-' + index}>
|
||||
<td>{trade.transaction_date || '-'}</td>
|
||||
<td>{trade.name || '-'}</td>
|
||||
<td>{trade.title || '-'}</td>
|
||||
<td style={{
|
||||
fontWeight: 700,
|
||||
color: isBuy === true ? '#00C853' : isBuy === false ? '#FF1744' : '#666666'
|
||||
}}>
|
||||
{isBuy === true ? '买入' : isBuy === false ? '卖出' : '-'}
|
||||
</td>
|
||||
<td>{trade.transaction_shares != null ? formatNumber(trade.transaction_shares) : '-'}</td>
|
||||
<td>${trade.transaction_price_per_share != null ? Number(trade.transaction_price_per_share).toFixed(2) : '-'}</td>
|
||||
<td style={{
|
||||
color: holdingChange != null ? (holdingChange > 0 ? '#00C853' : '#FF1744') : '#666666',
|
||||
fontWeight: holdingChange != null ? 700 : 400
|
||||
}}>
|
||||
{holdingChange != null ? (holdingChange > 0 ? '+' : '') + formatNumber(holdingChange) : '-'}
|
||||
</td>
|
||||
</tr>
|
||||
);
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,12 @@
|
||||
import React from 'react';
|
||||
import { formatDateTime } from '../../utils/formatters';
|
||||
|
||||
function renderFreshness(freshness) {
|
||||
if (!freshness || typeof freshness !== 'object') return null;
|
||||
const lastFetch = freshness.last_news_fetch || '-';
|
||||
return `新闻更新到 ${lastFetch}${freshness.refreshed ? ' · 本次已刷新' : ''}`;
|
||||
}
|
||||
|
||||
function categoryLabel(value) {
|
||||
const normalized = String(value || '').trim().toLowerCase();
|
||||
const labels = {
|
||||
@@ -47,6 +53,7 @@ export default function ExplainNewsSection({
|
||||
newsSnapshot,
|
||||
visibleNewsByCategory,
|
||||
visibleNews,
|
||||
selectedNewsFreshness,
|
||||
activeNewsCategory,
|
||||
onSelectNewsCategory,
|
||||
activeNewsSentiment,
|
||||
@@ -64,6 +71,11 @@ export default function ExplainNewsSection({
|
||||
<div style={{ fontSize: 11, color: '#666666' }}>
|
||||
{newsSnapshot?.source ? `最近 ${visibleNewsByCategory.length} 条 · ${newsSnapshot.source}` : `最近 ${visibleNewsByCategory.length} 条真实新闻`}
|
||||
</div>
|
||||
{renderFreshness(selectedNewsFreshness) ? (
|
||||
<div style={{ fontSize: 11, color: '#666666' }}>
|
||||
{renderFreshness(selectedNewsFreshness)}
|
||||
</div>
|
||||
) : null}
|
||||
<button
|
||||
onClick={onToggle}
|
||||
style={{
|
||||
|
||||
@@ -11,6 +11,37 @@ export default function ExplainPriceSection({
|
||||
isOpen,
|
||||
onToggle,
|
||||
}) {
|
||||
const timeTicks = (() => {
|
||||
const candles = Array.isArray(chartModel?.candles) ? chartModel.candles : [];
|
||||
if (!candles.length) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const targetCount = Math.min(4, candles.length);
|
||||
const step = Math.max(1, Math.floor((candles.length - 1) / Math.max(targetCount - 1, 1)));
|
||||
const ticks = [];
|
||||
|
||||
for (let index = 0; index < candles.length; index += step) {
|
||||
const candle = candles[index];
|
||||
const rawLabel = candle.startLabel || candle.time || candle.date || '';
|
||||
ticks.push({
|
||||
x: candle.centerX,
|
||||
label: String(rawLabel).slice(5, 16).replace('T', ' '),
|
||||
});
|
||||
}
|
||||
|
||||
const lastCandle = candles[candles.length - 1];
|
||||
const lastLabel = String(lastCandle.endLabel || lastCandle.time || lastCandle.date || '').slice(5, 16).replace('T', ' ');
|
||||
if (ticks.length === 0 || ticks[ticks.length - 1]?.x !== lastCandle.centerX) {
|
||||
ticks.push({
|
||||
x: lastCandle.centerX,
|
||||
label: lastLabel,
|
||||
});
|
||||
}
|
||||
|
||||
return ticks;
|
||||
})();
|
||||
|
||||
return (
|
||||
<div className="section">
|
||||
<div className="section-header">
|
||||
@@ -66,12 +97,35 @@ export default function ExplainPriceSection({
|
||||
strokeWidth="1"
|
||||
/>
|
||||
|
||||
{timeTicks.map((tick) => (
|
||||
<g key={`${tick.x}-${tick.label}`}>
|
||||
<line
|
||||
x1={tick.x}
|
||||
y1={chartModel.height - chartModel.padding}
|
||||
x2={tick.x}
|
||||
y2={chartModel.height - chartModel.padding + 4}
|
||||
stroke="#666666"
|
||||
strokeWidth="1"
|
||||
/>
|
||||
<text
|
||||
x={tick.x}
|
||||
y={chartModel.height - chartModel.padding + 16}
|
||||
fontSize="10"
|
||||
fill="#666666"
|
||||
textAnchor="middle"
|
||||
>
|
||||
{tick.label}
|
||||
</text>
|
||||
</g>
|
||||
))}
|
||||
|
||||
{chartModel.candles.length > 1 ? chartModel.candles.map((candle) => {
|
||||
const rising = candle.close >= candle.open;
|
||||
const stroke = rising ? '#00C853' : '#FF1744';
|
||||
const fill = rising ? 'rgba(0, 200, 83, 0.16)' : 'rgba(255, 23, 68, 0.16)';
|
||||
return (
|
||||
<g key={candle.id}>
|
||||
<title>{`${candle.startLabel || candle.time || candle.date || ''} → ${candle.endLabel || candle.time || candle.date || ''}`}</title>
|
||||
<line
|
||||
x1={candle.centerX}
|
||||
y1={candle.highY}
|
||||
@@ -123,7 +177,7 @@ export default function ExplainPriceSection({
|
||||
stroke={marker.isSelected ? '#111111' : '#ffffff'}
|
||||
strokeWidth={marker.isSelected ? '2.5' : '2'}
|
||||
/>
|
||||
<title>{`${marker.title} · ${marker.dateKey || ''}${marker.count ? ` · ${marker.count} 条新闻` : ''}`}</title>
|
||||
<title>{`${marker.title} · ${marker.timestamp || marker.dateKey || ''}${marker.count ? ` · ${marker.count} 条新闻` : ''}`}</title>
|
||||
</g>
|
||||
);
|
||||
})}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
import React from 'react';
|
||||
import { formatTickerPrice } from '../../utils/formatters';
|
||||
|
||||
function renderFreshness(freshness) {
|
||||
if (!freshness || typeof freshness !== 'object') return null;
|
||||
const lastFetch = freshness.last_news_fetch || '-';
|
||||
return `新闻更新到 ${lastFetch}${freshness.refreshed ? ' · 本次已刷新' : ''}`;
|
||||
}
|
||||
|
||||
function renderSentimentLabel(value) {
|
||||
const normalized = String(value || '').trim().toLowerCase();
|
||||
if (normalized === 'positive') return '利多';
|
||||
@@ -94,6 +100,11 @@ export default function ExplainRangeSection({
|
||||
: `分析来源 · ${renderAnalysisSourceLabel(selectedRangeExplain.analysis.analysis_source)}`}
|
||||
</div>
|
||||
) : null}
|
||||
{renderFreshness(selectedRangeExplain?.freshness) ? (
|
||||
<div style={{ fontSize: 11, color: '#666666' }}>
|
||||
{renderFreshness(selectedRangeExplain?.freshness)}
|
||||
</div>
|
||||
) : null}
|
||||
<button
|
||||
onClick={onToggle}
|
||||
style={{
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user