18 Commits

Author SHA1 Message Date
9bcc4221a4 refactor: remove mock trading functionality from backend and frontend
Removes all mock price simulation features:
- Delete MockPriceManager from backend/data/
- Remove mock_mode, enable_mock, is_mock_mode flags from services
- Remove mock CLI options and config
- Remove mock mode UI components and state from frontend
- Update tests to remove mock references

The system now supports only live and backtest modes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-26 13:38:51 +08:00
fecf8a9466 Fix agent workspace file handlers 2026-03-26 13:07:47 +08:00
86eb8c37a9 Fix empty dashboard fallback and websocket gating 2026-03-26 12:39:13 +08:00
1f9063edad Fix provider router import for news ingest 2026-03-26 11:29:14 +08:00
7e7a58769a feat: add incremental news refresh and frontend component fixes
- Add refresh_news_incremental/refresh_news_for_symbols functions for incremental news fetching
- Integrate news-refresh logic into the live cycle
- AgentFeed supports agentProfilesByAgent to display model info
- StatisticsView: fix the stats computation, using portfolioData as a fallback
- StockExplainView: fix a useEffect dependency issue
- AppShell/RoomView: pass the agentProfilesByAgent prop
- start-dev.sh: lower the log level to warning to reduce noise

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-26 10:50:45 +08:00
16bb3c4211 docs(README): update project documentation, improving the architecture and startup guide
- Revise CLAUDE.md, adding usage guidance and an architecture overview
- Add microservice startup notes and single-service launch commands
- Clarify the Gateway server's four-stage startup flow and division of responsibilities
- Expand the backend directory-structure notes with key file responsibilities
- Add a system layering diagram to aid overall architecture understanding
- Update .gitignore to ignore the runs directory
- Sync the .omc state files to update project status tracking
2026-03-24 17:19:31 +08:00
da6d642aaa Migrate agent control reads and writes to REST 2026-03-24 16:30:36 +08:00
8d6c3c5647 Add restore-mode task launch flow 2026-03-24 15:27:35 +08:00
6413edf8c9 Refine runtime data flow and UI layering 2026-03-24 15:00:35 +08:00
c5eaf2b5ad Fix runtime logging and frontend app regressions 2026-03-24 10:58:41 +08:00
032c37538f fix(frontend): add the missing lastDayHistory field to runtimeStore
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 20:15:27 +08:00
456748b01e fix(frontend): fix a blank page caused by a duplicate declaration of the now variable
- Remove now from runtimeStore (it should come from uiStore)
- Fix the duplicate-declaration error

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 18:12:32 +08:00
609b509446 feat(frontend): complete the Zustand state-management migration
- Migrate the useState calls in App.jsx to 5 Zustand stores
- useRuntimeStore: connection status, runtime config
- useMarketStore: market data, stock prices
- usePortfolioStore: portfolio, positions, trades
- useAgentStore: agent skills, workspaces
- useUIStore: UI state, view switching
- Keep the tickers useState (must stay in sync with INITIAL_TICKERS)
- Restore newsApi.js and tradingApi.js

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 18:03:23 +08:00
38102d0805 feat(backend): agent system and API enhancements
- task_delegator: improve team task-dispatch logic
- runtime API: enhance runtime management
- skills_manager: skill-management improvements
- tool_guard: optimize tool-call guarding
- evo_agent: core agent improvements

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 17:46:06 +08:00
3448667b79 feat: split into microservices and optimize the frontend and backend
Backend:
- Split out agent_service, runtime_service, trading_service, news_service
- Modularize the Gateway into submodules (gateway_*.py)
- Add a domains/ domain layer
- Add control_client and runtime_client
- Update start-dev.sh to support split-service mode

Frontend:
- Flesh out the API service layer (newsApi, tradingApi)
- Update vite.config.js
- Optimize the Explain components

Tests:
- Add tests for multiple service apps

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 17:45:39 +08:00
0f1bc2bb39 feat(frontend): add the Zustand store architecture
- Create 5 domain stores: runtimeStore, marketStore, portfolioStore, agentStore, uiStore
- Update CLAUDE.md to record the architecture improvement
- Zustand is installed but the stores are not yet used in App.jsx (incremental migration)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-23 17:44:17 +08:00
06a23c32a4 refactor: Fix code quality issues identified in analysis
1. Rename factory.py's EvoAgent data class to AgentConfig
   - Avoids naming conflict with base/evo_agent.py's EvoAgent

2. Export pipeline_runner functions in backend/core/__init__.py
   - Add create_agents, create_long_term_memory, stop_gateway

3. Consolidate PromptLoader to singleton pattern
   - Add get_prompt_loader() singleton function
   - Update all usages to use singleton

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-20 01:07:53 +08:00
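The singleton consolidation in item 3 above follows a standard pattern. This is a minimal sketch, assuming the point of `get_prompt_loader()` is to construct `PromptLoader` once per process and hand out the shared instance thereafter; the double-checked lock keeps it safe under threads.

```python
import threading

class PromptLoader:
    """Loads prompt templates; construction is assumed to be expensive."""
    def __init__(self):
        self.cache = {}  # template name -> loaded prompt text

_loader = None
_lock = threading.Lock()

def get_prompt_loader():
    """Return the process-wide PromptLoader, creating it on first use."""
    global _loader
    if _loader is None:           # fast path, no lock once initialized
        with _lock:
            if _loader is None:   # re-check under the lock
                _loader = PromptLoader()
    return _loader
```

Callers then replace direct `PromptLoader()` construction with `get_prompt_loader()`, so every usage shares one template cache.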
5b925fbe02 feat: Refactor services architecture and update project structure
- Remove Docker-based microservices (docker-compose.yml, Makefile, Dockerfiles)
- Update start-dev.sh to use backend.app:app entry point
- Add shared schema and client modules for service communication
- Add team coordination modules (messenger, registry, task_delegator, coordinator)
- Add evaluation hooks and skill adaptation hooks
- Add skill template and gateway server
- Update frontend WebSocket URL configuration
- Add explain components for insider and technical analysis

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-20 00:57:09 +08:00
144 changed files with 17040 additions and 6419 deletions

.env.example (new file, 41 lines added)

@@ -0,0 +1,41 @@
# Copy this file to `.env` for local development.
# Keep `.env` untracked and never paste real secrets into tracked files.
# ================== General Configuration | 通用配置 ==================
TICKERS=AAPL,MSFT,GOOGL,NVDA,TSLA,META,AMZN
# Financial Data API
# At least `FINANCIAL_DATASETS_API_KEY` is required when using `FIN_DATA_SOURCE=financial_datasets`.
# `FINNHUB_API_KEY` is recommended for `FIN_DATA_SOURCE=finnhub` and required for live mode.
FIN_DATA_SOURCE=finnhub
ENABLED_DATA_SOURCES=financial_datasets,finnhub,yfinance,local_csv
FINANCIAL_DATASETS_API_KEY=
FINNHUB_API_KEY=
POLYGON_API_KEY=
MARKET_DB_PATH=
# Model API
OPENAI_API_KEY=
OPENAI_BASE_URL=
MODEL_NAME=qwen3-max-preview
EXPLAIN_ENRICH_USE_LLM=false
EXPLAIN_ENRICH_MODEL_PROVIDER=
EXPLAIN_ENRICH_MODEL_NAME=
EXPLAIN_RANGE_USE_LLM=
# Memory module
MEMORY_API_KEY=
# ================== Agent-Specific Model Configuration | Agent特定模型配置 ==================
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
# ================== Advanced Configuration | 高阶配置 ==================
MAX_COMM_CYCLES=2
MARGIN_REQUIREMENT=0.5
DATA_START_DATE=2022-01-01
AUTO_UPDATE_DATA=true

.gitignore (vendored, 5 lines changed)

@@ -54,10 +54,13 @@ outputs/
/smoke_live_mock/
# Local tooling state
/.omc/
.omc/
/.pydeps/
/referance/
# Run outputs
/runs/
# Data files
backend/data/ret_data/


@@ -1,7 +1,7 @@
{
"version": "1.0.0",
"lastScanned": 1773304964541,
"projectRoot": "/Users/cillin/workspeace/agentscope-samples/evotraders",
"lastScanned": 1774313111650,
"projectRoot": "/Users/cillin/workspeace/evotraders",
"techStack": {
"languages": [
{
@@ -40,7 +40,8 @@
"isMonorepo": false,
"workspaces": [],
"mainDirectories": [
"docs"
"docs",
"scripts"
],
"gitBranches": {
"defaultBranch": "main",
@@ -52,26 +53,54 @@
"backend": {
"path": "backend",
"purpose": null,
"fileCount": 3,
"lastAccessed": 1773304964533,
"fileCount": 4,
"lastAccessed": 1774313111639,
"keyFiles": [
"__init__.py",
"cli.py",
"gateway_server.py",
"main.py"
]
},
"backtest": {
"path": "backtest",
"purpose": null,
"fileCount": 0,
"lastAccessed": 1774313111640,
"keyFiles": []
},
"data": {
"path": "data",
"purpose": "Data files",
"fileCount": 3,
"lastAccessed": 1774313111640,
"keyFiles": [
"market_research.db",
"market_research.db-shm",
"market_research.db-wal"
]
},
"deploy": {
"path": "deploy",
"purpose": null,
"fileCount": 0,
"lastAccessed": 1774313111640,
"keyFiles": []
},
"docs": {
"path": "docs",
"purpose": "Documentation",
"fileCount": 0,
"lastAccessed": 1773304964533,
"keyFiles": []
"fileCount": 1,
"lastAccessed": 1774313111641,
"keyFiles": [
"compat-removal-plan.md"
]
},
"evotraders.egg-info": {
"path": "evotraders.egg-info",
"purpose": null,
"fileCount": 6,
"lastAccessed": 1773304964534,
"lastAccessed": 1774313111641,
"keyFiles": [
"PKG-INFO",
"SOURCES.txt",
@@ -83,8 +112,8 @@
"frontend": {
"path": "frontend",
"purpose": null,
"fileCount": 12,
"lastAccessed": 1773304964535,
"fileCount": 13,
"lastAccessed": 1774313111641,
"keyFiles": [
"README.md",
"components.json",
@@ -93,239 +122,386 @@
"index.css"
]
},
"live": {
"path": "live",
"purpose": null,
"fileCount": 0,
"lastAccessed": 1774313111642,
"keyFiles": []
},
"logs": {
"path": "logs",
"purpose": null,
"fileCount": 6,
"lastAccessed": 1774313111642,
"keyFiles": [
"2026-03-16_00-48-03.log",
"2026-03-18_23-17-29.log",
"2026-03-18_23-17-30.log",
"2026-03-19_00-18-04.log",
"2026-03-19_00-34-21.log"
]
},
"reference": {
"path": "reference",
"purpose": null,
"fileCount": 0,
"lastAccessed": 1774313111643,
"keyFiles": []
},
"runs": {
"path": "runs",
"purpose": null,
"fileCount": 0,
"lastAccessed": 1774313111643,
"keyFiles": []
},
"scripts": {
"path": "scripts",
"purpose": "Build/utility scripts",
"fileCount": 1,
"lastAccessed": 1774313111644,
"keyFiles": [
"run_prod.sh"
]
},
"services": {
"path": "services",
"purpose": "Business logic services",
"fileCount": 1,
"lastAccessed": 1774313111644,
"keyFiles": [
"README.md"
]
},
"shared": {
"path": "shared",
"purpose": null,
"fileCount": 0,
"lastAccessed": 1774313111644,
"keyFiles": []
},
"workspaces": {
"path": "workspaces",
"purpose": null,
"fileCount": 0,
"lastAccessed": 1774313111645,
"keyFiles": []
},
"backend/api": {
"path": "backend/api",
"purpose": "API routes",
"fileCount": 5,
"lastAccessed": 1774313111645,
"keyFiles": [
"__init__.py",
"agents.py",
"guard.py"
]
},
"backend/config": {
"path": "backend/config",
"purpose": "Configuration files",
"fileCount": 4,
"lastAccessed": 1773304964535,
"fileCount": 6,
"lastAccessed": 1774313111646,
"keyFiles": [
"__init__.py",
"constants.py",
"data_config.py"
"agent_profiles.yaml",
"bootstrap_config.py"
]
},
"backend/data": {
"path": "backend/data",
"purpose": "Data files",
"fileCount": 7,
"lastAccessed": 1773304964536,
"fileCount": 13,
"lastAccessed": 1774313111647,
"keyFiles": [
"__init__.py",
"cache.py",
"historical_price_manager.py"
]
},
"backend/services": {
"path": "backend/services",
"purpose": "Business logic services",
"fileCount": 4,
"lastAccessed": 1773304964536,
"keyFiles": [
"__init__.py",
"gateway.py",
"market.py"
]
},
"backend/tests": {
"path": "backend/tests",
"purpose": "Test files",
"fileCount": 4,
"lastAccessed": 1773304964536,
"keyFiles": [
"__init__.py",
"test_agents.py",
"test_market_service.py"
]
},
"docs/assets": {
"path": "docs/assets",
"purpose": "Static assets",
"fileCount": 5,
"lastAccessed": 1773304964536,
"lastAccessed": 1774313111647,
"keyFiles": [
"dashboard.jpg",
"evotraders_demo.gif",
"evotraders_logo.jpg"
]
},
"frontend/public": {
"path": "frontend/public",
"purpose": "Public files",
"fileCount": 1,
"lastAccessed": 1773304964538,
"frontend/dist": {
"path": "frontend/dist",
"purpose": "Distribution/build output",
"fileCount": 2,
"lastAccessed": 1774313111647,
"keyFiles": [
"index.html",
"trading_logo.png"
]
},
"frontend/node_modules": {
"path": "frontend/node_modules",
"purpose": "Dependencies",
"fileCount": 1,
"lastAccessed": 1774313111650,
"keyFiles": []
}
},
"hotPaths": [
{
"path": "frontend/src/components/StatisticsView.jsx",
"accessCount": 22,
"lastAccessed": 1773310044545,
"type": "file"
},
{
"path": "frontend/src/components/AgentCard.jsx",
"accessCount": 17,
"lastAccessed": 1773309995177,
"type": "file"
"path": "CLAUDE.md",
"accessCount": 15,
"lastAccessed": 1774342728155,
"type": "directory"
},
{
"path": "frontend/src/App.jsx",
"accessCount": 12,
"lastAccessed": 1773309849392,
"accessCount": 10,
"lastAccessed": 1774339397617,
"type": "file"
},
{
"path": "frontend/src/components/AgentFeed.jsx",
"accessCount": 12,
"lastAccessed": 1773309960022,
"type": "file"
},
{
"path": ".env",
"accessCount": 7,
"lastAccessed": 1773308950505,
"type": "file"
},
{
"path": "frontend/src/components/RoomView.jsx",
"accessCount": 7,
"lastAccessed": 1773309864236,
"type": "file"
},
{
"path": "backend/tools/analysis_tools.py",
"accessCount": 5,
"lastAccessed": 1773312271446,
"type": "file"
},
{
"path": "frontend/src/components/Header.jsx",
"path": "frontend/src/hooks/useWebsocketSessionSync.js",
"accessCount": 4,
"lastAccessed": 1773309827069,
"lastAccessed": 1774313470024,
"type": "file"
},
{
"path": "frontend/src/components/AboutModal.jsx",
"path": "",
"accessCount": 4,
"lastAccessed": 1773310093371,
"type": "file"
},
{
"path": "backend/agents/prompts/analyst/personas.yaml",
"accessCount": 4,
"lastAccessed": 1773312049213,
"type": "file"
},
{
"path": "backend/agents/prompts/analyst/system.md",
"accessCount": 4,
"lastAccessed": 1773312049696,
"type": "file"
},
{
"path": "backend/agents/prompts/portfolio_manager/system.md",
"accessCount": 4,
"lastAccessed": 1773312050326,
"type": "file"
},
{
"path": "backend/agents/prompts/risk_manager/system.md",
"accessCount": 4,
"lastAccessed": 1773312050782,
"type": "file"
},
{
"path": "frontend/src/config/constants.js",
"accessCount": 3,
"lastAccessed": 1773309824671,
"type": "file"
},
{
"path": "frontend/src/components/RulesView.jsx",
"accessCount": 3,
"lastAccessed": 1773310061939,
"type": "file"
},
{
"path": "backend",
"accessCount": 3,
"lastAccessed": 1773312200721,
"lastAccessed": 1774339108220,
"type": "directory"
},
{
"path": "backend/services/gateway.py",
"accessCount": 3,
"lastAccessed": 1774339389171,
"type": "file"
},
{
"path": "backend/main.py",
"accessCount": 3,
"lastAccessed": 1774342613364,
"type": "file"
},
{
"path": "frontend/src/store/runtimeStore.js",
"accessCount": 2,
"lastAccessed": 1773312232905,
"type": "directory"
},
{
"path": "README.md",
"accessCount": 1,
"lastAccessed": 1773305013217,
"type": "file"
},
{
"path": "README_zh.md",
"accessCount": 1,
"lastAccessed": 1773305013274,
"type": "file"
},
{
"path": "env.template",
"accessCount": 1,
"lastAccessed": 1773305019965,
"lastAccessed": 1774317990919,
"type": "file"
},
{
"path": "frontend/src/services/websocket.js",
"accessCount": 2,
"lastAccessed": 1774318009819,
"type": "file"
},
{
"path": "backend/core/pipeline_runner.py",
"accessCount": 2,
"lastAccessed": 1774339367538,
"type": "file"
},
{
"path": "backend/runtime/manager.py",
"accessCount": 2,
"lastAccessed": 1774339367572,
"type": "file"
},
{
"path": "frontend/src/store/marketStore.js",
"accessCount": 1,
"lastAccessed": 1773309324302,
"lastAccessed": 1774313140483,
"type": "file"
},
{
"path": "frontend/src/hooks/useFeedProcessor.js",
"accessCount": 1,
"lastAccessed": 1774313148279,
"type": "file"
},
{
"path": "frontend/src/components/Header.jsx",
"accessCount": 1,
"lastAccessed": 1774313156696,
"type": "file"
},
{
"path": "frontend/src/components/TraderView.jsx",
"accessCount": 1,
"lastAccessed": 1774313156753,
"type": "file"
},
{
"path": "frontend/src/store/uiStore.js",
"accessCount": 1,
"lastAccessed": 1774313187460,
"type": "file"
},
{
"path": "frontend/src/store/portfolioStore.js",
"accessCount": 1,
"lastAccessed": 1774313187511,
"type": "file"
},
{
"path": "frontend/src/store/agentStore.js",
"accessCount": 1,
"lastAccessed": 1774313187573,
"type": "file"
},
{
"path": "frontend/src/hooks/useWebSocketConnection.js",
"accessCount": 1,
"lastAccessed": 1774313279414,
"type": "file"
},
{
"path": "frontend/src/hooks/useStockDataRequests.js",
"accessCount": 1,
"lastAccessed": 1774313319716,
"type": "file"
},
{
"path": "frontend/src/hooks/useAgentDataRequests.js",
"accessCount": 1,
"lastAccessed": 1774313347455,
"type": "file"
},
{
"path": "frontend/src/components/AppShell.jsx",
"accessCount": 1,
"lastAccessed": 1774313396331,
"type": "file"
},
{
"path": "start-dev.sh",
"accessCount": 1,
"lastAccessed": 1774317979859,
"type": "file"
},
{
"path": "backend/apps/agent_service.py",
"accessCount": 1,
"lastAccessed": 1774317984348,
"type": "file"
},
{
"path": "shared/client/trading_client.py",
"accessCount": 1,
"lastAccessed": 1774317984365,
"type": "file"
},
{
"path": "backend/apps/trading_service.py",
"accessCount": 1,
"lastAccessed": 1774317984408,
"type": "file"
},
{
"path": "pyproject.toml",
"accessCount": 1,
"lastAccessed": 1774317990970,
"type": "file"
},
{
"path": "backend/agents/factory.py",
"accessCount": 1,
"lastAccessed": 1774318009867,
"type": "file"
},
{
"path": "backend/config/constants.py",
"accessCount": 1,
"lastAccessed": 1774318009922,
"type": "file"
},
{
"path": "backend/api/__init__.py",
"accessCount": 1,
"lastAccessed": 1774318009973,
"type": "file"
},
{
"path": "README.md",
"accessCount": 1,
"lastAccessed": 1774339107381,
"type": "file"
},
{
"path": "backend/runtime/registry.py",
"accessCount": 1,
"lastAccessed": 1774339380024,
"type": "file"
},
{
"path": "backend/runtime/session.py",
"accessCount": 1,
"lastAccessed": 1774339380084,
"type": "file"
},
{
"path": "backend/runtime/context.py",
"accessCount": 1,
"lastAccessed": 1774339380120,
"type": "file"
},
{
"path": "backend/runtime/agent_runtime.py",
"accessCount": 1,
"lastAccessed": 1774339380185,
"type": "file"
},
{
"path": "backend/process/supervisor.py",
"accessCount": 1,
"lastAccessed": 1774339389110,
"type": "file"
},
{
"path": "backend/core/pipeline.py",
"accessCount": 1,
"lastAccessed": 1774339389187,
"type": "file"
},
{
"path": "backend/process/models.py",
"accessCount": 1,
"lastAccessed": 1774339397557,
"type": "file"
},
{
"path": "backend/process/registry.py",
"accessCount": 1,
"lastAccessed": 1774339397577,
"type": "file"
},
{
"path": "backend/config/env_config.py",
"accessCount": 1,
"lastAccessed": 1774342678236,
"type": "file"
},
{
"path": "backend/config/data_config.py",
"accessCount": 1,
"lastAccessed": 1773309324414,
"lastAccessed": 1774342678253,
"type": "file"
},
{
"path": "backend/cli.py",
"path": "frontend/env.template",
"accessCount": 1,
"lastAccessed": 1773309336899,
"type": "directory"
},
{
"path": "backend/agents/portfolio_manager.py",
"accessCount": 1,
"lastAccessed": 1773311956562,
"lastAccessed": 1774342678290,
"type": "file"
},
{
"path": "backend/agents/risk_manager.py",
"path": "env.template",
"accessCount": 1,
"lastAccessed": 1773311956760,
"lastAccessed": 1774342678310,
"type": "file"
},
{
"path": "backend/agents/analyst.py",
"accessCount": 1,
"lastAccessed": 1773311963222,
"type": "file"
},
{
"path": "backend/tools",
"accessCount": 1,
"lastAccessed": 1773312289643,
"type": "directory"
},
{
"path": "backend/tools/data_tools.py",
"accessCount": 1,
"lastAccessed": 1773312293851,
"type": "directory"
}
],
"userDirectives": []


@@ -1,6 +1,6 @@
{
"timestamp": "2026-03-12T20:33:59.497Z",
"timestamp": "2026-03-24T07:58:12.123Z",
"backgroundTasks": [],
"sessionStartTimestamp": "2026-03-12T14:19:33.615Z",
"sessionId": "73b0d597-0141-4873-9d0e-2b60e4e0635e"
"sessionStartTimestamp": "2026-03-24T07:58:09.417Z",
"sessionId": "fda34772-7bd2-402e-86b2-d656296416f3"
}


@@ -1 +1 @@
{"session_id":"73b0d597-0141-4873-9d0e-2b60e4e0635e","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-agentscope-samples-evotraders/73b0d597-0141-4873-9d0e-2b60e4e0635e.jsonl","cwd":"/Users/cillin/workspeace/agentscope-samples/evotraders","model":{"id":"kimi-for-coding","display_name":"kimi-for-coding"},"workspace":{"current_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","project_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","added_dirs":["/Users/cillin/workspeace/agentscope-samples/EvoTraders","/Users/cillin/workspeace/agentscope-samples/evotraders"]},"version":"2.1.63","output_style":{"name":"default"},"cost":{"total_cost_usd":6.822239999999999,"total_duration_ms":42679588,"total_api_duration_ms":1223637,"total_lines_added":275,"total_lines_removed":240},"context_window":{"total_input_tokens":654274,"total_output_tokens":27014,"context_window_size":200000,"current_usage":{"input_tokens":48465,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0},"used_percentage":24,"remaining_percentage":76},"exceeds_200k_tokens":false}
{"session_id":"fda34772-7bd2-402e-86b2-d656296416f3","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-evotraders/fda34772-7bd2-402e-86b2-d656296416f3.jsonl","cwd":"/Users/cillin/workspeace/evotraders","model":{"id":"MiniMax-M2.7-highspeed","display_name":"MiniMax-M2.7-highspeed"},"workspace":{"current_dir":"/Users/cillin/workspeace/evotraders","project_dir":"/Users/cillin/workspeace/evotraders","added_dirs":[]},"version":"2.1.78","output_style":{"name":"default"},"cost":{"total_cost_usd":36.63980749999998,"total_duration_ms":69778027,"total_api_duration_ms":2925118,"total_lines_added":3056,"total_lines_removed":4537},"context_window":{"total_input_tokens":910503,"total_output_tokens":145207,"context_window_size":200000,"current_usage":{"input_tokens":507,"output_tokens":247,"cache_creation_input_tokens":4132,"cache_read_input_tokens":96553},"used_percentage":51,"remaining_percentage":49},"exceeds_200k_tokens":false}


@@ -1,3 +1,3 @@
{
"lastSentAt": "2026-03-12T20:31:37.362Z"
"lastSentAt": "2026-03-24T08:58:57.965Z"
}


@@ -1,26 +1,26 @@
{
"agents": [
{
"agent_id": "a4090d26a45ac828d",
"agent_type": "oh-my-claudecode:executor",
"started_at": "2026-03-12T10:02:38.238Z",
"agent_id": "abeaf609b74a2b7ee",
"agent_type": "Explore",
"started_at": "2026-03-24T08:01:40.015Z",
"parent_mode": "none",
"status": "completed",
"completed_at": "2026-03-12T10:10:59.192Z",
"duration_ms": 500954
"completed_at": "2026-03-24T08:02:31.822Z",
"duration_ms": 51807
},
{
"agent_id": "af87583ef76a4df30",
"agent_type": "oh-my-claudecode:executor",
"started_at": "2026-03-12T10:40:04.409Z",
"agent_id": "afb6750eaae72bc72",
"agent_type": "Explore",
"started_at": "2026-03-24T08:56:21.471Z",
"parent_mode": "none",
"status": "completed",
"completed_at": "2026-03-12T10:41:17.387Z",
"duration_ms": 72978
"completed_at": "2026-03-24T08:57:27.856Z",
"duration_ms": 66385
}
],
"total_spawned": 2,
"total_completed": 2,
"total_failed": 0,
"last_updated": "2026-03-12T10:41:17.490Z"
"last_updated": "2026-03-24T08:59:06.380Z"
}

CLAUDE.md (new file, 378 lines added)

@@ -0,0 +1,378 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
EvoTraders is a self-evolving multi-agent trading system in which 6 AI agents (4 analysts + a portfolio manager + a risk manager) collaborate on trading decisions. The agents are built on the AgentScope framework and pair with the ReMe memory system for continuous learning.
## Common Commands
### Backend (Python)
```bash
# Install dependencies
uv pip install -e .
# Run commands
evotraders backtest --start 2025-11-01 --end 2025-12-01 # backtest mode
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory # backtest with memory
evotraders live # live trading
evotraders live --mock # mock/test mode
evotraders live -t 22:30 # scheduled daily trading
evotraders frontend # launch the visual UI
# Dev servers
./start-dev.sh # start all 4 microservices (agent, runtime, trading, news)
# Gateway WebSocket server
python backend/main.py --mode live --config-name mock --mock
# Start a single microservice
python -m uvicorn backend.apps.runtime_service:app --host 0.0.0.0 --port 8003 --reload
python -m uvicorn backend.apps.agent_service:app --host 0.0.0.0 --port 8000 --reload
python -m uvicorn backend.apps.trading_service:app --host 0.0.0.0 --port 8001 --reload
python -m uvicorn backend.apps.news_service:app --host 0.0.0.0 --port 8002 --reload
# Tests
pytest backend/tests # run all tests
pytest backend/tests/test_news_service_app.py -v # run a single test
```
### Frontend (React)
```bash
cd frontend
npm run dev # Vite dev server (http://localhost:5173)
npm run build # production build
npm run lint # ESLint check
npm run lint:fix # ESLint auto-fix
npm run test # Vitest unit tests
```
## Architecture Overview
### System Layering
```
┌─────────────────────────────────────────────────────────────┐
│ Frontend (React) │
│ connects to the Gateway via WebSocket ws://localhost:8765 │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ Gateway (backend/services/gateway.py) │
│ WebSocket server; orchestrates the Pipeline (4-stage start) │
└─────────────────────────────────────────────────────────────┘
      │              │              │              │
      ▼              ▼              ▼              ▼
┌────────────┐ ┌────────────┐ ┌────────────┐ ┌────────────┐
│ Market     │ │ Storage    │ │ Pipeline   │ │ Scheduler  │
│ Service    │ │ Service    │ │            │ │            │
└────────────┘ └────────────┘ └────────────┘ └────────────┘
      ┌──────────────────────┼──────────────────────┐
      ▼                      ▼                      ▼
┌──────────┐           ┌──────────┐           ┌──────────┐
│ Analysts │           │ PM       │           │ Risk     │
│ (4)      │           │          │           │ Manager  │
└──────────┘           └──────────┘           └──────────┘
```
### Microservice Architecture (`backend/apps/`)
| Service | Port | Responsibility |
|------|------|------|
| runtime_service | 8003 | runtime config, task launch, Pipeline Runner |
| agent_service | 8000 | agent lifecycle, workspace management |
| trading_service | 8001 | market data, trading operations |
| news_service | 8002 | news, news enrichment, explain features |
### Gateway 4-Stage Startup (`backend/services/gateway.py`)
1. **WebSocket Server** - the frontend can connect immediately
2. **Market Service** - price data starts streaming
3. **Market Status Monitor** - market-status monitoring
4. **Scheduler** - trading cycles begin
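The staged startup above can be sketched as a strictly ordered sequence of awaited bring-up steps, so the lightweight WebSocket surface is available before the heavier components. The stage bodies here are placeholders; only the ordering discipline reflects the documented design.

```python
import asyncio

async def start_websocket_server(started): started.append("ws")
async def start_market_service(started): started.append("market")
async def start_market_status_monitor(started): started.append("status")
async def start_scheduler(started): started.append("scheduler")

async def start_gateway():
    """Bring services up strictly in order: the frontend can connect as soon
    as stage 1 finishes, while the scheduler (trading cycles) comes last."""
    started = []
    for stage in (start_websocket_server, start_market_service,
                  start_market_status_monitor, start_scheduler):
        await stage(started)  # each stage must complete before the next begins
    return started
```

Keeping the scheduler last means a slow market-data warm-up never delays the UI connection, only the first trading cycle.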
### Runtime Management Layer (`backend/runtime/`)
| File | Responsibility |
|------|------|
| `manager.py` | TradingRuntimeManager - global runtime manager (agent registration, sessions, event snapshots) |
| `agent_runtime.py` | AgentRuntimeState - per-agent state (status, last_session) |
| `context.py` | TradingRunContext - run context |
| `session.py` | TradingSessionKey - trading-day session key |
| `registry.py` | RuntimeRegistry - agent state registry |
Snapshots are persisted to `runs/<run_id>/state/runtime_state.json`.
### Pipeline Execution (`backend/core/`)
| File | Responsibility |
|------|------|
| `pipeline.py` | TradingPipeline - core orchestrator (analyze → communicate → decide → execute → evaluate) |
| `pipeline_runner.py` | standalone execution triggered via the REST API (5-stage startup) |
| `scheduler.py` | BacktestScheduler, Scheduler - backtest/live scheduling |
| `state_sync.py` | StateSync - state synchronization and broadcast |
## Backend Structure
```
backend/
├── agents/ # multi-agent implementation
│ ├── analyst.py # AnalystAgent base class
│ ├── portfolio_manager.py # PMAgent, the portfolio manager
│ ├── risk_manager.py # RiskAgent, the risk manager
│ ├── factory.py # agent instance factory
│ ├── toolkit_factory.py # toolkit factory
│ ├── skills_manager.py # skill loading and management
│ ├── workspace_manager.py # workspace management
│ ├── skill_loader.py # skill loader
│ ├── agent_workspace.py # agent workspace
│ ├── prompt_loader.py # prompt loader
│ ├── prompt_factory.py # prompt factory
│ ├── skill_metadata.py # skill metadata
│ ├── registry.py # agent registry
│ ├── team_pipeline_config.py # team pipeline config
│ ├── compat.py # compatibility layer
│ ├── templates.py # templates
│ ├── workspace.py # workspace
│ ├── base/ # core classes, hooks
│ │ ├── evo_agent.py # core implementation built on AgentScope
│ │ └── hooks.py # lifecycle hooks
│ └── prompts/ # agent prompts
│     └── analyst/personas.yaml
├── apps/ # microservice entry points
│ ├── runtime_service.py # runtime service (port 8003)
│ ├── agent_service.py # agent service (port 8000)
│ ├── trading_service.py # trading service (port 8001)
│ ├── news_service.py # news service (port 8002)
│ └── cors.py
├── runtime/ # runtime management layer
│ ├── manager.py # TradingRuntimeManager
│ ├── agent_runtime.py # AgentRuntimeState
│ ├── context.py # TradingRunContext
│ ├── session.py # TradingSessionKey
│ └── registry.py # RuntimeRegistry
├── process/ # process supervision layer
│ ├── supervisor.py # ProcessSupervisor
│ ├── registry.py # RunRegistry
│ └── models.py # ProcessRun, ProcessRunState
├── core/ # pipeline execution
│ ├── pipeline.py # TradingPipeline, the core orchestrator
│ ├── pipeline_runner.py # standalone pipeline execution
│ ├── scheduler.py # schedulers
│ └── state_sync.py # state synchronization
├── services/ # gateway and services
│ ├── gateway.py # WebSocket gateway
│ ├── gateway_*.py # gateway submodules
│ ├── market.py # market data service
│ ├── storage.py # storage service
│ ├── runtime_db.py # runtime database
│ └── research_db.py # research database
├── data/ # market data handling
│ ├── provider_router.py # data source routing
│ ├── provider_utils.py # data source utilities
│ ├── market_store.py # market data store
│ ├── market_ingest.py # data ingestion
│ ├── cache.py # caching
│ ├── schema.py # data schema
│ ├── historical_price_manager.py # historical price management
│ ├── polling_price_manager.py # polling price management
│ ├── mock_price_manager.py # mock price management
│ ├── news_alignment.py # news alignment
│ ├── polygon_client.py # Polygon.io client
│ └── ret_data_updater.py # offline data updates
├── config/ # configuration
│ ├── constants.py # agent config, display names
│ ├── bootstrap_config.py # bootstrap config parsing
│ ├── env_config.py # environment-variable config
│ ├── data_config.py # data source config
│ └── agent_profiles.yaml # agent profile config
├── domains/ # domain business logic
│ ├── news.py
│ └── trading.py
├── llm/ # LLM integration
│ └── models.py # RetryChatModel, TokenRecordingModelWrapper
├── skills/ # skill definitions
├── tools/ # trading and analysis tools
├── enrich/ # LLM response enrichment
├── explain/ # trade decision explanations
├── utils/ # utilities
│ ├── settlement.py # settlement coordinator
│ ├── trade_executor.py # trade executor
│ ├── terminal_dashboard.py # terminal dashboard
│ ├── analyst_tracker.py # analyst tracking
│ ├── baselines.py # baselines
│ ├── msg_adapter.py # message adapter
│ └── progress.py # progress tracking
├── api/ # FastAPI endpoints
│ └── runtime.py
└── main.py # main entry point
```
## Frontend Structure
```
frontend/src/
├── App.jsx # main app (LiveTradingApp)
├── AppShell.jsx # app shell (layout, sidebar)
├── components/
│ ├── RuntimeView.jsx # trading runtime UI
│ ├── TraderView.jsx # trader view
│ ├── RoomView.jsx # chat room view
│ ├── StockExplainView.jsx # stock explanation view
│ ├── RuntimeSettingsPanel.jsx # runtime settings panel
│ ├── RuntimeLogsModal.jsx # runtime logs modal
│ ├── WatchlistPanel.jsx # watchlist
│ ├── PerformanceView.jsx # performance view
│ ├── StatisticsView.jsx # statistics view
│ ├── NetValueChart.jsx # net-value chart
│ ├── AgentCard.jsx # agent card
│ ├── AgentFeed.jsx # agent feed
│ ├── Header.jsx # header
│ ├── MarkdownModal.jsx # Markdown modal
│ ├── StockLogo.jsx # stock logo
│ └── explain/ # explanation components
│     ├── ExplainNewsSection.jsx
│     ├── ExplainRangeSection.jsx
│     ├── ExplainSimilarDaysSection.jsx
│     ├── ExplainStorySection.jsx
│     └── useExplainModel.js
├── hooks/ # React hooks
│ ├── useWebSocketConnection.js # WebSocket connection management
│ ├── useRuntimeControls.js # runtime config management
│ ├── useAgentDataRequests.js # agent data requests
│ ├── useStockDataRequests.js # stock data requests
│ ├── useStockExplainData.js # stock explanation data
│ ├── useAgentWorkspacePanel.js # agent workspace panel
│ ├── useWebsocketSessionSync.js # WebSocket session sync
│ └── useFeedProcessor.js # feed event processing
├── store/ # Zustand state management
│ ├── runtimeStore.js # connection status, runtime config
│ ├── marketStore.js # market data, stock prices
│ ├── portfolioStore.js # portfolio, positions, trades
│ ├── agentStore.js # agent skills, workspaces
│ └── uiStore.js # UI state, view switching
├── services/
│ ├── websocket.js # WebSocket client
│ ├── runtimeApi.js # runtime API
│ ├── runtimeControls.js # runtime controls
│ ├── newsApi.js # news API
│ └── tradingApi.js # trading API
├── utils/
│ ├── formatters.js # formatting helpers
│ └── modelIcons.js # model icons
└── config/
    └── constants.js # agent definitions, config
```
## Agent System
### 6 Agent Roles
| Role ID | Name | Responsibility |
|---------|------|------|
| `fundamentals_analyst` | Fundamentals Analyst | financial health, profitability, growth quality |
| `technical_analyst` | Technical Analyst | price trends, technical indicators, momentum analysis |
| `sentiment_analyst` | Sentiment Analyst | market sentiment, news sentiment, insider trading |
| `valuation_analyst` | Valuation Analyst | DCF, EV/EBITDA, intrinsic value |
| `portfolio_manager` | Portfolio Manager | decision execution, trade coordination |
| `risk_manager` | Risk Manager | realtime price/volatility monitoring, position limits |
### Adding a Custom Analyst
1. Register it in `backend/agents/prompts/analyst/personas.yaml`
2. Add it to the `ANALYST_TYPES` dict in `backend/config/constants.py`
3. Optionally update `frontend/src/config/constants.js`
### LLM Model Wrappers (`backend/llm/models.py`)
- **RetryChatModel**: automatically retries transient LLM errors with exponential backoff
- **TokenRecordingModelWrapper**: tracks token usage and cost
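The retry-with-exponential-backoff behavior attributed to RetryChatModel can be sketched generically. This is an illustrative helper, not the project's actual class: the set of "transient" exception types and the delay constants are assumptions.

```python
import time

def retry_with_backoff(call, max_attempts=3, base_delay=0.01,
                       transient=(TimeoutError, ConnectionError)):
    """Retry `call` on transient errors, doubling the sleep each attempt."""
    for attempt in range(max_attempts):
        try:
            return call()
        except transient:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            time.sleep(base_delay * (2 ** attempt))

# A fake LLM call that fails twice with a transient error, then succeeds.
attempts = []
def flaky():
    attempts.append(1)
    if len(attempts) < 3:
        raise TimeoutError("transient upstream error")
    return "ok"
```

Non-transient exceptions (e.g. an auth failure) are deliberately not caught, so genuine misconfiguration fails fast instead of burning retries.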
## Skill System (`backend/skills/`)
Skills are defined in `SKILL.md` and include `instructions`, `triggers`, `parameters`, and `available_tools`.
The skills manager supports 6 scopes: builtin, customized, installed, active, disabled, local.
## Runtime Data Layout
- `data/market_research.db` - persistent research data
- `runs/<run_id>/` - per-task run state
- `runs/<run_id>/team_dashboard/*.json` - dashboard export layer (not the source of truth)
- `runs/<run_id>/state/runtime_state.json` - runtime snapshot
- Runtime APIs prefer `server_state.json` and `runtime.db`
```bash
RUNS_RETENTION_COUNT=20 # automatic cleanup of timestamp-named run folders
```
## Environment Configuration
### Backend (`env.template`)
```bash
# Financial data sources (multi-source with fallback)
FIN_DATA_SOURCE=finnhub|financial_datasets|yfinance|local_csv
ENABLED_DATA_SOURCES=financial_datasets,finnhub,yfinance,local_csv
FINANCIAL_DATASETS_API_KEY= # required for backtests
FINNHUB_API_KEY= # required for live mode
POLYGON_API_KEY= # optional, for Polygon market-DB ingestion
# LLM configuration
OPENAI_API_KEY=
OPENAI_BASE_URL=
MODEL_NAME=qwen3-max-preview
# Agent-specific models
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
# ReMe memory system
MEMORY_API_KEY=
MEMORY_MODEL_NAME=qwen3-max
MEMORY_EMBEDDING_MODEL=text-embedding-v4
# Trading parameters
MAX_COMM_CYCLES=2
MARGIN_REQUIREMENT=0.5
DATA_START_DATE=2022-01-01
AUTO_UPDATE_DATA=true
```
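The multi-source fallback implied by `ENABLED_DATA_SOURCES` amounts to trying providers in order until one succeeds. This sketch is hypothetical; the real `provider_router.py` interface is not shown here, and the stub providers only simulate a rate-limited primary falling back to a secondary source.

```python
class ProviderError(Exception):
    pass

def finnhub_price(symbol):
    raise ProviderError("rate limited")  # simulate a failing primary source

def yfinance_price(symbol):
    return 187.5  # simulated fallback quote

PROVIDERS = {"finnhub": finnhub_price, "yfinance": yfinance_price}

def get_price(symbol, enabled=("finnhub", "yfinance")):
    """Try each enabled source in order; return (source, quote) from the first
    provider that succeeds, or re-raise the last error if all fail."""
    last_err = None
    for name in enabled:
        try:
            return name, PROVIDERS[name](symbol)
        except ProviderError as err:
            last_err = err
    raise last_err
```

Returning the source name alongside the quote makes it easy to log which provider actually served each request, which matters when diagnosing silent fallbacks.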
### Frontend (`frontend/env.template`)
```bash
VITE_WS_URL=ws://localhost:8765
```
## Key Dependencies
- **AgentScope** - multi-agent framework
- **ReMe** - continuous-learning memory system
- **FastAPI** + **uvicorn** - backend API
- **websockets** - realtime communication
- **React 19** + **Vite** + **TailwindCSS** - frontend
- **Zustand** - state management


@@ -110,6 +110,21 @@ evotraders frontend # Default connects to port 8765, you can modi
Visit `http://localhost:5173/` to view the trading room, select a date and click Run/Replay to observe the decision-making process.
### Runtime Data Layout
- Long-lived research data is stored in `data/market_research.db`
- Each task run writes run-scoped state under `runs/<run_id>/`
- `runs/<run_id>/team_dashboard/*.json` is an export/compatibility layer for dashboard views, not the authoritative runtime source of truth
- Runtime APIs prefer active runtime state, `server_state.json`, and `runtime.db`
Optional retention control:
```bash
RUNS_RETENTION_COUNT=20
```
Only timestamped run folders like `YYYYMMDD_HHMMSS` are pruned automatically when starting a new runtime. Named runs such as `smoke_fullstack` or `test_*` are preserved.
---
## System Architecture

View File

@@ -117,6 +117,54 @@ evotraders frontend # Default connects to port 8765, you can modi
Visit `http://localhost:5173/` to view the trading room; select a date and click Run/Replay to observe the decision-making process.
### Service Boundaries During Migration
The repository is mid-migration from a modular monolith to standalone services, and the default development path has already switched to independent app surfaces:
- `backend.apps.agent_service`
- `backend.apps.runtime_service`
- `backend.apps.trading_service`
- `backend.apps.news_service`
For local development, running the split services directly is now the recommended default:
```bash
./start-dev.sh split
# or start each service manually
python -m uvicorn backend.apps.agent_service:app --port 8000 --reload
python -m uvicorn backend.apps.runtime_service:app --port 8003 --reload
python -m uvicorn backend.apps.trading_service:app --port 8001 --reload
python -m uvicorn backend.apps.news_service:app --port 8002 --reload
```
Key environment variables during migration:
```bash
# Backend Gateway prefers reads via the standalone services
NEWS_SERVICE_URL=http://localhost:8002
TRADING_SERVICE_URL=http://localhost:8001
# Browser connects directly to the control plane / runtime plane
VITE_CONTROL_API_BASE_URL=http://localhost:8000/api
VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime
# Browser prefers direct connections to the standalone services
VITE_NEWS_SERVICE_URL=http://localhost:8002
VITE_TRADING_SERVICE_URL=http://localhost:8001
```
The frontend can already hit the read-only explain paths on `news-service` directly:
- story
- similar days
- range explain
- news for date
- news categories

The runtime panel and gateway-port queries can likewise be pointed independently at `runtime-service`. If these variables are not set, the system keeps using the retained local fallback logic.
---
## System Architecture

View File

@@ -0,0 +1,452 @@
# -*- coding: utf-8 -*-
"""Evaluation hooks system for skills.
Provides evaluation metric collection and storage for skill performance tracking.
Based on the evaluation hooks design in SKILL_TEMPLATE.md.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger(__name__)
class MetricType(Enum):
"""Types of evaluation metrics."""
HIT_RATE = "hit_rate" # signal hit rate
RISK_VIOLATION = "risk_violation" # risk-control violation rate
POSITION_DEVIATION = "position_deviation" # position deviation rate
PnL_ATTRIBUTION = "pnl_attribution" # P&L attribution consistency
SIGNAL_CONSISTENCY = "signal_consistency" # signal consistency
DECISION_LATENCY = "decision_latency" # decision latency
TOOL_USAGE = "tool_usage" # tool usage rate
CUSTOM = "custom" # custom metric
@dataclass
class EvaluationMetric:
"""A single evaluation metric."""
name: str
metric_type: MetricType
value: float
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
metadata: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"name": self.name,
"metric_type": self.metric_type.value,
"value": self.value,
"timestamp": self.timestamp,
"metadata": self.metadata,
}
@dataclass
class EvaluationResult:
"""Evaluation result for a skill execution."""
skill_name: str
run_id: str
agent_id: str
metrics: List[EvaluationMetric] = field(default_factory=list)
inputs: Dict[str, Any] = field(default_factory=dict)
outputs: Dict[str, Any] = field(default_factory=dict)
decision: Optional[str] = None
success: bool = True
error_message: Optional[str] = None
started_at: Optional[str] = None
completed_at: Optional[str] = field(default_factory=lambda: datetime.now().isoformat())
def to_dict(self) -> Dict[str, Any]:
return {
"skill_name": self.skill_name,
"run_id": self.run_id,
"agent_id": self.agent_id,
"metrics": [m.to_dict() for m in self.metrics],
"inputs": self.inputs,
"outputs": self.outputs,
"decision": self.decision,
"success": self.success,
"error_message": self.error_message,
"started_at": self.started_at,
"completed_at": self.completed_at,
}
class EvaluationHook:
"""Hook for collecting skill evaluation metrics.
This hook collects and stores evaluation metrics after skill execution
for later analysis and memory/reflection stages.
"""
def __init__(
self,
storage_dir: Path,
run_id: str,
agent_id: str,
):
"""Initialize evaluation hook.
Args:
storage_dir: Directory to store evaluation results
run_id: Current run identifier
agent_id: Current agent identifier
"""
self.storage_dir = Path(storage_dir)
self.run_id = run_id
self.agent_id = agent_id
self._current_evaluation: Optional[EvaluationResult] = None
def start_evaluation(
self,
skill_name: str,
inputs: Dict[str, Any],
) -> None:
"""Start a new evaluation session.
Args:
skill_name: Name of the skill being evaluated
inputs: Input parameters for the skill
"""
self._current_evaluation = EvaluationResult(
skill_name=skill_name,
run_id=self.run_id,
agent_id=self.agent_id,
inputs=inputs,
started_at=datetime.now().isoformat(),
)
logger.debug(f"Started evaluation for skill: {skill_name}")
def add_metric(
self,
name: str,
metric_type: MetricType,
value: float,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Add an evaluation metric.
Args:
name: Metric name
metric_type: Type of metric
value: Metric value
metadata: Additional metadata
"""
if self._current_evaluation is None:
logger.warning("No active evaluation session, ignoring metric")
return
metric = EvaluationMetric(
name=name,
metric_type=metric_type,
value=value,
metadata=metadata or {},
)
self._current_evaluation.metrics.append(metric)
logger.debug(f"Added metric: {name} = {value}")
def add_metrics(self, metrics: List[EvaluationMetric]) -> None:
"""Add multiple evaluation metrics at once.
Args:
metrics: List of metrics to add
"""
if self._current_evaluation is None:
logger.warning("No active evaluation session, ignoring metrics")
return
self._current_evaluation.metrics.extend(metrics)
def record_outputs(self, outputs: Dict[str, Any]) -> None:
"""Record skill outputs.
Args:
outputs: Output from skill execution
"""
if self._current_evaluation is None:
logger.warning("No active evaluation session, ignoring outputs")
return
self._current_evaluation.outputs = outputs
def record_decision(self, decision: str) -> None:
"""Record the final decision.
Args:
decision: Final decision made by the skill
"""
if self._current_evaluation is None:
logger.warning("No active evaluation session, ignoring decision")
return
self._current_evaluation.decision = decision
def complete_evaluation(
self,
success: bool = True,
error_message: Optional[str] = None,
) -> Optional[EvaluationResult]:
"""Complete the evaluation session and persist results.
Args:
success: Whether the skill execution was successful
error_message: Error message if failed
Returns:
The completed evaluation result, or None if no active evaluation
"""
if self._current_evaluation is None:
logger.warning("No active evaluation to complete")
return None
self._current_evaluation.success = success
self._current_evaluation.error_message = error_message
self._current_evaluation.completed_at = datetime.now().isoformat()
# Persist to storage
result = self._persist_evaluation(self._current_evaluation)
self._current_evaluation = None
logger.debug(f"Completed evaluation for skill: {result.skill_name}")
return result
def _persist_evaluation(self, evaluation: EvaluationResult) -> EvaluationResult:
"""Persist evaluation result to storage.
Args:
evaluation: Evaluation result to persist
Returns:
The persisted evaluation
"""
# Create run-specific directory
run_dir = self.storage_dir / self.run_id
run_dir.mkdir(parents=True, exist_ok=True)
# Create agent-specific subdirectory
agent_dir = run_dir / self.agent_id
agent_dir.mkdir(parents=True, exist_ok=True)
# Generate filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{evaluation.skill_name}_{timestamp}.json"
filepath = agent_dir / filename
# Write evaluation result
try:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(evaluation.to_dict(), f, ensure_ascii=False, indent=2)
logger.info(f"Persisted evaluation to: {filepath}")
except Exception as e:
logger.error(f"Failed to persist evaluation: {e}")
return evaluation
def cancel_evaluation(self) -> None:
"""Cancel the current evaluation session without saving."""
if self._current_evaluation is not None:
logger.debug(f"Cancelled evaluation for: {self._current_evaluation.skill_name}")
self._current_evaluation = None
class EvaluationCollector:
"""Collector for aggregating evaluation metrics across runs.
Provides methods to query and analyze evaluation results.
"""
def __init__(self, storage_dir: Path):
"""Initialize evaluation collector.
Args:
storage_dir: Root directory containing evaluation results
"""
self.storage_dir = Path(storage_dir)
def get_run_evaluations(
self,
run_id: str,
agent_id: Optional[str] = None,
) -> List[EvaluationResult]:
"""Get all evaluations for a run.
Args:
run_id: Run identifier
agent_id: Optional agent identifier to filter by
Returns:
List of evaluation results
"""
run_dir = self.storage_dir / run_id
if not run_dir.exists():
return []
evaluations = []
agent_dirs = [run_dir / agent_id] if agent_id else run_dir.iterdir()
for agent_dir in agent_dirs:
if not agent_dir.is_dir():
continue
for eval_file in agent_dir.glob("*.json"):
try:
with open(eval_file, "r", encoding="utf-8") as f:
data = json.load(f)
evaluations.append(self._parse_evaluation(data))
except Exception as e:
logger.warning(f"Failed to load evaluation {eval_file}: {e}")
return evaluations
def get_skill_metrics(
self,
skill_name: str,
run_ids: Optional[List[str]] = None,
) -> List[EvaluationMetric]:
"""Get all metrics for a specific skill.
Args:
skill_name: Name of the skill
run_ids: Optional list of run IDs to filter by
Returns:
List of metrics for the skill
"""
metrics = []
if run_ids is None:
run_ids = [d.name for d in self.storage_dir.iterdir() if d.is_dir()]
for run_id in run_ids:
evaluations = self.get_run_evaluations(run_id)
for eval_result in evaluations:
if eval_result.skill_name == skill_name:
metrics.extend(eval_result.metrics)
return metrics
def calculate_skill_stats(
self,
skill_name: str,
metric_type: MetricType,
run_ids: Optional[List[str]] = None,
) -> Dict[str, float]:
"""Calculate statistics for a specific metric type.
Args:
skill_name: Name of the skill
metric_type: Type of metric to calculate
run_ids: Optional list of run IDs to filter by
Returns:
Dictionary with min, max, avg, count statistics
"""
metrics = self.get_skill_metrics(skill_name, run_ids)
filtered = [m for m in metrics if m.metric_type == metric_type]
if not filtered:
return {"count": 0}
values = [m.value for m in filtered]
return {
"count": len(values),
"min": min(values),
"max": max(values),
"avg": sum(values) / len(values),
}
def _parse_evaluation(self, data: Dict[str, Any]) -> EvaluationResult:
"""Parse evaluation data into EvaluationResult.
Args:
data: Raw evaluation data
Returns:
Parsed EvaluationResult
"""
metrics = []
for m in data.get("metrics", []):
metrics.append(EvaluationMetric(
name=m["name"],
metric_type=MetricType(m["metric_type"]),
value=m["value"],
timestamp=m.get("timestamp", ""),
metadata=m.get("metadata", {}),
))
return EvaluationResult(
skill_name=data["skill_name"],
run_id=data["run_id"],
agent_id=data["agent_id"],
metrics=metrics,
inputs=data.get("inputs", {}),
outputs=data.get("outputs", {}),
decision=data.get("decision"),
success=data.get("success", True),
error_message=data.get("error_message"),
started_at=data.get("started_at"),
completed_at=data.get("completed_at"),
)
def parse_evaluation_hooks(skill_dir: Path) -> Dict[str, Any]:
"""Parse evaluation hooks from SKILL.md.
Extracts the Optional: Evaluation hooks section from skill documentation.
Args:
skill_dir: Skill directory path
Returns:
Dictionary containing evaluation hook definitions
"""
skill_md = skill_dir / "SKILL.md"
if not skill_md.exists():
return {}
try:
content = skill_md.read_text(encoding="utf-8")
# Extract evaluation hooks section
if "## Optional: Evaluation hooks" in content:
start = content.find("## Optional: Evaluation hooks")
# Find the next ## section or end of file
next_section = content.find("\n## ", start + 1)
if next_section == -1:
eval_section = content[start:]
else:
eval_section = content[start:next_section]
# Parse metrics from the section
metrics = []
for metric_type in MetricType:
if metric_type.value.replace("_", " ") in eval_section.lower():
metrics.append(metric_type.value)
return {
"supported_metrics": metrics,
"section_content": eval_section.strip(),
}
except Exception as e:
logger.warning(f"Failed to parse evaluation hooks: {e}")
return {}
__all__ = [
"MetricType",
"EvaluationMetric",
"EvaluationResult",
"EvaluationHook",
"EvaluationCollector",
"parse_evaluation_hooks",
]
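Persistence flattens enum members to their string values and `_parse_evaluation` rebuilds them with `MetricType(...)`. A condensed round-trip of a single metric (a trimmed copy of the classes above, not the full module):

```python
import json
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict


class MetricType(Enum):
    HIT_RATE = "hit_rate"


@dataclass
class EvaluationMetric:
    name: str
    metric_type: MetricType
    value: float
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    metadata: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        # Enum is flattened to its string value so the payload is plain JSON.
        return {"name": self.name, "metric_type": self.metric_type.value,
                "value": self.value, "timestamp": self.timestamp,
                "metadata": self.metadata}


metric = EvaluationMetric("signal_hit_rate", MetricType.HIT_RATE, 0.62)
payload = json.dumps(metric.to_dict(), ensure_ascii=False)

# Reading back mirrors EvaluationCollector._parse_evaluation:
data = json.loads(payload)
restored = EvaluationMetric(
    name=data["name"],
    metric_type=MetricType(data["metric_type"]),  # string -> enum round-trip
    value=data["value"],
    timestamp=data.get("timestamp", ""),
    metadata=data.get("metadata", {}),
)
```

An unknown `metric_type` string would raise `ValueError` at `MetricType(...)`, which is why the loader wraps each file read in a try/except.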

View File

@@ -470,7 +470,7 @@ class EvoAgent(ToolGuardMixin, ReActAgent):
"""
return self._messenger
def delegate_task(
async def delegate_task(
self,
task_type: str,
task_data: Dict[str, Any],
@@ -493,7 +493,7 @@ class EvoAgent(ToolGuardMixin, ReActAgent):
}
try:
return self._task_delegator.delegate_task(
return await self._task_delegator.delegate_task(
task_type=task_type,
task_data=task_data,
target_agent=target_agent,
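This hunk turns `delegate_task` into a coroutine, so every call site must move from `result = agent.delegate_task(...)` to `result = await agent.delegate_task(...)`. A minimal self-contained sketch of the pattern (class and field names are illustrative, not the repo's API):

```python
import asyncio
from typing import Any, Dict


class ResearchAgent:
    """Illustrative stand-in for an EvoAgent with an async task delegator."""

    async def delegate_task(self, task_type: str, task_data: Dict[str, Any],
                            target_agent: str) -> Dict[str, Any]:
        # Simulate handing the task off on the event loop.
        await asyncio.sleep(0)
        return {"status": "completed", "agent": target_agent, "task_type": task_type}


async def main() -> Dict[str, Any]:
    agent = ResearchAgent()
    # Forgetting `await` here would yield a coroutine object, not a result dict.
    return await agent.delegate_task("research", {"symbol": "AAPL"}, "fundamentals_analyst")


result = asyncio.run(main())
```

A sync caller that previously used the return value directly now needs to run inside an event loop (or go through `asyncio.run`).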

View File

@@ -0,0 +1,489 @@
# -*- coding: utf-8 -*-
"""Skill adaptation hook for automatic evaluation-to-iteration闭环.
Monitors evaluation metrics against configurable thresholds and triggers
automatic skill reload or logs warnings when thresholds are breached.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from .evaluation_hook import (
EvaluationCollector,
EvaluationResult,
MetricType,
)
logger = logging.getLogger(__name__)
class AdaptationAction(Enum):
"""Actions to take when threshold is breached."""
RELOAD = "reload" # automatically reload the skill
WARN = "warn" # log a warning for human review
BOTH = "both" # reload and warn
NONE = "none" # take no action
@dataclass
class AdaptationThreshold:
"""Threshold configuration for a metric."""
metric_type: MetricType
operator: str = "lt" # lt (less than), gt (greater than), lte, gte, eq
value: float = 0.0
window_size: int = 10 # sliding-window size used for the moving average
min_samples: int = 5 # minimum samples before threshold checks can trigger
action: AdaptationAction = AdaptationAction.WARN
cooldown_seconds: int = 300 # cooldown period after a trigger
def evaluate(self, current_value: float) -> bool:
"""Evaluate if threshold is breached."""
ops = {
"lt": lambda x, y: x < y,
"lte": lambda x, y: x <= y,
"gt": lambda x, y: x > y,
"gte": lambda x, y: x >= y,
"eq": lambda x, y: x == y,
}
op_func = ops.get(self.operator)
if op_func is None:
logger.warning(f"Unknown operator: {self.operator}")
return False
return op_func(current_value, self.value)
def to_dict(self) -> Dict[str, Any]:
return {
"metric_type": self.metric_type.value,
"operator": self.operator,
"value": self.value,
"window_size": self.window_size,
"min_samples": self.min_samples,
"action": self.action.value,
"cooldown_seconds": self.cooldown_seconds,
}
@dataclass
class AdaptationEvent:
"""Record of an adaptation trigger event."""
timestamp: str
skill_name: str
metric_type: MetricType
threshold: AdaptationThreshold
current_value: float
avg_value: float
action_taken: AdaptationAction
details: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
return {
"timestamp": self.timestamp,
"skill_name": self.skill_name,
"metric_type": self.metric_type.value,
"threshold": self.threshold.to_dict(),
"current_value": self.current_value,
"avg_value": self.avg_value,
"action_taken": self.action_taken.value,
"details": self.details,
}
class SkillAdaptationHook:
"""Hook for monitoring evaluation metrics and triggering skill adaptation.
This hook wraps EvaluationHook to add threshold-based adaptation logic.
When metrics breach configured thresholds, it can:
- Automatically reload skills via SkillsManager
- Log warnings for human review
- Both
"""
# Default thresholds for common metrics
DEFAULT_THRESHOLDS: List[AdaptationThreshold] = [
AdaptationThreshold(
metric_type=MetricType.HIT_RATE,
operator="lt",
value=0.5,
action=AdaptationAction.WARN,
cooldown_seconds=600,
),
AdaptationThreshold(
metric_type=MetricType.RISK_VIOLATION,
operator="gt",
value=0.1,
action=AdaptationAction.WARN,
cooldown_seconds=300,
),
AdaptationThreshold(
metric_type=MetricType.DECISION_LATENCY,
operator="gt",
value=5000, # 5 seconds
action=AdaptationAction.WARN,
cooldown_seconds=300,
),
]
def __init__(
self,
storage_dir: Path,
run_id: str,
agent_id: str,
thresholds: Optional[List[AdaptationThreshold]] = None,
collector: Optional[EvaluationCollector] = None,
):
"""Initialize skill adaptation hook.
Args:
storage_dir: Directory to store adaptation events
run_id: Current run identifier
agent_id: Current agent identifier
thresholds: Custom threshold configurations (uses defaults if None)
collector: Optional EvaluationCollector for historical data
"""
self.storage_dir = Path(storage_dir)
self.run_id = run_id
self.agent_id = agent_id
self.thresholds = thresholds or self.DEFAULT_THRESHOLDS
self.collector = collector or EvaluationCollector(storage_dir)
# Track cooldowns to prevent rapid re-triggering
self._cooldowns: Dict[str, datetime] = {}
# Store recent metrics in memory for quick access
self._recent_metrics: Dict[str, List[float]] = {}
# Pending adaptation events
self._pending_events: List[AdaptationEvent] = []
def check_threshold(
self,
skill_name: str,
metric_type: MetricType,
current_value: float,
) -> Optional[AdaptationEvent]:
"""Check if a metric breaches any threshold.
Args:
skill_name: Name of the skill
metric_type: Type of metric
current_value: Current metric value
Returns:
AdaptationEvent if threshold breached, None otherwise
"""
# Find applicable thresholds
applicable_thresholds = [
t for t in self.thresholds
if t.metric_type == metric_type
]
if not applicable_thresholds:
return None
# Check cooldown
cooldown_key = f"{skill_name}:{metric_type.value}"
now = datetime.now()
last_trigger = self._cooldowns.get(cooldown_key)
# Store current value first for avg calculation
self._store_metric(cooldown_key, current_value)
for threshold in applicable_thresholds:
if last_trigger:
elapsed = (now - last_trigger).total_seconds()
if elapsed < threshold.cooldown_seconds:
continue
# Evaluate threshold
if threshold.evaluate(current_value):
# Calculate moving average
avg_value = self._calculate_avg(skill_name, metric_type, current_value)
# Check minimum samples (allow immediate trigger if min_samples <= 1)
sample_count = len(self._recent_metrics.get(cooldown_key, []))
if threshold.min_samples > 1 and sample_count < threshold.min_samples:
# Not enough samples yet
continue
# Trigger adaptation
event = AdaptationEvent(
timestamp=now.isoformat(),
skill_name=skill_name,
metric_type=metric_type,
threshold=threshold,
current_value=current_value,
avg_value=avg_value,
action_taken=threshold.action,
details={
"run_id": self.run_id,
"agent_id": self.agent_id,
},
)
# Update cooldown
self._cooldowns[cooldown_key] = now
# Persist event
self._persist_event(event)
logger.info(
f"Threshold breached for {skill_name}.{metric_type.value}: "
f"current={current_value}, avg={avg_value}, action={threshold.action.value}"
)
return event
return None
def _calculate_avg(
self,
skill_name: str,
metric_type: MetricType,
current_value: float,
) -> float:
"""Calculate moving average for a metric."""
key = f"{skill_name}:{metric_type.value}"
values = self._recent_metrics.get(key, [])
if not values:
return current_value
return sum(values) / len(values)
def _store_metric(self, key: str, value: float) -> None:
"""Store metric value with sliding window."""
if key not in self._recent_metrics:
self._recent_metrics[key] = []
self._recent_metrics[key].append(value)
# Keep only last 100 values
if len(self._recent_metrics[key]) > 100:
self._recent_metrics[key] = self._recent_metrics[key][-100:]
def _persist_event(self, event: AdaptationEvent) -> None:
"""Persist adaptation event to storage."""
run_dir = self.storage_dir / self.run_id / "adaptations"
run_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{event.skill_name}_{event.metric_type.value}_{timestamp}.json"
filepath = run_dir / filename
try:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(event.to_dict(), f, ensure_ascii=False, indent=2)
logger.debug(f"Persisted adaptation event to: {filepath}")
except Exception as e:
logger.error(f"Failed to persist adaptation event: {e}")
# Also add to pending list
self._pending_events.append(event)
def get_pending_warnings(self) -> List[AdaptationEvent]:
"""Get all pending warning events that need human review."""
return [
e for e in self._pending_events
if e.action_taken in (AdaptationAction.WARN, AdaptationAction.BOTH)
]
def clear_pending_warnings(self) -> None:
"""Clear pending warnings after they have been reviewed."""
self._pending_events = [
e for e in self._pending_events
if e.action_taken == AdaptationAction.RELOAD
]
def get_recent_events(
self,
skill_name: Optional[str] = None,
metric_type: Optional[MetricType] = None,
limit: int = 50,
) -> List[AdaptationEvent]:
"""Get recent adaptation events.
Args:
skill_name: Optional filter by skill name
metric_type: Optional filter by metric type
limit: Maximum number of events to return
Returns:
List of recent adaptation events
"""
events_dir = self.storage_dir / self.run_id / "adaptations"
if not events_dir.exists():
return []
events = []
for eval_file in sorted(events_dir.glob("*.json"), reverse=True)[:limit]:
try:
with open(eval_file, "r", encoding="utf-8") as f:
data = json.load(f)
event = self._parse_event(data)
if skill_name and event.skill_name != skill_name:
continue
if metric_type and event.metric_type != metric_type:
continue
events.append(event)
except Exception as e:
logger.warning(f"Failed to load adaptation event {eval_file}: {e}")
return events
def _parse_event(self, data: Dict[str, Any]) -> AdaptationEvent:
"""Parse adaptation event from JSON data."""
threshold_data = data.get("threshold", {})
metric_type = MetricType(threshold_data.get("metric_type", "custom"))
threshold = AdaptationThreshold(
metric_type=metric_type,
operator=threshold_data.get("operator", "lt"),
value=threshold_data.get("value", 0.0),
window_size=threshold_data.get("window_size", 10),
min_samples=threshold_data.get("min_samples", 5),
action=AdaptationAction(threshold_data.get("action", "warn")),
cooldown_seconds=threshold_data.get("cooldown_seconds", 300),
)
return AdaptationEvent(
timestamp=data.get("timestamp", ""),
skill_name=data.get("skill_name", ""),
metric_type=metric_type,
threshold=threshold,
current_value=data.get("current_value", 0.0),
avg_value=data.get("avg_value", 0.0),
action_taken=AdaptationAction(data.get("action_taken", "warn")),
details=data.get("details", {}),
)
def add_threshold(self, threshold: AdaptationThreshold) -> None:
"""Add a new threshold configuration."""
self.thresholds.append(threshold)
def remove_threshold(self, metric_type: MetricType) -> None:
"""Remove all thresholds for a specific metric type."""
self.thresholds = [
t for t in self.thresholds
if t.metric_type != metric_type
]
def update_threshold(
self,
metric_type: MetricType,
**kwargs,
) -> None:
"""Update threshold configuration for a metric type."""
for threshold in self.thresholds:
if threshold.metric_type == metric_type:
for key, value in kwargs.items():
if hasattr(threshold, key):
setattr(threshold, key, value)
def get_thresholds(self) -> List[AdaptationThreshold]:
"""Get current threshold configurations."""
return list(self.thresholds)
def is_in_cooldown(self, skill_name: str, metric_type: MetricType) -> bool:
"""Check if a skill/metric combination is in cooldown period."""
key = f"{skill_name}:{metric_type.value}"
last_trigger = self._cooldowns.get(key)
if not last_trigger:
return False
# Find the threshold for this metric type
for threshold in self.thresholds:
if threshold.metric_type == metric_type:
elapsed = (datetime.now() - last_trigger).total_seconds()
return elapsed < threshold.cooldown_seconds
return False
class AdaptationManager:
"""Manager for coordinating skill adaptation across multiple agents.
Provides centralized tracking of adaptation events and skill reloads.
"""
def __init__(self, storage_dir: Path):
"""Initialize adaptation manager.
Args:
storage_dir: Root directory for storing adaptation data
"""
self.storage_dir = Path(storage_dir)
self._hooks: Dict[str, SkillAdaptationHook] = {}
def get_hook(
self,
run_id: str,
agent_id: str,
thresholds: Optional[List[AdaptationThreshold]] = None,
) -> SkillAdaptationHook:
"""Get or create an adaptation hook for an agent.
Args:
run_id: Run identifier
agent_id: Agent identifier
thresholds: Optional custom thresholds
Returns:
SkillAdaptationHook instance
"""
key = f"{run_id}:{agent_id}"
if key not in self._hooks:
self._hooks[key] = SkillAdaptationHook(
storage_dir=self.storage_dir,
run_id=run_id,
agent_id=agent_id,
thresholds=thresholds,
)
return self._hooks[key]
def get_all_pending_warnings(self) -> List[AdaptationEvent]:
"""Get all pending warnings from all hooks."""
warnings = []
for hook in self._hooks.values():
warnings.extend(hook.get_pending_warnings())
return warnings
def get_run_adaptations(self, run_id: str) -> List[AdaptationEvent]:
"""Get all adaptation events for a run."""
events = []
for hook in self._hooks.values():
if hook.run_id == run_id:
events.extend(hook.get_recent_events())
return events
# Global manager instance
_adaptation_manager: Optional[AdaptationManager] = None
def get_adaptation_manager(storage_dir: Optional[Path] = None) -> AdaptationManager:
"""Get global adaptation manager instance.
Args:
storage_dir: Optional storage directory (required on first call)
Returns:
AdaptationManager instance
"""
global _adaptation_manager
if _adaptation_manager is None:
if storage_dir is None:
raise ValueError("storage_dir required on first initialization")
_adaptation_manager = AdaptationManager(storage_dir)
return _adaptation_manager
__all__ = [
"AdaptationAction",
"AdaptationThreshold",
"AdaptationEvent",
"SkillAdaptationHook",
"AdaptationManager",
"get_adaptation_manager",
]
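The threshold operators (`lt`, `lte`, `gt`, `gte`, `eq`) dispatch through a small lambda table; a trimmed, runnable copy of that logic from `AdaptationThreshold`:

```python
from dataclasses import dataclass
from enum import Enum


class MetricType(Enum):
    HIT_RATE = "hit_rate"


@dataclass
class AdaptationThreshold:
    metric_type: MetricType
    operator: str = "lt"   # lt, lte, gt, gte, eq
    value: float = 0.0

    def evaluate(self, current_value: float) -> bool:
        """Return True when the threshold is breached."""
        ops = {
            "lt": lambda x, y: x < y,
            "lte": lambda x, y: x <= y,
            "gt": lambda x, y: x > y,
            "gte": lambda x, y: x >= y,
            "eq": lambda x, y: x == y,
        }
        op = ops.get(self.operator)
        # Unknown operators are treated as "not breached", matching the
        # warn-and-return-False behavior in the full class.
        return bool(op and op(current_value, self.value))


hit_rate_floor = AdaptationThreshold(MetricType.HIT_RATE, operator="lt", value=0.5)
```

With the default `WARN` action, a breach only logs an event; nothing reloads unless the action is `RELOAD` or `BOTH`.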

View File

@@ -289,6 +289,7 @@ class ToolGuardMixin:
self._approval_timeout = approval_timeout
self._pending_approval: Optional[ToolApprovalRequest] = None
self._approval_callback: Optional[Callable[[ToolApprovalRequest], None]] = None
self._approval_lock = asyncio.Lock()
def set_approval_callback(
self,
@@ -383,6 +384,7 @@ class ToolGuardMixin:
Returns:
True if approved, False otherwise
"""
async with self._approval_lock:
record = TOOL_GUARD_STORE.create_pending(
tool_name=tool_name,
tool_input=tool_input,
@@ -418,12 +420,15 @@ class ToolGuardMixin:
if self._approval_callback:
self._approval_callback(self._pending_approval)
# Wait for approval
# Wait for approval (lock is released during wait, re-acquired after)
approval_request = self._pending_approval
# Wait for approval outside the lock to allow concurrent approval
approved = await approval_request.wait_for_approval(
timeout=self._approval_timeout
)
async with self._approval_lock:
if approval_request:
status = (
ApprovalStatus.APPROVED
@@ -446,10 +451,13 @@ class ToolGuardMixin:
status=status.value,
)
# Only clear if this is still the same request
if self._pending_approval is approval_request:
self._pending_approval = None
return approved
def approve_guard_call(self, request_id: Optional[str] = None) -> bool:
async def approve_guard_call(self, request_id: Optional[str] = None) -> bool:
"""Approve a pending guard request.
This method is called externally to approve a tool call
@@ -461,6 +469,7 @@ class ToolGuardMixin:
Returns:
True if a request was approved, False if no pending request
"""
async with self._approval_lock:
if self._pending_approval is None:
logger.warning("No pending approval request to approve")
return False
@@ -482,7 +491,7 @@ class ToolGuardMixin:
logger.info("Approved tool call: %s", self._pending_approval.tool_name)
return True
def deny_guard_call(self, request_id: Optional[str] = None) -> bool:
async def deny_guard_call(self, request_id: Optional[str] = None) -> bool:
"""Deny a pending guard request.
This method is called externally to deny a tool call
@@ -494,6 +503,7 @@ class ToolGuardMixin:
Returns:
True if a request was denied, False if no pending request
"""
async with self._approval_lock:
if self._pending_approval is None:
logger.warning("No pending approval request to deny")
return False
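The change above introduces `_approval_lock` and deliberately awaits the approval outside the lock; holding it across the wait would deadlock `approve_guard_call`/`deny_guard_call`, which need the same lock to run. A self-contained sketch of that pattern (heavily simplified from the mixin):

```python
import asyncio
from typing import Optional


class ApprovalGuard:
    """Simplified sketch of the wait-outside-the-lock approval pattern."""

    def __init__(self) -> None:
        self._lock = asyncio.Lock()
        self._pending: Optional[asyncio.Event] = None
        self._approved = False

    async def request_approval(self) -> bool:
        async with self._lock:
            self._pending = asyncio.Event()
            pending = self._pending
        # Wait OUTSIDE the lock; approve() below needs the lock to run.
        await pending.wait()
        async with self._lock:
            if self._pending is pending:  # only clear if still the same request
                self._pending = None
            return self._approved

    async def approve(self) -> bool:
        async with self._lock:
            if self._pending is None:
                return False
            self._approved = True
            self._pending.set()
            return True


async def demo() -> bool:
    guard = ApprovalGuard()
    waiter = asyncio.create_task(guard.request_approval())
    await asyncio.sleep(0)  # let request_approval start waiting
    assert await guard.approve()
    return await waiter


approved = asyncio.run(demo())
```

The identity check before clearing `_pending` mirrors the diff's "only clear if this is still the same request" guard, which protects against a newer request arriving while the old waiter resumes.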

View File

@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
"""Agent Factory - Dynamic creation and management of EvoAgents."""
"""Agent Factory - Dynamic creation and management of AgentConfigs."""
import logging
import shutil
@@ -37,8 +37,8 @@ class RoleConfig:
self.constraints = []
class EvoAgent:
"""Represents a configured agent instance."""
class AgentConfig:
"""Represents a configured agent instance (data class)."""
def __init__(
self,
@@ -185,7 +185,7 @@ class AgentFactory:
model_config: Optional[ModelConfig] = None,
role_config: Optional[RoleConfig] = None,
clone_from: Optional[str] = None,
) -> EvoAgent:
) -> AgentConfig:
"""Create a new agent.
Args:
@@ -197,7 +197,7 @@ class AgentFactory:
clone_from: Path to existing agent to clone from (optional)
Returns:
EvoAgent instance
AgentConfig instance
Raises:
ValueError: If agent already exists or workspace doesn't exist
@@ -234,7 +234,7 @@ class AgentFactory:
config_path = agent_dir / "agent.yaml"
self._write_agent_yaml(config_path, agent_id, agent_type, model_config)
return EvoAgent(
return AgentConfig(
agent_id=agent_id,
agent_type=agent_type,
workspace_id=workspace_id,
@@ -267,7 +267,7 @@ class AgentFactory:
new_agent_id: str,
target_workspace_id: Optional[str] = None,
model_config: Optional[ModelConfig] = None,
) -> EvoAgent:
) -> AgentConfig:
"""Clone an existing agent.
Args:
@@ -278,7 +278,7 @@ class AgentFactory:
model_config: Optional new model configuration
Returns:
EvoAgent instance for the cloned agent
AgentConfig instance for the cloned agent
"""
target_workspace_id = target_workspace_id or source_workspace_id
source_dir = self.workspaces_root / source_workspace_id / "agents" / source_agent_id

View File

@@ -6,10 +6,10 @@ from typing import Any, Optional
from .agent_workspace import load_agent_workspace_config
from backend.config.bootstrap_config import get_bootstrap_config_for_run
from .prompt_loader import PromptLoader
from .prompt_loader import get_prompt_loader
from .skills_manager import SkillsManager
_prompt_loader = PromptLoader()
_prompt_loader = get_prompt_loader()
def _read_file_if_exists(path: Path) -> str:

View File

@@ -10,6 +10,17 @@ from typing import Any, Dict, Optional
import yaml
# Singleton instance
_prompt_loader_instance: Optional["PromptLoader"] = None
def get_prompt_loader() -> "PromptLoader":
"""Get the singleton PromptLoader instance."""
global _prompt_loader_instance
if _prompt_loader_instance is None:
_prompt_loader_instance = PromptLoader()
return _prompt_loader_instance
class PromptLoader:
"""Unified Prompt loader"""

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import shutil
import tempfile
import zipfile
from threading import Lock
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set
from urllib.parse import urlparse
from urllib.request import urlretrieve
@@ -39,6 +40,9 @@ class SkillsManager:
self.project_root / "backend" / "skills" / "customized"
)
self.runs_root = self.project_root / "runs"
self._lock = Lock()
# Instance-level pending skill changes (thread-safe via self._lock)
self._pending_skill_changes: Dict[str, Set[Path]] = {}
def get_active_root(self, config_name: str) -> Path:
return self.runs_root / config_name / "skills" / "active"
@@ -737,7 +741,7 @@ class SkillsManager:
if local_root.exists():
watched_paths.append(local_root)
handler = _SkillsChangeHandler(watched_paths, callback)
handler = _SkillsChangeHandler(watched_paths, self._pending_skill_changes, callback, self._lock)
observer = Observer()
for path in watched_paths:
observer.schedule(handler, str(path), recursive=True)
@@ -759,16 +763,19 @@ class SkillsManager:
Map of agent_id -> list of reloaded skill paths, or empty dict
if no changes were detected.
"""
with self._lock:
changed = self._pending_skill_changes.get(config_name)
if not changed:
return {}
self._pending_skill_changes[config_name] = set()
return self.prepare_active_skills(config_name, agent_defaults)
# -------------------------------------------------------------------------
# Internal change-tracking state (populated by _SkillsChangeHandler)
# -------------------------------------------------------------------------
# Legacy class-level reference kept for migration compatibility
_pending_skill_changes: Dict[str, Set[Path]] = {}
def _resolve_disabled_skill_names(
@@ -820,11 +827,15 @@ class _SkillsChangeHandler(FileSystemEventHandler):
def __init__(
self,
watched_paths: List[Path],
pending_changes: Dict[str, Set[Path]],
callback: Optional[Any] = None,
lock: Optional[Lock] = None,
) -> None:
super().__init__()
self._watched_paths = watched_paths
self._pending_changes = pending_changes
self._callback = callback
self._lock = lock
def on_any_event(self, event: FileSystemEvent) -> None:
if event.is_directory:
@@ -832,9 +843,12 @@ class _SkillsChangeHandler(FileSystemEventHandler):
src_path = Path(event.src_path)
for watched in self._watched_paths:
if src_path.is_relative_to(watched):
-SkillsManager._pending_skill_changes.setdefault(
-    self._run_id_from_path(src_path), set()
-).add(src_path)
+run_id = self._run_id_from_path(src_path)
+if self._lock:
+    with self._lock:
+        self._pending_changes.setdefault(run_id, set()).add(src_path)
+else:
+    self._pending_changes.setdefault(run_id, set()).add(src_path)
if self._callback:
self._callback([src_path])
break
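The lock-guarded buffer that `_SkillsChangeHandler` writes into and `reload_changed_skills` drains can be sketched as a minimal, self-contained pattern. The class and method names below are hypothetical; only the record-under-lock / swap-and-drain shape comes from the diff:

```python
from threading import Lock, Thread
from typing import Dict, Set

class PendingChanges:
    """Minimal sketch of the lock-guarded change buffer used above."""
    def __init__(self) -> None:
        self._lock = Lock()
        self._pending: Dict[str, Set[str]] = {}

    def record(self, run_id: str, path: str) -> None:
        # Watchdog-handler side: add the changed path under the lock
        with self._lock:
            self._pending.setdefault(run_id, set()).add(path)

    def drain(self, run_id: str) -> Set[str]:
        # Reload side: read the pending set and reset it atomically
        with self._lock:
            changed = self._pending.get(run_id, set())
            self._pending[run_id] = set()
            return changed

buf = PendingChanges()
threads = [Thread(target=buf.record, args=("run1", f"skill_{i}.md")) for i in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
changed = buf.drain("run1")
```

Draining inside the lock is what lets `reload_changed_skills` return early on an empty set without racing the filesystem watcher.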

View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""Team module for multi-agent orchestration.
Provides inter-agent communication, task delegation, and coordination
for subagent spawning and lifecycle management.
"""
from .messenger import AgentMessenger
from .task_delegator import TaskDelegator
from .team_coordinator import TeamCoordinator
from .registry import AgentRegistry
__all__ = [
"AgentMessenger",
"TaskDelegator",
"TeamCoordinator",
"AgentRegistry",
]

View File

@@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-
"""AgentMessenger - Pub/sub inter-agent communication.
Provides broadcast(), send(), and subscribe() for message passing
between agents using AgentScope's Msg format.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Callable, Dict, List, Optional, Set
from agentscope.message import Msg
logger = logging.getLogger(__name__)
class AgentMessenger:
"""Pub/sub messenger for inter-agent communication.
Supports:
- broadcast(): Send message to all subscribers
- send(): Send message to specific agent
- subscribe(): Register callback for agent messages
- announce(): Send system-wide announcement
- enable_auto_broadcast: Auto-broadcast agent replies to all participants
Messages use AgentScope's Msg format for compatibility.
"""
def __init__(self, enable_auto_broadcast: bool = False):
"""Initialize the messenger.
Args:
enable_auto_broadcast: If True, agent replies are automatically
broadcast to all subscribed agents.
"""
self._subscriptions: Dict[str, List[Callable[[Msg], None]]] = {}
self._inbox: Dict[str, List[Msg]] = {}
self._locks: Dict[str, asyncio.Lock] = {}
self._enable_auto_broadcast = enable_auto_broadcast
self._participants: Set[str] = set()
def subscribe(
self,
agent_id: str,
callback: Callable[[Msg], None],
) -> None:
"""Subscribe an agent to receive messages.
Args:
agent_id: Target agent identifier
callback: Sync or async callable invoked when a message is received
"""
if agent_id not in self._subscriptions:
self._subscriptions[agent_id] = []
self._subscriptions[agent_id].append(callback)
logger.debug("Agent %s subscribed to messages", agent_id)
def unsubscribe(self, agent_id: str, callback: Callable[[Msg], None]) -> None:
"""Unsubscribe an agent from messages.
Args:
agent_id: Target agent identifier
callback: Callback to remove
"""
if agent_id in self._subscriptions:
try:
self._subscriptions[agent_id].remove(callback)
logger.debug("Agent %s unsubscribed from messages", agent_id)
except ValueError:
pass
async def send(
self,
to_agent: str,
message: Msg,
) -> None:
"""Send message to specific agent.
Args:
to_agent: Target agent identifier
message: Message to send (uses Msg format)
"""
async def _deliver():
if to_agent in self._subscriptions:
for callback in self._subscriptions[to_agent]:
try:
if asyncio.iscoroutinefunction(callback):
await callback(message)
else:
callback(message)
except Exception as e:
logger.error(
"Error delivering message to %s: %s",
to_agent,
e,
)
await _deliver()
async def broadcast(self, message: Msg) -> None:
"""Broadcast message to all subscribed agents.
Args:
message: Message to broadcast (uses Msg format)
"""
delivery_tasks = []
for agent_id, callbacks in self._subscriptions.items():
for callback in callbacks:
async def _deliver(cb=callback, aid=agent_id):
try:
if asyncio.iscoroutinefunction(cb):
await cb(message)
else:
cb(message)
except Exception as e:
logger.error(
"Error broadcasting to %s: %s",
aid,
e,
)
delivery_tasks.append(_deliver())
if delivery_tasks:
await asyncio.gather(*delivery_tasks)
def inbox(self, agent_id: str) -> List[Msg]:
"""Get and clear inbox for agent.
Args:
agent_id: Agent identifier
Returns:
List of messages in inbox
"""
messages = self._inbox.get(agent_id, [])
self._inbox[agent_id] = []
return messages
def inbox_count(self, agent_id: str) -> int:
"""Count messages in agent's inbox without clearing.
Args:
agent_id: Agent identifier
Returns:
Number of messages waiting
"""
return len(self._inbox.get(agent_id, []))
def add_participant(self, agent_id: str) -> None:
"""Add a participant to the messenger.
Participants are the agents that can receive auto-broadcast messages.
Args:
agent_id: Agent identifier to add
"""
self._participants.add(agent_id)
logger.debug("Agent %s added as participant", agent_id)
def remove_participant(self, agent_id: str) -> None:
"""Remove a participant from the messenger.
Args:
agent_id: Agent identifier to remove
"""
self._participants.discard(agent_id)
logger.debug("Agent %s removed from participants", agent_id)
@property
def enable_auto_broadcast(self) -> bool:
"""Check if auto_broadcast is enabled."""
return self._enable_auto_broadcast
@enable_auto_broadcast.setter
def enable_auto_broadcast(self, value: bool) -> None:
"""Enable or disable auto_broadcast."""
self._enable_auto_broadcast = value
logger.debug("Auto_broadcast set to %s", value)
async def announce(self, message: Msg) -> None:
"""Send a system-wide announcement to all participants.
Unlike broadcast(), announce() is meant for messages from the system/host;
it logs the announcement at INFO level and then delivers it via broadcast(),
so only subscribed agents actually receive it.
Args:
message: Announcement message (uses Msg format)
"""
logger.info("System announcement: %s", message.content)
await self.broadcast(message)
async def auto_broadcast(self, message: Msg) -> None:
"""Auto-broadcast message to all participants.
This is called internally when enable_auto_broadcast is True.
Broadcasts to all registered participants.
Args:
message: Message to auto-broadcast (uses Msg format)
"""
if not self._enable_auto_broadcast:
return
# Broadcast to all participants
for participant_id in self._participants:
if participant_id in self._subscriptions:
for callback in self._subscriptions[participant_id]:
try:
if asyncio.iscoroutinefunction(callback):
await callback(message)
else:
callback(message)
except Exception as e:
logger.error(
"Error auto-broadcasting to %s: %s",
participant_id,
e,
)
__all__ = ["AgentMessenger"]

View File

@@ -0,0 +1,188 @@
# -*- coding: utf-8 -*-
"""AgentRegistry - Agent registration and lookup by role.
Provides register(), unregister(), and get_by_role() for agent
discovery and management.
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from agentscope.message import Msg
logger = logging.getLogger(__name__)
class AgentRegistry:
"""Registry for agent instances with role-based lookup.
Supports:
- register(): Add agent with roles
- unregister(): Remove agent
- get_by_role(): Find agents by role
- get_by_id(): Get specific agent
Each agent can have multiple roles for flexible dispatch.
"""
def __init__(self):
self._agents: Dict[str, Any] = {}
self._roles: Dict[str, List[str]] = {}
self._agent_roles: Dict[str, List[str]] = {}
def register(
self,
agent_id: str,
agent: Any,
roles: Optional[List[str]] = None,
) -> None:
"""Register an agent with optional roles.
Args:
agent_id: Unique agent identifier
agent: Agent instance
roles: Optional list of role strings
"""
self._agents[agent_id] = agent
self._agent_roles[agent_id] = roles or []
for role in self._agent_roles[agent_id]:
if role not in self._roles:
self._roles[role] = []
if agent_id not in self._roles[role]:
self._roles[role].append(agent_id)
logger.info(
"Registered agent %s with roles %s",
agent_id,
self._agent_roles[agent_id],
)
def unregister(self, agent_id: str) -> bool:
"""Unregister an agent.
Args:
agent_id: Agent identifier to remove
Returns:
True if agent was removed
"""
if agent_id not in self._agents:
return False
roles = self._agent_roles.pop(agent_id, [])
for role in roles:
if role in self._roles:
try:
self._roles[role].remove(agent_id)
except ValueError:
pass
del self._agents[agent_id]
logger.info("Unregistered agent: %s", agent_id)
return True
def get_by_id(self, agent_id: str) -> Optional[Any]:
"""Get agent by ID.
Args:
agent_id: Agent identifier
Returns:
Agent instance or None
"""
return self._agents.get(agent_id)
def get_by_role(self, role: str) -> List[Any]:
"""Get all agents with a given role.
Args:
role: Role string to search for
Returns:
List of agent instances with the role
"""
agent_ids = self._roles.get(role, [])
return [self._agents[aid] for aid in agent_ids if aid in self._agents]
def get_by_roles(self, roles: List[str]) -> List[Any]:
"""Get agents matching ANY of the given roles.
Args:
roles: List of role strings
Returns:
List of unique agent instances matching any role
"""
seen = set()
result = []
for role in roles:
for agent in self.get_by_role(role):
if id(agent) not in seen:
seen.add(id(agent))
result.append(agent)
return result
def list_agents(self) -> List[str]:
"""List all registered agent IDs.
Returns:
List of agent identifiers
"""
return list(self._agents.keys())
def list_roles(self) -> List[str]:
"""List all registered roles.
Returns:
List of role strings
"""
return list(self._roles.keys())
def list_roles_for_agent(self, agent_id: str) -> List[str]:
"""List roles for specific agent.
Args:
agent_id: Agent identifier
Returns:
List of role strings
"""
return list(self._agent_roles.get(agent_id, []))
def update_roles(self, agent_id: str, roles: List[str]) -> None:
"""Update roles for an existing agent.
Args:
agent_id: Agent identifier
roles: New list of roles
"""
if agent_id not in self._agents:
raise KeyError(f"Agent not registered: {agent_id}")
old_roles = self._agent_roles.get(agent_id, [])
for role in old_roles:
if role in self._roles:
try:
self._roles[role].remove(agent_id)
except ValueError:
pass
self._agent_roles[agent_id] = roles
for role in roles:
if role not in self._roles:
self._roles[role] = []
if agent_id not in self._roles[role]:
self._roles[role].append(agent_id)
logger.info("Updated roles for agent %s: %s", agent_id, roles)
@property
def agents(self) -> Dict[str, Any]:
"""Get copy of registered agents dict."""
return dict(self._agents)
__all__ = ["AgentRegistry"]

View File

@@ -0,0 +1,620 @@
# -*- coding: utf-8 -*-
"""TaskDelegator - Subagent spawning and task delegation.
Provides delegate() and delegate_parallel() for spawning subagents
with separate context and memory. Supports runtime dynamic subagent
definition via task_data with description, prompt, and tools.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
from agentscope.message import Msg
logger = logging.getLogger(__name__)
# Default timeout for subagent execution (seconds)
DEFAULT_EXECUTION_TIMEOUT = 120.0
# Type alias for subagent specification
SubagentSpec = Dict[str, Any]
"""Subagent specification format:
{
"description": "Expert code reviewer...",
"prompt": "Analyze code quality...",
"tools": ["Read", "Glob", "Grep"], # Optional: list of tool names
"model": "gpt-4o", # Optional: model name
}
"""
class TaskDelegator:
"""Delegates tasks to subagents with isolated context.
Supports:
- delegate(): Spawn single subagent for task
- delegate_parallel(): Spawn multiple subagents concurrently
- delegate_task(): Delegate with dynamic subagent definition from task_data
Each subagent gets its own memory/context to prevent
cross-contamination.
Dynamic Subagent Definition:
task_data can include an "agents" dict to define subagents inline:
task_data = {
"task": "Review the code changes",
"agents": {
"code-reviewer": {
"description": "Expert code reviewer for quality and security.",
"prompt": "Analyze code quality and suggest improvements.",
"tools": ["Read", "Glob", "Grep"],
}
}
}
"""
def __init__(self, agent: Any):
"""Initialize TaskDelegator.
Args:
agent: Parent EvoAgent instance for accessing model, formatter, workspace
"""
self._agent = agent
# Get messenger from parent agent if available
self._messenger = getattr(agent, "messenger", None)
self._registry = getattr(agent, "_registry", None)
self._subagents: Dict[str, Any] = {}
self._dynamic_subagents: Dict[str, SubagentSpec] = {}
self._tasks: Dict[str, asyncio.Task] = {}
# Extract model and formatter from parent agent
self._model = getattr(agent, "model", None)
self._formatter = getattr(agent, "formatter", None)
self._workspace_dir = getattr(agent, "workspace_dir", None)
self._config_name = getattr(agent, "config_name", None)
async def delegate(
self,
agent_id: str,
task: Callable[..., Awaitable[Msg]],
context: Optional[Dict[str, Any]] = None,
) -> asyncio.Task:
"""Delegate task to a single subagent.
Args:
agent_id: Unique identifier for this subagent instance
task: Async function representing the task
context: Optional context dict for the subagent
Returns:
asyncio.Task for the delegated task
"""
async def _run_with_context():
result = await task(context or {})
return result
self._tasks[agent_id] = asyncio.create_task(_run_with_context())
logger.info("Delegated task to subagent: %s", agent_id)
return self._tasks[agent_id]
async def delegate_parallel(
self,
tasks: List[Dict[str, Any]],
) -> List[asyncio.Task]:
"""Delegate multiple tasks in parallel.
Args:
tasks: List of task dicts with keys:
- agent_id: Unique identifier
- task: Async function to execute
- context: Optional context dict
Returns:
List of asyncio.Task for all delegated tasks
"""
async def _run_task(task_def: Dict[str, Any]):
agent_id = task_def["agent_id"]
task_func = task_def["task"]
context = task_def.get("context", {})
async def _run_with_context():
return await task_func(context)
self._tasks[agent_id] = asyncio.create_task(_run_with_context())
return self._tasks[agent_id]
gathered_tasks = await asyncio.gather(
*[_run_task(t) for t in tasks],
return_exceptions=True,
)
valid_tasks = [t for t in gathered_tasks if isinstance(t, asyncio.Task)]
logger.info(
"Delegated %d tasks in parallel (%d succeeded)",
len(tasks),
len(valid_tasks),
)
return valid_tasks
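The task-dict shape consumed by `delegate_parallel()` can be exercised with plain `asyncio`. `run_task_defs` below is a hypothetical condensed version that runs each `{"agent_id", "task", "context"}` entry and collects results by agent id:

```python
import asyncio
from typing import Any, Dict, List

async def run_task_defs(task_defs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Run each task dict and collect results keyed by agent_id."""
    async def _run(td: Dict[str, Any]) -> tuple:
        result = await td["task"](td.get("context", {}))
        return td["agent_id"], result

    pairs = await asyncio.gather(*[_run(td) for td in task_defs])
    return dict(pairs)

async def analyze(ctx: Dict[str, Any]) -> str:
    # Toy task function; the real tasks return Msg objects
    return f"analyzed {ctx.get('symbol', '?')}"

results = asyncio.run(run_task_defs([
    {"agent_id": "a1", "task": analyze, "context": {"symbol": "AAPL"}},
    {"agent_id": "a2", "task": analyze, "context": {"symbol": "TSLA"}},
]))
```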
async def wait_for(self, agent_id: str, timeout: Optional[float] = None) -> Any:
"""Wait for subagent task to complete.
Args:
agent_id: Subagent identifier
timeout: Optional timeout in seconds
Returns:
Task result
Raises:
asyncio.TimeoutError: If task doesn't complete in time
KeyError: If agent_id not found
"""
if agent_id not in self._tasks:
raise KeyError(f"Unknown subagent: {agent_id}")
try:
return await asyncio.wait_for(
self._tasks[agent_id],
timeout=timeout,
)
except asyncio.TimeoutError:
logger.warning("Task %s timed out after %s seconds", agent_id, timeout)
raise
async def cancel(self, agent_id: str) -> bool:
"""Cancel a subagent task.
Args:
agent_id: Subagent identifier
Returns:
True if task was cancelled
"""
if agent_id in self._tasks:
self._tasks[agent_id].cancel()
del self._tasks[agent_id]
logger.info("Cancelled subagent task: %s", agent_id)
return True
return False
def list_tasks(self) -> List[str]:
"""List active subagent task IDs.
Returns:
List of agent_ids with pending tasks
"""
return list(self._tasks.keys())
@property
def tasks(self) -> Dict[str, asyncio.Task]:
"""Get copy of active tasks dict."""
return dict(self._tasks)
async def delegate_task(
self,
task_type: str,
task_data: Dict[str, Any],
target_agent: Optional[str] = None,
) -> Dict[str, Any]:
"""Delegate a task with optional dynamic subagent definition.
Supports runtime subagent definition via task_data["agents"]:
task_data = {
"task": "Review code changes",
"agents": {
"code-reviewer": {
"description": "Expert code reviewer...",
"prompt": "Analyze code quality...",
"tools": ["Read", "Glob", "Grep"],
}
}
}
Args:
task_type: Type of task (e.g., "analysis", "review", "research")
task_data: Task payload, may include "agents" for dynamic subagent def
target_agent: Optional specific agent ID to delegate to
Returns:
Dict with "success" and result/error
"""
try:
# Extract dynamic subagent definitions from task_data
agents_def = task_data.get("agents", {})
if agents_def:
# Register dynamic subagents
for agent_name, agent_spec in agents_def.items():
self._dynamic_subagents[agent_name] = agent_spec
logger.info(
"Registered dynamic subagent: %s (description: %s)",
agent_name,
agent_spec.get("description", "")[:50],
)
# Determine target agent
effective_target = target_agent
if not effective_target:
# Use first available dynamic subagent or default
if agents_def:
effective_target = next(iter(agents_def.keys()))
else:
effective_target = "default"
# Execute the task (async)
task_result = await self._execute_task(
task_type=task_type,
task_data=task_data,
target_agent=effective_target,
)
# Clean up dynamic subagents after execution
for agent_name in agents_def.keys():
self._dynamic_subagents.pop(agent_name, None)
return {
"success": True,
"result": task_result,
"subagents_used": list(agents_def.keys()) if agents_def else [],
}
except Exception as e:
logger.error("Task delegation failed: %s", e)
return {
"success": False,
"error": str(e),
}
async def _execute_task(
self,
task_type: str,
task_data: Dict[str, Any],
target_agent: str,
) -> Dict[str, Any]:
"""Execute the delegated task with a real subagent.
Args:
task_type: Type of task
task_data: Task payload
target_agent: Target agent identifier
Returns:
Task execution result with success/failure info
"""
task_content = task_data.get("task", task_data.get("prompt", ""))
timeout = task_data.get("timeout", DEFAULT_EXECUTION_TIMEOUT)
# Check if we have a dynamic subagent spec for this target
agent_spec = self._dynamic_subagents.get(target_agent)
if agent_spec:
logger.info(
"Executing task '%s' with dynamic subagent '%s'",
task_type,
target_agent,
)
return await self._create_and_run_subagent(
agent_name=target_agent,
agent_spec=agent_spec,
task_content=task_content,
task_type=task_type,
timeout=timeout,
)
# Fallback: try to use parent agent's model to process the task directly
logger.info(
"Executing task '%s' with parent agent '%s' (no dynamic subagent)",
task_type,
target_agent,
)
return await self._run_with_parent_agent(
task_content=task_content,
task_type=task_type,
timeout=timeout,
)
async def _create_and_run_subagent(
self,
agent_name: str,
agent_spec: SubagentSpec,
task_content: str,
task_type: str,
timeout: float,
) -> Dict[str, Any]:
"""Create and run a dynamic subagent.
Args:
agent_name: Name identifier for the subagent
agent_spec: Subagent specification (description, prompt, tools, model)
task_content: Task prompt to send to the subagent
task_type: Type of task
timeout: Execution timeout in seconds
Returns:
Dict with execution results
"""
subagent_id = f"subagent_{agent_name}_{uuid.uuid4().hex[:8]}"
try:
# Create subagent instance
subagent = await self._create_subagent(
subagent_id=subagent_id,
agent_spec=agent_spec,
)
if subagent is None:
return {
"task_type": task_type,
"task": task_content,
"subagent": agent_name,
"status": "failed",
"error": "Failed to create subagent",
"message": f"Could not instantiate subagent '{agent_name}'",
}
# Store for potential cleanup
self._subagents[subagent_id] = subagent
# Execute with timeout
result = await asyncio.wait_for(
self._run_subagent(subagent, task_content),
timeout=timeout,
)
# Extract response content
response_content = ""
if isinstance(result, Msg):
response_content = result.content
elif hasattr(result, "content"):
response_content = str(result.content)
elif isinstance(result, dict):
response_content = result.get("content", str(result))
else:
response_content = str(result)
logger.info(
"Subagent '%s' completed task '%s' successfully",
agent_name,
task_type,
)
return {
"task_type": task_type,
"task": task_content,
"subagent": {
"name": agent_name,
"id": subagent_id,
"description": agent_spec.get("description", ""),
},
"status": "completed",
"response": response_content,
"message": f"Task '{task_type}' executed with subagent '{agent_name}'",
}
except asyncio.TimeoutError:
logger.warning(
"Subagent '%s' timed out after %.1f seconds for task '%s'",
agent_name,
timeout,
task_type,
)
# asyncio.wait_for has already cancelled the coroutine on timeout;
# just drop the stored subagent reference
self._subagents.pop(subagent_id, None)
return {
"task_type": task_type,
"task": task_content,
"subagent": agent_name,
"status": "timeout",
"error": f"Execution timed out after {timeout} seconds",
"message": f"Task '{task_type}' timed out for subagent '{agent_name}'",
}
except Exception as e:
logger.error(
"Subagent '%s' failed for task '%s': %s",
agent_name,
task_type,
e,
exc_info=True,
)
# Cleanup on failure
if subagent_id in self._subagents:
self._subagents.pop(subagent_id, None)
return {
"task_type": task_type,
"task": task_content,
"subagent": agent_name,
"status": "error",
"error": str(e),
"message": f"Task '{task_type}' failed for subagent '{agent_name}': {e}",
}
async def _create_subagent(
self,
subagent_id: str,
agent_spec: SubagentSpec,
) -> Optional[Any]:
"""Create a subagent instance.
Uses the parent agent's model/formatter to create a lightweight
subagent for task execution.
Args:
subagent_id: Unique identifier for the subagent
agent_spec: Subagent specification
Returns:
Subagent instance or None if creation fails
"""
try:
# Import here to avoid circular imports
from agentscope.memory import InMemoryMemory
# Get model and formatter from parent
model = self._model
formatter = self._formatter
if model is None:
logger.error("Cannot create subagent: parent agent has no model")
return None
# Build system prompt from agent spec
description = agent_spec.get("description", "")
prompt_template = agent_spec.get("prompt", "")
system_prompt = f"""You are {description}
{prompt_template}
Your task is to complete the user's request below.
"""
# Create a minimal ReActAgent as the subagent
from agentscope.agent import ReActAgent
subagent = ReActAgent(
name=subagent_id,
model=model,
sys_prompt=system_prompt,
toolkit=None, # Could load tools from agent_spec.get("tools", [])
memory=InMemoryMemory(),
formatter=formatter,
max_iters=agent_spec.get("max_iters", 5),
)
logger.debug("Created subagent: %s", subagent_id)
return subagent
except Exception as e:
logger.error(
"Failed to create subagent '%s': %s",
subagent_id,
e,
exc_info=True,
)
return None
async def _run_subagent(
self,
subagent: Any,
task_content: str,
) -> Any:
"""Run a subagent with the given task.
Args:
subagent: Subagent instance
task_content: Task prompt
Returns:
Agent response (Msg or similar)
"""
from agentscope.message import Msg
# Create message for the subagent
task_msg = Msg(
name="user",
content=task_content,
role="user",
)
# Execute the agent
response = await subagent.reply(task_msg)
return response
async def _run_with_parent_agent(
self,
task_content: str,
task_type: str,
timeout: float,
) -> Dict[str, Any]:
"""Run task using the parent agent directly.
Used when no dynamic subagent is defined.
Args:
task_content: Task prompt
task_type: Type of task
timeout: Execution timeout
Returns:
Dict with execution results
"""
try:
result = await asyncio.wait_for(
self._agent.reply(Msg(
name="user",
content=task_content,
role="user",
)),
timeout=timeout,
)
response_content = ""
if isinstance(result, Msg):
response_content = result.content
elif hasattr(result, "content"):
response_content = str(result.content)
else:
response_content = str(result)
return {
"task_type": task_type,
"task": task_content,
"status": "completed",
"response": response_content,
"message": f"Task '{task_type}' executed with parent agent",
}
except asyncio.TimeoutError:
return {
"task_type": task_type,
"task": task_content,
"status": "timeout",
"error": f"Execution timed out after {timeout} seconds",
"message": f"Task '{task_type}' timed out",
}
except Exception as e:
logger.error(
"Parent agent failed for task '%s': %s",
task_type,
e,
exc_info=True,
)
return {
"task_type": task_type,
"task": task_content,
"status": "error",
"error": str(e),
"message": f"Task '{task_type}' failed: {e}",
}
def get_dynamic_subagent(self, name: str) -> Optional[SubagentSpec]:
"""Get a dynamically defined subagent specification.
Args:
name: Subagent name
Returns:
Subagent spec dict or None if not found
"""
return self._dynamic_subagents.get(name)
def list_dynamic_subagents(self) -> List[str]:
"""List all registered dynamic subagent names.
Returns:
List of subagent names
"""
return list(self._dynamic_subagents.keys())
__all__ = ["TaskDelegator", "SubagentSpec"]

View File

@@ -0,0 +1,389 @@
# -*- coding: utf-8 -*-
"""TeamCoordinator - Agent lifecycle management and execution.
Provides run_parallel() using asyncio.gather() and run_sequential()
for coordinating multiple agents.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Awaitable, Callable, Dict, List, Optional, Type
from agentscope.message import Msg
logger = logging.getLogger(__name__)
class TeamCoordinator:
"""Coordinates agent lifecycle and execution.
Supports:
- run_parallel(): Execute multiple agents concurrently with asyncio.gather()
- run_sequential(): Execute agents one after another
- run_phase(): Execute a named phase with registered agents
- register_agent(): Add agent to coordinator
- unregister_agent(): Remove agent from coordinator
Each agent maintains separate context/memory.
"""
def __init__(
self,
participants: Optional[List[Any]] = None,
task_content: Optional[str] = None,
messenger: Optional[Any] = None,
registry: Optional[Any] = None,
):
"""Initialize TeamCoordinator.
Args:
participants: List of agent instances to coordinate
task_content: Task description content for the agents
messenger: AgentMessenger for communication (optional)
registry: AgentRegistry for agent lookup (optional)
"""
self._participants = participants or []
self._task_content = task_content or ""
self._messenger = messenger
self._registry = registry
self._agents: Dict[str, Any] = {}
self._running_tasks: Dict[str, asyncio.Task] = {}
# Auto-register participants
for agent in self._participants:
if hasattr(agent, "name"):
self._agents[agent.name] = agent
elif hasattr(agent, "id"):
self._agents[agent.id] = agent
def register_agent(self, agent_id: str, agent: Any) -> None:
"""Register an agent with the coordinator.
Args:
agent_id: Unique agent identifier
agent: Agent instance
"""
self._agents[agent_id] = agent
logger.info("Registered agent: %s", agent_id)
def unregister_agent(self, agent_id: str) -> None:
"""Unregister an agent from the coordinator.
Args:
agent_id: Agent identifier to remove
"""
if agent_id in self._agents:
del self._agents[agent_id]
logger.info("Unregistered agent: %s", agent_id)
def get_agent(self, agent_id: str) -> Any:
"""Get registered agent by ID.
Args:
agent_id: Agent identifier
Returns:
Agent instance
"""
return self._agents.get(agent_id)
def list_agents(self) -> List[str]:
"""List all registered agent IDs.
Returns:
List of agent identifiers
"""
return list(self._agents.keys())
async def run_parallel(
self,
agent_ids: List[str],
initial_message: Optional[Msg] = None,
) -> Dict[str, Any]:
"""Run multiple agents in parallel using asyncio.gather().
Args:
agent_ids: List of agent IDs to run concurrently
initial_message: Optional initial message to broadcast
Returns:
Dict mapping agent_id to result
"""
async def _run_agent(aid: str) -> tuple[str, Any]:
agent = self._agents.get(aid)
if agent is None:
logger.error("Agent %s not found", aid)
return (aid, None)
try:
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
if initial_message:
result = await agent.reply(initial_message)
else:
result = await agent.reply()
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
result = await agent.run()
else:
result = await agent()
logger.info("Agent %s completed successfully", aid)
return (aid, result)
except Exception as e:
logger.error("Agent %s failed: %s", aid, e)
return (aid, {"error": str(e)})
results = await asyncio.gather(
*[_run_agent(aid) for aid in agent_ids],
return_exceptions=True,
)
output: Dict[str, Any] = {}
for result in results:
if isinstance(result, tuple):
agent_id, agent_result = result
output[agent_id] = agent_result
else:
logger.error("Unexpected result from asyncio.gather: %s", result)
logger.info("Parallel run completed for %d agents", len(agent_ids))
return output
async def run_sequential(
self,
agent_ids: List[str],
initial_message: Optional[Msg] = None,
) -> Dict[str, Any]:
"""Run agents one after another in order.
Args:
agent_ids: List of agent IDs to run in sequence
initial_message: Optional initial message for first agent
Returns:
Dict mapping agent_id to result
"""
output: Dict[str, Any] = {}
current_message = initial_message
for agent_id in agent_ids:
agent = self._agents.get(agent_id)
if agent is None:
logger.error("Agent %s not found", agent_id)
output[agent_id] = {"error": "Agent not found"}
continue
try:
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
result = await agent.reply(current_message)
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
result = await agent.run()
else:
result = await agent()
output[agent_id] = result
current_message = result
logger.info("Agent %s completed sequentially", agent_id)
except Exception as e:
logger.error("Agent %s failed: %s", agent_id, e)
output[agent_id] = {"error": str(e)}
break
logger.info("Sequential run completed for %d agents", len(agent_ids))
return output
async def run_phase(
self,
phase_name: str,
agent_ids: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> List[Any]:
"""Execute a named phase with registered agents.
Args:
phase_name: Name of the phase (e.g., "analyst_analysis")
agent_ids: Optional list of agent IDs; if None, uses all registered
metadata: Optional metadata to include in the message (e.g., tickers, date)
Returns:
List of results from each agent
"""
if agent_ids is None:
agent_ids = list(self._agents.keys())
_agent_ids = [aid for aid in agent_ids if aid in self._agents]
logger.info(
"Running phase '%s' with %d agents: %s",
phase_name,
len(_agent_ids),
_agent_ids,
)
# Create messages for each agent
results: List[Any] = []
for agent_id in _agent_ids:
agent = self._agents[agent_id]
try:
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
# Create a message for the agent with proper structure
msg = Msg(
name="system",
content=self._task_content or f"Please execute phase: {phase_name}",
role="user",
metadata=metadata,
)
result = await agent.reply(msg)
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
result = await agent.run()
else:
result = await agent()
results.append(result)
logger.info("Phase '%s': Agent %s completed", phase_name, agent_id)
except Exception as e:
logger.error("Phase '%s': Agent %s failed: %s", phase_name, agent_id, e)
results.append(None)
logger.info("Phase '%s' completed with %d results", phase_name, len(results))
return results
async def run_with_dependencies(
self,
agent_tasks: Dict[str, List[str]],
initial_message: Optional[Msg] = None,
) -> Dict[str, Any]:
"""Run agents respecting dependency graph.
Args:
agent_tasks: Dict mapping agent_id to list of prerequisite agent_ids
initial_message: Optional initial message
Returns:
Dict mapping agent_id to result
"""
completed: Dict[str, Any] = {}
remaining = set(agent_tasks.keys())
while remaining:
ready = [
aid for aid in remaining
if all(dep in completed for dep in agent_tasks.get(aid, []))
]
if not ready:
logger.error("Circular dependency detected in agent tasks")
for aid in remaining:
completed[aid] = {"error": "Circular dependency"}
break
results = await self.run_parallel(ready, initial_message)
completed.update(results)
for aid in ready:
remaining.discard(aid)
initial_message = results.get(aid)
return completed
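The ready/remaining loop above is effectively a wave-based topological sort over the dependency graph. A standalone sketch of just that scheduling rule (a hypothetical helper for illustration, not part of the module):

```python
from typing import Dict, List

def resolve_waves(agent_tasks: Dict[str, List[str]]) -> List[List[str]]:
    """Group agents into waves where each wave depends only on earlier waves."""
    completed: set = set()
    remaining = set(agent_tasks)
    waves: List[List[str]] = []
    while remaining:
        ready = sorted(
            aid for aid in remaining
            if all(dep in completed for dep in agent_tasks.get(aid, []))
        )
        if not ready:  # circular dependency: nothing can make progress
            waves.append(sorted(remaining))
            break
        waves.append(ready)
        completed.update(ready)
        remaining.difference_update(ready)
    return waves

# Analysts run first, then the researcher, then the portfolio manager.
waves = resolve_waves({
    "fundamental": [],
    "technical": [],
    "researcher": ["fundamental", "technical"],
    "portfolio_manager": ["researcher"],
})
print(waves)  # [['fundamental', 'technical'], ['researcher'], ['portfolio_manager']]
```

As in `run_with_dependencies`, a cycle is detected when no agent is ready while some remain, which is when the real method marks the leftovers with a "Circular dependency" error.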
async def fanout_pipeline(
self,
agents: List[Any],
msg: Optional[Msg] = None,
) -> List[Msg]:
"""Fanout a message to multiple agents concurrently and collect all responses.
Similar to AgentScope's fanout_pipeline, this sends the same message
to all specified agents and returns a list of all agent responses.
Args:
agents: List of agent instances to fanout the message to
msg: Message to send to all agents (optional)
Returns:
List of Msg responses from each agent (in the same order as input agents)
Example:
>>> responses = await fanout_pipeline(
... agents=[alice, bob, charlie],
... msg=question,
... )
>>> # responses is a list of Msg responses from each agent
"""
async def _fanout_to_agent(agent: Any) -> Optional[Msg]:
"""Send message to a single agent and return its response."""
try:
if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
result = await agent.reply(msg) if msg is not None else await agent.reply()
elif hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
result = await agent.run()
else:
result = await agent()
# Convert result to Msg if needed
if result is None:
return None
if isinstance(result, Msg):
return result
# If result is a dict with content, wrap it
if isinstance(result, dict) and "content" in result:
return Msg(
name=getattr(agent, "name", "unknown"),
content=result.get("content", ""),
role="assistant",
metadata=result.get("metadata"),
)
# Otherwise wrap the result
return Msg(
name=getattr(agent, "name", "unknown"),
content=str(result),
role="assistant",
)
except Exception as e:
logger.error("Agent %s failed in fanout_pipeline: %s",
getattr(agent, "name", "unknown"), e)
return None
# Run all agents concurrently
results = await asyncio.gather(
*[_fanout_to_agent(agent) for agent in agents],
return_exceptions=True,
)
# Filter out exceptions and keep only valid responses
responses: List[Msg] = []
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error("Fanout to agent %d failed: %s", i, result)
responses.append(None) # type: ignore[arg-type]
else:
responses.append(result) # type: ignore[arg-type]
logger.info("Fanout pipeline completed for %d agents", len(agents))
return responses
async def shutdown(self, timeout: Optional[float] = 5.0) -> None:
"""Shutdown all running agents gracefully.
Args:
timeout: Timeout for graceful shutdown
"""
logger.info("Shutting down TeamCoordinator...")
cancel_tasks = [
asyncio.create_task(asyncio.wait_for(task, timeout=timeout))
for task in self._running_tasks.values()
]
if cancel_tasks:
await asyncio.gather(*cancel_tasks, return_exceptions=True)
self._running_tasks.clear()
logger.info("TeamCoordinator shutdown complete")
@property
def agents(self) -> Dict[str, Any]:
"""Get copy of registered agents dict."""
return dict(self._agents)
__all__ = ["TeamCoordinator"]
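The reply/run/callable dispatch repeated throughout the coordinator can be exercised in isolation. A minimal sketch (the toy agent classes are illustrative, not from the codebase):

```python
import asyncio
from typing import Any

async def dispatch(agent: Any, msg: Any = None) -> Any:
    """Mirror of the coordinator's dispatch: prefer an async reply(),
    then an async run(), else treat the agent itself as an awaitable callable."""
    if hasattr(agent, "reply") and asyncio.iscoroutinefunction(agent.reply):
        return await agent.reply(msg) if msg is not None else await agent.reply()
    if hasattr(agent, "run") and asyncio.iscoroutinefunction(agent.run):
        return await agent.run()
    return await agent()

class ReplyAgent:
    async def reply(self, msg=None):
        return f"reply:{msg}"

class RunAgent:
    async def run(self):
        return "run"

async def bare():  # a plain coroutine function also satisfies `await agent()`
    return "callable"

print(asyncio.run(dispatch(ReplyAgent(), "hi")))  # reply:hi
print(asyncio.run(dispatch(RunAgent())))          # run
print(asyncio.run(dispatch(bare)))                # callable
```

Because `iscoroutinefunction` is checked on the bound attribute, a synchronous `reply` falls through to the next branch rather than being awaited by mistake.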

View File

@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
"""Run-scoped team pipeline configuration helpers."""

from __future__ import annotations

from pathlib import Path
from typing import Iterable, List, Dict, Any

import yaml

DEFAULT_FILENAME = "TEAM_PIPELINE.yaml"


def team_pipeline_path(project_root: Path, config_name: str) -> Path:
    """Return run-scoped team pipeline config path."""
    return project_root / "runs" / config_name / DEFAULT_FILENAME


def ensure_team_pipeline_config(
    project_root: Path,
    config_name: str,
    default_analysts: Iterable[str],
) -> Path:
    """Ensure TEAM_PIPELINE.yaml exists for one run."""
    path = team_pipeline_path(project_root, config_name)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        return path
    payload = {
        "version": 1,
        "controller_agent": "portfolio_manager",
        "discussion": {
            "allow_dynamic_team_update": True,
            "active_analysts": list(default_analysts),
        },
        "decision": {
            "require_risk_manager": True,
        },
    }
    path.write_text(
        yaml.safe_dump(payload, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )
    return path


def load_team_pipeline_config(project_root: Path, config_name: str) -> Dict[str, Any]:
    """Load TEAM_PIPELINE.yaml and return parsed dict."""
    path = team_pipeline_path(project_root, config_name)
    if not path.exists():
        return {}
    parsed = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    return parsed if isinstance(parsed, dict) else {}


def save_team_pipeline_config(
    project_root: Path,
    config_name: str,
    config: Dict[str, Any],
) -> Path:
    """Persist TEAM_PIPELINE.yaml."""
    path = team_pipeline_path(project_root, config_name)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(
        yaml.safe_dump(config, allow_unicode=True, sort_keys=False),
        encoding="utf-8",
    )
    return path


def resolve_active_analysts(
    project_root: Path,
    config_name: str,
    available_analysts: Iterable[str],
) -> List[str]:
    """Resolve active analysts from TEAM_PIPELINE.yaml."""
    available = list(available_analysts)
    parsed = load_team_pipeline_config(project_root, config_name)
    discussion = parsed.get("discussion", {}) if isinstance(parsed, dict) else {}
    configured = discussion.get("active_analysts", [])
    if not isinstance(configured, list) or not configured:
        return available
    active = [item for item in configured if item in available]
    return active or available


def update_active_analysts(
    project_root: Path,
    config_name: str,
    available_analysts: Iterable[str],
    *,
    add: Iterable[str] | None = None,
    remove: Iterable[str] | None = None,
    set_to: Iterable[str] | None = None,
) -> List[str]:
    """Update active analysts and persist TEAM_PIPELINE.yaml."""
    available = list(available_analysts)
    ensure_team_pipeline_config(project_root, config_name, available)
    parsed = load_team_pipeline_config(project_root, config_name)
    discussion = parsed.setdefault("discussion", {})
    if not isinstance(discussion, dict):
        discussion = {}
        parsed["discussion"] = discussion
    current = discussion.get("active_analysts", [])
    if not isinstance(current, list):
        current = []
    current = [item for item in current if item in available]
    if not current:
        current = list(available)
    if set_to is not None:
        target = [item for item in set_to if item in available]
        current = target or current
    for item in add or []:
        if item in available and item not in current:
            current.append(item)
    for item in remove or []:
        current = [existing for existing in current if existing != item]
    if not current:
        current = [available[0]] if available else []
    discussion["active_analysts"] = current
    save_team_pipeline_config(project_root, config_name, parsed)
    return current
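The fallback rules in `resolve_active_analysts` (missing or empty config means the full roster, unknown names are dropped, an all-invalid list falls back to everyone) can be checked without touching the filesystem. A pure-logic mirror of that filtering, for illustration:

```python
from typing import Iterable, List

def resolve(configured: object, available: Iterable[str]) -> List[str]:
    """Mirror of resolve_active_analysts' filtering: a non-list or empty config
    means all analysts; unknown names are dropped; never return an empty team."""
    roster = list(available)
    if not isinstance(configured, list) or not configured:
        return roster
    active = [item for item in configured if item in roster]
    return active or roster

available = ["fundamental", "technical", "news"]
print(resolve(None, available))               # full roster: no config yet
print(resolve(["news", "bogus"], available))  # ['news']: unknown name dropped
print(resolve(["bogus"], available))          # full roster: nothing valid left
```

The same "never empty" invariant shows up again in `update_active_analysts`, which falls back to the first available analyst when every name has been removed.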

View File

@@ -129,6 +129,33 @@ class RunWorkspaceManager:
        )
        return asset_dir

    def load_agent_file(
        self,
        *,
        config_name: str,
        agent_id: str,
        filename: str,
    ) -> str:
        """Load one run-scoped agent workspace file."""
        path = self.get_agent_asset_dir(config_name, agent_id) / filename
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")
        return path.read_text(encoding="utf-8")

    def update_agent_file(
        self,
        *,
        config_name: str,
        agent_id: str,
        filename: str,
        content: str,
    ) -> None:
        """Write one run-scoped agent workspace file."""
        asset_dir = self.get_agent_asset_dir(config_name, agent_id)
        asset_dir.mkdir(parents=True, exist_ok=True)
        path = asset_dir / filename
        path.write_text(content, encoding="utf-8")

    def initialize_default_assets(
        self,
        config_name: str,

View File

@@ -13,8 +13,13 @@ from typing import Any, Dict, List, Optional
from fastapi import APIRouter, HTTPException, Depends, Body, UploadFile, File, Form
from pydantic import BaseModel, Field
from backend.agents import AgentFactory, WorkspaceManager, get_registry
from backend.agents import AgentFactory, get_registry
from backend.agents.workspace_manager import RunWorkspaceManager
from backend.agents.agent_workspace import load_agent_workspace_config
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import load_agent_profiles
from backend.config.bootstrap_config import get_bootstrap_config_for_run
from backend.llm.models import get_agent_model_info
logger = logging.getLogger(__name__)
@@ -47,6 +52,14 @@ class InstallExternalSkillRequest(BaseModel):
activate: bool = Field(True, description="Whether to enable skill immediately")
class LocalSkillRequest(BaseModel):
skill_name: str = Field(..., description="Local skill name")
class LocalSkillContentRequest(BaseModel):
content: str = Field(..., description="Updated SKILL.md content")
class AgentResponse(BaseModel):
"""Agent information response."""
agent_id: str
@@ -63,6 +76,24 @@ class AgentFileResponse(BaseModel):
content: str
class AgentProfileResponse(BaseModel):
agent_id: str
workspace_id: str
profile: Dict[str, Any]
class AgentSkillsResponse(BaseModel):
agent_id: str
workspace_id: str
skills: List[Dict[str, Any]]
class SkillDetailResponse(BaseModel):
agent_id: str
workspace_id: str
skill: Dict[str, Any]
# Dependencies
def get_agent_factory():
"""Get AgentFactory instance."""
@@ -70,8 +101,8 @@ def get_agent_factory():
def get_workspace_manager():
"""Get WorkspaceManager instance."""
return WorkspaceManager()
"""Get run-scoped workspace manager instance."""
return RunWorkspaceManager()
def get_skills_manager():
@@ -199,6 +230,108 @@ async def get_agent(
)
@router.get("/{agent_id}/profile", response_model=AgentProfileResponse)
async def get_agent_profile(
workspace_id: str,
agent_id: str,
skills_manager: SkillsManager = Depends(get_skills_manager),
):
asset_dir = skills_manager.get_agent_asset_dir(workspace_id, agent_id)
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
profiles = load_agent_profiles()
profile = profiles.get(agent_id, {})
bootstrap = get_bootstrap_config_for_run(skills_manager.project_root, workspace_id)
override = bootstrap.agent_override(agent_id)
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
if not isinstance(active_tool_groups, list):
active_tool_groups = []
disabled_tool_groups = agent_config.disabled_tool_groups
if disabled_tool_groups:
disabled_set = set(disabled_tool_groups)
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
default_skills = profile.get("skills", [])
if not isinstance(default_skills, list):
default_skills = []
resolved_skills = skills_manager.resolve_agent_skill_names(
config_name=workspace_id,
agent_id=agent_id,
default_skills=default_skills,
)
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
model_name, model_provider = get_agent_model_info(agent_id)
return AgentProfileResponse(
agent_id=agent_id,
workspace_id=workspace_id,
profile={
"model_name": model_name,
"model_provider": model_provider,
"prompt_files": prompt_files,
"default_skills": default_skills,
"resolved_skills": resolved_skills,
"active_tool_groups": active_tool_groups,
"disabled_tool_groups": disabled_tool_groups,
"enabled_skills": agent_config.enabled_skills,
"disabled_skills": agent_config.disabled_skills,
},
)
@router.get("/{agent_id}/skills", response_model=AgentSkillsResponse)
async def get_agent_skills(
workspace_id: str,
agent_id: str,
skills_manager: SkillsManager = Depends(get_skills_manager),
):
agent_asset_dir = skills_manager.get_agent_asset_dir(workspace_id, agent_id)
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=workspace_id, agent_id=agent_id, default_skills=[]))
enabled = set(agent_config.enabled_skills)
disabled = set(agent_config.disabled_skills)
payload = []
for item in skills_manager.list_agent_skill_catalog(workspace_id, agent_id):
if item.skill_name in disabled:
status = "disabled"
elif item.skill_name in enabled:
status = "enabled"
elif item.skill_name in resolved_skills:
status = "active"
else:
status = "available"
payload.append({
"skill_name": item.skill_name,
"name": item.name,
"description": item.description,
"version": item.version,
"source": item.source,
"tools": item.tools,
"status": status,
})
return AgentSkillsResponse(agent_id=agent_id, workspace_id=workspace_id, skills=payload)
@router.get("/{agent_id}/skills/{skill_name}", response_model=SkillDetailResponse)
async def get_agent_skill_detail(
workspace_id: str,
agent_id: str,
skill_name: str,
skills_manager: SkillsManager = Depends(get_skills_manager),
):
try:
detail = skills_manager.load_agent_skill_document(
config_name=workspace_id,
agent_id=agent_id,
skill_name=skill_name,
)
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"Unknown skill: {skill_name}")
return SkillDetailResponse(agent_id=agent_id, workspace_id=workspace_id, skill=detail)
@router.delete("/{agent_id}")
async def delete_agent(
workspace_id: str,
@@ -386,6 +519,85 @@ async def install_external_skill(
}
@router.post("/{agent_id}/skills/local")
async def create_local_skill(
workspace_id: str,
agent_id: str,
request: LocalSkillRequest,
registry=Depends(get_registry),
):
agent_info = registry.get(agent_id)
if not agent_info or agent_info.workspace_id != workspace_id:
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
skills_manager = SkillsManager()
try:
skills_manager.create_agent_local_skill(
config_name=workspace_id,
agent_id=agent_id,
skill_name=request.skill_name,
)
except (ValueError, FileExistsError) as exc:
raise HTTPException(status_code=400, detail=str(exc))
return {"message": f"Created local skill '{request.skill_name}' for '{agent_id}'"}
@router.put("/{agent_id}/skills/local/{skill_name}")
async def update_local_skill(
workspace_id: str,
agent_id: str,
skill_name: str,
request: LocalSkillContentRequest,
registry=Depends(get_registry),
):
agent_info = registry.get(agent_id)
if not agent_info or agent_info.workspace_id != workspace_id:
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
skills_manager = SkillsManager()
try:
skills_manager.update_agent_local_skill(
config_name=workspace_id,
agent_id=agent_id,
skill_name=skill_name,
content=request.content,
)
except (ValueError, FileNotFoundError) as exc:
raise HTTPException(status_code=400, detail=str(exc))
return {"message": f"Updated local skill '{skill_name}' for '{agent_id}'"}
@router.delete("/{agent_id}/skills/local/{skill_name}")
async def delete_local_skill(
workspace_id: str,
agent_id: str,
skill_name: str,
registry=Depends(get_registry),
):
agent_info = registry.get(agent_id)
if not agent_info or agent_info.workspace_id != workspace_id:
raise HTTPException(status_code=404, detail=f"Agent '{agent_id}' not found")
skills_manager = SkillsManager()
try:
skills_manager.delete_agent_local_skill(
config_name=workspace_id,
agent_id=agent_id,
skill_name=skill_name,
)
skills_manager.forget_agent_skill_overrides(
config_name=workspace_id,
agent_id=agent_id,
skill_names=[skill_name],
)
except (ValueError, FileNotFoundError) as exc:
raise HTTPException(status_code=400, detail=str(exc))
return {"message": f"Deleted local skill '{skill_name}' for '{agent_id}'"}
@router.post("/{agent_id}/skills/upload")
async def upload_external_skill(
workspace_id: str,
@@ -441,7 +653,7 @@ async def get_agent_file(
workspace_id: str,
agent_id: str,
filename: str,
workspace_manager: WorkspaceManager = Depends(get_workspace_manager),
workspace_manager: RunWorkspaceManager = Depends(get_workspace_manager),
):
"""
Read an agent's workspace file.
@@ -471,7 +683,7 @@ async def update_agent_file(
agent_id: str,
filename: str,
content: str = Body(..., media_type="text/plain"),
workspace_manager: WorkspaceManager = Depends(get_workspace_manager),
workspace_manager: RunWorkspaceManager = Depends(get_workspace_manager),
):
"""
Update an agent's workspace file.

View File

@@ -8,6 +8,7 @@ import json
import logging
import os
import signal
import shutil
import subprocess
import sys
from datetime import datetime
@@ -16,20 +17,124 @@ from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
from pydantic import BaseModel, Field
from backend.runtime.agent_runtime import AgentRuntimeState
from backend.runtime.manager import TradingRuntimeManager, get_global_runtime_manager
from backend.config.bootstrap_config import (
resolve_runtime_config,
update_bootstrap_values_for_run,
)
router = APIRouter(prefix="/api/runtime", tags=["runtime"])
runtime_manager: Optional[TradingRuntimeManager] = None
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Gateway process management
_gateway_process: Optional[subprocess.Popen] = None
_gateway_port: int = 8765
class RuntimeState:
"""Thread-safe singleton for managing runtime state.
Encapsulates runtime_manager, _gateway_process, and _gateway_port
with asyncio.Lock protection for concurrent access.
"""
_instance: Optional["RuntimeState"] = None
_lock: "threading.Lock" = __import__("threading").Lock()  # lazy import; threading is not in the module-level imports
def __new__(cls) -> "RuntimeState":
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self) -> None:
if self._initialized:
return
self._runtime_manager: Optional[Any] = None
self._gateway_process: Optional[subprocess.Popen] = None
self._gateway_port: int = 8765
self._state_lock = asyncio.Lock()
self._initialized = True
@property
def lock(self) -> asyncio.Lock:
"""Get the asyncio lock for state synchronization (returns the lock itself, not a coroutine)."""
return self._state_lock
@property
def runtime_manager(self) -> Optional[Any]:
"""Get the runtime manager (no lock - read only)."""
return self._runtime_manager
@runtime_manager.setter
def runtime_manager(self, value: Optional[Any]) -> None:
"""Set the runtime manager."""
self._runtime_manager = value
@property
def gateway_process(self) -> Optional[subprocess.Popen]:
"""Get the gateway process (no lock - read only)."""
return self._gateway_process
@gateway_process.setter
def gateway_process(self, value: Optional[subprocess.Popen]) -> None:
"""Set the gateway process."""
self._gateway_process = value
@property
def gateway_port(self) -> int:
"""Get the gateway port."""
return self._gateway_port
@gateway_port.setter
def gateway_port(self, value: int) -> None:
"""Set the gateway port."""
self._gateway_port = value
async def set_runtime_manager(self, manager: Any) -> None:
"""Set runtime manager with lock protection."""
async with self._state_lock:
self._runtime_manager = manager
async def get_runtime_manager(self) -> Optional[Any]:
"""Get runtime manager with lock protection."""
async with self._state_lock:
return self._runtime_manager
async def set_gateway_process(self, process: Optional[subprocess.Popen]) -> None:
"""Set gateway process with lock protection."""
async with self._state_lock:
self._gateway_process = process
async def get_gateway_process(self) -> Optional[subprocess.Popen]:
"""Get gateway process with lock protection."""
async with self._state_lock:
return self._gateway_process
async def set_gateway_port(self, port: int) -> None:
"""Set gateway port with lock protection."""
async with self._state_lock:
self._gateway_port = port
async def get_gateway_port(self) -> int:
"""Get gateway port with lock protection."""
async with self._state_lock:
return self._gateway_port
# Singleton instance
_runtime_state = RuntimeState()
def get_runtime_state() -> RuntimeState:
"""Get the RuntimeState singleton instance."""
return _runtime_state
# Backward compatibility: module-level runtime_manager for external imports
# This is set by register_runtime_manager() for backward compatibility
runtime_manager: Optional[Any] = None
class RunContextResponse(BaseModel):
@@ -62,6 +167,8 @@ class RuntimeEventsResponse(BaseModel):
class LaunchConfig(BaseModel):
"""Configuration for launching a new trading task."""
launch_mode: str = Field(default="fresh", description="Launch mode: fresh, restore")
restore_run_id: Optional[str] = Field(default=None, description="Historical run_id used when restoring a previous run")
tickers: List[str] = Field(default_factory=list, description="Ticker pool")
schedule_mode: str = Field(default="daily", description="Schedule mode: daily, interval")
interval_minutes: int = Field(default=60, ge=1, description="Interval in minutes")
@@ -74,7 +181,6 @@ class LaunchConfig(BaseModel):
start_date: Optional[str] = Field(default=None, description="Backtest start date YYYY-MM-DD")
end_date: Optional[str] = Field(default=None, description="Backtest end date YYYY-MM-DD")
poll_interval: int = Field(default=10, ge=1, le=300, description="Market data polling interval (seconds)")
enable_mock: bool = Field(default=False, description="Whether to enable mock mode (simulated price data)")
class LaunchResponse(BaseModel):
@@ -85,17 +191,61 @@ class LaunchResponse(BaseModel):
message: str
class RuntimeHistoryItem(BaseModel):
run_id: str
run_dir: str
updated_at: Optional[str] = None
total_trades: int = 0
total_asset_value: Optional[float] = None
bootstrap: Dict[str, Any] = Field(default_factory=dict)
class RuntimeHistoryResponse(BaseModel):
runs: List[RuntimeHistoryItem]
class StopResponse(BaseModel):
status: str
message: str
class CleanupResponse(BaseModel):
status: str
kept: int
pruned_run_ids: List[str]
class GatewayStatusResponse(BaseModel):
is_running: bool
port: int
run_id: Optional[str] = None
class RuntimeConfigResponse(BaseModel):
run_id: str
is_running: bool
gateway_port: int
bootstrap: Dict[str, Any]
resolved: Dict[str, Any]
class RuntimeLogResponse(BaseModel):
run_id: Optional[str] = None
is_running: bool
log_path: Optional[str] = None
content: str = ""
class UpdateRuntimeConfigRequest(BaseModel):
schedule_mode: Optional[str] = None
interval_minutes: Optional[int] = Field(default=None, ge=1)
trigger_time: Optional[str] = None
max_comm_cycles: Optional[int] = Field(default=None, ge=1)
initial_cash: Optional[float] = Field(default=None, gt=0)
margin_requirement: Optional[float] = Field(default=None, ge=0)
enable_memory: Optional[bool] = None
def _generate_run_id() -> str:
"""Generate timestamp-based run ID: YYYYMMDD_HHMMSS"""
return datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -106,6 +256,128 @@ def _get_run_dir(run_id: str) -> Path:
return PROJECT_ROOT / "runs" / run_id
def _load_run_snapshot(run_id: str) -> Dict[str, Any]:
"""Load a specific run snapshot by run_id."""
snapshot_path = _get_run_dir(run_id) / "state" / "runtime_state.json"
if not snapshot_path.exists():
raise HTTPException(status_code=404, detail=f"Run snapshot not found: {run_id}")
return json.loads(snapshot_path.read_text(encoding="utf-8"))
def _copy_path_if_exists(src: Path, dst: Path) -> None:
if not src.exists():
return
if src.is_dir():
shutil.copytree(src, dst, dirs_exist_ok=True)
else:
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy2(src, dst)
def _restore_run_assets(source_run_id: str, target_run_dir: Path) -> None:
"""Seed a fresh run directory from a historical run snapshot."""
source_run_dir = _get_run_dir(source_run_id)
if not source_run_dir.exists():
raise HTTPException(status_code=404, detail=f"Source run not found: {source_run_id}")
for relative in [
"team_dashboard",
"agents",
"skills",
"memory",
"state/server_state.json",
"state/runtime.db",
"state/research.db",
]:
_copy_path_if_exists(source_run_dir / relative, target_run_dir / relative)
def _list_runs(limit: int = 50) -> list[RuntimeHistoryItem]:
runs_root = PROJECT_ROOT / "runs"
if not runs_root.exists():
return []
items: list[RuntimeHistoryItem] = []
run_dirs = sorted(
[path for path in runs_root.iterdir() if path.is_dir()],
key=lambda path: path.stat().st_mtime,
reverse=True,
)
for run_dir in run_dirs[: max(1, int(limit))]:
run_id = run_dir.name
runtime_state_path = run_dir / "state" / "runtime_state.json"
summary_path = run_dir / "team_dashboard" / "summary.json"
bootstrap: Dict[str, Any] = {}
updated_at: Optional[str] = None
total_trades = 0
total_asset_value: Optional[float] = None
if runtime_state_path.exists():
try:
snapshot = json.loads(runtime_state_path.read_text(encoding="utf-8"))
context = snapshot.get("context") or {}
bootstrap = dict(context.get("bootstrap_values") or {})
updated_at = snapshot.get("events", [{}])[-1].get("timestamp") if snapshot.get("events") else None
except Exception:
bootstrap = {}
if summary_path.exists():
try:
summary = json.loads(summary_path.read_text(encoding="utf-8"))
total_trades = int(summary.get("totalTrades") or 0)
total_asset_value = float(summary.get("totalAssetValue")) if summary.get("totalAssetValue") is not None else None
except Exception:
total_trades = 0
total_asset_value = None
items.append(
RuntimeHistoryItem(
run_id=run_id,
run_dir=str(run_dir),
updated_at=updated_at,
total_trades=total_trades,
total_asset_value=total_asset_value,
bootstrap=bootstrap,
)
)
return items
def _is_timestamped_run_dir(path: Path) -> bool:
try:
datetime.strptime(path.name, "%Y%m%d_%H%M%S")
return True
except ValueError:
return False
def _prune_old_timestamped_runs(*, keep: int = 20, exclude_run_ids: Optional[set[str]] = None) -> list[str]:
"""Prune old timestamped run directories, preserving the newest N and excluded ids."""
exclude = exclude_run_ids or set()
runs_root = PROJECT_ROOT / "runs"
if not runs_root.exists():
return []
candidates = sorted(
[
path
for path in runs_root.iterdir()
if path.is_dir() and _is_timestamped_run_dir(path) and path.name not in exclude
],
key=lambda path: path.name,
reverse=True,
)
pruned: list[str] = []
for path in candidates[max(0, keep):]:
shutil.rmtree(path, ignore_errors=True)
pruned.append(path.name)
return pruned
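The retention rule in `_prune_old_timestamped_runs` can be sketched as a pure selection function (a hypothetical helper for illustration; the real function also deletes the directories it selects):

```python
from datetime import datetime
from typing import Iterable, List

def select_prunable(names: Iterable[str], keep: int = 2, exclude: Iterable[str] = ()) -> List[str]:
    """Mirror of the prune selection: only YYYYMMDD_HHMMSS names are candidates,
    the newest `keep` survive, and excluded run ids are never selected."""
    excluded = set(exclude)

    def is_timestamped(name: str) -> bool:
        try:
            datetime.strptime(name, "%Y%m%d_%H%M%S")
            return True
        except ValueError:
            return False

    candidates = sorted(
        (n for n in names if is_timestamped(n) and n not in excluded),
        reverse=True,  # timestamp names sort newest-first lexicographically
    )
    return candidates[max(0, keep):]

runs = ["20260101_090000", "20260102_090000", "20260103_090000", "baseline"]
print(select_prunable(runs, keep=2))
# ['20260101_090000'] -- 'baseline' is never touched, the two newest are kept
```

Sorting by name rather than mtime works here precisely because the run ids are zero-padded timestamps, so lexicographic and chronological order coincide.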
def _find_available_port(start_port: int = 8765, max_port: int = 9000) -> int:
"""Find an available port for Gateway."""
import socket
@@ -118,31 +390,31 @@ def _find_available_port(start_port: int = 8765, max_port: int = 9000) -> int:
def _is_gateway_running() -> bool:
"""Check if Gateway process is running."""
global _gateway_process
if _gateway_process is None:
process = _runtime_state.gateway_process
if process is None:
return False
return _gateway_process.poll() is None
return process.poll() is None
def _stop_gateway() -> bool:
"""Stop the Gateway process."""
global _gateway_process
if _gateway_process is None:
process = _runtime_state.gateway_process
if process is None:
return False
try:
# Try graceful shutdown first
_gateway_process.terminate()
process.terminate()
try:
_gateway_process.wait(timeout=5)
process.wait(timeout=5)
except subprocess.TimeoutExpired:
# Force kill if graceful shutdown fails
_gateway_process.kill()
_gateway_process.wait()
process.kill()
process.wait()
except Exception as e:
logger.warning(f"Error during gateway shutdown: {e}")
finally:
_gateway_process = None
_runtime_state.gateway_process = None
return True
@@ -167,29 +439,29 @@ def _start_gateway_process(
"--bootstrap", json.dumps(bootstrap)
]
# Start process
log_path = run_dir / "logs" / "gateway.log"
log_path.parent.mkdir(parents=True, exist_ok=True)
log_file = log_path.open("ab")
try:
process = subprocess.Popen(
cmd,
env=env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
stdout=log_file,
stderr=subprocess.STDOUT,
cwd=PROJECT_ROOT
)
finally:
log_file.close()
return process
@router.get("/context", response_model=RunContextResponse)
async def get_run_context() -> RunContextResponse:
"""Return the most recent run context."""
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
if not snapshots:
raise HTTPException(status_code=404, detail="No run context available")
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
context = latest.get("context")
"""Return active runtime context, or latest persisted context when stopped."""
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
context = snapshot.get("context")
if context is None:
raise HTTPException(status_code=404, detail="Run context is not ready")
@@ -202,15 +474,9 @@ async def get_run_context() -> RunContextResponse:
@router.get("/agents", response_model=RuntimeAgentsResponse)
async def get_runtime_agents() -> RuntimeAgentsResponse:
"""Return agent states from the most recent run."""
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
if not snapshots:
raise HTTPException(status_code=404, detail="No runtime state available")
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
agents = latest.get("agents", [])
"""Return agent states from the active runtime, or latest persisted run."""
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
agents = snapshot.get("agents", [])
return RuntimeAgentsResponse(
agents=[RuntimeAgentState(**a) for a in agents]
@@ -219,58 +485,219 @@ async def get_runtime_agents() -> RuntimeAgentsResponse:
@router.get("/events", response_model=RuntimeEventsResponse)
async def get_runtime_events() -> RuntimeEventsResponse:
"""Return events from the most recent run."""
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
if not snapshots:
raise HTTPException(status_code=404, detail="No runtime state available")
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
events = latest.get("events", [])
"""Return events from the active runtime, or latest persisted run."""
snapshot = _get_active_runtime_snapshot() if _is_gateway_running() else _load_latest_runtime_snapshot()
events = snapshot.get("events", [])
return RuntimeEventsResponse(
events=[RuntimeEvent(**e) for e in events]
)
@router.get("/history", response_model=RuntimeHistoryResponse)
async def get_runtime_history(limit: int = 20) -> RuntimeHistoryResponse:
"""List recent historical runs for restore/start selection."""
return RuntimeHistoryResponse(runs=_list_runs(limit=limit))
@router.get("/gateway/status", response_model=GatewayStatusResponse)
async def get_gateway_status() -> GatewayStatusResponse:
"""Get Gateway process status and port."""
global _gateway_port
is_running = _is_gateway_running()
run_id = None
if is_running:
# Try to find run_id from runtime state
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
if snapshots:
try:
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
run_id = latest.get("context", {}).get("config_name")
run_id = _get_active_runtime_context().get("config_name")
except Exception as e:
logger.warning(f"Failed to parse latest snapshot: {e}")
logger.warning(f"Failed to resolve active runtime context: {e}")
return GatewayStatusResponse(
is_running=is_running,
port=_gateway_port,
port=_runtime_state.gateway_port,
run_id=run_id
)
@router.get("/gateway/port")
async def get_gateway_port() -> Dict[str, Any]:
async def get_gateway_port(request: Request) -> Dict[str, Any]:
"""Get WebSocket Gateway port for frontend connection."""
global _gateway_port
gateway_port = _runtime_state.gateway_port
return {
"port": _gateway_port,
"port": gateway_port,
"is_running": _is_gateway_running(),
"ws_url": f"ws://localhost:{_gateway_port}"
"ws_url": _build_gateway_ws_url(request, gateway_port),
}
@router.get("/logs", response_model=RuntimeLogResponse)
async def get_runtime_logs() -> RuntimeLogResponse:
"""Return current runtime log tail, or the latest run log if runtime is stopped."""
try:
context = _get_active_runtime_context() if _is_gateway_running() else _get_runtime_context_from_latest_snapshot()
except HTTPException:
return RuntimeLogResponse(is_running=False, content="")
run_id = str(context.get("config_name") or "").strip() or None
log_path = _get_gateway_log_path_for_run(run_id) if run_id else None
content = _read_log_tail(log_path) if log_path else ""
return RuntimeLogResponse(
run_id=run_id,
is_running=_is_gateway_running(),
log_path=str(log_path) if log_path else None,
content=content,
)
def _build_gateway_ws_url(request: Request, port: int) -> str:
"""Build a proxy-safe Gateway WebSocket URL."""
forwarded_proto = request.headers.get("x-forwarded-proto", "").split(",")[0].strip()
scheme = forwarded_proto or request.url.scheme
ws_scheme = "wss" if scheme == "https" else "ws"
forwarded_host = request.headers.get("x-forwarded-host", "").split(",")[0].strip()
host = forwarded_host or request.url.hostname or "localhost"
if ":" in host and not host.startswith("["):
host = host.split(":", 1)[0]
return f"{ws_scheme}://{host}:{port}"
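The header-precedence logic in `_build_gateway_ws_url` can be sketched standalone. In this sketch the plain `headers` dict and the `fallback_*` arguments stand in for FastAPI's `Request`; the function and parameter names are illustrative, not the project's:

```python
def build_ws_url(headers: dict, fallback_scheme: str, fallback_host: str, port: int) -> str:
    # x-forwarded-* headers set by a reverse proxy win over the direct request URL;
    # only the first entry of a comma-separated header value counts
    proto = headers.get("x-forwarded-proto", "").split(",")[0].strip()
    scheme = proto or fallback_scheme
    ws_scheme = "wss" if scheme == "https" else "ws"
    host = headers.get("x-forwarded-host", "").split(",")[0].strip() or fallback_host
    if ":" in host and not host.startswith("["):
        host = host.split(":", 1)[0]  # drop a trailing :port, keep bracketed IPv6 intact
    return f"{ws_scheme}://{host}:{port}"

print(build_ws_url({"x-forwarded-proto": "https, http",
                    "x-forwarded-host": "trade.example.com:443"},
                   "http", "localhost", 8765))
# → wss://trade.example.com:8765
```

Without forwarded headers the direct scheme and host fall through unchanged, so local development keeps producing `ws://localhost:<port>`.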
def _load_latest_runtime_snapshot() -> Dict[str, Any]:
"""Load the latest persisted runtime snapshot."""
snapshots = sorted(
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
key=lambda p: p.stat().st_mtime,
reverse=True,
)
if not snapshots:
raise HTTPException(status_code=404, detail="No runtime information available")
return json.loads(snapshots[0].read_text(encoding="utf-8"))
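Snapshot selection ranks every `runs/*/state/runtime_state.json` by file mtime and takes the newest. The same selection can be exercised in isolation (the run layout below is fabricated for the demo):

```python
import json
import os
import tempfile
from pathlib import Path

def load_latest_snapshot(root: Path) -> dict:
    # newest runtime_state.json across all runs wins, ranked by file mtime
    snapshots = sorted(
        root.glob("runs/*/state/runtime_state.json"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    if not snapshots:
        raise FileNotFoundError("No runtime information available")
    return json.loads(snapshots[0].read_text(encoding="utf-8"))

with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    for name, ts in [("run-a", 100.0), ("run-b", 200.0)]:
        state_dir = root / "runs" / name / "state"
        state_dir.mkdir(parents=True)
        f = state_dir / "runtime_state.json"
        f.write_text(json.dumps({"context": {"config_name": name}}), encoding="utf-8")
        os.utime(f, (ts, ts))  # force distinct mtimes so ordering is deterministic
    print(load_latest_snapshot(root)["context"]["config_name"])  # → run-b
```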
def _get_active_runtime_snapshot() -> Dict[str, Any]:
"""Return the active runtime snapshot, preferring in-memory manager state."""
if not _is_gateway_running():
raise HTTPException(status_code=404, detail="No runtime is currently running")
manager = _runtime_state.runtime_manager
if manager is not None and hasattr(manager, "build_snapshot"):
snapshot = manager.build_snapshot()
context = snapshot.get("context") or {}
if context.get("config_name"):
return snapshot
return _load_latest_runtime_snapshot()
def _get_runtime_context_from_latest_snapshot() -> Dict[str, Any]:
"""Return the latest persisted runtime context regardless of active process state."""
latest = _load_latest_runtime_snapshot()
context = latest.get("context") or {}
if not context.get("config_name"):
raise HTTPException(status_code=404, detail="No runtime context available")
return context
def _get_gateway_log_path_for_run(run_id: str) -> Path:
return _get_run_dir(run_id) / "logs" / "gateway.log"
def _read_log_tail(path: Path, max_chars: int = 120_000) -> str:
if not path.exists() or not path.is_file():
return ""
text = path.read_text(encoding="utf-8", errors="replace")
if len(text) <= max_chars:
return text
return text[-max_chars:]
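`_read_log_tail` keeps at most `max_chars` characters from the end of the file and returns an empty string for missing files. A self-contained mirror of that behavior:

```python
import tempfile
from pathlib import Path

def read_log_tail(path: Path, max_chars: int = 120_000) -> str:
    # missing file -> empty string; oversized file -> last max_chars characters
    if not path.exists() or not path.is_file():
        return ""
    text = path.read_text(encoding="utf-8", errors="replace")
    return text if len(text) <= max_chars else text[-max_chars:]

with tempfile.TemporaryDirectory() as d:
    log = Path(d) / "gateway.log"
    log.write_text("a" * 50 + "TAIL", encoding="utf-8")
    print(read_log_tail(log, max_chars=10))       # → aaaaaaTAIL
    print(repr(read_log_tail(Path(d) / "missing.log")))  # → ''
```

Truncating by characters rather than lines trades exact line boundaries for a hard upper bound on response size, which is the sensible tradeoff for a log-tail API payload.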
def _get_current_runtime_context() -> Dict[str, Any]:
"""Return the active runtime context from the latest snapshot."""
if not _is_gateway_running():
raise HTTPException(status_code=404, detail="No runtime is currently running")
snapshot = _get_active_runtime_snapshot()
context = snapshot.get("context") or {}
if not context.get("config_name"):
raise HTTPException(status_code=404, detail="No runtime context available")
return context
def _get_active_runtime_context() -> Dict[str, Any]:
"""Return the active runtime context, preferring in-memory runtime manager state."""
return _get_current_runtime_context()
def _resolve_runtime_response(run_id: str) -> RuntimeConfigResponse:
"""Build a normalized runtime config response for the active run."""
context = _get_current_runtime_context()
bootstrap = dict(context.get("bootstrap_values") or {})
resolved = resolve_runtime_config(
project_root=PROJECT_ROOT,
config_name=run_id,
enable_memory=bool(bootstrap.get("enable_memory", False)),
schedule_mode=str(bootstrap.get("schedule_mode", "daily")),
interval_minutes=int(bootstrap.get("interval_minutes", 60) or 60),
trigger_time=str(bootstrap.get("trigger_time", "09:30") or "09:30"),
)
return RuntimeConfigResponse(
run_id=run_id,
is_running=True,
gateway_port=_runtime_state.gateway_port,
bootstrap=bootstrap,
resolved=resolved,
)
def _normalize_runtime_config_updates(
request: UpdateRuntimeConfigRequest,
) -> Dict[str, Any]:
"""Validate and normalize runtime config updates."""
updates: Dict[str, Any] = {}
if request.schedule_mode is not None:
schedule_mode = str(request.schedule_mode).strip().lower()
if schedule_mode not in {"daily", "intraday"}:
raise HTTPException(
status_code=400,
detail="schedule_mode must be 'daily' or 'intraday'",
)
updates["schedule_mode"] = schedule_mode
if request.interval_minutes is not None:
updates["interval_minutes"] = int(request.interval_minutes)
if request.trigger_time is not None:
trigger_time = str(request.trigger_time).strip()
if trigger_time and trigger_time != "now":
try:
datetime.strptime(trigger_time, "%H:%M")
except ValueError as exc:
raise HTTPException(
status_code=400,
detail="trigger_time must use HH:MM or 'now'",
) from exc
updates["trigger_time"] = trigger_time or "09:30"
if request.max_comm_cycles is not None:
updates["max_comm_cycles"] = int(request.max_comm_cycles)
if request.initial_cash is not None:
updates["initial_cash"] = float(request.initial_cash)
if request.margin_requirement is not None:
updates["margin_requirement"] = float(request.margin_requirement)
if request.enable_memory is not None:
updates["enable_memory"] = bool(request.enable_memory)
if not updates:
raise HTTPException(status_code=400, detail="No runtime config updates provided")
return updates
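The trigger-time rule in `_normalize_runtime_config_updates` (strict `HH:MM`, a literal `now` passthrough, empty falling back to `09:30`) can be isolated like this; the helper name is illustrative:

```python
from datetime import datetime

def normalize_trigger_time(raw: str) -> str:
    # empty -> default, "now" passes through, anything else must parse as HH:MM
    trigger_time = raw.strip()
    if trigger_time and trigger_time != "now":
        datetime.strptime(trigger_time, "%H:%M")  # raises ValueError on bad input
    return trigger_time or "09:30"

print(normalize_trigger_time("22:30"))  # → 22:30
print(normalize_trigger_time(""))      # → 09:30
print(normalize_trigger_time("now"))   # → now
```

In the endpoint the `ValueError` is translated into an HTTP 400 with a human-readable detail message, so clients get validation feedback instead of a stack trace.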
@router.post("/start", response_model=LaunchResponse)
async def start_runtime(
config: LaunchConfig,
@@ -284,19 +711,38 @@ async def start_runtime(
4. Start Gateway as subprocess (Data Plane)
5. Return Gateway port for WebSocket connection
"""
- global _gateway_process, _gateway_port
# Lazy import to avoid circular dependency
from backend.runtime.manager import TradingRuntimeManager
# 1. Stop existing Gateway
if _is_gateway_running():
_stop_gateway()
await asyncio.sleep(1) # Wait for port release
- # 2. Generate run ID and directory
+ launch_mode = str(config.launch_mode or "fresh").strip().lower()
+ if launch_mode not in {"fresh", "restore"}:
+ raise HTTPException(status_code=400, detail="launch_mode must be 'fresh' or 'restore'")
+ # 2. Resolve run ID, directory, and bootstrap
if launch_mode == "restore":
restore_run_id = str(config.restore_run_id or "").strip()
if not restore_run_id:
raise HTTPException(status_code=400, detail="restore_run_id is required when launch_mode=restore")
snapshot = _load_run_snapshot(restore_run_id)
context = snapshot.get("context") or {}
if not context.get("config_name"):
raise HTTPException(status_code=404, detail=f"Run context not found: {restore_run_id}")
run_id = restore_run_id
run_dir = _get_run_dir(run_id)
bootstrap = dict(context.get("bootstrap_values") or {})
bootstrap["launch_mode"] = "restore"
bootstrap["restore_run_id"] = restore_run_id
else:
run_id = _generate_run_id()
run_dir = _get_run_dir(run_id)
# 3. Prepare bootstrap config
bootstrap = {
"launch_mode": "fresh",
"restore_run_id": None,
"tickers": config.tickers,
"schedule_mode": config.schedule_mode,
"interval_minutes": config.interval_minutes,
@@ -309,9 +755,16 @@ async def start_runtime(
"start_date": config.start_date,
"end_date": config.end_date,
"poll_interval": config.poll_interval,
- "enable_mock": config.enable_mock,
}
retention_keep = max(1, int(os.getenv("RUNS_RETENTION_COUNT", "20") or "20"))
pruned_run_ids = _prune_old_timestamped_runs(
keep=retention_keep,
exclude_run_ids={run_id},
)
if pruned_run_ids:
logger.info("Pruned old run directories: %s", ", ".join(pruned_run_ids))
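A minimal sketch of what `_prune_old_timestamped_runs` plausibly does: keep only the newest `keep` timestamped run directories while never touching excluded (active) or non-timestamped (named) runs. The `YYYYMMDD-HHMMSS` naming convention and flat `runs/` layout are assumptions for illustration, not the project's confirmed implementation:

```python
import re
import shutil
import tempfile
from pathlib import Path

RUN_ID_RE = re.compile(r"^\d{8}-\d{6}$")  # assumed timestamped id, e.g. 20260326-133851

def prune_old_runs(runs_root: Path, keep: int, exclude: set[str]) -> list[str]:
    candidates = sorted(
        (p for p in runs_root.iterdir()
         if p.is_dir() and RUN_ID_RE.match(p.name) and p.name not in exclude),
        key=lambda p: p.name,  # timestamped names sort chronologically
        reverse=True,
    )
    pruned = []
    for stale in candidates[keep:]:  # everything past the newest `keep` entries
        shutil.rmtree(stale)
        pruned.append(stale.name)
    return pruned

with tempfile.TemporaryDirectory() as d:
    root = Path(d)
    for name in ["20260101-000000", "20260102-000000", "20260103-000000", "baseline"]:
        (root / name).mkdir()
    print(prune_old_runs(root, keep=1, exclude={"20260101-000000"}))
    # → ['20260102-000000']  (named "baseline" and the excluded run survive)
```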
# 4. Create runtime manager
manager = TradingRuntimeManager(
config_name=run_id,
@@ -325,25 +778,28 @@ async def start_runtime(
_write_bootstrap_md(run_dir, bootstrap)
# 6. Find available port and start Gateway process
- _gateway_port = _find_available_port(start_port=8765)
+ gateway_port = _find_available_port(start_port=8765)
+ _runtime_state.gateway_port = gateway_port
try:
- _gateway_process = _start_gateway_process(
+ process = _start_gateway_process(
run_id=run_id,
run_dir=run_dir,
bootstrap=bootstrap,
- port=_gateway_port
+ port=gateway_port
)
_runtime_state.gateway_process = process
# Wait briefly to check if process started successfully
await asyncio.sleep(2)
if not _is_gateway_running():
- stdout, stderr = _gateway_process.communicate(timeout=1)
- _gateway_process = None
+ _runtime_state.gateway_process = None
+ log_path = _get_gateway_log_path_for_run(run_id)
+ log_tail = _read_log_tail(log_path, max_chars=4000)
raise HTTPException(
status_code=500,
- detail=f"Gateway failed to start: {stderr.decode() if stderr else 'Unknown error'}"
+ detail=f"Gateway failed to start: {log_tail or 'Unknown error'}"
)
except Exception as e:
@@ -354,16 +810,44 @@ async def start_runtime(
run_id=run_id,
status="started",
run_dir=str(run_dir),
- gateway_port=_gateway_port,
- message=f"Runtime started with run_id: {run_id}, Gateway on port: {_gateway_port}",
+ gateway_port=gateway_port,
+ message=f"Runtime started with run_id: {run_id}, Gateway on port: {gateway_port}",
)
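`_find_available_port(start_port=8765)` is not shown in this diff; a plausible sketch of such a helper (a sequential bind probe, an assumption about the real implementation) looks like this:

```python
import socket

def find_available_port(start_port: int = 8765, max_tries: int = 50) -> int:
    # probe ports sequentially; the first successful bind wins
    for port in range(start_port, start_port + max_tries):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("127.0.0.1", port))
                return port  # socket closes on exit, freeing the port for the Gateway
            except OSError:
                continue
    raise RuntimeError(f"No free port in [{start_port}, {start_port + max_tries})")

port = find_available_port()
print(8765 <= port < 8815)  # → True
```

Note the inherent race: the port is free at probe time but another process could claim it before the Gateway binds, which is why the start flow still checks `_is_gateway_running()` after launch.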
@router.get("/config", response_model=RuntimeConfigResponse)
async def get_runtime_config() -> RuntimeConfigResponse:
"""Return the current runtime bootstrap and resolved settings."""
context = _get_current_runtime_context()
return _resolve_runtime_response(context["config_name"])
@router.put("/config", response_model=RuntimeConfigResponse)
async def update_runtime_config(
request: UpdateRuntimeConfigRequest,
) -> RuntimeConfigResponse:
"""Persist selected runtime configuration updates for the active run."""
context = _get_current_runtime_context()
run_id = context["config_name"]
updates = _normalize_runtime_config_updates(request)
updated = update_bootstrap_values_for_run(PROJECT_ROOT, run_id, updates)
manager = _runtime_state.runtime_manager
if manager is not None and getattr(manager, "config_name", None) == run_id:
manager.bootstrap.update(updates)
if getattr(manager, "context", None) is not None:
manager.context.bootstrap_values.update(updates)
if hasattr(manager, "_persist_snapshot"):
manager._persist_snapshot()
response = _resolve_runtime_response(run_id)
response.bootstrap = dict(updated.values)
return response
@router.post("/stop", response_model=StopResponse)
async def stop_runtime(force: bool = True) -> StopResponse:
"""Stop the current running runtime."""
global _gateway_process
was_running = _is_gateway_running()
if not was_running:
@@ -381,6 +865,25 @@ async def stop_runtime(force: bool = True) -> StopResponse:
)
@router.post("/cleanup", response_model=CleanupResponse)
async def cleanup_old_runs(keep: int = 20) -> CleanupResponse:
"""Prune old timestamped run directories while preserving named runs."""
keep_count = max(1, int(keep))
exclude: set[str] = set()
if _is_gateway_running():
try:
active_context = _get_active_runtime_context()
active_run_id = str(active_context.get("config_name") or "").strip()
if active_run_id:
exclude.add(active_run_id)
except HTTPException:
pass
pruned = _prune_old_timestamped_runs(keep=keep_count, exclude_run_ids=exclude)
return CleanupResponse(status="ok", kept=keep_count, pruned_run_ids=pruned)
@router.post("/restart")
async def restart_runtime(
config: LaunchConfig,
@@ -407,35 +910,31 @@ async def get_current_runtime():
if not _is_gateway_running():
raise HTTPException(status_code=404, detail="No runtime is currently running")
- # Find latest runtime state
- snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
- snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
- if not snapshots:
- raise HTTPException(status_code=404, detail="No runtime information available")
- latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
- context = latest.get("context", {})
+ context = _get_active_runtime_context()
return {
"run_id": context.get("config_name"),
"run_dir": context.get("run_dir"),
"is_running": True,
- "gateway_port": _gateway_port,
+ "gateway_port": _runtime_state.gateway_port,
"bootstrap": context.get("bootstrap_values", {}),
}
- def register_runtime_manager(manager: TradingRuntimeManager) -> None:
+ def register_runtime_manager(manager: Any) -> None:
"""Allow other modules to expose the runtime manager to the API."""
global runtime_manager
runtime_manager = manager
# Also update the RuntimeState singleton for internal consistency
_runtime_state.runtime_manager = manager
def unregister_runtime_manager() -> None:
"""Drop the runtime manager reference."""
global runtime_manager
runtime_manager = None
# Also update the RuntimeState singleton for internal consistency
_runtime_state.runtime_manager = None
def _write_bootstrap_md(run_dir: Path, bootstrap: Dict[str, Any]) -> None:


@@ -1,115 +0,0 @@
# -*- coding: utf-8 -*-
"""
FastAPI Application - REST API for EvoTraders
Provides HTTP endpoints for:
- Agent management
- Workspace management
- Tool guard operations
- Health checks
"""
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from backend.api import agents_router, workspaces_router, guard_router, runtime_router
from backend.agents import AgentFactory, WorkspaceManager, get_registry
# Global instances (initialized on startup)
agent_factory: AgentFactory | None = None
workspace_manager: WorkspaceManager | None = None
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
"""
Application lifespan manager.
Initializes global services on startup and cleans up on shutdown.
"""
global agent_factory, workspace_manager
# Startup: Initialize services
project_root = Path(__file__).parent.parent
# Initialize workspace manager
workspace_manager = WorkspaceManager(project_root=project_root)
# Initialize agent factory
agent_factory = AgentFactory(project_root=project_root)
# Ensure workspaces root exists
agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True)
# Get or create global registry
registry = get_registry()
print(f"✓ EvoTraders API started")
print(f" - Workspaces root: {agent_factory.workspaces_root}")
print(f" - Registered agents: {registry.get_agent_count()}")
yield
# Shutdown: Cleanup
print("✓ EvoTraders API shutting down")
# Create FastAPI application
app = FastAPI(
title="EvoTraders API",
description="REST API for the EvoTraders multi-agent trading system",
version="0.1.0",
lifespan=lifespan,
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Configure appropriately for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint."""
registry = get_registry()
return {
"status": "healthy",
"version": "0.1.0",
"agents_registered": registry.get_agent_count(),
"workspaces_available": len(workspace_manager.list_workspaces()) if workspace_manager else 0,
}
# API status endpoint
@app.get("/api/status")
async def api_status():
"""Get API status and system information."""
registry = get_registry()
stats = registry.get_stats()
return {
"status": "operational",
"registry": stats,
}
# Include routers
app.include_router(workspaces_router)
app.include_router(agents_router)
app.include_router(guard_router)
app.include_router(runtime_router)
# Main entry point for running with uvicorn
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

backend/apps/__init__.py Normal file

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""Application surfaces for progressive service extraction."""
from .agent_service import app as agent_app
from .agent_service import create_app as create_agent_app
from .news_service import app as news_app
from .news_service import create_app as create_news_app
from .runtime_service import app as runtime_app
from .runtime_service import create_app as create_runtime_app
from .trading_service import app as trading_app
from .trading_service import create_app as create_trading_app
from .cors import add_cors_middleware, get_cors_origins
app = agent_app
create_app = create_agent_app
__all__ = [
"app",
"create_app",
"agent_app",
"create_agent_app",
"news_app",
"create_news_app",
"runtime_app",
"create_runtime_app",
"trading_app",
"create_trading_app",
"add_cors_middleware",
"get_cors_origins",
]


@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
"""Agent control-plane FastAPI surface."""
from __future__ import annotations
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator
from fastapi import FastAPI
from backend.apps.cors import add_cors_middleware
from backend.api import agents_router, guard_router, workspaces_router
from backend.agents import AgentFactory, WorkspaceManager, get_registry
# Global instances (initialized on startup)
agent_factory: AgentFactory | None = None
workspace_manager: WorkspaceManager | None = None
def create_app(project_root: Path | None = None) -> FastAPI:
"""Create the agent control-plane app."""
resolved_project_root = project_root or Path(__file__).resolve().parents[2]
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
"""Initialize workspace and registry state for the control plane."""
global agent_factory, workspace_manager
workspace_manager = WorkspaceManager(project_root=resolved_project_root)
agent_factory = AgentFactory(project_root=resolved_project_root)
agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True)
registry = get_registry()
print("✓ EvoTraders API started")
print(f" - Workspaces root: {agent_factory.workspaces_root}")
print(f" - Registered agents: {registry.get_agent_count()}")
yield
print("✓ EvoTraders API shutting down")
app = FastAPI(
title="EvoTraders Agent Service",
description="REST API for the EvoTraders multi-agent control plane",
version="0.1.0",
lifespan=lifespan,
)
add_cors_middleware(app)
@app.get("/health")
async def health_check() -> dict[str, object]:
"""Health check endpoint."""
registry = get_registry()
return {
"status": "healthy",
"version": "0.1.0",
"agents_registered": registry.get_agent_count(),
"workspaces_available": (
len(workspace_manager.list_workspaces())
if workspace_manager
else 0
),
}
@app.get("/api/status")
async def api_status() -> dict[str, object]:
"""Get API status and registry information."""
registry = get_registry()
return {
"status": "operational",
"registry": registry.get_stats(),
}
app.include_router(workspaces_router)
app.include_router(agents_router)
app.include_router(guard_router)
return app
app = create_app()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

backend/apps/cors.py Normal file

@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""Shared CORS configuration for all microservice apps."""
import os
from typing import Sequence
from fastapi.middleware.cors import CORSMiddleware
def get_cors_origins() -> Sequence[str]:
"""Get allowed CORS origins from environment variable.
Defaults to ["*"] for backward compatibility.
Set CORS_ALLOWED_ORIGINS env var (comma-separated) in production.
"""
origins = os.getenv("CORS_ALLOWED_ORIGINS", "").strip()
if not origins:
return ["*"]
return [o.strip() for o in origins.split(",") if o.strip()]
def add_cors_middleware(app: "FastAPI") -> None:
"""Add CORS middleware to app with environment-configured origins."""
app.add_middleware(
CORSMiddleware,
allow_origins=get_cors_origins(),
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
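`get_cors_origins` falls back to `["*"]` when the variable is unset and tolerates stray whitespace and trailing commas, which the parsing below demonstrates:

```python
import os

def get_cors_origins() -> list[str]:
    origins = os.getenv("CORS_ALLOWED_ORIGINS", "").strip()
    if not origins:
        return ["*"]  # permissive default, matching the backward-compat behavior above
    return [o.strip() for o in origins.split(",") if o.strip()]

os.environ["CORS_ALLOWED_ORIGINS"] = " https://app.example.com, https://admin.example.com ,"
print(get_cors_origins())  # → ['https://app.example.com', 'https://admin.example.com']
os.environ.pop("CORS_ALLOWED_ORIGINS")
print(get_cors_origins())  # → ['*']
```

Centralizing this in one module means all four service apps pick up a production origin allow-list from a single environment variable.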


@@ -0,0 +1,154 @@
# -*- coding: utf-8 -*-
"""News and explain FastAPI surface."""
from __future__ import annotations
from typing import Any
from fastapi import Depends, FastAPI, Query
from backend.apps.cors import add_cors_middleware
from backend.data.market_store import MarketStore
from backend.domains import news as news_domain
def get_market_store() -> MarketStore:
"""Get the MarketStore singleton dependency."""
return MarketStore.get_instance()
def create_app() -> FastAPI:
"""Create the news/explain service app."""
app = FastAPI(
title="EvoTraders News Service",
description="Read-only news enrichment and explain service surface extracted from the monolith",
version="0.1.0",
)
add_cors_middleware(app)
@app.get("/health")
async def health_check() -> dict[str, str]:
return {"status": "healthy", "service": "news-service"}
@app.get("/api/enriched-news")
async def api_get_enriched_news(
ticker: str = Query(..., min_length=1),
start_date: str | None = Query(None),
end_date: str | None = Query(None),
limit: int = Query(100, ge=1, le=1000),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_enriched_news(
store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
refresh_if_stale=False,
)
@app.get("/api/news-for-date")
async def api_get_news_for_date(
ticker: str = Query(..., min_length=1),
date: str = Query(...),
limit: int = Query(20, ge=1, le=100),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_news_for_date(
store,
ticker=ticker,
date=date,
limit=limit,
refresh_if_stale=False,
)
@app.get("/api/news-timeline")
async def api_get_news_timeline(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_news_timeline(
store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
refresh_if_stale=False,
)
@app.get("/api/categories")
async def api_get_categories(
ticker: str = Query(..., min_length=1),
start_date: str | None = Query(None),
end_date: str | None = Query(None),
limit: int = Query(200, ge=1, le=1000),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_news_categories(
store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
refresh_if_stale=False,
)
@app.get("/api/similar-days")
async def api_get_similar_days(
ticker: str = Query(..., min_length=1),
date: str = Query(...),
n_similar: int = Query(5, ge=1, le=20),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_similar_days_payload(
store,
ticker=ticker,
date=date,
n_similar=n_similar,
refresh_if_stale=False,
)
@app.get("/api/stories/{ticker}")
async def api_get_story(
ticker: str,
as_of_date: str = Query(...),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_story_payload(
store,
ticker=ticker,
as_of_date=as_of_date,
refresh_if_stale=False,
)
@app.get("/api/range-explain")
async def api_get_range_explain(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),
article_ids: list[str] = Query(default=[]),
limit: int = Query(100, ge=1, le=500),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_range_explain_payload(
store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids,
limit=limit,
refresh_if_stale=False,
)
return app
app = create_app()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8002)


@@ -0,0 +1,62 @@
# -*- coding: utf-8 -*-
"""Dedicated runtime service FastAPI surface."""
from __future__ import annotations
from fastapi import FastAPI
from backend.api import runtime_router
from backend.api.runtime import get_runtime_state
from backend.apps.cors import add_cors_middleware
def create_app() -> FastAPI:
"""Create the runtime service app."""
app = FastAPI(
title="EvoTraders Runtime Service",
description="Runtime lifecycle and gateway service surface extracted from the monolith",
version="0.1.0",
)
add_cors_middleware(app)
@app.get("/health")
async def health_check() -> dict[str, object]:
"""Health check for the runtime service."""
runtime_state = get_runtime_state()
process = runtime_state.gateway_process
is_running = process is not None and process.poll() is None
return {
"status": "healthy",
"service": "runtime-service",
"gateway_running": is_running,
"gateway_port": runtime_state.gateway_port,
}
@app.get("/api/status")
async def api_status() -> dict[str, object]:
"""Service-level status payload for runtime orchestration."""
runtime_state = get_runtime_state()
process = runtime_state.gateway_process
is_running = process is not None and process.poll() is None
return {
"status": "operational",
"service": "runtime-service",
"runtime": {
"gateway_running": is_running,
"gateway_port": runtime_state.gateway_port,
"has_runtime_manager": runtime_state.runtime_manager is not None,
},
}
app.include_router(runtime_router)
return app
app = create_app()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8003)
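Both endpoints above decide liveness with `process is not None and process.poll() is None`. The `poll()` contract is easy to verify with a throwaway child process:

```python
import subprocess
import sys

# spawn a short-lived child: poll() is None while it runs, the exit code afterwards
proc = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(0.2)"])
print(proc.poll() is None)  # → True (still running)
proc.wait()
print(proc.poll())          # → 0 (exited cleanly)
```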


@@ -0,0 +1,136 @@
# -*- coding: utf-8 -*-
"""Trading data FastAPI surface."""
from __future__ import annotations
from typing import Any
from fastapi import FastAPI, Query
from backend.apps.cors import add_cors_middleware
from backend.domains import trading as trading_domain
from shared.schema import (
CompanyNewsResponse,
FinancialMetricsResponse,
InsiderTradeResponse,
LineItemResponse,
PriceResponse,
)
def create_app() -> FastAPI:
"""Create the trading data service app."""
app = FastAPI(
title="EvoTraders Trading Service",
description="Read-only trading data service surface extracted from the monolith",
version="0.1.0",
)
add_cors_middleware(app)
@app.get("/health")
async def health_check() -> dict[str, str]:
"""Health check endpoint."""
return {"status": "healthy", "service": "trading-service"}
@app.get("/api/prices", response_model=PriceResponse)
async def api_get_prices(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),
) -> PriceResponse:
payload = trading_domain.get_prices_payload(
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
@app.get("/api/financials", response_model=FinancialMetricsResponse)
async def api_get_financials(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
period: str = Query("ttm"),
limit: int = Query(10, ge=1, le=100),
) -> FinancialMetricsResponse:
payload = trading_domain.get_financials_payload(
ticker=ticker,
end_date=end_date,
period=period,
limit=limit,
)
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
@app.get("/api/news", response_model=CompanyNewsResponse)
async def api_get_news(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
start_date: str | None = Query(None),
limit: int = Query(1000, ge=1, le=5000),
) -> CompanyNewsResponse:
payload = trading_domain.get_news_payload(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
return CompanyNewsResponse(news=payload["news"])
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
async def api_get_insider_trades(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
start_date: str | None = Query(None),
limit: int = Query(1000, ge=1, le=5000),
) -> InsiderTradeResponse:
payload = trading_domain.get_insider_trades_payload(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
@app.get("/api/market/status")
async def api_get_market_status() -> dict[str, Any]:
"""Return current market status using the existing market service logic."""
return trading_domain.get_market_status_payload()
@app.get("/api/market-cap")
async def api_get_market_cap(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
) -> dict[str, Any]:
"""Return market cap for one ticker/date."""
return trading_domain.get_market_cap_payload(
ticker=ticker,
end_date=end_date,
)
@app.get("/api/line-items", response_model=LineItemResponse)
async def api_get_line_items(
ticker: str = Query(..., min_length=1),
line_items: list[str] = Query(...),
end_date: str = Query(...),
period: str = Query("ttm"),
limit: int = Query(10, ge=1, le=100),
) -> LineItemResponse:
payload = trading_domain.get_line_items_payload(
ticker=ticker,
line_items=line_items,
end_date=end_date,
period=period,
limit=limit,
)
return LineItemResponse(search_results=payload["search_results"])
return app
app = create_app()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)


@@ -29,7 +29,7 @@ from rich.table import Table
from dotenv import load_dotenv
from backend.agents.agent_workspace import load_agent_workspace_config
- from backend.agents.prompt_loader import PromptLoader
+ from backend.agents.prompt_loader import get_prompt_loader
from backend.agents.skills_manager import SkillsManager
from backend.agents.team_pipeline_config import (
ensure_team_pipeline_config,
@@ -55,7 +55,7 @@ team_app = typer.Typer(help="Inspect and manage run-scoped team pipeline config.
app.add_typer(team_app, name="team")
console = Console()
- _prompt_loader = PromptLoader()
+ _prompt_loader = get_prompt_loader()
load_dotenv()
@@ -1019,11 +1019,6 @@ def backtest(
@app.command()
def live(
- mock: bool = typer.Option(
- False,
- "--mock",
- help="Use mock mode with simulated prices (for testing)",
- ),
config_name: str = typer.Option(
"live",
"--config-name",
@@ -1078,7 +1073,6 @@ def live(
Example:
evotraders live # Run immediately (default)
- evotraders live --mock # Mock mode
evotraders live -t 22:30 # Run at 22:30 local time daily
evotraders live --schedule-mode intraday --interval-minutes 60
evotraders live --trigger-time now # Run immediately
@@ -1086,16 +1080,14 @@ def live(
"""
schedule_mode = str(_normalize_typer_value(schedule_mode, "daily"))
interval_minutes = int(_normalize_typer_value(interval_minutes, 60))
- mode_name = "MOCK" if mock else "LIVE"
console.print(
Panel.fit(
- f"[bold cyan]EvoTraders {mode_name} Mode[/bold cyan]",
+ "[bold cyan]EvoTraders LIVE Mode[/bold cyan]",
border_style="cyan",
),
)
# Check for required API key in live mode
- if not mock:
env_file = get_project_root() / ".env"
if not env_file.exists():
console.print("\n[yellow]Warning: .env file not found[/yellow]")
@@ -1168,9 +1160,6 @@ def live(
# Display configuration
console.print("\n[bold]Configuration:[/bold]")
- if mock:
- console.print(" Mode: [yellow]MOCK[/yellow] (Simulated prices)")
- else:
console.print(
" Mode: [green]LIVE[/green] (Real-time prices via Finnhub)",
)
@@ -1188,8 +1177,7 @@ def live(
project_root = get_project_root()
os.chdir(project_root)
- # Data update (if not mock mode)
- if not mock:
+ # Data update
run_data_updater(project_root)
auto_update_market_store(
config_name,
@@ -1200,10 +1188,6 @@ def live(
end_date=nyse_now.date().isoformat(),
force=False,
)
- else:
- console.print(
- "\n[dim]Mock mode enabled - skipping data update[/dim]\n",
- )
# Build command using backend.main
cmd = [
@@ -1229,8 +1213,6 @@ def live(
str(interval_minutes),
]
- if mock:
- cmd.append("--mock")
if enable_memory:
cmd.append("--enable-memory")


@@ -76,27 +76,19 @@ def _resolve_config() -> DataSourceConfig:
"""
Resolve data source configuration based on available API keys.
- Priority:
- 1. FINNHUB_API_KEY (if set)
- 2. FINANCIAL_DATASETS_API_KEY (if set)
- 3. Raises error if neither is available
+ The effective source should always match the first item in the resolved
+ ordered source list.
"""
sources = _ordered_sources()
- if "finnhub" in sources:
- return DataSourceConfig(
- source="finnhub",
- api_key=os.getenv("FINNHUB_API_KEY", "").strip(),
- sources=sources,
- )
- if "financial_datasets" in sources:
- return DataSourceConfig(
- source="financial_datasets",
- api_key=os.getenv("FINANCIAL_DATASETS_API_KEY", "").strip(),
- sources=sources,
- )
- if "yfinance" in sources:
- return DataSourceConfig(source="yfinance", api_key="", sources=sources)
- return DataSourceConfig(source="local_csv", api_key="", sources=sources)
+ source = sources[0] if sources else "local_csv"
+ api_key = ""
+ if source == "finnhub":
+ api_key = os.getenv("FINNHUB_API_KEY", "").strip()
+ elif source == "financial_datasets":
+ api_key = os.getenv("FINANCIAL_DATASETS_API_KEY", "").strip()
+ return DataSourceConfig(source=source, api_key=api_key, sources=sources)
def get_config() -> DataSourceConfig:

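The collapsed resolver above always derives the effective source from the head of the ordered source list. A minimal standalone sketch of the same pattern (the dataclass and the env-var map here are illustrative stand-ins, not the project's real definitions):

```python
import os
from dataclasses import dataclass, field

@dataclass
class DataSourceConfig:
    source: str
    api_key: str
    sources: list[str] = field(default_factory=list)

# Sources that require a credential, mapped to their environment variables.
_API_KEY_ENV = {
    "finnhub": "FINNHUB_API_KEY",
    "financial_datasets": "FINANCIAL_DATASETS_API_KEY",
}

def resolve_config(sources: list[str]) -> DataSourceConfig:
    # The effective source is always the first configured entry; fall back
    # to local CSV files when nothing is configured at all.
    source = sources[0] if sources else "local_csv"
    env_var = _API_KEY_ENV.get(source)
    api_key = os.getenv(env_var, "").strip() if env_var else ""
    return DataSourceConfig(source=source, api_key=api_key, sources=sources)
```

Key-free sources such as yfinance simply resolve with an empty `api_key`, which removes the need for one branch per provider.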

@@ -1,7 +1,35 @@
# -*- coding: utf-8 -*-
-"""Core pipeline and orchestration logic"""
+"""Core pipeline and orchestration logic.
+
+Keep ``pipeline_runner`` behind lazy wrappers so importing ``backend.core`` does
+not immediately pull in the gateway runtime graph.
+"""
from .pipeline import TradingPipeline
from .state_sync import StateSync
-__all__ = ["TradingPipeline", "StateSync"]
+def create_agents(*args, **kwargs):
+    from .pipeline_runner import create_agents as _create_agents
+    return _create_agents(*args, **kwargs)
+def create_long_term_memory(*args, **kwargs):
+    from .pipeline_runner import create_long_term_memory as _create_long_term_memory
+    return _create_long_term_memory(*args, **kwargs)
+def stop_gateway(*args, **kwargs):
+    from .pipeline_runner import stop_gateway as _stop_gateway
+    return _stop_gateway(*args, **kwargs)
+__all__ = [
+    "TradingPipeline",
+    "StateSync",
+    "create_agents",
+    "create_long_term_memory",
+    "stop_gateway",
+]

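The lazy-wrapper pattern in `backend/core/__init__.py` can be demonstrated in isolation. Here a fabricated module (hypothetical name `heavy_runtime_demo`) stands in for `pipeline_runner`, which the docstring says is expensive to import:

```python
import sys
import types

# Fabricated stand-in for the heavy runtime module (hypothetical name).
_heavy = types.ModuleType("heavy_runtime_demo")
_heavy.stop_gateway = lambda *a, **kw: "stopped"
sys.modules["heavy_runtime_demo"] = _heavy

def stop_gateway(*args, **kwargs):
    # The import is deferred to call time, so merely importing the package
    # that defines this wrapper never loads the heavy module.
    from heavy_runtime_demo import stop_gateway as _stop_gateway
    return _stop_gateway(*args, **kwargs)
```

The cost is one `sys.modules` lookup per call after the first import, which is negligible next to keeping `import backend.core` cheap.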

@@ -30,7 +30,7 @@ from backend.agents.team_pipeline_config import (
from backend.agents import AnalystAgent
from backend.agents.toolkit_factory import create_agent_toolkit
from backend.agents.workspace_manager import WorkspaceManager
-from backend.agents.prompt_loader import PromptLoader
+from backend.agents.prompt_loader import get_prompt_loader
from backend.llm.models import get_agent_formatter, get_agent_model
from backend.config.constants import ANALYST_TYPES
@@ -1623,7 +1623,7 @@ class TradingPipeline:
config_name = getattr(self.pm, "config", {}).get("config_name", "default")
project_root = Path(__file__).resolve().parents[2]
-personas = PromptLoader().load_yaml_config("analyst", "personas")
+personas = get_prompt_loader().load_yaml_config("analyst", "personas")
persona = personas.get(analyst_type, {})
WorkspaceManager(project_root=project_root).ensure_agent_assets(
config_name=config_name,

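The switch from `PromptLoader()` to `get_prompt_loader()` reads like a move to a shared cached instance. One common way to implement such an accessor (a sketch under that assumption, not the project's actual code) is an `lru_cache`-backed factory:

```python
from functools import lru_cache

class PromptLoader:
    """Stand-in: the real loader parses prompt/persona YAML from disk."""
    def load_yaml_config(self, *parts: str) -> dict:
        return {}  # placeholder for parsed YAML content

@lru_cache(maxsize=1)
def get_prompt_loader() -> PromptLoader:
    # Every call site shares one loader, so YAML parsing and file IO are
    # paid once per process instead of once per PromptLoader() construction.
    return PromptLoader()
```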

@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Callable
from backend.agents import AnalystAgent, PMAgent, RiskAgent
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
-from backend.agents.prompt_loader import PromptLoader
+from backend.agents.prompt_loader import get_prompt_loader
from backend.agents.workspace_manager import WorkspaceManager
from backend.config.constants import ANALYST_TYPES
from backend.core.pipeline import TradingPipeline
@@ -36,7 +36,7 @@ from backend.services.storage import StorageService
from backend.services.gateway import Gateway
from backend.utils.settlement import SettlementCoordinator
-_prompt_loader = PromptLoader()
+_prompt_loader = get_prompt_loader()
# Global gateway reference for cleanup
_gateway_instance: Optional[Gateway] = None
@@ -244,10 +244,8 @@ async def run_pipeline(
start_date = bootstrap.get("start_date")
end_date = bootstrap.get("end_date")
enable_memory = bootstrap.get("enable_memory", False)
-enable_mock = bootstrap.get("enable_mock", False)
is_backtest = mode == "backtest"
-is_mock = enable_mock or mode == "mock" or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
# ======================================================================
# PHASE 0: Initialize runtime manager
@@ -266,10 +264,6 @@ async def run_pipeline(
set_global_runtime_manager(runtime_manager)
-# Register runtime manager with API
-from backend.api.runtime import register_runtime_manager
-register_runtime_manager(runtime_manager)
# ======================================================================
# PHASE 1 & 2: Create infrastructure services (Market, Storage)
# These will be started by Gateway in the correct order
@@ -292,9 +286,8 @@ async def run_pipeline(
market_service = MarketService(
    tickers=tickers,
    poll_interval=10,
-    mock_mode=is_mock and not is_backtest,
    backtest_mode=is_backtest,
-    api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
+    api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
    backtest_start_date=start_date if is_backtest else None,
    backtest_end_date=end_date if is_backtest else None,
)
@@ -391,7 +384,6 @@ async def run_pipeline(
scheduler_callback=scheduler_callback,
config={
"mode": mode,
"mock_mode": is_mock,
"backtest_mode": is_backtest,
"tickers": tickers,
"config_name": run_id,


@@ -465,7 +465,6 @@ class StateSync:
payload = {
"server_mode": self._state.get("server_mode", "live"),
"is_mock_mode": self._state.get("is_mock_mode", False),
"is_backtest": self._state.get("is_backtest", False),
"tickers": self._state.get("tickers"),
"runtime_config": self._state.get("runtime_config"),
@@ -488,12 +487,13 @@ class StateSync:
}
if include_dashboard:
+    dashboard_snapshot = self.storage.build_dashboard_snapshot_from_state(self._state)
    payload["dashboard"] = {
-        "summary": self.storage.load_file("summary"),
-        "holdings": self.storage.load_file("holdings"),
-        "stats": self.storage.load_file("stats"),
-        "trades": self.storage.load_file("trades"),
-        "leaderboard": self.storage.load_file("leaderboard"),
+        "summary": dashboard_snapshot.get("summary"),
+        "holdings": dashboard_snapshot.get("holdings"),
+        "stats": dashboard_snapshot.get("stats"),
+        "trades": dashboard_snapshot.get("trades"),
+        "leaderboard": dashboard_snapshot.get("leaderboard"),
}
return payload


@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
from backend.data.historical_price_manager import HistoricalPriceManager
-from backend.data.mock_price_manager import MockPriceManager
from backend.data.polling_price_manager import PollingPriceManager
-__all__ = ["MockPriceManager", "PollingPriceManager", "HistoricalPriceManager"]
+__all__ = ["PollingPriceManager", "HistoricalPriceManager"]


@@ -8,6 +8,7 @@ from typing import Iterable
from backend.data.market_store import MarketStore
from backend.data.news_alignment import align_news_for_symbol
from backend.data.provider_router import DataProviderRouter
from backend.data.polygon_client import (
fetch_news,
fetch_ohlc,
@@ -24,6 +25,35 @@ def _default_start(years: int = 2) -> str:
return (datetime.now(timezone.utc).date() - timedelta(days=years * 366)).isoformat()
def _normalize_provider_news_rows(ticker: str, news_items: Iterable[Any]) -> list[dict]:
rows: list[dict] = []
for item in news_items:
payload = item.model_dump() if hasattr(item, "model_dump") else dict(item or {})
related = payload.get("related")
if isinstance(related, str):
related_list = [value.strip().upper() for value in related.split(",") if value.strip()]
elif isinstance(related, list):
related_list = [str(value).strip().upper() for value in related if str(value).strip()]
else:
related_list = []
if ticker not in related_list:
related_list.append(ticker)
rows.append(
{
"title": payload.get("title"),
"description": payload.get("summary"),
"summary": payload.get("summary"),
"article_url": payload.get("url"),
"published_utc": payload.get("date"),
"publisher": payload.get("source"),
"tickers": related_list,
"category": payload.get("category"),
"raw_json": payload,
}
)
return rows
def ingest_ticker_history(
symbol: str,
*,
@@ -114,6 +144,80 @@ def update_ticker_incremental(
}
def refresh_news_incremental(
symbol: str,
*,
end_date: str | None = None,
store: MarketStore | None = None,
) -> dict:
"""Incrementally fetch company news using the configured provider router."""
ticker = normalize_symbol(symbol)
market_store = store or MarketStore()
watermarks = market_store.get_ticker_watermarks(ticker)
end = end_date or _today_utc()
start_news = (
(datetime.fromisoformat(watermarks["last_news_fetch"]) + timedelta(days=1)).date().isoformat()
if watermarks.get("last_news_fetch")
else _default_start()
)
if start_news > end:
return {
"symbol": ticker,
"start_news_date": start_news,
"end_date": end,
"news": 0,
"aligned": 0,
}
router = DataProviderRouter()
news_items, source = router.get_company_news(
ticker=ticker,
start_date=start_news,
end_date=end,
limit=1000,
)
news_rows = _normalize_provider_news_rows(ticker, news_items)
news_count = market_store.upsert_news(ticker, news_rows, source=source) if news_rows else 0
aligned_count = align_news_for_symbol(market_store, ticker)
market_store.update_fetch_watermark(
symbol=ticker,
news_date=end if news_rows or watermarks.get("last_news_fetch") else None,
)
return {
"symbol": ticker,
"start_news_date": start_news,
"end_date": end,
"news": news_count,
"aligned": aligned_count,
"source": source,
}
def refresh_news_for_symbols(
symbols: Iterable[str],
*,
end_date: str | None = None,
store: MarketStore | None = None,
) -> list[dict]:
"""Incrementally refresh company news for a list of tickers."""
market_store = store or MarketStore()
results = []
for symbol in symbols:
ticker = normalize_symbol(symbol)
if not ticker:
continue
results.append(
refresh_news_incremental(
ticker,
end_date=end_date,
store=market_store,
)
)
return results
def ingest_symbols(
symbols: Iterable[str],
*,


@@ -9,7 +9,7 @@ import os
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
-from typing import Any, Iterable
+from typing import Any, Iterable, Optional
SCHEMA = """
@@ -147,12 +147,30 @@ def _utc_timestamp() -> str:
class MarketStore:
-    """SQLite-backed market research warehouse."""
+    """SQLite-backed market research warehouse. Use get_instance() for the singleton."""
+    _instance: Optional["MarketStore"] = None
+    def __new__(cls, db_path: Path | None = None) -> "MarketStore":
+        if cls._instance is not None:
+            if db_path is None or cls._instance.db_path == Path(db_path or get_market_db_path()):
+                return cls._instance
+        instance = super().__new__(cls)
+        cls._instance = instance
+        return instance
    def __init__(self, db_path: Path | None = None):
+        if getattr(self, "_initialized", False):
+            return
        self.db_path = Path(db_path or get_market_db_path())
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()
+        self._initialized = True
+    @classmethod
+    def get_instance(cls, db_path: Path | None = None) -> "MarketStore":
+        """Get the MarketStore singleton instance."""
+        return cls(db_path)
    def _connect(self) -> sqlite3.Connection:
        conn = sqlite3.connect(self.db_path)


@@ -1,244 +0,0 @@
# -*- coding: utf-8 -*-
"""
Mock Price Manager - For testing during non-trading hours
Generates virtual real-time price data
"""
import logging
import os
import random
import threading
import time
from typing import Callable, Dict, List, Optional
from backend.data.provider_utils import normalize_symbol
logger = logging.getLogger(__name__)
class MockPriceManager:
"""Mock Price Manager - Generates virtual prices for testing"""
def __init__(self, poll_interval: int = 10, volatility: float = 0.5):
"""
Args:
poll_interval: Price update interval in seconds
volatility: Price volatility percentage
"""
if poll_interval is None:
poll_interval = int(os.getenv("MOCK_POLL_INTERVAL", "5"))
if volatility is None:
volatility = float(os.getenv("MOCK_VOLATILITY", "0.5"))
self.poll_interval = poll_interval
self.volatility = volatility
self.subscribed_symbols: List[str] = []
self.base_prices: Dict[str, float] = {}
self.open_prices: Dict[str, float] = {}
self.latest_prices: Dict[str, float] = {}
self.price_callbacks: List[Callable] = []
self.running = False
self._thread: Optional[threading.Thread] = None
self.default_base_prices = {
"AAPL": 237.50,
"MSFT": 425.30,
"GOOGL": 161.50,
"AMZN": 218.45,
"NVDA": 950.00,
"META": 573.22,
"TSLA": 342.15,
"AMD": 168.90,
"NFLX": 688.25,
"INTC": 42.18,
"COIN": 285.50,
"PLTR": 45.80,
"BABA": 88.30,
"DIS": 112.50,
"BKNG": 4850.00,
}
logger.info(
f"MockPriceManager initialized (interval: {self.poll_interval}s, "
f"volatility: {self.volatility}%)",
)
def subscribe(
self,
symbols: List[str],
base_prices: Dict[str, float] = None,
):
"""Subscribe to stock symbols"""
for symbol in symbols:
symbol = normalize_symbol(symbol)
if symbol not in self.subscribed_symbols:
self.subscribed_symbols.append(symbol)
if base_prices and symbol in base_prices:
base_price = base_prices[symbol]
elif symbol in self.default_base_prices:
base_price = self.default_base_prices[symbol]
else:
base_price = random.uniform(50, 500)
self.base_prices[symbol] = base_price
self.open_prices[symbol] = base_price
self.latest_prices[symbol] = base_price
logger.info(
f"Subscribed to mock price: {symbol} (base: ${base_price:.2f})", # noqa: E501
)
def unsubscribe(self, symbols: List[str]):
"""Unsubscribe from symbols"""
for symbol in symbols:
symbol = normalize_symbol(symbol)
if symbol in self.subscribed_symbols:
self.subscribed_symbols.remove(symbol)
self.base_prices.pop(symbol, None)
self.open_prices.pop(symbol, None)
self.latest_prices.pop(symbol, None)
logger.info(f"Unsubscribed: {symbol}")
def add_price_callback(self, callback: Callable):
"""Add price update callback"""
self.price_callbacks.append(callback)
def _generate_price_update(self, symbol: str) -> float:
"""Generate price update based on random walk"""
current_price = self.latest_prices.get(
symbol,
self.base_prices[symbol],
)
change_percent = random.uniform(-self.volatility, self.volatility)
new_price = current_price * (1 + change_percent / 100)
# 10% chance of larger movement
if random.random() < 0.1:
trend_factor = random.uniform(-2, 2)
new_price = new_price * (1 + trend_factor / 100)
# Limit intraday movement to +/-10%
open_price = self.open_prices[symbol]
max_price = open_price * 1.10
min_price = open_price * 0.90
new_price = max(min_price, min(max_price, new_price))
return new_price
def _update_prices(self):
"""Update prices for all subscribed stocks"""
timestamp = int(time.time() * 1000)
for symbol in self.subscribed_symbols:
try:
new_price = self._generate_price_update(symbol)
self.latest_prices[symbol] = new_price
open_price = self.open_prices[symbol]
ret = ((new_price - open_price) / open_price) * 100
price_data = {
"symbol": symbol,
"price": new_price,
"timestamp": timestamp,
"volume": random.randint(1000000, 10000000),
"open": open_price,
"high": max(new_price, open_price),
"low": min(new_price, open_price),
"previous_close": open_price,
"ret": ret,
}
for callback in self.price_callbacks:
try:
callback(price_data)
except Exception as e:
logger.error(
f"Mock price callback error ({symbol}): {e}",
)
logger.debug(
f"Mock {symbol}: ${new_price:.2f} [ret: {ret:+.2f}%]",
)
except Exception as e:
logger.error(f"Failed to generate mock price ({symbol}): {e}")
def _polling_loop(self):
"""Main polling loop"""
logger.info(
f"Mock price generation started (interval: {self.poll_interval}s)",
)
while self.running:
try:
start_time = time.time()
self._update_prices()
elapsed = time.time() - start_time
sleep_time = max(0, self.poll_interval - elapsed)
if sleep_time > 0:
time.sleep(sleep_time)
except Exception as e:
logger.error(f"Mock polling loop error: {e}")
time.sleep(5)
def start(self):
"""Start mock price generation"""
if self.running:
logger.warning("Mock price manager already running")
return
if not self.subscribed_symbols:
logger.warning("No stocks subscribed")
return
self.running = True
self._thread = threading.Thread(target=self._polling_loop, daemon=True)
self._thread.start()
logger.info(
f"Mock price manager started: {', '.join(self.subscribed_symbols)}", # noqa: E501
)
def stop(self):
"""Stop mock price generation"""
self.running = False
if self._thread:
self._thread.join(timeout=5)
logger.info("Mock price manager stopped")
def get_latest_price(self, symbol: str) -> Optional[float]:
"""Get latest price for symbol"""
return self.latest_prices.get(symbol)
def get_all_latest_prices(self) -> Dict[str, float]:
"""Get all latest prices"""
return self.latest_prices.copy()
def get_open_price(self, symbol: str) -> Optional[float]:
"""Get open price for symbol"""
return self.open_prices.get(symbol)
def reset_open_prices(self):
"""Reset open prices for new trading day"""
for symbol in self.subscribed_symbols:
last_close = self.latest_prices[symbol]
gap_percent = random.uniform(-1, 1)
new_open = last_close * (1 + gap_percent / 100)
self.open_prices[symbol] = new_open
self.latest_prices[symbol] = new_open
logger.info("Open prices reset")
def set_base_price(self, symbol: str, price: float):
"""Manually set base price for testing"""
if symbol in self.subscribed_symbols:
self.base_prices[symbol] = price
self.open_prices[symbol] = price
self.latest_prices[symbol] = price
logger.info(f"{symbol} base price set to: ${price:.2f}")
else:
logger.warning(f"{symbol} not subscribed")


@@ -15,6 +15,9 @@ from backend.data.provider_utils import normalize_symbol
logger = logging.getLogger(__name__)
_SUPPRESSED_LOG_EVERY = 20
class PollingPriceManager:
"""Polling-based price manager using Finnhub or yfinance."""
@@ -43,6 +46,7 @@ class PollingPriceManager:
self.latest_prices: Dict[str, float] = {}
self.open_prices: Dict[str, float] = {}
self.price_callbacks: List[Callable] = []
self._failure_counts: Dict[str, int] = {}
self.running = False
self._thread: Optional[threading.Thread] = None
@@ -77,6 +81,8 @@ class PollingPriceManager:
for symbol in self.subscribed_symbols:
try:
quote_data = self._fetch_quote(symbol)
if not isinstance(quote_data, dict):
raise ValueError(f"{symbol}: Empty quote payload")
current_price = quote_data.get("c")
open_price = quote_data.get("o")
@@ -103,6 +109,13 @@ class PollingPriceManager:
)
self.latest_prices[symbol] = current_price
previous_failures = self._failure_counts.pop(symbol, 0)
if previous_failures > 0:
logger.info(
"%s quote polling recovered after %d consecutive failures",
symbol,
previous_failures,
)
price_data = {
"symbol": symbol,
@@ -128,7 +141,20 @@ class PollingPriceManager:
)
except Exception as e:
-    logger.error(f"Failed to fetch {symbol} price: {e}")
+    failure_count = self._failure_counts.get(symbol, 0) + 1
+    self._failure_counts[symbol] = failure_count
+    message = f"Failed to fetch {symbol} price: {e}"
+    if failure_count == 1:
+        logger.warning(message)
+    elif failure_count % _SUPPRESSED_LOG_EVERY == 0:
+        logger.warning(
+            "%s (repeated %d times; suppressing intermediate failures)",
+            message,
+            failure_count,
+        )
+    else:
+        logger.debug(message)
def _fetch_quote(self, symbol: str) -> Dict[str, float]:
"""Fetch a normalized quote payload from the configured provider."""
@@ -136,7 +162,10 @@ class PollingPriceManager:
return self._fetch_yfinance_quote(symbol)
if not self.finnhub_client:
raise ValueError("Finnhub API key required for finnhub polling")
-return self.finnhub_client.quote(symbol)
+quote = self.finnhub_client.quote(symbol)
+if not isinstance(quote, dict):
+    raise ValueError(f"{symbol}: Invalid Finnhub quote payload")
+return quote
def _fetch_yfinance_quote(self, symbol: str) -> Dict[str, float]:
"""Fetch quote data from yfinance and normalize to Finnhub-like keys."""
@@ -162,6 +191,8 @@ class PollingPriceManager:
if current_price is None:
history = ticker.history(period="1d", interval="1m", auto_adjust=False)
if history is None:
raise ValueError(f"{symbol}: yfinance returned no history frame")
if history.empty:
raise ValueError(f"{symbol}: No yfinance quote data")
latest = history.iloc[-1]

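The failure-count gating added to the polling loop reduces to a small pure decision function, which makes the WARNING/DEBUG policy easy to test on its own:

```python
_SUPPRESSED_LOG_EVERY = 20

def failure_log_level(count: int) -> str:
    # The first failure and every Nth repeat surface at WARNING; everything
    # in between drops to DEBUG so one flapping symbol cannot flood the logs.
    if count == 1 or count % _SUPPRESSED_LOG_EVERY == 0:
        return "warning"
    return "debug"
```

Together with the recovery message on the success path, an operator sees exactly one line when a symbol starts failing, a heartbeat every 20 failures, and one line when it recovers.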

@@ -11,7 +11,7 @@ import pandas as pd
import yfinance as yf
from backend.config.data_config import DataSource, get_data_sources
-from backend.data.schema import (
+from shared.schema import (
CompanyFactsResponse,
CompanyNews,
CompanyNewsResponse,


@@ -1,194 +1,50 @@
# -*- coding: utf-8 -*-
-from pydantic import BaseModel
+"""Compatibility schema bridge.
+
+This module preserves the legacy ``backend.data.schema`` import path while
+delegating the actual schema definitions to ``shared.schema``. Keeping one
+canonical DTO set avoids drift as the monolith is split into service-specific
+packages.
+"""
-class Price(BaseModel):
-    open: float
-    close: float
-    high: float
-    low: float
-    volume: int
-    time: str
+from shared.schema import (
+    AgentStateData,
+    AgentStateMetadata,
+    AnalystSignal,
+    CompanyFacts,
+    CompanyFactsResponse,
+    CompanyNews,
+    CompanyNewsResponse,
+    FinancialMetrics,
+    FinancialMetricsResponse,
+    InsiderTrade,
+    InsiderTradeResponse,
+    LineItem,
+    LineItemResponse,
+    Portfolio,
+    Position,
+    Price,
+    PriceResponse,
+    TickerAnalysis,
+)
-class PriceResponse(BaseModel):
-    ticker: str
-    prices: list[Price]
-class FinancialMetrics(BaseModel):
-    ticker: str
-    report_period: str
-    period: str
-    currency: str
-    market_cap: float | None
-    enterprise_value: float | None
-    price_to_earnings_ratio: float | None
-    price_to_book_ratio: float | None
-    price_to_sales_ratio: float | None
-    enterprise_value_to_ebitda_ratio: float | None
-    enterprise_value_to_revenue_ratio: float | None
-    free_cash_flow_yield: float | None
-    peg_ratio: float | None
-    gross_margin: float | None
-    operating_margin: float | None
-    net_margin: float | None
-    return_on_equity: float | None
-    return_on_assets: float | None
-    return_on_invested_capital: float | None
-    asset_turnover: float | None
-    inventory_turnover: float | None
-    receivables_turnover: float | None
-    days_sales_outstanding: float | None
-    operating_cycle: float | None
-    working_capital_turnover: float | None
-    current_ratio: float | None
-    quick_ratio: float | None
-    cash_ratio: float | None
-    operating_cash_flow_ratio: float | None
-    debt_to_equity: float | None
-    debt_to_assets: float | None
-    interest_coverage: float | None
-    revenue_growth: float | None
-    earnings_growth: float | None
-    book_value_growth: float | None
-    earnings_per_share_growth: float | None
-    free_cash_flow_growth: float | None
-    operating_income_growth: float | None
-    ebitda_growth: float | None
-    payout_ratio: float | None
-    earnings_per_share: float | None
-    book_value_per_share: float | None
-    free_cash_flow_per_share: float | None
-class FinancialMetricsResponse(BaseModel):
-    financial_metrics: list[FinancialMetrics]
-class LineItem(BaseModel):
-    ticker: str
-    report_period: str
-    period: str
-    currency: str
-    # Allow additional fields dynamically
-    model_config = {"extra": "allow"}
-class LineItemResponse(BaseModel):
-    search_results: list[LineItem]
-class InsiderTrade(BaseModel):
-    ticker: str
-    issuer: str | None
-    name: str | None
-    title: str | None
-    is_board_director: bool | None
-    transaction_date: str | None
-    transaction_shares: float | None
-    transaction_price_per_share: float | None
-    transaction_value: float | None
-    shares_owned_before_transaction: float | None
-    shares_owned_after_transaction: float | None
-    security_title: str | None
-    filing_date: str
-class InsiderTradeResponse(BaseModel):
-    insider_trades: list[InsiderTrade]
-class CompanyNews(BaseModel):
-    category: str | None = None
-    ticker: str
-    title: str
-    related: str | None = None
-    source: str
-    date: str | None = None
-    url: str
-    summary: str | None = None
-class CompanyNewsResponse(BaseModel):
-    news: list[CompanyNews]
-class CompanyFacts(BaseModel):
-    ticker: str
-    name: str
-    cik: str | None = None
-    industry: str | None = None
-    sector: str | None = None
-    category: str | None = None
-    exchange: str | None = None
-    is_active: bool | None = None
-    listing_date: str | None = None
-    location: str | None = None
-    market_cap: float | None = None
-    number_of_employees: int | None = None
-    sec_filings_url: str | None = None
-    sic_code: str | None = None
-    sic_industry: str | None = None
-    sic_sector: str | None = None
-    website_url: str | None = None
-    weighted_average_shares: int | None = None
-class CompanyFactsResponse(BaseModel):
-    company_facts: CompanyFacts
-class Position(BaseModel):
-    """Position information - for Portfolio mode"""
-    long: int = 0  # Long position quantity (shares)
-    short: int = 0  # Short position quantity (shares)
-    long_cost_basis: float = 0.0  # Long position average cost
-    short_cost_basis: float = 0.0  # Short position average cost
-class Portfolio(BaseModel):
-    """Portfolio - for Portfolio mode"""
-    cash: float = 100000.0  # Available cash
-    positions: dict[str, Position] = {}  # ticker -> Position mapping
-    # Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin)
-    margin_requirement: float = 0.0
-    margin_used: float = 0.0  # Margin used
-class AnalystSignal(BaseModel):
-    signal: str | None = None
-    confidence: float | None = None
-    reasoning: dict | str | None = None
-    # Extended fields for richer signal information
-    reasons: list[str] | None = None  # Core drivers/reasons for the signal
-    risks: list[str] | None = None  # Key risk factors
-    invalidation: str | None = None  # Conditions that would invalidate the thesis
-    next_action: str | None = None  # Suggested next action for PM
-    # Valuation-related fields
-    intrinsic_value: float | None = None  # DCF intrinsic value
-    fair_value_range: dict | None = None  # {bear, base, bull} fair value range
-    value_gap_pct: float | None = None  # Value gap percentage
-    valuation_methods: list[str] | None = None  # List of valuation methods used
-    max_position_size: float | None = None  # For risk management signals
-class TickerAnalysis(BaseModel):
-    ticker: str
-    analyst_signals: dict[str, AnalystSignal]  # agent_name -> signal mapping
-class AgentStateData(BaseModel):
-    tickers: list[str]
-    portfolio: Portfolio
-    start_date: str
-    end_date: str
-    ticker_analyses: dict[str, TickerAnalysis]  # ticker -> analysis mapping
-class AgentStateMetadata(BaseModel):
-    show_reasoning: bool = False
-    model_config = {"extra": "allow"}
__all__ = [
    "Price",
    "PriceResponse",
    "FinancialMetrics",
    "FinancialMetricsResponse",
    "LineItem",
    "LineItemResponse",
    "InsiderTrade",
    "InsiderTradeResponse",
    "CompanyNews",
    "CompanyNewsResponse",
    "CompanyFacts",
    "CompanyFactsResponse",
    "Position",
    "Portfolio",
    "AnalystSignal",
    "TickerAnalysis",
    "AgentStateData",
    "AgentStateMetadata",
]

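The point of the bridge is that the legacy path and `shared.schema` expose the identical class objects, not copies, so `isinstance` checks and validation keep working across old and new import paths. A self-contained demonstration of the re-export mechanism, using throwaway module names:

```python
import sys
import types

# Hypothetical stand-ins for shared.schema and the legacy bridge module.
shared = types.ModuleType("shared_schema_demo")

class Price:  # the canonical DTO lives in the shared package
    pass

shared.Price = Price
sys.modules["shared_schema_demo"] = shared

legacy = types.ModuleType("legacy_schema_demo")
legacy.Price = sys.modules["shared_schema_demo"].Price  # re-export, not a copy
sys.modules["legacy_schema_demo"] = legacy
```

Because both names bind the same object, code importing `Price` from either module path agrees on one class, which is what "avoids drift" means in the bridge docstring.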

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""Domain modules for split service internals."""

backend/domains/news.py (new file, 320 lines)

@@ -0,0 +1,320 @@
# -*- coding: utf-8 -*-
"""News/explain domain helpers shared by app surfaces and gateway fallbacks."""
from __future__ import annotations
from typing import Any
from backend.data.market_store import MarketStore
from backend.data.market_ingest import update_ticker_incremental
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.explain.range_explainer import build_range_explanation
from backend.explain.similarity_service import find_similar_days
from backend.explain.story_service import get_or_create_stock_story
def news_rows_need_enrichment(rows: list[dict[str, Any]]) -> bool:
"""Return whether news rows are missing explain-oriented analysis fields."""
if not rows:
return True
return all(
not row.get("sentiment")
and not row.get("relevance")
and not row.get("key_discussion")
for row in rows
)
def ensure_news_fresh(
store: MarketStore,
*,
ticker: str,
target_date: str | None = None,
refresh_if_stale: bool = True,
) -> dict[str, Any]:
"""Refresh raw news incrementally when stored watermarks are stale."""
normalized_target = str(target_date or "").strip()[:10]
if not normalized_target:
return {
"ticker": ticker,
"target_date": None,
"last_news_fetch": None,
"refreshed": False,
}
watermarks = store.get_ticker_watermarks(ticker)
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
refreshed = False
if refresh_if_stale and (not last_news_fetch or last_news_fetch < normalized_target):
update_ticker_incremental(
ticker,
end_date=normalized_target,
store=store,
)
refreshed = True
watermarks = store.get_ticker_watermarks(ticker)
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
return {
"ticker": ticker,
"target_date": normalized_target,
"last_news_fetch": last_news_fetch or None,
"refreshed": refreshed,
}
def get_enriched_news(
store: MarketStore,
*,
ticker: str,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 100,
refresh_if_stale: bool = False,
) -> dict[str, Any]:
freshness = ensure_news_fresh(
store,
ticker=ticker,
target_date=end_date,
refresh_if_stale=refresh_if_stale,
)
rows = store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
if news_rows_need_enrichment(rows):
enrich_news_for_symbol(
store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
rows = store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
return {"ticker": ticker, "news": rows, "freshness": freshness}
def get_news_for_date(
store: MarketStore,
*,
ticker: str,
date: str,
limit: int = 20,
refresh_if_stale: bool = False,
) -> dict[str, Any]:
freshness = ensure_news_fresh(
store,
ticker=ticker,
target_date=date,
refresh_if_stale=refresh_if_stale,
)
rows = store.get_news_items_enriched(
ticker,
trade_date=date,
limit=limit,
)
if news_rows_need_enrichment(rows):
enrich_news_for_symbol(
store,
ticker,
start_date=date,
end_date=date,
limit=limit,
)
rows = store.get_news_items_enriched(
ticker,
trade_date=date,
limit=limit,
)
return {"ticker": ticker, "date": date, "news": rows, "freshness": freshness}
def get_news_timeline(
store: MarketStore,
*,
ticker: str,
start_date: str,
end_date: str,
refresh_if_stale: bool = False,
) -> dict[str, Any]:
freshness = ensure_news_fresh(
store,
ticker=ticker,
target_date=end_date,
refresh_if_stale=refresh_if_stale,
)
timeline = store.get_news_timeline_enriched(
ticker,
start_date=start_date,
end_date=end_date,
)
if not timeline:
enrich_news_for_symbol(
store,
ticker,
start_date=start_date,
end_date=end_date,
limit=200,
)
timeline = store.get_news_timeline_enriched(
ticker,
start_date=start_date,
end_date=end_date,
)
return {
"ticker": ticker,
"timeline": timeline,
"start_date": start_date,
"end_date": end_date,
"freshness": freshness,
}
def get_news_categories(
store: MarketStore,
*,
ticker: str,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 200,
refresh_if_stale: bool = False,
) -> dict[str, Any]:
freshness = ensure_news_fresh(
store,
ticker=ticker,
target_date=end_date,
refresh_if_stale=refresh_if_stale,
)
rows = store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
if news_rows_need_enrichment(rows):
enrich_news_for_symbol(
store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
categories = store.get_news_categories_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
return {"ticker": ticker, "categories": categories, "freshness": freshness}
def get_similar_days_payload(
store: MarketStore,
*,
ticker: str,
date: str,
n_similar: int = 5,
refresh_if_stale: bool = False,
) -> dict[str, Any]:
freshness = ensure_news_fresh(
store,
ticker=ticker,
target_date=date,
refresh_if_stale=refresh_if_stale,
)
result = find_similar_days(
store,
symbol=ticker,
target_date=date,
top_k=n_similar,
)
result["freshness"] = freshness
return result
def get_story_payload(
store: MarketStore,
*,
ticker: str,
as_of_date: str,
refresh_if_stale: bool = False,
) -> dict[str, Any]:
freshness = ensure_news_fresh(
store,
ticker=ticker,
target_date=as_of_date,
refresh_if_stale=refresh_if_stale,
)
enrich_news_for_symbol(
store,
ticker,
end_date=as_of_date,
limit=80,
)
result = get_or_create_stock_story(
store,
symbol=ticker,
as_of_date=as_of_date,
)
result["freshness"] = freshness
return result
def get_range_explain_payload(
store: MarketStore,
*,
ticker: str,
start_date: str,
end_date: str,
article_ids: list[str] | None = None,
limit: int = 100,
refresh_if_stale: bool = False,
) -> dict[str, Any]:
freshness = ensure_news_fresh(
store,
ticker=ticker,
target_date=end_date,
refresh_if_stale=refresh_if_stale,
)
news_rows = []
if article_ids:
news_rows = store.get_news_by_ids_enriched(ticker, article_ids)
if not news_rows:
news_rows = store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
if news_rows_need_enrichment(news_rows):
enrich_news_for_symbol(
store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
news_rows = (
store.get_news_by_ids_enriched(ticker, article_ids)
if article_ids
else store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
)
result = build_range_explanation(
ticker=ticker,
start_date=start_date,
end_date=end_date,
news_rows=news_rows,
)
return {"ticker": ticker, "result": result, "freshness": freshness}

backend/domains/trading.py (new file, 106 lines)

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
"""Trading domain helpers shared by app surfaces and gateway fallbacks."""
from __future__ import annotations
from typing import Any
from backend.services.market import MarketService
from backend.tools.data_tools import (
get_company_news,
get_financial_metrics,
get_insider_trades,
get_market_cap,
get_prices,
search_line_items,
)
def get_prices_payload(*, ticker: str, start_date: str, end_date: str) -> dict[str, Any]:
return {
"ticker": ticker,
"prices": get_prices(ticker, start_date, end_date),
}
def get_financials_payload(
*,
ticker: str,
end_date: str,
period: str = "ttm",
limit: int = 10,
) -> dict[str, Any]:
return {
"financial_metrics": get_financial_metrics(
ticker=ticker,
end_date=end_date,
period=period,
limit=limit,
)
}
def get_news_payload(
*,
ticker: str,
end_date: str,
start_date: str | None = None,
limit: int = 1000,
) -> dict[str, Any]:
return {
"news": get_company_news(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
}
def get_insider_trades_payload(
*,
ticker: str,
end_date: str,
start_date: str | None = None,
limit: int = 1000,
) -> dict[str, Any]:
return {
"insider_trades": get_insider_trades(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
}
def get_market_status_payload() -> dict[str, Any]:
market_service = MarketService(tickers=[])
return market_service.get_market_status()
def get_market_cap_payload(*, ticker: str, end_date: str) -> dict[str, Any]:
return {
"ticker": ticker,
"end_date": end_date,
"market_cap": get_market_cap(ticker, end_date),
}
def get_line_items_payload(
*,
ticker: str,
line_items: list[str],
end_date: str,
period: str = "ttm",
limit: int = 10,
) -> dict[str, Any]:
return {
"search_results": search_line_items(
ticker=ticker,
line_items=line_items,
end_date=end_date,
period=period,
limit=limit,
)
}

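Each helper in `backend/domains/trading.py` is a thin keyword-only wrapper that shapes raw tool output into a JSON-ready payload dict. A minimal sketch with a stubbed data tool (the stub return value is made up for illustration):

```python
# Keyword-only payload wrapper: the bare `*` forces call sites to name
# every argument, keeping gateway fallback calls self-documenting.
from typing import Any

def get_prices(ticker: str, start_date: str, end_date: str) -> list[dict]:
    # Stand-in for backend.tools.data_tools.get_prices.
    return [{"date": start_date, "close": 100.0}]

def get_prices_payload(*, ticker: str, start_date: str, end_date: str) -> dict[str, Any]:
    return {"ticker": ticker, "prices": get_prices(ticker, start_date, end_date)}

payload = get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-05")
```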
backend/gateway_server.py (new file, +308)

@@ -0,0 +1,308 @@
# -*- coding: utf-8 -*-
"""Gateway Server - Entry point for Gateway subprocess.
This module is launched as a subprocess by the Control Plane (FastAPI)
to run the Data Plane (Gateway + Pipeline).
"""
import argparse
import asyncio
import json
import logging
import os
import sys
from contextlib import AsyncExitStack
from pathlib import Path
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
from backend.agents import AnalystAgent, PMAgent, RiskAgent
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
from backend.agents.prompt_loader import get_prompt_loader
from backend.agents.workspace_manager import WorkspaceManager
from backend.config.constants import ANALYST_TYPES
from backend.core.pipeline import TradingPipeline
from backend.core.pipeline_runner import create_agents, create_long_term_memory
from backend.core.scheduler import BacktestScheduler, Scheduler
from backend.llm.models import get_agent_formatter, get_agent_model
from backend.runtime.manager import (
TradingRuntimeManager,
set_global_runtime_manager,
clear_global_runtime_manager,
)
from backend.services.gateway import Gateway
from backend.services.market import MarketService
from backend.services.storage import StorageService
from backend.utils.settlement import SettlementCoordinator
logger = logging.getLogger(__name__)
_prompt_loader = get_prompt_loader()
INFO_LOGGER_PREFIXES = (
"backend.agents",
"backend.core.pipeline",
"backend.core.scheduler",
"backend.services.gateway_cycle_support",
"backend.utils.terminal_dashboard",
)
NOISY_LOGGER_LEVELS = {
"aiohttp": logging.WARNING,
"asyncio": logging.WARNING,
"dashscope": logging.WARNING,
"finnhub": logging.WARNING,
"httpcore": logging.WARNING,
"httpx": logging.WARNING,
"urllib3": logging.WARNING,
"websockets": logging.WARNING,
"yfinance": logging.WARNING,
"backend.data.polling_price_manager": logging.WARNING,
"backend.services.gateway": logging.WARNING,
"backend.services.market": logging.WARNING,
"backend.services.storage": logging.WARNING,
}
class SuppressNoisyInfoFilter(logging.Filter):
"""Filter out low-signal library INFO logs while keeping warnings/errors."""
def filter(self, record: logging.LogRecord) -> bool:
if record.levelno >= logging.WARNING:
return True
message = record.getMessage()
if record.name == "httpx" and message.startswith("HTTP Request:"):
return False
if record.name.startswith("websockets") and "connection open" in message:
return False
return True
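The filter's behavior can be exercised directly by feeding it hand-built `LogRecord`s; the class below reproduces the one above, and `make_record` is just a test helper:

```python
import logging

class SuppressNoisyInfoFilter(logging.Filter):
    """Drop low-signal library INFO logs while keeping warnings/errors."""
    def filter(self, record: logging.LogRecord) -> bool:
        if record.levelno >= logging.WARNING:
            return True  # warnings and errors always pass
        message = record.getMessage()
        if record.name == "httpx" and message.startswith("HTTP Request:"):
            return False
        if record.name.startswith("websockets") and "connection open" in message:
            return False
        return True

def make_record(name: str, level: int, msg: str) -> logging.LogRecord:
    return logging.LogRecord(name, level, __file__, 0, msg, None, None)

noise_filter = SuppressNoisyInfoFilter()
```

Note the level check comes first, so even the named noisy messages survive when emitted at WARNING or above.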
def configure_gateway_logging(verbose: bool = False) -> None:
"""Configure gateway logging with low-noise defaults for runtime logs."""
root_level = logging.DEBUG if verbose else logging.WARNING
logging.basicConfig(
level=root_level,
format="%(asctime)s | %(levelname)-7s | %(name)s:%(lineno)d - %(message)s",
force=True,
)
if not verbose:
suppress_filter = SuppressNoisyInfoFilter()
for handler in logging.getLogger().handlers:
handler.addFilter(suppress_filter)
for logger_name, level in NOISY_LOGGER_LEVELS.items():
logging.getLogger(logger_name).setLevel(logging.DEBUG if verbose else level)
if not verbose:
for prefix in INFO_LOGGER_PREFIXES:
logging.getLogger(prefix).setLevel(logging.INFO)
logging.getLogger(__name__).setLevel(logging.INFO if not verbose else logging.DEBUG)
async def run_gateway(
run_id: str,
run_dir: Path,
bootstrap: dict,
port: int
):
"""Run Gateway with Pipeline."""
# Extract config
tickers = bootstrap.get("tickers", ["AAPL", "MSFT"])
initial_cash = float(bootstrap.get("initial_cash", 100000.0))
margin_requirement = float(bootstrap.get("margin_requirement", 0.0))
max_comm_cycles = int(bootstrap.get("max_comm_cycles", 2))
schedule_mode = bootstrap.get("schedule_mode", "daily")
trigger_time = bootstrap.get("trigger_time", "09:30")
interval_minutes = int(bootstrap.get("interval_minutes", 60))
heartbeat_interval = int(bootstrap.get("heartbeat_interval", 0)) # 0 = disabled
mode = bootstrap.get("mode", "live")
start_date = bootstrap.get("start_date")
end_date = bootstrap.get("end_date")
enable_memory = bootstrap.get("enable_memory", False)
poll_interval = int(bootstrap.get("poll_interval", 10))
is_backtest = mode == "backtest"
logger.info(f"[Gateway Server] Starting run {run_id} on port {port}")
# Create runtime manager
runtime_manager = TradingRuntimeManager(
config_name=run_id,
run_dir=run_dir,
bootstrap=bootstrap,
)
runtime_manager.prepare_run()
set_global_runtime_manager(runtime_manager)
try:
async with AsyncExitStack() as stack:
# Create services
market_service = MarketService(
tickers=tickers,
poll_interval=poll_interval,
backtest_mode=is_backtest,
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
backtest_start_date=start_date if is_backtest else None,
backtest_end_date=end_date if is_backtest else None,
)
storage_service = StorageService(
dashboard_dir=run_dir / "team_dashboard",
initial_cash=initial_cash,
config_name=run_id,
)
if not storage_service.files["summary"].exists():
storage_service.initialize_empty_dashboard()
else:
storage_service.update_leaderboard_model_info()
# Create agents
analysts, risk_manager, pm, long_term_memories = create_agents(
run_id=run_id,
run_dir=run_dir,
initial_cash=initial_cash,
margin_requirement=margin_requirement,
enable_long_term_memory=enable_memory,
)
# Register agents
for agent in analysts + [risk_manager, pm]:
agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
if agent_id:
runtime_manager.register_agent(agent_id)
# Load portfolio state
portfolio_state = storage_service.load_portfolio_state()
pm.load_portfolio_state(portfolio_state)
# Create settlement coordinator
settlement_coordinator = SettlementCoordinator(
storage=storage_service,
initial_capital=initial_cash,
)
# Create pipeline
pipeline = TradingPipeline(
analysts=analysts,
risk_manager=risk_manager,
portfolio_manager=pm,
settlement_coordinator=settlement_coordinator,
max_comm_cycles=max_comm_cycles,
runtime_manager=runtime_manager,
)
# Create scheduler
scheduler_callback = None
live_scheduler = None
if is_backtest:
backtest_scheduler = BacktestScheduler(
start_date=start_date,
end_date=end_date,
trading_calendar="NYSE",
delay_between_days=0.5,
)
async def scheduler_callback_fn(callback):
await backtest_scheduler.start(callback)
scheduler_callback = scheduler_callback_fn
else:
live_scheduler = Scheduler(
mode=schedule_mode,
trigger_time=trigger_time,
interval_minutes=interval_minutes,
heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
config={"config_name": run_id},
)
async def scheduler_callback_fn(callback):
await live_scheduler.start(callback)
scheduler_callback = scheduler_callback_fn
# Enter long-term memory contexts
for memory in long_term_memories:
await stack.enter_async_context(memory)
# Create Gateway
gateway = Gateway(
market_service=market_service,
storage_service=storage_service,
pipeline=pipeline,
scheduler_callback=scheduler_callback,
config={
"mode": mode,
"backtest_mode": is_backtest,
"tickers": tickers,
"config_name": run_id,
"schedule_mode": schedule_mode,
"interval_minutes": interval_minutes,
"trigger_time": trigger_time,
"heartbeat_interval": heartbeat_interval,
"initial_cash": initial_cash,
"margin_requirement": margin_requirement,
"max_comm_cycles": max_comm_cycles,
"enable_memory": enable_memory,
},
scheduler=live_scheduler,
)
# Start Gateway (blocks until shutdown)
logger.info(f"[Gateway Server] Gateway starting on port {port}")
await gateway.start(host="0.0.0.0", port=port)
except asyncio.CancelledError:
logger.info("[Gateway Server] Cancelled")
raise
finally:
logger.info("[Gateway Server] Cleaning up")
clear_global_runtime_manager()
def main():
"""Main entry point."""
parser = argparse.ArgumentParser(description="Gateway Server")
parser.add_argument("--run-id", required=True, help="Run identifier")
parser.add_argument("--run-dir", required=True, help="Run directory path")
parser.add_argument("--port", type=int, default=8765, help="WebSocket port")
parser.add_argument("--bootstrap", required=True, help="Bootstrap config as JSON")
parser.add_argument("--verbose", action="store_true", help="Verbose logging")
args = parser.parse_args()
# Setup logging
configure_gateway_logging(verbose=args.verbose)
# Parse bootstrap
bootstrap = json.loads(args.bootstrap)
run_dir = Path(args.run_dir)
# Run
try:
asyncio.run(run_gateway(
run_id=args.run_id,
run_dir=run_dir,
bootstrap=bootstrap,
port=args.port
))
except KeyboardInterrupt:
logger.info("[Gateway Server] Interrupted by user")
except Exception as e:
logger.exception(f"[Gateway Server] Fatal error: {e}")
sys.exit(1)
if __name__ == "__main__":
main()

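Per the module docstring, this server is launched as a subprocess by the Control Plane with the bootstrap dict serialized into a single JSON `--bootstrap` argument. The actual launcher is not part of this diff; the sketch below only mirrors the argparse contract defined in `main()`:

```python
# Assemble the gateway_server command line the way a control plane might.
# build_gateway_argv is an illustrative helper, not code from this repo.
import json
import sys

def build_gateway_argv(run_id: str, run_dir: str, port: int, bootstrap: dict) -> list[str]:
    return [
        sys.executable, "-m", "backend.gateway_server",
        "--run-id", run_id,
        "--run-dir", run_dir,
        "--port", str(port),
        "--bootstrap", json.dumps(bootstrap),  # whole config as one argv entry
    ]

argv = build_gateway_argv(
    "run-001", "/tmp/runs/run-001", 8765,
    {"mode": "backtest", "tickers": ["AAPL"]},
)
```

Passing the config as one JSON argument keeps the subprocess boundary flat: the child re-parses it with `json.loads(args.bootstrap)` instead of mapping every setting to its own flag.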
View File

@@ -3,6 +3,8 @@
AgentScope Native Model Factory
Uses native AgentScope model classes for LLM calls
"""
import asyncio
import inspect
import os
import time
import logging
@@ -34,6 +36,27 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
def _usage_value(usage: Any, key: str, default: Any = 0) -> Any:
"""Read usage fields from both object-style and dict-style usage payloads."""
if usage is None:
return default
if isinstance(usage, dict):
return usage.get(key, default)
try:
return getattr(usage, key)
except (AttributeError, KeyError):
return default
def _usage_total_tokens(usage: Any) -> int:
total = _usage_value(usage, "total_tokens", None)
if total is not None:
return int(total or 0)
input_tokens = _usage_value(usage, "input_tokens", 0)
output_tokens = _usage_value(usage, "output_tokens", 0)
return int((input_tokens or 0) + (output_tokens or 0))
class RetryChatModel:
"""Wraps an AgentScope model with automatic retry for transient errors.
@@ -55,6 +78,7 @@ class RetryChatModel:
"502",
"504",
"connection",
"disconnected",
"temporary",
"overloaded",
"too_many_requests",
@@ -150,8 +174,8 @@ class RetryChatModel:
# Track usage if available
if hasattr(result, "usage") and result.usage:
usage = result.usage
self._total_tokens_used += getattr(usage, "total_tokens", 0)
self._total_cost += getattr(usage, "cost", 0.0)
self._total_tokens_used += _usage_total_tokens(usage)
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
return result
@@ -192,9 +216,66 @@ class RetryChatModel:
raise last_error
raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")
async def _call_with_retry_async(self, func: Callable[..., T], *args, **kwargs) -> T:
"""Call an async function with retry logic for transient errors."""
last_error: Optional[Exception] = None
for attempt in range(1, self._max_retries + 1):
try:
result = await func(*args, **kwargs)
if hasattr(result, "usage") and result.usage:
usage = result.usage
self._total_tokens_used += _usage_total_tokens(usage)
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
return result
except Exception as e:
last_error = e
if attempt >= self._max_retries:
logger.error(
"RetryChatModel: Max retries (%d) exhausted for %s",
self._max_retries,
self.model_name,
)
break
if not self._is_transient_error(e):
logger.warning(
"RetryChatModel: Non-transient error, not retrying: %s",
str(e),
)
break
delay = self._calculate_delay(attempt)
logger.warning(
"RetryChatModel: Transient async error on attempt %d/%d, "
"retrying in %.1fs: %s",
attempt,
self._max_retries,
delay,
str(e)[:200],
)
if self._on_retry:
self._on_retry(attempt, e, delay)
await asyncio.sleep(delay)
if last_error is not None:
raise last_error
raise RuntimeError("RetryChatModel: Unexpected async state, no error but no result")
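The async retry loop above reduces to: retry transient errors with a growing delay, bail out immediately on non-transient ones. A runnable sketch follows; the backoff formula `base * 2**(attempt-1)` is an assumption, since `_calculate_delay` is not shown in this diff:

```python
# Minimal async retry-with-backoff sketch. Marker-based transience
# detection mirrors the substring list on RetryChatModel.
import asyncio

TRANSIENT_MARKERS = ("timeout", "connection", "overloaded", "disconnected")

def is_transient(err: Exception) -> bool:
    return any(marker in str(err).lower() for marker in TRANSIENT_MARKERS)

async def call_with_retry(func, max_retries: int = 3, base_delay: float = 0.01):
    last_error = None
    for attempt in range(1, max_retries + 1):
        try:
            return await func()
        except Exception as e:
            last_error = e
            if attempt >= max_retries or not is_transient(e):
                break  # retries exhausted, or error is not worth retrying
            await asyncio.sleep(base_delay * 2 ** (attempt - 1))  # assumed backoff
    raise last_error

attempts = []

async def flaky():
    attempts.append(1)
    if len(attempts) < 3:
        raise RuntimeError("connection reset")  # matches a transient marker
    return "ok"

result = asyncio.run(call_with_retry(flaky))
```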
def __call__(self, *args, **kwargs) -> Any:
"""Forward calls to the wrapped model with retry logic."""
return self._call_with_retry(self._model, *args, **kwargs)
model_call = getattr(self._model, "__call__", None)
if inspect.iscoroutinefunction(self._model) or inspect.iscoroutinefunction(model_call):
return self._call_with_retry_async(self._model, *args, **kwargs)
result = self._model(*args, **kwargs)
return result
def __getattr__(self, name: str) -> Any:
"""Proxy attribute access to the wrapped model."""
@@ -248,10 +329,18 @@ class TokenRecordingModelWrapper:
if usage is None:
return
self._prompt_tokens += getattr(usage, "prompt_tokens", 0)
self._completion_tokens += getattr(usage, "completion_tokens", 0)
self._total_tokens += getattr(usage, "total_tokens", 0)
self._total_cost += getattr(usage, "cost", 0.0)
prompt_tokens = _usage_value(usage, "prompt_tokens", None)
completion_tokens = _usage_value(usage, "completion_tokens", None)
if prompt_tokens is None:
prompt_tokens = _usage_value(usage, "input_tokens", 0)
if completion_tokens is None:
completion_tokens = _usage_value(usage, "output_tokens", 0)
self._prompt_tokens += int(prompt_tokens or 0)
self._completion_tokens += int(completion_tokens or 0)
self._total_tokens += _usage_total_tokens(usage)
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
def __call__(self, *args, **kwargs) -> Any:
"""Forward calls and record usage."""
@@ -401,7 +490,8 @@ def create_model(
if host:
model_kwargs["host"] = host
return model_class(**model_kwargs)
model = model_class(**model_kwargs)
return RetryChatModel(model)
def get_agent_model(agent_id: str, stream: bool = False):

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
Main Entry Point
Supports: backtest, live, mock modes
Supports: backtest, live modes
"""
import argparse
import asyncio
@@ -16,7 +16,7 @@ from dotenv import load_dotenv
from backend.agents import AnalystAgent, PMAgent, RiskAgent
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
from backend.agents.prompt_loader import PromptLoader
from backend.agents.prompt_loader import get_prompt_loader
from backend.agents.workspace_manager import WorkspaceManager
from backend.config.bootstrap_config import resolve_runtime_config
from backend.config.constants import ANALYST_TYPES
@@ -38,7 +38,7 @@ load_dotenv()
logger = logging.getLogger(__name__)
loguru.logger.disable("flowllm")
loguru.logger.disable("reme_ai")
_prompt_loader = PromptLoader()
_prompt_loader = get_prompt_loader()
def _get_run_dir(config_name: str) -> Path:
@@ -226,17 +226,13 @@ async def run_with_gateway(args):
)
runtime_manager.prepare_run()
set_global_runtime_manager(runtime_manager)
register_runtime_manager(runtime_manager)
# Create market service
market_service = MarketService(
tickers=tickers,
poll_interval=args.poll_interval,
mock_mode=args.mock and not is_backtest,
backtest_mode=is_backtest,
api_key=os.getenv("FINNHUB_API_KEY")
if not args.mock and not is_backtest
else None,
api_key=os.getenv("FINNHUB_API_KEY") if not is_backtest else None,
backtest_start_date=args.start_date if is_backtest else None,
backtest_end_date=args.end_date if is_backtest else None,
)
@@ -321,7 +317,6 @@ async def run_with_gateway(args):
scheduler_callback=scheduler_callback,
config={
"mode": args.mode,
"mock_mode": args.mock,
"backtest_mode": is_backtest,
"tickers": tickers,
"config_name": config_name,
@@ -354,8 +349,7 @@ def main():
"""Main entry point"""
parser = argparse.ArgumentParser(description="Trading System")
parser.add_argument("--mode", choices=["live", "backtest"], default="live")
parser.add_argument("--mock", action="store_true")
parser.add_argument("--config-name", default="mock")
parser.add_argument("--config-name", default="live")
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8765)
parser.add_argument(

View File

@@ -13,15 +13,30 @@ from .registry import RuntimeRegistry
_global_runtime_manager: Optional["TradingRuntimeManager"] = None
_shutdown_event: Optional[asyncio.Event] = None
# Lazy import to avoid circular dependency
_api_runtime = None
def _get_api_runtime():
global _api_runtime
if _api_runtime is None:
from backend.api import runtime as api_runtime_module
_api_runtime = api_runtime_module
return _api_runtime
def set_global_runtime_manager(manager: "TradingRuntimeManager") -> None:
global _global_runtime_manager
_global_runtime_manager = manager
# Sync to RuntimeState for consistency
_get_api_runtime().register_runtime_manager(manager)
def clear_global_runtime_manager() -> None:
global _global_runtime_manager
_global_runtime_manager = None
# Sync to RuntimeState for consistency
_get_api_runtime().unregister_runtime_manager()
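The lazy-import-with-cache pattern used here breaks the circular dependency between `backend.runtime.manager` and `backend.api.runtime`: the import runs on first use, after both modules have finished loading, and is memoized in a module-level global. The shape of the pattern, with a stdlib module standing in for `backend.api.runtime`:

```python
# Lazy import memoized in a module global; json is only a stand-in target.
import importlib

_api_runtime = None

def _get_api_runtime():
    global _api_runtime
    if _api_runtime is None:  # import deferred until first call
        _api_runtime = importlib.import_module("json")
    return _api_runtime

first = _get_api_runtime()
second = _get_api_runtime()
```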
def get_global_runtime_manager() -> Optional["TradingRuntimeManager"]:

File diff suppressed because it is too large

View File

@@ -0,0 +1,426 @@
# -*- coding: utf-8 -*-
"""Runtime/workspace/skills handlers extracted from the main Gateway module.
Deprecated note:
Agent/workspace/skill read-write operations are being migrated to
agent_service REST endpoints. These websocket handlers remain as a
compatibility fallback and should not be considered the primary control
plane path for frontend reads/writes.
"""
from __future__ import annotations
import json
from datetime import datetime
from typing import Any
from backend.agents.agent_workspace import load_agent_workspace_config
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import load_agent_profiles
from backend.config.bootstrap_config import (
get_bootstrap_config_for_run,
resolve_runtime_config,
update_bootstrap_values_for_run,
)
from backend.data.market_ingest import ingest_symbols
from backend.llm.models import get_agent_model_info
async def handle_reload_runtime_assets(gateway: Any) -> None:
config_name = gateway.config.get("config_name", "default")
runtime_config = resolve_runtime_config(
project_root=gateway._project_root,
config_name=config_name,
enable_memory=gateway.config.get("enable_memory", False),
schedule_mode=gateway.config.get("schedule_mode", "daily"),
interval_minutes=gateway.config.get("interval_minutes", 60),
trigger_time=gateway.config.get("trigger_time", "09:30"),
)
result = gateway.pipeline.reload_runtime_assets(runtime_config=runtime_config)
runtime_updates = gateway._apply_runtime_config(runtime_config)
await gateway.state_sync.on_system_message("Runtime assets reloaded.")
await gateway.broadcast({"type": "runtime_assets_reloaded", **result, **runtime_updates})
async def handle_update_runtime_config(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
updates: dict[str, Any] = {}
schedule_mode = str(data.get("schedule_mode", "")).strip().lower()
if schedule_mode:
if schedule_mode not in {"daily", "intraday"}:
await websocket.send(json.dumps({"type": "error", "message": "schedule_mode must be 'daily' or 'intraday'."}, ensure_ascii=False))
return
updates["schedule_mode"] = schedule_mode
interval_minutes = data.get("interval_minutes")
if interval_minutes is not None:
try:
parsed_interval = int(interval_minutes)
except (TypeError, ValueError):
parsed_interval = 0
if parsed_interval <= 0:
await websocket.send(json.dumps({"type": "error", "message": "interval_minutes must be a positive integer."}, ensure_ascii=False))
return
updates["interval_minutes"] = parsed_interval
trigger_time = data.get("trigger_time")
if trigger_time is not None:
raw_trigger = str(trigger_time).strip()
if raw_trigger and raw_trigger != "now":
try:
datetime.strptime(raw_trigger, "%H:%M")
except ValueError:
await websocket.send(json.dumps({"type": "error", "message": "trigger_time must use HH:MM or 'now'."}, ensure_ascii=False))
return
updates["trigger_time"] = raw_trigger or "09:30"
max_comm_cycles = data.get("max_comm_cycles")
if max_comm_cycles is not None:
try:
parsed_cycles = int(max_comm_cycles)
except (TypeError, ValueError):
parsed_cycles = 0
if parsed_cycles <= 0:
await websocket.send(json.dumps({"type": "error", "message": "max_comm_cycles must be a positive integer."}, ensure_ascii=False))
return
updates["max_comm_cycles"] = parsed_cycles
initial_cash = data.get("initial_cash")
if initial_cash is not None:
try:
parsed_initial_cash = float(initial_cash)
except (TypeError, ValueError):
parsed_initial_cash = 0.0
if parsed_initial_cash <= 0:
await websocket.send(json.dumps({"type": "error", "message": "initial_cash must be a positive number."}, ensure_ascii=False))
return
updates["initial_cash"] = parsed_initial_cash
margin_requirement = data.get("margin_requirement")
if margin_requirement is not None:
try:
parsed_margin_requirement = float(margin_requirement)
except (TypeError, ValueError):
parsed_margin_requirement = -1.0
if parsed_margin_requirement < 0:
await websocket.send(json.dumps({"type": "error", "message": "margin_requirement must be a non-negative number."}, ensure_ascii=False))
return
updates["margin_requirement"] = parsed_margin_requirement
enable_memory = data.get("enable_memory")
if enable_memory is not None:
updates["enable_memory"] = bool(enable_memory)
if not updates:
await websocket.send(json.dumps({"type": "error", "message": "No runtime settings were provided."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=gateway._project_root,
config_name=config_name,
updates=updates,
)
    await gateway.state_sync.on_system_message("Runtime schedule settings saved; hot-reloading now")
await handle_reload_runtime_assets(gateway)
async def handle_update_watchlist(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
tickers = gateway._normalize_watchlist(data.get("tickers"))
if not tickers:
await websocket.send(json.dumps({"type": "error", "message": "update_watchlist requires at least one valid ticker."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=gateway._project_root,
config_name=config_name,
updates={"tickers": tickers},
)
await gateway.state_sync.on_system_message(f"Watchlist updated: {', '.join(tickers)}")
await gateway.broadcast({"type": "watchlist_updated", "config_name": config_name, "tickers": tickers})
await handle_reload_runtime_assets(gateway)
gateway._schedule_watchlist_market_store_refresh(tickers)
async def handle_get_agent_skills(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
if not agent_id:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_skills requires agent_id."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=config_name, agent_id=agent_id, default_skills=[]))
enabled = set(agent_config.enabled_skills)
disabled = set(agent_config.disabled_skills)
payload = []
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id):
if item.skill_name in disabled:
status = "disabled"
elif item.skill_name in enabled:
status = "enabled"
elif item.skill_name in resolved_skills:
status = "active"
else:
status = "available"
payload.append({
"skill_name": item.skill_name,
"name": item.name,
"description": item.description,
"version": item.version,
"source": item.source,
"tools": item.tools,
"status": status,
})
await websocket.send(json.dumps({
"type": "agent_skills_loaded",
"config_name": config_name,
"agent_id": agent_id,
"skills": payload,
}, ensure_ascii=False))
async def handle_get_agent_profile(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
if not agent_id:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_profile requires agent_id."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
profiles = load_agent_profiles()
profile = profiles.get(agent_id, {})
bootstrap = get_bootstrap_config_for_run(gateway._project_root, config_name)
override = bootstrap.agent_override(agent_id)
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
if not isinstance(active_tool_groups, list):
active_tool_groups = []
disabled_tool_groups = agent_config.disabled_tool_groups
if disabled_tool_groups:
disabled_set = set(disabled_tool_groups)
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
default_skills = profile.get("skills", [])
if not isinstance(default_skills, list):
default_skills = []
resolved_skills = skills_manager.resolve_agent_skill_names(
config_name=config_name,
agent_id=agent_id,
default_skills=default_skills,
)
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
model_name, model_provider = get_agent_model_info(agent_id)
await websocket.send(json.dumps({
"type": "agent_profile_loaded",
"config_name": config_name,
"agent_id": agent_id,
"profile": {
"model_name": model_name,
"model_provider": model_provider,
"prompt_files": prompt_files,
"default_skills": default_skills,
"resolved_skills": resolved_skills,
"active_tool_groups": active_tool_groups,
"disabled_tool_groups": disabled_tool_groups,
"enabled_skills": agent_config.enabled_skills,
"disabled_skills": agent_config.disabled_skills,
},
}, ensure_ascii=False))
async def handle_get_skill_detail(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "get_skill_detail requires skill_name."}, ensure_ascii=False))
return
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
if agent_id:
config_name = gateway.config.get("config_name", "default")
detail = skills_manager.load_agent_skill_document(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
else:
detail = skills_manager.load_skill_document(skill_name)
except FileNotFoundError:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
return
await websocket.send(json.dumps({
"type": "skill_detail_loaded",
"agent_id": agent_id,
"skill": detail,
}, ensure_ascii=False))
async def handle_create_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "create_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.create_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
except (ValueError, FileExistsError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Created local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_created", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
async def handle_update_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
content = data.get("content")
if not agent_id or not skill_name or not isinstance(content, str):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_local_skill requires agent_id, skill_name, and string content."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.update_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name, content=content)
except (ValueError, FileNotFoundError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Updated local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_updated", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
async def handle_delete_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "delete_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.delete_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
skills_manager.forget_agent_skill_overrides(config_name=config_name, agent_id=agent_id, skill_names=[skill_name])
except (ValueError, FileNotFoundError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Deleted local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_deleted", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_remove_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "remove_agent_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
skill_names = {
item.skill_name
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
if item.source != "local"
}
if skill_name not in skill_names:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown shared skill: {skill_name}"}, ensure_ascii=False))
return
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
await gateway.state_sync.on_system_message(f"Removed shared skill {skill_name} from {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_skill_removed", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_update_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
enabled = data.get("enabled")
if not agent_id or not skill_name or not isinstance(enabled, bool):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
skill_names = {item.skill_name for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)}
if skill_name not in skill_names:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
return
if enabled:
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, enable=[skill_name])
await gateway.state_sync.on_system_message(f"Enabled skill {skill_name} for {agent_id}")
else:
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
await gateway.state_sync.on_system_message(f"Disabled skill {skill_name} for {agent_id}")
await websocket.send(json.dumps({
"type": "agent_skill_updated",
"agent_id": agent_id,
"skill_name": skill_name,
"enabled": enabled,
}, ensure_ascii=False))
await gateway._handle_reload_runtime_assets()
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_get_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
if not agent_id or not filename:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_workspace_file requires agent_id and supported filename."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
asset_dir.mkdir(parents=True, exist_ok=True)
path = asset_dir / filename
content = path.read_text(encoding="utf-8") if path.exists() else ""
await websocket.send(json.dumps({
"type": "agent_workspace_file_loaded",
"config_name": config_name,
"agent_id": agent_id,
"filename": filename,
"content": content,
}, ensure_ascii=False))
async def handle_update_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
content = data.get("content")
if not agent_id or not filename or not isinstance(content, str):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_workspace_file requires agent_id, supported filename, and string content."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
asset_dir.mkdir(parents=True, exist_ok=True)
path = asset_dir / filename
path.write_text(content, encoding="utf-8")
await gateway.state_sync.on_system_message(f"Updated {filename} for {agent_id}")
await websocket.send(json.dumps({"type": "agent_workspace_file_updated", "agent_id": agent_id, "filename": filename}, ensure_ascii=False))
await gateway._handle_reload_runtime_assets()
await handle_get_agent_workspace_file(gateway, websocket, {"agent_id": agent_id, "filename": filename})


@@ -0,0 +1,391 @@
# -*- coding: utf-8 -*-
"""Cycle and monitoring helpers extracted from the main Gateway module."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from backend.data.market_ingest import ingest_symbols, refresh_news_for_symbols
from backend.domains import trading as trading_domain
from backend.utils.msg_adapter import FrontendAdapter
logger = logging.getLogger(__name__)
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
"""Kick off a non-blocking market-store refresh for an updated watchlist."""
if not tickers:
return
if gateway._watchlist_ingest_task and not gateway._watchlist_ingest_task.done():
gateway._watchlist_ingest_task.cancel()
gateway._watchlist_ingest_task = asyncio.create_task(
refresh_market_store_for_watchlist(gateway, tickers),
)
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
"""Refresh the long-lived market store after a watchlist update."""
try:
await gateway.state_sync.on_system_message(
f"Syncing watchlist market data: {', '.join(tickers)}",
)
results = await asyncio.to_thread(
ingest_symbols,
tickers,
mode="incremental",
)
summary = ", ".join(
f"{item['symbol']} prices={item['prices']} news={item['news']}"
for item in results
)
await gateway.state_sync.on_system_message(
f"Watchlist market data synced: {summary}",
)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning("Watchlist market store refresh failed: %s", exc)
await gateway.state_sync.on_system_message(
f"Watchlist market data sync failed: {exc}",
)
async def market_status_monitor(gateway: Any) -> None:
"""Periodically check and broadcast market status changes."""
while True:
try:
await gateway.market_service.check_and_broadcast_market_status()
status = gateway.market_service.get_market_status()
if status["status"] == "open" and not gateway.storage.is_live_session_active:
gateway.storage.start_live_session()
summary = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state).get("summary") or {}
gateway._session_start_portfolio_value = summary.get(
"totalAssetValue",
gateway.storage.initial_cash,
)
logger.info(
"Session start portfolio: $%s",
f"{gateway._session_start_portfolio_value:,.2f}",
)
elif status["status"] != "open" and gateway.storage.is_live_session_active:
gateway.storage.end_live_session()
gateway._session_start_portfolio_value = None
if gateway.storage.is_live_session_active:
await update_and_broadcast_live_returns(gateway)
await asyncio.sleep(60)
except asyncio.CancelledError:
break
except Exception as exc:
logger.error("Market status monitor error: %s", exc)
await asyncio.sleep(60)
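The loop shape used by `market_status_monitor` can be reduced to a small sketch: run one check per interval, exit cleanly on cancellation, and keep running through ordinary errors. The interval is shortened here so the sketch finishes quickly; `check` is a hypothetical callable, not the real market service.

```python
import asyncio

async def monitor(check, interval=0.01):
    """Periodic monitor: break on cancel, survive ordinary exceptions."""
    ticks = 0
    while True:
        try:
            check()
            ticks += 1
            await asyncio.sleep(interval)
        except asyncio.CancelledError:
            break  # cooperative shutdown, no re-raise
        except Exception:
            await asyncio.sleep(interval)  # log-and-continue in the real loop
    return ticks

async def main():
    task = asyncio.create_task(monitor(lambda: None))
    await asyncio.sleep(0.05)
    task.cancel()
    return await task  # returns normally because CancelledError is absorbed

ticks = asyncio.run(main())
```

Because the coroutine absorbs `CancelledError` and returns, awaiting the cancelled task yields the tick count instead of raising.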
async def update_and_broadcast_live_returns(gateway: Any) -> None:
"""Calculate and broadcast live returns for current session."""
if not gateway.storage.is_live_session_active:
return
prices = gateway.market_service.get_all_prices()
if not prices or not any(p > 0 for p in prices.values()):
return
state = gateway.storage.load_internal_state()
equity_history = state.get("equity_history", [])
baseline_history = state.get("baseline_history", [])
baseline_vw_history = state.get("baseline_vw_history", [])
momentum_history = state.get("momentum_history", [])
current_equity = equity_history[-1]["v"] if equity_history else None
current_baseline = baseline_history[-1]["v"] if baseline_history else None
current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None
current_momentum = momentum_history[-1]["v"] if momentum_history else None
point = gateway.storage.update_live_returns(
current_equity=current_equity,
current_baseline=current_baseline,
current_baseline_vw=current_baseline_vw,
current_momentum=current_momentum,
)
if point:
live_returns = gateway.storage.get_live_returns()
await gateway.broadcast(
{
"type": "team_summary",
"equity_return": live_returns["equity_return"],
"baseline_return": live_returns["baseline_return"],
"baseline_vw_return": live_returns["baseline_vw_return"],
"momentum_return": live_returns["momentum_return"],
},
)
async def on_strategy_trigger(gateway: Any, date: str) -> None:
"""Handle trading cycle trigger."""
if gateway._cycle_lock.locked():
logger.warning("Trading cycle already running, skipping trigger for %s", date)
await gateway.state_sync.on_system_message(f"A trading cycle is already running; skipping this trigger: {date}")
return
async with gateway._cycle_lock:
logger.info("Strategy triggered for %s", date)
tickers = gateway.config.get("tickers", [])
if gateway.is_backtest:
await run_backtest_cycle(gateway, date, tickers)
else:
await run_live_cycle(gateway, date, tickers)
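The guard in `on_strategy_trigger` is a non-blocking "skip if busy" pattern: a trigger that arrives while a cycle holds the lock is dropped rather than queued behind it. A self-contained sketch (dates and the sleep stand in for the real cycle):

```python
import asyncio

lock = asyncio.Lock()
runs, skips = [], []

async def trigger(date):
    if lock.locked():
        skips.append(date)   # drop the overlapping trigger
        return
    async with lock:
        runs.append(date)
        await asyncio.sleep(0.01)  # stand-in for the trading cycle

async def main():
    await asyncio.gather(trigger("2026-03-25"), trigger("2026-03-26"))

asyncio.run(main())
```

The second trigger observes the held lock and is skipped; with a plain `async with lock` and no `locked()` check it would instead queue and run a stale cycle afterwards.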
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
"""Run lightweight heartbeat check for all analysts."""
logger.info("[Heartbeat] Running heartbeat check for %s", date)
from pathlib import Path
from agentscope.message import Msg
from backend.agents.workspace_manager import get_workspace_dir
analysts = gateway.pipeline._all_analysts()
for analyst in analysts:
try:
ws_id = getattr(analyst, "workspace_id", None)
if ws_id:
ws_dir = get_workspace_dir(ws_id)
if ws_dir:
hb_path = Path(ws_dir) / "HEARTBEAT.md"
if hb_path.exists():
content = hb_path.read_text(encoding="utf-8").strip()
if content:
hb_task = f"# Scheduled proactive check\n\n{content}\n\nPlease run the checks above and report the results."
logger.info("[Heartbeat] Running heartbeat for %s", analyst.name)
msg = Msg(role="user", content=hb_task, name="system")
await analyst.reply([msg])
logger.info("[Heartbeat] %s heartbeat complete", analyst.name)
continue
logger.debug("[Heartbeat] No HEARTBEAT.md for %s, skipping", analyst.name)
except Exception as exc:
logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True)
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
gateway.market_service.set_backtest_date(date)
await gateway.market_service.emit_market_open()
await gateway.state_sync.on_cycle_start(date)
gateway._dashboard.update(date=date, status="Analyzing...")
prices = gateway.market_service.get_open_prices()
close_prices = gateway.market_service.get_close_prices()
market_caps = await get_market_caps(gateway, tickers, date)
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=date,
prices=prices,
close_prices=close_prices,
market_caps=market_caps,
)
await gateway.market_service.emit_market_close()
settlement_result = result.get("settlement_result")
save_cycle_results(gateway, result, date, close_prices, settlement_result)
await broadcast_portfolio_updates(gateway, result, close_prices)
await finalize_cycle(gateway, date)
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
trading_date = gateway.market_service.get_live_trading_date()
logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)
try:
news_refresh = await asyncio.to_thread(
refresh_news_for_symbols,
tickers,
end_date=trading_date,
store=gateway.storage.market_store,
)
logger.info(
"News refresh complete: %s",
", ".join(
f"{item['symbol']} news={item['news']}"
for item in news_refresh
) or "no symbols",
)
except Exception as exc:
logger.warning("Live cycle news refresh failed: %s", exc)
await gateway.state_sync.on_cycle_start(trading_date)
gateway._dashboard.update(date=trading_date, status="Analyzing...")
market_caps = await get_market_caps(gateway, tickers, trading_date)
schedule_mode = gateway.config.get("schedule_mode", "daily")
market_status = gateway.market_service.get_market_status()
current_prices = gateway.market_service.get_all_prices()
if schedule_mode == "intraday":
execute_decisions = market_status.get("status") == "open"
if execute_decisions:
await gateway.state_sync.on_system_message("Scheduled trigger: market is open; trading decisions will be executed this round")
else:
await gateway.state_sync.on_system_message("Scheduled trigger: market is closed; this round only updates data and analysis, no trades will be executed")
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
prices=current_prices,
market_caps=market_caps,
execute_decisions=execute_decisions,
)
close_prices = current_prices
else:
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
market_caps=market_caps,
get_open_prices_fn=gateway.market_service.wait_for_open_prices,
get_close_prices_fn=gateway.market_service.wait_for_close_prices,
)
close_prices = gateway.market_service.get_all_prices()
settlement_result = result.get("settlement_result")
save_cycle_results(gateway, result, trading_date, close_prices, settlement_result)
await broadcast_portfolio_updates(gateway, result, close_prices)
await finalize_cycle(gateway, trading_date)
async def finalize_cycle(gateway: Any, date: str) -> None:
dashboard_snapshot = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state)
summary = dashboard_snapshot.get("summary") or {}
if gateway.storage.is_live_session_active:
summary.update(gateway.storage.get_live_returns())
await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary)
holdings = dashboard_snapshot.get("holdings") or []
trades = dashboard_snapshot.get("trades") or []
leaderboard = dashboard_snapshot.get("leaderboard") or []
if leaderboard:
await gateway.state_sync.on_leaderboard_update(leaderboard)
gateway._dashboard.update(date=date, status="Running", portfolio=summary, holdings=holdings, trades=trades)
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
market_caps: dict[str, float] = {}
for ticker in tickers:
try:
market_cap = None
response = await gateway._call_trading_service(
f"get_market_cap for {ticker}",
lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
)
if response is not None:
market_cap = response.get("market_cap")
if market_cap is None:
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
market_cap = payload.get("market_cap")
market_caps[ticker] = market_cap if market_cap else 1e9
except Exception as exc:
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
market_caps[ticker] = 1e9
return market_caps
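`get_market_caps` follows a fallback chain per ticker: try the trading service, then the local domain payload, then a 1e9 default so one bad symbol cannot stall the cycle. The chain can be sketched with two hypothetical lookup callables standing in for the real clients:

```python
def resolve_market_cap(primary, secondary, default=1e9):
    """Try primary, then secondary; use the default when both are
    empty or the lookups raise."""
    try:
        value = primary()
        if value is None:
            value = secondary()
        return value if value else default
    except Exception:
        return default
```

Usage mirrors the handler: `resolve_market_cap(service_lookup, domain_lookup)` returns the first non-empty value, and a raised exception degrades to the default exactly as the `except` branch in `get_market_caps` does.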
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
portfolio = result.get("portfolio", {})
if portfolio:
holdings = FrontendAdapter.build_holdings(portfolio, prices)
if holdings:
await gateway.state_sync.on_holdings_update(holdings)
stats = FrontendAdapter.build_stats(portfolio, prices)
if stats:
await gateway.state_sync.on_stats_update(stats)
executed_trades = result.get("executed_trades", [])
if executed_trades:
await gateway.state_sync.on_trades_executed(executed_trades)
def save_cycle_results(
gateway: Any,
result: dict[str, Any],
date: str,
prices: dict[str, float],
settlement_result: dict[str, Any] | None = None,
) -> None:
portfolio = result.get("portfolio", {})
executed_trades = result.get("executed_trades", [])
baseline_values = settlement_result.get("baseline_values") if settlement_result else None
if portfolio:
gateway.storage.update_dashboard_after_cycle(
portfolio=portfolio,
prices=prices,
date=date,
executed_trades=executed_trades,
baseline_values=baseline_values,
)
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
gateway.state_sync.set_backtest_dates(dates)
gateway._dashboard.update(days_total=len(dates), days_completed=0)
await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
try:
for i, date in enumerate(dates):
gateway._dashboard.update(days_completed=i)
await gateway.on_strategy_trigger(date=date)
await asyncio.sleep(0.1)
await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
summary = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state).get("summary") or {}
gateway._dashboard.update(status="Complete", portfolio=summary, days_completed=len(dates))
gateway._dashboard.stop()
gateway._dashboard.print_final_summary()
except Exception as exc:
error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
logger.error(error_msg, exc_info=True)
asyncio.create_task(gateway.state_sync.on_system_message(error_msg))
gateway._dashboard.update(status=f"Failed: {str(exc)}")
gateway._dashboard.stop()
raise
finally:
gateway._backtest_task = None
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
try:
task.result()
except asyncio.CancelledError:
logger.info("Backtest task was cancelled")
except Exception as exc:
logger.error("Backtest task failed with exception: %s: %s", type(exc).__name__, exc, exc_info=True)
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
gateway._manual_cycle_task = None
try:
task.result()
except asyncio.CancelledError:
logger.info("Manual cycle task was cancelled")
except Exception as exc:
logger.error("Manual cycle task failed with exception: %s: %s", type(exc).__name__, exc, exc_info=True)
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
gateway.state_sync.set_backtest_dates(dates)
if dates:
gateway._backtest_start_date = dates[0]
gateway._backtest_end_date = dates[-1]
gateway._dashboard.days_total = len(dates)
def stop_gateway(gateway: Any) -> None:
gateway.state_sync.save_state()
gateway.market_service.stop()
if gateway._backtest_task:
gateway._backtest_task.cancel()
if gateway._market_status_task:
gateway._market_status_task.cancel()
if gateway._watchlist_ingest_task:
gateway._watchlist_ingest_task.cancel()
gateway._dashboard.stop()


@@ -0,0 +1,175 @@
# -*- coding: utf-8 -*-
"""Runtime/state support helpers extracted from the main Gateway module."""
from __future__ import annotations
from typing import Any
from backend.data.provider_utils import normalize_symbol
def normalize_watchlist(raw_tickers: Any) -> list[str]:
"""Parse watchlist payloads from websocket messages."""
if raw_tickers is None:
return []
if isinstance(raw_tickers, str):
candidates = raw_tickers.split(",")
elif isinstance(raw_tickers, list):
candidates = raw_tickers
else:
candidates = [raw_tickers]
tickers: list[str] = []
for candidate in candidates:
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
if symbol and symbol not in tickers:
tickers.append(symbol)
return tickers
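`normalize_watchlist` accepts either a comma-separated string or a list, strips quotes, and de-duplicates while preserving order. A simplified, self-contained sketch; `str.upper()` stands in for the real `normalize_symbol` from `backend.data.provider_utils`:

```python
def normalize_watchlist(raw):
    """Parse a watchlist payload into an ordered, de-duplicated list."""
    if raw is None:
        return []
    if isinstance(raw, str):
        candidates = raw.split(",")
    elif isinstance(raw, list):
        candidates = raw
    else:
        candidates = [raw]
    tickers = []
    for candidate in candidates:
        # Strip surrounding whitespace and quote characters, then normalize.
        symbol = str(candidate).strip().strip("\"'").upper()
        if symbol and symbol not in tickers:
            tickers.append(symbol)
    return tickers
```

So `normalize_watchlist("aapl, 'msft', aapl")` yields `["AAPL", "MSFT"]`: quotes are stripped, the duplicate is dropped, and first-seen order is kept.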
def normalize_agent_workspace_filename(
raw_name: Any,
*,
allowlist: set[str],
) -> str | None:
"""Restrict editable workspace files to a safe allowlist."""
filename = str(raw_name or "").strip()
if filename in allowlist:
return filename
return None
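The allowlist gate above rejects anything outside a fixed set, including path-traversal attempts, by returning `None`. A sketch with an illustrative allowlist (the gateway defines its own set of editable workspace files):

```python
ALLOWED = {"HEARTBEAT.md", "BOOTSTRAP.md"}  # illustrative, not the real set

def normalize_filename(raw, allowlist=ALLOWED):
    """Return the filename only if it is exactly in the allowlist."""
    name = str(raw or "").strip()
    return name if name in allowlist else None
```

Exact-match membership means `"../secrets.env"` or `"skills/HEARTBEAT.md"` can never resolve to a path outside the agent's asset directory.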
def apply_runtime_config(gateway: Any, runtime_config: dict[str, Any]) -> dict[str, Any]:
"""Apply runtime config to gateway-owned services and state."""
warnings: list[str] = []
ticker_changes = gateway.market_service.update_tickers(
runtime_config.get("tickers", []),
)
gateway.config["tickers"] = ticker_changes["active"]
gateway.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
gateway.config["max_comm_cycles"] = gateway.pipeline.max_comm_cycles
gateway.config["schedule_mode"] = runtime_config.get(
"schedule_mode",
gateway.config.get("schedule_mode", "daily"),
)
gateway.config["interval_minutes"] = int(
runtime_config.get(
"interval_minutes",
gateway.config.get("interval_minutes", 60),
),
)
gateway.config["trigger_time"] = runtime_config.get(
"trigger_time",
gateway.config.get("trigger_time", "09:30"),
)
if gateway.scheduler:
gateway.scheduler.reconfigure(
mode=gateway.config["schedule_mode"],
trigger_time=gateway.config["trigger_time"],
interval_minutes=gateway.config["interval_minutes"],
)
pm_apply_result = gateway.pipeline.pm.apply_runtime_portfolio_config(
margin_requirement=runtime_config["margin_requirement"],
)
gateway.config["margin_requirement"] = gateway.pipeline.pm.portfolio.get(
"margin_requirement",
runtime_config["margin_requirement"],
)
requested_initial_cash = float(runtime_config["initial_cash"])
current_initial_cash = float(gateway.storage.initial_cash)
initial_cash_applied = requested_initial_cash == current_initial_cash
if not initial_cash_applied:
if (
gateway.storage.can_apply_initial_cash()
and gateway.pipeline.pm.can_apply_initial_cash()
):
initial_cash_applied = gateway.storage.apply_initial_cash(
requested_initial_cash,
)
if initial_cash_applied:
gateway.pipeline.pm.apply_runtime_portfolio_config(
initial_cash=requested_initial_cash,
)
gateway.config["initial_cash"] = gateway.storage.initial_cash
else:
warnings.append(
"initial_cash changed in BOOTSTRAP.md but was not applied "
"because the run already has positions, margin usage, or trades.",
)
requested_enable_memory = bool(runtime_config["enable_memory"])
current_enable_memory = bool(gateway.config.get("enable_memory", False))
if requested_enable_memory != current_enable_memory:
warnings.append(
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
"because long-term memory contexts are created at startup.",
)
sync_runtime_state(gateway)
return {
"runtime_config_requested": runtime_config,
"runtime_config_applied": {
"tickers": list(gateway.config.get("tickers", [])),
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
"interval_minutes": gateway.config.get("interval_minutes", 60),
"trigger_time": gateway.config.get("trigger_time", "09:30"),
"initial_cash": gateway.storage.initial_cash,
"margin_requirement": gateway.config["margin_requirement"],
"max_comm_cycles": gateway.config["max_comm_cycles"],
"enable_memory": gateway.config.get("enable_memory", False),
},
"runtime_config_status": {
"tickers": True,
"schedule_mode": True,
"interval_minutes": True,
"trigger_time": True,
"initial_cash": initial_cash_applied,
"margin_requirement": pm_apply_result["margin_requirement"],
"max_comm_cycles": True,
"enable_memory": requested_enable_memory == current_enable_memory,
},
"ticker_changes": ticker_changes,
"runtime_config_warnings": warnings,
}
def sync_runtime_state(gateway: Any) -> None:
"""Refresh persisted state and dashboard after runtime config changes."""
gateway.state_sync.update_state("tickers", gateway.config.get("tickers", []))
gateway.state_sync.update_state(
"runtime_config",
{
"tickers": gateway.config.get("tickers", []),
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
"interval_minutes": gateway.config.get("interval_minutes", 60),
"trigger_time": gateway.config.get("trigger_time", "09:30"),
"initial_cash": gateway.storage.initial_cash,
"margin_requirement": gateway.config.get("margin_requirement"),
"max_comm_cycles": gateway.config.get("max_comm_cycles"),
"enable_memory": gateway.config.get("enable_memory", False),
},
)
gateway.storage.update_server_state_from_dashboard(gateway.state_sync.state)
gateway.state_sync.save_state()
gateway._dashboard.tickers = list(gateway.config.get("tickers", []))
gateway._dashboard.initial_cash = gateway.storage.initial_cash
gateway._dashboard.enable_memory = bool(gateway.config.get("enable_memory", False))
dashboard_snapshot = gateway.storage.build_dashboard_snapshot_from_state(gateway.state_sync.state)
summary = dashboard_snapshot.get("summary") or {}
holdings = dashboard_snapshot.get("holdings") or []
trades = dashboard_snapshot.get("trades") or []
gateway._dashboard.update(
portfolio=summary,
holdings=holdings,
trades=trades,
)


@@ -0,0 +1,716 @@
# -*- coding: utf-8 -*-
"""Stock-related Gateway handlers extracted from the main Gateway module."""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timedelta
from typing import Any
from backend.data.provider_utils import normalize_symbol
from backend.domains import news as news_domain
from backend.domains import trading as trading_domain
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.tools.data_tools import prices_to_df
from shared.client import NewsServiceClient, TradingServiceClient
logger = logging.getLogger(__name__)
async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_history_loaded",
"ticker": "",
"prices": [],
"source": None,
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
prices = []
source = "polygon"
response = await gateway._call_trading_service(
"get_prices for history",
lambda client: client.get_prices(ticker=ticker, start_date=start_date, end_date=end_date),
)
if response is not None:
prices = response.prices
source = "trading_service"
if not prices:
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
source = "market_store"
if not prices:
payload = await asyncio.to_thread(
trading_domain.get_prices_payload,
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
prices = payload.get("prices") or []
usage_snapshot = gateway._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
if prices:
await asyncio.to_thread(
gateway.storage.market_store.upsert_ohlc,
ticker,
[price.model_dump() for price in prices],
source=source or "provider",
)
await websocket.send(json.dumps({
"type": "stock_history_loaded",
"ticker": ticker,
"prices": [price if isinstance(price, dict) else price.model_dump() for price in prices][-120:],
"source": source,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
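The `lookback_days` clamp recurs across these handlers (and again for `limit`). Extracted as a helper for illustration; the handlers inline it, and the defaults below are just one handler's values:

```python
def clamp_lookback(value, default=90, lo=7, hi=365):
    """Coerce to int, fall back to the default on bad input, and pin
    the result to the [lo, hi] range."""
    try:
        return max(lo, min(int(value), hi))
    except (TypeError, ValueError):
        return default
```

This keeps a client-supplied `"100000"` or `None` from turning into an unbounded provider query: out-of-range values are pinned and unparseable ones revert to the default.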
async def handle_get_stock_explain_events(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
snapshot = gateway.storage.runtime_db.get_stock_explain_snapshot(ticker)
await websocket.send(json.dumps({
"type": "stock_explain_events_loaded",
"ticker": ticker,
"events": snapshot.get("events", []),
"signals": snapshot.get("signals", []),
"trades": snapshot.get("trades", []),
}, ensure_ascii=False, default=str))
async def handle_get_stock_news(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_loaded",
"ticker": "",
"news": [],
"source": None,
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 30)
limit = data.get("limit", 12)
try:
lookback_days = max(7, min(int(lookback_days), 180))
except (TypeError, ValueError):
lookback_days = 30
try:
limit = max(1, min(int(limit), 30))
except (TypeError, ValueError):
limit = 12
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
news_rows = []
source = "polygon"
response = await gateway._call_news_service(
"get_enriched_news",
lambda client: client.get_enriched_news(
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
),
)
if response is not None:
news_rows = response.get("news") or []
source = "news_service"
if not news_rows:
payload = await asyncio.to_thread(
news_domain.get_enriched_news,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=max(limit, 50),
refresh_if_stale=False,
)
news_rows = (payload.get("news") or [])[-limit:]
source = "market_store"
await websocket.send(json.dumps({
"type": "stock_news_loaded",
"ticker": ticker,
"news": news_rows[-limit:],
"source": source,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_for_date(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
trade_date = str(data.get("date") or "").strip()
if not ticker or not trade_date:
await websocket.send(json.dumps({
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": [],
"error": "ticker and date are required",
}, ensure_ascii=False))
return
limit = data.get("limit", 20)
try:
limit = max(1, min(int(limit), 50))
except (TypeError, ValueError):
limit = 20
source = "market_store"
news_rows = []
response = await gateway._call_news_service(
"get_news_for_date",
lambda client: client.get_news_for_date(ticker=ticker, date=trade_date, limit=limit),
)
if response is not None:
news_rows = response.get("news") or []
source = "news_service"
if not news_rows:
payload = await asyncio.to_thread(
news_domain.get_news_for_date,
gateway.storage.market_store,
ticker=ticker,
date=trade_date,
limit=limit,
refresh_if_stale=False,
)
news_rows = payload.get("news") or []
source = "market_store"
await websocket.send(json.dumps({
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": news_rows,
"source": source,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_timeline(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_timeline_loaded",
"ticker": "",
"timeline": [],
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
timeline = []
response = await gateway._call_news_service(
"get_news_timeline",
lambda client: client.get_news_timeline(ticker=ticker, start_date=start_date, end_date=end_date),
)
if response is not None:
timeline = response.get("timeline") or []
if not timeline:
payload = await asyncio.to_thread(
news_domain.get_news_timeline,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
refresh_if_stale=False,
)
timeline = payload.get("timeline") or []
await websocket.send(json.dumps({
"type": "stock_news_timeline_loaded",
"ticker": ticker,
"timeline": timeline,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_categories(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_categories_loaded",
"ticker": "",
"categories": {},
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
categories = {}
response = await gateway._call_news_service(
"get_categories",
lambda client: client.get_categories(
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=200,
),
)
if response is not None:
categories = response.get("categories") or {}
if not categories:
payload = await asyncio.to_thread(
news_domain.get_news_categories,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=200,
refresh_if_stale=False,
)
categories = payload.get("categories") or {}
await websocket.send(json.dumps({
"type": "stock_news_categories_loaded",
"ticker": ticker,
"categories": categories,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_range_explain(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()
end_date = str(data.get("end_date") or "").strip()
if not ticker or not start_date or not end_date:
await websocket.send(json.dumps({
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": {"error": "ticker, start_date, end_date are required"},
}, ensure_ascii=False))
return
article_ids = data.get("article_ids")
result = None
response = await gateway._call_news_service(
"get_range_explain",
lambda client: client.get_range_explain(
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids if isinstance(article_ids, list) else None,
limit=100,
),
)
if response is not None:
result = response.get("result")
if result is None:
payload = await asyncio.to_thread(
news_domain.get_range_explain_payload,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids if isinstance(article_ids, list) else None,
limit=100,
refresh_if_stale=False,
)
result = payload.get("result")
await websocket.send(json.dumps({
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": result,
}, ensure_ascii=False, default=str))
async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_insider_trades_loaded",
"ticker": "",
"trades": [],
"error": "invalid ticker",
}, ensure_ascii=False))
return
end_date = str(data.get("end_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
start_date = str(data.get("start_date") or "").strip()[:10]
limit = int(data.get("limit", 50))
trades = []
response = await gateway._call_trading_service(
"get_insider_trades",
lambda client: client.get_insider_trades(
ticker=ticker,
end_date=end_date,
start_date=start_date if start_date else None,
limit=limit,
),
)
if response is not None:
trades = response.insider_trades
if not trades:
payload = await asyncio.to_thread(
trading_domain.get_insider_trades_payload,
ticker=ticker,
end_date=end_date,
start_date=start_date if start_date else None,
limit=limit,
)
trades = payload.get("insider_trades") or []
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
formatted_trades = [{
"ticker": t.ticker,
"name": t.name,
"title": t.title,
"is_board_director": t.is_board_director,
"transaction_date": t.transaction_date,
"transaction_shares": t.transaction_shares,
"transaction_price_per_share": t.transaction_price_per_share,
"transaction_value": t.transaction_value,
"shares_owned_before_transaction": t.shares_owned_before_transaction,
"shares_owned_after_transaction": t.shares_owned_after_transaction,
"security_title": t.security_title,
"filing_date": t.filing_date,
"holding_change": (
t.shares_owned_after_transaction - t.shares_owned_before_transaction
if t.shares_owned_after_transaction is not None
and t.shares_owned_before_transaction is not None
else None
),
"is_buy": ((t.transaction_shares or 0) > 0) if t.transaction_shares is not None else None,
} for t in sorted_trades]
await websocket.send(json.dumps({
"type": "stock_insider_trades_loaded",
"ticker": ticker,
"start_date": start_date or None,
"end_date": end_date,
"trades": formatted_trades,
}, ensure_ascii=False, default=str))
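The `is not None` distinction matters for these derived fields: a plain truthiness check would turn a position closed out to zero shares into "unreported". A minimal sketch of the None-safe derivations (helper names are illustrative, not part of the codebase):

```python
def holding_change(before, after):
    """Net share change across a filing; None when either side is unreported."""
    if before is None or after is None:
        return None
    return after - before

def is_buy(transaction_shares):
    """True for purchases, False for sales, None when shares are unreported."""
    if transaction_shares is None:
        return None
    return transaction_shares > 0

closed_out = holding_change(1000, 0)   # fully exited position: -1000, not None
unreported = holding_change(None, 500)
sale = is_buy(-200)
```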
async def handle_get_stock_story(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_story_loaded",
"ticker": "",
"story": "",
"error": "invalid ticker",
}, ensure_ascii=False))
return
as_of_date = str(data.get("as_of_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
result = await gateway._call_news_service(
"get_story",
lambda client: client.get_story(ticker=ticker, as_of_date=as_of_date),
)
if result is None:
result = await asyncio.to_thread(
news_domain.get_story_payload,
gateway.storage.market_store,
ticker=ticker,
as_of_date=as_of_date,
)
await websocket.send(json.dumps({
"type": "stock_story_loaded",
"ticker": ticker,
"as_of_date": as_of_date,
"story": result.get("story") or "",
"source": result.get("source") or "local",
}, ensure_ascii=False, default=str))
async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
target_date = str(data.get("date") or "").strip()[:10]
if not ticker or not target_date:
await websocket.send(json.dumps({
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
"items": [],
"error": "ticker and date are required",
}, ensure_ascii=False))
return
top_k = data.get("top_k", 8)
try:
top_k = max(1, min(int(top_k), 20))
except (TypeError, ValueError):
top_k = 8
result = await gateway._call_news_service(
"get_similar_days",
lambda client: client.get_similar_days(ticker=ticker, date=target_date, n_similar=top_k),
)
if result is None:
result = await asyncio.to_thread(
news_domain.get_similar_days_payload,
gateway.storage.market_store,
ticker=ticker,
date=target_date,
n_similar=top_k,
)
await websocket.send(json.dumps({
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
**result,
}, ensure_ascii=False, default=str))
async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "ticker is required",
}, ensure_ascii=False))
return
try:
end_date = datetime.now()
start_date = end_date - timedelta(days=250)
prices = None
response = await gateway._call_trading_service(
"get_prices",
lambda client: client.get_prices(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
),
)
if response is not None:
prices = response.prices
if prices is None:
payload = trading_domain.get_prices_payload(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
)
prices = payload.get("prices") or []
if not prices or len(prices) < 20:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "Insufficient price data",
}, ensure_ascii=False))
return
df = prices_to_df(prices)
signal = gateway._technical_analyzer.analyze(ticker, df)
df_sorted = df.sort_values("time").reset_index(drop=True)
df_sorted["returns"] = df_sorted["close"].pct_change()
vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
ma_distance = {}
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
ma_value = getattr(signal, ma_key, None)
ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None
indicators = {
"ticker": ticker,
"current_price": signal.current_price,
"ma": {
"ma5": signal.ma5,
"ma10": signal.ma10,
"ma20": signal.ma20,
"ma50": signal.ma50,
"ma200": signal.ma200,
"distance": ma_distance,
},
"rsi": {
"rsi14": signal.rsi14,
"status": (
None if signal.rsi14 is None
else "oversold" if signal.rsi14 < 30
else "overbought" if signal.rsi14 > 70
else "neutral"
),
},
"macd": {
"macd": signal.macd,
"signal": signal.macd_signal,
"histogram": (
signal.macd - signal.macd_signal
if signal.macd is not None and signal.macd_signal is not None
else None
),
},
"bollinger": {
"upper": signal.bollinger_upper,
"mid": signal.bollinger_mid,
"lower": signal.bollinger_lower,
},
"volatility": {
"vol_10d": vol_10,
"vol_20d": vol_20,
"vol_60d": vol_60,
"annualized": signal.annualized_volatility_pct,
"risk_level": signal.risk_level,
},
"trend": signal.trend,
"mean_reversion": signal.mean_reversion_signal,
}
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": indicators,
}, ensure_ascii=False, default=str))
except Exception as exc:
logger.exception("Error getting technical indicators for %s", ticker)
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": str(exc),
}, ensure_ascii=False))
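The realized-volatility figures above come from the same formula at three window lengths: the sample standard deviation of daily returns, annualized by `sqrt(252)` and expressed in percent. A minimal sketch of that calculation without pandas (pandas' `.std()` also uses the sample standard deviation, ddof=1):

```python
import math

def annualized_vol_pct(daily_returns):
    """Annualized volatility in percent: sample std of daily returns * sqrt(252) * 100."""
    n = len(daily_returns)
    if n < 2:
        return None  # not enough observations, like the len(df_sorted) guards above
    mean = sum(daily_returns) / n
    variance = sum((r - mean) ** 2 for r in daily_returns) / (n - 1)
    return math.sqrt(variance) * (252 ** 0.5) * 100

vol = annualized_vol_pct([0.01, -0.02, 0.015, -0.005, 0.01])
too_short = annualized_vol_pct([0.01])
```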
async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()[:10]
end_date = str(data.get("end_date") or "").strip()[:10]
story_date = str(data.get("story_date") or end_date or "").strip()[:10]
target_date = str(data.get("target_date") or "").strip()[:10]
force = bool(data.get("force", False))
rebuild_story = bool(data.get("rebuild_story", True))
rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
only_local_to_llm = bool(data.get("only_local_to_llm", False))
limit = data.get("limit", 200)
try:
limit = max(10, min(int(limit), 500))
except (TypeError, ValueError):
limit = 200
if not ticker or not start_date or not end_date:
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "ticker, start_date, end_date are required",
}, ensure_ascii=False))
return
if only_local_to_llm and not llm_enrichment_enabled():
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
}, ensure_ascii=False))
return
result = await asyncio.to_thread(
enrich_news_for_symbol,
gateway.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
skip_existing=not force,
only_reanalyze_local=only_local_to_llm,
)
story_status = None
if rebuild_story and story_date:
await asyncio.to_thread(gateway.storage.market_store.delete_story_cache, ticker, as_of_date=story_date)
story_result = await asyncio.to_thread(
news_domain.get_story_payload,
gateway.storage.market_store,
ticker=ticker,
as_of_date=story_date,
)
story_status = {"as_of_date": story_date, "source": story_result.get("source") or "local"}
similar_status = None
if rebuild_similar_days and target_date:
await asyncio.to_thread(gateway.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date)
similar_result = await asyncio.to_thread(
news_domain.get_similar_days_payload,
gateway.storage.market_store,
ticker=ticker,
date=target_date,
n_similar=8,
)
similar_status = {
"target_date": target_date,
"count": len(similar_result.get("items") or []),
"error": similar_result.get("error"),
}
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"story_date": story_date or None,
"target_date": target_date or None,
"force": force,
"only_local_to_llm": only_local_to_llm,
"stats": result,
"story_status": story_status,
"similar_status": similar_status,
}, ensure_ascii=False, default=str))
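The rebuild steps here are invalidate-then-read: deleting the cached story and similar-day entries forces the next payload call to regenerate them from the freshly enriched news. A minimal sketch of that cache pattern, with a hypothetical class standing in for the market store:

```python
class StoryCache:
    """Toy delete-then-regenerate cache; illustrative, not the real market_store."""

    def __init__(self):
        self._cache = {}
        self.builds = 0  # counts how many times content was (re)generated

    def get(self, ticker, as_of_date):
        key = (ticker, as_of_date)
        if key not in self._cache:
            self.builds += 1
            self._cache[key] = f"story for {ticker} as of {as_of_date}"
        return self._cache[key]

    def invalidate(self, ticker, as_of_date):
        self._cache.pop((ticker, as_of_date), None)

cache = StoryCache()
first = cache.get("AAPL", "2026-03-16")
cache.invalidate("AAPL", "2026-03-16")  # enrichment changed the underlying news
second = cache.get("AAPL", "2026-03-16")  # rebuilt, not served stale
```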

View File

@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""
Market Data Service
Supports live, mock, and backtest modes
Supports live and backtest modes
"""
import asyncio
import logging
@@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, List, Optional
from zoneinfo import ZoneInfo
import pandas_market_calendars as mcal
from backend.config.data_config import get_data_source
from backend.config.data_config import get_data_sources
from backend.data.provider_utils import normalize_symbol
logger = logging.getLogger(__name__)
@@ -36,7 +36,6 @@ class MarketService:
self,
tickers: List[str],
poll_interval: int = 10,
mock_mode: bool = False,
backtest_mode: bool = False,
api_key: Optional[str] = None,
backtest_start_date: Optional[str] = None,
@@ -44,7 +43,6 @@ class MarketService:
):
self.tickers = [normalize_symbol(ticker) for ticker in tickers]
self.poll_interval = poll_interval
self.mock_mode = mock_mode
self.backtest_mode = backtest_mode
self.api_key = api_key
self.backtest_start_date = backtest_start_date
@@ -69,8 +67,6 @@ class MarketService:
"""Return the active live quote provider for UI/debugging."""
if self.backtest_mode:
return "backtest"
if self.mock_mode:
return "mock"
if self._price_manager and hasattr(self._price_manager, "provider"):
provider = getattr(self._price_manager, "provider", None)
if isinstance(provider, str) and provider.strip():
@@ -81,8 +77,6 @@ class MarketService:
def mode_name(self) -> str:
if self.backtest_mode:
return "BACKTEST"
elif self.mock_mode:
return "MOCK"
return "LIVE"
async def start(self, broadcast_func: Callable):
@@ -96,8 +90,6 @@ class MarketService:
if self.backtest_mode:
self._start_backtest_mode()
elif self.mock_mode:
self._start_mock_mode()
else:
self._start_real_mode()
@@ -125,26 +117,10 @@ class MarketService:
return callback
def _start_mock_mode(self):
from backend.data.mock_price_manager import MockPriceManager
self._price_manager = MockPriceManager(
poll_interval=self.poll_interval,
volatility=0.5,
)
self._price_manager.add_price_callback(self._make_price_callback())
self._price_manager.subscribe(
self.tickers,
base_prices={t: 100.0 for t in self.tickers},
)
self._price_manager.start()
def _start_real_mode(self):
from backend.data.polling_price_manager import PollingPriceManager
provider = get_data_source()
if provider == "local_csv":
provider = "yfinance"
provider = self._resolve_live_quote_provider()
if provider == "finnhub" and not self.api_key:
raise ValueError("API key required for live mode")
@@ -157,6 +133,13 @@ class MarketService:
self._price_manager.subscribe(self.tickers)
self._price_manager.start()
def _resolve_live_quote_provider(self) -> str:
"""Pick the first configured provider that supports live quote polling."""
for provider in get_data_sources():
if provider in {"finnhub", "yfinance"}:
return provider
return "yfinance"
def _start_backtest_mode(self):
from backend.data.historical_price_manager import (
HistoricalPriceManager,
@@ -257,12 +240,6 @@ class MarketService:
if removed:
self._price_manager.unsubscribe(removed)
if added:
if self.mock_mode:
self._price_manager.subscribe(
added,
base_prices={ticker: 100.0 for ticker in added},
)
else:
self._price_manager.subscribe(added)
if self.backtest_mode and self._current_date:

View File

@@ -9,7 +9,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable
from backend.data.schema import CompanyNews
from shared.schema import CompanyNews
SCHEMA = """

View File

@@ -11,7 +11,6 @@ from pathlib import Path
from typing import Any, Dict, List, Optional
from backend.data.market_store import MarketStore
from .research_db import ResearchDb
from .runtime_db import RuntimeDb
logger = logging.getLogger(__name__)
@@ -22,12 +21,18 @@ class StorageService:
Storage service for data persistence
Responsibilities:
1. Load/save dashboard JSON files
1. Export dashboard JSON files
(summary, holdings, stats, trades, leaderboard)
2. Load/save internal state (_internal_state.json)
3. Load/save server state (server_state.json) with feed history
4. Manage portfolio state persistence
5. Support loading from saved state to resume execution
Notes:
- team_dashboard/*.json is treated as an export/compatibility layer
rather than the authoritative runtime source of truth.
- authoritative runtime reads should prefer in-memory state, server_state,
runtime.db, and market_research.db.
"""
def __init__(
@@ -49,7 +54,7 @@ class StorageService:
self.initial_cash = initial_cash
self.config_name = config_name
# Dashboard file paths
# Dashboard export file paths
self.files = {
"summary": self.dashboard_dir / "summary.json",
"holdings": self.dashboard_dir / "holdings.json",
@@ -66,7 +71,6 @@ class StorageService:
self.state_dir.mkdir(parents=True, exist_ok=True)
self.server_state_file = self.state_dir / "server_state.json"
self.runtime_db = RuntimeDb(self.state_dir / "runtime.db")
self.research_db = ResearchDb(self.state_dir / "research.db")
self.market_store = MarketStore()
# Feed history (for agent messages)
@@ -84,16 +88,8 @@ class StorageService:
logger.info(f"Storage service initialized: {self.dashboard_dir}")
def load_file(self, file_type: str) -> Optional[Any]:
"""
Load dashboard JSON file
Args:
file_type: One of: summary, holdings, stats, trades, leaderboard
Returns:
Loaded data or None if file doesn't exist
"""
def load_export_file(self, file_type: str) -> Optional[Any]:
"""Load dashboard export JSON file."""
file_path = self.files.get(file_type)
if not file_path or not file_path.exists():
return None
@@ -105,14 +101,12 @@ class StorageService:
logger.error(f"Failed to load {file_type}.json: {e}")
return None
def save_file(self, file_type: str, data: Any):
"""
Save dashboard JSON file
def load_file(self, file_type: str) -> Optional[Any]:
"""Backward-compatible alias for export-layer JSON reads."""
return self.load_export_file(file_type)
Args:
file_type: One of: summary, holdings, stats, trades, leaderboard
data: Data to save
"""
def save_export_file(self, file_type: str, data: Any):
"""Save dashboard export JSON file."""
file_path = self.files.get(file_type)
if not file_path:
logger.error(f"Unknown file type: {file_type}")
@@ -129,6 +123,48 @@ class StorageService:
except Exception as e:
logger.error(f"Failed to save {file_type}.json: {e}")
def save_file(self, file_type: str, data: Any):
"""Backward-compatible alias for export-layer JSON writes."""
self.save_export_file(file_type, data)
def build_dashboard_snapshot_from_state(
self,
state: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Build dashboard view data from runtime state instead of JSON exports."""
runtime_state = state or self.load_server_state()
portfolio = dict(runtime_state.get("portfolio") or {})
holdings = list(runtime_state.get("holdings") or [])
stats = runtime_state.get("stats") or self._get_default_stats()
trades = list(runtime_state.get("trades") or [])
leaderboard = list(runtime_state.get("leaderboard") or [])
summary = {
"totalAssetValue": portfolio.get("total_value", self.initial_cash),
"totalReturn": portfolio.get("pnl_percent", 0.0),
"cashPosition": portfolio.get("cash", self.initial_cash),
"tickerWeights": stats.get("tickerWeights", {}),
"totalTrades": len(trades),
"pnlPct": portfolio.get("pnl_percent", 0.0),
"balance": portfolio.get("total_value", self.initial_cash),
"equity": portfolio.get("equity", []),
"baseline": portfolio.get("baseline", []),
"baseline_vw": portfolio.get("baseline_vw", []),
"momentum": portfolio.get("momentum", []),
"equity_return": portfolio.get("equity_return", []),
"baseline_return": portfolio.get("baseline_return", []),
"baseline_vw_return": portfolio.get("baseline_vw_return", []),
"momentum_return": portfolio.get("momentum_return", []),
}
return {
"summary": summary,
"holdings": holdings,
"stats": stats,
"trades": trades,
"leaderboard": leaderboard,
}
def check_file_updates(self) -> Dict[str, bool]:
"""
Check which dashboard files have been updated since last check
@@ -297,7 +333,7 @@ class StorageService:
def initialize_empty_dashboard(self):
"""Initialize empty dashboard files with default values"""
# Summary
self.save_file(
self.save_export_file(
"summary",
{
"totalAssetValue": self.initial_cash,
@@ -315,10 +351,10 @@ class StorageService:
)
# Holdings
self.save_file("holdings", [])
self.save_export_file("holdings", [])
# Stats
self.save_file(
self.save_export_file(
"stats",
{
"totalAssetValue": self.initial_cash,
@@ -335,7 +371,7 @@ class StorageService:
)
# Trades
self.save_file("trades", [])
self.save_export_file("trades", [])
# Leaderboard with model info
self.generate_leaderboard()
@@ -375,7 +411,7 @@ class StorageService:
ranking_entries.append(entry)
leaderboard = team_entries + ranking_entries
self.save_file("leaderboard", leaderboard)
self.save_export_file("leaderboard", leaderboard)
logger.info("Leaderboard generated with model info")
def update_leaderboard_model_info(self):
@@ -398,7 +434,7 @@ class StorageService:
entry["modelName"] = model_name
entry["modelProvider"] = model_provider
self.save_file("leaderboard", existing)
self.save_export_file("leaderboard", existing)
logger.info("Leaderboard model info updated")
def get_current_timestamp_ms(self, date: str = None) -> int:
@@ -653,7 +689,7 @@ class StorageService:
"momentum": state.get("momentum_history", []),
}
self.save_file("summary", summary)
self.save_export_file("summary", summary)
def _generate_holdings(
self,
@@ -715,7 +751,7 @@ class StorageService:
# Sort by weight
holdings.sort(key=lambda x: abs(x["weight"]), reverse=True)
self.save_file("holdings", holdings)
self.save_export_file("holdings", holdings)
def _generate_stats(self, state: Dict[str, Any], net_value: float):
"""Generate stats.json"""
@@ -738,7 +774,7 @@ class StorageService:
},
}
self.save_file("stats", stats)
self.save_export_file("stats", stats)
def _generate_trades(self, state: Dict[str, Any]):
"""Generate trades.json"""
@@ -764,7 +800,7 @@ class StorageService:
},
)
self.save_file("trades", trades)
self.save_export_file("trades", trades)
# Server State Management Methods
@@ -1001,12 +1037,12 @@ class StorageService:
Args:
state: Server state dictionary to update
"""
# Load dashboard data
summary = self.load_file("summary") or {}
holdings = self.load_file("holdings") or []
stats = self.load_file("stats") or self._get_default_stats()
trades = self.load_file("trades") or []
leaderboard = self.load_file("leaderboard") or []
dashboard_snapshot = self.build_dashboard_snapshot_from_state(state)
summary = dashboard_snapshot.get("summary") or {}
holdings = dashboard_snapshot.get("holdings") or []
stats = dashboard_snapshot.get("stats") or self._get_default_stats()
trades = dashboard_snapshot.get("trades") or []
leaderboard = dashboard_snapshot.get("leaderboard") or []
internal_state = self.load_internal_state()
# Update state
@@ -1040,7 +1076,6 @@ class StorageService:
Start tracking live returns for current trading session.
Captures current values as session start baseline.
"""
summary = self.load_file("summary") or {}
state = self.load_internal_state()
# Capture current values as session start
@@ -1052,7 +1087,7 @@ class StorageService:
self._session_start_equity = (
equity_history[-1]["v"]
if equity_history
else summary.get("totalAssetValue", self.initial_cash)
else self.initial_cash
)
self._session_start_baseline = (
baseline_history[-1]["v"]

View File

@@ -0,0 +1,119 @@
# Skill Template (Anthropic + AgentScope Aligned)
> Defines executable, routable, and evaluable skill specifications.
> Every `SKILL.md` should cover at least the six sections below.
---
## Frontmatter Spec
All `SKILL.md` files should begin with a YAML frontmatter block:
```yaml
---
name: skill_name # Required. Unique identifier for the skill.
description: ... # Required. One-line description of the skill.
version: "1.0.0" # Optional. Semantic version string.
tools: [...] # Optional. Tools provided or used by this skill.
allowed_tools: [...] # Optional. List of tool names permitted when this skill is active.
denied_tools: [...] # Optional. List of tool names denied when this skill is active.
---
```
### Frontmatter Fields
| Field | Type | Description |
|-------|------|-------------|
| `name` | string | Unique skill identifier (kebab-case recommended). |
| `description` | string | Human-readable one-line description. |
| `version` | string | Semantic version (e.g., `"1.0.0"`). |
| `tools` | list[string] | Tools provided by or associated with this skill. |
| `allowed_tools` | list[string] | Enumerates which tools are **permitted** when this skill is active. If set, only these tools may be used. |
| `denied_tools` | list[string] | Enumerates which tools are **forbidden** when this skill is active. Denied tools take precedence over `allowed_tools`. |
### Tool Restriction Rules
- If **only** `allowed_tools` is set: only those tools are accessible.
- If **only** `denied_tools` is set: all tools except those are accessible.
- If **both** are set: `allowed_tools` defines the initial set, then `denied_tools` removes from it.
- **Denial takes precedence**: a tool in `denied_tools` is always blocked even if also in `allowed_tools`.
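The precedence rules above can be sketched as a small resolver (illustrative only, not the actual loader implementation):

```python
def resolve_tools(available, allowed=None, denied=None):
    """Apply the rules above: allowed_tools narrows the set, denied_tools always wins."""
    tools = set(available)
    if allowed is not None:
        tools &= set(allowed)   # only explicitly allowed tools remain
    if denied is not None:
        tools -= set(denied)    # denial takes precedence over allowance
    return tools

all_tools = {"search", "trade", "email"}
assert resolve_tools(all_tools, allowed=["search", "trade"]) == {"search", "trade"}
assert resolve_tools(all_tools, denied=["email"]) == {"search", "trade"}
# denied beats allowed even when a tool appears in both lists
assert resolve_tools(all_tools, allowed=["search", "trade"], denied=["trade"]) == {"search"}
```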
---
## 1) When to use
- State explicit trigger conditions (task types, keywords, scenarios).
- State the boundaries where this skill must not be used (to avoid false triggers).
## 2) Required inputs
- List the minimal required inputs (e.g. `tickers`, prices, portfolio state, risk constraints).
- Declare how missing inputs are handled (abort / degrade / request more data).
## 3) Decision procedure
- Use fixed steps so the procedure is reproducible.
- For each step, state its goal, decision criteria, and artifacts (e.g. intermediate conclusions).
- Spell out conflict handling (signal conflicts, data conflicts, confidence conflicts).
## 4) Tool call policy
- State which tool groups and tools are preferred.
- Specify when a conclusion may be drawn without tools, and when tool-gathered evidence must precede any conclusion.
- Specify fallback actions when a tool fails, times out, or returns abnormal results.
## 5) Output schema
- Define standard output fields so downstream agents can consume and evaluate them.
- Recommended fields: `signal`, `confidence`, `reasons`, `risks`, `invalidation`, `next_action`.
- Portfolio-decision skills must include an `action` and `quantity` for each ticker.
## 6) Failure fallback
- Define the degradation strategy for insufficient data, conflicting signals, breached risk limits, or unavailable tools.
- Default to outputs that are conservative, explainable, and executable.
## Optional: Evaluation hooks
Define measurable metrics for the skill, to be written into long-term experience during later memory/reflection stages.
### Supported metric types
| Metric type | Description | Applicable skills |
|---------|------|---------|
| `hit_rate` | Signal hit rate: how closely decision signals match actual outcomes | sentiment_review, technical_review |
| `risk_violation` | Risk violation rate: how often risk-control rules are triggered | risk_review, portfolio_decisioning |
| `position_deviation` | Position deviation: gap between recommended and executed position sizes | portfolio_decisioning |
| `pnl_attribution` | P&L attribution consistency: how well return attribution matches realized returns | fundamental_review, valuation_review |
| `signal_consistency` | Signal consistency: agreement across multi-source signals | sentiment_review |
| `decision_latency` | Decision latency: time from inputs to decision | portfolio_decisioning |
| `tool_usage` | Tool usage: ratio of tool-call count to success rate | all skills |
| `custom` | Custom metric | domain-specific scenarios |
### Usage
```python
from backend.agents.base.evaluation_hook import EvaluationHook, MetricType
# At the start of skill execution
evaluation_hook.start_evaluation(
skill_name="technical_review",
inputs={"tickers": ["AAPL"], "prices": {...}}
)
# Add metrics during skill execution
evaluation_hook.add_metric(
name="signal_confidence",
metric_type=MetricType.HIT_RATE,
value=0.85,
metadata={"method": "rsi", "threshold": 30}
)
# Record results when the skill completes
evaluation_hook.record_outputs({"signal": "buy", "confidence": 0.8})
evaluation_hook.complete_evaluation(success=True)
```
### Evaluation result storage
Evaluation results are saved automatically to `runs/{run_id}/evaluations/{agent_id}/{skill_name}_{timestamp}.json`

View File

@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted agent service surface."""
from pathlib import Path
from fastapi.testclient import TestClient
from backend.apps.agent_service import create_app
from backend.api import agents as agents_module
def test_agent_service_routes_include_control_plane_endpoints(tmp_path):
app = create_app(project_root=tmp_path)
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/status" in paths
assert "/api/workspaces" in paths
assert "/api/guard/pending" in paths
def test_agent_service_excludes_runtime_routes(tmp_path):
app = create_app(project_root=tmp_path)
paths = {route.path for route in app.routes}
assert "/api/runtime/start" not in paths
assert "/api/runtime/gateway/port" not in paths
def test_agent_service_read_routes(monkeypatch, tmp_path):
class _FakeSkillsManager:
project_root = tmp_path
def get_agent_asset_dir(self, config_name, agent_id):
return tmp_path / "runs" / config_name / "agents" / agent_id
def resolve_agent_skill_names(self, config_name, agent_id, default_skills=None):
return ["demo_skill"]
def list_agent_skill_catalog(self, config_name, agent_id):
return [
type(
"Skill",
(),
{
"skill_name": "demo_skill",
"name": "Demo Skill",
"description": "demo",
"version": "1.0.0",
"source": "builtin",
"tools": [],
},
)()
]
def load_agent_skill_document(self, config_name, agent_id, skill_name):
return {"skill_name": skill_name, "content": "# demo"}
class _FakeWorkspaceManager:
def load_agent_file(self, config_name, agent_id, filename):
return f"{config_name}:{agent_id}:{filename}"
monkeypatch.setattr(agents_module, "load_agent_profiles", lambda: {"portfolio_manager": {"skills": ["demo_skill"]}})
monkeypatch.setattr(agents_module, "get_agent_model_info", lambda agent_id: ("deepseek-v3.2", "DASHSCOPE"))
monkeypatch.setattr(
agents_module,
"load_agent_workspace_config",
lambda path: type(
"Cfg",
(),
{
"active_tool_groups": ["portfolio_ops"],
"disabled_tool_groups": [],
"enabled_skills": [],
"disabled_skills": [],
"prompt_files": ["SOUL.md", "MEMORY.md"],
},
)(),
)
monkeypatch.setattr(
agents_module,
"get_bootstrap_config_for_run",
lambda project_root, config_name: type("Bootstrap", (), {"agent_override": lambda self, agent_id: {}})(),
)
app = create_app(project_root=tmp_path)
app.dependency_overrides[agents_module.get_skills_manager] = lambda: _FakeSkillsManager()
app.dependency_overrides[agents_module.get_workspace_manager] = lambda: _FakeWorkspaceManager()
with TestClient(app) as client:
profile = client.get("/api/workspaces/demo/agents/portfolio_manager/profile")
skills = client.get("/api/workspaces/demo/agents/portfolio_manager/skills")
detail = client.get("/api/workspaces/demo/agents/portfolio_manager/skills/demo_skill")
workspace_file = client.get("/api/workspaces/demo/agents/portfolio_manager/files/MEMORY.md")
assert profile.status_code == 200
assert profile.json()["profile"]["model_name"] == "deepseek-v3.2"
assert skills.status_code == 200
assert skills.json()["skills"][0]["skill_name"] == "demo_skill"
assert detail.status_code == 200
assert detail.json()["skill"]["content"] == "# demo"
assert workspace_file.status_code == 200
assert workspace_file.json()["content"] == "demo:portfolio_manager:MEMORY.md"

View File

@@ -34,7 +34,6 @@ def test_live_runs_incremental_market_store_update_before_start(monkeypatch, tmp
monkeypatch.setattr(cli.subprocess, "run", fake_run)
cli.live(
mock=False,
config_name="smoke_fullstack",
host="0.0.0.0",
port=8765,

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""Tests for data_tools preferring split services when configured."""
from backend.tools import data_tools
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price
def test_data_tools_prefers_trading_service(monkeypatch):
monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
monkeypatch.setenv("SERVICE_NAME", "agent_service")
monkeypatch.setattr(data_tools._cache, "get_prices", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_financial_metrics", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_insider_trades", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_company_news", lambda key: None)
def fake_service_get_json(base_url, path, *, params):
if path == "/api/prices":
return {
"ticker": "AAPL",
"prices": [
Price(
open=1,
close=2,
high=3,
low=1,
volume=10,
time="2026-03-16",
).model_dump()
],
}
if path == "/api/financials":
return {
"financial_metrics": [
FinancialMetrics(
ticker="AAPL",
report_period="2026-03-16",
period="ttm",
currency="USD",
market_cap=123.0,
enterprise_value=None,
price_to_earnings_ratio=None,
price_to_book_ratio=None,
price_to_sales_ratio=None,
enterprise_value_to_ebitda_ratio=None,
enterprise_value_to_revenue_ratio=None,
free_cash_flow_yield=None,
peg_ratio=None,
gross_margin=None,
operating_margin=None,
net_margin=None,
return_on_equity=None,
return_on_assets=None,
return_on_invested_capital=None,
asset_turnover=None,
inventory_turnover=None,
receivables_turnover=None,
days_sales_outstanding=None,
operating_cycle=None,
working_capital_turnover=None,
current_ratio=None,
quick_ratio=None,
cash_ratio=None,
operating_cash_flow_ratio=None,
debt_to_equity=None,
debt_to_assets=None,
interest_coverage=None,
revenue_growth=None,
earnings_growth=None,
book_value_growth=None,
earnings_per_share_growth=None,
free_cash_flow_growth=None,
operating_income_growth=None,
ebitda_growth=None,
payout_ratio=None,
earnings_per_share=None,
book_value_per_share=None,
free_cash_flow_per_share=None,
).model_dump()
]
}
if path == "/api/insider-trades":
return {
"insider_trades": [
InsiderTrade(ticker="AAPL", filing_date="2026-03-16").model_dump()
]
}
if path == "/api/news":
return {
"news": [
CompanyNews(
ticker="AAPL",
title="Title",
source="polygon",
url="https://example.com",
).model_dump()
]
}
if path == "/api/market-cap":
return {"ticker": "AAPL", "end_date": "2026-03-16", "market_cap": 2.5e12}
if path == "/api/line-items":
return {
"search_results": [
LineItem(
ticker="AAPL",
report_period="2026-03-16",
period="ttm",
currency="USD",
free_cash_flow=321.0,
).model_dump()
]
}
raise AssertionError(path)
monkeypatch.setattr(data_tools, "_service_get_json", fake_service_get_json)
prices = data_tools.get_prices("AAPL", "2026-03-01", "2026-03-16")
metrics = data_tools.get_financial_metrics("AAPL", "2026-03-16")
trades = data_tools.get_insider_trades("AAPL", "2026-03-16")
news = data_tools.get_company_news("AAPL", "2026-03-16")
market_cap = data_tools.get_market_cap("AAPL", "2026-03-16")
line_items = data_tools.search_line_items(
"AAPL",
["free_cash_flow"],
"2026-03-16",
)
assert prices[0].close == 2
assert metrics[0].ticker == "AAPL"
assert trades[0].ticker == "AAPL"
assert news[0].ticker == "AAPL"
assert market_cap == 2.5e12
assert line_items[0].free_cash_flow == 321.0
def test_data_tools_skips_self_recursion_for_trading_service(monkeypatch):
monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
monkeypatch.setenv("SERVICE_NAME", "trading_service")
assert data_tools._trading_service_url() is None
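The two tests above pin down a routing guard: data lookups go over HTTP to the trading service only when `TRADING_SERVICE_URL` is set *and* the current process is not the trading service itself. Below is a minimal sketch of that guard under assumed names (`trading_service_url` is a stand-in; the real `data_tools._trading_service_url` may check more than this):

```python
import os


def trading_service_url(url_env="TRADING_SERVICE_URL",
                        service_name_env="SERVICE_NAME"):
    """Return the remote trading-service base URL, or None.

    Hypothetical sketch of the self-recursion guard the tests above
    verify: the trading service must never route data calls back to
    itself over HTTP.
    """
    url = os.getenv(url_env)
    if not url:
        return None  # split services not configured; use local providers
    if os.getenv(service_name_env) == "trading_service":
        return None  # we ARE the trading service; avoid calling ourselves
    return url
```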

View File

@@ -6,6 +6,7 @@ import pytest
from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module
from shared.schema import InsiderTrade, InsiderTradeResponse, Price, PriceResponse
class DummyWebSocket:
@@ -35,6 +36,10 @@ class FakeMarketStore:
def __init__(self):
self.calls = []
def get_ticker_watermarks(self, symbol):
self.calls.append(("get_ticker_watermarks", symbol))
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}]
@@ -123,6 +128,75 @@ def make_gateway(market_store=None):
)
class FakeNewsClient:
def __init__(self, base_url):
self.base_url = base_url
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return None
async def get_categories(self, ticker, start_date=None, end_date=None, limit=200):
return {"ticker": ticker, "categories": {"remote": {"count": 2}}}
async def get_enriched_news(self, ticker, start_date=None, end_date=None, limit=None):
return {
"ticker": ticker,
"news": [
{
"id": "remote-news-1",
"ticker": ticker,
"title": "Remote Title",
"date": end_date,
}
],
}
async def get_story(self, ticker, as_of_date):
return {"symbol": ticker, "as_of_date": as_of_date, "story": "remote story", "source": "news_service"}
class FakeTradingClient:
def __init__(self, base_url):
self.base_url = base_url
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return None
async def get_insider_trades(self, ticker, end_date=None, start_date=None, limit=None):
return InsiderTradeResponse(
insider_trades=[
InsiderTrade(
ticker=ticker,
name="Remote Insider",
filing_date=end_date or "2026-03-16",
)
]
)
async def get_prices(self, ticker, start_date=None, end_date=None):
prices = [
Price(
open=float(100 + idx),
close=float(101 + idx),
high=float(102 + idx),
low=float(99 + idx),
volume=1000 + idx,
time=f"2026-01-{idx + 1:02d}",
)
for idx in range(30)
]
return PriceResponse(ticker=ticker, prices=prices)
async def get_market_cap(self, ticker, end_date):
return {"ticker": ticker, "end_date": end_date, "market_cap": 2.5e12}
@pytest.mark.asyncio
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
market_store = FakeMarketStore()
@@ -135,6 +209,7 @@ async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument(
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")
]
assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded"
@@ -153,6 +228,7 @@ async def test_handle_get_stock_news_categories_uses_market_store_symbol_argumen
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200)
]
@@ -175,7 +251,7 @@ async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch
}
monkeypatch.setattr(
gateway_module,
gateway_module.news_domain,
"build_range_explanation",
fake_build_range_explanation,
)
@@ -186,6 +262,7 @@ async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100)
]
assert websocket.messages[-1] == {
@@ -207,7 +284,7 @@ async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.news_domain,
"build_range_explanation",
lambda **kwargs: {"news_count": len(kwargs["news_rows"])},
)
@@ -222,7 +299,10 @@ async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch)
},
)
assert market_store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-99"])]
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_by_ids_enriched", "AAPL", ["news-99"])
]
assert websocket.messages[-1]["result"]["news_count"] == 1
@@ -238,6 +318,7 @@ async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
]
assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded"
@@ -251,7 +332,7 @@ async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
)
@@ -266,6 +347,132 @@ async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
assert "AAPL Story" in websocket.messages[-1]["story"]
@pytest.mark.asyncio
async def test_handle_get_stock_news_categories_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_news_categories(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
assert websocket.messages[-1]["categories"]["remote"]["count"] == 2
@pytest.mark.asyncio
async def test_handle_get_stock_story_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_story(
websocket,
{"ticker": "AAPL", "as_of_date": "2026-03-16"},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_story_loaded"
assert websocket.messages[-1]["story"] == "remote story"
@pytest.mark.asyncio
async def test_handle_get_stock_news_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_news(
websocket,
{"ticker": "AAPL", "lookback_days": 30, "limit": 5},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_news_loaded"
assert websocket.messages[-1]["source"] == "news_service"
assert websocket.messages[-1]["news"][0]["title"] == "Remote Title"
@pytest.mark.asyncio
async def test_handle_get_stock_insider_trades_uses_trading_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_insider_trades(
websocket,
{"ticker": "AAPL", "end_date": "2026-03-16", "limit": 10},
)
assert websocket.messages[-1]["type"] == "stock_insider_trades_loaded"
assert websocket.messages[-1]["trades"][0]["name"] == "Remote Insider"
@pytest.mark.asyncio
async def test_handle_get_stock_history_uses_trading_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_history(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_history_loaded"
assert websocket.messages[-1]["source"] == "trading_service"
assert len(websocket.messages[-1]["prices"]) == 30
@pytest.mark.asyncio
async def test_handle_get_stock_technical_indicators_uses_trading_service_client_when_configured(monkeypatch):
gateway = make_gateway(FakeMarketStore())
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_technical_indicators(
websocket,
{"ticker": "AAPL"},
)
assert websocket.messages[-1]["type"] == "stock_technical_indicators_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
assert websocket.messages[-1]["indicators"] is not None
@pytest.mark.asyncio
async def test_get_market_caps_uses_trading_service_client_when_configured(monkeypatch):
gateway = make_gateway(FakeMarketStore())
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
market_caps = await gateway._get_market_caps(["AAPL", "MSFT"], "2026-03-16")
assert market_caps == {"AAPL": 2.5e12, "MSFT": 2.5e12}
@pytest.mark.asyncio
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
market_store = FakeMarketStore()
@@ -273,7 +480,7 @@ async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
)
@@ -295,7 +502,12 @@ async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_stock_handlers,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
)
monkeypatch.setattr(
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
)
@@ -325,7 +537,7 @@ async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypa
gateway = make_gateway(FakeMarketStore())
websocket = DummyWebSocket()
monkeypatch.setattr(gateway_module, "llm_enrichment_enabled", lambda: False)
monkeypatch.setattr(gateway_module.gateway_stock_handlers, "llm_enrichment_enabled", lambda: False)
await gateway._handle_run_stock_enrich(
websocket,
@@ -361,7 +573,7 @@ def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])
assert captured["coro_name"] == "_refresh_market_store_for_watchlist"
assert captured["coro_name"] == "refresh_market_store_for_watchlist"
@pytest.mark.asyncio
@@ -369,7 +581,7 @@ async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypa
gateway = make_gateway()
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_cycle_support,
"ingest_symbols",
lambda symbols, mode="incremental": [
{"symbol": symbol, "prices": 3, "news": 4, "aligned": 4}
@@ -445,12 +657,12 @@ async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatc
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_admin_handlers,
"load_agent_profiles",
lambda: {"risk_manager": {"skills": ["risk_review"], "active_tool_groups": ["risk_ops", "legacy_group"]}},
)
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_admin_handlers,
"get_agent_model_info",
lambda agent_id: ("gpt-4o-mini", "OPENAI"),
)
@@ -461,7 +673,7 @@ async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatc
return {}
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_admin_handlers,
"get_bootstrap_config_for_run",
lambda project_root, config_name: _Bootstrap(),
)

View File

@@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-
"""Direct tests for Gateway support modules."""
from types import SimpleNamespace
import pytest
from backend.services import gateway_cycle_support, gateway_runtime_support
class _DummyDashboard:
def __init__(self):
self.updated = []
self.tickers = []
self.initial_cash = None
self.enable_memory = False
self.days_total = 0
def update(self, **kwargs):
self.updated.append(kwargs)
def stop(self):
return None
def print_final_summary(self):
return None
class _DummyScheduler:
def __init__(self):
self.calls = []
def reconfigure(self, **kwargs):
self.calls.append(kwargs)
class _DummyStateSync:
def __init__(self):
self.updated = []
self.saved = False
self.system_messages = []
self.backtest_dates = []
self.state = {}
def update_state(self, key, value):
self.updated.append((key, value))
self.state[key] = value
def save_state(self):
self.saved = True
async def on_system_message(self, message):
self.system_messages.append(message)
def set_backtest_dates(self, dates):
self.backtest_dates = list(dates)
class _DummyStorage:
def __init__(self):
self.initial_cash = 100000.0
self.is_live_session_active = False
self.server_state_updates = []
def can_apply_initial_cash(self):
return True
def apply_initial_cash(self, value):
self.initial_cash = value
return True
def update_server_state_from_dashboard(self, state):
self.server_state_updates.append(state)
def load_file(self, name):
if name == "summary":
return {"totalAssetValue": self.initial_cash}
return []
def build_dashboard_snapshot_from_state(self, state):
return {
"summary": {"totalAssetValue": self.initial_cash},
"holdings": [],
"stats": {},
"trades": [],
"leaderboard": [],
}
class _DummyPM:
def __init__(self):
self.portfolio = {"margin_requirement": 0.0}
def apply_runtime_portfolio_config(self, margin_requirement=None, initial_cash=None):
if margin_requirement is not None:
self.portfolio["margin_requirement"] = margin_requirement
return {"margin_requirement": True}
def can_apply_initial_cash(self):
return True
class _DummyMarketService:
def __init__(self):
self.updated = None
self.stopped = False
def update_tickers(self, tickers):
self.updated = list(tickers)
return {"active": list(tickers), "added": list(tickers), "removed": []}
def stop(self):
self.stopped = True
def make_gateway_stub():
pipeline = SimpleNamespace(max_comm_cycles=0, pm=_DummyPM())
gateway = SimpleNamespace(
market_service=_DummyMarketService(),
pipeline=pipeline,
scheduler=_DummyScheduler(),
config={
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"enable_memory": False,
},
storage=_DummyStorage(),
state_sync=_DummyStateSync(),
_dashboard=_DummyDashboard(),
_watchlist_ingest_task=None,
_market_status_task=None,
_backtest_task=None,
_backtest_start_date=None,
_backtest_end_date=None,
_manual_cycle_task=None,
)
return gateway
def test_normalize_watchlist_filters_invalid_and_dedupes():
assert gateway_runtime_support.normalize_watchlist(["aapl", " AAPL ", "", "msft"]) == ["AAPL", "MSFT"]
assert gateway_runtime_support.normalize_watchlist("aapl,msft") == ["AAPL", "MSFT"]
def test_normalize_agent_workspace_filename_obeys_allowlist():
allowlist = {"SOUL.md", "PROFILE.md"}
assert gateway_runtime_support.normalize_agent_workspace_filename("SOUL.md", allowlist=allowlist) == "SOUL.md"
assert gateway_runtime_support.normalize_agent_workspace_filename("README.md", allowlist=allowlist) is None
def test_apply_runtime_config_updates_gateway_state():
gateway = make_gateway_stub()
result = gateway_runtime_support.apply_runtime_config(
gateway,
{
"tickers": ["MSFT", "NVDA"],
"schedule_mode": "intraday",
"interval_minutes": 30,
"trigger_time": "10:30",
"initial_cash": 150000.0,
"margin_requirement": 0.5,
"max_comm_cycles": 4,
"enable_memory": False,
},
)
assert gateway.config["tickers"] == ["MSFT", "NVDA"]
assert gateway.config["schedule_mode"] == "intraday"
assert gateway.storage.initial_cash == 150000.0
assert result["runtime_config_applied"]["max_comm_cycles"] == 4
assert gateway.scheduler.calls[-1] == {
"mode": "intraday",
"trigger_time": "10:30",
"interval_minutes": 30,
}
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
gateway = make_gateway_stub()
captured = {}
class DummyTask:
def done(self):
return False
def cancel(self):
captured["cancelled"] = True
def fake_create_task(coro):
captured["name"] = coro.cr_code.co_name
coro.close()
return DummyTask()
monkeypatch.setattr(gateway_cycle_support.asyncio, "create_task", fake_create_task)
gateway_cycle_support.schedule_watchlist_market_store_refresh(gateway, ["AAPL", "MSFT"])
assert captured["name"] == "refresh_market_store_for_watchlist"
@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
gateway = make_gateway_stub()
monkeypatch.setattr(
gateway_cycle_support,
"ingest_symbols",
lambda symbols, mode="incremental": [
{"symbol": symbol, "prices": 3, "news": 4}
for symbol in symbols
],
)
await gateway_cycle_support.refresh_market_store_for_watchlist(gateway, ["AAPL", "MSFT"])
assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]
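`test_normalize_watchlist_filters_invalid_and_dedupes` fully specifies the normalizer's contract: accept a list or a comma-separated string, trim whitespace, uppercase, drop empties, and dedupe while preserving order. A minimal sketch matching those assertions (the real `gateway_runtime_support.normalize_watchlist` may add further symbol validation):

```python
def normalize_watchlist(raw):
    """Uppercase, trim, drop empties, and dedupe, preserving order."""
    if isinstance(raw, str):
        raw = raw.split(",")  # "aapl,msft" -> ["aapl", "msft"]
    seen, result = set(), []
    for item in raw:
        symbol = str(item).strip().upper()
        if symbol and symbol not in seen:
            seen.add(symbol)
            result.append(symbol)
    return result
```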

View File

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""Tests for HeartbeatHook."""
import tempfile
from pathlib import Path
import pytest
from backend.agents.base.hooks import HeartbeatHook
class TestHeartbeatHook:
"""Tests for HeartbeatHook._read_heartbeat_content."""
def test_read_heartbeat_content_with_content(self, tmp_path):
"""Test reading HEARTBEAT.md when it exists and has content."""
ws_dir = tmp_path / "analyst_workspace"
ws_dir.mkdir()
hb_file = ws_dir / "HEARTBEAT.md"
hb_file.write_text("# 定期主动检查\n\n- [ ] 持仓是否健康\n", encoding="utf-8")
hook = HeartbeatHook(workspace_dir=ws_dir)
content = hook._read_heartbeat_content()
assert content is not None
assert "# 定期主动检查" in content
assert "持仓是否健康" in content
def test_read_heartbeat_content_absent(self, tmp_path):
"""Test reading when HEARTBEAT.md does not exist."""
ws_dir = tmp_path / "analyst_workspace"
ws_dir.mkdir()
hook = HeartbeatHook(workspace_dir=ws_dir)
content = hook._read_heartbeat_content()
assert content is None
def test_read_heartbeat_content_empty(self, tmp_path):
"""Test reading when HEARTBEAT.md is empty."""
ws_dir = tmp_path / "analyst_workspace"
ws_dir.mkdir()
hb_file = ws_dir / "HEARTBEAT.md"
hb_file.write_text("", encoding="utf-8")
hook = HeartbeatHook(workspace_dir=ws_dir)
content = hook._read_heartbeat_content()
assert content is None
def test_read_heartbeat_content_whitespace_only(self, tmp_path):
"""Test reading when HEARTBEAT.md contains only whitespace."""
ws_dir = tmp_path / "analyst_workspace"
ws_dir.mkdir()
hb_file = ws_dir / "HEARTBEAT.md"
hb_file.write_text(" \n\n ", encoding="utf-8")
hook = HeartbeatHook(workspace_dir=ws_dir)
content = hook._read_heartbeat_content()
assert content is None
def test_completed_flag_path(self, tmp_path):
"""Test that completion flag is placed in workspace directory."""
ws_dir = tmp_path / "analyst_workspace"
ws_dir.mkdir()
hook = HeartbeatHook(workspace_dir=ws_dir)
assert hook._completed_flag == ws_dir / ".heartbeat_completed"

View File

@@ -2,153 +2,12 @@
# pylint: disable=W0212
import asyncio
import time
import logging
from unittest.mock import MagicMock, AsyncMock, patch
import pytest
from backend.services.market import MarketService
from backend.data.mock_price_manager import MockPriceManager
from backend.data.polling_price_manager import PollingPriceManager
class TestMockPriceManager:
def test_init_default(self):
manager = MockPriceManager()
assert manager.poll_interval == 10
assert manager.volatility == 0.5
assert manager.running is False
assert len(manager.subscribed_symbols) == 0
def test_init_custom(self):
manager = MockPriceManager(poll_interval=5, volatility=1.0)
assert manager.poll_interval == 5
assert manager.volatility == 1.0
def test_subscribe(self):
manager = MockPriceManager()
manager.subscribe(["AAPL", "MSFT"])
assert "AAPL" in manager.subscribed_symbols
assert "MSFT" in manager.subscribed_symbols
assert manager.base_prices["AAPL"] == 237.50 # default price
assert manager.base_prices["MSFT"] == 425.30 # default price
def test_subscribe_with_base_prices(self):
manager = MockPriceManager()
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
assert manager.base_prices["AAPL"] == 100.0
assert manager.open_prices["AAPL"] == 100.0
assert manager.latest_prices["AAPL"] == 100.0
def test_subscribe_unknown_symbol(self):
manager = MockPriceManager()
manager.subscribe(["UNKNOWN"])
assert "UNKNOWN" in manager.subscribed_symbols
assert manager.base_prices["UNKNOWN"] > 0 # random price generated
def test_unsubscribe(self):
manager = MockPriceManager()
manager.subscribe(["AAPL", "MSFT"])
manager.unsubscribe(["AAPL"])
assert "AAPL" not in manager.subscribed_symbols
assert "MSFT" in manager.subscribed_symbols
def test_add_price_callback(self):
manager = MockPriceManager()
callback = MagicMock()
manager.add_price_callback(callback)
assert callback in manager.price_callbacks
def test_generate_price_update_within_bounds(self):
manager = MockPriceManager(volatility=0.5)
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
for _ in range(100):
new_price = manager._generate_price_update("AAPL")
# Should be within +/-10% of open
assert 90.0 <= new_price <= 110.0
def test_update_prices_triggers_callback(self):
manager = MockPriceManager()
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
callback = MagicMock()
manager.add_price_callback(callback)
manager._update_prices()
callback.assert_called_once()
call_args = callback.call_args[0][0]
assert call_args["symbol"] == "AAPL"
assert "price" in call_args
assert "timestamp" in call_args
def test_start_stop(self):
manager = MockPriceManager(poll_interval=1)
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
manager.start()
assert manager.running is True
time.sleep(0.1) # let thread start
manager.stop()
assert manager.running is False
def test_start_without_subscription(self):
manager = MockPriceManager()
manager.start()
assert (
manager.running is False
) # should not start without subscriptions
def test_get_latest_price(self):
manager = MockPriceManager()
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
price = manager.get_latest_price("AAPL")
assert price == 100.0
def test_get_latest_price_unknown(self):
manager = MockPriceManager()
price = manager.get_latest_price("UNKNOWN")
assert price is None
def test_get_all_latest_prices(self):
manager = MockPriceManager()
manager.subscribe(
["AAPL", "MSFT"],
base_prices={"AAPL": 100.0, "MSFT": 200.0},
)
prices = manager.get_all_latest_prices()
assert prices["AAPL"] == 100.0
assert prices["MSFT"] == 200.0
def test_reset_open_prices(self):
manager = MockPriceManager()
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
manager.latest_prices["AAPL"] = 105.0
manager.reset_open_prices()
# Open price should change (based on latest with small gap)
assert manager.open_prices["AAPL"] != 100.0
def test_set_base_price(self):
manager = MockPriceManager()
manager.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
manager.set_base_price("AAPL", 150.0)
assert manager.base_prices["AAPL"] == 150.0
assert manager.open_prices["AAPL"] == 150.0
assert manager.latest_prices["AAPL"] == 150.0
from backend.llm.models import RetryChatModel
class TestPollingPriceManager:
@@ -231,37 +90,67 @@ class TestPollingPriceManager:
assert len(manager.open_prices) == 0
def test_fetch_prices_suppresses_repeated_failures(self, caplog):
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
manager.subscribe(["AAPL"])
with patch.object(manager, "_fetch_quote", side_effect=ValueError("empty quote")):
with caplog.at_level(logging.DEBUG):
for _ in range(3):
manager._fetch_prices()
assert manager._failure_counts["AAPL"] == 3
warning_messages = [record.message for record in caplog.records if record.levelno >= logging.WARNING]
assert any("Failed to fetch AAPL price: empty quote" in message for message in warning_messages)
def test_fetch_prices_logs_recovery_after_failure(self, caplog):
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
manager.subscribe(["AAPL"])
with patch.object(
manager,
"_fetch_quote",
side_effect=[
ValueError("temporary outage"),
{"c": 100.0, "o": 99.0, "h": 101.0, "l": 98.0, "pc": 99.5, "d": 0.5, "dp": 0.5, "t": 1},
],
):
with caplog.at_level(logging.INFO):
manager._fetch_prices()
manager._fetch_prices()
assert "AAPL" not in manager._failure_counts
assert any("recovered after 1 consecutive failures" in record.message for record in caplog.records)
class TestRetryChatModel:
@pytest.mark.asyncio
async def test_async_retry_recovers_from_disconnect(self):
attempts = {"count": 0}
class FakeAsyncModel:
model_name = "fake-async-model"
async def __call__(self, *args, **kwargs):
attempts["count"] += 1
if attempts["count"] < 2:
raise RuntimeError("Server disconnected")
return {"ok": True}
wrapped = RetryChatModel(FakeAsyncModel(), max_retries=2, initial_delay=0.01)
result = await wrapped("hello")
assert result == {"ok": True}
assert attempts["count"] == 2
class TestMarketService:
def test_init_mock_mode(self):
service = MarketService(
tickers=["AAPL", "MSFT"],
poll_interval=10,
mock_mode=True,
)
assert service.tickers == ["AAPL", "MSFT"]
assert service.poll_interval == 10
assert service.mock_mode is True
assert service.running is False
def test_init_real_mode(self):
service = MarketService(
tickers=["AAPL"],
mock_mode=False,
api_key="test_key",
)
assert service.mock_mode is False
assert service.api_key == "test_key"
@patch("backend.services.market.get_data_source", return_value="yfinance")
@patch("backend.services.market.get_data_sources", return_value=["yfinance", "local_csv"])
@patch.object(PollingPriceManager, "start")
def test_start_real_mode_with_yfinance(self, _mock_start, _mock_source):
def test_start_real_mode_with_yfinance(self, _mock_start, _mock_sources):
service = MarketService(
tickers=["AAPL"],
poll_interval=10,
mock_mode=False,
)
service._start_real_mode()
@@ -269,30 +158,24 @@ class TestMarketService:
assert isinstance(service._price_manager, PollingPriceManager)
assert service._price_manager.provider == "yfinance"
@pytest.mark.asyncio
async def test_start_mock_mode(self):
@patch("backend.services.market.get_data_sources", return_value=["financial_datasets", "yfinance", "local_csv"])
@patch.object(PollingPriceManager, "start")
def test_start_real_mode_uses_first_supported_live_provider(self, _mock_start, _mock_sources):
service = MarketService(
tickers=["AAPL"],
poll_interval=10,
mock_mode=True,
)
broadcast_func = AsyncMock()
service._start_real_mode()
await service.start(broadcast_func)
assert isinstance(service._price_manager, PollingPriceManager)
assert service._price_manager.provider == "yfinance"
assert service.running is True
assert service._price_manager is not None
assert isinstance(service._price_manager, MockPriceManager)
service.stop()
@patch("backend.services.market.get_data_source", return_value="finnhub")
@patch("backend.services.market.get_data_sources", return_value=["finnhub", "yfinance"])
@pytest.mark.asyncio
async def test_start_real_mode_without_api_key(self, _mock_source):
async def test_start_real_mode_without_api_key(self, _mock_sources):
service = MarketService(
tickers=["AAPL"],
mock_mode=False,
api_key=None,
)
@@ -307,11 +190,12 @@ class TestMarketService:
async def test_start_already_running(self):
service = MarketService(
tickers=["AAPL"],
mock_mode=True,
backtest_mode=True,
)
broadcast_func = AsyncMock()
# First start with backtest mode
await service.start(broadcast_func)
assert service.running is True
@@ -323,7 +207,7 @@ class TestMarketService:
def test_stop(self):
service = MarketService(
tickers=["AAPL"],
mock_mode=True,
backtest_mode=True,
)
service.running = True
service._price_manager = MagicMock()
@@ -336,7 +220,7 @@ class TestMarketService:
def test_stop_when_not_running(self):
service = MarketService(
tickers=["AAPL"],
mock_mode=True,
backtest_mode=True,
)
# Should not raise
@@ -344,20 +228,20 @@ class TestMarketService:
assert service.running is False
def test_get_price_sync(self):
service = MarketService(tickers=["AAPL"], mock_mode=True)
service = MarketService(tickers=["AAPL"], backtest_mode=True)
service.cache["AAPL"] = {"price": 150.0, "open": 148.0}
price = service.get_price_sync("AAPL")
assert price == 150.0
def test_get_price_sync_not_found(self):
service = MarketService(tickers=["AAPL"], mock_mode=True)
service = MarketService(tickers=["AAPL"], backtest_mode=True)
price = service.get_price_sync("MSFT")
assert price is None
def test_get_all_prices(self):
service = MarketService(tickers=["AAPL", "MSFT"], mock_mode=True)
service = MarketService(tickers=["AAPL", "MSFT"], backtest_mode=True)
service.cache["AAPL"] = {"price": 150.0}
service.cache["MSFT"] = {"price": 400.0}
@@ -368,7 +252,7 @@ class TestMarketService:
@pytest.mark.asyncio
async def test_broadcast_price_update(self):
service = MarketService(tickers=["AAPL"], mock_mode=True)
service = MarketService(tickers=["AAPL"], backtest_mode=True)
service._broadcast_func = AsyncMock()
price_data = {
@@ -388,7 +272,7 @@ class TestMarketService:
@pytest.mark.asyncio
async def test_broadcast_price_update_no_func(self):
service = MarketService(tickers=["AAPL"], mock_mode=True)
service = MarketService(tickers=["AAPL"], backtest_mode=True)
service._broadcast_func = None
price_data = {"symbol": "AAPL", "price": 150.0, "open": 148.0}
@@ -396,67 +280,6 @@ class TestMarketService:
# Should not raise
await service._broadcast_price_update(price_data)
@pytest.mark.asyncio
async def test_price_callback_thread_safety(self):
service = MarketService(
tickers=["AAPL"],
poll_interval=1,
mock_mode=True,
)
received_prices = []
async def capture_broadcast(msg):
received_prices.append(msg)
await service.start(capture_broadcast)
# Wait for at least one price update
await asyncio.sleep(1.5)
service.stop()
# Should have received at least one price update
assert len(received_prices) >= 1
assert received_prices[0]["type"] == "price_update"
class TestMarketServiceIntegration:
@pytest.mark.asyncio
async def test_full_mock_cycle(self):
service = MarketService(
tickers=["AAPL", "MSFT"],
poll_interval=1,
mock_mode=True,
)
messages = []
async def collect_messages(msg):
messages.append(msg)
await service.start(collect_messages)
# Wait for price updates
await asyncio.sleep(2.5)
service.stop()
# Should have received multiple price updates
assert len(messages) >= 2
# Check message structure
symbols_seen = set()
for msg in messages:
assert msg["type"] == "price_update"
assert "symbol" in msg
assert "price" in msg
assert "ret" in msg
symbols_seen.add(msg["symbol"])
# Should have prices for both tickers
assert "AAPL" in symbols_seen or "MSFT" in symbols_seen
if __name__ == "__main__":
pytest.main([__file__, "-v"])
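The `_broadcast_price_update` tests above check that a missing broadcast hook is a silent no-op rather than an error. A minimal sketch of that pattern, with simplified stand-in names rather than the repo's actual `MarketService`:

```python
import asyncio

# The service only awaits its broadcast hook when one is registered,
# so a missing hook skips silently instead of raising.
class MarketService:
    def __init__(self):
        self._broadcast_func = None

    async def _broadcast_price_update(self, price_data):
        if self._broadcast_func is None:
            return  # no subscriber wired up; silently skip
        await self._broadcast_func({"type": "price_update", **price_data})

service = MarketService()
received = []

async def capture(msg):
    received.append(msg)

# No hook set: must not raise.
asyncio.run(service._broadcast_price_update({"symbol": "AAPL", "price": 150.0}))
service._broadcast_func = capture
asyncio.run(service._broadcast_price_update({"symbol": "AAPL", "price": 150.0}))
assert received == [{"type": "price_update", "symbol": "AAPL", "price": 150.0}]
```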


@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-
"""Unit tests for the news domain helpers."""
from backend.domains import news as news_domain
class _FakeStore:
def __init__(self):
self.calls = []
def get_ticker_watermarks(self, symbol):
self.calls.append(("get_ticker_watermarks", symbol))
return {"symbol": symbol, "last_news_fetch": "2026-03-10"}
def get_news_items_enriched(self, ticker, start_date=None, end_date=None, trade_date=None, limit=100):
self.calls.append(("get_news_items_enriched", ticker, start_date, end_date, trade_date, limit))
target = trade_date or end_date
return [{"id": "n1", "ticker": ticker, "date": target, "trade_date": target}]
def get_news_timeline_enriched(self, ticker, start_date=None, end_date=None):
self.calls.append(("get_news_timeline_enriched", ticker, start_date, end_date))
return [{"date": end_date, "count": 1}]
def get_news_categories_enriched(self, ticker, start_date=None, end_date=None, limit=200):
self.calls.append(("get_news_categories_enriched", ticker, start_date, end_date, limit))
return {"macro": {"count": 1}}
def get_news_by_ids_enriched(self, ticker, article_ids):
self.calls.append(("get_news_by_ids_enriched", ticker, list(article_ids)))
return [{"id": article_ids[0], "ticker": ticker, "date": "2026-03-16"}]
def test_news_rows_need_enrichment_detects_missing_fields():
assert news_domain.news_rows_need_enrichment([]) is True
assert news_domain.news_rows_need_enrichment([{"sentiment": "", "relevance": "", "key_discussion": ""}]) is True
assert news_domain.news_rows_need_enrichment([{"sentiment": "positive"}]) is False
def test_ensure_news_fresh_triggers_incremental_refresh_when_watermark_is_stale(monkeypatch):
store = _FakeStore()
calls = []
monkeypatch.setattr(
news_domain,
"update_ticker_incremental",
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
)
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
assert calls == [("AAPL", "2026-03-16")]
assert payload["target_date"] == "2026-03-16"
assert payload["refreshed"] is True
def test_ensure_news_fresh_skips_refresh_when_watermark_is_current(monkeypatch):
store = _FakeStore()
calls = []
monkeypatch.setattr(
store,
"get_ticker_watermarks",
lambda symbol: {"symbol": symbol, "last_news_fetch": "2026-03-16"},
)
monkeypatch.setattr(
news_domain,
"update_ticker_incremental",
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
)
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
assert calls == []
assert payload["refreshed"] is False
def test_get_enriched_news_returns_rows_without_enrichment_when_present(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None, refresh_if_stale=False: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
payload = news_domain.get_enriched_news(
store,
ticker="AAPL",
start_date="2026-03-01",
end_date="2026-03-16",
limit=20,
)
assert payload["ticker"] == "AAPL"
assert payload["news"][0]["ticker"] == "AAPL"
assert payload["freshness"]["target_date"] is None or payload["freshness"]["target_date"] == "2026-03-16"
assert store.calls == [
("get_news_items_enriched", "AAPL", "2026-03-01", "2026-03-16", None, 20)
]
def test_get_story_and_similar_days_delegate(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None, refresh_if_stale=False: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
monkeypatch.setattr(news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"analyzed": 1})
monkeypatch.setattr(
news_domain,
"get_or_create_stock_story",
lambda store, symbol, as_of_date: {"symbol": symbol, "as_of_date": as_of_date, "story": "story"},
)
monkeypatch.setattr(
news_domain,
"find_similar_days",
lambda store, symbol, target_date, top_k: {"symbol": symbol, "target_date": target_date, "items": [{"score": 0.9}]},
)
story = news_domain.get_story_payload(store, ticker="AAPL", as_of_date="2026-03-16")
similar = news_domain.get_similar_days_payload(store, ticker="AAPL", date="2026-03-16", n_similar=8)
assert story["story"] == "story"
assert "freshness" in story
assert similar["items"][0]["score"] == 0.9
assert "freshness" in similar
def test_get_enriched_news_defaults_to_read_only_freshness(monkeypatch):
store = _FakeStore()
ensure_calls = []
def fake_ensure(store, ticker, target_date=None, refresh_if_stale=False):
ensure_calls.append(refresh_if_stale)
return {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
}
monkeypatch.setattr(news_domain, "ensure_news_fresh", fake_ensure)
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
payload = news_domain.get_enriched_news(
store,
ticker="AAPL",
end_date="2026-03-16",
)
assert payload["ticker"] == "AAPL"
assert ensure_calls == [False]
def test_get_range_explain_payload_uses_article_ids(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None, refresh_if_stale=False: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
monkeypatch.setattr(
news_domain,
"build_range_explanation",
lambda ticker, start_date, end_date, news_rows: {"ticker": ticker, "count": len(news_rows)},
)
payload = news_domain.get_range_explain_payload(
store,
ticker="AAPL",
start_date="2026-03-10",
end_date="2026-03-16",
article_ids=["news-9"],
limit=50,
)
assert payload["ticker"] == "AAPL"
assert payload["result"] == {"ticker": "AAPL", "count": 1}
assert "freshness" in payload
assert store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-9"])]
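The `_FakeStore` above works because the domain helpers take their store as a parameter, so tests can pass a recording double instead of patching globals. A self-contained sketch of that dependency-injection pattern, with illustrative names only:

```python
# A recording double: each call is appended to `calls` so the test can
# assert exactly which store methods the domain helper invoked.
class FakeStore:
    def __init__(self):
        self.calls = []

    def get_news_items(self, ticker, limit=100):
        self.calls.append(("get_news_items", ticker, limit))
        return [{"id": "n1", "ticker": ticker}]

def get_enriched_news(store, ticker, limit=100):
    # domain helper: delegates to whatever store it was handed
    rows = store.get_news_items(ticker, limit=limit)
    return {"ticker": ticker, "news": rows}

store = FakeStore()
payload = get_enriched_news(store, "AAPL", limit=20)
assert store.calls == [("get_news_items", "AAPL", 20)]
assert payload["news"][0]["id"] == "n1"
```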


@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted news service app surface."""
from fastapi.testclient import TestClient
from backend.apps.news_service import create_app
class _FakeStore:
def get_ticker_watermarks(self, symbol):
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
def get_news_timeline_enriched(self, symbol, start_date=None, end_date=None):
return [{"date": end_date, "count": 1}]
def get_news_items(self, symbol, start_date=None, end_date=None, limit=100):
return [{"id": "news-raw-1", "ticker": symbol, "title": "Raw Title", "date": end_date}]
def get_news_items_enriched(self, symbol, start_date=None, end_date=None, trade_date=None, limit=100):
return [{"id": "news-1", "ticker": symbol, "title": "Title", "date": trade_date or end_date}]
def upsert_news_analysis(self, symbol, rows):
return len(rows)
def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
return set()
def get_news_categories_enriched(self, symbol, start_date=None, end_date=None, limit=200):
return {"market": {"label": "market", "count": 1, "article_ids": ["news-1"]}}
def get_news_by_ids_enriched(self, symbol, article_ids):
return [{"id": article_ids[0], "ticker": symbol, "title": "Picked"}]
def test_news_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/enriched-news" in paths
assert "/api/news-for-date" in paths
assert "/api/news-timeline" in paths
assert "/api/categories" in paths
assert "/api/similar-days" in paths
assert "/api/stories/{ticker}" in paths
assert "/api/range-explain" in paths
def test_news_service_enriched_news_and_categories(monkeypatch):
app = create_app()
app.dependency_overrides.clear()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
with TestClient(app) as client:
news_response = client.get(
"/api/enriched-news",
params={"ticker": "AAPL", "end_date": "2026-03-23"},
)
categories_response = client.get(
"/api/categories",
params={"ticker": "AAPL", "end_date": "2026-03-23"},
)
assert news_response.status_code == 200
assert news_response.json()["news"][0]["ticker"] == "AAPL"
assert categories_response.status_code == 200
assert categories_response.json()["categories"]["market"]["count"] == 1
def test_news_service_news_for_date_and_timeline(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
with TestClient(app) as client:
date_response = client.get(
"/api/news-for-date",
params={"ticker": "AAPL", "date": "2026-03-23"},
)
timeline_response = client.get(
"/api/news-timeline",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-23",
},
)
assert date_response.status_code == 200
assert date_response.json()["date"] == "2026-03-23"
assert timeline_response.status_code == 200
assert timeline_response.json()["timeline"][0]["count"] == 1
def test_news_service_similar_days_and_story(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
monkeypatch.setattr(
"backend.domains.news.find_similar_days",
lambda store, symbol, target_date, top_k: {
"symbol": symbol,
"target_date": target_date,
"items": [{"date": "2026-03-20", "score": 0.9}],
},
)
monkeypatch.setattr(
"backend.domains.news.get_or_create_stock_story",
lambda store, symbol, as_of_date: {
"symbol": symbol,
"as_of_date": as_of_date,
"story": "story body",
"source": "local",
},
)
with TestClient(app) as client:
similar_response = client.get(
"/api/similar-days",
params={"ticker": "AAPL", "date": "2026-03-23", "n_similar": 3},
)
story_response = client.get(
"/api/stories/AAPL",
params={"as_of_date": "2026-03-23"},
)
assert similar_response.status_code == 200
assert similar_response.json()["items"][0]["score"] == 0.9
assert story_response.status_code == 200
assert story_response.json()["story"] == "story body"
def test_news_service_range_explain(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
monkeypatch.setattr(
"backend.domains.news.build_range_explanation",
lambda ticker, start_date, end_date, news_rows: {
"symbol": ticker,
"news_count": len(news_rows),
"start_date": start_date,
"end_date": end_date,
},
)
with TestClient(app) as client:
response = client.get(
"/api/range-explain",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-23",
"article_ids": ["news-7"],
},
)
assert response.status_code == 200
assert response.json()["result"]["news_count"] == 1
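The tests above swap the store via `app.dependency_overrides[get_market_store] = lambda: _FakeStore()`. Conceptually, FastAPI resolves each dependency through a dict keyed by the original provider callable, so overrides never touch route code. A stdlib-only sketch of that resolution rule (the `App` class here is a hypothetical stand-in, not FastAPI itself):

```python
def get_market_store():
    # production provider; must never run inside a test
    raise RuntimeError("real store not available in tests")

class App:
    def __init__(self):
        self.dependency_overrides = {}

    def resolve(self, provider):
        # use the override when present, else call the real provider
        return self.dependency_overrides.get(provider, provider)()

app = App()
app.dependency_overrides[get_market_store] = lambda: {"AAPL": 150.0}
store = app.resolve(get_market_store)
assert store == {"AAPL": 150.0}
```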


@@ -9,6 +9,7 @@ def test_router_includes_local_csv_fallback(monkeypatch):
monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
monkeypatch.delenv("FIN_DATA_SOURCE", raising=False)
monkeypatch.delenv("ENABLED_DATA_SOURCES", raising=False)
reset_config()
router = DataProviderRouter()


@@ -0,0 +1,364 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted runtime service app surface."""
import json
from pathlib import Path
from fastapi.testclient import TestClient
from backend.api import runtime as runtime_module
from backend.apps.runtime_service import create_app
def test_runtime_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/status" in paths
assert "/api/runtime/start" in paths
assert "/api/runtime/stop" in paths
assert "/api/runtime/cleanup" in paths
assert "/api/runtime/history" in paths
assert "/api/runtime/current" in paths
assert "/api/runtime/gateway/port" in paths
def test_runtime_service_health_and_status(monkeypatch):
runtime_state = runtime_module.get_runtime_state()
runtime_state.gateway_process = None
runtime_state.gateway_port = 9876
runtime_state.runtime_manager = object()
with TestClient(create_app()) as client:
health_response = client.get("/health")
status_response = client.get("/api/status")
assert health_response.status_code == 200
assert health_response.json() == {
"status": "healthy",
"service": "runtime-service",
"gateway_running": False,
"gateway_port": 9876,
}
assert status_response.status_code == 200
assert status_response.json() == {
"status": "operational",
"service": "runtime-service",
"runtime": {
"gateway_running": False,
"gateway_port": 9876,
"has_runtime_manager": True,
},
}
def test_runtime_service_gateway_port_endpoint_uses_runtime_router(monkeypatch):
runtime_module.get_runtime_state().gateway_port = 9345
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
with TestClient(create_app()) as client:
response = client.get(
"/api/runtime/gateway/port",
headers={"host": "runtime.example:8003", "x-forwarded-proto": "https"},
)
assert response.status_code == 200
assert response.json() == {
"port": 9345,
"is_running": True,
"ws_url": "wss://runtime.example:9345",
}
def test_runtime_service_get_runtime_config(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "demo"
state_dir = run_dir / "state"
state_dir.mkdir(parents=True)
(run_dir / "BOOTSTRAP.md").write_text(
"---\n"
"tickers:\n"
" - AAPL\n"
"schedule_mode: intraday\n"
"interval_minutes: 30\n"
"trigger_time: '10:00'\n"
"max_comm_cycles: 3\n"
"enable_memory: true\n"
"---\n",
encoding="utf-8",
)
(state_dir / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "demo",
"run_dir": str(run_dir),
"bootstrap_values": {
"tickers": ["AAPL"],
"schedule_mode": "intraday",
"interval_minutes": 30,
"trigger_time": "10:00",
"max_comm_cycles": 3,
"enable_memory": True,
},
}
}
),
encoding="utf-8",
)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
runtime_module.get_runtime_state().gateway_port = 8765
with TestClient(create_app()) as client:
response = client.get("/api/runtime/config")
assert response.status_code == 200
payload = response.json()
assert payload["run_id"] == "demo"
assert payload["bootstrap"]["schedule_mode"] == "intraday"
assert payload["resolved"]["interval_minutes"] == 30
assert payload["resolved"]["enable_memory"] is True
def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "demo"
state_dir = run_dir / "state"
state_dir.mkdir(parents=True)
(run_dir / "BOOTSTRAP.md").write_text(
"---\n"
"tickers:\n"
" - AAPL\n"
"schedule_mode: daily\n"
"interval_minutes: 60\n"
"trigger_time: '09:30'\n"
"max_comm_cycles: 2\n"
"---\n",
encoding="utf-8",
)
(state_dir / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "demo",
"run_dir": str(run_dir),
"bootstrap_values": {
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"max_comm_cycles": 2,
},
}
}
),
encoding="utf-8",
)
class _DummyContext:
def __init__(self):
self.bootstrap_values = {
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"max_comm_cycles": 2,
}
class _DummyManager:
def __init__(self):
self.config_name = "demo"
self.bootstrap = dict(_DummyContext().bootstrap_values)
self.context = _DummyContext()
def _persist_snapshot(self):
return None
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
runtime_module.get_runtime_state().runtime_manager = _DummyManager()
runtime_module.get_runtime_state().gateway_port = 8765
with TestClient(create_app()) as client:
response = client.put(
"/api/runtime/config",
json={
"schedule_mode": "intraday",
"interval_minutes": 15,
"trigger_time": "10:15",
"max_comm_cycles": 4,
},
)
assert response.status_code == 200
payload = response.json()
assert payload["bootstrap"]["schedule_mode"] == "intraday"
assert payload["resolved"]["interval_minutes"] == 15
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")
def test_prune_old_timestamped_runs_keeps_named_runs(monkeypatch, tmp_path):
runs_dir = tmp_path / "runs"
runs_dir.mkdir()
keep_dirs = ["20260324_110000", "20260324_120000"]
prune_dir = "20260324_100000"
named_dir = "smoke_fullstack"
for name in [*keep_dirs, prune_dir, named_dir]:
(runs_dir / name).mkdir(parents=True)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
pruned = runtime_module._prune_old_timestamped_runs(keep=1, exclude_run_ids={"20260324_120000"})
assert prune_dir in pruned
assert (runs_dir / named_dir).exists()
assert (runs_dir / "20260324_120000").exists()
assert (runs_dir / "20260324_110000").exists()
def test_runtime_cleanup_endpoint_prunes_old_runs(monkeypatch, tmp_path):
runs_dir = tmp_path / "runs"
runs_dir.mkdir()
for name in ["20260324_090000", "20260324_100000", "20260324_110000", "smoke_fullstack"]:
(runs_dir / name).mkdir(parents=True)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: False)
with TestClient(create_app()) as client:
response = client.post("/api/runtime/cleanup?keep=1")
assert response.status_code == 200
payload = response.json()
assert payload["status"] == "ok"
assert sorted(payload["pruned_run_ids"]) == ["20260324_090000", "20260324_100000"]
assert (runs_dir / "20260324_110000").exists()
assert (runs_dir / "smoke_fullstack").exists()
def test_runtime_history_lists_recent_runs(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "20260324_120000"
(run_dir / "state").mkdir(parents=True)
(run_dir / "team_dashboard").mkdir(parents=True)
(run_dir / "state" / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "20260324_120000",
"run_dir": str(run_dir),
"bootstrap_values": {"tickers": ["AAPL"]},
},
"events": [],
}
),
encoding="utf-8",
)
(run_dir / "team_dashboard" / "summary.json").write_text(
json.dumps({"totalTrades": 3, "totalAssetValue": 123456.0}),
encoding="utf-8",
)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
with TestClient(create_app()) as client:
response = client.get("/api/runtime/history?limit=5")
assert response.status_code == 200
payload = response.json()
assert payload["runs"][0]["run_id"] == "20260324_120000"
assert payload["runs"][0]["total_trades"] == 3
def test_restore_run_assets_copies_state(monkeypatch, tmp_path):
source_run = tmp_path / "runs" / "20260324_100000"
(source_run / "team_dashboard").mkdir(parents=True)
(source_run / "state").mkdir(parents=True)
(source_run / "agents").mkdir(parents=True)
(source_run / "team_dashboard" / "_internal_state.json").write_text("{}", encoding="utf-8")
(source_run / "state" / "server_state.json").write_text("{}", encoding="utf-8")
target_run = tmp_path / "runs" / "20260324_130000"
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
runtime_module._restore_run_assets("20260324_100000", target_run)
assert (target_run / "team_dashboard" / "_internal_state.json").exists()
assert (target_run / "state" / "server_state.json").exists()
def test_start_runtime_restore_reuses_historical_run_id(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "20260324_100000"
(run_dir / "state").mkdir(parents=True)
(run_dir / "state" / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "20260324_100000",
"run_dir": str(run_dir),
"bootstrap_values": {
"tickers": ["AAPL"],
"schedule_mode": "intraday",
"interval_minutes": 30,
"trigger_time": "now",
"max_comm_cycles": 2,
"initial_cash": 100000.0,
"margin_requirement": 0.0,
"enable_memory": False,
"mode": "live",
"poll_interval": 10,
},
}
}
),
encoding="utf-8",
)
class _DummyManager:
def __init__(self, config_name, run_dir, bootstrap):
self.config_name = config_name
self.run_dir = Path(run_dir)
self.bootstrap = bootstrap
self.context = None
def prepare_run(self):
self.context = type(
"Ctx",
(),
{
"config_name": self.config_name,
"run_dir": self.run_dir,
"bootstrap_values": self.bootstrap,
},
)()
return self.context
class _DummyProcess:
def poll(self):
return None
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_find_available_port", lambda start_port=8765, max_port=9000: 8765)
monkeypatch.setattr(runtime_module, "_start_gateway_process", lambda **kwargs: _DummyProcess())
monkeypatch.setattr(runtime_module, "_stop_gateway", lambda: True)
monkeypatch.setattr("backend.runtime.manager.TradingRuntimeManager", _DummyManager)
runtime_state = runtime_module.get_runtime_state()
runtime_state.gateway_process = None
with TestClient(create_app()) as client:
response = client.post(
"/api/runtime/start",
json={
"launch_mode": "restore",
"restore_run_id": "20260324_100000",
"tickers": [],
},
)
assert response.status_code == 200
payload = response.json()
assert payload["run_id"] == "20260324_100000"
assert payload["run_dir"] == str(run_dir)
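The pruning and cleanup tests above pin down one rule: only timestamp-named run directories (`YYYYMMDD_HHMMSS`) are prune candidates, and named or excluded runs always survive. A hedged sketch of that rule under those assumptions; the helper name and signature are illustrative, not the repo's actual `_prune_old_timestamped_runs`:

```python
import re
import tempfile
from pathlib import Path

TIMESTAMP_RE = re.compile(r"^\d{8}_\d{6}$")

def prune_old_timestamped_runs(runs_dir, keep=1, exclude_run_ids=frozenset()):
    # lexicographic sort equals chronological order for this name format
    candidates = sorted(
        d.name for d in runs_dir.iterdir()
        if d.is_dir() and TIMESTAMP_RE.match(d.name) and d.name not in exclude_run_ids
    )
    pruned = candidates[:-keep] if keep else candidates
    for name in pruned:
        (runs_dir / name).rmdir()  # dirs are empty in this sketch
    return pruned

runs = Path(tempfile.mkdtemp())
for name in ["20260324_100000", "20260324_110000", "smoke_fullstack"]:
    (runs / name).mkdir()
pruned = prune_old_timestamped_runs(runs, keep=1)
assert pruned == ["20260324_100000"]
assert (runs / "smoke_fullstack").exists()  # named runs are never pruned
```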


@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
"""Tests for split-aware shared service clients."""
import pytest
from shared.client.control_client import ControlPlaneClient
from shared.client.runtime_client import RuntimeServiceClient
class _DummyResponse:
def __init__(self, payload):
self._payload = payload
def raise_for_status(self):
return None
def json(self):
return self._payload
class _DummyAsyncClient:
def __init__(self):
self.calls = []
async def get(self, path, params=None):
self.calls.append(("get", path, params))
return _DummyResponse({"path": path, "params": params})
async def post(self, path, json=None):
self.calls.append(("post", path, json))
return _DummyResponse({"path": path, "json": json})
async def put(self, path, json=None):
self.calls.append(("put", path, json))
return _DummyResponse({"path": path, "json": json})
async def aclose(self):
return None
@pytest.mark.asyncio
async def test_control_plane_client_hits_current_workspace_and_guard_routes():
client = ControlPlaneClient()
client._client = _DummyAsyncClient()
await client.list_workspaces()
await client.get_workspace("demo")
await client.list_agents("demo")
await client.get_agent("demo", "risk_manager")
await client.fetch_pending_approvals()
await client.approve_pending_approval("ap-1")
await client.deny_pending_approval("ap-2", reason="nope")
assert client._client.calls == [
("get", "/workspaces", None),
("get", "/workspaces/demo", None),
("get", "/workspaces/demo/agents", None),
("get", "/workspaces/demo/agents/risk_manager", None),
("get", "/guard/pending", None),
(
"post",
"/guard/approve",
{
"approval_id": "ap-1",
"one_time": True,
"expires_in_minutes": 30,
},
),
(
"post",
"/guard/deny",
{
"approval_id": "ap-2",
"reason": "nope",
},
),
]
@pytest.mark.asyncio
async def test_runtime_service_client_hits_current_runtime_routes():
client = RuntimeServiceClient()
client._client = _DummyAsyncClient()
await client.fetch_context()
await client.fetch_agents()
await client.fetch_events()
await client.fetch_gateway_port()
await client.start_runtime({"tickers": ["AAPL"]})
await client.stop_runtime(force=True)
await client.restart_runtime({"tickers": ["MSFT"]})
await client.fetch_current_runtime()
await client.get_runtime_config()
await client.update_runtime_config({"schedule_mode": "intraday"})
assert client._client.calls == [
("get", "/context", None),
("get", "/agents", None),
("get", "/events", None),
("get", "/gateway/port", None),
("post", "/start", {"tickers": ["AAPL"]}),
("post", "/stop?force=true", None),
("post", "/restart", {"tickers": ["MSFT"]}),
("get", "/current", None),
("get", "/config", None),
("put", "/config", {"schedule_mode": "intraday"}),
]
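The `_DummyAsyncClient` above verifies route paths without any network: each HTTP verb records its call and returns a canned response. The same pattern in a runnable form, with a hypothetical `fetch_context` wrapper standing in for the real client methods:

```python
import asyncio

class DummyResponse:
    def __init__(self, payload):
        self._payload = payload

    def raise_for_status(self):
        return None  # never signals an HTTP error in tests

    def json(self):
        return self._payload

class DummyAsyncClient:
    def __init__(self):
        self.calls = []

    async def get(self, path, params=None):
        self.calls.append(("get", path, params))
        return DummyResponse({"path": path})

async def fetch_context(client):
    # thin wrapper in the style of RuntimeServiceClient.fetch_context
    resp = await client.get("/context")
    resp.raise_for_status()
    return resp.json()

client = DummyAsyncClient()
payload = asyncio.run(fetch_context(client))
assert client.calls == [("get", "/context", None)]
assert payload == {"path": "/context"}
```

Asserting on `client.calls` pins the exact route and parameters, which is what the split-aware client tests above rely on.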


@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""Regression coverage for the shared schema bridge."""
from backend.data import schema as legacy_schema
from shared import schema as shared_schema
def test_backend_data_schema_reexports_shared_contracts():
assert legacy_schema.Price is shared_schema.Price
assert legacy_schema.PriceResponse is shared_schema.PriceResponse
assert legacy_schema.FinancialMetrics is shared_schema.FinancialMetrics
assert legacy_schema.FinancialMetricsResponse is (
shared_schema.FinancialMetricsResponse
)
assert legacy_schema.LineItem is shared_schema.LineItem
assert legacy_schema.LineItemResponse is shared_schema.LineItemResponse
assert legacy_schema.InsiderTrade is shared_schema.InsiderTrade
assert legacy_schema.InsiderTradeResponse is (
shared_schema.InsiderTradeResponse
)
assert legacy_schema.CompanyNews is shared_schema.CompanyNews
assert legacy_schema.CompanyNewsResponse is shared_schema.CompanyNewsResponse
assert legacy_schema.CompanyFacts is shared_schema.CompanyFacts
assert legacy_schema.CompanyFactsResponse is (
shared_schema.CompanyFactsResponse
)
assert legacy_schema.Position is shared_schema.Position
assert legacy_schema.Portfolio is shared_schema.Portfolio
assert legacy_schema.AnalystSignal is shared_schema.AnalystSignal
assert legacy_schema.TickerAnalysis is shared_schema.TickerAnalysis
assert legacy_schema.AgentStateData is shared_schema.AgentStateData
assert legacy_schema.AgentStateMetadata is shared_schema.AgentStateMetadata
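These assertions use `is`, not `==`: the bridge test demands that the legacy module and the shared module expose the very same class objects, so `isinstance` checks and validation remain interchangeable across both import paths. A minimal stand-in showing why a re-export satisfies identity while a copy would not:

```python
import types

# Stand-in modules; the real ones are backend.data.schema and shared.schema.
shared = types.ModuleType("shared_schema")

class Price:
    pass

shared.Price = Price

legacy = types.ModuleType("legacy_schema")
legacy.Price = shared.Price  # re-export: same object, not a new class

assert legacy.Price is shared.Price           # identity holds for re-exports
assert isinstance(legacy.Price(), shared.Price)
```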


@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""Unit tests for the trading domain helpers."""
from backend.domains import trading as trading_domain
def test_trading_domain_payload_wrappers(monkeypatch):
monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}])
monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12)
assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == {
"ticker": "AAPL",
"prices": [{"close": 1}],
}
assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
"financial_metrics": [{"ticker": "AAPL"}],
}
assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
"news": [{"ticker": "AAPL"}],
}
assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
"insider_trades": [{"ticker": "AAPL"}],
}
assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
"ticker": "AAPL",
"end_date": "2026-03-16",
"market_cap": 2.5e12,
}
def test_get_market_status_payload_uses_market_service(monkeypatch):
class _FakeMarketService:
def __init__(self, tickers):
self.tickers = tickers
def get_market_status(self):
return {"status": "open", "status_text": "Open"}
monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)
assert trading_domain.get_market_status_payload() == {
"status": "open",
"status_text": "Open",
}
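The wrappers these tests cover are deliberately thin: each calls one provider function and shapes a dict for the API layer, which is why a single `monkeypatch.setattr` per provider suffices. A sketch of that shape, with an illustrative stand-in instead of the real `get_market_cap` provider:

```python
def get_market_cap(ticker, end_date):
    # stand-in for the real provider call
    return 2.5e12

def get_market_cap_payload(ticker, end_date):
    # wrapper: no logic beyond delegation and dict shaping
    return {
        "ticker": ticker,
        "end_date": end_date,
        "market_cap": get_market_cap(ticker, end_date),
    }

payload = get_market_cap_payload("AAPL", "2026-03-16")
assert payload == {
    "ticker": "AAPL",
    "end_date": "2026-03-16",
    "market_cap": 2.5e12,
}
```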


@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted trading service app surface."""
from fastapi.testclient import TestClient
from backend.apps.trading_service import create_app
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price
def test_trading_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/prices" in paths
assert "/api/financials" in paths
assert "/api/news" in paths
assert "/api/insider-trades" in paths
assert "/api/market/status" in paths
assert "/api/market-cap" in paths
assert "/api/line-items" in paths
def test_trading_service_prices_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_prices_payload",
lambda ticker, start_date, end_date: {
"ticker": ticker,
"prices": [
Price(
open=1.0,
close=2.0,
high=2.5,
low=0.5,
volume=100,
time="2026-03-20",
)
],
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/prices",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-20",
},
)
assert response.status_code == 200
assert response.json()["ticker"] == "AAPL"
assert response.json()["prices"][0]["close"] == 2.0
def test_trading_service_financials_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_financials_payload",
lambda ticker, end_date, period, limit: {
"financial_metrics": [
FinancialMetrics(
ticker=ticker,
report_period=end_date,
period=period,
currency="USD",
market_cap=123.0,
enterprise_value=None,
price_to_earnings_ratio=None,
price_to_book_ratio=None,
price_to_sales_ratio=None,
enterprise_value_to_ebitda_ratio=None,
enterprise_value_to_revenue_ratio=None,
free_cash_flow_yield=None,
peg_ratio=None,
gross_margin=None,
operating_margin=None,
net_margin=None,
return_on_equity=None,
return_on_assets=None,
return_on_invested_capital=None,
asset_turnover=None,
inventory_turnover=None,
receivables_turnover=None,
days_sales_outstanding=None,
operating_cycle=None,
working_capital_turnover=None,
current_ratio=None,
quick_ratio=None,
cash_ratio=None,
operating_cash_flow_ratio=None,
debt_to_equity=None,
debt_to_assets=None,
interest_coverage=None,
revenue_growth=None,
earnings_growth=None,
book_value_growth=None,
earnings_per_share_growth=None,
free_cash_flow_growth=None,
operating_income_growth=None,
ebitda_growth=None,
payout_ratio=None,
earnings_per_share=None,
book_value_per_share=None,
free_cash_flow_per_share=None,
)
]
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/financials",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert response.status_code == 200
assert response.json()["financial_metrics"][0]["ticker"] == "AAPL"
def test_trading_service_news_and_insider_endpoints(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_news_payload",
lambda ticker, end_date, start_date=None, limit=1000: {
"news": [
CompanyNews(
ticker=ticker,
title="News title",
source="polygon",
url="https://example.com/news",
date=end_date,
)
]
},
)
monkeypatch.setattr(
"backend.domains.trading.get_insider_trades_payload",
lambda ticker, end_date, start_date=None, limit=1000: {
"insider_trades": [
InsiderTrade(ticker=ticker, filing_date=end_date)
]
},
)
with TestClient(create_app()) as client:
news_response = client.get(
"/api/news",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
insider_response = client.get(
"/api/insider-trades",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert news_response.status_code == 200
assert news_response.json()["news"][0]["title"] == "News title"
assert insider_response.status_code == 200
assert insider_response.json()["insider_trades"][0]["ticker"] == "AAPL"
def test_trading_service_market_status_endpoint(monkeypatch):
class _FakeMarketService:
def get_market_status(self):
return {"status": "open", "status_text": "Open"}
monkeypatch.setattr(
"backend.domains.trading.get_market_status_payload",
lambda: _FakeMarketService().get_market_status(),
)
with TestClient(create_app()) as client:
response = client.get("/api/market/status")
assert response.status_code == 200
assert response.json() == {"status": "open", "status_text": "Open"}
def test_trading_service_market_cap_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_market_cap_payload",
lambda ticker, end_date: {
"ticker": ticker,
"end_date": end_date,
"market_cap": 3.5e12,
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/market-cap",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert response.status_code == 200
assert response.json() == {
"ticker": "AAPL",
"end_date": "2026-03-20",
"market_cap": 3.5e12,
}
def test_trading_service_line_items_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_line_items_payload",
lambda ticker, line_items, end_date, period, limit: {
"search_results": [
LineItem(
ticker=ticker,
report_period=end_date,
period=period,
currency="USD",
free_cash_flow=123.0,
)
]
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/line-items",
params=[
("ticker", "AAPL"),
("line_items", "free_cash_flow"),
("end_date", "2026-03-20"),
],
)
assert response.status_code == 200
assert response.json()["search_results"][0]["ticker"] == "AAPL"
assert response.json()["search_results"][0]["free_cash_flow"] == 123.0

View File

@@ -3,13 +3,16 @@
# pylint: disable=C0301
"""Data fetching tools backed by the unified provider router."""
import datetime
import os
import httpx
import pandas as pd
import pandas_market_calendars as mcal
from backend.data.provider_utils import normalize_symbol
from backend.data.cache import get_cache
from backend.data.provider_router import get_provider_router
-from backend.data.schema import (
+from shared.schema import (
CompanyNews,
FinancialMetrics,
InsiderTrade,
@@ -23,6 +26,31 @@ _cache = get_cache()
_router = get_provider_router()
def _service_name() -> str:
return str(os.getenv("SERVICE_NAME", "")).strip().lower()
def _trading_service_url() -> str | None:
value = str(os.getenv("TRADING_SERVICE_URL", "")).strip().rstrip("/")
if not value or _service_name() == "trading_service":
return None
return value
def _news_service_url() -> str | None:
value = str(os.getenv("NEWS_SERVICE_URL", "")).strip().rstrip("/")
if not value or _service_name() == "news_service":
return None
return value
def _service_get_json(base_url: str, path: str, *, params: dict[str, object]) -> dict:
with httpx.Client(base_url=base_url, timeout=30.0) as client:
response = client.get(path, params=params)
response.raise_for_status()
return response.json()
def get_last_tradeday(date: str) -> str:
"""
Get the previous trading day for the specified date
@@ -104,6 +132,24 @@ def get_prices(
if cached_data := _cache.get_prices(cache_key):
return [Price(**price) for price in cached_data]
service_url = _trading_service_url()
if service_url:
try:
payload = _service_get_json(
service_url,
"/api/prices",
params={
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
},
)
prices = [Price(**price) for price in payload.get("prices", [])]
if prices:
return prices
except Exception as exc:
logger.info("Trading service price lookup failed for %s: %s", ticker, exc)
try:
prices, data_source = _router.get_prices(ticker, start_date, end_date)
except Exception as exc:
@@ -146,6 +192,28 @@ def get_financial_metrics(
if cached_data := _cache.get_financial_metrics(cache_key):
return [FinancialMetrics(**metric) for metric in cached_data]
service_url = _trading_service_url()
if service_url:
try:
payload = _service_get_json(
service_url,
"/api/financials",
params={
"ticker": ticker,
"end_date": end_date,
"period": period,
"limit": limit,
},
)
metrics = [
FinancialMetrics(**metric)
for metric in payload.get("financial_metrics", [])
]
if metrics:
return metrics
except Exception as exc:
logger.info("Trading service financial lookup failed for %s: %s", ticker, exc)
try:
financial_metrics, data_source = _router.get_financial_metrics(
ticker=ticker,
@@ -183,6 +251,22 @@ def search_line_items(
ticker = normalize_symbol(ticker)
if not ticker:
return []
service_url = _trading_service_url()
if service_url:
payload = _service_get_json(
service_url,
"/api/line-items",
params={
"ticker": ticker,
"line_items": line_items,
"end_date": end_date,
"period": period,
"limit": limit,
},
)
return [LineItem(**item) for item in payload.get("search_results", [])]
return _router.search_line_items(
ticker=ticker,
line_items=line_items,
@@ -213,6 +297,26 @@ def get_insider_trades(
if cached_data := _cache.get_insider_trades(cache_key):
return [InsiderTrade(**trade) for trade in cached_data]
service_url = _trading_service_url()
if service_url:
try:
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
if start_date:
params["start_date"] = start_date
payload = _service_get_json(
service_url,
"/api/insider-trades",
params=params,
)
trades = [
InsiderTrade(**trade)
for trade in payload.get("insider_trades", [])
]
if trades:
return trades
except Exception as exc:
logger.info("Trading service insider lookup failed for %s: %s", ticker, exc)
try:
all_trades, data_source = _router.get_insider_trades(
ticker=ticker,
@@ -248,6 +352,40 @@ def get_company_news(
if cached_data := _cache.get_company_news(cache_key):
return [CompanyNews(**news) for news in cached_data]
trading_service_url = _trading_service_url()
if trading_service_url:
try:
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
if start_date:
params["start_date"] = start_date
payload = _service_get_json(
trading_service_url,
"/api/news",
params=params,
)
news = [CompanyNews(**item) for item in payload.get("news", [])]
if news:
return news
except Exception as exc:
logger.info("Trading service news lookup failed for %s: %s", ticker, exc)
news_service_url = _news_service_url()
if news_service_url:
try:
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
if start_date:
params["start_date"] = start_date
payload = _service_get_json(
news_service_url,
"/api/enriched-news",
params=params,
)
news = [CompanyNews(**item) for item in payload.get("news", [])]
if news:
return news
except Exception as exc:
logger.info("News service lookup failed for %s: %s", ticker, exc)
try:
all_news, data_source = _router.get_company_news(
ticker=ticker,
@@ -272,6 +410,19 @@ def get_market_cap(ticker: str, end_date: str) -> float | None:
if not ticker:
return None
service_url = _trading_service_url()
if service_url:
try:
payload = _service_get_json(
service_url,
"/api/market-cap",
params={"ticker": ticker, "end_date": end_date},
)
value = payload.get("market_cap")
return float(value) if value is not None else None
except Exception as exc:
logger.info("Trading service market-cap lookup failed for %s: %s", ticker, exc)
def _metrics_lookup(symbol: str, date: str):
for source in _router.api_sources():
cache_key = f"{symbol}_ttm_{date}_10_{source}"

View File

@@ -228,12 +228,12 @@ class SettlementCoordinator:
all_evaluations = {**analyst_evaluations, **pm_evaluations}
-leaderboard = self.storage.load_file("leaderboard") or []
+leaderboard = self.storage.load_export_file("leaderboard") or []
updated_leaderboard = update_leaderboard_with_evaluations(
leaderboard,
all_evaluations,
)
-self.storage.save_file("leaderboard", updated_leaderboard)
+self.storage.save_export_file("leaderboard", updated_leaderboard)
self._update_summary_with_baselines(
date,

View File

@@ -30,7 +30,6 @@ class TerminalDashboard:
self.port = 8765
self.poll_interval = 10
self.trigger_time = "now"
-self.mock = False
self.enable_memory = False
self.local_time = ""
self.nyse_time = ""
@@ -65,7 +64,6 @@ class TerminalDashboard:
port: int,
poll_interval: int,
trigger_time: str = "now",
-mock: bool = False,
enable_memory: bool = False,
local_time: str = "",
nyse_time: str = "",
@@ -82,7 +80,6 @@ class TerminalDashboard:
self.port = port
self.poll_interval = poll_interval
self.trigger_time = trigger_time
-self.mock = mock
self.enable_memory = enable_memory
self.local_time = local_time
self.nyse_time = nyse_time
@@ -109,8 +106,6 @@ class TerminalDashboard:
# Mode line
if self.mode == "backtest":
mode_str = "[cyan]Backtest[/cyan]"
-elif self.mock:
-mode_str = "[yellow]MOCK[/yellow]"
else:
mode_str = "[green]LIVE[/green]"
@@ -216,8 +211,6 @@ class TerminalDashboard:
title = "[bold cyan]EvoTraders[/bold cyan]"
if self.mode == "backtest":
title += " [dim]Backtest[/dim]"
-elif self.mock:
-title += " [dim]Mock[/dim]"
else:
title += " [dim]Live[/dim]"

View File

@@ -0,0 +1,28 @@
# Compatibility Removal Plan
This document tracks the remaining migration-only surfaces that still exist
after the move to split-first development.
## Migration-only Surfaces
None currently remain as dedicated compatibility wrappers.
## Completed Removals
### `backend.app`
- Removed after compatibility startup switched to
`backend.apps.combined_service:app` directly.
### `shared.client.AgentServiceClient`
- Removed after split-aware clients became the default import surface.
- Replacement:
- `ControlPlaneClient`
- `RuntimeServiceClient`
- `TradingServiceClient`
- `NewsServiceClient`
### `backend.apps.combined_service`
- Removed after split-service mode became the only supported dev startup path.

View File

@@ -1,7 +1,31 @@
## QuickStart
```bash
cd frontend
npm install
npm run dev
```
## Optional Direct Service Calls
The frontend still works against the compatibility backend entrypoint out of the box.
For the current test stage, however, the split services are the recommended setup.
Point the frontend directly at the standalone services:
```bash
VITE_CONTROL_API_BASE_URL=http://localhost:8000/api
VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime
VITE_NEWS_SERVICE_URL=http://localhost:8002
VITE_TRADING_SERVICE_URL=http://localhost:8001
```
Current direct-call coverage:
- runtime panel + gateway port discovery
- `story`
- `similar days`
- `range explain`
- `news for date`
- `news categories`
If these variables are not set, the frontend falls back to the existing
WebSocket-driven compatibility flow.
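As a rough illustration of the fallback rule above, a small helper could resolve a service's direct-call base URL from these variables, returning `null` when a variable is unset so callers keep using the WebSocket flow. This is a hypothetical sketch: the `resolveServiceBase` name and the explicit `env` parameter are not part of the codebase (in a Vite app you would pass `import.meta.env`).

```javascript
// Map logical service names to the Vite env variables documented above.
const SERVICE_ENV_KEYS = {
  control: 'VITE_CONTROL_API_BASE_URL',
  runtime: 'VITE_RUNTIME_API_BASE_URL',
  news: 'VITE_NEWS_SERVICE_URL',
  trading: 'VITE_TRADING_SERVICE_URL',
};

// Returns a trimmed base URL for direct REST calls, or null to signal
// "fall back to the WebSocket-driven compatibility flow".
function resolveServiceBase(service, env) {
  const key = SERVICE_ENV_KEYS[service];
  const value = key ? (env?.[key] ?? '').trim() : '';
  // Strip trailing slashes so path joining stays predictable.
  return value ? value.replace(/\/+$/, '') : null;
}
```

Callers can then branch once per request: `resolveServiceBase('trading', import.meta.env)` either yields a base URL for a direct `fetch`, or `null`, in which case the existing WebSocket request path is used unchanged.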

File diff suppressed because it is too large

View File

@@ -57,7 +57,7 @@ export default function AgentCard({ agent, onClose, isClosing }) {
background: '#ffffff',
borderBottom: '2px solid #000000',
boxShadow: '0 4px 12px rgba(0, 0, 0, 0.1)',
zIndex: 1000,
zIndex: 800,
animation: isClosing ? 'slideUp 0.2s ease-out forwards' : 'slideDown 0.25s ease-out'
}}>
{/* Horizontal scrollable content */}

View File

@@ -35,14 +35,22 @@ const stripMarkdown = (text) => {
.replace(/^[-=]+$/gm, '');
};
-const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
+const AgentFeed = forwardRef(({ feed, leaderboard, agentProfilesByAgent }, ref) => {
const feedContentRef = useRef(null);
const [highlightedId, setHighlightedId] = useState(null);
const [selectedAgent, setSelectedAgent] = useState('all');
const [dropdownOpen, setDropdownOpen] = useState(false);
const getAgentModelInfo = (agentId) => {
-if (!leaderboard || !agentId) return { modelName: null, modelProvider: null };
+if (!agentId) return { modelName: null, modelProvider: null };
+const profile = agentProfilesByAgent?.[agentId];
+if (profile?.model_name) {
+return {
+modelName: profile.model_name,
+modelProvider: profile.model_provider
+};
+}
+if (!leaderboard) return { modelName: null, modelProvider: null };
const agentData = leaderboard.find(lb => lb.id === agentId || lb.agentId === agentId);
return {
modelName: agentData?.modelName,
@@ -52,7 +60,17 @@ const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
// Get agent info by name
const getAgentInfoByName = (agentName) => {
-if (!leaderboard || !agentName) return null;
+if (!agentName) return null;
+const agentConfig = AGENTS.find((agent) => agent.name === agentName);
+const profile = agentConfig ? agentProfilesByAgent?.[agentConfig.id] : null;
+if (agentConfig && profile?.model_name) {
+return {
+agentId: agentConfig.id,
+modelName: profile.model_name,
+modelProvider: profile.model_provider
+};
+}
+if (!leaderboard) return null;
const agentData = leaderboard.find(lb => lb.name === agentName || lb.agentName === agentName);
if (!agentData) return null;
return {

View File

@@ -0,0 +1,506 @@
import React, { Suspense, lazy, useRef, useEffect, useMemo } from 'react';
import GlobalStyles from '../styles/GlobalStyles';
import Header from './Header.jsx';
import RuntimeSettingsPanel from './RuntimeSettingsPanel.jsx';
import StockLogo from './StockLogo.jsx';
import NetValueChart from './NetValueChart.jsx';
import { AGENTS } from '../config/constants';
import { useRuntimeStore } from '../store/runtimeStore';
import { useUIStore } from '../store/uiStore';
import { formatNumber, formatTickerPrice } from '../utils/formatters';
const RoomView = lazy(() => import('./RoomView'));
const AgentFeed = lazy(() => import('./AgentFeed'));
const StatisticsView = lazy(() => import('./StatisticsView'));
const StockExplainView = lazy(() => import('./StockExplainView.jsx'));
const TraderView = lazy(() => import('./TraderView.jsx'));
function ViewLoadingFallback({ label = '加载中...' }) {
return (
<div style={{
minHeight: 240,
height: '100%',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
border: '1px solid #000000',
background: '#ffffff',
fontSize: 12,
fontWeight: 700,
letterSpacing: 0.4
}}>
{label}
</div>
);
}
/**
* AppShell - Layout shell containing Header, TickerBar, ViewNavBar, View container, and AgentFeed
*/
export default function AppShell({
// Connection & status
isConnected,
virtualTime,
now,
marketStatus,
serverMode,
marketStatusLabel,
dataSourceLabel,
runtimeSummaryLabel,
isUpdating,
// Handlers
onManualTrigger,
onOpenRuntimeLogs,
onRuntimeSettingsToggle,
// Runtime settings panel props
isRuntimeSettingsOpen,
isRuntimeConfigSaving,
isWatchlistSaving,
runtimeConfigFeedback,
watchlistFeedback,
launchModeDraft,
restoreRunIdDraft,
runtimeHistoryRuns,
scheduleModeDraft,
intervalMinutesDraft,
triggerTimeDraft,
maxCommCyclesDraft,
initialCashDraft,
marginRequirementDraft,
enableMemoryDraft,
modeDraft,
pollIntervalDraft,
startDateDraft,
endDateDraft,
watchlistDraftSymbols,
watchlistInputValue,
watchlistSuggestions,
onLaunchModeChange,
onRestoreRunIdChange,
onScheduleModeChange,
onIntervalMinutesChange,
onTriggerTimeChange,
onMaxCommCyclesChange,
onInitialCashChange,
onMarginRequirementChange,
onEnableMemoryChange,
onModeChange,
onPollIntervalChange,
onStartDateChange,
onEndDateChange,
onWatchlistInputChange,
onWatchlistInputKeyDown,
onWatchlistAdd,
onWatchlistRemove,
onWatchlistRestoreCurrent,
onWatchlistRestoreDefault,
onWatchlistSuggestionClick,
onLaunchConfigSave,
onRestoreDefaults,
// Ticker and portfolio data
displayTickers,
portfolioData,
rollingTickers,
// Feed data
feed,
bubbles,
bubbleFor,
leaderboard,
// Views data
currentView,
chartTab,
holdings,
trades,
stats,
priceHistoryByTicker,
ohlcHistoryByTicker,
selectedExplainSymbol,
onSelectedExplainSymbolChange,
historySourceByTicker,
explainEventsByTicker,
newsByTicker,
insiderTradesByTicker,
technicalIndicatorsByTicker,
currentDate,
// Stock request handlers
stockRequests,
// Agent request handlers
agentRequests,
agentProfilesByAgent,
// Layout
leftWidth,
isResizing,
onMouseDown,
agentFeedRef
}) {
const containerRef = useRef(null);
const { setIsRuntimeSettingsOpen, setIsWatchlistPanelOpen } = useRuntimeStore();
const { setChartTab, setCurrentView, setIsResizing, setLeftWidth } = useUIStore();
// Resize handler
useEffect(() => {
if (!isResizing) return;
const handleMouseMove = (e) => {
if (!containerRef.current) return;
const containerRect = containerRef.current.getBoundingClientRect();
const newLeftWidth = ((e.clientX - containerRect.left) / containerRect.width) * 100;
if (newLeftWidth >= 30 && newLeftWidth <= 85) {
setLeftWidth(newLeftWidth);
}
};
const handleMouseUp = () => setIsResizing(false);
document.addEventListener('mousemove', handleMouseMove);
document.addEventListener('mouseup', handleMouseUp);
return () => {
document.removeEventListener('mousemove', handleMouseMove);
document.removeEventListener('mouseup', handleMouseUp);
};
}, [isResizing, setIsResizing, setLeftWidth]);
const handleJumpToMessage = (bubble) => {
if (agentFeedRef.current && agentFeedRef.current.scrollToMessage) {
agentFeedRef.current.scrollToMessage(bubble);
}
};
const viewClassName = useMemo(() => {
const base = `view-slider-five ${currentView === 'traders' ? 'show-traders' :
currentView === 'room' ? 'show-room' :
currentView === 'explain' ? 'show-explain' :
currentView === 'statistics' ? 'show-statistics' : 'show-chart'}`;
return base;
}, [currentView]);
return (
<div className="app">
<GlobalStyles />
{/* Header */}
<div className="header">
<Header />
<div className="header-right" style={{ display: 'flex', alignItems: 'center', gap: 24, marginLeft: 'auto', flexWrap: 'wrap', minWidth: 0 }}>
{/* Unified Status Indicator */}
<div className="header-status-inline">
<span className={`status-dot ${isConnected ? (isUpdating ? 'updating' : 'live') : 'offline'}`} />
<span className={`status-text ${isConnected ? 'live' : 'offline'}`}>
{isConnected ? (isUpdating ? '同步中' : '在线') : '离线'}
</span>
{marketStatus && (
<>
<span className="status-sep">·</span>
<span className={`market-text ${serverMode === 'backtest' ? 'backtest' : (marketStatus.status === 'open' ? 'open' : 'closed')}`}>
{marketStatusLabel}
</span>
</>
)}
{dataSourceLabel && (
<>
<span className="status-sep">·</span>
<span className="market-text backtest">{dataSourceLabel}</span>
</>
)}
{runtimeSummaryLabel && (
<>
<span className="status-sep">·</span>
<span className="market-text backtest" title="当前运行配置">{runtimeSummaryLabel}</span>
</>
)}
<span className="status-sep">·</span>
<span className="time-text">{now.toLocaleTimeString('en-US', { hour: '2-digit', minute: '2-digit', second: '2-digit', hour12: false })}</span>
</div>
{serverMode !== 'backtest' && (
<div style={{ display: 'flex', gap: 8, alignItems: 'center' }}>
{onOpenRuntimeLogs && (
<button
onClick={onOpenRuntimeLogs}
style={{
padding: '6px 12px',
borderRadius: 4,
background: '#FFFFFF',
border: '1px solid #111111',
color: '#111111',
fontSize: '11px',
fontFamily: '"Courier New", monospace',
fontWeight: 700,
cursor: 'pointer',
letterSpacing: '0.4px',
textTransform: 'uppercase'
}}
title="查看当前运行日志"
>
运行日志
</button>
)}
<button
onClick={onManualTrigger}
disabled={!isConnected}
style={{
padding: '6px 12px',
borderRadius: 4,
background: isConnected ? '#111111' : '#8a8a8a',
border: '1px solid #111111',
color: '#FFFFFF',
fontSize: '11px',
fontFamily: '"Courier New", monospace',
fontWeight: 700,
cursor: isConnected ? 'pointer' : 'not-allowed',
letterSpacing: '0.4px',
textTransform: 'uppercase'
}}
title="手动触发一轮分析与交易决策"
>
手动运行
</button>
</div>
)}
<RuntimeSettingsPanel
showTrigger={false}
isOpen={isRuntimeSettingsOpen}
isConnected={isConnected}
isSaving={isRuntimeConfigSaving || isWatchlistSaving}
feedback={runtimeConfigFeedback || watchlistFeedback}
launchMode={launchModeDraft}
restoreRunId={restoreRunIdDraft}
runtimeHistoryRuns={runtimeHistoryRuns}
scheduleMode={scheduleModeDraft}
intervalMinutes={intervalMinutesDraft}
triggerTime={triggerTimeDraft}
maxCommCycles={maxCommCyclesDraft}
initialCash={initialCashDraft}
marginRequirement={marginRequirementDraft}
enableMemory={enableMemoryDraft}
mode={modeDraft}
pollInterval={pollIntervalDraft}
startDate={startDateDraft}
endDate={endDateDraft}
watchlistSymbols={watchlistDraftSymbols}
watchlistInputValue={watchlistInputValue}
watchlistSuggestions={watchlistSuggestions}
onToggle={onRuntimeSettingsToggle}
onClose={() => setIsRuntimeSettingsOpen(false)}
onLaunchModeChange={onLaunchModeChange}
onRestoreRunIdChange={onRestoreRunIdChange}
onScheduleModeChange={onScheduleModeChange}
onIntervalMinutesChange={onIntervalMinutesChange}
onTriggerTimeChange={onTriggerTimeChange}
onMaxCommCyclesChange={onMaxCommCyclesChange}
onInitialCashChange={onInitialCashChange}
onMarginRequirementChange={onMarginRequirementChange}
onEnableMemoryChange={onEnableMemoryChange}
onModeChange={onModeChange}
onPollIntervalChange={onPollIntervalChange}
onStartDateChange={onStartDateChange}
onEndDateChange={onEndDateChange}
onWatchlistInputChange={onWatchlistInputChange}
onWatchlistInputKeyDown={onWatchlistInputKeyDown}
onWatchlistAdd={onWatchlistAdd}
onWatchlistRemove={onWatchlistRemove}
onWatchlistRestoreCurrent={onWatchlistRestoreCurrent}
onWatchlistRestoreDefault={onWatchlistRestoreDefault}
onWatchlistSuggestionClick={onWatchlistSuggestionClick}
onSave={onLaunchConfigSave}
onRestoreDefaults={onRestoreDefaults}
/>
</div>
</div>
{/* Main Content */}
<>
{/* Ticker Bar */}
<div className="ticker-bar">
<div className="ticker-track">
{[0, 1].map((groupIdx) => (
<div key={groupIdx} className="ticker-group">
{displayTickers.map(ticker => (
<div key={`${ticker.symbol}-${groupIdx}`} className="ticker-item">
<StockLogo ticker={ticker.symbol} size={16} />
<span className="ticker-symbol">{ticker.symbol}</span>
<span className="ticker-price">
<span className={`ticker-price-value ${rollingTickers[ticker.symbol] ? 'rolling' : ''}`}>
{ticker.price !== null && ticker.price !== undefined
? `$${formatTickerPrice(ticker.price)}` : '-'}
</span>
</span>
<span className={`ticker-change ${
ticker.change === null || ticker.change === undefined
? '' : ticker.change >= 0 ? 'positive' : 'negative'
}`}>
{ticker.change !== null && ticker.change !== undefined
? `${ticker.change >= 0 ? '+' : ''}${ticker.change.toFixed(2)}%` : '-'}
</span>
</div>
))}
</div>
))}
</div>
<div className="portfolio-value">
<span className="portfolio-label">投资组合</span>
<span className="portfolio-amount">${formatNumber(portfolioData.netValue)}</span>
</div>
</div>
<div className="main-container" ref={containerRef}>
{/* Left Panel */}
<div className="left-panel" style={{ width: `${leftWidth}%` }}>
<div className="chart-section">
<div className="view-container">
<div className="view-nav-bar">
<button
className={`view-nav-btn ${currentView === 'traders' ? 'active' : ''}`}
onClick={() => setCurrentView('traders')}
>
交易员
</button>
<button
className={`view-nav-btn ${currentView === 'room' ? 'active' : ''}`}
onClick={() => setCurrentView('room')}
>
交易室
</button>
<button
className={`view-nav-btn ${currentView === 'explain' ? 'active' : ''}`}
onClick={() => setCurrentView('explain')}
>
个股分析
</button>
<button
className={`view-nav-btn ${currentView === 'chart' ? 'active' : ''}`}
onClick={() => setCurrentView('chart')}
>
业绩图表
</button>
<button
className={`view-nav-btn ${currentView === 'statistics' ? 'active' : ''}`}
onClick={() => setCurrentView('statistics')}
>
统计
</button>
</div>
<div className={viewClassName}>
{/* Traders View */}
<div className="view-panel">
<Suspense fallback={<ViewLoadingFallback label="加载交易员视图..." />}>
<TraderView {...agentRequests} />
</Suspense>
</div>
{/* Room View Panel */}
<div className="view-panel">
<Suspense fallback={<ViewLoadingFallback label="加载交易室..." />}>
<RoomView
bubbles={bubbles}
bubbleFor={bubbleFor}
leaderboard={leaderboard}
agentProfilesByAgent={agentProfilesByAgent}
feed={feed}
onJumpToMessage={handleJumpToMessage}
onOpenLaunchConfig={() => setIsRuntimeSettingsOpen(true)}
/>
</Suspense>
</div>
{/* Stock Explain View Panel */}
<div className="view-panel">
<Suspense fallback={<ViewLoadingFallback label="加载个股分析..." />}>
<StockExplainView
tickers={displayTickers}
holdings={holdings}
trades={trades}
leaderboard={leaderboard}
feed={feed}
priceHistoryByTicker={priceHistoryByTicker}
ohlcHistoryByTicker={ohlcHistoryByTicker}
selectedSymbol={selectedExplainSymbol}
onSelectedSymbolChange={onSelectedExplainSymbolChange}
selectedHistorySource={historySourceByTicker[selectedExplainSymbol] || null}
explainEventsSnapshot={explainEventsByTicker[selectedExplainSymbol] || null}
newsSnapshot={newsByTicker[selectedExplainSymbol] || null}
insiderTradesSnapshot={insiderTradesByTicker[selectedExplainSymbol] || null}
technicalIndicatorsSnapshot={technicalIndicatorsByTicker[selectedExplainSymbol] || null}
onRequestHistory={stockRequests?.requestStockHistory}
onRequestExplainEvents={stockRequests?.requestStockExplainEvents}
onRequestNews={stockRequests?.requestStockNews}
onRequestRangeExplain={stockRequests?.requestStockRangeExplain}
onRequestNewsForDate={stockRequests?.requestStockNewsForDate}
onRequestStory={stockRequests?.requestStockStory}
onRequestInsiderTrades={stockRequests?.requestStockInsiderTrades}
onRequestTechnicalIndicators={stockRequests?.requestStockTechnicalIndicators}
currentDate={currentDate}
onRequestSimilarDays={stockRequests?.requestStockSimilarDays}
onRequestStockEnrich={stockRequests?.requestStockEnrich}
/>
</Suspense>
</div>
{/* Chart View Panel */}
<div className="view-panel">
<div className="chart-container">
<div className="chart-tabs-floating">
<button
className={`chart-tab ${chartTab === 'all' ? 'active' : ''}`}
onClick={() => setChartTab('all')}
>
日线
</button>
</div>
{currentView === 'chart' ? (
<NetValueChart
equity={portfolioData.equity}
baseline={portfolioData.baseline}
baseline_vw={portfolioData.baseline_vw}
momentum={portfolioData.momentum}
strategies={portfolioData.strategies}
equity_return={portfolioData.equity_return}
baseline_return={portfolioData.baseline_return}
baseline_vw_return={portfolioData.baseline_vw_return}
momentum_return={portfolioData.momentum_return}
chartTab={chartTab}
virtualTime={virtualTime}
/>
) : (
<div style={{ height: '100%', minHeight: 320 }} />
)}
</div>
</div>
{/* Statistics View Panel */}
<div className="view-panel">
<Suspense fallback={<ViewLoadingFallback label="加载统计视图..." />}>
<StatisticsView
trades={trades}
holdings={holdings}
stats={stats}
portfolioData={portfolioData}
baseline_vw={portfolioData.baseline_vw}
equity={portfolioData.equity}
leaderboard={leaderboard}
/>
</Suspense>
</div>
</div>
</div>
</div>
</div>
{/* Resizer */}
<div className={`resizer ${isResizing ? 'resizing' : ''}`} onMouseDown={onMouseDown} />
{/* Right Panel: Agent Feed */}
<div className="right-panel" style={{ width: `${100 - leftWidth}%` }}>
<Suspense fallback={<ViewLoadingFallback label="加载消息流..." />}>
<AgentFeed ref={agentFeedRef} feed={feed} leaderboard={leaderboard} agentProfilesByAgent={agentProfilesByAgent} />
</Suspense>
</div>
</div>
</>
</div>
);
}

View File

@@ -5,10 +5,10 @@ import { formatNumber, formatFullNumber } from '../utils/formatters';
/**
* Helper function to get the start time of the most recent trading session
* Trading session: 22:30 - next day 05:00
-* @param {Date|null} virtualTime - Virtual time from server (for mock mode), or null to use real time
+* @param {Date|null} virtualTime - Virtual time from server, or null to use real time
*/
function getRecentTradingSessionStart(virtualTime = null) {
-// Use virtual time if provided (for mock mode), otherwise use real time
+// Use virtual time if provided, otherwise use real time
let now;
if (virtualTime) {
// Ensure virtualTime is a valid Date object

View File

@@ -47,7 +47,7 @@ function getRankMedal(rank) {
* Supports click and hover (1.5s) to show agent performance cards
* Supports replay mode - completely independent from live mode
*/
-export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJumpToMessage, onOpenLaunchConfig }) {
+export default function RoomView({ bubbles, bubbleFor, leaderboard, agentProfilesByAgent, feed, onJumpToMessage, onOpenLaunchConfig }) {
const canvasRef = useRef(null);
const containerRef = useRef(null);
@@ -162,11 +162,14 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
const getAgentData = (agentId) => {
const agent = AGENTS.find(a => a.id === agentId);
if (!agent) return null;
const profile = agentProfilesByAgent?.[agentId] || null;
// If no leaderboard data, return agent with default stats
if (!leaderboard || !Array.isArray(leaderboard)) {
return {
...agent,
modelName: profile?.model_name || null,
modelProvider: profile?.model_provider || null,
bull: { n: 0, win: 0, unknown: 0 },
bear: { n: 0, win: 0, unknown: 0 },
winRate: null,
@@ -181,6 +184,8 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
if (!leaderboardData) {
return {
...agent,
modelName: profile?.model_name || null,
modelProvider: profile?.model_provider || null,
bull: { n: 0, win: 0, unknown: 0 },
bear: { n: 0, win: 0, unknown: 0 },
winRate: null,
@@ -193,6 +198,8 @@ export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJump
return {
...agent,
...leaderboardData,
modelName: profile?.model_name || leaderboardData.modelName || null,
modelProvider: profile?.model_provider || leaderboardData.modelProvider || null,
avatar: agent.avatar // Always use the frontend's avatar URL
};
};

View File

@@ -0,0 +1,190 @@
import React, { useEffect, useMemo, useRef, useState } from 'react';
import { createPortal } from 'react-dom';
export default function RuntimeLogsModal({
isOpen,
isLoading,
logPayload,
error,
onClose,
onRefresh
}) {
const logRef = useRef(null);
const [autoRefresh, setAutoRefresh] = useState(true);
const [followTail, setFollowTail] = useState(true);
const refreshIntervalMs = useMemo(() => 2000, []);
useEffect(() => {
if (!isOpen || !autoRefresh) {
return undefined;
}
const timerId = window.setInterval(() => {
onRefresh();
}, refreshIntervalMs);
return () => window.clearInterval(timerId);
}, [autoRefresh, isOpen, onRefresh, refreshIntervalMs]);
useEffect(() => {
if (!isOpen || !followTail || !logRef.current) {
return;
}
logRef.current.scrollTop = logRef.current.scrollHeight;
}, [followTail, isOpen, logPayload?.content]);
if (!isOpen) {
return null;
}
return createPortal(
<div
onClick={onClose}
style={{
position: 'fixed',
inset: 0,
background: 'rgba(15, 23, 42, 0.32)',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
padding: 24,
zIndex: 10000
}}
>
<div
onClick={(event) => event.stopPropagation()}
style={{
width: 'min(980px, 94vw)',
maxHeight: '82vh',
overflow: 'hidden',
borderRadius: 16,
border: '1px solid #D9E0E7',
background: '#FFFFFF',
boxShadow: '0 24px 60px rgba(15, 23, 42, 0.18)',
display: 'grid',
gridTemplateRows: 'auto auto minmax(0, 1fr)'
}}
>
<div style={{
padding: '18px 20px 10px',
display: 'flex',
justifyContent: 'space-between',
alignItems: 'center',
gap: 12
}}>
<div style={{ display: 'grid', gap: 4 }}>
<div style={{ fontSize: 14, fontWeight: 800, color: '#111111' }}>运行日志</div>
<div style={{ fontSize: 11, color: '#6B7280' }}>
{logPayload?.run_id ? `任务 ${logPayload.run_id}` : '当前运行任务'}
</div>
</div>
<div style={{ display: 'flex', gap: 8 }}>
<button
type="button"
onClick={onRefresh}
style={{
padding: '7px 10px',
borderRadius: 8,
border: '1px solid #D0D7DE',
background: '#FFFFFF',
color: '#111111',
fontSize: 11,
fontWeight: 700,
cursor: 'pointer'
}}
>
刷新
</button>
<button
type="button"
onClick={onClose}
style={{
padding: '7px 10px',
borderRadius: 8,
border: '1px solid #111111',
background: '#111111',
color: '#FFFFFF',
fontSize: 11,
fontWeight: 700,
cursor: 'pointer'
}}
>
关闭
</button>
</div>
</div>
<div style={{
padding: '0 20px 12px',
display: 'flex',
justifyContent: 'space-between',
gap: 12,
alignItems: 'center',
flexWrap: 'wrap'
}}>
<div style={{ fontSize: 11, color: '#6B7280', fontFamily: '"Courier New", monospace' }}>
{logPayload?.log_path || '未找到日志文件'}
</div>
{isLoading ? (
<div style={{ fontSize: 11, color: '#2563EB', fontWeight: 700 }}>加载中...</div>
) : error ? (
<div style={{ fontSize: 11, color: '#B91C1C', fontWeight: 700 }}>{error}</div>
) : null}
</div>
<div style={{
padding: '0 20px 12px',
display: 'flex',
gap: 16,
alignItems: 'center',
flexWrap: 'wrap'
}}>
<label style={{ display: 'inline-flex', alignItems: 'center', gap: 6, fontSize: 11, color: '#374151', cursor: 'pointer' }}>
<input
type="checkbox"
checked={autoRefresh}
onChange={(event) => setAutoRefresh(event.target.checked)}
/>
实时刷新
</label>
<label style={{ display: 'inline-flex', alignItems: 'center', gap: 6, fontSize: 11, color: '#374151', cursor: 'pointer' }}>
<input
type="checkbox"
checked={followTail}
onChange={(event) => setFollowTail(event.target.checked)}
/>
自动滚底
</label>
</div>
<div style={{ padding: '0 20px 20px', minHeight: 0 }}>
<pre
ref={logRef}
style={{
margin: 0,
height: '100%',
minHeight: 320,
maxHeight: 'calc(82vh - 140px)',
overflow: 'auto',
borderRadius: 12,
border: '1px solid #D9E0E7',
background: '#0F172A',
color: '#E2E8F0',
padding: 16,
fontSize: 11,
lineHeight: 1.6,
fontFamily: '"SFMono-Regular", Menlo, Consolas, "Liberation Mono", monospace',
whiteSpace: 'pre-wrap',
wordBreak: 'break-word'
}}
>
{logPayload?.content || '暂无日志输出'}
</pre>
</div>
</div>
</div>,
document.body
);
}

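Editor's note: the log modal above wires a 自动滚底 (follow-tail) checkbox to a `logRef` on the `<pre>`. The scroll bookkeeping that such a toggle presumably drives can be sketched framework-free; the helper names below are hypothetical, not part of the diff:

```javascript
// Hypothetical helpers for a follow-tail log pane.
// shouldFollowTail: true when the viewport is within `threshold` px of the
// bottom, so appended log lines keep following the tail, while a manual
// scroll-up naturally disengages it.
function shouldFollowTail(scrollTop, clientHeight, scrollHeight, threshold = 24) {
  const distanceFromBottom = scrollHeight - (scrollTop + clientHeight);
  return distanceFromBottom <= threshold;
}

// tailScrollTop: the scrollTop that pins the pane to its last line after new
// content arrives (clamped at 0 for short content).
function tailScrollTop(clientHeight, scrollHeight) {
  return Math.max(0, scrollHeight - clientHeight);
}
```

In a React effect this would typically run on `logPayload?.content` changes, setting `logRef.current.scrollTop = tailScrollTop(...)` only while `followTail` is checked.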
View File

@@ -1,12 +1,24 @@
import React from 'react';
import { createPortal } from 'react-dom';
const formatHistorySummary = (run) => {
const updatedAt = run?.updated_at ? String(run.updated_at).replace("T", " ").slice(0, 16) : "未知时间";
const mode = run?.bootstrap?.mode ? String(run.bootstrap.mode).toUpperCase() : "LIVE";
const tickers = Array.isArray(run?.bootstrap?.tickers) ? run.bootstrap.tickers.length : 0;
const assetValue = Number(run?.total_asset_value ?? 0).toFixed(2);
const trades = Number(run?.total_trades ?? 0);
return `${run.run_id} · ${updatedAt} · ${mode} · ${tickers}标的 · ${trades}笔交易 · $${assetValue}`;
};
export default function RuntimeSettingsPanel({
showTrigger = true,
isOpen,
isConnected,
isSaving,
feedback,
launchMode,
restoreRunId,
runtimeHistoryRuns,
scheduleMode,
intervalMinutes,
triggerTime,
@@ -18,13 +30,14 @@ export default function RuntimeSettingsPanel({
pollInterval,
startDate,
endDate,
enableMock,
watchlistSymbols,
watchlistInputValue,
watchlistSuggestions,
onToggle,
onClose,
onScheduleModeChange,
onLaunchModeChange,
onRestoreRunIdChange,
onIntervalMinutesChange,
onTriggerTimeChange,
onMaxCommCyclesChange,
@@ -35,7 +48,6 @@ export default function RuntimeSettingsPanel({
onPollIntervalChange,
onStartDateChange,
onEndDateChange,
onEnableMockChange,
onWatchlistInputChange,
onWatchlistInputKeyDown,
onWatchlistAdd,
@@ -134,6 +146,75 @@ export default function RuntimeSettingsPanel({
</div>
</div>
<div style={{
border: '1px solid #E5EAF1',
borderRadius: 12,
background: '#FCFDFE',
padding: 14,
display: 'grid',
gap: 12
}}>
<div style={{ fontSize: 12, fontWeight: 800, color: '#111111' }}>启动形式</div>
<label style={{ display: 'grid', gap: 4 }}>
<span style={{ fontSize: '10px', color: '#4B5563', fontWeight: 700 }}>任务模式</span>
<select
value={launchMode}
onChange={(e) => onLaunchModeChange(e.target.value)}
style={{
padding: '9px 10px',
borderRadius: 8,
border: '1px solid #D0D7DE',
background: '#FFFFFF',
color: '#111111',
fontSize: '12px'
}}
>
<option value="fresh">重新启动</option>
<option value="restore">从历史任务恢复</option>
</select>
</label>
{launchMode === 'restore' && (
<>
<label style={{ display: 'grid', gap: 4 }}>
<span style={{ fontSize: '10px', color: '#4B5563', fontWeight: 700 }}>历史任务</span>
<select
value={restoreRunId}
onChange={(e) => onRestoreRunIdChange(e.target.value)}
style={{
padding: '9px 10px',
borderRadius: 8,
border: '1px solid #D0D7DE',
background: '#FFFFFF',
color: '#111111',
fontSize: '12px'
}}
>
<option value="">请选择历史任务</option>
{runtimeHistoryRuns.map((run) => (
<option key={run.run_id} value={run.run_id}>
{formatHistorySummary(run)}
</option>
))}
</select>
</label>
<div style={{
fontSize: '11px',
color: '#6B7280',
lineHeight: 1.6,
padding: '10px 12px',
borderRadius: 8,
background: '#FFFFFF',
border: '1px dashed #D0D7DE'
}}>
恢复启动会从所选历史任务复制运行状态、组合、交易记录和 Agent 工作区资产,并以新的任务 ID 继续运行
</div>
</>
)}
</div>
{launchMode === 'fresh' && (
<div style={{
border: '1px solid #E5EAF1',
borderRadius: 12,
@@ -273,7 +354,9 @@ export default function RuntimeSettingsPanel({
</button>
</div>
</div>
)}
{launchMode === 'fresh' && (
<div style={{
border: '1px solid #E5EAF1',
borderRadius: 12,
@@ -495,22 +578,8 @@ export default function RuntimeSettingsPanel({
}}
/>
</label>
<label style={{ display: 'flex', alignItems: 'center', gap: 10, marginTop: 2 }}>
<input
type="checkbox"
checked={enableMock}
onChange={(e) => onEnableMockChange(e.target.checked)}
style={{
width: 16,
height: 16,
accentColor: '#0D47A1',
cursor: 'pointer'
}}
/>
<span style={{ fontSize: '11px', color: '#111111', fontWeight: 700 }}>启用模拟数据 (Mock)</span>
</label>
</div>
)}
<div style={{
border: '1px solid #E5EAF1',

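Editor's note: `formatHistorySummary` above builds the labels for the restore-mode dropdown. Run standalone it behaves as below; the sample run object is made up for illustration:

```javascript
// Copied verbatim from RuntimeSettingsPanel for a standalone check.
const formatHistorySummary = (run) => {
  const updatedAt = run?.updated_at ? String(run.updated_at).replace("T", " ").slice(0, 16) : "未知时间";
  const mode = run?.bootstrap?.mode ? String(run.bootstrap.mode).toUpperCase() : "LIVE";
  const tickers = Array.isArray(run?.bootstrap?.tickers) ? run.bootstrap.tickers.length : 0;
  const assetValue = Number(run?.total_asset_value ?? 0).toFixed(2);
  const trades = Number(run?.total_trades ?? 0);
  return `${run.run_id} · ${updatedAt} · ${mode} · ${tickers}标的 · ${trades}笔交易 · $${assetValue}`;
};

// Hypothetical history entry, shaped like the runtimeHistoryRuns items.
const sample = {
  run_id: 'run_20260324',
  updated_at: '2026-03-24T15:27:35',
  bootstrap: { mode: 'live', tickers: ['AAPL', 'MSFT'] },
  total_asset_value: 10500.5,
  total_trades: 3,
};
console.log(formatHistorySummary(sample));
// → run_20260324 · 2026-03-24 15:27 · LIVE · 2标的 · 3笔交易 · $10500.50
```

Note the `slice(0, 16)` trims `updated_at` to minute precision, and missing fields degrade to `未知时间` / `LIVE` / `0` rather than throwing.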
View File

@@ -34,6 +34,18 @@ const EVENT_FILTER_OPTIONS = [
{ value: 'approval', label: '审批事件' }
];
const SR_ONLY_STYLE = {
position: 'absolute',
width: 1,
height: 1,
padding: 0,
margin: -1,
overflow: 'hidden',
clip: 'rect(0, 0, 0, 0)',
whiteSpace: 'nowrap',
border: 0
};
function metricCard(label, value, accent, helper = null) {
return (
<div className="stat-card">
@@ -722,6 +734,9 @@ export default function RuntimeView() {
{sectionTitle(
'近期事件',
<select
id="runtime-event-filter"
name="runtime_event_filter"
aria-label="筛选近期事件"
value={eventFilter}
onChange={(event) => setEventFilter(event.target.value)}
style={{
@@ -739,6 +754,9 @@ export default function RuntimeView() {
))}
</select>
)}
<label htmlFor="runtime-event-filter" style={SR_ONLY_STYLE}>
筛选近期事件
</label>
<div style={{
display: 'grid',
gap: 8,

View File

@@ -8,12 +8,36 @@ import { formatNumber, formatDateTime } from '../utils/formatters';
* Left: Performance Overview (35%) | Right: Holdings + Trades (65%)
* No scrolling - content fits within viewport with pagination
*/
export default function StatisticsView({ trades, holdings, stats, baseline_vw, equity, leaderboard }) {
export default function StatisticsView({ trades, holdings, stats, baseline_vw, equity, leaderboard, portfolioData }) {
const [holdingsPage, setHoldingsPage] = useState(1);
const [tradesPage, setTradesPage] = useState(1);
const holdingsPerPage = 5;
const tradesPerPage = 8;
const effectiveStats = React.useMemo(() => {
const base = stats && typeof stats === 'object' ? stats : {};
const netValue = Number(portfolioData?.netValue ?? 0);
const pnl = Number(portfolioData?.pnl ?? 0);
const hasPortfolioValue = Number.isFinite(netValue) && netValue > 0;
const hasMeaningfulStats = Number(base?.totalAssetValue ?? 0) > 0;
if (hasMeaningfulStats || !hasPortfolioValue) {
return base;
}
const cashHolding = Array.isArray(holdings)
? holdings.find((item) => String(item?.ticker || '').toUpperCase() === 'CASH')
: null;
return {
...base,
totalAssetValue: netValue,
totalReturn: pnl,
cashPosition: Number(cashHolding?.marketValue ?? cashHolding?.currentPrice ?? 0),
totalTrades: Array.isArray(trades) ? trades.length : 0,
};
}, [holdings, portfolioData, stats, trades]);
// Calculate pagination for holdings
const totalHoldingsPages = Math.ceil(holdings.length / holdingsPerPage);
const holdingsStartIndex = (holdingsPage - 1) * holdingsPerPage;
@@ -28,12 +52,12 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
// Calculate excess return (Evatraders return - benchmark value-weighted return)
const calculateExcessReturn = () => {
if (!stats || !baseline_vw || baseline_vw.length === 0) {
if (!effectiveStats || !baseline_vw || baseline_vw.length === 0) {
return null;
}
// Get Evatraders return from stats
const evatradersReturn = stats.totalReturn || 0; // Already in percentage
const evatradersReturn = effectiveStats.totalReturn || 0; // Already in percentage
// Calculate benchmark return from baseline_vw
// baseline_vw format: [{t: timestamp, v: value}, ...] or [value, ...]
@@ -130,7 +154,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
borderRight: '2px solid #e0e0e0',
overflow: 'hidden'
}}>
{stats ? (
{effectiveStats ? (
<div style={{
padding: '24px',
display: 'flex',
@@ -179,7 +203,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
fontFamily: '"Courier New", monospace',
lineHeight: 1
}}>
${formatNumber(stats.totalAssetValue || 0)}
${formatNumber(effectiveStats.totalAssetValue || 0)}
</div>
</div>
@@ -272,10 +296,10 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
<div style={{
fontSize: 28,
fontWeight: 700,
color: (stats.totalReturn || 0) >= 0 ? '#00C853' : '#FF1744',
color: (effectiveStats.totalReturn || 0) >= 0 ? '#00C853' : '#FF1744',
fontFamily: '"Courier New", monospace'
}}>
{(stats.totalReturn || 0) >= 0 ? '+' : ''}{(stats.totalReturn || 0).toFixed(2)}%
{(effectiveStats.totalReturn || 0) >= 0 ? '+' : ''}{(effectiveStats.totalReturn || 0).toFixed(2)}%
</div>
</div>
</div>
@@ -304,7 +328,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
color: '#000000',
fontFamily: '"Courier New", monospace'
}}>
${formatNumber(stats.cashPosition || 0)}
${formatNumber(effectiveStats.cashPosition || 0)}
</div>
</div>
@@ -330,13 +354,13 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
color: '#000000',
fontFamily: '"Courier New", monospace'
}}>
{stats.totalTrades || 0}
{effectiveStats.totalTrades || 0}
</div>
</div>
</div>
{/* Ticker Weights - Compact */}
{stats.tickerWeights && Object.keys(stats.tickerWeights).length > 0 && (
{effectiveStats?.tickerWeights && Object.keys(effectiveStats.tickerWeights).length > 0 && (
<div style={{
marginTop: 'auto',
paddingTop: 20,
@@ -358,7 +382,7 @@ export default function StatisticsView({ trades, holdings, stats, baseline_vw, e
gap: 8,
maxHeight: 120
}}>
{Object.entries(stats.tickerWeights).map(([ticker, weight]) => {
{Object.entries(effectiveStats.tickerWeights).map(([ticker, weight]) => {
const weightValue = Number(weight);
const isNegative = weightValue < 0;
const displayWeight = (weightValue * 100).toFixed(1);

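Editor's note: the `effectiveStats` memo above is a pure computation over its inputs, so it can be restated as a plain function and exercised directly. This sketch reproduces its branching with the same field names; nothing beyond the diff is assumed:

```javascript
// Restatement of StatisticsView's effectiveStats memo: keep `stats` when it
// already carries a meaningful asset value (or there is no portfolio value to
// fall back to); otherwise synthesize stats from portfolio-level numbers.
function computeEffectiveStats(stats, portfolioData, holdings, trades) {
  const base = stats && typeof stats === 'object' ? stats : {};
  const netValue = Number(portfolioData?.netValue ?? 0);
  const pnl = Number(portfolioData?.pnl ?? 0);
  const hasPortfolioValue = Number.isFinite(netValue) && netValue > 0;
  const hasMeaningfulStats = Number(base?.totalAssetValue ?? 0) > 0;
  if (hasMeaningfulStats || !hasPortfolioValue) {
    return base;
  }
  // Cash position comes from the pseudo-holding whose ticker is 'CASH'.
  const cashHolding = Array.isArray(holdings)
    ? holdings.find((item) => String(item?.ticker || '').toUpperCase() === 'CASH')
    : null;
  return {
    ...base,
    totalAssetValue: netValue,
    totalReturn: pnl,
    cashPosition: Number(cashHolding?.marketValue ?? cashHolding?.currentPrice ?? 0),
    totalTrades: Array.isArray(trades) ? trades.length : 0,
  };
}
```

The guard order matters: real `stats` always wins, and the fallback only fires when the portfolio snapshot has a positive net value, so a half-initialized dashboard never shows fabricated zeros as data.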
View File

@@ -33,6 +33,9 @@ export default function StockExplainView({
insiderTradesSnapshot,
technicalIndicatorsSnapshot,
onRequestRangeExplain,
onRequestHistory,
onRequestExplainEvents,
onRequestNews,
onRequestNewsForDate,
onRequestStory,
onRequestInsiderTrades,
@@ -77,6 +80,7 @@ export default function StockExplainView({
visibleNews,
newsCategories,
visibleNewsByCategory,
selectedNewsFreshness,
selectedRangeWindow,
selectedRangeExplain,
latestSignal,
@@ -141,11 +145,37 @@ export default function StockExplainView({
setActiveNewsSentiment('all');
}, [selectedSymbol, selectedEventDate]);
useEffect(() => {
if (!selectedSymbol) {
return;
}
if (onRequestHistory && (!Array.isArray(ohlcHistoryByTicker?.[selectedSymbol]) || ohlcHistoryByTicker[selectedSymbol].length === 0)) {
onRequestHistory(selectedSymbol);
}
if (onRequestExplainEvents && !explainEventsSnapshot) {
onRequestExplainEvents(selectedSymbol);
}
if (onRequestNews && (!Array.isArray(newsSnapshot?.items) || newsSnapshot.items.length === 0)) {
onRequestNews(selectedSymbol);
}
}, [
explainEventsSnapshot,
newsSnapshot,
ohlcHistoryByTicker,
onRequestExplainEvents,
onRequestHistory,
onRequestNews,
selectedSymbol,
]);
useEffect(() => {
if (!selectedSymbol || !selectedEventDate || !onRequestNewsForDate) {
return;
}
if (Array.isArray(newsSnapshot?.byDate?.[selectedEventDate]) && newsSnapshot.byDate[selectedEventDate].length > 0) {
if (Object.prototype.hasOwnProperty.call(newsSnapshot?.byDate || {}, selectedEventDate)) {
return;
}
onRequestNewsForDate(selectedSymbol, selectedEventDate);
@@ -155,21 +185,21 @@ export default function StockExplainView({
if (!selectedSymbol || !onRequestStory || !currentDate) {
return;
}
if (selectedStory?.story) {
if (Object.prototype.hasOwnProperty.call(newsSnapshot?.storyCache || {}, currentDate)) {
return;
}
onRequestStory(selectedSymbol, currentDate);
}, [currentDate, onRequestStory, selectedStory, selectedSymbol]);
}, [currentDate, newsSnapshot, onRequestStory, selectedStory, selectedSymbol]);
useEffect(() => {
if (!selectedSymbol || !selectedEventDate || !onRequestSimilarDays) {
return;
}
if (selectedSimilarDays?.items?.length) {
if (Object.prototype.hasOwnProperty.call(newsSnapshot?.similarDaysCache || {}, selectedEventDate)) {
return;
}
onRequestSimilarDays(selectedSymbol, selectedEventDate);
}, [onRequestSimilarDays, selectedEventDate, selectedSimilarDays, selectedSymbol]);
}, [newsSnapshot, onRequestSimilarDays, selectedEventDate, selectedSimilarDays, selectedSymbol]);
useEffect(() => {
if (!selectedSymbol || !onRequestTechnicalIndicators) {
@@ -337,6 +367,7 @@ export default function StockExplainView({
newsSnapshot={newsSnapshot}
visibleNewsByCategory={visibleNewsByCategory}
visibleNews={visibleNews}
selectedNewsFreshness={selectedNewsFreshness}
activeNewsCategory={activeNewsCategory}
onSelectNewsCategory={setActiveNewsCategory}
activeNewsSentiment={activeNewsSentiment}

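Editor's note: the effects above switch the "already loaded?" checks from truthiness (`selectedStory?.story`, `selectedSimilarDays?.items?.length`) to `Object.prototype.hasOwnProperty` on a cache object, so a cached empty result (`[]`, `null`) also suppresses a refetch. The gate generalizes to a small helper; the name `requestOnce` is hypothetical:

```javascript
// Fire `fetcher(key)` only if `cache` has no entry for `key` yet.
// hasOwnProperty (rather than a truthiness check) means a stored empty
// result still counts as "already fetched", avoiding request loops for
// dates that legitimately have no news / story / similar days.
function requestOnce(cache, key, fetcher) {
  if (Object.prototype.hasOwnProperty.call(cache || {}, key)) {
    return false; // already requested, even if the stored value is empty
  }
  fetcher(key);
  return true;
}
```

In the component this corresponds to keying `newsSnapshot.byDate`, `storyCache`, and `similarDaysCache` by date and letting the responder write an entry (possibly empty) for every answered request.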
View File

@@ -38,6 +38,18 @@ export default function TraderView({
onWorkspaceFileSave,
onUploadExternalSkill
}) {
const srOnlyStyle = {
position: 'absolute',
width: 1,
height: 1,
padding: 0,
margin: -1,
overflow: 'hidden',
clip: 'rect(0, 0, 0, 0)',
whiteSpace: 'nowrap',
border: 0
};
const [expandedSkillKey, setExpandedSkillKey] = useState(null);
const [newLocalSkillName, setNewLocalSkillName] = useState('');
const [externalSkillFile, setExternalSkillFile] = useState(null);
@@ -460,6 +472,9 @@ export default function TraderView({
本地技能 SKILL.md
</div>
<textarea
id={`local-skill-${selectedAgentId}-${skill.skill_name}`}
name={`local_skill_${selectedAgentId}_${skill.skill_name}`}
aria-label={`${skill.skill_name} 本地技能内容`}
value={skillDraft}
onChange={(e) => onLocalSkillDraftChange(skill.skill_name, e.target.value)}
style={{
@@ -557,6 +572,9 @@ export default function TraderView({
</div>
<textarea
id={`workspace-editor-${selectedAgentId}-${selectedWorkspaceFile || 'file'}`}
name={`workspace_editor_${selectedAgentId}_${selectedWorkspaceFile || 'file'}`}
aria-label={`编辑 ${selectedWorkspaceFile || '工作区文件'} 内容`}
value={workspaceDraftContent}
onChange={(e) => onWorkspaceDraftChange(e.target.value)}
placeholder={isWorkspaceFileLoading ? '加载中...' : '输入 markdown 内容'}
@@ -687,7 +705,13 @@ export default function TraderView({
}}>
<div style={{ fontSize: 12, fontWeight: 800, color: '#111111' }}>创建本地技能</div>
<div style={{ display: 'flex', gap: 8, alignItems: 'center' }}>
<label htmlFor="new-local-skill-name" style={srOnlyStyle}>
输入本地技能名称
</label>
<input
id="new-local-skill-name"
name="new_local_skill_name"
aria-label="输入本地技能名称"
value={newLocalSkillName}
onChange={(e) => setNewLocalSkillName(e.target.value)}
placeholder="输入技能名,例如 event_playbook"
@@ -741,7 +765,13 @@ export default function TraderView({
支持上传 .zip,包内需包含一个技能目录及 SKILL.md
</div>
<div style={{ display: 'flex', gap: 8, alignItems: 'center', flexWrap: 'wrap' }}>
<label htmlFor="external-skill-zip" style={srOnlyStyle}>
上传外部技能 zip
</label>
<input
id="external-skill-zip"
name="external_skill_zip"
aria-label="上传外部技能 zip 包"
type="file"
accept=".zip,application/zip"
onChange={async (e) => {

View File

@@ -19,6 +19,18 @@ export default function WatchlistPanel({
onSuggestionClick,
onSave
}) {
const srOnlyStyle = {
position: 'absolute',
width: 1,
height: 1,
padding: 0,
margin: -1,
overflow: 'hidden',
clip: 'rect(0, 0, 0, 0)',
whiteSpace: 'nowrap',
border: 0
};
return (
<div style={{ display: 'flex', alignItems: 'center', gap: 8, minWidth: 0, position: 'relative', marginLeft: -6 }}>
<button
@@ -117,7 +129,13 @@ export default function WatchlistPanel({
</div>
<div style={{ display: 'flex', gap: 8 }}>
<label htmlFor="watchlist-symbol-input" style={srOnlyStyle}>
输入股票代码
</label>
<input
id="watchlist-symbol-input"
name="watchlist_symbol"
aria-label="输入股票代码"
value={inputValue}
onChange={(e) => onInputChange(e.target.value)}
onKeyDown={onInputKeyDown}

View File

@@ -0,0 +1,107 @@
import React from 'react';
import { formatDateTime, formatNumber } from '../../utils/formatters';
export default function ExplainInsiderSection({
insiderTrades,
selectedSymbol,
isOpen,
onToggle,
onRequest,
}) {
const handleRefresh = () => {
if (onRequest) {
onRequest(selectedSymbol);
}
};
return (
<div className="section">
<div className="section-header">
<h2 className="section-title">内部人交易</h2>
<div style={{ display: 'flex', alignItems: 'center', gap: 12, flexWrap: 'wrap' }}>
<div style={{ fontSize: 11, color: '#666666' }}>
{insiderTrades.length} 笔内部人交易记录
</div>
<button
onClick={handleRefresh}
style={{
border: '1px solid #111111',
background: '#ffffff',
color: '#111111',
padding: '5px 8px',
fontFamily: 'inherit',
fontSize: 10,
cursor: 'pointer'
}}
>
刷新
</button>
<button
onClick={onToggle}
style={{
border: '1px solid #111111',
background: isOpen ? '#111111' : '#ffffff',
color: isOpen ? '#ffffff' : '#111111',
padding: '7px 10px',
fontFamily: 'inherit',
fontSize: 11,
fontWeight: 700,
cursor: 'pointer'
}}
>
{isOpen ? '收起' : `展开 ${insiderTrades.length}`}
</button>
</div>
</div>
{!isOpen ? (
<div className="empty-state">点击展开查看内部人交易详情</div>
) : insiderTrades.length === 0 ? (
<div className="empty-state">暂无内部人交易数据</div>
) : (
<div className="table-wrapper">
<table className="data-table">
<thead>
<tr>
<th>交易日期</th>
<th>内部人</th>
<th>职位</th>
<th>方向</th>
<th>股份数</th>
<th>价格</th>
<th>持仓变化</th>
</tr>
</thead>
<tbody>
{insiderTrades.slice(0, 20).map((trade, index) => {
const isBuy = trade.is_buy;
const holdingChange = trade.holding_change;
return (
<tr key={trade.transaction_date + '-' + trade.name + '-' + index}>
<td>{trade.transaction_date || '-'}</td>
<td>{trade.name || '-'}</td>
<td>{trade.title || '-'}</td>
<td style={{
fontWeight: 700,
color: isBuy === true ? '#00C853' : isBuy === false ? '#FF1744' : '#666666'
}}>
{isBuy === true ? '买入' : isBuy === false ? '卖出' : '-'}
</td>
<td>{trade.transaction_shares != null ? formatNumber(trade.transaction_shares) : '-'}</td>
<td>${trade.transaction_price_per_share != null ? Number(trade.transaction_price_per_share).toFixed(2) : '-'}</td>
<td style={{
color: holdingChange != null ? (holdingChange > 0 ? '#00C853' : '#FF1744') : '#666666',
fontWeight: holdingChange != null ? 700 : 400
}}>
{holdingChange != null ? (holdingChange > 0 ? '+' : '') + formatNumber(holdingChange) : '-'}
</td>
</tr>
);
})}
</tbody>
</table>
</div>
)}
</div>
);
}

View File

@@ -1,6 +1,12 @@
import React from 'react';
import { formatDateTime } from '../../utils/formatters';
function renderFreshness(freshness) {
if (!freshness || typeof freshness !== 'object') return null;
const lastFetch = freshness.last_news_fetch || '-';
return `新闻更新到 ${lastFetch}${freshness.refreshed ? ' · 本次已刷新' : ''}`;
}
function categoryLabel(value) {
const normalized = String(value || '').trim().toLowerCase();
const labels = {
@@ -47,6 +53,7 @@ export default function ExplainNewsSection({
newsSnapshot,
visibleNewsByCategory,
visibleNews,
selectedNewsFreshness,
activeNewsCategory,
onSelectNewsCategory,
activeNewsSentiment,
@@ -64,6 +71,11 @@ export default function ExplainNewsSection({
<div style={{ fontSize: 11, color: '#666666' }}>
{newsSnapshot?.source ? `最近 ${visibleNewsByCategory.length} 条 · ${newsSnapshot.source}` : `最近 ${visibleNewsByCategory.length} 条真实新闻`}
</div>
{renderFreshness(selectedNewsFreshness) ? (
<div style={{ fontSize: 11, color: '#666666' }}>
{renderFreshness(selectedNewsFreshness)}
</div>
) : null}
<button
onClick={onToggle}
style={{

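Editor's note: `renderFreshness` above turns the backend freshness payload into the one-line caption shown next to the news header. Copied verbatim, it can be checked standalone:

```javascript
// Copied from ExplainNewsSection (the same helper appears in
// ExplainRangeSection): format a freshness object into a caption, or
// return null so the caller can skip rendering entirely.
function renderFreshness(freshness) {
  if (!freshness || typeof freshness !== 'object') return null;
  const lastFetch = freshness.last_news_fetch || '-';
  return `新闻更新到 ${lastFetch}${freshness.refreshed ? ' · 本次已刷新' : ''}`;
}

console.log(renderFreshness({ last_news_fetch: '2026-03-26 10:50', refreshed: true }));
// → 新闻更新到 2026-03-26 10:50 · 本次已刷新
```

Returning `null` (rather than an empty string) lets the JSX `renderFreshness(...) ? <div>…</div> : null` pattern drop the whole caption row when no freshness data arrived.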
View File

@@ -11,6 +11,37 @@ export default function ExplainPriceSection({
isOpen,
onToggle,
}) {
const timeTicks = (() => {
const candles = Array.isArray(chartModel?.candles) ? chartModel.candles : [];
if (!candles.length) {
return [];
}
const targetCount = Math.min(4, candles.length);
const step = Math.max(1, Math.floor((candles.length - 1) / Math.max(targetCount - 1, 1)));
const ticks = [];
for (let index = 0; index < candles.length; index += step) {
const candle = candles[index];
const rawLabel = candle.startLabel || candle.time || candle.date || '';
ticks.push({
x: candle.centerX,
label: String(rawLabel).slice(5, 16).replace('T', ' '),
});
}
const lastCandle = candles[candles.length - 1];
const lastLabel = String(lastCandle.endLabel || lastCandle.time || lastCandle.date || '').slice(5, 16).replace('T', ' ');
if (ticks.length === 0 || ticks[ticks.length - 1]?.x !== lastCandle.centerX) {
ticks.push({
x: lastCandle.centerX,
label: lastLabel,
});
}
return ticks;
})();
return (
<div className="section">
<div className="section-header">
@@ -66,12 +97,35 @@ export default function ExplainPriceSection({
strokeWidth="1"
/>
{timeTicks.map((tick) => (
<g key={`${tick.x}-${tick.label}`}>
<line
x1={tick.x}
y1={chartModel.height - chartModel.padding}
x2={tick.x}
y2={chartModel.height - chartModel.padding + 4}
stroke="#666666"
strokeWidth="1"
/>
<text
x={tick.x}
y={chartModel.height - chartModel.padding + 16}
fontSize="10"
fill="#666666"
textAnchor="middle"
>
{tick.label}
</text>
</g>
))}
{chartModel.candles.length > 1 ? chartModel.candles.map((candle) => {
const rising = candle.close >= candle.open;
const stroke = rising ? '#00C853' : '#FF1744';
const fill = rising ? 'rgba(0, 200, 83, 0.16)' : 'rgba(255, 23, 68, 0.16)';
return (
<g key={candle.id}>
<title>{`${candle.startLabel || candle.time || candle.date || ''} → ${candle.endLabel || candle.time || candle.date || ''}`}</title>
<line
x1={candle.centerX}
y1={candle.highY}
@@ -123,7 +177,7 @@ export default function ExplainPriceSection({
stroke={marker.isSelected ? '#111111' : '#ffffff'}
strokeWidth={marker.isSelected ? '2.5' : '2'}
/>
<title>{`${marker.title} · ${marker.dateKey || ''}${marker.count ? ` · ${marker.count} 条新闻` : ''}`}</title>
<title>{`${marker.title} · ${marker.timestamp || marker.dateKey || ''}${marker.count ? ` · ${marker.count} 条新闻` : ''}`}</title>
</g>
);
})}

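Editor's note: the `timeTicks` IIFE above picks up to four evenly spaced candles for the x-axis, then guarantees the final candle gets a tick. Lifted into a plain function (same logic, no JSX), it can be tested directly:

```javascript
// ExplainPriceSection's tick selection as a standalone function: choose at
// most 4 candles at a uniform stride, label each with a 'MM-DD HH:mm' slice
// of its timestamp, and append the last candle if the stride skipped it.
function computeTimeTicks(candles) {
  if (!Array.isArray(candles) || !candles.length) return [];
  const targetCount = Math.min(4, candles.length);
  const step = Math.max(1, Math.floor((candles.length - 1) / Math.max(targetCount - 1, 1)));
  const ticks = [];
  for (let index = 0; index < candles.length; index += step) {
    const candle = candles[index];
    const rawLabel = candle.startLabel || candle.time || candle.date || '';
    ticks.push({ x: candle.centerX, label: String(rawLabel).slice(5, 16).replace('T', ' ') });
  }
  const lastCandle = candles[candles.length - 1];
  if (ticks.length === 0 || ticks[ticks.length - 1].x !== lastCandle.centerX) {
    const lastLabel = String(lastCandle.endLabel || lastCandle.time || lastCandle.date || '').slice(5, 16).replace('T', ' ');
    ticks.push({ x: lastCandle.centerX, label: lastLabel });
  }
  return ticks;
}
```

With 10 candles the stride works out to `floor(9/3) = 3`, so indices 0, 3, 6, 9 are chosen and the final candle needs no extra tick; with a stride that misses the end, the trailing push keeps the axis anchored at the last bar.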
View File

@@ -1,6 +1,12 @@
import React from 'react';
import { formatTickerPrice } from '../../utils/formatters';
function renderFreshness(freshness) {
if (!freshness || typeof freshness !== 'object') return null;
const lastFetch = freshness.last_news_fetch || '-';
return `新闻更新到 ${lastFetch}${freshness.refreshed ? ' · 本次已刷新' : ''}`;
}
function renderSentimentLabel(value) {
const normalized = String(value || '').trim().toLowerCase();
if (normalized === 'positive') return '利多';
@@ -94,6 +100,11 @@ export default function ExplainRangeSection({
: `分析来源 · ${renderAnalysisSourceLabel(selectedRangeExplain.analysis.analysis_source)}`}
</div>
) : null}
{renderFreshness(selectedRangeExplain?.freshness) ? (
<div style={{ fontSize: 11, color: '#666666' }}>
{renderFreshness(selectedRangeExplain?.freshness)}
</div>
) : null}
<button
onClick={onToggle}
style={{
