Initial commit of integrated agent system
This commit is contained in:
0
backend/llm/__init__.py
Normal file
0
backend/llm/__init__.py
Normal file
549
backend/llm/models.py
Normal file
549
backend/llm/models.py
Normal file
@@ -0,0 +1,549 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
AgentScope Native Model Factory
|
||||
Uses native AgentScope model classes for LLM calls
|
||||
"""
|
||||
import asyncio
|
||||
import inspect
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Optional, Tuple, TypeVar, Union
|
||||
from agentscope.formatter import (
|
||||
AnthropicChatFormatter,
|
||||
DashScopeChatFormatter,
|
||||
GeminiChatFormatter,
|
||||
OllamaChatFormatter,
|
||||
OpenAIChatFormatter,
|
||||
)
|
||||
from agentscope.model import (
|
||||
AnthropicChatModel,
|
||||
DashScopeChatModel,
|
||||
GeminiChatModel,
|
||||
OllamaChatModel,
|
||||
OpenAIChatModel,
|
||||
)
|
||||
from backend.config.env_config import (
|
||||
canonicalize_model_provider,
|
||||
get_agent_model_config,
|
||||
get_env_str,
|
||||
)
|
||||
|
||||
# Module-level logger; the retry wrapper reports retries/failures through it.
logger = logging.getLogger(__name__)


# Retry wrapper types: generic return type for callables routed through
# RetryChatModel's retry helpers.
T = TypeVar("T")
|
||||
|
||||
|
||||
def _usage_value(usage: Any, key: str, default: Any = 0) -> Any:
    """Fetch *key* from a usage payload, whether dict-style or object-style.

    Returns ``default`` when the payload is ``None`` or lacks the field.
    """
    if usage is None:
        return default
    if isinstance(usage, dict):
        return usage.get(key, default)
    # Attribute-style payloads; some SDK objects raise KeyError from a
    # custom __getattr__, so both exception types mean "field missing".
    try:
        value = getattr(usage, key)
    except (AttributeError, KeyError):
        return default
    return value
|
||||
|
||||
|
||||
def _usage_total_tokens(usage: Any) -> int:
    """Return the total token count recorded in *usage*.

    Prefers an explicit ``total_tokens`` field; otherwise sums the
    ``input_tokens`` and ``output_tokens`` fields (Anthropic-style naming).
    """
    explicit_total = _usage_value(usage, "total_tokens", None)
    if explicit_total is None:
        prompt_side = _usage_value(usage, "input_tokens", 0) or 0
        completion_side = _usage_value(usage, "output_tokens", 0) or 0
        return int(prompt_side + completion_side)
    return int(explicit_total or 0)
|
||||
|
||||
|
||||
class RetryChatModel:
    """Wraps an AgentScope model with automatic retry for transient errors.

    Based on CoPaw's RetryChatModel design. Handles rate limits, timeouts,
    and other transient failures with exponential backoff. Token/cost usage
    reported on successful responses is accumulated on the wrapper, and any
    attribute not defined here is proxied to the wrapped model.
    """

    DEFAULT_MAX_RETRIES = 3
    DEFAULT_INITIAL_DELAY = 1.0
    DEFAULT_MAX_DELAY = 60.0
    DEFAULT_BACKOFF_MULTIPLIER = 2.0

    # Transient error codes/messages that should trigger retry; matched as
    # case-insensitive substrings of the exception text.
    TRANSIENT_ERROR_KEYWORDS = frozenset([
        "rate_limit",
        "429",
        "timeout",
        "503",
        "502",
        "504",
        "connection",
        "disconnected",
        "temporary",
        "overloaded",
        "too_many_requests",
    ])

    def __init__(
        self,
        model: Any,
        max_retries: int = DEFAULT_MAX_RETRIES,
        initial_delay: float = DEFAULT_INITIAL_DELAY,
        max_delay: float = DEFAULT_MAX_DELAY,
        backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER,
        on_retry: Optional[Callable[[int, Exception, float], None]] = None,
    ):
        """Initialize retry wrapper.

        Args:
            model: The underlying AgentScope model to wrap
            max_retries: Maximum number of retry attempts
            initial_delay: Initial delay in seconds before first retry
            max_delay: Maximum delay between retries
            backoff_multiplier: Multiplier for exponential backoff
            on_retry: Optional callback(retry_count, exception, delay) for logging
        """
        self._model = model
        self._max_retries = max_retries
        self._initial_delay = initial_delay
        self._max_delay = max_delay
        self._backoff_multiplier = backoff_multiplier
        self._on_retry = on_retry
        # Cumulative usage across all successful calls through this wrapper.
        self._total_tokens_used = 0
        self._total_cost = 0.0

    @property
    def model_name(self) -> str:
        """Name of the wrapped model (falls back to its string form)."""
        return getattr(self._model, "model_name", str(self._model))

    @property
    def total_tokens_used(self) -> int:
        """Total tokens consumed by successful calls so far."""
        return self._total_tokens_used

    @property
    def total_cost(self) -> float:
        """Accumulated cost reported by successful calls, if any."""
        return self._total_cost

    def _is_transient_error(self, error: Exception) -> bool:
        """Check if an error is transient and should be retried.

        Args:
            error: The exception to check

        Returns:
            True if the error's text contains any transient keyword
        """
        error_str = str(error).lower()
        for keyword in self.TRANSIENT_ERROR_KEYWORDS:
            if keyword in error_str:
                return True
        return False

    def _calculate_delay(self, retry_count: int) -> float:
        """Calculate delay for given retry attempt with exponential backoff.

        Args:
            retry_count: Current retry attempt number (1-based)

        Returns:
            Delay in seconds, capped at max_delay
        """
        delay = self._initial_delay * (self._backoff_multiplier ** (retry_count - 1))
        return min(delay, self._max_delay)

    def _record_usage(self, result: Any) -> None:
        """Accumulate token/cost usage from a successful result, if reported."""
        if hasattr(result, "usage") and result.usage:
            usage = result.usage
            self._total_tokens_used += _usage_total_tokens(usage)
            self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)

    def _call_with_retry(self, func: Callable[..., "T"], *args, **kwargs) -> "T":
        """Call a function with retry logic for transient errors.

        Args:
            func: Function to call
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Result from func

        Raises:
            Last exception if all retries are exhausted or the error is
            not transient
        """
        last_error: Optional[Exception] = None

        for attempt in range(1, self._max_retries + 1):
            try:
                result = func(*args, **kwargs)
                self._record_usage(result)
                return result

            except Exception as e:
                last_error = e

                # Give up once the attempt budget is spent.
                if attempt >= self._max_retries:
                    logger.error(
                        "RetryChatModel: Max retries (%d) exhausted for %s",
                        self._max_retries,
                        self.model_name,
                    )
                    break

                # Only transient failures are worth retrying.
                if not self._is_transient_error(e):
                    logger.warning(
                        "RetryChatModel: Non-transient error, not retrying: %s",
                        str(e),
                    )
                    break

                delay = self._calculate_delay(attempt)
                logger.warning(
                    "RetryChatModel: Transient error on attempt %d/%d, "
                    "retrying in %.1fs: %s",
                    attempt,
                    self._max_retries,
                    delay,
                    str(e)[:200],  # truncate long provider error payloads
                )

                if self._on_retry:
                    self._on_retry(attempt, e, delay)

                time.sleep(delay)

        if last_error is not None:
            raise last_error
        raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")

    async def _call_with_retry_async(self, func: Callable[..., "T"], *args, **kwargs) -> "T":
        """Call an async function with retry logic for transient errors.

        Mirrors :meth:`_call_with_retry` but awaits the call and uses a
        non-blocking sleep between attempts.
        """
        last_error: Optional[Exception] = None

        for attempt in range(1, self._max_retries + 1):
            try:
                result = await func(*args, **kwargs)
                self._record_usage(result)
                return result

            except Exception as e:
                last_error = e

                if attempt >= self._max_retries:
                    logger.error(
                        "RetryChatModel: Max retries (%d) exhausted for %s",
                        self._max_retries,
                        self.model_name,
                    )
                    break

                if not self._is_transient_error(e):
                    logger.warning(
                        "RetryChatModel: Non-transient error, not retrying: %s",
                        str(e),
                    )
                    break

                delay = self._calculate_delay(attempt)
                logger.warning(
                    "RetryChatModel: Transient async error on attempt %d/%d, "
                    "retrying in %.1fs: %s",
                    attempt,
                    self._max_retries,
                    delay,
                    str(e)[:200],
                )

                if self._on_retry:
                    self._on_retry(attempt, e, delay)

                await asyncio.sleep(delay)

        if last_error is not None:
            raise last_error
        raise RuntimeError("RetryChatModel: Unexpected async state, no error but no result")

    def __call__(self, *args, **kwargs) -> Any:
        """Forward calls to the wrapped model with retry logic.

        Async models (coroutine ``__call__``) get an awaitable back; sync
        models return the result directly.
        """
        model_call = getattr(self._model, "__call__", None)
        if inspect.iscoroutinefunction(self._model) or inspect.iscoroutinefunction(model_call):
            return self._call_with_retry_async(self._model, *args, **kwargs)

        # FIX: route synchronous calls through the retry helper. Previously
        # this path invoked self._model directly, bypassing both the retry
        # logic and usage tracking (leaving _call_with_retry dead code).
        return self._call_with_retry(self._model, *args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Proxy attribute access to the wrapped model.

        Only invoked for names not found on the wrapper itself.
        """
        return getattr(self._model, name)
