313 lines
11 KiB
Python
313 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""Gateway Server - entry point for the Gateway subprocess.

The Control Plane (FastAPI) launches this module as a subprocess to run the
Data Plane: the WebSocket Gateway together with the trading Pipeline.
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
from contextlib import AsyncExitStack
|
|
from pathlib import Path
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
# Populate os.environ from a local .env file. This runs before the backend
# imports below, presumably so anything they read from the environment at
# import time already sees those values. override=False (python-dotenv's
# default) keeps variables that are already set in the real environment.
load_dotenv(override=False)
|
|
|
|
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
|
from backend.agents.skills_manager import SkillsManager
|
|
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
|
from backend.agents.prompt_loader import get_prompt_loader
|
|
from backend.agents.workspace_manager import WorkspaceManager
|
|
from backend.config.constants import ANALYST_TYPES
|
|
from backend.core.pipeline import TradingPipeline
|
|
from backend.core.pipeline_runner import create_agents, create_long_term_memory
|
|
from backend.core.scheduler import BacktestScheduler, Scheduler
|
|
from backend.llm.models import get_agent_formatter, get_agent_model
|
|
from backend.runtime.manager import (
|
|
TradingRuntimeManager,
|
|
set_global_runtime_manager,
|
|
clear_global_runtime_manager,
|
|
)
|
|
from backend.services.gateway import Gateway
|
|
from backend.services.market import MarketService
|
|
from backend.services.storage import StorageService
|
|
from backend.utils.settlement import SettlementCoordinator
|
|
|
|
logger = logging.getLogger(name=__name__)

# Prompt loader created once at import time. It is not referenced anywhere else
# in this module; presumably kept for its initialization side effects.
# NOTE(review): confirm before removing.
_prompt_loader = get_prompt_loader()
|
|
|
|
|
|
# Loggers that configure_gateway_logging() raises to INFO in non-verbose mode,
# so that pipeline/agent progress stays visible while the root sits at WARNING.
INFO_LOGGER_PREFIXES = (
    "backend.agents",
    "backend.core.pipeline",
    "backend.core.scheduler",
    "backend.services.gateway_cycle_support",
    "backend.utils.terminal_dashboard",
)

