254 lines
8.5 KiB
Python
254 lines
8.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
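"""
Historical Price Manager for backtest mode.

Serves open/close prices from cached historical price frames so that a
backtest can step through trading dates one at a time.
"""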
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Callable, Dict, List, Optional
|
|
|
|
import pandas as pd
|
|
from backend.data.market_store import MarketStore
|
|
from backend.data.provider_utils import normalize_symbol
|
|
from backend.data.provider_router import get_provider_router
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class HistoricalPriceManager:
    """Provides historical prices for backtest mode.

    Typical flow: ``subscribe`` to symbols, ``preload_data`` to fill the
    per-symbol frame cache, then for every simulated day call ``set_date``
    followed by ``emit_open_prices`` / ``emit_close_prices`` to push price
    events to the registered callbacks.
    """

    # Milliseconds added to the day's base timestamp for close-price events
    # (6.5 hours; presumably the length of a regular session - confirm).
    _CLOSE_OFFSET_MS = 23400000

    def __init__(self):
        self.subscribed_symbols: List[str] = []
        self.price_callbacks: List[Callable] = []
        # symbol -> date-indexed frame with at least "open"/"close" columns
        self._price_cache: Dict[str, pd.DataFrame] = {}
        self._current_date: Optional[str] = None
        self.latest_prices: Dict[str, float] = {}
        self.open_prices: Dict[str, float] = {}
        self.close_prices: Dict[str, float] = {}
        self.running = False
        self._router = get_provider_router()
        self._market_store = MarketStore()

    def subscribe(
        self,
        symbols: List[str],
    ):
        """Subscribe to symbols.

        Each symbol is normalized first; already-subscribed symbols are
        ignored so the subscription list stays free of duplicates.
        """
        for symbol in symbols:
            symbol = normalize_symbol(symbol)
            if symbol not in self.subscribed_symbols:
                self.subscribed_symbols.append(symbol)

    def unsubscribe(self, symbols: List[str]):
        """Unsubscribe from symbols and drop their cached price frames."""
        for symbol in symbols:
            symbol = normalize_symbol(symbol)
            if symbol in self.subscribed_symbols:
                self.subscribed_symbols.remove(symbol)
                self._price_cache.pop(symbol, None)

    def add_price_callback(self, callback: Callable):
        """Add price update callback.

        The callback is invoked with a single ``dict`` argument (see
        ``_emit_price`` for its keys).
        """
        self.price_callbacks.append(callback)

    def _load_from_csv(self, symbol: str) -> Optional[pd.DataFrame]:
        """Load price data from local CSV file.

        Returns ``None`` when the router yields no frame, an empty frame, or
        raises; failures are logged as warnings rather than propagated.
        """
        try:
            df = self._router.load_local_price_frame(symbol)
            # A missing frame is "no data", not a load failure.
            if df is None or df.empty:
                return None
            return df
        except Exception as e:
            logger.warning(f"Failed to load CSV for {symbol}: {e}")
            return None

    def _load_from_market_db(
        self,
        symbol: str,
        start_date: str,
        end_date: str,
    ) -> Optional[pd.DataFrame]:
        """Load price data from the long-lived market research database.

        The rows are turned into a frame indexed by a ``Date`` datetime
        index (built from the ``date`` column) and sorted by that index.
        Returns ``None`` when there are no rows, no ``date`` column, or the
        lookup raises.
        """
        try:
            rows = self._market_store.get_ohlc(symbol, start_date, end_date)
            if not rows:
                return None
            df = pd.DataFrame(rows)
            if df.empty or "date" not in df.columns:
                return None
            df["Date"] = pd.to_datetime(df["date"])
            df.set_index("Date", inplace=True)
            df.sort_index(inplace=True)
            return df
        except Exception as e:
            logger.warning(f"Failed to load market DB data for {symbol}: {e}")
            return None

    def preload_data(self, start_date: str, end_date: str):
        """Preload historical data from market DB first, then local CSV.

        Symbols that already have a cached frame are skipped. The date range
        is only passed to the market DB lookup; the CSV fallback loads
        whatever the router returns.
        """
        logger.info(f"Preloading data: {start_date} to {end_date}")

        for symbol in self.subscribed_symbols:
            if symbol in self._price_cache:
                continue

            df = self._load_from_market_db(symbol, start_date, end_date)
            if df is not None and not df.empty:
                self._price_cache[symbol] = df
                logger.info(f"Loaded {symbol} from market DB: {len(df)} records")
                continue

            df = self._load_from_csv(symbol)
            if df is not None and not df.empty:
                self._price_cache[symbol] = df
                logger.info(f"Loaded {symbol} from CSV: {len(df)} records")
            else:
                logger.warning(f"No market DB or CSV data for {symbol}")

    @staticmethod
    def _is_valid_price(price) -> bool:
        """Return True for a strictly positive price.

        ``None`` is rejected explicitly; NaN is rejected because ``nan > 0``
        evaluates to False.
        """
        return price is not None and price > 0

    @staticmethod
    def _row_on_or_before(df: pd.DataFrame, date_dt: pd.Timestamp):
        """Return the row for ``date_dt`` or the latest earlier row.

        Returns ``None`` when the frame has no entry on or before the date.
        """
        if date_dt in df.index:
            key = date_dt
        else:
            earlier = df.index[df.index <= date_dt]
            if len(earlier) == 0:
                return None
            # max() instead of [-1]: CSV-backed frames are not sorted by this
            # class, so the last position is not necessarily the latest date.
            key = earlier.max()

        row = df.loc[key]
        if isinstance(row, pd.DataFrame):
            # Duplicate index labels yield a frame; use the last record.
            row = row.iloc[-1]
        return row

    def set_date(self, date: str):
        """Set current trading date and update prices.

        For every subscribed symbol the row for ``date`` (or the closest
        earlier date) supplies the open and close price; the latest price is
        set to the open. Symbols without usable data keep their previous
        prices.
        """
        self._current_date = date
        date_dt = pd.Timestamp(date)

        for symbol in self.subscribed_symbols:
            df = self._price_cache.get(symbol)
            if df is None or df.empty:
                # Keep previous prices if no data available
                logger.warning(f"No cached data for {symbol} on {date}")
                continue

            row = self._row_on_or_before(df, date_dt)
            if row is None:
                logger.warning(f"No data for {symbol} on or before {date}")
                continue

            raw_open = row["open"]
            raw_close = row["close"]
            if pd.isna(raw_open) or pd.isna(raw_close):
                # A missing value must not overwrite the last good prices.
                logger.warning(
                    f"Missing open/close for {symbol} on or before {date}",
                )
                continue

            open_price = float(raw_open)
            close_price = float(raw_close)

            self.open_prices[symbol] = open_price
            self.close_prices[symbol] = close_price
            self.latest_prices[symbol] = open_price

            # Lazy %-formatting: runs per symbol per simulated day.
            logger.debug(
                "%s @ %s: open=%.2f, close=%.2f",
                symbol,
                date,
                open_price,
                close_price,
            )

    def _emit_prices(
        self,
        prices: Dict[str, float],
        label: str,
        offset_ms: int = 0,
    ):
        """Emit one price per subscribed symbol from ``prices``.

        ``label`` ("open"/"close") is only used in the warning for invalid
        prices; ``offset_ms`` is added to the current date's base timestamp.
        Does nothing until a date has been set.
        """
        if not self._current_date:
            return

        # Naive datetime: the epoch value depends on the local timezone.
        timestamp = int(
            datetime.strptime(self._current_date, "%Y-%m-%d").timestamp()
            * 1000,
        )
        timestamp += offset_ms

        for symbol in self.subscribed_symbols:
            price = prices.get(symbol)
            if not self._is_valid_price(price):
                logger.warning(f"Invalid {label} price for {symbol}: {price}")
                continue

            self.latest_prices[symbol] = price
            self._emit_price(symbol, price, timestamp)

    def emit_open_prices(self):
        """Emit open prices to callbacks"""
        self._emit_prices(self.open_prices, "open")

    def emit_close_prices(self):
        """Emit close prices to callbacks"""
        self._emit_prices(self.close_prices, "close", self._CLOSE_OFFSET_MS)

    def _emit_price(self, symbol: str, price: float, timestamp: int):
        """Emit single price to callbacks.

        The payload has the keys ``symbol``, ``price``, ``timestamp``,
        ``open``, ``close``, ``high``, ``low`` and ``ret``. ``high``/``low``
        are the max/min of open and close, and ``ret`` is the percentage
        change of ``price`` against the open. A failing callback is logged
        and does not stop the remaining callbacks.
        """
        open_price = self.open_prices.get(symbol, price)
        close_price = self.close_prices.get(symbol, price)
        ret = (
            ((price - open_price) / open_price) * 100 if open_price > 0 else 0
        )

        price_data = {
            "symbol": symbol,
            "price": price,
            "timestamp": timestamp,
            "open": open_price,
            "close": close_price,
            "high": max(open_price, close_price),
            "low": min(open_price, close_price),
            "ret": ret,
        }

        for callback in self.price_callbacks:
            try:
                callback(price_data)
            except Exception as e:
                # logger.exception keeps the traceback for debugging.
                logger.exception(f"Callback error for {symbol}: {e}")

    def get_price_for_date(
        self,
        symbol: str,
        date: str,
        price_type: str = "close",
    ) -> Optional[float]:
        """Get price for a specific date.

        Uses the cached row for ``date`` or the closest earlier date. Falls
        back to the symbol's latest price (possibly ``None``) when nothing
        is cached or no row exists on or before the date.

        Raises:
            KeyError: if ``price_type`` is not a column of the cached frame.
        """
        df = self._price_cache.get(symbol)
        if df is None or df.empty:
            return self.latest_prices.get(symbol)

        row = self._row_on_or_before(df, pd.Timestamp(date))
        if row is None:
            return self.latest_prices.get(symbol)
        return float(row[price_type])

    def start(self):
        """Start manager"""
        self.running = True

    def stop(self):
        """Stop manager"""
        self.running = False

    def get_latest_price(self, symbol: str) -> Optional[float]:
        """Return the most recent price for ``symbol``, or ``None``."""
        return self.latest_prices.get(symbol)

    def get_all_latest_prices(self) -> Dict[str, float]:
        """Return a shallow copy of the latest-price mapping."""
        return self.latest_prices.copy()

    def get_open_price(self, symbol: str) -> Optional[float]:
        """Return open price, falling back to latest if unset or invalid."""
        price = self.open_prices.get(symbol)
        if not self._is_valid_price(price):
            return self.latest_prices.get(symbol)
        return price

    def get_close_price(self, symbol: str) -> Optional[float]:
        """Return close price, falling back to latest if unset or invalid."""
        price = self.close_prices.get(symbol)
        if not self._is_valid_price(price):
            return self.latest_prices.get(symbol)
        return price

    def reset_open_prices(self):
        """Intentional no-op: prices are kept for continuity."""
        # Don't clear prices - keep them for continuity
|