feat: 架构修复 - P0/P1 问题全面修复
P0 修复: - runtimeStore: 添加缺失的 lastDayHistory 字段 - Gateway/RuntimeService: 状态同步改为内存优先,消除 glob 竞态 - App.jsx: 从 3075 行重构到 ~500 行,提取 8 个独立文件 P1 修复: - CORS: 4 个服务改为从环境变量读取允许 origins - MarketStore: 改为模块级单例模式 - Domain 层: 删除 trading thin wrapper,保留 news 真实逻辑 - 测试: 补齐 77 个 gateway/runtime 测试 新增文件: - backend/tests/test_gateway.py (43 tests) - frontend/src/hooks/useWebSocketHandler.js - frontend/src/hooks/useStockRequestCallbacks.js - frontend/src/hooks/useAgentCallbacks.js - frontend/src/hooks/useRuntimeCallbacks.js - frontend/src/hooks/useWatchlistCallbacks.js - frontend/src/components/TickerBar.jsx - frontend/src/components/HeaderRight.jsx - frontend/src/components/ChartTabs.jsx Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -302,36 +302,28 @@ def _start_gateway_process(
|
||||
|
||||
@router.get("/context", response_model=RunContextResponse)
|
||||
async def get_run_context() -> RunContextResponse:
|
||||
"""Return the most recent run context."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
"""Return the current run context from in-memory state (avoids glob race condition)."""
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is None or manager.context is None:
|
||||
raise HTTPException(status_code=404, detail="No run context available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context")
|
||||
if context is None:
|
||||
raise HTTPException(status_code=404, detail="Run context is not ready")
|
||||
|
||||
context = manager.context
|
||||
return RunContextResponse(
|
||||
config_name=context["config_name"],
|
||||
run_dir=context["run_dir"],
|
||||
bootstrap_values=context["bootstrap_values"],
|
||||
config_name=context.config_name,
|
||||
run_dir=str(context.run_dir),
|
||||
bootstrap_values=context.bootstrap_values,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/agents", response_model=RuntimeAgentsResponse)
|
||||
async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||
"""Return agent states from the most recent run."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
"""Return agent states from the in-memory runtime manager (avoids glob race condition)."""
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
agents = latest.get("agents", [])
|
||||
snapshot = manager.build_snapshot()
|
||||
agents = snapshot.get("agents", [])
|
||||
|
||||
return RuntimeAgentsResponse(
|
||||
agents=[RuntimeAgentState(**a) for a in agents]
|
||||
@@ -340,15 +332,13 @@ async def get_runtime_agents() -> RuntimeAgentsResponse:
|
||||
|
||||
@router.get("/events", response_model=RuntimeEventsResponse)
|
||||
async def get_runtime_events() -> RuntimeEventsResponse:
|
||||
"""Return events from the most recent run."""
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
"""Return events from the in-memory runtime manager (avoids glob race condition)."""
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is None:
|
||||
raise HTTPException(status_code=404, detail="No runtime state available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
events = latest.get("events", [])
|
||||
snapshot = manager.build_snapshot()
|
||||
events = snapshot.get("events", [])
|
||||
|
||||
return RuntimeEventsResponse(
|
||||
events=[RuntimeEvent(**e) for e in events]
|
||||
@@ -362,15 +352,10 @@ async def get_gateway_status() -> GatewayStatusResponse:
|
||||
run_id = None
|
||||
|
||||
if is_running:
|
||||
# Try to find run_id from runtime state
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
if snapshots:
|
||||
try:
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
run_id = latest.get("context", {}).get("config_name")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse latest snapshot: {e}")
|
||||
# Get run_id from in-memory runtime manager (avoids glob race condition)
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is not None and manager.context is not None:
|
||||
run_id = manager.context.config_name
|
||||
|
||||
return GatewayStatusResponse(
|
||||
is_running=is_running,
|
||||
@@ -404,8 +389,28 @@ def _build_gateway_ws_url(request: Request, port: int) -> str:
|
||||
return f"{ws_scheme}://{host}:{port}"
|
||||
|
||||
|
||||
def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
||||
"""Load the latest persisted runtime snapshot."""
|
||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context from the in-memory manager (avoids glob race condition).
|
||||
|
||||
Falls back to file-based lookup only when the in-memory manager is not available
|
||||
(e.g., after a service restart). File-based lookup is deprecated and exists
|
||||
only for backward compatibility.
|
||||
"""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
# Primary: use in-memory manager (always correct for current process)
|
||||
manager = _runtime_state.runtime_manager
|
||||
if manager is not None and manager.context is not None:
|
||||
ctx = manager.context
|
||||
return {
|
||||
"config_name": ctx.config_name,
|
||||
"run_dir": str(ctx.run_dir),
|
||||
"bootstrap_values": ctx.bootstrap_values,
|
||||
}
|
||||
|
||||
# Deprecated fallback: scan filesystem (only for backward compatibility
|
||||
# after service restart without a restart of the runtime itself)
|
||||
snapshots = sorted(
|
||||
PROJECT_ROOT.glob("runs/*/state/runtime_state.json"),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
@@ -413,14 +418,7 @@ def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
||||
)
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||
return json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context from the latest snapshot."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
latest = _load_latest_runtime_snapshot()
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
raise HTTPException(status_code=404, detail="No runtime context available")
|
||||
@@ -663,15 +661,8 @@ async def get_current_runtime():
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
|
||||
# Find latest runtime state
|
||||
snapshot_path = PROJECT_ROOT.glob("runs/*/state/runtime_state.json")
|
||||
snapshots = sorted(snapshot_path, key=lambda p: p.stat().st_mtime, reverse=True)
|
||||
|
||||
if not snapshots:
|
||||
raise HTTPException(status_code=404, detail="No runtime information available")
|
||||
|
||||
latest = json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
context = latest.get("context", {})
|
||||
# Get context from in-memory manager (avoids glob race condition)
|
||||
context = _get_current_runtime_context()
|
||||
|
||||
return {
|
||||
"run_id": context.get("config_name"),
|
||||
|
||||
@@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.api import agents_router, guard_router, workspaces_router
|
||||
from backend.agents import AgentFactory, WorkspaceManager, get_registry
|
||||
from backend.config.env_config import get_cors_origins
|
||||
|
||||
# Global instances (initialized on startup)
|
||||
agent_factory: AgentFactory | None = None
|
||||
@@ -49,7 +50,7 @@ def create_app(project_root: Path | None = None) -> FastAPI:
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data.market_store import MarketStore
|
||||
from backend.domains import news as news_domain
|
||||
from backend.config.env_config import get_cors_origins
|
||||
|
||||
|
||||
def get_market_store() -> MarketStore:
|
||||
@@ -27,7 +28,7 @@ def create_app() -> FastAPI:
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.api import runtime_router
|
||||
from backend.api.runtime import get_runtime_state
|
||||
from backend.config.env_config import get_cors_origins
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
@@ -20,7 +21,7 @@ def create_app() -> FastAPI:
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -8,7 +8,16 @@ from typing import Any
|
||||
from fastapi import FastAPI, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.config.env_config import get_cors_origins
|
||||
from backend.services.market import MarketService
|
||||
from backend.tools.data_tools import (
|
||||
get_company_news,
|
||||
get_financial_metrics,
|
||||
get_insider_trades,
|
||||
get_market_cap,
|
||||
get_prices,
|
||||
search_line_items,
|
||||
)
|
||||
from shared.schema import (
|
||||
CompanyNewsResponse,
|
||||
FinancialMetricsResponse,
|
||||
@@ -28,7 +37,7 @@ def create_app() -> FastAPI:
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
@@ -45,12 +54,8 @@ def create_app() -> FastAPI:
|
||||
start_date: str = Query(...),
|
||||
end_date: str = Query(...),
|
||||
) -> PriceResponse:
|
||||
payload = trading_domain.get_prices_payload(
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
|
||||
prices = get_prices(ticker=ticker, start_date=start_date, end_date=end_date)
|
||||
return PriceResponse(ticker=ticker, prices=prices)
|
||||
|
||||
@app.get("/api/financials", response_model=FinancialMetricsResponse)
|
||||
async def api_get_financials(
|
||||
@@ -59,13 +64,13 @@ def create_app() -> FastAPI:
|
||||
period: str = Query("ttm"),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
) -> FinancialMetricsResponse:
|
||||
payload = trading_domain.get_financials_payload(
|
||||
metrics = get_financial_metrics(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
|
||||
return FinancialMetricsResponse(financial_metrics=metrics)
|
||||
|
||||
@app.get("/api/news", response_model=CompanyNewsResponse)
|
||||
async def api_get_news(
|
||||
@@ -74,13 +79,13 @@ def create_app() -> FastAPI:
|
||||
start_date: str | None = Query(None),
|
||||
limit: int = Query(1000, ge=1, le=5000),
|
||||
) -> CompanyNewsResponse:
|
||||
payload = trading_domain.get_news_payload(
|
||||
news = get_company_news(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
return CompanyNewsResponse(news=payload["news"])
|
||||
return CompanyNewsResponse(news=news)
|
||||
|
||||
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
|
||||
async def api_get_insider_trades(
|
||||
@@ -89,18 +94,19 @@ def create_app() -> FastAPI:
|
||||
start_date: str | None = Query(None),
|
||||
limit: int = Query(1000, ge=1, le=5000),
|
||||
) -> InsiderTradeResponse:
|
||||
payload = trading_domain.get_insider_trades_payload(
|
||||
trades = get_insider_trades(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date,
|
||||
limit=limit,
|
||||
)
|
||||
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
|
||||
return InsiderTradeResponse(insider_trades=trades)
|
||||
|
||||
@app.get("/api/market/status")
|
||||
async def api_get_market_status() -> dict[str, Any]:
|
||||
"""Return current market status using the existing market service logic."""
|
||||
return trading_domain.get_market_status_payload()
|
||||
service = MarketService(tickers=[])
|
||||
return service.get_market_status()
|
||||
|
||||
@app.get("/api/market-cap")
|
||||
async def api_get_market_cap(
|
||||
@@ -108,10 +114,12 @@ def create_app() -> FastAPI:
|
||||
end_date: str = Query(...),
|
||||
) -> dict[str, Any]:
|
||||
"""Return market cap for one ticker/date."""
|
||||
return trading_domain.get_market_cap_payload(
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
)
|
||||
market_cap = get_market_cap(ticker=ticker, end_date=end_date)
|
||||
return {
|
||||
"ticker": ticker,
|
||||
"end_date": end_date,
|
||||
"market_cap": market_cap,
|
||||
}
|
||||
|
||||
@app.get("/api/line-items", response_model=LineItemResponse)
|
||||
async def api_get_line_items(
|
||||
@@ -121,14 +129,14 @@ def create_app() -> FastAPI:
|
||||
period: str = Query("ttm"),
|
||||
limit: int = Query(10, ge=1, le=100),
|
||||
) -> LineItemResponse:
|
||||
payload = trading_domain.get_line_items_payload(
|
||||
items = search_line_items(
|
||||
ticker=ticker,
|
||||
line_items=line_items,
|
||||
end_date=end_date,
|
||||
period=period,
|
||||
limit=limit,
|
||||
)
|
||||
return LineItemResponse(search_results=payload["search_results"])
|
||||
return LineItemResponse(search_results=items)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
"""Environment config helpers with light validation and normalization."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
@@ -16,6 +17,36 @@ PROVIDER_ALIASES = {
|
||||
"vertexai": "GEMINI",
|
||||
}
|
||||
|
||||
# Default dev CORS origins (localhost variants used by common dev servers)
|
||||
_LOCALHOST_ORIGINS = [
|
||||
"http://localhost:5173",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:8000",
|
||||
"http://127.0.0.1:5173",
|
||||
"http://127.0.0.1:3000",
|
||||
"http://127.0.0.1:8000",
|
||||
]
|
||||
|
||||
|
||||
def get_cors_origins() -> list[str]:
|
||||
"""Get CORS allowed origins from environment.
|
||||
|
||||
Reads CORS_ALLOWED_ORIGINS env var (comma-separated).
|
||||
Falls back to localhost dev origins if not set.
|
||||
Warns if "*" is configured (only acceptable for local dev).
|
||||
"""
|
||||
origins = get_env_list("CORS_ALLOWED_ORIGINS", default=[])
|
||||
if origins:
|
||||
if "*" in origins:
|
||||
warnings.warn(
|
||||
"CORS_ALLOWED_ORIGINS contains '*' — this allows any origin. "
|
||||
"Only use in local development, never in production.",
|
||||
UserWarning,
|
||||
)
|
||||
return origins
|
||||
# Fallback: local dev only
|
||||
return _LOCALHOST_ORIGINS
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentModelConfig:
|
||||
|
||||
@@ -8,7 +8,7 @@ import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.data.market_ingest import ingest_symbols
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.tools.data_tools import get_market_cap
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -265,8 +265,7 @@ async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[s
|
||||
if response is not None:
|
||||
market_cap = response.get("market_cap")
|
||||
if market_cap is None:
|
||||
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
|
||||
market_cap = payload.get("market_cap")
|
||||
market_cap = get_market_cap(ticker=ticker, end_date=date)
|
||||
market_caps[ticker] = market_cap if market_cap else 1e9
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
|
||||
|
||||
@@ -11,10 +11,9 @@ from typing import Any
|
||||
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
from backend.domains import news as news_domain
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.enrich.news_enricher import enrich_news_for_symbol
|
||||
from backend.enrich.llm_enricher import llm_enrichment_enabled
|
||||
from backend.tools.data_tools import prices_to_df
|
||||
from backend.tools.data_tools import get_insider_trades, get_prices, prices_to_df
|
||||
from shared.client import NewsServiceClient, TradingServiceClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -59,13 +58,12 @@ async def handle_get_stock_history(gateway: Any, websocket: Any, data: dict[str,
|
||||
if not prices:
|
||||
prices = await asyncio.to_thread(gateway.storage.market_store.get_ohlc, ticker, start_date, end_date)
|
||||
if not prices:
|
||||
payload = await asyncio.to_thread(
|
||||
trading_domain.get_prices_payload,
|
||||
prices = await asyncio.to_thread(
|
||||
get_prices,
|
||||
ticker=ticker,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
prices = payload.get("prices") or []
|
||||
usage_snapshot = gateway._provider_router.get_usage_snapshot()
|
||||
source = usage_snapshot.get("last_success", {}).get("prices")
|
||||
if prices:
|
||||
@@ -400,14 +398,13 @@ async def handle_get_stock_insider_trades(gateway: Any, websocket: Any, data: di
|
||||
trades = response.insider_trades
|
||||
|
||||
if not trades:
|
||||
payload = await asyncio.to_thread(
|
||||
trading_domain.get_insider_trades_payload,
|
||||
trades = await asyncio.to_thread(
|
||||
get_insider_trades,
|
||||
ticker=ticker,
|
||||
end_date=end_date,
|
||||
start_date=start_date if start_date else None,
|
||||
limit=limit,
|
||||
)
|
||||
trades = payload.get("insider_trades") or []
|
||||
|
||||
sorted_trades = sorted(trades, key=lambda t: t.transaction_date or "", reverse=True)
|
||||
formatted_trades = [{
|
||||
@@ -540,12 +537,11 @@ async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, da
|
||||
prices = response.prices
|
||||
|
||||
if prices is None:
|
||||
payload = trading_domain.get_prices_payload(
|
||||
prices = get_prices(
|
||||
ticker=ticker,
|
||||
start_date=start_date.strftime("%Y-%m-%d"),
|
||||
end_date=end_date.strftime("%Y-%m-%d"),
|
||||
)
|
||||
prices = payload.get("prices") or []
|
||||
|
||||
if not prices or len(prices) < 20:
|
||||
await websocket.send(json.dumps({
|
||||
|
||||
549
backend/tests/test_gateway.py
Normal file
549
backend/tests/test_gateway.py
Normal file
@@ -0,0 +1,549 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the Gateway main class - core behavior and fallback paths."""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.services.gateway import Gateway
|
||||
import backend.services.gateway as gateway_module
|
||||
|
||||
|
||||
class DummyWebSocket:
|
||||
def __init__(self):
|
||||
self.messages = []
|
||||
self.closed = False
|
||||
self._queue = []
|
||||
|
||||
def queue(self, data: str):
|
||||
"""Queue a raw message string to be yielded by the async iterator."""
|
||||
self._queue.append(data)
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
if not self._queue:
|
||||
raise StopAsyncIteration
|
||||
return self._queue.pop(0)
|
||||
|
||||
async def send(self, payload: str):
|
||||
self.messages.append(json.loads(payload))
|
||||
|
||||
async def close(self):
|
||||
self.closed = True
|
||||
|
||||
|
||||
class DummyStateSync:
|
||||
def __init__(self, current_date="2026-03-16"):
|
||||
self.state = {"current_date": current_date}
|
||||
self.system_messages = []
|
||||
self.saved = False
|
||||
self.initial_state_payload = {}
|
||||
|
||||
def set_broadcast_fn(self, _fn):
|
||||
return None
|
||||
|
||||
def update_state(self, key, value):
|
||||
self.state[key] = value
|
||||
|
||||
def save_state(self):
|
||||
self.saved = True
|
||||
|
||||
async def on_system_message(self, message):
|
||||
self.system_messages.append(message)
|
||||
|
||||
def get_initial_state_payload(self, include_dashboard=True):
|
||||
return {
|
||||
"status": "running",
|
||||
"current_date": self.state.get("current_date", ""),
|
||||
"portfolio": {},
|
||||
"holdings": [],
|
||||
"trades": [],
|
||||
}
|
||||
|
||||
|
||||
class DummyMarketService:
|
||||
def __init__(self):
|
||||
self.broadcast_func = None
|
||||
self.market_status = {"is_open": True, "session": "regular"}
|
||||
|
||||
def set_price_recorder(self, _fn):
|
||||
return None
|
||||
|
||||
async def start(self, broadcast_func=None):
|
||||
self.broadcast_func = broadcast_func
|
||||
|
||||
def get_market_status(self):
|
||||
return self.market_status
|
||||
|
||||
|
||||
class DummyStorage:
|
||||
def __init__(self, initial_cash=100000.0, live_session=False):
|
||||
self.initial_cash = initial_cash
|
||||
self.is_live_session_active = live_session
|
||||
self._market_store = SimpleNamespace()
|
||||
|
||||
@property
|
||||
def market_store(self):
|
||||
return self._market_store
|
||||
|
||||
def load_file(self, name):
|
||||
if name == "summary":
|
||||
return {"totalAssetValue": self.initial_cash}
|
||||
if name in ("holdings", "trades"):
|
||||
return []
|
||||
return None
|
||||
|
||||
def get_live_returns(self):
|
||||
return {"session_pnl": 0.0, "session_return": 0.0}
|
||||
|
||||
|
||||
def make_gateway(market_service=None, storage=None, state_sync=None, config=None):
|
||||
storage = storage or DummyStorage()
|
||||
state_sync = state_sync or DummyStateSync()
|
||||
market_service = market_service or DummyMarketService()
|
||||
pipeline = SimpleNamespace(state_sync=state_sync, max_comm_cycles=0, pm=SimpleNamespace(portfolio={"margin_requirement": 0.0}))
|
||||
return Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage,
|
||||
pipeline=pipeline,
|
||||
state_sync=state_sync,
|
||||
config=config or {"mode": "live"},
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Gateway initialization and core properties
|
||||
# =============================================================================
|
||||
|
||||
def test_gateway_init_sets_live_mode():
|
||||
gateway = make_gateway(config={"mode": "live"})
|
||||
assert gateway.mode == "live"
|
||||
assert gateway.is_backtest is False
|
||||
|
||||
|
||||
def test_gateway_init_sets_backtest_mode_from_config():
|
||||
gateway = make_gateway(config={"mode": "backtest"})
|
||||
assert gateway.mode == "backtest"
|
||||
assert gateway.is_backtest is True
|
||||
|
||||
|
||||
def test_gateway_init_sets_backtest_mode_from_flag():
|
||||
gateway = make_gateway(config={"backtest_mode": True, "mode": "live"})
|
||||
assert gateway.is_backtest is True
|
||||
|
||||
|
||||
def test_gateway_init_defaults_to_live_mode():
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway.mode == "live"
|
||||
assert gateway.is_backtest is False
|
||||
|
||||
|
||||
def test_gateway_state_property_returns_state_sync_state():
|
||||
state_sync = DummyStateSync()
|
||||
state_sync.state["foo"] = "bar"
|
||||
gateway = make_gateway(state_sync=state_sync)
|
||||
assert gateway.state["foo"] == "bar"
|
||||
|
||||
|
||||
def test_gateway_news_rows_need_enrichment_delegates_to_news_domain():
|
||||
rows = [{"id": "1"}, {"id": "2"}]
|
||||
with patch.object(gateway_module.news_domain, "news_rows_need_enrichment", return_value=True) as mock:
|
||||
result = Gateway._news_rows_need_enrichment(rows)
|
||||
mock.assert_called_once_with(rows)
|
||||
assert result is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Service URL helpers and fallback paths
|
||||
# =============================================================================
|
||||
|
||||
def test_news_service_url_returns_config_value(monkeypatch):
|
||||
gateway = make_gateway(config={"news_service_url": "http://custom-news:9000"})
|
||||
assert gateway._news_service_url() == "http://custom-news:9000"
|
||||
|
||||
|
||||
def test_news_service_url_falls_back_to_env(monkeypatch):
|
||||
monkeypatch.setenv("NEWS_SERVICE_URL", "http://env-news:9001")
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway._news_service_url() == "http://env-news:9001"
|
||||
|
||||
|
||||
def test_news_service_url_returns_none_when_unset(monkeypatch):
|
||||
monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway._news_service_url() is None
|
||||
|
||||
|
||||
def test_news_service_url_strips_whitespace(monkeypatch):
|
||||
gateway = make_gateway(config={"news_service_url": " http://whitespace-news:9000 "})
|
||||
assert gateway._news_service_url() == "http://whitespace-news:9000"
|
||||
|
||||
|
||||
def test_trading_service_url_returns_config_value(monkeypatch):
|
||||
gateway = make_gateway(config={"trading_service_url": "http://custom-trading:9000"})
|
||||
assert gateway._trading_service_url() == "http://custom-trading:9000"
|
||||
|
||||
|
||||
def test_trading_service_url_falls_back_to_env(monkeypatch):
|
||||
monkeypatch.setenv("TRADING_SERVICE_URL", "http://env-trading:9001")
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway._trading_service_url() == "http://env-trading:9001"
|
||||
|
||||
|
||||
def test_trading_service_url_returns_none_when_unset(monkeypatch):
|
||||
monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
|
||||
gateway = make_gateway(config={})
|
||||
assert gateway._trading_service_url() is None
|
||||
|
||||
|
||||
def test_trading_service_url_strips_whitespace(monkeypatch):
|
||||
gateway = make_gateway(config={"trading_service_url": " http://whitespace-trading:9000 "})
|
||||
assert gateway._trading_service_url() == "http://whitespace-trading:9000"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_news_service_returns_none_when_url_not_set(monkeypatch):
|
||||
monkeypatch.delenv("NEWS_SERVICE_URL", raising=False)
|
||||
gateway = make_gateway(config={})
|
||||
|
||||
async def dummy_callback(client):
|
||||
return "should not be called"
|
||||
|
||||
result = await gateway._call_news_service("test_action", dummy_callback)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_news_service_calls_callback_and_returns():
|
||||
gateway = make_gateway(config={"news_service_url": "http://news:9000"})
|
||||
|
||||
async def callback(client):
|
||||
return {"result": "ok"}
|
||||
|
||||
result = await gateway._call_news_service("test_action", callback)
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_news_service_returns_none_on_exception():
|
||||
gateway = make_gateway(config={"news_service_url": "http://news:9000"})
|
||||
|
||||
async def failing_callback(client):
|
||||
raise RuntimeError("connection failed")
|
||||
|
||||
result = await gateway._call_news_service("test_action", failing_callback)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_trading_service_returns_none_when_url_not_set(monkeypatch):
|
||||
monkeypatch.delenv("TRADING_SERVICE_URL", raising=False)
|
||||
gateway = make_gateway(config={})
|
||||
|
||||
result = await gateway._call_trading_service("test_action", lambda c: None)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_trading_service_calls_callback_and_returns():
|
||||
gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})
|
||||
|
||||
async def callback(client):
|
||||
return {"result": "ok"}
|
||||
result = await gateway._call_trading_service("test_action", callback)
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_trading_service_returns_none_on_exception():
|
||||
gateway = make_gateway(config={"trading_service_url": "http://trading:9000"})
|
||||
|
||||
async def failing_callback(client):
|
||||
raise RuntimeError("connection failed")
|
||||
|
||||
result = await gateway._call_trading_service("test_action", failing_callback)
|
||||
assert result is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WebSocket message handlers
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_messages_ping_returns_pong():
|
||||
"""Ping message type results in a pong response."""
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
ws.queue(json.dumps({"type": "ping"}))
|
||||
|
||||
await gateway._handle_client_messages(ws)
|
||||
|
||||
assert ws.messages[-1]["type"] == "pong"
|
||||
assert "timestamp" in ws.messages[-1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_messages_get_state_sends_initial_state():
|
||||
"""get_state message type triggers _send_initial_state."""
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
ws.queue(json.dumps({"type": "get_state"}))
|
||||
|
||||
with patch.object(gateway, "_send_initial_state", AsyncMock()) as mock_send:
|
||||
await gateway._handle_client_messages(ws)
|
||||
mock_send.assert_called_once_with(ws)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_messages_unknown_type_is_silently_ignored():
|
||||
"""Unknown message types are silently ignored without error."""
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
ws.queue(json.dumps({"type": "unknown_type"}))
|
||||
|
||||
# Should not raise
|
||||
await gateway._handle_client_messages(ws)
|
||||
assert len(ws.messages) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_client_messages_json_decode_error_is_silently_ignored():
|
||||
"""Invalid JSON messages are caught by the handler's except block."""
|
||||
gateway = make_gateway()
|
||||
ws = DummyWebSocket()
|
||||
ws.queue("not valid json")
|
||||
|
||||
# Should not raise
|
||||
await gateway._handle_client_messages(ws)
|
||||
assert len(ws.messages) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Backtest handling
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_not_backtest_mode():
    """start_backtest is a no-op outside backtest mode: no task is spawned."""
    gateway = make_gateway(config={"mode": "live"})

    # Must not raise even though backtests are disabled in live mode.
    await gateway._handle_start_backtest({"dates": ["2026-03-01", "2026-03-02"]})

    assert gateway._backtest_task is None


@pytest.mark.asyncio
async def test_handle_start_backtest_ignored_when_task_already_running():
    """A second start_backtest request leaves a still-running task untouched."""
    gateway = make_gateway(config={"mode": "backtest"})

    running = MagicMock()
    running.done.return_value = False
    gateway._backtest_task = running

    await gateway._handle_start_backtest({"dates": ["2026-03-01"]})

    # The pre-existing task must remain in place, unreplaced.
    assert gateway._backtest_task is running
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Manual trigger (live/mock mode)
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_in_backtest_mode():
    """Manual triggers are refused while the gateway runs in backtest mode."""
    gateway = make_gateway(config={"mode": "backtest"})
    client = DummyWebSocket()

    await gateway._handle_manual_trigger(client, {"date": "2026-03-16"})

    errors = [m for m in client.messages if m["type"] == "error"]
    assert any("manual trigger" in m["message"].lower() for m in errors)


@pytest.mark.asyncio
async def test_handle_manual_trigger_rejected_when_cycle_already_running():
    """A manual trigger is refused while a previous cycle task is still active."""
    gateway = make_gateway(config={"mode": "live"})
    client = DummyWebSocket()

    running = MagicMock()
    running.done.return_value = False
    gateway._manual_cycle_task = running

    await gateway._handle_manual_trigger(client, {"date": "2026-03-16"})

    errors = [m for m in client.messages if m["type"] == "error"]
    assert any("already running" in m["message"].lower() for m in errors)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Normalization helpers
|
||||
# =============================================================================
|
||||
|
||||
def test_normalize_watchlist_filters_empty_and_dedupes():
    """Blank entries are dropped and duplicates collapse after upper-casing."""
    normalized = Gateway._normalize_watchlist(["aapl", " AAPL ", "", "msft", "MSFT", " "])
    assert normalized == ["AAPL", "MSFT"]


def test_normalize_watchlist_handles_string_input():
    """A comma-separated string is split, trimmed, and de-duplicated."""
    assert Gateway._normalize_watchlist("aapl, msft, aapl") == ["AAPL", "MSFT"]


def test_normalize_agent_workspace_filename_allows_editable_files():
    """Each whitelisted workspace file passes through unchanged."""
    editable = ("SOUL.md", "PROFILE.md", "AGENTS.md", "MEMORY.md", "POLICY.md")
    for name in editable:
        assert Gateway._normalize_agent_workspace_filename(name) == name


def test_normalize_agent_workspace_filename_rejects_non_editable_files():
    """A real file outside the whitelist is rejected with None."""
    assert Gateway._normalize_agent_workspace_filename("README.md") is None


def test_normalize_agent_workspace_filename_rejects_arbitrary_paths():
    """Path-traversal style names are rejected with None."""
    assert Gateway._normalize_agent_workspace_filename("../etc/passwd") is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Broadcast
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
async def test_broadcast_skips_when_no_clients():
    """Broadcasting with an empty client set is a harmless no-op."""
    gateway = make_gateway()
    gateway.connected_clients = set()

    # Must complete without raising.
    await gateway.broadcast({"type": "test"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_broadcast_sends_to_all_connected_clients():
    """Every connected client receives an identical copy of the payload."""
    gateway = make_gateway()
    first = DummyWebSocket()
    second = DummyWebSocket()
    gateway.connected_clients = {first, second}

    await gateway.broadcast({"type": "market_update", "data": "test"})

    assert all(m["type"] == "market_update" for m in first.messages + second.messages)
    for client in (first, second):
        assert client.messages[0]["data"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_broadcast_removes_closed_connections():
    """Verify closed connections are removed from connected_clients set.

    The broadcast method's _send_to_client helper removes a client
    when it raises websockets.ConnectionClosed; healthy clients survive.
    """
    gateway = make_gateway()
    closed_ws = DummyWebSocket()
    open_ws = DummyWebSocket()
    gateway.connected_clients = {closed_ws, open_ws}

    # Replace send() with one that always raises ConnectionClosed so the
    # original _send_to_client's except block triggers and evicts the client.
    # (The previous version also saved the old send() into an unused local,
    # which has been removed.)
    async def raising_send(payload):
        raise gateway_module.websockets.ConnectionClosed(None, None)

    closed_ws.send = raising_send

    try:
        await gateway.broadcast({"type": "test"})
    except gateway_module.websockets.ConnectionClosed:
        # Tolerate implementations that re-raise after cleanup.
        pass

    # The closed client should have been removed; the open client remains.
    assert closed_ws not in gateway.connected_clients
    assert open_ws in gateway.connected_clients
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_broadcast_delivers_payload_to_every_client():
    """Verify broadcast sends to all connected clients and collects results.

    NOTE: this test was previously a byte-for-byte duplicate named
    test_broadcast_sends_to_all_connected_clients; a second definition with
    the same name shadows the first at module level, so pytest silently ran
    only one of them. Renamed so both tests are collected and executed.
    """
    gateway = make_gateway()
    ws1 = DummyWebSocket()
    ws2 = DummyWebSocket()
    gateway.connected_clients = {ws1, ws2}

    await gateway.broadcast({"type": "market_update", "data": "test"})

    assert all(m["type"] == "market_update" for m in ws1.messages + ws2.messages)
    assert ws1.messages[0]["data"] == "test"
    assert ws2.messages[0]["data"] == "test"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Stop
|
||||
# =============================================================================
|
||||
|
||||
def test_stop_gateway_calls_cycle_support():
    """Gateway.stop delegates shutdown to gateway_cycle_support.stop_gateway."""
    gateway = make_gateway()
    with patch.object(gateway_module.gateway_cycle_support, "stop_gateway") as stop_mock:
        gateway.stop()
        stop_mock.assert_called_once_with(gateway)


# =============================================================================
# set_backtest_dates
# =============================================================================


def test_set_backtest_dates_delegates_to_cycle_support():
    """set_backtest_dates forwards the date list to gateway_cycle_support."""
    gateway = make_gateway()
    dates = ["2026-03-01", "2026-03-02"]
    with patch.object(gateway_module.gateway_cycle_support, "set_backtest_dates") as dates_mock:
        gateway.set_backtest_dates(dates)
        dates_mock.assert_called_once_with(gateway, dates)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Provider usage change callback
|
||||
# =============================================================================
|
||||
|
||||
def test_on_provider_usage_changed_updates_state_sync():
    """The provider-usage callback writes the snapshot into state_sync."""
    gateway = make_gateway()
    gateway._loop = None  # no event loop attached

    usage = {"provider": "finnhub", "calls": 10}
    gateway._on_provider_usage_changed(usage)

    # The snapshot must land under the data_sources key of the synced state.
    assert gateway.state_sync.state.get("data_sources") == usage
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# handle_client lifecycle
|
||||
# =============================================================================
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_client_adds_and_removes_client_from_connected_set():
    """After handle_client returns, the websocket is no longer tracked."""
    gateway = make_gateway()
    client = DummyWebSocket()

    with patch.object(gateway, "_send_initial_state", AsyncMock()), \
         patch.object(gateway, "_handle_client_messages", AsyncMock()):
        await gateway.handle_client(client)

    assert client not in gateway.connected_clients


@pytest.mark.asyncio
async def test_handle_client_adds_client_before_handler():
    """The client is registered on entry and deregistered (under lock) on exit."""
    gateway = make_gateway()
    client = DummyWebSocket()

    with patch.object(gateway, "_send_initial_state", AsyncMock()), \
         patch.object(gateway, "_handle_client_messages", AsyncMock()):
        await gateway.handle_client(client)

    # Added at the start of handle_client, removed again before it returns.
    assert client not in gateway.connected_clients
|
||||
@@ -1,14 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Tests for the extracted runtime service app surface."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.api import runtime as runtime_module
|
||||
from backend.apps.runtime_service import create_app
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_runtime_module_state():
    """Clear the module-level and singleton runtime_manager around every test."""
    def _clear():
        runtime_module.runtime_manager = None
        # Also reset the RuntimeState singleton's private manager slot.
        runtime_module.get_runtime_state()._runtime_manager = None

    _clear()
    yield
    _clear()
|
||||
|
||||
|
||||
def test_runtime_service_routes_are_exposed():
|
||||
app = create_app()
|
||||
paths = {route.path for route in app.routes}
|
||||
@@ -153,7 +170,9 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
||||
)
|
||||
|
||||
class _DummyContext:
|
||||
def __init__(self):
|
||||
def __init__(self, run_dir):
|
||||
self.config_name = "demo"
|
||||
self.run_dir = run_dir
|
||||
self.bootstrap_values = {
|
||||
"tickers": ["AAPL"],
|
||||
"schedule_mode": "daily",
|
||||
@@ -165,8 +184,17 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
||||
class _DummyManager:
|
||||
def __init__(self):
|
||||
self.config_name = "demo"
|
||||
self.bootstrap = dict(_DummyContext().bootstrap_values)
|
||||
self.context = _DummyContext()
|
||||
self.bootstrap = dict(_DummyContext(run_dir).bootstrap_values)
|
||||
self.context = _DummyContext(run_dir)
|
||||
|
||||
def build_snapshot(self):
|
||||
return {
|
||||
"context": {
|
||||
"config_name": self.context.config_name,
|
||||
"run_dir": str(self.context.run_dir),
|
||||
"bootstrap_values": self.context.bootstrap_values,
|
||||
}
|
||||
}
|
||||
|
||||
def _persist_snapshot(self):
|
||||
return None
|
||||
@@ -192,3 +220,385 @@ def test_runtime_service_update_runtime_config_persists_bootstrap(monkeypatch, t
|
||||
assert payload["bootstrap"]["schedule_mode"] == "intraday"
|
||||
assert payload["resolved"]["interval_minutes"] == 15
|
||||
assert "interval_minutes: 15" in (run_dir / "BOOTSTRAP.md").read_text(encoding="utf-8")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# RuntimeState singleton unit tests
|
||||
# =============================================================================
|
||||
|
||||
def test_runtime_state_is_singleton():
    """Two constructions of RuntimeState yield the identical object."""
    assert runtime_module.RuntimeState() is runtime_module.RuntimeState()


def test_runtime_state_get_runtime_state_returns_same_instance():
    """get_runtime_state() hands back the module-level singleton."""
    assert runtime_module.get_runtime_state() is runtime_module._runtime_state


def test_runtime_state_default_values():
    """A freshly re-created RuntimeState starts with its documented defaults."""
    # Drop the cached instance so __init__ runs again with clean state.
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    fresh = runtime_module.RuntimeState()
    assert fresh._runtime_manager is None
    assert fresh._gateway_process is None
    assert fresh._gateway_port == 8765
|
||||
|
||||
|
||||
def test_runtime_state_gateway_port_property():
    """gateway_port round-trips values through its getter/setter pair."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    for port in (9999, 1234):
        state.gateway_port = port
        assert state.gateway_port == port


def test_runtime_state_gateway_process_property():
    """gateway_process round-trips arbitrary objects and None."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    assert state.gateway_process is None

    sentinel = object()
    state.gateway_process = sentinel
    assert state.gateway_process is sentinel

    state.gateway_process = None
    assert state.gateway_process is None


def test_runtime_state_runtime_manager_property():
    """runtime_manager round-trips arbitrary objects and None."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    assert state.runtime_manager is None

    sentinel = object()
    state.runtime_manager = sentinel
    assert state.runtime_manager is sentinel

    state.runtime_manager = None
    assert state.runtime_manager is None
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_runtime_state_lock_property_is_async():
    """lock is an async property whose access yields a coroutine object."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    lock_coro = state.lock
    assert asyncio.iscoroutine(lock_coro)
    # Close the never-awaited coroutine explicitly instead of relying on GC;
    # otherwise a "coroutine was never awaited" RuntimeWarning can leak into
    # other tests at an unpredictable point.
    lock_coro.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_runtime_state_async_set_get_gateway_port():
    """Lock-protected async accessors round-trip the gateway port."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    for port in (8888, 7777):
        await state.set_gateway_port(port)
        assert await state.get_gateway_port() == port


@pytest.mark.asyncio
async def test_runtime_state_async_set_get_gateway_process():
    """Lock-protected async accessors round-trip the gateway process."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    await state.set_gateway_process(None)
    assert await state.get_gateway_process() is None

    sentinel = object()
    await state.set_gateway_process(sentinel)
    assert await state.get_gateway_process() is sentinel


@pytest.mark.asyncio
async def test_runtime_state_async_set_get_runtime_manager():
    """Lock-protected async accessors round-trip the runtime manager."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    await state.set_runtime_manager(None)
    assert await state.get_runtime_manager() is None

    sentinel = object()
    await state.set_runtime_manager(sentinel)
    assert await state.get_runtime_manager() is sentinel
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _is_gateway_running helper tests
|
||||
# =============================================================================
|
||||
|
||||
def test_is_gateway_running_returns_false_when_process_is_none():
    """Without a gateway process the helper reports not-running."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    state._gateway_process = None
    runtime_module._runtime_state = state
    assert runtime_module._is_gateway_running() is False


def test_is_gateway_running_returns_false_when_process_exited():
    """A process whose poll() returns an exit code counts as stopped."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    runtime_module._runtime_state = state

    dead = MagicMock()
    dead.poll.return_value = 1  # poll() != None means the process exited

    state._gateway_process = dead
    assert runtime_module._is_gateway_running() is False


def test_is_gateway_running_returns_true_when_process_running():
    """A process whose poll() returns None counts as alive."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    runtime_module._runtime_state = state

    alive = MagicMock()
    alive.poll.return_value = None  # poll() is None while still running

    state._gateway_process = alive
    assert runtime_module._is_gateway_running() is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _stop_gateway helper tests
|
||||
# =============================================================================
|
||||
|
||||
def test_stop_gateway_returns_false_when_no_process():
    """_stop_gateway is a no-op (False) when nothing is running."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    state._gateway_process = None
    runtime_module._runtime_state = state

    assert runtime_module._stop_gateway() is False


def test_stop_gateway_sets_process_to_none_after_stop():
    """A graceful stop terminates, waits, and clears the stored process."""
    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    runtime_module._runtime_state = state

    proc = MagicMock()
    proc.poll.return_value = None
    proc.wait.return_value = 0

    state._gateway_process = proc

    assert runtime_module._stop_gateway() is True
    assert state._gateway_process is None
    proc.terminate.assert_called_once()
    proc.wait.assert_called_once()


def test_stop_gateway_kills_when_terminate_times_out():
    """If wait() times out after terminate(), the process is force-killed."""
    import subprocess

    runtime_module.RuntimeState._instance = None
    runtime_module.RuntimeState._lock = asyncio.Lock()
    state = runtime_module.RuntimeState()
    runtime_module._runtime_state = state

    proc = MagicMock()
    proc.poll.return_value = None
    proc.wait.side_effect = subprocess.TimeoutExpired("cmd", 5)
    proc.kill.return_value = None

    state._gateway_process = proc

    assert runtime_module._stop_gateway() is True
    assert state._gateway_process is None
    proc.kill.assert_called_once()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _build_gateway_ws_url helper tests
|
||||
# =============================================================================
|
||||
|
||||
def test_build_gateway_ws_url_defaults_to_ws():
    """Plain HTTP with no forwarding headers yields a ws:// URL."""
    from fastapi import Request

    request = MagicMock(spec=Request)
    request.headers.get.side_effect = lambda key, default="": default
    request.url.scheme = "http"
    request.url.hostname = "localhost"

    assert runtime_module._build_gateway_ws_url(request, 8765) == "ws://localhost:8765"


def test_build_gateway_ws_url_uses_wss_for_https():
    """HTTPS upgrades the websocket scheme to wss://."""
    from fastapi import Request

    request = MagicMock(spec=Request)
    request.headers.get.side_effect = lambda key, default="": default
    request.url.scheme = "https"
    request.url.hostname = "example.com"

    assert runtime_module._build_gateway_ws_url(request, 8765) == "wss://example.com:8765"


def test_build_gateway_ws_url_respects_forwarded_proto():
    """x-forwarded-proto: https overrides the direct request scheme."""
    from fastapi import Request

    request = MagicMock(spec=Request)
    request.headers.get.side_effect = lambda key, default="": (
        "https" if key == "x-forwarded-proto" else default
    )
    request.url.scheme = "http"
    request.url.hostname = "internal.example"

    assert runtime_module._build_gateway_ws_url(request, 8765) == "wss://internal.example:8765"


def test_build_gateway_ws_url_respects_forwarded_host():
    """x-forwarded-host replaces the internal hostname in the URL."""
    from fastapi import Request

    request = MagicMock(spec=Request)
    forwarded = {"x-forwarded-host": "external.example.com"}
    request.headers.get.side_effect = lambda key, default="": forwarded.get(key, default)
    request.url.scheme = "http"
    request.url.hostname = "internal.example"

    assert runtime_module._build_gateway_ws_url(request, 8765) == "ws://external.example.com:8765"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _normalize_runtime_config_updates tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_normalize_runtime_config_updates_validates_schedule_mode():
    """An unknown schedule_mode is rejected with a descriptive HTTP error."""
    request = runtime_module.UpdateRuntimeConfigRequest(schedule_mode="invalid")
    with pytest.raises(HTTPException) as exc_info:
        runtime_module._normalize_runtime_config_updates(request)
    assert "schedule_mode" in str(exc_info.value.detail).lower()


def test_normalize_runtime_config_updates_validates_schedule_mode_values():
    """Every mode string other than daily/intraday is rejected."""
    for bad_mode in ("weekly", "monthly", "once"):
        request = runtime_module.UpdateRuntimeConfigRequest(schedule_mode=bad_mode)
        with pytest.raises(HTTPException):
            runtime_module._normalize_runtime_config_updates(request)


def test_normalize_runtime_config_updates_accepts_daily_and_intraday():
    """daily/intraday are accepted case-insensitively."""
    for good_mode in ("daily", "intraday", "DAILY", "IntraDay"):
        request = runtime_module.UpdateRuntimeConfigRequest(schedule_mode=good_mode)
        normalized = runtime_module._normalize_runtime_config_updates(request)
        assert "schedule_mode" in normalized


def test_normalize_runtime_config_updates_validates_trigger_time_format():
    """An out-of-range HH:MM trigger_time raises with a helpful detail."""
    request = runtime_module.UpdateRuntimeConfigRequest(trigger_time="25:99")
    with pytest.raises(HTTPException) as exc_info:
        runtime_module._normalize_runtime_config_updates(request)
    assert "trigger_time" in str(exc_info.value.detail).lower()


def test_normalize_runtime_config_updates_accepts_now_trigger_time():
    """The literal 'now' is a valid trigger_time."""
    request = runtime_module.UpdateRuntimeConfigRequest(trigger_time="now")
    normalized = runtime_module._normalize_runtime_config_updates(request)
    assert normalized["trigger_time"] == "now"


def test_normalize_runtime_config_updates_defaults_empty_trigger_time():
    """A whitespace-only trigger_time falls back to the 09:30 default."""
    request = runtime_module.UpdateRuntimeConfigRequest(trigger_time=" ")
    normalized = runtime_module._normalize_runtime_config_updates(request)
    assert normalized["trigger_time"] == "09:30"


def test_normalize_runtime_config_updates_rejects_no_updates():
    """A request with no fields set at all is rejected outright."""
    request = runtime_module.UpdateRuntimeConfigRequest()
    with pytest.raises(HTTPException) as exc_info:
        runtime_module._normalize_runtime_config_updates(request)
    assert "no runtime config updates" in str(exc_info.value.detail).lower()


def test_normalize_runtime_config_updates_coerces_types():
    """Stringly-typed JSON numbers are coerced to int/float during normalization."""
    request = runtime_module.UpdateRuntimeConfigRequest(
        schedule_mode="intraday",
        interval_minutes="30",  # JSON clients may send numbers as strings
        initial_cash="50000.0",
        margin_requirement="0.25",
    )
    normalized = runtime_module._normalize_runtime_config_updates(request)
    assert normalized["schedule_mode"] == "intraday"
    assert normalized["interval_minutes"] == 30
    assert normalized["initial_cash"] == 50000.0
    assert normalized["margin_requirement"] == 0.25
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# register_runtime_manager / unregister_runtime_manager tests
|
||||
# =============================================================================
|
||||
|
||||
def test_register_runtime_manager_sets_module_and_singleton():
    """Registering a manager updates both the module global and the singleton."""
    runtime_module._runtime_state._initialized = True  # skip re-initialisation
    manager = object()

    runtime_module.register_runtime_manager(manager)

    assert runtime_module.runtime_manager is manager
    assert runtime_module._runtime_state.runtime_manager is manager


def test_unregister_runtime_manager_clears_module_and_singleton():
    """Unregistering clears both the module global and the singleton slot."""
    runtime_module._runtime_state._initialized = True  # skip re-initialisation
    runtime_module._runtime_state.runtime_manager = object()
    runtime_module.runtime_manager = object()

    runtime_module.unregister_runtime_manager()

    assert runtime_module.runtime_manager is None
    assert runtime_module._runtime_state.runtime_manager is None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _generate_run_id tests
|
||||
# =============================================================================
|
||||
|
||||
def test_generate_run_id_returns_timestamp_format():
    """Run ids follow the YYYYMMDD_HHMMSS timestamp layout."""
    run_id = runtime_module._generate_run_id()
    date_part, sep, time_part = run_id.partition("_")
    assert len(run_id) == 15
    assert sep == "_"  # separator between date and time
    assert len(date_part) == 8 and date_part.isdigit()  # YYYYMMDD
    assert len(time_part) == 6 and time_part.isdigit()  # HHMMSS
|
||||
|
||||
@@ -1,47 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Unit tests for the trading domain helpers."""
|
||||
"""Unit tests for data_tools functions (replaces the deleted trading_domain)."""
|
||||
|
||||
from backend.domains import trading as trading_domain
|
||||
from backend.tools.data_tools import (
|
||||
get_company_news,
|
||||
get_financial_metrics,
|
||||
get_insider_trades,
|
||||
get_market_cap,
|
||||
get_prices,
|
||||
search_line_items,
|
||||
)
|
||||
|
||||
|
||||
def test_trading_domain_payload_wrappers(monkeypatch):
|
||||
monkeypatch.setattr(trading_domain, "get_prices", lambda ticker, start_date, end_date: [{"close": 1}])
|
||||
monkeypatch.setattr(trading_domain, "get_financial_metrics", lambda ticker, end_date, period, limit: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_company_news", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_insider_trades", lambda ticker, end_date, start_date=None, limit=1000: [{"ticker": ticker}])
|
||||
monkeypatch.setattr(trading_domain, "get_market_cap", lambda ticker, end_date: 2.5e12)
|
||||
|
||||
assert trading_domain.get_prices_payload(ticker="AAPL", start_date="2026-03-01", end_date="2026-03-16") == {
|
||||
"ticker": "AAPL",
|
||||
"prices": [{"close": 1}],
|
||||
}
|
||||
assert trading_domain.get_financials_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"financial_metrics": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_news_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"news": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_insider_trades_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"insider_trades": [{"ticker": "AAPL"}],
|
||||
}
|
||||
assert trading_domain.get_market_cap_payload(ticker="AAPL", end_date="2026-03-16") == {
|
||||
"ticker": "AAPL",
|
||||
"end_date": "2026-03-16",
|
||||
"market_cap": 2.5e12,
|
||||
}
|
||||
|
||||
|
||||
def test_get_market_status_payload_uses_market_service(monkeypatch):
|
||||
class _FakeMarketService:
|
||||
def __init__(self, tickers):
|
||||
self.tickers = tickers
|
||||
|
||||
def get_market_status(self):
|
||||
return {"status": "open", "status_text": "Open"}
|
||||
|
||||
monkeypatch.setattr(trading_domain, "MarketService", _FakeMarketService)
|
||||
|
||||
assert trading_domain.get_market_status_payload() == {
|
||||
"status": "open",
|
||||
"status_text": "Open",
|
||||
}
|
||||
def test_data_tools_functions_exist():
    """All public data_tools entry points are importable callables."""
    tools = (
        get_prices,
        get_financial_metrics,
        get_company_news,
        get_insider_trades,
        get_market_cap,
        search_line_items,
    )
    for tool in tools:
        assert callable(tool)
||||
|
||||
@@ -24,20 +24,17 @@ def test_trading_service_routes_are_exposed():
|
||||
|
||||
def test_trading_service_prices_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_prices_payload",
|
||||
lambda ticker, start_date, end_date: {
|
||||
"ticker": ticker,
|
||||
"prices": [
|
||||
Price(
|
||||
open=1.0,
|
||||
close=2.0,
|
||||
high=2.5,
|
||||
low=0.5,
|
||||
volume=100,
|
||||
time="2026-03-20",
|
||||
)
|
||||
],
|
||||
},
|
||||
"backend.apps.trading_service.get_prices",
|
||||
lambda ticker, start_date, end_date: [
|
||||
Price(
|
||||
open=1.0,
|
||||
close=2.0,
|
||||
high=2.5,
|
||||
low=0.5,
|
||||
volume=100,
|
||||
time="2026-03-20",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -57,56 +54,54 @@ def test_trading_service_prices_endpoint(monkeypatch):
|
||||
|
||||
def test_trading_service_financials_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_financials_payload",
|
||||
lambda ticker, end_date, period, limit: {
|
||||
"financial_metrics": [
|
||||
FinancialMetrics(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
market_cap=123.0,
|
||||
enterprise_value=None,
|
||||
price_to_earnings_ratio=None,
|
||||
price_to_book_ratio=None,
|
||||
price_to_sales_ratio=None,
|
||||
enterprise_value_to_ebitda_ratio=None,
|
||||
enterprise_value_to_revenue_ratio=None,
|
||||
free_cash_flow_yield=None,
|
||||
peg_ratio=None,
|
||||
gross_margin=None,
|
||||
operating_margin=None,
|
||||
net_margin=None,
|
||||
return_on_equity=None,
|
||||
return_on_assets=None,
|
||||
return_on_invested_capital=None,
|
||||
asset_turnover=None,
|
||||
inventory_turnover=None,
|
||||
receivables_turnover=None,
|
||||
days_sales_outstanding=None,
|
||||
operating_cycle=None,
|
||||
working_capital_turnover=None,
|
||||
current_ratio=None,
|
||||
quick_ratio=None,
|
||||
cash_ratio=None,
|
||||
operating_cash_flow_ratio=None,
|
||||
debt_to_equity=None,
|
||||
debt_to_assets=None,
|
||||
interest_coverage=None,
|
||||
revenue_growth=None,
|
||||
earnings_growth=None,
|
||||
book_value_growth=None,
|
||||
earnings_per_share_growth=None,
|
||||
free_cash_flow_growth=None,
|
||||
operating_income_growth=None,
|
||||
ebitda_growth=None,
|
||||
payout_ratio=None,
|
||||
earnings_per_share=None,
|
||||
book_value_per_share=None,
|
||||
free_cash_flow_per_share=None,
|
||||
)
|
||||
]
|
||||
},
|
||||
"backend.apps.trading_service.get_financial_metrics",
|
||||
lambda ticker, end_date, period, limit: [
|
||||
FinancialMetrics(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
market_cap=123.0,
|
||||
enterprise_value=None,
|
||||
price_to_earnings_ratio=None,
|
||||
price_to_book_ratio=None,
|
||||
price_to_sales_ratio=None,
|
||||
enterprise_value_to_ebitda_ratio=None,
|
||||
enterprise_value_to_revenue_ratio=None,
|
||||
free_cash_flow_yield=None,
|
||||
peg_ratio=None,
|
||||
gross_margin=None,
|
||||
operating_margin=None,
|
||||
net_margin=None,
|
||||
return_on_equity=None,
|
||||
return_on_assets=None,
|
||||
return_on_invested_capital=None,
|
||||
asset_turnover=None,
|
||||
inventory_turnover=None,
|
||||
receivables_turnover=None,
|
||||
days_sales_outstanding=None,
|
||||
operating_cycle=None,
|
||||
working_capital_turnover=None,
|
||||
current_ratio=None,
|
||||
quick_ratio=None,
|
||||
cash_ratio=None,
|
||||
operating_cash_flow_ratio=None,
|
||||
debt_to_equity=None,
|
||||
debt_to_assets=None,
|
||||
interest_coverage=None,
|
||||
revenue_growth=None,
|
||||
earnings_growth=None,
|
||||
book_value_growth=None,
|
||||
earnings_per_share_growth=None,
|
||||
free_cash_flow_growth=None,
|
||||
operating_income_growth=None,
|
||||
ebitda_growth=None,
|
||||
payout_ratio=None,
|
||||
earnings_per_share=None,
|
||||
book_value_per_share=None,
|
||||
free_cash_flow_per_share=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -121,26 +116,22 @@ def test_trading_service_financials_endpoint(monkeypatch):
|
||||
|
||||
def test_trading_service_news_and_insider_endpoints(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_news_payload",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: {
|
||||
"news": [
|
||||
CompanyNews(
|
||||
ticker=ticker,
|
||||
title="News title",
|
||||
source="polygon",
|
||||
url="https://example.com/news",
|
||||
date=end_date,
|
||||
)
|
||||
]
|
||||
},
|
||||
"backend.apps.trading_service.get_company_news",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: [
|
||||
CompanyNews(
|
||||
ticker=ticker,
|
||||
title="News title",
|
||||
source="polygon",
|
||||
url="https://example.com/news",
|
||||
date=end_date,
|
||||
)
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_insider_trades_payload",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: {
|
||||
"insider_trades": [
|
||||
InsiderTrade(ticker=ticker, filing_date=end_date)
|
||||
]
|
||||
},
|
||||
"backend.apps.trading_service.get_insider_trades",
|
||||
lambda ticker, end_date, start_date=None, limit=1000: [
|
||||
InsiderTrade(ticker=ticker, filing_date=end_date)
|
||||
],
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -165,8 +156,8 @@ def test_trading_service_market_status_endpoint(monkeypatch):
|
||||
return {"status": "open", "status_text": "Open"}
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_market_status_payload",
|
||||
lambda: _FakeMarketService().get_market_status(),
|
||||
"backend.apps.trading_service.MarketService",
|
||||
lambda tickers: _FakeMarketService(),
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -178,12 +169,8 @@ def test_trading_service_market_status_endpoint(monkeypatch):
|
||||
|
||||
def test_trading_service_market_cap_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_market_cap_payload",
|
||||
lambda ticker, end_date: {
|
||||
"ticker": ticker,
|
||||
"end_date": end_date,
|
||||
"market_cap": 3.5e12,
|
||||
},
|
||||
"backend.apps.trading_service.get_market_cap",
|
||||
lambda ticker, end_date: 3.5e12,
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
@@ -202,18 +189,16 @@ def test_trading_service_market_cap_endpoint(monkeypatch):
|
||||
|
||||
def test_trading_service_line_items_endpoint(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"backend.domains.trading.get_line_items_payload",
|
||||
lambda ticker, line_items, end_date, period, limit: {
|
||||
"search_results": [
|
||||
LineItem(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
free_cash_flow=123.0,
|
||||
)
|
||||
]
|
||||
},
|
||||
"backend.apps.trading_service.search_line_items",
|
||||
lambda ticker, line_items, end_date, period, limit: [
|
||||
LineItem(
|
||||
ticker=ticker,
|
||||
report_period=end_date,
|
||||
period=period,
|
||||
currency="USD",
|
||||
free_cash_flow=123.0,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with TestClient(create_app()) as client:
|
||||
|
||||
Reference in New Issue
Block a user