Add run-scoped skill and prompt asset management

This commit is contained in:
2026-03-16 00:04:04 +08:00
parent 964d3b6e13
commit 78f133617f
23 changed files with 1309 additions and 109 deletions

View File

@@ -14,6 +14,11 @@ import loguru
from dotenv import load_dotenv
from backend.agents import AnalystAgent, PMAgent, RiskAgent
from backend.agents.skills_manager import SkillsManager
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
from backend.agents.prompt_loader import PromptLoader
from backend.agents.workspace_manager import WorkspaceManager
from backend.config.bootstrap_config import get_bootstrap_config_for_run
from backend.config.constants import ANALYST_TYPES
from backend.config.env_config import get_env_float, get_env_int, get_env_list
from backend.core.pipeline import TradingPipeline
@@ -28,6 +33,38 @@ load_dotenv()
logger = logging.getLogger(__name__)
loguru.logger.disable("flowllm")
loguru.logger.disable("reme_ai")
_prompt_loader = PromptLoader()
def _resolve_runtime_config(args) -> dict:
"""Merge env defaults with run-scoped bootstrap config."""
project_root = Path(__file__).resolve().parents[1]
bootstrap = get_bootstrap_config_for_run(project_root, args.config_name)
return {
"tickers": bootstrap.get("tickers")
or get_env_list("TICKERS", ["AAPL", "MSFT"]),
"initial_cash": float(
bootstrap.get(
"initial_cash",
get_env_float("INITIAL_CASH", 100000.0),
),
),
"margin_requirement": float(
bootstrap.get(
"margin_requirement",
get_env_float("MARGIN_REQUIREMENT", 0.0),
),
),
"max_comm_cycles": int(
bootstrap.get(
"max_comm_cycles",
get_env_int("MAX_COMM_CYCLES", 2),
),
),
"enable_memory": args.enable_memory
or bool(bootstrap.get("enable_memory", False)),
}
def create_long_term_memory(agent_name: str, config_name: str):
@@ -84,11 +121,31 @@ def create_agents(
"""
analysts = []
long_term_memories = []
workspace_manager = WorkspaceManager()
workspace_manager.initialize_default_assets(
config_name=config_name,
agent_ids=list(ANALYST_TYPES.keys())
+ ["risk_manager", "portfolio_manager"],
analyst_personas=_prompt_loader.load_yaml_config("analyst", "personas"),
)
profiles = load_agent_profiles()
skills_manager = SkillsManager()
active_skill_map = skills_manager.prepare_active_skills(
config_name=config_name,
agent_defaults={
agent_id: profile.get("skills", [])
for agent_id, profile in profiles.items()
},
)
for analyst_type in ANALYST_TYPES:
model = get_agent_model(analyst_type)
formatter = get_agent_formatter(analyst_type)
toolkit = create_toolkit(analyst_type)
toolkit = create_agent_toolkit(
analyst_type,
config_name,
active_skill_dirs=active_skill_map.get(analyst_type, []),
)
long_term_memory = None
if enable_long_term_memory:
@@ -125,6 +182,11 @@ def create_agents(
name="risk_manager",
config={"config_name": config_name},
long_term_memory=risk_long_term_memory,
toolkit=create_agent_toolkit(
"risk_manager",
config_name,
active_skill_dirs=active_skill_map.get("risk_manager", []),
),
)
pm_long_term_memory = None
@@ -144,44 +206,25 @@ def create_agents(
margin_requirement=margin_requirement,
config={"config_name": config_name},
long_term_memory=pm_long_term_memory,
toolkit_factory=create_agent_toolkit,
toolkit_factory_kwargs={
"active_skill_dirs": active_skill_map.get(
"portfolio_manager",
[],
),
},
)
return analysts, risk_manager, portfolio_manager, long_term_memories
def create_toolkit(analyst_type: str):
"""Create AgentScope Toolkit with tools for specific analyst type"""
from agentscope.tool import Toolkit
from backend.agents.prompt_loader import PromptLoader
from backend.tools.analysis_tools import TOOL_REGISTRY
# Load analyst persona config
prompt_loader = PromptLoader()
personas_config = prompt_loader.load_yaml_config("analyst", "personas")
persona = personas_config.get(analyst_type, {})
# Get tool names for this analyst type
tool_names = persona.get("tools", [])
# Create toolkit and register tools
toolkit = Toolkit()
for tool_name in tool_names:
tool_func = TOOL_REGISTRY.get(tool_name)
if tool_func:
toolkit.register_tool_function(tool_func)
return toolkit
async def run_with_gateway(args):
"""Run with WebSocket gateway"""
is_backtest = args.mode == "backtest"
runtime_config = _resolve_runtime_config(args)
# Load config from env, override with args
tickers = get_env_list("TICKERS", ["AAPL", "MSFT"])
initial_cash = get_env_float("INITIAL_CASH", 100000.0)
margin_requirement = get_env_float("MARGIN_REQUIREMENT", 0.0)
config_name = args.config_name
tickers = runtime_config["tickers"]
initial_cash = runtime_config["initial_cash"]
margin_requirement = runtime_config["margin_requirement"]
# Create market service
market_service = MarketService(
@@ -213,7 +256,7 @@ async def run_with_gateway(args):
config_name=config_name,
initial_cash=initial_cash,
margin_requirement=margin_requirement,
enable_long_term_memory=args.enable_memory,
enable_long_term_memory=runtime_config["enable_memory"],
)
portfolio_state = storage_service.load_portfolio_state()
pm.load_portfolio_state(portfolio_state)
@@ -228,7 +271,7 @@ async def run_with_gateway(args):
risk_manager=risk_manager,
portfolio_manager=pm,
settlement_coordinator=settlement_coordinator,
max_comm_cycles=get_env_int("MAX_COMM_CYCLES", 2),
max_comm_cycles=runtime_config["max_comm_cycles"],
)
# Create scheduler callback
@@ -307,15 +350,17 @@ def main():
args = parser.parse_args()
# Load config from env for logging
tickers = get_env_list("TICKERS", ["AAPL", "MSFT"])
initial_cash = get_env_float("INITIAL_CASH", 100000.0)
runtime_config = _resolve_runtime_config(args)
tickers = runtime_config["tickers"]
initial_cash = runtime_config["initial_cash"]
logger.info("=" * 60)
logger.info(f"Mode: {args.mode}, Config: {args.config_name}")
logger.info(f"Tickers: {tickers}")
logger.info(f"Initial Cash: ${initial_cash:,.2f}")
logger.info(
f"Long-term Memory: {'enabled' if args.enable_memory else 'disabled'}",
"Long-term Memory: %s",
"enabled" if runtime_config["enable_memory"] else "disabled",
)
if args.mode == "backtest":
if not args.start_date or not args.end_date: