diff --git a/README_zh.md b/README_zh.md index 9fecf15..479b062 100644 --- a/README_zh.md +++ b/README_zh.md @@ -117,6 +117,54 @@ evotraders frontend # 默认连接 8765 端口, 你可以修改 . 访问 `http://localhost:5173/` 查看交易大厅,选择日期并点击 Run/Replay 观察决策过程。 +### 迁移期服务边界说明 + +当前仓库正处于从模块化单体向独立服务迁移的阶段,当前默认开发路径已经切到独立 app surface: + +- `backend.apps.agent_service` +- `backend.apps.runtime_service` +- `backend.apps.trading_service` +- `backend.apps.news_service` + +当前本地开发默认推荐直接运行拆分后的服务: + +```bash +./start-dev.sh split + +# 或分别手动启动 +python -m uvicorn backend.apps.agent_service:app --port 8000 --reload +python -m uvicorn backend.apps.runtime_service:app --port 8003 --reload +python -m uvicorn backend.apps.trading_service:app --port 8001 --reload +python -m uvicorn backend.apps.news_service:app --port 8002 --reload +``` + +迁移期关键环境变量: + +```bash +# 后端 Gateway 优先走独立服务读取 +NEWS_SERVICE_URL=http://localhost:8002 +TRADING_SERVICE_URL=http://localhost:8001 + +# 前端浏览器直连控制面 / 运行时面 +VITE_CONTROL_API_BASE_URL=http://localhost:8000/api +VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime + +# 前端浏览器优先直连独立服务 +VITE_NEWS_SERVICE_URL=http://localhost:8002 +VITE_TRADING_SERVICE_URL=http://localhost:8001 +``` + +目前前端已支持直连 `news-service` 的 explain 只读路径包括: + +- runtime panel / gateway port 查询已可独立指向 `runtime-service` +- story +- similar days +- range explain +- news for date +- news categories + +如果没有配置这些变量,系统会继续走当前保留的本地回退逻辑。 + --- ## 系统架构 diff --git a/backend/app.py b/backend/app.py deleted file mode 100644 index 940941d..0000000 --- a/backend/app.py +++ /dev/null @@ -1,115 +0,0 @@ -# -*- coding: utf-8 -*- -""" -FastAPI Application - REST API for EvoTraders - -Provides HTTP endpoints for: -- Agent management -- Workspace management -- Tool guard operations -- Health checks -""" -from contextlib import asynccontextmanager -from pathlib import Path -from typing import AsyncGenerator - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware - -from backend.api import agents_router, workspaces_router, guard_router, runtime_router -from backend.agents import AgentFactory, WorkspaceManager, get_registry - - -# Global instances (initialized on startup) -agent_factory: AgentFactory | None = None -workspace_manager: WorkspaceManager | None = None - - -@asynccontextmanager -async def lifespan(app: FastAPI) -> AsyncGenerator: - """ - Application lifespan manager. - - Initializes global services on startup and cleans up on shutdown. - """ - global agent_factory, workspace_manager - - # Startup: Initialize services - project_root = Path(__file__).parent.parent - - # Initialize workspace manager - workspace_manager = WorkspaceManager(project_root=project_root) - - # Initialize agent factory - agent_factory = AgentFactory(project_root=project_root) - - # Ensure workspaces root exists - agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True) - - # Get or create global registry - registry = get_registry() - - print(f"✓ EvoTraders API started") - print(f" - Workspaces root: {agent_factory.workspaces_root}") - print(f" - Registered agents: {registry.get_agent_count()}") - - yield - - # Shutdown: Cleanup - print("✓ EvoTraders API shutting down") - - -# Create FastAPI application -app = FastAPI( - title="EvoTraders API", - description="REST API for the EvoTraders multi-agent trading system", - version="0.1.0", - lifespan=lifespan, -) - -# CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # Configure appropriately for production - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - - -# Health check endpoint -@app.get("/health") -async def health_check(): - """Health check endpoint.""" - registry = get_registry() - return { - "status": "healthy", - "version": "0.1.0", - "agents_registered": registry.get_agent_count(), - "workspaces_available": len(workspace_manager.list_workspaces()) if workspace_manager else 0, - } - - -# API status endpoint -@app.get("/api/status") -async def api_status(): - """Get API status and system information.""" - registry = get_registry() - stats = registry.get_stats() - - return { - "status": "operational", - "registry": stats, - } - - -# Include routers -app.include_router(workspaces_router) -app.include_router(agents_router) -app.include_router(guard_router) -app.include_router(runtime_router) - - -# Main entry point for running with uvicorn -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/backend/apps/__init__.py b/backend/apps/__init__.py new file mode 100644 index 0000000..7e10106 --- /dev/null +++ b/backend/apps/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +"""Application surfaces for progressive service extraction.""" + +from .agent_service import app as agent_app +from .agent_service import create_app as create_agent_app +from .news_service import app as news_app +from .news_service import create_app as create_news_app +from .runtime_service import app as runtime_app +from .runtime_service import create_app as create_runtime_app +from .trading_service import app as trading_app +from .trading_service import create_app as create_trading_app + +app = agent_app +create_app = create_agent_app + +__all__ = [ + "app", + "create_app", + "agent_app", + "create_agent_app", + "news_app", + "create_news_app", + "runtime_app", + "create_runtime_app", + "trading_app", + "create_trading_app", +] diff --git a/backend/apps/agent_service.py b/backend/apps/agent_service.py new file mode 100644 index 0000000..e9812f5 --- /dev/null +++ b/backend/apps/agent_service.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +"""Agent control-plane FastAPI surface.""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from pathlib import Path +from typing import AsyncGenerator + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from backend.api import agents_router, guard_router, workspaces_router +from backend.agents import AgentFactory, WorkspaceManager, get_registry + +# Global instances (initialized on startup) +agent_factory: AgentFactory | None = None +workspace_manager: WorkspaceManager | None = None + + +def create_app(project_root: Path | None = None) -> FastAPI: + """Create the agent control-plane app.""" + resolved_project_root = project_root or Path(__file__).resolve().parents[2] + + @asynccontextmanager + async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: + """Initialize workspace and registry state for the control plane.""" + global agent_factory, workspace_manager + + workspace_manager = WorkspaceManager(project_root=resolved_project_root) + agent_factory = AgentFactory(project_root=resolved_project_root) + agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True) + + registry = get_registry() + print("✓ EvoTraders API started") + print(f" - Workspaces root: {agent_factory.workspaces_root}") + print(f" - Registered agents: {registry.get_agent_count()}") + + yield + + print("✓ EvoTraders API shutting down") + + app = FastAPI( + title="EvoTraders Agent Service", + description="REST API for the EvoTraders multi-agent control plane", + version="0.1.0", + lifespan=lifespan, + ) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/health") + async def health_check() -> dict[str, object]: + """Health check endpoint.""" + registry = get_registry() + return { + "status": "healthy", + "version": "0.1.0", + "agents_registered": registry.get_agent_count(), + "workspaces_available": ( + len(workspace_manager.list_workspaces()) + if workspace_manager + else 0 + ), + } + + @app.get("/api/status") + async def api_status() -> dict[str, object]: + """Get API status and registry information.""" + registry = get_registry() + return { + "status": "operational", + "registry": registry.get_stats(), + } + + app.include_router(workspaces_router) + app.include_router(agents_router) + app.include_router(guard_router) + return app + + +app = create_app() + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/backend/apps/news_service.py b/backend/apps/news_service.py new file mode 100644 index 0000000..88b6204 --- /dev/null +++ b/backend/apps/news_service.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +"""News and explain FastAPI surface.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import Depends, FastAPI, Query +from fastapi.middleware.cors import CORSMiddleware + +from backend.data.market_store import MarketStore +from backend.domains import news as news_domain + + +def get_market_store() -> MarketStore: + """Create a market store dependency.""" + return MarketStore() + + +def create_app() -> FastAPI: + """Create the news/explain service app.""" + app = FastAPI( + title="EvoTraders News Service", + description="Read-only news enrichment and explain service surface extracted from the monolith", + version="0.1.0", + ) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/health") + async def health_check() -> dict[str, str]: + return {"status": "healthy", "service": "news-service"} + + @app.get("/api/enriched-news") + async def api_get_enriched_news( + ticker: str = Query(..., min_length=1), + start_date: str | None = Query(None), + end_date: str | None = Query(None), + limit: int = Query(100, ge=1, le=1000), + store: MarketStore = Depends(get_market_store), + ) -> dict[str, Any]: + return news_domain.get_enriched_news( + store, + ticker=ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + @app.get("/api/news-for-date") + async def api_get_news_for_date( + ticker: str = Query(..., min_length=1), + date: str = Query(...), + limit: int = Query(20, ge=1, le=100), + store: MarketStore = Depends(get_market_store), + ) -> dict[str, Any]: + return news_domain.get_news_for_date( + store, + ticker=ticker, + date=date, + limit=limit, + ) + + @app.get("/api/news-timeline") + async def api_get_news_timeline( + ticker: str = Query(..., min_length=1), + start_date: str = Query(...), + end_date: str = Query(...), + store: MarketStore = Depends(get_market_store), + ) -> dict[str, Any]: + return news_domain.get_news_timeline( + store, + ticker=ticker, + start_date=start_date, + end_date=end_date, + ) + + @app.get("/api/categories") + async def api_get_categories( + ticker: str = Query(..., min_length=1), + start_date: str | None = Query(None), + end_date: str | None = Query(None), + limit: int = Query(200, ge=1, le=1000), + store: MarketStore = Depends(get_market_store), + ) -> dict[str, Any]: + return news_domain.get_news_categories( + store, + ticker=ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + + @app.get("/api/similar-days") + async def api_get_similar_days( + ticker: str = Query(..., min_length=1), + date: str = Query(...), + n_similar: int = Query(5, ge=1, le=20), + store: MarketStore = Depends(get_market_store), + ) -> dict[str, Any]: + return news_domain.get_similar_days_payload( + store, + ticker=ticker, + date=date, + n_similar=n_similar, + ) + + @app.get("/api/stories/{ticker}") + async def api_get_story( + ticker: str, + as_of_date: str = Query(...), + store: MarketStore = Depends(get_market_store), + ) -> dict[str, Any]: + return news_domain.get_story_payload( + store, + ticker=ticker, + as_of_date=as_of_date, + ) + + @app.get("/api/range-explain") + async def api_get_range_explain( + ticker: str = Query(..., min_length=1), + start_date: str = Query(...), + end_date: str = Query(...), + article_ids: list[str] = Query(default=[]), + limit: int = Query(100, ge=1, le=500), + store: MarketStore = Depends(get_market_store), + ) -> dict[str, Any]: + return news_domain.get_range_explain_payload( + store, + ticker=ticker, + start_date=start_date, + end_date=end_date, + article_ids=article_ids, + limit=limit, + ) + + return app + + +app = create_app() + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8002) diff --git a/backend/apps/runtime_service.py b/backend/apps/runtime_service.py new file mode 100644 index 0000000..c6014a6 --- /dev/null +++ b/backend/apps/runtime_service.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +"""Dedicated runtime service FastAPI surface.""" + +from __future__ import annotations + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from backend.api import runtime_router +from backend.api.runtime import get_runtime_state + + +def create_app() -> FastAPI: + """Create the runtime service app.""" + app = FastAPI( + title="EvoTraders Runtime Service", + description="Runtime lifecycle and gateway service surface extracted from the monolith", + version="0.1.0", + ) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/health") + async def health_check() -> dict[str, object]: + """Health check for the runtime service.""" + runtime_state = get_runtime_state() + process = runtime_state.gateway_process + is_running = process is not None and process.poll() is None + return { + "status": "healthy", + "service": "runtime-service", + "gateway_running": is_running, + "gateway_port": runtime_state.gateway_port, + } + + @app.get("/api/status") + async def api_status() -> dict[str, object]: + """Service-level status payload for runtime orchestration.""" + runtime_state = get_runtime_state() + process = runtime_state.gateway_process + is_running = process is not None and process.poll() is None + return { + "status": "operational", + "service": "runtime-service", + "runtime": { + "gateway_running": is_running, + "gateway_port": runtime_state.gateway_port, + "has_runtime_manager": runtime_state.runtime_manager is not None, + }, + } + + app.include_router(runtime_router) + return app + + +app = create_app() + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8003) diff --git a/backend/apps/trading_service.py b/backend/apps/trading_service.py new file mode 100644 index 0000000..ccd8f56 --- /dev/null +++ b/backend/apps/trading_service.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +"""Trading data FastAPI surface.""" + +from __future__ import annotations + +from typing import Any + +from fastapi import FastAPI, Query +from fastapi.middleware.cors import CORSMiddleware + +from backend.domains import trading as trading_domain +from shared.schema import ( + CompanyNewsResponse, + FinancialMetricsResponse, + InsiderTradeResponse, + LineItemResponse, + PriceResponse, +) + + +def create_app() -> FastAPI: + """Create the trading data service app.""" + app = FastAPI( + title="EvoTraders Trading Service", + description="Read-only trading data service surface extracted from the monolith", + version="0.1.0", + ) + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + @app.get("/health") + async def health_check() -> dict[str, str]: + """Health check endpoint.""" + return {"status": "healthy", "service": "trading-service"} + + @app.get("/api/prices", response_model=PriceResponse) + async def api_get_prices( + ticker: str = Query(..., min_length=1), + start_date: str = Query(...), + end_date: str = Query(...), + ) -> PriceResponse: + payload = trading_domain.get_prices_payload( + ticker=ticker, + start_date=start_date, + end_date=end_date, + ) + return PriceResponse(ticker=payload["ticker"], prices=payload["prices"]) + + @app.get("/api/financials", response_model=FinancialMetricsResponse) + async def api_get_financials( + ticker: str = Query(..., min_length=1), + end_date: str = Query(...), + period: str = Query("ttm"), + limit: int = Query(10, ge=1, le=100), + ) -> FinancialMetricsResponse: + payload = trading_domain.get_financials_payload( + ticker=ticker, + end_date=end_date, + period=period, + limit=limit, + ) + return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"]) + + @app.get("/api/news", response_model=CompanyNewsResponse) + async def api_get_news( + ticker: str = Query(..., min_length=1), + end_date: str = Query(...), + start_date: str | None = Query(None), + limit: int = Query(1000, ge=1, le=5000), + ) -> CompanyNewsResponse: + payload = trading_domain.get_news_payload( + ticker=ticker, + end_date=end_date, + start_date=start_date, + limit=limit, + ) + return CompanyNewsResponse(news=payload["news"]) + + @app.get("/api/insider-trades", response_model=InsiderTradeResponse) + async def api_get_insider_trades( + ticker: str = Query(..., min_length=1), + end_date: str = Query(...), + start_date: str | None = Query(None), + limit: int = Query(1000, ge=1, le=5000), + ) -> InsiderTradeResponse: + payload = trading_domain.get_insider_trades_payload( + ticker=ticker, + end_date=end_date, + start_date=start_date, + limit=limit, + ) + return InsiderTradeResponse(insider_trades=payload["insider_trades"]) + + @app.get("/api/market/status") + async def api_get_market_status() -> dict[str, Any]: + """Return current market status using the existing market service logic.""" + return trading_domain.get_market_status_payload() + + @app.get("/api/market-cap") + async def api_get_market_cap( + ticker: str = Query(..., min_length=1), + end_date: str = Query(...), + ) -> dict[str, Any]: + """Return market cap for one ticker/date.""" + return trading_domain.get_market_cap_payload( + ticker=ticker, + end_date=end_date, + ) + + @app.get("/api/line-items", response_model=LineItemResponse) + async def api_get_line_items( + ticker: str = Query(..., min_length=1), + line_items: list[str] = Query(...), + end_date: str = Query(...), + period: str = Query("ttm"), + limit: int = Query(10, ge=1, le=100), + ) -> LineItemResponse: + payload = trading_domain.get_line_items_payload( + ticker=ticker, + line_items=line_items, + end_date=end_date, + period=period, + limit=limit, + ) + return LineItemResponse(search_results=payload["search_results"]) + + return app + + +app = create_app() + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=8001) diff --git a/backend/core/__init__.py b/backend/core/__init__.py index 1e767f7..6fab6e7 100644 --- a/backend/core/__init__.py +++ b/backend/core/__init__.py @@ -1,9 +1,30 @@ # -*- coding: utf-8 -*- -"""Core pipeline and orchestration logic""" +"""Core pipeline and orchestration logic. + +Keep ``pipeline_runner`` behind lazy wrappers so importing ``backend.core`` does +not immediately pull in the gateway runtime graph. +""" from .pipeline import TradingPipeline from .state_sync import StateSync -from .pipeline_runner import create_agents, create_long_term_memory, stop_gateway + + +def create_agents(*args, **kwargs): + from .pipeline_runner import create_agents as _create_agents + + return _create_agents(*args, **kwargs) + + +def create_long_term_memory(*args, **kwargs): + from .pipeline_runner import create_long_term_memory as _create_long_term_memory + + return _create_long_term_memory(*args, **kwargs) + + +def stop_gateway(*args, **kwargs): + from .pipeline_runner import stop_gateway as _stop_gateway + + return _stop_gateway(*args, **kwargs) __all__ = [ "TradingPipeline", diff --git a/backend/data/provider_router.py b/backend/data/provider_router.py index 1e016b1..1bf4740 100644 --- a/backend/data/provider_router.py +++ b/backend/data/provider_router.py @@ -11,7 +11,7 @@ import pandas as pd import yfinance as yf from backend.config.data_config import DataSource, get_data_sources -from backend.data.schema import ( +from shared.schema import ( CompanyFactsResponse, CompanyNews, CompanyNewsResponse, diff --git a/backend/data/schema.py b/backend/data/schema.py index c9e4bde..24fc19f 100644 --- a/backend/data/schema.py +++ b/backend/data/schema.py @@ -1,194 +1,50 @@ # -*- coding: utf-8 -*- -from pydantic import BaseModel +"""Compatibility schema bridge. +This module preserves the legacy ``backend.data.schema`` import path while +delegating the actual schema definitions to ``shared.schema``. Keeping one +canonical DTO set avoids drift as the monolith is split into service-specific +packages. +""" -class Price(BaseModel): - open: float - close: float - high: float - low: float - volume: int - time: str +from shared.schema import ( + AgentStateData, + AgentStateMetadata, + AnalystSignal, + CompanyFacts, + CompanyFactsResponse, + CompanyNews, + CompanyNewsResponse, + FinancialMetrics, + FinancialMetricsResponse, + InsiderTrade, + InsiderTradeResponse, + LineItem, + LineItemResponse, + Portfolio, + Position, + Price, + PriceResponse, + TickerAnalysis, +) - -class PriceResponse(BaseModel): - ticker: str - prices: list[Price] - - -class FinancialMetrics(BaseModel): - ticker: str - report_period: str - period: str - currency: str - market_cap: float | None - enterprise_value: float | None - price_to_earnings_ratio: float | None - price_to_book_ratio: float | None - price_to_sales_ratio: float | None - enterprise_value_to_ebitda_ratio: float | None - enterprise_value_to_revenue_ratio: float | None - free_cash_flow_yield: float | None - peg_ratio: float | None - gross_margin: float | None - operating_margin: float | None - net_margin: float | None - return_on_equity: float | None - return_on_assets: float | None - return_on_invested_capital: float | None - asset_turnover: float | None - inventory_turnover: float | None - receivables_turnover: float | None - days_sales_outstanding: float | None - operating_cycle: float | None - working_capital_turnover: float | None - current_ratio: float | None - quick_ratio: float | None - cash_ratio: float | None - operating_cash_flow_ratio: float | None - debt_to_equity: float | None - debt_to_assets: float | None - interest_coverage: float | None - revenue_growth: float | None - earnings_growth: float | None - book_value_growth: float | None - earnings_per_share_growth: float | None - free_cash_flow_growth: float | None - operating_income_growth: float | None - ebitda_growth: float | None - payout_ratio: float | None - earnings_per_share: float | None - book_value_per_share: float | None - free_cash_flow_per_share: float | None - - -class FinancialMetricsResponse(BaseModel): - financial_metrics: list[FinancialMetrics] - - -class LineItem(BaseModel): - ticker: str - report_period: str - period: str - currency: str - - # Allow additional fields dynamically - model_config = {"extra": "allow"} - - -class LineItemResponse(BaseModel): - search_results: list[LineItem] - - -class InsiderTrade(BaseModel): - ticker: str - issuer: str | None - name: str | None - title: str | None - is_board_director: bool | None - transaction_date: str | None - transaction_shares: float | None - transaction_price_per_share: float | None - transaction_value: float | None - shares_owned_before_transaction: float | None - shares_owned_after_transaction: float | None - security_title: str | None - filing_date: str - - -class InsiderTradeResponse(BaseModel): - insider_trades: list[InsiderTrade] - - -class CompanyNews(BaseModel): - category: str | None = None - ticker: str - title: str - related: str | None = None - source: str - date: str | None = None - url: str - summary: str | None = None - - -class CompanyNewsResponse(BaseModel): - news: list[CompanyNews] - - -class CompanyFacts(BaseModel): - ticker: str - name: str - cik: str | None = None - industry: str | None = None - sector: str | None = None - category: str | None = None - exchange: str | None = None - is_active: bool | None = None - listing_date: str | None = None - location: str | None = None - market_cap: float | None = None - number_of_employees: int | None = None - sec_filings_url: str | None = None - sic_code: str | None = None - sic_industry: str | None = None - sic_sector: str | None = None - website_url: str | None = None - weighted_average_shares: int | None = None - - -class CompanyFactsResponse(BaseModel): - company_facts: CompanyFacts - - -class Position(BaseModel): - """Position information - for Portfolio mode""" - - long: int = 0 # Long position quantity (shares) - short: int = 0 # Short position quantity (shares) - long_cost_basis: float = 0.0 # Long position average cost - short_cost_basis: float = 0.0 # Short position average cost - - -class Portfolio(BaseModel): - """Portfolio - for Portfolio mode""" - - cash: float = 100000.0 # Available cash - positions: dict[str, Position] = {} # ticker -> Position mapping - # Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin) - margin_requirement: float = 0.0 - margin_used: float = 0.0 # Margin used - - -class AnalystSignal(BaseModel): - signal: str | None = None - confidence: float | None = None - reasoning: dict | str | None = None - # Extended fields for richer signal information - reasons: list[str] | None = None # Core drivers/reasons for the signal - risks: list[str] | None = None # Key risk factors - invalidation: str | None = None # Conditions that would invalidate the thesis - next_action: str | None = None # Suggested next action for PM - # Valuation-related fields - intrinsic_value: float | None = None # DCF intrinsic value - fair_value_range: dict | None = None # {bear, base, bull} fair value range - value_gap_pct: float | None = None # Value gap percentage - valuation_methods: list[str] | None = None # List of valuation methods used - max_position_size: float | None = None # For risk management signals - - -class TickerAnalysis(BaseModel): - ticker: str - analyst_signals: dict[str, AnalystSignal] # agent_name -> signal mapping - - -class AgentStateData(BaseModel): - tickers: list[str] - portfolio: Portfolio - start_date: str - end_date: str - ticker_analyses: dict[str, TickerAnalysis] # ticker -> analysis mapping - - -class AgentStateMetadata(BaseModel): - show_reasoning: bool = False - model_config = {"extra": "allow"} +__all__ = [ + "Price", + "PriceResponse", + "FinancialMetrics", + "FinancialMetricsResponse", + "LineItem", + "LineItemResponse", + "InsiderTrade", + "InsiderTradeResponse", + "CompanyNews", + "CompanyNewsResponse", + "CompanyFacts", + "CompanyFactsResponse", + "Position", + "Portfolio", + "AnalystSignal", + "TickerAnalysis", + "AgentStateData", + "AgentStateMetadata", +] diff --git a/backend/domains/__init__.py b/backend/domains/__init__.py new file mode 100644 index 0000000..3dc0930 --- /dev/null +++ b/backend/domains/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +"""Domain modules for split service internals.""" diff --git a/backend/domains/news.py b/backend/domains/news.py new file mode 100644 index 0000000..c3dc2ed --- /dev/null +++ b/backend/domains/news.py @@ -0,0 +1,277 @@ +# -*- coding: utf-8 -*- +"""News/explain domain helpers shared by app surfaces and gateway fallbacks.""" + +from __future__ import annotations + +from typing import Any + +from backend.data.market_store import MarketStore +from backend.data.market_ingest import update_ticker_incremental +from backend.enrich.news_enricher import enrich_news_for_symbol +from backend.explain.range_explainer import build_range_explanation +from backend.explain.similarity_service import find_similar_days +from backend.explain.story_service import get_or_create_stock_story + + +def news_rows_need_enrichment(rows: list[dict[str, Any]]) -> bool: + """Return whether news rows are missing explain-oriented analysis fields.""" + if not rows: + return True + return all( + not row.get("sentiment") + and not row.get("relevance") + and not row.get("key_discussion") + for row in rows + ) + + +def ensure_news_fresh( + store: MarketStore, + *, + ticker: str, + target_date: str | None = None, +) -> dict[str, Any]: + """Refresh raw news incrementally when stored watermarks are stale.""" + normalized_target = str(target_date or "").strip()[:10] + if not normalized_target: + return { + "ticker": ticker, + "target_date": None, + "last_news_fetch": None, + "refreshed": False, + } + + watermarks = store.get_ticker_watermarks(ticker) + last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10] + refreshed = False + if not last_news_fetch or last_news_fetch < normalized_target: + update_ticker_incremental( + ticker, + end_date=normalized_target, + store=store, + ) + refreshed = True + watermarks = store.get_ticker_watermarks(ticker) + last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10] + + return { + "ticker": ticker, + "target_date": normalized_target, + "last_news_fetch": last_news_fetch or None, + "refreshed": refreshed, + } + + +def get_enriched_news( + store: MarketStore, + *, + ticker: str, + start_date: str | None = None, + end_date: str | None = None, + limit: int = 100, +) -> dict[str, Any]: + freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date) + rows = store.get_news_items_enriched( + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + if news_rows_need_enrichment(rows): + enrich_news_for_symbol( + store, + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + rows = store.get_news_items_enriched( + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + return {"ticker": ticker, "news": rows, "freshness": freshness} + + +def get_news_for_date( + store: MarketStore, + *, + ticker: str, + date: str, + limit: int = 20, +) -> dict[str, Any]: + freshness = ensure_news_fresh(store, ticker=ticker, target_date=date) + rows = store.get_news_items_enriched( + ticker, + trade_date=date, + limit=limit, + ) + if news_rows_need_enrichment(rows): + enrich_news_for_symbol( + store, + ticker, + start_date=date, + end_date=date, + limit=limit, + ) + rows = store.get_news_items_enriched( + ticker, + trade_date=date, + limit=limit, + ) + return {"ticker": ticker, "date": date, "news": rows, "freshness": freshness} + + +def get_news_timeline( + store: MarketStore, + *, + ticker: str, + start_date: str, + end_date: str, +) -> dict[str, Any]: + freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date) + timeline = store.get_news_timeline_enriched( + ticker, + start_date=start_date, + end_date=end_date, + ) + if not timeline: + enrich_news_for_symbol( + store, + ticker, + start_date=start_date, + end_date=end_date, + limit=200, + ) + timeline = store.get_news_timeline_enriched( + ticker, + start_date=start_date, + end_date=end_date, + ) + return { + "ticker": ticker, + "timeline": timeline, + "start_date": start_date, + "end_date": end_date, + "freshness": freshness, + } + + +def get_news_categories( + store: MarketStore, + *, + ticker: str, + start_date: str | None = None, + end_date: str | None = None, + limit: int = 200, +) -> dict[str, Any]: + freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date) + rows = store.get_news_items_enriched( + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + if news_rows_need_enrichment(rows): + enrich_news_for_symbol( + store, + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + categories = store.get_news_categories_enriched( + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + return {"ticker": ticker, "categories": categories, "freshness": freshness} + + +def get_similar_days_payload( + store: MarketStore, + *, + ticker: str, + date: str, + n_similar: int = 5, +) -> dict[str, Any]: + freshness = ensure_news_fresh(store, ticker=ticker, target_date=date) + result = find_similar_days( + store, + symbol=ticker, + target_date=date, + top_k=n_similar, + ) + result["freshness"] = freshness + return result + + +def get_story_payload( + store: MarketStore, + *, + ticker: str, + as_of_date: str, +) -> dict[str, Any]: + freshness = ensure_news_fresh(store, ticker=ticker, target_date=as_of_date) + enrich_news_for_symbol( + store, + ticker, + end_date=as_of_date, + limit=80, + ) + result = get_or_create_stock_story( + store, + symbol=ticker, + as_of_date=as_of_date, + ) + result["freshness"] = freshness + return result + + +def get_range_explain_payload( + store: MarketStore, + *, + ticker: str, + start_date: str, + end_date: str, + article_ids: list[str] | None = None, + limit: int = 100, +) -> dict[str, Any]: + freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date) + news_rows = [] + if article_ids: + news_rows = store.get_news_by_ids_enriched(ticker, article_ids) + if not news_rows: + news_rows = store.get_news_items_enriched( + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + if news_rows_need_enrichment(news_rows): + enrich_news_for_symbol( + store, + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + news_rows = ( + store.get_news_by_ids_enriched(ticker, article_ids) + if article_ids + else store.get_news_items_enriched( + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ) + ) + result = build_range_explanation( + ticker=ticker, + start_date=start_date, + end_date=end_date, + news_rows=news_rows, + ) + return {"ticker": ticker, "result": result, "freshness": freshness} diff --git a/backend/domains/trading.py b/backend/domains/trading.py new file mode 100644 index 0000000..febeffa --- /dev/null +++ b/backend/domains/trading.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +"""Trading domain helpers shared by app surfaces and gateway fallbacks.""" + +from __future__ import annotations + +from typing import Any + +from backend.services.market import MarketService +from backend.tools.data_tools import ( + get_company_news, + get_financial_metrics, + get_insider_trades, + get_market_cap, + get_prices, + search_line_items, +) + + +def get_prices_payload(*, ticker: str, start_date: str, end_date: str) -> dict[str, Any]: + return { + "ticker": ticker, + "prices": get_prices(ticker, start_date, end_date), + } + + +def get_financials_payload( + *, + ticker: str, + end_date: str, + period: str = "ttm", + limit: int = 10, +) -> dict[str, Any]: + return { + "financial_metrics": get_financial_metrics( + ticker=ticker, + end_date=end_date, + period=period, + limit=limit, + ) + } + + +def get_news_payload( + *, + ticker: str, + end_date: str, + start_date: str | None = None, + limit: int = 1000, +) -> dict[str, Any]: + return { + "news": get_company_news( + ticker=ticker, + end_date=end_date, + start_date=start_date, + limit=limit, + ) + } + + +def get_insider_trades_payload( + *, + ticker: str, + end_date: str, + start_date: str | None = None, + limit: int = 1000, +) -> dict[str, Any]: + return { + "insider_trades": get_insider_trades( + ticker=ticker, + end_date=end_date, + start_date=start_date, + limit=limit, + ) + } + + +def get_market_status_payload() -> dict[str, Any]: + market_service = MarketService(tickers=[]) + return market_service.get_market_status() + + +def get_market_cap_payload(*, ticker: str, end_date: str) -> dict[str, Any]: + return { + "ticker": ticker, + "end_date": end_date, + "market_cap": get_market_cap(ticker, end_date), + } + + +def get_line_items_payload( + *, + ticker: str, + line_items: list[str], + end_date: str, + period: str = "ttm", + limit: int = 10, +) -> dict[str, Any]: + return { + "search_results": search_line_items( + ticker=ticker, + line_items=line_items, + end_date=end_date, + period=period, + limit=limit, + ) + } diff --git a/backend/services/gateway.py b/backend/services/gateway.py index 2378580..aa30822 100644 --- a/backend/services/gateway.py +++ b/backend/services/gateway.py @@ -5,42 +5,31 @@ WebSocket Gateway for frontend communication import asyncio import json import logging -from datetime import datetime, timedelta +import os +from datetime import datetime from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set import websockets from websockets.asyncio.server import ServerConnection -from backend.config.bootstrap_config import ( - get_bootstrap_config_for_run, - resolve_runtime_config, - update_bootstrap_values_for_run, -) -from backend.agents.agent_workspace import load_agent_workspace_config -from backend.agents.skills_manager import SkillsManager -from backend.agents.toolkit_factory import load_agent_profiles from backend.data.provider_utils import normalize_symbol -from backend.data.market_ingest import ingest_symbols -from backend.enrich.llm_enricher import llm_enrichment_enabled -from backend.enrich.news_enricher import enrich_news_for_symbol -from backend.explain.range_explainer import build_range_explanation -from backend.explain.similarity_service import find_similar_days -from backend.explain.story_service import get_or_create_stock_story +from backend.domains import news as news_domain from backend.llm.models import get_agent_model_info -from backend.utils.msg_adapter import FrontendAdapter from backend.utils.terminal_dashboard import get_dashboard from backend.core.pipeline import TradingPipeline from backend.core.state_sync import StateSync from backend.services.market import MarketService from backend.services.storage import StorageService from backend.data.provider_router import get_provider_router -from backend.tools.data_tools import get_prices -from backend.tools.data_tools import get_company_news -from backend.tools.data_tools import get_insider_trades -from backend.tools.data_tools import prices_to_df from backend.tools.technical_signals import StockTechnicalAnalyzer from backend.core.scheduler import Scheduler +from backend.services import gateway_admin_handlers +from backend.services import gateway_cycle_support +from backend.services import gateway_runtime_support +from backend.services import gateway_stock_handlers +from shared.client import NewsServiceClient +from shared.client import TradingServiceClient logger = logging.getLogger(__name__) EDITABLE_AGENT_WORKSPACE_FILES = { @@ -268,14 +257,59 @@ class Gateway: @staticmethod def _news_rows_need_enrichment(rows: List[Dict[str, Any]]) -> bool: - if not rows: - return True - return all( - not row.get("sentiment") - and not row.get("relevance") - and not row.get("key_discussion") - for row in rows + return news_domain.news_rows_need_enrichment(rows) + + def _news_service_url(self) -> str | None: + """Return configured news-service base URL, if any.""" + candidate = self.config.get("news_service_url") or os.getenv( + "NEWS_SERVICE_URL", + "", ) + value = str(candidate or "").strip() + return value or None + + def _trading_service_url(self) -> str | None: + """Return configured trading-service base URL, if any.""" + candidate = self.config.get("trading_service_url") or os.getenv( + "TRADING_SERVICE_URL", + "", + ) + value = str(candidate or "").strip() + return value or None + + async def _call_news_service( + self, + action: str, + callback: Callable[[NewsServiceClient], Any], + ) -> Any | None: + """Call news-service when configured, otherwise return None.""" + service_url = self._news_service_url() + if not service_url: + return None + + try: + async with NewsServiceClient(service_url) as client: + return await callback(client) + except Exception as exc: + logger.warning("news-service %s failed: %s", action, exc) + return None + + async def _call_trading_service( + self, + action: str, + callback: Callable[[TradingServiceClient], Any], + ) -> Any | None: + """Call trading-service when configured, otherwise return None.""" + service_url = self._trading_service_url() + if not service_url: + return None + + try: + async with TradingServiceClient(service_url) as client: + return await callback(client) + except Exception as exc: + logger.warning("trading-service %s failed: %s", action, exc) + return None async def handle_client(self, websocket: ServerConnection): """Handle WebSocket client connection""" @@ -415,952 +449,84 @@ class Gateway: websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - if not ticker: - await websocket.send( - json.dumps( - { - "type": "stock_history_loaded", - "ticker": "", - "prices": [], - "source": None, - "error": "invalid ticker", - }, - ensure_ascii=False, - ), - ) - return - - lookback_days = data.get("lookback_days", 90) - try: - lookback_days = max(7, min(int(lookback_days), 365)) - except (TypeError, ValueError): - lookback_days = 90 - - end_date = self.state_sync.state.get("current_date") - if not end_date: - end_date = datetime.now().strftime("%Y-%m-%d") - - try: - end_dt = datetime.strptime(end_date, "%Y-%m-%d") - except ValueError: - end_dt = datetime.now() - end_date = end_dt.strftime("%Y-%m-%d") - - start_date = (end_dt - timedelta(days=lookback_days)).strftime( - "%Y-%m-%d", - ) - - prices = await asyncio.to_thread( - self.storage.market_store.get_ohlc, - ticker, - start_date, - end_date, - ) - source = "polygon" - if not prices: - prices = await asyncio.to_thread( - get_prices, - ticker, - start_date, - end_date, - ) - usage_snapshot = self._provider_router.get_usage_snapshot() - source = usage_snapshot.get("last_success", {}).get("prices") - if prices: - await asyncio.to_thread( - self.storage.market_store.upsert_ohlc, - ticker, - [price.model_dump() for price in prices], - source=source or "provider", - ) - - await websocket.send( - json.dumps( - { - "type": "stock_history_loaded", - "ticker": ticker, - "prices": [ - price if isinstance(price, dict) else price.model_dump() - for price in prices - ][-120:], - "source": source, - "start_date": start_date, - "end_date": end_date, - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_history(self, websocket, data) async def _handle_get_stock_explain_events( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - snapshot = self.storage.runtime_db.get_stock_explain_snapshot(ticker) - await websocket.send( - json.dumps( - { - "type": "stock_explain_events_loaded", - "ticker": ticker, - "events": snapshot.get("events", []), - "signals": snapshot.get("signals", []), - "trades": snapshot.get("trades", []), - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_explain_events(self, websocket, data) async def _handle_get_stock_news( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - if not ticker: - await websocket.send( - json.dumps( - { - "type": "stock_news_loaded", - "ticker": "", - "news": [], - "source": None, - "error": "invalid ticker", - }, - ensure_ascii=False, - ), - ) - return - - lookback_days = data.get("lookback_days", 30) - limit = data.get("limit", 12) - try: - lookback_days = max(7, min(int(lookback_days), 180)) - except (TypeError, ValueError): - lookback_days = 30 - try: - limit = max(1, min(int(limit), 30)) - except (TypeError, ValueError): - limit = 12 - - end_date = self.state_sync.state.get("current_date") - if not end_date: - end_date = datetime.now().strftime("%Y-%m-%d") - - try: - end_dt = datetime.strptime(end_date, "%Y-%m-%d") - except ValueError: - end_dt = datetime.now() - end_date = end_dt.strftime("%Y-%m-%d") - - start_date = (end_dt - timedelta(days=lookback_days)).strftime( - "%Y-%m-%d", - ) - - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_items_enriched, - ticker, - start_date=start_date, - end_date=end_date, - limit=limit, - ) - source = "polygon" - if self._news_rows_need_enrichment(news_rows): - news = await asyncio.to_thread( - get_company_news, - ticker, - end_date, - start_date, - limit, - ) - if news: - usage_snapshot = self._provider_router.get_usage_snapshot() - source = usage_snapshot.get("last_success", {}).get("company_news") - await asyncio.to_thread( - self.storage.market_store.upsert_news, - ticker, - [item.model_dump() for item in news], - source=source or "provider", - ) - await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - start_date=start_date, - end_date=end_date, - limit=max(limit, 50), - ) - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_items_enriched, - ticker, - start_date=start_date, - end_date=end_date, - limit=limit, - ) - source = source or "market_store" - - await websocket.send( - json.dumps( - { - "type": "stock_news_loaded", - "ticker": ticker, - "news": news_rows[-limit:], - "source": source, - "start_date": start_date, - "end_date": end_date, - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_news(self, websocket, data) async def _handle_get_stock_news_for_date( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - trade_date = str(data.get("date") or "").strip() - if not ticker or not trade_date: - await websocket.send( - json.dumps( - { - "type": "stock_news_for_date_loaded", - "ticker": ticker, - "date": trade_date, - "news": [], - "error": "ticker and date are required", - }, - ensure_ascii=False, - ), - ) - return - - limit = data.get("limit", 20) - try: - limit = max(1, min(int(limit), 50)) - except (TypeError, ValueError): - limit = 20 - - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_items_enriched, - ticker, - trade_date=trade_date, - limit=limit, - ) - if self._news_rows_need_enrichment(news_rows): - await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - start_date=trade_date, - end_date=trade_date, - limit=limit, - ) - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_items_enriched, - ticker, - trade_date=trade_date, - limit=limit, - ) - - await websocket.send( - json.dumps( - { - "type": "stock_news_for_date_loaded", - "ticker": ticker, - "date": trade_date, - "news": news_rows, - "source": "market_store", - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_news_for_date(self, websocket, data) async def _handle_get_stock_news_timeline( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - if not ticker: - await websocket.send( - json.dumps( - { - "type": "stock_news_timeline_loaded", - "ticker": "", - "timeline": [], - "error": "invalid ticker", - }, - ensure_ascii=False, - ), - ) - return - - lookback_days = data.get("lookback_days", 90) - try: - lookback_days = max(7, min(int(lookback_days), 365)) - except (TypeError, ValueError): - lookback_days = 90 - - end_date = self.state_sync.state.get("current_date") - if not end_date: - end_date = datetime.now().strftime("%Y-%m-%d") - - try: - end_dt = datetime.strptime(end_date, "%Y-%m-%d") - except ValueError: - end_dt = datetime.now() - end_date = end_dt.strftime("%Y-%m-%d") - - start_date = (end_dt - timedelta(days=lookback_days)).strftime( - "%Y-%m-%d", - ) - timeline = await asyncio.to_thread( - self.storage.market_store.get_news_timeline_enriched, - ticker, - start_date=start_date, - end_date=end_date, - ) - if not timeline: - await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - start_date=start_date, - end_date=end_date, - limit=200, - ) - timeline = await asyncio.to_thread( - self.storage.market_store.get_news_timeline_enriched, - ticker, - start_date=start_date, - end_date=end_date, - ) - await websocket.send( - json.dumps( - { - "type": "stock_news_timeline_loaded", - "ticker": ticker, - "timeline": timeline, - "start_date": start_date, - "end_date": end_date, - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_news_timeline(self, websocket, data) async def _handle_get_stock_news_categories( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - if not ticker: - await websocket.send( - json.dumps( - { - "type": "stock_news_categories_loaded", - "ticker": "", - "categories": {}, - "error": "invalid ticker", - }, - ensure_ascii=False, - ), - ) - return - - lookback_days = data.get("lookback_days", 90) - try: - lookback_days = max(7, min(int(lookback_days), 365)) - except (TypeError, ValueError): - lookback_days = 90 - - end_date = self.state_sync.state.get("current_date") - if not end_date: - end_date = datetime.now().strftime("%Y-%m-%d") - - try: - end_dt = datetime.strptime(end_date, "%Y-%m-%d") - except ValueError: - end_dt = datetime.now() - end_date = end_dt.strftime("%Y-%m-%d") - - start_date = (end_dt - timedelta(days=lookback_days)).strftime( - "%Y-%m-%d", - ) - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_items_enriched, - ticker, - start_date=start_date, - end_date=end_date, - limit=200, - ) - if self._news_rows_need_enrichment(news_rows): - await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - start_date=start_date, - end_date=end_date, - limit=200, - ) - categories = await asyncio.to_thread( - self.storage.market_store.get_news_categories_enriched, - ticker, - start_date=start_date, - end_date=end_date, - limit=200, - ) - await websocket.send( - json.dumps( - { - "type": "stock_news_categories_loaded", - "ticker": ticker, - "categories": categories, - "start_date": start_date, - "end_date": end_date, - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_news_categories(self, websocket, data) async def _handle_get_stock_range_explain( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - start_date = str(data.get("start_date") or "").strip() - end_date = str(data.get("end_date") or "").strip() - if not ticker or not start_date or not end_date: - await websocket.send( - json.dumps( - { - "type": "stock_range_explain_loaded", - "ticker": ticker, - "result": {"error": "ticker, start_date, end_date are required"}, - }, - ensure_ascii=False, - ), - ) - return - - article_ids = data.get("article_ids") - if isinstance(article_ids, list) and article_ids: - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_by_ids_enriched, - ticker, - article_ids, - ) - if self._news_rows_need_enrichment(news_rows): - await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - start_date=start_date, - end_date=end_date, - limit=100, - ) - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_by_ids_enriched, - ticker, - article_ids, - ) - else: - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_items_enriched, - ticker, - start_date=start_date, - end_date=end_date, - limit=100, - ) - if not news_rows: - await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - start_date=start_date, - end_date=end_date, - limit=100, - ) - news_rows = await asyncio.to_thread( - self.storage.market_store.get_news_items_enriched, - ticker, - start_date=start_date, - end_date=end_date, - limit=100, - ) - - result = await asyncio.to_thread( - build_range_explanation, - ticker=ticker, - start_date=start_date, - end_date=end_date, - news_rows=news_rows, - ) - await websocket.send( - json.dumps( - { - "type": "stock_range_explain_loaded", - "ticker": ticker, - "result": result, - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_range_explain(self, websocket, data) async def _handle_get_stock_insider_trades( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - if not ticker: - await websocket.send( - json.dumps( - { - "type": "stock_insider_trades_loaded", - "ticker": "", - "trades": [], - "error": "invalid ticker", - }, - ensure_ascii=False, - ), - ) - return - - end_date = str( - data.get("end_date") - or self.state_sync.state.get("current_date") - or datetime.now().strftime("%Y-%m-%d") - ).strip()[:10] - start_date = str(data.get("start_date") or "").strip()[:10] - limit = int(data.get("limit", 50)) - - trades = await asyncio.to_thread( - get_insider_trades, - ticker=ticker, - end_date=end_date, - start_date=start_date if start_date else None, - limit=limit, - ) - - # Sort by transaction date descending - sorted_trades = sorted( - trades, - key=lambda t: t.transaction_date or "", - reverse=True, - ) - - # Format for frontend - formatted_trades = [ - { - "ticker": t.ticker, - "name": t.name, - "title": t.title, - "is_board_director": t.is_board_director, - "transaction_date": t.transaction_date, - "transaction_shares": t.transaction_shares, - "transaction_price_per_share": t.transaction_price_per_share, - "transaction_value": t.transaction_value, - "shares_owned_before_transaction": t.shares_owned_before_transaction, - "shares_owned_after_transaction": t.shares_owned_after_transaction, - "security_title": t.security_title, - "filing_date": t.filing_date, - # Calculated fields - "holding_change": ( - (t.shares_owned_after_transaction or 0) - - (t.shares_owned_before_transaction or 0) - if t.shares_owned_after_transaction and t.shares_owned_before_transaction - else None - ), - "is_buy": ( - (t.transaction_shares or 0) > 0 - if t.transaction_shares is not None - else None - ), - } - for t in sorted_trades - ] - - await websocket.send( - json.dumps( - { - "type": "stock_insider_trades_loaded", - "ticker": ticker, - "start_date": start_date or None, - "end_date": end_date, - "trades": formatted_trades, - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_insider_trades(self, websocket, data) async def _handle_get_stock_story( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - if not ticker: - await websocket.send( - json.dumps( - { - "type": "stock_story_loaded", - "ticker": "", - "story": "", - "error": "invalid ticker", - }, - ensure_ascii=False, - ), - ) - return - - as_of_date = str( - data.get("as_of_date") - or self.state_sync.state.get("current_date") - or datetime.now().strftime("%Y-%m-%d") - ).strip()[:10] - await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - end_date=as_of_date, - limit=80, - ) - result = await asyncio.to_thread( - get_or_create_stock_story, - self.storage.market_store, - symbol=ticker, - as_of_date=as_of_date, - ) - await websocket.send( - json.dumps( - { - "type": "stock_story_loaded", - "ticker": ticker, - "as_of_date": as_of_date, - "story": result.get("story") or "", - "source": result.get("source") or "local", - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_story(self, websocket, data) async def _handle_get_stock_similar_days( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - target_date = str(data.get("date") or "").strip()[:10] - if not ticker or not target_date: - await websocket.send( - json.dumps( - { - "type": "stock_similar_days_loaded", - "ticker": ticker, - "date": target_date, - "items": [], - "error": "ticker and date are required", - }, - ensure_ascii=False, - ), - ) - return - - top_k = data.get("top_k", 8) - try: - top_k = max(1, min(int(top_k), 20)) - except (TypeError, ValueError): - top_k = 8 - - await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - end_date=target_date, - limit=200, - ) - result = await asyncio.to_thread( - find_similar_days, - self.storage.market_store, - symbol=ticker, - target_date=target_date, - top_k=top_k, - ) - await websocket.send( - json.dumps( - { - "type": "stock_similar_days_loaded", - "ticker": ticker, - "date": target_date, - **result, - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_get_stock_similar_days(self, websocket, data) async def _handle_get_stock_technical_indicators( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - if not ticker: - await websocket.send( - json.dumps( - { - "type": "stock_technical_indicators_loaded", - "ticker": ticker, - "indicators": None, - "error": "ticker is required", - }, - ensure_ascii=False, - ), - ) - return - - try: - # Get price data for the ticker - from datetime import datetime, timedelta - end_date = datetime.now() - start_date = end_date - timedelta(days=250) # ~1 year for MA200 - - prices = get_prices( - ticker=ticker, - start_date=start_date.strftime("%Y-%m-%d"), - end_date=end_date.strftime("%Y-%m-%d"), - ) - - if not prices or len(prices) < 20: - await websocket.send( - json.dumps( - { - "type": "stock_technical_indicators_loaded", - "ticker": ticker, - "indicators": None, - "error": "Insufficient price data", - }, - ensure_ascii=False, - ), - ) - return - - # Analyze technical indicators - df = prices_to_df(prices) - signal = self._technical_analyzer.analyze(ticker, df) - - # Calculate additional volatility metrics - import pandas as pd - df_sorted = df.sort_values("time").reset_index(drop=True) - df_sorted["returns"] = df_sorted["close"].pct_change() - - vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None - vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None - vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None - - # Calculate MA distance from current price - ma_distance = {} - for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]: - ma_value = getattr(signal, ma_key, None) - if ma_value and ma_value > 0: - ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 - else: - ma_distance[ma_key] = None - - indicators = { - "ticker": ticker, - "current_price": signal.current_price, - "ma": { - "ma5": signal.ma5, - "ma10": signal.ma10, - "ma20": signal.ma20, - "ma50": signal.ma50, - "ma200": signal.ma200, - "distance": ma_distance, - }, - "rsi": { - "rsi14": signal.rsi14, - "status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral", - }, - "macd": { - "macd": signal.macd, - "signal": signal.macd_signal, - "histogram": signal.macd - signal.macd_signal, - }, - "bollinger": { - "upper": signal.bollinger_upper, - "mid": signal.bollinger_mid, - "lower": signal.bollinger_lower, - }, - "volatility": { - "vol_10d": vol_10, - "vol_20d": vol_20, - "vol_60d": vol_60, - "annualized": signal.annualized_volatility_pct, - "risk_level": signal.risk_level, - }, - "trend": signal.trend, - "mean_reversion": signal.mean_reversion_signal, - } - - await websocket.send( - json.dumps( - { - "type": "stock_technical_indicators_loaded", - "ticker": ticker, - "indicators": indicators, - }, - ensure_ascii=False, - default=str, - ), - ) - - except Exception as e: - logger.exception(f"Error getting technical indicators for {ticker}") - await websocket.send( - json.dumps( - { - "type": "stock_technical_indicators_loaded", - "ticker": ticker, - "indicators": None, - "error": str(e), - }, - ensure_ascii=False, - ), - ) + await gateway_stock_handlers.handle_get_stock_technical_indicators(self, websocket, data) async def _handle_run_stock_enrich( self, websocket: ServerConnection, data: Dict[str, Any], ): - ticker = normalize_symbol(data.get("ticker", "")) - start_date = str(data.get("start_date") or "").strip()[:10] - end_date = str(data.get("end_date") or "").strip()[:10] - story_date = str(data.get("story_date") or end_date or "").strip()[:10] - target_date = str(data.get("target_date") or "").strip()[:10] - force = bool(data.get("force", False)) - rebuild_story = bool(data.get("rebuild_story", True)) - rebuild_similar_days = bool(data.get("rebuild_similar_days", True)) - only_local_to_llm = bool(data.get("only_local_to_llm", False)) - limit = data.get("limit", 200) - - try: - limit = max(10, min(int(limit), 500)) - except (TypeError, ValueError): - limit = 200 - - if not ticker or not start_date or not end_date: - await websocket.send( - json.dumps( - { - "type": "stock_enrich_completed", - "ticker": ticker, - "start_date": start_date, - "end_date": end_date, - "error": "ticker, start_date, end_date are required", - }, - ensure_ascii=False, - ), - ) - return - - if only_local_to_llm and not llm_enrichment_enabled(): - await websocket.send( - json.dumps( - { - "type": "stock_enrich_completed", - "ticker": ticker, - "start_date": start_date, - "end_date": end_date, - "error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider", - }, - ensure_ascii=False, - ), - ) - return - - result = await asyncio.to_thread( - enrich_news_for_symbol, - self.storage.market_store, - ticker, - start_date=start_date, - end_date=end_date, - limit=limit, - skip_existing=not force, - only_reanalyze_local=only_local_to_llm, - ) - - story_status = None - if rebuild_story and story_date: - await asyncio.to_thread( - self.storage.market_store.delete_story_cache, - ticker, - as_of_date=story_date, - ) - story_result = await asyncio.to_thread( - get_or_create_stock_story, - self.storage.market_store, - symbol=ticker, - as_of_date=story_date, - ) - story_status = { - "as_of_date": story_date, - "source": story_result.get("source") or "local", - } - - similar_status = None - if rebuild_similar_days and target_date: - await asyncio.to_thread( - self.storage.market_store.delete_similar_day_cache, - ticker, - target_date=target_date, - ) - similar_result = await asyncio.to_thread( - find_similar_days, - self.storage.market_store, - symbol=ticker, - target_date=target_date, - top_k=8, - ) - similar_status = { - "target_date": target_date, - "count": len(similar_result.get("items") or []), - "error": similar_result.get("error"), - } - - await websocket.send( - json.dumps( - { - "type": "stock_enrich_completed", - "ticker": ticker, - "start_date": start_date, - "end_date": end_date, - "story_date": story_date or None, - "target_date": target_date or None, - "force": force, - "only_local_to_llm": only_local_to_llm, - "stats": result, - "story_status": story_status, - "similar_status": similar_status, - }, - ensure_ascii=False, - default=str, - ), - ) + await gateway_stock_handlers.handle_run_stock_enrich(self, websocket, data) async def _handle_start_backtest(self, data: Dict[str, Any]): if not self.is_backtest: @@ -1421,1008 +587,123 @@ class Gateway: self._manual_cycle_task = task async def _handle_reload_runtime_assets(self): - """Reload prompt, skills, and safe runtime config without restart.""" - config_name = self.config.get("config_name", "default") - runtime_config = resolve_runtime_config( - project_root=self._project_root, - config_name=config_name, - enable_memory=self.config.get("enable_memory", False), - schedule_mode=self.config.get("schedule_mode", "daily"), - interval_minutes=self.config.get("interval_minutes", 60), - trigger_time=self.config.get("trigger_time", "09:30"), - ) - result = self.pipeline.reload_runtime_assets(runtime_config=runtime_config) - runtime_updates = self._apply_runtime_config(runtime_config) - await self.state_sync.on_system_message( - "Runtime assets reloaded.", - ) - await self.broadcast( - { - "type": "runtime_assets_reloaded", - **result, - **runtime_updates, - }, - ) + await gateway_admin_handlers.handle_reload_runtime_assets(self) async def _handle_update_runtime_config( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Persist selected runtime settings and hot-reload them.""" - updates: Dict[str, Any] = {} - - schedule_mode = str(data.get("schedule_mode", "")).strip().lower() - if schedule_mode: - if schedule_mode not in {"daily", "intraday"}: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "schedule_mode must be 'daily' or 'intraday'.", - }, - ensure_ascii=False, - ), - ) - return - updates["schedule_mode"] = schedule_mode - - interval_minutes = data.get("interval_minutes") - if interval_minutes is not None: - try: - parsed_interval = int(interval_minutes) - except (TypeError, ValueError): - parsed_interval = 0 - if parsed_interval <= 0: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "interval_minutes must be a positive integer.", - }, - ensure_ascii=False, - ), - ) - return - updates["interval_minutes"] = parsed_interval - - trigger_time = data.get("trigger_time") - if trigger_time is not None: - raw_trigger = str(trigger_time).strip() - if raw_trigger and raw_trigger != "now": - try: - datetime.strptime(raw_trigger, "%H:%M") - except ValueError: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "trigger_time must use HH:MM or 'now'.", - }, - ensure_ascii=False, - ), - ) - return - updates["trigger_time"] = raw_trigger or "09:30" - - max_comm_cycles = data.get("max_comm_cycles") - if max_comm_cycles is not None: - try: - parsed_cycles = int(max_comm_cycles) - except (TypeError, ValueError): - parsed_cycles = 0 - if parsed_cycles <= 0: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "max_comm_cycles must be a positive integer.", - }, - ensure_ascii=False, - ), - ) - return - updates["max_comm_cycles"] = parsed_cycles - - initial_cash = data.get("initial_cash") - if initial_cash is not None: - try: - parsed_initial_cash = float(initial_cash) - except (TypeError, ValueError): - parsed_initial_cash = 0.0 - if parsed_initial_cash <= 0: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "initial_cash must be a positive number.", - }, - ensure_ascii=False, - ), - ) - return - updates["initial_cash"] = parsed_initial_cash - - margin_requirement = data.get("margin_requirement") - if margin_requirement is not None: - try: - parsed_margin_requirement = float(margin_requirement) - except (TypeError, ValueError): - parsed_margin_requirement = -1.0 - if parsed_margin_requirement < 0: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "margin_requirement must be a non-negative number.", - }, - ensure_ascii=False, - ), - ) - return - updates["margin_requirement"] = parsed_margin_requirement - - enable_memory = data.get("enable_memory") - if enable_memory is not None: - updates["enable_memory"] = bool(enable_memory) - - if not updates: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "No runtime settings were provided.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - update_bootstrap_values_for_run( - project_root=self._project_root, - config_name=config_name, - updates=updates, - ) - await self.state_sync.on_system_message("运行时调度配置已保存,正在热更新") - await self._handle_reload_runtime_assets() + await gateway_admin_handlers.handle_update_runtime_config(self, websocket, data) async def _handle_update_watchlist( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Persist a new watchlist to BOOTSTRAP.md and hot-reload it.""" - tickers = self._normalize_watchlist(data.get("tickers")) - if not tickers: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "update_watchlist requires at least one valid ticker.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - update_bootstrap_values_for_run( - project_root=self._project_root, - config_name=config_name, - updates={"tickers": tickers}, - ) - await self.state_sync.on_system_message( - f"Watchlist updated: {', '.join(tickers)}", - ) - await self.broadcast( - { - "type": "watchlist_updated", - "config_name": config_name, - "tickers": tickers, - }, - ) - await self._handle_reload_runtime_assets() - self._schedule_watchlist_market_store_refresh(tickers) + await gateway_admin_handlers.handle_update_watchlist(self, websocket, data) async def _handle_get_agent_skills( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Return skill catalog and status for one agent.""" - agent_id = str(data.get("agent_id", "")).strip() - if not agent_id: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "get_agent_skills requires agent_id.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id) - agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml") - resolved_skills = set( - skills_manager.resolve_agent_skill_names( - config_name=config_name, - agent_id=agent_id, - default_skills=[], - ), - ) - enabled = set(agent_config.enabled_skills) - disabled = set(agent_config.disabled_skills) - - payload = [] - for item in skills_manager.list_agent_skill_catalog(config_name, agent_id): - if item.skill_name in disabled: - status = "disabled" - elif item.skill_name in enabled: - status = "enabled" - elif item.skill_name in resolved_skills: - status = "active" - else: - status = "available" - payload.append( - { - "skill_name": item.skill_name, - "name": item.name, - "description": item.description, - "version": item.version, - "source": item.source, - "tools": item.tools, - "status": status, - }, - ) - - await websocket.send( - json.dumps( - { - "type": "agent_skills_loaded", - "config_name": config_name, - "agent_id": agent_id, - "skills": payload, - }, - ensure_ascii=False, - ), - ) + await gateway_admin_handlers.handle_get_agent_skills(self, websocket, data) async def _handle_get_agent_profile( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Return structured profile/config summary for one agent.""" - agent_id = str(data.get("agent_id", "")).strip() - if not agent_id: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "get_agent_profile requires agent_id.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id) - agent_config = load_agent_workspace_config(asset_dir / "agent.yaml") - profiles = load_agent_profiles() - profile = profiles.get(agent_id, {}) - bootstrap = get_bootstrap_config_for_run(self._project_root, config_name) - override = bootstrap.agent_override(agent_id) - active_tool_groups = override.get( - "active_tool_groups", - agent_config.active_tool_groups or profile.get("active_tool_groups", []), - ) - if not isinstance(active_tool_groups, list): - active_tool_groups = [] - disabled_tool_groups = agent_config.disabled_tool_groups - if disabled_tool_groups: - disabled_set = set(disabled_tool_groups) - active_tool_groups = [ - group_name - for group_name in active_tool_groups - if group_name not in disabled_set - ] - - default_skills = profile.get("skills", []) - if not isinstance(default_skills, list): - default_skills = [] - resolved_skills = skills_manager.resolve_agent_skill_names( - config_name=config_name, - agent_id=agent_id, - default_skills=default_skills, - ) - prompt_files = agent_config.prompt_files or [ - "SOUL.md", - "PROFILE.md", - "AGENTS.md", - "POLICY.md", - "MEMORY.md", - ] - model_name, model_provider = get_agent_model_info(agent_id) - - await websocket.send( - json.dumps( - { - "type": "agent_profile_loaded", - "config_name": config_name, - "agent_id": agent_id, - "profile": { - "model_name": model_name, - "model_provider": model_provider, - "prompt_files": prompt_files, - "default_skills": default_skills, - "resolved_skills": resolved_skills, - "active_tool_groups": active_tool_groups, - "disabled_tool_groups": disabled_tool_groups, - "enabled_skills": agent_config.enabled_skills, - "disabled_skills": agent_config.disabled_skills, - }, - }, - ensure_ascii=False, - ), - ) + await gateway_admin_handlers.handle_get_agent_profile(self, websocket, data) async def _handle_get_skill_detail( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Return full SKILL.md body for one skill.""" - agent_id = str(data.get("agent_id", "")).strip() - skill_name = str(data.get("skill_name", "")).strip() - if not skill_name: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "get_skill_detail requires skill_name.", - }, - ensure_ascii=False, - ), - ) - return - - skills_manager = SkillsManager(project_root=self._project_root) - try: - if agent_id: - config_name = self.config.get("config_name", "default") - detail = skills_manager.load_agent_skill_document( - config_name=config_name, - agent_id=agent_id, - skill_name=skill_name, - ) - else: - detail = skills_manager.load_skill_document(skill_name) - except FileNotFoundError: - await websocket.send( - json.dumps( - { - "type": "error", - "message": f"Unknown skill: {skill_name}", - }, - ensure_ascii=False, - ), - ) - return - - await websocket.send( - json.dumps( - { - "type": "skill_detail_loaded", - "agent_id": agent_id, - "skill": detail, - }, - ensure_ascii=False, - ), - ) + await gateway_admin_handlers.handle_get_skill_detail(self, websocket, data) async def _handle_create_agent_local_skill( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Create a new local skill for one agent and hot-reload runtime assets.""" - agent_id = str(data.get("agent_id", "")).strip() - skill_name = str(data.get("skill_name", "")).strip() - if not agent_id or not skill_name: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "create_agent_local_skill requires agent_id and skill_name.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - try: - skills_manager.create_agent_local_skill( - config_name=config_name, - agent_id=agent_id, - skill_name=skill_name, - ) - except (ValueError, FileExistsError) as exc: - await websocket.send( - json.dumps( - {"type": "error", "message": str(exc)}, - ensure_ascii=False, - ), - ) - return - - await self.state_sync.on_system_message( - f"Created local skill {skill_name} for {agent_id}", - ) - await self._handle_reload_runtime_assets() - await websocket.send( - json.dumps( - { - "type": "agent_local_skill_created", - "agent_id": agent_id, - "skill_name": skill_name, - }, - ensure_ascii=False, - ), - ) - await self._handle_get_agent_skills(websocket, {"agent_id": agent_id}) - await self._handle_get_skill_detail( - websocket, - {"agent_id": agent_id, "skill_name": skill_name}, - ) + await gateway_admin_handlers.handle_create_agent_local_skill(self, websocket, data) async def _handle_update_agent_local_skill( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Update one agent-local SKILL.md and hot-reload runtime assets.""" - agent_id = str(data.get("agent_id", "")).strip() - skill_name = str(data.get("skill_name", "")).strip() - content = data.get("content") - if not agent_id or not skill_name or not isinstance(content, str): - await websocket.send( - json.dumps( - { - "type": "error", - "message": "update_agent_local_skill requires agent_id, skill_name, and string content.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - try: - skills_manager.update_agent_local_skill( - config_name=config_name, - agent_id=agent_id, - skill_name=skill_name, - content=content, - ) - except (ValueError, FileNotFoundError) as exc: - await websocket.send( - json.dumps( - {"type": "error", "message": str(exc)}, - ensure_ascii=False, - ), - ) - return - - await self.state_sync.on_system_message( - f"Updated local skill {skill_name} for {agent_id}", - ) - await self._handle_reload_runtime_assets() - await websocket.send( - json.dumps( - { - "type": "agent_local_skill_updated", - "agent_id": agent_id, - "skill_name": skill_name, - }, - ensure_ascii=False, - ), - ) - await self._handle_get_skill_detail( - websocket, - {"agent_id": agent_id, "skill_name": skill_name}, - ) + await gateway_admin_handlers.handle_update_agent_local_skill(self, websocket, data) async def _handle_delete_agent_local_skill( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Delete one agent-local skill and hot-reload runtime assets.""" - agent_id = str(data.get("agent_id", "")).strip() - skill_name = str(data.get("skill_name", "")).strip() - if not agent_id or not skill_name: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "delete_agent_local_skill requires agent_id and skill_name.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - try: - skills_manager.delete_agent_local_skill( - config_name=config_name, - agent_id=agent_id, - skill_name=skill_name, - ) - skills_manager.forget_agent_skill_overrides( - config_name=config_name, - agent_id=agent_id, - skill_names=[skill_name], - ) - except (ValueError, FileNotFoundError) as exc: - await websocket.send( - json.dumps( - {"type": "error", "message": str(exc)}, - ensure_ascii=False, - ), - ) - return - - await self.state_sync.on_system_message( - f"Deleted local skill {skill_name} for {agent_id}", - ) - await self._handle_reload_runtime_assets() - await websocket.send( - json.dumps( - { - "type": "agent_local_skill_deleted", - "agent_id": agent_id, - "skill_name": skill_name, - }, - ensure_ascii=False, - ), - ) - await self._handle_get_agent_skills(websocket, {"agent_id": agent_id}) + await gateway_admin_handlers.handle_delete_agent_local_skill(self, websocket, data) async def _handle_remove_agent_skill( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Remove one shared skill from the agent's installed set.""" - agent_id = str(data.get("agent_id", "")).strip() - skill_name = str(data.get("skill_name", "")).strip() - if not agent_id or not skill_name: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "remove_agent_skill requires agent_id and skill_name.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - skill_names = { - item.skill_name - for item in skills_manager.list_agent_skill_catalog(config_name, agent_id) - if item.source != "local" - } - if skill_name not in skill_names: - await websocket.send( - json.dumps( - {"type": "error", "message": f"Unknown shared skill: {skill_name}"}, - ensure_ascii=False, - ), - ) - return - - skills_manager.update_agent_skill_overrides( - config_name=config_name, - agent_id=agent_id, - disable=[skill_name], - ) - await self.state_sync.on_system_message( - f"Removed shared skill {skill_name} from {agent_id}", - ) - await self._handle_reload_runtime_assets() - await websocket.send( - json.dumps( - { - "type": "agent_skill_removed", - "agent_id": agent_id, - "skill_name": skill_name, - }, - ensure_ascii=False, - ), - ) - await self._handle_get_agent_skills(websocket, {"agent_id": agent_id}) + await gateway_admin_handlers.handle_remove_agent_skill(self, websocket, data) async def _handle_update_agent_skill( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Enable or disable one skill for one agent and hot-reload assets.""" - agent_id = str(data.get("agent_id", "")).strip() - skill_name = str(data.get("skill_name", "")).strip() - enabled = data.get("enabled") - if not agent_id or not skill_name or not isinstance(enabled, bool): - await websocket.send( - json.dumps( - { - "type": "error", - "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - skill_names = { - item.skill_name - for item in skills_manager.list_agent_skill_catalog(config_name, agent_id) - } - if skill_name not in skill_names: - await websocket.send( - json.dumps( - { - "type": "error", - "message": f"Unknown skill: {skill_name}", - }, - ensure_ascii=False, - ), - ) - return - - if enabled: - skills_manager.update_agent_skill_overrides( - config_name=config_name, - agent_id=agent_id, - enable=[skill_name], - ) - await self.state_sync.on_system_message( - f"Enabled skill {skill_name} for {agent_id}", - ) - else: - skills_manager.update_agent_skill_overrides( - config_name=config_name, - agent_id=agent_id, - disable=[skill_name], - ) - await self.state_sync.on_system_message( - f"Disabled skill {skill_name} for {agent_id}", - ) - - await websocket.send( - json.dumps( - { - "type": "agent_skill_updated", - "agent_id": agent_id, - "skill_name": skill_name, - "enabled": enabled, - }, - ensure_ascii=False, - ), - ) - await self._handle_reload_runtime_assets() - await self._handle_get_agent_skills( - websocket, - {"agent_id": agent_id}, - ) + await gateway_admin_handlers.handle_update_agent_skill(self, websocket, data) async def _handle_get_agent_workspace_file( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Load one editable agent workspace markdown file.""" - agent_id = str(data.get("agent_id", "")).strip() - filename = self._normalize_agent_workspace_filename(data.get("filename")) - if not agent_id or not filename: - await websocket.send( - json.dumps( - { - "type": "error", - "message": "get_agent_workspace_file requires agent_id and supported filename.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id) - asset_dir.mkdir(parents=True, exist_ok=True) - path = asset_dir / filename - content = path.read_text(encoding="utf-8") if path.exists() else "" - await websocket.send( - json.dumps( - { - "type": "agent_workspace_file_loaded", - "config_name": config_name, - "agent_id": agent_id, - "filename": filename, - "content": content, - }, - ensure_ascii=False, - ), - ) + await gateway_admin_handlers.handle_get_agent_workspace_file(self, websocket, data) async def _handle_update_agent_workspace_file( self, websocket: ServerConnection, data: Dict[str, Any], ) -> None: - """Persist one editable agent workspace markdown file and hot-reload.""" - agent_id = str(data.get("agent_id", "")).strip() - filename = self._normalize_agent_workspace_filename(data.get("filename")) - content = data.get("content") - if not agent_id or not filename or not isinstance(content, str): - await websocket.send( - json.dumps( - { - "type": "error", - "message": "update_agent_workspace_file requires agent_id, supported filename, and string content.", - }, - ensure_ascii=False, - ), - ) - return - - config_name = self.config.get("config_name", "default") - skills_manager = SkillsManager(project_root=self._project_root) - asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id) - asset_dir.mkdir(parents=True, exist_ok=True) - path = asset_dir / filename - path.write_text(content, encoding="utf-8") - await self.state_sync.on_system_message( - f"Updated {filename} for {agent_id}", - ) - await websocket.send( - json.dumps( - { - "type": "agent_workspace_file_updated", - "agent_id": agent_id, - "filename": filename, - }, - ensure_ascii=False, - ), - ) - await self._handle_reload_runtime_assets() - await self._handle_get_agent_workspace_file( - websocket, - {"agent_id": agent_id, "filename": filename}, - ) + await gateway_admin_handlers.handle_update_agent_workspace_file(self, websocket, data) @staticmethod def _normalize_watchlist(raw_tickers: Any) -> List[str]: - """Parse watchlist payloads from websocket messages.""" - if raw_tickers is None: - return [] - - if isinstance(raw_tickers, str): - candidates = raw_tickers.split(",") - elif isinstance(raw_tickers, list): - candidates = raw_tickers - else: - candidates = [raw_tickers] - - tickers: List[str] = [] - for candidate in candidates: - symbol = normalize_symbol(str(candidate).strip().strip("\"'")) - if symbol and symbol not in tickers: - tickers.append(symbol) - return tickers + return gateway_runtime_support.normalize_watchlist(raw_tickers) @staticmethod def _normalize_agent_workspace_filename(raw_name: Any) -> Optional[str]: - """Restrict editable workspace files to a safe allowlist.""" - filename = str(raw_name or "").strip() - if filename in EDITABLE_AGENT_WORKSPACE_FILES: - return filename - return None + return gateway_runtime_support.normalize_agent_workspace_filename( + raw_name, + allowlist=EDITABLE_AGENT_WORKSPACE_FILES, + ) def _apply_runtime_config( self, runtime_config: Dict[str, Any], ) -> Dict[str, Any]: - """Apply runtime config to gateway-owned services and state.""" - warnings: List[str] = [] - - ticker_changes = self.market_service.update_tickers( - runtime_config.get("tickers", []), - ) - self.config["tickers"] = ticker_changes["active"] - - self.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"]) - self.config["max_comm_cycles"] = self.pipeline.max_comm_cycles - self.config["schedule_mode"] = runtime_config.get( - "schedule_mode", - self.config.get("schedule_mode", "daily"), - ) - self.config["interval_minutes"] = int( - runtime_config.get( - "interval_minutes", - self.config.get("interval_minutes", 60), - ), - ) - self.config["trigger_time"] = runtime_config.get( - "trigger_time", - self.config.get("trigger_time", "09:30"), - ) - - if self.scheduler: - self.scheduler.reconfigure( - mode=self.config["schedule_mode"], - trigger_time=self.config["trigger_time"], - interval_minutes=self.config["interval_minutes"], - ) - - pm_apply_result = self.pipeline.pm.apply_runtime_portfolio_config( - margin_requirement=runtime_config["margin_requirement"], - ) - self.config["margin_requirement"] = self.pipeline.pm.portfolio.get( - "margin_requirement", - runtime_config["margin_requirement"], - ) - - requested_initial_cash = float(runtime_config["initial_cash"]) - current_initial_cash = float(self.storage.initial_cash) - initial_cash_applied = requested_initial_cash == current_initial_cash - if not initial_cash_applied: - if ( - self.storage.can_apply_initial_cash() - and self.pipeline.pm.can_apply_initial_cash() - ): - initial_cash_applied = self.storage.apply_initial_cash( - requested_initial_cash, - ) - if initial_cash_applied: - self.pipeline.pm.apply_runtime_portfolio_config( - initial_cash=requested_initial_cash, - ) - self.config["initial_cash"] = self.storage.initial_cash - else: - warnings.append( - "initial_cash changed in BOOTSTRAP.md but was not applied " - "because the run already has positions, margin usage, or trades.", - ) - - requested_enable_memory = bool(runtime_config["enable_memory"]) - current_enable_memory = bool(self.config.get("enable_memory", False)) - if requested_enable_memory != current_enable_memory: - warnings.append( - "enable_memory changed in BOOTSTRAP.md but still requires a restart " - "because long-term memory contexts are created at startup.", - ) - - self._sync_runtime_state() - - return { - "runtime_config_requested": runtime_config, - "runtime_config_applied": { - "tickers": list(self.config.get("tickers", [])), - "schedule_mode": self.config.get("schedule_mode", "daily"), - "interval_minutes": self.config.get("interval_minutes", 60), - "trigger_time": self.config.get("trigger_time", "09:30"), - "initial_cash": self.storage.initial_cash, - "margin_requirement": self.config["margin_requirement"], - "max_comm_cycles": self.config["max_comm_cycles"], - "enable_memory": self.config.get("enable_memory", False), - }, - "runtime_config_status": { - "tickers": True, - "schedule_mode": True, - "interval_minutes": True, - "trigger_time": True, - "initial_cash": initial_cash_applied, - "margin_requirement": pm_apply_result["margin_requirement"], - "max_comm_cycles": True, - "enable_memory": requested_enable_memory == current_enable_memory, - }, - "ticker_changes": ticker_changes, - "runtime_config_warnings": warnings, - } + return gateway_runtime_support.apply_runtime_config(self, runtime_config) def _sync_runtime_state(self) -> None: - """Refresh persisted state and dashboard after runtime config changes.""" - self.state_sync.update_state("tickers", self.config.get("tickers", [])) - self.state_sync.update_state( - "runtime_config", - { - "tickers": self.config.get("tickers", []), - "schedule_mode": self.config.get("schedule_mode", "daily"), - "interval_minutes": self.config.get("interval_minutes", 60), - "trigger_time": self.config.get("trigger_time", "09:30"), - "initial_cash": self.storage.initial_cash, - "margin_requirement": self.config.get("margin_requirement"), - "max_comm_cycles": self.config.get("max_comm_cycles"), - "enable_memory": self.config.get("enable_memory", False), - }, - ) - - self.storage.update_server_state_from_dashboard(self.state_sync.state) - self.state_sync.save_state() - - self._dashboard.tickers = list(self.config.get("tickers", [])) - self._dashboard.initial_cash = self.storage.initial_cash - self._dashboard.enable_memory = bool( - self.config.get("enable_memory", False), - ) - - summary = self.storage.load_file("summary") or {} - holdings = self.storage.load_file("holdings") or [] - trades = self.storage.load_file("trades") or [] - self._dashboard.update( - portfolio=summary, - holdings=holdings, - trades=trades, - ) + gateway_runtime_support.sync_runtime_state(self) def _schedule_watchlist_market_store_refresh( self, tickers: List[str], ) -> None: - """Kick off a non-blocking Polygon refresh for the updated watchlist.""" - if not tickers: - return - if self._watchlist_ingest_task and not self._watchlist_ingest_task.done(): - self._watchlist_ingest_task.cancel() - self._watchlist_ingest_task = asyncio.create_task( - self._refresh_market_store_for_watchlist(tickers), - ) + gateway_cycle_support.schedule_watchlist_market_store_refresh(self, tickers) async def _refresh_market_store_for_watchlist( self, tickers: List[str], ) -> None: - """Refresh the long-lived market store after a watchlist update.""" - try: - await self.state_sync.on_system_message( - f"正在同步自选股市场数据: {', '.join(tickers)}", - ) - results = await asyncio.to_thread( - ingest_symbols, - tickers, - mode="incremental", - ) - summary = ", ".join( - f"{item['symbol']} prices={item['prices']} news={item['news']}" - for item in results - ) - await self.state_sync.on_system_message( - f"自选股市场数据已同步: {summary}", - ) - except asyncio.CancelledError: - raise - except Exception as exc: - logger.warning("Watchlist market store refresh failed: %s", exc) - await self.state_sync.on_system_message( - f"自选股市场数据同步失败: {exc}", - ) + await gateway_cycle_support.refresh_market_store_for_watchlist(self, tickers) async def broadcast(self, message: Dict[str, Any]): """Broadcast message to all connected clients""" @@ -2452,332 +733,39 @@ class Gateway: self.connected_clients.discard(client) async def _market_status_monitor(self): - """Periodically check and broadcast market status changes""" - while True: - try: - await self.market_service.check_and_broadcast_market_status() - - # On market open, start live session tracking - status = self.market_service.get_market_status() - if ( - status["status"] == "open" - and not self.storage.is_live_session_active - ): - self.storage.start_live_session() - summary = self.storage.load_file("summary") or {} - self._session_start_portfolio_value = summary.get( - "totalAssetValue", - self.storage.initial_cash, - ) - logger.info( - "Session start portfolio: " - f"${self._session_start_portfolio_value:,.2f}", - ) - elif ( - status["status"] != "open" - and self.storage.is_live_session_active - ): - self.storage.end_live_session() - self._session_start_portfolio_value = None - - # Update and broadcast live returns if session is active - if self.storage.is_live_session_active: - await self._update_and_broadcast_live_returns() - - await asyncio.sleep(60) # Check every minute - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"Market status monitor error: {e}") - await asyncio.sleep(60) + await gateway_cycle_support.market_status_monitor(self) async def _update_and_broadcast_live_returns(self): - """Calculate and broadcast live returns for current session""" - if not self.storage.is_live_session_active: - return - - # Get current prices and calculate portfolio value - prices = self.market_service.get_all_prices() - if not prices or not any(p > 0 for p in prices.values()): - return - - # Load current internal state to get baseline values - state = self.storage.load_internal_state() - - # Get latest values from history (if available) - equity_history = state.get("equity_history", []) - baseline_history = state.get("baseline_history", []) - baseline_vw_history = state.get("baseline_vw_history", []) - momentum_history = state.get("momentum_history", []) - - current_equity = equity_history[-1]["v"] if equity_history else None - current_baseline = ( - baseline_history[-1]["v"] if baseline_history else None - ) - current_baseline_vw = ( - baseline_vw_history[-1]["v"] if baseline_vw_history else None - ) - current_momentum = ( - momentum_history[-1]["v"] if momentum_history else None - ) - - # Update live returns with current values - point = self.storage.update_live_returns( - current_equity=current_equity, - current_baseline=current_baseline, - current_baseline_vw=current_baseline_vw, - current_momentum=current_momentum, - ) - - # Broadcast if we have new data - if point: - live_returns = self.storage.get_live_returns() - await self.broadcast( - { - "type": "team_summary", - "equity_return": live_returns["equity_return"], - "baseline_return": live_returns["baseline_return"], - "baseline_vw_return": live_returns["baseline_vw_return"], - "momentum_return": live_returns["momentum_return"], - }, - ) + await gateway_cycle_support.update_and_broadcast_live_returns(self) async def on_strategy_trigger(self, date: str): - """Handle trading cycle trigger""" - if self._cycle_lock.locked(): - logger.warning("Trading cycle already running, skipping trigger for %s", date) - await self.state_sync.on_system_message( - f"已有交易周期在运行,跳过本次触发: {date}", - ) - return - - async with self._cycle_lock: - logger.info(f"Strategy triggered for {date}") - - tickers = self.config.get("tickers", []) - - if self.is_backtest: - await self._run_backtest_cycle(date, tickers) - else: - await self._run_live_cycle(date, tickers) + await gateway_cycle_support.on_strategy_trigger(self, date) async def on_heartbeat_trigger(self, date: str): - """Run lightweight heartbeat check for all analysts. - - Each analyst reads its HEARTBEAT.md and performs a self-check - without running the full trading pipeline. - """ - logger.info(f"[Heartbeat] Running heartbeat check for {date}") - - tickers = self.config.get("tickers", []) - analysts = self.pipeline._all_analysts() - - for analyst in analysts: - try: - ws_id = getattr(analyst, "workspace_id", None) - if ws_id: - from backend.agents.workspace_manager import get_workspace_dir - ws_dir = get_workspace_dir(ws_id) - if ws_dir: - from pathlib import Path - hb_path = Path(ws_dir) / "HEARTBEAT.md" - if hb_path.exists(): - content = hb_path.read_text(encoding="utf-8").strip() - if content: - hb_task = ( - f"# 定期主动检查\n\n{content}\n\n" - "请执行上述检查并报告结果。" - ) - logger.info( - f"[Heartbeat] Running heartbeat for {analyst.name}", - ) - # Build a minimal user message and let the analyst reply - from agentscope.message import Msg - msg = Msg( - role="user", - content=hb_task, - name="system", - ) - result = await analyst.reply([msg]) - logger.info( - f"[Heartbeat] {analyst.name} heartbeat complete", - ) - continue - - logger.debug( - f"[Heartbeat] No HEARTBEAT.md for {analyst.name}, skipping", - ) - except Exception as e: - logger.error( - f"[Heartbeat] {analyst.name} failed: {e}", - exc_info=True, - ) + await gateway_cycle_support.on_heartbeat_trigger(self, date) async def _run_backtest_cycle(self, date: str, tickers: List[str]): - """Run backtest cycle with pre-loaded prices""" - self.market_service.set_backtest_date(date) - await self.market_service.emit_market_open() - - await self.state_sync.on_cycle_start(date) - self._dashboard.update(date=date, status="Analyzing...") - - prices = self.market_service.get_open_prices() - close_prices = self.market_service.get_close_prices() - market_caps = self._get_market_caps(tickers, date) - - result = await self.pipeline.run_cycle( - tickers=tickers, - date=date, - prices=prices, - close_prices=close_prices, - market_caps=market_caps, - ) - - await self.market_service.emit_market_close() - settlement_result = result.get("settlement_result") - self._save_cycle_results(result, date, close_prices, settlement_result) - await self._broadcast_portfolio_updates(result, close_prices) - await self._finalize_cycle(date) + await gateway_cycle_support.run_backtest_cycle(self, date, tickers) async def _run_live_cycle(self, date: str, tickers: List[str]): - """ - Run live cycle with real market timing. - - - Analysis runs immediately - - Daily mode waits for open/close as before - - Intraday mode executes only during market open - and skips trading outside market hours - """ - # Get actual trading date (might be next trading day if weekend) - trading_date = self.market_service.get_live_trading_date() - logger.info( - f"Live cycle: triggered={date}, trading_date={trading_date}", - ) - - await self.state_sync.on_cycle_start(trading_date) - self._dashboard.update(date=trading_date, status="Analyzing...") - - market_caps = self._get_market_caps(tickers, trading_date) - schedule_mode = self.config.get("schedule_mode", "daily") - market_status = self.market_service.get_market_status() - current_prices = self.market_service.get_all_prices() - - if schedule_mode == "intraday": - execute_decisions = market_status.get("status") == "open" - if execute_decisions: - await self.state_sync.on_system_message( - "定时任务触发:当前处于交易时段,本轮将执行交易决策", - ) - else: - await self.state_sync.on_system_message( - "定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易", - ) - - result = await self.pipeline.run_cycle( - tickers=tickers, - date=trading_date, - prices=current_prices, - market_caps=market_caps, - execute_decisions=execute_decisions, - ) - close_prices = current_prices - else: - # Daily mode keeps the original full-session behavior - result = await self.pipeline.run_cycle( - tickers=tickers, - date=trading_date, - market_caps=market_caps, - get_open_prices_fn=self.market_service.wait_for_open_prices, - get_close_prices_fn=self.market_service.wait_for_close_prices, - ) - close_prices = self.market_service.get_all_prices() - - settlement_result = result.get("settlement_result") - self._save_cycle_results( - result, - trading_date, - close_prices, - settlement_result, - ) - await self._broadcast_portfolio_updates(result, close_prices) - await self._finalize_cycle(trading_date) + await gateway_cycle_support.run_live_cycle(self, date, tickers) async def _finalize_cycle(self, date: str): - """Finalize cycle: broadcast state and update dashboard""" - summary = self.storage.load_file("summary") or {} + await gateway_cycle_support.finalize_cycle(self, date) - # Include live returns if session is active - if self.storage.is_live_session_active: - live_returns = self.storage.get_live_returns() - summary.update(live_returns) - - await self.state_sync.on_cycle_end(date, portfolio_summary=summary) - - holdings = self.storage.load_file("holdings") or [] - trades = self.storage.load_file("trades") or [] - leaderboard = self.storage.load_file("leaderboard") or [] - - if leaderboard: - await self.state_sync.on_leaderboard_update(leaderboard) - - self._dashboard.update( - date=date, - status="Running", - portfolio=summary, - holdings=holdings, - trades=trades, - ) - - def _get_market_caps( + async def _get_market_caps( self, tickers: List[str], date: str, ) -> Dict[str, float]: - """ - Get market caps for tickers (stub implementation) - - Args: - tickers: List of tickers - date: Trading date - - Returns: - Dict mapping ticker to market cap - """ - from ..tools.data_tools import get_market_cap - - market_caps = {} - for ticker in tickers: - try: - market_cap = get_market_cap(ticker, date) - if market_cap: - market_caps[ticker] = market_cap - else: - market_caps[ticker] = 1e9 - except Exception as e: - logger.warning(f"Failed to get market cap for {ticker}, using default 1e9: {e}") - market_caps[ticker] = 1e9 - - return market_caps + return await gateway_cycle_support.get_market_caps(self, tickers, date) async def _broadcast_portfolio_updates( self, result: Dict[str, Any], prices: Dict[str, float], ): - portfolio = result.get("portfolio", {}) - - if portfolio: - holdings = FrontendAdapter.build_holdings(portfolio, prices) - if holdings: - await self.state_sync.on_holdings_update(holdings) - - stats = FrontendAdapter.build_stats(portfolio, prices) - if stats: - await self.state_sync.on_stats_update(stats) - - executed_trades = result.get("executed_trades", []) - if executed_trades: - await self.state_sync.on_trades_executed(executed_trades) + await gateway_cycle_support.broadcast_portfolio_updates(self, result, prices) def _save_cycle_results( self, @@ -2786,101 +774,25 @@ class Gateway: prices: Dict[str, float], settlement_result: Optional[Dict[str, Any]] = None, ): - portfolio = result.get("portfolio", {}) - executed_trades = result.get("executed_trades", []) - - # Extract baseline values from settlement result - baseline_values = None - if settlement_result: - baseline_values = settlement_result.get("baseline_values") - - if portfolio: - self.storage.update_dashboard_after_cycle( - portfolio=portfolio, - prices=prices, - date=date, - executed_trades=executed_trades, - baseline_values=baseline_values, - ) - - async def _run_backtest_dates(self, dates: List[str]): - self.state_sync.set_backtest_dates(dates) - self._dashboard.update(days_total=len(dates), days_completed=0) - - await self.state_sync.on_system_message( - f"Starting backtest - {len(dates)} trading days", + gateway_cycle_support.save_cycle_results( + self, + result, + date, + prices, + settlement_result, ) - try: - for i, date in enumerate(dates): - self._dashboard.update(days_completed=i) - await self.on_strategy_trigger(date=date) - await asyncio.sleep(0.1) - - await self.state_sync.on_system_message( - f"Backtest complete - {len(dates)} days", - ) - - # Update dashboard with final state - summary = self.storage.load_file("summary") or {} - self._dashboard.update( - status="Complete", - portfolio=summary, - days_completed=len(dates), - ) - self._dashboard.stop() - self._dashboard.print_final_summary() - except Exception as e: - error_msg = f"Backtest failed: {type(e).__name__}: {str(e)}" - logger.error(error_msg, exc_info=True) - await self.state_sync.on_system_message(error_msg) - self._dashboard.update(status=f"Failed: {str(e)}") - self._dashboard.stop() - raise - finally: - self._backtest_task = None + async def _run_backtest_dates(self, dates: List[str]): + await gateway_cycle_support.run_backtest_dates(self, dates) def _handle_backtest_exception(self, task: asyncio.Task): - """Handle exceptions from backtest task""" - try: - task.result() - except asyncio.CancelledError: - logger.info("Backtest task was cancelled") - except Exception as e: - logger.error( - f"Backtest task failed with exception:{type(e).__name__}:{e}", - exc_info=True, - ) + gateway_cycle_support.handle_backtest_exception(self, task) def _handle_manual_cycle_exception(self, task: asyncio.Task): - """Handle exceptions from manually-triggered live cycles.""" - self._manual_cycle_task = None - try: - task.result() - except asyncio.CancelledError: - logger.info("Manual cycle task was cancelled") - except Exception as exc: - logger.error( - "Manual cycle task failed with exception:%s:%s", - type(exc).__name__, - exc, - exc_info=True, - ) + gateway_cycle_support.handle_manual_cycle_exception(self, task) def set_backtest_dates(self, dates: List[str]): - self.state_sync.set_backtest_dates(dates) - if dates: - self._backtest_start_date = dates[0] - self._backtest_end_date = dates[-1] - self._dashboard.days_total = len(dates) + gateway_cycle_support.set_backtest_dates(self, dates) def stop(self): - self.state_sync.save_state() - self.market_service.stop() - if self._backtest_task: - self._backtest_task.cancel() - if self._market_status_task: - self._market_status_task.cancel() - if self._watchlist_ingest_task: - self._watchlist_ingest_task.cancel() - self._dashboard.stop() + gateway_cycle_support.stop_gateway(self) diff --git a/backend/services/gateway_admin_handlers.py b/backend/services/gateway_admin_handlers.py new file mode 100644 index 0000000..424fb11 --- /dev/null +++ b/backend/services/gateway_admin_handlers.py @@ -0,0 +1,419 @@ +# -*- coding: utf-8 -*- +"""Runtime/workspace/skills handlers extracted from the main Gateway module.""" + +from __future__ import annotations + +import json +from datetime import datetime +from typing import Any + +from backend.agents.agent_workspace import load_agent_workspace_config +from backend.agents.skills_manager import SkillsManager +from backend.agents.toolkit_factory import load_agent_profiles +from backend.config.bootstrap_config import ( + get_bootstrap_config_for_run, + resolve_runtime_config, + update_bootstrap_values_for_run, +) +from backend.data.market_ingest import ingest_symbols +from backend.llm.models import get_agent_model_info + + +async def handle_reload_runtime_assets(gateway: Any) -> None: + config_name = gateway.config.get("config_name", "default") + runtime_config = resolve_runtime_config( + project_root=gateway._project_root, + config_name=config_name, + enable_memory=gateway.config.get("enable_memory", False), + schedule_mode=gateway.config.get("schedule_mode", "daily"), + interval_minutes=gateway.config.get("interval_minutes", 60), + trigger_time=gateway.config.get("trigger_time", "09:30"), + ) + result = gateway.pipeline.reload_runtime_assets(runtime_config=runtime_config) + runtime_updates = gateway._apply_runtime_config(runtime_config) + await gateway.state_sync.on_system_message("Runtime assets reloaded.") + await gateway.broadcast({"type": "runtime_assets_reloaded", **result, **runtime_updates}) + + +async def handle_update_runtime_config(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + updates: dict[str, Any] = {} + + schedule_mode = str(data.get("schedule_mode", "")).strip().lower() + if schedule_mode: + if schedule_mode not in {"daily", "intraday"}: + await websocket.send(json.dumps({"type": "error", "message": "schedule_mode must be 'daily' or 'intraday'."}, ensure_ascii=False)) + return + updates["schedule_mode"] = schedule_mode + + interval_minutes = data.get("interval_minutes") + if interval_minutes is not None: + try: + parsed_interval = int(interval_minutes) + except (TypeError, ValueError): + parsed_interval = 0 + if parsed_interval <= 0: + await websocket.send(json.dumps({"type": "error", "message": "interval_minutes must be a positive integer."}, ensure_ascii=False)) + return + updates["interval_minutes"] = parsed_interval + + trigger_time = data.get("trigger_time") + if trigger_time is not None: + raw_trigger = str(trigger_time).strip() + if raw_trigger and raw_trigger != "now": + try: + datetime.strptime(raw_trigger, "%H:%M") + except ValueError: + await websocket.send(json.dumps({"type": "error", "message": "trigger_time must use HH:MM or 'now'."}, ensure_ascii=False)) + return + updates["trigger_time"] = raw_trigger or "09:30" + + max_comm_cycles = data.get("max_comm_cycles") + if max_comm_cycles is not None: + try: + parsed_cycles = int(max_comm_cycles) + except (TypeError, ValueError): + parsed_cycles = 0 + if parsed_cycles <= 0: + await websocket.send(json.dumps({"type": "error", "message": "max_comm_cycles must be a positive integer."}, ensure_ascii=False)) + return + updates["max_comm_cycles"] = parsed_cycles + + initial_cash = data.get("initial_cash") + if initial_cash is not None: + try: + parsed_initial_cash = float(initial_cash) + except (TypeError, ValueError): + parsed_initial_cash = 0.0 + if parsed_initial_cash <= 0: + await websocket.send(json.dumps({"type": "error", "message": "initial_cash must be a positive number."}, ensure_ascii=False)) + return + updates["initial_cash"] = parsed_initial_cash + + margin_requirement = data.get("margin_requirement") + if margin_requirement is not None: + try: + parsed_margin_requirement = float(margin_requirement) + except (TypeError, ValueError): + parsed_margin_requirement = -1.0 + if parsed_margin_requirement < 0: + await websocket.send(json.dumps({"type": "error", "message": "margin_requirement must be a non-negative number."}, ensure_ascii=False)) + return + updates["margin_requirement"] = parsed_margin_requirement + + enable_memory = data.get("enable_memory") + if enable_memory is not None: + updates["enable_memory"] = bool(enable_memory) + + if not updates: + await websocket.send(json.dumps({"type": "error", "message": "No runtime settings were provided."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + update_bootstrap_values_for_run( + project_root=gateway._project_root, + config_name=config_name, + updates=updates, + ) + await gateway.state_sync.on_system_message("运行时调度配置已保存,正在热更新") + await handle_reload_runtime_assets(gateway) + + +async def handle_update_watchlist(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + tickers = gateway._normalize_watchlist(data.get("tickers")) + if not tickers: + await websocket.send(json.dumps({"type": "error", "message": "update_watchlist requires at least one valid ticker."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + update_bootstrap_values_for_run( + project_root=gateway._project_root, + config_name=config_name, + updates={"tickers": tickers}, + ) + await gateway.state_sync.on_system_message(f"Watchlist updated: {', '.join(tickers)}") + await gateway.broadcast({"type": "watchlist_updated", "config_name": config_name, "tickers": tickers}) + await handle_reload_runtime_assets(gateway) + gateway._schedule_watchlist_market_store_refresh(tickers) + + +async def handle_get_agent_skills(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + if not agent_id: + await websocket.send(json.dumps({"type": "error", "message": "get_agent_skills requires agent_id."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id) + agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml") + resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=config_name, agent_id=agent_id, default_skills=[])) + enabled = set(agent_config.enabled_skills) + disabled = set(agent_config.disabled_skills) + + payload = [] + for item in skills_manager.list_agent_skill_catalog(config_name, agent_id): + if item.skill_name in disabled: + status = "disabled" + elif item.skill_name in enabled: + status = "enabled" + elif item.skill_name in resolved_skills: + status = "active" + else: + status = "available" + payload.append({ + "skill_name": item.skill_name, + "name": item.name, + "description": item.description, + "version": item.version, + "source": item.source, + "tools": item.tools, + "status": status, + }) + + await websocket.send(json.dumps({ + "type": "agent_skills_loaded", + "config_name": config_name, + "agent_id": agent_id, + "skills": payload, + }, ensure_ascii=False)) + + +async def handle_get_agent_profile(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + if not agent_id: + await websocket.send(json.dumps({"type": "error", "message": "get_agent_profile requires agent_id."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id) + agent_config = load_agent_workspace_config(asset_dir / "agent.yaml") + profiles = load_agent_profiles() + profile = profiles.get(agent_id, {}) + bootstrap = get_bootstrap_config_for_run(gateway._project_root, config_name) + override = bootstrap.agent_override(agent_id) + active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", [])) + if not isinstance(active_tool_groups, list): + active_tool_groups = [] + disabled_tool_groups = agent_config.disabled_tool_groups + if disabled_tool_groups: + disabled_set = set(disabled_tool_groups) + active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set] + + default_skills = profile.get("skills", []) + if not isinstance(default_skills, list): + default_skills = [] + resolved_skills = skills_manager.resolve_agent_skill_names( + config_name=config_name, + agent_id=agent_id, + default_skills=default_skills, + ) + prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"] + model_name, model_provider = get_agent_model_info(agent_id) + + await websocket.send(json.dumps({ + "type": "agent_profile_loaded", + "config_name": config_name, + "agent_id": agent_id, + "profile": { + "model_name": model_name, + "model_provider": model_provider, + "prompt_files": prompt_files, + "default_skills": default_skills, + "resolved_skills": resolved_skills, + "active_tool_groups": active_tool_groups, + "disabled_tool_groups": disabled_tool_groups, + "enabled_skills": agent_config.enabled_skills, + "disabled_skills": agent_config.disabled_skills, + }, + }, ensure_ascii=False)) + + +async def handle_get_skill_detail(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + skill_name = str(data.get("skill_name", "")).strip() + if not skill_name: + await websocket.send(json.dumps({"type": "error", "message": "get_skill_detail requires skill_name."}, ensure_ascii=False)) + return + + skills_manager = SkillsManager(project_root=gateway._project_root) + try: + if agent_id: + config_name = gateway.config.get("config_name", "default") + detail = skills_manager.load_agent_skill_document(config_name=config_name, agent_id=agent_id, skill_name=skill_name) + else: + detail = skills_manager.load_skill_document(skill_name) + except FileNotFoundError: + await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False)) + return + + await websocket.send(json.dumps({ + "type": "skill_detail_loaded", + "agent_id": agent_id, + "skill": detail, + }, ensure_ascii=False)) + + +async def handle_create_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + skill_name = str(data.get("skill_name", "")).strip() + if not agent_id or not skill_name: + await websocket.send(json.dumps({"type": "error", "message": "create_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + try: + skills_manager.create_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name) + except (ValueError, FileExistsError) as exc: + await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False)) + return + + await gateway.state_sync.on_system_message(f"Created local skill {skill_name} for {agent_id}") + await gateway._handle_reload_runtime_assets() + await websocket.send(json.dumps({"type": "agent_local_skill_created", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False)) + await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id}) + await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name}) + + +async def handle_update_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + skill_name = str(data.get("skill_name", "")).strip() + content = data.get("content") + if not agent_id or not skill_name or not isinstance(content, str): + await websocket.send(json.dumps({"type": "error", "message": "update_agent_local_skill requires agent_id, skill_name, and string content."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + try: + skills_manager.update_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name, content=content) + except (ValueError, FileNotFoundError) as exc: + await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False)) + return + + await gateway.state_sync.on_system_message(f"Updated local skill {skill_name} for {agent_id}") + await gateway._handle_reload_runtime_assets() + await websocket.send(json.dumps({"type": "agent_local_skill_updated", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False)) + await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name}) + + +async def handle_delete_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + skill_name = str(data.get("skill_name", "")).strip() + if not agent_id or not skill_name: + await websocket.send(json.dumps({"type": "error", "message": "delete_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + try: + skills_manager.delete_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name) + skills_manager.forget_agent_skill_overrides(config_name=config_name, agent_id=agent_id, skill_names=[skill_name]) + except (ValueError, FileNotFoundError) as exc: + await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False)) + return + + await gateway.state_sync.on_system_message(f"Deleted local skill {skill_name} for {agent_id}") + await gateway._handle_reload_runtime_assets() + await websocket.send(json.dumps({"type": "agent_local_skill_deleted", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False)) + await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id}) + + +async def handle_remove_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + skill_name = str(data.get("skill_name", "")).strip() + if not agent_id or not skill_name: + await websocket.send(json.dumps({"type": "error", "message": "remove_agent_skill requires agent_id and skill_name."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + skill_names = { + item.skill_name + for item in skills_manager.list_agent_skill_catalog(config_name, agent_id) + if item.source != "local" + } + if skill_name not in skill_names: + await websocket.send(json.dumps({"type": "error", "message": f"Unknown shared skill: {skill_name}"}, ensure_ascii=False)) + return + + skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name]) + await gateway.state_sync.on_system_message(f"Removed shared skill {skill_name} from {agent_id}") + await gateway._handle_reload_runtime_assets() + await websocket.send(json.dumps({"type": "agent_skill_removed", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False)) + await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id}) + + +async def handle_update_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + skill_name = str(data.get("skill_name", "")).strip() + enabled = data.get("enabled") + if not agent_id or not skill_name or not isinstance(enabled, bool): + await websocket.send(json.dumps({"type": "error", "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + skill_names = {item.skill_name for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)} + if skill_name not in skill_names: + await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False)) + return + + if enabled: + skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, enable=[skill_name]) + await gateway.state_sync.on_system_message(f"Enabled skill {skill_name} for {agent_id}") + else: + skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name]) + await gateway.state_sync.on_system_message(f"Disabled skill {skill_name} for {agent_id}") + + await websocket.send(json.dumps({ + "type": "agent_skill_updated", + "agent_id": agent_id, + "skill_name": skill_name, + "enabled": enabled, + }, ensure_ascii=False)) + await gateway._handle_reload_runtime_assets() + await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id}) + + +async def handle_get_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + filename = gateway._normalize_agent_workspace_filename(data.get("filename")) + if not agent_id or not filename: + await websocket.send(json.dumps({"type": "error", "message": "get_agent_workspace_file requires agent_id and supported filename."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id) + asset_dir.mkdir(parents=True, exist_ok=True) + path = asset_dir / filename + content = path.read_text(encoding="utf-8") if path.exists() else "" + await websocket.send(json.dumps({ + "type": "agent_workspace_file_loaded", + "config_name": config_name, + "agent_id": agent_id, + "filename": filename, + "content": content, + }, ensure_ascii=False)) + + +async def handle_update_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + agent_id = str(data.get("agent_id", "")).strip() + filename = gateway._normalize_agent_workspace_filename(data.get("filename")) + content = data.get("content") + if not agent_id or not filename or not isinstance(content, str): + await websocket.send(json.dumps({"type": "error", "message": "update_agent_workspace_file requires agent_id, supported filename, and string content."}, ensure_ascii=False)) + return + + config_name = gateway.config.get("config_name", "default") + skills_manager = SkillsManager(project_root=gateway._project_root) + asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id) + asset_dir.mkdir(parents=True, exist_ok=True) + path = asset_dir / filename + path.write_text(content, encoding="utf-8") + await gateway.state_sync.on_system_message(f"Updated {filename} for {agent_id}") + await websocket.send(json.dumps({"type": "agent_workspace_file_updated", "agent_id": agent_id, "filename": filename}, ensure_ascii=False)) + await gateway._handle_reload_runtime_assets() + await handle_get_agent_workspace_file(gateway, websocket, {"agent_id": agent_id, "filename": filename}) diff --git a/backend/services/gateway_cycle_support.py b/backend/services/gateway_cycle_support.py new file mode 100644 index 0000000..d5a1319 --- /dev/null +++ b/backend/services/gateway_cycle_support.py @@ -0,0 +1,373 @@ +# -*- coding: utf-8 -*- +"""Cycle and monitoring helpers extracted from the main Gateway module.""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from backend.data.market_ingest import ingest_symbols +from backend.domains import trading as trading_domain +from backend.utils.msg_adapter import FrontendAdapter + +logger = logging.getLogger(__name__) + + +def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None: + """Kick off a non-blocking market-store refresh for an updated watchlist.""" + if not tickers: + return + if gateway._watchlist_ingest_task and not gateway._watchlist_ingest_task.done(): + gateway._watchlist_ingest_task.cancel() + gateway._watchlist_ingest_task = asyncio.create_task( + refresh_market_store_for_watchlist(gateway, tickers), + ) + + +async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None: + """Refresh the long-lived market store after a watchlist update.""" + try: + await gateway.state_sync.on_system_message( + f"正在同步自选股市场数据: {', '.join(tickers)}", + ) + results = await asyncio.to_thread( + ingest_symbols, + tickers, + mode="incremental", + ) + summary = ", ".join( + f"{item['symbol']} prices={item['prices']} news={item['news']}" + for item in results + ) + await gateway.state_sync.on_system_message( + f"自选股市场数据已同步: {summary}", + ) + except asyncio.CancelledError: + raise + except Exception as exc: + logger.warning("Watchlist market store refresh failed: %s", exc) + await gateway.state_sync.on_system_message( + f"自选股市场数据同步失败: {exc}", + ) + + +async def market_status_monitor(gateway: Any) -> None: + """Periodically check and broadcast market status changes.""" + while True: + try: + await gateway.market_service.check_and_broadcast_market_status() + + status = gateway.market_service.get_market_status() + if status["status"] == "open" and not gateway.storage.is_live_session_active: + gateway.storage.start_live_session() + summary = gateway.storage.load_file("summary") or {} + gateway._session_start_portfolio_value = summary.get( + "totalAssetValue", + gateway.storage.initial_cash, + ) + logger.info( + "Session start portfolio: $%s", + f"{gateway._session_start_portfolio_value:,.2f}", + ) + elif status["status"] != "open" and gateway.storage.is_live_session_active: + gateway.storage.end_live_session() + gateway._session_start_portfolio_value = None + + if gateway.storage.is_live_session_active: + await update_and_broadcast_live_returns(gateway) + + await asyncio.sleep(60) + except asyncio.CancelledError: + break + except Exception as exc: + logger.error("Market status monitor error: %s", exc) + await asyncio.sleep(60) + + +async def update_and_broadcast_live_returns(gateway: Any) -> None: + """Calculate and broadcast live returns for current session.""" + if not gateway.storage.is_live_session_active: + return + + prices = gateway.market_service.get_all_prices() + if not prices or not any(p > 0 for p in prices.values()): + return + + state = gateway.storage.load_internal_state() + equity_history = state.get("equity_history", []) + baseline_history = state.get("baseline_history", []) + baseline_vw_history = state.get("baseline_vw_history", []) + momentum_history = state.get("momentum_history", []) + + current_equity = equity_history[-1]["v"] if equity_history else None + current_baseline = baseline_history[-1]["v"] if baseline_history else None + current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None + current_momentum = momentum_history[-1]["v"] if momentum_history else None + + point = gateway.storage.update_live_returns( + current_equity=current_equity, + current_baseline=current_baseline, + current_baseline_vw=current_baseline_vw, + current_momentum=current_momentum, + ) + if point: + live_returns = gateway.storage.get_live_returns() + await gateway.broadcast( + { + "type": "team_summary", + "equity_return": live_returns["equity_return"], + "baseline_return": live_returns["baseline_return"], + "baseline_vw_return": live_returns["baseline_vw_return"], + "momentum_return": live_returns["momentum_return"], + }, + ) + + +async def on_strategy_trigger(gateway: Any, date: str) -> None: + """Handle trading cycle trigger.""" + if gateway._cycle_lock.locked(): + logger.warning("Trading cycle already running, skipping trigger for %s", date) + await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}") + return + + async with gateway._cycle_lock: + logger.info("Strategy triggered for %s", date) + tickers = gateway.config.get("tickers", []) + if gateway.is_backtest: + await run_backtest_cycle(gateway, date, tickers) + else: + await run_live_cycle(gateway, date, tickers) + + +async def on_heartbeat_trigger(gateway: Any, date: str) -> None: + """Run lightweight heartbeat check for all analysts.""" + logger.info("[Heartbeat] Running heartbeat check for %s", date) + analysts = gateway.pipeline._all_analysts() + + for analyst in analysts: + try: + ws_id = getattr(analyst, "workspace_id", None) + if ws_id: + from backend.agents.workspace_manager import get_workspace_dir + from pathlib import Path + from agentscope.message import Msg + + ws_dir = get_workspace_dir(ws_id) + if ws_dir: + hb_path = Path(ws_dir) / "HEARTBEAT.md" + if hb_path.exists(): + content = hb_path.read_text(encoding="utf-8").strip() + if content: + hb_task = f"# 定期主动检查\n\n{content}\n\n请执行上述检查并报告结果。" + logger.info("[Heartbeat] Running heartbeat for %s", analyst.name) + msg = Msg(role="user", content=hb_task, name="system") + await analyst.reply([msg]) + logger.info("[Heartbeat] %s heartbeat complete", analyst.name) + continue + logger.debug("[Heartbeat] No HEARTBEAT.md for %s, skipping", analyst.name) + except Exception as exc: + logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True) + + +async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None: + gateway.market_service.set_backtest_date(date) + await gateway.market_service.emit_market_open() + + await gateway.state_sync.on_cycle_start(date) + gateway._dashboard.update(date=date, status="Analyzing...") + + prices = gateway.market_service.get_open_prices() + close_prices = gateway.market_service.get_close_prices() + market_caps = await get_market_caps(gateway, tickers, date) + + result = await gateway.pipeline.run_cycle( + tickers=tickers, + date=date, + prices=prices, + close_prices=close_prices, + market_caps=market_caps, + ) + + await gateway.market_service.emit_market_close() + settlement_result = result.get("settlement_result") + save_cycle_results(gateway, result, date, close_prices, settlement_result) + await broadcast_portfolio_updates(gateway, result, close_prices) + await finalize_cycle(gateway, date) + + +async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None: + trading_date = gateway.market_service.get_live_trading_date() + logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date) + + await gateway.state_sync.on_cycle_start(trading_date) + gateway._dashboard.update(date=trading_date, status="Analyzing...") + + market_caps = await get_market_caps(gateway, tickers, trading_date) + schedule_mode = gateway.config.get("schedule_mode", "daily") + market_status = gateway.market_service.get_market_status() + current_prices = gateway.market_service.get_all_prices() + + if schedule_mode == "intraday": + execute_decisions = market_status.get("status") == "open" + if execute_decisions: + await gateway.state_sync.on_system_message("定时任务触发:当前处于交易时段,本轮将执行交易决策") + else: + await gateway.state_sync.on_system_message("定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易") + + result = await gateway.pipeline.run_cycle( + tickers=tickers, + date=trading_date, + prices=current_prices, + market_caps=market_caps, + execute_decisions=execute_decisions, + ) + close_prices = current_prices + else: + result = await gateway.pipeline.run_cycle( + tickers=tickers, + date=trading_date, + market_caps=market_caps, + get_open_prices_fn=gateway.market_service.wait_for_open_prices, + get_close_prices_fn=gateway.market_service.wait_for_close_prices, + ) + close_prices = gateway.market_service.get_all_prices() + + settlement_result = result.get("settlement_result") + save_cycle_results(gateway, result, trading_date, close_prices, settlement_result) + await broadcast_portfolio_updates(gateway, result, close_prices) + await finalize_cycle(gateway, trading_date) + + +async def finalize_cycle(gateway: Any, date: str) -> None: + summary = gateway.storage.load_file("summary") or {} + if gateway.storage.is_live_session_active: + summary.update(gateway.storage.get_live_returns()) + + await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary) + holdings = gateway.storage.load_file("holdings") or [] + trades = gateway.storage.load_file("trades") or [] + leaderboard = gateway.storage.load_file("leaderboard") or [] + if leaderboard: + await gateway.state_sync.on_leaderboard_update(leaderboard) + gateway._dashboard.update(date=date, status="Running", portfolio=summary, holdings=holdings, trades=trades) + + +async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]: + market_caps: dict[str, float] = {} + for ticker in tickers: + try: + market_cap = None + response = await gateway._call_trading_service( + f"get_market_cap for {ticker}", + lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date), + ) + if response is not None: + market_cap = response.get("market_cap") + if market_cap is None: + payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date) + market_cap = payload.get("market_cap") + market_caps[ticker] = market_cap if market_cap else 1e9 + except Exception as exc: + logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc) + market_caps[ticker] = 1e9 + return market_caps + + +async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None: + portfolio = result.get("portfolio", {}) + if portfolio: + holdings = FrontendAdapter.build_holdings(portfolio, prices) + if holdings: + await gateway.state_sync.on_holdings_update(holdings) + stats = FrontendAdapter.build_stats(portfolio, prices) + if stats: + await gateway.state_sync.on_stats_update(stats) + + executed_trades = result.get("executed_trades", []) + if executed_trades: + await gateway.state_sync.on_trades_executed(executed_trades) + + +def save_cycle_results( + gateway: Any, + result: dict[str, Any], + date: str, + prices: dict[str, float], + settlement_result: dict[str, Any] | None = None, +) -> None: + portfolio = result.get("portfolio", {}) + executed_trades = result.get("executed_trades", []) + baseline_values = settlement_result.get("baseline_values") if settlement_result else None + if portfolio: + gateway.storage.update_dashboard_after_cycle( + portfolio=portfolio, + prices=prices, + date=date, + executed_trades=executed_trades, + baseline_values=baseline_values, + ) + + +async def run_backtest_dates(gateway: Any, dates: list[str]) -> None: + gateway.state_sync.set_backtest_dates(dates) + gateway._dashboard.update(days_total=len(dates), days_completed=0) + await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days") + try: + for i, date in enumerate(dates): + gateway._dashboard.update(days_completed=i) + await gateway.on_strategy_trigger(date=date) + await asyncio.sleep(0.1) + await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days") + summary = gateway.storage.load_file("summary") or {} + gateway._dashboard.update(status="Complete", portfolio=summary, days_completed=len(dates)) + gateway._dashboard.stop() + gateway._dashboard.print_final_summary() + except Exception as exc: + error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}" + logger.error(error_msg, exc_info=True) + asyncio.create_task(gateway.state_sync.on_system_message(error_msg)) + gateway._dashboard.update(status=f"Failed: {str(exc)}") + gateway._dashboard.stop() + raise + finally: + gateway._backtest_task = None + + +def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None: + try: + task.result() + except asyncio.CancelledError: + logger.info("Backtest task was cancelled") + except Exception as exc: + logger.error("Backtest task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True) + + +def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None: + gateway._manual_cycle_task = None + try: + task.result() + except asyncio.CancelledError: + logger.info("Manual cycle task was cancelled") + except Exception as exc: + logger.error("Manual cycle task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True) + + +def set_backtest_dates(gateway: Any, dates: list[str]) -> None: + gateway.state_sync.set_backtest_dates(dates) + if dates: + gateway._backtest_start_date = dates[0] + gateway._backtest_end_date = dates[-1] + gateway._dashboard.days_total = len(dates) + + +def stop_gateway(gateway: Any) -> None: + gateway.state_sync.save_state() + gateway.market_service.stop() + if gateway._backtest_task: + gateway._backtest_task.cancel() + if gateway._market_status_task: + gateway._market_status_task.cancel() + if gateway._watchlist_ingest_task: + gateway._watchlist_ingest_task.cancel() + gateway._dashboard.stop() diff --git a/backend/services/gateway_runtime_support.py b/backend/services/gateway_runtime_support.py new file mode 100644 index 0000000..acfdff7 --- /dev/null +++ b/backend/services/gateway_runtime_support.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +"""Runtime/state support helpers extracted from the main Gateway module.""" + +from __future__ import annotations + +from typing import Any + +from backend.data.provider_utils import normalize_symbol + + +def normalize_watchlist(raw_tickers: Any) -> list[str]: + """Parse watchlist payloads from websocket messages.""" + if raw_tickers is None: + return [] + + if isinstance(raw_tickers, str): + candidates = raw_tickers.split(",") + elif isinstance(raw_tickers, list): + candidates = raw_tickers + else: + candidates = [raw_tickers] + + tickers: list[str] = [] + for candidate in candidates: + symbol = normalize_symbol(str(candidate).strip().strip("\"'")) + if symbol and symbol not in tickers: + tickers.append(symbol) + return tickers + + +def normalize_agent_workspace_filename( + raw_name: Any, + *, + allowlist: set[str], +) -> str | None: + """Restrict editable workspace files to a safe allowlist.""" + filename = str(raw_name or "").strip() + if filename in allowlist: + return filename + return None + + +def apply_runtime_config(gateway: Any, runtime_config: dict[str, Any]) -> dict[str, Any]: + """Apply runtime config to gateway-owned services and state.""" + warnings: list[str] = [] + + ticker_changes = gateway.market_service.update_tickers( + runtime_config.get("tickers", []), + ) + gateway.config["tickers"] = ticker_changes["active"] + + gateway.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"]) + gateway.config["max_comm_cycles"] = gateway.pipeline.max_comm_cycles + gateway.config["schedule_mode"] = runtime_config.get( + "schedule_mode", + gateway.config.get("schedule_mode", "daily"), + ) + gateway.config["interval_minutes"] = int( + runtime_config.get( + "interval_minutes", + gateway.config.get("interval_minutes", 60), + ), + ) + gateway.config["trigger_time"] = runtime_config.get( + "trigger_time", + gateway.config.get("trigger_time", "09:30"), + ) + + if gateway.scheduler: + gateway.scheduler.reconfigure( + mode=gateway.config["schedule_mode"], + trigger_time=gateway.config["trigger_time"], + interval_minutes=gateway.config["interval_minutes"], + ) + + pm_apply_result = gateway.pipeline.pm.apply_runtime_portfolio_config( + margin_requirement=runtime_config["margin_requirement"], + ) + gateway.config["margin_requirement"] = gateway.pipeline.pm.portfolio.get( + "margin_requirement", + runtime_config["margin_requirement"], + ) + + requested_initial_cash = float(runtime_config["initial_cash"]) + current_initial_cash = float(gateway.storage.initial_cash) + initial_cash_applied = requested_initial_cash == current_initial_cash + if not initial_cash_applied: + if ( + gateway.storage.can_apply_initial_cash() + and gateway.pipeline.pm.can_apply_initial_cash() + ): + initial_cash_applied = gateway.storage.apply_initial_cash( + requested_initial_cash, + ) + if initial_cash_applied: + gateway.pipeline.pm.apply_runtime_portfolio_config( + initial_cash=requested_initial_cash, + ) + gateway.config["initial_cash"] = gateway.storage.initial_cash + else: + warnings.append( + "initial_cash changed in BOOTSTRAP.md but was not applied " + "because the run already has positions, margin usage, or trades.", + ) + + requested_enable_memory = bool(runtime_config["enable_memory"]) + current_enable_memory = bool(gateway.config.get("enable_memory", False)) + if requested_enable_memory != current_enable_memory: + warnings.append( + "enable_memory changed in BOOTSTRAP.md but still requires a restart " + "because long-term memory contexts are created at startup.", + ) + + sync_runtime_state(gateway) + + return { + "runtime_config_requested": runtime_config, + "runtime_config_applied": { + "tickers": list(gateway.config.get("tickers", [])), + "schedule_mode": gateway.config.get("schedule_mode", "daily"), + "interval_minutes": gateway.config.get("interval_minutes", 60), + "trigger_time": gateway.config.get("trigger_time", "09:30"), + "initial_cash": gateway.storage.initial_cash, + "margin_requirement": gateway.config["margin_requirement"], + "max_comm_cycles": gateway.config["max_comm_cycles"], + "enable_memory": gateway.config.get("enable_memory", False), + }, + "runtime_config_status": { + "tickers": True, + "schedule_mode": True, + "interval_minutes": True, + "trigger_time": True, + "initial_cash": initial_cash_applied, + "margin_requirement": pm_apply_result["margin_requirement"], + "max_comm_cycles": True, + "enable_memory": requested_enable_memory == current_enable_memory, + }, + "ticker_changes": ticker_changes, + "runtime_config_warnings": warnings, + } + + +def sync_runtime_state(gateway: Any) -> None: + """Refresh persisted state and dashboard after runtime config changes.""" + gateway.state_sync.update_state("tickers", gateway.config.get("tickers", [])) + gateway.state_sync.update_state( + "runtime_config", + { + "tickers": gateway.config.get("tickers", []), + "schedule_mode": gateway.config.get("schedule_mode", "daily"), + "interval_minutes": gateway.config.get("interval_minutes", 60), + "trigger_time": gateway.config.get("trigger_time", "09:30"), + "initial_cash": gateway.storage.initial_cash, + "margin_requirement": gateway.config.get("margin_requirement"), + "max_comm_cycles": gateway.config.get("max_comm_cycles"), + "enable_memory": gateway.config.get("enable_memory", False), + }, + ) + + gateway.storage.update_server_state_from_dashboard(gateway.state_sync.state) + gateway.state_sync.save_state() + + gateway._dashboard.tickers = list(gateway.config.get("tickers", [])) + gateway._dashboard.initial_cash = gateway.storage.initial_cash + gateway._dashboard.enable_memory = bool(gateway.config.get("enable_memory", False)) + + summary = gateway.storage.load_file("summary") or {} + holdings = gateway.storage.load_file("holdings") or [] + trades = gateway.storage.load_file("trades") or [] + gateway._dashboard.update( + portfolio=summary, + holdings=holdings, + trades=trades, + ) diff --git a/backend/services/gateway_stock_handlers.py b/backend/services/gateway_stock_handlers.py new file mode 100644 index 0000000..5f25b17 --- /dev/null +++ b/backend/services/gateway_stock_handlers.py @@ -0,0 +1,711 @@ +# -*- coding: utf-8 -*- +"""Stock-related Gateway handlers extracted from the main Gateway module.""" + +from __future__ import annotations + +import asyncio +import json +import logging +from datetime import datetime, timedelta +from typing import Any + +from backend.data.provider_utils import normalize_symbol +from backend.domains import news as news_domain +from backend.domains import trading as trading_domain +from backend.enrich.news_enricher import enrich_news_for_symbol +from backend.enrich.llm_enricher import llm_enrichment_enabled +from backend.tools.data_tools import prices_to_df +from shared.client import NewsServiceClient, TradingServiceClient + +logger = logging.getLogger(__name__) + + +async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send(json.dumps({ + "type": "stock_history_loaded", + "ticker": "", + "prices": [], + "source": None, + "error": "invalid ticker", + }, ensure_ascii=False)) + return + + lookback_days = data.get("lookback_days", 90) + try: + lookback_days = max(7, min(int(lookback_days), 365)) + except (TypeError, ValueError): + lookback_days = 90 + + end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d") + try: + end_dt = datetime.strptime(end_date, "%Y-%m-%d") + except ValueError: + end_dt = datetime.now() + end_date = end_dt.strftime("%Y-%m-%d") + start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d") + + prices = [] + source = "polygon" + response = await gateway._call_trading_service( + "get_prices for history", + lambda client: client.get_prices(ticker=ticker, start_date=start_date, end_date=end_date), + ) + if response is not None: + prices = response.prices + source = "trading_service" + + if not prices: + prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date) + if not prices: + payload = await asyncio.to_thread( + trading_domain.get_prices_payload, + ticker=ticker, + start_date=start_date, + end_date=end_date, + ) + prices = payload.get("prices") or [] + usage_snapshot = gateway._provider_router.get_usage_snapshot() + source = usage_snapshot.get("last_success", {}).get("prices") + if prices: + await asyncio.to_thread( + gateway.storage.market_store.upsert_ohlc, + ticker, + [price.model_dump() for price in prices], + source=source or "provider", + ) + + await websocket.send(json.dumps({ + "type": "stock_history_loaded", + "ticker": ticker, + "prices": [price if isinstance(price, dict) else price.model_dump() for price in prices][-120:], + "source": source, + "start_date": start_date, + "end_date": end_date, + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_explain_events(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + snapshot = gateway.storage.runtime_db.get_stock_explain_snapshot(ticker) + await websocket.send(json.dumps({ + "type": "stock_explain_events_loaded", + "ticker": ticker, + "events": snapshot.get("events", []), + "signals": snapshot.get("signals", []), + "trades": snapshot.get("trades", []), + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_news(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send(json.dumps({ + "type": "stock_news_loaded", + "ticker": "", + "news": [], + "source": None, + "error": "invalid ticker", + }, ensure_ascii=False)) + return + + lookback_days = data.get("lookback_days", 30) + limit = data.get("limit", 12) + try: + lookback_days = max(7, min(int(lookback_days), 180)) + except (TypeError, ValueError): + lookback_days = 30 + try: + limit = max(1, min(int(limit), 30)) + except (TypeError, ValueError): + limit = 12 + + end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d") + try: + end_dt = datetime.strptime(end_date, "%Y-%m-%d") + except ValueError: + end_dt = datetime.now() + end_date = end_dt.strftime("%Y-%m-%d") + start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d") + + news_rows = [] + source = "polygon" + response = await gateway._call_news_service( + "get_enriched_news", + lambda client: client.get_enriched_news( + ticker=ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + ), + ) + if response is not None: + news_rows = response.get("news") or [] + source = "news_service" + + if not news_rows: + payload = await asyncio.to_thread( + news_domain.get_enriched_news, + gateway.storage.market_store, + ticker=ticker, + start_date=start_date, + end_date=end_date, + limit=max(limit, 50), + ) + news_rows = (payload.get("news") or [])[-limit:] + source = "market_store" + + await websocket.send(json.dumps({ + "type": "stock_news_loaded", + "ticker": ticker, + "news": news_rows[-limit:], + "source": source, + "start_date": start_date, + "end_date": end_date, + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_news_for_date(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + trade_date = str(data.get("date") or "").strip() + if not ticker or not trade_date: + await websocket.send(json.dumps({ + "type": "stock_news_for_date_loaded", + "ticker": ticker, + "date": trade_date, + "news": [], + "error": "ticker and date are required", + }, ensure_ascii=False)) + return + + limit = data.get("limit", 20) + try: + limit = max(1, min(int(limit), 50)) + except (TypeError, ValueError): + limit = 20 + + source = "market_store" + news_rows = [] + response = await gateway._call_news_service( + "get_news_for_date", + lambda client: client.get_news_for_date(ticker=ticker, date=trade_date, limit=limit), + ) + if response is not None: + news_rows = response.get("news") or [] + source = "news_service" + + if not news_rows: + payload = await asyncio.to_thread( + news_domain.get_news_for_date, + gateway.storage.market_store, + ticker=ticker, + date=trade_date, + limit=limit, + ) + news_rows = payload.get("news") or [] + source = "market_store" + + await websocket.send(json.dumps({ + "type": "stock_news_for_date_loaded", + "ticker": ticker, + "date": trade_date, + "news": news_rows, + "source": source, + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_news_timeline(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send(json.dumps({ + "type": "stock_news_timeline_loaded", + "ticker": "", + "timeline": [], + "error": "invalid ticker", + }, ensure_ascii=False)) + return + + lookback_days = data.get("lookback_days", 90) + try: + lookback_days = max(7, min(int(lookback_days), 365)) + except (TypeError, ValueError): + lookback_days = 90 + + end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d") + try: + end_dt = datetime.strptime(end_date, "%Y-%m-%d") + except ValueError: + end_dt = datetime.now() + end_date = end_dt.strftime("%Y-%m-%d") + start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d") + + timeline = [] + response = await gateway._call_news_service( + "get_news_timeline", + lambda client: client.get_news_timeline(ticker=ticker, start_date=start_date, end_date=end_date), + ) + if response is not None: + timeline = response.get("timeline") or [] + + if not timeline: + payload = await asyncio.to_thread( + news_domain.get_news_timeline, + gateway.storage.market_store, + ticker=ticker, + start_date=start_date, + end_date=end_date, + ) + timeline = payload.get("timeline") or [] + + await websocket.send(json.dumps({ + "type": "stock_news_timeline_loaded", + "ticker": ticker, + "timeline": timeline, + "start_date": start_date, + "end_date": end_date, + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_news_categories(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send(json.dumps({ + "type": "stock_news_categories_loaded", + "ticker": "", + "categories": {}, + "error": "invalid ticker", + }, ensure_ascii=False)) + return + + lookback_days = data.get("lookback_days", 90) + try: + lookback_days = max(7, min(int(lookback_days), 365)) + except (TypeError, ValueError): + lookback_days = 90 + + end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d") + try: + end_dt = datetime.strptime(end_date, "%Y-%m-%d") + except ValueError: + end_dt = datetime.now() + end_date = end_dt.strftime("%Y-%m-%d") + start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d") + + categories = {} + response = await gateway._call_news_service( + "get_categories", + lambda client: client.get_categories( + ticker=ticker, + start_date=start_date, + end_date=end_date, + limit=200, + ), + ) + if response is not None: + categories = response.get("categories") or {} + + if not categories: + payload = await asyncio.to_thread( + news_domain.get_news_categories, + gateway.storage.market_store, + ticker=ticker, + start_date=start_date, + end_date=end_date, + limit=200, + ) + categories = payload.get("categories") or {} + + await websocket.send(json.dumps({ + "type": "stock_news_categories_loaded", + "ticker": ticker, + "categories": categories, + "start_date": start_date, + "end_date": end_date, + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_range_explain(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + start_date = str(data.get("start_date") or "").strip() + end_date = str(data.get("end_date") or "").strip() + if not ticker or not start_date or not end_date: + await websocket.send(json.dumps({ + "type": "stock_range_explain_loaded", + "ticker": ticker, + "result": {"error": "ticker, start_date, end_date are required"}, + }, ensure_ascii=False)) + return + + article_ids = data.get("article_ids") + result = None + response = await gateway._call_news_service( + "get_range_explain", + lambda client: client.get_range_explain( + ticker=ticker, + start_date=start_date, + end_date=end_date, + article_ids=article_ids if isinstance(article_ids, list) else None, + limit=100, + ), + ) + if response is not None: + result = response.get("result") + + if result is None: + payload = await asyncio.to_thread( + news_domain.get_range_explain_payload, + gateway.storage.market_store, + ticker=ticker, + start_date=start_date, + end_date=end_date, + article_ids=article_ids if isinstance(article_ids, list) else None, + limit=100, + ) + result = payload.get("result") + + await websocket.send(json.dumps({ + "type": "stock_range_explain_loaded", + "ticker": ticker, + "result": result, + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send(json.dumps({ + "type": "stock_insider_trades_loaded", + "ticker": "", + "trades": [], + "error": "invalid ticker", + }, ensure_ascii=False)) + return + + end_date = str(data.get("end_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10] + start_date = str(data.get("start_date") or "").strip()[:10] + limit = int(data.get("limit", 50)) + + trades = [] + response = await gateway._call_trading_service( + "get_insider_trades", + lambda client: client.get_insider_trades( + ticker=ticker, + end_date=end_date, + start_date=start_date if start_date else None, + limit=limit, + ), + ) + if response is not None: + trades = response.insider_trades + + if not trades: + payload = await asyncio.to_thread( + trading_domain.get_insider_trades_payload, + ticker=ticker, + end_date=end_date, + start_date=start_date if start_date else None, + limit=limit, + ) + trades = payload.get("insider_trades") or [] + + sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True) + formatted_trades = [{ + "ticker": t.ticker, + "name": t.name, + "title": t.title, + "is_board_director": t.is_board_director, + "transaction_date": t.transaction_date, + "transaction_shares": t.transaction_shares, + "transaction_price_per_share": t.transaction_price_per_share, + "transaction_value": t.transaction_value, + "shares_owned_before_transaction": t.shares_owned_before_transaction, + "shares_owned_after_transaction": t.shares_owned_after_transaction, + "security_title": t.security_title, + "filing_date": t.filing_date, + "holding_change": ( + (t.shares_owned_after_transaction or 0) - (t.shares_owned_before_transaction or 0) + if t.shares_owned_after_transaction and t.shares_owned_before_transaction else None + ), + "is_buy": ((t.transaction_shares or 0) > 0) if t.transaction_shares is not None else None, + } for t in sorted_trades] + + await websocket.send(json.dumps({ + "type": "stock_insider_trades_loaded", + "ticker": ticker, + "start_date": start_date or None, + "end_date": end_date, + "trades": formatted_trades, + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_story(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send(json.dumps({ + "type": "stock_story_loaded", + "ticker": "", + "story": "", + "error": "invalid ticker", + }, ensure_ascii=False)) + return + + as_of_date = str(data.get("as_of_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10] + result = await gateway._call_news_service( + "get_story", + lambda client: client.get_story(ticker=ticker, as_of_date=as_of_date), + ) + if result is None: + result = await asyncio.to_thread( + news_domain.get_story_payload, + gateway.storage.market_store, + ticker=ticker, + as_of_date=as_of_date, + ) + + await websocket.send(json.dumps({ + "type": "stock_story_loaded", + "ticker": ticker, + "as_of_date": as_of_date, + "story": result.get("story") or "", + "source": result.get("source") or "local", + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + target_date = str(data.get("date") or "").strip()[:10] + if not ticker or not target_date: + await websocket.send(json.dumps({ + "type": "stock_similar_days_loaded", + "ticker": ticker, + "date": target_date, + "items": [], + "error": "ticker and date are required", + }, ensure_ascii=False)) + return + + top_k = data.get("top_k", 8) + try: + top_k = max(1, min(int(top_k), 20)) + except (TypeError, ValueError): + top_k = 8 + + result = await gateway._call_news_service( + "get_similar_days", + lambda client: client.get_similar_days(ticker=ticker, date=target_date, n_similar=top_k), + ) + if result is None: + result = await asyncio.to_thread( + news_domain.get_similar_days_payload, + gateway.storage.market_store, + ticker=ticker, + date=target_date, + n_similar=top_k, + ) + + await websocket.send(json.dumps({ + "type": "stock_similar_days_loaded", + "ticker": ticker, + "date": target_date, + **result, + }, ensure_ascii=False, default=str)) + + +async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + if not ticker: + await websocket.send(json.dumps({ + "type": "stock_technical_indicators_loaded", + "ticker": ticker, + "indicators": None, + "error": "ticker is required", + }, ensure_ascii=False)) + return + + try: + end_date = datetime.now() + start_date = end_date - timedelta(days=250) + + prices = None + response = await gateway._call_trading_service( + "get_prices", + lambda client: client.get_prices( + ticker=ticker, + start_date=start_date.strftime("%Y-%m-%d"), + end_date=end_date.strftime("%Y-%m-%d"), + ), + ) + if response is not None: + prices = response.prices + + if prices is None: + payload = trading_domain.get_prices_payload( + ticker=ticker, + start_date=start_date.strftime("%Y-%m-%d"), + end_date=end_date.strftime("%Y-%m-%d"), + ) + prices = payload.get("prices") or [] + + if not prices or len(prices) < 20: + await websocket.send(json.dumps({ + "type": "stock_technical_indicators_loaded", + "ticker": ticker, + "indicators": None, + "error": "Insufficient price data", + }, ensure_ascii=False)) + return + + df = prices_to_df(prices) + signal = gateway._technical_analyzer.analyze(ticker, df) + + import pandas as pd + df_sorted = df.sort_values("time").reset_index(drop=True) + df_sorted["returns"] = df_sorted["close"].pct_change() + vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None + vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None + vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None + ma_distance = {} + for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]: + ma_value = getattr(signal, ma_key, None) + ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None + + indicators = { + "ticker": ticker, + "current_price": signal.current_price, + "ma": { + "ma5": signal.ma5, + "ma10": signal.ma10, + "ma20": signal.ma20, + "ma50": signal.ma50, + "ma200": signal.ma200, + "distance": ma_distance, + }, + "rsi": { + "rsi14": signal.rsi14, + "status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral", + }, + "macd": { + "macd": signal.macd, + "signal": signal.macd_signal, + "histogram": signal.macd - signal.macd_signal, + }, + "bollinger": { + "upper": signal.bollinger_upper, + "mid": signal.bollinger_mid, + "lower": signal.bollinger_lower, + }, + "volatility": { + "vol_10d": vol_10, + "vol_20d": vol_20, + "vol_60d": vol_60, + "annualized": signal.annualized_volatility_pct, + "risk_level": signal.risk_level, + }, + "trend": signal.trend, + "mean_reversion": signal.mean_reversion_signal, + } + + await websocket.send(json.dumps({ + "type": "stock_technical_indicators_loaded", + "ticker": ticker, + "indicators": indicators, + }, ensure_ascii=False, default=str)) + except Exception as exc: + logger.exception("Error getting technical indicators for %s", ticker) + await websocket.send(json.dumps({ + "type": "stock_technical_indicators_loaded", + "ticker": ticker, + "indicators": None, + "error": str(exc), + }, ensure_ascii=False)) + + +async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None: + ticker = normalize_symbol(data.get("ticker", "")) + start_date = str(data.get("start_date") or "").strip()[:10] + end_date = str(data.get("end_date") or "").strip()[:10] + story_date = str(data.get("story_date") or end_date or "").strip()[:10] + target_date = str(data.get("target_date") or "").strip()[:10] + force = bool(data.get("force", False)) + rebuild_story = bool(data.get("rebuild_story", True)) + rebuild_similar_days = bool(data.get("rebuild_similar_days", True)) + only_local_to_llm = bool(data.get("only_local_to_llm", False)) + limit = data.get("limit", 200) + + try: + limit = max(10, min(int(limit), 500)) + except (TypeError, ValueError): + limit = 200 + + if not ticker or not start_date or not end_date: + await websocket.send(json.dumps({ + "type": "stock_enrich_completed", + "ticker": ticker, + "start_date": start_date, + "end_date": end_date, + "error": "ticker, start_date, end_date are required", + }, ensure_ascii=False)) + return + + if only_local_to_llm and not llm_enrichment_enabled(): + await websocket.send(json.dumps({ + "type": "stock_enrich_completed", + "ticker": ticker, + "start_date": start_date, + "end_date": end_date, + "error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider", + }, ensure_ascii=False)) + return + + result = await asyncio.to_thread( + enrich_news_for_symbol, + gateway.storage.market_store, + ticker, + start_date=start_date, + end_date=end_date, + limit=limit, + skip_existing=not force, + only_reanalyze_local=only_local_to_llm, + ) + + story_status = None + if rebuild_story and story_date: + await asyncio.to_thread(gateway.storage.market_store.delete_story_cache, ticker, as_of_date=story_date) + story_result = await asyncio.to_thread( + news_domain.get_story_payload, + gateway.storage.market_store, + ticker=ticker, + as_of_date=story_date, + ) + story_status = {"as_of_date": story_date, "source": story_result.get("source") or "local"} + + similar_status = None + if rebuild_similar_days and target_date: + await asyncio.to_thread(gateway.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date) + similar_result = await asyncio.to_thread( + news_domain.get_similar_days_payload, + gateway.storage.market_store, + ticker=ticker, + date=target_date, + n_similar=8, + ) + similar_status = { + "target_date": target_date, + "count": len(similar_result.get("items") or []), + "error": similar_result.get("error"), + } + + await websocket.send(json.dumps({ + "type": "stock_enrich_completed", + "ticker": ticker, + "start_date": start_date, + "end_date": end_date, + "story_date": story_date or None, + "target_date": target_date or None, + "force": force, + "only_local_to_llm": only_local_to_llm, + "stats": result, + "story_status": story_status, + "similar_status": similar_status, + }, ensure_ascii=False, default=str)) diff --git a/backend/services/research_db.py b/backend/services/research_db.py index 387e34f..e009f83 100644 --- a/backend/services/research_db.py +++ b/backend/services/research_db.py @@ -9,7 +9,7 @@ from datetime import datetime from pathlib import Path from typing import Any, Dict, Iterable -from backend.data.schema import CompanyNews +from shared.schema import CompanyNews SCHEMA = """ diff --git a/backend/tests/test_agent_service_app.py b/backend/tests/test_agent_service_app.py new file mode 100644 index 0000000..02d64d4 --- /dev/null +++ b/backend/tests/test_agent_service_app.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +"""Tests for the extracted agent service surface.""" + +from pathlib import Path + +from fastapi.testclient import TestClient + +from backend.apps.agent_service import create_app + + +def test_agent_service_routes_include_control_plane_endpoints(tmp_path): + app = create_app(project_root=tmp_path) + + paths = {route.path for route in app.routes} + + assert "/health" in paths + assert "/api/status" in paths + assert "/api/workspaces" in paths + assert "/api/guard/pending" in paths + + +def test_agent_service_excludes_runtime_routes(tmp_path): + app = create_app(project_root=tmp_path) + paths = {route.path for route in app.routes} + + assert "/api/runtime/start" not in paths + assert "/api/runtime/gateway/port" not in paths diff --git a/backend/tests/test_data_tools_service_routing.py b/backend/tests/test_data_tools_service_routing.py new file mode 100644 index 0000000..570e634 --- /dev/null +++ b/backend/tests/test_data_tools_service_routing.py @@ -0,0 +1,139 @@ +# -*- coding: utf-8 -*- +"""Tests for data_tools preferring split services when configured.""" + +from backend.tools import data_tools +from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price + + +def test_data_tools_prefers_trading_service(monkeypatch): + monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001") + monkeypatch.setenv("SERVICE_NAME", "agent_service") + monkeypatch.setattr(data_tools._cache, "get_prices", lambda key: None) + monkeypatch.setattr(data_tools._cache, "get_financial_metrics", lambda key: None) + monkeypatch.setattr(data_tools._cache, "get_insider_trades", lambda key: None) + monkeypatch.setattr(data_tools._cache, "get_company_news", lambda key: None) + + def fake_service_get_json(base_url, path, *, params): + if path == "/api/prices": + return { + "ticker": "AAPL", + "prices": [ + Price( + open=1, + close=2, + high=3, + low=1, + volume=10, + time="2026-03-16", + ).model_dump() + ], + } + if path == "/api/financials": + return { + "financial_metrics": [ + FinancialMetrics( + ticker="AAPL", + report_period="2026-03-16", + period="ttm", + currency="USD", + market_cap=123.0, + enterprise_value=None, + price_to_earnings_ratio=None, + price_to_book_ratio=None, + price_to_sales_ratio=None, + enterprise_value_to_ebitda_ratio=None, + enterprise_value_to_revenue_ratio=None, + free_cash_flow_yield=None, + peg_ratio=None, + gross_margin=None, + operating_margin=None, + net_margin=None, + return_on_equity=None, + return_on_assets=None, + return_on_invested_capital=None, + asset_turnover=None, + inventory_turnover=None, + receivables_turnover=None, + days_sales_outstanding=None, + operating_cycle=None, + working_capital_turnover=None, + current_ratio=None, + quick_ratio=None, + cash_ratio=None, + operating_cash_flow_ratio=None, + debt_to_equity=None, + debt_to_assets=None, + interest_coverage=None, + revenue_growth=None, + earnings_growth=None, + book_value_growth=None, + earnings_per_share_growth=None, + free_cash_flow_growth=None, + operating_income_growth=None, + ebitda_growth=None, + payout_ratio=None, + earnings_per_share=None, + book_value_per_share=None, + free_cash_flow_per_share=None, + ).model_dump() + ] + } + if path == "/api/insider-trades": + return { + "insider_trades": [ + InsiderTrade(ticker="AAPL", filing_date="2026-03-16").model_dump() + ] + } + if path == "/api/news": + return { + "news": [ + CompanyNews( + ticker="AAPL", + title="Title", + source="polygon", + url="https://example.com", + ).model_dump() + ] + } + if path == "/api/market-cap": + return {"ticker": "AAPL", "end_date": "2026-03-16", "market_cap": 2.5e12} + if path == "/api/line-items": + return { + "search_results": [ + LineItem( + ticker="AAPL", + report_period="2026-03-16", + period="ttm", + currency="USD", + free_cash_flow=321.0, + ).model_dump() + ] + } + raise AssertionError(path) + + monkeypatch.setattr(data_tools, "_service_get_json", fake_service_get_json) + + prices = data_tools.get_prices("AAPL", "2026-03-01", "2026-03-16") + metrics = data_tools.get_financial_metrics("AAPL", "2026-03-16") + trades = data_tools.get_insider_trades("AAPL", "2026-03-16") + news = data_tools.get_company_news("AAPL", "2026-03-16") + market_cap = data_tools.get_market_cap("AAPL", "2026-03-16") + line_items = data_tools.search_line_items( + "AAPL", + ["free_cash_flow"], + "2026-03-16", + ) + + assert prices[0].close == 2 + assert metrics[0].ticker == "AAPL" + assert trades[0].ticker == "AAPL" + assert news[0].ticker == "AAPL" + assert market_cap == 2.5e12 + assert line_items[0].free_cash_flow == 321.0 + + +def test_data_tools_skips_self_recursion_for_trading_service(monkeypatch): + monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001") + monkeypatch.setenv("SERVICE_NAME", "trading_service") + + assert data_tools._trading_service_url() is None diff --git a/backend/tests/test_gateway_explain_handlers.py b/backend/tests/test_gateway_explain_handlers.py index 0f870c0..7378e13 100644 --- a/backend/tests/test_gateway_explain_handlers.py +++ b/backend/tests/test_gateway_explain_handlers.py @@ -6,6 +6,7 @@ import pytest from backend.services.gateway import Gateway import backend.services.gateway as gateway_module +from shared.schema import InsiderTrade, InsiderTradeResponse, Price, PriceResponse class DummyWebSocket: @@ -35,6 +36,10 @@ class FakeMarketStore: def __init__(self): self.calls = [] + def get_ticker_watermarks(self, symbol): + self.calls.append(("get_ticker_watermarks", symbol)) + return {"symbol": symbol, "last_news_fetch": "2026-12-31"} + def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None): self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date)) return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}] @@ -123,6 +128,75 @@ def make_gateway(market_store=None): ) +class FakeNewsClient: + def __init__(self, base_url): + self.base_url = base_url + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + async def get_categories(self, ticker, start_date=None, end_date=None, limit=200): + return {"ticker": ticker, "categories": {"remote": {"count": 2}}} + + async def get_enriched_news(self, ticker, start_date=None, end_date=None, limit=None): + return { + "ticker": ticker, + "news": [ + { + "id": "remote-news-1", + "ticker": ticker, + "title": "Remote Title", + "date": end_date, + } + ], + } + + async def get_story(self, ticker, as_of_date): + return {"symbol": ticker, "as_of_date": as_of_date, "story": "remote story", "source": "news_service"} + + +class FakeTradingClient: + def __init__(self, base_url): + self.base_url = base_url + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return None + + async def get_insider_trades(self, ticker, end_date=None, start_date=None, limit=None): + return InsiderTradeResponse( + insider_trades=[ + InsiderTrade( + ticker=ticker, + name="Remote Insider", + filing_date=end_date or "2026-03-16", + ) + ] + ) + + async def get_prices(self, ticker, start_date=None, end_date=None): + prices = [ + Price( + open=float(100 + idx), + close=float(101 + idx), + high=float(102 + idx), + low=float(99 + idx), + volume=1000 + idx, + time=f"2026-01-{idx + 1:02d}", + ) + for idx in range(30) + ] + return PriceResponse(ticker=ticker, prices=prices) + + async def get_market_cap(self, ticker, end_date): + return {"ticker": ticker, "end_date": end_date, "market_cap": 2.5e12} + + @pytest.mark.asyncio async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument(): market_store = FakeMarketStore() @@ -135,6 +209,7 @@ async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument( ) assert market_store.calls == [ + ("get_ticker_watermarks", "AAPL"), ("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16") ] assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded" @@ -153,6 +228,7 @@ async def test_handle_get_stock_news_categories_uses_market_store_symbol_argumen ) assert market_store.calls == [ + ("get_ticker_watermarks", "AAPL"), ("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200), ("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200) ] @@ -175,7 +251,7 @@ async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch } monkeypatch.setattr( - gateway_module, + gateway_module.news_domain, "build_range_explanation", fake_build_range_explanation, ) @@ -186,6 +262,7 @@ async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch ) assert market_store.calls == [ + ("get_ticker_watermarks", "AAPL"), ("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100) ] assert websocket.messages[-1] == { @@ -207,7 +284,7 @@ async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch) websocket = DummyWebSocket() monkeypatch.setattr( - gateway_module, + gateway_module.news_domain, "build_range_explanation", lambda **kwargs: {"news_count": len(kwargs["news_rows"])}, ) @@ -222,7 +299,10 @@ async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch) }, ) - assert market_store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-99"])] + assert market_store.calls == [ + ("get_ticker_watermarks", "AAPL"), + ("get_news_by_ids_enriched", "AAPL", ["news-99"]) + ] assert websocket.messages[-1]["result"]["news_count"] == 1 @@ -238,6 +318,7 @@ async def test_handle_get_stock_news_for_date_uses_trade_date_lookup(): ) assert market_store.calls == [ + ("get_ticker_watermarks", "AAPL"), ("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10) ] assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded" @@ -251,7 +332,7 @@ async def test_handle_get_stock_story_returns_story_payload(monkeypatch): websocket = DummyWebSocket() monkeypatch.setattr( - gateway_module, + gateway_module.news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3}, ) @@ -266,6 +347,132 @@ async def test_handle_get_stock_story_returns_story_payload(monkeypatch): assert "AAPL Story" in websocket.messages[-1]["story"] +@pytest.mark.asyncio +async def test_handle_get_stock_news_categories_uses_news_service_client_when_configured(monkeypatch): + market_store = FakeMarketStore() + gateway = make_gateway(market_store) + websocket = DummyWebSocket() + + monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local") + monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient) + + await gateway._handle_get_stock_news_categories( + websocket, + {"ticker": "AAPL", "lookback_days": 30}, + ) + + assert market_store.calls == [] + assert websocket.messages[-1]["type"] == "stock_news_categories_loaded" + assert websocket.messages[-1]["categories"]["remote"]["count"] == 2 + + +@pytest.mark.asyncio +async def test_handle_get_stock_story_uses_news_service_client_when_configured(monkeypatch): + market_store = FakeMarketStore() + gateway = make_gateway(market_store) + websocket = DummyWebSocket() + + monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local") + monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient) + + await gateway._handle_get_stock_story( + websocket, + {"ticker": "AAPL", "as_of_date": "2026-03-16"}, + ) + + assert market_store.calls == [] + assert websocket.messages[-1]["type"] == "stock_story_loaded" + assert websocket.messages[-1]["story"] == "remote story" + + +@pytest.mark.asyncio +async def test_handle_get_stock_news_uses_news_service_client_when_configured(monkeypatch): + market_store = FakeMarketStore() + gateway = make_gateway(market_store) + websocket = DummyWebSocket() + + monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local") + monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient) + + await gateway._handle_get_stock_news( + websocket, + {"ticker": "AAPL", "lookback_days": 30, "limit": 5}, + ) + + assert market_store.calls == [] + assert websocket.messages[-1]["type"] == "stock_news_loaded" + assert websocket.messages[-1]["source"] == "news_service" + assert websocket.messages[-1]["news"][0]["title"] == "Remote Title" + + +@pytest.mark.asyncio +async def test_handle_get_stock_insider_trades_uses_trading_service_client_when_configured(monkeypatch): + market_store = FakeMarketStore() + gateway = make_gateway(market_store) + websocket = DummyWebSocket() + + monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local") + monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient) + + await gateway._handle_get_stock_insider_trades( + websocket, + {"ticker": "AAPL", "end_date": "2026-03-16", "limit": 10}, + ) + + assert websocket.messages[-1]["type"] == "stock_insider_trades_loaded" + assert websocket.messages[-1]["trades"][0]["name"] == "Remote Insider" + + +@pytest.mark.asyncio +async def test_handle_get_stock_history_uses_trading_service_client_when_configured(monkeypatch): + market_store = FakeMarketStore() + gateway = make_gateway(market_store) + websocket = DummyWebSocket() + + monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local") + monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient) + + await gateway._handle_get_stock_history( + websocket, + {"ticker": "AAPL", "lookback_days": 30}, + ) + + assert market_store.calls == [] + assert websocket.messages[-1]["type"] == "stock_history_loaded" + assert websocket.messages[-1]["source"] == "trading_service" + assert len(websocket.messages[-1]["prices"]) == 30 + + +@pytest.mark.asyncio +async def test_handle_get_stock_technical_indicators_uses_trading_service_client_when_configured(monkeypatch): + gateway = make_gateway(FakeMarketStore()) + websocket = DummyWebSocket() + + monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local") + monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient) + + await gateway._handle_get_stock_technical_indicators( + websocket, + {"ticker": "AAPL"}, + ) + + assert websocket.messages[-1]["type"] == "stock_technical_indicators_loaded" + assert websocket.messages[-1]["ticker"] == "AAPL" + assert websocket.messages[-1]["indicators"] is not None + + +@pytest.mark.asyncio +async def test_get_market_caps_uses_trading_service_client_when_configured(monkeypatch): + gateway = make_gateway(FakeMarketStore()) + + monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local") + monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient) + + market_caps = await gateway._get_market_caps(["AAPL", "MSFT"], "2026-03-16") + + assert market_caps == {"AAPL": 2.5e12, "MSFT": 2.5e12} + + @pytest.mark.asyncio async def test_handle_get_stock_similar_days_returns_items(monkeypatch): market_store = FakeMarketStore() @@ -273,7 +480,7 @@ async def test_handle_get_stock_similar_days_returns_items(monkeypatch): websocket = DummyWebSocket() monkeypatch.setattr( - gateway_module, + gateway_module.news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3}, ) @@ -295,7 +502,12 @@ async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch): websocket = DummyWebSocket() monkeypatch.setattr( - gateway_module, + gateway_module.gateway_stock_handlers, + "enrich_news_for_symbol", + lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2}, + ) + monkeypatch.setattr( + gateway_module.news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2}, ) @@ -325,7 +537,7 @@ async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypa gateway = make_gateway(FakeMarketStore()) websocket = DummyWebSocket() - monkeypatch.setattr(gateway_module, "llm_enrichment_enabled", lambda: False) + monkeypatch.setattr(gateway_module.gateway_stock_handlers, "llm_enrichment_enabled", lambda: False) await gateway._handle_run_stock_enrich( websocket, @@ -361,7 +573,7 @@ def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch): gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"]) - assert captured["coro_name"] == "_refresh_market_store_for_watchlist" + assert captured["coro_name"] == "refresh_market_store_for_watchlist" @pytest.mark.asyncio @@ -369,7 +581,7 @@ async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypa gateway = make_gateway() monkeypatch.setattr( - gateway_module, + gateway_module.gateway_cycle_support, "ingest_symbols", lambda symbols, mode="incremental": [ {"symbol": symbol, "prices": 3, "news": 4, "aligned": 4} @@ -445,12 +657,12 @@ async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatc websocket = DummyWebSocket() monkeypatch.setattr( - gateway_module, + gateway_module.gateway_admin_handlers, "load_agent_profiles", lambda: {"risk_manager": {"skills": ["risk_review"], "active_tool_groups": ["risk_ops", "legacy_group"]}}, ) monkeypatch.setattr( - gateway_module, + gateway_module.gateway_admin_handlers, "get_agent_model_info", lambda agent_id: ("gpt-4o-mini", "OPENAI"), ) @@ -461,7 +673,7 @@ async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatc return {} monkeypatch.setattr( - gateway_module, + gateway_module.gateway_admin_handlers, "get_bootstrap_config_for_run", lambda project_root, config_name: _Bootstrap(), ) diff --git a/backend/tests/test_gateway_support_modules.py b/backend/tests/test_gateway_support_modules.py new file mode 100644 index 0000000..1812828 --- /dev/null +++ b/backend/tests/test_gateway_support_modules.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +"""Direct tests for Gateway support modules.""" + +from types import SimpleNamespace + +import pytest + +from backend.services import gateway_cycle_support, gateway_runtime_support + + +class _DummyDashboard: + def __init__(self): + self.updated = [] + self.tickers = [] + self.initial_cash = None + self.enable_memory = False + self.days_total = 0 + + def update(self, **kwargs): + self.updated.append(kwargs) + + def stop(self): + return None + + def print_final_summary(self): + return None + + +class _DummyScheduler: + def __init__(self): + self.calls = [] + + def reconfigure(self, **kwargs): + self.calls.append(kwargs) + + +class _DummyStateSync: + def __init__(self): + self.updated = [] + self.saved = False + self.system_messages = [] + self.backtest_dates = [] + self.state = {} + + def update_state(self, key, value): + self.updated.append((key, value)) + self.state[key] = value + + def save_state(self): + self.saved = True + + async def on_system_message(self, message): + self.system_messages.append(message) + + def set_backtest_dates(self, dates): + self.backtest_dates = list(dates) + + +class _DummyStorage: + def __init__(self): + self.initial_cash = 100000.0 + self.is_live_session_active = False + self.server_state_updates = [] + + def can_apply_initial_cash(self): + return True + + def apply_initial_cash(self, value): + self.initial_cash = value + return True + + def update_server_state_from_dashboard(self, state): + self.server_state_updates.append(state) + + def load_file(self, name): + if name == "summary": + return {"totalAssetValue": self.initial_cash} + return [] + + +class _DummyPM: + def __init__(self): + self.portfolio = {"margin_requirement": 0.0} + + def apply_runtime_portfolio_config(self, margin_requirement=None, initial_cash=None): + if margin_requirement is not None: + self.portfolio["margin_requirement"] = margin_requirement + return {"margin_requirement": True} + + def can_apply_initial_cash(self): + return True + + +class _DummyMarketService: + def __init__(self): + self.updated = None + self.stopped = False + + def update_tickers(self, tickers): + self.updated = list(tickers) + return {"active": list(tickers), "added": list(tickers), "removed": []} + + def stop(self): + self.stopped = True + + +def make_gateway_stub(): + pipeline = SimpleNamespace(max_comm_cycles=0, pm=_DummyPM()) + gateway = SimpleNamespace( + market_service=_DummyMarketService(), + pipeline=pipeline, + scheduler=_DummyScheduler(), + config={ + "tickers": ["AAPL"], + "schedule_mode": "daily", + "interval_minutes": 60, + "trigger_time": "09:30", + "enable_memory": False, + }, + storage=_DummyStorage(), + state_sync=_DummyStateSync(), + _dashboard=_DummyDashboard(), + _watchlist_ingest_task=None, + _market_status_task=None, + _backtest_task=None, + _backtest_start_date=None, + _backtest_end_date=None, + _manual_cycle_task=None, + ) + return gateway + + +def test_normalize_watchlist_filters_invalid_and_dedupes(): + assert gateway_runtime_support.normalize_watchlist(["aapl", " AAPL ", "", "msft"]) == ["AAPL", "MSFT"] + assert gateway_runtime_support.normalize_watchlist("aapl,msft") == ["AAPL", "MSFT"] + + +def test_normalize_agent_workspace_filename_obeys_allowlist(): + allowlist = {"SOUL.md", "PROFILE.md"} + assert gateway_runtime_support.normalize_agent_workspace_filename("SOUL.md", allowlist=allowlist) == "SOUL.md" + assert gateway_runtime_support.normalize_agent_workspace_filename("README.md", allowlist=allowlist) is None + + +def test_apply_runtime_config_updates_gateway_state(): + gateway = make_gateway_stub() + + result = gateway_runtime_support.apply_runtime_config( + gateway, + { + "tickers": ["MSFT", "NVDA"], + "schedule_mode": "intraday", + "interval_minutes": 30, + "trigger_time": "10:30", + "initial_cash": 150000.0, + "margin_requirement": 0.5, + "max_comm_cycles": 4, + "enable_memory": False, + }, + ) + + assert gateway.config["tickers"] == ["MSFT", "NVDA"] + assert gateway.config["schedule_mode"] == "intraday" + assert gateway.storage.initial_cash == 150000.0 + assert result["runtime_config_applied"]["max_comm_cycles"] == 4 + assert gateway.scheduler.calls[-1] == { + "mode": "intraday", + "trigger_time": "10:30", + "interval_minutes": 30, + } + + +def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch): + gateway = make_gateway_stub() + captured = {} + + class DummyTask: + def done(self): + return False + + def cancel(self): + captured["cancelled"] = True + + def fake_create_task(coro): + captured["name"] = coro.cr_code.co_name + coro.close() + return DummyTask() + + monkeypatch.setattr(gateway_cycle_support.asyncio, "create_task", fake_create_task) + + gateway_cycle_support.schedule_watchlist_market_store_refresh(gateway, ["AAPL", "MSFT"]) + + assert captured["name"] == "refresh_market_store_for_watchlist" + + +@pytest.mark.asyncio +async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch): + gateway = make_gateway_stub() + + monkeypatch.setattr( + gateway_cycle_support, + "ingest_symbols", + lambda symbols, mode="incremental": [ + {"symbol": symbol, "prices": 3, "news": 4} + for symbol in symbols + ], + ) + + await gateway_cycle_support.refresh_market_store_for_watchlist(gateway, ["AAPL", "MSFT"]) + + assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT" + assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1] diff --git a/backend/tests/test_news_domain.py b/backend/tests/test_news_domain.py new file mode 100644 index 0000000..2b8cc7f --- /dev/null +++ b/backend/tests/test_news_domain.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +"""Unit tests for the news domain helpers.""" + +from backend.domains import news as news_domain + + +class _FakeStore: + def __init__(self): + self.calls = [] + + def get_ticker_watermarks(self, symbol): + self.calls.append(("get_ticker_watermarks", symbol)) + return {"symbol": symbol, "last_news_fetch": "2026-03-10"} + + def get_news_items_enriched(self, ticker, start_date=None, end_date=None, trade_date=None, limit=100): + self.calls.append(("get_news_items_enriched", ticker, start_date, end_date, trade_date, limit)) + target = trade_date or end_date + return [{"id": "n1", "ticker": ticker, "date": target, "trade_date": target}] + + def get_news_timeline_enriched(self, ticker, start_date=None, end_date=None): + self.calls.append(("get_news_timeline_enriched", ticker, start_date, end_date)) + return [{"date": end_date, "count": 1}] + + def get_news_categories_enriched(self, ticker, start_date=None, end_date=None, limit=200): + self.calls.append(("get_news_categories_enriched", ticker, start_date, end_date, limit)) + return {"macro": {"count": 1}} + + def get_news_by_ids_enriched(self, ticker, article_ids): + self.calls.append(("get_news_by_ids_enriched", ticker, list(article_ids))) + return [{"id": article_ids[0], "ticker": ticker, "date": "2026-03-16"}] + + +def test_news_rows_need_enrichment_detects_missing_fields(): + assert news_domain.news_rows_need_enrichment([]) is True + assert news_domain.news_rows_need_enrichment([{"sentiment": "", "relevance": "", "key_discussion": ""}]) is True + assert news_domain.news_rows_need_enrichment([{"sentiment": "positive"}]) is False + + +def test_ensure_news_fresh_triggers_incremental_refresh_when_watermark_is_stale(monkeypatch): + store = _FakeStore() + calls = [] + + monkeypatch.setattr( + news_domain, + "update_ticker_incremental", + lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)), + ) + + payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16") + + assert calls == [("AAPL", "2026-03-16")] + assert payload["target_date"] == "2026-03-16" + assert payload["refreshed"] is True + + +def test_ensure_news_fresh_skips_refresh_when_watermark_is_current(monkeypatch): + store = _FakeStore() + calls = [] + + monkeypatch.setattr( + store, + "get_ticker_watermarks", + lambda symbol: {"symbol": symbol, "last_news_fetch": "2026-03-16"}, + ) + monkeypatch.setattr( + news_domain, + "update_ticker_incremental", + lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)), + ) + + payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16") + + assert calls == [] + assert payload["refreshed"] is False + + +def test_get_enriched_news_returns_rows_without_enrichment_when_present(monkeypatch): + store = _FakeStore() + monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False) + monkeypatch.setattr( + news_domain, + "ensure_news_fresh", + lambda store, ticker, target_date=None: { + "ticker": ticker, + "target_date": target_date, + "last_news_fetch": target_date, + "refreshed": False, + }, + ) + + payload = news_domain.get_enriched_news( + store, + ticker="AAPL", + start_date="2026-03-01", + end_date="2026-03-16", + limit=20, + ) + + assert payload["ticker"] == "AAPL" + assert payload["news"][0]["ticker"] == "AAPL" + assert payload["freshness"]["target_date"] is None or payload["freshness"]["target_date"] == "2026-03-16" + assert store.calls == [ + ("get_news_items_enriched", "AAPL", "2026-03-01", "2026-03-16", None, 20) + ] + + +def test_get_story_and_similar_days_delegate(monkeypatch): + store = _FakeStore() + monkeypatch.setattr( + news_domain, + "ensure_news_fresh", + lambda store, ticker, target_date=None: { + "ticker": ticker, + "target_date": target_date, + "last_news_fetch": target_date, + "refreshed": False, + }, + ) + monkeypatch.setattr(news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"analyzed": 1}) + monkeypatch.setattr( + news_domain, + "get_or_create_stock_story", + lambda store, symbol, as_of_date: {"symbol": symbol, "as_of_date": as_of_date, "story": "story"}, + ) + monkeypatch.setattr( + news_domain, + "find_similar_days", + lambda store, symbol, target_date, top_k: {"symbol": symbol, "target_date": target_date, "items": [{"score": 0.9}]}, + ) + + story = news_domain.get_story_payload(store, ticker="AAPL", as_of_date="2026-03-16") + similar = news_domain.get_similar_days_payload(store, ticker="AAPL", date="2026-03-16", n_similar=8) + + assert story["story"] == "story" + assert "freshness" in story + assert similar["items"][0]["score"] == 0.9 + assert "freshness" in similar + + +def test_get_range_explain_payload_uses_article_ids(monkeypatch): + store = _FakeStore() + monkeypatch.setattr( + news_domain, + "ensure_news_fresh", + lambda store, ticker, target_date=None: { + "ticker": ticker, + "target_date": target_date, + "last_news_fetch": target_date, + "refreshed": False, + }, + ) + monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False) + monkeypatch.setattr( + news_domain, + "build_range_explanation", + lambda ticker, start_date, end_date, news_rows: {"ticker": ticker, "count": len(news_rows)}, + ) + + payload = news_domain.get_range_explain_payload( + store, + ticker="AAPL", + start_date="2026-03-10", + end_date="2026-03-16", + article_ids=["news-9"], + limit=50, + ) + + assert payload["ticker"] == "AAPL" + assert payload["result"] == {"ticker": "AAPL", "count": 1} + assert "freshness" in payload + assert store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-9"])] diff --git a/backend/tests/test_news_service_app.py b/backend/tests/test_news_service_app.py new file mode 100644 index 0000000..3f15ef2 --- /dev/null +++ b/backend/tests/test_news_service_app.py @@ -0,0 +1,180 @@ +# -*- coding: utf-8 -*- +"""Tests for the extracted news service app surface.""" + +from fastapi.testclient import TestClient + +from backend.apps.news_service import create_app + + +class _FakeStore: + def get_ticker_watermarks(self, symbol): + return {"symbol": symbol, "last_news_fetch": "2026-12-31"} + + def get_news_timeline_enriched(self, symbol, start_date=None, end_date=None): + return [{"date": end_date, "count": 1}] + + def get_news_items(self, symbol, start_date=None, end_date=None, limit=100): + return [{"id": "news-raw-1", "ticker": symbol, "title": "Raw Title", "date": end_date}] + + def get_news_items_enriched(self, symbol, start_date=None, end_date=None, trade_date=None, limit=100): + return [{"id": "news-1", "ticker": symbol, "title": "Title", "date": trade_date or end_date}] + + def upsert_news_analysis(self, symbol, rows): + return len(rows) + + def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None): + return set() + + def get_news_categories_enriched(self, symbol, start_date=None, end_date=None, limit=200): + return {"market": {"label": "market", "count": 1, "article_ids": ["news-1"]}} + + def get_news_by_ids_enriched(self, symbol, article_ids): + return [{"id": article_ids[0], "ticker": symbol, "title": "Picked"}] + + +def test_news_service_routes_are_exposed(): + app = create_app() + paths = {route.path for route in app.routes} + + assert "/health" in paths + assert "/api/enriched-news" in paths + assert "/api/news-for-date" in paths + assert "/api/news-timeline" in paths + assert "/api/categories" in paths + assert "/api/similar-days" in paths + assert "/api/stories/{ticker}" in paths + assert "/api/range-explain" in paths + + +def test_news_service_enriched_news_and_categories(monkeypatch): + app = create_app() + app.dependency_overrides.clear() + from backend.apps import news_service as news_service_module + + app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore() + monkeypatch.setattr( + "backend.domains.news.enrich_news_for_symbol", + lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1}, + ) + + with TestClient(app) as client: + news_response = client.get( + "/api/enriched-news", + params={"ticker": "AAPL", "end_date": "2026-03-23"}, + ) + categories_response = client.get( + "/api/categories", + params={"ticker": "AAPL", "end_date": "2026-03-23"}, + ) + + assert news_response.status_code == 200 + assert news_response.json()["news"][0]["ticker"] == "AAPL" + assert categories_response.status_code == 200 + assert categories_response.json()["categories"]["market"]["count"] == 1 + + +def test_news_service_news_for_date_and_timeline(monkeypatch): + app = create_app() + from backend.apps import news_service as news_service_module + + app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore() + monkeypatch.setattr( + "backend.domains.news.enrich_news_for_symbol", + lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1}, + ) + + with TestClient(app) as client: + date_response = client.get( + "/api/news-for-date", + params={"ticker": "AAPL", "date": "2026-03-23"}, + ) + timeline_response = client.get( + "/api/news-timeline", + params={ + "ticker": "AAPL", + "start_date": "2026-03-01", + "end_date": "2026-03-23", + }, + ) + + assert date_response.status_code == 200 + assert date_response.json()["date"] == "2026-03-23" + assert timeline_response.status_code == 200 + assert timeline_response.json()["timeline"][0]["count"] == 1 + + +def test_news_service_similar_days_and_story(monkeypatch): + app = create_app() + from backend.apps import news_service as news_service_module + + app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore() + monkeypatch.setattr( + "backend.domains.news.enrich_news_for_symbol", + lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1}, + ) + monkeypatch.setattr( + "backend.domains.news.find_similar_days", + lambda store, symbol, target_date, top_k: { + "symbol": symbol, + "target_date": target_date, + "items": [{"date": "2026-03-20", "score": 0.9}], + }, + ) + monkeypatch.setattr( + "backend.domains.news.get_or_create_stock_story", + lambda store, symbol, as_of_date: { + "symbol": symbol, + "as_of_date": as_of_date, + "story": "story body", + "source": "local", + }, + ) + + with TestClient(app) as client: + similar_response = client.get( + "/api/similar-days", + params={"ticker": "AAPL", "date": "2026-03-23", "n_similar": 3}, + ) + story_response = client.get( + "/api/stories/AAPL", + params={"as_of_date": "2026-03-23"}, + ) + + assert similar_response.status_code == 200 + assert similar_response.json()["items"][0]["score"] == 0.9 + assert story_response.status_code == 200 + assert story_response.json()["story"] == "story body" + + +def test_news_service_range_explain(monkeypatch): + app = create_app() + from backend.apps import news_service as news_service_module + + app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore() + monkeypatch.setattr( + "backend.domains.news.enrich_news_for_symbol", + lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1}, + ) + monkeypatch.setattr( + "backend.domains.news.build_range_explanation", + lambda ticker, start_date, end_date, news_rows: { + "symbol": ticker, + "news_count": len(news_rows), + "start_date": start_date, + "end_date": end_date, + }, + ) + + with TestClient(app) as client: + response = client.get( + "/api/range-explain", + params={ + "ticker": "AAPL", + "start_date": "2026-03-01", + "end_date": "2026-03-23", + "article_ids": ["news-7"], + }, + ) + + assert response.status_code == 200 + assert response.json()["result"]["news_count"] == 1 diff --git a/backend/tests/test_provider_router.py b/backend/tests/test_provider_router.py index f7ef926..cdb5487 100644 --- a/backend/tests/test_provider_router.py +++ b/backend/tests/test_provider_router.py @@ -9,6 +9,7 @@ def test_router_includes_local_csv_fallback(monkeypatch): monkeypatch.delenv("FINNHUB_API_KEY", raising=False) monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False) monkeypatch.delenv("FIN_DATA_SOURCE", raising=False) + monkeypatch.delenv("ENABLED_DATA_SOURCES", raising=False) reset_config() router = DataProviderRouter() diff --git a/backend/tests/test_runtime_service_app.py b/backend/tests/test_runtime_service_app.py new file mode 100644 index 0000000..2c406fa --- /dev/null +++ b/backend/tests/test_runtime_service_app.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +"""Tests for the extracted runtime service app surface.""" + +import json + +from fastapi.testclient import TestClient + +from backend.api import runtime as runtime_module +from backend.apps.runtime_service import create_app + + +def test_runtime_service_routes_are_exposed(): + app = create_app() + paths = {route.path for route in app.routes} + + assert "/health" in paths + assert "/api/status" in paths + assert "/api/runtime/start" in paths + assert "/api/runtime/stop" in paths + assert "/api/runtime/current" in paths + assert "/api/runtime/gateway/port" in paths + + +def test_runtime_service_health_and_status(monkeypatch): + runtime_state = runtime_module.get_runtime_state() + runtime_state.gateway_process = None + runtime_state.gateway_port = 9876 + runtime_state.runtime_manager = object() + + with TestClient(create_app()) as client: + health_response = client.get("/health") + status_response = client.get("/api/status") + + assert health_response.status_code == 200 + assert health_response.json() == { + "status": "healthy", + "service": "runtime-service", + "gateway_running": False, + "gateway_port": 9876, + } + assert status_response.status_code == 200 + assert status_response.json() == { + "status": "operational", + "service": "runtime-service", + "runtime": { + "gateway_running": False, + "gateway_port": 9876, + "has_runtime_manager": True, + }, + } + + +def test_runtime_service_gateway_port_endpoint_uses_runtime_router(monkeypatch): + runtime_module.get_runtime_state().gateway_port = 9345 + monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True) + + with TestClient(create_app()) as client: + response = client.get( + "/api/runtime/gateway/port", + headers={"host": "runtime.example:8003", "x-forwarded-proto": "https"}, + ) + + assert response.status_code == 200 + assert response.json() == { + "port": 9345, + "is_running": True, + "ws_url": "wss://runtime.example:9345", + } + + +def test_runtime_service_get_runtime_config(monkeypatch, tmp_path): + run_dir = tmp_path / "runs" / "demo" + state_dir = run_dir / "state" + state_dir.mkdir(parents=True) + (run_dir / "BOOTSTRAP.md").write_text( + "---\n" + "tickers:\n" + " - AAPL\n" + "schedule_mode: intraday\n" + "interval_minutes: 30\n" + "trigger_time: '10:00'\n" + "max_comm_cycles: 3\n" + "enable_memory: true\n" + "---\n", + encoding="utf-8", + ) + (state_dir / "runtime_state.json").write_text( + json.dumps( + { + "context": { + "config_name": "demo", + "run_dir": str(run_dir), + "bootstrap_values": { + "tickers": ["AAPL"], + "schedule_mode": "intraday", + "interval_minutes": 30, + "trigger_time": "10:00", + "max_comm_cycles": 3, + "enable_memory": True, + }, + } + } + ), + encoding="utf-8", + ) + monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path) + monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True) + runtime_module.get_runtime_state().gateway_port = 8765 + + with TestClient(create_app()) as client: + response = client.get("/api/runtime/config") + + assert response.status_code == 200 + payload = response.json() + assert payload["run_id"] == "demo" + assert payload["bootstrap"]["schedule_mode"] == "intraday" + assert payload["resolved"]["interval_minutes"] == 30 + assert payload["resolved"]["enable_memory"] is True + + +def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, tmp_path): + run_dir = tmp_path / "runs" / "demo" + state_dir = run_dir / "state" + state_dir.mkdir(parents=True) + (run_dir / "BOOTSTRAP.md").write_text( + "---\n" + "tickers:\n" + " - AAPL\n" + "schedule_mode: daily\n" + "interval_minutes: 60\n" + "trigger_time: '09:30'\n" + "max_comm_cycles: 2\n" + "---\n", + encoding="utf-8", + ) + (state_dir / "runtime_state.json").write_text( + json.dumps( + { + "context": { + "config_name": "demo", + "run_dir": str(run_dir), + "bootstrap_values": { + "tickers": ["AAPL"], + "schedule_mode": "daily", + "interval_minutes": 60, + "trigger_time": "09:30", + "max_comm_cycles": 2, + }, + } + } + ), + encoding="utf-8", + ) + + class _DummyContext: + def __init__(self): + self.bootstrap_values = { + "tickers": ["AAPL"], + "schedule_mode": "daily", + "interval_minutes": 60, + "trigger_time": "09:30", + "max_comm_cycles": 2, + } + + class _DummyManager: + def __init__(self): + self.config_name = "demo" + self.bootstrap = dict(_DummyContext().bootstrap_values) + self.context = _DummyContext() + + def _persist_snapshot(self): + return None + + monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path) + monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True) + runtime_module.get_runtime_state().runtime_manager = _DummyManager() + runtime_module.get_runtime_state().gateway_port = 8765 + + with TestClient(create_app()) as client: + response = client.put( + "/api/runtime/config", + json={ + "schedule_mode": "intraday", + "interval_minutes": 15, + "trigger_time": "10:15", + "max_comm_cycles": 4, + }, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["bootstrap"]["schedule_mode"] == "intraday" + assert payload["resolved"]["interval_minutes"] == 15 + assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8") diff --git a/backend/tests/test_service_clients.py b/backend/tests/test_service_clients.py new file mode 100644 index 0000000..19cc677 --- /dev/null +++ b/backend/tests/test_service_clients.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +"""Tests for split-aware shared service clients.""" + +import pytest + +from shared.client.control_client import ControlPlaneClient +from shared.client.runtime_client import RuntimeServiceClient + + +class _DummyResponse: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +class _DummyAsyncClient: + def __init__(self): + self.calls = [] + + async def get(self, path, params=None): + self.calls.append(("get", path, params)) + return _DummyResponse({"path": path, "params": params}) + + async def post(self, path, json=None): + self.calls.append(("post", path, json)) + return _DummyResponse({"path": path, "json": json}) + + async def put(self, path, json=None): + self.calls.append(("put", path, json)) + return _DummyResponse({"path": path, "json": json}) + + async def aclose(self): + return None + + +@pytest.mark.asyncio +async def test_control_plane_client_hits_current_workspace_and_guard_routes(): + client = ControlPlaneClient() + client._client = _DummyAsyncClient() + + await client.list_workspaces() + await client.get_workspace("demo") + await client.list_agents("demo") + await client.get_agent("demo", "risk_manager") + await client.fetch_pending_approvals() + await client.approve_pending_approval("ap-1") + await client.deny_pending_approval("ap-2", reason="nope") + + assert client._client.calls == [ + ("get", "/workspaces", None), + ("get", "/workspaces/demo", None), + ("get", "/workspaces/demo/agents", None), + ("get", "/workspaces/demo/agents/risk_manager", None), + ("get", "/guard/pending", None), + ( + "post", + "/guard/approve", + { + "approval_id": "ap-1", + "one_time": True, + "expires_in_minutes": 30, + }, + ), + ( + "post", + "/guard/deny", + { + "approval_id": "ap-2", + "reason": "nope", + }, + ), + ] + + +@pytest.mark.asyncio +async def test_runtime_service_client_hits_current_runtime_routes(): + client = RuntimeServiceClient() + client._client = _DummyAsyncClient() + + await client.fetch_context() + await client.fetch_agents() + await client.fetch_events() + await client.fetch_gateway_port() + await client.start_runtime({"tickers": ["AAPL"]}) + await client.stop_runtime(force=True) + await client.restart_runtime({"tickers": ["MSFT"]}) + await client.fetch_current_runtime() + await client.get_runtime_config() + await client.update_runtime_config({"schedule_mode": "intraday"}) + + assert client._client.calls == [ + ("get", "/context", None), + ("get", "/agents", None), + ("get", "/events", None), + ("get", "/gateway/port", None), + ("post", "/start", {"tickers": ["AAPL"]}), + ("post", "/stop?force=true", None), + ("post", "/restart", {"tickers": ["MSFT"]}), + ("get", "/current", None), + ("get", "/config", None), + ("put", "/config", {"schedule_mode": "intraday"}), + ] diff --git a/backend/tests/test_shared_schema_bridge.py b/backend/tests/test_shared_schema_bridge.py new file mode 100644 index 0000000..b7e91bb --- /dev/null +++ b/backend/tests/test_shared_schema_bridge.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +"""Regression coverage for the shared schema bridge.""" + +from backend.data import schema as legacy_schema +from shared import schema as shared_schema + + +def test_backend_data_schema_reexports_shared_contracts(): + assert legacy_schema.Price is shared_schema.Price + assert legacy_schema.PriceResponse is shared_schema.PriceResponse + assert legacy_schema.FinancialMetrics is shared_schema.FinancialMetrics + assert legacy_schema.FinancialMetricsResponse is ( + shared_schema.FinancialMetricsResponse + ) + assert legacy_schema.LineItem is shared_schema.LineItem + assert legacy_schema.LineItemResponse is shared_schema.LineItemResponse + assert legacy_schema.InsiderTrade is shared_schema.InsiderTrade + assert legacy_schema.InsiderTradeResponse is ( + shared_schema.InsiderTradeResponse + ) + assert legacy_schema.CompanyNews is shared_schema.CompanyNews + assert legacy_schema.CompanyNewsResponse is shared_schema.CompanyNewsResponse + assert legacy_schema.CompanyFacts is shared_schema.CompanyFacts + assert legacy_schema.CompanyFactsResponse is ( + shared_schema.CompanyFactsResponse + ) + assert legacy_schema.Position is shared_schema.Position + assert legacy_schema.Portfolio is shared_schema.Portfolio + assert legacy_schema.AnalystSignal is shared_schema.AnalystSignal + assert legacy_schema.TickerAnalysis is shared_schema.TickerAnalysis + assert legacy_schema.AgentStateData is shared_schema.AgentStateData + assert legacy_schema.AgentStateMetadata is shared_schema.AgentStateMetadata diff --git a/backend/tests/test_trading_domain.py b/backend/tests/test_trading_domain.py new file mode 100644 index 0000000..d248d57 --- /dev/null +++ b/backend/tests/test_trading_domain.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +"""Unit tests for the trading domain helpers.""" + +from backend.domains import trading as trading_domain + + +def test_trading_domain_payload_wrappers(monkeypatch): + monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}]) + monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}]) + monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}]) + monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}]) + monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12) + + assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == { + "ticker": "AAPL", + "prices": [{"close": 1}], + } + assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == { + "financial_metrics": [{"ticker": "AAPL"}], + } + assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == { + "news": [{"ticker": "AAPL"}], + } + assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == { + "insider_trades": [{"ticker": "AAPL"}], + } + assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == { + "ticker": "AAPL", + "end_date": "2026-03-16", + "market_cap": 2.5e12, + } + + +def test_get_market_status_payload_uses_market_service(monkeypatch): + class _FakeMarketService: + def __init__(self, tickers): + self.tickers = tickers + + def get_market_status(self): + return {"status": "open", "status_text": "Open"} + + monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService) + + assert trading_domain.get_market_status_payload() == { + "status": "open", + "status_text": "Open", + } diff --git a/backend/tests/test_trading_service_app.py b/backend/tests/test_trading_service_app.py new file mode 100644 index 0000000..1a7e9ea --- /dev/null +++ b/backend/tests/test_trading_service_app.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +"""Tests for the extracted trading service app surface.""" + +from fastapi.testclient import TestClient + +from backend.apps.trading_service import create_app +from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price + + +def test_trading_service_routes_are_exposed(): + app = create_app() + + paths = {route.path for route in app.routes} + + assert "/health" in paths + assert "/api/prices" in paths + assert "/api/financials" in paths + assert "/api/news" in paths + assert "/api/insider-trades" in paths + assert "/api/market/status" in paths + assert "/api/market-cap" in paths + assert "/api/line-items" in paths + + +def test_trading_service_prices_endpoint(monkeypatch): + monkeypatch.setattr( + "backend.domains.trading.get_prices_payload", + lambda ticker, start_date, end_date: { + "ticker": ticker, + "prices": [ + Price( + open=1.0, + close=2.0, + high=2.5, + low=0.5, + volume=100, + time="2026-03-20", + ) + ], + }, + ) + + with TestClient(create_app()) as client: + response = client.get( + "/api/prices", + params={ + "ticker": "AAPL", + "start_date": "2026-03-01", + "end_date": "2026-03-20", + }, + ) + + assert response.status_code == 200 + assert response.json()["ticker"] == "AAPL" + assert response.json()["prices"][0]["close"] == 2.0 + + +def test_trading_service_financials_endpoint(monkeypatch): + monkeypatch.setattr( + "backend.domains.trading.get_financials_payload", + lambda ticker, end_date, period, limit: { + "financial_metrics": [ + FinancialMetrics( + ticker=ticker, + report_period=end_date, + period=period, + currency="USD", + market_cap=123.0, + enterprise_value=None, + price_to_earnings_ratio=None, + price_to_book_ratio=None, + price_to_sales_ratio=None, + enterprise_value_to_ebitda_ratio=None, + enterprise_value_to_revenue_ratio=None, + free_cash_flow_yield=None, + peg_ratio=None, + gross_margin=None, + operating_margin=None, + net_margin=None, + return_on_equity=None, + return_on_assets=None, + return_on_invested_capital=None, + asset_turnover=None, + inventory_turnover=None, + receivables_turnover=None, + days_sales_outstanding=None, + operating_cycle=None, + working_capital_turnover=None, + current_ratio=None, + quick_ratio=None, + cash_ratio=None, + operating_cash_flow_ratio=None, + debt_to_equity=None, + debt_to_assets=None, + interest_coverage=None, + revenue_growth=None, + earnings_growth=None, + book_value_growth=None, + earnings_per_share_growth=None, + free_cash_flow_growth=None, + operating_income_growth=None, + ebitda_growth=None, + payout_ratio=None, + earnings_per_share=None, + book_value_per_share=None, + free_cash_flow_per_share=None, + ) + ] + }, + ) + + with TestClient(create_app()) as client: + response = client.get( + "/api/financials", + params={"ticker": "AAPL", "end_date": "2026-03-20"}, + ) + + assert response.status_code == 200 + assert response.json()["financial_metrics"][0]["ticker"] == "AAPL" + + +def test_trading_service_news_and_insider_endpoints(monkeypatch): + monkeypatch.setattr( + "backend.domains.trading.get_news_payload", + lambda ticker, end_date, start_date=None, limit=1000: { + "news": [ + CompanyNews( + ticker=ticker, + title="News title", + source="polygon", + url="https://example.com/news", + date=end_date, + ) + ] + }, + ) + monkeypatch.setattr( + "backend.domains.trading.get_insider_trades_payload", + lambda ticker, end_date, start_date=None, limit=1000: { + "insider_trades": [ + InsiderTrade(ticker=ticker, filing_date=end_date) + ] + }, + ) + + with TestClient(create_app()) as client: + news_response = client.get( + "/api/news", + params={"ticker": "AAPL", "end_date": "2026-03-20"}, + ) + insider_response = client.get( + "/api/insider-trades", + params={"ticker": "AAPL", "end_date": "2026-03-20"}, + ) + + assert news_response.status_code == 200 + assert news_response.json()["news"][0]["title"] == "News title" + assert insider_response.status_code == 200 + assert insider_response.json()["insider_trades"][0]["ticker"] == "AAPL" + + +def test_trading_service_market_status_endpoint(monkeypatch): + class _FakeMarketService: + def get_market_status(self): + return {"status": "open", "status_text": "Open"} + + monkeypatch.setattr( + "backend.domains.trading.get_market_status_payload", + lambda: _FakeMarketService().get_market_status(), + ) + + with TestClient(create_app()) as client: + response = client.get("/api/market/status") + + assert response.status_code == 200 + assert response.json() == {"status": "open", "status_text": "Open"} + + +def test_trading_service_market_cap_endpoint(monkeypatch): + monkeypatch.setattr( + "backend.domains.trading.get_market_cap_payload", + lambda ticker, end_date: { + "ticker": ticker, + "end_date": end_date, + "market_cap": 3.5e12, + }, + ) + + with TestClient(create_app()) as client: + response = client.get( + "/api/market-cap", + params={"ticker": "AAPL", "end_date": "2026-03-20"}, + ) + + assert response.status_code == 200 + assert response.json() == { + "ticker": "AAPL", + "end_date": "2026-03-20", + "market_cap": 3.5e12, + } + + +def test_trading_service_line_items_endpoint(monkeypatch): + monkeypatch.setattr( + "backend.domains.trading.get_line_items_payload", + lambda ticker, line_items, end_date, period, limit: { + "search_results": [ + LineItem( + ticker=ticker, + report_period=end_date, + period=period, + currency="USD", + free_cash_flow=123.0, + ) + ] + }, + ) + + with TestClient(create_app()) as client: + response = client.get( + "/api/line-items", + params=[ + ("ticker", "AAPL"), + ("line_items", "free_cash_flow"), + ("end_date", "2026-03-20"), + ], + ) + + assert response.status_code == 200 + assert response.json()["search_results"][0]["ticker"] == "AAPL" + assert response.json()["search_results"][0]["free_cash_flow"] == 123.0 diff --git a/backend/tools/data_tools.py b/backend/tools/data_tools.py index cb43af8..0ed359e 100644 --- a/backend/tools/data_tools.py +++ b/backend/tools/data_tools.py @@ -3,13 +3,16 @@ # pylint: disable=C0301 """Data fetching tools backed by the unified provider router.""" import datetime +import os + +import httpx import pandas as pd import pandas_market_calendars as mcal from backend.data.provider_utils import normalize_symbol from backend.data.cache import get_cache from backend.data.provider_router import get_provider_router -from backend.data.schema import ( +from shared.schema import ( CompanyNews, FinancialMetrics, InsiderTrade, @@ -23,6 +26,31 @@ _cache = get_cache() _router = get_provider_router() +def _service_name() -> str: + return str(os.getenv("SERVICE_NAME", "")).strip().lower() + + +def _trading_service_url() -> str | None: + value = str(os.getenv("TRADING_SERVICE_URL", "")).strip().rstrip("/") + if not value or _service_name() == "trading_service": + return None + return value + + +def _news_service_url() -> str | None: + value = str(os.getenv("NEWS_SERVICE_URL", "")).strip().rstrip("/") + if not value or _service_name() == "news_service": + return None + return value + + +def _service_get_json(base_url: str, path: str, *, params: dict[str, object]) -> dict: + with httpx.Client(base_url=base_url, timeout=30.0) as client: + response = client.get(path, params=params) + response.raise_for_status() + return response.json() + + def get_last_tradeday(date: str) -> str: """ Get the previous trading day for the specified date @@ -104,6 +132,24 @@ def get_prices( if cached_data := _cache.get_prices(cache_key): return [Price(**price) for price in cached_data] + service_url = _trading_service_url() + if service_url: + try: + payload = _service_get_json( + service_url, + "/api/prices", + params={ + "ticker": ticker, + "start_date": start_date, + "end_date": end_date, + }, + ) + prices = [Price(**price) for price in payload.get("prices", [])] + if prices: + return prices + except Exception as exc: + logger.info("Trading service price lookup failed for %s: %s", ticker, exc) + try: prices, data_source = _router.get_prices(ticker, start_date, end_date) except Exception as exc: @@ -146,6 +192,28 @@ def get_financial_metrics( if cached_data := _cache.get_financial_metrics(cache_key): return [FinancialMetrics(**metric) for metric in cached_data] + service_url = _trading_service_url() + if service_url: + try: + payload = _service_get_json( + service_url, + "/api/financials", + params={ + "ticker": ticker, + "end_date": end_date, + "period": period, + "limit": limit, + }, + ) + metrics = [ + FinancialMetrics(**metric) + for metric in payload.get("financial_metrics", []) + ] + if metrics: + return metrics + except Exception as exc: + logger.info("Trading service financial lookup failed for %s: %s", ticker, exc) + try: financial_metrics, data_source = _router.get_financial_metrics( ticker=ticker, @@ -183,6 +251,22 @@ def search_line_items( ticker = normalize_symbol(ticker) if not ticker: return [] + + service_url = _trading_service_url() + if service_url: + payload = _service_get_json( + service_url, + "/api/line-items", + params={ + "ticker": ticker, + "line_items": line_items, + "end_date": end_date, + "period": period, + "limit": limit, + }, + ) + return [LineItem(**item) for item in payload.get("search_results", [])] + return _router.search_line_items( ticker=ticker, line_items=line_items, @@ -213,6 +297,26 @@ def get_insider_trades( if cached_data := _cache.get_insider_trades(cache_key): return [InsiderTrade(**trade) for trade in cached_data] + service_url = _trading_service_url() + if service_url: + try: + params = {"ticker": ticker, "end_date": end_date, "limit": limit} + if start_date: + params["start_date"] = start_date + payload = _service_get_json( + service_url, + "/api/insider-trades", + params=params, + ) + trades = [ + InsiderTrade(**trade) + for trade in payload.get("insider_trades", []) + ] + if trades: + return trades + except Exception as exc: + logger.info("Trading service insider lookup failed for %s: %s", ticker, exc) + try: all_trades, data_source = _router.get_insider_trades( ticker=ticker, @@ -248,6 +352,40 @@ def get_company_news( if cached_data := _cache.get_company_news(cache_key): return [CompanyNews(**news) for news in cached_data] + trading_service_url = _trading_service_url() + if trading_service_url: + try: + params = {"ticker": ticker, "end_date": end_date, "limit": limit} + if start_date: + params["start_date"] = start_date + payload = _service_get_json( + trading_service_url, + "/api/news", + params=params, + ) + news = [CompanyNews(**item) for item in payload.get("news", [])] + if news: + return news + except Exception as exc: + logger.info("Trading service news lookup failed for %s: %s", ticker, exc) + + news_service_url = _news_service_url() + if news_service_url: + try: + params = {"ticker": ticker, "end_date": end_date, "limit": limit} + if start_date: + params["start_date"] = start_date + payload = _service_get_json( + news_service_url, + "/api/enriched-news", + params=params, + ) + news = [CompanyNews(**item) for item in payload.get("news", [])] + if news: + return news + except Exception as exc: + logger.info("News service lookup failed for %s: %s", ticker, exc) + try: all_news, data_source = _router.get_company_news( ticker=ticker, @@ -272,6 +410,19 @@ def get_market_cap(ticker: str, end_date: str) -> float | None: if not ticker: return None + service_url = _trading_service_url() + if service_url: + try: + payload = _service_get_json( + service_url, + "/api/market-cap", + params={"ticker": ticker, "end_date": end_date}, + ) + value = payload.get("market_cap") + return float(value) if value is not None else None + except Exception as exc: + logger.info("Trading service market-cap lookup failed for %s: %s", ticker, exc) + def _metrics_lookup(symbol: str, date: str): for source in _router.api_sources(): cache_key = f"{symbol}_ttm_{date}_10_{source}" diff --git a/docs/compat-removal-plan.md b/docs/compat-removal-plan.md new file mode 100644 index 0000000..f4f960f --- /dev/null +++ b/docs/compat-removal-plan.md @@ -0,0 +1,28 @@ +# Compatibility Removal Plan + +This document tracks the remaining migration-only surfaces that still exist +after the move to split-first development. + +## Migration-only Surfaces + +None currently remain as dedicated compatibility wrappers. + +## Completed Removals + +### `backend.app` + +- Removed after compatibility startup switched to + `backend.apps.combined_service:app` directly. + +### `shared.client.AgentServiceClient` + +- Removed after split-aware clients became the default import surface. +- Replacement: + - `ControlPlaneClient` + - `RuntimeServiceClient` + - `TradingServiceClient` + - `NewsServiceClient` + +### `backend.apps.combined_service` + +- Removed after split-service mode became the only supported dev startup path. diff --git a/frontend/README.md b/frontend/README.md index b5b523f..fa7fa51 100644 --- a/frontend/README.md +++ b/frontend/README.md @@ -1,7 +1,31 @@ - ## QuickStart ```bash cd frontend npm install npm run dev -``` \ No newline at end of file +``` + +## Optional Direct Service Calls + +The frontend still works with the compatibility backend entrypoint by default. +In the current test-stage setup, split services are the recommended default. +Point the frontend directly at those standalone services: + +```bash +VITE_CONTROL_API_BASE_URL=http://localhost:8000/api +VITE_RUNTIME_API_BASE_URL=http://localhost:8003/api/runtime +VITE_NEWS_SERVICE_URL=http://localhost:8002 +VITE_TRADING_SERVICE_URL=http://localhost:8001 +``` + +Current direct-call coverage: + +- runtime panel + gateway port discovery +- `story` +- `similar days` +- `range explain` +- `news for date` +- `news categories` + +If these variables are not set, the frontend falls back to the existing +WebSocket-driven compatibility flow. diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index a243630..e7c7415 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -6,6 +6,19 @@ import { AGENTS, INITIAL_TICKERS } from './config/constants'; // Services import { ReadOnlyClient } from './services/websocket'; import { startRuntime, uploadAgentSkillZip } from './services/runtimeApi'; +import { + fetchNewsCategoriesDirect, + fetchNewsForDateDirect, + fetchRangeExplainDirect, + fetchSimilarDaysDirect, + fetchStockStoryDirect, + hasDirectNewsService +} from './services/newsApi'; +import { + fetchInsiderTradesDirect, + fetchStockHistoryDirect, + hasDirectTradingService +} from './services/tradingApi'; // Hooks import { useFeedProcessor } from './hooks/useFeedProcessor'; @@ -937,7 +950,7 @@ export default function LiveTradingApp() { const requestStockHistory = useCallback((symbol, { force = false } = {}) => { const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; - if (!normalized || !clientRef.current) { + if (!normalized) { return false; } @@ -945,6 +958,65 @@ export default function LiveTradingApp() { return false; } + const endDate = currentDate + ? String(currentDate).slice(0, 10) + : new Date().toISOString().slice(0, 10); + const end = new Date(`${endDate}T00:00:00`); + const start = new Date(end); + start.setDate(start.getDate() - 120); + const startDate = start.toISOString().slice(0, 10); + + if (hasDirectTradingService()) { + void fetchStockHistoryDirect(normalized, startDate, endDate) + .then((payload) => { + const prices = Array.isArray(payload?.prices) ? payload.prices : []; + setOhlcHistoryByTicker((prev) => ({ + ...prev, + [normalized]: prices + })); + setPriceHistoryByTicker((prev) => ({ + ...prev, + [normalized]: prices + .map((point) => { + const price = Number(point?.close); + const timestamp = point?.time; + if (!timestamp || !Number.isFinite(price)) { + return null; + } + return { + timestamp: String(timestamp), + label: String(timestamp), + price + }; + }) + .filter(Boolean) + })); + setHistorySourceByTicker((prev) => ({ + ...prev, + [normalized]: 'trading_service' + })); + }) + .catch((error) => { + console.error('Direct stock-history fetch failed, falling back to websocket:', error); + if (clientRef.current) { + const success = clientRef.current.send({ + type: 'get_stock_history', + ticker: normalized, + lookback_days: 120 + }); + if (success) { + requestedStockHistoryRef.current.add(normalized); + } + } + }); + requestedStockHistoryRef.current.add(normalized); + return true; + } + + if (!clientRef.current) { + return false; + } + const success = clientRef.current.send({ type: 'get_stock_history', ticker: normalized, @@ -956,7 +1028,7 @@ export default function LiveTradingApp() { } return success; - }, []); + }, [currentDate]); const requestStockExplainEvents = useCallback((symbol) => { const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; @@ -984,9 +1056,49 @@ export default function LiveTradingApp() { const requestStockNewsForDate = useCallback((symbol, date) => { const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; - if (!normalized || !date || !clientRef.current) { + if (!normalized || !date) { return false; } + + if (hasDirectNewsService()) { + void fetchNewsForDateDirect(normalized, date, 20) + .then((payload) => { + const targetDate = typeof payload?.date === 'string' ? payload.date.trim() : date; + const news = Array.isArray(payload?.news) ? payload.news : []; + const freshness = payload?.freshness || null; + setNewsByTicker((prev) => ({ + ...prev, + [normalized]: { + ...(prev[normalized] || {}), + byDate: { + ...((prev[normalized] && prev[normalized].byDate) || {}), + [targetDate]: news + }, + byDateFreshness: { + ...((prev[normalized] && prev[normalized].byDateFreshness) || {}), + [targetDate]: freshness + } + } + })); + }) + .catch((error) => { + console.error('Direct news-for-date fetch failed, falling back to websocket:', error); + if (clientRef.current) { + clientRef.current.send({ + type: 'get_stock_news_for_date', + ticker: normalized, + date, + limit: 20 + }); + } + }); + return true; + } + + if (!clientRef.current) { + return false; + } + return clientRef.current.send({ type: 'get_stock_news_for_date', ticker: normalized, @@ -1009,21 +1121,96 @@ export default function LiveTradingApp() { const requestStockNewsCategories = useCallback((symbol) => { const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; - if (!normalized || !clientRef.current) { + if (!normalized) { return false; } + + const endDate = currentDate + ? String(currentDate).slice(0, 10) + : new Date().toISOString().slice(0, 10); + const end = new Date(`${endDate}T00:00:00`); + const start = new Date(end); + start.setDate(start.getDate() - 90); + const startDate = start.toISOString().slice(0, 10); + + if (hasDirectNewsService()) { + void fetchNewsCategoriesDirect(normalized, startDate, endDate, 200) + .then((payload) => { + const freshness = payload?.freshness || null; + setNewsByTicker((prev) => ({ + ...prev, + [normalized]: { + ...(prev[normalized] || {}), + categories: payload?.categories || {}, + categoriesStartDate: startDate, + categoriesEndDate: endDate, + categoriesFreshness: freshness + } + })); + }) + .catch((error) => { + console.error('Direct news-categories fetch failed, falling back to websocket:', error); + if (clientRef.current) { + clientRef.current.send({ + type: 'get_stock_news_categories', + ticker: normalized, + lookback_days: 90 + }); + } + }); + return true; + } + + if (!clientRef.current) { + return false; + } + return clientRef.current.send({ type: 'get_stock_news_categories', ticker: normalized, lookback_days: 90 }); - }, []); + }, [currentDate]); const requestStockInsiderTrades = useCallback((symbol, startDate = null, endDate = null) => { const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; - if (!normalized || !clientRef.current) { + if (!normalized) { return false; } + + if (hasDirectTradingService()) { + void fetchInsiderTradesDirect(normalized, startDate, endDate, 50) + .then((payload) => { + const rows = Array.isArray(payload?.insider_trades) ? payload.insider_trades : []; + setInsiderTradesByTicker((prev) => ({ + ...prev, + [normalized]: { + ticker: normalized, + startDate: startDate || null, + endDate: endDate || null, + trades: rows + } + })); + }) + .catch((error) => { + console.error('Direct insider-trades fetch failed, falling back to websocket:', error); + if (clientRef.current) { + clientRef.current.send({ + type: 'get_stock_insider_trades', + ticker: normalized, + start_date: startDate, + end_date: endDate, + limit: 50 + }); + } + }); + return true; + } + + if (!clientRef.current) { + return false; + } + return clientRef.current.send({ type: 'get_stock_insider_trades', ticker: normalized, @@ -1046,9 +1233,52 @@ export default function LiveTradingApp() { const requestStockRangeExplain = useCallback((symbol, startDate, endDate, articleIds = []) => { const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; - if (!normalized || !startDate || !endDate || !clientRef.current) { + if (!normalized || !startDate || !endDate) { return false; } + + if (hasDirectNewsService()) { + void fetchRangeExplainDirect(normalized, startDate, endDate, articleIds) + .then((payload) => { + const result = payload?.result && typeof payload.result === 'object' ? payload.result : null; + const freshness = payload?.freshness || null; + if (!result?.start_date || !result?.end_date) { + return; + } + const cacheKey = `${result.start_date}:${result.end_date}`; + setNewsByTicker((prev) => ({ + ...prev, + [normalized]: { + ...(prev[normalized] || {}), + rangeExplainCache: { + ...((prev[normalized] && prev[normalized].rangeExplainCache) || {}), + [cacheKey]: { + ...result, + freshness + } + } + } + })); + }) + .catch((error) => { + console.error('Direct range explain fetch failed, falling back to websocket:', error); + if (clientRef.current) { + clientRef.current.send({ + type: 'get_stock_range_explain', + ticker: normalized, + start_date: startDate, + end_date: endDate, + article_ids: Array.isArray(articleIds) ? articleIds : [] + }); + } + }); + return true; + } + + if (!clientRef.current) { + return false; + } + return clientRef.current.send({ type: 'get_stock_range_explain', ticker: normalized, @@ -1060,9 +1290,51 @@ export default function LiveTradingApp() { const requestStockStory = useCallback((symbol, asOfDate = null) => { const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; - if (!normalized || !clientRef.current) { + if (!normalized) { return false; } + + if (hasDirectNewsService()) { + void fetchStockStoryDirect(normalized, asOfDate) + .then((payload) => { + const storyDate = typeof payload?.as_of_date === 'string' ? payload.as_of_date.trim() : ''; + const freshness = payload?.freshness || null; + if (!storyDate) { + return; + } + setNewsByTicker((prev) => ({ + ...prev, + [normalized]: { + ...(prev[normalized] || {}), + storyCache: { + ...((prev[normalized] && prev[normalized].storyCache) || {}), + [storyDate]: { + story: payload.story || '', + source: payload.source || 'news_service', + asOfDate: storyDate, + freshness + } + } + } + })); + }) + .catch((error) => { + console.error('Direct story fetch failed, falling back to websocket:', error); + if (clientRef.current) { + clientRef.current.send({ + type: 'get_stock_story', + ticker: normalized, + as_of_date: asOfDate + }); + } + }); + return true; + } + + if (!clientRef.current) { + return false; + } + return clientRef.current.send({ type: 'get_stock_story', ticker: normalized, @@ -1072,9 +1344,46 @@ export default function LiveTradingApp() { const requestStockSimilarDays = useCallback((symbol, date, topK = 8) => { const normalized = typeof symbol === 'string' ? symbol.trim().toUpperCase() : ''; - if (!normalized || !date || !clientRef.current) { + if (!normalized || !date) { return false; } + + if (hasDirectNewsService()) { + void fetchSimilarDaysDirect(normalized, date, topK) + .then((payload) => { + const targetDate = typeof payload?.target_date === 'string' ? payload.target_date.trim() : date; + if (!targetDate) { + return; + } + setNewsByTicker((prev) => ({ + ...prev, + [normalized]: { + ...(prev[normalized] || {}), + similarDaysCache: { + ...((prev[normalized] && prev[normalized].similarDaysCache) || {}), + [targetDate]: payload + } + } + })); + }) + .catch((error) => { + console.error('Direct similar-days fetch failed, falling back to websocket:', error); + if (clientRef.current) { + clientRef.current.send({ + type: 'get_stock_similar_days', + ticker: normalized, + date, + top_k: topK + }); + } + }); + return true; + } + + if (!clientRef.current) { + return false; + } + return clientRef.current.send({ type: 'get_stock_similar_days', ticker: normalized, @@ -1707,7 +2016,8 @@ export default function LiveTradingApp() { items: Array.isArray(e.news) ? e.news : [], source: e.source || null, startDate: e.start_date || null, - endDate: e.end_date || null + endDate: e.end_date || null, + freshness: e.freshness || null } })); requestStockNewsTimeline(symbol); @@ -1726,6 +2036,10 @@ export default function LiveTradingApp() { byDate: { ...((prev[symbol] && prev[symbol].byDate) || {}), [date]: Array.isArray(e.news) ? e.news : [] + }, + byDateFreshness: { + ...((prev[symbol] && prev[symbol].byDateFreshness) || {}), + [date]: e.freshness || null } } })); @@ -1742,7 +2056,8 @@ export default function LiveTradingApp() { ...(prev[symbol] || {}), timeline: Array.isArray(e.timeline) ? e.timeline : [], timelineStartDate: e.start_date || null, - timelineEndDate: e.end_date || null + timelineEndDate: e.end_date || null, + timelineFreshness: e.freshness || null } })); }, @@ -1758,7 +2073,8 @@ export default function LiveTradingApp() { ...(prev[symbol] || {}), categories: e.categories || {}, categoriesStartDate: e.start_date || null, - categoriesEndDate: e.end_date || null + categoriesEndDate: e.end_date || null, + categoriesFreshness: e.freshness || null } })); }, @@ -1805,7 +2121,10 @@ export default function LiveTradingApp() { ...(prev[symbol] || {}), rangeExplainCache: { ...((prev[symbol] && prev[symbol].rangeExplainCache) || {}), - [cacheKey]: result + [cacheKey]: { + ...result, + freshness: e.freshness || null + } } } })); @@ -1826,7 +2145,8 @@ export default function LiveTradingApp() { [asOfDate]: { story: e.story || '', source: e.source || null, - asOfDate + asOfDate, + freshness: e.freshness || null } } } @@ -1852,7 +2172,8 @@ export default function LiveTradingApp() { [date]: { target_features: e.target_features || {}, items: Array.isArray(e.items) ? e.items : [], - error: e.error || null + error: e.error || null, + freshness: e.freshness || null } } } diff --git a/frontend/src/components/StockExplainView.jsx b/frontend/src/components/StockExplainView.jsx index a7a976c..644d66a 100644 --- a/frontend/src/components/StockExplainView.jsx +++ b/frontend/src/components/StockExplainView.jsx @@ -77,6 +77,7 @@ export default function StockExplainView({ visibleNews, newsCategories, visibleNewsByCategory, + selectedNewsFreshness, selectedRangeWindow, selectedRangeExplain, latestSignal, @@ -337,6 +338,7 @@ export default function StockExplainView({ newsSnapshot={newsSnapshot} visibleNewsByCategory={visibleNewsByCategory} visibleNews={visibleNews} + selectedNewsFreshness={selectedNewsFreshness} activeNewsCategory={activeNewsCategory} onSelectNewsCategory={setActiveNewsCategory} activeNewsSentiment={activeNewsSentiment} diff --git a/frontend/src/components/explain/ExplainNewsSection.jsx b/frontend/src/components/explain/ExplainNewsSection.jsx index 75f0e7f..92016d0 100644 --- a/frontend/src/components/explain/ExplainNewsSection.jsx +++ b/frontend/src/components/explain/ExplainNewsSection.jsx @@ -1,6 +1,12 @@ import React from 'react'; import { formatDateTime } from '../../utils/formatters'; +function renderFreshness(freshness) { + if (!freshness || typeof freshness !== 'object') return null; + const lastFetch = freshness.last_news_fetch || '-'; + return `新闻更新到 ${lastFetch}${freshness.refreshed ? ' · 本次已刷新' : ''}`; +} + function categoryLabel(value) { const normalized = String(value || '').trim().toLowerCase(); const labels = { @@ -47,6 +53,7 @@ export default function ExplainNewsSection({ newsSnapshot, visibleNewsByCategory, visibleNews, + selectedNewsFreshness, activeNewsCategory, onSelectNewsCategory, activeNewsSentiment, @@ -64,6 +71,11 @@ export default function ExplainNewsSection({