feat: 微服务架构拆分和前后端优化

后端:
- 拆分出 agent_service, runtime_service, trading_service, news_service
- Gateway 模块化拆分 (gateway_*.py)
- 添加 domains/ 领域层
- 新增 control_client, runtime_client
- 更新 start-dev.sh 支持 split 服务模式

前端:
- 完善 API 服务层 (newsApi, tradingApi)
- 更新 vite.config.js
- Explain 组件优化

测试:
- 添加多个服务 app 测试

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-23 17:45:39 +08:00
parent 0f1bc2bb39
commit 3448667b79
54 changed files with 5440 additions and 2947 deletions

View File

@@ -1,115 +0,0 @@
# -*- coding: utf-8 -*-
"""
FastAPI Application - REST API for EvoTraders
Provides HTTP endpoints for:
- Agent management
- Workspace management
- Tool guard operations
- Health checks
"""
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from backend.api import agents_router, workspaces_router, guard_router, runtime_router
from backend.agents import AgentFactory, WorkspaceManager, get_registry
# Global instances (initialized on startup)
agent_factory: AgentFactory | None = None
workspace_manager: WorkspaceManager | None = None
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
"""
Application lifespan manager.
Initializes global services on startup and cleans up on shutdown.
"""
global agent_factory, workspace_manager
# Startup: Initialize services
project_root = Path(__file__).parent.parent
# Initialize workspace manager
workspace_manager = WorkspaceManager(project_root=project_root)
# Initialize agent factory
agent_factory = AgentFactory(project_root=project_root)
# Ensure workspaces root exists
agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True)
# Get or create global registry
registry = get_registry()
print(f"✓ EvoTraders API started")
print(f" - Workspaces root: {agent_factory.workspaces_root}")
print(f" - Registered agents: {registry.get_agent_count()}")
yield
# Shutdown: Cleanup
print("✓ EvoTraders API shutting down")
# Create FastAPI application
app = FastAPI(
title="EvoTraders API",
description="REST API for the EvoTraders multi-agent trading system",
version="0.1.0",
lifespan=lifespan,
)
# CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Configure appropriately for production
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint."""
registry = get_registry()
return {
"status": "healthy",
"version": "0.1.0",
"agents_registered": registry.get_agent_count(),
"workspaces_available": len(workspace_manager.list_workspaces()) if workspace_manager else 0,
}
# API status endpoint
@app.get("/api/status")
async def api_status():
"""Get API status and system information."""
registry = get_registry()
stats = registry.get_stats()
return {
"status": "operational",
"registry": stats,
}
# Include routers
app.include_router(workspaces_router)
app.include_router(agents_router)
app.include_router(guard_router)
app.include_router(runtime_router)
# Main entry point for running with uvicorn
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

27
backend/apps/__init__.py Normal file
View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""Application surfaces for progressive service extraction."""
from .agent_service import app as agent_app
from .agent_service import create_app as create_agent_app
from .news_service import app as news_app
from .news_service import create_app as create_news_app
from .runtime_service import app as runtime_app
from .runtime_service import create_app as create_runtime_app
from .trading_service import app as trading_app
from .trading_service import create_app as create_trading_app
app = agent_app
create_app = create_agent_app
__all__ = [
"app",
"create_app",
"agent_app",
"create_agent_app",
"news_app",
"create_news_app",
"runtime_app",
"create_runtime_app",
"trading_app",
"create_trading_app",
]

View File

