feat: Integrate WebSocket Gateway into API-based task launch
- Gateway now starts automatically when calling POST /runtime/start - No need to run python backend/main.py separately - Single service architecture: only FastAPI (port 8000) needed - Gateway runs in background task and stops with pipeline - Add error handling and logging for Gateway lifecycle Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -33,10 +33,33 @@ from backend.runtime.manager import (
|
||||
)
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
# Global gateway reference for cleanup
|
||||
_gateway_instance: Optional[Gateway] = None
|
||||
|
||||
|
||||
def _set_gateway(gateway: Optional[Gateway]) -> None:
|
||||
"""Set global gateway reference."""
|
||||
global _gateway_instance
|
||||
_gateway_instance = gateway
|
||||
|
||||
|
||||
def stop_gateway() -> None:
|
||||
"""Stop the running gateway if exists."""
|
||||
global _gateway_instance
|
||||
if _gateway_instance is not None:
|
||||
try:
|
||||
_gateway_instance.stop()
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.getLogger(__name__).error(f"Error stopping gateway: {e}")
|
||||
finally:
|
||||
_gateway_instance = None
|
||||
|
||||
|
||||
async def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
|
||||
"""Create ReMeTaskLongTermMemory for an agent."""
|
||||
@@ -321,12 +344,45 @@ async def run_pipeline(
|
||||
|
||||
scheduler_callback = scheduler_callback_fn
|
||||
|
||||
# Create Gateway for WebSocket connections (after pipeline and scheduler are ready)
|
||||
gateway = Gateway(
|
||||
market_service=market_service,
|
||||
storage_service=storage_service,
|
||||
pipeline=pipeline,
|
||||
scheduler_callback=scheduler_callback,
|
||||
config={
|
||||
"mode": mode,
|
||||
"mock_mode": is_mock,
|
||||
"backtest_mode": is_backtest,
|
||||
"tickers": tickers,
|
||||
"config_name": run_id,
|
||||
"schedule_mode": schedule_mode,
|
||||
"interval_minutes": interval_minutes,
|
||||
"trigger_time": trigger_time,
|
||||
"initial_cash": initial_cash,
|
||||
"margin_requirement": margin_requirement,
|
||||
"max_comm_cycles": max_comm_cycles,
|
||||
"enable_memory": enable_memory,
|
||||
},
|
||||
scheduler=live_scheduler,
|
||||
)
|
||||
_set_gateway(gateway)
|
||||
|
||||
# Start pipeline execution
|
||||
async with AsyncExitStack() as stack:
|
||||
# Enter long-term memory contexts
|
||||
for memory in long_term_memories:
|
||||
await stack.enter_async_context(memory)
|
||||
|
||||
# Start Gateway in background task
|
||||
gateway_task = asyncio.create_task(
|
||||
gateway.start(host="0.0.0.0", port=8765)
|
||||
)
|
||||
logger.info("[Pipeline] Gateway started on ws://localhost:8765")
|
||||
|
||||
# Give Gateway a moment to start
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Define the trading cycle callback
|
||||
async def trading_cycle(session_key: str) -> None:
|
||||
"""Execute one trading cycle."""
|
||||
@@ -360,12 +416,30 @@ async def run_pipeline(
|
||||
while not stop_event.is_set():
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Cancel gateway task
|
||||
if not gateway_task.done():
|
||||
gateway_task.cancel()
|
||||
try:
|
||||
await gateway_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Handle cancellation gracefully
|
||||
raise
|
||||
finally:
|
||||
# Cleanup
|
||||
logger.info("[Pipeline] Cleaning up...")
|
||||
|
||||
# Stop Gateway
|
||||
try:
|
||||
stop_gateway()
|
||||
logger.info("[Pipeline] Gateway stopped")
|
||||
except Exception as e:
|
||||
logger.error(f"[Pipeline] Error stopping gateway: {e}")
|
||||
|
||||
clear_shutdown_event()
|
||||
clear_global_runtime_manager()
|
||||
from backend.api.runtime import unregister_runtime_manager
|
||||
unregister_runtime_manager()
|
||||
logger.info("[Pipeline] Cleanup complete")
|
||||
|
||||
@@ -238,30 +238,45 @@ class Gateway:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
async def _send_initial_state(self, websocket: ServerConnection):
|
||||
state_payload = self.state_sync.get_initial_state_payload(
|
||||
include_dashboard=True,
|
||||
)
|
||||
state_payload["data_sources"] = (
|
||||
self._provider_router.get_usage_snapshot()
|
||||
)
|
||||
# Include market status in initial state
|
||||
state_payload[
|
||||
"market_status"
|
||||
] = self.market_service.get_market_status()
|
||||
try:
|
||||
logger.info("[Gateway] Sending initial state to client...")
|
||||
state_payload = self.state_sync.get_initial_state_payload(
|
||||
include_dashboard=True,
|
||||
)
|
||||
state_payload["data_sources"] = (
|
||||
self._provider_router.get_usage_snapshot()
|
||||
)
|
||||
# Include market status in initial state
|
||||
state_payload[
|
||||
"market_status"
|
||||
] = self.market_service.get_market_status()
|
||||
|
||||
# Include live returns if session is active
|
||||
if self.storage.is_live_session_active:
|
||||
live_returns = self.storage.get_live_returns()
|
||||
if "portfolio" in state_payload:
|
||||
state_payload["portfolio"].update(live_returns)
|
||||
# Include live returns if session is active
|
||||
if self.storage.is_live_session_active:
|
||||
live_returns = self.storage.get_live_returns()
|
||||
if "portfolio" in state_payload:
|
||||
state_payload["portfolio"].update(live_returns)
|
||||
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{"type": "initial_state", "state": state_payload},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{"type": "initial_state", "state": state_payload},
|
||||
ensure_ascii=False,
|
||||
default=str,
|
||||
),
|
||||
)
|
||||
logger.info("[Gateway] Initial state sent successfully")
|
||||
except Exception as e:
|
||||
logger.exception(f"[Gateway] Failed to send initial state: {e}")
|
||||
# Send error response so client knows something went wrong
|
||||
try:
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{"type": "error", "message": "Failed to load initial state"},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _handle_client_messages(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user