|
||||
|
||||
|
||||
class TokenRecordingModelWrapper:
    """Transparent proxy that records token usage for a wrapped model.

    Based on CoPaw's TokenRecordingModelWrapper design.
    """

    def __init__(self, model: Any):
        """Wrap *model* and zero all usage counters.

        Args:
            model: The underlying AgentScope model to wrap
        """
        self._model = model
        self._total_tokens = 0
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._total_cost = 0.0

    @property
    def model_name(self) -> str:
        """Name of the wrapped model (falls back to its string form)."""
        return getattr(self._model, "model_name", str(self._model))

    @property
    def total_tokens(self) -> int:
        """Total tokens accumulated across all recorded responses."""
        return self._total_tokens

    @property
    def prompt_tokens(self) -> int:
        """Prompt-side tokens accumulated so far."""
        return self._prompt_tokens

    @property
    def completion_tokens(self) -> int:
        """Completion-side tokens accumulated so far."""
        return self._completion_tokens

    @property
    def total_cost(self) -> float:
        """Accumulated cost reported in usage payloads, if any."""
        return self._total_cost

    def record_usage(self, usage: Any) -> None:
        """Fold one response's usage payload into the running counters.

        OpenAI-style field names (prompt/completion) take precedence over
        the Anthropic-style ones (input/output) when both are present.

        Args:
            usage: Usage object from a model response; ignored when None.
        """
        if usage is None:
            return

        prompt = _usage_value(usage, "prompt_tokens", None)
        if prompt is None:
            prompt = _usage_value(usage, "input_tokens", 0)

        completion = _usage_value(usage, "completion_tokens", None)
        if completion is None:
            completion = _usage_value(usage, "output_tokens", 0)

        self._prompt_tokens += int(prompt or 0)
        self._completion_tokens += int(completion or 0)
        self._total_tokens += _usage_total_tokens(usage)
        self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)

    def __call__(self, *args, **kwargs) -> Any:
        """Invoke the wrapped model and record any reported usage."""
        response = self._model(*args, **kwargs)
        if getattr(response, "usage", None):
            self.record_usage(response.usage)
        return response

    def __getattr__(self, name: str) -> Any:
        """Proxy attribute access to the wrapped model."""
        return getattr(self._model, name)
