Fix runtime logging and frontend app regressions
This commit is contained in:
@@ -38,12 +38,13 @@ class RuntimeState:
|
||||
"""
|
||||
|
||||
_instance: Optional["RuntimeState"] = None
|
||||
_lock: asyncio.Lock = asyncio.Lock()
|
||||
_lock: "threading.Lock" = __import__("threading").Lock()
|
||||
|
||||
def __new__(cls) -> "RuntimeState":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -207,6 +208,13 @@ class RuntimeConfigResponse(BaseModel):
|
||||
resolved: Dict[str, Any]
|
||||
|
||||
|
||||
class RuntimeLogResponse(BaseModel):
|
||||
run_id: Optional[str] = None
|
||||
is_running: bool
|
||||
log_path: Optional[str] = None
|
||||
content: str = ""
|
||||
|
||||
|
||||
class UpdateRuntimeConfigRequest(BaseModel):
|
||||
schedule_mode: Optional[str] = None
|
||||
interval_minutes: Optional[int] = Field(default=None, ge=1)
|
||||
@@ -288,14 +296,20 @@ def _start_gateway_process(
|
||||
"--bootstrap", json.dumps(bootstrap)
|
||||
]
|
||||
|
||||
# Start process
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
cwd=PROJECT_ROOT
|
||||
)
|
||||
log_path = run_dir / "logs" / "gateway.log"
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
log_file = log_path.open("ab")
|
||||
try:
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
env=env,
|
||||
stdout=log_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=PROJECT_ROOT
|
||||
)
|
||||
finally:
|
||||
log_file.close()
|
||||
|
||||
return process
|
||||
|
||||
@@ -390,6 +404,26 @@ async def get_gateway_port(request: Request) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
@router.get("/logs", response_model=RuntimeLogResponse)
|
||||
async def get_runtime_logs() -> RuntimeLogResponse:
|
||||
"""Return current runtime log tail, or the latest run log if runtime is stopped."""
|
||||
try:
|
||||
context = _get_runtime_context_from_latest_snapshot()
|
||||
except HTTPException:
|
||||
return RuntimeLogResponse(is_running=False, content="")
|
||||
|
||||
run_id = str(context.get("config_name") or "").strip() or None
|
||||
log_path = _get_gateway_log_path_for_run(run_id) if run_id else None
|
||||
content = _read_log_tail(log_path) if log_path else ""
|
||||
|
||||
return RuntimeLogResponse(
|
||||
run_id=run_id,
|
||||
is_running=_is_gateway_running(),
|
||||
log_path=str(log_path) if log_path else None,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
def _build_gateway_ws_url(request: Request, port: int) -> str:
|
||||
"""Build a proxy-safe Gateway WebSocket URL."""
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "").split(",")[0].strip()
|
||||
@@ -416,10 +450,8 @@ def _load_latest_runtime_snapshot() -> Dict[str, Any]:
|
||||
return json.loads(snapshots[0].read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context from the latest snapshot."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
def _get_runtime_context_from_latest_snapshot() -> Dict[str, Any]:
|
||||
"""Return the latest persisted runtime context regardless of active process state."""
|
||||
latest = _load_latest_runtime_snapshot()
|
||||
context = latest.get("context") or {}
|
||||
if not context.get("config_name"):
|
||||
@@ -427,6 +459,26 @@ def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
return context
|
||||
|
||||
|
||||
def _get_gateway_log_path_for_run(run_id: str) -> Path:
|
||||
return _get_run_dir(run_id) / "logs" / "gateway.log"
|
||||
|
||||
|
||||
def _read_log_tail(path: Path, max_chars: int = 120_000) -> str:
|
||||
if not path.exists() or not path.is_file():
|
||||
return ""
|
||||
text = path.read_text(encoding="utf-8", errors="replace")
|
||||
if len(text) <= max_chars:
|
||||
return text
|
||||
return text[-max_chars:]
|
||||
|
||||
|
||||
def _get_current_runtime_context() -> Dict[str, Any]:
|
||||
"""Return the active runtime context from the latest snapshot."""
|
||||
if not _is_gateway_running():
|
||||
raise HTTPException(status_code=404, detail="No runtime is currently running")
|
||||
return _get_runtime_context_from_latest_snapshot()
|
||||
|
||||
|
||||
def _resolve_runtime_response(run_id: str) -> RuntimeConfigResponse:
|
||||
"""Build a normalized runtime config response for the active run."""
|
||||
context = _get_current_runtime_context()
|
||||
@@ -567,11 +619,12 @@ async def start_runtime(
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if not _is_gateway_running():
|
||||
stdout, stderr = process.communicate(timeout=1)
|
||||
_runtime_state.gateway_process = None
|
||||
log_path = _get_gateway_log_path_for_run(run_id)
|
||||
log_tail = _read_log_tail(log_path, max_chars=4000)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Gateway failed to start: {stderr.decode() if stderr else 'Unknown error'}"
|
||||
detail=f"Gateway failed to start: {log_tail or 'Unknown error'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
30
backend/apps/cors.py
Normal file
30
backend/apps/cors.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Shared CORS configuration for all microservice apps."""
|
||||
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
def get_cors_origins() -> Sequence[str]:
|
||||
"""Get allowed CORS origins from environment variable.
|
||||
|
||||
Defaults to ["*"] for backward compatibility.
|
||||
Set CORS_ALLOWED_ORIGINS env var (comma-separated) in production.
|
||||
"""
|
||||
origins = os.getenv("CORS_ALLOWED_ORIGINS", "").strip()
|
||||
if not origins:
|
||||
return ["*"]
|
||||
return [o.strip() for o in origins.split(",") if o.strip()]
|
||||
|
||||
|
||||
def add_cors_middleware(app: "FastAPI") -> None:
|
||||
"""Add CORS middleware to app with environment-configured origins."""
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=get_cors_origins(),
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
@@ -76,27 +76,19 @@ def _resolve_config() -> DataSourceConfig:
|
||||
"""
|
||||
Resolve data source configuration based on available API keys.
|
||||
|
||||
Priority:
|
||||
1. FINNHUB_API_KEY (if set)
|
||||
2. FINANCIAL_DATASETS_API_KEY (if set)
|
||||
3. Raises error if neither is available
|
||||
The effective source should always match the first item in the resolved
|
||||
ordered source list.
|
||||
"""
|
||||
sources = _ordered_sources()
|
||||
if "finnhub" in sources:
|
||||
return DataSourceConfig(
|
||||
source="finnhub",
|
||||
api_key=os.getenv("FINNHUB_API_KEY", "").strip(),
|
||||
sources=sources,
|
||||
)
|
||||
if "financial_datasets" in sources:
|
||||
return DataSourceConfig(
|
||||
source="financial_datasets",
|
||||
api_key=os.getenv("FINANCIAL_DATASETS_API_KEY", "").strip(),
|
||||
sources=sources,
|
||||
)
|
||||
if "yfinance" in sources:
|
||||
return DataSourceConfig(source="yfinance", api_key="", sources=sources)
|
||||
return DataSourceConfig(source="local_csv", api_key="", sources=sources)
|
||||
source = sources[0] if sources else "local_csv"
|
||||
|
||||
api_key = ""
|
||||
if source == "finnhub":
|
||||
api_key = os.getenv("FINNHUB_API_KEY", "").strip()
|
||||
elif source == "financial_datasets":
|
||||
api_key = os.getenv("FINANCIAL_DATASETS_API_KEY", "").strip()
|
||||
|
||||
return DataSourceConfig(source=source, api_key=api_key, sources=sources)
|
||||
|
||||
|
||||
def get_config() -> DataSourceConfig:
|
||||
|
||||
@@ -15,6 +15,9 @@ from backend.data.provider_utils import normalize_symbol
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_SUPPRESSED_LOG_EVERY = 20
|
||||
|
||||
|
||||
class PollingPriceManager:
|
||||
"""Polling-based price manager using Finnhub or yfinance."""
|
||||
|
||||
@@ -43,6 +46,7 @@ class PollingPriceManager:
|
||||
self.latest_prices: Dict[str, float] = {}
|
||||
self.open_prices: Dict[str, float] = {}
|
||||
self.price_callbacks: List[Callable] = []
|
||||
self._failure_counts: Dict[str, int] = {}
|
||||
|
||||
self.running = False
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
@@ -77,6 +81,8 @@ class PollingPriceManager:
|
||||
for symbol in self.subscribed_symbols:
|
||||
try:
|
||||
quote_data = self._fetch_quote(symbol)
|
||||
if not isinstance(quote_data, dict):
|
||||
raise ValueError(f"{symbol}: Empty quote payload")
|
||||
|
||||
current_price = quote_data.get("c")
|
||||
open_price = quote_data.get("o")
|
||||
@@ -103,6 +109,13 @@ class PollingPriceManager:
|
||||
)
|
||||
|
||||
self.latest_prices[symbol] = current_price
|
||||
previous_failures = self._failure_counts.pop(symbol, 0)
|
||||
if previous_failures > 0:
|
||||
logger.info(
|
||||
"%s quote polling recovered after %d consecutive failures",
|
||||
symbol,
|
||||
previous_failures,
|
||||
)
|
||||
|
||||
price_data = {
|
||||
"symbol": symbol,
|
||||
@@ -128,7 +141,20 @@ class PollingPriceManager:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch {symbol} price: {e}")
|
||||
failure_count = self._failure_counts.get(symbol, 0) + 1
|
||||
self._failure_counts[symbol] = failure_count
|
||||
message = f"Failed to fetch {symbol} price: {e}"
|
||||
|
||||
if failure_count == 1:
|
||||
logger.warning(message)
|
||||
elif failure_count % _SUPPRESSED_LOG_EVERY == 0:
|
||||
logger.warning(
|
||||
"%s (repeated %d times; suppressing intermediate failures)",
|
||||
message,
|
||||
failure_count,
|
||||
)
|
||||
else:
|
||||
logger.debug(message)
|
||||
|
||||
def _fetch_quote(self, symbol: str) -> Dict[str, float]:
|
||||
"""Fetch a normalized quote payload from the configured provider."""
|
||||
@@ -136,7 +162,10 @@ class PollingPriceManager:
|
||||
return self._fetch_yfinance_quote(symbol)
|
||||
if not self.finnhub_client:
|
||||
raise ValueError("Finnhub API key required for finnhub polling")
|
||||
return self.finnhub_client.quote(symbol)
|
||||
quote = self.finnhub_client.quote(symbol)
|
||||
if not isinstance(quote, dict):
|
||||
raise ValueError(f"{symbol}: Invalid Finnhub quote payload")
|
||||
return quote
|
||||
|
||||
def _fetch_yfinance_quote(self, symbol: str) -> Dict[str, float]:
|
||||
"""Fetch quote data from yfinance and normalize to Finnhub-like keys."""
|
||||
@@ -162,6 +191,8 @@ class PollingPriceManager:
|
||||
|
||||
if current_price is None:
|
||||
history = ticker.history(period="1d", interval="1m", auto_adjust=False)
|
||||
if history is None:
|
||||
raise ValueError(f"{symbol}: yfinance returned no history frame")
|
||||
if history.empty:
|
||||
raise ValueError(f"{symbol}: No yfinance quote data")
|
||||
latest = history.iloc[-1]
|
||||
|
||||
@@ -43,6 +43,71 @@ logger = logging.getLogger(__name__)
|
||||
_prompt_loader = get_prompt_loader()
|
||||
|
||||
|
||||
INFO_LOGGER_PREFIXES = (
|
||||
"backend.agents",
|
||||
"backend.core.pipeline",
|
||||
"backend.core.scheduler",
|
||||
"backend.services.gateway_cycle_support",
|
||||
"backend.utils.terminal_dashboard",
|
||||
)
|
||||
|
||||
NOISY_LOGGER_LEVELS = {
|
||||
"aiohttp": logging.WARNING,
|
||||
"asyncio": logging.WARNING,
|
||||
"dashscope": logging.WARNING,
|
||||
"finnhub": logging.WARNING,
|
||||
"httpcore": logging.WARNING,
|
||||
"httpx": logging.WARNING,
|
||||
"urllib3": logging.WARNING,
|
||||
"websockets": logging.WARNING,
|
||||
"yfinance": logging.WARNING,
|
||||
"backend.data.polling_price_manager": logging.WARNING,
|
||||
"backend.services.gateway": logging.WARNING,
|
||||
"backend.services.market": logging.WARNING,
|
||||
"backend.services.storage": logging.WARNING,
|
||||
}
|
||||
|
||||
|
||||
class SuppressNoisyInfoFilter(logging.Filter):
|
||||
"""Filter out low-signal library INFO logs while keeping warnings/errors."""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
if record.levelno >= logging.WARNING:
|
||||
return True
|
||||
|
||||
message = record.getMessage()
|
||||
if record.name == "httpx" and message.startswith("HTTP Request:"):
|
||||
return False
|
||||
if record.name.startswith("websockets") and "connection open" in message:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def configure_gateway_logging(verbose: bool = False) -> None:
|
||||
"""Configure gateway logging with low-noise defaults for runtime logs."""
|
||||
root_level = logging.DEBUG if verbose else logging.WARNING
|
||||
logging.basicConfig(
|
||||
level=root_level,
|
||||
format="%(asctime)s | %(levelname)-7s | %(name)s:%(lineno)d - %(message)s",
|
||||
force=True,
|
||||
)
|
||||
|
||||
if not verbose:
|
||||
suppress_filter = SuppressNoisyInfoFilter()
|
||||
for handler in logging.getLogger().handlers:
|
||||
handler.addFilter(suppress_filter)
|
||||
|
||||
for logger_name, level in NOISY_LOGGER_LEVELS.items():
|
||||
logging.getLogger(logger_name).setLevel(logging.DEBUG if verbose else level)
|
||||
|
||||
if not verbose:
|
||||
for prefix in INFO_LOGGER_PREFIXES:
|
||||
logging.getLogger(prefix).setLevel(logging.INFO)
|
||||
|
||||
logging.getLogger(__name__).setLevel(logging.INFO if not verbose else logging.DEBUG)
|
||||
|
||||
|
||||
async def run_gateway(
|
||||
run_id: str,
|
||||
run_dir: Path,
|
||||
@@ -222,11 +287,7 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
# Setup logging
|
||||
level = logging.DEBUG if args.verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s | %(levelname)-7s | %(name)s:%(lineno)d - %(message)s",
|
||||
)
|
||||
configure_gateway_logging(verbose=args.verbose)
|
||||
|
||||
# Parse bootstrap
|
||||
bootstrap = json.loads(args.bootstrap)
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
AgentScope Native Model Factory
|
||||
Uses native AgentScope model classes for LLM calls
|
||||
"""
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
@@ -34,6 +36,27 @@ logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _usage_value(usage: Any, key: str, default: Any = 0) -> Any:
|
||||
"""Read usage fields from both object-style and dict-style usage payloads."""
|
||||
if usage is None:
|
||||
return default
|
||||
if isinstance(usage, dict):
|
||||
return usage.get(key, default)
|
||||
try:
|
||||
return getattr(usage, key)
|
||||
except (AttributeError, KeyError):
|
||||
return default
|
||||
|
||||
|
||||
def _usage_total_tokens(usage: Any) -> int:
|
||||
total = _usage_value(usage, "total_tokens", None)
|
||||
if total is not None:
|
||||
return int(total or 0)
|
||||
input_tokens = _usage_value(usage, "input_tokens", 0)
|
||||
output_tokens = _usage_value(usage, "output_tokens", 0)
|
||||
return int((input_tokens or 0) + (output_tokens or 0))
|
||||
|
||||
|
||||
class RetryChatModel:
|
||||
"""Wraps an AgentScope model with automatic retry for transient errors.
|
||||
|
||||
@@ -55,6 +78,7 @@ class RetryChatModel:
|
||||
"502",
|
||||
"504",
|
||||
"connection",
|
||||
"disconnected",
|
||||
"temporary",
|
||||
"overloaded",
|
||||
"too_many_requests",
|
||||
@@ -150,8 +174,8 @@ class RetryChatModel:
|
||||
# Track usage if available
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
usage = result.usage
|
||||
self._total_tokens_used += getattr(usage, "total_tokens", 0)
|
||||
self._total_cost += getattr(usage, "cost", 0.0)
|
||||
self._total_tokens_used += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
return result
|
||||
|
||||
@@ -192,9 +216,66 @@ class RetryChatModel:
|
||||
raise last_error
|
||||
raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")
|
||||
|
||||
async def _call_with_retry_async(self, func: Callable[..., T], *args, **kwargs) -> T:
|
||||
"""Call an async function with retry logic for transient errors."""
|
||||
last_error: Optional[Exception] = None
|
||||
|
||||
for attempt in range(1, self._max_retries + 1):
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
if hasattr(result, "usage") and result.usage:
|
||||
usage = result.usage
|
||||
self._total_tokens_used += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
|
||||
if attempt >= self._max_retries:
|
||||
logger.error(
|
||||
"RetryChatModel: Max retries (%d) exhausted for %s",
|
||||
self._max_retries,
|
||||
self.model_name,
|
||||
)
|
||||
break
|
||||
|
||||
if not self._is_transient_error(e):
|
||||
logger.warning(
|
||||
"RetryChatModel: Non-transient error, not retrying: %s",
|
||||
str(e),
|
||||
)
|
||||
break
|
||||
|
||||
delay = self._calculate_delay(attempt)
|
||||
logger.warning(
|
||||
"RetryChatModel: Transient async error on attempt %d/%d, "
|
||||
"retrying in %.1fs: %s",
|
||||
attempt,
|
||||
self._max_retries,
|
||||
delay,
|
||||
str(e)[:200],
|
||||
)
|
||||
|
||||
if self._on_retry:
|
||||
self._on_retry(attempt, e, delay)
|
||||
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise RuntimeError("RetryChatModel: Unexpected async state, no error but no result")
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Forward calls to the wrapped model with retry logic."""
|
||||
return self._call_with_retry(self._model, *args, **kwargs)
|
||||
model_call = getattr(self._model, "__call__", None)
|
||||
if inspect.iscoroutinefunction(self._model) or inspect.iscoroutinefunction(model_call):
|
||||
return self._call_with_retry_async(self._model, *args, **kwargs)
|
||||
|
||||
result = self._model(*args, **kwargs)
|
||||
return result
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Proxy attribute access to the wrapped model."""
|
||||
@@ -248,10 +329,18 @@ class TokenRecordingModelWrapper:
|
||||
if usage is None:
|
||||
return
|
||||
|
||||
self._prompt_tokens += getattr(usage, "prompt_tokens", 0)
|
||||
self._completion_tokens += getattr(usage, "completion_tokens", 0)
|
||||
self._total_tokens += getattr(usage, "total_tokens", 0)
|
||||
self._total_cost += getattr(usage, "cost", 0.0)
|
||||
prompt_tokens = _usage_value(usage, "prompt_tokens", None)
|
||||
completion_tokens = _usage_value(usage, "completion_tokens", None)
|
||||
|
||||
if prompt_tokens is None:
|
||||
prompt_tokens = _usage_value(usage, "input_tokens", 0)
|
||||
if completion_tokens is None:
|
||||
completion_tokens = _usage_value(usage, "output_tokens", 0)
|
||||
|
||||
self._prompt_tokens += int(prompt_tokens or 0)
|
||||
self._completion_tokens += int(completion_tokens or 0)
|
||||
self._total_tokens += _usage_total_tokens(usage)
|
||||
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""Forward calls and record usage."""
|
||||
@@ -401,7 +490,8 @@ def create_model(
|
||||
if host:
|
||||
model_kwargs["host"] = host
|
||||
|
||||
return model_class(**model_kwargs)
|
||||
model = model_class(**model_kwargs)
|
||||
return RetryChatModel(model)
|
||||
|
||||
|
||||
def get_agent_model(agent_id: str, stream: bool = False):
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
from backend.config.data_config import get_data_source
|
||||
from backend.config.data_config import get_data_sources
|
||||
from backend.data.provider_utils import normalize_symbol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -142,9 +142,7 @@ class MarketService:
|
||||
def _start_real_mode(self):
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
provider = get_data_source()
|
||||
if provider == "local_csv":
|
||||
provider = "yfinance"
|
||||
provider = self._resolve_live_quote_provider()
|
||||
|
||||
if provider == "finnhub" and not self.api_key:
|
||||
raise ValueError("API key required for live mode")
|
||||
@@ -157,6 +155,13 @@ class MarketService:
|
||||
self._price_manager.subscribe(self.tickers)
|
||||
self._price_manager.start()
|
||||
|
||||
def _resolve_live_quote_provider(self) -> str:
|
||||
"""Pick the first configured provider that supports live quote polling."""
|
||||
for provider in get_data_sources():
|
||||
if provider in {"finnhub", "yfinance"}:
|
||||
return provider
|
||||
return "yfinance"
|
||||
|
||||
def _start_backtest_mode(self):
|
||||
from backend.data.historical_price_manager import (
|
||||
HistoricalPriceManager,
|
||||
|
||||
@@ -2,11 +2,13 @@
|
||||
# pylint: disable=W0212
|
||||
import asyncio
|
||||
import time
|
||||
import logging
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import pytest
|
||||
from backend.services.market import MarketService
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
from backend.llm.models import RetryChatModel
|
||||
|
||||
|
||||
class TestMockPriceManager:
|
||||
@@ -231,6 +233,59 @@ class TestPollingPriceManager:
|
||||
|
||||
assert len(manager.open_prices) == 0
|
||||
|
||||
def test_fetch_prices_suppresses_repeated_failures(self, caplog):
|
||||
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
with patch.object(manager, "_fetch_quote", side_effect=ValueError("empty quote")):
|
||||
with caplog.at_level(logging.DEBUG):
|
||||
for _ in range(3):
|
||||
manager._fetch_prices()
|
||||
|
||||
assert manager._failure_counts["AAPL"] == 3
|
||||
warning_messages = [record.message for record in caplog.records if record.levelno >= logging.WARNING]
|
||||
assert any("Failed to fetch AAPL price: empty quote" in message for message in warning_messages)
|
||||
|
||||
def test_fetch_prices_logs_recovery_after_failure(self, caplog):
|
||||
manager = PollingPriceManager(provider="yfinance", poll_interval=10)
|
||||
manager.subscribe(["AAPL"])
|
||||
|
||||
with patch.object(
|
||||
manager,
|
||||
"_fetch_quote",
|
||||
side_effect=[
|
||||
ValueError("temporary outage"),
|
||||
{"c": 100.0, "o": 99.0, "h": 101.0, "l": 98.0, "pc": 99.5, "d": 0.5, "dp": 0.5, "t": 1},
|
||||
],
|
||||
):
|
||||
with caplog.at_level(logging.INFO):
|
||||
manager._fetch_prices()
|
||||
manager._fetch_prices()
|
||||
|
||||
assert "AAPL" not in manager._failure_counts
|
||||
assert any("recovered after 1 consecutive failures" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
class TestRetryChatModel:
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_retry_recovers_from_disconnect(self):
|
||||
attempts = {"count": 0}
|
||||
|
||||
class FakeAsyncModel:
|
||||
model_name = "fake-async-model"
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
attempts["count"] += 1
|
||||
if attempts["count"] < 2:
|
||||
raise RuntimeError("Server disconnected")
|
||||
return {"ok": True}
|
||||
|
||||
wrapped = RetryChatModel(FakeAsyncModel(), max_retries=2, initial_delay=0.01)
|
||||
result = await wrapped("hello")
|
||||
|
||||
assert result == {"ok": True}
|
||||
assert attempts["count"] == 2
|
||||
|
||||
|
||||
class TestMarketService:
|
||||
def test_init_mock_mode(self):
|
||||
@@ -255,9 +310,23 @@ class TestMarketService:
|
||||
assert service.mock_mode is False
|
||||
assert service.api_key == "test_key"
|
||||
|
||||
@patch("backend.services.market.get_data_source", return_value="yfinance")
|
||||
@patch("backend.services.market.get_data_sources", return_value=["yfinance", "local_csv"])
|
||||
@patch.object(PollingPriceManager, "start")
|
||||
def test_start_real_mode_with_yfinance(self, _mock_start, _mock_source):
|
||||
def test_start_real_mode_with_yfinance(self, _mock_start, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=10,
|
||||
mock_mode=False,
|
||||
)
|
||||
|
||||
service._start_real_mode()
|
||||
|
||||
assert isinstance(service._price_manager, PollingPriceManager)
|
||||
assert service._price_manager.provider == "yfinance"
|
||||
|
||||
@patch("backend.services.market.get_data_sources", return_value=["financial_datasets", "yfinance", "local_csv"])
|
||||
@patch.object(PollingPriceManager, "start")
|
||||
def test_start_real_mode_uses_first_supported_live_provider(self, _mock_start, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
poll_interval=10,
|
||||
@@ -287,9 +356,9 @@ class TestMarketService:
|
||||
|
||||
service.stop()
|
||||
|
||||
@patch("backend.services.market.get_data_source", return_value="finnhub")
|
||||
@patch("backend.services.market.get_data_sources", return_value=["finnhub", "yfinance"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_real_mode_without_api_key(self, _mock_source):
|
||||
async def test_start_real_mode_without_api_key(self, _mock_sources):
|
||||
service = MarketService(
|
||||
tickers=["AAPL"],
|
||||
mock_mode=False,
|
||||
|
||||
Reference in New Issue
Block a user