feat: Integrate WebSocket Gateway into API-based task launch

- Gateway now starts automatically when calling POST /runtime/start
- No need to run python backend/main.py separately
- Single service architecture: only FastAPI (port 8000) needed
- Gateway runs in background task and stops with pipeline
- Add error handling and logging for Gateway lifecycle

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-17 18:36:33 +08:00
parent a3f767126f
commit 2dcda63394
2 changed files with 111 additions and 22 deletions

View File

@@ -33,10 +33,33 @@ from backend.runtime.manager import (
) )
from backend.services.market import MarketService from backend.services.market import MarketService
from backend.services.storage import StorageService from backend.services.storage import StorageService
from backend.services.gateway import Gateway
from backend.utils.settlement import SettlementCoordinator from backend.utils.settlement import SettlementCoordinator
_prompt_loader = PromptLoader() _prompt_loader = PromptLoader()
# Global gateway reference for cleanup
_gateway_instance: Optional[Gateway] = None
def _set_gateway(gateway: Optional[Gateway]) -> None:
"""Set global gateway reference."""
global _gateway_instance
_gateway_instance = gateway
def stop_gateway() -> None:
"""Stop the running gateway if exists."""
global _gateway_instance
if _gateway_instance is not None:
try:
_gateway_instance.stop()
except Exception as e:
import logging
logging.getLogger(__name__).error(f"Error stopping gateway: {e}")
finally:
_gateway_instance = None
async def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path): async def create_long_term_memory(agent_name: str, run_id: str, run_dir: Path):
"""Create ReMeTaskLongTermMemory for an agent.""" """Create ReMeTaskLongTermMemory for an agent."""
@@ -321,12 +344,45 @@ async def run_pipeline(
scheduler_callback = scheduler_callback_fn scheduler_callback = scheduler_callback_fn
# Create Gateway for WebSocket connections (after pipeline and scheduler are ready)
gateway = Gateway(
market_service=market_service,
storage_service=storage_service,
pipeline=pipeline,
scheduler_callback=scheduler_callback,
config={
"mode": mode,
"mock_mode": is_mock,
"backtest_mode": is_backtest,
"tickers": tickers,
"config_name": run_id,
"schedule_mode": schedule_mode,
"interval_minutes": interval_minutes,
"trigger_time": trigger_time,
"initial_cash": initial_cash,
"margin_requirement": margin_requirement,
"max_comm_cycles": max_comm_cycles,
"enable_memory": enable_memory,
},
scheduler=live_scheduler,
)
_set_gateway(gateway)
# Start pipeline execution # Start pipeline execution
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
# Enter long-term memory contexts # Enter long-term memory contexts
for memory in long_term_memories: for memory in long_term_memories:
await stack.enter_async_context(memory) await stack.enter_async_context(memory)
# Start Gateway in background task
gateway_task = asyncio.create_task(
gateway.start(host="0.0.0.0", port=8765)
)
logger.info("[Pipeline] Gateway started on ws://localhost:8765")
# Give Gateway a moment to start
await asyncio.sleep(0.5)
# Define the trading cycle callback # Define the trading cycle callback
async def trading_cycle(session_key: str) -> None: async def trading_cycle(session_key: str) -> None:
"""Execute one trading cycle.""" """Execute one trading cycle."""
@@ -360,12 +416,30 @@ async def run_pipeline(
while not stop_event.is_set(): while not stop_event.is_set():
await asyncio.sleep(1) await asyncio.sleep(1)
# Cancel gateway task
if not gateway_task.done():
gateway_task.cancel()
try:
await gateway_task
except asyncio.CancelledError:
pass
except asyncio.CancelledError: except asyncio.CancelledError:
# Handle cancellation gracefully # Handle cancellation gracefully
raise raise
finally: finally:
# Cleanup # Cleanup
logger.info("[Pipeline] Cleaning up...")
# Stop Gateway
try:
stop_gateway()
logger.info("[Pipeline] Gateway stopped")
except Exception as e:
logger.error(f"[Pipeline] Error stopping gateway: {e}")
clear_shutdown_event() clear_shutdown_event()
clear_global_runtime_manager() clear_global_runtime_manager()
from backend.api.runtime import unregister_runtime_manager from backend.api.runtime import unregister_runtime_manager
unregister_runtime_manager() unregister_runtime_manager()
logger.info("[Pipeline] Cleanup complete")

View File

@@ -238,30 +238,45 @@ class Gateway:
self.connected_clients.discard(websocket) self.connected_clients.discard(websocket)
async def _send_initial_state(self, websocket: ServerConnection): async def _send_initial_state(self, websocket: ServerConnection):
state_payload = self.state_sync.get_initial_state_payload( try:
include_dashboard=True, logger.info("[Gateway] Sending initial state to client...")
) state_payload = self.state_sync.get_initial_state_payload(
state_payload["data_sources"] = ( include_dashboard=True,
self._provider_router.get_usage_snapshot() )
) state_payload["data_sources"] = (
# Include market status in initial state self._provider_router.get_usage_snapshot()
state_payload[ )
"market_status" # Include market status in initial state
] = self.market_service.get_market_status() state_payload[
"market_status"
] = self.market_service.get_market_status()
# Include live returns if session is active # Include live returns if session is active
if self.storage.is_live_session_active: if self.storage.is_live_session_active:
live_returns = self.storage.get_live_returns() live_returns = self.storage.get_live_returns()
if "portfolio" in state_payload: if "portfolio" in state_payload:
state_payload["portfolio"].update(live_returns) state_payload["portfolio"].update(live_returns)
await websocket.send( await websocket.send(
json.dumps( json.dumps(
{"type": "initial_state", "state": state_payload}, {"type": "initial_state", "state": state_payload},
ensure_ascii=False, ensure_ascii=False,
default=str, default=str,
), ),
) )
logger.info("[Gateway] Initial state sent successfully")
except Exception as e:
logger.exception(f"[Gateway] Failed to send initial state: {e}")
# Send error response so client knows something went wrong
try:
await websocket.send(
json.dumps(
{"type": "error", "message": "Failed to load initial state"},
ensure_ascii=False,
),
)
except Exception:
pass
async def _handle_client_messages( async def _handle_client_messages(
self, self,