Add run-scoped skill and prompt asset management
This commit is contained in:
117
backend/main.py
117
backend/main.py
@@ -14,6 +14,11 @@ import loguru
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.agents.skills_manager import SkillsManager
|
||||
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.agents.workspace_manager import WorkspaceManager
|
||||
from backend.config.bootstrap_config import get_bootstrap_config_for_run
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.config.env_config import get_env_float, get_env_int, get_env_list
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
@@ -28,6 +33,38 @@ load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
loguru.logger.disable("flowllm")
|
||||
loguru.logger.disable("reme_ai")
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
|
||||
def _resolve_runtime_config(args) -> dict:
|
||||
"""Merge env defaults with run-scoped bootstrap config."""
|
||||
project_root = Path(__file__).resolve().parents[1]
|
||||
bootstrap = get_bootstrap_config_for_run(project_root, args.config_name)
|
||||
|
||||
return {
|
||||
"tickers": bootstrap.get("tickers")
|
||||
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
|
||||
"initial_cash": float(
|
||||
bootstrap.get(
|
||||
"initial_cash",
|
||||
get_env_float("INITIAL_CASH", 100000.0),
|
||||
),
|
||||
),
|
||||
"margin_requirement": float(
|
||||
bootstrap.get(
|
||||
"margin_requirement",
|
||||
get_env_float("MARGIN_REQUIREMENT", 0.0),
|
||||
),
|
||||
),
|
||||
"max_comm_cycles": int(
|
||||
bootstrap.get(
|
||||
"max_comm_cycles",
|
||||
get_env_int("MAX_COMM_CYCLES", 2),
|
||||
),
|
||||
),
|
||||
"enable_memory": args.enable_memory
|
||||
or bool(bootstrap.get("enable_memory", False)),
|
||||
}
|
||||
|
||||
|
||||
def create_long_term_memory(agent_name: str, config_name: str):
|
||||
@@ -84,11 +121,31 @@ def create_agents(
|
||||
"""
|
||||
analysts = []
|
||||
long_term_memories = []
|
||||
workspace_manager = WorkspaceManager()
|
||||
workspace_manager.initialize_default_assets(
|
||||
config_name=config_name,
|
||||
agent_ids=list(ANALYST_TYPES.keys())
|
||||
+ ["risk_manager", "portfolio_manager"],
|
||||
analyst_personas=_prompt_loader.load_yaml_config("analyst", "personas"),
|
||||
)
|
||||
profiles = load_agent_profiles()
|
||||
skills_manager = SkillsManager()
|
||||
active_skill_map = skills_manager.prepare_active_skills(
|
||||
config_name=config_name,
|
||||
agent_defaults={
|
||||
agent_id: profile.get("skills", [])
|
||||
for agent_id, profile in profiles.items()
|
||||
},
|
||||
)
|
||||
|
||||
for analyst_type in ANALYST_TYPES:
|
||||
model = get_agent_model(analyst_type)
|
||||
formatter = get_agent_formatter(analyst_type)
|
||||
toolkit = create_toolkit(analyst_type)
|
||||
toolkit = create_agent_toolkit(
|
||||
analyst_type,
|
||||
config_name,
|
||||
active_skill_dirs=active_skill_map.get(analyst_type, []),
|
||||
)
|
||||
|
||||
long_term_memory = None
|
||||
if enable_long_term_memory:
|
||||
@@ -125,6 +182,11 @@ def create_agents(
|
||||
name="risk_manager",
|
||||
config={"config_name": config_name},
|
||||
long_term_memory=risk_long_term_memory,
|
||||
toolkit=create_agent_toolkit(
|
||||
"risk_manager",
|
||||
config_name,
|
||||
active_skill_dirs=active_skill_map.get("risk_manager", []),
|
||||
),
|
||||
)
|
||||
|
||||
pm_long_term_memory = None
|
||||
@@ -144,44 +206,25 @@ def create_agents(
|
||||
margin_requirement=margin_requirement,
|
||||
config={"config_name": config_name},
|
||||
long_term_memory=pm_long_term_memory,
|
||||
toolkit_factory=create_agent_toolkit,
|
||||
toolkit_factory_kwargs={
|
||||
"active_skill_dirs": active_skill_map.get(
|
||||
"portfolio_manager",
|
||||
[],
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return analysts, risk_manager, portfolio_manager, long_term_memories
|
||||
|
||||
|
||||
def create_toolkit(analyst_type: str):
|
||||
"""Create AgentScope Toolkit with tools for specific analyst type"""
|
||||
from agentscope.tool import Toolkit
|
||||
from backend.agents.prompt_loader import PromptLoader
|
||||
from backend.tools.analysis_tools import TOOL_REGISTRY
|
||||
|
||||
# Load analyst persona config
|
||||
prompt_loader = PromptLoader()
|
||||
personas_config = prompt_loader.load_yaml_config("analyst", "personas")
|
||||
persona = personas_config.get(analyst_type, {})
|
||||
|
||||
# Get tool names for this analyst type
|
||||
tool_names = persona.get("tools", [])
|
||||
|
||||
# Create toolkit and register tools
|
||||
toolkit = Toolkit()
|
||||
for tool_name in tool_names:
|
||||
tool_func = TOOL_REGISTRY.get(tool_name)
|
||||
if tool_func:
|
||||
toolkit.register_tool_function(tool_func)
|
||||
|
||||
return toolkit
|
||||
|
||||
|
||||
async def run_with_gateway(args):
|
||||
"""Run with WebSocket gateway"""
|
||||
is_backtest = args.mode == "backtest"
|
||||
runtime_config = _resolve_runtime_config(args)
|
||||
|
||||
# Load config from env, override with args
|
||||
tickers = get_env_list("TICKERS", ["AAPL", "MSFT"])
|
||||
initial_cash = get_env_float("INITIAL_CASH", 100000.0)
|
||||
margin_requirement = get_env_float("MARGIN_REQUIREMENT", 0.0)
|
||||
config_name = args.config_name
|
||||
tickers = runtime_config["tickers"]
|
||||
initial_cash = runtime_config["initial_cash"]
|
||||
margin_requirement = runtime_config["margin_requirement"]
|
||||
|
||||
# Create market service
|
||||
market_service = MarketService(
|
||||
@@ -213,7 +256,7 @@ async def run_with_gateway(args):
|
||||
config_name=config_name,
|
||||
initial_cash=initial_cash,
|
||||
margin_requirement=margin_requirement,
|
||||
enable_long_term_memory=args.enable_memory,
|
||||
enable_long_term_memory=runtime_config["enable_memory"],
|
||||
)
|
||||
portfolio_state = storage_service.load_portfolio_state()
|
||||
pm.load_portfolio_state(portfolio_state)
|
||||
@@ -228,7 +271,7 @@ async def run_with_gateway(args):
|
||||
risk_manager=risk_manager,
|
||||
portfolio_manager=pm,
|
||||
settlement_coordinator=settlement_coordinator,
|
||||
max_comm_cycles=get_env_int("MAX_COMM_CYCLES", 2),
|
||||
max_comm_cycles=runtime_config["max_comm_cycles"],
|
||||
)
|
||||
|
||||
# Create scheduler callback
|
||||
@@ -307,15 +350,17 @@ def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load config from env for logging
|
||||
tickers = get_env_list("TICKERS", ["AAPL", "MSFT"])
|
||||
initial_cash = get_env_float("INITIAL_CASH", 100000.0)
|
||||
runtime_config = _resolve_runtime_config(args)
|
||||
tickers = runtime_config["tickers"]
|
||||
initial_cash = runtime_config["initial_cash"]
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Mode: {args.mode}, Config: {args.config_name}")
|
||||
logger.info(f"Tickers: {tickers}")
|
||||
logger.info(f"Initial Cash: ${initial_cash:,.2f}")
|
||||
logger.info(
|
||||
f"Long-term Memory: {'enabled' if args.enable_memory else 'disabled'}",
|
||||
"Long-term Memory: %s",
|
||||
"enabled" if runtime_config["enable_memory"] else "disabled",
|
||||
)
|
||||
if args.mode == "backtest":
|
||||
if not args.start_date or not args.end_date:
|
||||
|
||||
Reference in New Issue
Block a user