perf: optimize system concurrency, I/O stability and fix WebSocket disconnects

This commit is contained in:
2026-04-07 13:58:49 +08:00
parent 62c7341cf6
commit 11849208ed
21 changed files with 357 additions and 215 deletions

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
import asyncio
import json
import logging
from datetime import UTC, datetime
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set
@@ -78,7 +78,7 @@ class ApprovalRecord:
self.session_id = session_id
self.status = ApprovalStatus.PENDING
self.findings = findings or []
self.created_at = datetime.now(UTC)
self.created_at = datetime.now(timezone.utc)
self.resolved_at: Optional[datetime] = None
self.resolved_by: Optional[str] = None
self.metadata: Dict[str, Any] = {}
@@ -163,7 +163,7 @@ class ToolGuardStore:
return record
record.status = status
record.resolved_at = datetime.now(UTC)
record.resolved_at = datetime.now(timezone.utc)
record.resolved_by = resolved_by
if notify_request and record.pending_request:
if status == ApprovalStatus.APPROVED:

View File

@@ -7,7 +7,7 @@ Provides REST API endpoints for tool guard operations.
from __future__ import annotations
from typing import Any, Dict, List, Optional
from datetime import UTC, datetime
from datetime import datetime, timezone
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
@@ -146,7 +146,7 @@ async def check_tool_call(
if request.tool_name in SAFE_TOOLS:
record.status = ApprovalStatus.APPROVED
record.resolved_at = datetime.now(UTC)
record.resolved_at = datetime.now(timezone.utc)
record.resolved_by = "system"
STORE.set_status(
record.approval_id,

View File

@@ -81,7 +81,12 @@ async def proxy_ws(ws: WebSocket):
await ws.accept()
upstream = None
try:
upstream = await websockets.asyncio.client.connect(gateway_url)
upstream = await websockets.asyncio.client.connect(
gateway_url,
ping_interval=20,
ping_timeout=120,
max_size=10 * 1024 * 1024, # 10MB
)
async def client_to_upstream():
try:

View File

@@ -28,11 +28,11 @@ def create_app() -> FastAPI:
add_cors_middleware(app)
@app.get("/health")
async def health_check() -> dict[str, str]:
def health_check() -> dict[str, str]:
return {"status": "healthy", "service": "news-service"}
@app.get("/api/enriched-news")
async def api_get_enriched_news(
def api_get_enriched_news(
ticker: str = Query(..., min_length=1),
start_date: str | None = Query(None),
end_date: str | None = Query(None),
@@ -49,7 +49,7 @@ def create_app() -> FastAPI:
)
@app.get("/api/news-for-date")
async def api_get_news_for_date(
def api_get_news_for_date(
ticker: str = Query(..., min_length=1),
date: str = Query(...),
limit: int = Query(20, ge=1, le=100),
@@ -64,7 +64,7 @@ def create_app() -> FastAPI:
)
@app.get("/api/news-timeline")
async def api_get_news_timeline(
def api_get_news_timeline(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),
@@ -79,7 +79,7 @@ def create_app() -> FastAPI:
)
@app.get("/api/categories")
async def api_get_categories(
def api_get_categories(
ticker: str = Query(..., min_length=1),
start_date: str | None = Query(None),
end_date: str | None = Query(None),
@@ -96,7 +96,7 @@ def create_app() -> FastAPI:
)
@app.get("/api/similar-days")
async def api_get_similar_days(
def api_get_similar_days(
ticker: str = Query(..., min_length=1),
date: str = Query(...),
n_similar: int = Query(5, ge=1, le=20),
@@ -111,7 +111,7 @@ def create_app() -> FastAPI:
)
@app.get("/api/stories/{ticker}")
async def api_get_story(
def api_get_story(
ticker: str,
as_of_date: str = Query(...),
store: MarketStore = Depends(get_market_store),
@@ -124,7 +124,7 @@ def create_app() -> FastAPI:
)
@app.get("/api/range-explain")
async def api_get_range_explain(
def api_get_range_explain(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),

View File

@@ -29,12 +29,12 @@ def create_app() -> FastAPI:
add_cors_middleware(app)
@app.get("/health")
async def health_check() -> dict[str, str]:
def health_check() -> dict[str, str]:
"""Health check endpoint."""
return {"status": "healthy", "service": "trading-service"}
@app.get("/api/prices", response_model=PriceResponse)
async def api_get_prices(
def api_get_prices(
ticker: str = Query(..., min_length=1),
start_date: str = Query(...),
end_date: str = Query(...),
@@ -47,7 +47,7 @@ def create_app() -> FastAPI:
return PriceResponse(ticker=payload["ticker"], prices=payload["prices"])
@app.get("/api/financials", response_model=FinancialMetricsResponse)
async def api_get_financials(
def api_get_financials(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
period: str = Query("ttm"),
@@ -62,7 +62,7 @@ def create_app() -> FastAPI:
return FinancialMetricsResponse(financial_metrics=payload["financial_metrics"])
@app.get("/api/news", response_model=CompanyNewsResponse)
async def api_get_news(
def api_get_news(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
start_date: str | None = Query(None),
@@ -77,7 +77,7 @@ def create_app() -> FastAPI:
return CompanyNewsResponse(news=payload["news"])
@app.get("/api/insider-trades", response_model=InsiderTradeResponse)
async def api_get_insider_trades(
def api_get_insider_trades(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
start_date: str | None = Query(None),
@@ -92,12 +92,12 @@ def create_app() -> FastAPI:
return InsiderTradeResponse(insider_trades=payload["insider_trades"])
@app.get("/api/market/status")
async def api_get_market_status() -> dict[str, Any]:
def api_get_market_status() -> dict[str, Any]:
"""Return current market status using the existing market service logic."""
return trading_domain.get_market_status_payload()
@app.get("/api/market-cap")
async def api_get_market_cap(
def api_get_market_cap(
ticker: str = Query(..., min_length=1),
end_date: str = Query(...),
) -> dict[str, Any]:
@@ -108,7 +108,7 @@ def create_app() -> FastAPI:
)
@app.get("/api/line-items", response_model=LineItemResponse)
async def api_get_line_items(
def api_get_line_items(
ticker: str = Query(..., min_length=1),
line_items: list[str] = Query(...),
end_date: str = Query(...),

View File

@@ -144,7 +144,7 @@ class TradingPipeline:
self._team_controller = DynamicTeamController(
create_callback=self._create_runtime_analyst,
remove_callback=self._remove_runtime_analyst,
get_analysts_callback=self._all_analysts,
get_analysts_callback=lambda: self._all_analysts() + [self.risk_manager, self.pm],
)
set_controller(self._team_controller)

View File

@@ -123,7 +123,11 @@ class StateSync:
# Persist to feed_history
if persist:
self.storage.add_feed_message(self._state, event)
self.save_state()
# Make persistence non-blocking to keep event loop snappy
if asyncio.get_event_loop().is_running():
asyncio.create_task(asyncio.to_thread(self.save_state))
else:
self.save_state()
# Broadcast to frontend
if self._broadcast_fn:

View File

@@ -190,8 +190,9 @@ class MarketStore:
name: str | None = None,
sector: str | None = None,
is_active: bool = True,
) -> None:
) -> int:
timestamp = _utc_timestamp()
count = 0
with self._connect() as conn:
conn.execute(
"""
@@ -206,6 +207,8 @@ class MarketStore:
""",
(symbol, name, sector, 1 if is_active else 0, timestamp, timestamp),
)
count += 1
return count
def update_fetch_watermark(
self,
@@ -213,8 +216,9 @@ class MarketStore:
symbol: str,
price_date: str | None = None,
news_date: str | None = None,
) -> None:
) -> int:
timestamp = _utc_timestamp()
count = 0
with self._connect() as conn:
conn.execute(
"""
@@ -227,6 +231,8 @@ class MarketStore:
""",
(symbol, timestamp, timestamp, price_date, news_date),
)
count += 1
return count
def get_ticker_watermarks(self, symbol: str) -> dict[str, Any]:
with self._connect() as conn:
@@ -263,6 +269,8 @@ class MarketStore:
count = 0
with self._connect() as conn:
for row in rows:
if not row.get("date"):
continue
conn.execute(
"""
INSERT INTO ohlc
@@ -341,6 +349,7 @@ class MarketStore:
timestamp,
),
)
count += 1
for ticker in tickers:
conn.execute(
"""
@@ -349,7 +358,6 @@ class MarketStore:
""",
(news_id, str(ticker).strip().upper()),
)
count += 1
return count
def get_news_without_trade_date(self, symbol: str | None = None, *, limit: int = 5000) -> list[dict[str, Any]]:
@@ -928,8 +936,9 @@ class MarketStore:
as_of_date: str,
content: str,
source: str = "local",
) -> None:
) -> int:
timestamp = _utc_timestamp()
count = 0
with self._connect() as conn:
conn.execute(
"""
@@ -943,6 +952,8 @@ class MarketStore:
""",
(symbol, as_of_date, content, source, timestamp, timestamp),
)
count += 1
return count
def delete_story_cache(
self,
@@ -1002,8 +1013,9 @@ class MarketStore:
target_date: str,
payload: dict[str, Any],
source: str = "local",
) -> None:
) -> int:
timestamp = _utc_timestamp()
count = 0
with self._connect() as conn:
conn.execute(
"""
@@ -1017,6 +1029,8 @@ class MarketStore:
""",
(symbol, target_date, _json_dumps(payload), source, timestamp, timestamp),
)
count += 1
return count
def delete_similar_day_cache(
self,

View File

@@ -444,6 +444,16 @@ def create_model(
"""
provider = canonicalize_model_provider(provider)
# If provider is default OPENAI but model name looks like deepseek,
# check if we should switch to DASHSCOPE.
if provider == "OPENAI" and "deepseek" in model_name.lower() and os.getenv("DASHSCOPE_API_KEY"):
provider = "DASHSCOPE"
# Intelligent routing: if it's a DeepSeek model and we have DashScope credentials,
# prefer using DashScopeChatModel over OpenAIChatModel.
if provider == "DEEPSEEK" and os.getenv("DASHSCOPE_API_KEY"):
provider = "DASHSCOPE"
model_class = PROVIDER_MODEL_MAP.get(provider)
if model_class is None:
raise ValueError(f"Unsupported provider: {provider}")

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, UTC
from datetime import datetime, timezone
from typing import Any, Dict
@@ -11,12 +11,12 @@ class AgentRuntimeState:
display_name: str | None = None
status: str = "idle"
last_session: str | None = None
last_updated: datetime = field(default_factory=lambda: datetime.now(UTC))
last_updated: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
def update(self, status: str, session_key: str | None = None) -> None:
self.status = status
self.last_session = session_key
self.last_updated = datetime.now(UTC)
self.last_updated = datetime.now(timezone.utc)
def to_dict(self) -> Dict[str, Any]:
return {

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
import json
from datetime import datetime, UTC
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
@@ -93,7 +93,7 @@ class TradingRuntimeManager:
def log_event(self, event: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
entry = {
"timestamp": datetime.now(UTC).isoformat(),
"timestamp": datetime.now(timezone.utc).isoformat(),
"event": event,
"details": details or {},
"session": self.current_session_key,
@@ -120,7 +120,7 @@ class TradingRuntimeManager:
def register_pending_approval(self, approval_id: str, payload: Dict[str, Any]) -> None:
payload.setdefault("status", "pending")
payload.setdefault("created_at", datetime.now(UTC).isoformat())
payload.setdefault("created_at", datetime.now(timezone.utc).isoformat())
self.pending_approvals[approval_id] = payload
self._persist_snapshot()
@@ -149,7 +149,7 @@ class TradingRuntimeManager:
if not entry:
return
entry["status"] = status
entry["resolved_at"] = datetime.now(UTC).isoformat()
entry["resolved_at"] = datetime.now(timezone.utc).isoformat()
entry["resolved_by"] = resolved_by
self._persist_snapshot()

View File

@@ -148,8 +148,9 @@ class Gateway:
self.handle_client,
host,
port,
ping_interval=30,
ping_timeout=60,
ping_interval=20,
ping_timeout=120,
max_size=10 * 1024 * 1024, # 10MB
)
logger.info(f"WebSocket server ready: ws://{host}:{port}")
@@ -833,12 +834,18 @@ class Gateway:
if not self.connected_clients:
return
message_json = json.dumps(message, ensure_ascii=False, default=str)
# Offload potentially heavy JSON serialization to thread
message_json = await asyncio.to_thread(
json.dumps, message, ensure_ascii=False, default=str
)
async with self.lock:
# Filter only active clients to minimize unnecessary send attempts
# In websockets v13+, we must check state.name == 'OPEN'
active_clients = [c for c in self.connected_clients if c.state.name == 'OPEN']
tasks = [
self._send_to_client(client, message_json)
for client in self.connected_clients.copy()
for client in active_clients
]
if tasks:
@@ -849,9 +856,14 @@ class Gateway:
client: ServerConnection,
message: str,
):
if client.state.name != 'OPEN':
async with self.lock:
self.connected_clients.discard(client)
return
try:
await client.send(message)
except websockets.ConnectionClosed:
except (websockets.ConnectionClosed, Exception):
async with self.lock:
self.connected_clients.discard(client)

View File

@@ -253,7 +253,8 @@ async def finalize_cycle(gateway: Any, date: str) -> None:
async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[str, float]:
market_caps: dict[str, float] = {}
for ticker in tickers:
async def _get_one(ticker: str):
try:
market_cap = None
response = await gateway._call_trading_service(
@@ -263,12 +264,21 @@ async def get_market_caps(gateway: Any, tickers: list[str], date: str) -> dict[s
if response is not None:
market_cap = response.get("market_cap")
if market_cap is None:
payload = trading_domain.get_market_cap_payload(ticker=ticker, end_date=date)
payload = await asyncio.to_thread(
trading_domain.get_market_cap_payload,
ticker=ticker,
end_date=date,
)
market_cap = payload.get("market_cap")
market_caps[ticker] = market_cap if market_cap else 1e9
return ticker, (market_cap if market_cap else 1e9)
except Exception as exc:
logger.warning("Failed to get market cap for %s, using default 1e9: %s", ticker, exc)
market_caps[ticker] = 1e9
return ticker, 1e9
tasks = [_get_one(ticker) for ticker in tickers]
results = await asyncio.gather(*tasks)
for ticker, mc in results:
market_caps[ticker] = mc
return market_caps

View File

@@ -517,111 +517,129 @@ async def handle_get_stock_similar_days(gateway: Any, websocket: Any, data: dict
async def handle_get_stock_technical_indicators(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "ticker is required",
}, ensure_ascii=False))
return
ticker = normalize_symbol(data.get("ticker", ""))
if not ticker:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "ticker is required",
}, ensure_ascii=False))
return
try:
end_date = datetime.now()
start_date = end_date - timedelta(days=250)
try:
end_date = datetime.now()
            # Reduced from 250 to 150 calendar days to lower CPU/memory pressure.
            # NOTE(review): ~150 calendar days is only ~100-105 trading days, so MA200
            # cannot be fully populated from this window — confirm the analyzer tolerates
            # a partial MA200 (or restore a >=290-calendar-day lookback if it does not).
start_date = end_date - timedelta(days=150)
prices = None
response = await gateway._call_trading_service(
"get_prices",
lambda client: client.get_prices(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
),
)
if response is not None:
prices = response.prices
prices = None
response = await gateway._call_trading_service(
"get_prices",
lambda client: client.get_prices(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
),
)
if response is not None:
prices = response.prices
if prices is None:
payload = trading_domain.get_prices_payload(
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
)
prices = payload.get("prices") or []
if prices is None:
# Offload domain logic to thread
payload = await asyncio.to_thread(
trading_domain.get_prices_payload,
ticker=ticker,
start_date=start_date.strftime("%Y-%m-%d"),
end_date=end_date.strftime("%Y-%m-%d"),
)
prices = payload.get("prices") or []
if not prices or len(prices) < 20:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "Insufficient price data",
}, ensure_ascii=False))
return
if not prices or len(prices) < 20:
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": "Insufficient price data",
}, ensure_ascii=False))
return
df = prices_to_df(prices)
signal = gateway._technical_analyzer.analyze(ticker, df)
def _calc():
df = prices_to_df(prices)
signal = gateway._technical_analyzer.analyze(ticker, df)
df_sorted = df.sort_values("time").reset_index(drop=True)
df_sorted["returns"] = df_sorted["close"].pct_change()
v10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
v20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
v60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
df_sorted = df.sort_values("time").reset_index(drop=True)
df_sorted["returns"] = df_sorted["close"].pct_change()
vol_10 = float(df_sorted["returns"].tail(10).std() * (252**0.5) * 100) if len(df_sorted) >= 10 else None
vol_20 = float(df_sorted["returns"].tail(20).std() * (252**0.5) * 100) if len(df_sorted) >= 20 else None
vol_60 = float(df_sorted["returns"].tail(60).std() * (252**0.5) * 100) if len(df_sorted) >= 60 else None
ma_distance = {}
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
ma_value = getattr(signal, ma_key, None)
ma_distance[ma_key] = ((signal.current_price - ma_value) / ma_value) * 100 if ma_value and ma_value > 0 else None
ma_dist = {}
for ma_key in ["ma5", "ma10", "ma20", "ma50", "ma200"]:
ma_val = getattr(signal, ma_key, None)
ma_dist[ma_key] = ((signal.current_price - ma_val) / ma_val) * 100 if ma_val and ma_val > 0 else None
indicators = {
"ticker": ticker,
"current_price": signal.current_price,
"ma": {
"ma5": signal.ma5,
"ma10": signal.ma10,
"ma20": signal.ma20,
"ma50": signal.ma50,
"ma200": signal.ma200,
"distance": ma_distance,
},
"rsi": {
"rsi14": signal.rsi14,
"status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
},
"macd": {
"macd": signal.macd,
"signal": signal.macd_signal,
"histogram": signal.macd - signal.macd_signal,
},
"bollinger": {
"upper": signal.bollinger_upper,
"mid": signal.bollinger_mid,
"lower": signal.bollinger_lower,
},
"volatility": {
"vol_10d": vol_10,
"vol_20d": vol_20,
"vol_60d": vol_60,
"annualized": signal.annualized_volatility_pct,
"risk_level": signal.risk_level,
},
"trend": signal.trend,
"mean_reversion": signal.mean_reversion_signal,
}
return {
"ticker": ticker,
"current_price": signal.current_price,
"ma": {
"ma5": signal.ma5,
"ma10": signal.ma10,
"ma20": signal.ma20,
"ma50": signal.ma50,
"ma200": signal.ma200,
"distance": ma_dist,
},
"rsi": {
"rsi14": signal.rsi14,
"status": "oversold" if signal.rsi14 < 30 else "overbought" if signal.rsi14 > 70 else "neutral",
},
"macd": {
"macd": signal.macd,
"signal": signal.macd_signal,
"histogram": signal.macd - signal.macd_signal,
},
"bollinger": {
"upper": signal.bollinger_upper,
"mid": signal.bollinger_mid,
"lower": signal.bollinger_lower,
},
"volatility": {
"vol_10d": v10,
"vol_20d": v20,
"vol_60d": v60,
"annualized": signal.annualized_volatility_pct,
"risk_level": signal.risk_level,
},
"trend": signal.trend,
"mean_reversion": signal.mean_reversion_signal,
}
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": indicators,
}, ensure_ascii=False, default=str))
except Exception as exc:
logger.exception("Error getting technical indicators for %s", ticker)
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": str(exc),
}, ensure_ascii=False))
# Use a semaphore to prevent too many concurrent CPU-intensive calculations
# which can block the event loop heartbeats.
if not hasattr(gateway, "_calc_sem"):
gateway._calc_sem = asyncio.Semaphore(3)
async with gateway._calc_sem:
indicators = await asyncio.to_thread(_calc)
# Also offload JSON serialization to thread to avoid blocking main loop
msg = await asyncio.to_thread(json.dumps, {
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": indicators,
}, ensure_ascii=False, default=str)
if websocket.state.name == 'OPEN':
await websocket.send(msg)
else:
logger.warning("Websocket closed for %s, skipping indicator send", ticker)
except Exception as exc:
logger.exception("Error getting technical indicators for %s", ticker)
await websocket.send(json.dumps({
"type": "stock_technical_indicators_loaded",
"ticker": ticker,
"indicators": None,
"error": str(exc),
}, ensure_ascii=False))
async def handle_run_stock_enrich(gateway: Any, websocket: Any, data: dict[str, Any]) -> None:

View File

@@ -7,6 +7,7 @@ Handles reading/writing dashboard JSON files and portfolio state
import json
import logging
import os
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
@@ -950,11 +951,14 @@ class StorageService:
def save_server_state(self, state: Dict[str, Any]):
"""
Save server state to file
Args:
state: Server state dictionary
Save server state to file with rate-limiting to avoid I/O storms.
"""
now = time.time()
        # Throttle: skip this write if the last physical disk write was <2s ago.
        # NOTE(review): a skipped write is dropped, not deferred — state changed in the
        # final 2s before shutdown may never hit disk; confirm a flush-on-exit exists.
if hasattr(self, "_last_save_time") and (now - self._last_save_time) < 2.0:
return
self._last_save_time = now
state_to_save = {
**state,
"last_saved": datetime.now().isoformat(),
@@ -970,14 +974,17 @@ class StorageService:
if "trades" in state_to_save:
state_to_save["trades"] = state_to_save["trades"][:100]
with open(self.server_state_file, "w", encoding="utf-8") as f:
json.dump(
state_to_save,
f,
ensure_ascii=False,
indent=2,
default=str,
)
try:
with open(self.server_state_file, "w", encoding="utf-8") as f:
# Removed indent=2 to minimize file size and serialization overhead
json.dump(
state_to_save,
f,
ensure_ascii=False,
default=str,
)
except Exception as e:
logger.error(f"Failed to save server state: {e}")
logger.debug(f"Server state saved to: {self.server_state_file}")