feat: initial commit - EvoTraders project

量化交易多智能体系统,包含:
- 分析师、投资组合经理、风险经理等智能体
- 股票分析、投资组合管理、风险控制工具
- React 前端界面
- FastAPI 后端服务

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2026-03-13 04:34:06 +08:00
commit 12de93aa30
115 changed files with 29304 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""Services layer for infrastructure components"""

569
backend/services/gateway.py Normal file
View File

@@ -0,0 +1,569 @@
# -*- coding: utf-8 -*-
"""
WebSocket Gateway for frontend communication
"""
import asyncio
import json
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set
import websockets
from websockets.server import WebSocketServerProtocol
from backend.utils.msg_adapter import FrontendAdapter
from backend.utils.terminal_dashboard import get_dashboard
from backend.core.pipeline import TradingPipeline
from backend.core.state_sync import StateSync
from backend.services.market import MarketService
from backend.services.storage import StorageService
logger = logging.getLogger(__name__)
class Gateway:
"""WebSocket Gateway for frontend communication"""
def __init__(
self,
market_service: MarketService,
storage_service: StorageService,
pipeline: TradingPipeline,
state_sync: Optional[StateSync] = None,
scheduler_callback: Optional[Callable] = None,
config: Dict[str, Any] = None,
):
self.market_service = market_service
self.storage = storage_service
self.pipeline = pipeline
self.scheduler_callback = scheduler_callback
self.config = config or {}
self.mode = self.config.get("mode", "live")
self.is_backtest = self.mode == "backtest" or self.config.get(
"backtest_mode",
False,
)
self.state_sync = state_sync or StateSync(storage=storage_service)
# self.state_sync.set_mode(self.is_backtest)
self.state_sync.set_broadcast_fn(self.broadcast)
self.pipeline.state_sync = self.state_sync
self.connected_clients: Set[WebSocketServerProtocol] = set()
self.lock = asyncio.Lock()
self._backtest_task: Optional[asyncio.Task] = None
self._backtest_start_date: Optional[str] = None
self._backtest_end_date: Optional[str] = None
self._dashboard = get_dashboard()
self._market_status_task: Optional[asyncio.Task] = None
# Session tracking for live returns
self._session_start_portfolio_value: Optional[float] = None
async def start(self, host: str = "0.0.0.0", port: int = 8766):
"""Start gateway server"""
logger.info(f"Starting gateway on {host}:{port}")
# Initialize terminal dashboard
self._dashboard.set_config(
mode=self.mode,
config_name=self.config.get("config_name", "default"),
host=host,
port=port,
poll_interval=self.config.get("poll_interval", 10),
mock=self.config.get("mock_mode", False),
tickers=self.config.get("tickers", []),
initial_cash=self.storage.initial_cash,
start_date=self._backtest_start_date or "",
end_date=self._backtest_end_date or "",
)
self._dashboard.start()
self.state_sync.load_state()
self.state_sync.update_state("status", "running")
self.state_sync.update_state("server_mode", self.mode)
self.state_sync.update_state("is_backtest", self.is_backtest)
self.state_sync.update_state(
"is_mock_mode",
self.config.get("mock_mode", False),
)
# Load and display existing portfolio state if available
summary = self.storage.load_file("summary")
if summary:
holdings = self.storage.load_file("holdings") or []
trades = self.storage.load_file("trades") or []
current_date = self.state_sync.state.get("current_date")
self._dashboard.update(
date=current_date or "-",
status="running",
portfolio=summary,
holdings=holdings,
trades=trades,
)
logger.info(
"Loaded existing portfolio: $%s",
f"{summary.get('totalAssetValue', 0):,.2f}",
)
await self.market_service.start(broadcast_func=self.broadcast)
if self.scheduler_callback:
await self.scheduler_callback(callback=self.on_strategy_trigger)
# Start market status monitoring (only for live mode)
if not self.is_backtest:
self._market_status_task = asyncio.create_task(
self._market_status_monitor(),
)
async with websockets.serve(
self.handle_client,
host,
port,
ping_interval=30,
ping_timeout=60,
):
logger.info(
f"Gateway started: ws://{host}:{port}, mode={self.mode}",
)
await asyncio.Future()
@property
def state(self) -> Dict[str, Any]:
return self.state_sync.state
async def handle_client(self, websocket: WebSocketServerProtocol):
"""Handle WebSocket client connection"""
async with self.lock:
self.connected_clients.add(websocket)
await self._send_initial_state(websocket)
await self._handle_client_messages(websocket)
async with self.lock:
self.connected_clients.discard(websocket)
async def _send_initial_state(self, websocket: WebSocketServerProtocol):
state_payload = self.state_sync.get_initial_state_payload(
include_dashboard=True,
)
# Include market status in initial state
state_payload[
"market_status"
] = self.market_service.get_market_status()
# Include live returns if session is active
if self.storage.is_live_session_active:
live_returns = self.storage.get_live_returns()
if "portfolio" in state_payload:
state_payload["portfolio"].update(live_returns)
await websocket.send(
json.dumps(
{"type": "initial_state", "state": state_payload},
ensure_ascii=False,
default=str,
),
)
async def _handle_client_messages(
self,
websocket: WebSocketServerProtocol,
):
try:
async for message in websocket:
data = json.loads(message)
msg_type = data.get("type", "unknown")
if msg_type == "ping":
await websocket.send(
json.dumps(
{
"type": "pong",
"timestamp": datetime.now().isoformat(),
},
ensure_ascii=False,
),
)
elif msg_type == "get_state":
await self._send_initial_state(websocket)
elif msg_type == "start_backtest":
await self._handle_start_backtest(data)
except websockets.ConnectionClosed:
pass
except json.JSONDecodeError:
pass
async def _handle_start_backtest(self, data: Dict[str, Any]):
if not self.is_backtest:
return
dates = data.get("dates", [])
if dates and self._backtest_task is None:
task = asyncio.create_task(
self._run_backtest_dates(dates),
)
task.add_done_callback(self._handle_backtest_exception)
self._backtest_task = task
async def broadcast(self, message: Dict[str, Any]):
"""Broadcast message to all connected clients"""
if not self.connected_clients:
return
message_json = json.dumps(message, ensure_ascii=False, default=str)
async with self.lock:
tasks = [
self._send_to_client(client, message_json)
for client in self.connected_clients.copy()
]
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
async def _send_to_client(
self,
client: WebSocketServerProtocol,
message: str,
):
try:
await client.send(message)
except websockets.ConnectionClosed:
async with self.lock:
self.connected_clients.discard(client)
async def _market_status_monitor(self):
"""Periodically check and broadcast market status changes"""
while True:
try:
await self.market_service.check_and_broadcast_market_status()
# On market open, start live session tracking
status = self.market_service.get_market_status()
if (
status["status"] == "open"
and not self.storage.is_live_session_active
):
self.storage.start_live_session()
summary = self.storage.load_file("summary") or {}
self._session_start_portfolio_value = summary.get(
"totalAssetValue",
self.storage.initial_cash,
)
logger.info(
"Session start portfolio: "
f"${self._session_start_portfolio_value:,.2f}",
)
elif (
status["status"] != "open"
and self.storage.is_live_session_active
):
self.storage.end_live_session()
self._session_start_portfolio_value = None
# Update and broadcast live returns if session is active
if self.storage.is_live_session_active:
await self._update_and_broadcast_live_returns()
await asyncio.sleep(60) # Check every minute
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Market status monitor error: {e}")
await asyncio.sleep(60)
async def _update_and_broadcast_live_returns(self):
"""Calculate and broadcast live returns for current session"""
if not self.storage.is_live_session_active:
return
# Get current prices and calculate portfolio value
prices = self.market_service.get_all_prices()
if not prices or not any(p > 0 for p in prices.values()):
return
# Load current internal state to get baseline values
state = self.storage.load_internal_state()
# Get latest values from history (if available)
equity_history = state.get("equity_history", [])
baseline_history = state.get("baseline_history", [])
baseline_vw_history = state.get("baseline_vw_history", [])
momentum_history = state.get("momentum_history", [])
current_equity = equity_history[-1]["v"] if equity_history else None
current_baseline = (
baseline_history[-1]["v"] if baseline_history else None
)
current_baseline_vw = (
baseline_vw_history[-1]["v"] if baseline_vw_history else None
)
current_momentum = (
momentum_history[-1]["v"] if momentum_history else None
)
# Update live returns with current values
point = self.storage.update_live_returns(
current_equity=current_equity,
current_baseline=current_baseline,
current_baseline_vw=current_baseline_vw,
current_momentum=current_momentum,
)
# Broadcast if we have new data
if point:
live_returns = self.storage.get_live_returns()
await self.broadcast(
{
"type": "team_summary",
"equity_return": live_returns["equity_return"],
"baseline_return": live_returns["baseline_return"],
"baseline_vw_return": live_returns["baseline_vw_return"],
"momentum_return": live_returns["momentum_return"],
},
)
async def on_strategy_trigger(self, date: str):
"""Handle trading cycle trigger"""
logger.info(f"Strategy triggered for {date}")
tickers = self.config.get("tickers", [])
if self.is_backtest:
await self._run_backtest_cycle(date, tickers)
else:
await self._run_live_cycle(date, tickers)
async def _run_backtest_cycle(self, date: str, tickers: List[str]):
"""Run backtest cycle with pre-loaded prices"""
self.market_service.set_backtest_date(date)
await self.market_service.emit_market_open()
await self.state_sync.on_cycle_start(date)
self._dashboard.update(date=date, status="Analyzing...")
prices = self.market_service.get_open_prices()
close_prices = self.market_service.get_close_prices()
market_caps = self._get_market_caps(tickers, date)
result = await self.pipeline.run_cycle(
tickers=tickers,
date=date,
prices=prices,
close_prices=close_prices,
market_caps=market_caps,
)
await self.market_service.emit_market_close()
settlement_result = result.get("settlement_result")
self._save_cycle_results(result, date, close_prices, settlement_result)
await self._broadcast_portfolio_updates(result, close_prices)
await self._finalize_cycle(date)
async def _run_live_cycle(self, date: str, tickers: List[str]):
"""
Run live cycle with real market timing.
- Analysis runs immediately
- Execution waits for market open
(or uses current prices if already open)
- Settlement waits for market close
"""
# Get actual trading date (might be next trading day if weekend)
trading_date = self.market_service.get_live_trading_date()
logger.info(
f"Live cycle: triggered={date}, trading_date={trading_date}",
)
await self.state_sync.on_cycle_start(trading_date)
self._dashboard.update(date=trading_date, status="Analyzing...")
market_caps = self._get_market_caps(tickers, trading_date)
# Run pipeline with async price callbacks
result = await self.pipeline.run_cycle(
tickers=tickers,
date=trading_date,
market_caps=market_caps,
get_open_prices_fn=self.market_service.wait_for_open_prices,
get_close_prices_fn=self.market_service.wait_for_close_prices,
)
close_prices = self.market_service.get_all_prices()
settlement_result = result.get("settlement_result")
self._save_cycle_results(
result,
trading_date,
close_prices,
settlement_result,
)
await self._broadcast_portfolio_updates(result, close_prices)
await self._finalize_cycle(trading_date)
async def _finalize_cycle(self, date: str):
"""Finalize cycle: broadcast state and update dashboard"""
summary = self.storage.load_file("summary") or {}
# Include live returns if session is active
if self.storage.is_live_session_active:
live_returns = self.storage.get_live_returns()
summary.update(live_returns)
await self.state_sync.on_cycle_end(date, portfolio_summary=summary)
holdings = self.storage.load_file("holdings") or []
trades = self.storage.load_file("trades") or []
leaderboard = self.storage.load_file("leaderboard") or []
if leaderboard:
await self.state_sync.on_leaderboard_update(leaderboard)
self._dashboard.update(
date=date,
status="Running",
portfolio=summary,
holdings=holdings,
trades=trades,
)
def _get_market_caps(
self,
tickers: List[str],
date: str,
) -> Dict[str, float]:
"""
Get market caps for tickers (stub implementation)
Args:
tickers: List of tickers
date: Trading date
Returns:
Dict mapping ticker to market cap
"""
from ..tools.data_tools import get_market_cap
market_caps = {}
for ticker in tickers:
try:
market_cap = get_market_cap(ticker, date)
if market_cap:
market_caps[ticker] = market_cap
else:
market_caps[ticker] = 1e9
except Exception:
market_caps[ticker] = 1e9
return market_caps
async def _broadcast_portfolio_updates(
self,
result: Dict[str, Any],
prices: Dict[str, float],
):
portfolio = result.get("portfolio", {})
if portfolio:
holdings = FrontendAdapter.build_holdings(portfolio, prices)
if holdings:
await self.state_sync.on_holdings_update(holdings)
stats = FrontendAdapter.build_stats(portfolio, prices)
if stats:
await self.state_sync.on_stats_update(stats)
executed_trades = result.get("executed_trades", [])
if executed_trades:
await self.state_sync.on_trades_executed(executed_trades)
def _save_cycle_results(
self,
result: Dict[str, Any],
date: str,
prices: Dict[str, float],
settlement_result: Optional[Dict[str, Any]] = None,
):
portfolio = result.get("portfolio", {})
executed_trades = result.get("executed_trades", [])
# Extract baseline values from settlement result
baseline_values = None
if settlement_result:
baseline_values = settlement_result.get("baseline_values")
if portfolio:
self.storage.update_dashboard_after_cycle(
portfolio=portfolio,
prices=prices,
date=date,
executed_trades=executed_trades,
baseline_values=baseline_values,
)
async def _run_backtest_dates(self, dates: List[str]):
self.state_sync.set_backtest_dates(dates)
self._dashboard.update(days_total=len(dates), days_completed=0)
await self.state_sync.on_system_message(
f"Starting backtest - {len(dates)} trading days",
)
try:
for i, date in enumerate(dates):
self._dashboard.update(days_completed=i)
await self.on_strategy_trigger(date=date)
await asyncio.sleep(0.1)
await self.state_sync.on_system_message(
f"Backtest complete - {len(dates)} days",
)
# Update dashboard with final state
summary = self.storage.load_file("summary") or {}
self._dashboard.update(
status="Complete",
portfolio=summary,
days_completed=len(dates),
)
self._dashboard.stop()
self._dashboard.print_final_summary()
except Exception as e:
error_msg = f"Backtest failed: {type(e).__name__}: {str(e)}"
logger.error(error_msg, exc_info=True)
await self.state_sync.on_system_message(error_msg)
self._dashboard.update(status=f"Failed: {str(e)}")
self._dashboard.stop()
raise
finally:
self._backtest_task = None
def _handle_backtest_exception(self, task: asyncio.Task):
"""Handle exceptions from backtest task"""
try:
task.result()
except asyncio.CancelledError:
logger.info("Backtest task was cancelled")
except Exception as e:
logger.error(
f"Backtest task failed with exception:{type(e).__name__}:{e}",
exc_info=True,
)
def set_backtest_dates(self, dates: List[str]):
self.state_sync.set_backtest_dates(dates)
if dates:
self._backtest_start_date = dates[0]
self._backtest_end_date = dates[-1]
self._dashboard.days_total = len(dates)
def stop(self):
self.state_sync.save_state()
self.market_service.stop()
if self._backtest_task:
self._backtest_task.cancel()
if self._market_status_task:
self._market_status_task.cancel()
self._dashboard.stop()

625
backend/services/market.py Normal file
View File

@@ -0,0 +1,625 @@
# -*- coding: utf-8 -*-
"""
Market Data Service
Supports live, mock, and backtest modes
"""
import asyncio
import logging
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional
from zoneinfo import ZoneInfo
import pandas_market_calendars as mcal
logger = logging.getLogger(__name__)
# NYSE timezone and calendar
NYSE_TZ = ZoneInfo("America/New_York")
NYSE_CALENDAR = mcal.get_calendar("NYSE")
class MarketStatus:
"""Market status enum-like class"""
OPEN = "open"
CLOSED = "closed"
PREMARKET = "premarket"
AFTERHOURS = "afterhours"
class MarketService:
"""Market data service for price management"""
def __init__(
self,
tickers: List[str],
poll_interval: int = 10,
mock_mode: bool = False,
backtest_mode: bool = False,
api_key: Optional[str] = None,
backtest_start_date: Optional[str] = None,
backtest_end_date: Optional[str] = None,
):
self.tickers = tickers
self.poll_interval = poll_interval
self.mock_mode = mock_mode
self.backtest_mode = backtest_mode
self.api_key = api_key
self.backtest_start_date = backtest_start_date
self.backtest_end_date = backtest_end_date
self.cache: Dict[str, Dict[str, Any]] = {}
self.running = False
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._broadcast_func: Optional[Callable] = None
self._price_manager: Optional[Any] = None
self._current_date: Optional[str] = None
# Market status tracking
self._last_market_status: Optional[str] = None
# Session tracking for live returns
self._session_start_values: Optional[Dict[str, float]] = None
self._session_start_timestamp: Optional[int] = None
@property
def mode_name(self) -> str:
if self.backtest_mode:
return "BACKTEST"
elif self.mock_mode:
return "MOCK"
return "LIVE"
async def start(self, broadcast_func: Callable):
"""Start market data service"""
if self.running:
return
self.running = True
self._loop = asyncio.get_running_loop()
self._broadcast_func = broadcast_func
if self.backtest_mode:
self._start_backtest_mode()
elif self.mock_mode:
self._start_mock_mode()
else:
self._start_real_mode()
logger.info(
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
)
def _make_price_callback(self) -> Callable:
"""Create thread-safe price callback"""
def callback(price_data: Dict[str, Any]):
symbol = price_data["symbol"]
self.cache[symbol] = price_data
loop = self._loop
if loop and loop.is_running() and self._broadcast_func:
asyncio.run_coroutine_threadsafe(
self._broadcast_price_update(price_data),
loop,
)
return callback
def _start_mock_mode(self):
from backend.data.mock_price_manager import MockPriceManager
self._price_manager = MockPriceManager(
poll_interval=self.poll_interval,
volatility=0.5,
)
self._price_manager.add_price_callback(self._make_price_callback())
self._price_manager.subscribe(
self.tickers,
base_prices={t: 100.0 for t in self.tickers},
)
self._price_manager.start()
def _start_real_mode(self):
from backend.data.polling_price_manager import PollingPriceManager
if not self.api_key:
raise ValueError("API key required for live mode")
self._price_manager = PollingPriceManager(
api_key=self.api_key,
poll_interval=self.poll_interval,
)
self._price_manager.add_price_callback(self._make_price_callback())
self._price_manager.subscribe(self.tickers)
self._price_manager.start()
def _start_backtest_mode(self):
from backend.data.historical_price_manager import (
HistoricalPriceManager,
)
self._price_manager = HistoricalPriceManager()
self._price_manager.add_price_callback(self._make_price_callback())
self._price_manager.subscribe(self.tickers)
if self.backtest_start_date and self.backtest_end_date:
self._price_manager.preload_data(
self.backtest_start_date,
self.backtest_end_date,
)
self._price_manager.start()
async def _broadcast_price_update(self, price_data: Dict[str, Any]):
"""Broadcast price update to frontend"""
if not self._broadcast_func:
return
symbol = price_data["symbol"]
price = price_data["price"]
open_price = price_data.get("open", price)
ret = (
((price - open_price) / open_price) * 100 if open_price > 0 else 0
)
await self._broadcast_func(
{
"type": "price_update",
"symbol": symbol,
"price": price,
"open": open_price,
"ret": ret,
"timestamp": price_data.get("timestamp"),
"realtime_prices": {
t: self._get_cached_price(t) for t in self.tickers
},
},
)
def _get_cached_price(self, ticker: str) -> Dict[str, Any]:
"""Get cached price data for a ticker"""
if ticker in self.cache:
return self.cache[ticker]
# Return from price manager if not in cache
if self._price_manager:
price = self._price_manager.get_latest_price(ticker)
if price:
return {"price": price, "symbol": ticker}
return {"price": 0, "symbol": ticker}
def stop(self):
"""Stop market service"""
if not self.running:
return
self.running = False
if self._price_manager:
self._price_manager.stop()
self._price_manager = None
self._loop = None
self._broadcast_func = None
# Backtest methods
def set_backtest_date(self, date: str):
"""Set current backtest date"""
if not self.backtest_mode or not self._price_manager:
return
self._current_date = date
self._price_manager.set_date(date)
logger.info(f"Backtest date: {date}")
async def emit_market_open(self):
"""Emit market open prices"""
if self.backtest_mode and self._price_manager:
self._price_manager.emit_open_prices()
# Log prices for debugging
prices = self.get_open_prices()
logger.info(f"Open prices: {prices}")
async def emit_market_close(self):
"""Emit market close prices"""
if self.backtest_mode and self._price_manager:
self._price_manager.emit_close_prices()
# Log prices for debugging
prices = self.get_close_prices()
logger.info(f"Close prices: {prices}")
def get_open_prices(self) -> Dict[str, float]:
"""Get open prices for all tickers"""
prices = {}
for ticker in self.tickers:
price = None
# Try price manager first
if self.backtest_mode and self._price_manager:
price = self._price_manager.get_open_price(ticker)
# Fallback to cache
if price is None or price <= 0:
cached = self.cache.get(ticker, {})
price = cached.get("open") or cached.get("price")
prices[ticker] = price if price and price > 0 else 0.0
return prices
def get_close_prices(self) -> Dict[str, float]:
"""Get close prices for all tickers"""
prices = {}
for ticker in self.tickers:
price = None
# Try price manager first
if self.backtest_mode and self._price_manager:
price = self._price_manager.get_close_price(ticker)
# Fallback to cache
if price is None or price <= 0:
cached = self.cache.get(ticker, {})
price = cached.get("close") or cached.get("price")
prices[ticker] = price if price and price > 0 else 0.0
return prices
def get_price_for_date(
self,
ticker: str,
date: str,
price_type: str = "close",
) -> Optional[float]:
"""Get price for a specific date"""
if self.backtest_mode and self._price_manager:
return self._price_manager.get_price_for_date(
ticker,
date,
price_type,
)
return self.get_price_sync(ticker)
# Common methods
def get_price_sync(self, ticker: str) -> Optional[float]:
"""Get latest price synchronously"""
# Try cache first
data = self.cache.get(ticker)
if data and data.get("price"):
return data["price"]
# Try price manager
if self._price_manager:
return self._price_manager.get_latest_price(ticker)
return None
def get_all_prices(self) -> Dict[str, float]:
"""Get all latest prices"""
prices = {}
for ticker in self.tickers:
price = self.get_price_sync(ticker)
prices[ticker] = price if price and price > 0 else 0.0
return prices
# Live mode async waiting methods
def _now_nyse(self) -> datetime:
"""Get current time in NYSE timezone"""
return datetime.now(NYSE_TZ)
def _is_trading_day(self, date: datetime) -> bool:
"""Check if date is a NYSE trading day"""
date_str = date.strftime("%Y-%m-%d")
valid_days = NYSE_CALENDAR.valid_days(
start_date=date_str,
end_date=date_str,
)
return len(valid_days) > 0
def _get_market_hours(self, date: datetime) -> tuple:
"""Get market open and close times for a given date"""
date_str = date.strftime("%Y-%m-%d")
schedule = NYSE_CALENDAR.schedule(
start_date=date_str,
end_date=date_str,
)
if schedule.empty:
return None, None
market_open = schedule.iloc[0]["market_open"].to_pydatetime()
market_close = schedule.iloc[0]["market_close"].to_pydatetime()
return market_open, market_close
def _next_trading_day(self, from_date: datetime) -> datetime:
"""Find the next trading day from given date"""
check_date = from_date + timedelta(days=1)
for _ in range(10): # Max 10 days ahead (handles holidays)
if self._is_trading_day(check_date):
return check_date
check_date += timedelta(days=1)
return check_date
def _get_trading_date_for_execution(self) -> tuple:
"""
Determine the trading date for execution.
Returns:
(trading_date, market_open_time, market_close_time)
Logic:
- If today is a trading day and market has opened: use today
- If today is a trading day but market hasn't opened: wait for open
- If today is not a trading day: use next trading day
"""
now = self._now_nyse()
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
if self._is_trading_day(today):
market_open, market_close = self._get_market_hours(today)
return today, market_open, market_close
else:
# Weekend or holiday - find next trading day
next_day = self._next_trading_day(today)
market_open, market_close = self._get_market_hours(next_day)
return next_day, market_open, market_close
async def wait_for_open_prices(self) -> Dict[str, float]:
"""
Wait for market open and return open prices.
Behavior:
- If market is already open today: return current prices immediately
- If market hasn't opened yet today: wait until open
- If not a trading day: wait until next trading day opens
"""
now = self._now_nyse()
trading_date, market_open, _ = self._get_trading_date_for_execution()
if market_open is None:
logger.warning("Could not determine market hours")
return self.get_all_prices()
trading_date_str = trading_date.strftime("%Y-%m-%d")
# Check if we need to wait
if now < market_open:
wait_seconds = (market_open - now).total_seconds()
logger.info(
f"Waiting {wait_seconds/60:.1f} min for market open "
f"({trading_date_str} {market_open.strftime('%H:%M')} ET)",
)
await asyncio.sleep(wait_seconds)
# Small delay to ensure prices are available
await asyncio.sleep(5)
else:
logger.info(
f"Market already open for {trading_date_str}, "
f"getting current prices",
)
# Poll until we have valid prices
prices = await self._poll_for_prices()
logger.info(f"Got open prices for {trading_date_str}: {prices}")
return prices
async def wait_for_close_prices(self) -> Dict[str, float]:
"""
Wait for market close and return close prices.
Behavior:
- If market is already closed today: return current prices immediately
- If market hasn't closed yet: wait until close
"""
now = self._now_nyse()
trading_date, _, market_close = self._get_trading_date_for_execution()
if market_close is None:
logger.warning("Could not determine market hours")
return self.get_all_prices()
trading_date_str = trading_date.strftime("%Y-%m-%d")
# Check if we need to wait
if now < market_close:
wait_seconds = (market_close - now).total_seconds()
logger.info(
f"Waiting {wait_seconds/60:.1f} min for market close "
f"({trading_date_str} {market_close.strftime('%H:%M')} ET)",
)
await asyncio.sleep(wait_seconds)
# Small delay to ensure final prices settle
await asyncio.sleep(10)
else:
logger.info(
f"Market already closed for {trading_date_str}, "
f"getting close prices",
)
# Get final prices
prices = await self._poll_for_prices()
logger.info(f"Got close prices for {trading_date_str}: {prices}")
return prices
def get_live_trading_date(self) -> str:
"""Get the trading date that will be used for live execution"""
trading_date, _, _ = self._get_trading_date_for_execution()
return trading_date.strftime("%Y-%m-%d")
async def _poll_for_prices(
self,
max_retries: int = 12,
) -> Dict[str, float]:
"""Poll until all prices are available"""
for _ in range(max_retries):
prices = self.get_all_prices()
if all(p > 0 for p in prices.values()):
return prices
logger.debug("Waiting for prices to be available...")
await asyncio.sleep(5)
# Return whatever we have
return self.get_all_prices()
# ========== Market Status Methods ==========
def get_market_status(self) -> Dict[str, Any]:
"""
Get current market status
Returns:
Dict with status info:
- status: 'open' | 'closed' | 'premarket' | 'afterhours'
- status_text: Human readable status
- is_trading_day: Whether today is a trading day
- market_open: Market open time (if trading day)
- market_close: Market close time (if trading day)
"""
if self.backtest_mode:
# In backtest mode, always return open
return {
"status": MarketStatus.OPEN,
"status_text": "Backtest Mode",
"is_trading_day": True,
}
now = self._now_nyse()
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
is_trading = self._is_trading_day(today)
if not is_trading:
return {
"status": MarketStatus.CLOSED,
"status_text": "Market Closed (Non-trading Day)",
"is_trading_day": False,
}
market_open, market_close = self._get_market_hours(today)
if market_open is None or market_close is None:
return {
"status": MarketStatus.CLOSED,
"status_text": "Market Closed",
"is_trading_day": is_trading,
}
# Determine status based on current time
if now < market_open:
return {
"status": MarketStatus.PREMARKET,
"status_text": "Pre-Market",
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
}
elif now > market_close:
return {
"status": MarketStatus.CLOSED,
"status_text": "Market Closed",
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
}
else:
return {
"status": MarketStatus.OPEN,
"status_text": "Market Open",
"is_trading_day": True,
"market_open": market_open.isoformat(),
"market_close": market_close.isoformat(),
}
async def check_and_broadcast_market_status(self):
"""Check market status and broadcast if changed"""
status = self.get_market_status()
current_status = status["status"]
if current_status != self._last_market_status:
self._last_market_status = current_status
await self._broadcast_market_status(status)
# Handle session transitions
if current_status == MarketStatus.OPEN:
await self._on_session_start()
elif (
current_status == MarketStatus.CLOSED
and self._session_start_values is not None
):
self._on_session_end()
async def _broadcast_market_status(self, status: Dict[str, Any]):
"""Broadcast market status update to frontend"""
if not self._broadcast_func:
return
await self._broadcast_func(
{
"type": "market_status_update",
"market_status": status,
"timestamp": datetime.now(NYSE_TZ).isoformat(),
},
)
logger.info(f"Market status: {status['status_text']}")
async def _on_session_start(self):
"""Called when market session starts - capture baseline values"""
# Wait briefly for prices to be available
await asyncio.sleep(2)
prices = self.get_all_prices()
if prices and any(p > 0 for p in prices.values()):
self._session_start_values = prices.copy()
self._session_start_timestamp = int(
datetime.now().timestamp() * 1000,
)
logger.info(f"Session started with prices: {prices}")
def _on_session_end(self):
"""Called when market session ends - clear session data"""
self._session_start_values = None
self._session_start_timestamp = None
logger.info("Session ended, cleared session data")
def get_session_returns(
self,
current_prices: Dict[str, float],
portfolio_value: Optional[float] = None,
session_start_portfolio_value: Optional[float] = None,
) -> Optional[Dict[str, Any]]:
"""
Calculate session returns (from session start to now)
Args:
current_prices: Current prices for tickers
portfolio_value: Current portfolio value (optional)
session_start_portfolio_value:
Returns:
Dict with return data or None if session not started
"""
if self._session_start_values is None:
return None
timestamp = int(datetime.now().timestamp() * 1000)
returns = {}
# Calculate individual ticker returns
for ticker, start_price in self._session_start_values.items():
current = current_prices.get(ticker)
if current and start_price and start_price > 0:
ret = ((current - start_price) / start_price) * 100
returns[ticker] = round(ret, 4)
result = {
"timestamp": timestamp,
"ticker_returns": returns,
}
# Calculate portfolio return if values provided
if (
portfolio_value is not None
and session_start_portfolio_value is not None
):
if session_start_portfolio_value > 0:
portfolio_ret = (
(portfolio_value - session_start_portfolio_value)
/ session_start_portfolio_value
) * 100
result["portfolio_return"] = round(portfolio_ret, 4)
return result
@property
def session_start_values(self) -> Optional[Dict[str, float]]:
"""Get session start values for external use"""
return self._session_start_values
@property
def session_start_timestamp(self) -> Optional[int]:
"""Get session start timestamp"""
return self._session_start_timestamp

1099
backend/services/storage.py Normal file

File diff suppressed because it is too large Load Diff