|
||||
|
||||
|
||||
class ModelProvider(Enum):
    """Supported model providers.

    Values are the canonical upper-case provider strings used as keys in
    the PROVIDER_* lookup tables in this module.
    """

    OPENAI = "OPENAI"
    ANTHROPIC = "ANTHROPIC"
    DASHSCOPE = "DASHSCOPE"
    ALIBABA = "ALIBABA"  # maps to the same DashScope classes as DASHSCOPE
    GEMINI = "GEMINI"
    GOOGLE = "GOOGLE"  # maps to the same Gemini classes as GEMINI
    OLLAMA = "OLLAMA"  # local models; no API key required
    DEEPSEEK = "DEEPSEEK"  # OpenAI-compatible endpoint
    GROQ = "GROQ"  # OpenAI-compatible endpoint
    OPENROUTER = "OPENROUTER"  # OpenAI-compatible endpoint
|
||||
|
||||
|
||||
# Provider to AgentScope model class mapping. Keys are canonical provider
# strings (as returned by canonicalize_model_provider).
PROVIDER_MODEL_MAP = {
    "OPENAI": OpenAIChatModel,
    "ANTHROPIC": AnthropicChatModel,
    "DASHSCOPE": DashScopeChatModel,
    "ALIBABA": DashScopeChatModel,  # alias of DASHSCOPE
    "GEMINI": GeminiChatModel,
    "GOOGLE": GeminiChatModel,  # alias of GEMINI
    "OLLAMA": OllamaChatModel,
    # OpenAI-compatible providers use OpenAIChatModel with custom base_url
    "DEEPSEEK": OpenAIChatModel,
    "GROQ": OpenAIChatModel,
    "OPENROUTER": OpenAIChatModel,
}

