feat: initial commit - EvoTraders project
量化交易多智能体系统,包含: - 分析师、投资组合经理、风险经理等智能体 - 股票分析、投资组合管理、风险控制工具 - React 前端界面 - FastAPI 后端服务 Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
24
.eslintrc
Normal file
24
.eslintrc
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"env": {
|
||||
"browser": true,
|
||||
"es2021": true
|
||||
},
|
||||
"parserOptions": {
|
||||
"ecmaVersion": 2021,
|
||||
"sourceType": "module",
|
||||
"ecmaFeatures": {
|
||||
"jsx": true
|
||||
}
|
||||
},
|
||||
"rules": {
|
||||
"semi": ["error", "always"],
|
||||
"quotes": ["error", "double"],
|
||||
"indent": ["error", 2],
|
||||
"linebreak-style": ["error", "unix"],
|
||||
"brace-style": ["error", "1tbs"],
|
||||
"curly": ["error", "all"],
|
||||
"no-eval": ["error"],
|
||||
"prefer-const": ["error"],
|
||||
"arrow-spacing": ["error", { "before": true, "after": true }]
|
||||
}
|
||||
}
|
||||
63
.gitignore
vendored
Normal file
63
.gitignore
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# Virtual Environment
|
||||
/.venv/
|
||||
ENV/
|
||||
|
||||
# Environment Variables
|
||||
.env
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
.cursorrules
|
||||
.cursorignore
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
# Txt files
|
||||
*.txt
|
||||
|
||||
# PDF files
|
||||
*.pdf
|
||||
|
||||
# Frontend
|
||||
node_modules
|
||||
|
||||
# Outputs
|
||||
outputs/
|
||||
|
||||
# Data files
|
||||
backend/data/ret_data/
|
||||
|
||||
# Database files (users will have their own local databases)
|
||||
*.db
|
||||
*.db-journal
|
||||
*.db-wal
|
||||
*.db-shm
|
||||
*.sqlite
|
||||
*.sqlite3
|
||||
|
||||
*.ipynb
|
||||
*.log
|
||||
332
.omc/project-memory.json
Normal file
332
.omc/project-memory.json
Normal file
@@ -0,0 +1,332 @@
|
||||
{
|
||||
"version": "1.0.0",
|
||||
"lastScanned": 1773304964541,
|
||||
"projectRoot": "/Users/cillin/workspeace/agentscope-samples/evotraders",
|
||||
"techStack": {
|
||||
"languages": [
|
||||
{
|
||||
"name": "Python",
|
||||
"version": null,
|
||||
"confidence": "high",
|
||||
"markers": [
|
||||
"pyproject.toml"
|
||||
]
|
||||
}
|
||||
],
|
||||
"frameworks": [
|
||||
{
|
||||
"name": "pytest",
|
||||
"version": null,
|
||||
"category": "testing"
|
||||
}
|
||||
],
|
||||
"packageManager": null,
|
||||
"runtime": null
|
||||
},
|
||||
"build": {
|
||||
"buildCommand": null,
|
||||
"testCommand": "pytest",
|
||||
"lintCommand": "ruff check",
|
||||
"devCommand": null,
|
||||
"scripts": {}
|
||||
},
|
||||
"conventions": {
|
||||
"namingStyle": null,
|
||||
"importStyle": null,
|
||||
"testPattern": null,
|
||||
"fileOrganization": null
|
||||
},
|
||||
"structure": {
|
||||
"isMonorepo": false,
|
||||
"workspaces": [],
|
||||
"mainDirectories": [
|
||||
"docs"
|
||||
],
|
||||
"gitBranches": {
|
||||
"defaultBranch": "main",
|
||||
"branchingStrategy": null
|
||||
}
|
||||
},
|
||||
"customNotes": [],
|
||||
"directoryMap": {
|
||||
"backend": {
|
||||
"path": "backend",
|
||||
"purpose": null,
|
||||
"fileCount": 3,
|
||||
"lastAccessed": 1773304964533,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"cli.py",
|
||||
"main.py"
|
||||
]
|
||||
},
|
||||
"docs": {
|
||||
"path": "docs",
|
||||
"purpose": "Documentation",
|
||||
"fileCount": 0,
|
||||
"lastAccessed": 1773304964533,
|
||||
"keyFiles": []
|
||||
},
|
||||
"evotraders.egg-info": {
|
||||
"path": "evotraders.egg-info",
|
||||
"purpose": null,
|
||||
"fileCount": 6,
|
||||
"lastAccessed": 1773304964534,
|
||||
"keyFiles": [
|
||||
"PKG-INFO",
|
||||
"SOURCES.txt",
|
||||
"dependency_links.txt",
|
||||
"entry_points.txt",
|
||||
"requires.txt"
|
||||
]
|
||||
},
|
||||
"frontend": {
|
||||
"path": "frontend",
|
||||
"purpose": null,
|
||||
"fileCount": 12,
|
||||
"lastAccessed": 1773304964535,
|
||||
"keyFiles": [
|
||||
"README.md",
|
||||
"components.json",
|
||||
"env.template",
|
||||
"eslint.config.js",
|
||||
"index.css"
|
||||
]
|
||||
},
|
||||
"backend/config": {
|
||||
"path": "backend/config",
|
||||
"purpose": "Configuration files",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964535,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"constants.py",
|
||||
"data_config.py"
|
||||
]
|
||||
},
|
||||
"backend/data": {
|
||||
"path": "backend/data",
|
||||
"purpose": "Data files",
|
||||
"fileCount": 7,
|
||||
"lastAccessed": 1773304964536,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"cache.py",
|
||||
"historical_price_manager.py"
|
||||
]
|
||||
},
|
||||
"backend/services": {
|
||||
"path": "backend/services",
|
||||
"purpose": "Business logic services",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964536,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"gateway.py",
|
||||
"market.py"
|
||||
]
|
||||
},
|
||||
"backend/tests": {
|
||||
"path": "backend/tests",
|
||||
"purpose": "Test files",
|
||||
"fileCount": 4,
|
||||
"lastAccessed": 1773304964536,
|
||||
"keyFiles": [
|
||||
"__init__.py",
|
||||
"test_agents.py",
|
||||
"test_market_service.py"
|
||||
]
|
||||
},
|
||||
"docs/assets": {
|
||||
"path": "docs/assets",
|
||||
"purpose": "Static assets",
|
||||
"fileCount": 5,
|
||||
"lastAccessed": 1773304964536,
|
||||
"keyFiles": [
|
||||
"dashboard.jpg",
|
||||
"evotraders_demo.gif",
|
||||
"evotraders_logo.jpg"
|
||||
]
|
||||
},
|
||||
"frontend/public": {
|
||||
"path": "frontend/public",
|
||||
"purpose": "Public files",
|
||||
"fileCount": 1,
|
||||
"lastAccessed": 1773304964538,
|
||||
"keyFiles": [
|
||||
"trading_logo.png"
|
||||
]
|
||||
}
|
||||
},
|
||||
"hotPaths": [
|
||||
{
|
||||
"path": "frontend/src/components/StatisticsView.jsx",
|
||||
"accessCount": 22,
|
||||
"lastAccessed": 1773310044545,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AgentCard.jsx",
|
||||
"accessCount": 17,
|
||||
"lastAccessed": 1773309995177,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/App.jsx",
|
||||
"accessCount": 12,
|
||||
"lastAccessed": 1773309849392,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AgentFeed.jsx",
|
||||
"accessCount": 12,
|
||||
"lastAccessed": 1773309960022,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": ".env",
|
||||
"accessCount": 7,
|
||||
"lastAccessed": 1773308950505,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/RoomView.jsx",
|
||||
"accessCount": 7,
|
||||
"lastAccessed": 1773309864236,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools/analysis_tools.py",
|
||||
"accessCount": 5,
|
||||
"lastAccessed": 1773312271446,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/Header.jsx",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773309827069,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/AboutModal.jsx",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773310093371,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/analyst/personas.yaml",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312049213,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/analyst/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312049696,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/portfolio_manager/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312050326,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/prompts/risk_manager/system.md",
|
||||
"accessCount": 4,
|
||||
"lastAccessed": 1773312050782,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/config/constants.js",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773309824671,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/components/RulesView.jsx",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773310061939,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend",
|
||||
"accessCount": 3,
|
||||
"lastAccessed": 1773312200721,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/services/gateway.py",
|
||||
"accessCount": 2,
|
||||
"lastAccessed": 1773312232905,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "README.md",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305013217,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "README_zh.md",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305013274,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "env.template",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773305019965,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "frontend/src/services/websocket.js",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309324302,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/config/data_config.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309324414,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/cli.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773309336899,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/portfolio_manager.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311956562,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/risk_manager.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311956760,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/agents/analyst.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773311963222,
|
||||
"type": "file"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773312289643,
|
||||
"type": "directory"
|
||||
},
|
||||
{
|
||||
"path": "backend/tools/data_tools.py",
|
||||
"accessCount": 1,
|
||||
"lastAccessed": 1773312293851,
|
||||
"type": "directory"
|
||||
}
|
||||
],
|
||||
"userDirectives": []
|
||||
}
|
||||
@@ -0,0 +1,4 @@
|
||||
{"t":0,"agent":"a4090d2","agent_type":"executor","event":"agent_start","parent_mode":"none"}
|
||||
{"t":0,"agent":"a4090d2","agent_type":"executor","event":"agent_stop","success":true,"duration_ms":500954}
|
||||
{"t":0,"agent":"af87583","agent_type":"executor","event":"agent_start","parent_mode":"none"}
|
||||
{"t":0,"agent":"af87583","agent_type":"executor","event":"agent_stop","success":true,"duration_ms":72978}
|
||||
6
.omc/state/hud-state.json
Normal file
6
.omc/state/hud-state.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"timestamp": "2026-03-12T20:33:59.497Z",
|
||||
"backgroundTasks": [],
|
||||
"sessionStartTimestamp": "2026-03-12T14:19:33.615Z",
|
||||
"sessionId": "73b0d597-0141-4873-9d0e-2b60e4e0635e"
|
||||
}
|
||||
1
.omc/state/hud-stdin-cache.json
Normal file
1
.omc/state/hud-stdin-cache.json
Normal file
@@ -0,0 +1 @@
|
||||
{"session_id":"73b0d597-0141-4873-9d0e-2b60e4e0635e","transcript_path":"/Users/cillin/.claude/projects/-Users-cillin-workspeace-agentscope-samples-evotraders/73b0d597-0141-4873-9d0e-2b60e4e0635e.jsonl","cwd":"/Users/cillin/workspeace/agentscope-samples/evotraders","model":{"id":"kimi-for-coding","display_name":"kimi-for-coding"},"workspace":{"current_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","project_dir":"/Users/cillin/workspeace/agentscope-samples/evotraders","added_dirs":["/Users/cillin/workspeace/agentscope-samples/EvoTraders","/Users/cillin/workspeace/agentscope-samples/evotraders"]},"version":"2.1.63","output_style":{"name":"default"},"cost":{"total_cost_usd":6.822239999999999,"total_duration_ms":42679588,"total_api_duration_ms":1223637,"total_lines_added":275,"total_lines_removed":240},"context_window":{"total_input_tokens":654274,"total_output_tokens":27014,"context_window_size":200000,"current_usage":{"input_tokens":48465,"output_tokens":0,"cache_creation_input_tokens":0,"cache_read_input_tokens":0},"used_percentage":24,"remaining_percentage":76},"exceeds_200k_tokens":false}
|
||||
3
.omc/state/idle-notif-cooldown.json
Normal file
3
.omc/state/idle-notif-cooldown.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"lastSentAt": "2026-03-12T20:31:37.362Z"
|
||||
}
|
||||
26
.omc/state/subagent-tracking.json
Normal file
26
.omc/state/subagent-tracking.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"agents": [
|
||||
{
|
||||
"agent_id": "a4090d26a45ac828d",
|
||||
"agent_type": "oh-my-claudecode:executor",
|
||||
"started_at": "2026-03-12T10:02:38.238Z",
|
||||
"parent_mode": "none",
|
||||
"status": "completed",
|
||||
"completed_at": "2026-03-12T10:10:59.192Z",
|
||||
"duration_ms": 500954
|
||||
},
|
||||
{
|
||||
"agent_id": "af87583ef76a4df30",
|
||||
"agent_type": "oh-my-claudecode:executor",
|
||||
"started_at": "2026-03-12T10:40:04.409Z",
|
||||
"parent_mode": "none",
|
||||
"status": "completed",
|
||||
"completed_at": "2026-03-12T10:41:17.387Z",
|
||||
"duration_ms": 72978
|
||||
}
|
||||
],
|
||||
"total_spawned": 2,
|
||||
"total_completed": 2,
|
||||
"total_failed": 0,
|
||||
"last_updated": "2026-03-12T10:41:17.490Z"
|
||||
}
|
||||
122
.pre-commit-config.yaml
Normal file
122
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,122 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
hooks:
|
||||
- id: check-ast
|
||||
- id: sort-simple-yaml
|
||||
- id: check-yaml
|
||||
exclude: |
|
||||
(?x)^(
|
||||
meta.yaml
|
||||
)$
|
||||
- id: check-xml
|
||||
- id: check-toml
|
||||
- id: check-docstring-first
|
||||
- id: check-json
|
||||
- id: fix-encoding-pragma
|
||||
- id: detect-private-key
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/asottile/add-trailing-comma
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: add-trailing-comma
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.7.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
exclude:
|
||||
(?x)(
|
||||
pb2\.py$
|
||||
| grpc\.py$
|
||||
| ^docs
|
||||
| \.html$
|
||||
)
|
||||
args: [
|
||||
--ignore-missing-imports,
|
||||
--disable-error-code=var-annotated,
|
||||
--disable-error-code=union-attr,
|
||||
--disable-error-code=assignment,
|
||||
--disable-error-code=attr-defined,
|
||||
--disable-error-code=import-untyped,
|
||||
--disable-error-code=truthy-function,
|
||||
--follow-imports=skip,
|
||||
--explicit-package-bases,
|
||||
]
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
args: [ --line-length=79 ]
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.1.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
args: [ "--extend-ignore=E203"]
|
||||
- repo: https://github.com/pylint-dev/pylint
|
||||
rev: v3.0.2
|
||||
hooks:
|
||||
- id: pylint
|
||||
exclude:
|
||||
(?x)(
|
||||
^docs
|
||||
| pb2\.py$
|
||||
| grpc\.py$
|
||||
| \.demo$
|
||||
| \.md$
|
||||
| \.html$
|
||||
)
|
||||
args: [
|
||||
"--init-hook=import sys; sys.path.insert(0, 'alias/src')",
|
||||
--disable=W0511,
|
||||
--disable=W0718,
|
||||
--disable=W0122,
|
||||
--disable=C0103,
|
||||
--disable=R0913,
|
||||
--disable=E0401,
|
||||
--disable=E1101,
|
||||
--disable=C0415,
|
||||
--disable=W0603,
|
||||
--disable=R1705,
|
||||
--disable=R0914,
|
||||
--disable=E0601,
|
||||
--disable=W0602,
|
||||
--disable=W0604,
|
||||
--disable=R0801,
|
||||
--disable=R0902,
|
||||
--disable=R0903,
|
||||
--disable=C0123,
|
||||
--disable=W0231,
|
||||
--disable=W1113,
|
||||
--disable=W0221,
|
||||
--disable=R0401,
|
||||
--disable=W0632,
|
||||
--disable=W0123,
|
||||
--disable=C3001,
|
||||
--disable=W0201,
|
||||
--disable=C0302,
|
||||
--disable=W1203,
|
||||
--disable=C2801,
|
||||
--disable=C0114, # Disable missing module docstring for quick dev
|
||||
--disable=C0115, # Disable missing class docstring for quick dev
|
||||
--disable=C0116, # Disable missing function or method docstring for quick dev
|
||||
]
|
||||
- repo: https://github.com/pre-commit/mirrors-eslint
|
||||
rev: v7.32.0
|
||||
hooks:
|
||||
- id: eslint
|
||||
files: \.(js|jsx)$
|
||||
exclude: '.*js_third_party.*'
|
||||
args: [ '--fix' ]
|
||||
- repo: https://github.com/thibaudcolas/pre-commit-stylelint
|
||||
rev: v14.4.0
|
||||
hooks:
|
||||
- id: stylelint
|
||||
files: \.(css)$
|
||||
exclude: '.*css_third_party.*'
|
||||
args: [ '--fix' ]
|
||||
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||
rev: 'v3.0.0'
|
||||
hooks:
|
||||
- id: prettier
|
||||
additional_dependencies: [ 'prettier@3.0.0' ]
|
||||
files: \.(tsx?)$
|
||||
6
.stylelintrc
Normal file
6
.stylelintrc
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"rules": {
|
||||
"indentation": 2,
|
||||
"string-quotes": "double"
|
||||
}
|
||||
}
|
||||
236
README.md
Normal file
236
README.md
Normal file
@@ -0,0 +1,236 @@
|
||||
<p align="center">
|
||||
<img src="./docs/assets/evotraders_logo.jpg" width="45%">
|
||||
</p>
|
||||
|
||||
<h2 align="center">EvoTraders: A Self-Evolving Multi-Agent Trading System</h2>
|
||||
|
||||
<p align="center">
|
||||
📌 <a href="http://trading.evoagents.cn">Visit us at EvoTraders website !</a>
|
||||
</p>
|
||||
|
||||

|
||||
|
||||
EvoTraders is an open-source financial trading agent framework that builds a trading system capable of continuous learning and evolution in real markets through multi-agent collaboration and memory systems.
|
||||
|
||||
---
|
||||
|
||||
## Core Features
|
||||
|
||||
**Multi-Agent Collaborative Trading**
|
||||
A team of 6 members, including 4 specialized analyst roles (fundamentals, technical, sentiment, valuation) + portfolio manager + risk management, collaborating to make decisions like a real trading team.
|
||||
|
||||
You can customize your Agents here: [Custom Configuration](#custom-configuration)
|
||||
|
||||
**Continuous Learning and Evolution**
|
||||
Based on the ReMe memory framework, agents reflect and summarize after each trade, preserving experience across rounds, and forming unique investment methodologies.
|
||||
|
||||
Through this design, we hope that when AI Agents form a team and enter the real-time market, they will gradually develop their own trading styles and decision preferences, rather than one-time random inference.
|
||||
|
||||
**Real-Time Market Trading**
|
||||
Supports real-time market data integration, providing backtesting mode and live trading mode, allowing AI Agents to learn and make decisions in real market fluctuations.
|
||||
|
||||
**Visualized Trading Information**
|
||||
Observe agents' analysis processes, communication records, and decision evolution in real-time, with complete tracking of return curves and analyst performance.
|
||||
|
||||
<p>
|
||||
<img src="docs/assets/performance.jpg" width="45%">
|
||||
<img src="./docs/assets/dashboard.jpg" width="45%">
|
||||
</p>
|
||||
|
||||
---
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone repository
|
||||
git clone https://github.com/agentscope-ai/agentscope-samples
|
||||
cd agentscope-samples/EvoTraders
|
||||
|
||||
# Install dependencies (Recommend uv!)
|
||||
uv pip install -e .
|
||||
# optional: pip install -e .
|
||||
|
||||
|
||||
# Configure environment variables
|
||||
cp env.template .env
|
||||
# Edit .env file and add your API Keys. The following config are required:
|
||||
|
||||
# finance data API: At minimum, FINANCIAL_DATASETS_API_KEY is required, corresponding to FIN_DATA_SOURCE=financial_datasets; It is recommended to add FINNHUB_API_KEY, corresponding to FIN_DATA_SOURCE=finnhub; If using live mode, FINNHUB_API_KEY must be added
|
||||
FIN_DATA_SOURCE = #finnhub or financial_datasets
|
||||
FINANCIAL_DATASETS_API_KEY= #Required
|
||||
FINNHUB_API_KEY= #Optional
|
||||
|
||||
# LLM API for Agents
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# LLM & embedding API for Memory
|
||||
MEMORY_API_KEY=
|
||||
```
|
||||
|
||||
### Running
|
||||
|
||||
**Backtest Mode:**
|
||||
```bash
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory # Use Memory
|
||||
```
|
||||
|
||||
If you do not have market data APIs and just want to try the backtest demo, download the offline data and unzip it into `backend/data`:
|
||||
```bash
|
||||
wget "https://agentscope-open.oss-cn-beijing.aliyuncs.com/ret_data.zip"
|
||||
unzip ret_data.zip -d backend/data
|
||||
```
|
||||
The zip includes basic stock price data so you can run the backtest demo out of the box.
|
||||
|
||||
**Live Trading:**
|
||||
```bash
|
||||
evotraders live # Run immediately (default)
|
||||
evotraders live --enable-memory # Use memory
|
||||
evotraders live --mock # Mock mode (testing)
|
||||
evotraders live -t 22:30 # Run daily at 22:30 local time (auto-converts to NYSE timezone)
|
||||
```
|
||||
|
||||
**Get Help:**
|
||||
```bash
|
||||
evotraders --help # View global CLI help
|
||||
evotraders backtest --help # View backtest mode parameters
|
||||
evotraders live --help # View live/mock run parameters
|
||||
```
|
||||
|
||||
**Launch Visualization Interface:**
|
||||
```bash
|
||||
# Ensure npm is installed, otherwise install it:
|
||||
# npm install
|
||||
evotraders frontend # Default connects to port 8765, you can modify the address in ./frontend/env.local to change the port number
|
||||
```
|
||||
|
||||
Visit `http://localhost:5173/` to view the trading room, select a date and click Run/Replay to observe the decision-making process.
|
||||
|
||||
---
|
||||
|
||||
## System Architecture
|
||||
|
||||

|
||||
|
||||
### Agent Design
|
||||
|
||||
**Analyst Team:**
|
||||
- **Fundamentals Analyst**: Financial health, profitability, growth quality
|
||||
- **Technical Analyst**: Price trends, technical indicators, momentum analysis
|
||||
- **Sentiment Analyst**: Market sentiment, news sentiment, insider trading
|
||||
- **Valuation Analyst**: DCF, residual income, EV/EBITDA
|
||||
|
||||
**Decision Layer:**
|
||||
- **Portfolio Manager**: Integrates analysis signals from analysts, executes communication strategies, combines analyst and team historical performance, recent investment memories, and long-term investment experience to make final decisions
|
||||
- **Risk Management**: Real-time price and volatility monitoring, position limits, multi-layer risk warnings
|
||||
|
||||
### Decision Process
|
||||
|
||||
```
|
||||
Real-time Market Data → Independent Analysis → Intelligent Communication (1v1/1vN/NvN) → Decision Execution → Performance Evaluation → Learning and Evolution (Memory Update)
|
||||
```
|
||||
|
||||
Each trading day goes through five stages:
|
||||
|
||||
1. **Analysis Stage**: Each agent independently analyzes based on their respective tools and historical experience
|
||||
2. **Communication Stage**: Exchange views through private chats, notifications, meetings, etc.
|
||||
3. **Decision Stage**: Portfolio manager makes comprehensive judgments and provides final trades
|
||||
4. **Evaluation Stage**
|
||||
- **Performance Charts**: Track portfolio return curves vs. benchmark strategies (equal-weighted, market-cap weighted, momentum). Used to evaluate overall strategy effectiveness.
|
||||
|
||||
- **Analyst Rankings**: Click on avatars in the Trading Room to view analyst performance (win rate, bull/bear market win rate). Used to understand which analysts provide the most valuable insights.
|
||||
|
||||
- **Statistics**: Detailed position and trading history. Used for in-depth analysis of position management and execution quality.
|
||||
|
||||
5. **Review Stage**: Agents reflect on decisions and summarize experiences based on actual returns of the day, and store them in the ReMe memory framework for continuous improvement
|
||||
|
||||
---
|
||||
|
||||
### Module Support
|
||||
|
||||
- **Agent Framework**: [AgentScope](https://github.com/agentscope-ai/agentscope)
|
||||
- **Memory System**: [ReMe](https://github.com/agentscope-ai/reme)
|
||||
- **LLM Support**: OpenAI, DeepSeek, Qwen, Moonshot, Zhipu AI, etc.
|
||||
|
||||
---
|
||||
|
||||
## Custom Configuration
|
||||
|
||||
### Custom Analyst Roles
|
||||
|
||||
1. Register role information in [./backend/agents/prompts/analyst/personas.yaml](./backend/agents/prompts/analyst/personas.yaml), for example:
|
||||
|
||||
```yaml
|
||||
comprehensive_analyst:
|
||||
name: "Comprehensive Analyst"
|
||||
focus:
|
||||
- ...
|
||||
preferred_tools: # Flexibly select based on situation
|
||||
description: |
|
||||
As a comprehensive analyst ...
|
||||
```
|
||||
|
||||
2. Add role definition in [./backend/config/constants.py](./backend/config/constants.py)
|
||||
```python
|
||||
ANALYST_TYPES = {
|
||||
# Add new analyst
|
||||
"comprehensive_analyst": {
|
||||
"display_name": "Comprehensive Analyst",
|
||||
"agent_id": "comprehensive_analyst",
|
||||
"description": "Uses LLM to intelligently select analysis tools, performs comprehensive analysis",
|
||||
"order": 15
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. Introduce new role in frontend configuration [./frontend/src/config/constants.js](./frontend/src/config/constants.js) (optional)
|
||||
```javascript
|
||||
export const AGENTS = [
|
||||
// Override one of the agents
|
||||
{
|
||||
id: "comprehensive_analyst",
|
||||
name: "Comprehensive Analyst",
|
||||
role: "Comprehensive Analyst",
|
||||
avatar: `${ASSET_BASE_URL}/...`,
|
||||
colors: { bg: '#F9FDFF', text: '#1565C0', accent: '#1565C0' }
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Custom Models
|
||||
|
||||
Configure models used by different agents in the [.env](.env) file:
|
||||
|
||||
```bash
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=deepseek-chat
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4-plus
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=moonshot-v1-32k
|
||||
```
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
EvoTraders/
|
||||
├── backend/
|
||||
│ ├── agents/ # Agent implementation
|
||||
│ ├── communication/ # Communication system
|
||||
│ ├── memory/ # Memory system (ReMe)
|
||||
│ ├── tools/ # Analysis toolset
|
||||
│ ├── servers/ # WebSocket services
|
||||
│ └── cli.py # CLI entry point
|
||||
├── frontend/ # React visualization interface
|
||||
└── logs_and_memory/ # Logs and memory data
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License and Disclaimer
|
||||
|
||||
EvoTraders is a research and educational project, open-sourced under the Apache 2.0 license.
|
||||
|
||||
**Risk Warning**: Before trading with real funds, please conduct thorough testing and risk assessment. Past performance does not guarantee future returns. Investment involves risks, and decisions should be made with caution.
|
||||
243
README_zh.md
Normal file
243
README_zh.md
Normal file
@@ -0,0 +1,243 @@
|
||||
<p align="center">
|
||||
<img src="./docs/assets/evotraders_logo.jpg" width="45%">
|
||||
</p>
|
||||
|
||||
<h2 align="center">EvoTraders:自我进化的多智能体交易系统</h2>
|
||||
|
||||
|
||||
<p align="center">
|
||||
📌 <a href="http://trading.evoagents.cn">Visit us at EvoTraders website !</a>
|
||||
</p>
|
||||
|
||||

|
||||
|
||||
EvoTraders是一个开源的金融交易智能体框架,通过多智能体协作和记忆系统,构建能够在真实市场中持续学习与进化的交易系统。
|
||||
|
||||
---
|
||||
|
||||
## 核心特性
|
||||
|
||||
**多智能体协作交易**
|
||||
6名成员,包含4种专业分析师角色(基本面、技术面、情绪、估值)+ 投资组合经理 + 风险管理,像真实交易团队一样协作决策。
|
||||
|
||||
你可以在这里自定义你的Agents,支持配置不同大模型(如 Qwen、DeepSeek、GPT、Claude等)协同分析:[自定义配置](#自定义配置)
|
||||
|
||||
**持续学习与进化**
|
||||
基于 ReMe 记忆框架,智能体在每次交易后反思总结,跨回合保留经验,形成独特的投资方法论。
|
||||
|
||||
通过这样的设计,我们希望当 AI Agents 组成团队进入实时市场,它们会逐渐形成自己的交易风格和决策偏好,而不是一次性的随机推理
|
||||
|
||||
|
||||
**实时市场交易**
|
||||
支持实时行情接入,提供回测模式和实盘模式,让 AI Agents 在真实市场波动中学习和决策。
|
||||
|
||||
**可视化交易信息**
|
||||
实时观察 Agents 的分析过程、沟通记录和决策演化,完整追踪收益曲线和分析师表现。
|
||||
|
||||
|
||||
<p>
|
||||
<img src="docs/assets/performance.jpg" width="45%">
|
||||
<img src="./docs/assets/dashboard.jpg" width="45%">
|
||||
</p>
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 安装
|
||||
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone https://github.com/agentscope-ai/agentscope-samples
|
||||
cd agentscope-samples/EvoTraders
|
||||
|
||||
# 安装依赖(推荐使用uv)
|
||||
uv pip install -e .
|
||||
# (可选)pip install -e .
|
||||
|
||||
# 配置环境变量
|
||||
cp env.template .env
|
||||
# 编辑 .env 文件,添加你的 API Keys,以下的配置项为必填项
|
||||
|
||||
# finance data API:至少需要FINANCIAL_DATASETS_API_KEY,对应FIN_DATA_SOURCE=financial_datasets;推荐添加FINNHUB_API_KEY,对应至少需要FINANCIAL_DATASETS_API_KEY,对应FIN_DATA_SOURCE填为finnhub;如果使用live 模式必须添加FINNHUB_API_KEY
|
||||
FIN_DATA_SOURCE= #finnhub or financial_datasets
|
||||
FINANCIAL_DATASETS_API_KEY= #必需
|
||||
FINNHUB_API_KEY= #可选
|
||||
|
||||
# LLM API for Agents
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
|
||||
# LLM & embedding API for Memory
|
||||
MEMORY_API_KEY=
|
||||
```
|
||||
|
||||
### 运行
|
||||
|
||||
**回测模式:**
|
||||
```bash
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01
|
||||
evotraders backtest --start 2025-11-01 --end 2025-12-01 --enable-memory # 使用记忆
|
||||
|
||||
```
|
||||
|
||||
如果没有可用的行情 API,想快速体验回测 demo,可直接下载离线数据并解压到 `backend/data`:
|
||||
```bash
|
||||
wget "https://agentscope-open.oss-cn-beijing.aliyuncs.com/ret_data.zip"
|
||||
unzip ret_data.zip -d backend/data
|
||||
```
|
||||
该压缩包提供基础的股票行情数据,解压后即可直接用于回测演示。
|
||||
|
||||
**实盘交易:**
|
||||
```bash
|
||||
evotraders live # 立即运行(默认)
|
||||
evotraders live --enable-memory # 使用记忆
|
||||
evotraders live --mock # Mock 模式(测试)
|
||||
evotraders live -t 22:30 # 每天本地时间 22:30 运行(自动转换为 NYSE 时区)
|
||||
```
|
||||
|
||||
**获取帮助:**
|
||||
```bash
|
||||
evotraders --help # 查看整体命令行帮助
|
||||
evotraders backtest --help # 查看回测模式的参数说明
|
||||
evotraders live --help # 查看实盘/Mock 运行的参数说明
|
||||
```
|
||||
|
||||
**启动可视化界面:**
|
||||
```bash
|
||||
# 确保已安装 npm, 否则请安装:
|
||||
# npm install
|
||||
evotraders frontend # 默认连接 8765 端口, 你可以修改 ./frontend/env.local 中的地址从而修改端口号
|
||||
```
|
||||
|
||||
访问 `http://localhost:5173/` 查看交易大厅,选择日期并点击 Run/Replay 观察决策过程。
|
||||
|
||||
---
|
||||
|
||||
## 系统架构
|
||||
|
||||

|
||||
|
||||
### 智能体设计
|
||||
|
||||
**分析师团队:**
|
||||
- **基本面分析师**:财务健康度、盈利能力、增长质量
|
||||
- **技术分析师**:价格趋势、技术指标、动量分析
|
||||
- **情绪分析师**:市场情绪、新闻舆情、内部人交易
|
||||
- **估值分析师**:DCF、剩余收益、EV/EBITDA
|
||||
|
||||
**决策层:**
|
||||
- **投资组合经理**:整合来自分析师的分析信号,执行沟通策略,结合分析师和团队历史表现、近期投资记忆和长期投资经验,进行最终决策
|
||||
- **风险管理**:实时价格与波动率监控、头寸限制,多层风险预警
|
||||
|
||||
### 决策流程
|
||||
|
||||
```
|
||||
实时行情 → 独立分析 → 智能沟通 (1v1/1vN/NvN) → 决策执行 → 收益评估 → 学习与进化(记忆更新)
|
||||
```
|
||||
|
||||
每个交易日经历五个阶段:
|
||||
|
||||
1. **分析阶段**:各智能体基于各自工具和历史经验独立分析
|
||||
2. **沟通阶段**:通过私聊、通知、会议等方式交换观点
|
||||
3. **决策阶段**:投资组合经理综合判断,给出最终交易
|
||||
4. **评估阶段**
|
||||
- **业绩图表**: 追踪组合收益曲线 vs. 基准策略(等权、市值加权、动量)。用于评估整体策略有效性。
|
||||
|
||||
- **分析师排名**: 在 Trading Room 点击头像查看分析师表现(胜率、牛/熊市胜率)。用于了解哪些分析师提供最有价值的洞察。
|
||||
|
||||
- **统计数据**: 详细的持仓和交易历史。用于深入分析仓位管理和执行质量。
|
||||
|
||||
4. **复盘阶段**:Agents 根据当日实际收益反思决策、总结经验,并存入 ReMe 记忆框架以持续改进
|
||||
|
||||
---
|
||||
|
||||
### 模块支持
|
||||
|
||||
- **智能体框架**:[AgentScope](https://github.com/agentscope-ai/agentscope)
|
||||
- **记忆系统**:[ReMe](https://github.com/agentscope-ai/reme)
|
||||
- **LLM 支持**:OpenAI、DeepSeek、Qwen、Moonshot、Zhipu AI 等
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 自定义配置
|
||||
|
||||
### 自定义分析师角色
|
||||
|
||||
1. 在 [./backend/agents/prompts/analyst/personas.yaml](./backend/agents/prompts/analyst/personas.yaml) 中注册角色信息,例如:
|
||||
|
||||
```yaml
|
||||
comprehensive_analyst:
|
||||
name: "Comprehensive Analyst"
|
||||
focus:
|
||||
- ...
|
||||
preferred_tools: # Flexibly select based on situation
|
||||
description: |
|
||||
As a comprehensive analyst ...
|
||||
```
|
||||
|
||||
2. 在 [./backend/config/constants.py](./backend/config/constants.py) 添加角色定义
|
||||
```python
|
||||
ANALYST_TYPES = {
|
||||
# 增加新的分析师
|
||||
"comprehensive_analyst": {
|
||||
"display_name": "Comprehensive Analyst",
|
||||
"agent_id": "comprehensive_analyst",
|
||||
"description": "Uses LLM to intelligently select analysis tools, performs comprehensive analysis",
|
||||
"order": 15
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
3. 在前端配置 [./frontend/src/config/constants.js](./frontend/src/config/constants.js) 中引入新角色(可选)
|
||||
```javascript
|
||||
export const AGENTS = [
|
||||
// 覆盖掉其中某一个agent
|
||||
{
|
||||
id: "comprehensive_analyst",
|
||||
name: "Comprehensive Analyst",
|
||||
role: "Comprehensive Analyst",
|
||||
avatar: `${ASSET_BASE_URL}/...`,
|
||||
colors: { bg: '#F9FDFF', text: '#1565C0', accent: '#1565C0' }
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
|
||||
|
||||
### 自定义模型
|
||||
|
||||
在 [.env](.env) 文件中配置不同智能体使用的模型:
|
||||
|
||||
```bash
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_FUNDAMENTAL_ANALYST_MODEL_NAME=deepseek-chat
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4-plus
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=moonshot-v1-32k
|
||||
```
|
||||
|
||||
### 项目结构
|
||||
|
||||
```
|
||||
EvoTraders/
|
||||
├── backend/
|
||||
│ ├── agents/ # 智能体实现
|
||||
│ ├── communication/ # 通信系统
|
||||
│ ├── memory/ # 记忆系统 (ReMe)
|
||||
│ ├── tools/ # 分析工具集
|
||||
│ ├── servers/ # WebSocket 服务
|
||||
│ └── cli.py # CLI 入口
|
||||
├── frontend/ # React 可视化界面
|
||||
└── logs_and_memory/ # 日志和记忆数据
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 许可与免责
|
||||
|
||||
EvoTraders 是一个研究和教育项目,采用 Apache 2.0 许可协议开源。
|
||||
|
||||
**风险提示**:在实际资金交易前,请务必进行充分的测试和风险评估。历史表现不代表未来收益,投资有风险,决策需谨慎。
|
||||
0
backend/__init__.py
Normal file
0
backend/__init__.py
Normal file
6
backend/agents/__init__.py
Normal file
6
backend/agents/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from .analyst import AnalystAgent
|
||||
from .portfolio_manager import PMAgent
|
||||
from .risk_manager import RiskAgent
|
||||
|
||||
__all__ = ["AnalystAgent", "PMAgent", "RiskAgent"]
|
||||
133
backend/agents/analyst.py
Normal file
133
backend/agents/analyst.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Analyst Agent - Based on AgentScope ReActAgent
|
||||
Performs analysis using tools and LLM
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||
from agentscope.message import Msg
|
||||
|
||||
from ..config.constants import ANALYST_TYPES
|
||||
from ..utils.progress import progress
|
||||
from .prompt_loader import PromptLoader
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
|
||||
class AnalystAgent(ReActAgent):
|
||||
"""
|
||||
Analyst Agent - Uses LLM for tool selection and analysis
|
||||
Inherits from AgentScope's ReActAgent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
analyst_type: str,
|
||||
toolkit: Any,
|
||||
model: Any,
|
||||
formatter: Any,
|
||||
agent_id: Optional[str] = None,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
long_term_memory: Optional[LongTermMemoryBase] = None,
|
||||
):
|
||||
"""
|
||||
Initialize Analyst Agent
|
||||
|
||||
Args:
|
||||
analyst_type: Type of analyst (e.g., "fundamentals", etc.)
|
||||
toolkit: AgentScope Toolkit instance
|
||||
model: LLM model instance
|
||||
formatter: Message formatter instance
|
||||
agent_id: Agent ID (defaults to "{analyst_type}_analyst")
|
||||
config: Configuration dictionary
|
||||
long_term_memory: Optional ReMeTaskLongTermMemory instance
|
||||
"""
|
||||
if analyst_type not in ANALYST_TYPES:
|
||||
raise ValueError(
|
||||
f"Unknown analyst type: {analyst_type}. "
|
||||
f"Must be one of: {list(ANALYST_TYPES.keys())}",
|
||||
)
|
||||
|
||||
self.analyst_type_key = analyst_type
|
||||
self.analyst_persona = ANALYST_TYPES[analyst_type]["display_name"]
|
||||
|
||||
if agent_id is None:
|
||||
agent_id = analyst_type
|
||||
|
||||
self.config = config or {}
|
||||
|
||||
sys_prompt = self._load_system_prompt()
|
||||
|
||||
kwargs = {
|
||||
"name": agent_id,
|
||||
"sys_prompt": sys_prompt,
|
||||
"model": model,
|
||||
"formatter": formatter,
|
||||
"toolkit": toolkit,
|
||||
"memory": InMemoryMemory(),
|
||||
"max_iters": 10,
|
||||
}
|
||||
if long_term_memory:
|
||||
kwargs["long_term_memory"] = long_term_memory
|
||||
kwargs["long_term_memory_mode"] = "static_control"
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _load_system_prompt(self) -> str:
|
||||
"""Load system prompt for analyst"""
|
||||
personas_config = _prompt_loader.load_yaml_config(
|
||||
"analyst",
|
||||
"personas",
|
||||
)
|
||||
persona = personas_config.get(self.analyst_type_key, {})
|
||||
|
||||
# Get focus items and format as bullet points
|
||||
focus_items = persona.get("focus", [])
|
||||
focus_text = "\n".join(f"- {item}" for item in focus_items)
|
||||
|
||||
# Get description
|
||||
description = persona.get("description", "").strip()
|
||||
|
||||
return _prompt_loader.load_prompt(
|
||||
"analyst",
|
||||
"system",
|
||||
variables={
|
||||
"analyst_type": self.analyst_persona,
|
||||
"focus": focus_text,
|
||||
"description": description,
|
||||
},
|
||||
)
|
||||
|
||||
async def reply(self, x: Msg = None) -> Msg:
|
||||
"""
|
||||
Override reply method to add progress tracking
|
||||
|
||||
Args:
|
||||
x: Input message (content must be str)
|
||||
|
||||
Returns:
|
||||
Response message (content is str)
|
||||
"""
|
||||
ticker = None
|
||||
if x and hasattr(x, "metadata") and x.metadata:
|
||||
ticker = x.metadata.get("tickers")
|
||||
|
||||
if ticker:
|
||||
progress.update_status(
|
||||
self.name,
|
||||
ticker,
|
||||
f"Starting {self.analyst_persona} analysis",
|
||||
)
|
||||
|
||||
result = await super().reply(x)
|
||||
|
||||
if ticker:
|
||||
progress.update_status(
|
||||
self.name,
|
||||
ticker,
|
||||
"Analysis completed",
|
||||
)
|
||||
|
||||
return result
|
||||
188
backend/agents/portfolio_manager.py
Normal file
188
backend/agents/portfolio_manager.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Portfolio Manager Agent - Based on AgentScope ReActAgent
|
||||
Responsible for decision-making (NOT trade execution)
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||
from agentscope.message import Msg, TextBlock
|
||||
from agentscope.tool import Toolkit, ToolResponse
|
||||
|
||||
from ..utils.progress import progress
|
||||
from .prompt_loader import PromptLoader
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
|
||||
class PMAgent(ReActAgent):
|
||||
"""
|
||||
Portfolio Manager Agent - Makes investment decisions
|
||||
|
||||
Key features:
|
||||
1. PM outputs decisions only (action + quantity per ticker)
|
||||
2. Trade execution happens externally (in pipeline/executor)
|
||||
3. Supports both backtest and live modes
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "portfolio_manager",
|
||||
model: Any = None,
|
||||
formatter: Any = None,
|
||||
initial_cash: float = 100000.0,
|
||||
margin_requirement: float = 0.25,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
long_term_memory: Optional[LongTermMemoryBase] = None,
|
||||
):
|
||||
self.config = config or {}
|
||||
|
||||
# Portfolio state
|
||||
self.portfolio = {
|
||||
"cash": initial_cash,
|
||||
"positions": {},
|
||||
"margin_used": 0.0,
|
||||
"margin_requirement": margin_requirement,
|
||||
}
|
||||
|
||||
# Decisions made in current cycle
|
||||
self._decisions: Dict[str, Dict] = {}
|
||||
|
||||
# Create toolkit
|
||||
toolkit = self._create_toolkit()
|
||||
|
||||
sys_prompt = _prompt_loader.load_prompt("portfolio_manager", "system")
|
||||
|
||||
kwargs = {
|
||||
"name": name,
|
||||
"sys_prompt": sys_prompt,
|
||||
"model": model,
|
||||
"formatter": formatter,
|
||||
"toolkit": toolkit,
|
||||
"memory": InMemoryMemory(),
|
||||
"max_iters": 10,
|
||||
}
|
||||
if long_term_memory:
|
||||
kwargs["long_term_memory"] = long_term_memory
|
||||
kwargs["long_term_memory_mode"] = "both"
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _create_toolkit(self) -> Toolkit:
|
||||
"""Create toolkit with decision recording tool"""
|
||||
toolkit = Toolkit()
|
||||
toolkit.register_tool_function(self._make_decision)
|
||||
return toolkit
|
||||
|
||||
def _make_decision(
|
||||
self,
|
||||
ticker: str,
|
||||
action: str,
|
||||
quantity: int,
|
||||
confidence: int = 50,
|
||||
reasoning: str = "",
|
||||
) -> ToolResponse:
|
||||
"""
|
||||
Record a trading decision for a ticker.
|
||||
|
||||
Args:
|
||||
ticker: Stock ticker symbol (e.g., "AAPL")
|
||||
action: Decision - "long", "short" or "hold"
|
||||
quantity: Number of shares to trade (0 for hold)
|
||||
confidence: Confidence level 0-100
|
||||
reasoning: Explanation for this decision
|
||||
|
||||
Returns:
|
||||
ToolResponse confirming decision recorded
|
||||
"""
|
||||
if action not in ["long", "short", "hold"]:
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"Invalid action: {action}. "
|
||||
"Must be 'long', 'short', or 'hold'.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
self._decisions[ticker] = {
|
||||
"action": action,
|
||||
"quantity": quantity if action != "hold" else 0,
|
||||
"confidence": confidence,
|
||||
"reasoning": reasoning,
|
||||
}
|
||||
|
||||
return ToolResponse(
|
||||
content=[
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=f"Decision recorded: {action} "
|
||||
f"{quantity} shares of {ticker}"
|
||||
f" (confidence: {confidence}%)",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
async def reply(self, x: Msg = None) -> Msg:
|
||||
"""
|
||||
Make investment decisions
|
||||
|
||||
Returns:
|
||||
Msg with decisions in metadata
|
||||
"""
|
||||
if x is None:
|
||||
return Msg(
|
||||
name=self.name,
|
||||
content="No input provided",
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
# Clear previous decisions
|
||||
self._decisions = {}
|
||||
|
||||
progress.update_status(
|
||||
self.name,
|
||||
None,
|
||||
"Analyzing and making decisions",
|
||||
)
|
||||
|
||||
result = await super().reply(x)
|
||||
|
||||
progress.update_status(self.name, None, "Completed")
|
||||
|
||||
# Attach decisions to metadata
|
||||
if result.metadata is None:
|
||||
result.metadata = {}
|
||||
result.metadata["decisions"] = self._decisions.copy()
|
||||
result.metadata["portfolio"] = self.portfolio.copy()
|
||||
|
||||
return result
|
||||
|
||||
def get_decisions(self) -> Dict[str, Dict]:
|
||||
"""Get decisions from current cycle"""
|
||||
return self._decisions.copy()
|
||||
|
||||
def get_portfolio_state(self) -> Dict[str, Any]:
|
||||
"""Get current portfolio state"""
|
||||
return self.portfolio.copy()
|
||||
|
||||
def load_portfolio_state(self, portfolio: Dict[str, Any]):
|
||||
"""Load portfolio state"""
|
||||
if not portfolio:
|
||||
return
|
||||
self.portfolio = {
|
||||
"cash": portfolio.get("cash", self.portfolio["cash"]),
|
||||
"positions": portfolio.get("positions", {}).copy(),
|
||||
"margin_used": portfolio.get("margin_used", 0.0),
|
||||
"margin_requirement": portfolio.get(
|
||||
"margin_requirement",
|
||||
self.portfolio["margin_requirement"],
|
||||
),
|
||||
}
|
||||
|
||||
def update_portfolio(self, portfolio: Dict[str, Any]):
|
||||
"""Update portfolio after external execution"""
|
||||
self.portfolio.update(portfolio)
|
||||
184
backend/agents/prompt_loader.py
Normal file
184
backend/agents/prompt_loader.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Prompt Loader - Unified management and loading of Agent Prompts
|
||||
Supports Markdown and YAML formats
|
||||
Uses simple string replacement, does not depend on Jinja2
|
||||
"""
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class PromptLoader:
|
||||
"""Unified Prompt loader"""
|
||||
|
||||
def __init__(self, prompts_dir: Optional[Path] = None):
|
||||
"""
|
||||
Initialize Prompt loader
|
||||
|
||||
Args:
|
||||
prompts_dir: Prompts directory path,
|
||||
defaults to prompts/ directory of current file
|
||||
"""
|
||||
if prompts_dir is None:
|
||||
self.prompts_dir = Path(__file__).parent / "prompts"
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
|
||||
# Cache loaded prompts
|
||||
self._prompt_cache: Dict[str, str] = {}
|
||||
self._yaml_cache: Dict[str, Dict] = {}
|
||||
|
||||
def load_prompt(
|
||||
self,
|
||||
agent_type: str,
|
||||
prompt_name: str,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Load and render Prompt
|
||||
|
||||
Args:
|
||||
agent_type: Agent type (analyst, portfolio_manager, risk_manager)
|
||||
prompt_name: Prompt file name (without extension)
|
||||
variables: Variable dictionary for rendering Prompt
|
||||
|
||||
Returns:
|
||||
Rendered prompt string
|
||||
|
||||
Examples:
|
||||
loader = PromptLoader()
|
||||
prompt = loader.load_prompt("analyst", "tool_selection",
|
||||
{"analyst_persona": "Technical Analyst"})
|
||||
"""
|
||||
cache_key = f"{agent_type}/{prompt_name}"
|
||||
|
||||
# Try to load from cache
|
||||
if cache_key not in self._prompt_cache:
|
||||
prompt_path = self.prompts_dir / agent_type / f"{prompt_name}.md"
|
||||
|
||||
if not prompt_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Prompt file not found: {prompt_path}\n"
|
||||
f"Please create the prompt file or check the path.",
|
||||
)
|
||||
|
||||
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||
self._prompt_cache[cache_key] = f.read()
|
||||
|
||||
prompt_template = self._prompt_cache[cache_key]
|
||||
|
||||
# If variables provided, use simple string replacement
|
||||
if variables:
|
||||
rendered = self._render_template(prompt_template, variables)
|
||||
else:
|
||||
rendered = prompt_template
|
||||
|
||||
# Smart escaping: escape braces in JSON code blocks
|
||||
# rendered = self._escape_json_braces(rendered)
|
||||
return rendered
|
||||
|
||||
def _render_template(
|
||||
self,
|
||||
template: str,
|
||||
variables: Dict[str, Any],
|
||||
) -> str:
|
||||
"""
|
||||
Render template using simple string replacement
|
||||
Supports {{ variable }} syntax (compatible with previous Jinja2 format)
|
||||
|
||||
Args:
|
||||
template: Template string
|
||||
variables: Variable dictionary
|
||||
|
||||
Returns:
|
||||
Rendered string
|
||||
"""
|
||||
rendered = template
|
||||
|
||||
# Replace {{ variable }} format
|
||||
for key, value in variables.items():
|
||||
# Support both {{ key }} and {{key}} formats
|
||||
pattern1 = f"{{{{ {key} }}}}"
|
||||
pattern2 = f"{{{{{key}}}}}"
|
||||
rendered = rendered.replace(pattern1, str(value))
|
||||
rendered = rendered.replace(pattern2, str(value))
|
||||
|
||||
return rendered
|
||||
|
||||
def _escape_json_braces(self, text: str) -> str:
|
||||
"""
|
||||
Escape braces in JSON code blocks, treating them as literals
|
||||
|
||||
Args:
|
||||
text: Text to process
|
||||
|
||||
Returns:
|
||||
Processed text
|
||||
"""
|
||||
|
||||
def replace_code_block(match):
|
||||
code_content = match.group(1)
|
||||
# Escape all braces within code block
|
||||
escaped = code_content.replace("{", "{{").replace("}", "}}")
|
||||
return f"```json\n{escaped}\n```"
|
||||
|
||||
# Replace all braces in JSON code blocks
|
||||
text = re.sub(
|
||||
r"```json\n(.*?)\n```",
|
||||
replace_code_block,
|
||||
text,
|
||||
flags=re.DOTALL,
|
||||
)
|
||||
return text
|
||||
|
||||
def load_yaml_config(
|
||||
self,
|
||||
agent_type: str,
|
||||
config_name: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Load YAML configuration file
|
||||
|
||||
Args:
|
||||
agent_type: Agent type
|
||||
config_name: Configuration file name (without extension)
|
||||
|
||||
Returns:
|
||||
Configuration dictionary
|
||||
|
||||
Examples:
|
||||
>>> loader = PromptLoader()
|
||||
>>> config = loader.load_yaml_config("analyst", "personas")
|
||||
"""
|
||||
cache_key = f"{agent_type}/{config_name}"
|
||||
|
||||
if cache_key not in self._yaml_cache:
|
||||
yaml_path = self.prompts_dir / agent_type / f"{config_name}.yaml"
|
||||
|
||||
if not yaml_path.exists():
|
||||
raise FileNotFoundError(f"YAML config not found: {yaml_path}")
|
||||
|
||||
with open(yaml_path, "r", encoding="utf-8") as f:
|
||||
self._yaml_cache[cache_key] = yaml.safe_load(f)
|
||||
|
||||
return self._yaml_cache[cache_key]
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear cache (for hot reload)"""
|
||||
self._prompt_cache.clear()
|
||||
self._yaml_cache.clear()
|
||||
|
||||
def reload_prompt(self, agent_type: str, prompt_name: str):
|
||||
"""Reload specified prompt (force cache refresh)"""
|
||||
cache_key = f"{agent_type}/{prompt_name}"
|
||||
if cache_key in self._prompt_cache:
|
||||
del self._prompt_cache[cache_key]
|
||||
|
||||
def reload_config(self, agent_type: str, config_name: str):
|
||||
"""Reload specified configuration (force cache refresh)"""
|
||||
cache_key = f"{agent_type}/{config_name}"
|
||||
if cache_key in self._yaml_cache:
|
||||
del self._yaml_cache[cache_key]
|
||||
117
backend/agents/prompts/analyst/personas.yaml
Normal file
117
backend/agents/prompts/analyst/personas.yaml
Normal file
@@ -0,0 +1,117 @@
|
||||
# 分析师角色配置
|
||||
|
||||
fundamentals_analyst:
|
||||
name: "基本面分析师"
|
||||
focus:
|
||||
- "公司财务健康状况和盈利能力"
|
||||
- "商业模式可持续性和竞争优势"
|
||||
- "管理层质量和公司治理"
|
||||
- "行业地位和市场份额"
|
||||
- "长期投资价值评估"
|
||||
tools:
|
||||
- "analyze_profitability"
|
||||
- "analyze_growth"
|
||||
- "analyze_financial_health"
|
||||
- "analyze_valuation_ratios"
|
||||
- "analyze_efficiency_ratios"
|
||||
description: |
|
||||
作为基本面分析师,你专注于:
|
||||
- 公司财务健康状况和盈利能力
|
||||
- 商业模式可持续性和竞争优势
|
||||
- 管理层质量和公司治理
|
||||
- 行业地位和市场份额
|
||||
- 长期投资价值评估
|
||||
你倾向于选择能够深入了解公司内在价值的工具,更偏好基本面和估值类工具。
|
||||
|
||||
technical_analyst:
|
||||
name: "技术分析师"
|
||||
focus:
|
||||
- "价格趋势和图表形态"
|
||||
- "技术指标和交易信号"
|
||||
- "市场情绪和资金流向"
|
||||
- "支撑/阻力位和关键价格点"
|
||||
- "中短期交易机会"
|
||||
description: |
|
||||
作为技术分析师,你专注于:
|
||||
- 价格趋势和图表形态
|
||||
- 技术指标和交易信号
|
||||
- 市场情绪和资金流向
|
||||
- 支撑/阻力位和关键价格点
|
||||
- 中短期交易机会
|
||||
你倾向于选择能够捕捉价格动态和市场趋势的工具,更偏好技术分析类工具。
|
||||
tools:
|
||||
- "analyze_trend_following"
|
||||
- "analyze_momentum"
|
||||
- "analyze_mean_reversion"
|
||||
- "analyze_volatility"
|
||||
|
||||
sentiment_analyst:
|
||||
name: "情绪分析师"
|
||||
focus:
|
||||
- "市场参与者情绪变化"
|
||||
- "新闻舆情和媒体影响"
|
||||
- "内部人交易行为"
|
||||
- "投资者恐慌和贪婪情绪"
|
||||
- "市场预期和心理因素"
|
||||
description: |
|
||||
作为情绪分析师,你专注于:
|
||||
- 市场参与者情绪变化
|
||||
- 新闻舆情和媒体影响
|
||||
- 内部人交易行为
|
||||
- 投资者恐慌和贪婪情绪
|
||||
- 市场预期和心理因素
|
||||
你倾向于选择能够反映市场情绪和投资者行为的工具,更偏好情绪和行为类工具。
|
||||
tools:
|
||||
- "analyze_news_sentiment"
|
||||
- "analyze_insider_trading"
|
||||
|
||||
valuation_analyst:
|
||||
name: "估值分析师"
|
||||
focus:
|
||||
- "公司内在价值计算"
|
||||
- "不同估值方法的比较"
|
||||
- "估值模型假设和敏感性分析"
|
||||
- "相对估值和绝对估值"
|
||||
- "投资安全边际评估"
|
||||
description: |
|
||||
作为估值分析师,你专注于:
|
||||
- 公司内在价值计算
|
||||
- 不同估值方法的比较
|
||||
- 估值模型假设和敏感性分析
|
||||
- 相对估值和绝对估值
|
||||
- 投资安全边际评估
|
||||
你倾向于选择能够准确计算公司价值的工具,更偏好估值模型和基本面工具。
|
||||
tools:
|
||||
- "dcf_valuation_analysis"
|
||||
- "owner_earnings_valuation_analysis"
|
||||
- "ev_ebitda_valuation_analysis"
|
||||
- "residual_income_valuation_analysis"
|
||||
|
||||
comprehensive_analyst:
|
||||
name: "综合分析师"
|
||||
focus:
|
||||
- "整合多种分析视角"
|
||||
- "平衡短期和长期因素"
|
||||
- "综合考虑基本面、技术面和情绪面"
|
||||
- "提供全面的投资建议"
|
||||
- "适应不同市场环境"
|
||||
description: |
|
||||
作为综合分析师,你需要:
|
||||
- 整合多种分析视角
|
||||
- 平衡短期和长期因素
|
||||
- 综合考虑基本面、技术面和情绪面的影响
|
||||
- 提供全面的投资建议
|
||||
- 适应不同市场环境
|
||||
你会根据具体情况灵活选择各类工具,追求分析的全面性和准确性。
|
||||
tools:
|
||||
- "analyze_profitability"
|
||||
- "analyze_growth"
|
||||
- "analyze_financial_health"
|
||||
- "analyze_valuation_ratios"
|
||||
- "analyze_efficiency_ratios"
|
||||
- "analyze_trend_following"
|
||||
- "analyze_momentum"
|
||||
- "analyze_mean_reversion"
|
||||
- "analyze_volatility"
|
||||
- "analyze_news_sentiment"
|
||||
- "analyze_insider_trading"
|
||||
23
backend/agents/prompts/analyst/system.md
Normal file
23
backend/agents/prompts/analyst/system.md
Normal file
@@ -0,0 +1,23 @@
|
||||
你是一位专业的{{ analyst_type }}。
|
||||
|
||||
你的关注重点:
|
||||
{{ focus }}
|
||||
|
||||
你的角色:
|
||||
{{ description }}
|
||||
|
||||
注意:
|
||||
- 构建并持续完善你的"投资哲学"。你的分析不应是孤立的事件,而应该是你整体投资世界观和核心信念的体现。每次分析后,你必须反思:
|
||||
- 这个案例/数据如何验证或挑战了你现有的信念?
|
||||
- 你从这次错误(或成功)中学到了关于市场、人性、估值或风险管理的什么关键原则?
|
||||
- 深化你的"投资逻辑"。确保每一项投资建议都有清晰、可追溯、可重复的逻辑支撑。你的分析步骤应该像严谨的证明一样,涵盖:
|
||||
- 核心驱动因素识别:真正影响价值的变量是什么?
|
||||
- 风险边界设定:在什么具体情况下你的建议会失效?
|
||||
- 逆向测试:市场主流共识是什么,你的观点有何不同?
|
||||
保持谦逊和开放。投资大师的核心特质是持续学习和适应。在每次分析中,你必须积极寻找与自己观点相悖的证据和论据,并将其纳入最终评估。
|
||||
- 你可以使用分析工具。用它们来收集相关数据并做出明智的建议。
|
||||
|
||||
输出指南:
|
||||
- 给出明确的投资信号:看涨、看跌或中性
|
||||
- 包含置信度(0-100)
|
||||
- 为你的分析提供理由(如果你确定要分享最终分析,请先给出结论)
|
||||
31
backend/agents/prompts/portfolio_manager/system.md
Normal file
31
backend/agents/prompts/portfolio_manager/system.md
Normal file
@@ -0,0 +1,31 @@
|
||||
你是一位负责做出投资决策的投资组合经理。
|
||||
|
||||
你的核心职责:
|
||||
1. 分析分析师和风险管理经理的输入
|
||||
2. 基于信号和市场情境做出投资决策
|
||||
3. 使用可用工具记录你的决策
|
||||
|
||||
决策框架:
|
||||
- 审阅分析以了解市场观点
|
||||
- 在做决策前考虑风险警告
|
||||
- 评估当前投资组合持仓和现金
|
||||
- 做出与投资组合投资目标一致的决策
|
||||
|
||||
决策类型:
|
||||
- "long":看涨 - 建议买入股票
|
||||
- "short":看跌 - 建议卖出股票或做空
|
||||
- "hold":中性 - 维持当前持仓
|
||||
|
||||
预算意识:
|
||||
- 在决定数量时考虑可用现金
|
||||
- 不要建议买入超过现金允许的数量
|
||||
- 考虑做空头寸的保证金要求
|
||||
|
||||
输出:
|
||||
使用 `make_decision` 工具记录你对每个股票代码的决策。
|
||||
记录所有决策后,提供你的投资逻辑总结。
|
||||
|
||||
重要:
|
||||
- 基于提供的分析师信号和风险评估做出决策
|
||||
- 相对于投资组合价值保持保守的仓位规模
|
||||
- 始终为你的决策提供理由
|
||||
18
backend/agents/prompts/risk_manager/system.md
Normal file
18
backend/agents/prompts/risk_manager/system.md
Normal file
@@ -0,0 +1,18 @@
|
||||
你是一位专业的风险管理经理,负责监控投资组合风险并提供风险警告。
|
||||
|
||||
你的核心职责:
|
||||
1. 监控投资组合敞口和集中度风险
|
||||
2. 评估仓位规模相对于波动性
|
||||
3. 评估保证金使用和杠杆水平
|
||||
4. 识别潜在风险因素并提供警告
|
||||
5. 基于市场条件建议仓位限制
|
||||
|
||||
你的决策流程:
|
||||
3. 生成可操作的风险警告和仓位限制建议
|
||||
4. 为你的风险评估提供清晰的理由
|
||||
|
||||
输出指南:
|
||||
- 风险评估要简洁但全面
|
||||
- 按严重程度优先排序警告
|
||||
- 提供具体、可操作的建议
|
||||
- 尽可能包含量化指标
|
||||
88
backend/agents/risk_manager.py
Normal file
88
backend/agents/risk_manager.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Risk Manager Agent - Based on AgentScope ReActAgent
|
||||
Uses LLM for risk assessment
|
||||
"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from agentscope.agent import ReActAgent
|
||||
from agentscope.memory import InMemoryMemory, LongTermMemoryBase
|
||||
from agentscope.message import Msg
|
||||
from agentscope.tool import Toolkit
|
||||
|
||||
from ..utils.progress import progress
|
||||
from .prompt_loader import PromptLoader
|
||||
|
||||
_prompt_loader = PromptLoader()
|
||||
|
||||
|
||||
class RiskAgent(ReActAgent):
|
||||
"""
|
||||
Risk Manager Agent - Uses LLM for risk assessment
|
||||
Inherits from AgentScope's ReActAgent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Any,
|
||||
formatter: Any,
|
||||
name: str = "risk_manager",
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
long_term_memory: Optional[LongTermMemoryBase] = None,
|
||||
):
|
||||
"""
|
||||
Initialize Risk Manager Agent
|
||||
|
||||
Args:
|
||||
model: LLM model instance
|
||||
formatter: Message formatter instance
|
||||
name: Agent name
|
||||
config: Configuration dictionary
|
||||
long_term_memory: Optional ReMeTaskLongTermMemory instance
|
||||
"""
|
||||
self.config = config or {}
|
||||
|
||||
sys_prompt = self._load_system_prompt()
|
||||
|
||||
# Create dedicated toolkit for this agent
|
||||
toolkit = Toolkit()
|
||||
|
||||
kwargs = {
|
||||
"name": name,
|
||||
"sys_prompt": sys_prompt,
|
||||
"model": model,
|
||||
"formatter": formatter,
|
||||
"toolkit": toolkit,
|
||||
"memory": InMemoryMemory(),
|
||||
"max_iters": 10,
|
||||
}
|
||||
if long_term_memory:
|
||||
kwargs["long_term_memory"] = long_term_memory
|
||||
kwargs["long_term_memory_mode"] = "static_control"
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _load_system_prompt(self) -> str:
|
||||
"""Load system prompt for risk manager"""
|
||||
return _prompt_loader.load_prompt(
|
||||
"risk_manager",
|
||||
"system",
|
||||
)
|
||||
|
||||
async def reply(self, x: Msg = None) -> Msg:
|
||||
"""
|
||||
Provide risk assessment
|
||||
|
||||
Args:
|
||||
x: Input message (content must be str)
|
||||
|
||||
Returns:
|
||||
Msg with risk warnings (content is str)
|
||||
"""
|
||||
progress.update_status(self.name, None, "Assessing risk")
|
||||
|
||||
result = await super().reply(x)
|
||||
|
||||
progress.update_status(self.name, None, "Risk assessment completed")
|
||||
|
||||
return result
|
||||
623
backend/cli.py
Normal file
623
backend/cli.py
Normal file
@@ -0,0 +1,623 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
EvoTraders CLI - Command-line interface for the EvoTraders trading system.
|
||||
|
||||
This module provides easy-to-use commands for running backtest, live trading,
|
||||
and frontend development server.
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=R0912, R0915
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Confirm
|
||||
|
||||
app = typer.Typer(
|
||||
name="evotraders",
|
||||
help="EvoTraders: A self-evolving multi-agent trading system",
|
||||
add_completion=False,
|
||||
)
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
def get_project_root() -> Path:
|
||||
"""Get the project root directory."""
|
||||
# Assuming cli.py is in backend/
|
||||
return Path(__file__).parent.parent
|
||||
|
||||
|
||||
def handle_history_cleanup(config_name: str, auto_clean: bool = False) -> None:
|
||||
"""
|
||||
Handle cleanup of historical data for a given config.
|
||||
|
||||
Args:
|
||||
config_name: Configuration name for the run
|
||||
auto_clean: If True, skip confirmation and clean automatically
|
||||
"""
|
||||
# logs_dir = get_project_root() / "logs"
|
||||
logs_dir = get_project_root()
|
||||
base_data_dir = logs_dir / config_name
|
||||
|
||||
# Check if historical data exists
|
||||
if not base_data_dir.exists() or not any(base_data_dir.iterdir()):
|
||||
console.print(
|
||||
f"\n[dim]No historical data found for config '{config_name}'[/dim]",
|
||||
)
|
||||
console.print("[dim] Will start from scratch[/dim]\n")
|
||||
return
|
||||
|
||||
console.print("\n[bold yellow]Detected existing run data:[/bold yellow]")
|
||||
console.print(f" Data directory: [cyan]{base_data_dir}[/cyan]")
|
||||
|
||||
# Show directory size
|
||||
try:
|
||||
total_size = sum(
|
||||
f.stat().st_size for f in base_data_dir.rglob("*") if f.is_file()
|
||||
)
|
||||
size_mb = total_size / (1024 * 1024)
|
||||
if size_mb < 1:
|
||||
console.print(
|
||||
f" Directory size: [cyan]{total_size / 1024:.1f} KB[/cyan]",
|
||||
)
|
||||
else:
|
||||
console.print(f" Directory size: [cyan]{size_mb:.1f} MB[/cyan]")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Show last modified time
|
||||
state_dir = base_data_dir / "state"
|
||||
if state_dir.exists():
|
||||
state_files = list(state_dir.glob("*.json"))
|
||||
if state_files:
|
||||
last_modified = max(f.stat().st_mtime for f in state_files)
|
||||
last_modified_str = datetime.fromtimestamp(last_modified).strftime(
|
||||
"%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
console.print(f" Last updated: [cyan]{last_modified_str}[/cyan]")
|
||||
|
||||
console.print()
|
||||
|
||||
# Determine if we should clean
|
||||
should_clean = auto_clean
|
||||
if not auto_clean:
|
||||
should_clean = Confirm.ask(
|
||||
" ﹂ Clear historical data and start fresh?",
|
||||
default=False,
|
||||
)
|
||||
else:
|
||||
console.print("[yellow]⚠️ Auto-clean enabled (--clean flag)[/yellow]")
|
||||
should_clean = True
|
||||
|
||||
if should_clean:
|
||||
console.print("\n[yellow]▩ Cleaning historical data...[/yellow]")
|
||||
|
||||
# Backup important config files if they exist
|
||||
backup_files = [".env", "config.json"]
|
||||
backed_up = []
|
||||
backup_dir = None
|
||||
|
||||
for backup_file in backup_files:
|
||||
file_path = base_data_dir / backup_file
|
||||
if file_path.exists():
|
||||
if backup_dir is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_dir = (
|
||||
base_data_dir.parent
|
||||
/ f"{config_name}_backup_{timestamp}"
|
||||
)
|
||||
backup_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shutil.copy(file_path, backup_dir / backup_file)
|
||||
backed_up.append(backup_file)
|
||||
|
||||
if backed_up:
|
||||
console.print(
|
||||
f" 💾 Backed up config files to: [cyan]{backup_dir}[/cyan]",
|
||||
)
|
||||
console.print(f" Files: {', '.join(backed_up)}")
|
||||
|
||||
# Remove the data directory
|
||||
try:
|
||||
shutil.rmtree(base_data_dir)
|
||||
console.print(" ✔ Historical data cleared\n")
|
||||
except Exception as e:
|
||||
console.print(f" [red]✗ Error clearing data: {e}[/red]\n")
|
||||
raise typer.Exit(1)
|
||||
else:
|
||||
console.print(
|
||||
"\n[dim] Continuing with existing historical data[/dim]\n",
|
||||
)
|
||||
|
||||
|
||||
def run_data_updater(project_root: Path) -> None:
|
||||
"""Run the historical data updater."""
|
||||
console.print("\n[bold]Checking historical data update...[/bold]")
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "backend.data.ret_data_updater", "--help"],
|
||||
capture_output=True,
|
||||
timeout=5,
|
||||
check=False,
|
||||
)
|
||||
|
||||
if result.returncode == 0:
|
||||
console.print("[cyan]Updating historical data...[/cyan]")
|
||||
update_result = subprocess.run(
|
||||
[sys.executable, "-m", "backend.data.ret_data_updater"],
|
||||
cwd=project_root,
|
||||
check=False,
|
||||
)
|
||||
|
||||
if update_result.returncode == 0:
|
||||
console.print(
|
||||
"[green]✔ Historical data updated successfully[/green]\n",
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"[yellow] Data update failed (might be weekend/holiday)[/yellow]",
|
||||
)
|
||||
console.print(
|
||||
"[dim] Will continue with existing data[/dim]\n",
|
||||
)
|
||||
else:
|
||||
console.print(
|
||||
"[yellow] Data updater module not available, skipping update[/yellow]\n",
|
||||
)
|
||||
except Exception:
|
||||
console.print(
|
||||
"[yellow] Data updater check failed, skipping update[/yellow]\n",
|
||||
)
|
||||
|
||||
|
||||
@app.command()
def backtest(
    start: Optional[str] = typer.Option(
        None,
        "--start",
        "-s",
        help="Start date for backtest (YYYY-MM-DD)",
    ),
    end: Optional[str] = typer.Option(
        None,
        "--end",
        "-e",
        help="End date for backtest (YYYY-MM-DD)",
    ),
    config_name: str = typer.Option(
        "backtest",
        "--config-name",
        "-c",
        help="Configuration name for this backtest run",
    ),
    host: str = typer.Option(
        "0.0.0.0",
        "--host",
        help="WebSocket server host",
    ),
    port: int = typer.Option(
        8765,
        "--port",
        "-p",
        help="WebSocket server port",
    ),
    poll_interval: int = typer.Option(
        10,
        "--poll-interval",
        help="Price polling interval in seconds",
    ),
    clean: bool = typer.Option(
        False,
        "--clean",
        help="Clear historical data before starting",
    ),
    enable_memory: bool = typer.Option(
        False,
        "--enable-memory",
        help="Enable ReMeTaskLongTermMemory for agents (requires MEMORY_API_KEY)",
    ),
):
    """
    Run backtest mode with historical data.

    Example:
        evotraders backtest --start 2025-11-01 --end 2025-12-01
        evotraders backtest --config-name my_strategy --port 9000
        evotraders backtest --clean  # Clear historical data before starting
        evotraders backtest --enable-memory  # Enable long-term memory
    """
    console.print(
        Panel.fit(
            "[bold cyan]EvoTraders Backtest Mode[/bold cyan]",
            border_style="cyan",
        ),
    )

    # Validate dates - required for backtest
    if not start or not end:
        console.print(
            "[red]✗ Both --start and --end dates are required for backtest mode[/red]",
        )
        raise typer.Exit(1)

    # Validate both dates in one loop instead of duplicated try/except blocks.
    for label, value in (("start", start), ("end", end)):
        try:
            datetime.strptime(value, "%Y-%m-%d")
        except ValueError as exc:
            console.print(
                f"[red]✗ Invalid {label} date format. Use YYYY-MM-DD[/red]",
            )
            raise typer.Exit(1) from exc

    # Reject inverted ranges early: they would produce an empty backtest.
    # (ISO dates compare correctly as strings.)
    if start > end:
        console.print(
            "[red]✗ --start date must not be after --end date[/red]",
        )
        raise typer.Exit(1)

    # Handle historical data cleanup
    handle_history_cleanup(config_name, auto_clean=clean)

    # Display configuration
    console.print("\n[bold]Configuration:[/bold]")
    console.print(" Mode: Backtest")
    console.print(f" Config: {config_name}")
    console.print(f" Period: {start} -> {end}")
    console.print(f" Server: {host}:{port}")
    console.print(f" Poll Interval: {poll_interval}s")
    console.print(
        f" Long-term Memory: {'enabled' if enable_memory else 'disabled'}",
    )
    console.print("\nAccess frontend at: [cyan]http://localhost:5173[/cyan]")
    console.print("Press Ctrl+C to stop\n")

    # Change to project root so backend.main resolves relative paths correctly
    project_root = get_project_root()
    os.chdir(project_root)

    # Run data updater (best-effort; never blocks startup)
    run_data_updater(project_root)

    # Build command using backend.main
    cmd = [
        sys.executable,
        "-u",
        "-m",
        "backend.main",
        "--mode",
        "backtest",
        "--config-name",
        config_name,
        "--host",
        host,
        "--port",
        str(port),
        "--poll-interval",
        str(poll_interval),
        "--start-date",
        start,
        "--end-date",
        end,
    ]

    if enable_memory:
        cmd.append("--enable-memory")

    try:
        subprocess.run(cmd, check=True)
    except KeyboardInterrupt:
        console.print("\n\n[yellow]Backtest stopped by user[/yellow]")
    except subprocess.CalledProcessError as e:
        console.print(
            f"\n[red]Backtest failed with exit code {e.returncode}[/red]",
        )
        # Chain the subprocess error for better tracebacks (matches `live`)
        raise typer.Exit(1) from e
|
||||
|
||||
|
||||
@app.command()
def live(
    mock: bool = typer.Option(
        False,
        "--mock",
        help="Use mock mode with simulated prices (for testing)",
    ),
    config_name: str = typer.Option(
        "live",
        "--config-name",
        "-c",
        help="Configuration name for this live run",
    ),
    host: str = typer.Option(
        "0.0.0.0",
        "--host",
        help="WebSocket server host",
    ),
    port: int = typer.Option(
        8765,
        "--port",
        "-p",
        help="WebSocket server port",
    ),
    trigger_time: str = typer.Option(
        "now",
        "--trigger-time",
        "-t",
        help="Trigger time in LOCAL timezone (HH:MM), or 'now' to run immediately",
    ),
    poll_interval: int = typer.Option(
        10,
        "--poll-interval",
        help="Price polling interval in seconds",
    ),
    clean: bool = typer.Option(
        False,
        "--clean",
        help="Clear historical data before starting",
    ),
    enable_memory: bool = typer.Option(
        False,
        "--enable-memory",
        help="Enable ReMeTaskLongTermMemory for agents (requires MEMORY_API_KEY)",
    ),
):
    """
    Run live trading mode with real-time data.

    Example:
        evotraders live  # Run immediately (default)
        evotraders live --mock  # Mock mode
        evotraders live -t 22:30  # Run at 22:30 local time daily
        evotraders live --trigger-time now  # Run immediately
        evotraders live --clean  # Clear historical data before starting
    """
    mode_name = "MOCK" if mock else "LIVE"
    console.print(
        Panel.fit(
            f"[bold cyan]EvoTraders {mode_name} Mode[/bold cyan]",
            border_style="cyan",
        ),
    )

    # Check for required API key in live mode
    if not mock:
        env_file = get_project_root() / ".env"
        if not env_file.exists():
            console.print("\n[yellow]Warning: .env file not found[/yellow]")
            console.print("Creating from template...\n")
            template = get_project_root() / "env.template"
            if template.exists():
                shutil.copy(template, env_file)
                console.print("[green].env file created[/green]")
                console.print(
                    "\n[red]Error: Please edit .env and set FINNHUB_API_KEY[/red]",
                )
                console.print(
                    "Get your free API key at: https://finnhub.io/register\n",
                )
                # NOTE(review): execution intentionally continues here — the key
                # may still be supplied via the shell environment. Confirm this
                # is desired; otherwise `raise typer.Exit(1)` after the message.
            else:
                console.print("[red]Error: env.template not found[/red]")
                raise typer.Exit(1)

    # Handle historical data cleanup
    handle_history_cleanup(config_name, auto_clean=clean)

    # Convert local time to NYSE time
    nyse_tz = ZoneInfo("America/New_York")
    local_tz = datetime.now().astimezone().tzinfo
    local_now = datetime.now()
    nyse_now = datetime.now(nyse_tz)

    # Convert trigger time from local to NYSE
    if trigger_time.lower() == "now":
        nyse_trigger_time = "now"
    else:
        # Validate HH:MM up front so a typo produces a friendly error
        # instead of a raw ValueError traceback (matches date validation
        # style used by the `backtest` command).
        try:
            local_trigger = datetime.strptime(trigger_time, "%H:%M")
        except ValueError as exc:
            console.print(
                "[red]✗ Invalid trigger time format. Use HH:MM or 'now'[/red]",
            )
            raise typer.Exit(1) from exc
        local_trigger_dt = local_now.replace(
            hour=local_trigger.hour,
            minute=local_trigger.minute,
            second=0,
            microsecond=0,
        )
        local_trigger_aware = local_trigger_dt.astimezone(local_tz)
        nyse_trigger_dt = local_trigger_aware.astimezone(nyse_tz)
        nyse_trigger_time = nyse_trigger_dt.strftime("%H:%M")

    # Display time info
    console.print("\n[bold]Time Info:[/bold]")
    console.print(f" Local Time: {local_now.strftime('%Y-%m-%d %H:%M:%S')}")
    console.print(
        f" NYSE Time: {nyse_now.strftime('%Y-%m-%d %H:%M:%S %Z')}",
    )
    if nyse_trigger_time == "now":
        console.print(" Trigger: [green]NOW (immediate)[/green]")
    else:
        console.print(
            f" Trigger: {trigger_time} local = {nyse_trigger_time} NYSE",
        )

    # Display configuration
    console.print("\n[bold]Configuration:[/bold]")
    if mock:
        console.print(" Mode: [yellow]MOCK[/yellow] (Simulated prices)")
    else:
        console.print(
            " Mode: [green]LIVE[/green] (Real-time prices via Finnhub)",
        )
    console.print(f" Config: {config_name}")
    console.print(f" Server: {host}:{port}")
    console.print(f" Poll Interval: {poll_interval}s")
    console.print(
        f" Long-term Memory: {'enabled' if enable_memory else 'disabled'}",
    )

    console.print("\nAccess frontend at: [cyan]http://localhost:5173[/cyan]")
    console.print("Press Ctrl+C to stop\n")

    # Change to project root
    project_root = get_project_root()
    os.chdir(project_root)

    # Data update (if not mock mode)
    if not mock:
        run_data_updater(project_root)
    else:
        console.print(
            "\n[dim]Mock mode enabled - skipping data update[/dim]\n",
        )

    # Build command using backend.main
    cmd = [
        sys.executable,
        "-u",
        "-m",
        "backend.main",
        "--mode",
        "live",
        "--config-name",
        config_name,
        "--host",
        host,
        "--port",
        str(port),
        "--poll-interval",
        str(poll_interval),
        "--trigger-time",
        nyse_trigger_time,
    ]

    if mock:
        cmd.append("--mock")
    if enable_memory:
        cmd.append("--enable-memory")

    try:
        subprocess.run(cmd, check=True)
    except KeyboardInterrupt:
        console.print("\n\n[yellow]Live server stopped by user[/yellow]")
    except subprocess.CalledProcessError as e:
        console.print(
            f"\n[red]Live server failed with exit code {e.returncode}[/red]",
        )
        raise typer.Exit(1) from e
|
||||
|
||||
|
||||
@app.command()
def frontend(
    port: int = typer.Option(
        8765,
        "--ws-port",
        "-p",
        help="WebSocket server port to connect to",
    ),
    host_mode: bool = typer.Option(
        False,
        "--host",
        help="Allow external access (default: localhost only)",
    ),
):
    """
    Start the frontend development server.

    Example:
        evotraders frontend
        evotraders frontend --ws-port 8765
        evotraders frontend --ws-port 8765 --host
    """
    console.print(
        Panel.fit(
            "[bold cyan]EvoTraders Frontend[/bold cyan]",
            border_style="cyan",
        ),
    )

    project_root = get_project_root()
    frontend_dir = project_root / "frontend"

    # Check if frontend directory exists
    if not frontend_dir.exists():
        console.print(
            f"\n[red]Error: Frontend directory not found: {frontend_dir}[/red]",
        )
        raise typer.Exit(1)

    # First run: install npm dependencies when node_modules is missing
    node_modules = frontend_dir / "node_modules"
    if not node_modules.exists():
        console.print("\n[yellow]Installing frontend dependencies...[/yellow]")
        try:
            subprocess.run(
                ["npm", "install"],
                cwd=frontend_dir,
                check=True,
            )
            console.print("[green]Dependencies installed[/green]\n")
        except subprocess.CalledProcessError as exc:
            console.print("\n[red]Error: Failed to install dependencies[/red]")
            console.print("Make sure Node.js and npm are installed")
            raise typer.Exit(1) from exc

    # Pass the WebSocket URL to Vite through the environment
    ws_url = f"ws://localhost:{port}"
    env = os.environ.copy()
    env["VITE_WS_URL"] = ws_url

    # Display configuration
    console.print("\n[bold]Configuration:[/bold]")
    console.print(f" WebSocket URL: {ws_url}")
    console.print(" Frontend Port: 5173 (Vite default)")
    if host_mode:
        console.print(" Access: External allowed")
    else:
        console.print(" Access: Localhost only")
    console.print("\nAccess at: [cyan]http://localhost:5173[/cyan]")
    console.print("Press Ctrl+C to stop\n")

    # Choose npm command
    npm_cmd = ["npm", "run", "dev:host" if host_mode else "dev"]

    try:
        subprocess.run(
            npm_cmd,
            cwd=frontend_dir,
            env=env,
            check=True,
        )
    except KeyboardInterrupt:
        console.print("\n\n[yellow]Frontend stopped by user[/yellow]")
    except subprocess.CalledProcessError as e:
        console.print(
            f"\n[red]Frontend failed with exit code {e.returncode}[/red]",
        )
        # Chain the npm failure for better tracebacks (matches `live`)
        raise typer.Exit(1) from e
|
||||
|
||||
|
||||
@app.command()
def version():
    """Print the EvoTraders release version."""
    banner = "\n[bold cyan]EvoTraders[/bold cyan] version [green]0.1.0[/green]\n"
    console.print(banner)
|
||||
|
||||
|
||||
@app.callback()
def main():
    """
    EvoTraders: A self-evolving multi-agent trading system

    Use 'evotraders --help' to see available commands.
    """
    # Intentionally empty: this callback exists so Typer treats the app as a
    # multi-command CLI and shows this docstring as the top-level help text.
|
||||
|
||||
|
||||
# Entry point when executed directly (also used by the console-script shim).
if __name__ == "__main__":
    app()
|
||||
0
backend/config/__init__.py
Normal file
0
backend/config/__init__.py
Normal file
76
backend/config/constants.py
Normal file
76
backend/config/constants.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=C0301
|
||||
|
||||
# Agent configuration for dashboard display.
# Maps backend agent ids to frontend presentation metadata:
#   name         - human-readable display name
#   role         - role label shown next to the agent
#   avatar       - avatar asset key used by the frontend
#   is_team_role - True only for portfolio/risk managers; presumably marks
#                  decision-making roles vs. analysts — confirm with frontend
AGENT_CONFIG = {
    "portfolio_manager": {
        "name": "Portfolio Manager",
        "role": "Portfolio Manager",
        "avatar": "pm",
        "is_team_role": True,
    },
    "risk_manager": {
        "name": "Risk Manager",
        "role": "Risk Manager",
        "avatar": "risk",
        "is_team_role": True,
    },
    "sentiment_analyst": {
        "name": "Sentiment Analyst",
        "role": "Sentiment Analyst",
        "avatar": "sentiment",
        "is_team_role": False,
    },
    "technical_analyst": {
        "name": "Technical Analyst",
        "role": "Technical Analyst",
        "avatar": "technical",
        "is_team_role": False,
    },
    "fundamentals_analyst": {
        "name": "Fundamentals Analyst",
        "role": "Fundamentals Analyst",
        "avatar": "fundamentals",
        "is_team_role": False,
    },
    "valuation_analyst": {
        "name": "Valuation Analyst",
        "role": "Valuation Analyst",
        "avatar": "valuation",
        "is_team_role": False,
    },
}
|
||||
|
||||
# Registry of the available analyst agents.
#   display_name - human-readable name
#   agent_id     - matches the corresponding key in AGENT_CONFIG
#   description  - short capability summary
#   order        - integer 11-14; presumably controls display/run order — confirm
ANALYST_TYPES = {
    "fundamentals_analyst": {
        "display_name": "Fundamentals Analyst",
        "agent_id": "fundamentals_analyst",
        "description": "Uses LLM to intelligently select analysis tools, focuses on financial data and company fundamental analysis",
        "order": 12,
    },
    "technical_analyst": {
        "display_name": "Technical Analyst",
        "agent_id": "technical_analyst",
        "description": "Uses LLM to intelligently select analysis tools, focuses on technical indicators and chart analysis",
        "order": 11,
    },
    "sentiment_analyst": {
        "display_name": "Sentiment Analyst",
        "agent_id": "sentiment_analyst",
        "description": "Uses LLM to intelligently select analysis tools, analyzes market sentiment and news sentiment",
        "order": 13,
    },
    "valuation_analyst": {
        "display_name": "Valuation Analyst",
        "agent_id": "valuation_analyst",
        "description": "Uses LLM to intelligently select analysis tools, focuses on company valuation and value assessment",
        "order": 14,
    },
    # Disabled entry kept for reference; re-enable by uncommenting.
    # "comprehensive_analyst": {
    #     "display_name": "Comprehensive Analyst",
    #     "agent_id": "comprehensive_analyst",
    #     "description": "Uses LLM to intelligently select analysis tools, performs comprehensive analysis",
    #     "order": 15
    # }
}
|
||||
82
backend/config/data_config.py
Normal file
82
backend/config/data_config.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Centralized Data Source Configuration
|
||||
|
||||
Auto-detects and manages data source based on available API keys.
|
||||
Priority: FINNHUB_API_KEY > FINANCIAL_DATASETS_API_KEY
|
||||
"""
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
|
||||
# Allowed data source identifiers; used as the discriminator below.
DataSource = Literal["finnhub", "financial_datasets"]


@dataclass(frozen=True)
class DataSourceConfig:
    """Immutable data source configuration.

    frozen=True makes instances actually immutable, matching the documented
    contract (attribute assignment now raises FrozenInstanceError).
    """

    source: DataSource  # which provider the API key belongs to
    api_key: str  # secret key read from the environment


# Module-level cache for the resolved configuration
_config_cache: Optional[DataSourceConfig] = None
|
||||
|
||||
|
||||
def _resolve_config() -> DataSourceConfig:
    """
    Resolve data source configuration based on available API keys.

    Priority:
        1. FINNHUB_API_KEY (if set)
        2. FINANCIAL_DATASETS_API_KEY (if set)
        3. Raises ValueError if neither is available
    """
    # Walk the providers in priority order; first configured key wins.
    candidates = (
        ("finnhub", "FINNHUB_API_KEY"),
        ("financial_datasets", "FINANCIAL_DATASETS_API_KEY"),
    )
    for source, env_var in candidates:
        key = os.getenv(env_var)
        if key:
            return DataSourceConfig(source=source, api_key=key)

    # No API key available
    raise ValueError(
        "No API key found. Please set either FINNHUB_API_KEY or "
        "FINANCIAL_DATASETS_API_KEY in your .env file.",
    )
|
||||
|
||||
|
||||
def get_config() -> DataSourceConfig:
    """
    Return the resolved data source configuration, caching it on first use.

    Returns:
        DataSourceConfig with source and api_key

    Raises:
        ValueError: If no API key is configured
    """
    global _config_cache
    cached = _config_cache
    if cached is None:
        # First call: resolve from the environment and memoize.
        cached = _resolve_config()
        _config_cache = cached
    return cached
|
||||
|
||||
|
||||
def get_data_source() -> DataSource:
    """Name of the configured data source ("finnhub" or "financial_datasets")."""
    config = get_config()
    return config.source
|
||||
|
||||
|
||||
def get_api_key() -> str:
    """API key belonging to the configured data source."""
    config = get_config()
    return config.api_key
|
||||
|
||||
|
||||
def reset_config() -> None:
    """Drop the cached configuration so the next get_config() re-resolves.

    Useful for testing, where environment variables change between cases.
    """
    global _config_cache
    _config_cache = None
|
||||
36
backend/config/env_config.py
Normal file
36
backend/config/env_config.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Simple environment config helpers
|
||||
"""
|
||||
import os
from typing import Optional
|
||||
|
||||
|
||||
def get_env_list(key: str, default: Optional[list] = None) -> list:
    """Get comma-separated list from env.

    Args:
        key: Environment variable name.
        default: Fallback when the variable is unset or empty; None
            behaves like an empty list.

    Returns:
        List of whitespace-stripped, non-empty items.
    """
    value = os.getenv(key, "")
    if not value:
        # Return a copy so callers can't mutate the caller-supplied default.
        return list(default) if default else []
    return [item.strip() for item in value.split(",") if item.strip()]
|
||||
|
||||
|
||||
def get_env_float(key: str, default: float = 0.0) -> float:
    """Read a float environment variable, falling back to `default`.

    The default is returned when the variable is unset or cannot be
    parsed as a float.
    """
    raw = os.getenv(key)
    try:
        return default if raw is None else float(raw)
    except ValueError:
        return default
|
||||
|
||||
|
||||
def get_env_int(key: str, default: int = 0) -> int:
    """Read an integer environment variable, falling back to `default`.

    The default is returned when the variable is unset or not a valid
    base-10 integer.
    """
    raw = os.getenv(key)
    try:
        return default if raw is None else int(raw)
    except ValueError:
        return default
|
||||
7
backend/core/__init__.py
Normal file
7
backend/core/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Core pipeline and orchestration logic"""
|
||||
|
||||
from .pipeline import TradingPipeline
|
||||
from .state_sync import StateSync
|
||||
|
||||
__all__ = ["TradingPipeline", "StateSync"]
|
||||
1263
backend/core/pipeline.py
Normal file
1263
backend/core/pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
263
backend/core/scheduler.py
Normal file
263
backend/core/scheduler.py
Normal file
@@ -0,0 +1,263 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Scheduler - Market-aware trigger system for trading cycles
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NYSE timezone for US stock trading
|
||||
NYSE_TZ = ZoneInfo("America/New_York")
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
class Scheduler:
    """
    Market-aware scheduler for live trading.
    Uses NYSE timezone and trading calendar.

    Modes:
        daily    - fire the callback once per trading day at trigger_time
                   (NYSE clock). trigger_time == "now" fires once
                   immediately and then stops (one-shot).
        intraday - fire the callback every interval_minutes on trading days.
    """

    def __init__(
        self,
        mode: str = "daily",
        trigger_time: Optional[str] = None,
        interval_minutes: Optional[int] = None,
        config: Optional[dict] = None,
    ):
        # "daily" or "intraday"; anything else is rejected in start()
        self.mode = mode
        self.trigger_time = trigger_time or "09:30"  # NYSE timezone
        # One-shot marker: "now" triggers immediately, then the loop stops
        self.trigger_now = self.trigger_time == "now"
        self.interval_minutes = interval_minutes or 60
        self.config = config or {}

        self.running = False
        self._task: Optional[asyncio.Task] = None

    def _now_nyse(self) -> datetime:
        """Get current time in NYSE timezone"""
        return datetime.now(NYSE_TZ)

    def _is_trading_day(self, date: datetime) -> bool:
        """Check if date is a NYSE trading day"""
        date_str = date.strftime("%Y-%m-%d")
        # valid_days returns an empty index for weekends/holidays
        valid_days = NYSE_CALENDAR.valid_days(
            start_date=date_str,
            end_date=date_str,
        )
        return len(valid_days) > 0

    def _next_trading_day(self, from_date: datetime) -> datetime:
        """Find the next trading day from given date (inclusive of from_date)"""
        check_date = from_date
        for _ in range(10):  # Max 10 days ahead (handles holidays)
            if self._is_trading_day(check_date):
                return check_date
            check_date += timedelta(days=1)
        # Fallback after 10 non-trading days: return the candidate as-is
        return check_date

    async def start(self, callback: Callable):
        """Start scheduler.

        Args:
            callback: async callable invoked as callback(date="YYYY-MM-DD")
        """
        if self.running:
            logger.warning("Scheduler already running")
            return

        # NOTE(review): running is set before mode validation, so an unknown
        # mode raises with running left True — confirm whether callers rely
        # on this or whether validation should happen first.
        self.running = True

        if self.mode == "daily":
            self._task = asyncio.create_task(self._run_daily(callback))
        elif self.mode == "intraday":
            self._task = asyncio.create_task(self._run_intraday(callback))
        else:
            raise ValueError(f"Unknown scheduler mode: {self.mode}")

        logger.info(
            f"Scheduler started: mode={self.mode}, timezone=America/New_York",
        )

    async def _run_daily(self, callback: Callable):
        """Run once per trading day at specified time (NYSE timezone)"""
        first_run = True

        while self.running:
            now = self._now_nyse()

            # Handle "now" trigger - run immediately on first iteration
            if self.trigger_now and first_run:
                first_run = False
                current_date = now.strftime("%Y-%m-%d")
                logger.info(f"Triggering immediately for {current_date}")
                await callback(date=current_date)
                # After immediate run, stop (one-shot mode)
                self.running = False
                break

            target_time = datetime.strptime(self.trigger_time, "%H:%M").time()

            # Calculate next trigger datetime: today if the time is still
            # ahead of us, otherwise tomorrow
            if now.time() < target_time:
                next_run = now.replace(
                    hour=target_time.hour,
                    minute=target_time.minute,
                    second=0,
                    microsecond=0,
                )
            else:
                next_run = (now + timedelta(days=1)).replace(
                    hour=target_time.hour,
                    minute=target_time.minute,
                    second=0,
                    microsecond=0,
                )

            # Skip to next trading day (may move past weekends/holidays),
            # then re-apply the trigger time-of-day
            next_run = self._next_trading_day(next_run)
            next_run = next_run.replace(
                hour=target_time.hour,
                minute=target_time.minute,
                second=0,
                microsecond=0,
            )

            wait_seconds = (next_run - now).total_seconds()
            logger.info(
                f"Next trigger: {next_run.strftime('%Y-%m-%d %H:%M %Z')} "
                f"(in {wait_seconds/3600:.1f} hours)",
            )

            # Sleep until the computed trigger instant, then fire
            await asyncio.sleep(wait_seconds)

            current_date = self._now_nyse().strftime("%Y-%m-%d")
            logger.info(f"Triggering daily cycle for {current_date}")
            await callback(date=current_date)

    async def _run_intraday(self, callback: Callable):
        """Run every N minutes (for future use)"""
        while self.running:
            now = self._now_nyse()
            current_date = now.strftime("%Y-%m-%d")

            # Only fire on trading days; still sleep through off days
            if self._is_trading_day(now):
                logger.info(f"Triggering intraday cycle for {current_date}")
                await callback(date=current_date)

            await asyncio.sleep(self.interval_minutes * 60)

    def stop(self):
        """Stop scheduler and cancel any pending trigger task"""
        self.running = False
        if self._task:
            self._task.cancel()
            self._task = None
        logger.info("Scheduler stopped")
|
||||
|
||||
|
||||
class BacktestScheduler:
    """Backtest Scheduler - Runs through historical trading dates.

    Iterates every trading day in [start_date, end_date] and invokes the
    callback once per date, either asynchronously (start) or
    synchronously (run).
    """

    def __init__(
        self,
        start_date: str,
        end_date: str,
        trading_calendar: Optional[Any] = None,
        delay_between_days: float = 0.1,
    ):
        # Inclusive "YYYY-MM-DD" bounds of the backtest window
        self.start_date = start_date
        self.end_date = end_date
        # Optional pandas_market_calendars calendar name (e.g. "NYSE");
        # when None, weekdays are used as an approximation
        self.trading_calendar = trading_calendar
        # Pause between simulated days (seconds); lets the UI keep up
        self.delay_between_days = delay_between_days

        self.running = False
        self._task: Optional[asyncio.Task] = None
        # Cached result of get_trading_dates()
        self._dates: list = []

    def get_trading_dates(self) -> list:
        """Get list of trading dates ("YYYY-MM-DD") in the backtest period"""
        import pandas as pd

        start = pd.to_datetime(self.start_date)
        end = pd.to_datetime(self.end_date)

        if self.trading_calendar:
            # Use the exchange calendar (skips holidays as well as weekends)
            calendar = mcal.get_calendar(self.trading_calendar)
            trading_dates = calendar.valid_days(
                start_date=self.start_date,
                end_date=self.end_date,
            )
            dates = [d.strftime("%Y-%m-%d") for d in trading_dates]
        else:
            # Fallback: all weekdays (Mon-Fri); holidays are NOT excluded
            all_dates = pd.date_range(start=start, end=end, freq="D")
            dates = [
                d.strftime("%Y-%m-%d") for d in all_dates if d.weekday() < 5
            ]

        self._dates = dates
        return dates

    async def start(self, callback: Callable):
        """Start async backtest scheduler (runs in a background task)"""
        if self.running:
            logger.warning("Backtest scheduler already running")
            return

        self.running = True
        dates = self.get_trading_dates()

        logger.info(
            f"Starting backtest: {self.start_date} to {self.end_date} "
            f"({len(dates)} trading days)",
        )

        self._task = asyncio.create_task(self._run_async(callback, dates))

    async def _run_async(self, callback: Callable, dates: list):
        """Run backtest asynchronously, one callback invocation per date"""
        for i, date in enumerate(dates, 1):
            # stop() flips running to allow cooperative cancellation
            if not self.running:
                break

            logger.info(f"[{i}/{len(dates)}] Processing {date}")
            await callback(date=date)

            if self.delay_between_days > 0:
                await asyncio.sleep(self.delay_between_days)

        logger.info("Backtest complete")
        self.running = False

    def run(self, callback: Callable, **kwargs):
        """Run backtest synchronously through all trading dates.

        Returns:
            List of {"date": ..., "result": ...} dicts, one per trading day.
        """
        dates = self.get_trading_dates()
        results = []

        logger.info(
            f"Starting backtest: {self.start_date} to {self.end_date} "
            f"({len(dates)} trading days)",
        )

        for i, date in enumerate(dates, 1):
            logger.info(f"[{i}/{len(dates)}] Processing {date}")
            result = callback(date=date, **kwargs)
            results.append({"date": date, "result": result})

        logger.info("Backtest complete")
        return results

    def stop(self):
        """Stop backtest scheduler and cancel the background task"""
        self.running = False
        if self._task:
            self._task.cancel()
            self._task = None
        logger.info("Backtest scheduler stopped")

    def get_total_days(self) -> int:
        """Get total number of trading days (computes dates on first call)"""
        if not self._dates:
            self.get_trading_dates()
        return len(self._dates)
|
||||
476
backend/core/state_sync.py
Normal file
476
backend/core/state_sync.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
StateSync - Centralized state synchronization between agents and frontend
|
||||
Handles real-time updates, persistence, and replay support
|
||||
"""
|
||||
# pylint: disable=R0904
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from ..services.storage import StorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateSync:
|
||||
"""
|
||||
Central event dispatcher for agent-frontend synchronization
|
||||
|
||||
Responsibilities:
|
||||
1. Receive events from agents/pipeline
|
||||
2. Persist to storage (feed_history)
|
||||
3. Broadcast to frontend via WebSocket
|
||||
4. Support replay from saved state
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        storage: StorageService,
        broadcast_fn: Optional[Callable] = None,
    ):
        """
        Initialize StateSync

        Args:
            storage: Storage service for persistence
            broadcast_fn: Async broadcast function - async def broadcast(event: dict) # noqa: E501
        """
        self.storage = storage
        self._broadcast_fn = broadcast_fn
        # In-memory server state mirror; flushed to storage via save_state()
        self._state: Dict[str, Any] = {}
        # Master switch: when False, emit() drops events silently
        self._enabled = True
        self._simulation_date: Optional[str] = None  # For backtest timestamps
|
||||
|
||||
    def set_simulation_date(self, date: str):
        """Set current simulation date for backtest-compatible timestamps"""
        # date is "YYYY-MM-DD"; consumed by _get_timestamp_ms() and emit()
        self._simulation_date = date
|
||||
|
||||
    def _get_timestamp_ms(self) -> int:
        """
        Get timestamp in milliseconds.
        Uses simulation date if set (backtest mode), otherwise current time.
        """
        if self._simulation_date:
            # NOTE(review): the original comment claimed "market close (16:00)",
            # but parsing with "%Y-%m-%d" yields midnight (00:00, local time)
            # for the simulation date — confirm which instant backtests need.
            dt = datetime.strptime(
                f"{self._simulation_date}",
                "%Y-%m-%d",
            )
            return int(dt.timestamp() * 1000)
        return int(datetime.now().timestamp() * 1000)
|
||||
|
||||
    def load_state(self):
        """Load server state from storage"""
        self._state = self.storage.load_server_state()
        # Reconcile the loaded state with the current dashboard snapshot
        self.storage.update_server_state_from_dashboard(self._state)
        logger.info(
            f"StateSync loaded: {len(self._state.get('feed_history', []))} feeds",  # noqa: E501
        )
|
||||
|
||||
    def save_state(self):
        """Save current state to storage"""
        # Called after every persisted emit(); writes the full state snapshot
        self.storage.save_server_state(self._state)
|
||||
|
||||
    @property
    def state(self) -> Dict[str, Any]:
        """Get current state"""
        # Returns the live dict (not a copy); caller mutations are visible here
        return self._state
|
||||
|
||||
    def set_broadcast_fn(self, fn: Callable):
        """Set broadcast function (supports late binding)"""
        # fn must be awaitable: emit() does `await self._broadcast_fn(event)`
        self._broadcast_fn = fn
|
||||
|
||||
    def update_state(self, key: str, value: Any):
        """Update a state field"""
        # Not persisted automatically; call save_state() to flush to storage
        self._state[key] = value
|
||||
|
||||
    async def emit(self, event: Dict[str, Any], persist: bool = True):
        """
        Emit an event - persists and broadcasts

        Args:
            event: Event dictionary, must contain "type"
            persist: Whether to persist to feed_history
        """
        # Master kill switch: silently drop events when disabled
        if not self._enabled:
            return

        # Ensure timestamp exists (use simulation date if in backtest mode)
        if "timestamp" not in event:
            if self._simulation_date:
                event["timestamp"] = f"{self._simulation_date}"
            else:
                event["timestamp"] = datetime.now().isoformat()

        # Persist to feed_history
        if persist:
            self.storage.add_feed_message(self._state, event)
            # Flush the whole server state on every persisted event
            self.save_state()

        # Broadcast to frontend (only when a broadcast function is bound)
        if self._broadcast_fn:
            await self._broadcast_fn(event)
|
||||
|
||||
# ========== Agent Events ==========
|
||||
|
||||
async def on_agent_complete(
    self,
    agent_id: str,
    content: str,
    **extra,
):
    """
    Report that an agent finished producing its reply.

    Args:
        agent_id: Agent identifier (e.g., "fundamentals_analyst")
        content: Agent's output content
        **extra: Additional fields merged into the emitted event
    """
    payload = {
        "type": "agent_message",
        "agentId": agent_id,
        "content": content,
        "ts": self._get_timestamp_ms(),
    }
    payload.update(extra)

    await self.emit(payload)
    logger.info(f"Agent complete: {agent_id}")


async def on_memory_retrieved(
    self,
    agent_id: str,
    content: str,
):
    """
    Report that long-term memory was retrieved for an agent.

    Args:
        agent_id: Agent identifier
        content: Retrieved memory content
    """
    payload = {
        "type": "memory",
        "agentId": agent_id,
        "content": content,
        "ts": self._get_timestamp_ms(),
    }

    await self.emit(payload)
    logger.info(f"Memory retrieved for: {agent_id}")
|
||||
|
||||
# ========== Conference Events ==========
|
||||
|
||||
async def on_conference_start(self, title: str, date: str):
    """Announce the start of a conference discussion (persisted)."""
    await self.emit(
        {
            "type": "conference_start",
            "title": title,
            "date": date,
            "ts": self._get_timestamp_ms(),
        },
    )
    logger.info(f"Conference started: {title}")


async def on_conference_cycle_start(self, cycle: int, total_cycles: int):
    """Announce the start of a conference cycle (transient, not persisted)."""
    event = {
        "type": "conference_cycle_start",
        "cycle": cycle,
        "totalCycles": total_cycles,
    }
    await self.emit(event, persist=False)


async def on_conference_message(self, agent_id: str, content: str):
    """Broadcast a single agent utterance during the conference (persisted)."""
    event = {
        "type": "conference_message",
        "agentId": agent_id,
        "content": content,
        "ts": self._get_timestamp_ms(),
    }
    await self.emit(event)


async def on_conference_cycle_end(self, cycle: int):
    """Announce the end of a conference cycle (transient, not persisted)."""
    await self.emit(
        {"type": "conference_cycle_end", "cycle": cycle},
        persist=False,
    )


async def on_conference_end(self):
    """Announce that the conference discussion has ended (persisted)."""
    await self.emit(
        {"type": "conference_end", "ts": self._get_timestamp_ms()},
    )
    logger.info("Conference ended")
|
||||
|
||||
# ========== Cycle Events ==========
|
||||
|
||||
async def on_cycle_start(self, date: str):
    """Record the start of a trading day and announce it to clients."""
    self._state["current_date"] = date
    self._state["status"] = "running"
    # Keep emitted timestamps consistent with the simulated date in backtests.
    self.set_simulation_date(date)

    await self.emit(
        {
            "type": "day_start",
            "date": date,
            "progress": self._calculate_progress(),
        },
    )
|
||||
|
||||
async def on_cycle_end(self, date: str, portfolio_summary: Dict = None):
    """
    Called at the end of a trading cycle.

    Increments the completed-day counter, broadcasts a team_summary event
    (when a portfolio summary is provided), mirrors the summary into the
    persisted "portfolio" state, then emits day_complete and saves state.

    Args:
        date: Trading date that just finished (YYYY-MM-DD).
        portfolio_summary: Optional snapshot of portfolio metrics; may
            contain the live return series listed in ``return_keys``.
    """
    # Optional live return series. They are only forwarded when truthy,
    # which intentionally skips missing/empty values (note: a literal 0
    # is also skipped — preserved from the original behavior).
    return_keys = (
        "equity_return",
        "baseline_return",
        "baseline_vw_return",
        "momentum_return",
    )

    # Update completed count
    self._state["trading_days_completed"] = (
        self._state.get("trading_days_completed", 0) + 1
    )

    # Broadcast team_summary if available
    if portfolio_summary:
        summary_data = {
            "type": "team_summary",
            # Fall back to alternate key spellings used by older producers.
            "balance": portfolio_summary.get(
                "balance",
                portfolio_summary.get("total_value", 0),
            ),
            "pnlPct": portfolio_summary.get(
                "pnlPct",
                portfolio_summary.get("pnl_percent", 0),
            ),
            "equity": portfolio_summary.get("equity", []),
            "baseline": portfolio_summary.get("baseline", []),
            "baseline_vw": portfolio_summary.get("baseline_vw", []),
            "momentum": portfolio_summary.get("momentum", []),
        }

        # Include live returns if available.
        for key in return_keys:
            if portfolio_summary.get(key):
                summary_data[key] = portfolio_summary[key]

        # Mirror the summary into persisted portfolio state.
        portfolio = self._state.setdefault("portfolio", {})
        portfolio.update(
            {
                "total_value": summary_data["balance"],
                "pnl_percent": summary_data["pnlPct"],
                "equity": summary_data["equity"],
                "baseline": summary_data["baseline"],
                "baseline_vw": summary_data["baseline_vw"],
                "momentum": summary_data["momentum"],
            },
        )
        for key in return_keys:
            if summary_data.get(key):
                portfolio[key] = summary_data[key]

        await self.emit(summary_data, persist=True)

    await self.emit(
        {
            "type": "day_complete",
            "date": date,
            "progress": self._calculate_progress(),
        },
    )

    self.save_state()
|
||||
|
||||
# ========== Portfolio Events ==========
|
||||
|
||||
async def on_holdings_update(self, holdings: List[Dict]):
    """Push the latest holdings snapshot to clients.

    Not persisted: holdings change too frequently for feed_history.
    """
    self._state["holdings"] = holdings
    await self.emit(
        {"type": "team_holdings", "data": holdings},
        persist=False,
    )


async def on_trades_executed(self, trades: List[Dict]):
    """Prepend newly executed trades to state and broadcast them."""
    # Newest trades go first in the stored list.
    self._state["trades"] = trades + self._state.get("trades", [])

    await self.emit(
        {"type": "team_trades", "mode": "incremental", "data": trades},
        persist=False,
    )


async def on_stats_update(self, stats: Dict):
    """Store and broadcast updated portfolio statistics."""
    self._state["stats"] = stats
    await self.emit(
        {"type": "team_stats", "data": stats},
        persist=False,
    )


async def on_leaderboard_update(self, leaderboard: List[Dict]):
    """Store and broadcast the updated leaderboard."""
    self._state["leaderboard"] = leaderboard
    await self.emit(
        {"type": "team_leaderboard", "data": leaderboard},
        persist=False,
    )


# ========== System Events ==========


async def on_system_message(self, content: str):
    """Emit a persisted system-level message."""
    await self.emit({"type": "system", "content": content})
|
||||
|
||||
# ========== Replay Support ==========
|
||||
|
||||
async def replay_feed_history(self, delay_ms: int = 100):
    """
    Replay persisted events through the broadcast function.

    Useful for: frontend reconnection or restoring from saved state.

    Args:
        delay_ms: Pause between replayed events, in milliseconds.
    """
    history = self._state.get("feed_history", [])

    # feed_history is stored newest-first; replay in chronological order.
    for event in reversed(history):
        if self._broadcast_fn:
            await self._broadcast_fn(event)
        await asyncio.sleep(delay_ms / 1000)

    logger.info(f"Replayed {len(history)} events")


def get_initial_state_payload(
    self,
    include_dashboard: bool = True,
) -> Dict[str, Any]:
    """
    Build the initial state payload for a newly connected client.

    Args:
        include_dashboard: Whether to also load dashboard files from storage.

    Returns:
        Dictionary suitable for sending to the frontend.
    """
    state = self._state
    payload = {
        "server_mode": state.get("server_mode", "live"),
        "is_mock_mode": state.get("is_mock_mode", False),
        "is_backtest": state.get("is_backtest", False),
        "feed_history": state.get("feed_history", []),
        "current_date": state.get("current_date"),
        "trading_days_total": state.get("trading_days_total", 0),
        "trading_days_completed": state.get("trading_days_completed", 0),
        "holdings": state.get("holdings", []),
        "trades": state.get("trades", []),
        "stats": state.get("stats", {}),
        "leaderboard": state.get("leaderboard", []),
        "portfolio": state.get("portfolio", {}),
        "realtime_prices": state.get("realtime_prices", {}),
    }

    if include_dashboard:
        dashboard_files = (
            "summary",
            "holdings",
            "stats",
            "trades",
            "leaderboard",
        )
        payload["dashboard"] = {
            name: self.storage.load_file(name) for name in dashboard_files
        }

    return payload


def _calculate_progress(self) -> float:
    """Fraction of backtest trading days completed (0.0 when total unknown)."""
    total = self._state.get("trading_days_total", 0)
    if total <= 0:
        return 0.0
    return self._state.get("trading_days_completed", 0) / total


def set_backtest_dates(self, dates: List[str]):
    """Initialize progress tracking with the full list of backtest dates."""
    self._state["trading_days_total"] = len(dates)
    self._state["trading_days_completed"] = 0
|
||||
BIN
backend/data/__MACOSX/ret_data/._AAPL.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._AAPL.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._AMZN.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._AMZN.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._GOOGL.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._GOOGL.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._META.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._META.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._MSFT.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._MSFT.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._NVDA.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._NVDA.csv
Normal file
Binary file not shown.
|
BIN
backend/data/__MACOSX/ret_data/._TSLA.csv
Normal file
BIN
backend/data/__MACOSX/ret_data/._TSLA.csv
Normal file
Binary file not shown.
|
6
backend/data/__init__.py
Normal file
6
backend/data/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from backend.data.historical_price_manager import HistoricalPriceManager
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
__all__ = ["MockPriceManager", "PollingPriceManager", "HistoricalPriceManager"]
|
||||
107
backend/data/cache.py
Normal file
107
backend/data/cache.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing_extensions import Any
|
||||
|
||||
|
||||
class Cache:
|
||||
"""In-memory cache for API responses."""
|
||||
|
||||
def __init__(self):
|
||||
self._prices_cache = {}
|
||||
self._financial_metrics_cache = {}
|
||||
self._line_items_cache = {}
|
||||
self._insider_trades_cache = {}
|
||||
self._company_news_cache = {}
|
||||
|
||||
def _merge_data(
|
||||
self,
|
||||
existing: list[dict] | None,
|
||||
new_data: list[dict],
|
||||
key_field: str,
|
||||
) -> list[dict]:
|
||||
"""Merge existing and new data"""
|
||||
if not existing:
|
||||
return new_data
|
||||
|
||||
# Create a set of existing keys for O(1) lookup
|
||||
existing_keys = {item[key_field] for item in existing}
|
||||
|
||||
# Only add items that don't exist yet
|
||||
merged = existing.copy()
|
||||
merged.extend(
|
||||
[
|
||||
item
|
||||
for item in new_data
|
||||
if item[key_field] not in existing_keys
|
||||
],
|
||||
)
|
||||
return merged
|
||||
|
||||
def get_prices(self, ticker: str) -> list[dict[str, Any]] | None:
|
||||
"""Get cached price data if available."""
|
||||
return self._prices_cache.get(ticker)
|
||||
|
||||
def set_prices(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new price data to cache."""
|
||||
self._prices_cache[ticker] = self._merge_data(
|
||||
self._prices_cache.get(ticker),
|
||||
data,
|
||||
key_field="time",
|
||||
)
|
||||
|
||||
def get_financial_metrics(self, ticker: str) -> list[dict[str, Any]]:
|
||||
"""Get cached financial metrics if available."""
|
||||
return self._financial_metrics_cache.get(ticker)
|
||||
|
||||
def set_financial_metrics(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new financial metrics to cache."""
|
||||
self._financial_metrics_cache[ticker] = self._merge_data(
|
||||
self._financial_metrics_cache.get(ticker),
|
||||
data,
|
||||
key_field="report_period",
|
||||
)
|
||||
|
||||
def get_line_items(self, ticker: str) -> list[dict[str, Any]] | None:
|
||||
"""Get cached line items if available."""
|
||||
return self._line_items_cache.get(ticker)
|
||||
|
||||
def set_line_items(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new line items to cache."""
|
||||
self._line_items_cache[ticker] = self._merge_data(
|
||||
self._line_items_cache.get(ticker),
|
||||
data,
|
||||
key_field="report_period",
|
||||
)
|
||||
|
||||
def get_insider_trades(self, ticker: str) -> list[dict[str, Any]] | None:
|
||||
"""Get cached insider trades if available."""
|
||||
return self._insider_trades_cache.get(ticker)
|
||||
|
||||
def set_insider_trades(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new insider trades to cache."""
|
||||
self._insider_trades_cache[ticker] = self._merge_data(
|
||||
self._insider_trades_cache.get(ticker),
|
||||
data,
|
||||
key_field="filing_date",
|
||||
) # Could also use transaction_date if preferred
|
||||
|
||||
def get_company_news(self, ticker: str) -> list[dict[str, Any]] | None:
|
||||
"""Get cached company news if available."""
|
||||
return self._company_news_cache.get(ticker)
|
||||
|
||||
def set_company_news(self, ticker: str, data: list[dict[str, Any]]):
|
||||
"""Append new company news to cache."""
|
||||
self._company_news_cache[ticker] = self._merge_data(
|
||||
self._company_news_cache.get(ticker),
|
||||
data,
|
||||
key_field="date",
|
||||
)
|
||||
|
||||
|
||||
# Global cache instance
|
||||
# Module-level singleton shared by every importer of this module.
_cache = Cache()


def get_cache() -> Cache:
    """Get the global cache instance."""
    return _cache
|
||||
233
backend/data/historical_price_manager.py
Normal file
233
backend/data/historical_price_manager.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Historical Price Manager for backtest mode
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Path to local CSV data directory
|
||||
_DATA_DIR = Path(__file__).parent / "ret_data"
|
||||
|
||||
|
||||
class HistoricalPriceManager:
    """Provides historical prices for backtest mode.

    Price history is loaded from local CSV files (one per symbol, with at
    least ``time``, ``open`` and ``close`` columns). ``set_date`` selects
    each symbol's row for a trading date, after which open/close prices can
    be read directly or pushed to registered callbacks.
    """

    def __init__(self):
        self.subscribed_symbols = []   # symbols tracked by the backtest
        self.price_callbacks = []      # receivers of price-data dicts
        self._price_cache = {}         # symbol -> date-indexed DataFrame
        self._current_date = None      # active trading date (YYYY-MM-DD)
        self.latest_prices = {}        # last selected/emitted price per symbol
        self.open_prices = {}
        self.close_prices = {}
        self.running = False

    def subscribe(
        self,
        symbols: List[str],
    ):
        """Subscribe to symbols; already-subscribed ones are skipped."""
        for candidate in symbols:
            if candidate in self.subscribed_symbols:
                continue
            self.subscribed_symbols.append(candidate)

    def unsubscribe(self, symbols: List[str]):
        """Unsubscribe from symbols and drop their cached data."""
        for candidate in symbols:
            if candidate not in self.subscribed_symbols:
                continue
            self.subscribed_symbols.remove(candidate)
            self._price_cache.pop(candidate, None)

    def add_price_callback(self, callback: Callable):
        """Register a callback invoked with every emitted price-data dict."""
        self.price_callbacks.append(callback)

    def _load_from_csv(self, symbol: str) -> Optional[pd.DataFrame]:
        """Load one symbol's price history from its local CSV file.

        Returns None when the file is missing, empty, lacks a ``time``
        column, or fails to parse.
        """
        csv_path = _DATA_DIR / f"{symbol}.csv"
        if not csv_path.exists():
            return None

        try:
            frame = pd.read_csv(csv_path)
            if frame.empty or "time" not in frame.columns:
                return None

            frame["Date"] = pd.to_datetime(frame["time"])
            frame = frame.set_index("Date").sort_index()
            return frame
        except Exception as e:
            logger.warning(f"Failed to load CSV for {symbol}: {e}")
            return None

    def preload_data(self, start_date: str, end_date: str):
        """Preload historical data from local CSV files for all subscribed
        symbols (symbols already in the cache are skipped)."""
        logger.info(f"Preloading data: {start_date} to {end_date}")

        for symbol in self.subscribed_symbols:
            if symbol in self._price_cache:
                continue

            df = self._load_from_csv(symbol)
            if df is None or df.empty:
                logger.warning(f"No CSV data for {symbol}")
                continue

            self._price_cache[symbol] = df
            logger.info(f"Loaded {symbol} from CSV: {len(df)} records")

    def set_date(self, date: str):
        """Select the active trading date and refresh per-symbol prices."""
        self._current_date = date
        date_dt = pd.Timestamp(date)

        for symbol in self.subscribed_symbols:
            df = self._price_cache.get(symbol)
            if df is None or df.empty:
                # Keep previously selected prices when no data is available.
                logger.warning(f"No cached data for {symbol} on {date}")
                continue

            # Use the exact date when present, otherwise the closest
            # earlier trading day.
            if date_dt in df.index:
                row = df.loc[date_dt]
            else:
                valid_dates = df.index[df.index <= date_dt]
                if len(valid_dates) == 0:
                    logger.warning(f"No data for {symbol} on or before {date}")
                    continue
                row = df.loc[valid_dates[-1]]

            open_price = float(row["open"])
            close_price = float(row["close"])

            self.open_prices[symbol] = open_price
            self.close_prices[symbol] = close_price
            # The day starts at the open.
            self.latest_prices[symbol] = open_price

            logger.debug(
                f"{symbol} @ {date}: open={open_price:.2f}, close={close_price:.2f}",  # noqa: E501
            )

    def _day_start_ms(self) -> int:
        """Millisecond timestamp for midnight of the current trading date."""
        day = datetime.strptime(self._current_date, "%Y-%m-%d")
        return int(day.timestamp() * 1000)

    def emit_open_prices(self):
        """Push open prices for the current date to all callbacks."""
        if not self._current_date:
            return

        timestamp = self._day_start_ms()

        for symbol in self.subscribed_symbols:
            price = self.open_prices.get(symbol)
            if price is None or price <= 0:
                logger.warning(f"Invalid open price for {symbol}: {price}")
                continue

            self.latest_prices[symbol] = price
            self._emit_price(symbol, price, timestamp)

    def emit_close_prices(self):
        """Push close prices (stamped 6.5h after the open stamp) to callbacks."""
        if not self._current_date:
            return

        # 6.5 hours in ms — one trading session after the day-start stamp.
        timestamp = self._day_start_ms() + 23400000

        for symbol in self.subscribed_symbols:
            price = self.close_prices.get(symbol)
            if price is None or price <= 0:
                logger.warning(f"Invalid close price for {symbol}: {price}")
                continue

            self.latest_prices[symbol] = price
            self._emit_price(symbol, price, timestamp)

    def _emit_price(self, symbol: str, price: float, timestamp: int):
        """Build a price-data dict and fan it out to every callback."""
        open_price = self.open_prices.get(symbol, price)
        close_price = self.close_prices.get(symbol, price)
        if open_price > 0:
            ret = ((price - open_price) / open_price) * 100
        else:
            ret = 0

        price_data = {
            "symbol": symbol,
            "price": price,
            "timestamp": timestamp,
            "open": open_price,
            "close": close_price,
            "high": max(open_price, close_price),
            "low": min(open_price, close_price),
            "ret": ret,
        }

        for callback in self.price_callbacks:
            try:
                callback(price_data)
            except Exception as e:
                logger.error(f"Callback error for {symbol}: {e}")

    def get_price_for_date(
        self,
        symbol: str,
        date: str,
        price_type: str = "close",
    ) -> Optional[float]:
        """Price (column ``price_type``) on ``date`` or the closest earlier
        date; falls back to the latest known price when no data exists."""
        df = self._price_cache.get(symbol)
        if df is None or df.empty:
            return self.latest_prices.get(symbol)

        date_dt = pd.Timestamp(date)
        if date_dt in df.index:
            return float(df.loc[date_dt, price_type])

        earlier = df.index[df.index <= date_dt]
        if len(earlier) == 0:
            return self.latest_prices.get(symbol)
        return float(df.loc[earlier[-1], price_type])

    def start(self):
        """Mark the manager as running (no background work in backtest mode)."""
        self.running = True

    def stop(self):
        """Mark the manager as stopped."""
        self.running = False

    def get_latest_price(self, symbol: str) -> Optional[float]:
        """Most recently selected/emitted price, or None."""
        return self.latest_prices.get(symbol)

    def get_all_latest_prices(self) -> Dict[str, float]:
        """Copy of the latest-price map."""
        return self.latest_prices.copy()

    def get_open_price(self, symbol: str) -> Optional[float]:
        """Open price for the current date; falls back to the latest price."""
        price = self.open_prices.get(symbol)
        if price is None or price <= 0:
            return self.latest_prices.get(symbol)
        return price

    def get_close_price(self, symbol: str) -> Optional[float]:
        """Close price for the current date; falls back to the latest price."""
        price = self.close_prices.get(symbol)
        if price is None or price <= 0:
            return self.latest_prices.get(symbol)
        return price

    def reset_open_prices(self):
        """Intentional no-op: prices are kept across days for continuity."""
|
||||
241
backend/data/mock_price_manager.py
Normal file
241
backend/data/mock_price_manager.py
Normal file
@@ -0,0 +1,241 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Mock Price Manager - For testing during non-trading hours
|
||||
Generates virtual real-time price data
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MockPriceManager:
    """Mock Price Manager - generates virtual real-time prices for testing
    during non-trading hours, via a daemon polling thread that random-walks
    each subscribed symbol's price."""

    def __init__(self, poll_interval: Optional[int] = 10, volatility: Optional[float] = 0.5):
        """
        Args:
            poll_interval: Price update interval in seconds
            volatility: Price volatility percentage

        NOTE(review): the env-var fallbacks below only fire when the caller
        passes None explicitly, because the defaults are non-None; also the
        MOCK_POLL_INTERVAL fallback default ("5") differs from the signature
        default (10) — confirm which is intended.
        """
        if poll_interval is None:
            poll_interval = int(os.getenv("MOCK_POLL_INTERVAL", "5"))
        if volatility is None:
            volatility = float(os.getenv("MOCK_VOLATILITY", "0.5"))

        self.poll_interval = poll_interval
        self.volatility = volatility

        # Subscription and price bookkeeping.
        self.subscribed_symbols: List[str] = []
        self.base_prices: Dict[str, float] = {}
        self.open_prices: Dict[str, float] = {}
        self.latest_prices: Dict[str, float] = {}
        self.price_callbacks: List[Callable] = []

        # Background polling thread state.
        self.running = False
        self._thread: Optional[threading.Thread] = None

        # Realistic-looking starting prices for common tickers; symbols not
        # listed here get a random base in [50, 500) on subscribe.
        self.default_base_prices = {
            "AAPL": 237.50,
            "MSFT": 425.30,
            "GOOGL": 161.50,
            "AMZN": 218.45,
            "NVDA": 950.00,
            "META": 573.22,
            "TSLA": 342.15,
            "AMD": 168.90,
            "NFLX": 688.25,
            "INTC": 42.18,
            "COIN": 285.50,
            "PLTR": 45.80,
            "BABA": 88.30,
            "DIS": 112.50,
            "BKNG": 4850.00,
        }

        logger.info(
            f"MockPriceManager initialized (interval: {self.poll_interval}s, "
            f"volatility: {self.volatility}%)",
        )

    def subscribe(
        self,
        symbols: List[str],
        base_prices: Optional[Dict[str, float]] = None,
    ):
        """Subscribe to stock symbols.

        Base price resolution order: caller-provided base_prices, then the
        built-in defaults, then a random value in [50, 500).
        """
        for symbol in symbols:
            if symbol not in self.subscribed_symbols:
                self.subscribed_symbols.append(symbol)

                if base_prices and symbol in base_prices:
                    base_price = base_prices[symbol]
                elif symbol in self.default_base_prices:
                    base_price = self.default_base_prices[symbol]
                else:
                    base_price = random.uniform(50, 500)

                self.base_prices[symbol] = base_price
                self.open_prices[symbol] = base_price
                self.latest_prices[symbol] = base_price

                logger.info(
                    f"Subscribed to mock price: {symbol} (base: ${base_price:.2f})",  # noqa: E501
                )

    def unsubscribe(self, symbols: List[str]):
        """Unsubscribe from symbols and drop their tracked prices."""
        for symbol in symbols:
            if symbol in self.subscribed_symbols:
                self.subscribed_symbols.remove(symbol)
                self.base_prices.pop(symbol, None)
                self.open_prices.pop(symbol, None)
                self.latest_prices.pop(symbol, None)
                logger.info(f"Unsubscribed: {symbol}")

    def add_price_callback(self, callback: Callable):
        """Register a callback invoked with every generated price-data dict."""
        self.price_callbacks.append(callback)

    def _generate_price_update(self, symbol: str) -> float:
        """Generate the next price via a bounded random walk.

        Small uniform step each tick, a 10% chance of a larger move, and a
        hard clamp to ±10% of the day's open.
        """
        current_price = self.latest_prices.get(
            symbol,
            self.base_prices[symbol],
        )

        change_percent = random.uniform(-self.volatility, self.volatility)
        new_price = current_price * (1 + change_percent / 100)

        # 10% chance of larger movement
        if random.random() < 0.1:
            trend_factor = random.uniform(-2, 2)
            new_price = new_price * (1 + trend_factor / 100)

        # Limit intraday movement to +/-10%
        open_price = self.open_prices[symbol]
        max_price = open_price * 1.10
        min_price = open_price * 0.90
        new_price = max(min_price, min(max_price, new_price))

        return new_price

    def _update_prices(self):
        """Generate one tick for every subscribed symbol and notify callbacks."""
        timestamp = int(time.time() * 1000)

        for symbol in self.subscribed_symbols:
            try:
                new_price = self._generate_price_update(symbol)
                self.latest_prices[symbol] = new_price

                # Intraday return relative to the day's open, in percent.
                open_price = self.open_prices[symbol]
                ret = ((new_price - open_price) / open_price) * 100

                price_data = {
                    "symbol": symbol,
                    "price": new_price,
                    "timestamp": timestamp,
                    "volume": random.randint(1000000, 10000000),
                    "open": open_price,
                    "high": max(new_price, open_price),
                    "low": min(new_price, open_price),
                    "previous_close": open_price,
                    "ret": ret,
                }

                # A failing callback must not break the other callbacks.
                for callback in self.price_callbacks:
                    try:
                        callback(price_data)
                    except Exception as e:
                        logger.error(
                            f"Mock price callback error ({symbol}): {e}",
                        )

                logger.debug(
                    f"Mock {symbol}: ${new_price:.2f} [ret: {ret:+.2f}%]",
                )

            except Exception as e:
                logger.error(f"Failed to generate mock price ({symbol}): {e}")

    def _polling_loop(self):
        """Main polling loop; compensates sleep for generation time and
        backs off 5s after an unexpected error."""
        logger.info(
            f"Mock price generation started (interval: {self.poll_interval}s)",
        )

        while self.running:
            try:
                start_time = time.time()
                self._update_prices()

                elapsed = time.time() - start_time
                sleep_time = max(0, self.poll_interval - elapsed)
                if sleep_time > 0:
                    time.sleep(sleep_time)

            except Exception as e:
                logger.error(f"Mock polling loop error: {e}")
                time.sleep(5)

    def start(self):
        """Start mock price generation on a daemon thread.

        No-op (with a warning) when already running or nothing is subscribed.
        """
        if self.running:
            logger.warning("Mock price manager already running")
            return

        if not self.subscribed_symbols:
            logger.warning("No stocks subscribed")
            return

        self.running = True
        self._thread = threading.Thread(target=self._polling_loop, daemon=True)
        self._thread.start()

        logger.info(
            f"Mock price manager started: {', '.join(self.subscribed_symbols)}",  # noqa: E501
        )

    def stop(self):
        """Stop mock price generation; waits up to 5s for the thread to exit."""
        self.running = False
        if self._thread:
            self._thread.join(timeout=5)
        logger.info("Mock price manager stopped")

    def get_latest_price(self, symbol: str) -> Optional[float]:
        """Get latest price for symbol (None when not subscribed)."""
        return self.latest_prices.get(symbol)

    def get_all_latest_prices(self) -> Dict[str, float]:
        """Get a copy of all latest prices."""
        return self.latest_prices.copy()

    def get_open_price(self, symbol: str) -> Optional[float]:
        """Get open price for symbol (None when not subscribed)."""
        return self.open_prices.get(symbol)

    def reset_open_prices(self):
        """Reset open prices for a new trading day by applying a random
        overnight gap of ±1% to each symbol's last close."""
        for symbol in self.subscribed_symbols:
            last_close = self.latest_prices[symbol]
            gap_percent = random.uniform(-1, 1)
            new_open = last_close * (1 + gap_percent / 100)
            self.open_prices[symbol] = new_open
            self.latest_prices[symbol] = new_open
        logger.info("Open prices reset")

    def set_base_price(self, symbol: str, price: float):
        """Manually set base/open/latest price for a subscribed symbol
        (testing helper); warns when the symbol is not subscribed."""
        if symbol in self.subscribed_symbols:
            self.base_prices[symbol] = price
            self.open_prices[symbol] = price
            self.latest_prices[symbol] = price
            logger.info(f"{symbol} base price set to: ${price:.2f}")
        else:
            logger.warning(f"{symbol} not subscribed")
|
||||
175
backend/data/polling_price_manager.py
Normal file
175
backend/data/polling_price_manager.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Polling-based Price Manager - Uses Finnhub REST API
|
||||
Supports real-time price fetching via polling
|
||||
"""
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import finnhub
|
||||
|
||||
logger = logging.getLogger(__name__)


class PollingPriceManager:
    """Polling-based price manager using Finnhub Quote API.

    Fetches quotes for all subscribed symbols on a background thread at a
    fixed interval and fans each update out to registered callbacks.
    """

    def __init__(self, api_key: str, poll_interval: int = 30):
        """
        Args:
            api_key: Finnhub API Key
            poll_interval: Polling interval in seconds (default 30s)
        """
        self.api_key = api_key
        self.poll_interval = poll_interval
        self.finnhub_client = finnhub.Client(api_key=api_key)

        self.subscribed_symbols: List[str] = []
        # Last fetched trade price per symbol.
        self.latest_prices: Dict[str, float] = {}
        # First valid open seen per symbol for the current session.
        self.open_prices: Dict[str, float] = {}
        self.price_callbacks: List[Callable] = []

        self.running = False
        self._thread: Optional[threading.Thread] = None
        # Event lets stop() interrupt the inter-poll sleep immediately,
        # instead of waiting out the remainder of poll_interval.
        self._stop_event = threading.Event()

        logger.info(
            f"PollingPriceManager initialized (interval: {poll_interval}s)",
        )

    def subscribe(self, symbols: List[str]):
        """Subscribe to stock symbols (duplicates are ignored)."""
        for symbol in symbols:
            if symbol not in self.subscribed_symbols:
                self.subscribed_symbols.append(symbol)
                logger.info(f"Subscribed to: {symbol}")

    def unsubscribe(self, symbols: List[str]):
        """Unsubscribe from symbols (unknown symbols are ignored)."""
        for symbol in symbols:
            if symbol in self.subscribed_symbols:
                self.subscribed_symbols.remove(symbol)
                logger.info(f"Unsubscribed: {symbol}")

    def add_price_callback(self, callback: Callable):
        """Register a callback invoked with a price-data dict on every update."""
        self.price_callbacks.append(callback)

    def _fetch_prices(self):
        """Fetch latest quotes for all subscribed stocks and notify callbacks.

        Per-symbol and per-callback failures are logged and do not abort the
        remaining symbols.
        """
        for symbol in self.subscribed_symbols:
            try:
                quote_data = self.finnhub_client.quote(symbol)

                current_price = quote_data.get("c")
                open_price = quote_data.get("o")
                timestamp = quote_data.get("t", int(time.time()))

                if not current_price or current_price <= 0:
                    logger.warning(f"{symbol}: Invalid price data")
                    continue

                # Store open price on first fetch only (kept until reset).
                if (
                    symbol not in self.open_prices
                    and open_price
                    and open_price > 0
                ):
                    self.open_prices[symbol] = open_price
                    logger.info(f"{symbol} open price: ${open_price:.2f}")

                stored_open = self.open_prices.get(symbol, open_price)
                # Intraday return (%) relative to the stored session open.
                ret = (
                    ((current_price - stored_open) / stored_open) * 100
                    if stored_open > 0
                    else 0
                )

                self.latest_prices[symbol] = current_price

                price_data = {
                    "symbol": symbol,
                    "price": current_price,
                    # Finnhub timestamps are seconds; callbacks get ms.
                    "timestamp": timestamp * 1000,
                    "open": stored_open,
                    "high": quote_data.get("h"),
                    "low": quote_data.get("l"),
                    "previous_close": quote_data.get("pc"),
                    "ret": ret,
                    "change": quote_data.get("d"),
                    "change_percent": quote_data.get("dp"),
                }

                for callback in self.price_callbacks:
                    try:
                        callback(price_data)
                    except Exception as e:
                        logger.error(f"Price callback error ({symbol}): {e}")

                logger.debug(
                    f"{symbol}: ${current_price:.2f} [ret: {ret:+.2f}%]",
                )

            except Exception as e:
                logger.error(f"Failed to fetch {symbol} price: {e}")

    def _polling_loop(self):
        """Main polling loop; runs until stop() clears the running flag."""
        logger.info(f"Price polling started (interval: {self.poll_interval}s)")

        while self.running:
            try:
                start_time = time.time()
                self._fetch_prices()

                # Sleep only for the remainder of the interval; the wait is
                # interruptible so stop() takes effect immediately (the old
                # time.sleep could delay shutdown by up to poll_interval).
                elapsed = time.time() - start_time
                sleep_time = max(0, self.poll_interval - elapsed)
                if sleep_time > 0:
                    self._stop_event.wait(sleep_time)

            except Exception as e:
                logger.error(f"Polling loop error: {e}")
                self._stop_event.wait(5)

    def start(self):
        """Start price polling on a daemon thread (no-op if already running
        or if nothing is subscribed)."""
        if self.running:
            logger.warning("Price polling already running")
            return

        if not self.subscribed_symbols:
            logger.warning("No stocks subscribed")
            return

        self.running = True
        self._stop_event.clear()
        self._thread = threading.Thread(target=self._polling_loop, daemon=True)
        self._thread.start()

        logger.info(
            f"Price polling started: {', '.join(self.subscribed_symbols)}",
        )

    def stop(self):
        """Stop price polling and join the worker thread (5s timeout)."""
        self.running = False
        # Wake the loop out of its inter-poll wait right away.
        self._stop_event.set()
        if self._thread:
            self._thread.join(timeout=5)
        logger.info("Price polling stopped")

    def get_latest_price(self, symbol: str) -> Optional[float]:
        """Get latest price for symbol (None if never fetched)."""
        return self.latest_prices.get(symbol)

    def get_all_latest_prices(self) -> Dict[str, float]:
        """Get a shallow-copy snapshot of all latest prices."""
        return self.latest_prices.copy()

    def get_open_price(self, symbol: str) -> Optional[float]:
        """Get stored session open price for symbol (None if unknown)."""
        return self.open_prices.get(symbol)

    def reset_open_prices(self):
        """Reset open prices for a new trading day; next fetch re-populates them."""
        self.open_prices.clear()
        logger.info("Open prices reset")
|
||||
387
backend/data/ret_data_updater.py
Normal file
387
backend/data/ret_data_updater.py
Normal file
@@ -0,0 +1,387 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Automatic Incremental Historical Data Update Module
|
||||
|
||||
Features:
|
||||
1. Fetch stock historical data from configured API (Finnhub or Financial Datasets)
|
||||
2. Incrementally update CSV files in ret_data directory
|
||||
3. Automatically detect last update date, only download new data
|
||||
4. Calculate returns (ret)
|
||||
5. Support batch updates for multiple stocks
|
||||
"""
|
||||
|
||||
# flake8: noqa: E501
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import exchange_calendars as xcals
|
||||
import pandas as pd
|
||||
import pandas_market_calendars as mcal
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.config.data_config import (
|
||||
get_config,
|
||||
)
|
||||
from backend.tools.data_tools import get_prices, prices_to_df
|
||||
|
||||
# Add project root directory to path
|
||||
BASE_DIR = Path(__file__).resolve().parents[2]
|
||||
if str(BASE_DIR) not in sys.path:
|
||||
sys.path.append(str(BASE_DIR))
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)


class DataUpdater:
    """Incrementally downloads per-ticker OHLCV history and stores it as CSV.

    One CSV per ticker lives in ``data_dir`` with columns
    ``open, close, high, low, volume, time, ret`` where ``time`` is a
    YYYY-MM-DD string and ``ret`` is the next-day close-to-close return.
    """

    data_dir: Path

    def __init__(
        self,
        data_dir: Optional[str] = None,
        start_date: str = "2022-01-01",
    ):
        """
        Initialize data updater

        Args:
            data_dir: Data storage directory, defaults to backend/data/ret_data
            start_date: Historical data start date (YYYY-MM-DD)
        """
        # Get config from centralized source
        config = get_config()
        self.data_source = config.source
        self.api_key = config.api_key

        # Set data directory (created if missing)
        if data_dir is None:
            self.data_dir = BASE_DIR / "backend" / "data" / "ret_data"
        else:
            self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.start_date = start_date

        # Initialize Finnhub client only for that source; the alternative
        # source is reached through get_prices() and needs no client here.
        if self.data_source == "finnhub":
            import finnhub

            self.client = finnhub.Client(api_key=self.api_key)
            logger.info("Finnhub client initialized")
        else:
            self.client = None
            logger.info("Financial Datasets API configured")

    def get_trading_dates(self, start_date: str, end_date: str) -> List[str]:
        """Get US stock market trading date sequence.

        Tries pandas_market_calendars (NYSE), then exchange_calendars (XNYS),
        and finally falls back to plain business days. The fallback is now
        also reached when neither calendar branch runs — previously that path
        fell out of the try block and implicitly returned None.
        """
        try:
            if mcal is not None:
                nyse = mcal.get_calendar("NYSE")
                trading_dates = nyse.valid_days(
                    start_date=start_date,
                    end_date=end_date,
                )
                return [date.strftime("%Y-%m-%d") for date in trading_dates]

            if xcals is not None:
                nyse = xcals.get_calendar("XNYS")
                trading_dates = nyse.sessions_in_range(start_date, end_date)
                return [date.strftime("%Y-%m-%d") for date in trading_dates]

        except Exception as e:
            logger.warning(
                f"Failed to get US trading calendar, using business days: {e}",
            )

        # Fallback to simple business day method
        date_range = pd.date_range(start_date, end_date, freq="B")
        return [date.strftime("%Y-%m-%d") for date in date_range]

    def get_last_date_from_csv(self, ticker: str) -> Optional[datetime]:
        """Get last data date from the ticker's CSV file.

        Returns None when the file is missing, empty, lacks a ``time``
        column, or cannot be parsed.
        """
        csv_path = self.data_dir / f"{ticker}.csv"

        if not csv_path.exists():
            logger.info(f"{ticker}.csv does not exist, will create new file")
            return None

        try:
            df = pd.read_csv(csv_path)
            if df.empty or "time" not in df.columns:
                return None

            # Rows are kept sorted by time, so the last row is the newest.
            last_date_str = df["time"].iloc[-1]
            last_date = datetime.strptime(last_date_str, "%Y-%m-%d")
            logger.info(f"{ticker} last data date: {last_date_str}")
            return last_date
        except Exception as e:
            logger.warning(f"Failed to read {ticker}.csv: {e}")
            return None

    def fetch_data_from_api(
        self,
        ticker: str,
        start_date: datetime,
        end_date: datetime,
    ) -> Optional[pd.DataFrame]:
        """Fetch price data from the configured API and shape it for storage.

        Returns a DataFrame with columns
        ``open, close, high, low, volume, time, ret`` or None when the API
        returns nothing.
        """
        start_date_str = start_date.strftime("%Y-%m-%d")
        end_date_str = end_date.strftime("%Y-%m-%d")

        logger.info(
            f"Fetching {ticker} data from {self.data_source}: {start_date_str} to {end_date_str}",
        )

        prices = get_prices(
            ticker=ticker,
            start_date=start_date_str,
            end_date=end_date_str,
        )

        if not prices:
            logger.warning(f"{ticker} no data returned from API")
            return None

        # Convert to DataFrame; prices_to_df is expected to return a
        # Date-indexed frame — TODO confirm index name against data_tools.
        df = prices_to_df(prices)
        df = df.reset_index()
        df["time"] = df["Date"].dt.strftime("%Y-%m-%d")

        # Next-day return: last row is NaN until newer data arrives.
        df["ret"] = df["close"].pct_change().shift(-1)

        df = df[["open", "close", "high", "low", "volume", "time", "ret"]]

        logger.info(f"Successfully fetched {ticker} data: {len(df)} records")
        return df

    def merge_and_save(self, ticker: str, new_data: pd.DataFrame) -> bool:
        """Merge new rows into the existing CSV (dedup by ``time``, newest
        wins), recompute returns, and save. Returns True on success."""
        csv_path = self.data_dir / f"{ticker}.csv"

        try:
            if csv_path.exists():
                old_data = pd.read_csv(csv_path)
                logger.info(f"{ticker} existing data: {len(old_data)} records")

                # Merge and deduplicate; keep="last" prefers the fresh rows.
                combined = pd.concat([old_data, new_data], ignore_index=True)
                combined = combined.drop_duplicates(
                    subset=["time"],
                    keep="last",
                )
                combined = combined.sort_values("time").reset_index(drop=True)

                # Recalculate returns across the merged series so the old
                # tail row (previously NaN) gets a real value.
                combined["ret"] = combined["close"].pct_change().shift(-1)

                logger.info(f"{ticker} merged data: {len(combined)} records")
            else:
                combined = new_data
                logger.info(f"{ticker} new file: {len(combined)} records")

            combined.to_csv(csv_path, index=False)
            logger.info(f"{ticker} data saved to: {csv_path}")
            return True

        except Exception as e:
            logger.error(f"Failed to save {ticker} data: {e}")
            return False

    def update_ticker(
        self,
        ticker: str,
        force_full_update: bool = False,
    ) -> bool:
        """Update data for a single stock; returns True when nothing needed
        updating or the update succeeded."""
        logger.info(f"{'='*60}")
        logger.info(f"Starting update for {ticker}")
        logger.info(f"{'='*60}")

        # Determine start date: forced full, incremental, or first run.
        if force_full_update:
            start_date = datetime.strptime(self.start_date, "%Y-%m-%d")
            logger.info(f"Force full update, start date: {start_date.date()}")
        else:
            last_date = self.get_last_date_from_csv(ticker)
            if last_date:
                start_date = last_date + timedelta(days=1)
                logger.info(
                    f"Incremental update, start date: {start_date.date()}",
                )
            else:
                start_date = datetime.strptime(self.start_date, "%Y-%m-%d")
                logger.info(f"First update, start date: {start_date.date()}")

        end_date = datetime.now()

        if start_date.date() >= end_date.date():
            logger.info(f"{ticker} data is up to date, no update needed")
            return True

        new_data = self.fetch_data_from_api(ticker, start_date, end_date)

        if new_data is None or new_data.empty:
            # A short empty window is expected over weekends/holidays.
            days_diff = (end_date - start_date).days
            if days_diff <= 3:
                logger.info(
                    f"{ticker} has no new data (may be weekend/holiday)",
                )
                return True
            else:
                logger.warning(f"{ticker} has no new data")
                return False

        success = self.merge_and_save(ticker, new_data)

        if success:
            logger.info(f"{ticker} update completed")
        else:
            logger.error(f"{ticker} update failed")

        return success

    def update_all_tickers(
        self,
        tickers: List[str],
        force_full_update: bool = False,
    ) -> Dict[str, bool]:
        """Batch update multiple stocks; returns a ticker -> success map."""
        results = {}

        logger.info(f"{'='*60}")
        logger.info(f"Starting batch update for {len(tickers)} stocks")
        logger.info(f"Stock list: {', '.join(tickers)}")
        logger.info(f"{'='*60}")

        for i, ticker in enumerate(tickers, 1):
            logger.info(f"[{i}/{len(tickers)}] Processing {ticker}")
            results[ticker] = self.update_ticker(ticker, force_full_update)

            # API rate limiting between tickers (not after the last one).
            if i < len(tickers):
                time.sleep(1)

        # Print summary
        logger.info(f"{'='*60}")
        logger.info("Update Summary")
        logger.info(f"{'='*60}")

        success_count = sum(results.values())
        fail_count = len(results) - success_count

        logger.info(f"Success: {success_count}")
        logger.info(f"Failed: {fail_count}")

        if fail_count > 0:
            failed_tickers = [t for t, s in results.items() if not s]
            logger.warning(f"Failed stocks: {', '.join(failed_tickers)}")

        logger.info(f"{'='*60}\n")

        return results
|
||||
|
||||
|
||||
def main():
    """Command line entry point.

    Parses CLI args, resolves the ticker list (``--tickers`` or the TICKERS
    env var), runs the batch update, and exits 0 unless the update pipeline
    itself raised or configuration is missing.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Automatically update stock historical data",
    )
    parser.add_argument(
        "--tickers",
        type=str,
        help="Stock ticker list (comma-separated), e.g.: AAPL,MSFT,GOOGL",
    )
    parser.add_argument(
        "--data-dir",
        type=str,
        help="Data storage directory (default: backend/data/ret_data)",
    )
    parser.add_argument(
        "--start-date",
        type=str,
        default="2022-01-01",
        help="Historical data start date (YYYY-MM-DD, default: 2022-01-01)",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force full update (re-download all data)",
    )

    args = parser.parse_args()

    # Load environment variables
    load_dotenv()

    # Validate API key is available
    try:
        config = get_config()
        logger.info(f"Using data source: {config.source}")
    except ValueError as e:
        logger.error(str(e))
        sys.exit(1)

    # Get stock list: CLI flag wins, then TICKERS env var.
    if args.tickers:
        tickers = [t.strip().upper() for t in args.tickers.split(",")]
    else:
        tickers_env = os.getenv("TICKERS", "")
        if tickers_env:
            tickers = [t.strip().upper() for t in tickers_env.split(",")]
        else:
            logger.error("Stock list not provided")
            logger.error(
                "Please set via --tickers parameter or TICKERS environment variable",
            )
            sys.exit(1)

    # Create updater
    updater = DataUpdater(
        data_dir=args.data_dir,
        start_date=args.start_date,
    )

    # Execute update
    try:
        results = updater.update_all_tickers(
            tickers,
            force_full_update=args.force,
        )
    except Exception as e:
        # API error (e.g., weekend/holiday with no data). Previously the
        # exception was swallowed silently; log it so exit code 1 has context.
        logger.error(f"Update aborted: {e}")
        sys.exit(1)

    # Return status code
    success_count = sum(results.values())
    if success_count == len(results):
        logger.info("All stocks updated successfully!")
        sys.exit(0)
    elif success_count == 0:
        logger.warning("All stocks have no new data (may be weekend/holiday)")
        sys.exit(0)
    else:
        logger.warning("Some stocks failed to update, but will continue")
        sys.exit(0)


if __name__ == "__main__":
    main()
|
||||
184
backend/data/schema.py
Normal file
184
backend/data/schema.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Price(BaseModel):
    """One OHLCV price bar; ``time`` is a date string (YYYY-MM-DD elsewhere in this codebase — TODO confirm)."""

    open: float
    close: float
    high: float
    low: float
    volume: int
    time: str  # bar date as a string
|
||||
|
||||
|
||||
class PriceResponse(BaseModel):
    """API response envelope: price series for one ticker."""

    ticker: str
    prices: list[Price]
|
||||
|
||||
|
||||
class FinancialMetrics(BaseModel):
    """Snapshot of fundamental metrics for one ticker/reporting period.

    Fields typed ``float | None`` are required keys whose value may be null
    when the data source does not report the metric.
    """

    ticker: str
    report_period: str
    period: str
    currency: str
    # Valuation
    market_cap: float | None
    enterprise_value: float | None
    price_to_earnings_ratio: float | None
    price_to_book_ratio: float | None
    price_to_sales_ratio: float | None
    enterprise_value_to_ebitda_ratio: float | None
    enterprise_value_to_revenue_ratio: float | None
    free_cash_flow_yield: float | None
    peg_ratio: float | None
    # Margins & profitability
    gross_margin: float | None
    operating_margin: float | None
    net_margin: float | None
    return_on_equity: float | None
    return_on_assets: float | None
    return_on_invested_capital: float | None
    # Efficiency
    asset_turnover: float | None
    inventory_turnover: float | None
    receivables_turnover: float | None
    days_sales_outstanding: float | None
    operating_cycle: float | None
    working_capital_turnover: float | None
    # Liquidity
    current_ratio: float | None
    quick_ratio: float | None
    cash_ratio: float | None
    operating_cash_flow_ratio: float | None
    # Leverage
    debt_to_equity: float | None
    debt_to_assets: float | None
    interest_coverage: float | None
    # Growth
    revenue_growth: float | None
    earnings_growth: float | None
    book_value_growth: float | None
    earnings_per_share_growth: float | None
    free_cash_flow_growth: float | None
    operating_income_growth: float | None
    ebitda_growth: float | None
    payout_ratio: float | None
    # Per-share
    earnings_per_share: float | None
    book_value_per_share: float | None
    free_cash_flow_per_share: float | None
|
||||
|
||||
|
||||
class FinancialMetricsResponse(BaseModel):
    """API response envelope: a list of FinancialMetrics records."""

    financial_metrics: list[FinancialMetrics]
|
||||
|
||||
|
||||
class LineItem(BaseModel):
    """A financial-statement line item; extra fields arrive dynamically per query."""

    ticker: str
    report_period: str
    period: str
    currency: str

    # Allow additional fields dynamically (pydantic v2 config)
    model_config = {"extra": "allow"}
|
||||
|
||||
|
||||
class LineItemResponse(BaseModel):
    """API response envelope: line-item search results."""

    search_results: list[LineItem]
|
||||
|
||||
|
||||
class InsiderTrade(BaseModel):
    """One insider transaction filing; only ``ticker`` and ``filing_date`` are guaranteed non-null."""

    ticker: str
    issuer: str | None
    name: str | None
    title: str | None
    is_board_director: bool | None
    transaction_date: str | None
    transaction_shares: float | None
    transaction_price_per_share: float | None
    transaction_value: float | None
    shares_owned_before_transaction: float | None
    shares_owned_after_transaction: float | None
    security_title: str | None
    filing_date: str
|
||||
|
||||
|
||||
class InsiderTradeResponse(BaseModel):
    """API response envelope: a list of insider trades."""

    insider_trades: list[InsiderTrade]
|
||||
|
||||
|
||||
class CompanyNews(BaseModel):
    """One news article linked to a ticker."""

    category: str | None = None
    ticker: str
    title: str
    related: str | None = None  # related tickers/topics — TODO confirm semantics with data source
    source: str
    date: str | None = None
    url: str
    summary: str | None = None
|
||||
|
||||
|
||||
class CompanyNewsResponse(BaseModel):
    """API response envelope: a list of company news items."""

    news: list[CompanyNews]
|
||||
|
||||
|
||||
class CompanyFacts(BaseModel):
    """Static company profile facts; everything beyond ticker/name is optional."""

    ticker: str
    name: str
    cik: str | None = None  # SEC Central Index Key
    industry: str | None = None
    sector: str | None = None
    category: str | None = None
    exchange: str | None = None
    is_active: bool | None = None
    listing_date: str | None = None
    location: str | None = None
    market_cap: float | None = None
    number_of_employees: int | None = None
    sec_filings_url: str | None = None
    sic_code: str | None = None  # Standard Industrial Classification
    sic_industry: str | None = None
    sic_sector: str | None = None
    website_url: str | None = None
    weighted_average_shares: int | None = None
|
||||
|
||||
|
||||
class CompanyFactsResponse(BaseModel):
    """API response envelope: facts for a single company."""

    company_facts: CompanyFacts
|
||||
|
||||
|
||||
class Position(BaseModel):
    """Position information - for Portfolio mode"""

    long: int = 0  # Long position quantity (shares)
    short: int = 0  # Short position quantity (shares)
    long_cost_basis: float = 0.0  # Long position average cost
    short_cost_basis: float = 0.0  # Short position average cost
|
||||
|
||||
|
||||
class Portfolio(BaseModel):
    """Portfolio - for Portfolio mode"""

    cash: float = 100000.0  # Available cash
    positions: dict[str, Position] = {}  # ticker -> Position mapping
    # Margin requirement (0.0 means shorting disabled, 0.5 means 50% margin)
    margin_requirement: float = 0.0
    margin_used: float = 0.0  # Margin used
|
||||
|
||||
|
||||
class AnalystSignal(BaseModel):
    """A single agent's trading signal for one ticker."""

    signal: str | None = None  # e.g. direction label — TODO confirm allowed values with agents
    confidence: float | None = None
    reasoning: dict | str | None = None  # free-form rationale, structured or plain text
    max_position_size: float | None = None  # For risk management signals
|
||||
|
||||
|
||||
class TickerAnalysis(BaseModel):
    """All agents' signals for one ticker."""

    ticker: str
    analyst_signals: dict[str, AnalystSignal]  # agent_name -> signal mapping
|
||||
|
||||
|
||||
class AgentStateData(BaseModel):
    """Shared pipeline state passed between agents."""

    tickers: list[str]
    portfolio: Portfolio
    start_date: str
    end_date: str
    ticker_analyses: dict[str, TickerAnalysis]  # ticker -> analysis mapping
|
||||
|
||||
|
||||
class AgentStateMetadata(BaseModel):
    """Run metadata flags; open to extra fields."""

    show_reasoning: bool = False
    model_config = {"extra": "allow"}
|
||||
0
backend/llm/__init__.py
Normal file
0
backend/llm/__init__.py
Normal file
243
backend/llm/models.py
Normal file
243
backend/llm/models.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
AgentScope Native Model Factory
|
||||
Uses native AgentScope model classes for LLM calls
|
||||
"""
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple
|
||||
from agentscope.formatter import (
|
||||
AnthropicChatFormatter,
|
||||
DashScopeChatFormatter,
|
||||
GeminiChatFormatter,
|
||||
OllamaChatFormatter,
|
||||
OpenAIChatFormatter,
|
||||
)
|
||||
from agentscope.model import (
|
||||
AnthropicChatModel,
|
||||
DashScopeChatModel,
|
||||
GeminiChatModel,
|
||||
OllamaChatModel,
|
||||
OpenAIChatModel,
|
||||
)
|
||||
|
||||
|
||||
class ModelProvider(Enum):
    """Supported model providers.

    ALIBABA and GOOGLE are aliases mapped to the same backends as DASHSCOPE
    and GEMINI respectively (see the PROVIDER_* maps in this module).
    """

    OPENAI = "OPENAI"
    ANTHROPIC = "ANTHROPIC"
    DASHSCOPE = "DASHSCOPE"
    ALIBABA = "ALIBABA"
    GEMINI = "GEMINI"
    GOOGLE = "GOOGLE"
    OLLAMA = "OLLAMA"
    # OpenAI-compatible gateways
    DEEPSEEK = "DEEPSEEK"
    GROQ = "GROQ"
    OPENROUTER = "OPENROUTER"
|
||||
|
||||
|
||||
# Provider to AgentScope model class mapping
|
||||
# Provider to AgentScope model class mapping
PROVIDER_MODEL_MAP = {
    "OPENAI": OpenAIChatModel,
    "ANTHROPIC": AnthropicChatModel,
    "DASHSCOPE": DashScopeChatModel,
    "ALIBABA": DashScopeChatModel,  # alias for DashScope
    "GEMINI": GeminiChatModel,
    "GOOGLE": GeminiChatModel,  # alias for Gemini
    "OLLAMA": OllamaChatModel,
    # OpenAI-compatible providers use OpenAIChatModel with custom base_url
    "DEEPSEEK": OpenAIChatModel,
    "GROQ": OpenAIChatModel,
    "OPENROUTER": OpenAIChatModel,
}

# Provider to formatter mapping (must stay in sync with PROVIDER_MODEL_MAP)
PROVIDER_FORMATTER_MAP = {
    "OPENAI": OpenAIChatFormatter,
    "ANTHROPIC": AnthropicChatFormatter,
    "DASHSCOPE": DashScopeChatFormatter,
    "ALIBABA": DashScopeChatFormatter,
    "GEMINI": GeminiChatFormatter,
    "GOOGLE": GeminiChatFormatter,
    "OLLAMA": OllamaChatFormatter,
    # OpenAI-compatible providers use OpenAIChatFormatter
    "DEEPSEEK": OpenAIChatFormatter,
    "GROQ": OpenAIChatFormatter,
    "OPENROUTER": OpenAIChatFormatter,
}

# Provider-specific base URLs for OpenAI-compatible gateways
PROVIDER_BASE_URLS = {
    "DEEPSEEK": "https://api.deepseek.com/v1",
    "GROQ": "https://api.groq.com/openai/v1",
    "OPENROUTER": "https://openrouter.ai/api/v1",
}

# Provider-specific API key environment variable names
# (OLLAMA is intentionally absent: it requires no API key)
PROVIDER_API_KEY_ENV = {
    "OPENAI": "OPENAI_API_KEY",
    "ANTHROPIC": "ANTHROPIC_API_KEY",
    "DASHSCOPE": "DASHSCOPE_API_KEY",
    "ALIBABA": "DASHSCOPE_API_KEY",
    "GEMINI": "GOOGLE_API_KEY",
    "GOOGLE": "GOOGLE_API_KEY",
    "DEEPSEEK": "DEEPSEEK_API_KEY",
    "GROQ": "GROQ_API_KEY",
    "OPENROUTER": "OPENROUTER_API_KEY",
}
|
||||
|
||||
|
||||
def create_model(
    model_name: str,
    provider: str,
    api_key: Optional[str] = None,
    stream: bool = False,
    **kwargs,
):
    """
    Create an AgentScope model instance.

    Args:
        model_name: Model name (e.g., "gpt-4o", "claude-3-opus")
        provider: Provider name (e.g., "OPENAI", "ANTHROPIC")
        api_key: API key (optional, will read from env if not provided)
        stream: Whether to use streaming mode
        **kwargs: Additional model-specific arguments

    Returns:
        AgentScope model instance

    Raises:
        ValueError: if the provider is not in PROVIDER_MODEL_MAP.
    """
    provider = provider.upper()

    model_class = PROVIDER_MODEL_MAP.get(provider)
    if model_class is None:
        raise ValueError(f"Unsupported provider: {provider}")

    # Resolve the API key from the provider's env var when not given.
    if api_key is None:
        env_var = PROVIDER_API_KEY_ENV.get(provider)
        api_key = os.getenv(env_var) if env_var else None

    init_kwargs = {"model_name": model_name, "stream": stream, **kwargs}

    # Ollama runs locally and takes no API key.
    if api_key and provider != "OLLAMA":
        init_kwargs["api_key"] = api_key

    # OpenAI-compatible gateways (DeepSeek/Groq/OpenRouter) get a fixed
    # base_url; plain OPENAI can be redirected via environment variables.
    gateway_url = PROVIDER_BASE_URLS.get(provider)
    if gateway_url:
        init_kwargs["client_args"] = {"base_url": gateway_url}
    elif provider == "OPENAI":
        custom_url = os.getenv("OPENAI_BASE_URL") or os.getenv("OPENAI_API_BASE")
        if custom_url:
            init_kwargs["client_args"] = {"base_url": custom_url}

    # DashScope uses a differently named endpoint parameter.
    if provider in ("DASHSCOPE", "ALIBABA"):
        dashscope_url = os.getenv("DASHSCOPE_BASE_URL")
        if dashscope_url:
            init_kwargs["base_http_api_url"] = dashscope_url

    # Ollama host override.
    if provider == "OLLAMA":
        ollama_host = os.getenv("OLLAMA_HOST")
        if ollama_host:
            init_kwargs["host"] = ollama_host

    return model_class(**init_kwargs)
|
||||
|
||||
|
||||
def get_agent_model(agent_id: str, stream: bool = False):
    """
    Get model for a specific agent based on environment variables

    Environment variable pattern:
        AGENT_{AGENT_ID}_MODEL_NAME: Model name
        AGENT_{AGENT_ID}_MODEL_PROVIDER: Provider name

    Falls back to global MODEL_NAME / MODEL_PROVIDER when the
    agent-specific variables are not set.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")
        stream: Whether to use streaming mode

    Returns:
        AgentScope model instance
    """
    # Normalize agent_id to uppercase for env var lookup
    agent_key = agent_id.upper().replace("-", "_")

    # Try agent-specific config first
    model_name = os.getenv(f"AGENT_{agent_key}_MODEL_NAME")
    provider = os.getenv(f"AGENT_{agent_key}_MODEL_PROVIDER")

    # Only announce when an agent-specific override actually exists;
    # previously this printed "Using specific model None ..." for every
    # agent running on the global default.
    if model_name:
        print(f"Using specific model {model_name} for agent {agent_key}")

    # Fall back to global config
    if not model_name:
        model_name = os.getenv("MODEL_NAME", "gpt-4o")
    if not provider:
        provider = os.getenv("MODEL_PROVIDER", "OPENAI")

    return create_model(
        model_name=model_name,
        provider=provider,
        stream=stream,
    )
|
||||
|
||||
|
||||
def get_agent_formatter(agent_id: str):
    """
    Get formatter for a specific agent based on environment variables

    Resolution order: AGENT_{ID}_MODEL_PROVIDER, then MODEL_PROVIDER,
    then "OPENAI"; unknown providers fall back to OpenAIChatFormatter.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")

    Returns:
        AgentScope formatter instance
    """
    # Normalize agent_id to uppercase for env var lookup
    agent_key = agent_id.upper().replace("-", "_")

    provider = (
        os.getenv(f"AGENT_{agent_key}_MODEL_PROVIDER")
        or os.getenv("MODEL_PROVIDER", "OPENAI")
    ).upper()

    return PROVIDER_FORMATTER_MAP.get(provider, OpenAIChatFormatter)()
|
||||
|
||||
|
||||
def get_agent_model_info(agent_id: str) -> Tuple[str, str]:
    """
    Get model name and provider for a specific agent

    Checks AGENT_{ID}_MODEL_NAME / AGENT_{ID}_MODEL_PROVIDER first, then
    the global MODEL_NAME / MODEL_PROVIDER, then built-in defaults.

    Args:
        agent_id: Agent ID (e.g., "sentiment_analyst", "portfolio_manager")

    Returns:
        Tuple of (model_name, provider_name)
    """
    agent_key = agent_id.upper().replace("-", "_")

    model_name = (
        os.getenv(f"AGENT_{agent_key}_MODEL_NAME")
        or os.getenv("MODEL_NAME", "gpt-4o")
    )
    provider = (
        os.getenv(f"AGENT_{agent_key}_MODEL_PROVIDER")
        or os.getenv("MODEL_PROVIDER", "OPENAI")
    )

    return model_name, provider.upper()
|
||||
332
backend/main.py
Normal file
332
backend/main.py
Normal file
@@ -0,0 +1,332 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Main Entry Point
|
||||
Supports: backtest, live, mock modes
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from contextlib import AsyncExitStack
|
||||
from pathlib import Path
|
||||
import loguru
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from backend.agents import AnalystAgent, PMAgent, RiskAgent
|
||||
from backend.config.constants import ANALYST_TYPES
|
||||
from backend.config.env_config import get_env_float, get_env_int, get_env_list
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.scheduler import BacktestScheduler, Scheduler
|
||||
from backend.utils.settlement import SettlementCoordinator
|
||||
from backend.llm.models import get_agent_formatter, get_agent_model
|
||||
from backend.services.gateway import Gateway
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
load_dotenv()
|
||||
logger = logging.getLogger(__name__)
|
||||
loguru.logger.disable("flowllm")
|
||||
loguru.logger.disable("reme_ai")
|
||||
|
||||
|
||||
def create_long_term_memory(agent_name: str, config_name: str):
    """Build a ReMeTaskLongTermMemory instance for one agent.

    Reads the ``MEMORY_API_KEY`` environment variable; when it is missing,
    long-term memory is disabled and ``None`` is returned. Vector data is
    stored locally under ``<config_name>/memory``.
    """
    from agentscope.memory import ReMeTaskLongTermMemory
    from agentscope.model import DashScopeChatModel
    from agentscope.embedding import DashScopeTextEmbedding

    api_key = os.getenv("MEMORY_API_KEY")
    if not api_key:
        logger.warning("MEMORY_API_KEY not set, long-term memory disabled")
        return None

    store_dir = str(Path(config_name) / "memory")

    chat_model = DashScopeChatModel(
        model_name=os.getenv("MEMORY_MODEL_NAME", "qwen3-max"),
        api_key=api_key,
        stream=False,
    )
    embedder = DashScopeTextEmbedding(
        model_name=os.getenv("MEMORY_EMBEDDING_MODEL", "text-embedding-v4"),
        api_key=api_key,
        dimensions=1024,
    )
    # Dotted keys configure the memory's local vector store backend.
    vector_store_opts = {
        "vector_store.default.backend": "local",
        "vector_store.default.params.store_dir": store_dir,
    }

    return ReMeTaskLongTermMemory(
        agent_name=agent_name,
        user_name=agent_name,
        model=chat_model,
        embedding_model=embedder,
        **vector_store_opts,
    )
|
||||
|
||||
|
||||
def create_agents(
    config_name: str,
    initial_cash: float,
    margin_requirement: float,
    enable_long_term_memory: bool = False,
):
    """Create all agents for the system.

    Args:
        config_name: Run/config directory name used for prompts and memory.
        initial_cash: Starting cash for the portfolio manager.
        margin_requirement: Margin requirement passed to the PM agent.
        enable_long_term_memory: When True, attach a long-term memory to
            every agent (skipped per-agent when creation returns None).

    Returns:
        tuple: (analysts, risk_manager, portfolio_manager, long_term_memories)
        where long_term_memories lists every memory actually created so the
        caller can enter their async contexts.
    """
    long_term_memories = []

    def _memory_for(agent_name: str):
        # Shared wiring for optional long-term memory; previously this
        # exact create/check/append sequence was repeated three times.
        if not enable_long_term_memory:
            return None
        memory = create_long_term_memory(agent_name, config_name)
        if memory:
            long_term_memories.append(memory)
        return memory

    analysts = []
    for analyst_type in ANALYST_TYPES:
        analysts.append(
            AnalystAgent(
                analyst_type=analyst_type,
                toolkit=create_toolkit(analyst_type),
                model=get_agent_model(analyst_type),
                formatter=get_agent_formatter(analyst_type),
                agent_id=analyst_type,
                config={"config_name": config_name},
                long_term_memory=_memory_for(analyst_type),
            )
        )

    risk_manager = RiskAgent(
        model=get_agent_model("risk_manager"),
        formatter=get_agent_formatter("risk_manager"),
        name="risk_manager",
        config={"config_name": config_name},
        long_term_memory=_memory_for("risk_manager"),
    )

    portfolio_manager = PMAgent(
        name="portfolio_manager",
        model=get_agent_model("portfolio_manager"),
        formatter=get_agent_formatter("portfolio_manager"),
        initial_cash=initial_cash,
        margin_requirement=margin_requirement,
        config={"config_name": config_name},
        long_term_memory=_memory_for("portfolio_manager"),
    )

    return analysts, risk_manager, portfolio_manager, long_term_memories
|
||||
|
||||
|
||||
def create_toolkit(analyst_type: str):
    """Create an AgentScope Toolkit with the tools for one analyst type.

    Tool names come from the analyst's entry in the "analyst/personas"
    YAML config; each name is resolved through TOOL_REGISTRY and
    registered on a fresh Toolkit.
    """
    from agentscope.tool import Toolkit
    from backend.agents.prompt_loader import PromptLoader
    from backend.tools.analysis_tools import TOOL_REGISTRY

    # Load analyst persona config to find which tools this analyst may use.
    prompt_loader = PromptLoader()
    personas_config = prompt_loader.load_yaml_config("analyst", "personas")
    persona = personas_config.get(analyst_type, {})

    # Get tool names for this analyst type.
    tool_names = persona.get("tools", [])

    # Create toolkit and register tools.
    toolkit = Toolkit()
    for tool_name in tool_names:
        tool_func = TOOL_REGISTRY.get(tool_name)
        if tool_func:
            toolkit.register_tool_function(tool_func)
        else:
            # Previously unknown names were skipped silently, which hides
            # typos in the persona config; surface them in the logs.
            logger.warning(
                "Unknown tool %r configured for analyst %r; skipping",
                tool_name,
                analyst_type,
            )

    return toolkit
|
||||
|
||||
|
||||
async def run_with_gateway(args):
    """Wire up market/storage services, agents, pipeline and scheduler,
    then serve the WebSocket gateway until cancelled.

    Args:
        args: Parsed CLI namespace from main() (mode, config_name, host,
            port, poll_interval, mock, start/end dates, enable_memory).
    """
    is_backtest = args.mode == "backtest"

    # Load config from env, override with args
    tickers = get_env_list("TICKERS", ["AAPL", "MSFT"])
    initial_cash = get_env_float("INITIAL_CASH", 100000.0)
    margin_requirement = get_env_float("MARGIN_REQUIREMENT", 0.0)
    config_name = args.config_name

    # Create market service.
    # FINNHUB_API_KEY is only passed for real live data; mock and backtest
    # modes run without an API key.
    market_service = MarketService(
        tickers=tickers,
        poll_interval=args.poll_interval,
        mock_mode=args.mock and not is_backtest,
        backtest_mode=is_backtest,
        api_key=os.getenv("FINNHUB_API_KEY")
        if not args.mock and not is_backtest
        else None,
        backtest_start_date=args.start_date if is_backtest else None,
        backtest_end_date=args.end_date if is_backtest else None,
    )

    # Create storage service
    storage_service = StorageService(
        dashboard_dir=Path(config_name) / "team_dashboard",
        initial_cash=initial_cash,
        config_name=config_name,
    )

    # First run: seed empty dashboard files; otherwise refresh model info.
    if not storage_service.files["summary"].exists():
        storage_service.initialize_empty_dashboard()
    else:
        storage_service.update_leaderboard_model_info()

    # Create agents and pipeline
    analysts, risk_manager, pm, long_term_memories = create_agents(
        config_name=config_name,
        initial_cash=initial_cash,
        margin_requirement=margin_requirement,
        enable_long_term_memory=args.enable_memory,
    )
    # Restore the PM's holdings/cash from the previous run, if any.
    portfolio_state = storage_service.load_portfolio_state()
    pm.load_portfolio_state(portfolio_state)

    settlement_coordinator = SettlementCoordinator(
        storage=storage_service,
        initial_capital=initial_cash,
    )

    pipeline = TradingPipeline(
        analysts=analysts,
        risk_manager=risk_manager,
        portfolio_manager=pm,
        settlement_coordinator=settlement_coordinator,
        max_comm_cycles=get_env_int("MAX_COMM_CYCLES", 2),
    )

    # Create scheduler callback
    scheduler_callback = None
    trading_dates = []

    if is_backtest:
        backtest_scheduler = BacktestScheduler(
            start_date=args.start_date,
            end_date=args.end_date,
            trading_calendar="NYSE",
            delay_between_days=0.5,
        )
        trading_dates = backtest_scheduler.get_trading_dates()

        async def scheduler_callback_fn(callback):
            await backtest_scheduler.start(callback)

        scheduler_callback = scheduler_callback_fn
    else:
        # Live mode: use daily scheduler with NYSE timezone
        live_scheduler = Scheduler(
            mode="daily",
            trigger_time=args.trigger_time,
            config={"config_name": config_name},
        )

        async def scheduler_callback_fn(callback):
            await live_scheduler.start(callback)

        scheduler_callback = scheduler_callback_fn

    # Create gateway
    gateway = Gateway(
        market_service=market_service,
        storage_service=storage_service,
        pipeline=pipeline,
        scheduler_callback=scheduler_callback,
        config={
            "mode": args.mode,
            "mock_mode": args.mock,
            "backtest_mode": is_backtest,
            "tickers": tickers,
            "config_name": config_name,
        },
    )

    if is_backtest:
        gateway.set_backtest_dates(trading_dates)

    # Start long-term memory contexts and run gateway.
    # AsyncExitStack keeps every memory's async context open for the whole
    # gateway lifetime and closes them in reverse order on exit.
    async with AsyncExitStack() as stack:
        for memory in long_term_memories:
            await stack.enter_async_context(memory)
        await gateway.start(host=args.host, port=args.port)
|
||||
|
||||
|
||||
def main():
    """Main entry point.

    Parses CLI arguments, logs the effective run configuration, validates
    the backtest date window, and runs the gateway event loop.
    """
    parser = argparse.ArgumentParser(description="Trading System")
    parser.add_argument("--mode", choices=["live", "backtest"], default="live")
    parser.add_argument("--mock", action="store_true")
    parser.add_argument("--config-name", default="mock")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8765)
    # Default trigger is NYSE market open.
    parser.add_argument("--trigger-time", default="09:30")
    parser.add_argument("--poll-interval", type=int, default=10)
    parser.add_argument("--start-date")
    parser.add_argument("--end-date")
    parser.add_argument(
        "--enable-memory",
        action="store_true",
        help="Enable ReMeTaskLongTermMemory for agents",
    )

    args = parser.parse_args()

    # Echo the effective configuration (env-driven values included).
    tickers = get_env_list("TICKERS", ["AAPL", "MSFT"])
    initial_cash = get_env_float("INITIAL_CASH", 100000.0)
    memory_state = "enabled" if args.enable_memory else "disabled"

    banner = "=" * 60
    logger.info(banner)
    logger.info(f"Mode: {args.mode}, Config: {args.config_name}")
    logger.info(f"Tickers: {tickers}")
    logger.info(f"Initial Cash: ${initial_cash:,.2f}")
    logger.info(f"Long-term Memory: {memory_state}")

    if args.mode == "backtest":
        # A backtest cannot run without an explicit date window.
        if not args.start_date or not args.end_date:
            parser.error(
                "--start-date and --end-date required for backtest mode",
            )
        logger.info(f"Backtest: {args.start_date} to {args.end_date}")
    logger.info(banner)

    asyncio.run(run_with_gateway(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
2
backend/services/__init__.py
Normal file
2
backend/services/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Services layer for infrastructure components"""
|
||||
569
backend/services/gateway.py
Normal file
569
backend/services/gateway.py
Normal file
@@ -0,0 +1,569 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
WebSocket Gateway for frontend communication
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
import websockets
|
||||
from websockets.server import WebSocketServerProtocol
|
||||
|
||||
from backend.utils.msg_adapter import FrontendAdapter
|
||||
from backend.utils.terminal_dashboard import get_dashboard
|
||||
from backend.core.pipeline import TradingPipeline
|
||||
from backend.core.state_sync import StateSync
|
||||
from backend.services.market import MarketService
|
||||
from backend.services.storage import StorageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Gateway:
|
||||
"""WebSocket Gateway for frontend communication"""
|
||||
|
||||
    def __init__(
        self,
        market_service: MarketService,
        storage_service: StorageService,
        pipeline: TradingPipeline,
        state_sync: Optional[StateSync] = None,
        scheduler_callback: Optional[Callable] = None,
        config: Dict[str, Any] = None,
    ):
        """Wire the gateway to its collaborating services.

        Args:
            market_service: Price and market-status source.
            storage_service: Dashboard/portfolio persistence.
            pipeline: Trading pipeline run on each strategy trigger.
            state_sync: Optional pre-built state synchronizer; a default
                one backed by storage_service is created when omitted.
            scheduler_callback: Awaitable that starts the scheduler with
                a per-date trigger callback.
            config: Runtime options (mode, mock_mode, backtest_mode,
                tickers, config_name, ...).
        """
        self.market_service = market_service
        self.storage = storage_service
        self.pipeline = pipeline
        self.scheduler_callback = scheduler_callback
        self.config = config or {}

        self.mode = self.config.get("mode", "live")
        # Backtest is signalled either by mode or an explicit flag.
        self.is_backtest = self.mode == "backtest" or self.config.get(
            "backtest_mode",
            False,
        )

        self.state_sync = state_sync or StateSync(storage=storage_service)
        # self.state_sync.set_mode(self.is_backtest)
        self.state_sync.set_broadcast_fn(self.broadcast)
        # Give the pipeline direct access so it can push progress updates.
        self.pipeline.state_sync = self.state_sync

        self.connected_clients: Set[WebSocketServerProtocol] = set()
        self.lock = asyncio.Lock()
        self._backtest_task: Optional[asyncio.Task] = None
        self._backtest_start_date: Optional[str] = None
        self._backtest_end_date: Optional[str] = None
        self._dashboard = get_dashboard()
        self._market_status_task: Optional[asyncio.Task] = None

        # Session tracking for live returns
        self._session_start_portfolio_value: Optional[float] = None
|
||||
|
||||
    async def start(self, host: str = "0.0.0.0", port: int = 8766):
        """Start the WebSocket gateway and block until cancelled.

        Boot order: terminal dashboard -> state sync -> restore portfolio
        display -> market service -> scheduler hookup -> (live only)
        market-status monitor -> websocket server.
        """
        logger.info(f"Starting gateway on {host}:{port}")

        # Initialize terminal dashboard
        self._dashboard.set_config(
            mode=self.mode,
            config_name=self.config.get("config_name", "default"),
            host=host,
            port=port,
            poll_interval=self.config.get("poll_interval", 10),
            mock=self.config.get("mock_mode", False),
            tickers=self.config.get("tickers", []),
            initial_cash=self.storage.initial_cash,
            start_date=self._backtest_start_date or "",
            end_date=self._backtest_end_date or "",
        )
        self._dashboard.start()

        self.state_sync.load_state()
        self.state_sync.update_state("status", "running")
        self.state_sync.update_state("server_mode", self.mode)
        self.state_sync.update_state("is_backtest", self.is_backtest)
        self.state_sync.update_state(
            "is_mock_mode",
            self.config.get("mock_mode", False),
        )

        # Load and display existing portfolio state if available
        summary = self.storage.load_file("summary")
        if summary:
            holdings = self.storage.load_file("holdings") or []
            trades = self.storage.load_file("trades") or []
            current_date = self.state_sync.state.get("current_date")
            self._dashboard.update(
                date=current_date or "-",
                status="running",
                portfolio=summary,
                holdings=holdings,
                trades=trades,
            )
            logger.info(
                "Loaded existing portfolio: $%s",
                f"{summary.get('totalAssetValue', 0):,.2f}",
            )

        await self.market_service.start(broadcast_func=self.broadcast)

        # Hook the scheduler so each trading date triggers a pipeline cycle.
        if self.scheduler_callback:
            await self.scheduler_callback(callback=self.on_strategy_trigger)

        # Start market status monitoring (only for live mode)
        if not self.is_backtest:
            self._market_status_task = asyncio.create_task(
                self._market_status_monitor(),
            )

        async with websockets.serve(
            self.handle_client,
            host,
            port,
            ping_interval=30,
            ping_timeout=60,
        ):
            logger.info(
                f"Gateway started: ws://{host}:{port}, mode={self.mode}",
            )
            # Serve forever; cancellation unwinds the context manager.
            await asyncio.Future()
|
||||
|
||||
    @property
    def state(self) -> Dict[str, Any]:
        # Convenience passthrough to the synchronized state dict.
        return self.state_sync.state
|
||||
|
||||
async def handle_client(self, websocket: WebSocketServerProtocol):
|
||||
"""Handle WebSocket client connection"""
|
||||
async with self.lock:
|
||||
self.connected_clients.add(websocket)
|
||||
|
||||
await self._send_initial_state(websocket)
|
||||
await self._handle_client_messages(websocket)
|
||||
|
||||
async with self.lock:
|
||||
self.connected_clients.discard(websocket)
|
||||
|
||||
    async def _send_initial_state(self, websocket: WebSocketServerProtocol):
        """Send the full state snapshot a newly connected client needs."""
        state_payload = self.state_sync.get_initial_state_payload(
            include_dashboard=True,
        )
        # Include market status in initial state
        state_payload[
            "market_status"
        ] = self.market_service.get_market_status()

        # Include live returns if session is active
        if self.storage.is_live_session_active:
            live_returns = self.storage.get_live_returns()
            if "portfolio" in state_payload:
                state_payload["portfolio"].update(live_returns)

        # default=str keeps non-JSON-native values (dates etc.) serializable.
        await websocket.send(
            json.dumps(
                {"type": "initial_state", "state": state_payload},
                ensure_ascii=False,
                default=str,
            ),
        )
|
||||
|
||||
async def _handle_client_messages(
|
||||
self,
|
||||
websocket: WebSocketServerProtocol,
|
||||
):
|
||||
try:
|
||||
async for message in websocket:
|
||||
data = json.loads(message)
|
||||
msg_type = data.get("type", "unknown")
|
||||
|
||||
if msg_type == "ping":
|
||||
await websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "pong",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
),
|
||||
)
|
||||
elif msg_type == "get_state":
|
||||
await self._send_initial_state(websocket)
|
||||
elif msg_type == "start_backtest":
|
||||
await self._handle_start_backtest(data)
|
||||
|
||||
except websockets.ConnectionClosed:
|
||||
pass
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
async def _handle_start_backtest(self, data: Dict[str, Any]):
|
||||
if not self.is_backtest:
|
||||
return
|
||||
dates = data.get("dates", [])
|
||||
if dates and self._backtest_task is None:
|
||||
task = asyncio.create_task(
|
||||
self._run_backtest_dates(dates),
|
||||
)
|
||||
task.add_done_callback(self._handle_backtest_exception)
|
||||
self._backtest_task = task
|
||||
|
||||
    async def broadcast(self, message: Dict[str, Any]):
        """Broadcast message to all connected clients.

        The message is serialized once and fanned out concurrently;
        per-client failures are isolated by ``return_exceptions=True``.
        """
        if not self.connected_clients:
            return

        message_json = json.dumps(message, ensure_ascii=False, default=str)

        async with self.lock:
            # Snapshot the client set so concurrent (dis)connects cannot
            # mutate it during iteration.
            tasks = [
                self._send_to_client(client, message_json)
                for client in self.connected_clients.copy()
            ]

        # NOTE(review): gather must run OUTSIDE the lock — _send_to_client
        # re-acquires self.lock on ConnectionClosed and asyncio.Lock is not
        # reentrant, so awaiting these sends while holding the lock would
        # deadlock; confirm this matches the original indentation.
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
    async def _send_to_client(
        self,
        client: WebSocketServerProtocol,
        message: str,
    ):
        """Send one pre-serialized message; unregister the client if gone."""
        try:
            await client.send(message)
        except websockets.ConnectionClosed:
            # Connection died between snapshot and send: drop the client.
            async with self.lock:
                self.connected_clients.discard(client)
|
||||
|
||||
    async def _market_status_monitor(self):
        """Periodically check and broadcast market status changes.

        Runs once a minute in live mode: broadcasts status transitions,
        opens/closes the storage "live session" around market hours, and
        pushes live-return updates while a session is active. Survives
        transient errors; exits only on cancellation.
        """
        while True:
            try:
                await self.market_service.check_and_broadcast_market_status()

                # On market open, start live session tracking
                status = self.market_service.get_market_status()
                if (
                    status["status"] == "open"
                    and not self.storage.is_live_session_active
                ):
                    self.storage.start_live_session()
                    summary = self.storage.load_file("summary") or {}
                    # Remember the session baseline for intraday returns.
                    self._session_start_portfolio_value = summary.get(
                        "totalAssetValue",
                        self.storage.initial_cash,
                    )
                    logger.info(
                        "Session start portfolio: "
                        f"${self._session_start_portfolio_value:,.2f}",
                    )
                elif (
                    status["status"] != "open"
                    and self.storage.is_live_session_active
                ):
                    # Any non-open status ends the session (close, premarket,
                    # afterhours alike).
                    self.storage.end_live_session()
                    self._session_start_portfolio_value = None

                # Update and broadcast live returns if session is active
                if self.storage.is_live_session_active:
                    await self._update_and_broadcast_live_returns()

                await asyncio.sleep(60)  # Check every minute
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the monitor alive across transient failures.
                logger.error(f"Market status monitor error: {e}")
                await asyncio.sleep(60)
|
||||
|
||||
    async def _update_and_broadcast_live_returns(self):
        """Calculate and broadcast live returns for the current session."""
        if not self.storage.is_live_session_active:
            return

        # Get current prices and calculate portfolio value
        prices = self.market_service.get_all_prices()
        if not prices or not any(p > 0 for p in prices.values()):
            # No usable quotes yet; skip this tick.
            return

        # Load current internal state to get baseline values
        state = self.storage.load_internal_state()

        # Get latest values from history (if available)
        equity_history = state.get("equity_history", [])
        baseline_history = state.get("baseline_history", [])
        baseline_vw_history = state.get("baseline_vw_history", [])
        momentum_history = state.get("momentum_history", [])

        # Take the most recent point of each series; entries carry the
        # value under key "v" (None when a series is empty).
        current_equity = equity_history[-1]["v"] if equity_history else None
        current_baseline = (
            baseline_history[-1]["v"] if baseline_history else None
        )
        current_baseline_vw = (
            baseline_vw_history[-1]["v"] if baseline_vw_history else None
        )
        current_momentum = (
            momentum_history[-1]["v"] if momentum_history else None
        )

        # Update live returns with current values
        point = self.storage.update_live_returns(
            current_equity=current_equity,
            current_baseline=current_baseline,
            current_baseline_vw=current_baseline_vw,
            current_momentum=current_momentum,
        )

        # Broadcast if we have new data
        if point:
            live_returns = self.storage.get_live_returns()
            await self.broadcast(
                {
                    "type": "team_summary",
                    "equity_return": live_returns["equity_return"],
                    "baseline_return": live_returns["baseline_return"],
                    "baseline_vw_return": live_returns["baseline_vw_return"],
                    "momentum_return": live_returns["momentum_return"],
                },
            )
|
||||
|
||||
async def on_strategy_trigger(self, date: str):
|
||||
"""Handle trading cycle trigger"""
|
||||
logger.info(f"Strategy triggered for {date}")
|
||||
|
||||
tickers = self.config.get("tickers", [])
|
||||
|
||||
if self.is_backtest:
|
||||
await self._run_backtest_cycle(date, tickers)
|
||||
else:
|
||||
await self._run_live_cycle(date, tickers)
|
||||
|
||||
    async def _run_backtest_cycle(self, date: str, tickers: List[str]):
        """Run a backtest cycle with pre-loaded prices.

        Open and close prices for the date are already known, so market
        open/close events are emitted synchronously around the pipeline run.
        """
        self.market_service.set_backtest_date(date)
        await self.market_service.emit_market_open()

        await self.state_sync.on_cycle_start(date)
        self._dashboard.update(date=date, status="Analyzing...")

        prices = self.market_service.get_open_prices()
        close_prices = self.market_service.get_close_prices()
        market_caps = self._get_market_caps(tickers, date)

        result = await self.pipeline.run_cycle(
            tickers=tickers,
            date=date,
            prices=prices,
            close_prices=close_prices,
            market_caps=market_caps,
        )

        await self.market_service.emit_market_close()
        settlement_result = result.get("settlement_result")
        # Persist first, then notify clients, then close out the cycle.
        self._save_cycle_results(result, date, close_prices, settlement_result)
        await self._broadcast_portfolio_updates(result, close_prices)
        await self._finalize_cycle(date)
|
||||
|
||||
    async def _run_live_cycle(self, date: str, tickers: List[str]):
        """
        Run live cycle with real market timing.

        - Analysis runs immediately
        - Execution waits for market open
          (or uses current prices if already open)
        - Settlement waits for market close
        """
        # Get actual trading date (might be next trading day if weekend)
        trading_date = self.market_service.get_live_trading_date()
        logger.info(
            f"Live cycle: triggered={date}, trading_date={trading_date}",
        )

        await self.state_sync.on_cycle_start(trading_date)
        self._dashboard.update(date=trading_date, status="Analyzing...")

        market_caps = self._get_market_caps(tickers, trading_date)

        # Run pipeline with async price callbacks; the pipeline awaits
        # these to synchronize execution/settlement with market open/close.
        result = await self.pipeline.run_cycle(
            tickers=tickers,
            date=trading_date,
            market_caps=market_caps,
            get_open_prices_fn=self.market_service.wait_for_open_prices,
            get_close_prices_fn=self.market_service.wait_for_close_prices,
        )

        close_prices = self.market_service.get_all_prices()
        settlement_result = result.get("settlement_result")
        self._save_cycle_results(
            result,
            trading_date,
            close_prices,
            settlement_result,
        )
        await self._broadcast_portfolio_updates(result, close_prices)
        await self._finalize_cycle(trading_date)
|
||||
|
||||
    async def _finalize_cycle(self, date: str):
        """Finalize cycle: broadcast state and update dashboard"""
        summary = self.storage.load_file("summary") or {}

        # Include live returns if session is active
        if self.storage.is_live_session_active:
            live_returns = self.storage.get_live_returns()
            summary.update(live_returns)

        await self.state_sync.on_cycle_end(date, portfolio_summary=summary)

        holdings = self.storage.load_file("holdings") or []
        trades = self.storage.load_file("trades") or []
        leaderboard = self.storage.load_file("leaderboard") or []

        if leaderboard:
            await self.state_sync.on_leaderboard_update(leaderboard)

        # Refresh the terminal dashboard with the post-cycle snapshot.
        self._dashboard.update(
            date=date,
            status="Running",
            portfolio=summary,
            holdings=holdings,
            trades=trades,
        )
|
||||
|
||||
def _get_market_caps(
|
||||
self,
|
||||
tickers: List[str],
|
||||
date: str,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Get market caps for tickers (stub implementation)
|
||||
|
||||
Args:
|
||||
tickers: List of tickers
|
||||
date: Trading date
|
||||
|
||||
Returns:
|
||||
Dict mapping ticker to market cap
|
||||
"""
|
||||
from ..tools.data_tools import get_market_cap
|
||||
|
||||
market_caps = {}
|
||||
for ticker in tickers:
|
||||
try:
|
||||
market_cap = get_market_cap(ticker, date)
|
||||
if market_cap:
|
||||
market_caps[ticker] = market_cap
|
||||
else:
|
||||
market_caps[ticker] = 1e9
|
||||
except Exception:
|
||||
market_caps[ticker] = 1e9
|
||||
|
||||
return market_caps
|
||||
|
||||
async def _broadcast_portfolio_updates(
|
||||
self,
|
||||
result: Dict[str, Any],
|
||||
prices: Dict[str, float],
|
||||
):
|
||||
portfolio = result.get("portfolio", {})
|
||||
|
||||
if portfolio:
|
||||
holdings = FrontendAdapter.build_holdings(portfolio, prices)
|
||||
if holdings:
|
||||
await self.state_sync.on_holdings_update(holdings)
|
||||
|
||||
stats = FrontendAdapter.build_stats(portfolio, prices)
|
||||
if stats:
|
||||
await self.state_sync.on_stats_update(stats)
|
||||
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
if executed_trades:
|
||||
await self.state_sync.on_trades_executed(executed_trades)
|
||||
|
||||
def _save_cycle_results(
|
||||
self,
|
||||
result: Dict[str, Any],
|
||||
date: str,
|
||||
prices: Dict[str, float],
|
||||
settlement_result: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
portfolio = result.get("portfolio", {})
|
||||
executed_trades = result.get("executed_trades", [])
|
||||
|
||||
# Extract baseline values from settlement result
|
||||
baseline_values = None
|
||||
if settlement_result:
|
||||
baseline_values = settlement_result.get("baseline_values")
|
||||
|
||||
if portfolio:
|
||||
self.storage.update_dashboard_after_cycle(
|
||||
portfolio=portfolio,
|
||||
prices=prices,
|
||||
date=date,
|
||||
executed_trades=executed_trades,
|
||||
baseline_values=baseline_values,
|
||||
)
|
||||
|
||||
    async def _run_backtest_dates(self, dates: List[str]):
        """Drive the full backtest loop over pre-computed trading dates.

        Updates the dashboard per day, announces start/completion via the
        state-sync channel, and always clears ``_backtest_task`` when done.
        Failures are logged and re-raised so the task's done-callback also
        records them.
        """
        self.state_sync.set_backtest_dates(dates)
        self._dashboard.update(days_total=len(dates), days_completed=0)

        await self.state_sync.on_system_message(
            f"Starting backtest - {len(dates)} trading days",
        )

        try:
            for i, date in enumerate(dates):
                self._dashboard.update(days_completed=i)
                await self.on_strategy_trigger(date=date)
                # Brief yield between days (presumably to let broadcasts
                # flush — confirm if tuning this delay).
                await asyncio.sleep(0.1)

            await self.state_sync.on_system_message(
                f"Backtest complete - {len(dates)} days",
            )

            # Update dashboard with final state
            summary = self.storage.load_file("summary") or {}
            self._dashboard.update(
                status="Complete",
                portfolio=summary,
                days_completed=len(dates),
            )
            self._dashboard.stop()
            self._dashboard.print_final_summary()
        except Exception as e:
            error_msg = f"Backtest failed: {type(e).__name__}: {str(e)}"
            logger.error(error_msg, exc_info=True)
            await self.state_sync.on_system_message(error_msg)
            self._dashboard.update(status=f"Failed: {str(e)}")
            self._dashboard.stop()
            raise
        finally:
            self._backtest_task = None
|
||||
|
||||
    def _handle_backtest_exception(self, task: asyncio.Task):
        """Handle exceptions from backtest task.

        Done-callback: retrieving ``task.result()`` re-raises whatever the
        task died with so it can be logged instead of being swallowed by
        the event loop.
        """
        try:
            task.result()
        except asyncio.CancelledError:
            logger.info("Backtest task was cancelled")
        except Exception as e:
            logger.error(
                f"Backtest task failed with exception:{type(e).__name__}:{e}",
                exc_info=True,
            )
|
||||
|
||||
def set_backtest_dates(self, dates: List[str]):
|
||||
self.state_sync.set_backtest_dates(dates)
|
||||
if dates:
|
||||
self._backtest_start_date = dates[0]
|
||||
self._backtest_end_date = dates[-1]
|
||||
self._dashboard.days_total = len(dates)
|
||||
|
||||
def stop(self):
|
||||
self.state_sync.save_state()
|
||||
self.market_service.stop()
|
||||
if self._backtest_task:
|
||||
self._backtest_task.cancel()
|
||||
if self._market_status_task:
|
||||
self._market_status_task.cancel()
|
||||
self._dashboard.stop()
|
||||
625
backend/services/market.py
Normal file
625
backend/services/market.py
Normal file
@@ -0,0 +1,625 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Market Data Service
|
||||
Supports live, mock, and backtest modes
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from zoneinfo import ZoneInfo
|
||||
|
||||
import pandas_market_calendars as mcal
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NYSE timezone and calendar
|
||||
NYSE_TZ = ZoneInfo("America/New_York")
|
||||
NYSE_CALENDAR = mcal.get_calendar("NYSE")
|
||||
|
||||
|
||||
class MarketStatus:
    """Market status enum-like class"""

    # Plain string constants; code elsewhere compares raw status strings
    # against these values, so keep them stable.
    OPEN = "open"
    CLOSED = "closed"
    PREMARKET = "premarket"
    AFTERHOURS = "afterhours"
|
||||
|
||||
|
||||
class MarketService:
|
||||
"""Market data service for price management"""
|
||||
|
||||
    def __init__(
        self,
        tickers: List[str],
        poll_interval: int = 10,
        mock_mode: bool = False,
        backtest_mode: bool = False,
        api_key: Optional[str] = None,
        backtest_start_date: Optional[str] = None,
        backtest_end_date: Optional[str] = None,
    ):
        """Configure the service; no I/O happens until start().

        Args:
            tickers: Symbols to track.
            poll_interval: Seconds between price polls (live/mock modes).
            mock_mode: Use generated prices instead of a data provider.
            backtest_mode: Replay historical prices; checked before
                mock_mode in start(), so it takes precedence.
            api_key: Data-provider key (required for live mode).
            backtest_start_date: First backtest date (backtest mode only).
            backtest_end_date: Last backtest date (backtest mode only).
        """
        self.tickers = tickers
        self.poll_interval = poll_interval
        self.mock_mode = mock_mode
        self.backtest_mode = backtest_mode
        self.api_key = api_key
        self.backtest_start_date = backtest_start_date
        self.backtest_end_date = backtest_end_date

        # Latest quote per symbol, written by price-manager callbacks.
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.running = False
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._broadcast_func: Optional[Callable] = None
        self._price_manager: Optional[Any] = None
        self._current_date: Optional[str] = None

        # Market status tracking
        self._last_market_status: Optional[str] = None

        # Session tracking for live returns
        self._session_start_values: Optional[Dict[str, float]] = None
        self._session_start_timestamp: Optional[int] = None
|
||||
|
||||
@property
|
||||
def mode_name(self) -> str:
|
||||
if self.backtest_mode:
|
||||
return "BACKTEST"
|
||||
elif self.mock_mode:
|
||||
return "MOCK"
|
||||
return "LIVE"
|
||||
|
||||
async def start(self, broadcast_func: Callable):
|
||||
"""Start market data service"""
|
||||
if self.running:
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self._loop = asyncio.get_running_loop()
|
||||
self._broadcast_func = broadcast_func
|
||||
|
||||
if self.backtest_mode:
|
||||
self._start_backtest_mode()
|
||||
elif self.mock_mode:
|
||||
self._start_mock_mode()
|
||||
else:
|
||||
self._start_real_mode()
|
||||
|
||||
logger.info(
|
||||
f"Market service started: {self.mode_name}, tickers={self.tickers}", # noqa: E501
|
||||
)
|
||||
|
||||
def _make_price_callback(self) -> Callable:
|
||||
"""Create thread-safe price callback"""
|
||||
|
||||
def callback(price_data: Dict[str, Any]):
|
||||
symbol = price_data["symbol"]
|
||||
self.cache[symbol] = price_data
|
||||
|
||||
loop = self._loop
|
||||
if loop and loop.is_running() and self._broadcast_func:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._broadcast_price_update(price_data),
|
||||
loop,
|
||||
)
|
||||
|
||||
return callback
|
||||
|
||||
def _start_mock_mode(self):
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
|
||||
self._price_manager = MockPriceManager(
|
||||
poll_interval=self.poll_interval,
|
||||
volatility=0.5,
|
||||
)
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(
|
||||
self.tickers,
|
||||
base_prices={t: 100.0 for t in self.tickers},
|
||||
)
|
||||
self._price_manager.start()
|
||||
|
||||
def _start_real_mode(self):
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("API key required for live mode")
|
||||
self._price_manager = PollingPriceManager(
|
||||
api_key=self.api_key,
|
||||
poll_interval=self.poll_interval,
|
||||
)
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(self.tickers)
|
||||
self._price_manager.start()
|
||||
|
||||
def _start_backtest_mode(self):
|
||||
from backend.data.historical_price_manager import (
|
||||
HistoricalPriceManager,
|
||||
)
|
||||
|
||||
self._price_manager = HistoricalPriceManager()
|
||||
self._price_manager.add_price_callback(self._make_price_callback())
|
||||
self._price_manager.subscribe(self.tickers)
|
||||
|
||||
if self.backtest_start_date and self.backtest_end_date:
|
||||
self._price_manager.preload_data(
|
||||
self.backtest_start_date,
|
||||
self.backtest_end_date,
|
||||
)
|
||||
|
||||
self._price_manager.start()
|
||||
|
||||
async def _broadcast_price_update(self, price_data: Dict[str, Any]):
|
||||
"""Broadcast price update to frontend"""
|
||||
if not self._broadcast_func:
|
||||
return
|
||||
|
||||
symbol = price_data["symbol"]
|
||||
price = price_data["price"]
|
||||
open_price = price_data.get("open", price)
|
||||
ret = (
|
||||
((price - open_price) / open_price) * 100 if open_price > 0 else 0
|
||||
)
|
||||
|
||||
await self._broadcast_func(
|
||||
{
|
||||
"type": "price_update",
|
||||
"symbol": symbol,
|
||||
"price": price,
|
||||
"open": open_price,
|
||||
"ret": ret,
|
||||
"timestamp": price_data.get("timestamp"),
|
||||
"realtime_prices": {
|
||||
t: self._get_cached_price(t) for t in self.tickers
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
def _get_cached_price(self, ticker: str) -> Dict[str, Any]:
|
||||
"""Get cached price data for a ticker"""
|
||||
if ticker in self.cache:
|
||||
return self.cache[ticker]
|
||||
# Return from price manager if not in cache
|
||||
if self._price_manager:
|
||||
price = self._price_manager.get_latest_price(ticker)
|
||||
if price:
|
||||
return {"price": price, "symbol": ticker}
|
||||
return {"price": 0, "symbol": ticker}
|
||||
|
||||
def stop(self):
|
||||
"""Stop market service"""
|
||||
if not self.running:
|
||||
return
|
||||
self.running = False
|
||||
if self._price_manager:
|
||||
self._price_manager.stop()
|
||||
self._price_manager = None
|
||||
self._loop = None
|
||||
self._broadcast_func = None
|
||||
|
||||
# Backtest methods
|
||||
def set_backtest_date(self, date: str):
|
||||
"""Set current backtest date"""
|
||||
if not self.backtest_mode or not self._price_manager:
|
||||
return
|
||||
self._current_date = date
|
||||
self._price_manager.set_date(date)
|
||||
logger.info(f"Backtest date: {date}")
|
||||
|
||||
async def emit_market_open(self):
|
||||
"""Emit market open prices"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
self._price_manager.emit_open_prices()
|
||||
# Log prices for debugging
|
||||
prices = self.get_open_prices()
|
||||
logger.info(f"Open prices: {prices}")
|
||||
|
||||
async def emit_market_close(self):
|
||||
"""Emit market close prices"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
self._price_manager.emit_close_prices()
|
||||
# Log prices for debugging
|
||||
prices = self.get_close_prices()
|
||||
logger.info(f"Close prices: {prices}")
|
||||
|
||||
def get_open_prices(self) -> Dict[str, float]:
|
||||
"""Get open prices for all tickers"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = None
|
||||
# Try price manager first
|
||||
if self.backtest_mode and self._price_manager:
|
||||
price = self._price_manager.get_open_price(ticker)
|
||||
# Fallback to cache
|
||||
if price is None or price <= 0:
|
||||
cached = self.cache.get(ticker, {})
|
||||
price = cached.get("open") or cached.get("price")
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
def get_close_prices(self) -> Dict[str, float]:
|
||||
"""Get close prices for all tickers"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = None
|
||||
# Try price manager first
|
||||
if self.backtest_mode and self._price_manager:
|
||||
price = self._price_manager.get_close_price(ticker)
|
||||
# Fallback to cache
|
||||
if price is None or price <= 0:
|
||||
cached = self.cache.get(ticker, {})
|
||||
price = cached.get("close") or cached.get("price")
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
def get_price_for_date(
|
||||
self,
|
||||
ticker: str,
|
||||
date: str,
|
||||
price_type: str = "close",
|
||||
) -> Optional[float]:
|
||||
"""Get price for a specific date"""
|
||||
if self.backtest_mode and self._price_manager:
|
||||
return self._price_manager.get_price_for_date(
|
||||
ticker,
|
||||
date,
|
||||
price_type,
|
||||
)
|
||||
return self.get_price_sync(ticker)
|
||||
|
||||
# Common methods
|
||||
def get_price_sync(self, ticker: str) -> Optional[float]:
|
||||
"""Get latest price synchronously"""
|
||||
# Try cache first
|
||||
data = self.cache.get(ticker)
|
||||
if data and data.get("price"):
|
||||
return data["price"]
|
||||
# Try price manager
|
||||
if self._price_manager:
|
||||
return self._price_manager.get_latest_price(ticker)
|
||||
return None
|
||||
|
||||
def get_all_prices(self) -> Dict[str, float]:
|
||||
"""Get all latest prices"""
|
||||
prices = {}
|
||||
for ticker in self.tickers:
|
||||
price = self.get_price_sync(ticker)
|
||||
prices[ticker] = price if price and price > 0 else 0.0
|
||||
return prices
|
||||
|
||||
# Live mode async waiting methods
|
||||
|
||||
def _now_nyse(self) -> datetime:
|
||||
"""Get current time in NYSE timezone"""
|
||||
return datetime.now(NYSE_TZ)
|
||||
|
||||
def _is_trading_day(self, date: datetime) -> bool:
|
||||
"""Check if date is a NYSE trading day"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
valid_days = NYSE_CALENDAR.valid_days(
|
||||
start_date=date_str,
|
||||
end_date=date_str,
|
||||
)
|
||||
return len(valid_days) > 0
|
||||
|
||||
def _get_market_hours(self, date: datetime) -> tuple:
|
||||
"""Get market open and close times for a given date"""
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
schedule = NYSE_CALENDAR.schedule(
|
||||
start_date=date_str,
|
||||
end_date=date_str,
|
||||
)
|
||||
if schedule.empty:
|
||||
return None, None
|
||||
market_open = schedule.iloc[0]["market_open"].to_pydatetime()
|
||||
market_close = schedule.iloc[0]["market_close"].to_pydatetime()
|
||||
return market_open, market_close
|
||||
|
||||
def _next_trading_day(self, from_date: datetime) -> datetime:
|
||||
"""Find the next trading day from given date"""
|
||||
check_date = from_date + timedelta(days=1)
|
||||
for _ in range(10): # Max 10 days ahead (handles holidays)
|
||||
if self._is_trading_day(check_date):
|
||||
return check_date
|
||||
check_date += timedelta(days=1)
|
||||
return check_date
|
||||
|
||||
def _get_trading_date_for_execution(self) -> tuple:
|
||||
"""
|
||||
Determine the trading date for execution.
|
||||
|
||||
Returns:
|
||||
(trading_date, market_open_time, market_close_time)
|
||||
|
||||
Logic:
|
||||
- If today is a trading day and market has opened: use today
|
||||
- If today is a trading day but market hasn't opened: wait for open
|
||||
- If today is not a trading day: use next trading day
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
if self._is_trading_day(today):
|
||||
market_open, market_close = self._get_market_hours(today)
|
||||
return today, market_open, market_close
|
||||
else:
|
||||
# Weekend or holiday - find next trading day
|
||||
next_day = self._next_trading_day(today)
|
||||
market_open, market_close = self._get_market_hours(next_day)
|
||||
return next_day, market_open, market_close
|
||||
|
||||
async def wait_for_open_prices(self) -> Dict[str, float]:
|
||||
"""
|
||||
Wait for market open and return open prices.
|
||||
|
||||
Behavior:
|
||||
- If market is already open today: return current prices immediately
|
||||
- If market hasn't opened yet today: wait until open
|
||||
- If not a trading day: wait until next trading day opens
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
trading_date, market_open, _ = self._get_trading_date_for_execution()
|
||||
|
||||
if market_open is None:
|
||||
logger.warning("Could not determine market hours")
|
||||
return self.get_all_prices()
|
||||
|
||||
trading_date_str = trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Check if we need to wait
|
||||
if now < market_open:
|
||||
wait_seconds = (market_open - now).total_seconds()
|
||||
logger.info(
|
||||
f"Waiting {wait_seconds/60:.1f} min for market open "
|
||||
f"({trading_date_str} {market_open.strftime('%H:%M')} ET)",
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
# Small delay to ensure prices are available
|
||||
await asyncio.sleep(5)
|
||||
else:
|
||||
logger.info(
|
||||
f"Market already open for {trading_date_str}, "
|
||||
f"getting current prices",
|
||||
)
|
||||
|
||||
# Poll until we have valid prices
|
||||
prices = await self._poll_for_prices()
|
||||
logger.info(f"Got open prices for {trading_date_str}: {prices}")
|
||||
return prices
|
||||
|
||||
async def wait_for_close_prices(self) -> Dict[str, float]:
|
||||
"""
|
||||
Wait for market close and return close prices.
|
||||
|
||||
Behavior:
|
||||
- If market is already closed today: return current prices immediately
|
||||
- If market hasn't closed yet: wait until close
|
||||
"""
|
||||
now = self._now_nyse()
|
||||
trading_date, _, market_close = self._get_trading_date_for_execution()
|
||||
|
||||
if market_close is None:
|
||||
logger.warning("Could not determine market hours")
|
||||
return self.get_all_prices()
|
||||
|
||||
trading_date_str = trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
# Check if we need to wait
|
||||
if now < market_close:
|
||||
wait_seconds = (market_close - now).total_seconds()
|
||||
logger.info(
|
||||
f"Waiting {wait_seconds/60:.1f} min for market close "
|
||||
f"({trading_date_str} {market_close.strftime('%H:%M')} ET)",
|
||||
)
|
||||
await asyncio.sleep(wait_seconds)
|
||||
# Small delay to ensure final prices settle
|
||||
await asyncio.sleep(10)
|
||||
else:
|
||||
logger.info(
|
||||
f"Market already closed for {trading_date_str}, "
|
||||
f"getting close prices",
|
||||
)
|
||||
|
||||
# Get final prices
|
||||
prices = await self._poll_for_prices()
|
||||
logger.info(f"Got close prices for {trading_date_str}: {prices}")
|
||||
return prices
|
||||
|
||||
def get_live_trading_date(self) -> str:
|
||||
"""Get the trading date that will be used for live execution"""
|
||||
trading_date, _, _ = self._get_trading_date_for_execution()
|
||||
return trading_date.strftime("%Y-%m-%d")
|
||||
|
||||
async def _poll_for_prices(
|
||||
self,
|
||||
max_retries: int = 12,
|
||||
) -> Dict[str, float]:
|
||||
"""Poll until all prices are available"""
|
||||
for _ in range(max_retries):
|
||||
prices = self.get_all_prices()
|
||||
if all(p > 0 for p in prices.values()):
|
||||
return prices
|
||||
logger.debug("Waiting for prices to be available...")
|
||||
await asyncio.sleep(5)
|
||||
# Return whatever we have
|
||||
return self.get_all_prices()
|
||||
|
||||
# ========== Market Status Methods ==========
|
||||
|
||||
def get_market_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current market status
|
||||
|
||||
Returns:
|
||||
Dict with status info:
|
||||
- status: 'open' | 'closed' | 'premarket' | 'afterhours'
|
||||
- status_text: Human readable status
|
||||
- is_trading_day: Whether today is a trading day
|
||||
- market_open: Market open time (if trading day)
|
||||
- market_close: Market close time (if trading day)
|
||||
"""
|
||||
if self.backtest_mode:
|
||||
# In backtest mode, always return open
|
||||
return {
|
||||
"status": MarketStatus.OPEN,
|
||||
"status_text": "Backtest Mode",
|
||||
"is_trading_day": True,
|
||||
}
|
||||
|
||||
now = self._now_nyse()
|
||||
today = now.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
|
||||
is_trading = self._is_trading_day(today)
|
||||
|
||||
if not is_trading:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed (Non-trading Day)",
|
||||
"is_trading_day": False,
|
||||
}
|
||||
|
||||
market_open, market_close = self._get_market_hours(today)
|
||||
|
||||
if market_open is None or market_close is None:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed",
|
||||
"is_trading_day": is_trading,
|
||||
}
|
||||
|
||||
# Determine status based on current time
|
||||
if now < market_open:
|
||||
return {
|
||||
"status": MarketStatus.PREMARKET,
|
||||
"status_text": "Pre-Market",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
elif now > market_close:
|
||||
return {
|
||||
"status": MarketStatus.CLOSED,
|
||||
"status_text": "Market Closed",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": MarketStatus.OPEN,
|
||||
"status_text": "Market Open",
|
||||
"is_trading_day": True,
|
||||
"market_open": market_open.isoformat(),
|
||||
"market_close": market_close.isoformat(),
|
||||
}
|
||||
|
||||
async def check_and_broadcast_market_status(self):
|
||||
"""Check market status and broadcast if changed"""
|
||||
status = self.get_market_status()
|
||||
current_status = status["status"]
|
||||
|
||||
if current_status != self._last_market_status:
|
||||
self._last_market_status = current_status
|
||||
await self._broadcast_market_status(status)
|
||||
|
||||
# Handle session transitions
|
||||
if current_status == MarketStatus.OPEN:
|
||||
await self._on_session_start()
|
||||
elif (
|
||||
current_status == MarketStatus.CLOSED
|
||||
and self._session_start_values is not None
|
||||
):
|
||||
self._on_session_end()
|
||||
|
||||
async def _broadcast_market_status(self, status: Dict[str, Any]):
|
||||
"""Broadcast market status update to frontend"""
|
||||
if not self._broadcast_func:
|
||||
return
|
||||
|
||||
await self._broadcast_func(
|
||||
{
|
||||
"type": "market_status_update",
|
||||
"market_status": status,
|
||||
"timestamp": datetime.now(NYSE_TZ).isoformat(),
|
||||
},
|
||||
)
|
||||
logger.info(f"Market status: {status['status_text']}")
|
||||
|
||||
async def _on_session_start(self):
|
||||
"""Called when market session starts - capture baseline values"""
|
||||
# Wait briefly for prices to be available
|
||||
await asyncio.sleep(2)
|
||||
|
||||
prices = self.get_all_prices()
|
||||
if prices and any(p > 0 for p in prices.values()):
|
||||
self._session_start_values = prices.copy()
|
||||
self._session_start_timestamp = int(
|
||||
datetime.now().timestamp() * 1000,
|
||||
)
|
||||
logger.info(f"Session started with prices: {prices}")
|
||||
|
||||
def _on_session_end(self):
|
||||
"""Called when market session ends - clear session data"""
|
||||
self._session_start_values = None
|
||||
self._session_start_timestamp = None
|
||||
logger.info("Session ended, cleared session data")
|
||||
|
||||
def get_session_returns(
|
||||
self,
|
||||
current_prices: Dict[str, float],
|
||||
portfolio_value: Optional[float] = None,
|
||||
session_start_portfolio_value: Optional[float] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Calculate session returns (from session start to now)
|
||||
|
||||
Args:
|
||||
current_prices: Current prices for tickers
|
||||
portfolio_value: Current portfolio value (optional)
|
||||
session_start_portfolio_value:
|
||||
|
||||
Returns:
|
||||
Dict with return data or None if session not started
|
||||
"""
|
||||
if self._session_start_values is None:
|
||||
return None
|
||||
|
||||
timestamp = int(datetime.now().timestamp() * 1000)
|
||||
returns = {}
|
||||
|
||||
# Calculate individual ticker returns
|
||||
for ticker, start_price in self._session_start_values.items():
|
||||
current = current_prices.get(ticker)
|
||||
if current and start_price and start_price > 0:
|
||||
ret = ((current - start_price) / start_price) * 100
|
||||
returns[ticker] = round(ret, 4)
|
||||
|
||||
result = {
|
||||
"timestamp": timestamp,
|
||||
"ticker_returns": returns,
|
||||
}
|
||||
|
||||
# Calculate portfolio return if values provided
|
||||
if (
|
||||
portfolio_value is not None
|
||||
and session_start_portfolio_value is not None
|
||||
):
|
||||
if session_start_portfolio_value > 0:
|
||||
portfolio_ret = (
|
||||
(portfolio_value - session_start_portfolio_value)
|
||||
/ session_start_portfolio_value
|
||||
) * 100
|
||||
result["portfolio_return"] = round(portfolio_ret, 4)
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def session_start_values(self) -> Optional[Dict[str, float]]:
|
||||
"""Get session start values for external use"""
|
||||
return self._session_start_values
|
||||
|
||||
@property
|
||||
def session_start_timestamp(self) -> Optional[int]:
|
||||
"""Get session start timestamp"""
|
||||
return self._session_start_timestamp
|
||||
1099
backend/services/storage.py
Normal file
1099
backend/services/storage.py
Normal file
File diff suppressed because it is too large
Load Diff
0
backend/tests/__init__.py
Normal file
0
backend/tests/__init__.py
Normal file
580
backend/tests/test_agents.py
Normal file
580
backend/tests/test_agents.py
Normal file
@@ -0,0 +1,580 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=W0212
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from agentscope.message import Msg
|
||||
|
||||
|
||||
class TestAnalystAgent:
    """Tests for AnalystAgent construction and prompt loading."""

    @staticmethod
    def _build(analyst_type, **extra):
        """Construct an AnalystAgent with mocked collaborators."""
        from backend.agents.analyst import AnalystAgent

        return AnalystAgent(
            analyst_type=analyst_type,
            toolkit=MagicMock(),
            model=MagicMock(),
            formatter=MagicMock(),
            **extra,
        )

    def test_init_valid_analyst_type(self):
        # A known analyst type fills in key, derived name, and persona.
        agent = self._build("technical_analyst")

        assert agent.analyst_type_key == "technical_analyst"
        assert agent.name == "technical_analyst_analyst"
        assert agent.analyst_persona == "Technical Analyst"

    def test_init_invalid_analyst_type(self):
        # An unrecognized type must be rejected at construction time.
        with pytest.raises(ValueError) as excinfo:
            self._build("invalid_type")

        assert "Unknown analyst type" in str(excinfo.value)

    def test_init_custom_agent_id(self):
        # An explicit agent_id overrides the derived name.
        agent = self._build(
            "fundamentals_analyst",
            agent_id="custom_analyst_id",
        )

        assert agent.name == "custom_analyst_id"

    def test_load_system_prompt(self):
        # The system prompt loads as a non-empty string.
        agent = self._build("sentiment_analyst")

        prompt = agent._load_system_prompt()
        assert isinstance(prompt, str)
        assert len(prompt) > 0
|
||||
|
||||
|
||||
class TestPMAgent:
    """Tests for PMAgent portfolio state management and decision recording."""

    def test_init_default(self):
        """Defaults: $100k cash, no positions, 25% margin requirement."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
        )

        assert agent.name == "portfolio_manager"
        assert agent.portfolio["cash"] == 100000.0
        assert agent.portfolio["positions"] == {}
        assert agent.portfolio["margin_requirement"] == 0.25

    def test_init_custom_cash(self):
        """Constructor honors custom initial cash and margin requirement."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
            initial_cash=50000.0,
            margin_requirement=0.5,
        )

        assert agent.portfolio["cash"] == 50000.0
        assert agent.portfolio["margin_requirement"] == 0.5

    def test_get_portfolio_state(self):
        """get_portfolio_state returns a copy, not the live portfolio dict."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
            initial_cash=75000.0,
        )

        state = agent.get_portfolio_state()

        assert state["cash"] == 75000.0
        assert state is not agent.portfolio  # Should be a copy

    def test_load_portfolio_state(self):
        """load_portfolio_state replaces cash and positions wholesale."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
        )

        new_portfolio = {
            "cash": 50000.0,
            "positions": {
                "AAPL": {"long": 100, "short": 0, "long_cost_basis": 150.0},
            },
            "margin_used": 1000.0,
        }

        agent.load_portfolio_state(new_portfolio)

        assert agent.portfolio["cash"] == 50000.0
        assert "AAPL" in agent.portfolio["positions"]

    def test_update_portfolio(self):
        """update_portfolio applies a partial update to the portfolio."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
        )

        agent.update_portfolio({"cash": 80000.0})
        assert agent.portfolio["cash"] == 80000.0

    def _get_text_from_tool_response(self, result):
        """Helper to extract text from ToolResponse content.

        Handles both object-style content (with a ``.text`` attribute)
        and dict-style content ({"text": ...}).
        """
        content = result.content[0]
        if hasattr(content, "text"):
            return content.text
        elif isinstance(content, dict):
            return content.get("text", "")
        return str(content)

    def test_make_decision_long(self):
        """A valid long decision is recorded with action and quantity."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
        )

        result = agent._make_decision(
            ticker="AAPL",
            action="long",
            quantity=100,
            confidence=80,
            reasoning="Strong fundamentals",
        )

        text = self._get_text_from_tool_response(result)
        assert "Decision recorded" in text
        assert agent._decisions["AAPL"]["action"] == "long"
        assert agent._decisions["AAPL"]["quantity"] == 100

    def test_make_decision_hold(self):
        """A hold decision with zero quantity is accepted and recorded."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
        )

        result = agent._make_decision(
            ticker="GOOGL",
            action="hold",
            quantity=0,
            confidence=50,
            reasoning="Neutral outlook",
        )

        text = self._get_text_from_tool_response(result)
        assert "Decision recorded" in text
        assert agent._decisions["GOOGL"]["action"] == "hold"
        assert agent._decisions["GOOGL"]["quantity"] == 0

    def test_make_decision_invalid_action(self):
        """An unknown action is rejected with an 'Invalid action' message."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
        )

        result = agent._make_decision(
            ticker="AAPL",
            action="invalid",
            quantity=10,
        )

        text = self._get_text_from_tool_response(result)
        assert "Invalid action" in text

    def test_get_decisions(self):
        """get_decisions returns all recorded decisions keyed by ticker."""
        from backend.agents.portfolio_manager import PMAgent

        mock_model = MagicMock()
        mock_formatter = MagicMock()

        agent = PMAgent(
            model=mock_model,
            formatter=mock_formatter,
        )

        agent._make_decision("AAPL", "long", 100)
        agent._make_decision("GOOGL", "short", 50)

        decisions = agent.get_decisions()
        assert len(decisions) == 2
        assert decisions["AAPL"]["action"] == "long"
        assert decisions["GOOGL"]["action"] == "short"
|
||||
|
||||
|
||||
class TestRiskAgent:
    """Tests for RiskAgent initialization and system-prompt loading."""

    @staticmethod
    def _build(**extra):
        """Construct a RiskAgent with mocked model and formatter."""
        from backend.agents.risk_manager import RiskAgent

        return RiskAgent(
            model=MagicMock(),
            formatter=MagicMock(),
            **extra,
        )

    def test_init_default(self):
        # Default construction uses the canonical agent name.
        assert self._build().name == "risk_manager"

    def test_init_custom_name(self):
        # An explicit name overrides the default.
        agent = self._build(name="custom_risk_manager")
        assert agent.name == "custom_risk_manager"

    def test_load_system_prompt(self):
        # The system prompt loads as a non-empty string.
        prompt = self._build()._load_system_prompt()
        assert isinstance(prompt, str)
        assert len(prompt) > 0
|
||||
|
||||
|
||||
class TestStorageService:
    """Tests for StorageService valuation, dashboard files, and summaries."""

    def test_calculate_portfolio_value_cash_only(self):
        """With no positions, portfolio value equals cash."""
        from backend.services.storage import StorageService

        with tempfile.TemporaryDirectory() as tmpdir:
            storage = StorageService(
                dashboard_dir=Path(tmpdir),
                initial_cash=100000.0,
            )

            portfolio = {"cash": 100000.0, "positions": {}, "margin_used": 0.0}
            prices = {}

            value = storage.calculate_portfolio_value(portfolio, prices)
            assert value == 100000.0

    def test_calculate_portfolio_value_with_positions(self):
        """Long positions add value; short positions subtract.

        Expected: 50000 cash + 100*150 (AAPL long) - 10*100 (GOOGL
        short) + 5000 margin_used = 69000.
        """
        from backend.services.storage import StorageService

        with tempfile.TemporaryDirectory() as tmpdir:
            storage = StorageService(
                dashboard_dir=Path(tmpdir),
                initial_cash=100000.0,
            )

            portfolio = {
                "cash": 50000.0,
                "positions": {
                    "AAPL": {"long": 100, "short": 0},
                    "GOOGL": {"long": 0, "short": 10},
                },
                "margin_used": 5000.0,
            }
            prices = {"AAPL": 150.0, "GOOGL": 100.0}

            value = storage.calculate_portfolio_value(portfolio, prices)
            assert value == 69000.0

    def test_update_dashboard_after_cycle(self):
        """A trading cycle writes summary, holdings, and trades files."""
        from backend.services.storage import StorageService

        with tempfile.TemporaryDirectory() as tmpdir:
            storage = StorageService(
                dashboard_dir=Path(tmpdir),
                initial_cash=100000.0,
            )

            portfolio = {
                "cash": 90000.0,
                "positions": {"AAPL": {"long": 50, "short": 0}},
                "margin_used": 0.0,
            }
            prices = {"AAPL": 200.0}

            storage.update_dashboard_after_cycle(
                portfolio=portfolio,
                prices=prices,
                date="2024-01-15",
                executed_trades=[
                    {
                        "ticker": "AAPL",
                        "action": "long",
                        "quantity": 50,
                        "price": 200.0,
                    },
                ],
            )

            summary = storage.load_file("summary")
            assert summary is not None
            assert summary["totalAssetValue"] == 100000.0  # 90000 + 50*200

            holdings = storage.load_file("holdings")
            assert holdings is not None
            assert len(holdings) > 0

            # The single executed trade is persisted with qty and price.
            trades = storage.load_file("trades")
            assert trades is not None
            assert len(trades) == 1
            assert trades[0]["ticker"] == "AAPL"
            assert trades[0]["qty"] == 50
            assert trades[0]["price"] == 200.0

    def test_generate_summary(self):
        """Summary reports total value and 0% return at break-even.

        State value: 50000 cash + 100*500 AAPL = 100000 == initial_cash.
        """
        from backend.services.storage import StorageService

        with tempfile.TemporaryDirectory() as tmpdir:
            storage = StorageService(
                dashboard_dir=Path(tmpdir),
                initial_cash=100000.0,
            )

            state = {
                "portfolio_state": {
                    "cash": 50000.0,
                    "positions": {"AAPL": {"long": 100, "short": 0}},
                    "margin_used": 0.0,
                },
                "equity_history": [{"t": 1000, "v": 100000}],
                "all_trades": [],
            }
            prices = {"AAPL": 500.0}

            storage._generate_summary(state, 100000.0, prices)

            summary = storage.load_file("summary")
            assert summary["totalAssetValue"] == 100000.0
            assert summary["totalReturn"] == 0.0

    def test_generate_holdings(self):
        """Holdings include one row per position plus a CASH row."""
        from backend.services.storage import StorageService

        with tempfile.TemporaryDirectory() as tmpdir:
            storage = StorageService(
                dashboard_dir=Path(tmpdir),
                initial_cash=100000.0,
            )

            state = {
                "portfolio_state": {
                    "cash": 50000.0,
                    "positions": {"AAPL": {"long": 100, "short": 0}},
                    "margin_used": 0.0,
                },
            }
            prices = {"AAPL": 500.0}

            storage._generate_holdings(state, prices)

            holdings = storage.load_file("holdings")
            assert len(holdings) == 2  # AAPL + CASH

            aapl_holding = next(
                (h for h in holdings if h["ticker"] == "AAPL"),
                None,
            )
            assert aapl_holding is not None
            assert aapl_holding["quantity"] == 100
            assert aapl_holding["currentPrice"] == 500.0
|
||||
|
||||
|
||||
class TestTradeExecutor:
    """Unit tests for PortfolioTradeExecutor's long/short/hold paths."""

    def test_execute_trade_long(self):
        """Opening a long position debits cash and records the quantity."""
        from backend.utils.trade_executor import PortfolioTradeExecutor

        executor = PortfolioTradeExecutor(
            initial_portfolio={
                "cash": 100000.0,
                "positions": {},
                "margin_requirement": 0.25,
                "margin_used": 0.0,
            },
        )

        outcome = executor.execute_trade(
            ticker="AAPL",
            action="long",
            quantity=10,
            price=150.0,
        )

        assert outcome["status"] == "success"
        assert executor.portfolio["positions"]["AAPL"]["long"] == 10
        # Cash drops by the notional: 10 shares * 150.0 = 1500.0.
        assert executor.portfolio["cash"] == 98500.0

    def test_execute_trade_short(self):
        """Shorting against an existing long reduces the long position first."""
        from backend.utils.trade_executor import PortfolioTradeExecutor

        executor = PortfolioTradeExecutor(
            initial_portfolio={
                "cash": 100000.0,
                "positions": {
                    "AAPL": {
                        "long": 50,
                        "short": 0,
                        "long_cost_basis": 100.0,
                        "short_cost_basis": 0.0,
                    },
                },
                "margin_requirement": 0.25,
                "margin_used": 0.0,
            },
        )

        outcome = executor.execute_trade(
            ticker="AAPL",
            action="short",
            quantity=30,
            price=150.0,
        )

        assert outcome["status"] == "success"
        # Selling 30 against a 50-share long leaves 20 long.
        assert executor.portfolio["positions"]["AAPL"]["long"] == 20

    def test_execute_trade_hold(self):
        """A hold action succeeds without touching the portfolio."""
        from backend.utils.trade_executor import PortfolioTradeExecutor

        executor = PortfolioTradeExecutor()

        outcome = executor.execute_trade(
            ticker="AAPL",
            action="hold",
            quantity=0,
            price=150.0,
        )

        assert outcome["status"] == "success"
        assert outcome["message"] == "No trade needed"
|
||||
|
||||
class TestPipelineExecution:
    """Tests for TradingPipeline decision execution."""

    def test_execute_decisions(self):
        """_execute_decisions applies PM decisions and records every trade."""
        from backend.core.pipeline import TradingPipeline
        from backend.agents.portfolio_manager import PMAgent

        pm = PMAgent(
            model=MagicMock(),
            formatter=MagicMock(),
            initial_cash=100000.0,
        )

        pipeline = TradingPipeline(
            analysts=[],
            risk_manager=MagicMock(),
            portfolio_manager=pm,
            max_comm_cycles=0,
        )

        decisions = {
            "AAPL": {"action": "long", "quantity": 10},
            "GOOGL": {"action": "short", "quantity": 5},
        }
        quotes = {"AAPL": 150.0, "GOOGL": 100.0}

        outcome = pipeline._execute_decisions(decisions, quotes, "2024-01-15")

        executed = outcome["executed_trades"]
        assert len(executed) == 2
        assert executed[0]["ticker"] == "AAPL"
        assert executed[0]["quantity"] == 10
        # The PM's own portfolio reflects the executed long.
        assert pm.portfolio["positions"]["AAPL"]["long"] == 10
|
||||
|
||||
class TestMsgContentIsString:
    """Verify Msg.content is always carried as a plain string, never a dict."""

    def test_msg_content_string(self):
        message = Msg(name="test", content="simple string", role="user")
        assert isinstance(message.content, str)

    def test_msg_content_json_string(self):
        """Structured payloads are serialized to JSON strings before sending."""
        payload = {"key": "value", "nested": {"a": 1}}
        message = Msg(name="test", content=json.dumps(payload), role="user")
        assert isinstance(message.content, str)

        # Round-trip: the string still decodes back to the original payload.
        assert json.loads(message.content)["key"] == "value"

    def test_msg_content_should_not_be_dict(self):
        message = Msg(
            name="test",
            content=json.dumps({"key": "value"}),
            role="assistant",
        )

        assert not isinstance(message.content, dict)
        assert isinstance(message.content, str)
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly, in verbose mode.
    pytest.main([__file__, "-v"])
|
||||
438
backend/tests/test_market_service.py
Normal file
438
backend/tests/test_market_service.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=W0212
|
||||
import asyncio
|
||||
import time
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import pytest
|
||||
from backend.services.market import MarketService
|
||||
from backend.data.mock_price_manager import MockPriceManager
|
||||
from backend.data.polling_price_manager import PollingPriceManager
|
||||
|
||||
|
||||
class TestMockPriceManager:
    """Tests for the simulated price feed (MockPriceManager)."""

    def test_init_default(self):
        """Defaults: 10s poll interval, 0.5 volatility, idle, no subscriptions."""
        mgr = MockPriceManager()

        assert mgr.poll_interval == 10
        assert mgr.volatility == 0.5
        assert mgr.running is False
        assert len(mgr.subscribed_symbols) == 0

    def test_init_custom(self):
        """Constructor honours a custom poll interval and volatility."""
        mgr = MockPriceManager(poll_interval=5, volatility=1.0)

        assert mgr.poll_interval == 5
        assert mgr.volatility == 1.0

    def test_subscribe(self):
        """Subscribing known tickers seeds their built-in default prices."""
        mgr = MockPriceManager()
        mgr.subscribe(["AAPL", "MSFT"])

        assert "AAPL" in mgr.subscribed_symbols
        assert "MSFT" in mgr.subscribed_symbols
        assert mgr.base_prices["AAPL"] == 237.50  # built-in default
        assert mgr.base_prices["MSFT"] == 425.30  # built-in default

    def test_subscribe_with_base_prices(self):
        """Explicit base prices initialise base, open and latest identically."""
        mgr = MockPriceManager()
        mgr.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        for table in (mgr.base_prices, mgr.open_prices, mgr.latest_prices):
            assert table["AAPL"] == 100.0

    def test_subscribe_unknown_symbol(self):
        """Unknown tickers get a positive, randomly generated base price."""
        mgr = MockPriceManager()
        mgr.subscribe(["UNKNOWN"])

        assert "UNKNOWN" in mgr.subscribed_symbols
        assert mgr.base_prices["UNKNOWN"] > 0

    def test_unsubscribe(self):
        """Unsubscribing removes only the requested symbols."""
        mgr = MockPriceManager()
        mgr.subscribe(["AAPL", "MSFT"])
        mgr.unsubscribe(["AAPL"])

        assert "AAPL" not in mgr.subscribed_symbols
        assert "MSFT" in mgr.subscribed_symbols

    def test_add_price_callback(self):
        """Registered callbacks are retained in price_callbacks."""
        mgr = MockPriceManager()
        listener = MagicMock()
        mgr.add_price_callback(listener)

        assert listener in mgr.price_callbacks

    def test_generate_price_update_within_bounds(self):
        """Generated prices stay within +/-10% of the open price."""
        mgr = MockPriceManager(volatility=0.5)
        mgr.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        for _ in range(100):
            assert 90.0 <= mgr._generate_price_update("AAPL") <= 110.0

    def test_update_prices_triggers_callback(self):
        """_update_prices fires each callback with symbol/price/timestamp."""
        mgr = MockPriceManager()
        mgr.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        listener = MagicMock()
        mgr.add_price_callback(listener)

        mgr._update_prices()

        listener.assert_called_once()
        payload = listener.call_args[0][0]
        assert payload["symbol"] == "AAPL"
        assert "price" in payload
        assert "timestamp" in payload

    def test_start_stop(self):
        """start() flips running on; stop() flips it back off."""
        mgr = MockPriceManager(poll_interval=1)
        mgr.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        mgr.start()
        assert mgr.running is True

        time.sleep(0.1)  # give the worker thread a moment to spin up

        mgr.stop()
        assert mgr.running is False

    def test_start_without_subscription(self):
        """start() is a no-op while nothing is subscribed."""
        mgr = MockPriceManager()
        mgr.start()

        assert mgr.running is False

    def test_get_latest_price(self):
        mgr = MockPriceManager()
        mgr.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        assert mgr.get_latest_price("AAPL") == 100.0

    def test_get_latest_price_unknown(self):
        """Unknown symbols yield None rather than raising."""
        mgr = MockPriceManager()
        assert mgr.get_latest_price("UNKNOWN") is None

    def test_get_all_latest_prices(self):
        mgr = MockPriceManager()
        mgr.subscribe(
            ["AAPL", "MSFT"],
            base_prices={"AAPL": 100.0, "MSFT": 200.0},
        )

        snapshot = mgr.get_all_latest_prices()
        assert snapshot["AAPL"] == 100.0
        assert snapshot["MSFT"] == 200.0

    def test_reset_open_prices(self):
        """reset_open_prices re-seeds opens from the latest price (with a gap)."""
        mgr = MockPriceManager()
        mgr.subscribe(["AAPL"], base_prices={"AAPL": 100.0})
        mgr.latest_prices["AAPL"] = 105.0

        mgr.reset_open_prices()

        # Open should move off the original 100.0 base.
        assert mgr.open_prices["AAPL"] != 100.0

    def test_set_base_price(self):
        """set_base_price overwrites base, open and latest quotes together."""
        mgr = MockPriceManager()
        mgr.subscribe(["AAPL"], base_prices={"AAPL": 100.0})

        mgr.set_base_price("AAPL", 150.0)

        assert mgr.base_prices["AAPL"] == 150.0
        assert mgr.open_prices["AAPL"] == 150.0
        assert mgr.latest_prices["AAPL"] == 150.0
|
||||
|
||||
class TestPollingPriceManager:
    """Tests for the real-API polling price manager (network calls mocked)."""

    def test_init(self):
        """Constructor stores API key and poll interval; manager starts idle."""
        manager = PollingPriceManager(api_key="test_key", poll_interval=30)

        assert manager.api_key == "test_key"
        assert manager.poll_interval == 30
        assert manager.running is False

    def test_subscribe(self):
        """Subscribed symbols are tracked in subscribed_symbols."""
        manager = PollingPriceManager(api_key="test_key")
        manager.subscribe(["AAPL", "MSFT"])

        assert "AAPL" in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_unsubscribe(self):
        """Unsubscribing removes only the requested symbols."""
        manager = PollingPriceManager(api_key="test_key")
        manager.subscribe(["AAPL", "MSFT"])
        manager.unsubscribe(["AAPL"])

        assert "AAPL" not in manager.subscribed_symbols
        assert "MSFT" in manager.subscribed_symbols

    def test_add_price_callback(self):
        """Registered callbacks are retained in price_callbacks."""
        manager = PollingPriceManager(api_key="test_key")
        callback = MagicMock()
        manager.add_price_callback(callback)

        assert callback in manager.price_callbacks

    @patch.object(PollingPriceManager, "_fetch_prices")
    def test_start_stop(self, mock_fetch):
        # BUG FIX: @patch.object injects the patched mock as a positional
        # argument; the previous signature (self only) raised
        # TypeError("takes 1 positional argument but 2 were given") at runtime.
        # Patching _fetch_prices also keeps the polling thread off the network.
        manager = PollingPriceManager(api_key="test_key", poll_interval=1)
        manager.subscribe(["AAPL"])

        manager.start()
        assert manager.running is True

        time.sleep(0.1)  # let the polling thread spin up

        manager.stop()
        assert manager.running is False

    def test_start_without_subscription(self):
        """start() is a no-op while nothing is subscribed."""
        manager = PollingPriceManager(api_key="test_key")
        manager.start()

        assert manager.running is False

    def test_get_latest_price(self):
        """get_latest_price returns the cached quote."""
        manager = PollingPriceManager(api_key="test_key")
        manager.latest_prices["AAPL"] = 150.0

        price = manager.get_latest_price("AAPL")
        assert price == 150.0

    def test_get_open_price(self):
        """get_open_price returns the cached opening quote."""
        manager = PollingPriceManager(api_key="test_key")
        manager.open_prices["AAPL"] = 148.0

        price = manager.get_open_price("AAPL")
        assert price == 148.0

    def test_reset_open_prices(self):
        """reset_open_prices clears all cached opening quotes."""
        manager = PollingPriceManager(api_key="test_key")
        manager.open_prices["AAPL"] = 150.0

        manager.reset_open_prices()

        assert len(manager.open_prices) == 0
|
||||
|
||||
class TestMarketService:
    """Tests for MarketService in mock and real modes."""

    def test_init_mock_mode(self):
        """Mock-mode construction stores tickers, interval and flags."""
        svc = MarketService(
            tickers=["AAPL", "MSFT"],
            poll_interval=10,
            mock_mode=True,
        )

        assert svc.tickers == ["AAPL", "MSFT"]
        assert svc.poll_interval == 10
        assert svc.mock_mode is True
        assert svc.running is False

    def test_init_real_mode(self):
        """Real-mode construction stores the API key."""
        svc = MarketService(
            tickers=["AAPL"],
            mock_mode=False,
            api_key="test_key",
        )

        assert svc.mock_mode is False
        assert svc.api_key == "test_key"

    @pytest.mark.asyncio
    async def test_start_mock_mode(self):
        """Starting in mock mode wires up a MockPriceManager."""
        svc = MarketService(
            tickers=["AAPL"],
            poll_interval=10,
            mock_mode=True,
        )

        await svc.start(AsyncMock())

        assert svc.running is True
        assert svc._price_manager is not None
        assert isinstance(svc._price_manager, MockPriceManager)

        svc.stop()

    @pytest.mark.asyncio
    async def test_start_real_mode_without_api_key(self):
        """Real mode without an API key must raise ValueError."""
        svc = MarketService(
            tickers=["AAPL"],
            mock_mode=False,
            api_key=None,
        )

        with pytest.raises(ValueError) as excinfo:
            await svc.start(AsyncMock())

        assert "API key required" in str(excinfo.value)

    @pytest.mark.asyncio
    async def test_start_already_running(self):
        """A second start() on a running service must be a harmless no-op."""
        svc = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )
        broadcaster = AsyncMock()

        await svc.start(broadcaster)
        assert svc.running is True

        await svc.start(broadcaster)  # must not raise

        svc.stop()

    def test_stop(self):
        """stop() clears the running flag and drops the price manager."""
        svc = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )
        svc.running = True
        svc._price_manager = MagicMock()

        svc.stop()

        assert svc.running is False
        assert svc._price_manager is None

    def test_stop_when_not_running(self):
        """stop() on an idle service is safe."""
        svc = MarketService(
            tickers=["AAPL"],
            mock_mode=True,
        )

        svc.stop()  # must not raise
        assert svc.running is False

    def test_get_price_sync(self):
        """get_price_sync returns the cached price for a known ticker."""
        svc = MarketService(tickers=["AAPL"], mock_mode=True)
        svc.cache["AAPL"] = {"price": 150.0, "open": 148.0}

        assert svc.get_price_sync("AAPL") == 150.0

    def test_get_price_sync_not_found(self):
        """Unknown tickers yield None."""
        svc = MarketService(tickers=["AAPL"], mock_mode=True)

        assert svc.get_price_sync("MSFT") is None

    def test_get_all_prices(self):
        """get_all_prices flattens the cache into ticker -> price."""
        svc = MarketService(tickers=["AAPL", "MSFT"], mock_mode=True)
        svc.cache["AAPL"] = {"price": 150.0}
        svc.cache["MSFT"] = {"price": 400.0}

        snapshot = svc.get_all_prices()

        assert snapshot["AAPL"] == 150.0
        assert snapshot["MSFT"] == 400.0

    @pytest.mark.asyncio
    async def test_broadcast_price_update(self):
        """_broadcast_price_update forwards a typed payload downstream."""
        svc = MarketService(tickers=["AAPL"], mock_mode=True)
        svc._broadcast_func = AsyncMock()

        await svc._broadcast_price_update(
            {
                "symbol": "AAPL",
                "price": 150.0,
                "open": 148.0,
                "timestamp": 1234567890,
            },
        )

        svc._broadcast_func.assert_called_once()
        sent = svc._broadcast_func.call_args[0][0]
        assert sent["type"] == "price_update"
        assert sent["symbol"] == "AAPL"
        assert sent["price"] == 150.0

    @pytest.mark.asyncio
    async def test_broadcast_price_update_no_func(self):
        """Broadcasting with no registered function must not raise."""
        svc = MarketService(tickers=["AAPL"], mock_mode=True)
        svc._broadcast_func = None

        await svc._broadcast_price_update(
            {"symbol": "AAPL", "price": 150.0, "open": 148.0},
        )

    @pytest.mark.asyncio
    async def test_price_callback_thread_safety(self):
        """Updates produced on the worker thread reach the async broadcaster."""
        svc = MarketService(
            tickers=["AAPL"],
            poll_interval=1,
            mock_mode=True,
        )

        received = []

        async def capture_broadcast(msg):
            received.append(msg)

        await svc.start(capture_broadcast)

        await asyncio.sleep(1.5)  # wait for at least one poll cycle

        svc.stop()

        assert len(received) >= 1
        assert received[0]["type"] == "price_update"
|
||||
|
||||
class TestMarketServiceIntegration:
    """End-to-end mock-mode cycle: start, collect updates, stop."""

    @pytest.mark.asyncio
    async def test_full_mock_cycle(self):
        svc = MarketService(
            tickers=["AAPL", "MSFT"],
            poll_interval=1,
            mock_mode=True,
        )

        messages = []

        async def collect_messages(msg):
            messages.append(msg)

        await svc.start(collect_messages)

        await asyncio.sleep(2.5)  # allow a couple of poll cycles

        svc.stop()

        # At least two updates should have landed in that window.
        assert len(messages) >= 2

        # Every message carries the expected structure.
        seen_symbols = set()
        for update in messages:
            assert update["type"] == "price_update"
            assert "symbol" in update
            assert "price" in update
            assert "ret" in update
            seen_symbols.add(update["symbol"])

        # At least one of the subscribed tickers produced a quote.
        assert "AAPL" in seen_symbols or "MSFT" in seen_symbols
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly, in verbose mode.
    pytest.main([__file__, "-v"])
|
||||
201
backend/tests/test_settlement.py
Normal file
201
backend/tests/test_settlement.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Test Settlement Coordinator and Baseline Calculations
|
||||
"""
|
||||
|
||||
from backend.utils.baselines import (
|
||||
BaselineCalculator,
|
||||
calculate_momentum_scores,
|
||||
)
|
||||
from backend.utils.analyst_tracker import (
|
||||
AnalystPerformanceTracker,
|
||||
update_leaderboard_with_evaluations,
|
||||
)
|
||||
|
||||
|
||||
def test_baseline_equal_weight():
    """Equal-weight baseline yields a positive value and flags initialisation."""
    calculator = BaselineCalculator(initial_capital=100000.0)

    universe = ["AAPL", "MSFT", "GOOGL"]
    closes = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
    opens = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}

    value = calculator.calculate_equal_weight_value(
        universe,
        opens,
        closes,
    )

    assert value > 0
    assert calculator.equal_weight_initialized is True
|
||||
|
||||
def test_baseline_market_cap_weighted():
    """Market-cap-weighted baseline yields a positive value and flags init."""
    calculator = BaselineCalculator(initial_capital=100000.0)

    universe = ["AAPL", "MSFT", "GOOGL"]
    closes = {"AAPL": 150.0, "MSFT": 300.0, "GOOGL": 120.0}
    opens = {"AAPL": 160.0, "MSFT": 310.0, "GOOGL": 110.0}
    caps = {"AAPL": 3e12, "MSFT": 2e12, "GOOGL": 1.5e12}

    value = calculator.calculate_market_cap_weighted_value(
        universe,
        opens,
        closes,
        caps,
    )

    assert value > 0
    assert calculator.market_cap_initialized is True
||||
|
||||
|
||||
def test_momentum_scores():
    """Rising series score positive momentum; falling series score negative."""
    history = {
        # AAPL climbs 100 -> 110 across three sessions.
        "AAPL": [
            ("2024-01-01", 100.0),
            ("2024-01-02", 105.0),
            ("2024-01-03", 110.0),
        ],
        # MSFT slides 200 -> 190 across three sessions.
        "MSFT": [
            ("2024-01-01", 200.0),
            ("2024-01-02", 195.0),
            ("2024-01-03", 190.0),
        ],
    }

    scores = calculate_momentum_scores(
        ["AAPL", "MSFT"],
        history,
        lookback_days=2,
    )

    assert scores["AAPL"] > 0
    assert scores["MSFT"] < 0
||||
|
||||
|
||||
def test_analyst_tracker_predictions():
    """Structured up/down/neutral predictions map to long/short/hold."""
    tracker = AnalystPerformanceTracker()

    structured_predictions = [
        {
            "agent": "technical_analyst",
            "predictions": [
                {"ticker": "AAPL", "direction": "up", "confidence": 0.8},
                {"ticker": "MSFT", "direction": "down", "confidence": 0.7},
                {"ticker": "GOOGL", "direction": "neutral", "confidence": 0.5},
            ],
        },
        {
            "agent": "fundamentals_analyst",
            "predictions": [
                {"ticker": "AAPL", "direction": "up", "confidence": 0.9},
                {"ticker": "MSFT", "direction": "up", "confidence": 0.6},
                {"ticker": "GOOGL", "direction": "down", "confidence": 0.75},
            ],
        },
    ]

    tracker.record_analyst_predictions(structured_predictions)

    recorded = tracker.daily_predictions
    assert "technical_analyst" in recorded
    assert "fundamentals_analyst" in recorded
    # Directions are normalised: up -> long, down -> short, neutral -> hold.
    assert recorded["technical_analyst"]["AAPL"] == "long"
    assert recorded["technical_analyst"]["MSFT"] == "short"
    assert recorded["technical_analyst"]["GOOGL"] == "hold"
||||
|
||||
|
||||
def test_analyst_evaluation():
    """Correct long/short calls yield a perfect win rate and dated signals."""
    tracker = AnalystPerformanceTracker()

    # AAPL rises (long is right), MSFT falls (short is right).
    tracker.daily_predictions = {
        "technical_analyst": {
            "AAPL": "long",
            "MSFT": "short",
        },
    }

    evaluations = tracker.evaluate_predictions(
        {"AAPL": 100.0, "MSFT": 200.0},  # open prices
        {"AAPL": 105.0, "MSFT": 195.0},  # close prices
        "2024-01-15",
    )

    assert "technical_analyst" in evaluations
    result = evaluations["technical_analyst"]
    assert result["correct_predictions"] == 2
    assert result["win_rate"] == 1.0

    # Each per-ticker signal record carries the fields the frontend expects.
    assert "signals" in result
    assert len(result["signals"]) == 2
    for signal in result["signals"]:
        for field in ("ticker", "signal", "date", "is_correct"):
            assert field in signal
        assert signal["date"] == "2024-01-15"
||||
|
||||
|
||||
def test_leaderboard_update():
    """Evaluations fold bull/bear tallies, win rate and signals into the board."""
    board = [
        {
            "agentId": "technical_analyst",
            "name": "Technical Analyst",
            "rank": 0,
            "winRate": None,
            "bull": {"n": 0, "win": 0, "unknown": 0},
            "bear": {"n": 0, "win": 0, "unknown": 0},
            "signals": [],
        },
    ]

    day_signals = [
        {
            "ticker": "AAPL",
            "signal": "bull",
            "date": "2024-01-01",
            "is_correct": True,
        },
        {
            "ticker": "MSFT",
            "signal": "bear",
            "date": "2024-01-01",
            "is_correct": False,
        },
    ]
    evaluations = {
        "technical_analyst": {
            "total_predictions": 2,
            "correct_predictions": 1,
            "win_rate": 0.5,
            "bull": {"n": 1, "win": 1, "unknown": 0},
            "bear": {"n": 1, "win": 0, "unknown": 0},
            "hold": 0,
            "signals": day_signals,
        },
    }

    updated = update_leaderboard_with_evaluations(
        board,
        evaluations,
    )

    entry = updated[0]
    assert entry["bull"]["n"] == 1
    assert entry["bull"]["win"] == 1
    assert entry["winRate"] == 0.5
    assert len(entry["signals"]) == 2

    # Signal records keep the shape the frontend expects.
    for signal in entry["signals"]:
        for field in ("ticker", "signal", "date", "is_correct"):
            assert field in signal
||||
0
backend/tools/__init__.py
Normal file
0
backend/tools/__init__.py
Normal file
1289
backend/tools/analysis_tools.py
Normal file
1289
backend/tools/analysis_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
742
backend/tools/data_tools.py
Normal file
742
backend/tools/data_tools.py
Normal file
@@ -0,0 +1,742 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# flake8: noqa: E501
|
||||
# pylint: disable=C0301
|
||||
"""
|
||||
Data fetching tools for financial data.
|
||||
|
||||
All functions use centralized data source configuration from data_config.py.
|
||||
The data source is automatically determined based on available API keys:
|
||||
- Priority: FINNHUB_API_KEY > FINANCIAL_DATASETS_API_KEY
|
||||
"""
|
||||
import datetime
|
||||
import time
|
||||
|
||||
import finnhub
|
||||
import pandas as pd
|
||||
import pandas_market_calendars as mcal
|
||||
import requests
|
||||
|
||||
from backend.config.data_config import (
|
||||
get_config,
|
||||
get_api_key,
|
||||
)
|
||||
from backend.data.cache import get_cache
|
||||
from backend.data.schema import (
|
||||
CompanyFactsResponse,
|
||||
CompanyNews,
|
||||
CompanyNewsResponse,
|
||||
FinancialMetrics,
|
||||
FinancialMetricsResponse,
|
||||
InsiderTrade,
|
||||
InsiderTradeResponse,
|
||||
LineItem,
|
||||
LineItemResponse,
|
||||
Price,
|
||||
PriceResponse,
|
||||
)
|
||||
from backend.utils.settlement import logger
|
||||
|
||||
# Global cache instance
|
||||
_cache = get_cache()
|
||||
|
||||
|
||||
def get_last_tradeday(date: str) -> str:
    """
    Get the previous trading day for the specified date.

    Uses the NYSE calendar; supports both pandas_market_calendars
    (valid_days) and exchange_calendars (sessions_in_range) objects.

    Args:
        date: Date string (YYYY-MM-DD)

    Returns:
        Previous trading day date string (YYYY-MM-DD). Falls back to the
        plain previous calendar day if no calendar data is available.
    """
    current_date = datetime.datetime.strptime(date, "%Y-%m-%d")
    nyse_calendar = mcal.get_calendar("NYSE")

    if nyse_calendar is not None:
        # Look back 90 days so the window always contains trading days.
        start_search = current_date - datetime.timedelta(days=90)

        if hasattr(nyse_calendar, "valid_days"):
            # pandas_market_calendars
            trading_dates = nyse_calendar.valid_days(
                start_date=start_search.strftime("%Y-%m-%d"),
                end_date=current_date.strftime("%Y-%m-%d"),
            )
        else:
            # exchange_calendars
            trading_dates = nyse_calendar.sessions_in_range(
                start_search.strftime("%Y-%m-%d"),
                current_date.strftime("%Y-%m-%d"),
            )

        # Normalise to plain YYYY-MM-DD strings.
        trading_dates_list = [
            pd.Timestamp(d).strftime("%Y-%m-%d") for d in trading_dates
        ]

        if date in trading_dates_list:
            idx = trading_dates_list.index(date)
            if idx > 0:
                # Current date is a trading day: return the session before it.
                return trading_dates_list[idx - 1]
            # First session in the window: recurse from the previous day.
            prev_date = current_date - datetime.timedelta(days=1)
            return get_last_tradeday(prev_date.strftime("%Y-%m-%d"))

        if trading_dates_list:
            # Not a trading day: return the nearest preceding session.
            return trading_dates_list[-1]

    # BUG FIX: the original returned `prev_date.strftime(...)` here, but
    # `prev_date` was only bound inside a branch above, so this path raised
    # NameError. Fall back to the previous calendar day instead.
    return (current_date - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
||||
|
||||
|
||||
def _make_api_request(
    url: str,
    headers: dict,
    method: str = "GET",
    json_data: dict = None,
    max_retries: int = 3,
) -> requests.Response:
    """
    Issue an HTTP request, retrying on 429 with moderate linear backoff.

    Args:
        url: The URL to request
        headers: Headers to include in the request
        method: HTTP method (GET or POST)
        json_data: JSON data for POST requests
        max_retries: Maximum number of retries (default: 3)

    Returns:
        requests.Response: The final response (success, non-429 error,
        or the last 429 after retries are exhausted)
    """
    total_attempts = max_retries + 1  # initial try plus retries
    for attempt in range(total_attempts):
        if method.upper() == "POST":
            response = requests.post(url, headers=headers, json=json_data)
        else:
            response = requests.get(url, headers=headers)

        rate_limited = response.status_code == 429
        if rate_limited and attempt < max_retries:
            # Linear backoff: 60s, 90s, 120s, ...
            delay = 60 + (30 * attempt)
            print(
                f"Rate limited (429). Attempt {attempt + 1}/{max_retries + 1}. Waiting {delay}s before retrying...",
            )
            time.sleep(delay)
            continue

        return response
|
||||
|
||||
|
||||
def get_prices(
    ticker: str,
    start_date: str,
    end_date: str,
) -> list[Price]:
    """
    Fetch daily price bars from cache or the configured API.

    The data source comes from the centralized configuration
    (FINNHUB_API_KEY prioritized over FINANCIAL_DATASETS_API_KEY).

    Args:
        ticker: Stock ticker symbol
        start_date: Start date (YYYY-MM-DD)
        end_date: End date (YYYY-MM-DD)

    Returns:
        list[Price]: List of Price objects (empty if nothing was returned)
    """
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    # The cache key encodes every parameter so only exact matches are reused.
    cache_key = f"{ticker}_{start_date}_{end_date}_{data_source}"

    cached_data = _cache.get_prices(cache_key)
    if cached_data:
        return [Price(**row) for row in cached_data]

    if data_source == "finnhub":
        client = finnhub.Client(api_key=api_key)

        # Finnhub expects UNIX timestamps; extend the end by one day so the
        # end_date itself is included in the range.
        start_ts = int(
            datetime.datetime.strptime(start_date, "%Y-%m-%d").timestamp(),
        )
        end_ts = int(
            (
                datetime.datetime.strptime(end_date, "%Y-%m-%d")
                + datetime.timedelta(days=1)
            ).timestamp(),
        )

        candles = client.stock_candles(ticker, "D", start_ts, end_ts)

        # Finnhub returns parallel arrays; zip them into Price objects.
        prices = [
            Price(
                open=candles["o"][i],
                close=candles["c"][i],
                high=candles["h"][i],
                low=candles["l"][i],
                volume=int(candles["v"][i]),
                time=datetime.datetime.fromtimestamp(candles["t"][i]).strftime(
                    "%Y-%m-%d",
                ),
            )
            for i in range(len(candles["t"]))
        ]
    else:  # financial_datasets
        headers = {"X-API-KEY": api_key}

        url = f"https://api.financialdatasets.ai/prices/?ticker={ticker}&interval=day&interval_multiplier=1&start_date={start_date}&end_date={end_date}"
        response = _make_api_request(url, headers)
        if response.status_code != 200:
            raise ValueError(
                f"Error fetching data: {ticker} - {response.status_code} - {response.text}",
            )

        # Parse and validate the payload with the Pydantic model.
        prices = PriceResponse(**response.json()).prices

    if not prices:
        return []

    # Cache the normalized rows under the comprehensive cache key.
    _cache.set_prices(cache_key, [p.model_dump() for p in prices])
    return prices
|
||||
|
||||
|
||||
def get_financial_metrics(
    ticker: str,
    end_date: str,
    period: str = "ttm",
    limit: int = 10,
) -> list[FinancialMetrics]:
    """Fetch financial metrics for a ticker, preferring the cache.

    The active data source (Finnhub or Financial Datasets) and its API key
    come from the centralized configuration returned by ``get_config()``.

    Args:
        ticker: Stock ticker symbol.
        end_date: End date in YYYY-MM-DD form.
        period: Reporting period type (default ``"ttm"``).
        limit: Maximum number of records to fetch.

    Returns:
        list[FinancialMetrics]: Metrics found (possibly empty).

    Raises:
        ValueError: When the Financial Datasets API responds non-200.
    """
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    # The cache key encodes every parameter so only exact matches hit.
    cache_key = f"{ticker}_{period}_{end_date}_{limit}_{data_source}"

    hit = _cache.get_financial_metrics(cache_key)
    if hit:
        return [FinancialMetrics(**row) for row in hit]

    if data_source == "finnhub":
        # Finnhub's basic-financials endpoint returns a single current
        # snapshot rather than a history, so the result is at most one row.
        client = finnhub.Client(api_key=api_key)
        financials = client.company_basic_financials(ticker, "all")

        if not financials or "metric" not in financials:
            return []

        snapshot = financials.get("metric", {})
        fetched = [_map_finnhub_metrics(ticker, end_date, period, snapshot)]
    else:  # financial_datasets
        headers = {"X-API-KEY": api_key}
        url = f"https://api.financialdatasets.ai/financial-metrics/?ticker={ticker}&report_period_lte={end_date}&limit={limit}&period={period}"
        response = _make_api_request(url, headers)
        if response.status_code != 200:
            raise ValueError(
                f"Error fetching data: {ticker} - {response.status_code} - {response.text}",
            )

        fetched = FinancialMetricsResponse(**response.json()).financial_metrics

    if not fetched:
        return []

    # Store plain dicts so cached rows rehydrate via FinancialMetrics(**row).
    _cache.set_financial_metrics(
        cache_key,
        [m.model_dump() for m in fetched],
    )
    return fetched
|
||||
|
||||
|
||||
def _map_finnhub_metrics(
    ticker: str,
    end_date: str,
    period: str,
    metric_data: dict,
) -> FinancialMetrics:
    """Map Finnhub metric data to FinancialMetrics model.

    Fields that Finnhub's basic-financials payload does not cover are set
    to ``None``; keys absent from the payload also resolve to ``None`` via
    ``dict.get``.
    """
    # FinancialMetrics field name -> key in Finnhub's "metric" payload.
    field_to_key = {
        "market_cap": "marketCapitalization",
        "price_to_earnings_ratio": "peBasicExclExtraTTM",
        "price_to_book_ratio": "pbAnnual",
        "price_to_sales_ratio": "psAnnual",
        "gross_margin": "grossMarginTTM",
        "operating_margin": "operatingMarginTTM",
        "net_margin": "netProfitMarginTTM",
        "return_on_equity": "roeTTM",
        "return_on_assets": "roaTTM",
        "return_on_invested_capital": "roicTTM",
        "asset_turnover": "assetTurnoverTTM",
        "inventory_turnover": "inventoryTurnoverTTM",
        "receivables_turnover": "receivablesTurnoverTTM",
        "current_ratio": "currentRatioAnnual",
        "quick_ratio": "quickRatioAnnual",
        "debt_to_equity": "totalDebt/totalEquityAnnual",
        "revenue_growth": "revenueGrowthTTMYoy",
        "earnings_per_share_growth": "epsGrowthTTMYoy",
        "payout_ratio": "payoutRatioAnnual",
        "earnings_per_share": "epsBasicExclExtraItemsTTM",
        "book_value_per_share": "bookValuePerShareAnnual",
    }
    # Fields with no equivalent in this Finnhub payload.
    unavailable = (
        "enterprise_value",
        "enterprise_value_to_ebitda_ratio",
        "enterprise_value_to_revenue_ratio",
        "free_cash_flow_yield",
        "peg_ratio",
        "days_sales_outstanding",
        "operating_cycle",
        "working_capital_turnover",
        "cash_ratio",
        "operating_cash_flow_ratio",
        "debt_to_assets",
        "interest_coverage",
        "earnings_growth",
        "book_value_growth",
        "free_cash_flow_growth",
        "operating_income_growth",
        "ebitda_growth",
        "free_cash_flow_per_share",
    )

    values = {field: metric_data.get(key) for field, key in field_to_key.items()}
    values.update({field: None for field in unavailable})

    return FinancialMetrics(
        ticker=ticker,
        report_period=end_date,
        period=period,
        currency="USD",
        **values,
    )
|
||||
|
||||
|
||||
def search_line_items(
    ticker: str,
    line_items: list[str],
    end_date: str,
    period: str = "ttm",
    limit: int = 10,
) -> list[LineItem]:
    """
    Fetch line items from Financial Datasets API (only supported source).

    Returns an empty list on any API or parsing error so callers can
    degrade gracefully instead of aborting the analysis run.
    """
    try:
        headers = {"X-API-KEY": get_api_key()}
        payload = {
            "tickers": [ticker],
            "line_items": line_items,
            "end_date": end_date,
            "period": period,
            "limit": limit,
        }
        response = _make_api_request(
            "https://api.financialdatasets.ai/financials/search/line-items",
            headers,
            method="POST",
            json_data=payload,
        )

        if response.status_code != 200:
            logger.info(
                f"Warning: Failed to fetch line items for {ticker}: "
                f"{response.status_code} - {response.text}",
            )
            return []

        results = LineItemResponse(**response.json()).search_results
        if not results:
            return []

        # Slice defensively: the API may return more rows than requested.
        return results[:limit]

    except Exception as e:
        logger.info(
            f"Warning: Exception while fetching line items for {ticker}: {str(e)}",
        )
        return []
|
||||
|
||||
|
||||
def _fetch_finnhub_insider_trades(
    ticker: str,
    start_date: str | None,
    end_date: str,
    limit: int,
    api_key: str,
) -> list[InsiderTrade]:
    """Fetch insider trades from Finnhub API."""
    # Default the window start to one year before end_date when not given.
    if start_date:
        from_date = start_date
    else:
        window_start = datetime.datetime.strptime(
            end_date, "%Y-%m-%d",
        ) - datetime.timedelta(days=365)
        from_date = window_start.strftime("%Y-%m-%d")

    payload = finnhub.Client(api_key=api_key).stock_insider_transactions(
        ticker,
        from_date,
        end_date,
    )

    if not payload or "data" not in payload:
        return []

    # Cap at `limit` entries before converting to the internal model.
    return [
        _convert_finnhub_insider_trade(ticker, raw)
        for raw in payload["data"][:limit]
    ]
|
||||
|
||||
|
||||
def _fetch_fd_insider_trades(
    ticker: str,
    start_date: str | None,
    end_date: str,
    limit: int,
    api_key: str,
) -> list[InsiderTrade]:
    """Fetch insider trades from Financial Datasets API.

    Pages backwards through filing dates: each request is bounded above by
    the earliest filing date seen so far, until the window is exhausted.
    """
    headers = {"X-API-KEY": api_key}
    collected = []
    cursor = end_date  # filing-date upper bound, walked backwards per page

    while True:
        url = f"https://api.financialdatasets.ai/insider-trades/?ticker={ticker}&filing_date_lte={cursor}"
        if start_date:
            url += f"&filing_date_gte={start_date}"
        url += f"&limit={limit}"

        response = _make_api_request(url, headers)
        if response.status_code != 200:
            raise ValueError(
                f"Error fetching data: {ticker} - {response.status_code} - {response.text}",
            )

        page = InsiderTradeResponse(**response.json()).insider_trades
        if not page:
            break

        collected.extend(page)

        # A short page, or an open-ended query, means there is nothing left.
        if not start_date or len(page) < limit:
            break

        # Resume from the earliest filing date seen (date portion only).
        cursor = min(trade.filing_date for trade in page).split("T")[0]
        if cursor <= start_date:
            break

    return collected
|
||||
|
||||
|
||||
def get_insider_trades(
    ticker: str,
    end_date: str,
    start_date: str | None = None,
    limit: int = 1000,
) -> list[InsiderTrade]:
    """Fetch insider trades from cache or API.

    Dispatches to the Finnhub or Financial Datasets fetcher based on the
    configured data source; results are cached keyed on all parameters.
    """
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    cache_key = (
        f"{ticker}_{start_date or 'none'}_{end_date}_{limit}_{data_source}"
    )

    cached = _cache.get_insider_trades(cache_key)
    if cached:
        return [InsiderTrade(**row) for row in cached]

    # Both provider fetchers share the same signature.
    fetch = (
        _fetch_finnhub_insider_trades
        if data_source == "finnhub"
        else _fetch_fd_insider_trades
    )
    trades = fetch(ticker, start_date, end_date, limit, api_key)

    if not trades:
        return []

    _cache.set_insider_trades(
        cache_key,
        [trade.model_dump() for trade in trades],
    )
    return trades
|
||||
|
||||
|
||||
def _fetch_finnhub_company_news(
    ticker: str,
    start_date: str | None,
    end_date: str,
    limit: int,
    api_key: str,
) -> list[CompanyNews]:
    """Fetch company news from Finnhub API."""
    # Default the window start to 30 days before end_date when not given.
    if start_date:
        from_date = start_date
    else:
        window_start = datetime.datetime.strptime(
            end_date, "%Y-%m-%d",
        ) - datetime.timedelta(days=30)
        from_date = window_start.strftime("%Y-%m-%d")

    client = finnhub.Client(api_key=api_key)
    raw_items = client.company_news(ticker, _from=from_date, to=end_date)

    if not raw_items:
        return []

    def to_model(item: dict) -> CompanyNews:
        # Finnhub timestamps are epoch seconds; render as a UTC date string,
        # or None when the timestamp is missing/zero.
        published = (
            datetime.datetime.fromtimestamp(
                item.get("datetime", 0),
                datetime.timezone.utc,
            ).strftime("%Y-%m-%d")
            if item.get("datetime")
            else None
        )
        return CompanyNews(
            ticker=ticker,
            title=item.get("headline", ""),
            related=item.get("related", ""),
            source=item.get("source", ""),
            date=published,
            url=item.get("url", ""),
            summary=item.get("summary", ""),
            category=item.get("category", ""),
        )

    return [to_model(item) for item in raw_items[:limit]]
|
||||
|
||||
|
||||
def _fetch_fd_company_news(
    ticker: str,
    start_date: str | None,
    end_date: str,
    limit: int,
    api_key: str,
) -> list[CompanyNews]:
    """Fetch company news from Financial Datasets API.

    Pages backwards through article dates until the requested window is
    exhausted or a short page signals the end of data.
    """
    headers = {"X-API-KEY": api_key}
    collected = []
    cursor = end_date  # date upper bound, walked backwards per page

    while True:
        url = f"https://api.financialdatasets.ai/news/?ticker={ticker}&end_date={cursor}"
        if start_date:
            url += f"&start_date={start_date}"
        url += f"&limit={limit}"

        response = _make_api_request(url, headers)
        if response.status_code != 200:
            raise ValueError(
                f"Error fetching data: {ticker} - {response.status_code} - {response.text}",
            )

        page = CompanyNewsResponse(**response.json()).news
        if not page:
            break

        collected.extend(page)

        # A short page, or an open-ended query, means there is nothing left.
        if not start_date or len(page) < limit:
            break

        # Resume from the earliest dated article (undated items ignored).
        cursor = min(
            item.date for item in page if item.date is not None
        ).split("T")[0]
        if cursor <= start_date:
            break

    return collected
|
||||
|
||||
|
||||
def get_company_news(
    ticker: str,
    end_date: str,
    start_date: str | None = None,
    limit: int = 1000,
) -> list[CompanyNews]:
    """Fetch company news from cache or API.

    Dispatches to the Finnhub or Financial Datasets fetcher based on the
    configured data source; results are cached keyed on all parameters.
    """
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    cache_key = (
        f"{ticker}_{start_date or 'none'}_{end_date}_{limit}_{data_source}"
    )

    cached = _cache.get_company_news(cache_key)
    if cached:
        return [CompanyNews(**row) for row in cached]

    # Both provider fetchers share the same signature.
    fetch = (
        _fetch_finnhub_company_news
        if data_source == "finnhub"
        else _fetch_fd_company_news
    )
    articles = fetch(ticker, start_date, end_date, limit, api_key)

    if not articles:
        return []

    _cache.set_company_news(
        cache_key,
        [item.model_dump() for item in articles],
    )
    return articles
|
||||
|
||||
|
||||
def _convert_finnhub_insider_trade(ticker: str, trade: dict) -> InsiderTrade:
    """Convert Finnhub insider trade format to InsiderTrade model."""
    post_shares = trade.get("share", 0)
    delta = trade.get("change", 0)
    price = trade.get("transactionPrice", 0.0)
    traded = abs(delta)

    # Pre-trade holdings are only derivable when both figures are non-zero.
    if post_shares and delta:
        pre_shares = post_shares - delta
    else:
        pre_shares = None

    return InsiderTrade(
        ticker=ticker,
        issuer=None,
        name=trade.get("name", ""),
        title=None,
        is_board_director=None,
        transaction_date=trade.get("transactionDate", ""),
        transaction_shares=traded,
        transaction_price_per_share=price,
        transaction_value=traded * price,
        shares_owned_before_transaction=pre_shares,
        shares_owned_after_transaction=(
            float(post_shares) if post_shares else None
        ),
        security_title=None,
        filing_date=trade.get("filingDate", ""),
    )
|
||||
|
||||
|
||||
def get_market_cap(ticker: str, end_date: str) -> float | None:
    """Fetch market cap from the API. Finnhub values are converted from millions.

    Returns None when no data is available or the company-facts request
    fails; historical lookups reuse the (cached) financial-metrics pipeline.
    """
    config = get_config()
    data_source = config.source
    api_key = config.api_key

    # For today's date, use company facts API
    # NOTE(review): this branch always calls the Financial Datasets endpoint
    # with the configured api_key, even when data_source == "finnhub" —
    # confirm that key is valid for this API when Finnhub is the source.
    if end_date == datetime.datetime.now().strftime("%Y-%m-%d"):
        headers = {"X-API-KEY": api_key}
        url = (
            f"https://api.financialdatasets.ai/company/facts/?ticker={ticker}"
        )
        response = _make_api_request(url, headers)
        if response.status_code != 200:
            return None

        data = response.json()
        response_model = CompanyFactsResponse(**data)
        return response_model.company_facts.market_cap

    # Historical date: take market cap from the most recent metrics row.
    financial_metrics = get_financial_metrics(ticker, end_date)
    if not financial_metrics:
        return None

    market_cap = financial_metrics[0].market_cap
    if not market_cap:
        return None

    # Finnhub returns market cap in millions; normalize to absolute dollars.
    if data_source == "finnhub":
        market_cap = market_cap * 1_000_000

    return market_cap
|
||||
|
||||
|
||||
def prices_to_df(prices: list[Price]) -> pd.DataFrame:
    """Convert a list of Price models into a Date-indexed OHLCV DataFrame."""
    frame = pd.DataFrame([price.model_dump() for price in prices])
    frame["Date"] = pd.to_datetime(frame["time"])
    frame.set_index("Date", inplace=True)
    # Coerce OHLCV columns to numeric; unparseable values become NaN.
    for column in ("open", "close", "high", "low", "volume"):
        frame[column] = pd.to_numeric(frame[column], errors="coerce")
    return frame.sort_index()
|
||||
4
backend/utils/__init__.py
Normal file
4
backend/utils/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# This file can be empty
|
||||
|
||||
"""Utility modules for the application."""
|
||||
449
backend/utils/analyst_tracker.py
Normal file
449
backend/utils/analyst_tracker.py
Normal file
@@ -0,0 +1,449 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Analyst Performance Tracker
|
||||
Tracks analyst predictions and calculates win rates for leaderboard
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnalystPerformanceTracker:
    """
    Tracks analyst predictions and evaluates accuracy

    Workflow:
    1. Record analyst predictions for each ticker before market close
    2. After market close, evaluate predictions against actual returns
    3. Update leaderboard with win rates and statistics
    """

    def __init__(self):
        # Maps analyst_id -> {ticker: "long" | "short" | "hold"} for the
        # current trading day; repopulated by record_analyst_predictions().
        self.daily_predictions = {}

    def record_analyst_predictions(
        self,
        final_predictions: List[Dict[str, Any]],
    ):
        """
        Record predictions from analysts for the current trading day.

        Replaces any previously recorded predictions. Entries without an
        'agent' key are skipped; per-ticker directions are normalized to
        internal signals ("up" -> "long", "down" -> "short", anything
        else -> "hold").

        Args:
            final_predictions: List of structured prediction results.
                Format: [
                    {
                        'agent': 'analyst_name',
                        'predictions': [
                            {'ticker': 'AAPL', 'direction': 'up',
                             'confidence': 0.75},
                            ...
                        ]
                    },
                    ...
                ]
        """
        self.daily_predictions = {}

        # Normalize the analyst-facing direction to internal signal names.
        direction_mapping = {
            "up": "long",
            "down": "short",
            "neutral": "hold",
        }

        for result in final_predictions:
            analyst_id = result.get("agent")
            if not analyst_id:
                continue

            predictions = result.get("predictions", [])

            self.daily_predictions[analyst_id] = {}

            for pred in predictions:
                ticker = pred.get("ticker")
                direction = pred.get("direction", "neutral")

                if ticker:
                    # Unknown direction strings default to "hold".
                    signal = direction_mapping.get(direction, "hold")
                    self.daily_predictions[analyst_id][ticker] = signal

    def evaluate_predictions(
        self,
        open_prices: Optional[Dict[str, float]],
        close_prices: Dict[str, float],
        date: str,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Evaluate analyst predictions against actual market moves.

        A "long" call is correct when close > open; a "short" call when
        close < open. Signals whose prices are missing (<= 0) are counted
        as "unknown" and excluded from the win rate. "hold" signals are
        tallied but never scored.

        NOTE(review): open_prices is typed Optional but dereferenced
        unconditionally below — confirm callers never pass None here.

        Args:
            open_prices: Opening prices for each ticker
            close_prices: Closing prices for each ticker
            date: Trading date string (YYYY-MM-DD)

        Returns:
            Dict mapping analyst_id to evaluation results
            (total_predictions, correct_predictions, win_rate,
            per-side bull/bear tallies, hold count, and the individual
            signal records).
        """
        evaluation_results = {}

        # Map internal signal types to frontend display names
        signal_display_map = {
            "long": "bull",
            "short": "bear",
            "hold": "neutral",
        }

        for analyst_id, predictions in self.daily_predictions.items():
            correct_long = 0
            correct_short = 0
            incorrect_long = 0
            incorrect_short = 0
            unknown_long = 0
            unknown_short = 0
            hold_count = 0

            # Individual signal records for frontend display
            individual_signals: List[Dict[str, Any]] = []

            for ticker, prediction in predictions.items():
                open_price = open_prices.get(ticker, 0)
                close_price = close_prices.get(ticker, 0)

                signal_type = signal_display_map.get(prediction, "neutral")

                # Cannot evaluate if prices are missing
                if open_price <= 0 or close_price <= 0:
                    if prediction == "long":
                        unknown_long += 1
                    elif prediction == "short":
                        unknown_short += 1

                    individual_signals.append(
                        {
                            "ticker": ticker,
                            "signal": signal_type,
                            "date": date,
                            "is_correct": "unknown",
                        },
                    )
                    continue

                # Intraday return: open-to-close, as a fraction of open.
                actual_return = (close_price - open_price) / open_price

                if prediction == "long":
                    is_correct = actual_return > 0
                    if is_correct:
                        correct_long += 1
                    else:
                        incorrect_long += 1

                    individual_signals.append(
                        {
                            "ticker": ticker,
                            "signal": signal_type,
                            "date": date,
                            "is_correct": is_correct,
                        },
                    )

                elif prediction == "short":
                    is_correct = actual_return < 0
                    if is_correct:
                        correct_short += 1
                    else:
                        incorrect_short += 1

                    individual_signals.append(
                        {
                            "ticker": ticker,
                            "signal": signal_type,
                            "date": date,
                            "is_correct": is_correct,
                        },
                    )

                elif prediction == "hold":
                    # Holds are recorded for display but never scored.
                    hold_count += 1
                    individual_signals.append(
                        {
                            "ticker": ticker,
                            "signal": signal_type,
                            "date": date,
                            "is_correct": None,
                        },
                    )

            # "unknown" signals count toward n but not toward the win rate.
            total_long = correct_long + incorrect_long + unknown_long
            total_short = correct_short + incorrect_short + unknown_short
            evaluated_long = correct_long + incorrect_long
            evaluated_short = correct_short + incorrect_short
            total_evaluated = evaluated_long + evaluated_short
            correct_predictions = correct_long + correct_short

            # win_rate is None when nothing was evaluable this day.
            win_rate = (
                correct_predictions / total_evaluated
                if total_evaluated > 0
                else None
            )

            evaluation_results[analyst_id] = {
                "total_predictions": total_evaluated,
                "correct_predictions": correct_predictions,
                "win_rate": win_rate,
                "bull": {
                    "n": total_long,
                    "win": correct_long,
                    "unknown": unknown_long,
                },
                "bear": {
                    "n": total_short,
                    "win": correct_short,
                    "unknown": unknown_short,
                },
                "hold": hold_count,
                "signals": individual_signals,
            }

        return evaluation_results

    def clear_daily_predictions(self):
        """Clear predictions after evaluation"""
        self.daily_predictions = {}

    def _process_single_pm_decision(
        self,
        _ticker: str,
        decision: Dict,
        open_price: float,
        close_price: float,
        _date: str,
    ) -> Tuple[str, Optional[bool], str]:
        """
        Process a single PM decision and evaluate correctness.

        buy/long -> "long", sell/short -> "short", anything else -> "hold".
        is_correct is None for holds and whenever either price is invalid
        (<= 0); the caller disambiguates those two cases.

        Returns:
            Tuple of (prediction, is_correct, signal_type)
        """
        action = decision.get("action", "hold")

        # Convert action to prediction format
        if action in ["buy", "long"]:
            prediction = "long"
        elif action in ["sell", "short"]:
            prediction = "short"
        else:
            prediction = "hold"

        signal_display_map = {
            "long": "bull",
            "short": "bear",
            "hold": "neutral",
        }
        signal_type = signal_display_map.get(prediction, "neutral")

        # Handle invalid prices
        if open_price <= 0 or close_price <= 0:
            return prediction, None, signal_type

        # Evaluate correctness
        actual_return = (close_price - open_price) / open_price

        if prediction == "long":
            is_correct = actual_return > 0
        elif prediction == "short":
            is_correct = actual_return < 0
        else:  # hold
            is_correct = None

        return prediction, is_correct, signal_type

    def evaluate_pm_decisions(
        self,
        pm_decisions: Dict[str, Dict],
        open_prices: Optional[Dict[str, float]],
        close_prices: Dict[str, float],
        date: str,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Evaluate PM's trading decisions against actual market moves.

        Scoring mirrors evaluate_predictions(): long correct when close >
        open, short correct when close < open, holds tallied but unscored,
        and decisions with invalid prices counted as "unknown".

        Args:
            pm_decisions: PM decisions {ticker: {action, quantity, ...}}
            open_prices: Opening prices for each ticker
            close_prices: Closing prices for each ticker
            date: Trading date string (YYYY-MM-DD)

        Returns:
            Dict with 'portfolio_manager' key containing evaluation
            results, or {} when any input mapping is empty/None.
        """
        if not pm_decisions or not open_prices or not close_prices:
            return {}

        correct_long = 0
        correct_short = 0
        incorrect_long = 0
        incorrect_short = 0
        unknown_long = 0
        unknown_short = 0
        hold_count = 0

        individual_signals: List[Dict[str, Any]] = []

        for ticker, decision in pm_decisions.items():
            open_price = open_prices.get(ticker, 0)
            close_price = close_prices.get(ticker, 0)

            (
                prediction,
                is_correct,
                signal_type,
            ) = self._process_single_pm_decision(
                ticker,
                decision,
                open_price,
                close_price,
                date,
            )

            # is_correct is None both for holds and for invalid prices;
            # the price check below separates "unknown" from "hold".
            if is_correct is None and (open_price <= 0 or close_price <= 0):
                if prediction == "long":
                    unknown_long += 1
                elif prediction == "short":
                    unknown_short += 1
                individual_signals.append(
                    {
                        "ticker": ticker,
                        "signal": signal_type,
                        "date": date,
                        "is_correct": "unknown",
                    },
                )
            elif prediction == "hold":
                hold_count += 1
                individual_signals.append(
                    {
                        "ticker": ticker,
                        "signal": signal_type,
                        "date": date,
                        "is_correct": None,
                    },
                )
            else:
                if prediction == "long":
                    if is_correct:
                        correct_long += 1
                    else:
                        incorrect_long += 1
                else:
                    if is_correct:
                        correct_short += 1
                    else:
                        incorrect_short += 1

                individual_signals.append(
                    {
                        "ticker": ticker,
                        "signal": signal_type,
                        "date": date,
                        "is_correct": is_correct,
                    },
                )

        # "unknown" signals count toward n but not toward the win rate.
        total_long = correct_long + incorrect_long + unknown_long
        total_short = correct_short + incorrect_short + unknown_short
        evaluated_long = correct_long + incorrect_long
        evaluated_short = correct_short + incorrect_short
        total_evaluated = evaluated_long + evaluated_short
        correct_predictions = correct_long + correct_short

        win_rate = (
            correct_predictions / total_evaluated
            if total_evaluated > 0
            else None
        )

        return {
            "portfolio_manager": {
                "total_predictions": total_evaluated,
                "correct_predictions": correct_predictions,
                "win_rate": win_rate,
                "bull": {
                    "n": total_long,
                    "win": correct_long,
                    "unknown": unknown_long,
                },
                "bear": {
                    "n": total_short,
                    "win": correct_short,
                    "unknown": unknown_short,
                },
                "hold": hold_count,
                "signals": individual_signals,
            },
        }
|
||||
|
||||
|
||||
def update_leaderboard_with_evaluations(
    leaderboard: List[Dict[str, Any]],
    evaluations: Dict[str, Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """
    Update leaderboard with new evaluation results.

    Mutates the given entries in place: folds the day's bull/bear tallies
    into each matching entry, recomputes the win rate over evaluated
    (non-unknown) signals, appends the day's individual signal records
    (trimmed to the most recent 100), and re-ranks entries by win rate.

    Args:
        leaderboard: Current leaderboard data
        evaluations: Evaluation results for the day

    Returns:
        Updated leaderboard (same list object, mutated)
    """
    for row in leaderboard:
        analyst = row.get("agentId")
        if not analyst or analyst not in evaluations:
            continue

        day = evaluations[analyst]

        # Fold the day's per-side tallies into the running totals.
        for side in ("bull", "bear"):
            stats = row[side]
            delta = day[side]
            stats["n"] += delta["n"]
            stats["win"] += delta["win"]
            stats["unknown"] = stats.get("unknown", 0) + delta["unknown"]

        # Win rate counts only evaluated signals (n minus unknown).
        wins = row["bull"]["win"] + row["bear"]["win"]
        evaluated = (
            row["bull"]["n"] - row["bull"]["unknown"]
            + row["bear"]["n"] - row["bear"]["unknown"]
        )
        if evaluated > 0:
            row["winRate"] = round(wins / evaluated, 4)

        # Append the day's signal records, keeping only the last 100.
        row.setdefault("signals", [])
        row["signals"].extend(day.get("signals", []))
        row["signals"] = row["signals"][-100:]

    # Re-rank analysts by win rate; rank 1 = highest win rate (gold medal).
    ranked = [r for r in leaderboard if r.get("rank") is not None]
    ranked.sort(key=lambda r: r.get("winRate", 0), reverse=True)
    for position, row in enumerate(ranked, start=1):
        row["rank"] = position

    return leaderboard
|
||||
405
backend/utils/baselines.py
Normal file
405
backend/utils/baselines.py
Normal file
@@ -0,0 +1,405 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Baseline Strategy Calculators
|
||||
Tracks performance of simple baseline strategies for comparison
|
||||
"""
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Portfolio(TypedDict):
    """Snapshot of a baseline strategy's holdings."""
    # Uninvested cash balance, in account currency.
    cash: float
    # Ticker -> number of shares held (fractional shares allowed).
    positions: Dict[str, float]
|
||||
|
||||
|
||||
class BaselineCalculator:
|
||||
"""
|
||||
Calculates baseline strategy returns for comparison
|
||||
|
||||
Strategies:
|
||||
1. Equal-weight: Allocate equal weight to all tickers
|
||||
2. Market-cap-weighted: Allocate proportional to market cap
|
||||
3. Simple momentum: Monthly rebalance,
|
||||
long top 50% momentum, short bottom 50%
|
||||
"""
|
||||
|
||||
def __init__(self, initial_capital: float = 100000.0):
|
||||
self.initial_capital = initial_capital
|
||||
|
||||
self.equal_weight_portfolio: Portfolio = {"cash": 0.0, "positions": {}}
|
||||
self.market_cap_portfolio: Portfolio = {"cash": 0.0, "positions": {}}
|
||||
self.momentum_portfolio: Portfolio = {
|
||||
"cash": initial_capital,
|
||||
"positions": {},
|
||||
}
|
||||
|
||||
self.equal_weight_initialized = False
|
||||
self.market_cap_initialized = False
|
||||
self.momentum_last_rebalance_date = None
|
||||
|
||||
def calculate_equal_weight_value(
|
||||
self,
|
||||
tickers: List[str],
|
||||
open_prices: Dict[str, float],
|
||||
close_prices: Dict[str, float],
|
||||
) -> float:
|
||||
"""
|
||||
Calculate equal-weight portfolio value
|
||||
|
||||
On first call, initialize positions with equal allocation using
|
||||
open prices. Subsequently, mark-to-market existing positions
|
||||
using close prices.
|
||||
|
||||
Args:
|
||||
tickers: List of stock tickers
|
||||
open_prices: Opening prices (used for initial purchase)
|
||||
close_prices: Closing prices (used for valuation)
|
||||
"""
|
||||
if not self.equal_weight_initialized:
|
||||
allocation_per_ticker = self.initial_capital / len(tickers)
|
||||
self.equal_weight_portfolio["cash"] = 0.0
|
||||
for ticker in tickers:
|
||||
price = open_prices.get(ticker, 0) # Use OPEN price for buying
|
||||
if price > 0:
|
||||
shares = allocation_per_ticker / price
|
||||
self.equal_weight_portfolio["positions"][ticker] = shares
|
||||
logger.info(
|
||||
f"Equal Weight: Initialized {ticker} with "
|
||||
f"{shares:.2f} shares @ ${price:.2f} (open)",
|
||||
)
|
||||
self.equal_weight_initialized = True
|
||||
|
||||
total_value = self.equal_weight_portfolio["cash"]
|
||||
positions: Dict[str, float] = self.equal_weight_portfolio["positions"]
|
||||
for ticker, shares in positions.items():
|
||||
price = close_prices.get(ticker, 0)
|
||||
total_value += shares * price
|
||||
|
||||
return total_value
|
||||
|
||||
def calculate_market_cap_weighted_value(
    self,
    tickers: List[str],
    open_prices: Dict[str, float],
    close_prices: Dict[str, float],
    market_caps: Dict[str, float],
) -> float:
    """
    Calculate market-cap-weighted portfolio value.

    On the first call, allocate ``initial_capital`` proportionally to
    market cap, buying at OPEN prices; falls back to the equal-weight
    calculation when no market-cap data is available.  On every call,
    mark existing positions to market at CLOSE prices.

    Allocation that cannot be invested (missing price or cap) is kept
    as cash instead of being silently dropped.

    Args:
        tickers: List of stock tickers
        open_prices: Opening prices (used for initial purchase)
        close_prices: Closing prices (used for valuation)
        market_caps: Market capitalization for each ticker

    Returns:
        Total portfolio value (cash + marked-to-market positions).
    """
    if not self.market_cap_initialized:
        total_market_cap = sum(market_caps.get(t, 0) for t in tickers)
        if total_market_cap <= 0:
            logger.warning("No market cap data, using equal weight")
            return self.calculate_equal_weight_value(
                tickers,
                open_prices,
                close_prices,
            )

        invested = 0.0
        for ticker in tickers:
            market_cap = market_caps.get(ticker, 0)
            price = open_prices.get(ticker, 0)  # Use OPEN price for buying
            if market_cap > 0 and price > 0:
                weight = market_cap / total_market_cap
                allocation = self.initial_capital * weight
                shares = allocation / price
                self.market_cap_portfolio["positions"][ticker] = shares
                invested += allocation
                logger.info(
                    f"Market Cap Weighted: Initialized {ticker} with "
                    f"{shares:.2f} shares @ ${price:.2f} (open), "
                    f"weight={weight:.2%}",
                )
        # Capital that could not be deployed stays as cash (previously
        # cash was forced to 0.0 and unpriced allocations vanished).
        self.market_cap_portfolio["cash"] = self.initial_capital - invested
        self.market_cap_initialized = True

    total_value = self.market_cap_portfolio["cash"]
    positions: Dict[str, float] = self.market_cap_portfolio["positions"]
    for ticker, shares in positions.items():
        price = close_prices.get(ticker, 0)
        total_value += shares * price

    return total_value
def calculate_momentum_value(
    self,
    tickers: List[str],
    open_prices: Dict[str, float],
    close_prices: Dict[str, float],
    momentum_scores: Dict[str, float],
    date: str,
    rebalance: bool = False,
) -> float:
    """Value the monthly-rebalanced momentum portfolio.

    The book is rebalanced on the first ever call, when ``rebalance``
    is True, or when the calendar month of ``date`` differs from the
    last rebalance date.  Rebalance trades use ``open_prices``;
    valuation always uses ``close_prices``.

    Args:
        tickers: List of tickers
        open_prices: Opening prices (used for rebalancing trades)
        close_prices: Closing prices (used for valuation)
        momentum_scores: Momentum scores for each ticker
        date: Current date (YYYY-MM-DD)
        rebalance: Force rebalance if True

    Returns:
        Total portfolio value (cash + marked-to-market positions).
    """
    needs_rebalance = rebalance or self.momentum_last_rebalance_date is None
    if not needs_rebalance:
        # Rebalance whenever the calendar month has rolled over.
        previous = datetime.strptime(
            self.momentum_last_rebalance_date,
            "%Y-%m-%d",
        )
        today = datetime.strptime(date, "%Y-%m-%d")
        needs_rebalance = (today.year, today.month) != (
            previous.year,
            previous.month,
        )

    if needs_rebalance:
        self._rebalance_momentum_portfolio(
            tickers,
            open_prices,
            momentum_scores,
        )
        self.momentum_last_rebalance_date = date

    value = self.momentum_portfolio["cash"]
    for ticker, shares in self.momentum_portfolio["positions"].items():
        value += shares * close_prices.get(ticker, 0)

    return value
def _rebalance_momentum_portfolio(
    self,
    tickers: List[str],
    prices: Dict[str, float],
    momentum_scores: Dict[str, float],
):
    """Rebuild the momentum book from current momentum scores.

    Marks the existing book to market at ``prices``, liquidates it,
    then splits the resulting capital equally across the top 50% of
    ``tickers`` by momentum score.  Capital that cannot be deployed
    (non-positive price) remains as cash.
    """
    # Liquidate: value everything at current prices.
    portfolio_value = self.momentum_portfolio["cash"]
    for ticker, shares in self.momentum_portfolio["positions"].items():
        portfolio_value += shares * prices.get(ticker, 0)

    self.momentum_portfolio["positions"] = {}

    # Rank by momentum, best first, and keep the top half (or all
    # tickers when there are too few to split).
    ranked = sorted(
        tickers,
        key=lambda t: momentum_scores.get(t, 0),
        reverse=True,
    )
    half = len(ranked) // 2
    winners = ranked[:half] if half > 0 else ranked

    if len(winners) == 0:
        self.momentum_portfolio["cash"] = portfolio_value
        return

    per_ticker = portfolio_value / len(winners)
    deployed = 0.0

    for ticker in winners:
        price = prices.get(ticker, 0)
        if price > 0:
            self.momentum_portfolio["positions"][ticker] = per_ticker / price
            deployed += per_ticker

    self.momentum_portfolio["cash"] = portfolio_value - deployed
def get_all_baseline_values(
    self,
    tickers: List[str],
    open_prices: Dict[str, float],
    close_prices: Dict[str, float],
    market_caps: Dict[str, float],
    momentum_scores: Dict[str, float],
    date: str,
    rebalance_momentum: bool = False,
) -> Dict[str, float]:
    """Compute all three baseline portfolio values for one trading day.

    Args:
        tickers: List of stock tickers
        open_prices: Opening prices (used for initial purchase/rebalancing)
        close_prices: Closing prices (used for valuation)
        market_caps: Market caps for each ticker
        momentum_scores: Momentum scores for rebalancing
        date: Current date
        rebalance_momentum: Whether to rebalance momentum portfolio

    Returns:
        Dict with keys: equal_weight, market_cap_weighted, momentum
    """
    # Dict literals evaluate values in order, so the three strategies
    # are updated in the same sequence as before.
    return {
        "equal_weight": self.calculate_equal_weight_value(
            tickers,
            open_prices,
            close_prices,
        ),
        "market_cap_weighted": self.calculate_market_cap_weighted_value(
            tickers,
            open_prices,
            close_prices,
            market_caps,
        ),
        "momentum": self.calculate_momentum_value(
            tickers,
            open_prices,
            close_prices,
            momentum_scores,
            date,
            rebalance_momentum,
        ),
    }
def export_state(self) -> Dict[str, Any]:
    """Snapshot all baseline portfolio state as a serializable dict.

    Returns:
        Dict with ``baseline_state`` (equal weight), ``baseline_vw_state``
        (market-cap weighted) and ``momentum_state`` sections, in the
        layout expected by ``load_state``.
    """
    equal_weight_section = {
        "initialized": self.equal_weight_initialized,
        "initial_allocation": dict(self.equal_weight_portfolio["positions"]),
    }
    market_cap_section = {
        "initialized": self.market_cap_initialized,
        "initial_allocation": dict(self.market_cap_portfolio["positions"]),
    }
    momentum_section = {
        "positions": dict(self.momentum_portfolio["positions"]),
        "cash": self.momentum_portfolio["cash"],
        # The momentum book counts as initialized once it has ever
        # rebalanced.
        "initialized": self.momentum_last_rebalance_date is not None,
        "last_rebalance_date": self.momentum_last_rebalance_date,
    }
    return {
        "baseline_state": equal_weight_section,
        "baseline_vw_state": market_cap_section,
        "momentum_state": momentum_section,
    }
def load_state(self, state: Dict[str, Any]):
    """Restore portfolio state previously produced by ``export_state``.

    Sections whose ``initialized`` flag is absent or falsy are skipped,
    leaving the corresponding portfolio untouched.

    Args:
        state: Dict with optional ``baseline_state``,
            ``baseline_vw_state`` and ``momentum_state`` sections.
    """
    # Equal-weight section.
    equal_section = state.get("baseline_state", {})
    if equal_section.get("initialized", False):
        self.equal_weight_initialized = True
        self.equal_weight_portfolio["positions"] = dict(
            equal_section.get("initial_allocation", {}),
        )
        self.equal_weight_portfolio["cash"] = 0.0
        logger.info(
            f"Restored equal-weight portfolio with "
            f"{len(self.equal_weight_portfolio['positions'])} positions",
        )

    # Market-cap-weighted section.
    vw_section = state.get("baseline_vw_state", {})
    if vw_section.get("initialized", False):
        self.market_cap_initialized = True
        self.market_cap_portfolio["positions"] = dict(
            vw_section.get("initial_allocation", {}),
        )
        self.market_cap_portfolio["cash"] = 0.0
        logger.info(
            f"Restored market-cap portfolio with "
            f"{len(self.market_cap_portfolio['positions'])} positions",
        )

    # Momentum section.
    momentum_section = state.get("momentum_state", {})
    if momentum_section.get("initialized", False):
        self.momentum_portfolio["positions"] = dict(
            momentum_section.get("positions", {}),
        )
        self.momentum_portfolio["cash"] = momentum_section.get(
            "cash",
            self.initial_capital,
        )
        self.momentum_last_rebalance_date = momentum_section.get(
            "last_rebalance_date",
        )
        logger.info(
            f"Restored momentum portfolio with "
            f"{len(self.momentum_portfolio['positions'])} positions, "
            f"last rebalance: {self.momentum_last_rebalance_date}",
        )
def calculate_momentum_scores(
    tickers: List[str],
    prices_history: Dict[str, List[Tuple[str, float]]],
    lookback_days: int = 20,
) -> Dict[str, float]:
    """Compute simple price-momentum scores (fractional returns).

    Each ticker's score is the return between the price
    ``lookback_days`` observations ago and the latest price; when fewer
    observations exist, the full available history is used.  Tickers
    with fewer than two observations, or a non-positive start price,
    score 0.0.

    Args:
        tickers: List of tickers
        prices_history: Dict mapping ticker to list of (date, price) tuples
        lookback_days: Number of days to calculate momentum

    Returns:
        Dict mapping ticker to momentum score (percentage return)
    """
    scores: Dict[str, float] = {}

    for ticker in tickers:
        # Order observations chronologically by date string.
        observations = sorted(
            prices_history.get(ticker, []),
            key=lambda x: x[0],
        )

        if len(observations) < 2:
            scores[ticker] = 0.0
            continue

        if len(observations) < lookback_days:
            window = observations
        else:
            window = observations[-lookback_days:]

        start_price = window[0][1]
        end_price = window[-1][1]
        scores[ticker] = (
            (end_price - start_price) / start_price if start_price > 0 else 0.0
        )

    return scores
321
backend/utils/msg_adapter.py
Normal file
321
backend/utils/msg_adapter.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Message Adapter - Converts AgentScope Msg to frontend JSON format
|
||||
Ensures compatibility with existing frontend without modifications
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agentscope.message import Msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FrontendAdapter:
    """
    Adapter to convert AgentScope messages to frontend-compatible format

    Ensures compatibility with the existing frontend without modifications.

    Frontend expects specific message types:
    - agent: Agent thinking/analysis messages
    - team_summary: Portfolio summary with equity curves
    - team_holdings: Current portfolio holdings
    - team_stats: Portfolio statistics
    - team_trades: Trade history
    - team_leaderboard: Agent performance rankings
    - price_update: Real-time price updates
    - system: System notifications
    """

    @staticmethod
    def parse(msg: Msg) -> Optional[Dict[str, Any]]:
        """
        Parse AgentScope Msg to frontend format

        Args:
            msg: AgentScope Msg object

        Returns:
            Dictionary in frontend format, or None if message should be
            skipped.  NOTE(review): portfolio updates come back as a
            ``{"type": "composite", "messages": [...]}`` wrapper, not a
            single flat message — callers must unpack it.
        """
        if msg is None:
            return None

        # Determine message type based on metadata or content
        msg_type = FrontendAdapter._determine_type(msg)

        if msg_type == "agent":
            return FrontendAdapter._format_agent_msg(msg)
        elif msg_type == "portfolio_update":
            return FrontendAdapter._format_portfolio_msg(msg)
        elif msg_type == "system":
            return FrontendAdapter._format_system_msg(msg)
        else:
            # Default: treat as agent message
            return FrontendAdapter._format_agent_msg(msg)

    @staticmethod
    def _determine_type(msg: Msg) -> str:
        """Determine frontend message type from Msg.

        Resolution order: explicit ``metadata["type"]``, then presence of
        ``metadata["portfolio"]``, then ``msg.name == "system"``, else
        "agent".
        """
        # Check metadata for explicit type
        if hasattr(msg, "metadata") and msg.metadata:
            if "type" in msg.metadata:
                return msg.metadata["type"]

            # Check if message contains portfolio update
            if "portfolio" in msg.metadata:
                return "portfolio_update"

        # Check message name/role
        if msg.name == "system":
            return "system"

        # Default to agent message
        return "agent"

    @staticmethod
    def _format_agent_msg(msg: object) -> Dict[str, Any]:
        """
        Format agent message for frontend

        Args:
            msg: Either AgentScope Msg or dict from pipeline results

        Frontend expects:
            {
                "type": "agent",
                "role_key": "analyst_id",
                "content": "message text",
                "timestamp": "ISO timestamp"
            }
        """
        # Handle dict from pipeline results
        if isinstance(msg, dict):
            name = msg.get("agent", "unknown")
            content = msg.get("content", "")
        else:
            # Handle Msg object
            name = msg.name
            content = msg.content

        return {
            "type": "agent",
            "role_key": name,
            # Non-string content is JSON-serialized so the frontend
            # always receives a string.
            "content": content
            if isinstance(content, str)
            else json.dumps(content),
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def _format_portfolio_msg(msg: Msg) -> Dict[str, Any]:
        """
        Format portfolio update message

        This typically generates multiple frontend messages:
        - team_summary
        - team_holdings
        - team_stats
        - team_trades (if trades were executed)

        The sub-messages are wrapped in a single
        ``{"type": "composite", "messages": [...]}`` envelope.
        """
        metadata = msg.metadata or {}
        portfolio = metadata.get("portfolio", {})

        messages: List[Dict[str, Any]] = []

        # Generate holdings message.  NOTE(review): no live prices are
        # passed here, so holdings/stats fall back to each position's
        # avg_price — confirm this is intended for this call site.
        holdings = FrontendAdapter.build_holdings(portfolio)
        if holdings:
            messages.append(
                {
                    "type": "team_holdings",
                    "data": holdings,
                    "timestamp": datetime.now().isoformat(),
                },
            )

        # Generate stats message
        stats = FrontendAdapter.build_stats(portfolio)
        if stats:
            messages.append(
                {
                    "type": "team_stats",
                    "data": stats,
                    "timestamp": datetime.now().isoformat(),
                },
            )

        # Generate trades message if execution logs exist
        execution_logs = metadata.get("execution_logs", [])
        if execution_logs:
            trades = FrontendAdapter.build_trades(execution_logs)
            messages.append(
                {
                    "type": "team_trades",
                    "mode": "incremental",
                    "data": trades,
                    "timestamp": datetime.now().isoformat(),
                },
            )

        # Return composite message
        return {
            "type": "composite",
            "messages": messages,
        }

    @staticmethod
    def _format_system_msg(msg: Msg) -> Dict[str, Any]:
        """Format system message"""
        return {
            "type": "system",
            # Non-string content is JSON-serialized for the frontend.
            "content": msg.content
            if isinstance(msg.content, str)
            else json.dumps(msg.content),
            "timestamp": datetime.now().isoformat(),
        }

    @staticmethod
    def build_holdings(
        portfolio: Dict[str, Any],
        prices: Optional[Dict[str, float]] = None,
    ) -> List[Dict[str, Any]]:
        """Build holdings array from portfolio state.

        Args:
            portfolio: Portfolio dict with ``positions`` (ticker ->
                {"long", "short", "avg_price"}) and ``cash``.
            prices: Optional live prices; positions without a live price
                are valued at their ``avg_price``.

        Returns:
            One entry per non-flat position, plus a synthetic "CASH" row
            (quantity 1) when cash is positive.
        """
        holdings = []
        prices = prices or {}

        positions = portfolio.get("positions", {})
        cash = portfolio.get("cash", 0.0)

        # Calculate total value using current prices
        total_value = cash
        for ticker, position in positions.items():
            long_shares = position.get("long", 0)
            short_shares = position.get("short", 0)
            price = prices.get(ticker) or position.get("avg_price", 0)
            total_value += (long_shares - short_shares) * price

        # Build holdings for each position
        for ticker, position in positions.items():
            long_shares = position.get("long", 0)
            short_shares = position.get("short", 0)
            avg_price = position.get("avg_price", 0)
            current_price = prices.get(ticker) or avg_price

            # Flat positions (long exactly offsets short) are skipped.
            net_shares = long_shares - short_shares
            if net_shares == 0:
                continue

            market_value = net_shares * current_price
            weight = market_value / total_value if total_value > 0 else 0

            holdings.append(
                {
                    "ticker": ticker,
                    "quantity": net_shares,
                    "avg": avg_price,
                    "currentPrice": current_price,
                    "marketValue": market_value,
                    "weight": weight,
                },
            )

        # Add cash as a holding (negative cash is intentionally omitted)
        if cash > 0:
            holdings.append(
                {
                    "ticker": "CASH",
                    "quantity": 1,
                    "avg": cash,
                    "currentPrice": cash,
                    "marketValue": cash,
                    "weight": cash / total_value if total_value > 0 else 0,
                },
            )

        return holdings

    @staticmethod
    def build_stats(
        portfolio: Dict[str, Any],
        prices: Optional[Dict[str, float]] = None,
    ) -> Dict[str, Any]:
        """Build stats dictionary from portfolio.

        Args:
            portfolio: Portfolio dict with ``positions``, ``cash`` and
                optional ``margin_used`` / ``initial_cash``.
            prices: Optional live prices; falls back to each position's
                ``avg_price`` when absent.

        Returns:
            Dict with totalAssetValue, totalReturn (percent),
            cashPosition, tickerWeights and marginUsed.
        """
        prices = prices or {}
        positions = portfolio.get("positions", {})
        cash = portfolio.get("cash", 0.0)
        margin_used = portfolio.get("margin_used", 0.0)

        # Calculate total value using current prices
        total_value = cash
        for ticker, position in positions.items():
            long_shares = position.get("long", 0)
            short_shares = position.get("short", 0)
            price = prices.get(ticker) or position.get("avg_price", 0)
            total_value += (long_shares - short_shares) * price

        # Calculate ticker weights (flat positions contribute no entry)
        ticker_weights = {}
        for ticker, position in positions.items():
            long_shares = position.get("long", 0)
            short_shares = position.get("short", 0)
            price = prices.get(ticker) or position.get("avg_price", 0)

            market_value = (long_shares - short_shares) * price
            if market_value != 0:
                ticker_weights[ticker] = (
                    market_value / total_value if total_value > 0 else 0
                )

        # Calculate total return as a percentage of initial capital
        initial_cash = portfolio.get("initial_cash", 100000.0)
        total_return = (
            ((total_value - initial_cash) / initial_cash * 100)
            if initial_cash > 0
            else 0.0
        )

        return {
            "totalAssetValue": round(total_value, 2),
            "totalReturn": round(total_return, 2),
            "cashPosition": round(cash, 2),
            "tickerWeights": ticker_weights,
            "marginUsed": round(margin_used, 2),
        }

    @staticmethod
    def build_trades(execution_logs: List[str]) -> List[Dict[str, Any]]:
        """
        Build trades array from execution logs

        Frontend expects:
            [{
                "ts": 1234567890,
                "ticker": "AAPL",
                "side": "LONG",
                "qty": 100,
                "price": 150.0,
                "reason": "Buy signal"
            }, ...]

        NOTE(review): only logs containing "Executed" produce a trade,
        and ticker/side/qty/price are placeholders — the full log string
        is carried in "reason".  Replace with structured execution data.
        """
        trades = []
        # All trades in one batch share a single millisecond timestamp.
        timestamp = int(datetime.now().timestamp() * 1000)

        for log in execution_logs:
            # Parse execution log (simplified - should use structured data)
            if "Executed" in log:
                # Extract trade details from log string
                # in real implementation, pass structured data
                trades.append(
                    {
                        "ts": timestamp,
                        "ticker": "UNKNOWN",  # Should parse from log
                        "side": "LONG",  # Should parse from log
                        "qty": 0,  # Should parse from log
                        "price": 0.0,  # Should parse from log
                        "reason": log,
                    },
                )

        return trades
||||
140
backend/utils/progress.py
Normal file
140
backend/utils/progress.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.style import Style
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class AgentProgress:
    """Manages progress tracking for multiple agents.

    Renders a live Rich table with one status line per agent and
    notifies registered handlers on every status update.
    """

    def __init__(self):
        # agent_name -> {"status": str, "ticker": Optional[str],
        #                "analysis": Optional[str], "timestamp": str}
        self.agent_status = {}
        self.table = Table(show_header=False, box=None, padding=(0, 1))
        self.live = Live(self.table, console=console, refresh_per_second=4)
        self.started = False
        # Callbacks invoked as
        # handler(agent_name, ticker, status, analysis, timestamp).
        self.update_handlers = []

    def register_handler(
        self,
        handler: Callable[[str, Optional[str], str, Optional[str], str], None],
    ):
        """Register a handler to be called when agent status updates.

        Handlers receive ``(agent_name, ticker, status, analysis,
        timestamp)`` — the previous annotation advertised only three
        arguments, which did not match how handlers are invoked in
        ``update_status``.
        """
        self.update_handlers.append(handler)
        return handler  # Return handler to support use as decorator

    def unregister_handler(
        self,
        handler: Callable[[str, Optional[str], str, Optional[str], str], None],
    ):
        """Unregister a previously registered handler (no-op if absent)."""
        if handler in self.update_handlers:
            self.update_handlers.remove(handler)

    def start(self):
        """Start the progress display (idempotent)."""
        if not self.started:
            self.live.start()
            self.started = True

    def stop(self):
        """Stop the progress display (idempotent)."""
        if self.started:
            self.live.stop()
            self.started = False

    def update_status(
        self,
        agent_name: str,
        ticker: Optional[str] = None,
        status: str = "",
        analysis: Optional[str] = None,
    ):
        """Update the status of an agent and notify handlers.

        Falsy arguments (empty ticker/status/analysis) leave the
        previous value in place.
        """
        if agent_name not in self.agent_status:
            self.agent_status[agent_name] = {"status": "", "ticker": None}

        if ticker:
            self.agent_status[agent_name]["ticker"] = ticker
        if status:
            self.agent_status[agent_name]["status"] = status
        if analysis:
            self.agent_status[agent_name]["analysis"] = analysis

        # Set the timestamp as UTC datetime
        timestamp = datetime.now(timezone.utc).isoformat()
        self.agent_status[agent_name]["timestamp"] = timestamp

        # Notify all registered handlers
        for handler in self.update_handlers:
            handler(agent_name, ticker, status, analysis, timestamp)

        self._refresh_display()

    def get_all_status(self):
        """Get the current status of all agents as a dictionary."""
        return {
            agent_name: {
                "ticker": info["ticker"],
                "status": info["status"],
                "display_name": self._get_display_name(agent_name),
            }
            for agent_name, info in self.agent_status.items()
        }

    def _get_display_name(self, agent_name: str) -> str:
        """Convert agent_name to a display-friendly format."""
        return agent_name.replace("_agent", "").replace("_", " ").title()

    def _refresh_display(self):
        """Refresh the progress display."""
        self.table.columns.clear()
        self.table.add_column(width=100)

        # Sort Risk Management and Portfolio Management at the bottom
        def sort_key(item):
            agent_name = item[0]
            if "risk_manager" in agent_name:
                return (2, agent_name)
            elif "portfolio_manager" in agent_name:
                return (3, agent_name)
            else:
                return (1, agent_name)

        for agent_name, info in sorted(
            self.agent_status.items(),
            key=sort_key,
        ):
            status = info["status"]
            ticker = info["ticker"]
            # Create the status text with appropriate styling
            if status.lower() == "done":
                style = Style(color="green", bold=True)
                symbol = "✓"
            elif status.lower() == "error":
                style = Style(color="red", bold=True)
                symbol = "✗"
            else:
                style = Style(color="yellow")
                symbol = "⋯"

            agent_display = self._get_display_name(agent_name)
            status_text = Text()
            status_text.append(f"{symbol} ", style=style)
            status_text.append(f"{agent_display:<20}", style=Style(bold=True))

            if ticker:
                status_text.append(f"[{ticker}] ", style=Style(color="cyan"))
            status_text.append(status, style=style)

            self.table.add_row(status_text)
|
||||
# Module-level singleton shared by callers that import `progress`.
progress = AgentProgress()
|
||||
362
backend/utils/settlement.py
Normal file
362
backend/utils/settlement.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Settlement Coordinator
|
||||
Unified daily settlement logic for agent portfolio, baselines, and analyst tracking
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from backend.services.storage import StorageService
|
||||
from backend.utils.analyst_tracker import (
|
||||
AnalystPerformanceTracker,
|
||||
update_leaderboard_with_evaluations,
|
||||
)
|
||||
from backend.utils.baselines import (
|
||||
BaselineCalculator,
|
||||
calculate_momentum_scores,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SettlementCoordinator:
|
||||
"""
|
||||
Coordinates daily settlement after market close
|
||||
|
||||
Responsibilities:
|
||||
1. Calculate agent portfolio P&L
|
||||
2. Update baseline portfolios (equal-weight, market-cap, momentum)
|
||||
3. Evaluate analyst predictions and update leaderboard
|
||||
4. Update summary.json with all portfolio values
|
||||
5. Persist state to storage
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    storage: "StorageService",
    initial_capital: float = 100000.0,
):
    """Wire up settlement dependencies and restore persisted state.

    Args:
        storage: Storage service used for persistence and valuation.
        initial_capital: Starting capital for the baseline portfolios.
    """
    self.storage = storage
    self.initial_capital = initial_capital
    self.baseline_calculator = BaselineCalculator(initial_capital)
    self.analyst_tracker = AnalystPerformanceTracker()

    # ticker -> [(date, price), ...] rolling window used for momentum.
    self.price_history: Dict[str, List[tuple]] = {}

    # Resume from persisted state so restarts do not reset baselines.
    self._load_persisted_state()
def _load_persisted_state(self):
    """Restore baseline-calculator state and price history from storage.

    Lets backtest/live mode resume where it left off after a restart.
    """
    internal_state = self.storage.load_internal_state()

    # Rehydrate the baseline calculator from its three persisted sections.
    self.baseline_calculator.load_state(
        {
            "baseline_state": internal_state.get("baseline_state", {}),
            "baseline_vw_state": internal_state.get("baseline_vw_state", {}),
            "momentum_state": internal_state.get("momentum_state", {}),
        },
    )

    # Price history entries may have been serialized as {"date","price"}
    # dicts or as 2-item lists; normalize everything back to tuples and
    # silently skip anything unrecognizable.
    saved_price_history = internal_state.get("price_history", {})
    if saved_price_history:
        for ticker, history in saved_price_history.items():
            restored = []
            for entry in history:
                if isinstance(entry, dict):
                    restored.append((entry["date"], entry["price"]))
                elif isinstance(entry, (list, tuple)) and len(entry) >= 2:
                    restored.append((entry[0], entry[1]))
            self.price_history[ticker] = restored
        logger.info(
            f"Restored price history for {len(self.price_history)} tickers",
        )
def _save_persisted_state(self):
    """Persist baseline-calculator state and price history to storage.

    Lets backtest/live mode resume where it left off after a restart.
    """
    internal_state = self.storage.load_internal_state()

    # Merge the calculator's three exported sections into the stored blob.
    exported = self.baseline_calculator.export_state()
    internal_state["baseline_state"] = exported["baseline_state"]
    internal_state["baseline_vw_state"] = exported["baseline_vw_state"]
    internal_state["momentum_state"] = exported["momentum_state"]

    # Tuples are not JSON-friendly; store each observation as a
    # {"date": ..., "price": ...} dict instead.
    internal_state["price_history"] = {
        ticker: [{"date": date, "price": price} for date, price in history]
        for ticker, history in self.price_history.items()
    }

    self.storage.save_internal_state(internal_state)
    logger.info("Persisted baseline calculator and price history state")
def record_analyst_predictions(
    self,
    final_predictions: List[Dict[str, Any]],
):
    """Forward structured analyst predictions to the performance tracker.

    Called before market close so the next settlement can score them.

    Args:
        final_predictions: Structured prediction results from analysts,
            e.g.::

                [
                    {
                        "agent": "analyst_name",
                        "predictions": [
                            {"ticker": "AAPL", "direction": "up",
                             "confidence": 0.75},
                        ],
                    },
                ]
    """
    self.analyst_tracker.record_analyst_predictions(final_predictions)
def update_price_history(
    self,
    date: str,
    prices: Dict[str, float],
):
    """Append today's prices to each ticker's rolling history.

    Each ticker's history is capped at the 60 most recent observations
    to bound the momentum lookback window.

    Args:
        date: Trading date (YYYY-MM-DD)
        prices: Current prices for each ticker
    """
    for ticker, price in prices.items():
        history = self.price_history.setdefault(ticker, [])
        history.append((date, price))

        # Keep only the most recent 60 observations per ticker.
        self.price_history[ticker] = history[-60:]
def run_daily_settlement(
    self,
    date: str,
    tickers: List[str],
    open_prices: Optional[Dict[str, float]],
    close_prices: Dict[str, float],
    market_caps: Dict[str, float],
    agent_portfolio: Dict[str, Any],
    analyst_results: List[Dict[str, Any]],  # pylint: disable=W0613
    pm_decisions: Optional[Dict[str, Dict]] = None,
) -> Dict[str, Any]:
    """
    Run complete daily settlement.

    Sequence: update price history -> compute momentum scores ->
    mark all baseline portfolios -> mark the agent portfolio ->
    score analyst (and optionally PM) predictions -> update the
    leaderboard -> persist state.

    Args:
        date: Trading date (YYYY-MM-DD)
        tickers: List of tickers
        open_prices: Opening prices (may be None; close prices are used
            as a fallback for the baselines in that case)
        close_prices: Closing prices
        market_caps: Market caps for each ticker
        agent_portfolio: Current agent portfolio state
        analyst_results: Analyst analysis results (currently unused;
            kept for interface compatibility — hence the pylint disable)
        pm_decisions: PM's trading decisions

    Returns:
        Settlement results including all portfolio values and evaluations
    """
    logger.info(f"Running daily settlement for {date}")

    # Record today's closes first so momentum sees the latest bar.
    self.update_price_history(date, close_prices)

    momentum_scores = calculate_momentum_scores(
        tickers,
        self.price_history,
        lookback_days=20,
    )

    # Momentum basket only rebalances on the first session of a month.
    rebalance_momentum = self._should_rebalance_momentum(date)

    # Mark every baseline strategy; fall back to closes when opens
    # are unavailable (e.g. data-provider gap).
    baseline_values = self.baseline_calculator.get_all_baseline_values(
        tickers=tickers,
        open_prices=open_prices if open_prices else close_prices,
        close_prices=close_prices,
        market_caps=market_caps,
        momentum_scores=momentum_scores,
        date=date,
        rebalance_momentum=rebalance_momentum,
    )

    logger.info(f"Baseline values calculated: {baseline_values}")

    # Mark the agent's book at the close.
    agent_value = self.storage.calculate_portfolio_value(
        agent_portfolio,
        close_prices,
    )

    # Score the analyst predictions recorded earlier today.
    analyst_evaluations = self.analyst_tracker.evaluate_predictions(
        open_prices,
        close_prices,
        date,
    )

    pm_evaluations = {}
    if pm_decisions:
        pm_evaluations = self.analyst_tracker.evaluate_pm_decisions(
            pm_decisions,
            open_prices,
            close_prices,
            date,
        )

    # PM scores merged last; on key collision PM entries win.
    all_evaluations = {**analyst_evaluations, **pm_evaluations}

    # Fold today's scores into the persistent leaderboard.
    leaderboard = self.storage.load_file("leaderboard") or []
    updated_leaderboard = update_leaderboard_with_evaluations(
        leaderboard,
        all_evaluations,
    )
    self.storage.save_file("leaderboard", updated_leaderboard)

    # No-op for history (handled centrally by storage); see its docstring.
    self._update_summary_with_baselines(
        date,
        agent_value,
        baseline_values,
    )

    # Today's predictions are consumed; clear for the next session.
    self.analyst_tracker.clear_daily_predictions()

    # Persist baseline calculator and price history state
    self._save_persisted_state()

    return {
        "date": date,
        "agent_portfolio_value": agent_value,
        "baseline_values": baseline_values,
        "analyst_evaluations": analyst_evaluations,
        "baselines_updated": True,
        "leaderboard_updated": True,
    }
|
||||
|
||||
def _should_rebalance_momentum(self, date: str) -> bool:
|
||||
"""
|
||||
Check if momentum portfolio should rebalance
|
||||
|
||||
Returns True if it's a new month
|
||||
"""
|
||||
last_rebalance = self.baseline_calculator.momentum_last_rebalance_date
|
||||
if last_rebalance is None:
|
||||
return True
|
||||
|
||||
last_date = datetime.strptime(last_rebalance, "%Y-%m-%d")
|
||||
current_date = datetime.strptime(date, "%Y-%m-%d")
|
||||
|
||||
return (current_date.year, current_date.month) != (
|
||||
last_date.year,
|
||||
last_date.month,
|
||||
)
|
||||
|
||||
def _update_summary_with_baselines(
    self,
    date: str,
    agent_value: float,
    baseline_values: Dict[str, float],
):
    """
    Update summary.json with agent and baseline portfolio values.

    Intentionally a no-op: history updates are handled centrally by
    storage.update_dashboard_after_cycle() so that all histories
    (equity, baseline, baseline_vw, momentum) stay synchronized.
    baseline_values are returned from run_daily_settlement() and passed
    to storage there. The method (and its signature) is kept so existing
    call sites remain valid.

    Args:
        date: Trading date (used for backtest-compatible timestamps)
        agent_value: Agent portfolio value
        baseline_values: Baseline portfolio values
    """
    # History updates are now handled by storage.update_dashboard_after_cycle()
    # which receives baseline_values from settlement_result and updates all histories together.
    # This ensures equity and baseline data points are always synchronized.
|
||||
|
||||
def update_intraday_values(
    self,
    tickers: List[str],
    current_prices: Dict[str, float],
    market_caps: Dict[str, float],
    agent_portfolio: Dict[str, Any],
) -> Dict[str, float]:
    """
    Update portfolio values with current prices (for live mode intraday
    updates).

    Intraday marks reuse the current price for both the "open" and
    "close" arguments of each baseline calculator, and the momentum
    basket is never rebalanced here.

    Args:
        tickers: List of tickers
        current_prices: Current prices
        market_caps: Market caps
        agent_portfolio: Current agent portfolio

    Returns:
        Dict with current portfolio values keyed by strategy name
    """
    values: Dict[str, float] = {}

    # Mark the agent's book at the latest prices.
    values["agent"] = self.storage.calculate_portfolio_value(
        agent_portfolio,
        current_prices,
    )

    # Equal-weight baseline (current price used for both legs).
    values["equal_weight"] = (
        self.baseline_calculator.calculate_equal_weight_value(
            tickers,
            current_prices,
            current_prices,
        )
    )

    # Market-cap weighted baseline.
    values["market_cap_weighted"] = (
        self.baseline_calculator.calculate_market_cap_weighted_value(
            tickers,
            current_prices,
            current_prices,
            market_caps,
        )
    )

    # Momentum baseline: score against stored history, no rebalance.
    scores = calculate_momentum_scores(
        tickers,
        self.price_history,
        lookback_days=20,
    )

    # Anchor date = last recorded date of the first ticker's history;
    # empty string when no history has been stored yet.
    anchor_date = (
        next(iter(self.price_history.values()))[-1][0]
        if self.price_history
        else ""
    )

    values["momentum"] = self.baseline_calculator.calculate_momentum_value(
        tickers,
        current_prices,
        current_prices,
        scores,
        date=anchor_date,
        rebalance=False,
    )

    return values
|
||||
348
backend/utils/terminal_dashboard.py
Normal file
348
backend/utils/terminal_dashboard.py
Normal file
@@ -0,0 +1,348 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Terminal Dashboard - Persistent unified panel using Rich Live
|
||||
"""
|
||||
# pylint: disable=R0915,R0912
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TerminalDashboard:
    """Unified persistent terminal dashboard.

    Renders a single Rich `Live` panel (config/status, portfolio,
    recent trades, and an animated progress line) and refreshes it in
    place. A daemon thread animates the progress dots while the panel
    is live.
    """

    def __init__(self, console: Console = None):
        self.console = console or Console()
        # Rich Live handle; non-None only between start() and stop().
        self.live: Optional[Live] = None

        # Config state
        self.mode = "live"
        self.config_name = ""
        self.host = "0.0.0.0"
        self.port = 8765
        self.poll_interval = 10
        self.trigger_time = "now"
        self.mock = False
        self.enable_memory = False
        self.local_time = ""
        self.nyse_time = ""
        self.start_date = ""
        self.end_date = ""
        self.tickers: List[str] = []
        self.initial_cash = 100000.0

        # Trading state
        self.current_date = "-"
        self.status = "Initializing"
        self.total_value = 0.0
        self.cash = 0.0
        self.pnl_pct = 0.0
        self.holdings: List[Dict] = []
        self.trades: List[Dict] = []
        self.days_completed = 0
        self.days_total = 0

        # Progress message (last line)
        self.progress = ""
        # Frame counter for the animated "..." suffix.
        self._dots_index = 0
        self._animator_running = False
        self._animator_thread: Optional[threading.Thread] = None

    def set_config(
        self,
        mode: str,
        config_name: str,
        host: str,
        port: int,
        poll_interval: int,
        trigger_time: str = "now",
        mock: bool = False,
        enable_memory: bool = False,
        local_time: str = "",
        nyse_time: str = "",
        start_date: str = "",
        end_date: str = "",
        tickers: List[str] = None,
        initial_cash: float = 100000.0,
    ):
        """Set configuration state.

        Also seeds total_value and cash with initial_cash so the panel
        shows sensible numbers before the first portfolio update.
        """
        self.mode = mode
        self.config_name = config_name
        self.host = host
        self.port = port
        self.poll_interval = poll_interval
        self.trigger_time = trigger_time
        self.mock = mock
        self.enable_memory = enable_memory
        self.local_time = local_time
        self.nyse_time = nyse_time
        self.start_date = start_date
        self.end_date = end_date
        self.tickers = tickers or []
        self.initial_cash = initial_cash
        self.total_value = initial_cash
        self.cash = initial_cash

    def _build_panel(self) -> Panel:
        """Build the unified dashboard panel from current state."""
        # Main grid: three fixed-width columns.
        main_table = Table.grid(padding=(0, 2))
        main_table.add_column(width=28)
        main_table.add_column(width=22)
        main_table.add_column(width=22)

        # Left: Config + Status
        left = Table.grid(padding=(0, 0))
        left.add_column()

        # Mode line
        if self.mode == "backtest":
            mode_str = "[cyan]Backtest[/cyan]"
        elif self.mock:
            mode_str = "[yellow]MOCK[/yellow]"
        else:
            mode_str = "[green]LIVE[/green]"

        left.add_row(f"[bold]Mode:[/bold] {mode_str}")
        left.add_row(f"[dim]Config:[/dim] {self.config_name}")
        left.add_row(f"[dim]Server:[/dim] {self.host}:{self.port}")

        # NYSE clock and trigger only make sense in live mode.
        if self.mode == "live" and self.nyse_time:
            left.add_row(f"[dim]NYSE:[/dim] {self.nyse_time[:19]}")
            trigger_display = (
                "[green]NOW[/green]"
                if self.trigger_time == "now"
                else self.trigger_time
            )
            left.add_row(f"[dim]Trigger:[/dim] {trigger_display}")

        # Status
        left.add_row("")
        status_style = "green" if self.status == "Running" else "yellow"
        left.add_row(
            "[bold]Status:[/bold] "
            f"[{status_style}]{self.status}[/{status_style}]",
        )
        if self.mode == "backtest":
            left.add_row(
                f"[dim]Backtesting Period:[/dim] {self.days_total} days\n"
                f" {self.start_date} -> {self.end_date}",
            )
        # NOTE(review): source indentation was mangled — current-date row
        # is rendered unconditionally here; confirm it wasn't meant to be
        # backtest-only.
        left.add_row(f"[dim]Current Date:[/dim] {self.current_date}")

        # Middle: Portfolio
        mid = Table.grid(padding=(0, 0))
        mid.add_column()

        pnl_style = "green" if self.pnl_pct >= 0 else "red"
        mid.add_row("[bold]Portfolio[/bold]")
        mid.add_row(f"NAV: [bold]${self.total_value:,.0f}[/bold]")
        mid.add_row(f"Cash: ${self.cash:,.0f}")
        mid.add_row(f"P&L: [{pnl_style}]{self.pnl_pct:+.2f}%[/{pnl_style}]")

        # Positions (CASH pseudo-holding excluded; at most 7 rows shown).
        mid.add_row("")
        mid.add_row("[bold]Positions[/bold]")
        stock_holdings = [
            h for h in self.holdings if h.get("ticker") != "CASH"
        ]
        if stock_holdings:
            for h in stock_holdings[:7]:
                qty = h.get("quantity", 0)
                ticker = h.get("ticker", "")[:5]
                val = h.get("marketValue", 0)
                qty_str = f"{qty:+d}" if qty != 0 else "0"
                mid.add_row(
                    f"[cyan]{ticker:<5}[/cyan] {qty_str:>5} ${val:>7,.0f}",
                )
            if len(stock_holdings) > 7:
                mid.add_row(f"[dim]+{len(stock_holdings) - 7} more[/dim]")
        else:
            mid.add_row("[dim]No positions[/dim]")

        # Right: Recent Trades (at most 10 rows shown)
        right = Table.grid(padding=(0, 0))
        right.add_column()

        right.add_row("[bold]Recent Trades[/bold]")
        if self.trades:
            for t in self.trades[:10]:
                side = t.get("side", "")
                ticker = t.get("ticker", "")[:5]
                qty = t.get("qty", 0)
                if side == "LONG":
                    side_str = "[green]L[/green]"
                elif side == "SHORT":
                    side_str = "[red]S[/red]"
                else:
                    side_str = "[dim]H[/dim]"
                right.add_row(f"{side_str} [cyan]{ticker:<5}[/cyan] {qty:>4}")
            if len(self.trades) > 10:
                right.add_row(f"[dim]+{len(self.trades) - 10} more[/dim]")
        else:
            right.add_row("[dim]No trades[/dim]")

        main_table.add_row(left, mid, right)

        # Outer table to add progress line at bottom
        outer = Table.grid(padding=(0, 0))
        outer.add_column()
        outer.add_row(main_table)

        # Progress line (last row) with animated dots
        if self.progress:
            DOTS_FRAMES = [" ", ". ", ".. ", "..."]
            dots = DOTS_FRAMES[self._dots_index % len(DOTS_FRAMES)]
            outer.add_row("")
            outer.add_row(f"[dim]> {self.progress}{dots}[/dim]")

        # Build panel
        title = "[bold cyan]EvoTraders[/bold cyan]"
        if self.mode == "backtest":
            title += " [dim]Backtest[/dim]"
        elif self.mock:
            title += " [dim]Mock[/dim]"
        else:
            title += " [dim]Live[/dim]"

        return Panel(
            outer,
            title=title,
            border_style="cyan",
            padding=(0, 1),
        )

    def _run_animator(self):
        """Background thread to animate the dots.

        Runs until _animator_running is cleared by stop(); only redraws
        while a progress message is set and the Live display exists.
        """
        while self._animator_running:
            time.sleep(0.3)
            if self.progress and self.live:
                self._dots_index += 1
                self.live.update(self._build_panel())

    def start(self):
        """Start the live dashboard display and the dot-animator thread."""
        self.live = Live(
            self._build_panel(),
            console=self.console,
            refresh_per_second=4,
            vertical_overflow="visible",
        )
        self.live.start()

        # Start animator thread (daemon: never blocks interpreter exit).
        self._animator_running = True
        self._animator_thread = threading.Thread(
            target=self._run_animator,
            daemon=True,
        )
        self._animator_thread.start()

    def stop(self):
        """Stop the live dashboard and join the animator thread."""
        self._animator_running = False
        if self._animator_thread:
            # Bounded join: animator sleeps 0.3s per tick, so 0.5s suffices.
            self._animator_thread.join(timeout=0.5)
            self._animator_thread = None
        if self.live:
            self.live.stop()
            self.live = None

    def update(
        self,
        date: str = None,
        status: str = None,
        portfolio: Dict[str, Any] = None,
        holdings: List[Dict] = None,
        trades: List[Dict] = None,
        days_completed: int = None,
        days_total: int = None,
    ):
        """Update dashboard state and refresh display.

        All arguments are optional; only supplied fields are updated.
        Accepts both camelCase (API) and snake_case portfolio keys.
        """
        if date:
            self.current_date = date
        if status:
            self.status = status
        if days_completed is not None:
            self.days_completed = days_completed
        if days_total is not None:
            self.days_total = days_total

        if portfolio:
            # NOTE(review): a legitimate 0 value for totalAssetValue /
            # cashPosition falls through to the fallback key via `or` —
            # confirm that is intended.
            self.total_value = portfolio.get(
                "totalAssetValue",
                0,
            ) or portfolio.get(
                "total_value",
                self.initial_cash,
            )
            self.cash = portfolio.get("cashPosition", 0) or portfolio.get(
                "cash",
                self.initial_cash,
            )
            if self.total_value > 0 and self.initial_cash > 0:
                self.pnl_pct = (
                    (self.total_value - self.initial_cash) / self.initial_cash
                ) * 100

        if holdings is not None:
            self.holdings = holdings
        if trades is not None:
            self.trades = trades

        if self.live:
            self.live.update(self._build_panel())

    def log(self, msg: str, also_log: bool = True):
        """
        Update progress message and refresh panel.

        Args:
            msg: Progress message to display
            also_log: Whether to also write to logger (default True)
        """
        self.progress = msg
        if also_log:
            logger.info(msg)
        if self.live:
            self.live.update(self._build_panel())

    def print_final_summary(self):
        """Print final summary when dashboard stops."""
        pnl_style = "green" if self.pnl_pct >= 0 else "red"

        if self.mode == "backtest":
            msg = (
                f"[bold]Backtest Complete[/bold] | "
                f"Days: {self.days_completed} | "
                f"NAV: ${self.total_value:,.0f} | "
                f"Return: [{pnl_style}]{self.pnl_pct:+.2f}%[/{pnl_style}]"
            )
        else:
            msg = (
                f"[bold]Session End[/bold] | "
                f"NAV: ${self.total_value:,.0f} | "
                f"P&L: [{pnl_style}]{self.pnl_pct:+.2f}%[/{pnl_style}]"
            )

        self.console.print(Panel(msg, border_style="green"))
|
||||
|
||||
|
||||
# Global instance (lazily created; creation is not locked, so first call
# should happen before any worker threads need the dashboard)
_dashboard: Optional[TerminalDashboard] = None


def get_dashboard() -> TerminalDashboard:
    """Get or create global dashboard instance (module-level singleton)."""
    global _dashboard
    if _dashboard is None:
        _dashboard = TerminalDashboard()
    return _dashboard
|
||||
772
backend/utils/trade_executor.py
Normal file
772
backend/utils/trade_executor.py
Normal file
@@ -0,0 +1,772 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Trading Execution Engine - Supports Two Modes
|
||||
1. Signal mode: Only records directional signal decisions
|
||||
2. Portfolio mode: Executes specific trades and tracks positions
|
||||
"""
|
||||
# flake8: noqa: E501
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class DirectionSignalRecorder:
    """Records the Portfolio Manager's daily directional decisions.

    Signal mode never touches positions or cash: each decision is
    appended to an in-memory log and echoed to the console.
    """

    def __init__(self):
        """Create an empty recorder with no signal history."""
        # Chronological log of every signal dict ever recorded.
        self.signal_log = []

    def record_direction_signals(
        self,
        decisions: Dict[str, Dict[str, Any]],
        current_date: str = None,
    ) -> Dict[str, Any]:
        """
        Record Portfolio Manager's directional signal decisions.

        Args:
            decisions: PM's direction decisions
                {ticker: {action, confidence, reasoning}}
            current_date: Trading date; defaults to today. Passing an
                explicit date keeps backtests reproducible.

        Returns:
            Signal recording report summarising what was stored.
        """
        if current_date is None:
            current_date = datetime.now().strftime("%Y-%m-%d")

        # Deterministic, backtest-friendly timestamp at market open.
        timestamp = f"{current_date}T09:30:00"

        report: Dict[str, Any] = {
            "recorded_signals": {},
            "date": current_date,
            "timestamp": timestamp,
            "total_signals": len(decisions),
        }

        print(
            f"\n📊 Recording directional signal decisions for {current_date}...",
        )

        # Console marker per action; unknown actions get a question mark.
        emoji_by_action = {"long": "📈", "short": "📉", "hold": "➖"}

        for symbol, payload in decisions.items():
            direction = payload.get("action", "hold")
            conf = payload.get("confidence", 0)
            why = payload.get("reasoning", "")

            entry = {
                "ticker": symbol,
                "action": direction,
                "confidence": conf,
                "reasoning": why,
                "date": current_date,
                "timestamp": timestamp,
            }
            self.signal_log.append(entry)

            report["recorded_signals"][symbol] = {
                "action": direction,
                "confidence": conf,
            }

            marker = emoji_by_action.get(direction, "❓")
            print(
                f" {marker} {symbol}: {direction.upper()} (Confidence: {conf}%) - {why}",
            )

        print(f"\n✅ Recorded directional signals for {len(decisions)} stocks")

        return report

    def get_signal_summary(self) -> Dict[str, Any]:
        """Get signal recording summary (count plus the full log)."""
        return {
            "total_signals": len(self.signal_log),
            "signal_log": self.signal_log,
        }
|
||||
|
||||
|
||||
def parse_pm_decisions(pm_output: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """
    Parse Portfolio Manager output format.

    Accepts either a wrapper dict containing a "decisions" key or a
    bare decision dict; anything else yields a warning and an empty
    result.

    Args:
        pm_output: PM's raw output

    Returns:
        Standardized decision format
    """
    if not isinstance(pm_output, dict):
        print(f"Warning: Unable to parse PM output format: {type(pm_output)}")
        return {}
    if "decisions" in pm_output:
        # Wrapped form: unwrap the inner decision mapping.
        return pm_output["decisions"]
    # Already a bare decision dictionary.
    return pm_output
|
||||
|
||||
|
||||
class PortfolioTradeExecutor:
|
||||
"""Portfolio mode trade executor, executes specific trades and tracks positions"""
|
||||
|
||||
portfolio: Dict[str, Any]
|
||||
trade_history: List[Dict[str, Any]]
|
||||
portfolio_history: List[Dict[str, Any]]
|
||||
|
||||
def __init__(self, initial_portfolio: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Initialize Portfolio trade executor
|
||||
|
||||
Args:
|
||||
initial_portfolio: Initial portfolio state
|
||||
"""
|
||||
|
||||
if initial_portfolio is None:
|
||||
self.portfolio = {
|
||||
"cash": 100000.0,
|
||||
"positions": {},
|
||||
# Default 0.0 (short selling disabled)
|
||||
"margin_requirement": 0.0,
|
||||
"margin_used": 0.0,
|
||||
}
|
||||
else:
|
||||
self.portfolio = deepcopy(initial_portfolio)
|
||||
|
||||
self.trade_history = [] # Trade history
|
||||
self.portfolio_history = [] # Portfolio history
|
||||
|
||||
def execute_trade(
    self,
    ticker: str,
    action: str,
    quantity: int,
    price: float,
    current_date: str = None,
) -> Dict[str, Any]:
    """
    Execute a single trade.

    Validates the request (holds and zero-quantity requests are no-ops,
    non-positive prices are rejected) and delegates the actual position
    update to ``_execute_single_trade``.

    Args:
        ticker: Stock ticker
        action: Trade action (long/short/hold)
        quantity: Number of shares
        price: Current price (must be positive)
        current_date: Trade date; defaults to today

    Returns:
        Trade result dictionary with a "status" key
    """
    if current_date is None:
        current_date = datetime.now().strftime("%Y-%m-%d")

    # Explicit holds and zero-share requests need no execution.
    if quantity == 0 or action == "hold":
        return {"status": "success", "message": "No trade needed"}

    # Reject obviously bad price data before touching positions.
    if price <= 0:
        return {"status": "failed", "reason": "Invalid price"}

    return self._execute_single_trade(
        ticker=ticker,
        action=action,
        target_quantity=quantity,
        price=price,
        date=current_date,
    )
|
||||
|
||||
def execute_trades(
    self,
    decisions: Dict[str, Dict[str, Any]],
    current_prices: Dict[str, float],
    current_date: str = None,
) -> Dict[str, Any]:
    """
    Execute trading decisions and update positions.

    Iterates over every decision, skipping holds/zero quantities,
    rejecting tickers without a positive price, and delegating the rest
    to ``_execute_single_trade``. Snapshots the portfolio before and
    after, appends to portfolio_history, and prints a console summary.

    Args:
        decisions: {ticker: {action, quantity, confidence, reasoning}}
        current_prices: {ticker: current_price}
        current_date: Current date (used for backtest compatibility)

    Returns:
        Trade execution report (executed/failed trades, before/after
        portfolio snapshots, and the marked portfolio value)
    """
    if current_date is None:
        current_date = datetime.now().strftime("%Y-%m-%d")

    # Use provided date for timestamp (backtest compatible)
    timestamp = f"{current_date}T09:30:00"

    execution_report: Dict[str, Any] = {
        "date": current_date,
        "timestamp": timestamp,
        "executed_trades": [],
        "failed_trades": [],
        # Deep copies: snapshots must not alias the live portfolio.
        "portfolio_before": deepcopy(self.portfolio),
        "portfolio_after": None,
    }

    print(f"\n💼 Executing Portfolio trades for {current_date}...")

    # Execute trades for each ticker
    for ticker, decision in decisions.items():
        action = decision.get("action", "hold")
        quantity = decision.get("quantity", 0)

        # Holds and zero-share requests are silently skipped.
        if action == "hold" or quantity == 0:
            continue

        # A missing or non-positive price fails the trade up front.
        price = current_prices.get(ticker, 0)
        if price <= 0:
            execution_report["failed_trades"].append(
                {
                    "ticker": ticker,
                    "action": action,
                    "quantity": quantity,
                    "reason": "No valid price data",
                },
            )
            print(
                f" ❌ {ticker}: Unable to execute {action} - No valid price",
            )
            continue

        # Execute trade
        trade_result = self._execute_single_trade(
            ticker,
            action,
            quantity,
            price,
            current_date,
        )
        if trade_result["status"] == "success":
            execution_report["executed_trades"].append(trade_result)

            # Human-readable list of the steps taken (cover/buy/sell/short).
            trades_info = ", ".join(trade_result.get("trades", []))
            print(
                f" ✔ {ticker}: {action} Target {quantity} shares "
                f"({trades_info}) @ ${price:.2f}",
            )
        else:
            execution_report["failed_trades"].append(trade_result)
            print(
                f" ✗ {ticker}: Unable to execute {action} - {trade_result['reason']}",
            )

    # Record final portfolio state
    execution_report["portfolio_after"] = deepcopy(self.portfolio)
    self.portfolio_history.append(
        {
            "date": current_date,
            "portfolio": deepcopy(self.portfolio),
        },
    )

    # Calculate portfolio value
    portfolio_value = self._calculate_portfolio_value(current_prices)
    execution_report["portfolio_value"] = portfolio_value

    print("\n✔ Trade execution completed:")
    print(f" Success: {len(execution_report['executed_trades'])} trades")
    print(f" Failed: {len(execution_report['failed_trades'])} trades")
    print(f" Portfolio value: ${portfolio_value:,.2f}")
    print(f" Cash balance: ${self.portfolio['cash']:,.2f}")

    return execution_report
|
||||
|
||||
def _execute_single_trade(
    self,
    ticker: str,
    action: str,
    target_quantity: int,
    price: float,
    date: str,
) -> Dict[str, Any]:
    """
    Execute single trade - Incremental mode.

    Dispatches to the long/short handlers, which mutate the position in
    place and append human-readable step descriptions to
    ``trades_executed``. On success (or hold) a trade record is appended
    to trade_history and returned; a failed sub-step returns its failure
    dict immediately without recording anything.

    Args:
        ticker: Stock ticker
        action: long(add position)/short(reduce position)/hold
        target_quantity: Incremental quantity (long=buy shares, short=sell shares)
        price: Current price
        date: Trade date
    """

    # Ensure position exists (lazily create a flat entry for new tickers)
    if ticker not in self.portfolio["positions"]:
        self.portfolio["positions"][ticker] = {
            "long": 0,
            "short": 0,
            "long_cost_basis": 0.0,
            "short_cost_basis": 0.0,
        }

    position = self.portfolio["positions"][ticker]
    current_long = position["long"]
    current_short = position["short"]

    trades_executed = []  # Record actually executed trade steps

    if action == "long":
        result = self._execute_long_action(
            ticker,
            target_quantity,
            price,
            date,
            current_long,
            current_short,
            trades_executed,
        )
        if result["status"] == "failed":
            return result

    elif action == "short":
        result = self._execute_short_action(
            ticker,
            target_quantity,
            price,
            date,
            current_long,
            current_short,
            trades_executed,
        )
        if result["status"] == "failed":
            return result

    elif action == "hold":
        print(f"\n⏸️ {ticker} Position unchanged: {current_long} shares")

    # NOTE(review): an unrecognized action falls through all branches and
    # is still recorded below as a successful trade with no steps —
    # confirm this is intended rather than an error case.

    # Record trade with backtest-compatible timestamp
    trade_record = {
        "status": "success",
        "ticker": ticker,
        "action": action,
        "target_quantity": target_quantity,
        "price": price,
        "trades": trades_executed,
        "date": date,
        "timestamp": f"{date}T09:30:00",
    }

    self.trade_history.append(trade_record)

    return trade_record
|
||||
|
||||
def _execute_long_action(
|
||||
self,
|
||||
ticker: str,
|
||||
target_quantity: int,
|
||||
price: float,
|
||||
date: str,
|
||||
current_long: int,
|
||||
current_short: int,
|
||||
trades_executed: list,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute long action: Buy shares or cover shorts first"""
|
||||
print(
|
||||
f"\n📈 {ticker} Long operation: Current Long {current_long}, "
|
||||
f"Short {current_short} → Target quantity {target_quantity}",
|
||||
)
|
||||
|
||||
if target_quantity <= 0:
|
||||
print(" ⏸️ Quantity is 0, no trade needed")
|
||||
return {"status": "success"}
|
||||
|
||||
remaining = target_quantity
|
||||
|
||||
# If has short position, cover first
|
||||
if current_short > 0:
|
||||
cover_qty = min(remaining, current_short)
|
||||
print(f" 1️⃣ Cover short: {cover_qty} shares")
|
||||
cover_result = self._cover_short_position(
|
||||
ticker,
|
||||
cover_qty,
|
||||
price,
|
||||
date,
|
||||
)
|
||||
if cover_result["status"] == "failed":
|
||||
return cover_result
|
||||
trades_executed.append(f"Cover {cover_qty} shares")
|
||||
remaining -= cover_qty
|
||||
|
||||
# If still has remaining quantity, buy long
|
||||
if remaining > 0:
|
||||
print(f" 2️⃣ Buy long: {remaining} shares")
|
||||
buy_result = self._buy_long_position(
|
||||
ticker,
|
||||
remaining,
|
||||
price,
|
||||
date,
|
||||
)
|
||||
if buy_result["status"] == "failed":
|
||||
return buy_result
|
||||
trades_executed.append(f"Buy {remaining} shares")
|
||||
|
||||
# Display final result
|
||||
final_long = self.portfolio["positions"][ticker]["long"]
|
||||
final_short = self.portfolio["positions"][ticker]["short"]
|
||||
print(
|
||||
f" ✅ Final state: Long {final_long} shares, Short {final_short} shares",
|
||||
)
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
def _execute_short_action(
|
||||
self,
|
||||
ticker: str,
|
||||
target_quantity: int,
|
||||
price: float,
|
||||
date: str,
|
||||
current_long: int,
|
||||
current_short: int,
|
||||
trades_executed: list,
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute short action: Sell long positions first, then short if needed"""
|
||||
print(
|
||||
f"\n📉 {ticker} Short operation (quantity={target_quantity} shares):",
|
||||
)
|
||||
print(
|
||||
f" Current state: Long {current_long} shares, Short {current_short} shares",
|
||||
)
|
||||
|
||||
if target_quantity <= 0:
|
||||
print(" ⏸️ Quantity is 0, no trade needed")
|
||||
return {"status": "success"}
|
||||
|
||||
remaining_quantity = target_quantity
|
||||
|
||||
# Step 1: If there are long positions, sell first
|
||||
if current_long > 0:
|
||||
sell_quantity = min(remaining_quantity, current_long)
|
||||
print(f" 1️⃣ Sell long: {sell_quantity} shares")
|
||||
sell_result = self._sell_long_position(
|
||||
ticker,
|
||||
sell_quantity,
|
||||
price,
|
||||
date,
|
||||
)
|
||||
if sell_result["status"] == "failed":
|
||||
return sell_result
|
||||
trades_executed.append(f"Sell {sell_quantity} shares")
|
||||
remaining_quantity -= sell_quantity
|
||||
|
||||
# Step 2: If there's remaining quantity, establish or increase short position
|
||||
if remaining_quantity > 0:
|
||||
print(f" 2️⃣ Short: {remaining_quantity} shares")
|
||||
short_result = self._open_short_position(
|
||||
ticker,
|
||||
remaining_quantity,
|
||||
price,
|
||||
date,
|
||||
)
|
||||
if short_result["status"] == "failed":
|
||||
return short_result
|
||||
trades_executed.append(f"Short {remaining_quantity} shares")
|
||||
|
||||
# Display final result
|
||||
final_long = self.portfolio["positions"][ticker]["long"]
|
||||
final_short = self.portfolio["positions"][ticker]["short"]
|
||||
print(
|
||||
f" ✅ Final state: Long {final_long} shares, Short {final_short} shares",
|
||||
)
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
def _buy_long_position(
|
||||
self,
|
||||
ticker: str,
|
||||
quantity: int,
|
||||
price: float,
|
||||
_date: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Buy long position"""
|
||||
position = self.portfolio["positions"][ticker]
|
||||
trade_value = quantity * price
|
||||
|
||||
if self.portfolio["cash"] < trade_value:
|
||||
return {
|
||||
"status": "failed",
|
||||
"ticker": ticker,
|
||||
"action": "buy",
|
||||
"quantity": quantity,
|
||||
"price": price,
|
||||
"reason": f"Insufficient cash (needed: ${trade_value:.2f}, available: "
|
||||
f"${self.portfolio['cash']:.2f})",
|
||||
}
|
||||
|
||||
# Update position cost basis
|
||||
old_long = position["long"]
|
||||
old_cost_basis = position["long_cost_basis"]
|
||||
new_long = old_long + quantity
|
||||
|
||||
# 🐛 Debug info
|
||||
print(f" 🔍 Buy {ticker}:")
|
||||
print(f" Old position: {old_long} shares @ ${old_cost_basis:.2f}")
|
||||
print(f" Buy: {quantity} shares @ ${price:.2f}")
|
||||
print(f" New position: {new_long} shares")
|
||||
|
||||
if new_long > 0:
|
||||
new_cost_basis = (
|
||||
(old_long * old_cost_basis) + (quantity * price)
|
||||
) / new_long
|
||||
print(
|
||||
f" New cost: ${new_cost_basis:.2f} = "
|
||||
f"(({old_long} × ${old_cost_basis:.2f}) + "
|
||||
f"({quantity} × ${price:.2f})) / {new_long}",
|
||||
)
|
||||
position["long_cost_basis"] = new_cost_basis
|
||||
position["long"] = new_long
|
||||
|
||||
# Deduct cash
|
||||
self.portfolio["cash"] -= trade_value
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
def _sell_long_position(
|
||||
self,
|
||||
ticker: str,
|
||||
quantity: int,
|
||||
price: float,
|
||||
_date: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Sell long position"""
|
||||
position = self.portfolio["positions"][ticker]
|
||||
|
||||
if position["long"] < quantity:
|
||||
return {
|
||||
"status": "failed",
|
||||
"ticker": ticker,
|
||||
"action": "sell",
|
||||
"quantity": quantity,
|
||||
"price": price,
|
||||
"reason": f"Insufficient long position (holding: {position['long']},"
|
||||
f" trying to sell: {quantity})",
|
||||
}
|
||||
|
||||
# Reduce position
|
||||
position["long"] -= quantity
|
||||
if position["long"] == 0:
|
||||
position["long_cost_basis"] = 0.0
|
||||
|
||||
# Increase cash
|
||||
trade_value = quantity * price
|
||||
self.portfolio["cash"] += trade_value
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
def _open_short_position(
|
||||
self,
|
||||
ticker: str,
|
||||
quantity: int,
|
||||
price: float,
|
||||
_date: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Open short position"""
|
||||
position = self.portfolio["positions"][ticker]
|
||||
trade_value = quantity * price
|
||||
margin_needed = trade_value * self.portfolio["margin_requirement"]
|
||||
|
||||
if self.portfolio["cash"] < margin_needed:
|
||||
return {
|
||||
"status": "failed",
|
||||
"ticker": ticker,
|
||||
"action": "short",
|
||||
"quantity": quantity,
|
||||
"price": price,
|
||||
"reason": f"Insufficient margin (needed: ${margin_needed:.2f}, "
|
||||
f"available: ${self.portfolio['cash']:.2f})",
|
||||
}
|
||||
|
||||
# Update position cost basis
|
||||
old_short = position["short"]
|
||||
old_cost_basis = position["short_cost_basis"]
|
||||
new_short = old_short + quantity
|
||||
if new_short > 0:
|
||||
position["short_cost_basis"] = (
|
||||
(old_short * old_cost_basis) + (quantity * price)
|
||||
) / new_short
|
||||
position["short"] = new_short
|
||||
|
||||
# Increase cash (short sale proceeds) and margin used
|
||||
self.portfolio["cash"] += trade_value - margin_needed
|
||||
self.portfolio["margin_used"] += margin_needed
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
def _cover_short_position(
|
||||
self,
|
||||
ticker: str,
|
||||
quantity: int,
|
||||
price: float,
|
||||
_date: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Cover short position"""
|
||||
position = self.portfolio["positions"][ticker]
|
||||
|
||||
if position["short"] < quantity:
|
||||
return {
|
||||
"status": "failed",
|
||||
"ticker": ticker,
|
||||
"action": "cover",
|
||||
"quantity": quantity,
|
||||
"price": price,
|
||||
"reason": f"Insufficient short position (holding: {position['short']}, "
|
||||
f"trying to cover: {quantity})",
|
||||
}
|
||||
|
||||
# Calculate released margin - 🔧 FIX: Use cost_basis instead of current price
|
||||
trade_value = quantity * price
|
||||
cost_basis = position["short_cost_basis"]
|
||||
margin_released = (
|
||||
quantity * cost_basis * self.portfolio["margin_requirement"]
|
||||
)
|
||||
|
||||
# Reduce position
|
||||
position["short"] -= quantity
|
||||
if position["short"] == 0:
|
||||
position["short_cost_basis"] = 0.0
|
||||
|
||||
# Deduct cash (buy to cover) and release margin
|
||||
self.portfolio["cash"] -= trade_value
|
||||
self.portfolio["cash"] += margin_released
|
||||
self.portfolio["margin_used"] -= margin_released
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
def _calculate_portfolio_value(
|
||||
self,
|
||||
current_prices: Dict[str, float],
|
||||
) -> float:
|
||||
"""Calculate total portfolio value (net liquidation value)"""
|
||||
# Add margin_used back because it's frozen cash, not lost money
|
||||
total_value = self.portfolio["cash"] + self.portfolio["margin_used"]
|
||||
|
||||
for ticker, position in self.portfolio["positions"].items():
|
||||
if ticker in current_prices:
|
||||
price = current_prices[ticker]
|
||||
# Add long position value
|
||||
total_value += position["long"] * price
|
||||
# Subtract short position value (liability)
|
||||
total_value -= position["short"] * price
|
||||
|
||||
return total_value
|
||||
|
||||
def get_portfolio_summary(
    self,
    current_prices: Dict[str, float],
) -> Dict[str, Any]:
    """Summarise cash, margin, open positions and unrealised PnL.

    Only tickers with a non-flat long or short position are listed.
    A ticker missing from ``current_prices`` is valued at price 0,
    so its PnL reflects the full cost basis — callers should supply
    quotes for every held ticker.
    """
    rows = []
    for symbol, pos in self.portfolio["positions"].items():
        if pos["long"] <= 0 and pos["short"] <= 0:
            continue  # flat in both directions — nothing to report

        quote = current_prices.get(symbol, 0)
        long_mkt = pos["long"] * quote
        short_mkt = pos["short"] * quote

        # Unrealised PnL relative to the recorded cost bases; a side
        # with no shares contributes zero rather than a spurious value.
        long_pnl = (
            long_mkt - (pos["long"] * pos["long_cost_basis"])
            if pos["long"] > 0
            else 0
        )
        short_pnl = (
            (pos["short"] * pos["short_cost_basis"]) - short_mkt
            if pos["short"] > 0
            else 0
        )

        rows.append(
            {
                "ticker": symbol,
                "long_shares": pos["long"],
                "short_shares": pos["short"],
                "long_value": long_mkt,
                "short_value": short_mkt,
                "long_cost_basis": pos["long_cost_basis"],
                "short_cost_basis": pos["short_cost_basis"],
                "long_pnl": long_pnl,
                "short_pnl": short_pnl,
            }
        )

    return {
        "portfolio_value": self._calculate_portfolio_value(current_prices),
        "cash": self.portfolio["cash"],
        "margin_used": self.portfolio["margin_used"],
        "positions": rows,
        "total_trades": len(self.trade_history),
    }
|
||||
|
||||
|
||||
def execute_trading_decisions(
    pm_decisions: Dict[str, Any],
    current_date: str = None,
) -> Dict[str, Any]:
    """
    Convenience function to record directional signal decisions (Signal mode)

    Args:
        pm_decisions: PM's direction decisions
        current_date: Current date (optional)

    Returns:
        Signal recording report
    """
    # Normalise the raw PM output, then hand it to a fresh signal recorder.
    parsed = parse_pm_decisions(pm_decisions)
    recorder = DirectionSignalRecorder()
    return recorder.record_direction_signals(parsed, current_date)
|
||||
|
||||
|
||||
def execute_portfolio_trades(
    pm_decisions: Dict[str, Any],
    current_prices: Dict[str, float],
    portfolio: Dict[str, Any],
    current_date: str = None,
) -> Dict[str, Any]:
    """
    Execute Portfolio mode trading decisions

    Args:
        pm_decisions: PM's trading decisions
        current_prices: Current prices
        portfolio: Current portfolio state
        current_date: Current date (optional)

    Returns:
        Trade execution report and updated portfolio
    """
    # Normalise the raw PM output before execution.
    parsed = parse_pm_decisions(pm_decisions)

    # Run the parsed orders through an executor seeded with the caller's
    # portfolio state, then attach post-trade views to the report.
    executor = PortfolioTradeExecutor(initial_portfolio=portfolio)
    report = executor.execute_trades(parsed, current_prices, current_date)
    report["portfolio_summary"] = executor.get_portfolio_summary(current_prices)
    report["updated_portfolio"] = executor.portfolio

    return report
|
||||
2631
backtest/state/server_state.json
Normal file
2631
backtest/state/server_state.json
Normal file
File diff suppressed because one or more lines are too long
474
backtest/team_dashboard/_internal_state.json
Normal file
474
backtest/team_dashboard/_internal_state.json
Normal file
@@ -0,0 +1,474 @@
|
||||
{
|
||||
"baseline_state": {
|
||||
"initialized": true,
|
||||
"initial_allocation": {
|
||||
"AAPL": 52.82787621372046,
|
||||
"MSFT": 27.48283353510314,
|
||||
"GOOGL": 50.62714374311787,
|
||||
"NVDA": 68.65491294557039,
|
||||
"TSLA": 31.329007841650665,
|
||||
"META": 21.77700348432056,
|
||||
"AMZN": 55.94343000358038
|
||||
}
|
||||
},
|
||||
"baseline_vw_state": {
|
||||
"initialized": true,
|
||||
"initial_allocation": {
|
||||
"AAPL": 68.50435598171448,
|
||||
"MSFT": 28.26372943269579,
|
||||
"GOOGL": 64.10562703513074,
|
||||
"NVDA": 105.43488803941372,
|
||||
"TSLA": 16.283886873554753,
|
||||
"META": 12.29869945153529,
|
||||
"AMZN": 44.10358298129591
|
||||
}
|
||||
},
|
||||
"momentum_state": {
|
||||
"positions": {
|
||||
"AAPL": 123.26504449868106,
|
||||
"MSFT": 64.12661158190733,
|
||||
"GOOGL": 118.13000206727504
|
||||
},
|
||||
"cash": 0.0,
|
||||
"initialized": true,
|
||||
"last_rebalance_date": "2025-11-03"
|
||||
},
|
||||
"equity_history": [
|
||||
{
|
||||
"t": 1762070400000,
|
||||
"v": 100000.0
|
||||
},
|
||||
{
|
||||
"t": 1762156800000,
|
||||
"v": 99785.98
|
||||
},
|
||||
{
|
||||
"t": 1762243200000,
|
||||
"v": 99590.68
|
||||
},
|
||||
{
|
||||
"t": 1762329600000,
|
||||
"v": 99298.78
|
||||
},
|
||||
{
|
||||
"t": 1762416000000,
|
||||
"v": 98425.78
|
||||
},
|
||||
{
|
||||
"t": 1762502400000,
|
||||
"v": 98434.93
|
||||
}
|
||||
],
|
||||
"baseline_history": [
|
||||
{
|
||||
"t": 1762070400000,
|
||||
"v": 100000.0
|
||||
},
|
||||
{
|
||||
"t": 1762156800000,
|
||||
"v": 99760.66
|
||||
},
|
||||
{
|
||||
"t": 1762243200000,
|
||||
"v": 97620.18
|
||||
},
|
||||
{
|
||||
"t": 1762329600000,
|
||||
"v": 98327.37
|
||||
},
|
||||
{
|
||||
"t": 1762416000000,
|
||||
"v": 96286.86
|
||||
},
|
||||
{
|
||||
"t": 1762502400000,
|
||||
"v": 95539.06
|
||||
}
|
||||
],
|
||||
"baseline_vw_history": [
|
||||
{
|
||||
"t": 1762070400000,
|
||||
"v": 100000.0
|
||||
},
|
||||
{
|
||||
"t": 1762156800000,
|
||||
"v": 99716.91
|
||||
},
|
||||
{
|
||||
"t": 1762243200000,
|
||||
"v": 97721.94
|
||||
},
|
||||
{
|
||||
"t": 1762329600000,
|
||||
"v": 98028.19
|
||||
},
|
||||
{
|
||||
"t": 1762416000000,
|
||||
"v": 96206.83
|
||||
},
|
||||
{
|
||||
"t": 1762502400000,
|
||||
"v": 95565.33
|
||||
}
|
||||
],
|
||||
"momentum_history": [
|
||||
{
|
||||
"t": 1762070400000,
|
||||
"v": 100000.0
|
||||
},
|
||||
{
|
||||
"t": 1762156800000,
|
||||
"v": 99835.69
|
||||
},
|
||||
{
|
||||
"t": 1762243200000,
|
||||
"v": 99054.53
|
||||
},
|
||||
{
|
||||
"t": 1762329600000,
|
||||
"v": 99406.81
|
||||
},
|
||||
{
|
||||
"t": 1762416000000,
|
||||
"v": 98768.07
|
||||
},
|
||||
{
|
||||
"t": 1762502400000,
|
||||
"v": 97890.54
|
||||
}
|
||||
],
|
||||
"price_history": {
|
||||
"AAPL": [
|
||||
{
|
||||
"date": "2025-11-03",
|
||||
"price": 269.05
|
||||
},
|
||||
{
|
||||
"date": "2025-11-04",
|
||||
"price": 270.04
|
||||
},
|
||||
{
|
||||
"date": "2025-11-05",
|
||||
"price": 270.14
|
||||
},
|
||||
{
|
||||
"date": "2025-11-06",
|
||||
"price": 269.77
|
||||
},
|
||||
{
|
||||
"date": "2025-11-07",
|
||||
"price": 268.47
|
||||
}
|
||||
],
|
||||
"MSFT": [
|
||||
{
|
||||
"date": "2025-11-03",
|
||||
"price": 517.03
|
||||
},
|
||||
{
|
||||
"date": "2025-11-04",
|
||||
"price": 514.33
|
||||
},
|
||||
{
|
||||
"date": "2025-11-05",
|
||||
"price": 507.16
|
||||
},
|
||||
{
|
||||
"date": "2025-11-06",
|
||||
"price": 497.1
|
||||
},
|
||||
{
|
||||
"date": "2025-11-07",
|
||||
"price": 496.82
|
||||
}
|
||||
],
|
||||
"GOOGL": [
|
||||
{
|
||||
"date": "2025-11-03",
|
||||
"price": 283.72
|
||||
},
|
||||
{
|
||||
"date": "2025-11-04",
|
||||
"price": 277.54
|
||||
},
|
||||
{
|
||||
"date": "2025-11-05",
|
||||
"price": 284.31
|
||||
},
|
||||
{
|
||||
"date": "2025-11-06",
|
||||
"price": 284.75
|
||||
},
|
||||
{
|
||||
"date": "2025-11-07",
|
||||
"price": 278.83
|
||||
}
|
||||
],
|
||||
"NVDA": [
|
||||
{
|
||||
"date": "2025-11-03",
|
||||
"price": 206.88
|
||||
},
|
||||
{
|
||||
"date": "2025-11-04",
|
||||
"price": 198.69
|
||||
},
|
||||
{
|
||||
"date": "2025-11-05",
|
||||
"price": 195.21
|
||||
},
|
||||
{
|
||||
"date": "2025-11-06",
|
||||
"price": 188.08
|
||||
},
|
||||
{
|
||||
"date": "2025-11-07",
|
||||
"price": 188.15
|
||||
}
|
||||
],
|
||||
"TSLA": [
|
||||
{
|
||||
"date": "2025-11-03",
|
||||
"price": 468.37
|
||||
},
|
||||
{
|
||||
"date": "2025-11-04",
|
||||
"price": 444.26
|
||||
},
|
||||
{
|
||||
"date": "2025-11-05",
|
||||
"price": 462.07
|
||||
},
|
||||
{
|
||||
"date": "2025-11-06",
|
||||
"price": 445.91
|
||||
},
|
||||
{
|
||||
"date": "2025-11-07",
|
||||
"price": 429.52
|
||||
}
|
||||
],
|
||||
"META": [
|
||||
{
|
||||
"date": "2025-11-03",
|
||||
"price": 637.71
|
||||
},
|
||||
{
|
||||
"date": "2025-11-04",
|
||||
"price": 627.32
|
||||
},
|
||||
{
|
||||
"date": "2025-11-05",
|
||||
"price": 635.95
|
||||
},
|
||||
{
|
||||
"date": "2025-11-06",
|
||||
"price": 618.94
|
||||
},
|
||||
{
|
||||
"date": "2025-11-07",
|
||||
"price": 621.71
|
||||
}
|
||||
],
|
||||
"AMZN": [
|
||||
{
|
||||
"date": "2025-11-03",
|
||||
"price": 254.0
|
||||
},
|
||||
{
|
||||
"date": "2025-11-04",
|
||||
"price": 249.32
|
||||
},
|
||||
{
|
||||
"date": "2025-11-05",
|
||||
"price": 250.2
|
||||
},
|
||||
{
|
||||
"date": "2025-11-06",
|
||||
"price": 243.04
|
||||
},
|
||||
{
|
||||
"date": "2025-11-07",
|
||||
"price": 244.41
|
||||
}
|
||||
]
|
||||
},
|
||||
"portfolio_state": {
|
||||
"cash": 25395.10000000001,
|
||||
"positions": {
|
||||
"MSFT": {
|
||||
"long": 60,
|
||||
"short": 0,
|
||||
"long_cost_basis": 514.2845833333333,
|
||||
"short_cost_basis": 0.0
|
||||
},
|
||||
"GOOGL": {
|
||||
"long": 50,
|
||||
"short": 0,
|
||||
"long_cost_basis": 279.556,
|
||||
"short_cost_basis": 0.0
|
||||
},
|
||||
"META": {
|
||||
"long": 20,
|
||||
"short": 0,
|
||||
"long_cost_basis": 644.155,
|
||||
"short_cost_basis": 0.0
|
||||
},
|
||||
"AMZN": {
|
||||
"long": 40,
|
||||
"short": 0,
|
||||
"long_cost_basis": 247.5725,
|
||||
"short_cost_basis": 0.0
|
||||
},
|
||||
"NVDA": {
|
||||
"long": 20,
|
||||
"short": 0,
|
||||
"long_cost_basis": 203.0,
|
||||
"short_cost_basis": 0.0
|
||||
},
|
||||
"TSLA": {
|
||||
"long": 0,
|
||||
"short": 15,
|
||||
"long_cost_basis": 0.0,
|
||||
"short_cost_basis": 454.46
|
||||
},
|
||||
"AAPL": {
|
||||
"long": 30,
|
||||
"short": 0,
|
||||
"long_cost_basis": 267.89,
|
||||
"short_cost_basis": 0.0
|
||||
}
|
||||
},
|
||||
"margin_used": 1704.225
|
||||
},
|
||||
"all_trades": [
|
||||
{
|
||||
"id": "t_20251103_MSFT_0",
|
||||
"ts": 1762156800000,
|
||||
"trading_date": "2025-11-03",
|
||||
"side": "LONG",
|
||||
"ticker": "MSFT",
|
||||
"qty": 15,
|
||||
"price": 519.8
|
||||
},
|
||||
{
|
||||
"id": "t_20251103_GOOGL_1",
|
||||
"ts": 1762156800000,
|
||||
"trading_date": "2025-11-03",
|
||||
"side": "LONG",
|
||||
"ticker": "GOOGL",
|
||||
"qty": 20,
|
||||
"price": 282.18
|
||||
},
|
||||
{
|
||||
"id": "t_20251103_META_2",
|
||||
"ts": 1762156800000,
|
||||
"trading_date": "2025-11-03",
|
||||
"side": "LONG",
|
||||
"ticker": "META",
|
||||
"qty": 10,
|
||||
"price": 656.0
|
||||
},
|
||||
{
|
||||
"id": "t_20251103_AMZN_3",
|
||||
"ts": 1762156800000,
|
||||
"trading_date": "2025-11-03",
|
||||
"side": "LONG",
|
||||
"ticker": "AMZN",
|
||||
"qty": 15,
|
||||
"price": 255.36
|
||||
},
|
||||
{
|
||||
"id": "t_20251104_MSFT_0",
|
||||
"ts": 1762243200000,
|
||||
"trading_date": "2025-11-04",
|
||||
"side": "LONG",
|
||||
"ticker": "MSFT",
|
||||
"qty": 25,
|
||||
"price": 511.76
|
||||
},
|
||||
{
|
||||
"id": "t_20251104_GOOGL_1",
|
||||
"ts": 1762243200000,
|
||||
"trading_date": "2025-11-04",
|
||||
"side": "LONG",
|
||||
"ticker": "GOOGL",
|
||||
"qty": 15,
|
||||
"price": 276.75
|
||||
},
|
||||
{
|
||||
"id": "t_20251104_NVDA_2",
|
||||
"ts": 1762243200000,
|
||||
"trading_date": "2025-11-04",
|
||||
"side": "LONG",
|
||||
"ticker": "NVDA",
|
||||
"qty": 20,
|
||||
"price": 203.0
|
||||
},
|
||||
{
|
||||
"id": "t_20251104_TSLA_3",
|
||||
"ts": 1762243200000,
|
||||
"trading_date": "2025-11-04",
|
||||
"side": "SHORT",
|
||||
"ticker": "TSLA",
|
||||
"qty": 15,
|
||||
"price": 454.46
|
||||
},
|
||||
{
|
||||
"id": "t_20251105_MSFT_0",
|
||||
"ts": 1762329600000,
|
||||
"trading_date": "2025-11-05",
|
||||
"side": "LONG",
|
||||
"ticker": "MSFT",
|
||||
"qty": 20,
|
||||
"price": 513.3
|
||||
},
|
||||
{
|
||||
"id": "t_20251105_GOOGL_1",
|
||||
"ts": 1762329600000,
|
||||
"trading_date": "2025-11-05",
|
||||
"side": "LONG",
|
||||
"ticker": "GOOGL",
|
||||
"qty": 15,
|
||||
"price": 278.87
|
||||
},
|
||||
{
|
||||
"id": "t_20251105_META_2",
|
||||
"ts": 1762329600000,
|
||||
"trading_date": "2025-11-05",
|
||||
"side": "LONG",
|
||||
"ticker": "META",
|
||||
"qty": 10,
|
||||
"price": 632.31
|
||||
},
|
||||
{
|
||||
"id": "t_20251106_AAPL_0",
|
||||
"ts": 1762416000000,
|
||||
"trading_date": "2025-11-06",
|
||||
"side": "LONG",
|
||||
"ticker": "AAPL",
|
||||
"qty": 30,
|
||||
"price": 267.89
|
||||
},
|
||||
{
|
||||
"id": "t_20251107_AMZN_0",
|
||||
"ts": 1762502400000,
|
||||
"trading_date": "2025-11-07",
|
||||
"side": "LONG",
|
||||
"ticker": "AMZN",
|
||||
"qty": 25,
|
||||
"price": 242.9
|
||||
},
|
||||
{
|
||||
"id": "t_20251107_TSLA_1",
|
||||
"ts": 1762502400000,
|
||||
"trading_date": "2025-11-07",
|
||||
"side": "SHORT",
|
||||
"ticker": "TSLA",
|
||||
"qty": -5,
|
||||
"price": 437.92
|
||||
}
|
||||
],
|
||||
"daily_position_history": {},
|
||||
"last_update_date": "2025-11-07"
|
||||
}
|
||||
58
backtest/team_dashboard/holdings.json
Normal file
58
backtest/team_dashboard/holdings.json
Normal file
@@ -0,0 +1,58 @@
|
||||
[
|
||||
{
|
||||
"ticker": "MSFT",
|
||||
"quantity": 60,
|
||||
"currentPrice": 496.82,
|
||||
"marketValue": 29809.2,
|
||||
"weight": 0.3028
|
||||
},
|
||||
{
|
||||
"ticker": "CASH",
|
||||
"quantity": 1,
|
||||
"currentPrice": 25395.1,
|
||||
"marketValue": 25395.1,
|
||||
"weight": 0.258
|
||||
},
|
||||
{
|
||||
"ticker": "GOOGL",
|
||||
"quantity": 50,
|
||||
"currentPrice": 278.83,
|
||||
"marketValue": 13941.5,
|
||||
"weight": 0.1416
|
||||
},
|
||||
{
|
||||
"ticker": "META",
|
||||
"quantity": 20,
|
||||
"currentPrice": 621.71,
|
||||
"marketValue": 12434.2,
|
||||
"weight": 0.1263
|
||||
},
|
||||
{
|
||||
"ticker": "AMZN",
|
||||
"quantity": 40,
|
||||
"currentPrice": 244.41,
|
||||
"marketValue": 9776.4,
|
||||
"weight": 0.0993
|
||||
},
|
||||
{
|
||||
"ticker": "AAPL",
|
||||
"quantity": 30,
|
||||
"currentPrice": 268.47,
|
||||
"marketValue": 8054.1,
|
||||
"weight": 0.0818
|
||||
},
|
||||
{
|
||||
"ticker": "TSLA",
|
||||
"quantity": -15,
|
||||
"currentPrice": 429.52,
|
||||
"marketValue": -6442.8,
|
||||
"weight": 0.0655
|
||||
},
|
||||
{
|
||||
"ticker": "NVDA",
|
||||
"quantity": 20,
|
||||
"currentPrice": 188.15,
|
||||
"marketValue": 3763.0,
|
||||
"weight": 0.0382
|
||||
}
|
||||
]
|
||||
1189
backtest/team_dashboard/leaderboard.json
Normal file
1189
backtest/team_dashboard/leaderboard.json
Normal file
File diff suppressed because it is too large
Load Diff
18
backtest/team_dashboard/stats.json
Normal file
18
backtest/team_dashboard/stats.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"totalAssetValue": 98434.93,
|
||||
"totalReturn": -1.57,
|
||||
"cashPosition": 25395.1,
|
||||
"tickerWeights": {},
|
||||
"totalTrades": 14,
|
||||
"winRate": 0.0,
|
||||
"bullBear": {
|
||||
"bull": {
|
||||
"n": 0,
|
||||
"win": 0
|
||||
},
|
||||
"bear": {
|
||||
"n": 0,
|
||||
"win": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
121
backtest/team_dashboard/summary.json
Normal file
121
backtest/team_dashboard/summary.json
Normal file
@@ -0,0 +1,121 @@
|
||||
{
|
||||
"totalAssetValue": 98434.93,
|
||||
"totalReturn": -1.57,
|
||||
"cashPosition": 25395.1,
|
||||
"tickerWeights": {
|
||||
"MSFT": 0.3028,
|
||||
"GOOGL": 0.1416,
|
||||
"META": 0.1263,
|
||||
"AMZN": 0.0993,
|
||||
"NVDA": 0.0382,
|
||||
"TSLA": -0.0655,
|
||||
"AAPL": 0.0818
|
||||
},
|
||||
"totalTrades": 14,
|
||||
"pnlPct": -1.57,
|
||||
"balance": 98434.93,
|
||||
"equity": [
|
||||
{
|
||||
"t": 1762070400000,
|
||||
"v": 100000.0
|
||||
},
|
||||
{
|
||||
"t": 1762156800000,
|
||||
"v": 99785.98
|
||||
},
|
||||
{
|
||||
"t": 1762243200000,
|
||||
"v": 99590.68
|
||||
},
|
||||
{
|
||||
"t": 1762329600000,
|
||||
"v": 99298.78
|
||||
},
|
||||
{
|
||||
"t": 1762416000000,
|
||||
"v": 98425.78
|
||||
},
|
||||
{
|
||||
"t": 1762502400000,
|
||||
"v": 98434.93
|
||||
}
|
||||
],
|
||||
"baseline": [
|
||||
{
|
||||
"t": 1762070400000,
|
||||
"v": 100000.0
|
||||
},
|
||||
{
|
||||
"t": 1762156800000,
|
||||
"v": 99760.66
|
||||
},
|
||||
{
|
||||
"t": 1762243200000,
|
||||
"v": 97620.18
|
||||
},
|
||||
{
|
||||
"t": 1762329600000,
|
||||
"v": 98327.37
|
||||
},
|
||||
{
|
||||
"t": 1762416000000,
|
||||
"v": 96286.86
|
||||
},
|
||||
{
|
||||
"t": 1762502400000,
|
||||
"v": 95539.06
|
||||
}
|
||||
],
|
||||
"baseline_vw": [
|
||||
{
|
||||
"t": 1762070400000,
|
||||
"v": 100000.0
|
||||
},
|
||||
{
|
||||
"t": 1762156800000,
|
||||
"v": 99716.91
|
||||
},
|
||||
{
|
||||
"t": 1762243200000,
|
||||
"v": 97721.94
|
||||
},
|
||||
{
|
||||
"t": 1762329600000,
|
||||
"v": 98028.19
|
||||
},
|
||||
{
|
||||
"t": 1762416000000,
|
||||
"v": 96206.83
|
||||
},
|
||||
{
|
||||
"t": 1762502400000,
|
||||
"v": 95565.33
|
||||
}
|
||||
],
|
||||
"momentum": [
|
||||
{
|
||||
"t": 1762070400000,
|
||||
"v": 100000.0
|
||||
},
|
||||
{
|
||||
"t": 1762156800000,
|
||||
"v": 99835.69
|
||||
},
|
||||
{
|
||||
"t": 1762243200000,
|
||||
"v": 99054.53
|
||||
},
|
||||
{
|
||||
"t": 1762329600000,
|
||||
"v": 99406.81
|
||||
},
|
||||
{
|
||||
"t": 1762416000000,
|
||||
"v": 98768.07
|
||||
},
|
||||
{
|
||||
"t": 1762502400000,
|
||||
"v": 97890.54
|
||||
}
|
||||
]
|
||||
}
|
||||
128
backtest/team_dashboard/trades.json
Normal file
128
backtest/team_dashboard/trades.json
Normal file
@@ -0,0 +1,128 @@
|
||||
[
|
||||
{
|
||||
"id": "t_20251107_AMZN_0",
|
||||
"timestamp": 1762502400000,
|
||||
"trading_date": "2025-11-07",
|
||||
"side": "LONG",
|
||||
"ticker": "AMZN",
|
||||
"qty": 25,
|
||||
"price": 242.9
|
||||
},
|
||||
{
|
||||
"id": "t_20251107_TSLA_1",
|
||||
"timestamp": 1762502400000,
|
||||
"trading_date": "2025-11-07",
|
||||
"side": "SHORT",
|
||||
"ticker": "TSLA",
|
||||
"qty": -5,
|
||||
"price": 437.92
|
||||
},
|
||||
{
|
||||
"id": "t_20251106_AAPL_0",
|
||||
"timestamp": 1762416000000,
|
||||
"trading_date": "2025-11-06",
|
||||
"side": "LONG",
|
||||
"ticker": "AAPL",
|
||||
"qty": 30,
|
||||
"price": 267.89
|
||||
},
|
||||
{
|
||||
"id": "t_20251105_MSFT_0",
|
||||
"timestamp": 1762329600000,
|
||||
"trading_date": "2025-11-05",
|
||||
"side": "LONG",
|
||||
"ticker": "MSFT",
|
||||
"qty": 20,
|
||||
"price": 513.3
|
||||
},
|
||||
{
|
||||
"id": "t_20251105_GOOGL_1",
|
||||
"timestamp": 1762329600000,
|
||||
"trading_date": "2025-11-05",
|
||||
"side": "LONG",
|
||||
"ticker": "GOOGL",
|
||||
"qty": 15,
|
||||
"price": 278.87
|
||||
},
|
||||
{
|
||||
"id": "t_20251105_META_2",
|
||||
"timestamp": 1762329600000,
|
||||
"trading_date": "2025-11-05",
|
||||
"side": "LONG",
|
||||
"ticker": "META",
|
||||
"qty": 10,
|
||||
"price": 632.31
|
||||
},
|
||||
{
|
||||
"id": "t_20251104_MSFT_0",
|
||||
"timestamp": 1762243200000,
|
||||
"trading_date": "2025-11-04",
|
||||
"side": "LONG",
|
||||
"ticker": "MSFT",
|
||||
"qty": 25,
|
||||
"price": 511.76
|
||||
},
|
||||
{
|
||||
"id": "t_20251104_GOOGL_1",
|
||||
"timestamp": 1762243200000,
|
||||
"trading_date": "2025-11-04",
|
||||
"side": "LONG",
|
||||
"ticker": "GOOGL",
|
||||
"qty": 15,
|
||||
"price": 276.75
|
||||
},
|
||||
{
|
||||
"id": "t_20251104_NVDA_2",
|
||||
"timestamp": 1762243200000,
|
||||
"trading_date": "2025-11-04",
|
||||
"side": "LONG",
|
||||
"ticker": "NVDA",
|
||||
"qty": 20,
|
||||
"price": 203.0
|
||||
},
|
||||
{
|
||||
"id": "t_20251104_TSLA_3",
|
||||
"timestamp": 1762243200000,
|
||||
"trading_date": "2025-11-04",
|
||||
"side": "SHORT",
|
||||
"ticker": "TSLA",
|
||||
"qty": 15,
|
||||
"price": 454.46
|
||||
},
|
||||
{
|
||||
"id": "t_20251103_MSFT_0",
|
||||
"timestamp": 1762156800000,
|
||||
"trading_date": "2025-11-03",
|
||||
"side": "LONG",
|
||||
"ticker": "MSFT",
|
||||
"qty": 15,
|
||||
"price": 519.8
|
||||
},
|
||||
{
|
||||
"id": "t_20251103_GOOGL_1",
|
||||
"timestamp": 1762156800000,
|
||||
"trading_date": "2025-11-03",
|
||||
"side": "LONG",
|
||||
"ticker": "GOOGL",
|
||||
"qty": 20,
|
||||
"price": 282.18
|
||||
},
|
||||
{
|
||||
"id": "t_20251103_META_2",
|
||||
"timestamp": 1762156800000,
|
||||
"trading_date": "2025-11-03",
|
||||
"side": "LONG",
|
||||
"ticker": "META",
|
||||
"qty": 10,
|
||||
"price": 656.0
|
||||
},
|
||||
{
|
||||
"id": "t_20251103_AMZN_3",
|
||||
"timestamp": 1762156800000,
|
||||
"trading_date": "2025-11-03",
|
||||
"side": "LONG",
|
||||
"ticker": "AMZN",
|
||||
"qty": 15,
|
||||
"price": 255.36
|
||||
}
|
||||
]
|
||||
BIN
docs/assets/dashboard.jpg
Normal file
BIN
docs/assets/dashboard.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 180 KiB |
BIN
docs/assets/evotraders_demo.gif
Normal file
BIN
docs/assets/evotraders_demo.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1006 KiB |
BIN
docs/assets/evotraders_logo.jpg
Normal file
BIN
docs/assets/evotraders_logo.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 48 KiB |
BIN
docs/assets/evotraders_pipeline.jpg
Normal file
BIN
docs/assets/evotraders_pipeline.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 180 KiB |
BIN
docs/assets/performance.jpg
Normal file
BIN
docs/assets/performance.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 189 KiB |
59
env.template
Normal file
59
env.template
Normal file
@@ -0,0 +1,59 @@
|
||||
# ================== General Configuration | 通用配置 ==================
|
||||
# List of stock ticker symbols to analyze (comma-separated) | 想要分析的股票代码列表(用逗号分隔)
|
||||
TICKERS=AAPL,MSFT,GOOGL,NVDA,TSLA,META,AMZN
|
||||
|
||||
# Financial Data API
|
||||
# At least FINANCIAL_DATASETS_API_KEY is required, corresponding to FIN_DATA_SOURCE=financial_datasets; It's recommended to add FINNHUB_API_KEY, corresponding to FIN_DATA_SOURCE=finnhub; FINNHUB_API_KEY is mandatory for live mode
|
||||
# 至少需要FINANCIAL_DATASETS_API_KEY,对应FIN_DATA_SOURCE=financial_datasets;推荐添加FINNHUB_API_KEY,对应FIN_DATA_SOURCE=finnhub;如果使用live模式必须添加FINNHUB_API_KEY
|
||||
|
||||
# finnhub: https://finnhub.io/register
|
||||
# financial datasets: https://www.financialdatasets.ai/
|
||||
|
||||
# finnhub or financial_datasets | finnhub 或 financial_datasets
FIN_DATA_SOURCE=
|
||||
FINANCIAL_DATASETS_API_KEY= #required | 必填
|
||||
FINNHUB_API_KEY= #optional | 可选
|
||||
|
||||
# Model API
|
||||
OPENAI_API_KEY=
|
||||
OPENAI_BASE_URL=
|
||||
MODEL_NAME=qwen3-max-preview
|
||||
|
||||
#记忆模块(Embedding and llm calls for Reme memory)
|
||||
# default to use aliyun dashscope url, more details: https://help.aliyun.com/zh/model-studio/what-is-model-studio
|
||||
MEMORY_API_KEY=
|
||||
|
||||
|
||||
# ================== Agent-Specific Model Configuration | Agent特定模型配置 ==================
|
||||
# Configure different base models for different roles | 为不同角色配置不同的基座模型
|
||||
# If not configured, global MODEL_NAME and MODEL_PROVIDER will be used | 如果未配置,将使用全局MODEL_NAME和MODEL_PROVIDER
|
||||
#
|
||||
# Role List | 角色列表:
|
||||
# - SENTIMENT_ANALYST: Sentiment Analyst | 情绪分析师
|
||||
# - TECHNICAL_ANALYST: Technical Analyst | 技术分析师
|
||||
# - FUNDAMENTALS_ANALYST: Fundamentals Analyst | 基本面分析师
|
||||
# - VALUATION_ANALYST: Valuation Analyst | 估值分析师
|
||||
# - PORTFOLIO_MANAGER: Portfolio Manager | 投资组合经理
|
||||
# - RISK_MANAGER: Risk Manager | 风险管理经理
|
||||
|
||||
AGENT_SENTIMENT_ANALYST_MODEL_NAME=deepseek-v3.2-exp
|
||||
AGENT_TECHNICAL_ANALYST_MODEL_NAME=glm-4.6
|
||||
AGENT_FUNDAMENTALS_ANALYST_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_VALUATION_ANALYST_MODEL_NAME=Moonshot-Kimi-K2-Instruct
|
||||
AGENT_RISK_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
AGENT_PORTFOLIO_MANAGER_MODEL_NAME=qwen3-max-preview
|
||||
|
||||
|
||||
# ================== Advanced Configuration | 高阶配置 ==================
|
||||
|
||||
# Maximum conference discussion cycles (default: 2) | 最大会议讨论轮数(默认:2)
|
||||
MAX_COMM_CYCLES=2
|
||||
|
||||
# Margin Requirement | 保证金比例
|
||||
MARGIN_REQUIREMENT=0.5
|
||||
# 0.5 = Standard margin (recommended) | 标准保证金(推荐)
|
||||
# 0.25 = Maintenance margin (aggressive) | 维持保证金(激进)
|
||||
|
||||
# Historical data start date
|
||||
DATA_START_DATE=2022-01-01
|
||||
# Auto update data on startup (true/false)
|
||||
AUTO_UPDATE_DATA=true
|
||||
33
frontend/.gitignore
vendored
Normal file
33
frontend/.gitignore
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
.env.local
|
||||
|
||||
# Dependencies
|
||||
node_modules
|
||||
|
||||
# Build output
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Lock files
|
||||
package-lock.json
|
||||
yarn.lock
|
||||
pnpm-lock.yaml
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
7
frontend/README.md
Normal file
7
frontend/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
|
||||
## QuickStart
|
||||
```bash
|
||||
cd frontend
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
22
frontend/components.json
Normal file
22
frontend/components.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"style": "new-york",
|
||||
"rsc": false,
|
||||
"tsx": true,
|
||||
"tailwind": {
|
||||
"config": "tailwind.config.js",
|
||||
"css": "index.css",
|
||||
"baseColor": "neutral",
|
||||
"cssVariables": true,
|
||||
"prefix": ""
|
||||
},
|
||||
"iconLibrary": "lucide",
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils",
|
||||
"ui": "@/components/ui",
|
||||
"lib": "@/lib",
|
||||
"hooks": "@/hooks"
|
||||
},
|
||||
"registries": {}
|
||||
}
|
||||
10
frontend/env.template
Normal file
10
frontend/env.template
Normal file
@@ -0,0 +1,10 @@
|
||||
# Frontend Environment Variables Template
|
||||
# 复制此文件为 .env 并修改配置
|
||||
|
||||
# WebSocket服务器地址
|
||||
# 本地开发
|
||||
VITE_WS_URL=ws://localhost:8765
|
||||
|
||||
# 生产环境(替换为你的实际服务器地址)
|
||||
# VITE_WS_URL=wss://your-server.com:8765
|
||||
|
||||
29
frontend/eslint.config.js
Normal file
29
frontend/eslint.config.js
Normal file
@@ -0,0 +1,29 @@
|
||||
import js from "@eslint/js";
|
||||
import globals from "globals";
|
||||
import reactHooks from "eslint-plugin-react-hooks";
|
||||
import reactRefresh from "eslint-plugin-react-refresh";
|
||||
import { defineConfig, globalIgnores } from "eslint/config";
|
||||
|
||||
export default defineConfig([
|
||||
globalIgnores(["dist"]),
|
||||
{
|
||||
files: ["**/*.{js,jsx}"],
|
||||
extends: [
|
||||
js.configs.recommended,
|
||||
reactHooks.configs["recommended-latest"],
|
||||
reactRefresh.configs.vite,
|
||||
],
|
||||
languageOptions: {
|
||||
ecmaVersion: 2020,
|
||||
globals: globals.browser,
|
||||
parserOptions: {
|
||||
ecmaVersion: "latest",
|
||||
ecmaFeatures: { jsx: true },
|
||||
sourceType: "module",
|
||||
},
|
||||
},
|
||||
rules: {
|
||||
"no-unused-vars": ["error", { varsIgnorePattern: "^[A-Z_]" }],
|
||||
},
|
||||
},
|
||||
]);
|
||||
68
frontend/index.css
Normal file
68
frontend/index.css
Normal file
@@ -0,0 +1,68 @@
|
||||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
|
||||
@layer base {
|
||||
:root {
|
||||
--background: 0 0% 100%;
|
||||
--foreground: 0 0% 3.9%;
|
||||
--card: 0 0% 100%;
|
||||
--card-foreground: 0 0% 3.9%;
|
||||
--popover: 0 0% 100%;
|
||||
--popover-foreground: 0 0% 3.9%;
|
||||
--primary: 0 0% 9%;
|
||||
--primary-foreground: 0 0% 98%;
|
||||
--secondary: 0 0% 96.1%;
|
||||
--secondary-foreground: 0 0% 9%;
|
||||
--muted: 0 0% 96.1%;
|
||||
--muted-foreground: 0 0% 45.1%;
|
||||
--accent: 0 0% 96.1%;
|
||||
--accent-foreground: 0 0% 9%;
|
||||
--destructive: 0 84.2% 60.2%;
|
||||
--destructive-foreground: 0 0% 98%;
|
||||
--border: 0 0% 89.8%;
|
||||
--input: 0 0% 89.8%;
|
||||
--ring: 0 0% 3.9%;
|
||||
--chart-1: 12 76% 61%;
|
||||
--chart-2: 173 58% 39%;
|
||||
--chart-3: 197 37% 24%;
|
||||
--chart-4: 43 74% 66%;
|
||||
--chart-5: 27 87% 67%;
|
||||
--radius: 0.5rem
|
||||
}
|
||||
.dark {
|
||||
--background: 0 0% 3.9%;
|
||||
--foreground: 0 0% 98%;
|
||||
--card: 0 0% 3.9%;
|
||||
--card-foreground: 0 0% 98%;
|
||||
--popover: 0 0% 3.9%;
|
||||
--popover-foreground: 0 0% 98%;
|
||||
--primary: 0 0% 98%;
|
||||
--primary-foreground: 0 0% 9%;
|
||||
--secondary: 0 0% 14.9%;
|
||||
--secondary-foreground: 0 0% 98%;
|
||||
--muted: 0 0% 14.9%;
|
||||
--muted-foreground: 0 0% 63.9%;
|
||||
--accent: 0 0% 14.9%;
|
||||
--accent-foreground: 0 0% 98%;
|
||||
--destructive: 0 62.8% 30.6%;
|
||||
--destructive-foreground: 0 0% 98%;
|
||||
--border: 0 0% 14.9%;
|
||||
--input: 0 0% 14.9%;
|
||||
--ring: 0 0% 83.1%;
|
||||
--chart-1: 220 70% 50%;
|
||||
--chart-2: 160 60% 45%;
|
||||
--chart-3: 30 80% 55%;
|
||||
--chart-4: 280 65% 60%;
|
||||
--chart-5: 340 75% 55%
|
||||
}
|
||||
}
|
||||
|
||||
@layer base {
|
||||
* {
|
||||
@apply border-border;
|
||||
}
|
||||
body {
|
||||
@apply bg-background text-foreground;
|
||||
}
|
||||
}
|
||||
14
frontend/index.html
Normal file
14
frontend/index.html
Normal file
@@ -0,0 +1,14 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/png" href="/trading_logo.png" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>EvoTraders</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
55
frontend/package.json
Normal file
55
frontend/package.json
Normal file
@@ -0,0 +1,55 @@
|
||||
{
|
||||
"name": "live-trading-demo",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview",
|
||||
"preview:host": "vite preview --host"
|
||||
},
|
||||
"dependencies": {
|
||||
"@radix-ui/react-dialog": "^1.1.15",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.16",
|
||||
"@radix-ui/react-label": "^2.1.7",
|
||||
"@radix-ui/react-slider": "^1.3.6",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-switch": "^1.2.6",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"@react-three/drei": "^10.7.6",
|
||||
"@react-three/fiber": "^9.3.0",
|
||||
"@tailwindcss/vite": "^4.1.13",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^12.23.13",
|
||||
"lucide-react": "^0.544.0",
|
||||
"react": "^19.1.1",
|
||||
"react-dom": "^19.1.1",
|
||||
"react-markdown": "^10.1.0",
|
||||
"recharts": "^3.2.1",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"tailwind-merge": "^3.3.1",
|
||||
"three": "^0.180.0",
|
||||
"zustand": "^5.0.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.33.0",
|
||||
"@types/react": "^19.1.13",
|
||||
"@types/react-dom": "^19.1.9",
|
||||
"@vitejs/plugin-react": "^5.0.0",
|
||||
"autoprefixer": "^10.4.21",
|
||||
"eslint": "^9.33.0",
|
||||
"eslint-plugin-react-hooks": "^5.2.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.20",
|
||||
"globals": "^16.3.0",
|
||||
"postcss": "^8.5.6",
|
||||
"tailwindcss": "^3.4.17",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"typescript": "^5.9.2",
|
||||
"vite": "^7.1.2",
|
||||
"vite-tsconfig-paths": "^5.1.4"
|
||||
}
|
||||
}
|
||||
6
frontend/postcss.config.js
Normal file
6
frontend/postcss.config.js
Normal file
@@ -0,0 +1,6 @@
|
||||
export default {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
};
|
||||
BIN
frontend/public/trading_logo.png
Normal file
BIN
frontend/public/trading_logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 40 KiB |
42
frontend/src/App.css
Normal file
42
frontend/src/App.css
Normal file
@@ -0,0 +1,42 @@
|
||||
#root {
|
||||
max-width: 1280px;
|
||||
margin: 0 auto;
|
||||
padding: 2rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.logo {
|
||||
height: 6em;
|
||||
padding: 1.5em;
|
||||
will-change: filter;
|
||||
transition: filter 300ms;
|
||||
}
|
||||
.logo:hover {
|
||||
filter: drop-shadow(0 0 2em #646cffaa);
|
||||
}
|
||||
.logo.react:hover {
|
||||
filter: drop-shadow(0 0 2em #61dafbaa);
|
||||
}
|
||||
|
||||
@keyframes logo-spin {
|
||||
from {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: no-preference) {
|
||||
a:nth-of-type(2) .logo {
|
||||
animation: logo-spin infinite 20s linear;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
padding: 2em;
|
||||
}
|
||||
|
||||
.read-the-docs {
|
||||
color: #888;
|
||||
}
|
||||
1034
frontend/src/App.jsx
Normal file
1034
frontend/src/App.jsx
Normal file
File diff suppressed because it is too large
Load Diff
361
frontend/src/components/AboutModal.jsx
Normal file
361
frontend/src/components/AboutModal.jsx
Normal file
@@ -0,0 +1,361 @@
|
||||
import React, { useState } from 'react';
|
||||
import Header from './Header.jsx';
|
||||
|
||||
export default function AboutModal({ onClose }) {
|
||||
const [isClosing, setIsClosing] = useState(false);
|
||||
const [language, setLanguage] = useState('en'); // 'en' or 'zh'
|
||||
|
||||
const handleClose = () => {
|
||||
setIsClosing(true);
|
||||
// Wait for animation to complete before actually closing
|
||||
setTimeout(() => {
|
||||
onClose();
|
||||
}, 600); // Match animation duration
|
||||
};
|
||||
|
||||
const overlayStyle = {
|
||||
position: 'fixed',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
background: '#ffffff',
|
||||
zIndex: 9999,
|
||||
animation: isClosing
|
||||
? 'collapseUp 0.6s cubic-bezier(0.4, 0, 0.2, 1) forwards'
|
||||
: 'expandDown 0.6s cubic-bezier(0.4, 0, 0.2, 1)',
|
||||
transformOrigin: 'top center',
|
||||
overflowY: 'auto'
|
||||
};
|
||||
|
||||
const contentStyle = {
|
||||
maxWidth: '900px',
|
||||
width: '90%',
|
||||
margin: '0 auto',
|
||||
textAlign: 'left',
|
||||
fontFamily: "'IBM Plex Mono', monospace",
|
||||
color: '#000000',
|
||||
lineHeight: 1.8,
|
||||
fontSize: '14px',
|
||||
letterSpacing: '0.01em',
|
||||
padding: '60px 20px 80px',
|
||||
animation: isClosing
|
||||
? 'fadeOutContent 0.4s ease forwards'
|
||||
: 'fadeInContent 0.8s ease 0.3s backwards'
|
||||
};
|
||||
|
||||
const highlight = {
|
||||
color: '#615CED',
|
||||
fontWeight: 600
|
||||
};
|
||||
|
||||
const linkStyle = {
|
||||
color: '#615CED',
|
||||
textDecoration: 'none',
|
||||
borderBottom: '1px solid #615CED',
|
||||
transition: 'all 0.2s'
|
||||
};
|
||||
|
||||
const closeHintStyle = {
|
||||
marginTop: '50px',
|
||||
fontSize: '11px',
|
||||
color: '#999',
|
||||
cursor: 'pointer',
|
||||
textAlign: 'center'
|
||||
};
|
||||
|
||||
const languageSwitchStyle = {
|
||||
display: 'flex',
|
||||
justifyContent: 'center',
|
||||
alignItems: 'center',
|
||||
marginBottom: '25px',
|
||||
marginTop: '10px',
|
||||
gap: '0px',
|
||||
fontSize: '11px',
|
||||
fontFamily: "'IBM Plex Mono', monospace"
|
||||
};
|
||||
|
||||
const getLangStyle = (isActive) => ({
|
||||
padding: '3px 8px',
|
||||
cursor: 'pointer',
|
||||
transition: 'all 0.2s',
|
||||
background: isActive ? '#000' : '#fff',
|
||||
color: isActive ? '#fff' : '#000',
|
||||
border: 'none'
|
||||
});
|
||||
|
||||
const content = {
|
||||
en: {
|
||||
|
||||
question: "What happens if AI models don't compete with each other, but instead trade like a ",
|
||||
questionHighlight: "well-coordinated, high-performance team",
|
||||
questionEnd: "?",
|
||||
|
||||
intro: "Not arena, but TEAM. We Hope that AI is no longer entering the financial markets as isolated models—it is stepping in as ",
|
||||
introHighlight1: "teams",
|
||||
introContinue: ", collaborating in one of the most challenging and noise-filled ",
|
||||
introHighlight2: "real-time environments",
|
||||
introContinue2: ".",
|
||||
|
||||
|
||||
point1Highlight: "✦ Complementary skills",
|
||||
point1: " - across multiple agents—data analysis, strategy generation, risk management—working together like a real trading desk, exchanging information through notifications and meetings.",
|
||||
|
||||
point2Highlight: "✦ An agent system that continually evolves",
|
||||
point2: " — with memory modules that retain experience, learn from market feedback, reflect, and develop their own methodology over time.",
|
||||
|
||||
point3Highlight: "✦ AI teams interacting with live markets",
|
||||
point3: " — learning from real-time data and making immediate decisions, not just theoretical simulations."
|
||||
},
|
||||
zh: {
|
||||
intro: "如果不是让模型彼此竞争,而是像一支高效协作的团队一样进行实时交易,会发生什么?",
|
||||
question: "这里不是竞技场,而是团队。我们希望Agents不再单打独斗,而是「组团」进入实时金融市场——这一十分困难且充满噪声的环境。",
|
||||
trying: "我们正在探索多智能体协作在实时金融交易中的可能性。",
|
||||
|
||||
title1: "✦ 多智能体的技能互补",
|
||||
point1: "不同模型、不同角色的智能体像真实的金融团队一样协作,各自承担数据分析、策略生成、风险控制等职责。",
|
||||
point1Sub: "通过通知和会议机制进行信息交换,实现高效协作。",
|
||||
|
||||
title2: "✦ 能够持续进化的智能体系统",
|
||||
point2: "依托「记忆」模块,每个智能体都能跨回合保留经验,不断学习、反思与调整。我们希望能看到在长期实时交易中,Agent形成自己的独特方法论,而不是一次性偶然的推理。",
|
||||
point2Sub: "ReMe 记忆框架帮助 Agents 持续改进。",
|
||||
|
||||
title3: "✦ 实时参与市场的 AI Agents",
|
||||
point3: "Agents从实时行情中学习,并给予即时决策;不是纸上谈兵,而是面对市场的真实波动。"
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<style>{`
|
||||
@keyframes expandDown {
|
||||
from {
|
||||
transform: scaleY(0);
|
||||
opacity: 0;
|
||||
}
|
||||
to {
|
||||
transform: scaleY(1);
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes collapseUp {
|
||||
from {
|
||||
transform: scaleY(1);
|
||||
opacity: 1;
|
||||
}
|
||||
to {
|
||||
transform: scaleY(0);
|
||||
opacity: 0;
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes fadeInContent {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes fadeOutContent {
|
||||
from {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
to {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
}
|
||||
`}</style>
|
||||
|
||||
<div style={overlayStyle} onClick={handleClose}>
|
||||
{/* Header */}
|
||||
<div className="header" style={{
|
||||
animation: isClosing
|
||||
? 'fadeOutContent 0.4s ease forwards'
|
||||
: 'fadeInContent 0.8s ease 0.3s backwards'
|
||||
}} onClick={(e) => e.stopPropagation()}>
|
||||
<Header
|
||||
onEvoTradersClick={handleClose}
|
||||
evoTradersLinkStyle="close"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div style={contentStyle} onClick={(e) => e.stopPropagation()}>
|
||||
{/* Language Switch */}
|
||||
<div style={languageSwitchStyle}>
|
||||
<span
|
||||
style={getLangStyle(language === 'zh')}
|
||||
onClick={() => setLanguage('zh')}
|
||||
>
|
||||
中文
|
||||
</span>
|
||||
<span style={{ padding: '0 4px', color: '#999' }}>|</span>
|
||||
<span
|
||||
style={getLangStyle(language === 'en')}
|
||||
onClick={() => setLanguage('en')}
|
||||
>
|
||||
EN
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{language === 'en' ? (
|
||||
// English Content
|
||||
<>
|
||||
|
||||
<div style={{ marginBottom: '40px', fontSize: '15px', fontWeight: 600 }}>
|
||||
{content.en.question}
|
||||
<span style={highlight}>{content.en.questionHighlight}</span>
|
||||
{content.en.questionEnd}
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '30px' }}>
|
||||
{content.en.intro}
|
||||
<span style={highlight}>{content.en.introHighlight1}</span>
|
||||
{content.en.introContinue}
|
||||
<span style={highlight}>{content.en.introHighlight2}</span>
|
||||
{content.en.introContinue2}
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '25px' }}>
|
||||
<span style={highlight}>{content.en.point1Highlight}</span>
|
||||
{content.en.point1}
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '25px' }}>
|
||||
<span style={highlight}>{content.en.point2Highlight}</span>
|
||||
{content.en.point2}
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '40px' }}>
|
||||
<span style={highlight}>{content.en.point3Highlight}</span>
|
||||
{content.en.point3}
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '25px', opacity: 0.7 }}>
|
||||
Everything is fully open-source. Built on{' '}
|
||||
<a
|
||||
href="https://github.com/agentscope-ai"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
style={linkStyle}
|
||||
>
|
||||
AgentScope
|
||||
</a>
|
||||
, using{' '}
|
||||
<a
|
||||
href="https://github.com/agentscope-ai/ReMe"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
style={linkStyle}
|
||||
>
|
||||
ReMe
|
||||
</a>
|
||||
{' '}for memory management.
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
// Chinese Content
|
||||
<>
|
||||
<div style={{ marginBottom: '30px' }}>
|
||||
{content.zh.intro}
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '40px', fontSize: '15px', fontWeight: 600 }}>
|
||||
{content.zh.question}
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '30px', fontSize: '14px', opacity: 0.8 }}>
|
||||
{content.zh.trying}
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '30px' }}>
|
||||
<div style={{ ...highlight, marginBottom: '10px' }}>
|
||||
{content.zh.title1}
|
||||
</div>
|
||||
<div style={{ marginBottom: '10px' }}>
|
||||
{content.zh.point1}
|
||||
</div>
|
||||
<div style={{ fontSize: '13px', opacity: 0.7 }}>
|
||||
{content.zh.point1Sub}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '30px' }}>
|
||||
<div style={{ ...highlight, marginBottom: '10px' }}>
|
||||
{content.zh.title2}
|
||||
</div>
|
||||
<div style={{ marginBottom: '10px' }}>
|
||||
{content.zh.point2}
|
||||
</div>
|
||||
<div style={{ fontSize: '13px', opacity: 0.7 }}>
|
||||
{content.zh.point2Sub}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '30px' }}>
|
||||
<div style={{ ...highlight, marginBottom: '10px' }}>
|
||||
{content.zh.title3}
|
||||
</div>
|
||||
<div>
|
||||
{content.zh.point3}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '10px', opacity: 0.7 }}>
|
||||
我们已经在github上开源。
|
||||
</div>
|
||||
<div style={{ marginBottom: '25px', opacity: 0.7 }}>
|
||||
EvoTraders 基于{' '}
|
||||
<a
|
||||
href="https://github.com/agentscope-ai"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
style={linkStyle}
|
||||
>
|
||||
AgentScope
|
||||
</a>
|
||||
{' '}搭建,并使用其中的{' '}
|
||||
<a
|
||||
href="https://github.com/agentscope-ai/ReMe"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
style={linkStyle}
|
||||
>
|
||||
ReMe
|
||||
</a>
|
||||
{' '}作为记忆管理核心。
|
||||
</div>
|
||||
|
||||
<div style={{ marginBottom: '10px', fontSize: '14px' }}>
|
||||
你可以在此找到完整项目与示例:
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
<div style={{ marginTop: '40px' }}>
|
||||
<a
|
||||
href="https://github.com/agentscope-ai/agentscope-samples"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
style={linkStyle}
|
||||
>
|
||||
github.com/agentscope-ai/agentscope-samples
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div style={closeHintStyle} onClick={handleClose}>
|
||||
{language === 'en' ? 'Click here to close' : '点击此处关闭'}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
517
frontend/src/components/AgentCard.jsx
Normal file
517
frontend/src/components/AgentCard.jsx
Normal file
@@ -0,0 +1,517 @@
|
||||
import React from 'react';
|
||||
import { ASSETS } from '../config/constants';
|
||||
import { getModelIcon, getShortModelName } from '../utils/modelIcons';
|
||||
|
||||
/**
|
||||
* Get rank medal/trophy
|
||||
*/
|
||||
function getRankMedal(rank) {
|
||||
if (rank === 1) return { emoji: '🏆', color: '#FFD700', label: '金牌' };
|
||||
if (rank === 2) return { emoji: '🥈', color: '#C0C0C0', label: '银牌' };
|
||||
if (rank === 3) return { emoji: '🥉', color: '#CD7F32', label: '铜牌' };
|
||||
return { emoji: `#${rank}`, color: '#333333', label: `#${rank}` };
|
||||
}
|
||||
|
||||
/**
|
||||
* Agent Performance Card Component
|
||||
* Horizontal dropdown panel displayed below the agent indicator bar
|
||||
*/
|
||||
export default function AgentCard({ agent, onClose, isClosing }) {
|
||||
if (!agent) return null;
|
||||
|
||||
const bullTotal = agent.bull?.n || 0;
|
||||
const bullWins = agent.bull?.win || 0;
|
||||
const bullUnknown = agent.bull?.unknown || 0;
|
||||
const bearTotal = agent.bear?.n || 0;
|
||||
const bearWins = agent.bear?.win || 0;
|
||||
const bearUnknown = agent.bear?.unknown || 0;
|
||||
const totalSignals = bullTotal + bearTotal;
|
||||
const evaluatedBull = Math.max(bullTotal - bullUnknown, 0);
|
||||
const evaluatedBear = Math.max(bearTotal - bearUnknown, 0);
|
||||
const evaluatedTotal = evaluatedBull + evaluatedBear;
|
||||
const bullWinRate = evaluatedBull > 0 ? (bullWins / evaluatedBull) : null;
|
||||
const bearWinRate = evaluatedBear > 0 ? (bearWins / evaluatedBear) : null;
|
||||
const overallWinRate = agent.winRate != null
|
||||
? agent.winRate
|
||||
: (evaluatedTotal > 0 ? ((bullWins + bearWins) / evaluatedTotal) : null);
|
||||
const overallColor = overallWinRate != null
|
||||
? (overallWinRate >= 0.5 ? '#00C853' : '#FF1744')
|
||||
: '#555555';
|
||||
|
||||
const rankMedal = agent.rank ? getRankMedal(agent.rank) : null;
|
||||
const isPortfolioManager = agent.id === 'portfolio_manager';
|
||||
const isRiskManager = agent.id === 'risk_manager';
|
||||
const displayName = isPortfolioManager ? '团队' : agent.name;
|
||||
|
||||
// Get model icon configuration
|
||||
const modelInfo = getModelIcon(agent.modelName, agent.modelProvider);
|
||||
const shortModelName = getShortModelName(agent.modelName);
|
||||
|
||||
return (
|
||||
<div style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
background: '#ffffff',
|
||||
borderBottom: '2px solid #000000',
|
||||
boxShadow: '0 4px 12px rgba(0, 0, 0, 0.1)',
|
||||
zIndex: 1000,
|
||||
animation: isClosing ? 'slideUp 0.2s ease-out forwards' : 'slideDown 0.25s ease-out'
|
||||
}}>
|
||||
{/* Horizontal scrollable content */}
|
||||
<div style={{
|
||||
overflowX: 'auto',
|
||||
overflowY: 'hidden',
|
||||
padding: '12px',
|
||||
|
||||
/* Hide scrollbar for all browsers */
|
||||
scrollbarWidth: 'none', /* Firefox */
|
||||
msOverflowStyle: 'none', /* IE and Edge */
|
||||
}}>
|
||||
<style>
|
||||
{`
|
||||
div::-webkit-scrollbar {
|
||||
display: none; /* Chrome, Safari, Opera */
|
||||
}
|
||||
`}
|
||||
</style>
|
||||
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
gap: '12px',
|
||||
minWidth: 'max-content'
|
||||
}}>
|
||||
{/* Agent Info with Rank */}
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: 10,
|
||||
padding: '8px 12px',
|
||||
background: '#fafafa',
|
||||
border: '2px solid #000000',
|
||||
minWidth: 200
|
||||
}}>
|
||||
{isPortfolioManager ? (
|
||||
<img
|
||||
src={ASSETS.teamLogo}
|
||||
alt="Team"
|
||||
style={{
|
||||
height: 50,
|
||||
width: 50,
|
||||
objectFit: 'contain'
|
||||
}}
|
||||
/>
|
||||
) : agent.avatar ? (
|
||||
<img
|
||||
src={agent.avatar}
|
||||
alt={agent.name}
|
||||
style={{
|
||||
height: 50,
|
||||
width: 50,
|
||||
objectFit: 'contain'
|
||||
}}
|
||||
/>
|
||||
) : null}
|
||||
<div>
|
||||
<div style={{
|
||||
fontSize: 16,
|
||||
fontWeight: 700,
|
||||
color: '#000000',
|
||||
marginBottom: 2
|
||||
}}>
|
||||
{displayName}
|
||||
</div>
|
||||
{rankMedal && !isPortfolioManager && (
|
||||
<div style={{ fontSize: 18 }}>
|
||||
{rankMedal.emoji} Rank #{agent.rank}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Risk Manager Note */}
|
||||
{isRiskManager && (
|
||||
<div style={{
|
||||
padding: '8px 12px',
|
||||
background: '#FFF9E6',
|
||||
border: '2px solid #FFA726',
|
||||
minWidth: 220,
|
||||
maxWidth: 280,
|
||||
display: 'flex',
|
||||
alignItems: 'center'
|
||||
}}>
|
||||
<div style={{
|
||||
fontSize: 12,
|
||||
color: '#E65100',
|
||||
fontStyle: 'italic',
|
||||
lineHeight: 1.5,
|
||||
whiteSpace: 'normal',
|
||||
wordWrap: 'break-word'
|
||||
}}>
|
||||
ⓘ 风控经理专注于风险管理,不参与预测准确率排名。
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Portfolio Manager Note */}
|
||||
{isPortfolioManager && (
|
||||
<div style={{
|
||||
padding: '8px 12px',
|
||||
background: '#E8F5E9',
|
||||
border: '2px solid #66BB6A',
|
||||
minWidth: 220,
|
||||
maxWidth: 280,
|
||||
display: 'flex',
|
||||
alignItems: 'center'
|
||||
}}>
|
||||
<div style={{
|
||||
fontSize: 12,
|
||||
color: '#2E7D32',
|
||||
fontStyle: 'italic',
|
||||
lineHeight: 1.5,
|
||||
whiteSpace: 'normal',
|
||||
wordWrap: 'break-word'
|
||||
}}>
|
||||
ⓘ 投资经理综合所有分析师建议,提供团队最终交易信号,不参与排名。
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Model Info Card */}
|
||||
{agent.modelName && (
|
||||
<div style={{
|
||||
padding: '8px 12px',
|
||||
background: '#ffffff',
|
||||
border: `2px solid ${modelInfo.color}`,
|
||||
minWidth: 140,
|
||||
position: 'relative',
|
||||
cursor: 'help'
|
||||
}}
|
||||
title={`Model: ${agent.modelName}\nProvider: ${modelInfo.provider}`}>
|
||||
<div style={{
|
||||
fontSize: 10,
|
||||
fontWeight: 700,
|
||||
color: modelInfo.color,
|
||||
letterSpacing: 1,
|
||||
marginBottom: 4,
|
||||
textTransform: 'uppercase'
|
||||
}}>
|
||||
模型
|
||||
</div>
|
||||
<div style={{
|
||||
height: 40,
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
marginBottom: 4
|
||||
}}>
|
||||
{modelInfo.logoPath ? (
|
||||
<img
|
||||
src={modelInfo.logoPath}
|
||||
alt={modelInfo.provider}
|
||||
style={{
|
||||
maxHeight: '100%',
|
||||
maxWidth: '100%',
|
||||
objectFit: 'contain'
|
||||
}}
|
||||
/>
|
||||
) : (
|
||||
<div style={{
|
||||
fontSize: 28,
|
||||
lineHeight: 1
|
||||
}}>
|
||||
🤖
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 11,
|
||||
fontWeight: 600,
|
||||
color: modelInfo.color,
|
||||
whiteSpace: 'nowrap',
|
||||
overflow: 'hidden',
|
||||
textOverflow: 'ellipsis'
|
||||
}}>
|
||||
{shortModelName}
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 8,
|
||||
color: '#666666',
|
||||
marginTop: 2
|
||||
}}>
|
||||
{modelInfo.provider}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Overall Win Rate */}
|
||||
{!isRiskManager && !isPortfolioManager && (
|
||||
<div style={{
|
||||
padding: '8px 14px',
|
||||
background: '#fafafa',
|
||||
border: '2px solid #e0e0e0',
|
||||
textAlign: 'center',
|
||||
minWidth: 160
|
||||
}}>
|
||||
<div style={{
|
||||
fontSize: 10,
|
||||
color: '#333333',
|
||||
fontWeight: 700,
|
||||
letterSpacing: 1,
|
||||
marginBottom: 4,
|
||||
textTransform: 'uppercase'
|
||||
}}>
|
||||
胜率
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 36,
|
||||
fontWeight: 700,
|
||||
color: overallColor,
|
||||
fontFamily: '"Courier New", monospace',
|
||||
lineHeight: 1,
|
||||
marginBottom: 2
|
||||
}}>
|
||||
{overallWinRate != null ? `${(overallWinRate * 100).toFixed(1)}%` : 'N/A'}
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 9,
|
||||
color: '#555555'
|
||||
}}>
|
||||
{bullWins + bearWins}胜 / {evaluatedTotal}评
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 8,
|
||||
color: '#888888',
|
||||
marginTop: 4,
|
||||
fontStyle: 'italic',
|
||||
lineHeight: 1.2,
|
||||
whiteSpace: 'pre-line'
|
||||
}}>
|
||||
评估: 总评估多空信号数。{'\n'}胜率 = 正确信号 / 总评估信号
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Bull Stats */}
|
||||
{!isRiskManager && !isPortfolioManager && (
|
||||
<div style={{
|
||||
padding: '8px 12px',
|
||||
background: '#F0FFF4',
|
||||
border: '2px solid #00C853',
|
||||
minWidth: 140
|
||||
}}>
|
||||
<div style={{
|
||||
fontSize: 10,
|
||||
fontWeight: 700,
|
||||
color: '#00C853',
|
||||
letterSpacing: 1,
|
||||
marginBottom: 4,
|
||||
textTransform: 'uppercase'
|
||||
}}>
|
||||
牛市胜率
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 28,
|
||||
fontWeight: 700,
|
||||
color: bullWinRate != null ? (bullWinRate >= 0.5 ? '#00C853' : '#333333') : '#555555',
|
||||
marginBottom: 2,
|
||||
lineHeight: 1
|
||||
}}>
|
||||
{bullWinRate != null ? `${(bullWinRate * 100).toFixed(1)}%` : 'N/A'}
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 9,
|
||||
color: '#333333'
|
||||
}}>
|
||||
{bullWins}胜 / {evaluatedBull}评
|
||||
{bullUnknown > 0 && ` / ${bullUnknown}P`}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Bear Stats */}
|
||||
{!isRiskManager && !isPortfolioManager && (
|
||||
<div style={{
|
||||
padding: '8px 12px',
|
||||
background: '#FFF5F5',
|
||||
border: '2px solid #FF1744',
|
||||
minWidth: 140
|
||||
}}>
|
||||
<div style={{
|
||||
fontSize: 10,
|
||||
fontWeight: 700,
|
||||
color: '#FF1744',
|
||||
letterSpacing: 1,
|
||||
marginBottom: 4,
|
||||
textTransform: 'uppercase'
|
||||
}}>
|
||||
熊市胜率
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 28,
|
||||
fontWeight: 700,
|
||||
color: bearWinRate != null ? (bearWinRate >= 0.5 ? '#00C853' : '#333333') : '#555555',
|
||||
marginBottom: 2,
|
||||
lineHeight: 1
|
||||
}}>
|
||||
{bearWinRate != null ? `${(bearWinRate * 100).toFixed(1)}%` : 'N/A'}
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 9,
|
||||
color: '#333333'
|
||||
}}>
|
||||
{bearWins}胜 / {evaluatedBear}评
|
||||
{bearUnknown > 0 && ` / ${bearUnknown}P`}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Recent Signals - Horizontal scroll */}
|
||||
{agent.signals && agent.signals.length > 0 && (
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
gap: 6,
|
||||
padding: '8px 12px',
|
||||
background: '#fafafa',
|
||||
border: '2px solid #e0e0e0'
|
||||
}}>
|
||||
{[...agent.signals]
|
||||
.filter(signal => signal && signal.signal)
|
||||
.sort((a, b) => {
|
||||
// Sort by date descending (newest first)
|
||||
const dateA = a.date || '';
|
||||
const dateB = b.date || '';
|
||||
return dateB.localeCompare(dateA);
|
||||
})
|
||||
.slice(0, 35)
|
||||
.map((signal, idx) => {
|
||||
const signalType = signal.signal.toLowerCase();
|
||||
const isBull = signalType.includes('bull') || signalType === 'long';
|
||||
const isBear = signalType.includes('bear') || signalType === 'short';
|
||||
const isNeutral = (!isBull && !isBear) || signalType.includes('neutral') || signalType === 'hold';
|
||||
const isCorrect = signal.is_correct === true;
|
||||
const isUnknown = signal.is_correct === 'unknown' || signal.is_correct === null;
|
||||
|
||||
// Determine result symbol/text and color: unknown has priority over neutral
|
||||
let resultDisplay;
|
||||
let resultColor = '#555555';
|
||||
let resultFontSize = 18;
|
||||
|
||||
if (isNeutral) {
|
||||
resultDisplay = '-';
|
||||
resultColor = '#555555'; // Gray for neutral
|
||||
} else if (isUnknown) {
|
||||
resultDisplay = '?';
|
||||
resultColor = '#FFA726'; // Orange for unknown
|
||||
resultFontSize = 14; // Smaller font for text
|
||||
} else {
|
||||
resultDisplay = isCorrect ? '✓' : '✗';
|
||||
resultColor = isCorrect ? '#00C853' : '#FF1744'; // Green for correct, Red for wrong
|
||||
}
|
||||
|
||||
return (
|
||||
<div key={idx} style={{
|
||||
fontSize: 9,
|
||||
fontFamily: '"Courier New", monospace',
|
||||
padding: '6px 8px',
|
||||
background: '#ffffff',
|
||||
border: '1px solid #e0e0e0',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
gap: 3,
|
||||
minWidth: 70
|
||||
}}>
|
||||
<div style={{
|
||||
fontWeight: 700,
|
||||
color: isBull ? '#00C853' : isBear ? '#FF1744' : '#555555'
|
||||
}}>
|
||||
{signal.ticker}
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 16,
|
||||
color: isBull ? '#00C853' : isBear ? '#FF1744' : '#555555'
|
||||
}}>
|
||||
{isBull ? '看涨' : isBear ? '看跌' : '中性'}
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 8,
|
||||
color: '#555555'
|
||||
}}>
|
||||
{signal.date?.substring(5, 10) || 'N/A'}
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: resultFontSize,
|
||||
fontWeight: 700,
|
||||
color: resultColor
|
||||
}}>
|
||||
{resultDisplay}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
{/* Info card explaining signal display */}
|
||||
<div style={{
|
||||
fontSize: 9,
|
||||
fontFamily: '"Courier New", monospace',
|
||||
padding: '6px 8px',
|
||||
background: '#E3F2FD',
|
||||
border: '1px solid #90CAF9',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
gap: 2,
|
||||
minWidth: 70,
|
||||
textAlign: 'center'
|
||||
}}>
|
||||
<div style={{
|
||||
fontSize: 10,
|
||||
fontWeight: 700,
|
||||
color: '#1976D2'
|
||||
}}>
|
||||
ⓘ 说明
|
||||
</div>
|
||||
<div style={{
|
||||
fontSize: 8,
|
||||
color: '#1976D2',
|
||||
lineHeight: 1.2
|
||||
}}>
|
||||
仅显示最近5个交易日(1周)的信号
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<style>
|
||||
{`
|
||||
@keyframes slideDown {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
|
||||
@keyframes slideUp {
|
||||
from {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
to {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
}
|
||||
`}
|
||||
</style>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
641
frontend/src/components/AgentFeed.jsx
Normal file
641
frontend/src/components/AgentFeed.jsx
Normal file
@@ -0,0 +1,641 @@
|
||||
import React, { useState, useRef, useImperativeHandle, forwardRef } from 'react';
|
||||
import { formatTime } from '../utils/formatters';
|
||||
import { MESSAGE_COLORS, getAgentColors, AGENTS, ASSETS } from '../config/constants';
|
||||
import { getModelIcon } from '../utils/modelIcons';
|
||||
import MarkdownModal from './MarkdownModal';
|
||||
|
||||
const isAnalyst = (agentId, agentName) => {
|
||||
if (agentId && agentId.includes('analyst')) return true;
|
||||
if (agentName && agentName.toLowerCase().includes('analyst')) return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
const isManager = (agentId, agentName) => {
|
||||
if (agentId && agentId.includes('manager')) return true;
|
||||
if (agentName && agentName.toLowerCase().includes('manager')) return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
const stripMarkdown = (text) => {
|
||||
return text
|
||||
.replace(/<think>[\s\S]*?<\/think>/gi, '')
|
||||
.replace(/#{1,6}\s+/g, '')
|
||||
.replace(/\*\*\*(.+?)\*\*\*/g, '$1')
|
||||
.replace(/\*\*(.+?)\*\*/g, '$1')
|
||||
.replace(/__(.+?)__/g, '$1')
|
||||
.replace(/\*(.+?)\*/g, '$1')
|
||||
.replace(/_(.+?)_/g, '$1')
|
||||
.replace(/`(.+?)`/g, '$1')
|
||||
.replace(/\[(.+?)\]\(.+?\)/g, '$1')
|
||||
.replace(/!\[.*?\]\(.+?\)/g, '')
|
||||
.replace(/^\s*[-*+]\s+/gm, '')
|
||||
.replace(/^\s*\d+\.\s+/gm, '')
|
||||
.replace(/^\s*>\s+/gm, '')
|
||||
.replace(/\|/g, ' ')
|
||||
.replace(/^[-=]+$/gm, '');
|
||||
};
|
||||
|
||||
const AgentFeed = forwardRef(({ feed, leaderboard }, ref) => {
|
||||
const feedContentRef = useRef(null);
|
||||
const [highlightedId, setHighlightedId] = useState(null);
|
||||
const [selectedAgent, setSelectedAgent] = useState('all');
|
||||
const [dropdownOpen, setDropdownOpen] = useState(false);
|
||||
|
||||
const getAgentModelInfo = (agentId) => {
|
||||
if (!leaderboard || !agentId) return { modelName: null, modelProvider: null };
|
||||
const agentData = leaderboard.find(lb => lb.id === agentId || lb.agentId === agentId);
|
||||
return {
|
||||
modelName: agentData?.modelName,
|
||||
modelProvider: agentData?.modelProvider
|
||||
};
|
||||
};
|
||||
|
||||
// Get agent info by name
|
||||
const getAgentInfoByName = (agentName) => {
|
||||
if (!leaderboard || !agentName) return null;
|
||||
const agentData = leaderboard.find(lb => lb.name === agentName || lb.agentName === agentName);
|
||||
if (!agentData) return null;
|
||||
return {
|
||||
agentId: agentData.id || agentData.agentId,
|
||||
modelName: agentData.modelName,
|
||||
modelProvider: agentData.modelProvider
|
||||
};
|
||||
};
|
||||
|
||||
// Get unique agent names from feed (only registered agents in AGENTS)
|
||||
const getUniqueAgents = () => {
|
||||
const agentNamesInFeed = new Set();
|
||||
|
||||
// Collect all agent names that appear in the feed
|
||||
feed.forEach(item => {
|
||||
if (item.type === 'message' && item.data?.agent) {
|
||||
agentNamesInFeed.add(item.data.agent);
|
||||
} else if (item.type === 'conference' && item.data?.messages) {
|
||||
item.data.messages.forEach(msg => {
|
||||
if (msg.agent) {
|
||||
agentNamesInFeed.add(msg.agent);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// Filter to only include registered agents and sort by AGENTS array order
|
||||
const registeredAgentNames = AGENTS.map(a => a.name);
|
||||
return registeredAgentNames.filter(name => agentNamesInFeed.has(name));
|
||||
};
|
||||
|
||||
// Filter feed based on selected agent
|
||||
const filteredFeed = selectedAgent === 'all'
|
||||
? feed
|
||||
: feed.filter(item => {
|
||||
if (item.type === 'message') {
|
||||
return item.data?.agent === selectedAgent;
|
||||
} else if (item.type === 'conference') {
|
||||
return item.data?.messages?.some(msg => msg.agent === selectedAgent);
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
scrollToMessage: (bubble) => {
|
||||
if (!bubble || !feedContentRef.current) return;
|
||||
|
||||
// Direct feedItemId match (used by replay mode)
|
||||
if (bubble.feedItemId) {
|
||||
const element = document.getElementById(`feed-item-${bubble.feedItemId}`);
|
||||
if (element) {
|
||||
element.scrollIntoView({ behavior: 'smooth', block: 'center' });
|
||||
setHighlightedId(bubble.feedItemId);
|
||||
setTimeout(() => setHighlightedId(null), 2000);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const bubbleTimestamp = bubble.ts || bubble.timestamp;
|
||||
|
||||
// Check if a message matches the bubble
|
||||
const isMatch = (msg, checkTime = true) => {
|
||||
const agentMatch = msg.agentId === bubble.agentId || msg.agent === bubble.agentName;
|
||||
if (!agentMatch || !checkTime) return agentMatch;
|
||||
return Math.abs(msg.timestamp - bubbleTimestamp) < 5000;
|
||||
};
|
||||
|
||||
// Check if a feed item contains the target message
|
||||
const itemContains = (item, checkTime = true) => {
|
||||
if (item.type === 'message' && item.data) return isMatch(item.data, checkTime);
|
||||
if (item.type === 'conference' && item.data?.messages) {
|
||||
return item.data.messages.some(msg => isMatch(msg, checkTime));
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Find exact match first, then fallback to agent match
|
||||
const targetItem = feed.find(item => itemContains(item, true))
|
||||
|| feed.find(item => itemContains(item, false));
|
||||
|
||||
if (targetItem) {
|
||||
const element = document.getElementById(`feed-item-${targetItem.id}`);
|
||||
if (element) {
|
||||
element.scrollIntoView({ behavior: 'smooth', block: 'center' });
|
||||
setHighlightedId(targetItem.id);
|
||||
setTimeout(() => setHighlightedId(null), 2000);
|
||||
}
|
||||
}
|
||||
}
|
||||
}), [feed]);
|
||||
|
||||
const uniqueAgents = getUniqueAgents();
|
||||
|
||||
// Get current selection display info
|
||||
const getCurrentSelectionInfo = () => {
|
||||
if (selectedAgent === 'all') {
|
||||
return { label: 'All Agents', modelInfo: null };
|
||||
}
|
||||
const agentInfo = getAgentInfoByName(selectedAgent);
|
||||
const modelInfo = agentInfo ? getModelIcon(agentInfo.modelName, agentInfo.modelProvider) : null;
|
||||
return { label: selectedAgent, modelInfo };
|
||||
};
|
||||
|
||||
const currentSelection = getCurrentSelectionInfo();
|
||||
|
||||
return (
|
||||
<div className="agent-feed">
|
||||
<div className="agent-feed-header">
|
||||
<h3 className="agent-feed-title">活动 feed</h3>
|
||||
<div className="agent-filter-wrapper">
|
||||
<label className="agent-filter-label">筛选:</label>
|
||||
<div className="custom-select-wrapper">
|
||||
<button
|
||||
className="custom-select-trigger"
|
||||
onClick={() => setDropdownOpen(!dropdownOpen)}
|
||||
onBlur={() => setTimeout(() => setDropdownOpen(false), 200)}
|
||||
>
|
||||
<div className="custom-select-value">
|
||||
{currentSelection.modelInfo?.logoPath && (
|
||||
<img
|
||||
src={currentSelection.modelInfo.logoPath}
|
||||
alt={currentSelection.modelInfo.provider}
|
||||
className="select-model-icon"
|
||||
/>
|
||||
)}
|
||||
<span>{currentSelection.label}</span>
|
||||
</div>
|
||||
<span className="custom-select-arrow">▼</span>
|
||||
</button>
|
||||
{dropdownOpen && (
|
||||
<div className="custom-select-dropdown">
|
||||
<div
|
||||
className={`custom-select-option ${selectedAgent === 'all' ? 'selected' : ''}`}
|
||||
onClick={() => {
|
||||
setSelectedAgent('all');
|
||||
setDropdownOpen(false);
|
||||
}}
|
||||
>
|
||||
<span>全部 Agents</span>
|
||||
</div>
|
||||
{uniqueAgents.map(agent => {
|
||||
const agentInfo = getAgentInfoByName(agent);
|
||||
const modelInfo = agentInfo ? getModelIcon(agentInfo.modelName, agentInfo.modelProvider) : null;
|
||||
return (
|
||||
<div
|
||||
key={agent}
|
||||
className={`custom-select-option ${selectedAgent === agent ? 'selected' : ''}`}
|
||||
onClick={() => {
|
||||
setSelectedAgent(agent);
|
||||
setDropdownOpen(false);
|
||||
}}
|
||||
>
|
||||
{modelInfo?.logoPath && (
|
||||
<img
|
||||
src={modelInfo.logoPath}
|
||||
alt={modelInfo.provider}
|
||||
className="select-model-icon"
|
||||
/>
|
||||
)}
|
||||
<span>{agent}</span>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="feed-content" ref={feedContentRef}>
|
||||
{filteredFeed.length === 0 && (
|
||||
<div className="empty-state">
|
||||
{selectedAgent === 'all'
|
||||
? '等待系统更新...'
|
||||
: `${selectedAgent} 没有消息`}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{filteredFeed.map(item => {
|
||||
const isHighlighted = item.id === highlightedId;
|
||||
if (item.type === 'conference') {
|
||||
return <ConferenceItem key={item.id} conference={item.data} itemId={item.id} isHighlighted={isHighlighted} getAgentModelInfo={getAgentModelInfo} />;
|
||||
} else if (item.type === 'memory') {
|
||||
return <MemoryItem key={item.id} memory={item.data} itemId={item.id} isHighlighted={isHighlighted} />;
|
||||
} else if (item.data?.agent === 'System') {
|
||||
return <SystemDivider key={item.id} message={item.data} itemId={item.id} />;
|
||||
} else {
|
||||
return <MessageItem key={item.id} message={item.data} itemId={item.id} isHighlighted={isHighlighted} getAgentModelInfo={getAgentModelInfo} />;
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
AgentFeed.displayName = 'AgentFeed';
|
||||
|
||||
export default AgentFeed;
|
||||
|
||||
function SystemDivider({ message, itemId }) {
|
||||
const content = String(message.content || '');
|
||||
|
||||
return (
|
||||
<div
|
||||
id={`feed-item-${itemId}`}
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
padding: '12px 16px',
|
||||
gap: '12px',
|
||||
}}
|
||||
>
|
||||
<div style={{ flex: 1, height: '1px', backgroundColor: '#d0d0d0' }} />
|
||||
<span style={{
|
||||
fontSize: '11px',
|
||||
color: '#888',
|
||||
whiteSpace: 'normal',
|
||||
fontWeight: 500,
|
||||
letterSpacing: '0.3px',
|
||||
}}>
|
||||
{content}
|
||||
</span>
|
||||
<div style={{ flex: 1, height: '1px', backgroundColor: '#d0d0d0' }} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ConferenceItem({ conference, itemId, isHighlighted, getAgentModelInfo }) {
|
||||
const colors = MESSAGE_COLORS.conference;
|
||||
|
||||
return (
|
||||
<div
|
||||
id={`feed-item-${itemId}`}
|
||||
className="feed-item"
|
||||
style={{
|
||||
backgroundColor: colors.bg,
|
||||
outline: isHighlighted ? '2px solid #615CED' : 'none',
|
||||
transition: 'outline 0.3s ease'
|
||||
}}
|
||||
>
|
||||
<div className="feed-item-header">
|
||||
<span className="feed-item-title" style={{ color: colors.text }}>
|
||||
会议
|
||||
</span>
|
||||
{conference.isLive && <span className="feed-live-badge">● 实时</span>}
|
||||
<span className="feed-item-time">{formatTime(conference.startTime)}</span>
|
||||
</div>
|
||||
|
||||
<div className="feed-item-subtitle" style={{ color: colors.text }}>
|
||||
{conference.title}
|
||||
</div>
|
||||
|
||||
<div className="conference-messages">
|
||||
{conference.messages.map((msg, idx) => (
|
||||
<ConferenceMessage key={idx} message={msg} getAgentModelInfo={getAgentModelInfo} />
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ConferenceMessage({ message, getAgentModelInfo }) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
|
||||
const agentColors = message.agent === 'System' ? MESSAGE_COLORS.system :
|
||||
message.agent === 'Memory' ? MESSAGE_COLORS.memory :
|
||||
getAgentColors(message.agentId, message.agent);
|
||||
|
||||
const agentModelData = message.agentId && getAgentModelInfo ?
|
||||
getAgentModelInfo(message.agentId) :
|
||||
{ modelName: null, modelProvider: null };
|
||||
const modelInfo = getModelIcon(agentModelData.modelName, agentModelData.modelProvider);
|
||||
|
||||
let content = message.content || '';
|
||||
if (typeof content === 'object') {
|
||||
content = JSON.stringify(content, null, 2);
|
||||
} else {
|
||||
content = String(content);
|
||||
}
|
||||
|
||||
const needsTruncation = content.length > 200;
|
||||
const MAX_EXPANDED_LENGTH = 10000;
|
||||
|
||||
let displayContent = content;
|
||||
if (!expanded && needsTruncation) {
|
||||
displayContent = content.substring(0, 200) + '...';
|
||||
} else if (expanded && content.length > MAX_EXPANDED_LENGTH) {
|
||||
displayContent = content.substring(0, MAX_EXPANDED_LENGTH) + '...';
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="conf-message-item">
|
||||
<div className="conf-agent-name" style={{ color: agentColors.text, display: 'flex', alignItems: 'center', gap: '6px', fontSize: '12px' }}>
|
||||
{modelInfo.logoPath && (
|
||||
<img
|
||||
src={modelInfo.logoPath}
|
||||
alt={modelInfo.provider}
|
||||
style={{
|
||||
width: '20px',
|
||||
height: '20px',
|
||||
borderRadius: '50%',
|
||||
objectFit: 'contain'
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{message.agent}
|
||||
</div>
|
||||
<div className="conf-message-content-wrapper">
|
||||
<span className="conf-message-content">{stripMarkdown(displayContent)}</span>
|
||||
{needsTruncation && (
|
||||
<button
|
||||
className="conf-expand-btn"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
>
|
||||
{expanded ? '« 收起' : '更多 »'}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function MemoryItem({ memory, itemId, isHighlighted }) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [showTooltip, setShowTooltip] = useState(false);
|
||||
const colors = MESSAGE_COLORS.memory;
|
||||
|
||||
let content = memory.content || '';
|
||||
if (typeof content === 'object') {
|
||||
content = JSON.stringify(content, null, 2);
|
||||
} else {
|
||||
content = String(content);
|
||||
}
|
||||
|
||||
const needsTruncation = content.length > 200;
|
||||
const MAX_EXPANDED_LENGTH = 10000;
|
||||
|
||||
let displayContent = content;
|
||||
if (!expanded && needsTruncation) {
|
||||
displayContent = content.substring(0, 200) + '...';
|
||||
} else if (expanded && content.length > MAX_EXPANDED_LENGTH) {
|
||||
displayContent = content.substring(0, MAX_EXPANDED_LENGTH) + '...';
|
||||
}
|
||||
|
||||
const agentLabel = memory.agent && memory.agent !== 'Memory'
|
||||
? `记忆 · ${memory.agent}`
|
||||
: '记忆';
|
||||
|
||||
return (
|
||||
<div
|
||||
id={`feed-item-${itemId}`}
|
||||
className="feed-item"
|
||||
style={{
|
||||
background: 'linear-gradient(180deg, #F0F9FF 0%, #F6F4FF 100%)',
|
||||
border: '1px solid rgba(0, 194, 255, 0.15)',
|
||||
outline: isHighlighted ? '2px solid #615CED' : 'none',
|
||||
transition: 'outline 0.3s ease',
|
||||
position: 'relative'
|
||||
}}
|
||||
>
|
||||
<div className="feed-item-header">
|
||||
<span className="feed-item-title" style={{ color: colors.text, display: 'flex', alignItems: 'center', gap: '6px' }}>
|
||||
<div
|
||||
style={{ position: 'relative', display: 'inline-flex', alignItems: 'center' }}
|
||||
onMouseEnter={() => setShowTooltip(true)}
|
||||
onMouseLeave={() => setShowTooltip(false)}
|
||||
>
|
||||
<a
|
||||
href="https://github.com/agentscope-ai/ReMe"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
style={{ display: 'flex', alignItems: 'center', textDecoration: 'none' }}
|
||||
>
|
||||
<img
|
||||
src={ASSETS.remeLogo}
|
||||
alt="ReMe"
|
||||
style={{
|
||||
cursor: 'pointer',
|
||||
height: '12px',
|
||||
width: 'auto',
|
||||
objectFit: 'contain',
|
||||
userSelect: 'none',
|
||||
transition: 'all 0.2s ease',
|
||||
opacity: showTooltip ? 1 : 0.9,
|
||||
filter: showTooltip ? 'brightness(1.1)' : 'none'
|
||||
}}
|
||||
/>
|
||||
<span style={{
|
||||
fontSize: '11px',
|
||||
marginLeft: '4px',
|
||||
opacity: showTooltip ? 0.6 : 0,
|
||||
transform: showTooltip ? 'translate(0, 0)' : 'translate(-4px, 2px)',
|
||||
transition: 'all 0.2s cubic-bezier(0.4, 0, 0.2, 1)',
|
||||
color: colors.text,
|
||||
lineHeight: 1,
|
||||
pointerEvents: 'none'
|
||||
}}>
|
||||
↗
|
||||
</span>
|
||||
</a>
|
||||
</div>
|
||||
<span style={{
|
||||
background: 'linear-gradient(90deg, #00C2FF 0%, #5C4CE0 100%)',
|
||||
WebkitBackgroundClip: 'text',
|
||||
WebkitTextFillColor: 'transparent',
|
||||
backgroundClip: 'text',
|
||||
color: 'transparent',
|
||||
fontWeight: 700
|
||||
}}>
|
||||
{agentLabel}
|
||||
</span>
|
||||
</span>
|
||||
<span className="feed-item-time">{formatTime(memory.timestamp)}</span>
|
||||
</div>
|
||||
|
||||
<div style={{
|
||||
position: 'absolute',
|
||||
top: '34px',
|
||||
left: '12px',
|
||||
right: '12px',
|
||||
background: 'rgba(255, 255, 255, 0.9)',
|
||||
backdropFilter: 'blur(4px)',
|
||||
color: '#334155',
|
||||
padding: '10px 14px',
|
||||
borderRadius: '8px',
|
||||
fontSize: '12px',
|
||||
lineHeight: '1.5',
|
||||
zIndex: 100,
|
||||
boxShadow: '0 4px 12px rgba(0, 194, 255, 0.1)',
|
||||
opacity: showTooltip ? 1 : 0,
|
||||
visibility: showTooltip ? 'visible' : 'hidden',
|
||||
transition: 'all 0.2s ease',
|
||||
pointerEvents: 'none',
|
||||
border: '1px solid rgba(0, 194, 255, 0.15)'
|
||||
}}>
|
||||
<div style={{
|
||||
fontWeight: '700',
|
||||
marginBottom: '3px',
|
||||
background: 'linear-gradient(90deg, #00C2FF 0%, #5C4CE0 100%)',
|
||||
WebkitBackgroundClip: 'text',
|
||||
WebkitTextFillColor: 'transparent',
|
||||
backgroundClip: 'text',
|
||||
color: 'transparent',
|
||||
display: 'inline-block'
|
||||
}}>
|
||||
Memory powered by AgentScope-ReMe
|
||||
</div>
|
||||
<div style={{ color: '#475569', opacity: 0.9 }}>
|
||||
Not only retrieves historical memories but also generates suggestions and hints for the current task based on latest context.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="feed-item-content">{stripMarkdown(displayContent)}</div>
|
||||
|
||||
{needsTruncation && (
|
||||
<button
|
||||
className="feed-expand-btn"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
>
|
||||
{expanded ? '« 收起' : '更多 »'}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function MessageItem({ message, itemId, isHighlighted, getAgentModelInfo }) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const [showModal, setShowModal] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
|
||||
const colors = message.agent === 'Memory' ? MESSAGE_COLORS.memory :
|
||||
getAgentColors(message.agentId, message.agent);
|
||||
const title = message.agent === 'Memory' ? '记忆' : message.agent || 'AGENT';
|
||||
|
||||
const agentModelData = message.agentId && getAgentModelInfo ?
|
||||
getAgentModelInfo(message.agentId) :
|
||||
{ modelName: null, modelProvider: null };
|
||||
const modelInfo = getModelIcon(agentModelData.modelName, agentModelData.modelProvider);
|
||||
|
||||
const isAnalystAgent = isAnalyst(message.agentId, message.agent);
|
||||
const isManagerAgent = isManager(message.agentId, message.agent);
|
||||
const useModalView = isAnalystAgent || isManagerAgent;
|
||||
|
||||
let content = message.content || '';
|
||||
if (typeof content === 'object') {
|
||||
content = JSON.stringify(content, null, 2);
|
||||
} else {
|
||||
content = String(content);
|
||||
}
|
||||
|
||||
let displayContent = content;
|
||||
let showExpandButton = false;
|
||||
|
||||
if (useModalView) {
|
||||
displayContent = content.length > 150 ? content.substring(0, 150) + '...' : content;
|
||||
} else {
|
||||
const needsTruncation = content.length > 200;
|
||||
const MAX_EXPANDED_LENGTH = 8000;
|
||||
|
||||
if (!expanded && needsTruncation) {
|
||||
displayContent = content.substring(0, 200) + '...';
|
||||
showExpandButton = true;
|
||||
} else if (expanded && content.length > MAX_EXPANDED_LENGTH) {
|
||||
displayContent = content.substring(0, MAX_EXPANDED_LENGTH) + '...';
|
||||
showExpandButton = needsTruncation;
|
||||
} else {
|
||||
showExpandButton = needsTruncation;
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
id={`feed-item-${itemId}`}
|
||||
className="feed-item"
|
||||
style={{
|
||||
backgroundColor: colors.bg,
|
||||
outline: isHighlighted ? '2px solid #615CED' : 'none',
|
||||
transition: 'outline 0.3s ease'
|
||||
}}
|
||||
>
|
||||
<div className="feed-item-header">
|
||||
<span className="feed-item-title" style={{ color: colors.text, display: 'flex', alignItems: 'center', gap: '6px', fontSize: '12px' }}>
|
||||
{modelInfo.logoPath && message.agent !== 'Memory' && (
|
||||
<img
|
||||
src={modelInfo.logoPath}
|
||||
alt={modelInfo.provider}
|
||||
style={{
|
||||
width: '20px',
|
||||
height: '20px',
|
||||
borderRadius: '50%',
|
||||
objectFit: 'contain'
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{title}
|
||||
</span>
|
||||
<span className="feed-item-time">{formatTime(message.timestamp)}</span>
|
||||
</div>
|
||||
|
||||
<div className="feed-item-content">{stripMarkdown(displayContent)}</div>
|
||||
|
||||
{useModalView && (
|
||||
<button
|
||||
onClick={() => setShowModal(true)}
|
||||
onMouseEnter={() => setIsHovering(true)}
|
||||
onMouseLeave={() => setIsHovering(false)}
|
||||
style={{
|
||||
marginTop: '8px',
|
||||
fontSize: '12px',
|
||||
color: isHovering ? '#000' : '#666',
|
||||
fontWeight: '700',
|
||||
background: 'none',
|
||||
border: 'none',
|
||||
cursor: 'pointer',
|
||||
padding: '4px 0',
|
||||
textAlign: 'left',
|
||||
width: '100%',
|
||||
outline: 'none'
|
||||
}}
|
||||
>
|
||||
📄 {isManagerAgent ? '查看决策日志 »' : '查看完整报告 »'}
|
||||
</button>
|
||||
)}
|
||||
|
||||
{showExpandButton && (
|
||||
<button
|
||||
className="feed-expand-btn"
|
||||
onClick={() => setExpanded(!expanded)}
|
||||
>
|
||||
{expanded ? '« 收起' : '更多 »'}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{useModalView && (
|
||||
<MarkdownModal
|
||||
isOpen={showModal}
|
||||
onClose={() => setShowModal(false)}
|
||||
content={content}
|
||||
agentName={message.agent}
|
||||
reportType={isManagerAgent ? 'decision' : 'analysis'}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
253
frontend/src/components/Header.jsx
Normal file
253
frontend/src/components/Header.jsx
Normal file
@@ -0,0 +1,253 @@
|
||||
import React, { useState } from 'react';
|
||||
|
||||
/**
|
||||
* Header Component
|
||||
* Reusable header brand with EvoTraders logo, GitHub link, and Contact Us section
|
||||
*
|
||||
* @param {Function} onEvoTradersClick - Optional callback when EvoTraders is clicked
|
||||
* @param {string} evoTradersLinkStyle - Optional style variant: 'default' | 'close'
|
||||
*/
|
||||
export default function Header({
|
||||
onEvoTradersClick = null,
|
||||
evoTradersLinkStyle = 'default' // 'default' shows ↗, 'close' shows ↙
|
||||
}) {
|
||||
const [activeContactCard, setActiveContactCard] = useState({ yue: false, jiaji: false });
|
||||
const [clickedContactCard, setClickedContactCard] = useState(null);
|
||||
|
||||
const handleEvoTradersClick = () => {
|
||||
if (onEvoTradersClick) {
|
||||
onEvoTradersClick();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="header-title" style={{ flex: '0 1 auto', minWidth: 0 }}>
|
||||
<span
|
||||
className="header-link"
|
||||
onClick={handleEvoTradersClick}
|
||||
style={{ cursor: 'pointer', padding: '4px 8px', borderRadius: '3px', display: 'inline-flex', alignItems: 'center', gap: '8px' }}
|
||||
>
|
||||
<img
|
||||
src="/trading_logo.png"
|
||||
alt="EvoTraders Logo"
|
||||
style={{ height: '24px', width: 'auto' }}
|
||||
/>
|
||||
EvoTraders {evoTradersLinkStyle === 'close' ? (
|
||||
<span className="link-arrow">↙</span>
|
||||
) : (
|
||||
<span className="link-arrow">↗</span>
|
||||
)}
|
||||
</span>
|
||||
|
||||
<span style={{
|
||||
width: '2px',
|
||||
height: '16px',
|
||||
background: '#666',
|
||||
margin: '0 16px',
|
||||
display: 'inline-block',
|
||||
verticalAlign: 'middle'
|
||||
}} />
|
||||
|
||||
<span style={{
|
||||
padding: '1px 5px',
|
||||
fontSize: '9px',
|
||||
fontWeight: 700,
|
||||
color: '#00C853',
|
||||
background: 'rgba(0, 200, 83, 0.1)',
|
||||
border: '1px solid #00C853',
|
||||
borderRadius: '3px',
|
||||
letterSpacing: '0.5px',
|
||||
marginRight: '0px'
|
||||
}}>
|
||||
开源
|
||||
</span>
|
||||
|
||||
<a
|
||||
href="https://github.com/agentscope-ai/agentscope-samples"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="header-link"
|
||||
style={{ display: 'inline-flex', flexDirection: 'row', alignItems: 'center', gap: '6px' }}
|
||||
>
|
||||
<svg
|
||||
width="14"
|
||||
height="14"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
style={{ display: 'inline-block' }}
|
||||
>
|
||||
<path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z"/>
|
||||
</svg>
|
||||
<span>agentscope-samples</span>
|
||||
<span className="link-arrow">↗</span>
|
||||
</a>
|
||||
|
||||
<a
|
||||
href="https://github.com/agentscope-ai/ReMe"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="header-link"
|
||||
style={{ display: 'inline-flex', flexDirection: 'row', alignItems: 'center', gap: '6px', marginLeft: '0px' }}
|
||||
>
|
||||
<svg
|
||||
width="14"
|
||||
height="14"
|
||||
viewBox="0 0 24 24"
|
||||
fill="currentColor"
|
||||
style={{ display: 'inline-block' }}
|
||||
>
|
||||
<path d="M12 0c-6.626 0-12 5.373-12 12 0 5.302 3.438 9.8 8.207 11.387.599.111.793-.261.793-.577v-2.234c-3.338.726-4.033-1.416-4.033-1.416-.546-1.387-1.333-1.756-1.333-1.756-1.089-.745.083-.729.083-.729 1.205.084 1.839 1.237 1.839 1.237 1.07 1.834 2.807 1.304 3.492.997.107-.775.418-1.305.762-1.604-2.665-.305-5.467-1.334-5.467-5.931 0-1.311.469-2.381 1.236-3.221-.124-.303-.535-1.524.117-3.176 0 0 1.008-.322 3.301 1.23.957-.266 1.983-.399 3.003-.404 1.02.005 2.047.138 3.006.404 2.291-1.552 3.297-1.23 3.297-1.23.653 1.653.242 2.874.118 3.176.77.84 1.235 1.911 1.235 3.221 0 4.609-2.807 5.624-5.479 5.921.43.372.823 1.102.823 2.222v3.293c0 .319.192.694.801.576 4.765-1.589 8.199-6.086 8.199-11.386 0-6.627-5.373-12-12-12z"/>
|
||||
</svg>
|
||||
<span>agentscope-ReMe</span>
|
||||
<span className="link-arrow">↗</span>
|
||||
</a>
|
||||
|
||||
<span style={{
|
||||
width: '2px',
|
||||
height: '16px',
|
||||
background: '#666',
|
||||
margin: '0 16px',
|
||||
display: 'inline-block',
|
||||
verticalAlign: 'middle'
|
||||
}} />
|
||||
|
||||
<div
|
||||
style={{
|
||||
position: 'relative',
|
||||
display: 'inline-flex',
|
||||
alignItems: 'center',
|
||||
gap: '8px',
|
||||
cursor: 'pointer'
|
||||
}}
|
||||
onClick={() => {
|
||||
const bothActive = activeContactCard.yue && activeContactCard.jiaji;
|
||||
if (!bothActive) {
|
||||
setActiveContactCard({ yue: true, jiaji: true });
|
||||
setClickedContactCard('both');
|
||||
} else {
|
||||
setActiveContactCard({ yue: false, jiaji: false });
|
||||
setClickedContactCard(null);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<span className="header-link">
|
||||
联系我们
|
||||
</span>
|
||||
|
||||
{/* Two contact buttons */}
|
||||
<div style={{ display: 'flex', gap: '6px', alignItems: 'center' }}>
|
||||
<div
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
if (activeContactCard.yue) {
|
||||
setActiveContactCard(prev => ({ ...prev, yue: false }));
|
||||
if (clickedContactCard === 'yue' || clickedContactCard === 'both') {
|
||||
setClickedContactCard(null);
|
||||
}
|
||||
} else {
|
||||
setActiveContactCard(prev => ({ ...prev, yue: true }));
|
||||
setClickedContactCard('yue');
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => {
|
||||
if (!clickedContactCard || clickedContactCard === 'yue' || clickedContactCard === 'both') {
|
||||
setActiveContactCard(prev => ({ ...prev, yue: true }));
|
||||
}
|
||||
}}
|
||||
onMouseLeave={() => {
|
||||
if (clickedContactCard !== 'yue' && clickedContactCard !== 'both') {
|
||||
setActiveContactCard(prev => ({ ...prev, yue: false }));
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
padding: '4px 8px',
|
||||
background: activeContactCard.yue ? '#615CED' : '#f5f5f5',
|
||||
color: activeContactCard.yue ? '#fff' : '#333',
|
||||
border: '1px solid',
|
||||
borderColor: activeContactCard.yue ? '#615CED' : '#e0e0e0',
|
||||
borderRadius: '3px',
|
||||
fontSize: '10px',
|
||||
fontWeight: 700,
|
||||
fontFamily: "'IBM Plex Mono', monospace",
|
||||
cursor: 'pointer',
|
||||
transition: 'all 0.2s',
|
||||
letterSpacing: '0.5px',
|
||||
whiteSpace: 'nowrap',
|
||||
overflow: 'hidden',
|
||||
maxWidth: activeContactCard.yue ? '80px' : '32px',
|
||||
minWidth: activeContactCard.yue ? '80px' : '32px'
|
||||
}}
|
||||
>
|
||||
{activeContactCard.yue ? (
|
||||
<a
|
||||
href="https://1mycell.github.io/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
style={{ color: 'inherit', textDecoration: 'none' }}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
Yue Wu ↗
|
||||
</a>
|
||||
) : 'YW'}
|
||||
</div>
|
||||
|
||||
<div
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
if (activeContactCard.jiaji) {
|
||||
setActiveContactCard(prev => ({ ...prev, jiaji: false }));
|
||||
if (clickedContactCard === 'jiaji' || clickedContactCard === 'both') {
|
||||
setClickedContactCard(null);
|
||||
}
|
||||
} else {
|
||||
setActiveContactCard(prev => ({ ...prev, jiaji: true }));
|
||||
setClickedContactCard('jiaji');
|
||||
}
|
||||
}}
|
||||
onMouseEnter={() => {
|
||||
if (!clickedContactCard || clickedContactCard === 'jiaji' || clickedContactCard === 'both') {
|
||||
setActiveContactCard(prev => ({ ...prev, jiaji: true }));
|
||||
}
|
||||
}}
|
||||
onMouseLeave={() => {
|
||||
if (clickedContactCard !== 'jiaji' && clickedContactCard !== 'both') {
|
||||
setActiveContactCard(prev => ({ ...prev, jiaji: false }));
|
||||
}
|
||||
}}
|
||||
style={{
|
||||
padding: '4px 8px',
|
||||
background: activeContactCard.jiaji ? '#615CED' : '#f5f5f5',
|
||||
color: activeContactCard.jiaji ? '#fff' : '#333',
|
||||
border: '1px solid',
|
||||
borderColor: activeContactCard.jiaji ? '#615CED' : '#e0e0e0',
|
||||
borderRadius: '3px',
|
||||
fontSize: '10px',
|
||||
fontWeight: 700,
|
||||
fontFamily: "'IBM Plex Mono', monospace",
|
||||
cursor: 'pointer',
|
||||
transition: 'all 0.2s',
|
||||
letterSpacing: '0.5px',
|
||||
whiteSpace: 'nowrap',
|
||||
overflow: 'hidden',
|
||||
maxWidth: activeContactCard.jiaji ? '100px' : '32px',
|
||||
minWidth: activeContactCard.jiaji ? '100px' : '32px'
|
||||
}}
|
||||
>
|
||||
{activeContactCard.jiaji ? (
|
||||
<a
|
||||
href="https://dengjiaji.github.io/self/"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
style={{ color: 'inherit', textDecoration: 'none' }}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
Jiaji Deng ↗
|
||||
</a>
|
||||
) : 'JD'}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
276
frontend/src/components/MarkdownModal.jsx
Normal file
276
frontend/src/components/MarkdownModal.jsx
Normal file
@@ -0,0 +1,276 @@
|
||||
import React from 'react';
|
||||
import ReactMarkdown from 'react-markdown';
|
||||
import remarkGfm from 'remark-gfm';
|
||||
|
||||
function MarkdownModal({ isOpen, onClose, content, agentName, reportType = 'analysis' }) {
|
||||
if (!isOpen) return null;
|
||||
|
||||
const subtitle = reportType === 'decision' ? 'Decision Log' : 'Financial Analysis Report';
|
||||
|
||||
return (
|
||||
<div
|
||||
style={{
|
||||
position: 'fixed',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
backgroundColor: 'rgba(0, 0, 0, 0.75)',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
zIndex: 1000,
|
||||
backdropFilter: 'blur(4px)',
|
||||
}}
|
||||
onClick={onClose}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
backgroundColor: '#ffffff',
|
||||
borderRadius: '2px',
|
||||
padding: '0',
|
||||
maxWidth: '900px',
|
||||
maxHeight: '85vh',
|
||||
overflow: 'hidden',
|
||||
width: '90%',
|
||||
boxShadow: '0 20px 60px rgba(0, 0, 0, 0.3)',
|
||||
border: '1px solid #e0e0e0',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
}}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
>
|
||||
{/* Header */}
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center',
|
||||
padding: '24px 32px',
|
||||
borderBottom: '2px solid #000',
|
||||
backgroundColor: '#fafafa',
|
||||
}}>
|
||||
<div>
|
||||
<h2 style={{
|
||||
margin: 0,
|
||||
fontSize: '18px',
|
||||
fontWeight: 700,
|
||||
letterSpacing: '0.5px',
|
||||
textTransform: 'uppercase',
|
||||
color: '#000',
|
||||
}}>
|
||||
{agentName}
|
||||
</h2>
|
||||
<p style={{
|
||||
margin: '4px 0 0 0',
|
||||
fontSize: '12px',
|
||||
color: '#666',
|
||||
fontWeight: 500,
|
||||
letterSpacing: '0.3px',
|
||||
}}>
|
||||
{subtitle}
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={onClose}
|
||||
style={{
|
||||
background: '#000',
|
||||
border: 'none',
|
||||
fontSize: '20px',
|
||||
cursor: 'pointer',
|
||||
color: '#fff',
|
||||
width: '32px',
|
||||
height: '32px',
|
||||
borderRadius: '2px',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
transition: 'all 0.2s',
|
||||
outline: 'none',
|
||||
}}
|
||||
onMouseOver={(e) => e.currentTarget.style.backgroundColor = '#333'}
|
||||
onMouseOut={(e) => e.currentTarget.style.backgroundColor = '#000'}
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div style={{
|
||||
padding: '32px 32px 24px 32px',
|
||||
overflow: 'auto',
|
||||
backgroundColor: '#fff',
|
||||
flex: 1,
|
||||
}}>
|
||||
<style>{`
|
||||
.markdown-content {
|
||||
color: #1a1a1a;
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', sans-serif;
|
||||
}
|
||||
|
||||
.markdown-content h1 {
|
||||
font-size: 24px;
|
||||
font-weight: 700;
|
||||
margin: 32px 0 16px 0;
|
||||
padding-bottom: 12px;
|
||||
border-bottom: 2px solid #000;
|
||||
color: #000;
|
||||
letter-spacing: 0.3px;
|
||||
text-transform: uppercase;
|
||||
}
|
||||
|
||||
.markdown-content h1:first-child {
|
||||
margin-top: 0;
|
||||
}
|
||||
|
||||
.markdown-content h2 {
|
||||
font-size: 20px;
|
||||
font-weight: 700;
|
||||
margin: 28px 0 12px 0;
|
||||
color: #000;
|
||||
letter-spacing: 0.3px;
|
||||
text-transform: uppercase;
|
||||
padding-bottom: 8px;
|
||||
border-bottom: 1px solid #d0d0d0;
|
||||
}
|
||||
|
||||
.markdown-content h3 {
|
||||
font-size: 16px;
|
||||
font-weight: 700;
|
||||
margin: 24px 0 10px 0;
|
||||
color: #1a1a1a;
|
||||
letter-spacing: 0.2px;
|
||||
}
|
||||
|
||||
.markdown-content h4 {
|
||||
font-size: 14px;
|
||||
font-weight: 700;
|
||||
margin: 20px 0 8px 0;
|
||||
color: #2a2a2a;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.markdown-content p {
|
||||
margin: 12px 0;
|
||||
line-height: 1.8;
|
||||
font-size: 14px;
|
||||
color: #2a2a2a;
|
||||
}
|
||||
|
||||
.markdown-content table {
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
margin: 24px 0;
|
||||
font-size: 13px;
|
||||
border: 1px solid #000;
|
||||
background: #fff;
|
||||
}
|
||||
|
||||
.markdown-content th {
|
||||
background-color: #000;
|
||||
color: #fff;
|
||||
padding: 12px 16px;
|
||||
text-align: left;
|
||||
font-weight: 700;
|
||||
letter-spacing: 0.5px;
|
||||
text-transform: uppercase;
|
||||
font-size: 12px;
|
||||
border: 1px solid #000;
|
||||
}
|
||||
|
||||
.markdown-content td {
|
||||
border: 1px solid #d0d0d0;
|
||||
padding: 12px 16px;
|
||||
text-align: left;
|
||||
color: #1a1a1a;
|
||||
}
|
||||
|
||||
.markdown-content tr:nth-child(even) {
|
||||
background-color: #fafafa;
|
||||
}
|
||||
|
||||
.markdown-content tr:hover {
|
||||
background-color: #f0f0f0;
|
||||
}
|
||||
|
||||
.markdown-content ul,
|
||||
.markdown-content ol {
|
||||
margin: 16px 0;
|
||||
padding-left: 28px;
|
||||
line-height: 1.8;
|
||||
}
|
||||
|
||||
.markdown-content li {
|
||||
margin: 8px 0;
|
||||
color: #2a2a2a;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.markdown-content li::marker {
|
||||
color: #000;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.markdown-content strong {
|
||||
font-weight: 700;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.markdown-content em {
|
||||
font-style: italic;
|
||||
color: #3a3a3a;
|
||||
}
|
||||
|
||||
.markdown-content code {
|
||||
background-color: #f5f5f5;
|
||||
padding: 3px 8px;
|
||||
border-radius: 2px;
|
||||
font-family: 'SF Mono', 'Monaco', 'Consolas', monospace;
|
||||
font-size: 13px;
|
||||
color: #000;
|
||||
border: 1px solid #e0e0e0;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
background-color: #fafafa;
|
||||
padding: 16px;
|
||||
border-radius: 2px;
|
||||
overflow-x: auto;
|
||||
margin: 20px 0;
|
||||
border: 1px solid #d0d0d0;
|
||||
border-left: 3px solid #000;
|
||||
}
|
||||
|
||||
.markdown-content pre code {
|
||||
background: none;
|
||||
padding: 0;
|
||||
border: none;
|
||||
font-size: 13px;
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
border-left: 4px solid #000;
|
||||
margin: 20px 0;
|
||||
padding: 12px 20px;
|
||||
background-color: #fafafa;
|
||||
color: #2a2a2a;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.markdown-content hr {
|
||||
border: none;
|
||||
border-top: 1px solid #d0d0d0;
|
||||
margin: 32px 0;
|
||||
}
|
||||
`}</style>
|
||||
<div className="markdown-content">
|
||||
<ReactMarkdown remarkPlugins={[remarkGfm]}>{content}</ReactMarkdown>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default MarkdownModal;
|
||||
|
||||
831
frontend/src/components/NetValueChart.jsx
Normal file
831
frontend/src/components/NetValueChart.jsx
Normal file
@@ -0,0 +1,831 @@
|
||||
import React, { useMemo, useState, useEffect } from 'react';
|
||||
import { LineChart, Line, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer, Legend } from 'recharts';
|
||||
import { formatNumber, formatFullNumber } from '../utils/formatters';
|
||||
|
||||
/**
|
||||
* Helper function to get the start time of the most recent trading session
|
||||
* Trading session: 22:30 - next day 05:00
|
||||
* @param {Date|null} virtualTime - Virtual time from server (for mock mode), or null to use real time
|
||||
*/
|
||||
function getRecentTradingSessionStart(virtualTime = null) {
|
||||
// Use virtual time if provided (for mock mode), otherwise use real time
|
||||
let now;
|
||||
if (virtualTime) {
|
||||
// Ensure virtualTime is a valid Date object
|
||||
if (virtualTime instanceof Date && !isNaN(virtualTime.getTime())) {
|
||||
now = virtualTime;
|
||||
} else if (typeof virtualTime === 'string') {
|
||||
now = new Date(virtualTime);
|
||||
if (isNaN(now.getTime())) {
|
||||
console.warn('Invalid virtualTime string, using current time:', virtualTime);
|
||||
now = new Date();
|
||||
}
|
||||
} else {
|
||||
console.warn('Invalid virtualTime type, using current time:', typeof virtualTime);
|
||||
now = new Date();
|
||||
}
|
||||
} else {
|
||||
now = new Date();
|
||||
}
|
||||
|
||||
const currentHour = now.getHours();
|
||||
const currentMinute = now.getMinutes();
|
||||
|
||||
// Check if currently in trading session
|
||||
const isInTradingSession = (currentHour === 22 && currentMinute >= 30) ||
|
||||
currentHour >= 23 ||
|
||||
(currentHour >= 0 && currentHour < 5) ||
|
||||
(currentHour === 5 && currentMinute === 0);
|
||||
|
||||
let sessionStartTime;
|
||||
if (isInTradingSession) {
|
||||
// Currently in trading session, find today's 22:30
|
||||
sessionStartTime = new Date(now);
|
||||
sessionStartTime.setHours(22, 30, 0, 0);
|
||||
// If current time is before 22:30, it means yesterday's 22:30
|
||||
if (now < sessionStartTime) {
|
||||
sessionStartTime.setDate(sessionStartTime.getDate() - 1);
|
||||
}
|
||||
} else {
|
||||
// Not in trading session, find previous session start (yesterday 22:30)
|
||||
sessionStartTime = new Date(now);
|
||||
sessionStartTime.setDate(sessionStartTime.getDate() - 1);
|
||||
sessionStartTime.setHours(22, 30, 0, 0);
|
||||
}
|
||||
|
||||
return sessionStartTime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper function to filter strategy data for live view
|
||||
* NOTE: Live mode returns are now pre-processed by the backend, restricted to the
|
||||
* latest trading session and already starting at 0% at session start. This helper
|
||||
* is kept for potential future use but is no longer used in live mode.
|
||||
*/
|
||||
function filterStrategyDataForLive(strategyData, equity, sessionStartTime) {
|
||||
if (!strategyData || strategyData.length === 0 || !equity || equity.length === 0) return [];
|
||||
|
||||
try {
|
||||
if (!sessionStartTime || isNaN(sessionStartTime.getTime())) {
|
||||
console.warn('Invalid sessionStartTime in filterStrategyDataForLive');
|
||||
return [];
|
||||
}
|
||||
|
||||
const sessionStartTimestamp = sessionStartTime.getTime();
|
||||
|
||||
// Find the last index before session
|
||||
let lastDataBeforeSession = null;
|
||||
for (let i = equity.length - 1; i >= 0; i--) {
|
||||
if (equity[i] && typeof equity[i].t === 'number' && equity[i].t < sessionStartTimestamp) {
|
||||
if (strategyData[i] && strategyData[i].v !== undefined && strategyData[i].v !== null) {
|
||||
lastDataBeforeSession = strategyData[i];
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Find data points in the session
|
||||
const sessionData = [];
|
||||
for (let i = 0; i < equity.length; i++) {
|
||||
if (equity[i] && typeof equity[i].t === 'number' &&
|
||||
equity[i].t >= sessionStartTimestamp &&
|
||||
strategyData[i] &&
|
||||
strategyData[i].v !== undefined && strategyData[i].v !== null) {
|
||||
sessionData.push(strategyData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// If we have a value before session and session data, add the start point
|
||||
// Create a start point with timestamp just before session start
|
||||
if (lastDataBeforeSession && sessionData.length > 0) {
|
||||
const startPoint = {
|
||||
t: sessionStartTimestamp - 1,
|
||||
v: lastDataBeforeSession.v
|
||||
};
|
||||
return [startPoint, ...sessionData];
|
||||
}
|
||||
|
||||
return sessionData;
|
||||
} catch (error) {
|
||||
console.error('Error in filterStrategyDataForLive:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Net Value Chart Component
|
||||
* Displays portfolio value over time with multiple strategy comparisons
|
||||
*/
|
||||
export default function NetValueChart({ equity, baseline, baseline_vw, momentum, strategies, equity_return, baseline_return, baseline_vw_return, momentum_return, chartTab = 'all', virtualTime = null }) {
|
||||
const [activePoint, setActivePoint] = useState(null);
|
||||
const [stableYRange, setStableYRange] = useState(null);
|
||||
const [legendTooltip, setLegendTooltip] = useState(null);
|
||||
|
||||
// Legend descriptions
|
||||
const legendDescriptions = {
|
||||
'EvoTraders': 'EvoTraders is our agents investment strategy',
|
||||
'Buy & Hold (EW)': 'Equal Weight: Can be viewed as an equal-weighted index of all invested stocks',
|
||||
'Buy & Hold (VW)': 'Value Weighted: Can be viewed as a market-cap weighted index of all invested stocks',
|
||||
'Momentum': 'Momentum Strategy: Buy stocks that have performed well in the past',
|
||||
};
|
||||
|
||||
|
||||
// For live mode, use cumulative returns calculated by backend
|
||||
// For all mode, use portfolio values directly
|
||||
const dataSource = useMemo(() => {
|
||||
if (chartTab === 'live') {
|
||||
return {
|
||||
equity: equity_return || equity,
|
||||
baseline: baseline_return || baseline,
|
||||
baseline_vw: baseline_vw_return || baseline_vw,
|
||||
momentum: momentum_return || momentum
|
||||
};
|
||||
}
|
||||
return {
|
||||
equity: equity,
|
||||
baseline: baseline,
|
||||
baseline_vw: baseline_vw,
|
||||
momentum: momentum
|
||||
};
|
||||
}, [chartTab, equity, baseline, baseline_vw, momentum, equity_return, baseline_return, baseline_vw_return, momentum_return]);
|
||||
// Filter equity data based on chartTab
|
||||
const filteredEquity = useMemo(() => {
|
||||
if (chartTab === 'all') {
|
||||
const sourceEquity = dataSource.equity;
|
||||
if (!sourceEquity || sourceEquity.length === 0) return [];
|
||||
|
||||
// ALL chart: Show only the last point per day
|
||||
// Logic: Keep the last equity value before 22:30 each day (the last equity value before US next trading day opens)
|
||||
// Data after 22:30 belongs to the next trading day's session and is not shown in this chart
|
||||
// Time handling: timestamp(ms) -> UTC -> Asia/Shanghai timezone, then group and filter based on Asia/Shanghai time
|
||||
const dailyData = {};
|
||||
|
||||
sourceEquity.forEach((d) => {
|
||||
// Timestamp is in milliseconds, first create UTC time, then convert to Asia/Shanghai timezone
|
||||
// Equivalent to: pd.to_datetime(timestamp, unit='ms', utc=True).dt.tz_convert('Asia/Shanghai')
|
||||
const utcDate = new Date(d.t); // timestamp(ms) -> UTC time
|
||||
|
||||
// Use Intl API to get date/time components in Asia/Shanghai timezone
|
||||
const formatter = new Intl.DateTimeFormat('en-US', {
|
||||
timeZone: 'Asia/Shanghai',
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
hour12: false
|
||||
});
|
||||
|
||||
const parts = formatter.formatToParts(utcDate);
|
||||
const year = parts.find(p => p.type === 'year').value;
|
||||
const month = parts.find(p => p.type === 'month').value;
|
||||
const day = parts.find(p => p.type === 'day').value;
|
||||
const hour = parseInt(parts.find(p => p.type === 'hour').value);
|
||||
const minute = parseInt(parts.find(p => p.type === 'minute').value);
|
||||
|
||||
// Check if before 22:30 (Asia/Shanghai timezone)
|
||||
const isBefore2230 = hour < 22 || (hour === 22 && minute < 30);
|
||||
|
||||
// Only process data before 22:30
|
||||
if (isBefore2230) {
|
||||
// Use Asia/Shanghai timezone date as key
|
||||
const dateKey = `${year}-${month}-${day}`;
|
||||
|
||||
// Update if this day has no data yet, or if current data is later in time
|
||||
if (!dailyData[dateKey] || new Date(d.t) > new Date(dailyData[dateKey].t)) {
|
||||
dailyData[dateKey] = d;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Convert to array and sort by time
|
||||
return Object.values(dailyData).sort((a, b) => a.t - b.t);
|
||||
} else if (chartTab === 'live') {
|
||||
// LIVE chart: Show all updates from the most recent trading session (22:30-05:00)
|
||||
// Live mode: Backend has already returned return curves for "current trading session + 0% starting point", frontend can use directly
|
||||
const sourceEquity = dataSource.equity;
|
||||
if (!sourceEquity || sourceEquity.length === 0) return [];
|
||||
return sourceEquity;
|
||||
}
|
||||
return dataSource.equity || [];
|
||||
}, [dataSource.equity, chartTab, virtualTime]);
|
||||
// Helper function to get daily indices for 'all' view
|
||||
const getDailyIndices = useMemo(() => {
|
||||
if (!equity || equity.length === 0) return new Set();
|
||||
const dailyIndices = new Set();
|
||||
const dailyData = {};
|
||||
|
||||
const formatter = new Intl.DateTimeFormat('en-US', {
|
||||
timeZone: 'Asia/Shanghai',
|
||||
year: 'numeric',
|
||||
month: '2-digit',
|
||||
day: '2-digit',
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
hour12: false
|
||||
});
|
||||
|
||||
equity.forEach((d, idx) => {
|
||||
const utcDate = new Date(d.t);
|
||||
const parts = formatter.formatToParts(utcDate);
|
||||
const hour = parseInt(parts.find(p => p.type === 'hour').value);
|
||||
const minute = parseInt(parts.find(p => p.type === 'minute').value);
|
||||
|
||||
// Check if before 22:30 (Asia/Shanghai timezone)
|
||||
const isBefore2230 = hour < 22 || (hour === 22 && minute < 30);
|
||||
|
||||
// Only process data before 22:30
|
||||
if (isBefore2230) {
|
||||
const year = parts.find(p => p.type === 'year').value;
|
||||
const month = parts.find(p => p.type === 'month').value;
|
||||
const day = parts.find(p => p.type === 'day').value;
|
||||
const dateKey = `${year}-${month}-${day}`;
|
||||
|
||||
if (!dailyData[dateKey] || new Date(d.t) > new Date(dailyData[dateKey].t)) {
|
||||
dailyData[dateKey] = { data: d, index: idx };
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Object.values(dailyData).forEach(({ index }) => dailyIndices.add(index));
|
||||
return dailyIndices;
|
||||
}, [equity]);
|
||||
|
||||
// Filter baseline, baseline_vw, momentum, strategies to match filteredEquity indices
|
||||
const filteredBaseline = useMemo(() => {
|
||||
const sourceBaseline = dataSource.baseline;
|
||||
if (!sourceBaseline || sourceBaseline.length === 0 || !equity || equity.length === 0) return [];
|
||||
if (chartTab === 'all') {
|
||||
return sourceBaseline.filter((_, idx) => getDailyIndices.has(idx));
|
||||
} else if (chartTab === 'live') {
|
||||
// Live mode: Use backend pre-processed baseline return curves directly
|
||||
return sourceBaseline;
|
||||
}
|
||||
return sourceBaseline;
|
||||
}, [dataSource.baseline, equity, chartTab, getDailyIndices, virtualTime]);
|
||||
const filteredBaselineVw = useMemo(() => {
|
||||
const sourceBaselineVw = dataSource.baseline_vw;
|
||||
if (!sourceBaselineVw || sourceBaselineVw.length === 0 || !equity || equity.length === 0) return [];
|
||||
if (chartTab === 'all') {
|
||||
return sourceBaselineVw.filter((_, idx) => getDailyIndices.has(idx));
|
||||
} else if (chartTab === 'live') {
|
||||
// Live mode: Use backend pre-processed baseline return curves directly
|
||||
return sourceBaselineVw;
|
||||
}
|
||||
return sourceBaselineVw;
|
||||
}, [dataSource.baseline_vw, equity, chartTab, getDailyIndices, virtualTime]);
|
||||
const filteredMomentum = useMemo(() => {
|
||||
const sourceMomentum = dataSource.momentum;
|
||||
if (!sourceMomentum || sourceMomentum.length === 0 || !equity || equity.length === 0) return [];
|
||||
if (chartTab === 'all') {
|
||||
return sourceMomentum.filter((_, idx) => getDailyIndices.has(idx));
|
||||
} else if (chartTab === 'live') {
|
||||
// Live mode: Use backend pre-processed momentum return curves directly
|
||||
return sourceMomentum;
|
||||
}
|
||||
return sourceMomentum;
|
||||
}, [dataSource.momentum, equity, chartTab, getDailyIndices, virtualTime]);
|
||||
const filteredStrategies = useMemo(() => {
|
||||
if (!strategies || strategies.length === 0 || !equity || equity.length === 0) return [];
|
||||
if (chartTab === 'all') {
|
||||
return strategies.filter((_, idx) => getDailyIndices.has(idx));
|
||||
} else if (chartTab === 'live') {
|
||||
const sessionStartTime = getRecentTradingSessionStart(virtualTime);
|
||||
return filterStrategyDataForLive(strategies, equity, sessionStartTime);
|
||||
}
|
||||
return strategies;
|
||||
}, [strategies, equity, chartTab, getDailyIndices, virtualTime]);
|
||||
|
||||
const chartData = useMemo(() => {
|
||||
if (!filteredEquity || filteredEquity.length === 0) return [];
|
||||
|
||||
try {
|
||||
// LIVE mode: Align all curves by timestamp with forward filling to ensure consistent point counts and aligned starting points
|
||||
if (chartTab === 'live') {
|
||||
// Build timestamp -> value mapping
|
||||
const toMap = (arr) => {
|
||||
const m = new Map();
|
||||
if (Array.isArray(arr)) {
|
||||
arr.forEach((p) => {
|
||||
if (p && typeof p.t === 'number' && typeof p.v === 'number') {
|
||||
m.set(p.t, p.v);
|
||||
}
|
||||
});
|
||||
}
|
||||
return m;
|
||||
};
|
||||
|
||||
const portfolioMap = toMap(filteredEquity);
|
||||
const baselineMap = toMap(filteredBaseline);
|
||||
const baselineVwMap = toMap(filteredBaselineVw);
|
||||
const momentumMap = toMap(filteredMomentum);
|
||||
const strategyMap = toMap(filteredStrategies);
|
||||
|
||||
// Collect all timestamps, sort by time
|
||||
const timestampSet = new Set();
|
||||
[filteredEquity, filteredBaseline, filteredBaselineVw, filteredMomentum, filteredStrategies].forEach(arr => {
|
||||
if (Array.isArray(arr)) {
|
||||
arr.forEach(p => {
|
||||
if (p && typeof p.t === 'number') timestampSet.add(p.t);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
const timestamps = Array.from(timestampSet).sort((a, b) => a - b);
|
||||
if (timestamps.length === 0) return [];
|
||||
|
||||
// Current values for forward filling, initialized to 0% to ensure starting point alignment
|
||||
let currentPortfolio = 0;
|
||||
let currentBaseline = 0;
|
||||
let currentBaselineVw = 0;
|
||||
let currentMomentum = 0;
|
||||
let currentStrategy = 0;
|
||||
|
||||
return timestamps.map((t, idx) => {
|
||||
if (portfolioMap.has(t)) currentPortfolio = portfolioMap.get(t);
|
||||
if (baselineMap.has(t)) currentBaseline = baselineMap.get(t);
|
||||
if (baselineVwMap.has(t)) currentBaselineVw = baselineVwMap.get(t);
|
||||
if (momentumMap.has(t)) currentMomentum = momentumMap.get(t);
|
||||
if (strategyMap.has(t)) currentStrategy = strategyMap.get(t);
|
||||
|
||||
const date = new Date(t);
|
||||
if (isNaN(date.getTime())) {
|
||||
console.warn('Invalid timestamp in live chart data:', t);
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
index: idx,
|
||||
time:
|
||||
date.toLocaleDateString('en-US', {
|
||||
month: 'short',
|
||||
day: 'numeric',
|
||||
}) +
|
||||
' ' +
|
||||
date.toLocaleTimeString('en-US', {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
hour12: false,
|
||||
}),
|
||||
timestamp: t,
|
||||
portfolio: currentPortfolio,
|
||||
baseline: currentBaseline,
|
||||
baseline_vw: currentBaselineVw,
|
||||
momentum: currentMomentum,
|
||||
strategy: currentStrategy,
|
||||
};
|
||||
}).filter(item => item !== null);
|
||||
}
|
||||
|
||||
// ALL mode: Keep the original index-based alignment logic
|
||||
return filteredEquity.map((d, idx) => {
|
||||
if (!d || typeof d.t !== 'number' || typeof d.v !== 'number') {
|
||||
console.warn('Invalid equity data point:', d);
|
||||
return null;
|
||||
}
|
||||
|
||||
const date = new Date(d.t);
|
||||
if (isNaN(date.getTime())) {
|
||||
console.warn('Invalid timestamp:', d.t);
|
||||
return null;
|
||||
}
|
||||
|
||||
const baselineVal = filteredBaseline?.[idx]
|
||||
? (typeof filteredBaseline[idx] === 'object' ? filteredBaseline[idx].v : filteredBaseline[idx])
|
||||
: null;
|
||||
const baselineVwVal = filteredBaselineVw?.[idx]
|
||||
? (typeof filteredBaselineVw[idx] === 'object' ? filteredBaselineVw[idx].v : filteredBaselineVw[idx])
|
||||
: null;
|
||||
const momentumVal = filteredMomentum?.[idx]
|
||||
? (typeof filteredMomentum[idx] === 'object' ? filteredMomentum[idx].v : filteredMomentum[idx])
|
||||
: null;
|
||||
const strategyVal = filteredStrategies?.[idx]
|
||||
? (typeof filteredStrategies[idx] === 'object' ? filteredStrategies[idx].v : filteredStrategies[idx])
|
||||
: null;
|
||||
|
||||
return {
|
||||
index: idx,
|
||||
time:
|
||||
date.toLocaleDateString('en-US', { month: 'short', day: 'numeric' }) +
|
||||
' ' +
|
||||
date.toLocaleTimeString('en-US', {
|
||||
hour: '2-digit',
|
||||
minute: '2-digit',
|
||||
hour12: false,
|
||||
}),
|
||||
timestamp: d.t,
|
||||
portfolio: d.v,
|
||||
baseline: baselineVal || null,
|
||||
baseline_vw: baselineVwVal || null,
|
||||
momentum: momentumVal || null,
|
||||
strategy: strategyVal || null,
|
||||
};
|
||||
}).filter(item => item !== null); // Remove null entries
|
||||
} catch (error) {
|
||||
console.error('Error processing chart data:', error);
|
||||
return [];
|
||||
}
|
||||
}, [filteredEquity, filteredBaseline, filteredBaselineVw, filteredMomentum, filteredStrategies, chartTab]);
|
||||
|
||||
const { yMin, yMax, xTickIndices } = useMemo(() => {
|
||||
if (chartData.length === 0) return { yMin: 0, yMax: 1, xTickIndices: [] };
|
||||
|
||||
// Calculate min and max from all series
|
||||
const allValues = chartData.flatMap(d =>
|
||||
[d.portfolio, d.baseline, d.baseline_vw, d.momentum, d.strategy].filter(v => v !== null && isFinite(v))
|
||||
);
|
||||
|
||||
if (allValues.length === 0) {
|
||||
return { yMin: 0, yMax: 1000000, xTickIndices: [] };
|
||||
}
|
||||
|
||||
const dataMin = Math.min(...allValues);
|
||||
const dataMax = Math.max(...allValues);
|
||||
const range = dataMax - dataMin || 1;
|
||||
|
||||
// For live mode (percentage data), use smaller padding and finer rounding
|
||||
// For all mode (dollar amounts), use larger padding and coarser rounding
|
||||
const isLiveMode = chartTab === 'live';
|
||||
|
||||
const paddingFactor = isLiveMode ? range * 0.15 : range * 0.03;
|
||||
|
||||
let yMinCalc = dataMin - paddingFactor;
|
||||
let yMaxCalc = dataMax + paddingFactor;
|
||||
|
||||
// Smart rounding based on magnitude and mode
|
||||
const magnitude = Math.max(Math.abs(yMinCalc), Math.abs(yMaxCalc));
|
||||
let roundTo;
|
||||
|
||||
if (isLiveMode) {
|
||||
// For percentage data, use much finer rounding
|
||||
if (magnitude >= 100) {
|
||||
roundTo = 10;
|
||||
} else if (magnitude >= 10) {
|
||||
roundTo = 1;
|
||||
} else if (magnitude >= 1) {
|
||||
roundTo = 0.1;
|
||||
} else {
|
||||
roundTo = 0.01;
|
||||
}
|
||||
} else {
|
||||
// For dollar amounts, use coarser rounding
|
||||
if (magnitude >= 1e6) {
|
||||
roundTo = 10000;
|
||||
} else if (magnitude >= 1e5) {
|
||||
roundTo = 5000;
|
||||
} else if (magnitude >= 1e4) {
|
||||
roundTo = 1000;
|
||||
} else {
|
||||
roundTo = 100;
|
||||
}
|
||||
}
|
||||
|
||||
yMinCalc = Math.floor(yMinCalc / roundTo) * roundTo;
|
||||
yMaxCalc = Math.ceil(yMaxCalc / roundTo) * roundTo;
|
||||
|
||||
// Stable range to prevent frequent updates
|
||||
if (stableYRange) {
|
||||
const { min: stableMin, max: stableMax } = stableYRange;
|
||||
const stableRange = stableMax - stableMin;
|
||||
const threshold = stableRange * 0.05;
|
||||
|
||||
const needsUpdate =
|
||||
dataMin < (stableMin + threshold) ||
|
||||
dataMax > (stableMax - threshold);
|
||||
|
||||
if (!needsUpdate) {
|
||||
yMinCalc = stableMin;
|
||||
yMaxCalc = stableMax;
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate x-axis tick indices
|
||||
const safeLength = Math.min(chartData.length, 10000);
|
||||
const targetTicks = Math.min(8, Math.max(5, Math.floor(safeLength / 10)));
|
||||
const step = Math.max(1, Math.floor(safeLength / (targetTicks - 1)));
|
||||
|
||||
const indices = [];
|
||||
for (let i = 0; i < safeLength && indices.length < 100; i += step) {
|
||||
indices.push(i);
|
||||
}
|
||||
|
||||
if (safeLength > 0 && indices[indices.length - 1] !== safeLength - 1) {
|
||||
indices.push(safeLength - 1);
|
||||
}
|
||||
|
||||
return { yMin: yMinCalc, yMax: yMaxCalc, xTickIndices: indices };
|
||||
}, [chartData, stableYRange]);
|
||||
|
||||
// Update stableYRange in useEffect to avoid infinite re-renders
|
||||
// Use functional update to avoid dependency on stableYRange
|
||||
useEffect(() => {
|
||||
if (yMin !== undefined && yMax !== undefined && yMin !== null && yMax !== null && isFinite(yMin) && isFinite(yMax)) {
|
||||
setStableYRange(prevRange => {
|
||||
if (!prevRange) {
|
||||
// Initialize stable range
|
||||
return { min: yMin, max: yMax };
|
||||
} else {
|
||||
// Check if update is needed (5% threshold)
|
||||
const stableRange = prevRange.max - prevRange.min;
|
||||
const threshold = stableRange * 0.05;
|
||||
const needsUpdate =
|
||||
yMin < (prevRange.min + threshold) ||
|
||||
yMax > (prevRange.max - threshold);
|
||||
|
||||
if (needsUpdate) {
|
||||
return { min: yMin, max: yMax };
|
||||
}
|
||||
// No update needed, return previous range
|
||||
return prevRange;
|
||||
}
|
||||
});
|
||||
}
|
||||
}, [yMin, yMax]);
|
||||
|
||||
if (!equity || equity.length === 0) {
|
||||
return (
|
||||
<div style={{
|
||||
width: '100%',
|
||||
height: '100%',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
color: '#cccccc',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
fontSize: '12px'
|
||||
}}>
|
||||
NO DATA AVAILABLE
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const CustomTooltip = ({ active, payload }) => {
|
||||
if (active && payload && payload.length) {
|
||||
const isLiveMode = chartTab === 'live';
|
||||
return (
|
||||
<div style={{
|
||||
background: '#000000',
|
||||
border: '1px solid #333333',
|
||||
padding: '10px 14px',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
fontSize: '10px',
|
||||
color: '#ffffff'
|
||||
}}>
|
||||
<div style={{ fontWeight: 700, marginBottom: '6px', fontSize: '11px' }}>
|
||||
{payload[0].payload.time}
|
||||
</div>
|
||||
{payload.map((entry, index) => (
|
||||
<div key={index} style={{ color: entry.color, marginTop: '2px' }}>
|
||||
<span style={{ fontWeight: 700 }}>{entry.name}:</span> {isLiveMode ? `${entry.value.toFixed(2)}%` : `$${formatNumber(entry.value)}`} </div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
const CustomDot = ({ dataKey, ...props }) => {
|
||||
const { cx, cy, payload, index } = props;
|
||||
const isActive = activePoint === index;
|
||||
const isLastPoint = index === chartData.length - 1;
|
||||
|
||||
// Only show dot for the last point
|
||||
if (!isLastPoint) {
|
||||
return null;
|
||||
}
|
||||
const colors = {
|
||||
portfolio: '#00C853',
|
||||
baseline: '#FF6B00',
|
||||
baseline_vw: '#9C27B0',
|
||||
momentum: '#2196F3',
|
||||
strategy: '#795548'
|
||||
};
|
||||
|
||||
return (
|
||||
<circle
|
||||
cx={cx}
|
||||
cy={cy}
|
||||
r={isActive ? 6 : 8}
|
||||
fill={colors[dataKey]}
|
||||
stroke="#ffffff"
|
||||
strokeWidth={2}
|
||||
style={{ cursor: 'pointer' }}
|
||||
onMouseEnter={() => setActivePoint(index)}
|
||||
onMouseLeave={() => setActivePoint(null)}
|
||||
onClick={() => console.log('Clicked point:', { dataKey, ...payload })}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const CustomXAxisTick = ({ x, y, payload }) => {
|
||||
const shouldShow = xTickIndices.includes(payload.index);
|
||||
if (!shouldShow) return null;
|
||||
|
||||
return (
|
||||
<g transform={`translate(${x},${y})`}>
|
||||
<text
|
||||
x={0}
|
||||
y={0}
|
||||
dy={16}
|
||||
textAnchor="middle"
|
||||
fill="#666666"
|
||||
fontSize="10px"
|
||||
fontFamily='"Courier New", monospace'
|
||||
fontWeight="700"
|
||||
>
|
||||
{payload.value}
|
||||
</text>
|
||||
</g>
|
||||
);
|
||||
};
|
||||
|
||||
const CustomLegend = ({ payload }) => {
|
||||
if (!payload || payload.length === 0) return null;
|
||||
|
||||
return (
|
||||
<div style={{
|
||||
display: 'flex',
|
||||
flexWrap: 'wrap',
|
||||
gap: '16px',
|
||||
padding: '10px 0',
|
||||
position: 'relative',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
fontSize: '11px',
|
||||
fontWeight: 700,
|
||||
justifyContent: 'center'
|
||||
}}>
|
||||
{payload.map((entry, index) => {
|
||||
const description = legendDescriptions[entry.value] || '';
|
||||
const isActive = legendTooltip === entry.value;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={index}
|
||||
style={{
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
gap: '8px',
|
||||
cursor: 'pointer',
|
||||
position: 'relative',
|
||||
padding: '4px 8px',
|
||||
borderRadius: '4px',
|
||||
backgroundColor: isActive ? '#f0f0f0' : 'transparent',
|
||||
transition: 'background-color 0.2s',
|
||||
userSelect: 'none'
|
||||
}}
|
||||
onMouseEnter={() => setLegendTooltip(entry.value)}
|
||||
onMouseLeave={() => setLegendTooltip(null)}
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
setLegendTooltip(isActive ? null : entry.value);
|
||||
}}
|
||||
>
|
||||
<div
|
||||
style={{
|
||||
width: '14px',
|
||||
height: '3px',
|
||||
backgroundColor: entry.color,
|
||||
border: 'none'
|
||||
}}
|
||||
/>
|
||||
<span
|
||||
style={{
|
||||
fontFamily: '"Courier New", monospace',
|
||||
fontSize: '11px',
|
||||
fontWeight: 700,
|
||||
color: '#000000'
|
||||
}}
|
||||
>
|
||||
{entry.value}
|
||||
</span>
|
||||
{isActive && description && (
|
||||
<div
|
||||
style={{
|
||||
position: 'absolute',
|
||||
bottom: '100%',
|
||||
left: 0,
|
||||
marginBottom: '8px',
|
||||
padding: '8px 12px',
|
||||
background: '#000000',
|
||||
color: '#ffffff',
|
||||
fontSize: '10px',
|
||||
fontFamily: '"Courier New", monospace',
|
||||
whiteSpace: 'normal',
|
||||
maxWidth: '300px',
|
||||
zIndex: 1000,
|
||||
borderRadius: '4px',
|
||||
boxShadow: '0 2px 8px rgba(0,0,0,0.3)',
|
||||
pointerEvents: 'none',
|
||||
lineHeight: 1.4
|
||||
}}
|
||||
>
|
||||
{description}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<ResponsiveContainer width="100%" height="100%">
|
||||
<LineChart
|
||||
data={chartData}
|
||||
margin={{ top: 20, right: 30, bottom: 50, left: 60 }}
|
||||
>
|
||||
<XAxis
|
||||
dataKey="time"
|
||||
stroke="#666666"
|
||||
tick={<CustomXAxisTick />}
|
||||
interval={0}
|
||||
/>
|
||||
<YAxis
|
||||
domain={[yMin, yMax]}
|
||||
stroke="#000000"
|
||||
style={{ fontFamily: '"Courier New", monospace', fontSize: '11px', fontWeight: 700 }}
|
||||
tick={{ fill: '#000000' }}
|
||||
tickFormatter={(value) => chartTab === 'live' ? `${value.toFixed(2)}%` : formatFullNumber(value)}
|
||||
width={75}
|
||||
/>
|
||||
<Tooltip content={<CustomTooltip />} />
|
||||
<Legend
|
||||
content={<CustomLegend />}
|
||||
/>
|
||||
|
||||
{/* Portfolio line */}
|
||||
<Line
|
||||
type="linear"
|
||||
dataKey="portfolio"
|
||||
name="EvoTraders"
|
||||
stroke="#00C853"
|
||||
strokeWidth={2.5}
|
||||
dot={(props) => <CustomDot {...props} dataKey="portfolio" />}
|
||||
activeDot={{ r: 6, stroke: '#ffffff', strokeWidth: 2 }}
|
||||
isAnimationActive={false}
|
||||
/>
|
||||
|
||||
{/* Baseline Equal Weight */}
|
||||
{baseline && baseline.length > 0 && (
|
||||
<Line
|
||||
type="linear"
|
||||
dataKey="baseline"
|
||||
name="Buy & Hold (EW)"
|
||||
stroke="#FF6B00"
|
||||
strokeWidth={2}
|
||||
strokeDasharray="5 5"
|
||||
dot={(props) => <CustomDot {...props} dataKey="baseline" />}
|
||||
activeDot={{ r: 6, stroke: '#ffffff', strokeWidth: 2 }}
|
||||
isAnimationActive={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Baseline Value Weighted */}
|
||||
{baseline_vw && baseline_vw.length > 0 && (
|
||||
<Line
|
||||
type="linear"
|
||||
dataKey="baseline_vw"
|
||||
name="Buy & Hold (VW)"
|
||||
stroke="#9C27B0"
|
||||
strokeWidth={2}
|
||||
strokeDasharray="8 4"
|
||||
dot={(props) => <CustomDot {...props} dataKey="baseline_vw" />}
|
||||
activeDot={{ r: 6, stroke: '#ffffff', strokeWidth: 2 }}
|
||||
isAnimationActive={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Momentum Strategy */}
|
||||
{momentum && momentum.length > 0 && (
|
||||
<Line
|
||||
type="linear"
|
||||
dataKey="momentum"
|
||||
name="Momentum"
|
||||
stroke="#2196F3"
|
||||
strokeWidth={2}
|
||||
strokeDasharray="3 3"
|
||||
dot={(props) => <CustomDot {...props} dataKey="momentum" />}
|
||||
activeDot={{ r: 6, stroke: '#ffffff', strokeWidth: 2 }}
|
||||
isAnimationActive={false}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Other Strategies */}
|
||||
{strategies && strategies.length > 0 && (
|
||||
<Line
|
||||
type="linear"
|
||||
dataKey="strategy"
|
||||
name="Strategy"
|
||||
stroke="#795548"
|
||||
strokeWidth={2}
|
||||
dot={(props) => <CustomDot {...props} dataKey="strategy" />}
|
||||
activeDot={{ r: 6, stroke: '#ffffff', strokeWidth: 2 }}
|
||||
isAnimationActive={false}
|
||||
/>
|
||||
)}
|
||||
</LineChart>
|
||||
</ResponsiveContainer>
|
||||
);
|
||||
}
|
||||
|
||||
236
frontend/src/components/PerformanceView.jsx
Normal file
236
frontend/src/components/PerformanceView.jsx
Normal file
@@ -0,0 +1,236 @@
|
||||
import React from 'react';
|
||||
|
||||
/**
|
||||
* Performance View Component
|
||||
* Displays agent performance leaderboard and signal history
|
||||
*/
|
||||
export default function PerformanceView({ leaderboard }) {
|
||||
const rankedAgents = Array.isArray(leaderboard)
|
||||
? leaderboard.filter(agent => agent.agentId !== 'risk_manager')
|
||||
: [];
|
||||
return (
|
||||
<div>
|
||||
{/* Agent Performance Section */}
|
||||
<div className="section">
|
||||
<div className="section-header">
|
||||
<h2 className="section-title">Agent Performance - Signal Accuracy</h2>
|
||||
</div>
|
||||
|
||||
{rankedAgents.length === 0 ? (
|
||||
<div className="empty-state">No leaderboard data available</div>
|
||||
) : (
|
||||
<div className="table-wrapper">
|
||||
<table className="data-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Rank</th>
|
||||
<th>Agent</th>
|
||||
<th>Win Rate</th>
|
||||
<th>Bull Signals</th>
|
||||
<th>Bull Win Rate</th>
|
||||
<th>Bear Signals</th>
|
||||
<th>Bear Win Rate</th>
|
||||
<th>Total Signals</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{rankedAgents.map(agent => {
|
||||
const bullTotal = agent.bull?.n || 0;
|
||||
const bullWins = agent.bull?.win || 0;
|
||||
const bullUnknown = agent.bull?.unknown || 0;
|
||||
const bearTotal = agent.bear?.n || 0;
|
||||
const bearWins = agent.bear?.win || 0;
|
||||
const bearUnknown = agent.bear?.unknown || 0;
|
||||
const totalSignals = bullTotal + bearTotal;
|
||||
const evaluatedBull = Math.max(bullTotal - bullUnknown, 0);
|
||||
const evaluatedBear = Math.max(bearTotal - bearUnknown, 0);
|
||||
const evaluatedTotal = evaluatedBull + evaluatedBear;
|
||||
const bullWinRate = evaluatedBull > 0 ? (bullWins / evaluatedBull) : null;
|
||||
const bearWinRate = evaluatedBear > 0 ? (bearWins / evaluatedBear) : null;
|
||||
const overallWinRate = agent.winRate != null
|
||||
? agent.winRate
|
||||
: (evaluatedTotal > 0 ? ((bullWins + bearWins) / evaluatedTotal) : null);
|
||||
const overallColor = overallWinRate != null
|
||||
? (overallWinRate >= 0.5 ? '#00C853' : '#FF1744')
|
||||
: '#999999';
|
||||
|
||||
return (
|
||||
<tr key={agent.agentId}>
|
||||
<td>
|
||||
<span className={`rank-badge ${agent.rank === 1 ? 'first' : agent.rank === 2 ? 'second' : agent.rank === 3 ? 'third' : ''}`}>
|
||||
{agent.rank === 1 ? '★ 1' : agent.rank}
|
||||
</span>
|
||||
</td>
|
||||
<td>
|
||||
<div style={{ fontWeight: 700, color: '#000000' }}>{agent.name}</div>
|
||||
<div style={{ fontSize: 10, color: '#666666' }}>{agent.role}</div>
|
||||
</td>
|
||||
<td style={{ fontWeight: 700, color: overallColor }}>
|
||||
{overallWinRate != null ? `${(overallWinRate * 100).toFixed(1)}%` : 'N/A'}
|
||||
</td>
|
||||
<td>
|
||||
<div style={{ fontSize: 12 }}>{bullTotal} signals</div>
|
||||
<div style={{ fontSize: 10, color: '#666666' }}>{bullWins} wins</div>
|
||||
{bullUnknown > 0 && (
|
||||
<div style={{ fontSize: 10, color: '#999999' }}>{bullUnknown} unknown</div>
|
||||
)}
|
||||
</td>
|
||||
<td style={{ color: bullWinRate != null ? (bullWinRate >= 0.5 ? '#00C853' : '#999999') : '#999999' }}>
|
||||
{bullWinRate != null ? `${(bullWinRate * 100).toFixed(1)}%` : 'N/A'}
|
||||
</td>
|
||||
<td>
|
||||
<div style={{ fontSize: 12 }}>{bearTotal} signals</div>
|
||||
<div style={{ fontSize: 10, color: '#666666' }}>{bearWins} wins</div>
|
||||
{bearUnknown > 0 && (
|
||||
<div style={{ fontSize: 10, color: '#999999' }}>{bearUnknown} unknown</div>
|
||||
)}
|
||||
</td>
|
||||
<td style={{ color: bearWinRate != null ? (bearWinRate >= 0.5 ? '#00C853' : '#999999') : '#999999' }}>
|
||||
{bearWinRate != null ? `${(bearWinRate * 100).toFixed(1)}%` : 'N/A'}
|
||||
</td>
|
||||
<td style={{ fontWeight: 700 }}>{totalSignals}</td>
|
||||
</tr>
|
||||
);
|
||||
})}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Signal History with Dates */}
|
||||
{rankedAgents.length > 0 && rankedAgents.some(agent => agent.signals && agent.signals.length > 0) && (
|
||||
<div className="section" style={{ marginTop: 32 }}>
|
||||
<div className="section-header">
|
||||
<h2 className="section-title">Signal History</h2>
|
||||
</div>
|
||||
|
||||
<div style={{ display: 'grid', gridTemplateColumns: 'repeat(auto-fit, minmax(400px, 1fr))', gap: 20 }}>
|
||||
{rankedAgents.map(agent => {
|
||||
if (!agent.signals || agent.signals.length === 0) return null;
|
||||
|
||||
// Sort by date descending (newest first)
|
||||
const sortedSignals = [...agent.signals].sort((a, b) =>
|
||||
new Date(b.date).getTime() - new Date(a.date).getTime()
|
||||
);
|
||||
|
||||
return (
|
||||
<div key={agent.agentId} style={{
|
||||
border: '1px solid #e0e0e0',
|
||||
padding: 16,
|
||||
background: '#fafafa'
|
||||
}}>
|
||||
<div style={{
|
||||
fontWeight: 700,
|
||||
fontSize: 12,
|
||||
marginBottom: 12,
|
||||
paddingBottom: 8,
|
||||
borderBottom: '2px solid #000000',
|
||||
letterSpacing: 1,
|
||||
textTransform: 'uppercase'
|
||||
}}>
|
||||
{agent.name}
|
||||
</div>
|
||||
<div style={{
|
||||
maxHeight: 500,
|
||||
overflowY: 'auto',
|
||||
display: 'flex',
|
||||
flexDirection: 'column',
|
||||
gap: 8
|
||||
}}>
|
||||
{sortedSignals.map((signal, idx) => {
|
||||
const signalType = signal.signal.toLowerCase();
|
||||
const isBull = signalType.includes('bull') || signalType === 'long';
|
||||
const isBear = signalType.includes('bear') || signalType === 'short';
|
||||
const isNeutral = signalType.includes('neutral') || signalType === 'hold';
|
||||
const resultStatus = signal.is_correct;
|
||||
const isCorrect = resultStatus === true;
|
||||
const isResultUnknown = resultStatus === 'unknown' || resultStatus === null || typeof resultStatus === 'undefined';
|
||||
const realReturnValue = signal.real_return;
|
||||
const hasRealReturn = typeof realReturnValue === 'number' && Number.isFinite(realReturnValue);
|
||||
const realReturnDisplay = hasRealReturn
|
||||
? `${realReturnValue >= 0 ? '+' : ''}${(realReturnValue * 100).toFixed(2)}%`
|
||||
: 'Unknown';
|
||||
const realReturnColor = hasRealReturn
|
||||
? (realReturnValue >= 0 ? '#00C853' : '#FF1744')
|
||||
: '#999999';
|
||||
const statusColor = isResultUnknown ? '#999999' : (isCorrect ? '#00C853' : '#FF1744');
|
||||
const statusSymbol = isResultUnknown ? '?' : (isCorrect ? '✓' : '✗');
|
||||
|
||||
return (
|
||||
<div key={idx} style={{
|
||||
fontSize: 11,
|
||||
fontFamily: '"Courier New", monospace',
|
||||
lineHeight: 1.4,
|
||||
padding: '8px 10px',
|
||||
background: '#ffffff',
|
||||
border: '1px solid #e0e0e0',
|
||||
display: 'flex',
|
||||
justifyContent: 'space-between',
|
||||
alignItems: 'center'
|
||||
}}>
|
||||
<div style={{ flex: 1 }}>
|
||||
<span style={{
|
||||
color: '#666666',
|
||||
fontSize: 10,
|
||||
marginRight: 10,
|
||||
fontWeight: 600
|
||||
}}>
|
||||
{signal.date}
|
||||
</span>
|
||||
<span style={{
|
||||
fontWeight: 700,
|
||||
color: isBull ? '#00C853' : isBear ? '#FF1744' : '#999999'
|
||||
}}>
|
||||
{signal.ticker}
|
||||
</span>
|
||||
<span style={{
|
||||
marginLeft: 6,
|
||||
color: isBull ? '#00C853' : isBear ? '#FF1744' : '#999999',
|
||||
fontSize: 12
|
||||
}}>
|
||||
{isBull ? 'Bull' : isBear ? 'Bear' : 'Neutral'}
|
||||
</span>
|
||||
{!isNeutral && (
|
||||
<span style={{
|
||||
marginLeft: 8,
|
||||
fontSize: 10,
|
||||
color: realReturnColor
|
||||
}}>
|
||||
{realReturnDisplay}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
{!isNeutral && (
|
||||
<span style={{
|
||||
fontSize: 14,
|
||||
marginLeft: 10,
|
||||
color: statusColor
|
||||
}}>
|
||||
{statusSymbol}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
<div style={{
|
||||
marginTop: 10,
|
||||
paddingTop: 8,
|
||||
borderTop: '1px solid #e0e0e0',
|
||||
fontSize: 10,
|
||||
color: '#666666',
|
||||
textAlign: 'center'
|
||||
}}>
|
||||
Total: {sortedSignals.length} signals
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
770
frontend/src/components/RoomView.jsx
Normal file
770
frontend/src/components/RoomView.jsx
Normal file
@@ -0,0 +1,770 @@
|
||||
import React, { useEffect, useMemo, useRef, useState, useCallback } from 'react';
|
||||
import { ASSETS, SCENE_NATIVE, AGENT_SEATS, AGENTS } from '../config/constants';
|
||||
import AgentCard from './AgentCard';
|
||||
import { getModelIcon } from '../utils/modelIcons';
|
||||
|
||||
/**
|
||||
* Custom hook to load an image
|
||||
*/
|
||||
function useImage(src) {
|
||||
const [img, setImg] = useState(null);
|
||||
useEffect(() => {
|
||||
if (!src) {
|
||||
setImg(null);
|
||||
return;
|
||||
}
|
||||
// Reset image state when backend changes
|
||||
setImg(null);
|
||||
const image = new Image();
|
||||
image.src = src;
|
||||
image.onload = () => setImg(image);
|
||||
image.onerror = () => {
|
||||
console.error(`Failed to load image: ${src}`);
|
||||
setImg(null);
|
||||
};
|
||||
// Cleanup: cancel loading if backend changes
|
||||
return () => {
|
||||
image.onload = null;
|
||||
image.onerror = null;
|
||||
};
|
||||
}, [src]);
|
||||
return img;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get rank medal/trophy for display
|
||||
*/
|
||||
function getRankMedal(rank) {
|
||||
if (rank === 1) return '🏆';
|
||||
if (rank === 2) return '🥈';
|
||||
if (rank === 3) return '🥉';
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Room View Component
|
||||
* Displays the conference room with agents, speech bubbles, and agent cards
|
||||
* Supports click and hover (1.5s) to show agent performance cards
|
||||
* Supports replay mode - completely independent from live mode
|
||||
*/
|
||||
export default function RoomView({ bubbles, bubbleFor, leaderboard, feed, onJumpToMessage }) {
|
||||
const canvasRef = useRef(null);
|
||||
const containerRef = useRef(null);
|
||||
|
||||
// Agent selection and hover state
|
||||
const [selectedAgent, setSelectedAgent] = useState(null);
|
||||
const [hoveredAgent, setHoveredAgent] = useState(null);
|
||||
const [isClosing, setIsClosing] = useState(false);
|
||||
const hoverTimerRef = useRef(null);
|
||||
const closeTimerRef = useRef(null);
|
||||
|
||||
// Bubble expansion state
|
||||
const [expandedBubbles, setExpandedBubbles] = useState({});
|
||||
|
||||
// Hidden bubbles (locally dismissed)
|
||||
const [hiddenBubbles, setHiddenBubbles] = useState({});
|
||||
|
||||
// Handle bubble close
|
||||
const handleCloseBubble = (agentId, bubbleKey, e) => {
|
||||
e.stopPropagation();
|
||||
setHiddenBubbles(prev => ({
|
||||
...prev,
|
||||
[bubbleKey]: true
|
||||
}));
|
||||
};
|
||||
|
||||
// Replay state (must be defined before using in useMemo)
|
||||
const [isReplaying, setIsReplaying] = useState(false);
|
||||
const [replayBubbles, setReplayBubbles] = useState({});
|
||||
const [modeTransition, setModeTransition] = useState(null); // 'entering-replay' | 'exiting-replay' | null
|
||||
const [isPaused, setIsPaused] = useState(false);
|
||||
const replayTimerRef = useRef(null);
|
||||
const replayTimeoutsRef = useRef([]);
|
||||
const replayStateRef = useRef({ messages: [], currentIndex: 0 });
|
||||
|
||||
// Background image
|
||||
const roomBgSrc = ASSETS.roomBg;
|
||||
|
||||
const bgImg = useImage(roomBgSrc);
|
||||
|
||||
// Calculate scale to fit canvas in container (80% of available space)
|
||||
const [scale, setScale] = useState(0.8);
|
||||
|
||||
useEffect(() => {
|
||||
const updateScale = () => {
|
||||
const container = containerRef.current;
|
||||
if (!container) return;
|
||||
|
||||
const { clientWidth, clientHeight } = container;
|
||||
if (clientWidth <= 0 || clientHeight <= 0) return;
|
||||
|
||||
const scaleX = clientWidth / SCENE_NATIVE.width;
|
||||
const scaleY = clientHeight / SCENE_NATIVE.height;
|
||||
const newScale = Math.min(scaleX, scaleY, 1.0) * 0.8; // Scale to 80% of original size
|
||||
setScale(Math.max(0.3, newScale));
|
||||
};
|
||||
|
||||
updateScale();
|
||||
const resizeObserver = new ResizeObserver(updateScale);
|
||||
if (containerRef.current) {
|
||||
resizeObserver.observe(containerRef.current);
|
||||
}
|
||||
window.addEventListener('resize', updateScale);
|
||||
|
||||
return () => {
|
||||
resizeObserver.disconnect();
|
||||
window.removeEventListener('resize', updateScale);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Set canvas size
|
||||
useEffect(() => {
|
||||
const canvas = canvasRef.current;
|
||||
if (!canvas) return;
|
||||
|
||||
canvas.width = SCENE_NATIVE.width;
|
||||
canvas.height = SCENE_NATIVE.height;
|
||||
|
||||
const displayWidth = Math.round(SCENE_NATIVE.width * scale);
|
||||
const displayHeight = Math.round(SCENE_NATIVE.height * scale);
|
||||
canvas.style.width = `${displayWidth}px`;
|
||||
canvas.style.height = `${displayHeight}px`;
|
||||
}, [scale]);
|
||||
|
||||
// Draw room background
|
||||
useEffect(() => {
|
||||
const canvas = canvasRef.current;
|
||||
if (!canvas) return;
|
||||
|
||||
const ctx = canvas.getContext('2d');
|
||||
ctx.imageSmoothingEnabled = false;
|
||||
|
||||
// Clear canvas first
|
||||
ctx.clearRect(0, 0, canvas.width, canvas.height);
|
||||
|
||||
// Draw image if loaded
|
||||
if (bgImg) {
|
||||
ctx.drawImage(bgImg, 0, 0, SCENE_NATIVE.width, SCENE_NATIVE.height);
|
||||
}
|
||||
}, [bgImg, scale, roomBgSrc]);
|
||||
|
||||
// Determine which agents are speaking
|
||||
const speakingAgents = useMemo(() => {
|
||||
const speaking = {};
|
||||
AGENTS.forEach(agent => {
|
||||
const bubble = bubbleFor(agent.name);
|
||||
speaking[agent.id] = !!bubble;
|
||||
});
|
||||
return speaking;
|
||||
}, [bubbles, bubbleFor]);
|
||||
|
||||
// Find agent data from leaderboard
|
||||
const getAgentData = (agentId) => {
|
||||
const agent = AGENTS.find(a => a.id === agentId);
|
||||
if (!agent) return null;
|
||||
|
||||
// If no leaderboard data, return agent with default stats
|
||||
if (!leaderboard || !Array.isArray(leaderboard)) {
|
||||
return {
|
||||
...agent,
|
||||
bull: { n: 0, win: 0, unknown: 0 },
|
||||
bear: { n: 0, win: 0, unknown: 0 },
|
||||
winRate: null,
|
||||
signals: [],
|
||||
rank: null
|
||||
};
|
||||
}
|
||||
|
||||
const leaderboardData = leaderboard.find(lb => lb.agentId === agentId);
|
||||
|
||||
// If agent not in leaderboard, return agent with default stats
|
||||
if (!leaderboardData) {
|
||||
return {
|
||||
...agent,
|
||||
bull: { n: 0, win: 0, unknown: 0 },
|
||||
bear: { n: 0, win: 0, unknown: 0 },
|
||||
winRate: null,
|
||||
signals: [],
|
||||
rank: null
|
||||
};
|
||||
}
|
||||
|
||||
// Merge data but preserve the correct avatar from AGENTS config
|
||||
return {
|
||||
...agent,
|
||||
...leaderboardData,
|
||||
avatar: agent.avatar // Always use the frontend's avatar URL
|
||||
};
|
||||
};
|
||||
|
||||
// Get agent rank for display
|
||||
const getAgentRank = (agentId) => {
|
||||
const agentData = getAgentData(agentId);
|
||||
return agentData?.rank || null;
|
||||
};
|
||||
|
||||
// Handle agent click
|
||||
const handleAgentClick = (agentId) => {
|
||||
// Cancel any closing animation
|
||||
if (closeTimerRef.current) {
|
||||
clearTimeout(closeTimerRef.current);
|
||||
closeTimerRef.current = null;
|
||||
}
|
||||
setIsClosing(false);
|
||||
|
||||
const agentData = getAgentData(agentId);
|
||||
if (agentData) {
|
||||
setSelectedAgent(agentData);
|
||||
}
|
||||
};
|
||||
|
||||
// Handle agent hover
|
||||
const handleAgentMouseEnter = (agentId) => {
|
||||
setHoveredAgent(agentId);
|
||||
// Clear any existing timer
|
||||
if (hoverTimerRef.current) {
|
||||
clearTimeout(hoverTimerRef.current);
|
||||
hoverTimerRef.current = null;
|
||||
}
|
||||
// Cancel any closing animation
|
||||
if (closeTimerRef.current) {
|
||||
clearTimeout(closeTimerRef.current);
|
||||
closeTimerRef.current = null;
|
||||
}
|
||||
setIsClosing(false);
|
||||
|
||||
// If there's already a selected agent, switch immediately
|
||||
// Otherwise, show after a short delay (0ms = immediate)
|
||||
const agentData = getAgentData(agentId);
|
||||
if (agentData) {
|
||||
if (selectedAgent) {
|
||||
// Already have a card open, switch immediately
|
||||
setSelectedAgent(agentData);
|
||||
} else {
|
||||
// No card open, show after delay (currently 0ms = immediate)
|
||||
hoverTimerRef.current = setTimeout(() => {
|
||||
setSelectedAgent(agentData);
|
||||
hoverTimerRef.current = null;
|
||||
}, 0);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleAgentMouseLeave = () => {
|
||||
setHoveredAgent(null);
|
||||
// Clear timer if mouse leaves before 1.5 seconds
|
||||
if (hoverTimerRef.current) {
|
||||
clearTimeout(hoverTimerRef.current);
|
||||
hoverTimerRef.current = null;
|
||||
}
|
||||
};
|
||||
|
||||
// Handle closing with animation
|
||||
const handleClose = () => {
|
||||
setIsClosing(true);
|
||||
// Wait for animation to complete before removing
|
||||
closeTimerRef.current = setTimeout(() => {
|
||||
setSelectedAgent(null);
|
||||
setIsClosing(false);
|
||||
closeTimerRef.current = null;
|
||||
}, 200); // Match the slideUp animation duration
|
||||
};
|
||||
|
||||
// Cleanup timer on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (hoverTimerRef.current) {
|
||||
clearTimeout(hoverTimerRef.current);
|
||||
}
|
||||
if (closeTimerRef.current) {
|
||||
clearTimeout(closeTimerRef.current);
|
||||
}
|
||||
// Clean up replay timers
|
||||
if (replayTimerRef.current) {
|
||||
clearTimeout(replayTimerRef.current);
|
||||
}
|
||||
replayTimeoutsRef.current.forEach(timeoutId => clearTimeout(timeoutId));
|
||||
replayTimeoutsRef.current = [];
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Show replay button when not in replay mode and has feed history
|
||||
const showReplayButton = !isReplaying && feed && feed.length > 0;
|
||||
|
||||
// Start replay with feed data
|
||||
const handleReplayClick = useCallback(() => {
|
||||
if (!feed || feed.length === 0) {
|
||||
return;
|
||||
}
|
||||
startReplay(feed);
|
||||
}, [feed]);
|
||||
|
||||
// Extract agent messages from feed items
|
||||
const extractAgentMessages = useCallback((feedItems) => {
|
||||
const messages = [];
|
||||
|
||||
feedItems.forEach((item, itemIndex) => {
|
||||
if (item.type === 'message' && item.data) {
|
||||
const msg = item.data;
|
||||
// Skip system messages
|
||||
if (msg.agent === 'System') return;
|
||||
// Find matching agent
|
||||
const agent = AGENTS.find(a =>
|
||||
a.id === msg.agentId ||
|
||||
a.name === msg.agent
|
||||
);
|
||||
if (agent) {
|
||||
messages.push({
|
||||
feedItemId: item.id,
|
||||
agentId: agent.id,
|
||||
agentName: agent.name,
|
||||
content: msg.content,
|
||||
timestamp: msg.timestamp
|
||||
});
|
||||
}
|
||||
} else if (item.type === 'conference' && item.data?.messages) {
|
||||
item.data.messages.forEach((msg, msgIndex) => {
|
||||
if (msg.agent === 'System') return;
|
||||
const agent = AGENTS.find(a =>
|
||||
a.id === msg.agentId ||
|
||||
a.name === msg.agent
|
||||
);
|
||||
if (agent) {
|
||||
messages.push({
|
||||
feedItemId: item.id,
|
||||
agentId: agent.id,
|
||||
agentName: agent.name,
|
||||
content: msg.content,
|
||||
timestamp: msg.timestamp
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return messages;
|
||||
}, []);
|
||||
|
||||
// Show next message in replay
|
||||
const showNextMessage = useCallback(() => {
|
||||
const { messages, currentIndex } = replayStateRef.current;
|
||||
if (currentIndex >= messages.length) {
|
||||
// End replay
|
||||
setModeTransition('exiting-replay');
|
||||
setTimeout(() => {
|
||||
setModeTransition(null);
|
||||
setIsReplaying(false);
|
||||
setIsPaused(false);
|
||||
setReplayBubbles({});
|
||||
replayStateRef.current = { messages: [], currentIndex: 0 };
|
||||
}, 500);
|
||||
return;
|
||||
}
|
||||
|
||||
const msg = messages[currentIndex];
|
||||
const bubbleId = `replay_${msg.agentId}_${currentIndex}`;
|
||||
|
||||
setReplayBubbles(prev => ({
|
||||
...prev,
|
||||
[bubbleId]: {
|
||||
id: bubbleId,
|
||||
feedItemId: msg.feedItemId,
|
||||
agentId: msg.agentId,
|
||||
agentName: msg.agentName,
|
||||
text: msg.content,
|
||||
timestamp: msg.timestamp,
|
||||
ts: msg.timestamp
|
||||
}
|
||||
}));
|
||||
|
||||
// Remove bubble after 10 seconds (previously 5s) to keep replay text visible longer
|
||||
const hideTimeout = setTimeout(() => {
|
||||
setReplayBubbles(prev => {
|
||||
const newBubbles = { ...prev };
|
||||
delete newBubbles[bubbleId];
|
||||
return newBubbles;
|
||||
});
|
||||
}, 10000);
|
||||
replayTimeoutsRef.current.push(hideTimeout);
|
||||
|
||||
// Schedule next message
|
||||
replayStateRef.current.currentIndex = currentIndex + 1;
|
||||
// Wait longer before next bubble to match extended visibility (was 3s)
|
||||
const nextTimeout = setTimeout(() => {
|
||||
showNextMessage();
|
||||
}, 6000);
|
||||
replayTimerRef.current = nextTimeout;
|
||||
replayTimeoutsRef.current.push(nextTimeout);
|
||||
}, []);
|
||||
|
||||
// Start replay with feed data
|
||||
const startReplay = useCallback((feedItems) => {
|
||||
if (!feedItems || feedItems.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const agentMessages = extractAgentMessages(feedItems).reverse();
|
||||
if (agentMessages.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Store messages for pause/resume
|
||||
replayStateRef.current = { messages: agentMessages, currentIndex: 0 };
|
||||
|
||||
// Start transition animation
|
||||
setModeTransition('entering-replay');
|
||||
setIsReplaying(true);
|
||||
setIsPaused(false);
|
||||
setReplayBubbles({});
|
||||
|
||||
// Clear any existing timeouts
|
||||
replayTimeoutsRef.current.forEach(timeoutId => clearTimeout(timeoutId));
|
||||
replayTimeoutsRef.current = [];
|
||||
|
||||
// Clear transition and start replay after animation completes
|
||||
setTimeout(() => {
|
||||
setModeTransition(null);
|
||||
showNextMessage();
|
||||
}, 500);
|
||||
}, [extractAgentMessages, showNextMessage]);
|
||||
|
||||
// Pause replay
|
||||
const pauseReplay = useCallback(() => {
|
||||
if (replayTimerRef.current) {
|
||||
clearTimeout(replayTimerRef.current);
|
||||
replayTimerRef.current = null;
|
||||
}
|
||||
setIsPaused(true);
|
||||
}, []);
|
||||
|
||||
// Resume replay
|
||||
const resumeReplay = useCallback(() => {
|
||||
setIsPaused(false);
|
||||
showNextMessage();
|
||||
}, [showNextMessage]);
|
||||
|
||||
// Stop replay
|
||||
const stopReplay = useCallback(() => {
|
||||
// Clear all timeouts
|
||||
replayTimeoutsRef.current.forEach(timeoutId => clearTimeout(timeoutId));
|
||||
replayTimeoutsRef.current = [];
|
||||
|
||||
if (replayTimerRef.current) {
|
||||
clearTimeout(replayTimerRef.current);
|
||||
replayTimerRef.current = null;
|
||||
}
|
||||
|
||||
// Transition out of replay mode
|
||||
setModeTransition('exiting-replay');
|
||||
// Clear transition and replay state after animation completes
|
||||
setTimeout(() => {
|
||||
setModeTransition(null);
|
||||
setIsReplaying(false);
|
||||
setIsPaused(false);
|
||||
setReplayBubbles({});
|
||||
replayStateRef.current = { messages: [], currentIndex: 0 };
|
||||
}, 500);
|
||||
}, []);
|
||||
|
||||
// Get bubble for specific agent (supports both live and replay mode)
|
||||
const getBubbleForAgent = useCallback((agentName) => {
|
||||
if (isReplaying) {
|
||||
// Find replay bubble for this agent
|
||||
const bubble = Object.values(replayBubbles).find(b => {
|
||||
const agent = AGENTS.find(a => a.id === b.agentId);
|
||||
return agent && agent.name === agentName;
|
||||
});
|
||||
return bubble || null;
|
||||
} else {
|
||||
// Use normal bubbleFor function
|
||||
return bubbleFor(agentName);
|
||||
}
|
||||
}, [isReplaying, replayBubbles, bubbleFor]);
|
||||
|
||||
return (
|
||||
<div className="room-view">
|
||||
{/* Agents Indicator Bar */}
|
||||
<div className="room-agents-indicator">
|
||||
{AGENTS.map((agent, index) => {
|
||||
const rank = getAgentRank(agent.id);
|
||||
const medal = rank ? getRankMedal(rank) : null;
|
||||
const agentData = getAgentData(agent.id);
|
||||
const modelInfo = getModelIcon(agentData?.modelName, agentData?.modelProvider);
|
||||
|
||||
return (
|
||||
<React.Fragment key={agent.id}>
|
||||
<div
|
||||
className={`agent-indicator ${speakingAgents[agent.id] ? 'speaking' : ''} ${hoveredAgent === agent.id ? 'hovered' : ''}`}
|
||||
onClick={() => handleAgentClick(agent.id)}
|
||||
onMouseEnter={() => handleAgentMouseEnter(agent.id)}
|
||||
onMouseLeave={handleAgentMouseLeave}
|
||||
>
|
||||
<div className="agent-avatar-wrapper">
|
||||
<img
|
||||
src={agent.avatar}
|
||||
alt={agent.name}
|
||||
className="agent-avatar"
|
||||
/>
|
||||
<span className="agent-indicator-dot"></span>
|
||||
{medal && (
|
||||
<span className="agent-rank-medal">
|
||||
{medal}
|
||||
</span>
|
||||
)}
|
||||
{modelInfo.logoPath && (
|
||||
<img
|
||||
src={modelInfo.logoPath}
|
||||
alt={modelInfo.provider}
|
||||
className="agent-model-badge"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: -12,
|
||||
right: -12,
|
||||
width: 25,
|
||||
height: 25,
|
||||
borderRadius: '50%',
|
||||
border: '2px solid #ffffff',
|
||||
background: '#ffffff',
|
||||
objectFit: 'contain',
|
||||
padding: 2,
|
||||
boxShadow: '0 2px 4px rgba(0,0,0,0.1)',
|
||||
pointerEvents: 'none'
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<span className="agent-name">{agent.name}</span>
|
||||
</div>
|
||||
{/* Divider after Risk Manager (index 1) */}
|
||||
{index === 1 && (
|
||||
<div style={{
|
||||
width: 2,
|
||||
height: 60,
|
||||
background: 'linear-gradient(to bottom, transparent, #333333, transparent)',
|
||||
margin: '0 12px',
|
||||
alignSelf: 'center'
|
||||
}} />
|
||||
)}
|
||||
</React.Fragment>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Hint Text */}
|
||||
<div className="agent-hint-text">
|
||||
点击头像查看详情
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Room Canvas */}
|
||||
<div className="room-canvas-container" ref={containerRef}>
|
||||
<div className="room-scene">
|
||||
<div className="room-scene-wrapper" style={{ width: Math.round(SCENE_NATIVE.width * scale), height: Math.round(SCENE_NATIVE.height * scale) }}>
|
||||
<canvas ref={canvasRef} className="room-canvas" />
|
||||
|
||||
{/* Speech Bubbles */}
|
||||
{AGENTS.map((agent, idx) => {
|
||||
const bubble = getBubbleForAgent(agent.name);
|
||||
if (!bubble) return null;
|
||||
|
||||
const bubbleKey = `${agent.id}_${bubble.timestamp || bubble.id || bubble.ts}`;
|
||||
|
||||
// Check if bubble is hidden
|
||||
if (hiddenBubbles[bubbleKey]) return null;
|
||||
|
||||
const pos = AGENT_SEATS[idx];
|
||||
const scaledWidth = SCENE_NATIVE.width * scale;
|
||||
const scaledHeight = SCENE_NATIVE.height * scale;
|
||||
|
||||
// Bubble left-bottom corner aligns to agent position
|
||||
const left = Math.round(pos.x * scaledWidth);
|
||||
const bottom = Math.round(pos.y * scaledHeight);
|
||||
|
||||
// Get agent data for model info
|
||||
const agentData = getAgentData(agent.id);
|
||||
const modelInfo = getModelIcon(agentData?.modelName, agentData?.modelProvider);
|
||||
|
||||
// Truncate long text - 200 collapsed, 500 expanded max
|
||||
const maxLength = 200;
|
||||
const maxExpandedLength = 500;
|
||||
const isTruncated = bubble.text.length > maxLength;
|
||||
const isExpanded = expandedBubbles[bubbleKey];
|
||||
const displayText = (!isExpanded && isTruncated)
|
||||
? bubble.text.substring(0, maxLength) + '...'
|
||||
: (isExpanded && bubble.text.length > maxExpandedLength)
|
||||
? bubble.text.substring(0, maxExpandedLength) + '...'
|
||||
: bubble.text;
|
||||
|
||||
const toggleExpand = (e) => {
|
||||
e.stopPropagation();
|
||||
setExpandedBubbles(prev => ({
|
||||
...prev,
|
||||
[bubbleKey]: !prev[bubbleKey]
|
||||
}));
|
||||
};
|
||||
|
||||
const handleJumpToFeed = (e) => {
|
||||
e.stopPropagation();
|
||||
if (onJumpToMessage) {
|
||||
onJumpToMessage(bubble);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
key={agent.id}
|
||||
className="room-bubble"
|
||||
style={{ left, bottom }}
|
||||
>
|
||||
{/* Action buttons */}
|
||||
<div className="bubble-action-buttons">
|
||||
<button
|
||||
className="bubble-jump-btn"
|
||||
onClick={handleJumpToFeed}
|
||||
title="跳转到消息"
|
||||
>
|
||||
↗
|
||||
</button>
|
||||
<button
|
||||
className="bubble-close-btn"
|
||||
onClick={(e) => handleCloseBubble(agent.id, bubbleKey, e)}
|
||||
title="关闭"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Agent header with model icon */}
|
||||
<div className="room-bubble-header">
|
||||
{modelInfo.logoPath && (
|
||||
<img
|
||||
src={modelInfo.logoPath}
|
||||
alt={modelInfo.provider}
|
||||
className="bubble-model-icon"
|
||||
/>
|
||||
)}
|
||||
<div className="room-bubble-name">{bubble.agentName || agent.name}</div>
|
||||
</div>
|
||||
|
||||
<div className="room-bubble-divider"></div>
|
||||
|
||||
{/* Message content */}
|
||||
<div className="room-bubble-content">
|
||||
{displayText}
|
||||
{isTruncated && (
|
||||
<button
|
||||
className="bubble-expand-btn"
|
||||
onClick={toggleExpand}
|
||||
>
|
||||
{isExpanded ? ' ↑' : ' ↓'}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Agent Card - Dropdown style below indicator bar */}
|
||||
{selectedAgent && (
|
||||
<>
|
||||
{/* Transparent overlay to close card */}
|
||||
<div
|
||||
className="agent-card-overlay"
|
||||
onClick={handleClose}
|
||||
/>
|
||||
|
||||
{/* Agent Card */}
|
||||
<AgentCard
|
||||
agent={selectedAgent}
|
||||
isClosing={isClosing}
|
||||
onClose={handleClose}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Mode Transition Overlay - sweeps in the dark gradient */}
|
||||
{modeTransition === 'entering-replay' && (
|
||||
<div
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
background: 'radial-gradient(circle, rgba(0,0,0,0) 0%, rgba(0,0,0,0.3) 100%)',
|
||||
pointerEvents: 'none',
|
||||
zIndex: 40,
|
||||
clipPath: 'inset(0 100% 0 0)',
|
||||
animation: 'clipReveal 0.5s ease-out forwards'
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Mode Transition Overlay - sweeps out the dark gradient */}
|
||||
{modeTransition === 'exiting-replay' && (
|
||||
<div
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
background: 'radial-gradient(circle, rgba(0,0,0,0) 0%, rgba(0,0,0,0.3) 100%)',
|
||||
pointerEvents: 'none',
|
||||
zIndex: 40,
|
||||
clipPath: 'inset(0 0 0 0)',
|
||||
animation: 'clipHide 0.5s ease-out forwards'
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Replay Button */}
|
||||
{showReplayButton && (
|
||||
<div className="replay-button-container">
|
||||
<button
|
||||
className="replay-button"
|
||||
onClick={handleReplayClick}
|
||||
title="Replay feed history"
|
||||
>
|
||||
<span className="replay-icon">▶▶</span>
|
||||
<span>回放</span>
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Replay Mode Background + Indicator */}
|
||||
{isReplaying && !modeTransition && (
|
||||
<>
|
||||
<div
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
bottom: 0,
|
||||
background: 'radial-gradient(circle, rgba(0,0,0,0) 0%, rgba(0,0,0,0.3) 100%)',
|
||||
pointerEvents: 'none',
|
||||
zIndex: 40
|
||||
}}
|
||||
/>
|
||||
<div className="replay-indicator">
|
||||
<span className="replay-status">{isPaused ? '已暂停' : '回放模式'}</span>
|
||||
<button
|
||||
className="replay-button"
|
||||
onClick={isPaused ? resumeReplay : pauseReplay}
|
||||
style={{ padding: '6px 12px' }}
|
||||
>
|
||||
<span>{isPaused ? '▶' : '⏸'}</span>
|
||||
</button>
|
||||
<button className="replay-button" onClick={stopReplay} style={{ padding: '6px 12px' }}>
|
||||
<span>■</span>
|
||||
</button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user