Initial commit of integrated agent system
This commit is contained in:
140
backend/utils/progress.py
Normal file
140
backend/utils/progress.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.style import Style
|
||||
from rich.table import Table
|
||||
from rich.text import Text
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class AgentProgress:
|
||||
"""Manages progress tracking for multiple agents."""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_status = {}
|
||||
self.table = Table(show_header=False, box=None, padding=(0, 1))
|
||||
self.live = Live(self.table, console=console, refresh_per_second=4)
|
||||
self.started = False
|
||||
self.update_handlers = []
|
||||
|
||||
def register_handler(
|
||||
self,
|
||||
handler: Callable[[str, Optional[str], str], None],
|
||||
):
|
||||
"""Register a handler to be called when agent status updates."""
|
||||
self.update_handlers.append(handler)
|
||||
return handler # Return handler to support use as decorator
|
||||
|
||||
def unregister_handler(
|
||||
self,
|
||||
handler: Callable[[str, Optional[str], str], None],
|
||||
):
|
||||
"""Unregister a previously registered handler."""
|
||||
if handler in self.update_handlers:
|
||||
self.update_handlers.remove(handler)
|
||||
|
||||
def start(self):
|
||||
"""Start the progress display."""
|
||||
if not self.started:
|
||||
self.live.start()
|
||||
self.started = True
|
||||
|
||||
def stop(self):
|
||||
"""Stop the progress display."""
|
||||
if self.started:
|
||||
self.live.stop()
|
||||
self.started = False
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
agent_name: str,
|
||||
ticker: Optional[str] = None,
|
||||
status: str = "",
|
||||
analysis: Optional[str] = None,
|
||||
):
|
||||
"""Update the status of an agent."""
|
||||
if agent_name not in self.agent_status:
|
||||
self.agent_status[agent_name] = {"status": "", "ticker": None}
|
||||
|
||||
if ticker:
|
||||
self.agent_status[agent_name]["ticker"] = ticker
|
||||
if status:
|
||||
self.agent_status[agent_name]["status"] = status
|
||||
if analysis:
|
||||
self.agent_status[agent_name]["analysis"] = analysis
|
||||
|
||||
# Set the timestamp as UTC datetime
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
self.agent_status[agent_name]["timestamp"] = timestamp
|
||||
|
||||
# Notify all registered handlers
|
||||
for handler in self.update_handlers:
|
||||
handler(agent_name, ticker, status, analysis, timestamp)
|
||||
|
||||
self._refresh_display()
|
||||
|
||||
def get_all_status(self):
|
||||
"""Get the current status of all agents as a dictionary."""
|
||||
return {
|
||||
agent_name: {
|
||||
"ticker": info["ticker"],
|
||||
"status": info["status"],
|
||||
"display_name": self._get_display_name(agent_name),
|
||||
}
|
||||
for agent_name, info in self.agent_status.items()
|
||||
}
|
||||
|
||||
def _get_display_name(self, agent_name: str) -> str:
|
||||
"""Convert agent_name to a display-friendly format."""
|
||||
return agent_name.replace("_agent", "").replace("_", " ").title()
|
||||
|
||||
def _refresh_display(self):
|
||||
"""Refresh the progress display."""
|
||||
self.table.columns.clear()
|
||||
self.table.add_column(width=100)
|
||||
|
||||
# Sort Risk Management and Portfolio Management at the bottom
|
||||
def sort_key(item):
|
||||
agent_name = item[0]
|
||||
if "risk_manager" in agent_name:
|
||||
return (2, agent_name)
|
||||
elif "portfolio_manager" in agent_name:
|
||||
return (3, agent_name)
|
||||
else:
|
||||
return (1, agent_name)
|
||||
|
||||
for agent_name, info in sorted(
|
||||
self.agent_status.items(),
|
||||
key=sort_key,
|
||||
):
|
||||
status = info["status"]
|
||||
ticker = info["ticker"]
|
||||
# Create the status text with appropriate styling
|
||||
if status.lower() == "done":
|
||||
style = Style(color="green", bold=True)
|
||||
symbol = "✓"
|
||||
elif status.lower() == "error":
|
||||
style = Style(color="red", bold=True)
|
||||
symbol = "✗"
|
||||
else:
|
||||
style = Style(color="yellow")
|
||||
symbol = "⋯"
|
||||
|
||||
agent_display = self._get_display_name(agent_name)
|
||||
status_text = Text()
|
||||
status_text.append(f"{symbol} ", style=style)
|
||||
status_text.append(f"{agent_display:<20}", style=Style(bold=True))
|
||||
|
||||
if ticker:
|
||||
status_text.append(f"[{ticker}] ", style=Style(color="cyan"))
|
||||
status_text.append(status, style=style)
|
||||
|
||||
self.table.add_row(status_text)
|
||||
|
||||
|
||||
# Create a global instance
|
||||
progress = AgentProgress()
|
||||
Reference in New Issue
Block a user