Files
evotraders/shared/client/openclaw_websocket_client.py

754 lines
26 KiB
Python

# -*- coding: utf-8 -*-
"""OpenClaw Gateway WebSocket client for bidirectional agent communication."""
from __future__ import annotations
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
import httpx
import websockets
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature
from cryptography.hazmat.backends import default_backend
logger = logging.getLogger(name=__name__)

# Local Gateway endpoint used when callers do not supply one; the URL is
# derived from the port so the two can never disagree.
DEFAULT_GATEWAY_PORT: int = 18789
DEFAULT_GATEWAY_URL: str = f"ws://127.0.0.1:{DEFAULT_GATEWAY_PORT}"

# Wire protocol version this client speaks
# (mirrors protocol/schema/protocol-schemas.ts).
PROTOCOL_VERSION: int = 3
@dataclass
class DeviceIdentity:
    """Device identity (id plus PEM key pair) used for Gateway authentication."""

    device_id: str
    public_key_pem: bytes  # PEM-encoded public key
    private_key_pem: bytes  # PEM-encoded private key (unencrypted); keep secret

    @classmethod
    def load_or_create(cls, identity_dir: Path | None = None) -> "DeviceIdentity":
        """Load an existing device identity, or generate and persist a new one.

        Sources are tried in this order:

        1. ``<identity_dir>/device.json`` in OpenClaw's format (keys ``deviceId``,
           ``publicKeyPem``, ``privateKeyPem``). ``identity_dir`` defaults to
           ``~/.openclaw/identity``.
        2. The legacy ``~/.openclaw/devices`` directory holding ``device_id``,
           ``device_pubkey.pem`` and ``device_privkey.pem``. This location is
           not affected by ``identity_dir``.
        3. A freshly generated Ed25519 key pair whose id is the SHA-256 hex
           digest of the raw 32-byte public key. It is saved in the legacy
           format from step 2. A newly created directory gets mode 0o700 and the
           private key file mode 0o600 (effective on POSIX systems).

        Args:
            identity_dir: Directory that may contain ``device.json``.

        Returns:
            The loaded or newly created identity.
        """
        if identity_dir is None:
            identity_dir = Path.home() / ".openclaw" / "identity"
        device_json = identity_dir / "device.json"
        # 1. OpenClaw's own identity file.
        if device_json.exists():
            data = json.loads(device_json.read_text())
            return cls(
                device_id=data["deviceId"],
                public_key_pem=data["publicKeyPem"].encode(),
                private_key_pem=data["privateKeyPem"].encode(),
            )

        # 2. Legacy per-file layout.
        device_dir = Path.home() / ".openclaw" / "devices"
        id_file = device_dir / "device_id"
        pubkey_file = device_dir / "device_pubkey.pem"
        privkey_file = device_dir / "device_privkey.pem"
        if id_file.exists() and pubkey_file.exists() and privkey_file.exists():
            return cls(
                device_id=id_file.read_text().strip(),
                public_key_pem=pubkey_file.read_bytes(),
                private_key_pem=privkey_file.read_bytes(),
            )

        # 3. Generate a new identity (Ed25519, matching OpenClaw's approach).
        import hashlib

        from cryptography.hazmat.primitives.asymmetric import ed25519

        private_key = ed25519.Ed25519PrivateKey.generate()
        public_key = private_key.public_key()
        # Device ID is the SHA-256 hex digest of the raw public key bytes.
        public_key_raw = public_key.public_bytes(
            encoding=serialization.Encoding.Raw,
            format=serialization.PublicFormat.Raw,
        )
        device_id = hashlib.sha256(public_key_raw).hexdigest()
        public_key_pem = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        private_key_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

        device_dir.mkdir(mode=0o700, parents=True, exist_ok=True)
        id_file.write_text(device_id)
        pubkey_file.write_bytes(public_key_pem)
        # Create the private key file owner-only *before* writing the secret so
        # it is never readable by other users, even briefly.
        privkey_file.touch(mode=0o600, exist_ok=True)
        privkey_file.chmod(0o600)
        privkey_file.write_bytes(private_key_pem)
        logger.info("Created new device identity: %s", device_id)
        return cls(
            device_id=device_id,
            public_key_pem=public_key_pem,
            private_key_pem=private_key_pem,
        )

    def sign(self, payload: str) -> tuple[int, int]:
        """Sign ``payload`` (UTF-8) with ECDSA/SHA-256 and return the ``(r, s)`` pair.

        Only valid for elliptic-curve private keys; the ``cryptography`` library
        rejects the ECDSA algorithm argument for other key types.
        """
        private_key = serialization.load_pem_private_key(
            self.private_key_pem, password=None, backend=default_backend()
        )
        signature = private_key.sign(payload.encode(), ec.ECDSA(hashes.SHA256()))
        r, s = decode_dss_signature(signature)
        return r, s

    def sign_base64url(self, payload: str) -> str:
        """Sign ``payload`` (UTF-8) and return the unpadded base64url signature.

        Ed25519 keys (what OpenClaw generates) are signed directly. Elliptic-curve
        keys are signed with ECDSA/SHA-256, producing a DER-encoded signature.

        Raises:
            TypeError: If the private key is neither Ed25519 nor elliptic-curve.
        """
        import base64

        from cryptography.hazmat.primitives.asymmetric import ed25519

        private_key = serialization.load_pem_private_key(
            self.private_key_pem, password=None, backend=default_backend()
        )
        data = payload.encode()
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            sig = private_key.sign(data)
        elif isinstance(private_key, ec.EllipticCurvePrivateKey):
            # NOTE(review): assumes the Gateway verifies ECDSA device proofs as
            # DER-encoded SHA-256 signatures (Node crypto.sign's default) — confirm.
            sig = private_key.sign(data, ec.ECDSA(hashes.SHA256()))
        else:
            raise TypeError(
                f"Unsupported device key type: {type(private_key).__name__}"
            )
        return base64.urlsafe_b64encode(sig).rstrip(b"=").decode()
