381 lines
12 KiB
Python
381 lines
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Main Entry Point
|
|
Supports: backtest, live, mock modes
|
|
"""
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
from contextlib import AsyncExitStack
|
|
from pathlib import Path
|
|
import loguru
|
|
|
|
from dotenv import load_dotenv
|
|
|
|
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
|
from backend.agents.skills_manager import SkillsManager
|
|
from backend.agents.toolkit_factory import create_agent_toolkit, load_agent_profiles
|
|
from backend.agents.prompt_loader import PromptLoader
|
|
from backend.agents.workspace_manager import WorkspaceManager
|
|
from backend.config.bootstrap_config import resolve_runtime_config
|
|
from backend.config.constants import ANALYST_TYPES
|
|
from backend.core.pipeline import TradingPipeline
|
|
from backend.core.scheduler import BacktestScheduler, Scheduler
|
|
from backend.utils.settlement import SettlementCoordinator
|
|
from backend.llm.models import get_agent_formatter, get_agent_model
|
|
from backend.services.gateway import Gateway
|
|
from backend.services.market import MarketService
|
|
from backend.services.storage import StorageService
|
|
|
|
load_dotenv()
|
|
logger = logging.getLogger(__name__)
|
|
loguru.logger.disable("flowllm")
|
|
loguru.logger.disable("reme_ai")
|
|
_prompt_loader = PromptLoader()
|
|
|
|
|
|
def _get_run_dir(config_name: str) -> Path:
|
|
"""Return the canonical run-scoped directory for a config."""
|
|
project_root = Path(__file__).resolve().parents[1]
|
|
return WorkspaceManager(project_root=project_root).get_run_dir(config_name)
|
|
|
|
|
|
def _resolve_runtime_config(args) -> dict:
|
|
"""Merge env defaults with run-scoped bootstrap config."""
|
|
project_root = Path(__file__).resolve().parents[1]
|
|
return resolve_runtime_config(
|
|
project_root=project_root,
|
|
config_name=args.config_name,
|
|
enable_memory=args.enable_memory,
|
|
schedule_mode=args.schedule_mode,
|
|
interval_minutes=args.interval_minutes,
|
|
trigger_time=args.trigger_time,
|
|
)
|
|
|
|
|
|
def create_long_term_memory(agent_name: str, config_name: str):
|
|
"""
|
|
Create ReMeTaskLongTermMemory for an agent
|
|
|
|
Requires DASHSCOPE_API_KEY env var
|
|
"""
|
|
from agentscope.memory import ReMeTaskLongTermMemory
|
|
from agentscope.model import DashScopeChatModel
|
|
from agentscope.embedding import DashScopeTextEmbedding
|
|
|
|
api_key = os.getenv("MEMORY_API_KEY")
|
|
if not api_key:
|
|
logger.warning("MEMORY_API_KEY not set, long-term memory disabled")
|
|
return None
|
|
|
|
memory_dir = str(_get_run_dir(config_name) / "memory")
|
|
|
|
return ReMeTaskLongTermMemory(
|
|
agent_name=agent_name,
|
|
user_name=agent_name,
|
|
model=DashScopeChatModel(
|
|
model_name=os.getenv("MEMORY_MODEL_NAME", "qwen3-max"),
|
|
api_key=api_key,
|
|
stream=False,
|
|
),
|
|
embedding_model=DashScopeTextEmbedding(
|
|
model_name=os.getenv(
|
|
"MEMORY_EMBEDDING_MODEL",
|
|
"text-embedding-v4",
|
|
),
|
|
api_key=api_key,
|
|
dimensions=1024,
|
|
),
|
|
**{
|
|
"vector_store.default.backend": "local",
|
|
"vector_store.default.params.store_dir": memory_dir,
|
|
},
|
|
)
|
|
|
|
|
|
def create_agents(
|
|
config_name: str,
|
|
initial_cash: float,
|
|
margin_requirement: float,
|
|
enable_long_term_memory: bool = False,
|
|
):
|
|
"""Create all agents for the system
|
|
|
|
Returns:
|
|
tuple: (analysts, risk_manager, portfolio_manager, long_term_memories)
|
|
long_term_memories is a list of memory
|
|
"""
|
|
analysts = []
|
|
long_term_memories = []
|
|
workspace_manager = WorkspaceManager()
|
|
workspace_manager.initialize_default_assets(
|
|
config_name=config_name,
|
|
agent_ids=list(ANALYST_TYPES.keys())
|
|
+ ["risk_manager", "portfolio_manager"],
|
|
analyst_personas=_prompt_loader.load_yaml_config("analyst", "personas"),
|
|
)
|
|
profiles = load_agent_profiles()
|
|
skills_manager = SkillsManager()
|
|
active_skill_map = skills_manager.prepare_active_skills(
|
|
config_name=config_name,
|
|
agent_defaults={
|
|
agent_id: profile.get("skills", [])
|
|
for agent_id, profile in profiles.items()
|
|
},
|
|
)
|
|
|
|
for analyst_type in ANALYST_TYPES:
|
|
model = get_agent_model(analyst_type)
|
|
formatter = get_agent_formatter(analyst_type)
|
|
toolkit = create_agent_toolkit(
|
|
analyst_type,
|
|
config_name,
|
|
active_skill_dirs=active_skill_map.get(analyst_type, []),
|
|
)
|
|
|
|
long_term_memory = None
|
|
if enable_long_term_memory:
|
|
long_term_memory = create_long_term_memory(
|
|
analyst_type,
|
|
config_name,
|
|
)
|
|
if long_term_memory:
|
|
long_term_memories.append(long_term_memory)
|
|
|
|
analyst = AnalystAgent(
|
|
analyst_type=analyst_type,
|
|
toolkit=toolkit,
|
|
model=model,
|
|
formatter=formatter,
|
|
agent_id=analyst_type,
|
|
config={"config_name": config_name},
|
|
long_term_memory=long_term_memory,
|
|
)
|
|
analysts.append(analyst)
|
|
|
|
risk_long_term_memory = None
|
|
if enable_long_term_memory:
|
|
risk_long_term_memory = create_long_term_memory(
|
|
"risk_manager",
|
|
config_name,
|
|
)
|
|
if risk_long_term_memory:
|
|
long_term_memories.append(risk_long_term_memory)
|
|
|
|
risk_manager = RiskAgent(
|
|
model=get_agent_model("risk_manager"),
|
|
formatter=get_agent_formatter("risk_manager"),
|
|
name="risk_manager",
|
|
config={"config_name": config_name},
|
|
long_term_memory=risk_long_term_memory,
|
|
toolkit=create_agent_toolkit(
|
|
"risk_manager",
|
|
config_name,
|
|
active_skill_dirs=active_skill_map.get("risk_manager", []),
|
|
),
|
|
)
|
|
|
|
pm_long_term_memory = None
|
|
if enable_long_term_memory:
|
|
pm_long_term_memory = create_long_term_memory(
|
|
"portfolio_manager",
|
|
config_name,
|
|
)
|
|
if pm_long_term_memory:
|
|
long_term_memories.append(pm_long_term_memory)
|
|
|
|
portfolio_manager = PMAgent(
|
|
name="portfolio_manager",
|
|
model=get_agent_model("portfolio_manager"),
|
|
formatter=get_agent_formatter("portfolio_manager"),
|
|
initial_cash=initial_cash,
|
|
margin_requirement=margin_requirement,
|
|
config={"config_name": config_name},
|
|
long_term_memory=pm_long_term_memory,
|
|
toolkit_factory=create_agent_toolkit,
|
|
toolkit_factory_kwargs={
|
|
"active_skill_dirs": active_skill_map.get(
|
|
"portfolio_manager",
|
|
[],
|
|
),
|
|
},
|
|
)
|
|
|
|
return analysts, risk_manager, portfolio_manager, long_term_memories
|
|
async def run_with_gateway(args):
|
|
"""Run with WebSocket gateway"""
|
|
is_backtest = args.mode == "backtest"
|
|
runtime_config = _resolve_runtime_config(args)
|
|
|
|
config_name = args.config_name
|
|
tickers = runtime_config["tickers"]
|
|
initial_cash = runtime_config["initial_cash"]
|
|
margin_requirement = runtime_config["margin_requirement"]
|
|
|
|
# Create market service
|
|
market_service = MarketService(
|
|
tickers=tickers,
|
|
poll_interval=args.poll_interval,
|
|
mock_mode=args.mock and not is_backtest,
|
|
backtest_mode=is_backtest,
|
|
api_key=os.getenv("FINNHUB_API_KEY")
|
|
if not args.mock and not is_backtest
|
|
else None,
|
|
backtest_start_date=args.start_date if is_backtest else None,
|
|
backtest_end_date=args.end_date if is_backtest else None,
|
|
)
|
|
|
|
# Create storage service
|
|
storage_service = StorageService(
|
|
dashboard_dir=_get_run_dir(config_name) / "team_dashboard",
|
|
initial_cash=initial_cash,
|
|
config_name=config_name,
|
|
)
|
|
|
|
if not storage_service.files["summary"].exists():
|
|
storage_service.initialize_empty_dashboard()
|
|
else:
|
|
storage_service.update_leaderboard_model_info()
|
|
|
|
# Create agents and pipeline
|
|
analysts, risk_manager, pm, long_term_memories = create_agents(
|
|
config_name=config_name,
|
|
initial_cash=initial_cash,
|
|
margin_requirement=margin_requirement,
|
|
enable_long_term_memory=runtime_config["enable_memory"],
|
|
)
|
|
portfolio_state = storage_service.load_portfolio_state()
|
|
pm.load_portfolio_state(portfolio_state)
|
|
|
|
settlement_coordinator = SettlementCoordinator(
|
|
storage=storage_service,
|
|
initial_capital=initial_cash,
|
|
)
|
|
|
|
pipeline = TradingPipeline(
|
|
analysts=analysts,
|
|
risk_manager=risk_manager,
|
|
portfolio_manager=pm,
|
|
settlement_coordinator=settlement_coordinator,
|
|
max_comm_cycles=runtime_config["max_comm_cycles"],
|
|
)
|
|
|
|
# Create scheduler callback
|
|
scheduler_callback = None
|
|
trading_dates = []
|
|
live_scheduler = None
|
|
|
|
if is_backtest:
|
|
backtest_scheduler = BacktestScheduler(
|
|
start_date=args.start_date,
|
|
end_date=args.end_date,
|
|
trading_calendar="NYSE",
|
|
delay_between_days=0.5,
|
|
)
|
|
trading_dates = backtest_scheduler.get_trading_dates()
|
|
|
|
async def scheduler_callback_fn(callback):
|
|
await backtest_scheduler.start(callback)
|
|
|
|
scheduler_callback = scheduler_callback_fn
|
|
else:
|
|
# Live mode: use daily or intraday scheduler with NYSE timezone
|
|
live_scheduler = Scheduler(
|
|
mode=runtime_config["schedule_mode"],
|
|
trigger_time=runtime_config["trigger_time"],
|
|
interval_minutes=runtime_config["interval_minutes"],
|
|
config={"config_name": config_name},
|
|
)
|
|
|
|
async def scheduler_callback_fn(callback):
|
|
await live_scheduler.start(callback)
|
|
|
|
scheduler_callback = scheduler_callback_fn
|
|
|
|
# Create gateway
|
|
gateway = Gateway(
|
|
market_service=market_service,
|
|
storage_service=storage_service,
|
|
pipeline=pipeline,
|
|
scheduler_callback=scheduler_callback,
|
|
config={
|
|
"mode": args.mode,
|
|
"mock_mode": args.mock,
|
|
"backtest_mode": is_backtest,
|
|
"tickers": tickers,
|
|
"config_name": config_name,
|
|
"schedule_mode": runtime_config["schedule_mode"],
|
|
"interval_minutes": runtime_config["interval_minutes"],
|
|
"trigger_time": runtime_config["trigger_time"],
|
|
"initial_cash": initial_cash,
|
|
"margin_requirement": margin_requirement,
|
|
"max_comm_cycles": runtime_config["max_comm_cycles"],
|
|
"enable_memory": runtime_config["enable_memory"],
|
|
},
|
|
scheduler=live_scheduler if not is_backtest else None,
|
|
)
|
|
|
|
if is_backtest:
|
|
gateway.set_backtest_dates(trading_dates)
|
|
|
|
# Start long-term memory contexts and run gateway
|
|
async with AsyncExitStack() as stack:
|
|
for memory in long_term_memories:
|
|
await stack.enter_async_context(memory)
|
|
await gateway.start(host=args.host, port=args.port)
|
|
|
|
|
|
def main():
|
|
"""Main entry point"""
|
|
parser = argparse.ArgumentParser(description="Trading System")
|
|
parser.add_argument("--mode", choices=["live", "backtest"], default="live")
|
|
parser.add_argument("--mock", action="store_true")
|
|
parser.add_argument("--config-name", default="mock")
|
|
parser.add_argument("--host", default="0.0.0.0")
|
|
parser.add_argument("--port", type=int, default=8765)
|
|
parser.add_argument(
|
|
"--schedule-mode",
|
|
choices=["daily", "intraday"],
|
|
default="daily",
|
|
)
|
|
parser.add_argument("--trigger-time", default="09:30") # NYSE market open
|
|
parser.add_argument("--interval-minutes", type=int, default=60)
|
|
parser.add_argument("--poll-interval", type=int, default=10)
|
|
parser.add_argument("--start-date")
|
|
parser.add_argument("--end-date")
|
|
parser.add_argument(
|
|
"--enable-memory",
|
|
action="store_true",
|
|
help="Enable ReMeTaskLongTermMemory for agents",
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
# Load config from env for logging
|
|
runtime_config = _resolve_runtime_config(args)
|
|
tickers = runtime_config["tickers"]
|
|
initial_cash = runtime_config["initial_cash"]
|
|
|
|
logger.info("=" * 60)
|
|
logger.info(f"Mode: {args.mode}, Config: {args.config_name}")
|
|
logger.info(f"Tickers: {tickers}")
|
|
logger.info(f"Initial Cash: ${initial_cash:,.2f}")
|
|
logger.info(
|
|
"Long-term Memory: %s",
|
|
"enabled" if runtime_config["enable_memory"] else "disabled",
|
|
)
|
|
if args.mode == "backtest":
|
|
if not args.start_date or not args.end_date:
|
|
parser.error(
|
|
"--start-date and --end-date required for backtest mode",
|
|
)
|
|
logger.info(f"Backtest: {args.start_date} to {args.end_date}")
|
|
logger.info("=" * 60)
|
|
|
|
asyncio.run(run_with_gateway(args))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|