- Migrate OpenClaw from HTTP (port 8004) to WebSocket (port 18789) - Add workspace file list and content preview handlers - Add OpenClawStatus component with agent/skills view - Add OpenClawView panel in trader interface - Add Zustand store for OpenClaw state management - Fix gateway logging noise (yfinance, websockets) - Fix RunWorkspaceManager.get_agent_asset_dir attribute error - Handle missing workspace files gracefully in preview Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
740 lines
26 KiB
Python
740 lines
26 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""OpenClaw Gateway WebSocket client for bidirectional agent communication."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import sys
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Callable
|
|
|
|
import httpx
|
|
import websockets
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from cryptography.hazmat.primitives.asymmetric import ec
|
|
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
|
|
from cryptography.hazmat.backends import default_backend
|
|
|
|
logger = logging.getLogger(__name__)

# Port the OpenClaw Gateway listens on by default, and the loopback WebSocket
# URL derived from it.
DEFAULT_GATEWAY_PORT = 18789
DEFAULT_GATEWAY_URL = "ws://127.0.0.1:" + str(DEFAULT_GATEWAY_PORT)

# Wire protocol version sent as both minProtocol and maxProtocol in the
# connect request (mirrors protocol/schema/protocol-schemas.ts upstream).
PROTOCOL_VERSION = 3
|
|
|
|
|
|
@dataclass
class DeviceIdentity:
    """Device identity (ID plus PEM-encoded key pair) used for Gateway authentication."""

    device_id: str
    public_key_pem: bytes
    private_key_pem: bytes

    @classmethod
    def load_or_create(cls, identity_dir: Path | None = None) -> "DeviceIdentity":
        """Load an existing device identity, or generate and persist a new one.

        Lookup order:
          1. ``<identity_dir>/device.json`` (OpenClaw's format with ``deviceId``,
             ``publicKeyPem`` and ``privateKeyPem`` keys). ``identity_dir``
             defaults to ``~/.openclaw/identity``.
          2. The legacy ``~/.openclaw/devices`` directory (``device_id``,
             ``device_pubkey.pem``, ``device_privkey.pem``). This location is
             fixed and does not depend on ``identity_dir``.
          3. Otherwise a fresh Ed25519 key pair is generated and written to the
             legacy ``~/.openclaw/devices`` directory. The private key file is
             restricted to the owner (mode 0o600).

        Args:
            identity_dir: Directory searched for ``device.json``.

        Returns:
            The loaded or newly created identity.

        Raises:
            KeyError: If ``device.json`` exists but lacks one of the expected keys.
        """
        if identity_dir is None:
            identity_dir = Path.home() / ".openclaw" / "identity"

        # 1. OpenClaw's JSON identity format.
        device_json = identity_dir / "device.json"
        if device_json.exists():
            data = json.loads(device_json.read_text())
            return cls(
                device_id=data["deviceId"],
                public_key_pem=data["publicKeyPem"].encode(),
                private_key_pem=data["privateKeyPem"].encode(),
            )

        # 2. Legacy three-file format.
        device_dir = Path.home() / ".openclaw" / "devices"
        id_file = device_dir / "device_id"
        pubkey_file = device_dir / "device_pubkey.pem"
        privkey_file = device_dir / "device_privkey.pem"

        if id_file.exists() and pubkey_file.exists() and privkey_file.exists():
            return cls(
                device_id=id_file.read_text().strip(),
                public_key_pem=pubkey_file.read_bytes(),
                private_key_pem=privkey_file.read_bytes(),
            )

        # 3. Generate a new Ed25519 identity.
        import hashlib

        from cryptography.hazmat.primitives.asymmetric import ed25519

        private_key = ed25519.Ed25519PrivateKey.generate()
        public_key = private_key.public_key()

        # Device ID is the hex SHA-256 of the raw 32-byte public key.
        public_key_raw = public_key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        device_id = hashlib.sha256(public_key_raw).hexdigest()

        public_key_pem = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        private_key_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

        device_dir.mkdir(parents=True, exist_ok=True)
        id_file.write_text(device_id)
        pubkey_file.write_bytes(public_key_pem)
        # The private key is an unencrypted long-lived credential: make the file
        # owner-only *before* writing key material so it is never world-readable.
        privkey_file.touch(mode=0o600, exist_ok=True)
        privkey_file.chmod(0o600)
        privkey_file.write_bytes(private_key_pem)

        logger.info("Created new device identity: %s", device_id)
        return cls(
            device_id=device_id,
            public_key_pem=public_key_pem,
            private_key_pem=private_key_pem,
        )

    def sign(self, payload: str) -> tuple[int, int]:
        """Sign ``payload`` with ECDSA/SHA-256 and return the ``(r, s)`` integers.

        Only valid for EC private keys; an Ed25519 key's ``sign`` does not
        accept the algorithm argument and will raise.
        """
        private_key = serialization.load_pem_private_key(
            self.private_key_pem, password=None, backend=default_backend()
        )
        signature = private_key.sign(payload.encode(), ec.ECDSA(hashes.SHA256()))
        r, s = decode_dss_signature(signature)
        return r, s

    def sign_base64url(self, payload: str) -> str:
        """Sign ``payload`` and return the unpadded base64url-encoded signature.

        Uses the single-argument ``sign(data)`` form, i.e. Ed25519 signing (the
        scheme used by OpenClaw). EC keys need a hash algorithm and are not
        handled here.
        """
        import base64

        private_key = serialization.load_pem_private_key(
            self.private_key_pem, password=None, backend=default_backend()
        )
        sig = private_key.sign(payload.encode())
        return base64.urlsafe_b64encode(sig).rstrip(b"=").decode()