@dataclass
class GatewayHello:
    """Parsed result of a successful Gateway ``connect`` handshake (hello-ok)."""

    protocol: int  # protocol version reported by the Gateway
    server_version: str  # Gateway server version string
    conn_id: str  # connection id assigned by the Gateway
    methods: list[str]  # RPC method names the Gateway advertises
    events: list[str]  # event names the Gateway advertises
    # Optional auth details; None when the Gateway does not report them.
    device_token: str | None = field(default=None)
    role: str | None = field(default=None)
    scopes: list[str] | None = field(default=None)
@dataclass
class MessageEvent:
    """A single ``event`` frame received from the Gateway."""

    event: str  # event name, e.g. "connect.challenge"
    payload: dict[str, Any]  # the frame's payload object
    seq: int | None = field(default=None)  # the frame's sequence number, if present
@dataclass
class SendResult:
    """Outcome of a send: the message id, the target session and a success flag."""

    message_id: str  # id of the sent message
    session_key: str  # session the message was sent to
    ok: bool  # True when the send succeeded
class OpenClawWebSocketClient:
    """WebSocket client for the OpenClaw Gateway.

    Supports:
    - Device authentication (signed challenge/nonce handshake)
    - Sending messages to agents via ``sessions.send``
    - Receiving real-time events via ``sessions.messages.subscribe``
    - Session, agent and channel queries

    Example usage::

        async with OpenClawWebSocketClient() as client:
            await client.connect()
            result = await client.send_message(session_key, "Hello agent!")
            events = await client.subscribe(session_key)
            async for event in events:
                print(event)

    ``subscribe`` is a coroutine: await it to obtain the iterator.
    """

    def __init__(
        self,
        url: str = DEFAULT_GATEWAY_URL,
        gateway_token: str | None = None,
        device_identity: DeviceIdentity | None = None,
        client_name: str = "cli",  # Must be a valid GatewayClientId (cli, gateway-client, etc)
        client_version: str = "1.0.0",
        timeout_ms: int = 30000,
        client_platform: str = "darwin",
    ):
        """Configure the client; no network I/O happens until :meth:`connect`.

        Args:
            url: Gateway WebSocket URL.
            gateway_token: Shared gateway auth token. When falsy, the token is
                read from ``~/.openclaw/openclaw.json`` (``gateway.auth.token``).
            device_identity: Key pair used to sign the handshake; loaded or
                created by :meth:`connect` when omitted.
            client_name: Client id sent in the handshake and the signed payload.
            client_version: Client version sent in the handshake.
            timeout_ms: Default timeout for the handshake and for each request.
            client_platform: Platform string sent in the handshake and the
                signed payload. Defaults to ``"darwin"`` for backward compatibility.
        """
        self.url = url
        self.gateway_token = gateway_token or self._load_gateway_token()
        self.device_identity = device_identity
        self.client_name = client_name
        self.client_version = client_version
        self.timeout_ms = timeout_ms
        self.client_platform = client_platform
        self._ws: websockets.WebSocketClientProtocol | None = None
        self._hello: GatewayHello | None = None
        # Request id -> future resolved by the matching "res" frame.
        self._pending: dict[str, asyncio.Future] = {}
        self._event_handlers: list[Callable[[MessageEvent], None]] = []
        self._recv_task: asyncio.Task | None = None
        # Nonce from the latest connect.challenge; it is signed into the auth payload.
        self._nonce: str | None = None
        self._connected = False

    @staticmethod
    def _load_gateway_token() -> str | None:
        """Best-effort read of ``gateway.auth.token`` from ``~/.openclaw/openclaw.json``.

        Returns:
            The token, or None when the file is missing, unreadable, malformed
            or lacks the key. Failures are only logged at debug level because a
            missing token must not prevent constructing the client.
        """
        try:
            token_file = Path.home() / ".openclaw" / "openclaw.json"
            if token_file.exists():
                data = json.loads(token_file.read_text())
                return data.get("gateway", {}).get("auth", {}).get("token")
        except Exception as exc:
            logger.debug("Could not load gateway token from ~/.openclaw/openclaw.json: %s", exc)
        return None

    @property
    def is_connected(self) -> bool:
        """True after a successful handshake until the connection is closed."""
        return self._connected and self._ws is not None

    async def __aenter__(self) -> "OpenClawWebSocketClient":
        # Entering the context does not connect; call connect() explicitly.
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.disconnect()

    async def connect(self) -> GatewayHello:
        """Open the WebSocket and complete the device-authentication handshake.

        Waits for the Gateway's ``connect.challenge`` event, sends a signed
        ``connect`` request that embeds the challenge nonce, and parses the
        reply. When already connected, the cached hello is returned. If any
        step fails or is cancelled, the socket and receive task are torn down
        before the exception propagates.

        Returns:
            The parsed Gateway hello.

        Raises:
            TimeoutError / asyncio.TimeoutError: If the challenge or the connect
                response does not arrive within ``timeout_ms``.
            Exception: If the Gateway answers the connect request with an error.
        """
        if self._connected:
            return self._hello
        if self.device_identity is None:
            self.device_identity = DeviceIdentity.load_or_create()
        logger.info("Connecting to OpenClaw Gateway at %s", self.url)
        self._ws = await websockets.connect(
            self.url,
            max_size=25 * 1024 * 1024,  # 25MB max payload
        )
        try:
            # create_task() only schedules the loop; it first runs at the next
            # await, by which time the challenge waiter below is registered.
            self._recv_task = asyncio.create_task(self._recv_loop())
            challenge = await self._wait_for_event("connect.challenge")
            self._nonce = challenge.payload.get("nonce")
            connect_params = self._build_connect_params()
            hello_payload = await self._send_request(
                "connect", connect_params, _allow_handshake=True
            )
            self._hello = self._parse_hello(hello_payload)
        except BaseException:
            # Never leave a half-open socket or a dangling receive task behind.
            await self.disconnect()
            raise
        self._connected = True
        logger.info("Connected to OpenClaw Gateway v%s", self._hello.server_version)
        logger.info("Supported methods: %s", self._hello.methods)
        return self._hello

    @staticmethod
    def _parse_hello(payload: dict[str, Any]) -> GatewayHello:
        """Convert the ``connect`` response payload into a :class:`GatewayHello`.

        ``protocol`` and ``server.version``/``server.connId`` are required
        (KeyError otherwise); missing ``features`` lists default to empty and
        missing ``auth`` fields to None.
        """
        server = payload["server"]
        features = payload.get("features") or {}
        auth = payload.get("auth") or {}
        return GatewayHello(
            protocol=payload["protocol"],
            server_version=server["version"],
            conn_id=server["connId"],
            methods=features.get("methods", []),
            events=features.get("events", []),
            device_token=auth.get("deviceToken"),
            role=auth.get("role"),
            scopes=auth.get("scopes"),
        )

    def _build_connect_params(self) -> dict[str, Any]:
        """Build the ``connect`` request params, including the signed device proof.

        The signature covers the V3 device auth payload::

            v3|deviceId|clientId|clientMode|role|scopes|signedAtMs|token|nonce|platform|deviceFamily

        The public key is sent base64url-encoded without padding: the raw 32
        bytes for Ed25519 keys, DER SubjectPublicKeyInfo for other key types.
        """
        import base64

        from cryptography.hazmat.primitives.asymmetric import ed25519

        private_key = serialization.load_pem_private_key(
            self.device_identity.private_key_pem, password=None, backend=default_backend()
        )
        if isinstance(private_key, ed25519.Ed25519PrivateKey):
            public_key_raw = private_key.public_key().public_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PublicFormat.Raw,
            )
        else:
            # ECDSA: use SPKI format
            public_key = serialization.load_pem_public_key(
                self.device_identity.public_key_pem, backend=default_backend()
            )
            public_key_raw = public_key.public_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PublicFormat.SubjectPublicKeyInfo,
            )
        public_key_b64 = base64.urlsafe_b64encode(public_key_raw).rstrip(b"=").decode()

        signed_at_ms = int(time.time() * 1000)
        scopes = "operator.admin,operator.approvals,operator.pairing,operator.read,operator.write"
        token = self.gateway_token or ""
        platform = self.client_platform
        device_family = ""
        client_mode = "backend"
        role = "operator"
        auth_payload = "|".join([
            "v3",
            self.device_identity.device_id,
            self.client_name,  # clientId
            client_mode,
            role,
            scopes,
            str(signed_at_ms),
            token,
            self._nonce or "",
            platform,
            device_family,
        ])
        signature_b64 = self.device_identity.sign_base64url(auth_payload)
        return {
            "minProtocol": PROTOCOL_VERSION,
            "maxProtocol": PROTOCOL_VERSION,
            "client": {
                "id": self.client_name,
                "version": self.client_version,
                "platform": platform,
                "mode": client_mode,
            },
            "device": {
                "id": self.device_identity.device_id,
                "publicKey": public_key_b64,
                "signature": signature_b64,
                "signedAt": signed_at_ms,
                "nonce": self._nonce,
            },
            "auth": {
                "token": token or None,
            },
            "role": role,
            "scopes": scopes.split(","),
        }

    async def _recv_loop(self) -> None:
        """Read frames until the socket closes and dispatch each to :meth:`_handle_frame`.

        A frame that is not valid JSON, is not a JSON object, or whose handling
        raises is logged and skipped rather than ending the loop. On exit, for
        any reason, every still-pending request fails with
        ``ConnectionError("Connection closed")`` and the client is marked
        disconnected.
        """
        try:
            async for raw in self._ws:
                try:
                    frame = json.loads(raw)
                except (TypeError, ValueError) as exc:
                    logger.warning("Dropping undecodable Gateway frame: %s", exc)
                    continue
                if not isinstance(frame, dict):
                    logger.warning(
                        "Dropping non-object Gateway frame of type %s", type(frame).__name__
                    )
                    continue
                try:
                    await self._handle_frame(frame)
                except Exception:
                    logger.exception("Error handling Gateway frame")
        except websockets.exceptions.ConnectionClosed:
            pass
        except Exception as e:
            logger.error("Receive loop error: %s", e)
        finally:
            pending = list(self._pending.values())
            self._pending.clear()
            for future in pending:
                if not future.done():
                    future.set_exception(ConnectionError("Connection closed"))
            self._connected = False

    async def _handle_frame(self, frame: dict[str, Any]) -> None:
        """Route one decoded frame.

        ``event`` frames are dispatched exactly once to every registered handler
        (a ``connect.challenge`` nonce is also remembered). ``res`` frames
        resolve the pending request with the same id; responses for unknown or
        already-finished requests (e.g. cancelled callers) are ignored.
        """
        frame_type = frame.get("type")
        if frame_type == "event":
            event_name = frame.get("event", "")
            payload = frame.get("payload")
            if payload is None:
                payload = {}
            if event_name == "connect.challenge" and isinstance(payload, dict):
                nonce = payload.get("nonce")
                if nonce:
                    self._nonce = nonce
            event = MessageEvent(event=event_name, payload=payload, seq=frame.get("seq"))
            # Snapshot the list so handlers may add/remove handlers safely.
            for handler in list(self._event_handlers):
                try:
                    handler(event)
                except Exception as e:
                    logger.error("Event handler error: %s", e)
        elif frame_type == "res":
            future = self._pending.pop(frame.get("id"), None)
            # A cancelled/timed-out caller leaves a done future; setting it again
            # would raise InvalidStateError.
            if future is None or future.done():
                return
            if frame.get("ok"):
                result = frame.get("payload")
                future.set_result({} if result is None else result)
            else:
                error = frame.get("error") or {}
                if not isinstance(error, dict):
                    error = {"message": str(error)}
                future.set_exception(
                    Exception(
                        f"{error.get('code', 'ERROR')}: {error.get('message', 'Unknown error')}"
                    )
                )

    async def _wait_for_event(self, event_name: str, timeout_ms: int | None = None) -> MessageEvent:
        """Wait for the next event named exactly ``event_name``.

        Args:
            event_name: Event name to match.
            timeout_ms: Timeout in milliseconds; a falsy value means ``self.timeout_ms``.

        Raises:
            asyncio.TimeoutError: If no matching event arrives in time.
        """
        future: asyncio.Future = asyncio.get_running_loop().create_future()
        timeout = timeout_ms or self.timeout_ms

        def handler(event: MessageEvent) -> None:
            if event.event == event_name and not future.done():
                future.set_result(event)

        self._event_handlers.append(handler)
        try:
            return await asyncio.wait_for(future, timeout / 1000)
        finally:
            self._event_handlers.remove(handler)

    async def _send_request(
        self,
        method: str,
        params: dict[str, Any] | None = None,
        _allow_handshake: bool = False,
        *,
        timeout_ms: int | None = None,
    ) -> dict[str, Any]:
        """Send a ``req`` frame and wait for its ``res`` payload.

        Args:
            method: The RPC method name.
            params: Parameters for the method. Omitted from the frame only when
                None; an explicit empty dict is sent as ``{}``.
            _allow_handshake: If True, allow sending before the handshake has
                completed (used for the ``connect`` method).
            timeout_ms: Per-request timeout override; defaults to ``self.timeout_ms``.

        Raises:
            Exception: If not connected, or if the Gateway returns an error.
            TimeoutError: If no response arrives in time.
            ConnectionError: If the connection closes while waiting.
        """
        if not self._ws or (not self._connected and not _allow_handshake):
            raise Exception("Not connected to Gateway")
        effective_timeout_ms = self.timeout_ms if timeout_ms is None else timeout_ms
        req_id = str(uuid.uuid4())
        frame: dict[str, Any] = {"type": "req", "id": req_id, "method": method}
        if params is not None:
            frame["params"] = params
        future: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending[req_id] = future
        try:
            await self._ws.send(json.dumps(frame))
            return await asyncio.wait_for(future, effective_timeout_ms / 1000)
        except asyncio.TimeoutError:
            raise TimeoutError(
                f"Request {method} timed out after {effective_timeout_ms}ms"
            ) from None
        finally:
            # Always drop the entry: covers send failures, timeouts and cancellation.
            self._pending.pop(req_id, None)

    async def call_method(self, method: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
        """Call any RPC method on the Gateway.

        Args:
            method: The RPC method name (e.g., "sessions.list", "agents.list").
            params: Optional parameters for the method.

        Returns:
            The response payload from the Gateway.
        """
        return await self._send_request(method, params)

    async def disconnect(self) -> None:
        """Stop the receive loop, close the socket and reset handshake state.

        Safe to call repeatedly or when never connected.
        """
        self._connected = False
        recv_task, self._recv_task = self._recv_task, None
        if recv_task is not None:
            recv_task.cancel()
            try:
                await recv_task
            except asyncio.CancelledError:
                pass
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()
        self._hello = None
        self._nonce = None

    # -------------------------------------------------------------------------
    # Session operations
    # -------------------------------------------------------------------------

    async def list_sessions(
        self,
        limit: int = 50,
        agent_id: str | None = None,
        include_last_message: bool = True,
    ) -> list[dict[str, Any]]:
        """Return up to ``limit`` sessions (``sessions.list``), optionally for one agent."""
        params: dict[str, Any] = {"limit": limit, "includeLastMessage": include_last_message}
        if agent_id:
            params["agentId"] = agent_id
        result = await self._send_request("sessions.list", params)
        return result.get("sessions", [])

    async def resolve_session(
        self,
        agent_id: str | None = None,
        label: str | None = None,
        channel: str | None = None,
        include_global: bool = True,
    ) -> str | None:
        """Resolve a session key via ``sessions.resolve``.

        Returns:
            The stripped session key, or None if the Gateway returned no
            non-blank string key.
        """
        params: dict[str, Any] = {"includeGlobal": include_global}
        if agent_id:
            params["agentId"] = agent_id
        if label:
            params["label"] = label
        if channel:
            params["channel"] = channel
        result = await self._send_request("sessions.resolve", params)
        key = result.get("key")
        if isinstance(key, str) and key.strip():
            return key.strip()
        return None

    async def send_message(
        self,
        session_key: str,
        message: str,
        thinking: str | None = None,
        timeout_ms: int | None = None,
    ) -> dict[str, Any]:
        """Send a message to an agent session via ``sessions.send``.

        Args:
            session_key: The session key (e.g. as returned by :meth:`resolve_session`).
            message: The message text to send.
            thinking: Optional thinking/reasoning to include.
            timeout_ms: Timeout for this request only; defaults to ``self.timeout_ms``.

        Returns:
            The ``sessions.send`` response payload.
        """
        params: dict[str, Any] = {
            "key": session_key,
            "message": message,
        }
        if thinking:
            params["thinking"] = thinking
        # Pass the override per request instead of mutating self.timeout_ms, so
        # concurrent requests on this client keep their own timeouts.
        return await self._send_request("sessions.send", params, timeout_ms=timeout_ms)

    async def unsubscribe(self, session_key: str) -> dict[str, Any]:
        """Unsubscribe from messages for a session (``sessions.messages.unsubscribe``)."""
        return await self._send_request("sessions.messages.unsubscribe", {"key": session_key})

    async def get_session_history(
        self,
        session_key: str,
        limit: int = 20,
    ) -> dict[str, Any]:
        """Best-effort session history read (``sessions.preview``).

        OpenClaw's public Gateway surface is subscription-first for live message flow.
        History is not consistently exposed over the same WS methods across builds, so
        callers should still keep a CLI or REST fallback available.
        """
        return await self._send_request("sessions.preview", {"keys": [session_key], "limit": limit})

    async def subscribe(self, session_key: str) -> AsyncMessageIterator:
        """Subscribe to live events for ``session_key``.

        Usage::

            events = await client.subscribe(session_key)
            async for event in events:
                print(f"Event: {event.event}", event.payload)

        The iterator is attached to this client's event stream *before* the
        subscribe request is sent, so events delivered right after the Gateway
        acknowledges the subscription are buffered instead of lost. Call the
        iterator's ``aclose()`` when done to detach it.

        Args:
            session_key: The session key to subscribe to.

        Returns:
            An ``AsyncMessageIterator`` yielding ``MessageEvent`` objects.
        """
        iterator = AsyncMessageIterator(self, session_key)
        self.add_event_handler(iterator._on_event)
        iterator._handler_added = True  # iterating must not register it a second time
        try:
            await self._send_request("sessions.messages.subscribe", {"key": session_key})
        except BaseException:
            await iterator.aclose()
            raise
        return iterator

    # -------------------------------------------------------------------------
    # Agent operations
    # -------------------------------------------------------------------------

    async def list_agents(self) -> list[dict[str, Any]]:
        """List configured agents (``agents.list``)."""
        result = await self._send_request("agents.list", {})
        return result.get("agents", [])

    async def get_agent(self, agent_id: str) -> dict[str, Any] | None:
        """Return the agent whose ``id`` equals ``agent_id``, or None."""
        agents = await self.list_agents()
        return next((agent for agent in agents if agent.get("id") == agent_id), None)

    # -------------------------------------------------------------------------
    # Channel operations
    # -------------------------------------------------------------------------

    async def channels_status(self, probe: bool = False) -> dict[str, Any]:
        """Get channel status (``channels.status``); ``probe`` is sent only when True."""
        params = {"probe": probe} if probe else {}
        return await self._send_request("channels.status", params)

    async def channels_list(self) -> list[dict[str, Any]]:
        """List configured channels (``channels.list``)."""
        result = await self._send_request("channels.list", {})
        return result.get("channels", [])

    # -------------------------------------------------------------------------
    # Convenience methods
    # -------------------------------------------------------------------------

    async def send_to_agent(
        self,
        agent_id: str,
        message: str,
        channel: str | None = None,
        label: str | None = None,
    ) -> dict[str, Any]:
        """Resolve the agent's session and send ``message`` to it.

        Args:
            agent_id: The agent ID.
            message: Message to send.
            channel: Optional channel to route through.
            label: Optional session label.

        Returns:
            The ``sessions.send`` response payload.

        Raises:
            ValueError: If no session could be resolved for the agent.
        """
        session_key = await self.resolve_session(agent_id=agent_id, label=label, channel=channel)
        if not session_key:
            raise ValueError(f"No session found for agent {agent_id}")
        return await self.send_message(session_key, message)

    def add_event_handler(self, handler: Callable[[MessageEvent], None]) -> None:
        """Register a synchronous callback invoked for every incoming event."""
        self._event_handlers.append(handler)

    def remove_event_handler(self, handler: Callable[[MessageEvent], None]) -> None:
        """Unregister a callback; raises ValueError if it was not registered."""
        self._event_handlers.remove(handler)
class AsyncMessageIterator:
    """Async iterator over the events routed to one session.

    Consume with ``async for``. The handler is attached to the client on the
    first ``__aiter__`` call (or earlier by whoever sets ``_handler_added``),
    and events are buffered in an unbounded queue until consumed. Call
    :meth:`aclose` (or use ``async with``) to detach; a running ``async for``
    then ends after draining the events already queued.

    Routing rule: an event is queued when its payload's ``sessionKey`` (or
    ``key``) equals this iterator's session key. Events that carry no such key
    are queued only when their name starts with ``"sessions."``; events tagged
    with a different session's key are dropped.
    """

    # Pushed by aclose() to wake a consumer blocked in __anext__.
    _CLOSED = object()

    def __init__(self, client: OpenClawWebSocketClient, session_key: str):
        self._client = client
        self._session_key = session_key
        self._queue: asyncio.Queue[object] = asyncio.Queue()
        self._handler_added = False
        self._closed = False

    def _on_event(self, event: MessageEvent) -> None:
        """Queue ``event`` if it belongs to this iterator's session."""
        if self._closed:
            return
        payload = event.payload or {}
        event_session_key = payload.get("sessionKey") or payload.get("key")
        if event_session_key is not None:
            belongs = event_session_key == self._session_key
        else:
            belongs = event.event.startswith("sessions.")
        if belongs:
            self._queue.put_nowait(event)

    def __aiter__(self) -> "AsyncMessageIterator":
        # Must be a plain method: `async for` requires __aiter__ to return the
        # iterator itself, not a coroutine.
        if not self._handler_added and not self._closed:
            self._client.add_event_handler(self._on_event)
            self._handler_added = True
        return self

    async def __anext__(self) -> MessageEvent:
        item = await self._queue.get()
        if item is self._CLOSED:
            # Put the sentinel back so later __anext__ calls also stop.
            self._queue.put_nowait(item)
            raise StopAsyncIteration
        return item

    async def aclose(self) -> None:
        """Detach from the client and end iteration once queued events are drained."""
        if self._handler_added:
            self._client.remove_event_handler(self._on_event)
            self._handler_added = False
        if not self._closed:
            self._closed = True
            self._queue.put_nowait(self._CLOSED)

    async def __aenter__(self) -> "AsyncMessageIterator":
        return self.__aiter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.aclose()
# -----------------------------------------------------------------------------
# One-shot convenience functions (async; each opens and closes its own connection)
# -----------------------------------------------------------------------------
async def send_to_agent(
    message: str,
    agent_id: str,
    gateway_url: str = DEFAULT_GATEWAY_URL,
    gateway_token: str | None = None,
    channel: str | None = None,
    label: str | None = None,
    timeout_ms: int = 60000,
) -> dict[str, Any]:
    """Send one message to an agent over a short-lived Gateway connection.

    Opens a connection, resolves the agent's session (optionally narrowed by
    ``channel``/``label``), sends ``message`` via ``sessions.send`` and closes
    the connection again.

    Args:
        message: The message to send.
        agent_id: The agent ID to target.
        gateway_url: Gateway WebSocket URL.
        gateway_token: Optional gateway auth token.
        channel: Optional channel to route through.
        label: Optional session label.
        timeout_ms: Timeout in milliseconds, applied to the handshake and to
            every Gateway request made on this connection.

    Returns:
        The ``sessions.send`` response payload.

    Raises:
        ValueError: If no session can be resolved for ``agent_id``.

    Example:
        response = await send_to_agent("Hello!", agent_id="my-agent")
    """
    async with OpenClawWebSocketClient(
        url=gateway_url,
        gateway_token=gateway_token,
        timeout_ms=timeout_ms,  # previously accepted but silently ignored
    ) as client:
        await client.connect()
        return await client.send_to_agent(
            agent_id=agent_id,
            message=message,
            channel=channel,
            label=label,
        )
async def list_active_sessions(
    gateway_url: str = DEFAULT_GATEWAY_URL,
    gateway_token: str | None = None,
    agent_id: str | None = None,
) -> list[dict[str, Any]]:
    """Open a short-lived Gateway connection and return its active sessions.

    Args:
        gateway_url: Gateway WebSocket URL.
        gateway_token: Optional gateway auth token.
        agent_id: When given, only sessions of this agent are requested.

    Returns:
        The session dicts reported by ``sessions.list``.
    """
    client = OpenClawWebSocketClient(url=gateway_url, gateway_token=gateway_token)
    async with client:
        await client.connect()
        sessions = await client.list_sessions(agent_id=agent_id)
    return sessions