Initial commit of integrated agent system

cillin
2026-03-30 17:46:44 +08:00
commit 0fa413380c
337 changed files with 75268 additions and 0 deletions

backend/llm/models.py Normal file

@@ -0,0 +1,549 @@
# -*- coding: utf-8 -*-
"""
AgentScope Native Model Factory
Uses native AgentScope model classes for LLM calls
"""
import asyncio
import inspect
import os
import time
import logging
from enum import Enum
from typing import Any, Callable, Optional, Tuple, TypeVar
from agentscope.formatter import (
AnthropicChatFormatter,
DashScopeChatFormatter,
GeminiChatFormatter,
OllamaChatFormatter,
OpenAIChatFormatter,
)
from agentscope.model import (
AnthropicChatModel,
DashScopeChatModel,
GeminiChatModel,
OllamaChatModel,
OpenAIChatModel,
)
from backend.config.env_config import (
canonicalize_model_provider,
get_agent_model_config,
get_env_str,
)
logger = logging.getLogger(__name__)
# Retry wrapper types
T = TypeVar("T")
def _usage_value(usage: Any, key: str, default: Any = 0) -> Any:
"""Read usage fields from both object-style and dict-style usage payloads."""
if usage is None:
return default
if isinstance(usage, dict):
return usage.get(key, default)
try:
return getattr(usage, key)
except (AttributeError, KeyError):
return default
def _usage_total_tokens(usage: Any) -> int:
total = _usage_value(usage, "total_tokens", None)
if total is not None:
return int(total or 0)
input_tokens = _usage_value(usage, "input_tokens", 0)
output_tokens = _usage_value(usage, "output_tokens", 0)
return int((input_tokens or 0) + (output_tokens or 0))
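# Illustration (made-up numbers): both usage payload shapes that the helpers
# above accept normalize the same way:
#   _usage_total_tokens({"total_tokens": 15})                      -> 15
#   _usage_total_tokens({"input_tokens": 10, "output_tokens": 5})  -> 15
#   _usage_value({"cost": 0.01}, "cost", 0.0)                      -> 0.01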
class RetryChatModel:
"""Wraps an AgentScope model with automatic retry for transient errors.
Based on CoPaw's RetryChatModel design. Handles rate limits, timeouts,
and other transient failures with exponential backoff.
"""
DEFAULT_MAX_RETRIES = 3
DEFAULT_INITIAL_DELAY = 1.0
DEFAULT_MAX_DELAY = 60.0
DEFAULT_BACKOFF_MULTIPLIER = 2.0
# Transient error codes/messages that should trigger retry
TRANSIENT_ERROR_KEYWORDS = frozenset([
"rate_limit",
"429",
"timeout",
"503",
"502",
"504",
"connection",
"disconnected",
"temporary",
"overloaded",
"too_many_requests",
])
def __init__(
self,
model: Any,
max_retries: int = DEFAULT_MAX_RETRIES,
initial_delay: float = DEFAULT_INITIAL_DELAY,
max_delay: float = DEFAULT_MAX_DELAY,
backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER,
on_retry: Optional[Callable[[int, Exception, float], None]] = None,
):
"""Initialize retry wrapper.
Args:
model: The underlying AgentScope model to wrap
max_retries: Maximum number of retry attempts
initial_delay: Initial delay in seconds before first retry
max_delay: Maximum delay between retries
backoff_multiplier: Multiplier for exponential backoff
on_retry: Optional callback(retry_count, exception, delay) for logging
"""
self._model = model
self._max_retries = max_retries
self._initial_delay = initial_delay
self._max_delay = max_delay
self._backoff_multiplier = backoff_multiplier
self._on_retry = on_retry
self._total_tokens_used = 0
self._total_cost = 0.0
@property
def model_name(self) -> str:
return getattr(self._model, "model_name", str(self._model))
@property
def total_tokens_used(self) -> int:
return self._total_tokens_used
@property
def total_cost(self) -> float:
return self._total_cost
def _is_transient_error(self, error: Exception) -> bool:
"""Check if an error is transient and should be retried.
Args:
error: The exception to check
Returns:
True if the error is transient
"""
error_str = str(error).lower()
for keyword in self.TRANSIENT_ERROR_KEYWORDS:
if keyword in error_str:
return True
return False
def _calculate_delay(self, retry_count: int) -> float:
"""Calculate delay for given retry attempt with exponential backoff.
Args:
retry_count: Current retry attempt number (1-based)
Returns:
Delay in seconds
"""
delay = self._initial_delay * (self._backoff_multiplier ** (retry_count - 1))
return min(delay, self._max_delay)
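    # Worked example with the defaults (initial_delay=1.0, multiplier=2.0,
    # max_delay=60.0): retries sleep 1s, 2s, 4s, 8s, ... capped at 60s, so
    # _calculate_delay(3) == 4.0.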
def _call_with_retry(self, func: Callable[..., T], *args, **kwargs) -> T:
"""Call a function with retry logic for transient errors.
Args:
func: Function to call
*args: Positional arguments
**kwargs: Keyword arguments
Returns:
Result from func
Raises:
Last exception if all retries exhausted
"""
last_error: Optional[Exception] = None
for attempt in range(1, self._max_retries + 1):
try:
result = func(*args, **kwargs)
# Track usage if available
if hasattr(result, "usage") and result.usage:
usage = result.usage
self._total_tokens_used += _usage_total_tokens(usage)
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
return result
except Exception as e:
last_error = e
if attempt >= self._max_retries:
logger.error(
"RetryChatModel: Max retries (%d) exhausted for %s",
self._max_retries,
self.model_name,
)
break
if not self._is_transient_error(e):
logger.warning(
"RetryChatModel: Non-transient error, not retrying: %s",
str(e),
)
break
delay = self._calculate_delay(attempt)
logger.warning(
"RetryChatModel: Transient error on attempt %d/%d, "
"retrying in %.1fs: %s",
attempt,
self._max_retries,
delay,
str(e)[:200],
)
if self._on_retry:
self._on_retry(attempt, e, delay)
time.sleep(delay)
if last_error is not None:
raise last_error
raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")
async def _call_with_retry_async(self, func: Callable[..., T], *args, **kwargs) -> T:
"""Call an async function with retry logic for transient errors."""
last_error: Optional[Exception] = None
for attempt in range(1, self._max_retries + 1):
try:
result = await func(*args, **kwargs)
if hasattr(result, "usage") and result.usage:
usage = result.usage
self._total_tokens_used += _usage_total_tokens(usage)
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
return result
except Exception as e:
last_error = e
if attempt >= self._max_retries:
logger.error(
"RetryChatModel: Max retries (%d) exhausted for %s",
self._max_retries,
self.model_name,
)
break
if not self._is_transient_error(e):
logger.warning(
"RetryChatModel: Non-transient error, not retrying: %s",
str(e),
)
break
delay = self._calculate_delay(attempt)
logger.warning(
"RetryChatModel: Transient async error on attempt %d/%d, "
"retrying in %.1fs: %s",
attempt,
self._max_retries,
delay,
str(e)[:200],
)
if self._on_retry:
self._on_retry(attempt, e, delay)
await asyncio.sleep(delay)
if last_error is not None:
raise last_error
raise RuntimeError("RetryChatModel: Unexpected async state, no error but no result")
    def __call__(self, *args, **kwargs) -> Any:
        """Forward calls to the wrapped model with retry logic."""
        model_call = getattr(self._model, "__call__", None)
        if inspect.iscoroutinefunction(self._model) or inspect.iscoroutinefunction(model_call):
            return self._call_with_retry_async(self._model, *args, **kwargs)
        # Route sync calls through the retry helper as well, so transient
        # errors are retried and usage is tracked on both paths.
        return self._call_with_retry(self._model, *args, **kwargs)
def __getattr__(self, name: str) -> Any:
"""Proxy attribute access to the wrapped model."""
return getattr(self._model, name)
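# Usage sketch (illustrative; `raw_model` stands in for any AgentScope model
# instance). The wrapper is invoked exactly like the model it wraps:
#   def log_retry(attempt, error, delay):
#       logger.info("retry %d after %r, sleeping %.1fs", attempt, error, delay)
#   model = RetryChatModel(raw_model, max_retries=5, on_retry=log_retry)
#   result = model(...)  # returns a coroutine if the wrapped model is async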
class TokenRecordingModelWrapper:
"""Wraps a model to track token usage per provider.
Based on CoPaw's TokenRecordingModelWrapper design.
"""
def __init__(self, model: Any):
"""Initialize token recorder.
Args:
model: The underlying AgentScope model to wrap
"""
self._model = model
self._total_tokens = 0
self._prompt_tokens = 0
self._completion_tokens = 0
self._total_cost = 0.0
@property
def model_name(self) -> str:
return getattr(self._model, "model_name", str(self._model))
@property
def total_tokens(self) -> int:
return self._total_tokens
@property
def prompt_tokens(self) -> int:
return self._prompt_tokens
@property
def completion_tokens(self) -> int:
return self._completion_tokens
@property
def total_cost(self) -> float:
return self._total_cost
def record_usage(self, usage: Any) -> None:
"""Record token usage from a model response.
Args:
usage: Usage object from model response
"""
if usage is None:
return
prompt_tokens = _usage_value(usage, "prompt_tokens", None)
completion_tokens = _usage_value(usage, "completion_tokens", None)
if prompt_tokens is None:
prompt_tokens = _usage_value(usage, "input_tokens", 0)
if completion_tokens is None:
completion_tokens = _usage_value(usage, "output_tokens", 0)
self._prompt_tokens += int(prompt_tokens or 0)
self._completion_tokens += int(completion_tokens or 0)
self._total_tokens += _usage_total_tokens(usage)
self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
    def __call__(self, *args, **kwargs) -> Any:
        """Forward calls and record usage (handles sync and async models)."""
        result = self._model(*args, **kwargs)
        if inspect.isawaitable(result):
            # Async model: defer recording until the coroutine resolves.
            return self._await_and_record(result)
        if hasattr(result, "usage") and result.usage:
            self.record_usage(result.usage)
        return result
    async def _await_and_record(self, awaitable: Any) -> Any:
        """Await an async model call, then record its usage."""
        result = await awaitable
        if hasattr(result, "usage") and result.usage:
            self.record_usage(result.usage)
        return result
def __getattr__(self, name: str) -> Any:
"""Proxy attribute access to the wrapped model."""
return getattr(self._model, name)
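# Composition sketch (illustrative): both wrappers proxy unknown attributes to
# the wrapped model, so they stack cleanly:
#   recorder = TokenRecordingModelWrapper(RetryChatModel(raw_model))
#   response = recorder(...)   # transient errors retried by the inner wrapper
#   recorder.total_tokens      # cumulative tokens across recorded calls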
class ModelProvider(Enum):
"""Supported model providers"""
OPENAI = "OPENAI"
ANTHROPIC = "ANTHROPIC"
DASHSCOPE = "DASHSCOPE"
ALIBABA = "ALIBABA"
GEMINI = "GEMINI"
GOOGLE = "GOOGLE"
OLLAMA = "OLLAMA"
DEEPSEEK = "DEEPSEEK"
GROQ = "GROQ"
OPENROUTER = "OPENROUTER"
# Provider to AgentScope model class mapping
PROVIDER_MODEL_MAP = {
"OPENAI": OpenAIChatModel,
"ANTHROPIC": AnthropicChatModel,
"DASHSCOPE": DashScopeChatModel,
"ALIBABA": DashScopeChatModel,
"GEMINI": GeminiChatModel,
"GOOGLE": GeminiChatModel,
"OLLAMA": OllamaChatModel,
# OpenAI-compatible providers use OpenAIChatModel with custom base_url
"DEEPSEEK": OpenAIChatModel,
"GROQ": OpenAIChatModel,
"OPENROUTER": OpenAIChatModel,
}
# Provider to formatter mapping
PROVIDER_FORMATTER_MAP = {
"OPENAI": OpenAIChatFormatter,
"ANTHROPIC": AnthropicChatFormatter,
"DASHSCOPE": DashScopeChatFormatter,
"ALIBABA": DashScopeChatFormatter,
"GEMINI": GeminiChatFormatter,
"GOOGLE": GeminiChatFormatter,
"OLLAMA": OllamaChatFormatter,
# OpenAI-compatible providers use OpenAIChatFormatter
"DEEPSEEK": OpenAIChatFormatter,
"GROQ": OpenAIChatFormatter,
"OPENROUTER": OpenAIChatFormatter,
}
# Provider-specific base URLs
PROVIDER_BASE_URLS = {
"DEEPSEEK": "https://api.deepseek.com/v1",
"GROQ": "https://api.groq.com/openai/v1",
"OPENROUTER": "https://openrouter.ai/api/v1",
}
# Provider-specific API key environment variable names
PROVIDER_API_KEY_ENV = {
"OPENAI": "OPENAI_API_KEY",
"ANTHROPIC": "ANTHROPIC_API_KEY",
"DASHSCOPE": "DASHSCOPE_API_KEY",
"ALIBABA": "DASHSCOPE_API_KEY",
"GEMINI": "GOOGLE_API_KEY",
"GOOGLE": "GOOGLE_API_KEY",
"DEEPSEEK": "DEEPSEEK_API_KEY",
"GROQ": "GROQ_API_KEY",
"OPENROUTER": "OPENROUTER_API_KEY",
}
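# Adding another OpenAI-compatible provider (hypothetical example) only
# requires registering it in the maps above:
#   PROVIDER_MODEL_MAP["TOGETHER"] = OpenAIChatModel
#   PROVIDER_FORMATTER_MAP["TOGETHER"] = OpenAIChatFormatter
#   PROVIDER_BASE_URLS["TOGETHER"] = "https://api.together.xyz/v1"  # assumed URL
#   PROVIDER_API_KEY_ENV["TOGETHER"] = "TOGETHER_API_KEY"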
def create_model(
model_name: str,
provider: str,
api_key: Optional[str] = None,
stream: bool = False,
**kwargs,
):
"""
Create an AgentScope model instance
Args:
model_name: Model name (e.g., "gpt-4o", "claude-3-opus")
provider: Provider name (e.g., "OPENAI", "ANTHROPIC")
api_key: API key (optional, will read from env if not provided)
stream: Whether to use streaming mode
**kwargs: Additional model-specific arguments
Returns:
AgentScope model instance
"""
provider = canonicalize_model_provider(provider)
model_class = PROVIDER_MODEL_MAP.get(provider)
if model_class is None:
raise ValueError(f"Unsupported provider: {provider}")
# Get API key from env if not provided
if api_key is None:
env_key = PROVIDER_API_KEY_ENV.get(provider)
if env_key:
api_key = os.getenv(env_key)
# Build model kwargs
model_kwargs = {
"model_name": model_name,
"stream": stream,
**kwargs,
}
# Add API key if needed (Ollama doesn't need it)
if provider != "OLLAMA" and api_key:
model_kwargs["api_key"] = api_key
# Handle OpenAI-compatible providers with custom base_url
if provider in PROVIDER_BASE_URLS:
base_url = PROVIDER_BASE_URLS[provider]
model_kwargs["client_args"] = {"base_url": base_url}
# Handle custom OpenAI base URL
if provider == "OPENAI":
base_url = get_env_str("OPENAI_BASE_URL") or get_env_str(
"OPENAI_API_BASE",
)
if base_url:
model_kwargs["client_args"] = {"base_url": base_url}
# Handle DashScope base URL (uses different parameter)
if provider in ("DASHSCOPE", "ALIBABA"):
base_url = get_env_str("DASHSCOPE_BASE_URL")
if base_url:
model_kwargs["base_http_api_url"] = base_url
# Handle Ollama host
if provider == "OLLAMA":
host = get_env_str("OLLAMA_HOST")
if host:
model_kwargs["host"] = host
model = model_class(**model_kwargs)
return RetryChatModel(model)
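# Example (assumes OPENAI_API_KEY is set in the environment):
#   model = create_model("gpt-4o", "OPENAI", stream=True)
# Provider strings are canonicalized before the lookup, so variants such as
# "openai" are expected to resolve to the same entry.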
def get_agent_model(agent_id: str, stream: bool = False):
"""
Get model for a specific agent based on environment variables
Environment variable pattern:
AGENT_{AGENT_ID}_MODEL_NAME: Model name
AGENT_{AGENT_ID}_MODEL_PROVIDER: Provider name
    Falls back to the global MODEL_NAME and MODEL_PROVIDER when the
    agent-specific variables are not set.
Args:
agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")
stream: Whether to use streaming mode
Returns:
AgentScope model instance
"""
resolved = get_agent_model_config(agent_id)
return create_model(
model_name=resolved.model_name,
provider=resolved.provider,
stream=stream,
)
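# Example (hypothetical agent id; assumes the config layer upper-cases the id
# when substituting into the pattern above):
#   AGENT_SENTIMENT_ANALYST_MODEL_NAME=claude-3-opus
#   AGENT_SENTIMENT_ANALYST_MODEL_PROVIDER=ANTHROPIC
#   model = get_agent_model("sentiment_analyst", stream=True)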
def get_agent_formatter(agent_id: str):
"""
Get formatter for a specific agent based on environment variables
Args:
agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")
Returns:
AgentScope formatter instance
"""
provider = get_agent_model_config(agent_id).provider
formatter_class = PROVIDER_FORMATTER_MAP.get(provider, OpenAIChatFormatter)
return formatter_class()
def get_agent_model_info(agent_id: str) -> Tuple[str, str]:
"""
Get model name and provider for a specific agent
Args:
agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")
Returns:
Tuple of (model_name, provider_name)
"""
resolved = get_agent_model_config(agent_id)
return resolved.model_name, resolved.provider
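# Putting it together (illustrative): callers typically pair the model with
# the matching formatter for the same provider:
#   model = get_agent_model("portfolio_manager")
#   formatter = get_agent_formatter("portfolio_manager")
#   name, provider = get_agent_model_info("portfolio_manager")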