|
|
|
|
|
|
@dataclass()
class GatewayHello:
    """Parsed payload of the Gateway's response to the ``connect`` request."""

    # Protocol version reported by the Gateway.
    protocol: int
    # Gateway server version string.
    server_version: str
    # Connection id assigned by the Gateway.
    conn_id: str
    # RPC method names the Gateway advertises.
    methods: list[str]
    # Event names the Gateway advertises.
    events: list[str]
    # Optional auth details from the response; None when not provided.
    device_token: str | None = field(default=None)
    role: str | None = field(default=None)
    scopes: list[str] | None = field(default=None)
|
|
|
|
|
|
@dataclass()
class MessageEvent:
    """An ``event`` frame received from the Gateway."""

    # Event name, e.g. "connect.challenge".
    event: str
    # Event payload object.
    payload: dict[str, Any]
    # Frame sequence number, when the Gateway supplies one.
    seq: int | None = field(default=None)
|
|
|
|
|
|
@dataclass()
class SendResult:
    """Outcome of sending a message: message id, target session and success flag."""

    # Identifier of the sent message.
    message_id: str
    # Session the message was sent to.
    session_key: str
    # Whether the send succeeded.
    ok: bool
|
|
|
|
|
|
class OpenClawWebSocketClient:
    """WebSocket client for the OpenClaw Gateway.

    Supports:
        - Device authentication (signed ``connect`` handshake)
        - Sending messages to agents via ``sessions.send``
        - Receiving real-time events via ``sessions.messages.subscribe``
        - Session, agent and channel queries

    Example usage::

        async with OpenClawWebSocketClient() as client:
            await client.connect()
            result = await client.send_message(session_key, "Hello agent!")
            async for event in await client.subscribe(session_key):
                print(event)
    """

    def __init__(
        self,
        url: str = DEFAULT_GATEWAY_URL,
        gateway_token: str | None = None,
        device_identity: DeviceIdentity | None = None,
        client_name: str = "cli",  # Must be a valid GatewayClientId (cli, gateway-client, etc)
        client_version: str = "1.0.0",
        timeout_ms: int = 30000,
    ):
        """Configure the client; the socket is not opened until :meth:`connect`.

        Args:
            url: Gateway WebSocket URL.
            gateway_token: Gateway auth token. When omitted, it is read from
                ``~/.openclaw/openclaw.json`` if present.
            device_identity: Identity used to sign the handshake; loaded or
                created in :meth:`connect` when omitted.
            client_name: Client id reported to (and signed for) the Gateway.
            client_version: Client version reported to the Gateway.
            timeout_ms: Default timeout for requests and awaited events.
        """
        self.url = url
        self.gateway_token = gateway_token or self._load_gateway_token()
        self.device_identity = device_identity
        self.client_name = client_name
        self.client_version = client_version
        self.timeout_ms = timeout_ms

        self._ws: websockets.WebSocketClientProtocol | None = None
        self._hello: GatewayHello | None = None
        # Outstanding requests keyed by request id, resolved by "res" frames.
        self._pending: dict[str, asyncio.Future] = {}
        self._event_handlers: list[Callable[[MessageEvent], None]] = []
        self._recv_task: asyncio.Task | None = None
        self._nonce: str | None = None
        self._connected = False

    @staticmethod
    def _load_gateway_token() -> str | None:
        """Read ``gateway.auth.token`` from ``~/.openclaw/openclaw.json``.

        Best effort: returns None if the file is missing, unreadable, or does
        not have the expected shape.
        """
        try:
            token_file = Path.home() / ".openclaw" / "openclaw.json"
            if not token_file.exists():
                return None
            data = json.loads(token_file.read_text())
            return data.get("gateway", {}).get("auth", {}).get("token")
        except Exception as exc:  # deliberate: the token is optional
            logger.debug("Could not load gateway token: %s", exc)
            return None

    @property
    def is_connected(self) -> bool:
        """True once the handshake completed and the socket is still held."""
        return self._connected and self._ws is not None

    async def __aenter__(self) -> "OpenClawWebSocketClient":
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.disconnect()

    async def connect(self) -> GatewayHello:
        """Connect to the Gateway and complete the authentication handshake.

        Flow: open the socket, start the receive loop, wait for the
        ``connect.challenge`` event, send a signed ``connect`` request, and
        parse its response into a :class:`GatewayHello`. On any failure the
        socket and receive task are torn down before the error propagates.

        Returns:
            The Gateway hello (returned as-is on repeated calls while connected).

        Raises:
            asyncio.TimeoutError / TimeoutError: If the challenge or the
                ``connect`` response does not arrive within ``timeout_ms``.
            Exception: If the Gateway rejects the ``connect`` request.
        """
        if self._connected:
            return self._hello

        if self.device_identity is None:
            self.device_identity = DeviceIdentity.load_or_create()

        logger.info("Connecting to OpenClaw Gateway at %s", self.url)
        self._ws = await websockets.connect(
            self.url,
            max_size=25 * 1024 * 1024,  # 25MB max payload
        )

        try:
            # The receive loop must be running to deliver the challenge event.
            self._recv_task = asyncio.create_task(self._recv_loop())

            challenge = await self._wait_for_event("connect.challenge")
            self._nonce = challenge.payload.get("nonce")

            hello = await self._send_request(
                "connect", self._build_connect_params(), _allow_handshake=True
            )
            auth = hello.get("auth") or {}
            self._hello = GatewayHello(
                protocol=hello["protocol"],
                server_version=hello["server"]["version"],
                conn_id=hello["server"]["connId"],
                methods=hello["features"]["methods"],
                events=hello["features"]["events"],
                device_token=auth.get("deviceToken"),
                role=auth.get("role"),
                scopes=auth.get("scopes"),
            )
        except BaseException:
            # Don't leak the socket / receive task on a failed handshake.
            await self.disconnect()
            raise

        self._connected = True
        logger.info("Connected to OpenClaw Gateway v%s", self._hello.server_version)
        logger.debug("Supported methods: %s", self._hello.methods)
        return self._hello

    def _build_connect_params(self) -> dict[str, Any]:
        """Build signed ``connect`` request parameters.

        Implements the V3 device auth payload format::

            v3|deviceId|clientId|clientMode|role|scopes|signedAtMs|token|nonce|platform|deviceFamily

        The payload is signed with the device key; the signature, public key,
        timestamp and nonce are sent in the ``device`` section.
        """
        import base64

        from cryptography.hazmat.primitives.asymmetric import ed25519

        private_key = serialization.load_pem_private_key(
            self.device_identity.private_key_pem, password=None, backend=default_backend()
        )
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            # Ed25519: send the raw 32-byte public key.
            public_key_raw = private_key.public_key().public_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PublicFormat.Raw,
            )
        else:
            # ECDSA: send the DER-encoded SubjectPublicKeyInfo.
            public_key = serialization.load_pem_public_key(
                self.device_identity.public_key_pem, backend=default_backend()
            )
            public_key_raw = public_key.public_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            )
        public_key_b64 = base64.urlsafe_b64encode(public_key_raw).rstrip(b"=").decode()

        signed_at_ms = int(time.time() * 1000)
        scopes = "operator.admin,operator.approvals,operator.pairing,operator.read,operator.write"
        token = self.gateway_token or ""
        # NOTE(review): platform is hard-coded rather than derived from the host
        # OS — confirm the Gateway does not depend on it being accurate.
        platform = "darwin"
        device_family = ""

        auth_payload = "|".join([
            "v3",
            self.device_identity.device_id,
            self.client_name,  # clientId
            "backend",  # clientMode
            "operator",  # role
            scopes,
            str(signed_at_ms),
            token,
            self._nonce or "",
            platform,
            device_family,
        ])
        signature_b64 = self.device_identity.sign_base64url(auth_payload)

        params = {
            "minProtocol": PROTOCOL_VERSION,
            "maxProtocol": PROTOCOL_VERSION,
            "client": {
                "id": self.client_name,
                "version": self.client_version,
                "platform": platform,
                "mode": "backend",
            },
            "device": {
                "id": self.device_identity.device_id,
                "publicKey": public_key_b64,
                "signature": signature_b64,
                "signedAt": signed_at_ms,
                "nonce": self._nonce,
            },
            "auth": {
                "token": token or None,
            },
            "role": "operator",
            "scopes": scopes.split(","),
        }

        # Never log the token, payload or signature: they are credentials.
        logger.debug(
            "Built connect params for device %s (nonce=%s)",
            self.device_identity.device_id,
            self._nonce,
        )
        return params

    async def _recv_loop(self) -> None:
        """Read frames until the socket closes, dispatching each to _handle_frame.

        Malformed frames are skipped rather than ending the loop. On exit, all
        outstanding requests fail and the client is marked disconnected.
        """
        try:
            async for raw in self._ws:
                if raw is None:
                    break
                try:
                    frame = json.loads(raw)
                except (TypeError, ValueError):
                    logger.warning("Ignoring malformed Gateway frame: %r", raw[:200])
                    continue
                if not isinstance(frame, dict):
                    logger.warning("Ignoring non-object Gateway frame: %r", frame)
                    continue
                await self._handle_frame(frame)
        except websockets.exceptions.ConnectionClosed:
            pass
        except Exception:
            logger.exception("Receive loop error")
        finally:
            # Fail every outstanding request so callers don't hang until timeout.
            for future in list(self._pending.values()):
                if not future.done():
                    future.set_exception(ConnectionError("Connection closed"))
            self._pending.clear()
            self._connected = False

    async def _handle_frame(self, frame: dict[str, Any]) -> None:
        """Dispatch a decoded frame: "event" frames to handlers, "res" frames to requests."""
        frame_type = frame.get("type")

        if frame_type == "event":
            event_name = frame.get("event", "")
            payload = frame.get("payload") or {}

            if event_name == "connect.challenge":
                nonce = payload.get("nonce")
                if nonce:
                    self._nonce = nonce

            event = MessageEvent(event=event_name, payload=payload, seq=frame.get("seq"))
            # Iterate a snapshot: handlers may add/remove handlers while running.
            for handler in list(self._event_handlers):
                try:
                    handler(event)
                except Exception:
                    logger.exception("Event handler error")

        elif frame_type == "res":
            future = self._pending.pop(frame.get("id"), None)
            if future is None or future.done():
                # The caller already timed out or was cancelled; drop the response.
                return
            if frame.get("ok"):
                future.set_result(frame.get("payload", {}))
            else:
                error = frame.get("error") or {}
                future.set_exception(
                    Exception(f"{error.get('code', 'ERROR')}: {error.get('message', 'Unknown error')}")
                )

    async def _wait_for_event(self, event_name: str, timeout_ms: int | None = None) -> MessageEvent:
        """Wait for the next event named ``event_name``.

        Raises:
            asyncio.TimeoutError: If it does not arrive within the timeout
                (``timeout_ms`` or the client default).
        """
        future = asyncio.get_running_loop().create_future()
        timeout = timeout_ms or self.timeout_ms

        def handler(event: MessageEvent) -> None:
            if event.event == event_name and not future.done():
                future.set_result(event)

        self._event_handlers.append(handler)
        try:
            return await asyncio.wait_for(future, timeout / 1000)
        finally:
            self._event_handlers.remove(handler)

    async def _send_request(
        self,
        method: str,
        params: dict[str, Any] | None = None,
        _allow_handshake: bool = False,
        timeout_ms: int | None = None,
    ) -> dict[str, Any]:
        """Send a request and wait for its response payload.

        Args:
            method: The RPC method name.
            params: Optional parameters (omitted from the frame when empty).
            _allow_handshake: If True, allow sending before the handshake has
                completed (used for the ``connect`` method).
            timeout_ms: Per-request timeout; defaults to the client timeout.

        Raises:
            Exception: If not connected, or if the Gateway returns an error.
            TimeoutError: If no response arrives in time.
        """
        if not self._ws:
            raise Exception("Not connected to Gateway")
        if not self._connected and not _allow_handshake:
            raise Exception("Not connected to Gateway")

        req_id = str(uuid.uuid4())
        frame = {"type": "req", "id": req_id, "method": method}
        if params:
            frame["params"] = params

        timeout = timeout_ms or self.timeout_ms
        future = asyncio.get_running_loop().create_future()
        self._pending[req_id] = future
        try:
            await self._ws.send(json.dumps(frame))
            return await asyncio.wait_for(future, timeout / 1000)
        except asyncio.TimeoutError:
            raise TimeoutError(f"Request {method} timed out after {timeout}ms")
        finally:
            # Covers send failures, timeouts and caller cancellation alike.
            self._pending.pop(req_id, None)

    async def call_method(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
        """Call any RPC method on the Gateway.

        Args:
            method: The RPC method name (e.g., "sessions.list", "agents.list")
            params: Optional parameters for the method

        Returns:
            The response payload from the Gateway
        """
        return await self._send_request(method, params)

    async def disconnect(self) -> None:
        """Stop the receive loop, close the socket and clear connection state."""
        self._connected = False

        recv_task, self._recv_task = self._recv_task, None
        if recv_task is not None:
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass

        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

        self._hello = None

    # -------------------------------------------------------------------------
    # Session operations
    # -------------------------------------------------------------------------

    async def list_sessions(
        self,
        limit: int = 50,
        agent_id: str | None = None,
        include_last_message: bool = True,
    ) -> list[dict[str, Any]]:
        """List sessions via ``sessions.list`` (the ``sessions`` list from the response)."""
        params: dict[str, Any] = {"limit": limit, "includeLastMessage": include_last_message}
        if agent_id:
            params["agentId"] = agent_id

        result = await self._send_request("sessions.list", params)
        return result.get("sessions", [])

    async def resolve_session(
        self,
        agent_id: str | None = None,
        label: str | None = None,
        channel: str | None = None,
        include_global: bool = True,
    ) -> str | None:
        """Resolve a session key by agent and optional label.

        Args:
            agent_id: Agent to resolve for.
            label: Optional session label.
            channel: Accepted for API compatibility but currently NOT sent to
                the Gateway.
            include_global: Forwarded as ``includeGlobal``.

        Returns:
            The ``key`` of the first session in the response, or None.
        """
        params: dict[str, Any] = {"includeGlobal": include_global}
        if agent_id:
            params["agentId"] = agent_id
        if label:
            params["label"] = label

        result = await self._send_request("sessions.resolve", params)
        sessions = result.get("sessions", [])
        if sessions:
            return sessions[0].get("key")
        return None

    async def send_message(
        self,
        session_key: str,
        message: str,
        thinking: str | None = None,
        timeout_ms: int | None = None,
    ) -> dict[str, Any]:
        """Send a message to an agent session.

        Args:
            session_key: The session key (format: agentId:channelId:accountId:conversationId)
            message: The message text to send
            thinking: Optional thinking/reasoning to include
            timeout_ms: Timeout for this request; defaults to the client timeout

        Returns:
            The response payload containing message ID and result
        """
        params: dict[str, Any] = {
            "key": session_key,
            "message": message,
        }
        if thinking:
            params["thinking"] = thinking

        # sessions.send may wait on the agent, so honour a per-call timeout.
        return await self._send_request("sessions.send", params, timeout_ms=timeout_ms)

    async def subscribe(self, session_key: str) -> AsyncMessageIterator:
        """Subscribe to messages from a session.

        Usage::

            async for event in await client.subscribe(session_key):
                print(f"Event: {event.event}", event.payload)

        Args:
            session_key: The session key to subscribe to

        Returns:
            An AsyncMessageIterator yielding MessageEvents
        """
        await self._send_request("sessions.messages.subscribe", {"key": session_key})
        return AsyncMessageIterator(self, session_key)

    # -------------------------------------------------------------------------
    # Agent operations
    # -------------------------------------------------------------------------

    async def list_agents(self) -> list[dict[str, Any]]:
        """List configured agents via ``agents.list``."""
        result = await self._send_request("agents.list", {})
        return result.get("agents", [])

    async def get_agent(self, agent_id: str) -> dict[str, Any] | None:
        """Return the agent whose ``id`` equals ``agent_id``, or None."""
        agents = await self.list_agents()
        for agent in agents:
            if agent.get("id") == agent_id:
                return agent
        return None

    # -------------------------------------------------------------------------
    # Channel operations
    # -------------------------------------------------------------------------

    async def channels_status(self, probe: bool = False) -> dict[str, Any]:
        """Get channel status via ``channels.status`` (``probe`` sent only when True)."""
        params = {"probe": probe} if probe else {}
        return await self._send_request("channels.status", params)

    async def channels_list(self) -> list[dict[str, Any]]:
        """List configured channels via ``channels.list``."""
        result = await self._send_request("channels.list", {})
        return result.get("channels", [])

    # -------------------------------------------------------------------------
    # Convenience methods
    # -------------------------------------------------------------------------

    async def send_to_agent(
        self,
        agent_id: str,
        message: str,
        channel: str | None = None,
        label: str | None = None,
    ) -> dict[str, Any]:
        """Resolve the agent's session and send ``message`` to it.

        Args:
            agent_id: The agent ID
            message: Message to send
            channel: Passed to resolve_session (currently not forwarded to the Gateway)
            label: Optional session label

        Returns:
            The sessions.send response payload

        Raises:
            ValueError: If no session could be resolved for the agent.
        """
        session_key = await self.resolve_session(agent_id=agent_id, label=label, channel=channel)
        if not session_key:
            raise ValueError(f"No session found for agent {agent_id}")

        return await self.send_message(session_key, message)

    def add_event_handler(self, handler: Callable[[MessageEvent], None]) -> None:
        """Register a synchronous callback invoked for every incoming event."""
        self._event_handlers.append(handler)

    def remove_event_handler(self, handler: Callable[[MessageEvent], None]) -> None:
        """Unregister a callback; raises ValueError if it was not registered."""
        self._event_handlers.remove(handler)
|
|
|
|
|
|
class AsyncMessageIterator:
    """Async iterator over events relevant to one session.

    The first time it is iterated, it registers an event handler on the client
    and buffers matching events in an unbounded queue. Call :meth:`close` to
    unregister the handler when you are done.
    """

    def __init__(self, client: OpenClawWebSocketClient, session_key: str):
        self._client = client
        self._session_key = session_key
        self._queue: asyncio.Queue[MessageEvent] = asyncio.Queue()
        self._handler_added = False

    def _on_event(self, event: MessageEvent) -> None:
        """Queue events addressed to this session, plus every ``sessions.*`` event."""
        payload = event.payload or {}
        event_session_key = payload.get("sessionKey") or payload.get("key")
        # NOTE(review): any "sessions.*" event is forwarded even when it belongs
        # to a different session — confirm this is intended.
        if event_session_key == self._session_key or event.event.startswith("sessions."):
            self._queue.put_nowait(event)

    def __aiter__(self) -> "AsyncMessageIterator":
        # Must be a plain method: ``async for`` calls __aiter__ and needs an
        # async iterator back, not a coroutine.
        if not self._handler_added:
            self._client.add_event_handler(self._on_event)
            self._handler_added = True
        return self

    async def __anext__(self) -> MessageEvent:
        """Wait for and return the next matching event (never stops on its own)."""
        return await self._queue.get()

    def close(self) -> None:
        """Unregister this iterator's handler from the client (idempotent)."""
        if self._handler_added:
            self._client.remove_event_handler(self._on_event)
            self._handler_added = False
|
|
|
|
|
|
# -----------------------------------------------------------------------------
|
|
# Module-level async convenience functions (one connection per call)
|
|
# -----------------------------------------------------------------------------
|
|
|
|
async def send_to_agent(
    message: str,
    agent_id: str,
    gateway_url: str = DEFAULT_GATEWAY_URL,
    gateway_token: str | None = None,
    channel: str | None = None,
    label: str | None = None,
    timeout_ms: int = 60000,
) -> dict[str, Any]:
    """Send a message to an agent and wait for the response.

    Opens a dedicated connection, performs the handshake, resolves the agent's
    session, sends the message, and closes the connection.

    Args:
        message: The message to send
        agent_id: The agent ID to target
        gateway_url: Gateway WebSocket URL
        gateway_token: Optional gateway auth token
        channel: Optional channel (passed to session resolution)
        label: Optional session label
        timeout_ms: Timeout in milliseconds applied to the handshake and to
            every request on the connection

    Returns:
        The sessions.send response payload

    Raises:
        ValueError: If no session could be resolved for the agent.

    Example:
        response = await send_to_agent("Hello!", agent_id="my-agent")
    """
    async with OpenClawWebSocketClient(
        url=gateway_url,
        gateway_token=gateway_token,
        # Apply the caller's timeout to the whole connection.
        timeout_ms=timeout_ms,
    ) as client:
        await client.connect()
        return await client.send_to_agent(
            agent_id=agent_id,
            message=message,
            channel=channel,
            label=label,
        )
|
|
|
|
|
|
async def list_active_sessions(
    gateway_url: str = DEFAULT_GATEWAY_URL,
    gateway_token: str | None = None,
    agent_id: str | None = None,
) -> list[dict[str, Any]]:
    """Open a one-shot connection and return the Gateway's session list.

    Args:
        gateway_url: Gateway WebSocket URL
        gateway_token: Optional gateway auth token
        agent_id: Optional agent ID to filter by

    Returns:
        The sessions reported by ``sessions.list``
    """
    client = OpenClawWebSocketClient(url=gateway_url, gateway_token=gateway_token)
    async with client:
        await client.connect()
        sessions = await client.list_sessions(agent_id=agent_id)
    return sessions
|