# Third-party and internal loggers that are pinned to the given level in
# non-verbose mode (all WARNING today). In verbose mode they are set to DEBUG.
NOISY_LOGGER_LEVELS = dict.fromkeys(
    (
        "aiohttp",
        "asyncio",
        "dashscope",
        "finnhub",
        "httpcore",
        "httpx",
        "urllib3",
        "websockets",
        "yfinance",
        "backend.data.polling_price_manager",
        "backend.services.gateway",
        "backend.services.market",
        "backend.services.storage",
    ),
    logging.WARNING,
)
|
|
|
|
|
|
class SuppressNoisyInfoFilter(logging.Filter):
    """Drop two specific kinds of low-signal sub-WARNING library records.

    Records at WARNING or above always pass. Below WARNING, the filter discards
    httpx's per-request ``"HTTP Request:"`` lines (logger name exactly
    ``"httpx"``) and ``"connection open"`` notices from any logger whose name
    starts with ``"websockets"``. Every other record passes.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        if record.levelno < logging.WARNING:
            text = record.getMessage()
            httpx_request_line = (
                record.name == "httpx" and text.startswith("HTTP Request:")
            )
            websocket_open_notice = (
                record.name.startswith("websockets") and "connection open" in text
            )
            if httpx_request_line or websocket_open_notice:
                return False
        return True
|
|
|
|
|
|
def configure_gateway_logging(verbose: bool = False) -> None:
    """Reset root logging and apply the gateway's per-logger levels.

    Verbose mode puts the root logger, every logger in ``NOISY_LOGGER_LEVELS``
    and this module's logger at DEBUG.

    Quiet mode (the default) puts the root logger at WARNING and attaches a
    ``SuppressNoisyInfoFilter`` to every root handler. It pins each noisy
    logger to its configured level, then raises ``INFO_LOGGER_PREFIXES`` and
    this module's logger to INFO.

    Args:
        verbose: When true, log everything at DEBUG.
    """
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.WARNING,
        format="%(asctime)s | %(levelname)-7s | %(name)s:%(lineno)d - %(message)s",
        force=True,  # replace any handlers installed before this call
    )

    if verbose:
        for noisy_name in NOISY_LOGGER_LEVELS:
            logging.getLogger(noisy_name).setLevel(logging.DEBUG)
        logging.getLogger(__name__).setLevel(logging.DEBUG)
        return

    noise_filter = SuppressNoisyInfoFilter()
    for root_handler in logging.getLogger().handlers:
        root_handler.addFilter(noise_filter)

    for noisy_name, quiet_level in NOISY_LOGGER_LEVELS.items():
        logging.getLogger(noisy_name).setLevel(quiet_level)

    # Child loggers below WARNING still reach the root handlers through
    # propagation, so these stay visible despite the WARNING root level.
    for info_name in INFO_LOGGER_PREFIXES:
        logging.getLogger(info_name).setLevel(logging.INFO)

    logging.getLogger(__name__).setLevel(logging.INFO)
|
|
|
|
|
|
def _bootstrap_value(bootstrap: dict, key: str, default=None):
    """Return ``bootstrap[key]``, treating an explicit ``None`` like a missing key.

    ``dict.get(key, default)`` only falls back to *default* when the key is
    absent. A bootstrap serialized with ``"key": null`` would otherwise feed
    ``None`` into the ``int()``/``float()`` coercions in ``run_gateway``
    (raising ``TypeError``) or into services and the gateway config.
    """
    value = bootstrap.get(key)
    return default if value is None else value


async def run_gateway(
    run_id: str,
    run_dir: Path,
    bootstrap: dict,
    port: int,
):
    """Run the Gateway together with the trading Pipeline.

    Builds, in this order:
    1. The runtime manager, which is also published as the global one.
    2. The market and storage services.
    3. The agents, which are registered with the runtime manager.
    4. The settlement coordinator and the pipeline.
    5. The scheduler.
    6. The long-term-memory contexts.

    It then awaits ``Gateway.start``, which blocks until shutdown.

    Args:
        run_id: Identifier of the run. Also passed as ``config_name``.
        run_dir: Run directory. Dashboard files live in ``run_dir / "team_dashboard"``.
        bootstrap: Run configuration decoded from the ``--bootstrap`` JSON.
            Keys that are missing or ``null`` use the defaults below.
        port: Port the gateway listens on. The gateway binds host ``0.0.0.0``.

    Raises:
        asyncio.CancelledError: Logged and re-raised if the task is cancelled.
    """
    # Extract config (missing or null keys fall back to these defaults)
    tickers = _bootstrap_value(bootstrap, "tickers", ["AAPL", "MSFT"])
    initial_cash = float(_bootstrap_value(bootstrap, "initial_cash", 100000.0))
    margin_requirement = float(_bootstrap_value(bootstrap, "margin_requirement", 0.0))
    max_comm_cycles = int(_bootstrap_value(bootstrap, "max_comm_cycles", 2))
    schedule_mode = _bootstrap_value(bootstrap, "schedule_mode", "daily")
    trigger_time = _bootstrap_value(bootstrap, "trigger_time", "09:30")
    interval_minutes = int(_bootstrap_value(bootstrap, "interval_minutes", 60))
    heartbeat_interval = int(_bootstrap_value(bootstrap, "heartbeat_interval", 0))  # 0 = disabled
    mode = _bootstrap_value(bootstrap, "mode", "live")
    start_date = bootstrap.get("start_date")
    end_date = bootstrap.get("end_date")
    enable_memory = _bootstrap_value(bootstrap, "enable_memory", False)
    poll_interval = int(_bootstrap_value(bootstrap, "poll_interval", 10))
    enable_mock = _bootstrap_value(bootstrap, "enable_mock", False)

    is_backtest = mode == "backtest"
    # MOCK_MODE from the environment is only honoured outside backtests.
    is_mock = (
        enable_mock
        or mode == "mock"
        or (not is_backtest and os.getenv("MOCK_MODE", "false").lower() == "true")
    )

    logger.info("[Gateway Server] Starting run %s on port %s", run_id, port)

    # Create runtime manager
    runtime_manager = TradingRuntimeManager(
        config_name=run_id,
        run_dir=run_dir,
        bootstrap=bootstrap,
    )
    runtime_manager.prepare_run()
    set_global_runtime_manager(runtime_manager)

    try:
        # The exit stack closes the long-term-memory contexts entered below.
        async with AsyncExitStack() as stack:
            # Create services
            market_service = MarketService(
                tickers=tickers,
                poll_interval=poll_interval,
                mock_mode=is_mock and not is_backtest,
                backtest_mode=is_backtest,
                # A real API key is only needed for live (non-mock) data.
                api_key=os.getenv("FINNHUB_API_KEY") if not is_mock and not is_backtest else None,
                backtest_start_date=start_date if is_backtest else None,
                backtest_end_date=end_date if is_backtest else None,
            )

            storage_service = StorageService(
                dashboard_dir=run_dir / "team_dashboard",
                initial_cash=initial_cash,
                config_name=run_id,
            )

            # Fresh run: create the dashboard; resumed run: refresh model info.
            if not storage_service.files["summary"].exists():
                storage_service.initialize_empty_dashboard()
            else:
                storage_service.update_leaderboard_model_info()

            # Create agents
            analysts, risk_manager, pm, long_term_memories = create_agents(
                run_id=run_id,
                run_dir=run_dir,
                initial_cash=initial_cash,
                margin_requirement=margin_requirement,
                enable_long_term_memory=enable_memory,
            )

            # Register agents under agent_id, falling back to name.
            for agent in analysts + [risk_manager, pm]:
                agent_id = getattr(agent, "agent_id", None) or getattr(agent, "name", None)
                if agent_id:
                    runtime_manager.register_agent(agent_id)

            # Load portfolio state
            portfolio_state = storage_service.load_portfolio_state()
            pm.load_portfolio_state(portfolio_state)

            # Create settlement coordinator
            settlement_coordinator = SettlementCoordinator(
                storage=storage_service,
                initial_capital=initial_cash,
            )

            # Create pipeline
            pipeline = TradingPipeline(
                analysts=analysts,
                risk_manager=risk_manager,
                portfolio_manager=pm,
                settlement_coordinator=settlement_coordinator,
                max_comm_cycles=max_comm_cycles,
                runtime_manager=runtime_manager,
            )

            # Create scheduler. Only the live scheduler is also passed to the
            # Gateway directly (scheduler=...); backtests pass None there.
            live_scheduler = None
            if is_backtest:
                backtest_scheduler = BacktestScheduler(
                    start_date=start_date,
                    end_date=end_date,
                    trading_calendar="NYSE",
                    delay_between_days=0.5,
                )

                async def scheduler_callback(callback):
                    await backtest_scheduler.start(callback)
            else:
                live_scheduler = Scheduler(
                    mode=schedule_mode,
                    trigger_time=trigger_time,
                    interval_minutes=interval_minutes,
                    heartbeat_interval=heartbeat_interval if heartbeat_interval > 0 else None,
                    config={"config_name": run_id},
                )

                async def scheduler_callback(callback):
                    await live_scheduler.start(callback)

            # Enter long-term memory contexts
            for memory in long_term_memories:
                await stack.enter_async_context(memory)

            # Create Gateway
            gateway = Gateway(
                market_service=market_service,
                storage_service=storage_service,
                pipeline=pipeline,
                scheduler_callback=scheduler_callback,
                config={
                    "mode": mode,
                    "mock_mode": is_mock,
                    "backtest_mode": is_backtest,
                    "tickers": tickers,
                    "config_name": run_id,
                    "schedule_mode": schedule_mode,
                    "interval_minutes": interval_minutes,
                    "trigger_time": trigger_time,
                    "heartbeat_interval": heartbeat_interval,
                    "initial_cash": initial_cash,
                    "margin_requirement": margin_requirement,
                    "max_comm_cycles": max_comm_cycles,
                    "enable_memory": enable_memory,
                },
                scheduler=live_scheduler,
            )

            # Start Gateway (blocks until shutdown)
            logger.info("[Gateway Server] Gateway starting on port %s", port)
            await gateway.start(host="0.0.0.0", port=port)

    except asyncio.CancelledError:
        logger.info("[Gateway Server] Cancelled")
        raise
    finally:
        logger.info("[Gateway Server] Cleaning up")
        clear_global_runtime_manager()
|
|
|
|
|
|
def main():
    """Parse CLI arguments and run the gateway until it exits.

    Exit status:
        * 1 if ``--bootstrap`` is not valid JSON or does not decode to a JSON
          object. This is the status the uncaught parse error produced before;
          the problem is now reported through the configured logger instead of
          a raw traceback.
        * 1 if ``run_gateway`` raises any other exception, which is logged
          with its traceback.
        * 0 on normal completion or ``KeyboardInterrupt``.
    """
    parser = argparse.ArgumentParser(description="Gateway Server")
    parser.add_argument("--run-id", required=True, help="Run identifier")
    parser.add_argument("--run-dir", required=True, help="Run directory path")
    parser.add_argument("--port", type=int, default=8765, help="WebSocket port")
    parser.add_argument("--bootstrap", required=True, help="Bootstrap config as JSON")
    parser.add_argument("--verbose", action="store_true", help="Verbose logging")

    args = parser.parse_args()

    # Setup logging first so bootstrap errors below use the same format.
    configure_gateway_logging(verbose=args.verbose)

    # Parse bootstrap. run_gateway() reads it with dict.get(), so anything
    # other than a JSON object is rejected up front with a clear message.
    try:
        bootstrap = json.loads(args.bootstrap)
    except json.JSONDecodeError as exc:
        logger.error("[Gateway Server] Invalid --bootstrap JSON: %s", exc)
        sys.exit(1)
    if not isinstance(bootstrap, dict):
        logger.error(
            "[Gateway Server] --bootstrap must be a JSON object, got %s",
            type(bootstrap).__name__,
        )
        sys.exit(1)

    run_dir = Path(args.run_dir)

    # Run
    try:
        asyncio.run(run_gateway(
            run_id=args.run_id,
            run_dir=run_dir,
            bootstrap=bootstrap,
            port=args.port,
        ))
    except KeyboardInterrupt:
        logger.info("[Gateway Server] Interrupted by user")
    except Exception as e:
        logger.exception(f"[Gateway Server] Fatal error: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
|