# -*- coding: utf-8 -*-
"""
Portfolio Manager Agent - Based on AgentScope ReActAgent

Responsible for decision-making (NOT trade execution)
"""

import copy
from pathlib import Path
from typing import Any, Callable, Dict, Optional

from agentscope.agent import ReActAgent
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
from agentscope.message import Msg, TextBlock
from agentscope.tool import Toolkit, ToolResponse

from ..utils.progress import progress
from .prompt_factory import build_agent_system_prompt, clear_prompt_factory_cache
from .team_pipeline_config import update_active_analysts
from ..config.constants import ANALYST_TYPES


class PMAgent(ReActAgent):
    """
    Portfolio Manager Agent - Makes investment decisions

    Key features:
    1. PM outputs decisions only (action + quantity per ticker)
    2. Trade execution happens externally (in pipeline/executor)
    3. Supports both backtest and live modes

    Decisions are recorded through the ``_make_decision`` tool while
    ``reply`` runs, and are attached to the reply's ``metadata`` together
    with a snapshot of the tracked portfolio.
    """

    # Actions accepted by the ``_make_decision`` tool.
    _VALID_ACTIONS = ("long", "short", "hold")

    def __init__(
        self,
        name: str = "portfolio_manager",
        model: Any = None,
        formatter: Any = None,
        initial_cash: float = 100000.0,
        margin_requirement: float = 0.25,
        config: Optional[Dict[str, Any]] = None,
        long_term_memory: Optional[LongTermMemoryBase] = None,
        toolkit_factory: Any = None,
        toolkit_factory_kwargs: Optional[Dict[str, Any]] = None,
        toolkit: Optional[Toolkit] = None,
        max_iters: int = 10,
    ):
        """Build the agent, its toolkit and its system prompt.

        Args:
            name: Agent name; also used as the prompt/toolkit agent id.
            model: Chat model forwarded to ``ReActAgent``.
            formatter: Message formatter forwarded to ``ReActAgent``.
            initial_cash: Starting ``portfolio["cash"]``.
            margin_requirement: Starting ``portfolio["margin_requirement"]``.
            config: Agent config; only ``"config_name"`` is read here
                (falls back to ``"default"``).
            long_term_memory: Optional long-term memory. When given, it is
                forwarded with ``long_term_memory_mode="both"``.
            toolkit_factory: Optional callable invoked as
                ``toolkit_factory(name, config_name, owner=self, **kwargs)``.
            toolkit_factory_kwargs: Extra keyword arguments for the factory.
                A copy is kept for ``reload_runtime_assets``.
            toolkit: Pre-built toolkit. It takes precedence over the factory.
            max_iters: Maximum ReAct iterations per reply (previously fixed
                at 10, which remains the default).
        """
        # NOTE(review): these attributes are assigned before super().__init__()
        # via object.__setattr__, presumably because the AgentScope base
        # class's __setattr__ relies on state created in its __init__ — confirm.
        object.__setattr__(self, "config", config or {})

        # Portfolio state
        object.__setattr__(
            self,
            "portfolio",
            {
                "cash": initial_cash,
                "positions": {},
                "margin_used": 0.0,
                "margin_requirement": margin_requirement,
            },
        )

        # Decisions made in current cycle
        object.__setattr__(self, "_decisions", {})

        # Copy so later mutation of the caller's dict cannot silently change
        # what reload_runtime_assets() passes to the factory.
        toolkit_factory_kwargs = dict(toolkit_factory_kwargs or {})
        object.__setattr__(self, "_toolkit_factory", toolkit_factory)
        object.__setattr__(self, "_toolkit_factory_kwargs", toolkit_factory_kwargs)
        object.__setattr__(self, "_create_team_agent_cb", None)
        object.__setattr__(self, "_remove_team_agent_cb", None)

        config_name = self._get_config_name()

        # Create toolkit after local state is ready so bound tool methods can
        # be registered.
        if toolkit is None:
            if toolkit_factory is not None:
                toolkit = toolkit_factory(
                    name,
                    config_name,
                    owner=self,
                    **toolkit_factory_kwargs,
                )
            else:
                toolkit = self._create_toolkit()
        object.__setattr__(self, "toolkit", toolkit)

        sys_prompt = build_agent_system_prompt(
            agent_id=name,
            config_name=config_name,
            toolkit=toolkit,
        )

        kwargs = {
            "name": name,
            "sys_prompt": sys_prompt,
            "model": model,
            "formatter": formatter,
            "toolkit": toolkit,
            "memory": InMemoryMemory(),
            "max_iters": max_iters,
        }
        # Explicit None check: a memory object must not be skipped merely
        # because it happens to evaluate as falsy.
        if long_term_memory is not None:
            kwargs["long_term_memory"] = long_term_memory
            kwargs["long_term_memory_mode"] = "both"

        super().__init__(**kwargs)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_config_name(self) -> str:
        """Return ``config["config_name"]``, defaulting to ``"default"``."""
        return self.config.get("config_name", "default")

    @staticmethod
    def _text_response(text: str) -> ToolResponse:
        """Wrap ``text`` in a ``ToolResponse`` holding a single text block."""
        return ToolResponse(content=[TextBlock(type="text", text=text)])

    @staticmethod
    def _parse_quantity(quantity: Any) -> Optional[int]:
        """Coerce a tool-supplied quantity to a non-negative int.

        Returns ``None`` when the value cannot be converted with ``int()`` or
        is negative. Note that ``int()`` truncates fractional floats.
        """
        try:
            shares = int(quantity)
        except (TypeError, ValueError):
            return None
        return shares if shares >= 0 else None

    def _update_active_team(self, **changes: Any) -> Any:
        """Forward an active-analyst change to ``update_active_analysts``.

        ``changes`` carries the operation keyword (this class uses ``add=``,
        ``remove=`` or ``set_to=``). Returns whatever ``update_active_analysts``
        returns; callers join it as strings.
        """
        return update_active_analysts(
            project_root=Path(__file__).resolve().parents[2],
            config_name=self._get_config_name(),
            available_analysts=list(ANALYST_TYPES.keys()),
            **changes,
        )

    def _create_toolkit(self) -> Toolkit:
        """Create toolkit with decision recording tool"""
        toolkit = Toolkit()
        toolkit.register_tool_function(self._make_decision)
        return toolkit

    # ------------------------------------------------------------------
    # Tool functions (their docstrings are LLM-facing tool descriptions)
    # ------------------------------------------------------------------

    def _make_decision(
        self,
        ticker: str,
        action: str,
        quantity: int,
        confidence: int = 50,
        reasoning: str = "",
    ) -> ToolResponse:
        """
        Record a trading decision for a ticker.

        Args:
            ticker: Stock ticker symbol (e.g., "AAPL")
            action: Decision - "long", "short" or "hold"
            quantity: Number of shares to trade (0 for hold); must be a
                non-negative integer
            confidence: Confidence level 0-100
            reasoning: Explanation for this decision

        Returns:
            ToolResponse confirming decision recorded
        """
        # Accept case/whitespace variants such as "Long " from the model.
        normalized_action = str(action).strip().lower()
        if normalized_action not in self._VALID_ACTIONS:
            return self._text_response(
                f"Invalid action: {action}. "
                "Must be 'long', 'short', or 'hold'.",
            )

        if normalized_action == "hold":
            shares: Optional[int] = 0
        else:
            shares = self._parse_quantity(quantity)
            if shares is None:
                return self._text_response(
                    f"Invalid quantity: {quantity}. "
                    "Must be a non-negative integer number of shares.",
                )

        self._decisions[ticker] = {
            "action": normalized_action,
            "quantity": shares,
            "confidence": confidence,
            "reasoning": reasoning,
        }

        # Report the quantity actually recorded, so "hold" always reads
        # as 0 shares.
        return self._text_response(
            f"Decision recorded: {normalized_action} "
            f"{shares} shares of {ticker}"
            f" (confidence: {confidence}%)",
        )

    def _add_team_analyst(self, agent_id: str) -> ToolResponse:
        """Add one analyst to active discussion team.

        Args:
            agent_id: Identifier of the analyst to add
        """
        active = self._update_active_team(add=[agent_id])
        return self._text_response(
            f"Active analyst team updated. Added: {agent_id}. "
            f"Current active analysts: {', '.join(active)}",
        )

    def _remove_team_analyst(self, agent_id: str) -> ToolResponse:
        """Remove one analyst from active discussion team.

        Args:
            agent_id: Identifier of the analyst to remove
        """
        # The runtime callback (if injected) runs before the persisted team
        # list is updated, matching the original ordering.
        callback_msg = ""
        callback = self._remove_team_agent_cb
        if callback is not None:
            callback_msg = callback(agent_id=agent_id)

        active = self._update_active_team(remove=[agent_id])
        return self._text_response(
            f"Active analyst team updated. Removed: {agent_id}. "
            f"Current active analysts: {', '.join(active)}"
            + (f" | {callback_msg}" if callback_msg else ""),
        )

    def _set_active_analysts(self, agent_ids: str) -> ToolResponse:
        """Set active analysts from comma-separated agent ids.

        Args:
            agent_ids: Comma-separated analyst identifiers
        """
        # A list/tuple is also accepted, because str(list) would otherwise be
        # split into garbage ids such as "['a'".
        if isinstance(agent_ids, (list, tuple)):
            raw_items = [str(item) for item in agent_ids]
        else:
            raw_items = str(agent_ids or "").split(",")
        requested = [item.strip() for item in raw_items if item.strip()]

        active = self._update_active_team(set_to=requested)
        return self._text_response(
            f"Active analyst team set to: {', '.join(active)}",
        )

    def _create_team_analyst(self, agent_id: str, analyst_type: str) -> ToolResponse:
        """Create a runtime analyst instance and activate it.

        Args:
            agent_id: Identifier for the new analyst
            analyst_type: Type of analyst to create
        """
        callback = self._create_team_agent_cb
        if callback is None:
            return self._text_response(
                "Runtime agent creation is not available in current pipeline.",
            )
        result = callback(agent_id=agent_id, analyst_type=analyst_type)
        # TextBlock text must be a string even if the callback returns
        # something else (or nothing).
        return self._text_response("" if result is None else str(result))

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def set_team_controller(
        self,
        *,
        create_agent_callback: Optional[Callable[..., str]] = None,
        remove_agent_callback: Optional[Callable[..., str]] = None,
    ) -> None:
        """Inject runtime team lifecycle callbacks from pipeline.

        Args:
            create_agent_callback: Called as
                ``cb(agent_id=..., analyst_type=...)`` by
                ``_create_team_analyst``. ``None`` disables runtime creation.
            remove_agent_callback: Called as ``cb(agent_id=...)`` by
                ``_remove_team_analyst`` before the team list is updated.
        """
        object.__setattr__(self, "_create_team_agent_cb", create_agent_callback)
        object.__setattr__(self, "_remove_team_agent_cb", remove_agent_callback)

    async def reply(self, x: Optional[Msg] = None) -> Msg:
        """
        Make investment decisions

        Args:
            x: Input message. ``None`` short-circuits with a plain
                "No input provided" reply; no ReAct loop runs and no metadata
                is attached.

        Returns:
            Msg with decisions in metadata (``"decisions"``) and a portfolio
            snapshot (``"portfolio"``). Both are deep copies, so consumers
            cannot mutate the agent's state through them.
        """
        if x is None:
            return Msg(
                name=self.name,
                content="No input provided",
                role="assistant",
            )

        # Clear previous decisions
        self._decisions = {}

        progress.update_status(
            self.name,
            None,
            "Analyzing and making decisions",
        )

        result = await super().reply(x)

        progress.update_status(self.name, None, "Completed")

        # Attach decisions to metadata
        if result.metadata is None:
            result.metadata = {}
        result.metadata["decisions"] = self.get_decisions()
        result.metadata["portfolio"] = self.get_portfolio_state()

        return result

    def get_decisions(self) -> Dict[str, Dict]:
        """Get decisions from current cycle (deep copy)."""
        return copy.deepcopy(self._decisions)

    def get_portfolio_state(self) -> Dict[str, Any]:
        """Get current portfolio state (deep copy, including positions)."""
        return copy.deepcopy(self.portfolio)

    def load_portfolio_state(self, portfolio: Dict[str, Any]):
        """Replace the portfolio with a copy of ``portfolio``.

        A falsy ``portfolio`` is ignored. Missing ``cash`` and
        ``margin_requirement`` keep their current values. Missing
        ``positions``/``margin_used`` reset to empty/0.0. Keys other than
        these four are not kept.
        """
        if not portfolio:
            return
        current = self.portfolio
        self.portfolio = {
            "cash": portfolio.get("cash", current["cash"]),
            # Deep copy: nested per-ticker dicts must not be shared with the
            # caller. ``or {}`` also tolerates an explicit None.
            "positions": copy.deepcopy(portfolio.get("positions") or {}),
            "margin_used": portfolio.get("margin_used", 0.0),
            "margin_requirement": portfolio.get(
                "margin_requirement",
                current["margin_requirement"],
            ),
        }

    def update_portfolio(self, portfolio: Dict[str, Any]):
        """Update portfolio after external execution.

        Merges a deep copy of ``portfolio`` into the current state, so later
        mutation by the executor does not leak in. ``None``/empty is a no-op.
        """
        if not portfolio:
            return
        self.portfolio.update(copy.deepcopy(portfolio))

    def _has_open_positions(self) -> bool:
        """Return whether any position has a non-zero "long" or "short" value."""
        # NOTE(review): assumes each position is a dict with optional
        # "long"/"short" keys — confirm against the executor's schema.
        positions = self.portfolio.get("positions") or {}
        return any(
            position.get("long", 0) or position.get("short", 0)
            for position in positions.values()
        )

    def can_apply_initial_cash(self) -> bool:
        """Only allow cash rebasing before any positions or margin exist."""
        return (
            not self._has_open_positions()
            and float(self.portfolio.get("margin_used", 0.0) or 0.0) == 0.0
        )

    def apply_runtime_portfolio_config(
        self,
        *,
        margin_requirement: Optional[float] = None,
        initial_cash: Optional[float] = None,
    ) -> Dict[str, bool]:
        """Apply safe run-time portfolio config updates.

        Args:
            margin_requirement: New margin requirement. It is always applied
                when given.
            initial_cash: New cash balance. It overwrites ``cash`` and is
                applied only while ``can_apply_initial_cash()`` is true.

        Returns:
            ``{"margin_requirement": bool, "initial_cash": bool}`` telling
            which updates were applied.
        """
        result = {
            "margin_requirement": False,
            "initial_cash": False,
        }

        if margin_requirement is not None:
            self.portfolio["margin_requirement"] = float(margin_requirement)
            result["margin_requirement"] = True

        if initial_cash is not None and self.can_apply_initial_cash():
            self.portfolio["cash"] = float(initial_cash)
            result["initial_cash"] = True

        return result

    def reload_runtime_assets(self, active_skill_dirs: Optional[list] = None) -> None:
        """Reload toolkit and system prompt from current run assets.

        Args:
            active_skill_dirs: When given, it overrides ``active_skill_dirs``
                in the stored factory kwargs for this reload only. It is not
                persisted for later reloads.
        """
        # Local import kept from the original, presumably to avoid an import
        # cycle — confirm before hoisting.
        from .toolkit_factory import create_agent_toolkit

        clear_prompt_factory_cache()
        # Without an injected factory this falls back to create_agent_toolkit,
        # not to _create_toolkit() as __init__ does.
        toolkit_factory = self._toolkit_factory or create_agent_toolkit
        toolkit_kwargs = dict(self._toolkit_factory_kwargs)
        if active_skill_dirs is not None:
            toolkit_kwargs["active_skill_dirs"] = active_skill_dirs

        config_name = self._get_config_name()
        self.toolkit = toolkit_factory(
            self.name,
            config_name,
            owner=self,
            **toolkit_kwargs,
        )
        self._apply_runtime_sys_prompt(
            build_agent_system_prompt(
                agent_id=self.name,
                config_name=config_name,
                toolkit=self.toolkit,
            ),
        )

    def _apply_runtime_sys_prompt(self, sys_prompt: str) -> None:
        """Update the prompt used by future turns and the cached system msg."""
        self._sys_prompt = sys_prompt
        # Memory entries are (msg, marks) pairs. Only the first message with
        # role "system" is rewritten.
        for msg, _marks in self.memory.content:
            if getattr(msg, "role", None) == "system":
                msg.content = sys_prompt
                break