210 lines
6.3 KiB
Python
210 lines
6.3 KiB
Python
"""Tests for the LangGraph-based trading workflow.

This module tests the LangGraph workflow implementation, including:

- State management
- Node functions
- Graph construction and execution
- End-to-end workflow runs
"""
|
|
|
|
import asyncio
|
|
|
|
import pytest
|
|
|
|
from openclaw.workflow.state import (
|
|
TradingWorkflowState,
|
|
create_initial_state,
|
|
get_state_summary,
|
|
)
|
|
from openclaw.workflow.trading_workflow import TradingWorkflow, run_trading_workflow
|
|
|
|
|
|
class TestWorkflowState:
    """Unit tests for the workflow state helpers.

    Covers ``create_initial_state`` (the state a fresh run starts from) and
    ``get_state_summary`` (the condensed view derived from a state).
    """

    def test_create_initial_state(self):
        """A fresh state carries the symbol and has no progress or reports."""
        initial = create_initial_state("AAPL", 1000.0)

        assert initial["config"]["symbol"] == "AAPL"
        assert initial["current_step"] == "START"

        # Nothing has executed yet, so the bookkeeping lists start empty.
        for list_key in ("completed_steps", "errors"):
            assert initial[list_key] == []

        # No analysis node has produced a report yet.
        for report_key in (
            "technical_report",
            "sentiment_report",
            "fundamental_report",
        ):
            assert initial[report_key] is None

    def test_create_initial_state_different_symbol(self):
        """The symbol is propagated into the per-agent configuration too."""
        initial = create_initial_state("TSLA", 2000.0)
        config = initial["config"]

        assert config["symbol"] == "TSLA"
        assert config["market_analyst"]["agent_id"] == "market_analyst_TSLA"

    def test_get_state_summary_initial(self):
        """Summarizing an untouched state reports no analyses and no errors."""
        summary = get_state_summary(create_initial_state("AAPL", 1000.0))

        assert summary["symbol"] == "AAPL"
        assert summary["current_step"] == "START"
        for flag in ("has_technical", "has_sentiment", "has_fundamental"):
            assert summary[flag] is False
        assert summary["error_count"] == 0
|
|
|
|
|
|
class TestTradingWorkflow:
    """Tests for the ``TradingWorkflow`` class.

    Covers construction, graph building and caching, async and sync
    execution, final-decision extraction, and Mermaid visualization.
    """

    def test_workflow_initialization(self):
        """The constructor stores its arguments; ``enable_parallel`` defaults to True."""
        workflow = TradingWorkflow("AAPL", 1000.0)

        assert workflow.symbol == "AAPL"
        assert workflow.initial_capital == 1000.0
        assert workflow.enable_parallel is True

    def test_graph_build(self):
        """``_build_graph`` returns a non-None graph object."""
        workflow = TradingWorkflow("AAPL", 1000.0)

        graph = workflow._build_graph()

        assert graph is not None

    def test_graph_property(self):
        """``graph`` is built on first access and the same object is returned afterwards."""
        workflow = TradingWorkflow("AAPL", 1000.0)

        # First access triggers the build.
        graph = workflow.graph
        assert graph is not None

        # Second access must return the cached instance, not a rebuilt one.
        assert workflow.graph is graph

    @pytest.mark.asyncio
    async def test_workflow_run_basic(self):
        """An async run in debug mode returns a state containing ``completed_steps``."""
        workflow = TradingWorkflow("AAPL", 1000.0)

        final_state = await workflow.run(debug=True)

        assert final_state is not None
        assert "completed_steps" in final_state

    @pytest.mark.asyncio
    async def test_workflow_run_outputs_generated(self):
        """An async run completes at least one step."""
        workflow = TradingWorkflow("AAPL", 1000.0)

        final_state = await workflow.run()

        completed = final_state.get("completed_steps", [])
        assert len(completed) > 0

    def test_run_sync(self):
        """The synchronous entry point returns a state containing ``completed_steps``."""
        workflow = TradingWorkflow("AAPL", 1000.0)

        final_state = workflow.run_sync()

        assert final_state is not None
        assert "completed_steps" in final_state

    def test_get_final_decision(self):
        """``get_final_decision`` returns None or a decision for the workflow's symbol."""
        workflow = TradingWorkflow("AAPL", 1000.0)

        final_state = workflow.run_sync()
        decision = workflow.get_final_decision(final_state)

        # The decision may be None if the workflow didn't complete, but the
        # call must not raise. Compare against None explicitly: a truthiness
        # check would silently skip these assertions for an empty (malformed)
        # decision mapping.
        if decision is not None:
            assert "symbol" in decision
            assert decision["symbol"] == "AAPL"

    def test_visualize(self):
        """The Mermaid output is a flowchart that names every pipeline stage."""
        workflow = TradingWorkflow("AAPL", 1000.0)

        mermaid = workflow.visualize()

        assert "flowchart" in mermaid
        for node in (
            "MarketAnalysis",
            "SentimentAnalysis",
            "FundamentalAnalysis",
            "BullBearDebate",
            "DecisionFusion",
            "RiskAssessment",
        ):
            # The per-node message makes it obvious which stage is missing.
            assert node in mermaid, f"missing node {node!r} in Mermaid output"
|
|
|
|
|
|
class TestWorkflowIntegration:
    """End-to-end tests that drive a complete workflow run."""

    @pytest.mark.asyncio
    async def test_full_workflow_streaming(self):
        """Streaming a debug run yields at least one update."""
        streamer = TradingWorkflow("MSFT", 1000.0)

        # Drain the async stream into a list in a single expression.
        received = [chunk async for chunk in streamer.astream(debug=True)]

        assert len(received) > 0

    def test_convenience_function(self):
        """``run_trading_workflow`` returns a decision tagged with the requested symbol."""
        result = run_trading_workflow("GOOGL", 1000.0, debug=False)

        assert result is not None
        assert "symbol" in result
        assert result["symbol"] == "GOOGL"
|
|
|
|
|
|
class TestWorkflowEdgeCases:
    """Degenerate inputs and concurrent execution of the workflow."""

    def test_empty_symbol(self):
        """An empty ticker symbol still lets ``run_sync`` return a state."""
        edge_workflow = TradingWorkflow("", 1000.0)

        outcome = edge_workflow.run_sync()

        # Only checks that a state came back; state["errors"] is not inspected.
        assert outcome is not None

    def test_zero_capital(self):
        """Zero starting capital still lets ``run_sync`` return a state."""
        edge_workflow = TradingWorkflow("AAPL", 0.0)

        outcome = edge_workflow.run_sync()

        assert outcome is not None

    @pytest.mark.asyncio
    async def test_multiple_workflows(self):
        """Independent workflows can be awaited concurrently with ``asyncio.gather``."""
        workflows = [
            TradingWorkflow(ticker, 1000.0) for ticker in ("AAPL", "GOOGL", "MSFT")
        ]

        # gather preserves input order, so results line up with ``workflows``.
        results = await asyncio.gather(*(wf.run() for wf in workflows))

        assert len(results) == 3
        assert all(state is not None for state in results)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this module's tests when executed directly. pytest.main() returns an
    # exit code rather than exiting, so propagate it explicitly; otherwise the
    # process reports success to the shell/CI even when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))
|