# Provider to formatter mapping; mirrors PROVIDER_MODEL_MAP above.
PROVIDER_FORMATTER_MAP = {
    "OPENAI": OpenAIChatFormatter,
    "ANTHROPIC": AnthropicChatFormatter,
    "DASHSCOPE": DashScopeChatFormatter,
    "ALIBABA": DashScopeChatFormatter,
    "GEMINI": GeminiChatFormatter,
    "GOOGLE": GeminiChatFormatter,
    "OLLAMA": OllamaChatFormatter,
    # OpenAI-compatible providers use OpenAIChatFormatter
    "DEEPSEEK": OpenAIChatFormatter,
    "GROQ": OpenAIChatFormatter,
    "OPENROUTER": OpenAIChatFormatter,
}

# Provider-specific base URLs for providers that expose an OpenAI-compatible
# API; passed to OpenAIChatModel via client_args in create_model.
PROVIDER_BASE_URLS = {
    "DEEPSEEK": "https://api.deepseek.com/v1",
    "GROQ": "https://api.groq.com/openai/v1",
    "OPENROUTER": "https://openrouter.ai/api/v1",
}

# Provider-specific API key environment variable names. OLLAMA is absent on
# purpose: local Ollama needs no API key (see create_model).
PROVIDER_API_KEY_ENV = {
    "OPENAI": "OPENAI_API_KEY",
    "ANTHROPIC": "ANTHROPIC_API_KEY",
    "DASHSCOPE": "DASHSCOPE_API_KEY",
    "ALIBABA": "DASHSCOPE_API_KEY",
    "GEMINI": "GOOGLE_API_KEY",
    "GOOGLE": "GOOGLE_API_KEY",
    "DEEPSEEK": "DEEPSEEK_API_KEY",
    "GROQ": "GROQ_API_KEY",
    "OPENROUTER": "OPENROUTER_API_KEY",
}
|
||||
|
||||
|
||||
def create_model(
    model_name: str,
    provider: str,
    api_key: Optional[str] = None,
    stream: bool = False,
    **kwargs,
):
    """Create an AgentScope model instance wrapped with retry support.

    Args:
        model_name: Model name (e.g., "gpt-4o", "claude-3-opus")
        provider: Provider name (e.g., "OPENAI", "ANTHROPIC")
        api_key: API key (optional, will read from env if not provided)
        stream: Whether to use streaming mode
        **kwargs: Additional model-specific arguments

    Returns:
        AgentScope model instance (a RetryChatModel wrapping the native model)

    Raises:
        ValueError: If the provider has no registered model class.
    """
    canonical = canonicalize_model_provider(provider)

    model_class = PROVIDER_MODEL_MAP.get(canonical)
    if model_class is None:
        raise ValueError(f"Unsupported provider: {canonical}")

    # Resolve the API key from the provider's env var when not given.
    if api_key is None:
        env_key = PROVIDER_API_KEY_ENV.get(canonical)
        if env_key:
            api_key = os.getenv(env_key)

    model_kwargs = {
        "model_name": model_name,
        "stream": stream,
        **kwargs,
    }

    # Ollama runs locally and takes no API key.
    if canonical != "OLLAMA" and api_key:
        model_kwargs["api_key"] = api_key

    # OpenAI-compatible providers (DeepSeek/Groq/OpenRouter) carry a fixed
    # endpoint through client_args.
    fixed_base = PROVIDER_BASE_URLS.get(canonical)
    if fixed_base:
        model_kwargs["client_args"] = {"base_url": fixed_base}

    if canonical == "OPENAI":
        # Allow overriding the OpenAI endpoint via environment.
        custom_base = get_env_str("OPENAI_BASE_URL") or get_env_str(
            "OPENAI_API_BASE",
        )
        if custom_base:
            model_kwargs["client_args"] = {"base_url": custom_base}
    elif canonical in ("DASHSCOPE", "ALIBABA"):
        # DashScope takes its endpoint through a dedicated parameter.
        dashscope_base = get_env_str("DASHSCOPE_BASE_URL")
        if dashscope_base:
            model_kwargs["base_http_api_url"] = dashscope_base
    elif canonical == "OLLAMA":
        ollama_host = get_env_str("OLLAMA_HOST")
        if ollama_host:
            model_kwargs["host"] = ollama_host

    return RetryChatModel(model_class(**model_kwargs))
