# -*- coding: utf-8 -*-
"""Agent service client for agent orchestration and runtime operations."""

import json
from typing import Any, AsyncIterator
from urllib.parse import quote, urlencode

import httpx
import websockets

from shared.schema.signals import AgentStateData


class AgentServiceClient:
    """Async client for the Agent Service API.

    The HTTP methods require the instance to be entered with ``async with``
    first, because that is where the underlying ``httpx.AsyncClient`` is
    created. ``websocket_connect`` opens its own connection and only needs
    ``base_url``.
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        """Initialize the client with a base URL.

        Args:
            base_url: Base URL for the agent service API. Trailing slashes
                are stripped.
        """
        self.base_url = base_url.rstrip("/")
        self._client: httpx.AsyncClient | None = None

    async def __aenter__(self) -> "AgentServiceClient":
        """Create the HTTP client (30 second timeout) and return ``self``."""
        self._client = httpx.AsyncClient(base_url=self.base_url, timeout=30.0)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the HTTP client, if one was opened."""
        # Detach before closing so the instance never keeps a reference to a
        # closed client, even when aclose() itself raises.
        client, self._client = self._client, None
        if client is not None:
            await client.aclose()

    def _require_client(self) -> httpx.AsyncClient:
        """Return the open HTTP client or fail with a descriptive error."""
        if self._client is None:
            raise RuntimeError(
                "AgentServiceClient has no open HTTP client; "
                "use it as 'async with AgentServiceClient(...) as client:'"
            )
        return self._client

    @staticmethod
    def _json_or_raise(response: httpx.Response) -> Any:
        """Raise on an error status, otherwise return the decoded JSON body."""
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _segment(value: str) -> str:
        """Percent-encode a value for use as a single URL path segment."""
        # safe="" also escapes "/", so an identifier cannot add path levels.
        return quote(str(value), safe="")

    def _websocket_url(self, run_id: str | None = None) -> str:
        """Build the WebSocket endpoint URL derived from ``base_url``."""
        scheme, separator, remainder = self.base_url.partition("://")
        # Only the scheme is rewritten; a blanket str.replace would also
        # mangle hosts or paths that happen to contain "http".
        ws_scheme = {"http": "ws", "https": "wss"}.get(scheme.lower())
        if separator and ws_scheme is not None:
            url = f"{ws_scheme}://{remainder}/ws"
        else:
            url = f"{self.base_url}/ws"
        if run_id:
            url += "?" + urlencode({"run_id": run_id})
        return url

    async def get_agents(self) -> dict:
        """Get list of all registered agents.

        Returns:
            Decoded JSON body of ``GET /api/agents``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        return self._json_or_raise(await http.get("/api/agents"))

    async def get_agent_status(self, agent_id: str) -> dict:
        """Get status of a specific agent.

        Args:
            agent_id: The agent identifier.

        Returns:
            Decoded JSON body of ``GET /api/agents/{agent_id}/status``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        path = f"/api/agents/{self._segment(agent_id)}/status"
        return self._json_or_raise(await http.get(path))

    async def post_run_daily(
        self,
        tickers: list[str],
        start_date: str,
        end_date: str,
        runtime_config: dict[str, Any] | None = None,
    ) -> dict:
        """Trigger a daily analysis run.

        Args:
            tickers: List of stock tickers to analyze.
            start_date: Start date (YYYY-MM-DD).
            end_date: End date (YYYY-MM-DD).
            runtime_config: Optional runtime configuration. It is left out
                of the request when ``None`` or empty.

        Returns:
            Decoded JSON body of ``POST /api/run/daily``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        payload: dict[str, Any] = {
            "tickers": tickers,
            "start_date": start_date,
            "end_date": end_date,
        }
        if runtime_config:
            payload["runtime_config"] = runtime_config
        return self._json_or_raise(
            await http.post("/api/run/daily", json=payload)
        )

    async def get_run_status(self, run_id: str) -> dict:
        """Get status of a run.

        Args:
            run_id: The run identifier.

        Returns:
            Decoded JSON body of ``GET /api/runs/{run_id}/status``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        path = f"/api/runs/{self._segment(run_id)}/status"
        return self._json_or_raise(await http.get(path))

    async def get_run_result(self, run_id: str) -> AgentStateData:
        """Get the result of a completed run.

        Args:
            run_id: The run identifier.

        Returns:
            The JSON body of ``GET /api/runs/{run_id}/result`` validated
            into an ``AgentStateData``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        path = f"/api/runs/{self._segment(run_id)}/result"
        body = self._json_or_raise(await http.get(path))
        return AgentStateData.model_validate(body)

    async def get_run_logs(self, run_id: str) -> dict:
        """Get logs for a run.

        Args:
            run_id: The run identifier.

        Returns:
            Decoded JSON body of ``GET /api/runs/{run_id}/logs``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        path = f"/api/runs/{self._segment(run_id)}/logs"
        return self._json_or_raise(await http.get(path))

    async def cancel_run(self, run_id: str) -> dict:
        """Cancel a running task.

        Args:
            run_id: The run identifier.

        Returns:
            Decoded JSON body of ``POST /api/runs/{run_id}/cancel``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        path = f"/api/runs/{self._segment(run_id)}/cancel"
        return self._json_or_raise(await http.post(path))

    async def get_runtime_config(self) -> dict:
        """Get current runtime configuration.

        Returns:
            Decoded JSON body of ``GET /api/runtime/config``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        return self._json_or_raise(await http.get("/api/runtime/config"))

    async def update_runtime_config(self, config: dict[str, Any]) -> dict:
        """Update runtime configuration.

        Args:
            config: New runtime configuration, sent as the JSON body.

        Returns:
            Decoded JSON body of ``PUT /api/runtime/config``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        return self._json_or_raise(
            await http.put("/api/runtime/config", json=config)
        )

    async def websocket_connect(
        self,
        run_id: str | None = None,
    ) -> AsyncIterator[dict]:
        """Connect to WebSocket for real-time updates.

        The connection is opened on the first iteration and closed when the
        generator finishes or is closed.

        Args:
            run_id: Optional run ID to subscribe to. It is sent URL-encoded
                as the ``run_id`` query parameter.

        Yields:
            Each received WebSocket message decoded from JSON.
        """
        async with websockets.connect(self._websocket_url(run_id)) as ws:
            async for message in ws:
                yield json.loads(message)

    async def get_pipeline_status(self) -> dict:
        """Get current pipeline execution status.

        Returns:
            Decoded JSON body of ``GET /api/pipeline/status``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        return self._json_or_raise(await http.get("/api/pipeline/status"))

    async def trigger_pipeline(
        self,
        pipeline_type: str,
        tickers: list[str],
        config: dict[str, Any] | None = None,
    ) -> dict:
        """Trigger a pipeline execution.

        Args:
            pipeline_type: Type of pipeline to run.
            tickers: List of tickers to process.
            config: Optional pipeline configuration. It is left out of the
                request when ``None`` or empty.

        Returns:
            Decoded JSON body of ``POST /api/pipeline/trigger``.

        Raises:
            RuntimeError: If the client is used outside ``async with``.
            httpx.HTTPStatusError: If the service returns an error status.
        """
        http = self._require_client()
        payload: dict[str, Any] = {
            "pipeline_type": pipeline_type,
            "tickers": tickers,
        }
        if config:
            payload["config"] = config
        return self._json_or_raise(
            await http.post("/api/pipeline/trigger", json=payload)
        )