# -*- coding: utf-8 -*-
"""
AgentScope Native Model Factory

Uses native AgentScope model classes for LLM calls
"""
import asyncio
import inspect
import os
import time
import logging
from enum import Enum
from typing import Any, Callable, Optional, Tuple, TypeVar

from agentscope.formatter import (
    AnthropicChatFormatter,
    DashScopeChatFormatter,
    GeminiChatFormatter,
    OllamaChatFormatter,
    OpenAIChatFormatter,
)
from agentscope.model import (
    AnthropicChatModel,
    DashScopeChatModel,
    GeminiChatModel,
    OllamaChatModel,
    OpenAIChatModel,
)

from backend.config.env_config import (
    canonicalize_model_provider,
    get_agent_model_config,
    get_env_str,
)

logger = logging.getLogger(__name__)

# Retry wrapper types
T = TypeVar("T")


def _usage_value(usage: Any, key: str, default: Any = 0) -> Any:
    """Read usage fields from both object-style and dict-style usage payloads."""
    if usage is None:
        return default
    if isinstance(usage, dict):
        return usage.get(key, default)
    try:
        return getattr(usage, key)
    except (AttributeError, KeyError):
        return default


def _usage_total_tokens(usage: Any) -> int:
    total = _usage_value(usage, "total_tokens", None)
    if total is not None:
        return int(total or 0)
    input_tokens = _usage_value(usage, "input_tokens", 0)
    output_tokens = _usage_value(usage, "output_tokens", 0)
    return int((input_tokens or 0) + (output_tokens or 0))


class RetryChatModel:
    """Wraps an AgentScope model with automatic retry for transient errors.

    Based on CoPaw's RetryChatModel design. Handles rate limits, timeouts,
    and other transient failures with exponential backoff.
    """

    DEFAULT_MAX_RETRIES = 3
    DEFAULT_INITIAL_DELAY = 1.0
    DEFAULT_MAX_DELAY = 60.0
    DEFAULT_BACKOFF_MULTIPLIER = 2.0

    # Transient error codes/messages that should trigger retry
    TRANSIENT_ERROR_KEYWORDS = frozenset([
        "rate_limit",
        "429",
        "timeout",
        "503",
        "502",
        "504",
        "connection",
        "disconnected",
        "temporary",
        "overloaded",
        "too_many_requests",
    ])

    def __init__(
        self,
        model: Any,
        max_retries: int = DEFAULT_MAX_RETRIES,
        initial_delay: float = DEFAULT_INITIAL_DELAY,
        max_delay: float = DEFAULT_MAX_DELAY,
        backoff_multiplier: float = DEFAULT_BACKOFF_MULTIPLIER,
        on_retry: Optional[Callable[[int, Exception, float], None]] = None,
    ):
        """Initialize retry wrapper.

        Args:
            model: The underlying AgentScope model to wrap
            max_retries: Maximum number of retry attempts
            initial_delay: Initial delay in seconds before first retry
            max_delay: Maximum delay between retries
            backoff_multiplier: Multiplier for exponential backoff
            on_retry: Optional callback(retry_count, exception, delay) for logging
        """
        self._model = model
        self._max_retries = max_retries
        self._initial_delay = initial_delay
        self._max_delay = max_delay
        self._backoff_multiplier = backoff_multiplier
        self._on_retry = on_retry
        self._total_tokens_used = 0
        self._total_cost = 0.0

    @property
    def model_name(self) -> str:
        return getattr(self._model, "model_name", str(self._model))

    @property
    def total_tokens_used(self) -> int:
        return self._total_tokens_used

    @property
    def total_cost(self) -> float:
        return self._total_cost

    def _is_transient_error(self, error: Exception) -> bool:
        """Check if an error is transient and should be retried.

        Args:
            error: The exception to check

        Returns:
            True if the error is transient
        """
        error_str = str(error).lower()
        for keyword in self.TRANSIENT_ERROR_KEYWORDS:
            if keyword in error_str:
                return True
        return False

    def _calculate_delay(self, retry_count: int) -> float:
        """Calculate delay for given retry attempt with exponential backoff.

        Args:
            retry_count: Current retry attempt number (1-based)

        Returns:
            Delay in seconds
        """
        delay = self._initial_delay * (self._backoff_multiplier ** (retry_count - 1))
        return min(delay, self._max_delay)
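
    # A quick sketch of the resulting schedule, assuming the class defaults
    # above (an illustrative trace, not part of the public API):
    #
    #     wrapper = RetryChatModel(model)   # "model" is any wrapped model
    #     wrapper._calculate_delay(1)       # -> 1.0
    #     wrapper._calculate_delay(2)       # -> 2.0
    #     wrapper._calculate_delay(3)       # -> 4.0, always capped at 60.0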

    def _call_with_retry(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Call a function with retry logic for transient errors.

        Args:
            func: Function to call
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Result from func

        Raises:
            Last exception if all retries exhausted
        """
        last_error: Optional[Exception] = None
        for attempt in range(1, self._max_retries + 1):
            try:
                result = func(*args, **kwargs)
                # Track usage if available
                if hasattr(result, "usage") and result.usage:
                    usage = result.usage
                    self._total_tokens_used += _usage_total_tokens(usage)
                    self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
                return result
            except Exception as e:
                last_error = e
                if attempt >= self._max_retries:
                    logger.error(
                        "RetryChatModel: Max retries (%d) exhausted for %s",
                        self._max_retries,
                        self.model_name,
                    )
                    break
                if not self._is_transient_error(e):
                    logger.warning(
                        "RetryChatModel: Non-transient error, not retrying: %s",
                        str(e),
                    )
                    break
                delay = self._calculate_delay(attempt)
                logger.warning(
                    "RetryChatModel: Transient error on attempt %d/%d, "
                    "retrying in %.1fs: %s",
                    attempt,
                    self._max_retries,
                    delay,
                    str(e)[:200],
                )
                if self._on_retry:
                    self._on_retry(attempt, e, delay)
                time.sleep(delay)
        if last_error is not None:
            raise last_error
        raise RuntimeError("RetryChatModel: Unexpected state, no error but no result")

    async def _call_with_retry_async(self, func: Callable[..., T], *args, **kwargs) -> T:
        """Call an async function with retry logic for transient errors."""
        last_error: Optional[Exception] = None
        for attempt in range(1, self._max_retries + 1):
            try:
                result = await func(*args, **kwargs)
                if hasattr(result, "usage") and result.usage:
                    usage = result.usage
                    self._total_tokens_used += _usage_total_tokens(usage)
                    self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)
                return result
            except Exception as e:
                last_error = e
                if attempt >= self._max_retries:
                    logger.error(
                        "RetryChatModel: Max retries (%d) exhausted for %s",
                        self._max_retries,
                        self.model_name,
                    )
                    break
                if not self._is_transient_error(e):
                    logger.warning(
                        "RetryChatModel: Non-transient error, not retrying: %s",
                        str(e),
                    )
                    break
                delay = self._calculate_delay(attempt)
                logger.warning(
                    "RetryChatModel: Transient async error on attempt %d/%d, "
                    "retrying in %.1fs: %s",
                    attempt,
                    self._max_retries,
                    delay,
                    str(e)[:200],
                )
                if self._on_retry:
                    self._on_retry(attempt, e, delay)
                await asyncio.sleep(delay)
        if last_error is not None:
            raise last_error
        raise RuntimeError("RetryChatModel: Unexpected async state, no error but no result")

    def __call__(self, *args, **kwargs) -> Any:
        """Forward calls to the wrapped model with retry logic."""
        model_call = getattr(self._model, "__call__", None)
        if inspect.iscoroutinefunction(self._model) or inspect.iscoroutinefunction(model_call):
            return self._call_with_retry_async(self._model, *args, **kwargs)
        # Route the synchronous path through the retry helper as well, so both
        # paths share the same backoff and usage tracking (the original code
        # called self._model directly here and silently skipped retries).
        return self._call_with_retry(self._model, *args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        """Proxy attribute access to the wrapped model."""
        return getattr(self._model, name)
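

# A minimal usage sketch. The constructor kwargs mirror those that
# create_model() builds below; the callback and model name are illustrative:
#
#     base = OpenAIChatModel(model_name="gpt-4o", api_key=os.getenv("OPENAI_API_KEY"))
#     model = RetryChatModel(
#         base,
#         max_retries=5,
#         on_retry=lambda attempt, err, delay: logger.info(
#             "retry %d in %.1fs: %s", attempt, delay, err
#         ),
#     )
#     # model(...) now retries 429/timeout-style failures with backoff,
#     # while non-transient errors are raised immediately.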


class TokenRecordingModelWrapper:
    """Wraps a model to track token usage per provider.

    Based on CoPaw's TokenRecordingModelWrapper design.
    """

    def __init__(self, model: Any):
        """Initialize token recorder.

        Args:
            model: The underlying AgentScope model to wrap
        """
        self._model = model
        self._total_tokens = 0
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._total_cost = 0.0

    @property
    def model_name(self) -> str:
        return getattr(self._model, "model_name", str(self._model))

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @property
    def prompt_tokens(self) -> int:
        return self._prompt_tokens

    @property
    def completion_tokens(self) -> int:
        return self._completion_tokens

    @property
    def total_cost(self) -> float:
        return self._total_cost

    def record_usage(self, usage: Any) -> None:
        """Record token usage from a model response.

        Args:
            usage: Usage object from model response
        """
        if usage is None:
            return
        prompt_tokens = _usage_value(usage, "prompt_tokens", None)
        completion_tokens = _usage_value(usage, "completion_tokens", None)
        if prompt_tokens is None:
            prompt_tokens = _usage_value(usage, "input_tokens", 0)
        if completion_tokens is None:
            completion_tokens = _usage_value(usage, "output_tokens", 0)
        self._prompt_tokens += int(prompt_tokens or 0)
        self._completion_tokens += int(completion_tokens or 0)
        self._total_tokens += _usage_total_tokens(usage)
        self._total_cost += float(_usage_value(usage, "cost", 0.0) or 0.0)

    def __call__(self, *args, **kwargs) -> Any:
        """Forward calls and record usage."""
        result = self._model(*args, **kwargs)
        if hasattr(result, "usage") and result.usage:
            self.record_usage(result.usage)
        return result

    def __getattr__(self, name: str) -> Any:
        """Proxy attribute access to the wrapped model."""
        return getattr(self._model, name)


class ModelProvider(Enum):
    """Supported model providers"""

    OPENAI = "OPENAI"
    ANTHROPIC = "ANTHROPIC"
    DASHSCOPE = "DASHSCOPE"
    ALIBABA = "ALIBABA"
    GEMINI = "GEMINI"
    GOOGLE = "GOOGLE"
    OLLAMA = "OLLAMA"
    DEEPSEEK = "DEEPSEEK"
    GROQ = "GROQ"
    OPENROUTER = "OPENROUTER"


# Provider to AgentScope model class mapping
PROVIDER_MODEL_MAP = {
    "OPENAI": OpenAIChatModel,
    "ANTHROPIC": AnthropicChatModel,
    "DASHSCOPE": DashScopeChatModel,
    "ALIBABA": DashScopeChatModel,
    "GEMINI": GeminiChatModel,
    "GOOGLE": GeminiChatModel,
    "OLLAMA": OllamaChatModel,
    # OpenAI-compatible providers use OpenAIChatModel with custom base_url
    "DEEPSEEK": OpenAIChatModel,
    "GROQ": OpenAIChatModel,
    "OPENROUTER": OpenAIChatModel,
}

# Provider to formatter mapping
PROVIDER_FORMATTER_MAP = {
    "OPENAI": OpenAIChatFormatter,
    "ANTHROPIC": AnthropicChatFormatter,
    "DASHSCOPE": DashScopeChatFormatter,
    "ALIBABA": DashScopeChatFormatter,
    "GEMINI": GeminiChatFormatter,
    "GOOGLE": GeminiChatFormatter,
    "OLLAMA": OllamaChatFormatter,
    # OpenAI-compatible providers use OpenAIChatFormatter
    "DEEPSEEK": OpenAIChatFormatter,
    "GROQ": OpenAIChatFormatter,
    "OPENROUTER": OpenAIChatFormatter,
}

# Provider-specific base URLs
PROVIDER_BASE_URLS = {
    "DEEPSEEK": "https://api.deepseek.com/v1",
    "GROQ": "https://api.groq.com/openai/v1",
    "OPENROUTER": "https://openrouter.ai/api/v1",
}

# Provider-specific API key environment variable names
PROVIDER_API_KEY_ENV = {
    "OPENAI": "OPENAI_API_KEY",
    "ANTHROPIC": "ANTHROPIC_API_KEY",
    "DASHSCOPE": "DASHSCOPE_API_KEY",
    "ALIBABA": "DASHSCOPE_API_KEY",
    "GEMINI": "GOOGLE_API_KEY",
    "GOOGLE": "GOOGLE_API_KEY",
    "DEEPSEEK": "DEEPSEEK_API_KEY",
    "GROQ": "GROQ_API_KEY",
    "OPENROUTER": "OPENROUTER_API_KEY",
}
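
# Registering another OpenAI-compatible provider is a matter of extending the
# maps above. A hypothetical example ("MYPROVIDER" and its URL are
# illustrative, not shipped with this module):
#
#     PROVIDER_MODEL_MAP["MYPROVIDER"] = OpenAIChatModel
#     PROVIDER_FORMATTER_MAP["MYPROVIDER"] = OpenAIChatFormatter
#     PROVIDER_BASE_URLS["MYPROVIDER"] = "https://api.myprovider.example/v1"
#     PROVIDER_API_KEY_ENV["MYPROVIDER"] = "MYPROVIDER_API_KEY"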


def create_model(
    model_name: str,
    provider: str,
    api_key: Optional[str] = None,
    stream: bool = False,
    **kwargs,
):
    """
    Create an AgentScope model instance

    Args:
        model_name: Model name (e.g., "gpt-4o", "claude-3-opus")
        provider: Provider name (e.g., "OPENAI", "ANTHROPIC")
        api_key: API key (optional, will read from env if not provided)
        stream: Whether to use streaming mode
        **kwargs: Additional model-specific arguments

    Returns:
        AgentScope model instance
    """
    provider = canonicalize_model_provider(provider)

    # If provider is the default OPENAI but the model name looks like DeepSeek,
    # check whether we should switch to DASHSCOPE.
    if provider == "OPENAI" and "deepseek" in model_name.lower() and os.getenv("DASHSCOPE_API_KEY"):
        provider = "DASHSCOPE"

    # Intelligent routing: if it's a DeepSeek model and we have DashScope
    # credentials, prefer DashScopeChatModel over OpenAIChatModel.
    if provider == "DEEPSEEK" and os.getenv("DASHSCOPE_API_KEY"):
        provider = "DASHSCOPE"

    model_class = PROVIDER_MODEL_MAP.get(provider)
    if model_class is None:
        raise ValueError(f"Unsupported provider: {provider}")

    # Get API key from env if not provided
    if api_key is None:
        env_key = PROVIDER_API_KEY_ENV.get(provider)
        if env_key:
            api_key = os.getenv(env_key)

    # Build model kwargs
    model_kwargs = {
        "model_name": model_name,
        "stream": stream,
        **kwargs,
    }

    # Add API key if needed (Ollama doesn't need it)
    if provider != "OLLAMA" and api_key:
        model_kwargs["api_key"] = api_key

    # Handle OpenAI-compatible providers with custom base_url
    if provider in PROVIDER_BASE_URLS:
        base_url = PROVIDER_BASE_URLS[provider]
        model_kwargs["client_args"] = {"base_url": base_url}

    # Handle custom OpenAI base URL
    if provider == "OPENAI":
        base_url = get_env_str("OPENAI_BASE_URL") or get_env_str(
            "OPENAI_API_BASE",
        )
        if base_url:
            model_kwargs["client_args"] = {"base_url": base_url}

    # Handle DashScope base URL (uses a different parameter)
    if provider in ("DASHSCOPE", "ALIBABA"):
        base_url = get_env_str("DASHSCOPE_BASE_URL")
        if base_url:
            model_kwargs["base_http_api_url"] = base_url

    # Handle Ollama host
    if provider == "OLLAMA":
        host = get_env_str("OLLAMA_HOST")
        if host:
            model_kwargs["host"] = host

    model = model_class(**model_kwargs)
    return RetryChatModel(model)


def get_agent_model(agent_id: str, stream: bool = False):
    """
    Get model for a specific agent based on environment variables

    Environment variable pattern:
        AGENT_{AGENT_ID}_MODEL_NAME: Model name
        AGENT_{AGENT_ID}_MODEL_PROVIDER: Provider name

    Falls back to the global MODEL_NAME and MODEL_PROVIDER if no
    agent-specific values are given.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")
        stream: Whether to use streaming mode

    Returns:
        AgentScope model instance
    """
    resolved = get_agent_model_config(agent_id)
    return create_model(
        model_name=resolved.model_name,
        provider=resolved.provider,
        stream=stream,
    )


def get_agent_formatter(agent_id: str):
    """
    Get formatter for a specific agent based on environment variables

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")

    Returns:
        AgentScope formatter instance
    """
    provider = get_agent_model_config(agent_id).provider
    formatter_class = PROVIDER_FORMATTER_MAP.get(provider, OpenAIChatFormatter)
    return formatter_class()


def get_agent_model_info(agent_id: str) -> Tuple[str, str]:
    """
    Get model name and provider for a specific agent

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")

    Returns:
        Tuple of (model_name, provider_name)
    """
    resolved = get_agent_model_config(agent_id)
    return resolved.model_name, resolved.provider
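

# End-to-end sketch, assuming the AGENT_{AGENT_ID}_* pattern documented in
# get_agent_model with the agent id uppercased by get_agent_model_config (an
# assumption; the agent id and model name below are illustrative):
#
#     os.environ["AGENT_SENTIMENT_ANALYST_MODEL_NAME"] = "gpt-4o"
#     os.environ["AGENT_SENTIMENT_ANALYST_MODEL_PROVIDER"] = "OPENAI"
#
#     model = get_agent_model("sentiment_analyst")          # RetryChatModel
#     formatter = get_agent_formatter("sentiment_analyst")  # OpenAIChatFormatter
#     name, provider = get_agent_model_info("sentiment_analyst")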