|
||||
|
||||
|
||||
def get_agent_model(agent_id: str, stream: bool = False):
    """Build the model instance configured for a specific agent.

    Resolution follows the environment variable pattern:
        AGENT_{AGENT_ID}_MODEL_NAME: Model name
        AGENT_{AGENT_ID}_MODEL_PROVIDER: Provider name
    falling back to the global MODEL_NAME / MODEL_PROVIDER values when no
    agent-specific setting is given (handled by get_agent_model_config).

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")
        stream: Whether to use streaming mode

    Returns:
        AgentScope model instance
    """
    config = get_agent_model_config(agent_id)
    return create_model(
        model_name=config.model_name,
        provider=config.provider,
        stream=stream,
    )
|
||||
|
||||
|
||||
def get_agent_formatter(agent_id: str):
    """Build the chat formatter matching an agent's configured provider.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")

    Returns:
        AgentScope formatter instance; defaults to OpenAIChatFormatter when
        the provider has no dedicated formatter
    """
    config = get_agent_model_config(agent_id)
    formatter_cls = PROVIDER_FORMATTER_MAP.get(config.provider, OpenAIChatFormatter)
    return formatter_cls()
|
||||
|
||||
|
||||
def get_agent_model_info(agent_id: str) -> Tuple[str, str]:
    """Look up the model name and provider configured for an agent.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")

    Returns:
        Tuple of (model_name, provider_name)
    """
    config = get_agent_model_config(agent_id)
    return (config.model_name, config.provider)
|
||||
Reference in New Issue
Block a user