@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
"""Agent control-plane FastAPI surface."""
from __future__ import annotations
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from backend.api import agents_router, guard_router, workspaces_router
from backend.agents import AgentFactory, WorkspaceManager, get_registry
# Global instances (initialized on startup)
agent_factory: AgentFactory | None = None
workspace_manager: WorkspaceManager | None = None
def create_app(project_root: Path | None = None) -> FastAPI:
"""Create the agent control-plane app."""
resolved_project_root = project_root or Path(__file__).resolve().parents[2]
@asynccontextmanager
async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]:
"""Initialize workspace and registry state for the control plane."""
global agent_factory, workspace_manager
workspace_manager = WorkspaceManager(project_root=resolved_project_root)
agent_factory = AgentFactory(project_root=resolved_project_root)
agent_factory.workspaces_root.mkdir(parents=True, exist_ok=True)
registry = get_registry()
print("✓ EvoTraders API started")
print(f" - Workspaces root: {agent_factory.workspaces_root}")
print(f" - Registered agents: {registry.get_agent_count()}")
yield
print("✓ EvoTraders API shutting down")
app = FastAPI(
title="EvoTraders Agent Service",
description="REST API for the EvoTraders multi-agent control plane",
version="0.1.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health_check() -> dict[str, object]:
"""Health check endpoint."""
registry = get_registry()
return {
"status": "healthy",
"version": "0.1.0",
"agents_registered": registry.get_agent_count(),
"workspaces_available": (
len(workspace_manager.list_workspaces())
if workspace_manager
else 0
),
}
@app.get("/api/status")
async def api_status() -> dict[str, object]:
"""Get API status and registry information."""
registry = get_registry()
return {
"status": "operational",
"registry": registry.get_stats(),
}
app.include_router(workspaces_router)
app.include_router(agents_router)
app.include_router(guard_router)
return app
app = create_app()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,153 @@
# -*- coding: utf-8 -*-
"""News and explain FastAPI surface."""
from __future__ import annotations
from typing import Any
from fastapi import Depends, FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from backend.data.market_store import MarketStore
from backend.domains import news as news_domain
def get_market_store() -> MarketStore:
"""Create a market store dependency."""
return MarketStore()
def create_app() -> FastAPI:
"""Create the news/explain service app."""
app = FastAPI(
title="EvoTraders News Service",
description="Read-only news enrichment and explain service surface extracted from the monolith",
version="0.1.0",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health_check() -> dict[str, str]:
return {"status": "healthy", "service": "news-service"}
@app.get("/api/enriched-news")
async def api_get_enriched_news(
ticker: str = Query(..., min_length=1),
start_date: str | None = Query(None),
end_date: str | None = Query(None),
limit: int = Query(100, ge=1, le=1000),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_enriched_news(
store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
@app.get("/api/news-for-date")
async def api_get_news_for_date(
ticker: str = Query(..., min_length=1),
date: str = Query(...),
limit: int = Query(20, ge=1, le=100),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_news_for_date(
store,
ticker=ticker,
date=date,
limit=limit,
)
@app.get("/api/news-timeline")
async def api_get_news_timeline(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_news_timeline(
store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
@app.get("/api/categories")
async def api_get_categories(
ticker: str = Query(..., min_length=1),
start_date: str | None = Query(None),
end_date: str | None = Query(None),
limit: int = Query(200, ge=1, le=1000),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_news_categories(
store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
@app.get("/api/similar-days")
async def api_get_similar_days(
ticker: str = Query(..., min_length=1),
date: str = Query(...),
n_similar: int = Query(5, ge=1, le=20),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_similar_days_payload(
store,
ticker=ticker,
date=date,
n_similar=n_similar,
)
@app.get("/api/stories/{ticker}")
async def api_get_story(
ticker: str,
as_of_date: str = Query(...),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_story_payload(
store,
ticker=ticker,
as_of_date=as_of_date,
)
@app.get("/api/range-explain")
async def api_get_range_explain(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),
article_ids: list[str] = Query(default=[]),
limit: int = Query(100, ge=1, le=500),
store: MarketStore = Depends(get_market_store),
) -> dict[str, Any]:
return news_domain.get_range_explain_payload(
store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids,
limit=limit,
)
return app
app = create_app()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8002)

View File

@@ -0,0 +1,68 @@
# -*- coding: utf-8 -*-
"""Dedicated runtime service FastAPI surface."""
from __future__ import annotations
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from backend.api import runtime_router
from backend.api.runtime import get_runtime_state
def create_app() -> FastAPI:
"""Create the runtime service app."""
app = FastAPI(
title="EvoTraders Runtime Service",
description="Runtime lifecycle and gateway service surface extracted from the monolith",
version="0.1.0",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health_check() -> dict[str, object]:
"""Health check for the runtime service."""
runtime_state = get_runtime_state()
process = runtime_state.gateway_process
is_running = process is not None and process.poll() is None
return {
"status": "healthy",
"service": "runtime-service",
"gateway_running": is_running,
"gateway_port": runtime_state.gateway_port,
}
@app.get("/api/status")
async def api_status() -> dict[str, object]:
"""Service-level status payload for runtime orchestration."""
runtime_state = get_runtime_state()
process = runtime_state.gateway_process
is_running = process is not None and process.poll() is None
return {
"status": "operational",
"service": "runtime-service",
"runtime": {
"gateway_running": is_running,
"gateway_port": runtime_state.gateway_port,
"has_runtime_manager": runtime_state.runtime_manager is not None,
},
}
app.include_router(runtime_router)
return app
app = create_app()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8003)

View File

@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
"""Trading data FastAPI surface."""
from __future__ import annotations
from typing import Any
from fastapi import FastAPI, Query
from fastapi.middleware.cors import CORSMiddleware
from backend.domains import trading as trading_domain
from shared.schema import (
CompanyNewsResponse,
FinancialMetricsResponse,
InsiderTradeResponse,
LineItemResponse,
PriceResponse,
)
def create_app() -> FastAPI:
"""Create the trading data service app."""
app = FastAPI(
title="EvoTraders Trading Service",
description="Read-only trading data service surface extracted from the monolith",
version="0.1.0",
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health_check() -> dict[str, str]:
"""Health check endpoint."""
return {"status": "healthy", "service": "trading-service"}
@app.get("/api/prices", response_model=PriceResponse)
async def api_get_prices(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),
) -> PriceResponse:
payload = trading_domain.get_prices_payload(
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
@app.get("/api/financials", response_model=FinancialMetricsResponse)
async def api_get_financials(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
period: str = Query("ttm"),
limit: int = Query(10, ge=1, le=100),
) -> FinancialMetricsResponse:
payload = trading_domain.get_financials_payload(
ticker=ticker,
end_date=end_date,
period=period,
limit=limit,
)
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
@app.get("/api/news", response_model=CompanyNewsResponse)
async def api_get_news(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
start_date: str | None = Query(None),
limit: int = Query(1000, ge=1, le=5000),
) -> CompanyNewsResponse:
payload = trading_domain.get_news_payload(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
return CompanyNewsResponse(news=payload["news"])
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
async def api_get_insider_trades(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
start_date: str | None = Query(None),
limit: int = Query(1000, ge=1, le=5000),
) -> InsiderTradeResponse:
payload = trading_domain.get_insider_trades_payload(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
@app.get("/api/market/status")
async def api_get_market_status() -> dict[str, Any]:
"""Return current market status using the existing market service logic."""
return trading_domain.get_market_status_payload()
@app.get("/api/market-cap")
async def api_get_market_cap(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
) -> dict[str, Any]:
"""Return market cap for one ticker/date."""
return trading_domain.get_market_cap_payload(
ticker=ticker,
end_date=end_date,
)
@app.get("/api/line-items", response_model=LineItemResponse)
async def api_get_line_items(
ticker: str = Query(..., min_length=1),
line_items: list[str] = Query(...),
end_date: str = Query(...),
period: str = Query("ttm"),
limit: int = Query(10, ge=1, le=100),
) -> LineItemResponse:
payload = trading_domain.get_line_items_payload(
ticker=ticker,
line_items=line_items,
end_date=end_date,
period=period,
limit=limit,
)
return LineItemResponse(search_results=payload["search_results"])
return app
app = create_app()
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8001)

View File

@@ -1,9 +1,30 @@
# -*- coding: utf-8 -*-
"""Core pipeline and orchestration logic"""
"""Core pipeline and orchestration logic.
Keep ``pipeline_runner`` behind lazy wrappers so importing ``backend.core`` does
not immediately pull in the gateway runtime graph.
"""
from .pipeline import TradingPipeline
from .state_sync import StateSync
from .pipeline_runner import create_agents, create_long_term_memory, stop_gateway
def create_agents(*args, **kwargs):
from .pipeline_runner import create_agents as _create_agents
return _create_agents(*args, **kwargs)
def create_long_term_memory(*args, **kwargs):
from .pipeline_runner import create_long_term_memory as _create_long_term_memory
return _create_long_term_memory(*args, **kwargs)
def stop_gateway(*args, **kwargs):
from .pipeline_runner import stop_gateway as _stop_gateway
return _stop_gateway(*args, **kwargs)
__all__ = [
"TradingPipeline",

View File

@@ -11,7 +11,7 @@ import pandas as pd
import yfinance as yf
from backend.config.data_config import DataSource, get_data_sources
from backend.data.schema import (
from shared.schema import (
CompanyFactsResponse,
CompanyNews,
CompanyNewsResponse,

View File

@@ -1,194 +1,50 @@
# -*- coding: utf-8 -*-
from pydantic import BaseModel
"""Compatibility schema bridge.
This module preserves the legacy ``backend.data.schema`` import path while
delegating the actual schema definitions to ``shared.schema``. Keeping one
canonical DTO set avoids drift as the monolith is split into service-specific
packages.
"""
class Price(BaseModel):
open: float
close: float
high: float
low: float
volume: int
time: str
from shared.schema import (
AgentStateData,
AgentStateMetadata,
AnalystSignal,
CompanyFacts,
CompanyFactsResponse,
CompanyNews,
CompanyNewsResponse,
FinancialMetrics,
FinancialMetricsResponse,
InsiderTrade,
InsiderTradeResponse,
LineItem,
LineItemResponse,
Portfolio,
Position,
Price,
PriceResponse,
TickerAnalysis,
)
class PriceResponse(BaseModel):
ticker: str
prices: list[Price]
class FinancialMetrics(BaseModel):
ticker: str
report_period: str
period: str
currency: str
market_cap: float | None
enterprise_value: float | None
price_to_earnings_ratio: float | None
price_to_book_ratio: float | None
price_to_sales_ratio: float | None
enterprise_value_to_ebitda_ratio: float | None
enterprise_value_to_revenue_ratio: float | None
free_cash_flow_yield: float | None
peg_ratio: float | None
gross_margin: float | None
operating_margin: float | None
net_margin: float | None
return_on_equity: float | None
return_on_assets: float | None
return_on_invested_capital: float | None
asset_turnover: float | None
inventory_turnover: float | None
receivables_turnover: float | None
days_sales_outstanding: float | None
operating_cycle: float | None
working_capital_turnover: float | None
current_ratio: float | None
quick_ratio: float | None
cash_ratio: float | None
operating_cash_flow_ratio: float | None
debt_to_equity: float | None
debt_to_assets: float | None
interest_coverage: float | None
revenue_growth: float | None
earnings_growth: float | None
book_value_growth: float | None
earnings_per_share_growth: float | None
free_cash_flow_growth: float | None
operating_income_growth: float | None
ebitda_growth: float | None
payout_ratio: float | None
earnings_per_share: float | None
book_value_per_share: float | None
free_cash_flow_per_share: float | None
class FinancialMetricsResponse(BaseModel):
financial_metrics: list[FinancialMetrics]
class LineItem(BaseModel):
ticker: str
report_period: str
period: str
currency: str
# Allow additional fields dynamically
model_config = {"extra": "allow"}
class LineItemResponse(BaseModel):
search_results: list[LineItem]
class InsiderTrade(BaseModel):
ticker: str
issuer: str | None
name: str | None
title: str | None
is_board_director: bool | None
transaction_date: str | None
transaction_shares: float | None
transaction_price_per_share: float | None
transaction_value: float | None
shares_owned_before_transaction: float | None
shares_owned_after_transaction: float | None
security_title: str | None
filing_date: str
class InsiderTradeResponse(BaseModel):
insider_trades: list[InsiderTrade]
class CompanyNews(BaseModel):
category: str | None = None
ticker: str
title: str
related: str | None = None
source: str
date: str | None = None
url: str
summary: str | None = None
class CompanyNewsResponse(BaseModel):
news: list[CompanyNews]
class CompanyFacts(BaseModel):
ticker: str
name: str
cik: str | None = None
industry: str | None = None
sector: str | None = None
category: str | None = None
exchange: str | None = None
is_active: bool | None = None
listing_date: str | None = None
location: str | None = None
market_cap: float | None = None
number_of_employees: int | None = None
sec_filings_url: str | None = None
sic_code: str | None = None
sic_industry: str | None = None
sic_sector: str | None = None
website_url: str | None = None
weighted_average_shares: int | None = None
class CompanyFactsResponse(BaseModel):
company_facts: CompanyFacts
class Position(BaseModel):
"""Position information - for Portfolio mode"""
long: int = 0 # Long position quantity (shares)
short: int = 0 # Short position quantity (shares)
long_cost_basis: float = 0.0 # Long position average cost
short_cost_basis: float = 0.0 # Short position average cost
class Portfolio(BaseModel):
"""Portfolio - for Portfolio mode"""
cash: float = 100000.0 # Available cash
positions: dict[str, Position] = {} # ticker -> Position mapping
# Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin)
margin_requirement: float = 0.0
margin_used: float = 0.0 # Margin used
class AnalystSignal(BaseModel):
signal: str | None = None
confidence: float | None = None
reasoning: dict | str | None = None
# Extended fields for richer signal information
reasons: list[str] | None = None # Core drivers/reasons for the signal
risks: list[str] | None = None # Key risk factors
invalidation: str | None = None # Conditions that would invalidate the thesis
next_action: str | None = None # Suggested next action for PM
# Valuation-related fields
intrinsic_value: float | None = None # DCF intrinsic value
fair_value_range: dict | None = None # {bear, base, bull} fair value range
value_gap_pct: float | None = None # Value gap percentage
valuation_methods: list[str] | None = None # List of valuation methods used
max_position_size: float | None = None # For risk management signals
class TickerAnalysis(BaseModel):
ticker: str
analyst_signals: dict[str, AnalystSignal] # agent_name -> signal mapping
class AgentStateData(BaseModel):
tickers: list[str]
portfolio: Portfolio
start_date: str
end_date: str
ticker_analyses: dict[str, TickerAnalysis] # ticker -> analysis mapping
class AgentStateMetadata(BaseModel):
show_reasoning: bool = False
model_config = {"extra": "allow"}
__all__ = [
"Price",
"PriceResponse",
"FinancialMetrics",
"FinancialMetricsResponse",
"LineItem",
"LineItemResponse",
"InsiderTrade",
"InsiderTradeResponse",
"CompanyNews",
"CompanyNewsResponse",
"CompanyFacts",
"CompanyFactsResponse",
"Position",
"Portfolio",
"AnalystSignal",
"TickerAnalysis",
"AgentStateData",
"AgentStateMetadata",
]

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""Domain modules for split service internals."""

277
backend/domains/news.py Normal file
View File

@@ -0,0 +1,277 @@
# -*- coding: utf-8 -*-
"""News/explain domain helpers shared by app surfaces and gateway fallbacks."""
from __future__ import annotations
from typing import Any
from backend.data.market_store import MarketStore
from backend.data.market_ingest import update_ticker_incremental
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.explain.range_explainer import build_range_explanation
from backend.explain.similarity_service import find_similar_days
from backend.explain.story_service import get_or_create_stock_story
def news_rows_need_enrichment(rows: list[dict[str, Any]]) -> bool:
"""Return whether news rows are missing explain-oriented analysis fields."""
if not rows:
return True
return all(
not row.get("sentiment")
and not row.get("relevance")
and not row.get("key_discussion")
for row in rows
)
def ensure_news_fresh(
store: MarketStore,
*,
ticker: str,
target_date: str | None = None,
) -> dict[str, Any]:
"""Refresh raw news incrementally when stored watermarks are stale."""
normalized_target = str(target_date or "").strip()[:10]
if not normalized_target:
return {
"ticker": ticker,
"target_date": None,
"last_news_fetch": None,
"refreshed": False,
}
watermarks = store.get_ticker_watermarks(ticker)
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
refreshed = False
if not last_news_fetch or last_news_fetch < normalized_target:
update_ticker_incremental(
ticker,
end_date=normalized_target,
store=store,
)
refreshed = True
watermarks = store.get_ticker_watermarks(ticker)
last_news_fetch = str(watermarks.get("last_news_fetch") or "").strip()[:10]
return {
"ticker": ticker,
"target_date": normalized_target,
"last_news_fetch": last_news_fetch or None,
"refreshed": refreshed,
}
def get_enriched_news(
store: MarketStore,
*,
ticker: str,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 100,
) -> dict[str, Any]:
freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date)
rows = store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
if news_rows_need_enrichment(rows):
enrich_news_for_symbol(
store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
rows = store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
return {"ticker": ticker, "news": rows, "freshness": freshness}
def get_news_for_date(
store: MarketStore,
*,
ticker: str,
date: str,
limit: int = 20,
) -> dict[str, Any]:
freshness = ensure_news_fresh(store, ticker=ticker, target_date=date)
rows = store.get_news_items_enriched(
ticker,
trade_date=date,
limit=limit,
)
if news_rows_need_enrichment(rows):
enrich_news_for_symbol(
store,
ticker,
start_date=date,
end_date=date,
limit=limit,
)
rows = store.get_news_items_enriched(
ticker,
trade_date=date,
limit=limit,
)
return {"ticker": ticker, "date": date, "news": rows, "freshness": freshness}
def get_news_timeline(
store: MarketStore,
*,
ticker: str,
start_date: str,
end_date: str,
) -> dict[str, Any]:
freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date)
timeline = store.get_news_timeline_enriched(
ticker,
start_date=start_date,
end_date=end_date,
)
if not timeline:
enrich_news_for_symbol(
store,
ticker,
start_date=start_date,
end_date=end_date,
limit=200,
)
timeline = store.get_news_timeline_enriched(
ticker,
start_date=start_date,
end_date=end_date,
)
return {
"ticker": ticker,
"timeline": timeline,
"start_date": start_date,
"end_date": end_date,
"freshness": freshness,
}
def get_news_categories(
store: MarketStore,
*,
ticker: str,
start_date: str | None = None,
end_date: str | None = None,
limit: int = 200,
) -> dict[str, Any]:
freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date)
rows = store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
if news_rows_need_enrichment(rows):
enrich_news_for_symbol(
store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
categories = store.get_news_categories_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
return {"ticker": ticker, "categories": categories, "freshness": freshness}
def get_similar_days_payload(
store: MarketStore,
*,
ticker: str,
date: str,
n_similar: int = 5,
) -> dict[str, Any]:
freshness = ensure_news_fresh(store, ticker=ticker, target_date=date)
result = find_similar_days(
store,
symbol=ticker,
target_date=date,
top_k=n_similar,
)
result["freshness"] = freshness
return result
def get_story_payload(
store: MarketStore,
*,
ticker: str,
as_of_date: str,
) -> dict[str, Any]:
freshness = ensure_news_fresh(store, ticker=ticker, target_date=as_of_date)
enrich_news_for_symbol(
store,
ticker,
end_date=as_of_date,
limit=80,
)
result = get_or_create_stock_story(
store,
symbol=ticker,
as_of_date=as_of_date,
)
result["freshness"] = freshness
return result
def get_range_explain_payload(
store: MarketStore,
*,
ticker: str,
start_date: str,
end_date: str,
article_ids: list[str] | None = None,
limit: int = 100,
) -> dict[str, Any]:
freshness = ensure_news_fresh(store, ticker=ticker, target_date=end_date)
news_rows = []
if article_ids:
news_rows = store.get_news_by_ids_enriched(ticker, article_ids)
if not news_rows:
news_rows = store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
if news_rows_need_enrichment(news_rows):
enrich_news_for_symbol(
store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
news_rows = (
store.get_news_by_ids_enriched(ticker, article_ids)
if article_ids
else store.get_news_items_enriched(
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
)
)
result = build_range_explanation(
ticker=ticker,
start_date=start_date,
end_date=end_date,
news_rows=news_rows,
)
return {"ticker": ticker, "result": result, "freshness": freshness}

106
backend/domains/trading.py Normal file
View File

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
"""Trading domain helpers shared by app surfaces and gateway fallbacks."""
from __future__ import annotations
from typing import Any
from backend.services.market import MarketService
from backend.tools.data_tools import (
get_company_news,
get_financial_metrics,
get_insider_trades,
get_market_cap,
get_prices,
search_line_items,
)
def get_prices_payload(*, ticker: str, start_date: str, end_date: str) -> dict[str, Any]:
return {
"ticker": ticker,
"prices": get_prices(ticker, start_date, end_date),
}
def get_financials_payload(
*,
ticker: str,
end_date: str,
period: str = "ttm",
limit: int = 10,
) -> dict[str, Any]:
return {
"financial_metrics": get_financial_metrics(
ticker=ticker,
end_date=end_date,
period=period,
limit=limit,
)
}
def get_news_payload(
*,
ticker: str,
end_date: str,
start_date: str | None = None,
limit: int = 1000,
) -> dict[str, Any]:
return {
"news": get_company_news(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
}
def get_insider_trades_payload(
*,
ticker: str,
end_date: str,
start_date: str | None = None,
limit: int = 1000,
) -> dict[str, Any]:
return {
"insider_trades": get_insider_trades(
ticker=ticker,
end_date=end_date,
start_date=start_date,
limit=limit,
)
}
def get_market_status_payload() -> dict[str, Any]:
market_service = MarketService(tickers=[])
return market_service.get_market_status()
def get_market_cap_payload(*, ticker: str, end_date: str) -> dict[str, Any]:
return {
"ticker": ticker,
"end_date": end_date,
"market_cap": get_market_cap(ticker, end_date),
}
def get_line_items_payload(
*,
ticker: str,
line_items: list[str],
end_date: str,
period: str = "ttm",
limit: int = 10,
) -> dict[str, Any]:
return {
"search_results": search_line_items(
ticker=ticker,
line_items=line_items,
end_date=end_date,
period=period,
limit=limit,
)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,419 @@
# -*- coding: utf-8 -*-
"""Runtime/workspace/skills handlers extracted from the main Gateway module."""
from __future__ import annotations
import json
from datetime import datetime
from typing import Any
from backend.agents.agent_workspace import load_agent_workspace_config
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import load_agent_profiles
from backend.config.bootstrap_config import (
get_bootstrap_config_for_run,
resolve_runtime_config,
update_bootstrap_values_for_run,
)
from backend.data.market_ingest import ingest_symbols
from backend.llm.models import get_agent_model_info
async def handle_reload_runtime_assets(gateway: Any) -> None:
config_name = gateway.config.get("config_name", "default")
runtime_config = resolve_runtime_config(
project_root=gateway._project_root,
config_name=config_name,
enable_memory=gateway.config.get("enable_memory", False),
schedule_mode=gateway.config.get("schedule_mode", "daily"),
interval_minutes=gateway.config.get("interval_minutes", 60),
trigger_time=gateway.config.get("trigger_time", "09:30"),
)
result = gateway.pipeline.reload_runtime_assets(runtime_config=runtime_config)
runtime_updates = gateway._apply_runtime_config(runtime_config)
await gateway.state_sync.on_system_message("Runtime assets reloaded.")
await gateway.broadcast({"type": "runtime_assets_reloaded", **result, **runtime_updates})
async def handle_update_runtime_config(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
updates: dict[str, Any] = {}
schedule_mode = str(data.get("schedule_mode", "")).strip().lower()
if schedule_mode:
if schedule_mode not in {"daily", "intraday"}:
await websocket.send(json.dumps({"type": "error", "message": "schedule_mode must be 'daily' or 'intraday'."}, ensure_ascii=False))
return
updates["schedule_mode"] = schedule_mode
interval_minutes = data.get("interval_minutes")
if interval_minutes is not None:
try:
parsed_interval = int(interval_minutes)
except (TypeError, ValueError):
parsed_interval = 0
if parsed_interval <= 0:
await websocket.send(json.dumps({"type": "error", "message": "interval_minutes must be a positive integer."}, ensure_ascii=False))
return
updates["interval_minutes"] = parsed_interval
trigger_time = data.get("trigger_time")
if trigger_time is not None:
raw_trigger = str(trigger_time).strip()
if raw_trigger and raw_trigger != "now":
try:
datetime.strptime(raw_trigger, "%H:%M")
except ValueError:
await websocket.send(json.dumps({"type": "error", "message": "trigger_time must use HH:MM or 'now'."}, ensure_ascii=False))
return
updates["trigger_time"] = raw_trigger or "09:30"
max_comm_cycles = data.get("max_comm_cycles")
if max_comm_cycles is not None:
try:
parsed_cycles = int(max_comm_cycles)
except (TypeError, ValueError):
parsed_cycles = 0
if parsed_cycles <= 0:
await websocket.send(json.dumps({"type": "error", "message": "max_comm_cycles must be a positive integer."}, ensure_ascii=False))
return
updates["max_comm_cycles"] = parsed_cycles
initial_cash = data.get("initial_cash")
if initial_cash is not None:
try:
parsed_initial_cash = float(initial_cash)
except (TypeError, ValueError):
parsed_initial_cash = 0.0
if parsed_initial_cash <= 0:
await websocket.send(json.dumps({"type": "error", "message": "initial_cash must be a positive number."}, ensure_ascii=False))
return
updates["initial_cash"] = parsed_initial_cash
margin_requirement = data.get("margin_requirement")
if margin_requirement is not None:
try:
parsed_margin_requirement = float(margin_requirement)
except (TypeError, ValueError):
parsed_margin_requirement = -1.0
if parsed_margin_requirement < 0:
await websocket.send(json.dumps({"type": "error", "message": "margin_requirement must be a non-negative number."}, ensure_ascii=False))
return
updates["margin_requirement"] = parsed_margin_requirement
enable_memory = data.get("enable_memory")
if enable_memory is not None:
updates["enable_memory"] = bool(enable_memory)
if not updates:
await websocket.send(json.dumps({"type": "error", "message": "No runtime settings were provided."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=gateway._project_root,
config_name=config_name,
updates=updates,
)
await gateway.state_sync.on_system_message("运行时调度配置已保存,正在热更新")
await handle_reload_runtime_assets(gateway)
async def handle_update_watchlist(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
tickers = gateway._normalize_watchlist(data.get("tickers"))
if not tickers:
await websocket.send(json.dumps({"type": "error", "message": "update_watchlist requires at least one valid ticker."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
update_bootstrap_values_for_run(
project_root=gateway._project_root,
config_name=config_name,
updates={"tickers": tickers},
)
await gateway.state_sync.on_system_message(f"Watchlist updated: {', '.join(tickers)}")
await gateway.broadcast({"type": "watchlist_updated", "config_name": config_name, "tickers": tickers})
await handle_reload_runtime_assets(gateway)
gateway._schedule_watchlist_market_store_refresh(tickers)
async def handle_get_agent_skills(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
if not agent_id:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_skills requires agent_id."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
agent_asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
agent_config = load_agent_workspace_config(agent_asset_dir / "agent.yaml")
resolved_skills = set(skills_manager.resolve_agent_skill_names(config_name=config_name, agent_id=agent_id, default_skills=[]))
enabled = set(agent_config.enabled_skills)
disabled = set(agent_config.disabled_skills)
payload = []
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id):
if item.skill_name in disabled:
status = "disabled"
elif item.skill_name in enabled:
status = "enabled"
elif item.skill_name in resolved_skills:
status = "active"
else:
status = "available"
payload.append({
"skill_name": item.skill_name,
"name": item.name,
"description": item.description,
"version": item.version,
"source": item.source,
"tools": item.tools,
"status": status,
})
await websocket.send(json.dumps({
"type": "agent_skills_loaded",
"config_name": config_name,
"agent_id": agent_id,
"skills": payload,
}, ensure_ascii=False))
async def handle_get_agent_profile(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
if not agent_id:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_profile requires agent_id."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
agent_config = load_agent_workspace_config(asset_dir / "agent.yaml")
profiles = load_agent_profiles()
profile = profiles.get(agent_id, {})
bootstrap = get_bootstrap_config_for_run(gateway._project_root, config_name)
override = bootstrap.agent_override(agent_id)
active_tool_groups = override.get("active_tool_groups", agent_config.active_tool_groups or profile.get("active_tool_groups", []))
if not isinstance(active_tool_groups, list):
active_tool_groups = []
disabled_tool_groups = agent_config.disabled_tool_groups
if disabled_tool_groups:
disabled_set = set(disabled_tool_groups)
active_tool_groups = [group_name for group_name in active_tool_groups if group_name not in disabled_set]
default_skills = profile.get("skills", [])
if not isinstance(default_skills, list):
default_skills = []
resolved_skills = skills_manager.resolve_agent_skill_names(
config_name=config_name,
agent_id=agent_id,
default_skills=default_skills,
)
prompt_files = agent_config.prompt_files or ["SOUL.md", "PROFILE.md", "AGENTS.md", "POLICY.md", "MEMORY.md"]
model_name, model_provider = get_agent_model_info(agent_id)
await websocket.send(json.dumps({
"type": "agent_profile_loaded",
"config_name": config_name,
"agent_id": agent_id,
"profile": {
"model_name": model_name,
"model_provider": model_provider,
"prompt_files": prompt_files,
"default_skills": default_skills,
"resolved_skills": resolved_skills,
"active_tool_groups": active_tool_groups,
"disabled_tool_groups": disabled_tool_groups,
"enabled_skills": agent_config.enabled_skills,
"disabled_skills": agent_config.disabled_skills,
},
}, ensure_ascii=False))
async def handle_get_skill_detail(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "get_skill_detail requires skill_name."}, ensure_ascii=False))
return
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
if agent_id:
config_name = gateway.config.get("config_name", "default")
detail = skills_manager.load_agent_skill_document(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
else:
detail = skills_manager.load_skill_document(skill_name)
except FileNotFoundError:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
return
await websocket.send(json.dumps({
"type": "skill_detail_loaded",
"agent_id": agent_id,
"skill": detail,
}, ensure_ascii=False))
async def handle_create_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "create_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.create_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
except (ValueError, FileExistsError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Created local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_created", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
async def handle_update_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
content = data.get("content")
if not agent_id or not skill_name or not isinstance(content, str):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_local_skill requires agent_id, skill_name, and string content."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.update_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name, content=content)
except (ValueError, FileNotFoundError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Updated local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_updated", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_skill_detail(gateway, websocket, {"agent_id": agent_id, "skill_name": skill_name})
async def handle_delete_agent_local_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "delete_agent_local_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
try:
skills_manager.delete_agent_local_skill(config_name=config_name, agent_id=agent_id, skill_name=skill_name)
skills_manager.forget_agent_skill_overrides(config_name=config_name, agent_id=agent_id, skill_names=[skill_name])
except (ValueError, FileNotFoundError) as exc:
await websocket.send(json.dumps({"type": "error", "message": str(exc)}, ensure_ascii=False))
return
await gateway.state_sync.on_system_message(f"Deleted local skill {skill_name} for {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_local_skill_deleted", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_remove_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
if not agent_id or not skill_name:
await websocket.send(json.dumps({"type": "error", "message": "remove_agent_skill requires agent_id and skill_name."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
skill_names = {
item.skill_name
for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)
if item.source != "local"
}
if skill_name not in skill_names:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown shared skill: {skill_name}"}, ensure_ascii=False))
return
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
await gateway.state_sync.on_system_message(f"Removed shared skill {skill_name} from {agent_id}")
await gateway._handle_reload_runtime_assets()
await websocket.send(json.dumps({"type": "agent_skill_removed", "agent_id": agent_id, "skill_name": skill_name}, ensure_ascii=False))
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_update_agent_skill(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
skill_name = str(data.get("skill_name", "")).strip()
enabled = data.get("enabled")
if not agent_id or not skill_name or not isinstance(enabled, bool):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_skill requires agent_id, skill_name, and boolean enabled."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
skill_names = {item.skill_name for item in skills_manager.list_agent_skill_catalog(config_name, agent_id)}
if skill_name not in skill_names:
await websocket.send(json.dumps({"type": "error", "message": f"Unknown skill: {skill_name}"}, ensure_ascii=False))
return
if enabled:
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, enable=[skill_name])
await gateway.state_sync.on_system_message(f"Enabled skill {skill_name} for {agent_id}")
else:
skills_manager.update_agent_skill_overrides(config_name=config_name, agent_id=agent_id, disable=[skill_name])
await gateway.state_sync.on_system_message(f"Disabled skill {skill_name} for {agent_id}")
await websocket.send(json.dumps({
"type": "agent_skill_updated",
"agent_id": agent_id,
"skill_name": skill_name,
"enabled": enabled,
}, ensure_ascii=False))
await gateway._handle_reload_runtime_assets()
await handle_get_agent_skills(gateway, websocket, {"agent_id": agent_id})
async def handle_get_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
if not agent_id or not filename:
await websocket.send(json.dumps({"type": "error", "message": "get_agent_workspace_file requires agent_id and supported filename."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
asset_dir.mkdir(parents=True, exist_ok=True)
path = asset_dir / filename
content = path.read_text(encoding="utf-8") if path.exists() else ""
await websocket.send(json.dumps({
"type": "agent_workspace_file_loaded",
"config_name": config_name,
"agent_id": agent_id,
"filename": filename,
"content": content,
}, ensure_ascii=False))
async def handle_update_agent_workspace_file(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
agent_id = str(data.get("agent_id", "")).strip()
filename = gateway._normalize_agent_workspace_filename(data.get("filename"))
content = data.get("content")
if not agent_id or not filename or not isinstance(content, str):
await websocket.send(json.dumps({"type": "error", "message": "update_agent_workspace_file requires agent_id, supported filename, and string content."}, ensure_ascii=False))
return
config_name = gateway.config.get("config_name", "default")
skills_manager = SkillsManager(project_root=gateway._project_root)
asset_dir = skills_manager.get_agent_asset_dir(config_name, agent_id)
asset_dir.mkdir(parents=True, exist_ok=True)
path = asset_dir / filename
path.write_text(content, encoding="utf-8")
await gateway.state_sync.on_system_message(f"Updated {filename} for {agent_id}")
await websocket.send(json.dumps({"type": "agent_workspace_file_updated", "agent_id": agent_id, "filename": filename}, ensure_ascii=False))
await gateway._handle_reload_runtime_assets()
await handle_get_agent_workspace_file(gateway, websocket, {"agent_id": agent_id, "filename": filename})

View File

@@ -0,0 +1,373 @@
# -*- coding: utf-8 -*-
"""Cycle and monitoring helpers extracted from the main Gateway module."""
from __future__ import annotations
import asyncio
import logging
from typing import Any
from backend.data.market_ingest import ingest_symbols
from backend.domains import trading as trading_domain
from backend.utils.msg_adapter import FrontendAdapter
logger = logging.getLogger(__name__)
def schedule_watchlist_market_store_refresh(gateway: Any, tickers: list[str]) -> None:
"""Kick off a non-blocking market-store refresh for an updated watchlist."""
if not tickers:
return
if gateway._watchlist_ingest_task and not gateway._watchlist_ingest_task.done():
gateway._watchlist_ingest_task.cancel()
gateway._watchlist_ingest_task = asyncio.create_task(
refresh_market_store_for_watchlist(gateway, tickers),
)
async def refresh_market_store_for_watchlist(gateway: Any, tickers: list[str]) -> None:
"""Refresh the long-lived market store after a watchlist update."""
try:
await gateway.state_sync.on_system_message(
f"正在同步自选股市场数据: {', '.join(tickers)}",
)
results = await asyncio.to_thread(
ingest_symbols,
tickers,
mode="incremental",
)
summary = ", ".join(
f"{item['symbol']} prices={item['prices']} news={item['news']}"
for item in results
)
await gateway.state_sync.on_system_message(
f"自选股市场数据已同步: {summary}",
)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.warning("Watchlist market store refresh failed: %s", exc)
await gateway.state_sync.on_system_message(
f"自选股市场数据同步失败: {exc}",
)
async def market_status_monitor(gateway: Any) -> None:
"""Periodically check and broadcast market status changes."""
while True:
try:
await gateway.market_service.check_and_broadcast_market_status()
status = gateway.market_service.get_market_status()
if status["status"] == "open" and not gateway.storage.is_live_session_active:
gateway.storage.start_live_session()
summary = gateway.storage.load_file("summary") or {}
gateway._session_start_portfolio_value = summary.get(
"totalAssetValue",
gateway.storage.initial_cash,
)
logger.info(
"Session start portfolio: $%s",
f"{gateway._session_start_portfolio_value:,.2f}",
)
elif status["status"] != "open" and gateway.storage.is_live_session_active:
gateway.storage.end_live_session()
gateway._session_start_portfolio_value = None
if gateway.storage.is_live_session_active:
await update_and_broadcast_live_returns(gateway)
await asyncio.sleep(60)
except asyncio.CancelledError:
break
except Exception as exc:
logger.error("Market status monitor error: %s", exc)
await asyncio.sleep(60)
async def update_and_broadcast_live_returns(gateway: Any) -> None:
"""Calculate and broadcast live returns for current session."""
if not gateway.storage.is_live_session_active:
return
prices = gateway.market_service.get_all_prices()
if not prices or not any(p > 0 for p in prices.values()):
return
state = gateway.storage.load_internal_state()
equity_history = state.get("equity_history", [])
baseline_history = state.get("baseline_history", [])
baseline_vw_history = state.get("baseline_vw_history", [])
momentum_history = state.get("momentum_history", [])
current_equity = equity_history[-1]["v"] if equity_history else None
current_baseline = baseline_history[-1]["v"] if baseline_history else None
current_baseline_vw = baseline_vw_history[-1]["v"] if baseline_vw_history else None
current_momentum = momentum_history[-1]["v"] if momentum_history else None
point = gateway.storage.update_live_returns(
current_equity=current_equity,
current_baseline=current_baseline,
current_baseline_vw=current_baseline_vw,
current_momentum=current_momentum,
)
if point:
live_returns = gateway.storage.get_live_returns()
await gateway.broadcast(
{
"type": "team_summary",
"equity_return": live_returns["equity_return"],
"baseline_return": live_returns["baseline_return"],
"baseline_vw_return": live_returns["baseline_vw_return"],
"momentum_return": live_returns["momentum_return"],
},
)
async def on_strategy_trigger(gateway: Any, date: str) -> None:
"""Handle trading cycle trigger."""
if gateway._cycle_lock.locked():
logger.warning("Trading cycle already running, skipping trigger for %s", date)
await gateway.state_sync.on_system_message(f"已有交易周期在运行,跳过本次触发: {date}")
return
async with gateway._cycle_lock:
logger.info("Strategy triggered for %s", date)
tickers = gateway.config.get("tickers", [])
if gateway.is_backtest:
await run_backtest_cycle(gateway, date, tickers)
else:
await run_live_cycle(gateway, date, tickers)
async def on_heartbeat_trigger(gateway: Any, date: str) -> None:
"""Run lightweight heartbeat check for all analysts."""
logger.info("[Heartbeat] Running heartbeat check for %s", date)
analysts = gateway.pipeline._all_analysts()
for analyst in analysts:
try:
ws_id = getattr(analyst, "workspace_id", None)
if ws_id:
from backend.agents.workspace_manager import get_workspace_dir
from pathlib import Path
from agentscope.message import Msg
ws_dir = get_workspace_dir(ws_id)
if ws_dir:
hb_path = Path(ws_dir) / "HEARTBEAT.md"
if hb_path.exists():
content = hb_path.read_text(encoding="utf-8").strip()
if content:
hb_task = f"# 定期主动检查\n\n{content}\n\n请执行上述检查并报告结果。"
logger.info("[Heartbeat] Running heartbeat for %s", analyst.name)
msg = Msg(role="user", content=hb_task, name="system")
await analyst.reply([msg])
logger.info("[Heartbeat] %s heartbeat complete", analyst.name)
continue
logger.debug("[Heartbeat] No HEARTBEAT.md for %s, skipping", analyst.name)
except Exception as exc:
logger.error("[Heartbeat] %s failed: %s", analyst.name, exc, exc_info=True)
async def run_backtest_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
gateway.market_service.set_backtest_date(date)
await gateway.market_service.emit_market_open()
await gateway.state_sync.on_cycle_start(date)
gateway._dashboard.update(date=date, status="Analyzing...")
prices = gateway.market_service.get_open_prices()
close_prices = gateway.market_service.get_close_prices()
market_caps = await get_market_caps(gateway, tickers, date)
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=date,
prices=prices,
close_prices=close_prices,
market_caps=market_caps,
)
await gateway.market_service.emit_market_close()
settlement_result = result.get("settlement_result")
save_cycle_results(gateway, result, date, close_prices, settlement_result)
await broadcast_portfolio_updates(gateway, result, close_prices)
await finalize_cycle(gateway, date)
async def run_live_cycle(gateway: Any, date: str, tickers: list[str]) -> None:
trading_date = gateway.market_service.get_live_trading_date()
logger.info("Live cycle: triggered=%s, trading_date=%s", date, trading_date)
await gateway.state_sync.on_cycle_start(trading_date)
gateway._dashboard.update(date=trading_date, status="Analyzing...")
market_caps = await get_market_caps(gateway, tickers, trading_date)
schedule_mode = gateway.config.get("schedule_mode", "daily")
market_status = gateway.market_service.get_market_status()
current_prices = gateway.market_service.get_all_prices()
if schedule_mode == "intraday":
execute_decisions = market_status.get("status") == "open"
if execute_decisions:
await gateway.state_sync.on_system_message("定时任务触发:当前处于交易时段,本轮将执行交易决策")
else:
await gateway.state_sync.on_system_message("定时任务触发:当前非交易时段,本轮仅更新数据与分析,不执行交易")
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
prices=current_prices,
market_caps=market_caps,
execute_decisions=execute_decisions,
)
close_prices = current_prices
else:
result = await gateway.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
market_caps=market_caps,
get_open_prices_fn=gateway.market_service.wait_for_open_prices,
get_close_prices_fn=gateway.market_service.wait_for_close_prices,
)
close_prices = gateway.market_service.get_all_prices()
settlement_result = result.get("settlement_result")
save_cycle_results(gateway, result, trading_date, close_prices, settlement_result)
await broadcast_portfolio_updates(gateway, result, close_prices)
await finalize_cycle(gateway, trading_date)
async def finalize_cycle(gateway: Any, date: str) -> None:
summary = gateway.storage.load_file("summary") or {}
if gateway.storage.is_live_session_active:
summary.update(gateway.storage.get_live_returns())
await gateway.state_sync.on_cycle_end(date, portfolio_summary=summary)
holdings = gateway.storage.load_file("holdings") or []
trades = gateway.storage.load_file("trades") or []
leaderboard = gateway.storage.load_file("leaderboard") or []
if leaderboard:
await gateway.state_sync.on_leaderboard_update(leaderboard)
gateway._dashboard.update(date=date, status="Running", portfolio=summary, holdings=holdings, trades=trades)
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
market_caps: dict[str, float] = {}
for ticker in tickers:
try:
market_cap = None
response = await gateway._call_trading_service(
f"get_market_cap for {ticker}",
lambda client, symbol=ticker: client.get_market_cap(ticker=symbol, end_date=date),
)
if response is not None:
market_cap = response.get("market_cap")
if market_cap is None:
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
market_cap = payload.get("market_cap")
market_caps[ticker] = market_cap if market_cap else 1e9
except Exception as exc:
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
market_caps[ticker] = 1e9
return market_caps
async def broadcast_portfolio_updates(gateway: Any, result: dict[str, Any], prices: dict[str, float]) -> None:
portfolio = result.get("portfolio", {})
if portfolio:
holdings = FrontendAdapter.build_holdings(portfolio, prices)
if holdings:
await gateway.state_sync.on_holdings_update(holdings)
stats = FrontendAdapter.build_stats(portfolio, prices)
if stats:
await gateway.state_sync.on_stats_update(stats)
executed_trades = result.get("executed_trades", [])
if executed_trades:
await gateway.state_sync.on_trades_executed(executed_trades)
def save_cycle_results(
gateway: Any,
result: dict[str, Any],
date: str,
prices: dict[str, float],
settlement_result: dict[str, Any] | None = None,
) -> None:
portfolio = result.get("portfolio", {})
executed_trades = result.get("executed_trades", [])
baseline_values = settlement_result.get("baseline_values") if settlement_result else None
if portfolio:
gateway.storage.update_dashboard_after_cycle(
portfolio=portfolio,
prices=prices,
date=date,
executed_trades=executed_trades,
baseline_values=baseline_values,
)
async def run_backtest_dates(gateway: Any, dates: list[str]) -> None:
gateway.state_sync.set_backtest_dates(dates)
gateway._dashboard.update(days_total=len(dates), days_completed=0)
await gateway.state_sync.on_system_message(f"Starting backtest - {len(dates)} trading days")
try:
for i, date in enumerate(dates):
gateway._dashboard.update(days_completed=i)
await gateway.on_strategy_trigger(date=date)
await asyncio.sleep(0.1)
await gateway.state_sync.on_system_message(f"Backtest complete - {len(dates)} days")
summary = gateway.storage.load_file("summary") or {}
gateway._dashboard.update(status="Complete", portfolio=summary, days_completed=len(dates))
gateway._dashboard.stop()
gateway._dashboard.print_final_summary()
except Exception as exc:
error_msg = f"Backtest failed: {type(exc).__name__}: {str(exc)}"
logger.error(error_msg, exc_info=True)
asyncio.create_task(gateway.state_sync.on_system_message(error_msg))
gateway._dashboard.update(status=f"Failed: {str(exc)}")
gateway._dashboard.stop()
raise
finally:
gateway._backtest_task = None
def handle_backtest_exception(gateway: Any, task: asyncio.Task) -> None:
try:
task.result()
except asyncio.CancelledError:
logger.info("Backtest task was cancelled")
except Exception as exc:
logger.error("Backtest task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
def handle_manual_cycle_exception(gateway: Any, task: asyncio.Task) -> None:
gateway._manual_cycle_task = None
try:
task.result()
except asyncio.CancelledError:
logger.info("Manual cycle task was cancelled")
except Exception as exc:
logger.error("Manual cycle task failed with exception:%s:%s", type(exc).__name__, exc, exc_info=True)
def set_backtest_dates(gateway: Any, dates: list[str]) -> None:
gateway.state_sync.set_backtest_dates(dates)
if dates:
gateway._backtest_start_date = dates[0]
gateway._backtest_end_date = dates[-1]
gateway._dashboard.days_total = len(dates)
def stop_gateway(gateway: Any) -> None:
gateway.state_sync.save_state()
gateway.market_service.stop()
if gateway._backtest_task:
gateway._backtest_task.cancel()
if gateway._market_status_task:
gateway._market_status_task.cancel()
if gateway._watchlist_ingest_task:
gateway._watchlist_ingest_task.cancel()
gateway._dashboard.stop()

View File

@@ -0,0 +1,174 @@
# -*- coding: utf-8 -*-
"""Runtime/state support helpers extracted from the main Gateway module."""
from __future__ import annotations
from typing import Any
from backend.data.provider_utils import normalize_symbol
def normalize_watchlist(raw_tickers: Any) -> list[str]:
"""Parse watchlist payloads from websocket messages."""
if raw_tickers is None:
return []
if isinstance(raw_tickers, str):
candidates = raw_tickers.split(",")
elif isinstance(raw_tickers, list):
candidates = raw_tickers
else:
candidates = [raw_tickers]
tickers: list[str] = []
for candidate in candidates:
symbol = normalize_symbol(str(candidate).strip().strip("\"'"))
if symbol and symbol not in tickers:
tickers.append(symbol)
return tickers
def normalize_agent_workspace_filename(
raw_name: Any,
*,
allowlist: set[str],
) -> str | None:
"""Restrict editable workspace files to a safe allowlist."""
filename = str(raw_name or "").strip()
if filename in allowlist:
return filename
return None
def apply_runtime_config(gateway: Any, runtime_config: dict[str, Any]) -> dict[str, Any]:
"""Apply runtime config to gateway-owned services and state."""
warnings: list[str] = []
ticker_changes = gateway.market_service.update_tickers(
runtime_config.get("tickers", []),
)
gateway.config["tickers"] = ticker_changes["active"]
gateway.pipeline.max_comm_cycles = int(runtime_config["max_comm_cycles"])
gateway.config["max_comm_cycles"] = gateway.pipeline.max_comm_cycles
gateway.config["schedule_mode"] = runtime_config.get(
"schedule_mode",
gateway.config.get("schedule_mode", "daily"),
)
gateway.config["interval_minutes"] = int(
runtime_config.get(
"interval_minutes",
gateway.config.get("interval_minutes", 60),
),
)
gateway.config["trigger_time"] = runtime_config.get(
"trigger_time",
gateway.config.get("trigger_time", "09:30"),
)
if gateway.scheduler:
gateway.scheduler.reconfigure(
mode=gateway.config["schedule_mode"],
trigger_time=gateway.config["trigger_time"],
interval_minutes=gateway.config["interval_minutes"],
)
pm_apply_result = gateway.pipeline.pm.apply_runtime_portfolio_config(
margin_requirement=runtime_config["margin_requirement"],
)
gateway.config["margin_requirement"] = gateway.pipeline.pm.portfolio.get(
"margin_requirement",
runtime_config["margin_requirement"],
)
requested_initial_cash = float(runtime_config["initial_cash"])
current_initial_cash = float(gateway.storage.initial_cash)
initial_cash_applied = requested_initial_cash == current_initial_cash
if not initial_cash_applied:
if (
gateway.storage.can_apply_initial_cash()
and gateway.pipeline.pm.can_apply_initial_cash()
):
initial_cash_applied = gateway.storage.apply_initial_cash(
requested_initial_cash,
)
if initial_cash_applied:
gateway.pipeline.pm.apply_runtime_portfolio_config(
initial_cash=requested_initial_cash,
)
gateway.config["initial_cash"] = gateway.storage.initial_cash
else:
warnings.append(
"initial_cash changed in BOOTSTRAP.md but was not applied "
"because the run already has positions, margin usage, or trades.",
)
requested_enable_memory = bool(runtime_config["enable_memory"])
current_enable_memory = bool(gateway.config.get("enable_memory", False))
if requested_enable_memory != current_enable_memory:
warnings.append(
"enable_memory changed in BOOTSTRAP.md but still requires a restart "
"because long-term memory contexts are created at startup.",
)
sync_runtime_state(gateway)
return {
"runtime_config_requested": runtime_config,
"runtime_config_applied": {
"tickers": list(gateway.config.get("tickers", [])),
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
"interval_minutes": gateway.config.get("interval_minutes", 60),
"trigger_time": gateway.config.get("trigger_time", "09:30"),
"initial_cash": gateway.storage.initial_cash,
"margin_requirement": gateway.config["margin_requirement"],
"max_comm_cycles": gateway.config["max_comm_cycles"],
"enable_memory": gateway.config.get("enable_memory", False),
},
"runtime_config_status": {
"tickers": True,
"schedule_mode": True,
"interval_minutes": True,
"trigger_time": True,
"initial_cash": initial_cash_applied,
"margin_requirement": pm_apply_result["margin_requirement"],
"max_comm_cycles": True,
"enable_memory": requested_enable_memory == current_enable_memory,
},
"ticker_changes": ticker_changes,
"runtime_config_warnings": warnings,
}
def sync_runtime_state(gateway: Any) -> None:
"""Refresh persisted state and dashboard after runtime config changes."""
gateway.state_sync.update_state("tickers", gateway.config.get("tickers", []))
gateway.state_sync.update_state(
"runtime_config",
{
"tickers": gateway.config.get("tickers", []),
"schedule_mode": gateway.config.get("schedule_mode", "daily"),
"interval_minutes": gateway.config.get("interval_minutes", 60),
"trigger_time": gateway.config.get("trigger_time", "09:30"),
"initial_cash": gateway.storage.initial_cash,
"margin_requirement": gateway.config.get("margin_requirement"),
"max_comm_cycles": gateway.config.get("max_comm_cycles"),
"enable_memory": gateway.config.get("enable_memory", False),
},
)
gateway.storage.update_server_state_from_dashboard(gateway.state_sync.state)
gateway.state_sync.save_state()
gateway._dashboard.tickers = list(gateway.config.get("tickers", []))
gateway._dashboard.initial_cash = gateway.storage.initial_cash
gateway._dashboard.enable_memory = bool(gateway.config.get("enable_memory", False))
summary = gateway.storage.load_file("summary") or {}
holdings = gateway.storage.load_file("holdings") or []
trades = gateway.storage.load_file("trades") or []
gateway._dashboard.update(
portfolio=summary,
holdings=holdings,
trades=trades,
)

View File

@@ -0,0 +1,711 @@
# -*- coding: utf-8 -*-
"""Stock-related Gateway handlers extracted from the main Gateway module."""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timedelta
from typing import Any
from backend.data.provider_utils import normalize_symbol
from backend.domains import news as news_domain
from backend.domains import trading as trading_domain
from backend.enrich.news_enricher import enrich_news_for_symbol
from backend.enrich.llm_enricher import llm_enrichment_enabled
from backend.tools.data_tools import prices_to_df
from shared.client import NewsServiceClient, TradingServiceClient
logger = logging.getLogger(__name__)
async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_history_loaded",
"ticker": "",
"prices": [],
"source": None,
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
prices = []
source = "polygon"
response = await gateway._call_trading_service(
"get_prices for history",
lambda client: client.get_prices(ticker=ticker, start_date=start_date, end_date=end_date),
)
if response is not None:
prices = response.prices
source = "trading_service"
if not prices:
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
if not prices:
payload = await asyncio.to_thread(
trading_domain.get_prices_payload,
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
prices = payload.get("prices") or []
usage_snapshot = gateway._provider_router.get_usage_snapshot()
source = usage_snapshot.get("last_success", {}).get("prices")
if prices:
await asyncio.to_thread(
gateway.storage.market_store.upsert_ohlc,
ticker,
[price.model_dump() for price in prices],
source=source or "provider",
)
await websocket.send(json.dumps({
"type": "stock_history_loaded",
"ticker": ticker,
"prices": [price if isinstance(price, dict) else price.model_dump() for price in prices][-120:],
"source": source,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_explain_events(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
snapshot = gateway.storage.runtime_db.get_stock_explain_snapshot(ticker)
await websocket.send(json.dumps({
"type": "stock_explain_events_loaded",
"ticker": ticker,
"events": snapshot.get("events", []),
"signals": snapshot.get("signals", []),
"trades": snapshot.get("trades", []),
}, ensure_ascii=False, default=str))
async def handle_get_stock_news(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_loaded",
"ticker": "",
"news": [],
"source": None,
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 30)
limit = data.get("limit", 12)
try:
lookback_days = max(7, min(int(lookback_days), 180))
except (TypeError, ValueError):
lookback_days = 30
try:
limit = max(1, min(int(limit), 30))
except (TypeError, ValueError):
limit = 12
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
news_rows = []
source = "polygon"
response = await gateway._call_news_service(
"get_enriched_news",
lambda client: client.get_enriched_news(
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
),
)
if response is not None:
news_rows = response.get("news") or []
source = "news_service"
if not news_rows:
payload = await asyncio.to_thread(
news_domain.get_enriched_news,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=max(limit, 50),
)
news_rows = (payload.get("news") or [])[-limit:]
source = "market_store"
await websocket.send(json.dumps({
"type": "stock_news_loaded",
"ticker": ticker,
"news": news_rows[-limit:],
"source": source,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_for_date(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
trade_date = str(data.get("date") or "").strip()
if not ticker or not trade_date:
await websocket.send(json.dumps({
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": [],
"error": "ticker and date are required",
}, ensure_ascii=False))
return
limit = data.get("limit", 20)
try:
limit = max(1, min(int(limit), 50))
except (TypeError, ValueError):
limit = 20
source = "market_store"
news_rows = []
response = await gateway._call_news_service(
"get_news_for_date",
lambda client: client.get_news_for_date(ticker=ticker, date=trade_date, limit=limit),
)
if response is not None:
news_rows = response.get("news") or []
source = "news_service"
if not news_rows:
payload = await asyncio.to_thread(
news_domain.get_news_for_date,
gateway.storage.market_store,
ticker=ticker,
date=trade_date,
limit=limit,
)
news_rows = payload.get("news") or []
source = "market_store"
await websocket.send(json.dumps({
"type": "stock_news_for_date_loaded",
"ticker": ticker,
"date": trade_date,
"news": news_rows,
"source": source,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_timeline(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_timeline_loaded",
"ticker": "",
"timeline": [],
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
timeline = []
response = await gateway._call_news_service(
"get_news_timeline",
lambda client: client.get_news_timeline(ticker=ticker, start_date=start_date, end_date=end_date),
)
if response is not None:
timeline = response.get("timeline") or []
if not timeline:
payload = await asyncio.to_thread(
news_domain.get_news_timeline,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
)
timeline = payload.get("timeline") or []
await websocket.send(json.dumps({
"type": "stock_news_timeline_loaded",
"ticker": ticker,
"timeline": timeline,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_news_categories(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_news_categories_loaded",
"ticker": "",
"categories": {},
"error": "invalid ticker",
}, ensure_ascii=False))
return
lookback_days = data.get("lookback_days", 90)
try:
lookback_days = max(7, min(int(lookback_days), 365))
except (TypeError, ValueError):
lookback_days = 90
end_date = gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")
try:
end_dt = datetime.strptime(end_date, "%Y-%m-%d")
except ValueError:
end_dt = datetime.now()
end_date = end_dt.strftime("%Y-%m-%d")
start_date = (end_dt - timedelta(days=lookback_days)).strftime("%Y-%m-%d")
categories = {}
response = await gateway._call_news_service(
"get_categories",
lambda client: client.get_categories(
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=200,
),
)
if response is not None:
categories = response.get("categories") or {}
if not categories:
payload = await asyncio.to_thread(
news_domain.get_news_categories,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
limit=200,
)
categories = payload.get("categories") or {}
await websocket.send(json.dumps({
"type": "stock_news_categories_loaded",
"ticker": ticker,
"categories": categories,
"start_date": start_date,
"end_date": end_date,
}, ensure_ascii=False, default=str))
async def handle_get_stock_range_explain(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()
end_date = str(data.get("end_date") or "").strip()
if not ticker or not start_date or not end_date:
await websocket.send(json.dumps({
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": {"error": "ticker, start_date, end_date are required"},
}, ensure_ascii=False))
return
article_ids = data.get("article_ids")
result = None
response = await gateway._call_news_service(
"get_range_explain",
lambda client: client.get_range_explain(
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids if isinstance(article_ids, list) else None,
limit=100,
),
)
if response is not None:
result = response.get("result")
if result is None:
payload = await asyncio.to_thread(
news_domain.get_range_explain_payload,
gateway.storage.market_store,
ticker=ticker,
start_date=start_date,
end_date=end_date,
article_ids=article_ids if isinstance(article_ids, list) else None,
limit=100,
)
result = payload.get("result")
await websocket.send(json.dumps({
"type": "stock_range_explain_loaded",
"ticker": ticker,
"result": result,
}, ensure_ascii=False, default=str))
async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_insider_trades_loaded",
"ticker": "",
"trades": [],
"error": "invalid ticker",
}, ensure_ascii=False))
return
end_date = str(data.get("end_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
start_date = str(data.get("start_date") or "").strip()[:10]
limit = int(data.get("limit", 50))
trades = []
response = await gateway._call_trading_service(
"get_insider_trades",
lambda client: client.get_insider_trades(
ticker=ticker,
end_date=end_date,
start_date=start_date if start_date else None,
limit=limit,
),
)
if response is not None:
trades = response.insider_trades
if not trades:
payload = await asyncio.to_thread(
trading_domain.get_insider_trades_payload,
ticker=ticker,
end_date=end_date,
start_date=start_date if start_date else None,
limit=limit,
)
trades = payload.get("insider_trades") or []
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
formatted_trades = [{
"ticker": t.ticker,
"name": t.name,
"title": t.title,
"is_board_director": t.is_board_director,
"transaction_date": t.transaction_date,
"transaction_shares": t.transaction_shares,
"transaction_price_per_share": t.transaction_price_per_share,
"transaction_value": t.transaction_value,
"shares_owned_before_transaction": t.shares_owned_before_transaction,
"shares_owned_after_transaction": t.shares_owned_after_transaction,
"security_title": t.security_title,
"filing_date": t.filing_date,
"holding_change": (
(t.shares_owned_after_transaction or 0) - (t.shares_owned_before_transaction or 0)
if t.shares_owned_after_transaction and t.shares_owned_before_transaction else None
),
"is_buy": ((t.transaction_shares or 0) > 0) if t.transaction_shares is not None else None,
} for t in sorted_trades]
await websocket.send(json.dumps({
"type": "stock_insider_trades_loaded",
"ticker": ticker,
"start_date": start_date or None,
"end_date": end_date,
"trades": formatted_trades,
}, ensure_ascii=False, default=str))
async def handle_get_stock_story(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_story_loaded",
"ticker": "",
"story": "",
"error": "invalid ticker",
}, ensure_ascii=False))
return
as_of_date = str(data.get("as_of_date") or gateway.state_sync.state.get("current_date") or datetime.now().strftime("%Y-%m-%d")).strip()[:10]
result = await gateway._call_news_service(
"get_story",
lambda client: client.get_story(ticker=ticker, as_of_date=as_of_date),
)
if result is None:
result = await asyncio.to_thread(
news_domain.get_story_payload,
gateway.storage.market_store,
ticker=ticker,
as_of_date=as_of_date,
)
await websocket.send(json.dumps({
"type": "stock_story_loaded",
"ticker": ticker,
"as_of_date": as_of_date,
"story": result.get("story") or "",
"source": result.get("source") or "local",
}, ensure_ascii=False, default=str))
async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
target_date = str(data.get("date") or "").strip()[:10]
if not ticker or not target_date:
await websocket.send(json.dumps({
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
"items": [],
"error": "ticker and date are required",
}, ensure_ascii=False))
return
top_k = data.get("top_k", 8)
try:
top_k = max(1, min(int(top_k), 20))
except (TypeError, ValueError):
top_k = 8
result = await gateway._call_news_service(
"get_similar_days",
lambda client: client.get_similar_days(ticker=ticker, date=target_date, n_similar=top_k),
)
if result is None:
result = await asyncio.to_thread(
news_domain.get_similar_days_payload,
gateway.storage.market_store,
ticker=ticker,
date=target_date,
n_similar=top_k,
)
await websocket.send(json.dumps({
"type": "stock_similar_days_loaded",
"ticker": ticker,
"date": target_date,
**result,
}, ensure_ascii=False, default=str))
async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "ticker is required",
}, ensure_ascii=False))
return
try:
end_date = datetime.now()
start_date = end_date - timedelta(days=250)
prices = None
response = await gateway._call_trading_service(
"get_prices",
lambda client: client.get_prices(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
),
)
if response is not None:
prices = response.prices
if prices is None:
payload = trading_domain.get_prices_payload(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
)
prices = payload.get("prices") or []
if not prices or len(prices) < 20:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "Insufficient price data",
}, ensure_ascii=False))
return
df = prices_to_df(prices)
signal = gateway._technical_analyzer.analyze(ticker, df)
import pandas as pd
df_sorted = df.sort_values("time").reset_index(drop=True)
df_sorted["returns"] = df_sorted["close"].pct_change()
vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
ma_distance = {}
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
ma_value = getattr(signal, ma_key, None)
ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None
indicators = {
"ticker": ticker,
"current_price": signal.current_price,
"ma": {
"ma5": signal.ma5,
"ma10": signal.ma10,
"ma20": signal.ma20,
"ma50": signal.ma50,
"ma200": signal.ma200,
"distance": ma_distance,
},
"rsi": {
"rsi14": signal.rsi14,
"status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
},
"macd": {
"macd": signal.macd,
"signal": signal.macd_signal,
"histogram": signal.macd - signal.macd_signal,
},
"bollinger": {
"upper": signal.bollinger_upper,
"mid": signal.bollinger_mid,
"lower": signal.bollinger_lower,
},
"volatility": {
"vol_10d": vol_10,
"vol_20d": vol_20,
"vol_60d": vol_60,
"annualized": signal.annualized_volatility_pct,
"risk_level": signal.risk_level,
},
"trend": signal.trend,
"mean_reversion": signal.mean_reversion_signal,
}
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": indicators,
}, ensure_ascii=False, default=str))
except Exception as exc:
logger.exception("Error getting technical indicators for %s", ticker)
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": str(exc),
}, ensure_ascii=False))
async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
start_date = str(data.get("start_date") or "").strip()[:10]
end_date = str(data.get("end_date") or "").strip()[:10]
story_date = str(data.get("story_date") or end_date or "").strip()[:10]
target_date = str(data.get("target_date") or "").strip()[:10]
force = bool(data.get("force", False))
rebuild_story = bool(data.get("rebuild_story", True))
rebuild_similar_days = bool(data.get("rebuild_similar_days", True))
only_local_to_llm = bool(data.get("only_local_to_llm", False))
limit = data.get("limit", 200)
try:
limit = max(10, min(int(limit), 500))
except (TypeError, ValueError):
limit = 200
if not ticker or not start_date or not end_date:
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "ticker, start_date, end_date are required",
}, ensure_ascii=False))
return
if only_local_to_llm and not llm_enrichment_enabled():
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"error": "only_local_to_llm requires EXPLAIN_ENRICH_USE_LLM=true and a configured LLM provider",
}, ensure_ascii=False))
return
result = await asyncio.to_thread(
enrich_news_for_symbol,
gateway.storage.market_store,
ticker,
start_date=start_date,
end_date=end_date,
limit=limit,
skip_existing=not force,
only_reanalyze_local=only_local_to_llm,
)
story_status = None
if rebuild_story and story_date:
await asyncio.to_thread(gateway.storage.market_store.delete_story_cache, ticker, as_of_date=story_date)
story_result = await asyncio.to_thread(
news_domain.get_story_payload,
gateway.storage.market_store,
ticker=ticker,
as_of_date=story_date,
)
story_status = {"as_of_date": story_date, "source": story_result.get("source") or "local"}
similar_status = None
if rebuild_similar_days and target_date:
await asyncio.to_thread(gateway.storage.market_store.delete_similar_day_cache, ticker, target_date=target_date)
similar_result = await asyncio.to_thread(
news_domain.get_similar_days_payload,
gateway.storage.market_store,
ticker=ticker,
date=target_date,
n_similar=8,
)
similar_status = {
"target_date": target_date,
"count": len(similar_result.get("items") or []),
"error": similar_result.get("error"),
}
await websocket.send(json.dumps({
"type": "stock_enrich_completed",
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
"story_date": story_date or None,
"target_date": target_date or None,
"force": force,
"only_local_to_llm": only_local_to_llm,
"stats": result,
"story_status": story_status,
"similar_status": similar_status,
}, ensure_ascii=False, default=str))

View File

@@ -9,7 +9,7 @@ from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterable
from backend.data.schema import CompanyNews
from shared.schema import CompanyNews
SCHEMA = """

View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted agent service surface."""
from pathlib import Path
from fastapi.testclient import TestClient
from backend.apps.agent_service import create_app
def test_agent_service_routes_include_control_plane_endpoints(tmp_path):
app = create_app(project_root=tmp_path)
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/status" in paths
assert "/api/workspaces" in paths
assert "/api/guard/pending" in paths
def test_agent_service_excludes_runtime_routes(tmp_path):
app = create_app(project_root=tmp_path)
paths = {route.path for route in app.routes}
assert "/api/runtime/start" not in paths
assert "/api/runtime/gateway/port" not in paths

View File

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
"""Tests for data_tools preferring split services when configured."""
from backend.tools import data_tools
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price
def test_data_tools_prefers_trading_service(monkeypatch):
monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
monkeypatch.setenv("SERVICE_NAME", "agent_service")
monkeypatch.setattr(data_tools._cache, "get_prices", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_financial_metrics", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_insider_trades", lambda key: None)
monkeypatch.setattr(data_tools._cache, "get_company_news", lambda key: None)
def fake_service_get_json(base_url, path, *, params):
if path == "/api/prices":
return {
"ticker": "AAPL",
"prices": [
Price(
open=1,
close=2,
high=3,
low=1,
volume=10,
time="2026-03-16",
).model_dump()
],
}
if path == "/api/financials":
return {
"financial_metrics": [
FinancialMetrics(
ticker="AAPL",
report_period="2026-03-16",
period="ttm",
currency="USD",
market_cap=123.0,
enterprise_value=None,
price_to_earnings_ratio=None,
price_to_book_ratio=None,
price_to_sales_ratio=None,
enterprise_value_to_ebitda_ratio=None,
enterprise_value_to_revenue_ratio=None,
free_cash_flow_yield=None,
peg_ratio=None,
gross_margin=None,
operating_margin=None,
net_margin=None,
return_on_equity=None,
return_on_assets=None,
return_on_invested_capital=None,
asset_turnover=None,
inventory_turnover=None,
receivables_turnover=None,
days_sales_outstanding=None,
operating_cycle=None,
working_capital_turnover=None,
current_ratio=None,
quick_ratio=None,
cash_ratio=None,
operating_cash_flow_ratio=None,
debt_to_equity=None,
debt_to_assets=None,
interest_coverage=None,
revenue_growth=None,
earnings_growth=None,
book_value_growth=None,
earnings_per_share_growth=None,
free_cash_flow_growth=None,
operating_income_growth=None,
ebitda_growth=None,
payout_ratio=None,
earnings_per_share=None,
book_value_per_share=None,
free_cash_flow_per_share=None,
).model_dump()
]
}
if path == "/api/insider-trades":
return {
"insider_trades": [
InsiderTrade(ticker="AAPL", filing_date="2026-03-16").model_dump()
]
}
if path == "/api/news":
return {
"news": [
CompanyNews(
ticker="AAPL",
title="Title",
source="polygon",
url="https://example.com",
).model_dump()
]
}
if path == "/api/market-cap":
return {"ticker": "AAPL", "end_date": "2026-03-16", "market_cap": 2.5e12}
if path == "/api/line-items":
return {
"search_results": [
LineItem(
ticker="AAPL",
report_period="2026-03-16",
period="ttm",
currency="USD",
free_cash_flow=321.0,
).model_dump()
]
}
raise AssertionError(path)
monkeypatch.setattr(data_tools, "_service_get_json", fake_service_get_json)
prices = data_tools.get_prices("AAPL", "2026-03-01", "2026-03-16")
metrics = data_tools.get_financial_metrics("AAPL", "2026-03-16")
trades = data_tools.get_insider_trades("AAPL", "2026-03-16")
news = data_tools.get_company_news("AAPL", "2026-03-16")
market_cap = data_tools.get_market_cap("AAPL", "2026-03-16")
line_items = data_tools.search_line_items(
"AAPL",
["free_cash_flow"],
"2026-03-16",
)
assert prices[0].close == 2
assert metrics[0].ticker == "AAPL"
assert trades[0].ticker == "AAPL"
assert news[0].ticker == "AAPL"
assert market_cap == 2.5e12
assert line_items[0].free_cash_flow == 321.0
def test_data_tools_skips_self_recursion_for_trading_service(monkeypatch):
monkeypatch.setenv("TRADING_SERVICE_URL", "http://localhost:8001")
monkeypatch.setenv("SERVICE_NAME", "trading_service")
assert data_tools._trading_service_url() is None

View File

@@ -6,6 +6,7 @@ import pytest
from backend.services.gateway import Gateway
import backend.services.gateway as gateway_module
from shared.schema import InsiderTrade, InsiderTradeResponse, Price, PriceResponse
class DummyWebSocket:
@@ -35,6 +36,10 @@ class FakeMarketStore:
def __init__(self):
self.calls = []
def get_ticker_watermarks(self, symbol):
self.calls.append(("get_ticker_watermarks", symbol))
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
def get_news_timeline_enriched(self, symbol, *, start_date=None, end_date=None):
self.calls.append(("get_news_timeline_enriched", symbol, start_date, end_date))
return [{"date": end_date, "count": 2, "source_count": 1, "top_title": "Top", "positive_count": 1}]
@@ -123,6 +128,75 @@ def make_gateway(market_store=None):
)
class FakeNewsClient:
def __init__(self, base_url):
self.base_url = base_url
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return None
async def get_categories(self, ticker, start_date=None, end_date=None, limit=200):
return {"ticker": ticker, "categories": {"remote": {"count": 2}}}
async def get_enriched_news(self, ticker, start_date=None, end_date=None, limit=None):
return {
"ticker": ticker,
"news": [
{
"id": "remote-news-1",
"ticker": ticker,
"title": "Remote Title",
"date": end_date,
}
],
}
async def get_story(self, ticker, as_of_date):
return {"symbol": ticker, "as_of_date": as_of_date, "story": "remote story", "source": "news_service"}
class FakeTradingClient:
def __init__(self, base_url):
self.base_url = base_url
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return None
async def get_insider_trades(self, ticker, end_date=None, start_date=None, limit=None):
return InsiderTradeResponse(
insider_trades=[
InsiderTrade(
ticker=ticker,
name="Remote Insider",
filing_date=end_date or "2026-03-16",
)
]
)
async def get_prices(self, ticker, start_date=None, end_date=None):
prices = [
Price(
open=float(100 + idx),
close=float(101 + idx),
high=float(102 + idx),
low=float(99 + idx),
volume=1000 + idx,
time=f"2026-01-{idx + 1:02d}",
)
for idx in range(30)
]
return PriceResponse(ticker=ticker, prices=prices)
async def get_market_cap(self, ticker, end_date):
return {"ticker": ticker, "end_date": end_date, "market_cap": 2.5e12}
@pytest.mark.asyncio
async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument():
market_store = FakeMarketStore()
@@ -135,6 +209,7 @@ async def test_handle_get_stock_news_timeline_uses_market_store_symbol_argument(
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_timeline_enriched", "AAPL", "2026-02-14", "2026-03-16")
]
assert websocket.messages[-1]["type"] == "stock_news_timeline_loaded"
@@ -153,6 +228,7 @@ async def test_handle_get_stock_news_categories_uses_market_store_symbol_argumen
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", "2026-02-14", "2026-03-16", None, 200),
("get_news_categories_enriched", "AAPL", "2026-02-14", "2026-03-16", 200)
]
@@ -175,7 +251,7 @@ async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch
}
monkeypatch.setattr(
gateway_module,
gateway_module.news_domain,
"build_range_explanation",
fake_build_range_explanation,
)
@@ -186,6 +262,7 @@ async def test_handle_get_stock_range_explain_uses_market_store_rows(monkeypatch
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", "2026-03-10", "2026-03-16", None, 100)
]
assert websocket.messages[-1] == {
@@ -207,7 +284,7 @@ async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch)
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.news_domain,
"build_range_explanation",
lambda **kwargs: {"news_count": len(kwargs["news_rows"])},
)
@@ -222,7 +299,10 @@ async def test_handle_get_stock_range_explain_uses_article_ids_path(monkeypatch)
},
)
assert market_store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-99"])]
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_by_ids_enriched", "AAPL", ["news-99"])
]
assert websocket.messages[-1]["result"]["news_count"] == 1
@@ -238,6 +318,7 @@ async def test_handle_get_stock_news_for_date_uses_trade_date_lookup():
)
assert market_store.calls == [
("get_ticker_watermarks", "AAPL"),
("get_news_items_enriched", "AAPL", None, None, "2026-03-16", 10)
]
assert websocket.messages[-1]["type"] == "stock_news_for_date_loaded"
@@ -251,7 +332,7 @@ async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
)
@@ -266,6 +347,132 @@ async def test_handle_get_stock_story_returns_story_payload(monkeypatch):
assert "AAPL Story" in websocket.messages[-1]["story"]
@pytest.mark.asyncio
async def test_handle_get_stock_news_categories_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_news_categories(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_news_categories_loaded"
assert websocket.messages[-1]["categories"]["remote"]["count"] == 2
@pytest.mark.asyncio
async def test_handle_get_stock_story_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_story(
websocket,
{"ticker": "AAPL", "as_of_date": "2026-03-16"},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_story_loaded"
assert websocket.messages[-1]["story"] == "remote story"
@pytest.mark.asyncio
async def test_handle_get_stock_news_uses_news_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("NEWS_SERVICE_URL", "http://news-service.local")
monkeypatch.setattr(gateway_module, "NewsServiceClient", FakeNewsClient)
await gateway._handle_get_stock_news(
websocket,
{"ticker": "AAPL", "lookback_days": 30, "limit": 5},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_news_loaded"
assert websocket.messages[-1]["source"] == "news_service"
assert websocket.messages[-1]["news"][0]["title"] == "Remote Title"
@pytest.mark.asyncio
async def test_handle_get_stock_insider_trades_uses_trading_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_insider_trades(
websocket,
{"ticker": "AAPL", "end_date": "2026-03-16", "limit": 10},
)
assert websocket.messages[-1]["type"] == "stock_insider_trades_loaded"
assert websocket.messages[-1]["trades"][0]["name"] == "Remote Insider"
@pytest.mark.asyncio
async def test_handle_get_stock_history_uses_trading_service_client_when_configured(monkeypatch):
market_store = FakeMarketStore()
gateway = make_gateway(market_store)
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_history(
websocket,
{"ticker": "AAPL", "lookback_days": 30},
)
assert market_store.calls == []
assert websocket.messages[-1]["type"] == "stock_history_loaded"
assert websocket.messages[-1]["source"] == "trading_service"
assert len(websocket.messages[-1]["prices"]) == 30
@pytest.mark.asyncio
async def test_handle_get_stock_technical_indicators_uses_trading_service_client_when_configured(monkeypatch):
gateway = make_gateway(FakeMarketStore())
websocket = DummyWebSocket()
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
await gateway._handle_get_stock_technical_indicators(
websocket,
{"ticker": "AAPL"},
)
assert websocket.messages[-1]["type"] == "stock_technical_indicators_loaded"
assert websocket.messages[-1]["ticker"] == "AAPL"
assert websocket.messages[-1]["indicators"] is not None
@pytest.mark.asyncio
async def test_get_market_caps_uses_trading_service_client_when_configured(monkeypatch):
gateway = make_gateway(FakeMarketStore())
monkeypatch.setenv("TRADING_SERVICE_URL", "http://trading-service.local")
monkeypatch.setattr(gateway_module, "TradingServiceClient", FakeTradingClient)
market_caps = await gateway._get_market_caps(["AAPL", "MSFT"], "2026-03-16")
assert market_caps == {"AAPL": 2.5e12, "MSFT": 2.5e12}
@pytest.mark.asyncio
async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
market_store = FakeMarketStore()
@@ -273,7 +480,7 @@ async def test_handle_get_stock_similar_days_returns_items(monkeypatch):
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 3},
)
@@ -295,7 +502,12 @@ async def test_handle_run_stock_enrich_rebuilds_caches(monkeypatch):
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_stock_handlers,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
)
monkeypatch.setattr(
gateway_module.news_domain,
"enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 2, "queued_count": 2},
)
@@ -325,7 +537,7 @@ async def test_handle_run_stock_enrich_rejects_local_to_llm_without_llm(monkeypa
gateway = make_gateway(FakeMarketStore())
websocket = DummyWebSocket()
monkeypatch.setattr(gateway_module, "llm_enrichment_enabled", lambda: False)
monkeypatch.setattr(gateway_module.gateway_stock_handlers, "llm_enrichment_enabled", lambda: False)
await gateway._handle_run_stock_enrich(
websocket,
@@ -361,7 +573,7 @@ def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
gateway._schedule_watchlist_market_store_refresh(["AAPL", "MSFT"])
assert captured["coro_name"] == "_refresh_market_store_for_watchlist"
assert captured["coro_name"] == "refresh_market_store_for_watchlist"
@pytest.mark.asyncio
@@ -369,7 +581,7 @@ async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypa
gateway = make_gateway()
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_cycle_support,
"ingest_symbols",
lambda symbols, mode="incremental": [
{"symbol": symbol, "prices": 3, "news": 4, "aligned": 4}
@@ -445,12 +657,12 @@ async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatc
websocket = DummyWebSocket()
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_admin_handlers,
"load_agent_profiles",
lambda: {"risk_manager": {"skills": ["risk_review"], "active_tool_groups": ["risk_ops", "legacy_group"]}},
)
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_admin_handlers,
"get_agent_model_info",
lambda agent_id: ("gpt-4o-mini", "OPENAI"),
)
@@ -461,7 +673,7 @@ async def test_handle_get_agent_profile_returns_model_and_tool_groups(monkeypatc
return {}
monkeypatch.setattr(
gateway_module,
gateway_module.gateway_admin_handlers,
"get_bootstrap_config_for_run",
lambda project_root, config_name: _Bootstrap(),
)

View File

@@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
"""Direct tests for Gateway support modules."""
from types import SimpleNamespace
import pytest
from backend.services import gateway_cycle_support, gateway_runtime_support
class _DummyDashboard:
def __init__(self):
self.updated = []
self.tickers = []
self.initial_cash = None
self.enable_memory = False
self.days_total = 0
def update(self, **kwargs):
self.updated.append(kwargs)
def stop(self):
return None
def print_final_summary(self):
return None
class _DummyScheduler:
def __init__(self):
self.calls = []
def reconfigure(self, **kwargs):
self.calls.append(kwargs)
class _DummyStateSync:
def __init__(self):
self.updated = []
self.saved = False
self.system_messages = []
self.backtest_dates = []
self.state = {}
def update_state(self, key, value):
self.updated.append((key, value))
self.state[key] = value
def save_state(self):
self.saved = True
async def on_system_message(self, message):
self.system_messages.append(message)
def set_backtest_dates(self, dates):
self.backtest_dates = list(dates)
class _DummyStorage:
def __init__(self):
self.initial_cash = 100000.0
self.is_live_session_active = False
self.server_state_updates = []
def can_apply_initial_cash(self):
return True
def apply_initial_cash(self, value):
self.initial_cash = value
return True
def update_server_state_from_dashboard(self, state):
self.server_state_updates.append(state)
def load_file(self, name):
if name == "summary":
return {"totalAssetValue": self.initial_cash}
return []
class _DummyPM:
def __init__(self):
self.portfolio = {"margin_requirement": 0.0}
def apply_runtime_portfolio_config(self, margin_requirement=None, initial_cash=None):
if margin_requirement is not None:
self.portfolio["margin_requirement"] = margin_requirement
return {"margin_requirement": True}
def can_apply_initial_cash(self):
return True
class _DummyMarketService:
def __init__(self):
self.updated = None
self.stopped = False
def update_tickers(self, tickers):
self.updated = list(tickers)
return {"active": list(tickers), "added": list(tickers), "removed": []}
def stop(self):
self.stopped = True
def make_gateway_stub():
pipeline = SimpleNamespace(max_comm_cycles=0, pm=_DummyPM())
gateway = SimpleNamespace(
market_service=_DummyMarketService(),
pipeline=pipeline,
scheduler=_DummyScheduler(),
config={
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"enable_memory": False,
},
storage=_DummyStorage(),
state_sync=_DummyStateSync(),
_dashboard=_DummyDashboard(),
_watchlist_ingest_task=None,
_market_status_task=None,
_backtest_task=None,
_backtest_start_date=None,
_backtest_end_date=None,
_manual_cycle_task=None,
)
return gateway
def test_normalize_watchlist_filters_invalid_and_dedupes():
assert gateway_runtime_support.normalize_watchlist(["aapl", " AAPL ", "", "msft"]) == ["AAPL", "MSFT"]
assert gateway_runtime_support.normalize_watchlist("aapl,msft") == ["AAPL", "MSFT"]
def test_normalize_agent_workspace_filename_obeys_allowlist():
allowlist = {"SOUL.md", "PROFILE.md"}
assert gateway_runtime_support.normalize_agent_workspace_filename("SOUL.md", allowlist=allowlist) == "SOUL.md"
assert gateway_runtime_support.normalize_agent_workspace_filename("README.md", allowlist=allowlist) is None
def test_apply_runtime_config_updates_gateway_state():
gateway = make_gateway_stub()
result = gateway_runtime_support.apply_runtime_config(
gateway,
{
"tickers": ["MSFT", "NVDA"],
"schedule_mode": "intraday",
"interval_minutes": 30,
"trigger_time": "10:30",
"initial_cash": 150000.0,
"margin_requirement": 0.5,
"max_comm_cycles": 4,
"enable_memory": False,
},
)
assert gateway.config["tickers"] == ["MSFT", "NVDA"]
assert gateway.config["schedule_mode"] == "intraday"
assert gateway.storage.initial_cash == 150000.0
assert result["runtime_config_applied"]["max_comm_cycles"] == 4
assert gateway.scheduler.calls[-1] == {
"mode": "intraday",
"trigger_time": "10:30",
"interval_minutes": 30,
}
def test_schedule_watchlist_market_store_refresh_creates_task(monkeypatch):
gateway = make_gateway_stub()
captured = {}
class DummyTask:
def done(self):
return False
def cancel(self):
captured["cancelled"] = True
def fake_create_task(coro):
captured["name"] = coro.cr_code.co_name
coro.close()
return DummyTask()
monkeypatch.setattr(gateway_cycle_support.asyncio, "create_task", fake_create_task)
gateway_cycle_support.schedule_watchlist_market_store_refresh(gateway, ["AAPL", "MSFT"])
assert captured["name"] == "refresh_market_store_for_watchlist"
@pytest.mark.asyncio
async def test_refresh_market_store_for_watchlist_emits_system_messages(monkeypatch):
gateway = make_gateway_stub()
monkeypatch.setattr(
gateway_cycle_support,
"ingest_symbols",
lambda symbols, mode="incremental": [
{"symbol": symbol, "prices": 3, "news": 4}
for symbol in symbols
],
)
await gateway_cycle_support.refresh_market_store_for_watchlist(gateway, ["AAPL", "MSFT"])
assert gateway.state_sync.system_messages[0] == "正在同步自选股市场数据: AAPL, MSFT"
assert "自选股市场数据已同步:" in gateway.state_sync.system_messages[1]

View File

@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
"""Unit tests for the news domain helpers."""
from backend.domains import news as news_domain
class _FakeStore:
def __init__(self):
self.calls = []
def get_ticker_watermarks(self, symbol):
self.calls.append(("get_ticker_watermarks", symbol))
return {"symbol": symbol, "last_news_fetch": "2026-03-10"}
def get_news_items_enriched(self, ticker, start_date=None, end_date=None, trade_date=None, limit=100):
self.calls.append(("get_news_items_enriched", ticker, start_date, end_date, trade_date, limit))
target = trade_date or end_date
return [{"id": "n1", "ticker": ticker, "date": target, "trade_date": target}]
def get_news_timeline_enriched(self, ticker, start_date=None, end_date=None):
self.calls.append(("get_news_timeline_enriched", ticker, start_date, end_date))
return [{"date": end_date, "count": 1}]
def get_news_categories_enriched(self, ticker, start_date=None, end_date=None, limit=200):
self.calls.append(("get_news_categories_enriched", ticker, start_date, end_date, limit))
return {"macro": {"count": 1}}
def get_news_by_ids_enriched(self, ticker, article_ids):
self.calls.append(("get_news_by_ids_enriched", ticker, list(article_ids)))
return [{"id": article_ids[0], "ticker": ticker, "date": "2026-03-16"}]
def test_news_rows_need_enrichment_detects_missing_fields():
assert news_domain.news_rows_need_enrichment([]) is True
assert news_domain.news_rows_need_enrichment([{"sentiment": "", "relevance": "", "key_discussion": ""}]) is True
assert news_domain.news_rows_need_enrichment([{"sentiment": "positive"}]) is False
def test_ensure_news_fresh_triggers_incremental_refresh_when_watermark_is_stale(monkeypatch):
store = _FakeStore()
calls = []
monkeypatch.setattr(
news_domain,
"update_ticker_incremental",
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
)
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
assert calls == [("AAPL", "2026-03-16")]
assert payload["target_date"] == "2026-03-16"
assert payload["refreshed"] is True
def test_ensure_news_fresh_skips_refresh_when_watermark_is_current(monkeypatch):
store = _FakeStore()
calls = []
monkeypatch.setattr(
store,
"get_ticker_watermarks",
lambda symbol: {"symbol": symbol, "last_news_fetch": "2026-03-16"},
)
monkeypatch.setattr(
news_domain,
"update_ticker_incremental",
lambda symbol, end_date=None, store=None: calls.append((symbol, end_date)),
)
payload = news_domain.ensure_news_fresh(store, ticker="AAPL", target_date="2026-03-16")
assert calls == []
assert payload["refreshed"] is False
def test_get_enriched_news_returns_rows_without_enrichment_when_present(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
payload = news_domain.get_enriched_news(
store,
ticker="AAPL",
start_date="2026-03-01",
end_date="2026-03-16",
limit=20,
)
assert payload["ticker"] == "AAPL"
assert payload["news"][0]["ticker"] == "AAPL"
assert payload["freshness"]["target_date"] is None or payload["freshness"]["target_date"] == "2026-03-16"
assert store.calls == [
("get_news_items_enriched", "AAPL", "2026-03-01", "2026-03-16", None, 20)
]
def test_get_story_and_similar_days_delegate(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
monkeypatch.setattr(news_domain, "enrich_news_for_symbol", lambda *args, **kwargs: {"analyzed": 1})
monkeypatch.setattr(
news_domain,
"get_or_create_stock_story",
lambda store, symbol, as_of_date: {"symbol": symbol, "as_of_date": as_of_date, "story": "story"},
)
monkeypatch.setattr(
news_domain,
"find_similar_days",
lambda store, symbol, target_date, top_k: {"symbol": symbol, "target_date": target_date, "items": [{"score": 0.9}]},
)
story = news_domain.get_story_payload(store, ticker="AAPL", as_of_date="2026-03-16")
similar = news_domain.get_similar_days_payload(store, ticker="AAPL", date="2026-03-16", n_similar=8)
assert story["story"] == "story"
assert "freshness" in story
assert similar["items"][0]["score"] == 0.9
assert "freshness" in similar
def test_get_range_explain_payload_uses_article_ids(monkeypatch):
store = _FakeStore()
monkeypatch.setattr(
news_domain,
"ensure_news_fresh",
lambda store, ticker, target_date=None: {
"ticker": ticker,
"target_date": target_date,
"last_news_fetch": target_date,
"refreshed": False,
},
)
monkeypatch.setattr(news_domain, "news_rows_need_enrichment", lambda rows: False)
monkeypatch.setattr(
news_domain,
"build_range_explanation",
lambda ticker, start_date, end_date, news_rows: {"ticker": ticker, "count": len(news_rows)},
)
payload = news_domain.get_range_explain_payload(
store,
ticker="AAPL",
start_date="2026-03-10",
end_date="2026-03-16",
article_ids=["news-9"],
limit=50,
)
assert payload["ticker"] == "AAPL"
assert payload["result"] == {"ticker": "AAPL", "count": 1}
assert "freshness" in payload
assert store.calls == [("get_news_by_ids_enriched", "AAPL", ["news-9"])]

View File

@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted news service app surface."""
from fastapi.testclient import TestClient
from backend.apps.news_service import create_app
class _FakeStore:
def get_ticker_watermarks(self, symbol):
return {"symbol": symbol, "last_news_fetch": "2026-12-31"}
def get_news_timeline_enriched(self, symbol, start_date=None, end_date=None):
return [{"date": end_date, "count": 1}]
def get_news_items(self, symbol, start_date=None, end_date=None, limit=100):
return [{"id": "news-raw-1", "ticker": symbol, "title": "Raw Title", "date": end_date}]
def get_news_items_enriched(self, symbol, start_date=None, end_date=None, trade_date=None, limit=100):
return [{"id": "news-1", "ticker": symbol, "title": "Title", "date": trade_date or end_date}]
def upsert_news_analysis(self, symbol, rows):
return len(rows)
def get_analyzed_news_ids(self, symbol, start_date=None, end_date=None):
return set()
def get_news_categories_enriched(self, symbol, start_date=None, end_date=None, limit=200):
return {"market": {"label": "market", "count": 1, "article_ids": ["news-1"]}}
def get_news_by_ids_enriched(self, symbol, article_ids):
return [{"id": article_ids[0], "ticker": symbol, "title": "Picked"}]
def test_news_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/enriched-news" in paths
assert "/api/news-for-date" in paths
assert "/api/news-timeline" in paths
assert "/api/categories" in paths
assert "/api/similar-days" in paths
assert "/api/stories/{ticker}" in paths
assert "/api/range-explain" in paths
def test_news_service_enriched_news_and_categories(monkeypatch):
app = create_app()
app.dependency_overrides.clear()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
with TestClient(app) as client:
news_response = client.get(
"/api/enriched-news",
params={"ticker": "AAPL", "end_date": "2026-03-23"},
)
categories_response = client.get(
"/api/categories",
params={"ticker": "AAPL", "end_date": "2026-03-23"},
)
assert news_response.status_code == 200
assert news_response.json()["news"][0]["ticker"] == "AAPL"
assert categories_response.status_code == 200
assert categories_response.json()["categories"]["market"]["count"] == 1
def test_news_service_news_for_date_and_timeline(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
with TestClient(app) as client:
date_response = client.get(
"/api/news-for-date",
params={"ticker": "AAPL", "date": "2026-03-23"},
)
timeline_response = client.get(
"/api/news-timeline",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-23",
},
)
assert date_response.status_code == 200
assert date_response.json()["date"] == "2026-03-23"
assert timeline_response.status_code == 200
assert timeline_response.json()["timeline"][0]["count"] == 1
def test_news_service_similar_days_and_story(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
monkeypatch.setattr(
"backend.domains.news.find_similar_days",
lambda store, symbol, target_date, top_k: {
"symbol": symbol,
"target_date": target_date,
"items": [{"date": "2026-03-20", "score": 0.9}],
},
)
monkeypatch.setattr(
"backend.domains.news.get_or_create_stock_story",
lambda store, symbol, as_of_date: {
"symbol": symbol,
"as_of_date": as_of_date,
"story": "story body",
"source": "local",
},
)
with TestClient(app) as client:
similar_response = client.get(
"/api/similar-days",
params={"ticker": "AAPL", "date": "2026-03-23", "n_similar": 3},
)
story_response = client.get(
"/api/stories/AAPL",
params={"as_of_date": "2026-03-23"},
)
assert similar_response.status_code == 200
assert similar_response.json()["items"][0]["score"] == 0.9
assert story_response.status_code == 200
assert story_response.json()["story"] == "story body"
def test_news_service_range_explain(monkeypatch):
app = create_app()
from backend.apps import news_service as news_service_module
app.dependency_overrides[news_service_module.get_market_store] = lambda: _FakeStore()
monkeypatch.setattr(
"backend.domains.news.enrich_news_for_symbol",
lambda *args, **kwargs: {"symbol": "AAPL", "analyzed": 1},
)
monkeypatch.setattr(
"backend.domains.news.build_range_explanation",
lambda ticker, start_date, end_date, news_rows: {
"symbol": ticker,
"news_count": len(news_rows),
"start_date": start_date,
"end_date": end_date,
},
)
with TestClient(app) as client:
response = client.get(
"/api/range-explain",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-23",
"article_ids": ["news-7"],
},
)
assert response.status_code == 200
assert response.json()["result"]["news_count"] == 1

View File

@@ -9,6 +9,7 @@ def test_router_includes_local_csv_fallback(monkeypatch):
monkeypatch.delenv("FINNHUB_API_KEY", raising=False)
monkeypatch.delenv("FINANCIAL_DATASETS_API_KEY", raising=False)
monkeypatch.delenv("FIN_DATA_SOURCE", raising=False)
monkeypatch.delenv("ENABLED_DATA_SOURCES", raising=False)
reset_config()
router = DataProviderRouter()

View File

@@ -0,0 +1,194 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted runtime service app surface."""
import json
from fastapi.testclient import TestClient
from backend.api import runtime as runtime_module
from backend.apps.runtime_service import create_app
def test_runtime_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/status" in paths
assert "/api/runtime/start" in paths
assert "/api/runtime/stop" in paths
assert "/api/runtime/current" in paths
assert "/api/runtime/gateway/port" in paths
def test_runtime_service_health_and_status(monkeypatch):
runtime_state = runtime_module.get_runtime_state()
runtime_state.gateway_process = None
runtime_state.gateway_port = 9876
runtime_state.runtime_manager = object()
with TestClient(create_app()) as client:
health_response = client.get("/health")
status_response = client.get("/api/status")
assert health_response.status_code == 200
assert health_response.json() == {
"status": "healthy",
"service": "runtime-service",
"gateway_running": False,
"gateway_port": 9876,
}
assert status_response.status_code == 200
assert status_response.json() == {
"status": "operational",
"service": "runtime-service",
"runtime": {
"gateway_running": False,
"gateway_port": 9876,
"has_runtime_manager": True,
},
}
def test_runtime_service_gateway_port_endpoint_uses_runtime_router(monkeypatch):
runtime_module.get_runtime_state().gateway_port = 9345
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
with TestClient(create_app()) as client:
response = client.get(
"/api/runtime/gateway/port",
headers={"host": "runtime.example:8003", "x-forwarded-proto": "https"},
)
assert response.status_code == 200
assert response.json() == {
"port": 9345,
"is_running": True,
"ws_url": "wss://runtime.example:9345",
}
def test_runtime_service_get_runtime_config(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "demo"
state_dir = run_dir / "state"
state_dir.mkdir(parents=True)
(run_dir / "BOOTSTRAP.md").write_text(
"---\n"
"tickers:\n"
" - AAPL\n"
"schedule_mode: intraday\n"
"interval_minutes: 30\n"
"trigger_time: '10:00'\n"
"max_comm_cycles: 3\n"
"enable_memory: true\n"
"---\n",
encoding="utf-8",
)
(state_dir / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "demo",
"run_dir": str(run_dir),
"bootstrap_values": {
"tickers": ["AAPL"],
"schedule_mode": "intraday",
"interval_minutes": 30,
"trigger_time": "10:00",
"max_comm_cycles": 3,
"enable_memory": True,
},
}
}
),
encoding="utf-8",
)
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
runtime_module.get_runtime_state().gateway_port = 8765
with TestClient(create_app()) as client:
response = client.get("/api/runtime/config")
assert response.status_code == 200
payload = response.json()
assert payload["run_id"] == "demo"
assert payload["bootstrap"]["schedule_mode"] == "intraday"
assert payload["resolved"]["interval_minutes"] == 30
assert payload["resolved"]["enable_memory"] is True
def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, tmp_path):
run_dir = tmp_path / "runs" / "demo"
state_dir = run_dir / "state"
state_dir.mkdir(parents=True)
(run_dir / "BOOTSTRAP.md").write_text(
"---\n"
"tickers:\n"
" - AAPL\n"
"schedule_mode: daily\n"
"interval_minutes: 60\n"
"trigger_time: '09:30'\n"
"max_comm_cycles: 2\n"
"---\n",
encoding="utf-8",
)
(state_dir / "runtime_state.json").write_text(
json.dumps(
{
"context": {
"config_name": "demo",
"run_dir": str(run_dir),
"bootstrap_values": {
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"max_comm_cycles": 2,
},
}
}
),
encoding="utf-8",
)
class _DummyContext:
def __init__(self):
self.bootstrap_values = {
"tickers": ["AAPL"],
"schedule_mode": "daily",
"interval_minutes": 60,
"trigger_time": "09:30",
"max_comm_cycles": 2,
}
class _DummyManager:
def __init__(self):
self.config_name = "demo"
self.bootstrap = dict(_DummyContext().bootstrap_values)
self.context = _DummyContext()
def _persist_snapshot(self):
return None
monkeypatch.setattr(runtime_module, "PROJECT_ROOT", tmp_path)
monkeypatch.setattr(runtime_module, "_is_gateway_running", lambda: True)
runtime_module.get_runtime_state().runtime_manager = _DummyManager()
runtime_module.get_runtime_state().gateway_port = 8765
with TestClient(create_app()) as client:
response = client.put(
"/api/runtime/config",
json={
"schedule_mode": "intraday",
"interval_minutes": 15,
"trigger_time": "10:15",
"max_comm_cycles": 4,
},
)
assert response.status_code == 200
payload = response.json()
assert payload["bootstrap"]["schedule_mode"] == "intraday"
assert payload["resolved"]["interval_minutes"] == 15
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")

View File

@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
"""Tests for split-aware shared service clients."""
import pytest
from shared.client.control_client import ControlPlaneClient
from shared.client.runtime_client import RuntimeServiceClient
class _DummyResponse:
def __init__(self, payload):
self._payload = payload
def raise_for_status(self):
return None
def json(self):
return self._payload
class _DummyAsyncClient:
def __init__(self):
self.calls = []
async def get(self, path, params=None):
self.calls.append(("get", path, params))
return _DummyResponse({"path": path, "params": params})
async def post(self, path, json=None):
self.calls.append(("post", path, json))
return _DummyResponse({"path": path, "json": json})
async def put(self, path, json=None):
self.calls.append(("put", path, json))
return _DummyResponse({"path": path, "json": json})
async def aclose(self):
return None
@pytest.mark.asyncio
async def test_control_plane_client_hits_current_workspace_and_guard_routes():
client = ControlPlaneClient()
client._client = _DummyAsyncClient()
await client.list_workspaces()
await client.get_workspace("demo")
await client.list_agents("demo")
await client.get_agent("demo", "risk_manager")
await client.fetch_pending_approvals()
await client.approve_pending_approval("ap-1")
await client.deny_pending_approval("ap-2", reason="nope")
assert client._client.calls == [
("get", "/workspaces", None),
("get", "/workspaces/demo", None),
("get", "/workspaces/demo/agents", None),
("get", "/workspaces/demo/agents/risk_manager", None),
("get", "/guard/pending", None),
(
"post",
"/guard/approve",
{
"approval_id": "ap-1",
"one_time": True,
"expires_in_minutes": 30,
},
),
(
"post",
"/guard/deny",
{
"approval_id": "ap-2",
"reason": "nope",
},
),
]
@pytest.mark.asyncio
async def test_runtime_service_client_hits_current_runtime_routes():
client = RuntimeServiceClient()
client._client = _DummyAsyncClient()
await client.fetch_context()
await client.fetch_agents()
await client.fetch_events()
await client.fetch_gateway_port()
await client.start_runtime({"tickers": ["AAPL"]})
await client.stop_runtime(force=True)
await client.restart_runtime({"tickers": ["MSFT"]})
await client.fetch_current_runtime()
await client.get_runtime_config()
await client.update_runtime_config({"schedule_mode": "intraday"})
assert client._client.calls == [
("get", "/context", None),
("get", "/agents", None),
("get", "/events", None),
("get", "/gateway/port", None),
("post", "/start", {"tickers": ["AAPL"]}),
("post", "/stop?force=true", None),
("post", "/restart", {"tickers": ["MSFT"]}),
("get", "/current", None),
("get", "/config", None),
("put", "/config", {"schedule_mode": "intraday"}),
]

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""Regression coverage for the shared schema bridge."""
from backend.data import schema as legacy_schema
from shared import schema as shared_schema
def test_backend_data_schema_reexports_shared_contracts():
assert legacy_schema.Price is shared_schema.Price
assert legacy_schema.PriceResponse is shared_schema.PriceResponse
assert legacy_schema.FinancialMetrics is shared_schema.FinancialMetrics
assert legacy_schema.FinancialMetricsResponse is (
shared_schema.FinancialMetricsResponse
)
assert legacy_schema.LineItem is shared_schema.LineItem
assert legacy_schema.LineItemResponse is shared_schema.LineItemResponse
assert legacy_schema.InsiderTrade is shared_schema.InsiderTrade
assert legacy_schema.InsiderTradeResponse is (
shared_schema.InsiderTradeResponse
)
assert legacy_schema.CompanyNews is shared_schema.CompanyNews
assert legacy_schema.CompanyNewsResponse is shared_schema.CompanyNewsResponse
assert legacy_schema.CompanyFacts is shared_schema.CompanyFacts
assert legacy_schema.CompanyFactsResponse is (
shared_schema.CompanyFactsResponse
)
assert legacy_schema.Position is shared_schema.Position
assert legacy_schema.Portfolio is shared_schema.Portfolio
assert legacy_schema.AnalystSignal is shared_schema.AnalystSignal
assert legacy_schema.TickerAnalysis is shared_schema.TickerAnalysis
assert legacy_schema.AgentStateData is shared_schema.AgentStateData
assert legacy_schema.AgentStateMetadata is shared_schema.AgentStateMetadata

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""Unit tests for the trading domain helpers."""
from backend.domains import trading as trading_domain
def test_trading_domain_payload_wrappers(monkeypatch):
monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}])
monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12)
assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == {
"ticker": "AAPL",
"prices": [{"close": 1}],
}
assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
"financial_metrics": [{"ticker": "AAPL"}],
}
assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
"news": [{"ticker": "AAPL"}],
}
assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
"insider_trades": [{"ticker": "AAPL"}],
}
assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
"ticker": "AAPL",
"end_date": "2026-03-16",
"market_cap": 2.5e12,
}
def test_get_market_status_payload_uses_market_service(monkeypatch):
class _FakeMarketService:
def __init__(self, tickers):
self.tickers = tickers
def get_market_status(self):
return {"status": "open", "status_text": "Open"}
monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)
assert trading_domain.get_market_status_payload() == {
"status": "open",
"status_text": "Open",
}

View File

@@ -0,0 +1,231 @@
# -*- coding: utf-8 -*-
"""Tests for the extracted trading service app surface."""
from fastapi.testclient import TestClient
from backend.apps.trading_service import create_app
from shared.schema import CompanyNews, FinancialMetrics, InsiderTrade, LineItem, Price
def test_trading_service_routes_are_exposed():
app = create_app()
paths = {route.path for route in app.routes}
assert "/health" in paths
assert "/api/prices" in paths
assert "/api/financials" in paths
assert "/api/news" in paths
assert "/api/insider-trades" in paths
assert "/api/market/status" in paths
assert "/api/market-cap" in paths
assert "/api/line-items" in paths
def test_trading_service_prices_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_prices_payload",
lambda ticker, start_date, end_date: {
"ticker": ticker,
"prices": [
Price(
open=1.0,
close=2.0,
high=2.5,
low=0.5,
volume=100,
time="2026-03-20",
)
],
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/prices",
params={
"ticker": "AAPL",
"start_date": "2026-03-01",
"end_date": "2026-03-20",
},
)
assert response.status_code == 200
assert response.json()["ticker"] == "AAPL"
assert response.json()["prices"][0]["close"] == 2.0
def test_trading_service_financials_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_financials_payload",
lambda ticker, end_date, period, limit: {
"financial_metrics": [
FinancialMetrics(
ticker=ticker,
report_period=end_date,
period=period,
currency="USD",
market_cap=123.0,
enterprise_value=None,
price_to_earnings_ratio=None,
price_to_book_ratio=None,
price_to_sales_ratio=None,
enterprise_value_to_ebitda_ratio=None,
enterprise_value_to_revenue_ratio=None,
free_cash_flow_yield=None,
peg_ratio=None,
gross_margin=None,
operating_margin=None,
net_margin=None,
return_on_equity=None,
return_on_assets=None,
return_on_invested_capital=None,
asset_turnover=None,
inventory_turnover=None,
receivables_turnover=None,
days_sales_outstanding=None,
operating_cycle=None,
working_capital_turnover=None,
current_ratio=None,
quick_ratio=None,
cash_ratio=None,
operating_cash_flow_ratio=None,
debt_to_equity=None,
debt_to_assets=None,
interest_coverage=None,
revenue_growth=None,
earnings_growth=None,
book_value_growth=None,
earnings_per_share_growth=None,
free_cash_flow_growth=None,
operating_income_growth=None,
ebitda_growth=None,
payout_ratio=None,
earnings_per_share=None,
book_value_per_share=None,
free_cash_flow_per_share=None,
)
]
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/financials",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert response.status_code == 200
assert response.json()["financial_metrics"][0]["ticker"] == "AAPL"
def test_trading_service_news_and_insider_endpoints(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_news_payload",
lambda ticker, end_date, start_date=None, limit=1000: {
"news": [
CompanyNews(
ticker=ticker,
title="News title",
source="polygon",
url="https://example.com/news",
date=end_date,
)
]
},
)
monkeypatch.setattr(
"backend.domains.trading.get_insider_trades_payload",
lambda ticker, end_date, start_date=None, limit=1000: {
"insider_trades": [
InsiderTrade(ticker=ticker, filing_date=end_date)
]
},
)
with TestClient(create_app()) as client:
news_response = client.get(
"/api/news",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
insider_response = client.get(
"/api/insider-trades",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert news_response.status_code == 200
assert news_response.json()["news"][0]["title"] == "News title"
assert insider_response.status_code == 200
assert insider_response.json()["insider_trades"][0]["ticker"] == "AAPL"
def test_trading_service_market_status_endpoint(monkeypatch):
class _FakeMarketService:
def get_market_status(self):
return {"status": "open", "status_text": "Open"}
monkeypatch.setattr(
"backend.domains.trading.get_market_status_payload",
lambda: _FakeMarketService().get_market_status(),
)
with TestClient(create_app()) as client:
response = client.get("/api/market/status")
assert response.status_code == 200
assert response.json() == {"status": "open", "status_text": "Open"}
def test_trading_service_market_cap_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_market_cap_payload",
lambda ticker, end_date: {
"ticker": ticker,
"end_date": end_date,
"market_cap": 3.5e12,
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/market-cap",
params={"ticker": "AAPL", "end_date": "2026-03-20"},
)
assert response.status_code == 200
assert response.json() == {
"ticker": "AAPL",
"end_date": "2026-03-20",
"market_cap": 3.5e12,
}
def test_trading_service_line_items_endpoint(monkeypatch):
monkeypatch.setattr(
"backend.domains.trading.get_line_items_payload",
lambda ticker, line_items, end_date, period, limit: {
"search_results": [
LineItem(
ticker=ticker,
report_period=end_date,
period=period,
currency="USD",
free_cash_flow=123.0,
)
]
},
)
with TestClient(create_app()) as client:
response = client.get(
"/api/line-items",
params=[
("ticker", "AAPL"),
("line_items", "free_cash_flow"),
("end_date", "2026-03-20"),
],
)
assert response.status_code == 200
assert response.json()["search_results"][0]["ticker"] == "AAPL"
assert response.json()["search_results"][0]["free_cash_flow"] == 123.0

View File

@@ -3,13 +3,16 @@
# pylint: disable=C0301
"""Data fetching tools backed by the unified provider router."""
import datetime
import os
import httpx
import pandas as pd
import pandas_market_calendars as mcal
from backend.data.provider_utils import normalize_symbol
from backend.data.cache import get_cache
from backend.data.provider_router import get_provider_router
from backend.data.schema import (
from shared.schema import (
CompanyNews,
FinancialMetrics,
InsiderTrade,
@@ -23,6 +26,31 @@ _cache = get_cache()
_router = get_provider_router()
def _service_name() -> str:
return str(os.getenv("SERVICE_NAME", "")).strip().lower()
def _trading_service_url() -> str | None:
value = str(os.getenv("TRADING_SERVICE_URL", "")).strip().rstrip("/")
if not value or _service_name() == "trading_service":
return None
return value
def _news_service_url() -> str | None:
value = str(os.getenv("NEWS_SERVICE_URL", "")).strip().rstrip("/")
if not value or _service_name() == "news_service":
return None
return value
def _service_get_json(base_url: str, path: str, *, params: dict[str, object]) -> dict:
with httpx.Client(base_url=base_url, timeout=30.0) as client:
response = client.get(path, params=params)
response.raise_for_status()
return response.json()
def get_last_tradeday(date: str) -> str:
"""
Get the previous trading day for the specified date
@@ -104,6 +132,24 @@ def get_prices(
if cached_data := _cache.get_prices(cache_key):
return [Price(**price) for price in cached_data]
service_url = _trading_service_url()
if service_url:
try:
payload = _service_get_json(
service_url,
"/api/prices",
params={
"ticker": ticker,
"start_date": start_date,
"end_date": end_date,
},
)
prices = [Price(**price) for price in payload.get("prices", [])]
if prices:
return prices
except Exception as exc:
logger.info("Trading service price lookup failed for %s: %s", ticker, exc)
try:
prices, data_source = _router.get_prices(ticker, start_date, end_date)
except Exception as exc:
@@ -146,6 +192,28 @@ def get_financial_metrics(
if cached_data := _cache.get_financial_metrics(cache_key):
return [FinancialMetrics(**metric) for metric in cached_data]
service_url = _trading_service_url()
if service_url:
try:
payload = _service_get_json(
service_url,
"/api/financials",
params={
"ticker": ticker,
"end_date": end_date,
"period": period,
"limit": limit,
},
)
metrics = [
FinancialMetrics(**metric)
for metric in payload.get("financial_metrics", [])
]
if metrics:
return metrics
except Exception as exc:
logger.info("Trading service financial lookup failed for %s: %s", ticker, exc)
try:
financial_metrics, data_source = _router.get_financial_metrics(
ticker=ticker,
@@ -183,6 +251,22 @@ def search_line_items(
ticker = normalize_symbol(ticker)
if not ticker:
return []
service_url = _trading_service_url()
if service_url:
payload = _service_get_json(
service_url,
"/api/line-items",
params={
"ticker": ticker,
"line_items": line_items,
"end_date": end_date,
"period": period,
"limit": limit,
},
)
return [LineItem(**item) for item in payload.get("search_results", [])]
return _router.search_line_items(
ticker=ticker,
line_items=line_items,
@@ -213,6 +297,26 @@ def get_insider_trades(
if cached_data := _cache.get_insider_trades(cache_key):
return [InsiderTrade(**trade) for trade in cached_data]
service_url = _trading_service_url()
if service_url:
try:
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
if start_date:
params["start_date"] = start_date
payload = _service_get_json(
service_url,
"/api/insider-trades",
params=params,
)
trades = [
InsiderTrade(**trade)
for trade in payload.get("insider_trades", [])
]
if trades:
return trades
except Exception as exc:
logger.info("Trading service insider lookup failed for %s: %s", ticker, exc)
try:
all_trades, data_source = _router.get_insider_trades(
ticker=ticker,
@@ -248,6 +352,40 @@ def get_company_news(
if cached_data := _cache.get_company_news(cache_key):
return [CompanyNews(**news) for news in cached_data]
trading_service_url = _trading_service_url()
if trading_service_url:
try:
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
if start_date:
params["start_date"] = start_date
payload = _service_get_json(
trading_service_url,
"/api/news",
params=params,
)
news = [CompanyNews(**item) for item in payload.get("news", [])]
if news:
return news
except Exception as exc:
logger.info("Trading service news lookup failed for %s: %s", ticker, exc)
news_service_url = _news_service_url()
if news_service_url:
try:
params = {"ticker": ticker, "end_date": end_date, "limit": limit}
if start_date:
params["start_date"] = start_date
payload = _service_get_json(
news_service_url,
"/api/enriched-news",
params=params,
)
news = [CompanyNews(**item) for item in payload.get("news", [])]
if news:
return news
except Exception as exc:
logger.info("News service lookup failed for %s: %s", ticker, exc)
try:
all_news, data_source = _router.get_company_news(
ticker=ticker,
@@ -272,6 +410,19 @@ def get_market_cap(ticker: str, end_date: str) -> float | None:
if not ticker:
return None
service_url = _trading_service_url()
if service_url:
try:
payload = _service_get_json(
service_url,
"/api/market-cap",
params={"ticker": ticker, "end_date": end_date},
)
value = payload.get("market_cap")
return float(value) if value is not None else None
except Exception as exc:
logger.info("Trading service market-cap lookup failed for %s: %s", ticker, exc)
def _metrics_lookup(symbol: str, date: str):
for source in _router.api_sources():
cache_key = f"{symbol}_ttm_{date}_10_{source}"