feat: initial commit - EvoTraders project
量化交易多智能体系统,包含: - 分析师、投资组合经理、风险经理等智能体 - 股票分析、投资组合管理、风险控制工具 - React 前端界面 - FastAPI 后端服务 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2
backend/services/__init__.py
Normal file
2
backend/services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Services layer for infrastructure components"""
|
||||
569
backend/services/gateway.py
Normal file
569
backend/services/gateway.py
Normal file
@@ -0,0 +1,569 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
WebSocket Gateway for frontend communication
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
from backend.utils.terminal_dashboard import get_dashboard
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.state_sync import StateSync
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Gateway:
|
||||
"""WebSocket Gateway for frontend communication"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
market_service: MarketService,
|
||||
storage_service: StorageService,
|
||||
pipeline: TradingPipeline,
|
||||
state_sync: Optional[StateSync] = None,
|
||||
scheduler_callback: Optional[Callable] = None,
|
||||
config: Dict[str, Any] = None,
|
||||
):
|
||||
self.market_service = market_service
|
||||
self.storage = storage_service
|
||||
self.pipeline = pipeline
|
||||
self.scheduler_callback = scheduler_callback
|
||||
self.config = config or {}
|
||||
|
||||
self.mode = self.config.get("mode", "live")
|
||||
self.is_backtest = self.mode == "backtest" or self.config.get(
|
||||
"backtest_mode",
|
||||
False,
|
||||
)
|
||||
|
||||
self.state_sync = state_sync or StateSync(storage=storage_service)
|
||||
# self.state_sync.set_mode(self.is_backtest)
|
||||
self.state_sync.set_broadcast_fn(self.broadcast)
|
||||
self.pipeline.state_sync = self.state_sync
|
||||
|
||||
self.connected_clients: Set[WebSocketServerProtocol] = set()
|
||||
self.lock = asyncio.Lock()
|
||||
self._backtest_task: Optional[asyncio.Task] = None
|
||||
self._backtest_start_date: Optional[str] = None
|
||||
self._backtest_end_date: Optional[str] = None
|
||||
self._dashboard = get_dashboard()
|
||||
self._market_status_task: Optional[asyncio.Task] = None
|
||||
|
||||
# Session tracking for live returns
|
||||
self._session_start_portfolio_value: Optional[float] = None
|
||||
|
||||
async def start(self, host: str = "0.0.0.0", port: int = 8766):
|
||||
"""Start gateway server"""
|
||||
logger.info(f"Starting gateway on {host}:{port}")
|
||||
|
||||
# Initialize terminal dashboard
|
||||
self._dashboard.set_config(
|
||||
mode=self.mode,
|
||||
config_name=self.config.get("config_name", "default"),
|
||||
host=host,
|
||||
port=port,
|
||||
poll_interval=self.config.get("poll_interval", 10),
|
||||
mock=self.config.get("mock_mode", False),
|
||||
tickers=self.config.get("tickers", []),
|
||||
initial_cash=self.storage.initial_cash,
|
||||
start_date=self._backtest_start_date or "",
|
||||
end_date=self._backtest_end_date or "",
|
||||
)
|
||||
self._dashboard.start()
|
||||
|
||||
self.state_sync.load_state()
|
||||
self.state_sync.update_state("status", "running")
|
||||
self.state_sync.update_state("server_mode", self.mode)
|
||||
self.state_sync.update_state("is_backtest", self.is_backtest)
|
||||
self.state_sync.update_state(
|
||||
"is_mock_mode",
|
||||
self.config.get("mock_mode", False),
|
||||
)
|
||||
|
||||
# Load and display existing portfolio state if available
|
||||
summary = self.storage.load_file("summary")
|
||||
if summary:
|
||||
holdings = self.storage.load_file("holdings") or []
|
||||
trades = self.storage.load_file("trades") or []
|
||||
current_date = self.state_sync.state.get("current_date")
|
||||
self._dashboard.update(
|
||||
date=current_date or "-",
|
||||
status="running",
|
||||
portfolio=summary,
|
||||
holdings=holdings,
|
||||
trades=trades,
|
||||
)
|
||||
logger.info(
|
||||
"Loaded existing portfolio: $%s",
|
||||
f"{summary.get('totalAssetValue', 0):,.2f}",
|
||||
)
|
||||
|
||||
await self.market_service.start(broadcast_func=self.broadcast)
|
||||
|
||||
if self.scheduler_callback:
|
||||
await self.scheduler_callback(callback=self.on_strategy_trigger)
|
||||
|
||||
# Start market status monitoring (only for live mode)
|
||||
if not self.is_backtest:
|
||||
self._market_status_task = asyncio.create_task(
|
||||
self._market_status_monitor(),
|
||||
)
|
||||
|
||||
async with websockets.serve(
|
||||
self.handle_client,
|
||||
host,
|
||||
port,
|
||||
ping_interval=30,
|
||||
ping_timeout=60,
|
||||
):
|
||||
logger.info(
|
||||
f"Gateway started: ws://{host}:{port}, mode={self.mode}",
|
||||
)
|
||||
await asyncio.Future()
|
||||
|
||||
@property
|
||||
def state(self) -> Dict[str, Any]:
|
||||
return self.state_sync.state
|
||||
|
||||
async def handle_client(self, websocket: WebSocketServerProtocol):
|
||||
"""Handle WebSocket client connection"""
|
||||
async with self.lock:
|
||||
self.connected_clients.add(websocket)
|
||||
|
||||
await self._send_initial_state(websocket)
|
||||
await self._handle_client_messages(websocket)
|
||||
|
||||
async with self.lock:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
async def _send_initial_state(self, websocket: WebSocketServerProtocol):
|
||||
state_payload = self.state_sync.get_initial_state_payload(
|
||||
include_dashboard=True,
|
||||
)
|
||||
# Include market status in initial state
|
||||
state_payload[
|
||||
"market_status"
|
||||
] = self.market_service.get_market_status()
|
||||
|
||||
# Include live returns if session is active
|
||||
if self.storage.is_live_session_active:
|
||||
live_returns = self.storage.get_live_returns()
|
||||
if "portfolio" in state_payload:
|
||||
state_payload["portfolio"].update(live_returns)
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{"type": "initial_state", "state": state_payload},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
|
||||
async def _handle_client_messages(
|
||||
self,
|
||||
websocket: WebSocketServerProtocol,
|
||||
):
|
||||
try:
|
||||
async for message in websocket:
|
||||
data = json.loads(message)
|
||||
msg_type = data.get("type", "unknown")
|
||||
|
||||
if msg_type == "ping":
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "pong",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
elif msg_type == "get_state":
|
||||
await self._send_initial_state(websocket)
|
||||
elif msg_type == "start_backtest":
|
||||
await self._handle_start_backtest(data)
|
||||
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
||||
if not self.is_backtest:
|
||||
return
|
||||
dates = data.get("dates", [])
|
||||
if dates and self._backtest_task is None:
|
||||
task = asyncio.create_task(
|
||||
self._run_backtest_dates(dates),
|
||||
)
|
||||
task.add_done_callback(self._handle_backtest_exception)
|
||||
self._backtest_task = task
|
||||
|
||||
async def broadcast(self, message: Dict[str, Any]):
|
||||
"""Broadcast message to all connected clients"""
|
||||
if not self.connected_clients:
|
||||
return
|
||||
|
||||
message_json = json.dumps(message, ensure_ascii=False, default=str)
|
||||
|
||||
async with self.lock:
|
||||
tasks = [
|
||||
self._send_to_client(client, message_json)
|
||||
for client in self.connected_clients.copy()
|
||||
]
|
||||
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
async def _send_to_client(
|
||||
self,
|
||||
client: WebSocketServerProtocol,
|
||||
message: str,
|
||||
):
|
||||
try:
|
||||
await client.send(message)
|
||||
except websockets.ConnectionClosed:
|
||||
async with self.lock:
|
||||
self.connected_clients.discard(client)
|
||||
|
||||
async def _market_status_monitor(self):
|
||||
"""Periodically check and broadcast market status changes"""
|
||||
while True:
|
||||
try:
|
||||
await self.market_service.check_and_broadcast_market_status()
|
||||
|
||||
# On market open, start live session tracking
|
||||
status = self.market_service.get_market_status()
|
||||
if (
|
||||
status["status"] == "open"
|
||||
and not self.storage.is_live_session_active
|
||||
):
|
||||
self.storage.start_live_session()
|
||||
summary = self.storage.load_file("summary") or {}
|
||||
self._session_start_portfolio_value = summary.get(
|
||||
"totalAssetValue",
|
||||
self.storage.initial_cash,
|
||||
)
|
||||
logger.info(
|
||||
"Session start portfolio: "
|
||||
f"${self._session_start_portfolio_value:,.2f}",
|
||||
)
|
||||
elif (
|
||||
status["status"] != "open"
|
||||
and self.storage.is_live_session_active
|
||||
):
|
||||
self.storage.end_live_session()
|
||||
self._session_start_portfolio_value = None
|
||||
|
||||
# Update and broadcast live returns if session is active
|
||||
if self.storage.is_live_session_active:
|
||||
await self._update_and_broadcast_live_returns()
|
||||
|
||||
await asyncio.sleep(60) # Check every minute
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Market status monitor error: {e}")
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _update_and_broadcast_live_returns(self):
|
||||
"""Calculate and broadcast live returns for current session"""
|
||||
if not self.storage.is_live_session_active:
|
||||
return
|
||||
|
||||
# Get current prices and calculate portfolio value
|
||||
prices = self.market_service.get_all_prices()
|
||||
if not prices or not any(p > 0 for p in prices.values()):
|
||||
return
|
||||
|
||||
# Load current internal state to get baseline values
|
||||
state = self.storage.load_internal_state()
|
||||
|
||||
# Get latest values from history (if available)
|
||||
equity_history = state.get("equity_history", [])
|
||||
baseline_history = state.get("baseline_history", [])
|
||||
baseline_vw_history = state.get("baseline_vw_history", [])
|
||||
momentum_history = state.get("momentum_history", [])
|
||||
|
||||
current_equity = equity_history[-1]["v"] if equity_history else None
|
||||
current_baseline = (
|
||||
baseline_history[-1]["v"] if baseline_history else None
|
||||
)
|
||||
current_baseline_vw = (
|
||||
baseline_vw_history[-1]["v"] if baseline_vw_history else None
|
||||
)
|
||||
current_momentum = (
|
||||
momentum_history[-1]["v"] if momentum_history else None
|
||||
)
|
||||
|
||||
# Update live returns with current values
|
||||
point = self.storage.update_live_returns(
|
||||
current_equity=current_equity,
|
||||
current_baseline=current_baseline,
|
||||
current_baseline_vw=current_baseline_vw,
|
||||
current_momentum=current_momentum,
|
||||
)
|
||||
|
||||
# Broadcast if we have new data
|
||||
if point:
|
||||
live_returns = self.storage.get_live_returns()
|
||||
await self.broadcast(
|
||||
{
|
||||
"type": "team_summary",
|
||||
"equity_return": live_returns["equity_return"],
|
||||
"baseline_return": live_returns["baseline_return"],
|
||||
"baseline_vw_return": live_returns["baseline_vw_return"],
|
||||
"momentum_return": live_returns["momentum_return"],
|
||||
},
|
||||
)
|
||||
|
||||
async def on_strategy_trigger(self, date: str):
|
||||
"""Handle trading cycle trigger"""
|
||||
logger.info(f"Strategy triggered for {date}")
|
||||
|
||||
tickers = self.config.get("tickers", [])
|
||||
|
||||
if self.is_backtest:
|
||||
await self._run_backtest_cycle(date, tickers)
|
||||
else:
|
||||
await self._run_live_cycle(date, tickers)
|
||||
|
||||
async def _run_backtest_cycle(self, date: str, tickers: List[str]):
|
||||
"""Run backtest cycle with pre-loaded prices"""
|
||||
self.market_service.set_backtest_date(date)
|
||||
await self.market_service.emit_market_open()
|
||||
|
||||
await self.state_sync.on_cycle_start(date)
|
||||
self._dashboard.update(date=date, status="Analyzing...")
|
||||
|
||||
prices = self.market_service.get_open_prices()
|
||||
close_prices = self.market_service.get_close_prices()
|
||||
market_caps = self._get_market_caps(tickers, date)
|
||||
|
||||
result = await self.pipeline.run_cycle(
|
||||
tickers=tickers,
|
||||
date=date,
|
||||
prices=prices,
|
||||
close_prices=close_prices,
|
||||
market_caps=market_caps,
|
||||
)
|
||||
|
||||
await self.market_service.emit_market_close()
|
||||
settlement_result = result.get("settlement_result")
|
||||
self._save_cycle_results(result, date, close_prices, settlement_result)
|
||||
await self._broadcast_portfolio_updates(result, close_prices)
|
||||
await self._finalize_cycle(date)
|
||||
|
||||
async def _run_live_cycle(self, date: str, tickers: List[str]):
|
||||
"""
|
||||
Run live cycle with real market timing.
|
||||
|
||||
- Analysis runs immediately
|
||||
- Execution waits for market open
|
||||
(or uses current prices if already open)
|
||||
- Settlement waits for market close
|
||||
"""
|
||||
# Get actual trading date (might be next trading day if weekend)
|
||||
trading_date = self.market_service.get_live_trading_date()
|
||||
logger.info(
|
||||
f"Live cycle: triggered={date}, trading_date={trading_date}",
|
||||
)
|
||||
|
||||
await self.state_sync.on_cycle_start(trading_date)
|
||||
self._dashboard.update(date=trading_date, status="Analyzing...")
|
||||
|
||||
market_caps = self._get_market_caps(tickers, trading_date)
|
||||
|
||||
# Run pipeline with async price callbacks
|
||||
result = await self.pipeline.run_cycle(
|
||||
tickers=tickers,
|
||||
date=trading_date,
|
||||
market_caps=market_caps,
|
||||
get_open_prices_fn=self.market_service.wait_for_open_prices,
|
||||
get_close_prices_fn=self.market_service.wait_for_close_prices,
|
||||
)
|
||||
|
||||
close_prices = self.market_service.get_all_prices()
|
||||
settlement_result = result.get("settlement_result")
|
||||
self._save_cycle_results(
|
||||
result,
|
||||
trading_date,
|
||||
close_prices,
|
||||
settlement_result,
|
||||
)
|
||||
await self._broadcast_portfolio_updates(result, close_prices)
|
||||
await self._finalize_cycle(trading_date)
|
||||
|
||||
async def _finalize_cycle(self, date: str):
|
||||
"""Finalize cycle: broadcast state and update dashboard"""
|
||||
summary = self.storage.load_file("summary") or {}
|
||||
|
||||
# Include live returns if session is active
|
||||
if self.storage.is_live_session_active:
|
||||
live_returns = self.storage.get_live_returns()
|
||||
summary.update(live_returns)
|
||||
|
||||
await self.state_sync.on_cycle_end(date, portfolio_summary=summary)
|
||||
|
||||
holdings = self.storage.load_file("holdings") or []
|
||||
trades = self.storage.load_file("trades") or []
|
||||
leaderboard = self.storage.load_file("leaderboard") or []
|
||||
|
||||
if leaderboard:
|
||||
await self.state_sync.on_leaderboard_update(leaderboard)
|
||||
|
||||
self._dashboard.update(
|
||||
date=date,
|
||||
status="Running",
|
||||
portfolio=summary,
|
||||
holdings=holdings,
|
||||
trades=trades,
|
||||
)
|
||||
|
||||
def _get_market_caps(
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Get market caps for tickers (stub implementation)
|
||||
|
||||
Args:
|
||||
tickers: List of tickers
|
||||
date: Trading date
|
||||
|
||||
Returns:
|
||||
Dict mapping ticker to market cap
|
||||
"""
|
||||
from ..tools.data_tools import get_market_cap
|
||||
|
||||
market_caps = {}
|
||||
for ticker in tickers:
|
||||
try:
|
||||
market_cap = get_market_cap(ticker, date)
|
||||
if market_cap:
|
||||
market_caps[ticker] = market_cap
|
||||
else:
|
||||
market_caps[ticker] = 1e9
|
||||
except Exception:
|
||||
market_caps[ticker] = 1e9
|
||||
|
||||
return market_caps
|
||||
|
||||
async def _broadcast_portfolio_updates(
|
||||
self,
|
||||
result: Dict[str, Any],
|
||||
prices: Dict[str, float],
|
||||
):
|
||||
portfolio = result.get("portfolio", {})
|
||||
|
||||
if portfolio:
|
||||
holdings = FrontendAdapter.build_holdings(portfolio, prices)
|
||||
if holdings:
|
||||
await self.state_sync.on_holdings_update(holdings)
|
||||
|
||||
stats = FrontendAdapter.build_stats(portfolio, prices)
|
||||
if stats:
|
||||
await self.state_sync.on_stats_update(stats)
|
||||
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
if executed_trades:
|
||||
await self.state_sync.on_trades_executed(executed_trades)
|
||||
|
||||
def _save_cycle_results(
|
||||
self,
|
||||
result: Dict[str, Any],
|
||||
date: str,
|
||||
prices: Dict[str, float],
|
||||
settlement_result: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
portfolio = result.get("portfolio", {})
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
|
||||
# Extract baseline values from settlement result
|
||||
baseline_values = None
|
||||
if settlement_result:
|
||||
baseline_values = settlement_result.get("baseline_values")
|
||||
|
||||
if portfolio:
|
||||
self.storage.update_dashboard_after_cycle(
|
||||
portfolio=portfolio,
|
||||
prices=prices,
|
||||
date=date,
|
||||
executed_trades=executed_trades,
|
||||
baseline_values=baseline_values,
|
||||
)
|
||||
|
||||
async def _run_backtest_dates(self, dates: List[str]):
|
||||
self.state_sync.set_backtest_dates(dates)
|
||||
self._dashboard.update(days_total=len(dates), days_completed=0)
|
||||
|
||||
await self.state_sync.on_system_message(
|
||||
f"Starting backtest - {len(dates)} trading days",
|
||||
)
|
||||
|
||||
try:
|
||||
for i, date in enumerate(dates):
|
||||
self._dashboard.update(days_completed=i)
|
||||
await self.on_strategy_trigger(date=date)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await self.state_sync.on_system_message(
|
||||
f"Backtest complete - {len(dates)} days",
|
||||
)
|
||||
|
||||
# Update dashboard with final state
|
||||
summary = self.storage.load_file("summary") or {}
|
||||
self._dashboard.update(
|
||||
status="Complete",
|
||||
portfolio=summary,
|
||||
days_completed=len(dates),
|
||||
)
|
||||
self._dashboard.stop()
|
||||
self._dashboard.print_final_summary()
|
||||
except Exception as e:
|
||||
error_msg = f"Backtest failed: {type(e).__name__}: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
await self.state_sync.on_system_message(error_msg)
|
||||
self._dashboard.update(status=f"Failed: {str(e)}")
|
||||
self._dashboard.stop()
|
||||
raise
|
||||
finally:
|
||||
self._backtest_task = None
|
||||
|
||||
def _handle_backtest_exception(self, task: asyncio.Task):
|
||||
"""Handle exceptions from backtest task"""
|
||||
try:
|
||||
task.result()
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Backtest task was cancelled")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Backtest task failed with exception:{type(e).__name__}:{e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def set_backtest_dates(self, dates: List[str]):
|
||||
self.state_sync.set_backtest_dates(dates)
|
||||
if dates:
|
||||
self._backtest_start_date = dates[0]
|
||||
self._backtest_end_date = dates[-1]
|
||||
self._dashboard.days_total = len(dates)
|
||||
|
||||
def stop(self):
|
||||
self.state_sync.save_state()
|
||||
self.market_service.stop()
|
||||
if self._backtest_task:
|
||||
self._backtest_task.cancel()
|
||||
if self._market_status_task:
|
||||
self._market_status_task.cancel()
|
||||
self._dashboard.stop()
|
||||
625
backend/services/market.py
Normal file
625
backend/services/market.py
Normal file
@@ -0,0 +1,625 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Market Data Service
|
||||
Supports live, mock, and backtest modes
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NYSE timezone and calendar
|
||||
NYSE_TZ = ZoneInfo("America/New_York")
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
class MarketStatus:
|
||||
"""Market status enum-like class"""
|
||||
|
||||
OPEN = "open"
|
||||
CLOSED = "closed"
|
||||
PREMARKET = "premarket"
|
||||
AFTERHOURS = "afterhours"
|
||||
|
||||
|
||||
class MarketService:
|
||||
"""Market data service for price management"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tickers: List[str],
|
||||
poll_interval: int = 10,
|
||||
mock_mode: bool = False,
|
||||
backtest_mode: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
backtest_start_date: Optional[str] = None,
|
||||
backtest_end_date: Optional[str] = None,
|
||||
):
|
||||
self.tickers = tickers
|
||||
self.poll_interval = poll_interval
|
||||
self.mock_mode = mock_mode
|
||||
self.backtest_mode = backtest_mode
|
||||
self.api_key = api_key
|
||||
self.backtest_start_date = backtest_start_date
|
||||
self.backtest_end_date = backtest_end_date
|
||||
|
||||
self.cache: Dict[str, Dict[str, Any]] = {}
|
||||
self.running = False
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._broadcast_func: Optional[Callable] = None
|
||||
self._price_manager: Optional[Any] = None
|
||||
self._current_date: Optional[str] = None
|
||||
|
||||
# Market status tracking
|
||||
self._last_market_status: Optional[str] = None
|
||||
|
||||
# Session tracking for live returns
|
||||
self._session_start_values: Optional[Dict[str, float]] = None
|
||||
self._session_start_timestamp: Optional[int] = None
|
||||
|
||||
@property
|
||||
def mode_name(self) -> str:
|
||||
if self.backtest_mode:
|
||||
return "BACKTEST"
|
||||
elif self.mock_mode:
|
||||
return "MOCK"
|
||||
return "LIVE"
|
||||
|
||||
async def start(self, broadcast_func: Callable):
|
||||
"""Start market data service"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._broadcast_func = broadcast_func
|
||||
|
||||
if self.backtest_mode:
|
||||
self._start_backtest_mode()
|
||||
elif self.mock_mode:
|
||||
self._start_mock_mode()
|
||||
else:
|
||||
self._start_real_mode()
|
||||
|
||||
logger.info(
|
||||
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
|
||||
)
|
||||
|
||||
def _make_price_callback(self) -> Callable:
|
||||
"""Create thread-safe price callback"""
|
||||
|
||||
def callback(price_data: Dict[str, Any]):
|
||||
symbol = price_data["symbol"]
|
||||
self.cache[symbol] = price_data
|
||||
|
||||
loop = self._loop
|
||||
if loop and loop.is_running() and self._broadcast_func:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_price_update(price_data),
|
||||
loop,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
def _start_mock_mode(self):
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
|
||||
self._price_manager = MockPriceManager(
|
||||
poll_interval=self.poll_interval,
|
||||
volatility=0.5,
|
||||
)
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(
|
||||
self.tickers,
|
||||
base_prices={t: 100.0 for t in self.tickers},
|
||||
)
|
||||
self._price_manager.start()
|
||||
|
||||
def _start_real_mode(self):
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for live mode")
|
||||
self._price_manager = PollingPriceManager(
|
||||
api_key=self.api_key,
|
||||
poll_interval=self.poll_interval,
|
||||
)
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(self.tickers)
|
||||
self._price_manager.start()
|
||||
|
||||
def _start_backtest_mode(self):
|
||||
from backend.data.historical_price_manager import (
|
||||
HistoricalPriceManager,
|
||||
)
|
||||
|
||||
self._price_manager = HistoricalPriceManager()
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(self.tickers)
|
||||
|
||||
if self.backtest_start_date and self.backtest_end_date:
|
||||
self._price_manager.preload_data(
|
||||
self.backtest_start_date,
|
||||
self.backtest_end_date,
|
||||
)
|
||||
|
||||
self._price_manager.start()
|
||||
|
||||
async def _broadcast_price_update(self, price_data: Dict[str, Any]):
|
||||
"""Broadcast price update to frontend"""
|
||||
if not self._broadcast_func:
|
||||
return
|
||||
|
||||
symbol = price_data["symbol"]
|
||||
price = price_data["price"]
|
||||
open_price = price_data.get("open", price)
|
||||
ret = (
|
||||
((price - open_price) / open_price) * 100 if open_price > 0 else 0
|
||||
)
|
||||
|
||||
await self._broadcast_func(
|
||||
{
|
||||
"type": "price_update",
|
||||
"symbol": symbol,
|
||||
"price": price,
|
||||
"open": open_price,
|
||||
"ret": ret,
|
||||
"timestamp": price_data.get("timestamp"),
|
||||
"realtime_prices": {
|
||||
t: self._get_cached_price(t) for t in self.tickers
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _get_cached_price(self, ticker: str) -> Dict[str, Any]:
|
||||
"""Get cached price data for a ticker"""
|
||||
if ticker in self.cache:
|
||||
return self.cache[ticker]
|
||||
# Return from price manager if not in cache
|
||||
if self._price_manager:
|
||||
price = self._price_manager.get_latest_price(ticker)
|
||||
if price:
|
||||
return {"price": price, "symbol": ticker}
|
||||
return {"price": 0, "symbol": ticker}
|
||||
|
||||
def stop(self):
|
||||
"""Stop market service"""
|
||||
if not self.running:
|
||||
return
|
||||
self.running = False
|
||||
if self._price_manager:
|
||||
self._price_manager.stop()
|
||||
self._price_manager = None
|
||||
self._loop = None
|
||||
self._broadcast_func = None
|
||||
|
||||
# Backtest methods
|
||||
def set_backtest_date(self, date: str):
|
||||
"""Set current backtest date"""
|
||||
if not self.backtest_mode or not self._price_manager:
|
||||
return
|
||||
self._current_date = date
|
||||
self._price_manager.set_date(date)
|
||||
logger.info(f"Backtest date: {date}")
|
||||
|
||||
async def emit_market_open(self):
|
||||
"""Emit market open prices"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
self._price_manager.emit_open_prices()
|
||||
# Log prices for debugging
|
||||
prices = self.get_open_prices()
|
||||
logger.info(f"Open prices: {prices}")
|
||||
|
||||
async def emit_market_close(self):
|
||||
"""Emit market close prices"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
self._price_manager.emit_close_prices()
|
||||
# Log prices for debugging
|
||||
prices = self.get_close_prices()
|
||||
logger.info(f"Close prices: {prices}")
|
||||
|
||||
def get_open_prices(self) -> Dict[str, float]:
|
||||
"""Get open prices for all tickers"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = None
|
||||
# Try price manager first
|
||||
if self.backtest_mode and self._price_manager:
|
||||
price = self._price_manager.get_open_price(ticker)
|
||||
# Fallback to cache
|
||||
if price is None or price <= 0:
|
||||
cached = self.cache.get(ticker, {})
|
||||
price = cached.get("open") or cached.get("price")
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
def get_close_prices(self) -> Dict[str, float]:
|
||||
"""Get close prices for all tickers"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = None
|
||||
# Try price manager first
|
||||
if self.backtest_mode and self._price_manager:
|
||||
price = self._price_manager.get_close_price(ticker)
|
||||
# Fallback to cache
|
||||
if price is None or price <= 0:
|
||||
cached = self.cache.get(ticker, {})
|
||||
price = cached.get("close") or cached.get("price")
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
def get_price_for_date(
|
||||
self,
|
||||
ticker: str,
|
||||
date: str,
|
||||
price_type: str = "close",
|
||||
) -> Optional[float]:
|
||||
"""Get price for a specific date"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
return self._price_manager.get_price_for_date(
|
||||
ticker,
|
||||
date,
|
||||
price_type,
|
||||
)
|
||||
return self.get_price_sync(ticker)
|
||||
|
||||
# Common methods
|
||||
def get_price_sync(self, ticker: str) -> Optional[float]:
|
||||
"""Get latest price synchronously"""
|
||||
# Try cache first
|
||||
data = self.cache.get(ticker)
|
||||
if data and data.get("price"):
|
||||
return data["price"]
|
||||
# Try price manager
|
||||
if self._price_manager:
|
||||
return self._price_manager.get_latest_price(ticker)
|
||||
return None
|
||||
|
||||
def get_all_prices(self) -> Dict[str, float]:
|
||||
"""Get all latest prices"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = self.get_price_sync(ticker)
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
# Live mode async waiting methods
|
||||
|
||||
def _now_nyse(self) -> datetime:
|
||||
"""Get current time in NYSE timezone"""
|
||||
return datetime.now(NYSE_TZ)
|
||||
|
||||
def _is_trading_day(self, date: datetime) -> bool:
|
||||
"""Check if date is a NYSE trading day"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
valid_days = NYSE_CALENDAR.valid_days(
|
||||
start_date=date_str,
|
||||
end_date=date_str,
|
||||
)
|
||||
return len(valid_days) > 0
|
||||
|
||||
def _get_market_hours(self, date: datetime) -> tuple:
|
||||
"""Get market open and close times for a given date"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
schedule = NYSE_CALENDAR.schedule(
|
||||
start_date=date_str,
|
||||
end_date=date_str,
|
||||
)
|
||||
if schedule.empty:
|
||||
return None, None
|
||||
market_open = schedule.iloc[0]["market_open"].to_pydatetime()
|
||||
market_close = schedule.iloc[0]["market_close"].to_pydatetime()
|
||||
return market_open, market_close
|
||||
|
||||
def _next_trading_day(self, from_date: datetime) -> datetime:
|
||||
"""Find the next trading day from given date"""
|
||||
check_date = from_date + timedelta(days=1)
|
||||
for _ in range(10): # Max 10 days ahead (handles holidays)
|
||||
if self._is_trading_day(check_date):
|
||||
return check_date
|
||||
check_date += timedelta(days=1)
|
||||
return check_date
|
||||
|
||||
def _get_trading_date_for_execution(self) -> tuple:
|
||||
"""
|
||||
Determine the trading date for execution.
|
||||
|
||||
Returns:
|
||||
(trading_date, market_open_time, market_close_time)
|
||||
|
||||
Logic:
|
||||
- If today is a trading day and market has opened: use today
|
||||
- If today is a trading day but market hasn't opened: wait for open
|
||||
- If today is not a trading day: use next trading day
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if self._is_trading_day(today):
|
||||
market_open, market_close = self._get_market_hours(today)
|
||||
return today, market_open, market_close
|
||||
else:
|
||||
# Weekend or holiday - find next trading day
|
||||
next_day = self._next_trading_day(today)
|
||||
market_open, market_close = self._get_market_hours(next_day)
|
||||
return next_day, market_open, market_close
|
||||
|
||||
async def wait_for_open_prices(self) -> Dict[str, float]:
|
||||
"""
|
||||
Wait for market open and return open prices.
|
||||
|
||||
Behavior:
|
||||
- If market is already open today: return current prices immediately
|
||||
- If market hasn't opened yet today: wait until open
|
||||
- If not a trading day: wait until next trading day opens
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
trading_date, market_open, _ = self._get_trading_date_for_execution()
|
||||
|
||||
if market_open is None:
|
||||
logger.warning("Could not determine market hours")
|
||||
return self.get_all_prices()
|
||||
|
||||
trading_date_str = trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Check if we need to wait
|
||||
if now < market_open:
|
||||
wait_seconds = (market_open - now).total_seconds()
|
||||
logger.info(
|
||||
f"Waiting {wait_seconds/60:.1f} min for market open "
|
||||
f"({trading_date_str} {market_open.strftime('%H:%M')} ET)",
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
# Small delay to ensure prices are available
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
logger.info(
|
||||
f"Market already open for {trading_date_str}, "
|
||||
f"getting current prices",
|
||||
)
|
||||
|
||||
# Poll until we have valid prices
|
||||
prices = await self._poll_for_prices()
|
||||
logger.info(f"Got open prices for {trading_date_str}: {prices}")
|
||||
return prices
|
||||
|
||||
async def wait_for_close_prices(self) -> Dict[str, float]:
|
||||
"""
|
||||
Wait for market close and return close prices.
|
||||
|
||||
Behavior:
|
||||
- If market is already closed today: return current prices immediately
|
||||
- If market hasn't closed yet: wait until close
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
trading_date, _, market_close = self._get_trading_date_for_execution()
|
||||
|
||||
if market_close is None:
|
||||
logger.warning("Could not determine market hours")
|
||||
return self.get_all_prices()
|
||||
|
||||
trading_date_str = trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Check if we need to wait
|
||||
if now < market_close:
|
||||
wait_seconds = (market_close - now).total_seconds()
|
||||
logger.info(
|
||||
f"Waiting {wait_seconds/60:.1f} min for market close "
|
||||
f"({trading_date_str} {market_close.strftime('%H:%M')} ET)",
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
# Small delay to ensure final prices settle
|
||||
await asyncio.sleep(10)
|
||||
else:
|
||||
logger.info(
|
||||
f"Market already closed for {trading_date_str}, "
|
||||
f"getting close prices",
|
||||
)
|
||||
|
||||
# Get final prices
|
||||
prices = await self._poll_for_prices()
|
||||
logger.info(f"Got close prices for {trading_date_str}: {prices}")
|
||||
return prices
|
||||
|
||||
def get_live_trading_date(self) -> str:
|
||||
"""Get the trading date that will be used for live execution"""
|
||||
trading_date, _, _ = self._get_trading_date_for_execution()
|
||||
return trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
async def _poll_for_prices(
|
||||
self,
|
||||
max_retries: int = 12,
|
||||
) -> Dict[str, float]:
|
||||
"""Poll until all prices are available"""
|
||||
for _ in range(max_retries):
|
||||
prices = self.get_all_prices()
|
||||
if all(p > 0 for p in prices.values()):
|
||||
return prices
|
||||
logger.debug("Waiting for prices to be available...")
|
||||
await asyncio.sleep(5)
|
||||
# Return whatever we have
|
||||
return self.get_all_prices()
|
||||
|
||||
# ========== Market Status Methods ==========
|
||||
|
||||
def get_market_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current market status
|
||||
|
||||
Returns:
|
||||
Dict with status info:
|
||||
- status: 'open' | 'closed' | 'premarket' | 'afterhours'
|
||||
- status_text: Human readable status
|
||||
- is_trading_day: Whether today is a trading day
|
||||
- market_open: Market open time (if trading day)
|
||||
- market_close: Market close time (if trading day)
|
||||
"""
|
||||
if self.backtest_mode:
|
||||
# In backtest mode, always return open
|
||||
return {
|
||||
"status": MarketStatus.OPEN,
|
||||
"status_text": "Backtest Mode",
|
||||
"is_trading_day": True,
|
||||
}
|
||||
|
||||
now = self._now_nyse()
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
is_trading = self._is_trading_day(today)
|
||||
|
||||
if not is_trading:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed (Non-trading Day)",
|
||||
"is_trading_day": False,
|
||||
}
|
||||
|
||||
market_open, market_close = self._get_market_hours(today)
|
||||
|
||||
if market_open is None or market_close is None:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed",
|
||||
"is_trading_day": is_trading,
|
||||
}
|
||||
|
||||
# Determine status based on current time
|
||||
if now < market_open:
|
||||
return {
|
||||
"status": MarketStatus.PREMARKET,
|
||||
"status_text": "Pre-Market",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
elif now > market_close:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": MarketStatus.OPEN,
|
||||
"status_text": "Market Open",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
|
||||
async def check_and_broadcast_market_status(self):
|
||||
"""Check market status and broadcast if changed"""
|
||||
status = self.get_market_status()
|
||||
current_status = status["status"]
|
||||
|
||||
if current_status != self._last_market_status:
|
||||
self._last_market_status = current_status
|
||||
await self._broadcast_market_status(status)
|
||||
|
||||
# Handle session transitions
|
||||
if current_status == MarketStatus.OPEN:
|
||||
await self._on_session_start()
|
||||
elif (
|
||||
current_status == MarketStatus.CLOSED
|
||||
and self._session_start_values is not None
|
||||
):
|
||||
self._on_session_end()
|
||||
|
||||
async def _broadcast_market_status(self, status: Dict[str, Any]):
|
||||
"""Broadcast market status update to frontend"""
|
||||
if not self._broadcast_func:
|
||||
return
|
||||
|
||||
await self._broadcast_func(
|
||||
{
|
||||
"type": "market_status_update",
|
||||
"market_status": status,
|
||||
"timestamp": datetime.now(NYSE_TZ).isoformat(),
|
||||
},
|
||||
)
|
||||
logger.info(f"Market status: {status['status_text']}")
|
||||
|
||||
async def _on_session_start(self):
|
||||
"""Called when market session starts - capture baseline values"""
|
||||
# Wait briefly for prices to be available
|
||||
await asyncio.sleep(2)
|
||||
|
||||
prices = self.get_all_prices()
|
||||
if prices and any(p > 0 for p in prices.values()):
|
||||
self._session_start_values = prices.copy()
|
||||
self._session_start_timestamp = int(
|
||||
datetime.now().timestamp() * 1000,
|
||||
)
|
||||
logger.info(f"Session started with prices: {prices}")
|
||||
|
||||
def _on_session_end(self):
|
||||
"""Called when market session ends - clear session data"""
|
||||
self._session_start_values = None
|
||||
self._session_start_timestamp = None
|
||||
logger.info("Session ended, cleared session data")
|
||||
|
||||
def get_session_returns(
|
||||
self,
|
||||
current_prices: Dict[str, float],
|
||||
portfolio_value: Optional[float] = None,
|
||||
session_start_portfolio_value: Optional[float] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Calculate session returns (from session start to now)
|
||||
|
||||
Args:
|
||||
current_prices: Current prices for tickers
|
||||
portfolio_value: Current portfolio value (optional)
|
||||
session_start_portfolio_value:
|
||||
|
||||
Returns:
|
||||
Dict with return data or None if session not started
|
||||
"""
|
||||
if self._session_start_values is None:
|
||||
return None
|
||||
|
||||
timestamp = int(datetime.now().timestamp() * 1000)
|
||||
returns = {}
|
||||
|
||||
# Calculate individual ticker returns
|
||||
for ticker, start_price in self._session_start_values.items():
|
||||
current = current_prices.get(ticker)
|
||||
if current and start_price and start_price > 0:
|
||||
ret = ((current - start_price) / start_price) * 100
|
||||
returns[ticker] = round(ret, 4)
|
||||
|
||||
result = {
|
||||
"timestamp": timestamp,
|
||||
"ticker_returns": returns,
|
||||
}
|
||||
|
||||
# Calculate portfolio return if values provided
|
||||
if (
|
||||
portfolio_value is not None
|
||||
and session_start_portfolio_value is not None
|
||||
):
|
||||
if session_start_portfolio_value > 0:
|
||||
portfolio_ret = (
|
||||
(portfolio_value - session_start_portfolio_value)
|
||||
/ session_start_portfolio_value
|
||||
) * 100
|
||||
result["portfolio_return"] = round(portfolio_ret, 4)
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def session_start_values(self) -> Optional[Dict[str, float]]:
|
||||
"""Get session start values for external use"""
|
||||
return self._session_start_values
|
||||
|
||||
@property
|
||||
def session_start_timestamp(self) -> Optional[int]:
|
||||
"""Get session start timestamp"""
|
||||
return self._session_start_timestamp
|
||||
1099
backend/services/storage.py
Normal file
1099
backend/services/storage.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user