229 lines
7.2 KiB
Python
229 lines
7.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Portfolio Manager Agent - Based on AgentScope ReActAgent
|
|
Responsible for decision-making (NOT trade execution)
|
|
"""
|
|
|
|
import copy
from typing import Any, Dict, Optional

from agentscope.agent import ReActAgent
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
from agentscope.message import Msg, TextBlock
from agentscope.tool import Toolkit, ToolResponse

from ..utils.progress import progress
from .prompt_factory import build_agent_system_prompt, clear_prompt_factory_cache
|
|
|
|
|
|
class PMAgent(ReActAgent):
    """
    Portfolio Manager Agent - Makes investment decisions

    Key features:
    1. PM outputs decisions only (action + quantity per ticker)
    2. Trade execution happens externally (in pipeline/executor)
    3. Supports both backtest and live modes

    State held by this agent:
    - ``portfolio``: dict with ``cash``, ``positions``, ``margin_used`` and
      ``margin_requirement``. This class never applies trades itself; the
      state is replaced/updated from outside via ``load_portfolio_state`` /
      ``update_portfolio``.
    - ``_decisions``: per-ticker decisions recorded by the ``_make_decision``
      tool during the current ``reply`` cycle.

    Accessors and the ``reply`` metadata hand out deep copies, so callers
    cannot mutate the agent's internal state by accident.
    """

    # Actions accepted by the decision-recording tool.
    _VALID_ACTIONS = ("long", "short", "hold")

    def __init__(
        self,
        name: str = "portfolio_manager",
        model: Any = None,
        formatter: Any = None,
        initial_cash: float = 100000.0,
        margin_requirement: float = 0.25,
        config: Optional[Dict[str, Any]] = None,
        long_term_memory: Optional[LongTermMemoryBase] = None,
        toolkit_factory: Any = None,
        toolkit_factory_kwargs: Optional[Dict[str, Any]] = None,
        toolkit: Optional[Toolkit] = None,
    ):
        """
        Build the PM agent, its toolkit and its system prompt.

        Args:
            name: Agent name; also used as the agent id for prompt/toolkit
                lookup.
            model: Model object passed through to ``ReActAgent``.
            formatter: Formatter passed through to ``ReActAgent``.
            initial_cash: Starting cash stored in the portfolio state.
            margin_requirement: Margin ratio stored in the portfolio state.
            config: Optional configuration; its ``config_name`` key (default
                ``"default"``) selects prompt/toolkit assets.
            long_term_memory: Optional long-term memory; when provided it is
                enabled with ``long_term_memory_mode="both"``.
            toolkit_factory: Optional callable invoked as
                ``toolkit_factory(name, config_name, owner=self, **kwargs)``.
            toolkit_factory_kwargs: Extra keyword arguments for the factory;
                also reused by ``reload_runtime_assets``.
            toolkit: Pre-built toolkit; when given, the factory is not called.
        """
        self.config = config or {}
        config_name = self.config.get("config_name", "default")

        # Portfolio state
        self.portfolio = {
            "cash": initial_cash,
            "positions": {},
            "margin_used": 0.0,
            "margin_requirement": margin_requirement,
        }

        # Decisions made in current cycle
        self._decisions: Dict[str, Dict] = {}

        self._toolkit_factory = toolkit_factory
        # Own a copy so later mutation of the caller's dict cannot change
        # what reload_runtime_assets passes to the factory.
        self._toolkit_factory_kwargs = dict(toolkit_factory_kwargs or {})

        # Create toolkit after local state is ready so bound tool methods can
        # be registered.
        if toolkit is None:
            if toolkit_factory is not None:
                toolkit = toolkit_factory(
                    name,
                    config_name,
                    owner=self,
                    **self._toolkit_factory_kwargs,
                )
            else:
                toolkit = self._create_toolkit()
        self.toolkit = toolkit

        sys_prompt = build_agent_system_prompt(
            agent_id=name,
            config_name=config_name,
            toolkit=self.toolkit,
        )

        kwargs = {
            "name": name,
            "sys_prompt": sys_prompt,
            "model": model,
            "formatter": formatter,
            "toolkit": toolkit,
            "memory": InMemoryMemory(),
            "max_iters": 10,
        }
        # Explicit None check: an empty-but-valid memory object must not be
        # silently dropped because of its truthiness.
        if long_term_memory is not None:
            kwargs["long_term_memory"] = long_term_memory
            kwargs["long_term_memory_mode"] = "both"

        super().__init__(**kwargs)

    def _create_toolkit(self) -> Toolkit:
        """Create a toolkit containing only the decision recording tool."""
        toolkit = Toolkit()
        toolkit.register_tool_function(self._make_decision)
        return toolkit

    @staticmethod
    def _text_response(text: str) -> ToolResponse:
        """Wrap a plain text message in a single-block ToolResponse."""
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text=text,
                ),
            ],
        )

    def _make_decision(
        self,
        ticker: str,
        action: str,
        quantity: int,
        confidence: int = 50,
        reasoning: str = "",
    ) -> ToolResponse:
        """
        Record a trading decision for a ticker.

        Args:
            ticker: Stock ticker symbol (e.g., "AAPL")
            action: Decision - "long", "short" or "hold"
            quantity: Number of shares to trade (0 for hold)
            confidence: Confidence level 0-100
            reasoning: Explanation for this decision

        Returns:
            ToolResponse confirming decision recorded
        """
        # NOTE: the docstring above is registered as the tool description, so
        # validation details live in comments rather than in the Args text.

        # Tolerate "Long" / " hold " from the model; store the canonical form
        # because downstream execution compares against exact action strings.
        normalized = action.strip().lower() if isinstance(action, str) else action
        if normalized not in self._VALID_ACTIONS:
            return self._text_response(
                f"Invalid action: {action}. "
                "Must be 'long', 'short', or 'hold'.",
            )

        try:
            shares = int(quantity)
        except (TypeError, ValueError):
            return self._text_response(
                f"Invalid quantity: {quantity}. "
                "Must be a whole number of shares.",
            )
        if shares < 0:
            # Direction is expressed by the action; a negative count would be
            # ambiguous (e.g. "long -100").
            return self._text_response(
                f"Invalid quantity: {quantity}. "
                "Must be non-negative; use the action to express direction.",
            )

        try:
            # Keep confidence inside the documented 0-100 range.
            score = min(100, max(0, int(confidence)))
        except (TypeError, ValueError):
            return self._text_response(
                f"Invalid confidence: {confidence}. "
                "Must be an integer between 0 and 100.",
            )

        recorded_quantity = shares if normalized != "hold" else 0
        self._decisions[ticker] = {
            "action": normalized,
            "quantity": recorded_quantity,
            "confidence": score,
            "reasoning": reasoning,
        }

        # Report what was actually stored (a hold is always 0 shares).
        return self._text_response(
            f"Decision recorded: {normalized} "
            f"{recorded_quantity} shares of {ticker}"
            f" (confidence: {score}%)",
        )

    async def reply(self, x: Optional[Msg] = None) -> Msg:
        """
        Make investment decisions for the given input message.

        Clears the previous cycle's decisions, runs the ReAct loop (during
        which the model may call ``_make_decision``), then attaches snapshots
        of the recorded decisions and the portfolio to the reply metadata.

        Args:
            x: Input message; when None a placeholder reply is returned
                without invoking the model.

        Returns:
            Msg with ``metadata["decisions"]`` and ``metadata["portfolio"]``.
        """
        if x is None:
            return Msg(
                name=self.name,
                content="No input provided",
                role="assistant",
            )

        # Clear previous decisions
        self._decisions = {}

        progress.update_status(
            self.name,
            None,
            "Analyzing and making decisions",
        )

        result = await super().reply(x)

        progress.update_status(self.name, None, "Completed")

        # Attach decisions to metadata as independent copies so consumers of
        # the message cannot mutate the agent's state.
        if result.metadata is None:
            result.metadata = {}
        result.metadata["decisions"] = self.get_decisions()
        result.metadata["portfolio"] = self.get_portfolio_state()

        return result

    def get_decisions(self) -> Dict[str, Dict]:
        """Return a deep copy of the decisions recorded in the current cycle."""
        return copy.deepcopy(self._decisions)

    def get_portfolio_state(self) -> Dict[str, Any]:
        """Return a deep copy of the current portfolio state."""
        return copy.deepcopy(self.portfolio)

    def load_portfolio_state(self, portfolio: Dict[str, Any]):
        """
        Replace the portfolio state from a saved snapshot.

        Missing ``cash`` / ``margin_requirement`` keep their current values;
        missing or None ``positions`` become empty and missing ``margin_used``
        becomes 0.0. Keys other than these four are ignored. An empty or None
        snapshot is a no-op.

        Args:
            portfolio: Snapshot dict; it is deep-copied, not aliased.
        """
        if not portfolio:
            return
        current = self.portfolio
        self.portfolio = {
            "cash": portfolio.get("cash", current["cash"]),
            "positions": copy.deepcopy(portfolio.get("positions") or {}),
            "margin_used": portfolio.get("margin_used", 0.0),
            "margin_requirement": portfolio.get(
                "margin_requirement",
                current["margin_requirement"],
            ),
        }

    def update_portfolio(self, portfolio: Dict[str, Any]):
        """
        Merge portfolio fields produced by external trade execution.

        Top-level keys in ``portfolio`` overwrite the current ones; the values
        are deep-copied so the caller's nested dicts are not aliased. An
        empty or None argument is a no-op.
        """
        if not portfolio:
            return
        self.portfolio.update(copy.deepcopy(portfolio))

    def reload_runtime_assets(self, active_skill_dirs: Optional[list] = None) -> None:
        """
        Reload toolkit and system prompt from current run assets.

        Clears the prompt factory cache, rebuilds the toolkit with the
        configured factory (falling back to ``create_agent_toolkit``, not to
        ``_create_toolkit``), then rebuilds the system prompt from it.

        Args:
            active_skill_dirs: When not None, passed to the factory as
                ``active_skill_dirs``, overriding any stored value.
        """
        # Local import, presumably to avoid a circular import at module load.
        from .toolkit_factory import create_agent_toolkit

        clear_prompt_factory_cache()
        config_name = self.config.get("config_name", "default")
        toolkit_factory = self._toolkit_factory or create_agent_toolkit
        toolkit_kwargs = dict(self._toolkit_factory_kwargs)
        if active_skill_dirs is not None:
            toolkit_kwargs["active_skill_dirs"] = active_skill_dirs

        self.toolkit = toolkit_factory(
            self.name,
            config_name,
            owner=self,
            **toolkit_kwargs,
        )
        # NOTE(review): if the installed ReActAgent exposes ``sys_prompt`` as a
        # read-only property, this assignment raises AttributeError — confirm
        # against the agentscope version in use.
        self.sys_prompt = build_agent_system_prompt(
            agent_id=self.name,
            config_name=config_name,
            toolkit=self.toolkit,
        )