"""Tests for LangGraph-based trading workflow. This module tests the LangGraph workflow implementation including: - State management - Node functions - Graph construction and execution - End-to-end workflow runs """ import asyncio import pytest from openclaw.workflow.state import ( TradingWorkflowState, create_initial_state, get_state_summary, ) from openclaw.workflow.trading_workflow import TradingWorkflow, run_trading_workflow class TestWorkflowState: """Test suite for workflow state management.""" def test_create_initial_state(self): """Test initial state creation.""" state = create_initial_state("AAPL", 1000.0) assert state["config"]["symbol"] == "AAPL" assert state["current_step"] == "START" assert state["completed_steps"] == [] assert state["errors"] == [] assert state["technical_report"] is None assert state["sentiment_report"] is None assert state["fundamental_report"] is None def test_create_initial_state_different_symbol(self): """Test initial state with different symbol.""" state = create_initial_state("TSLA", 2000.0) assert state["config"]["symbol"] == "TSLA" assert state["config"]["market_analyst"]["agent_id"] == "market_analyst_TSLA" def test_get_state_summary_initial(self): """Test state summary for initial state.""" state = create_initial_state("AAPL", 1000.0) summary = get_state_summary(state) assert summary["symbol"] == "AAPL" assert summary["current_step"] == "START" assert summary["has_technical"] is False assert summary["has_sentiment"] is False assert summary["has_fundamental"] is False assert summary["error_count"] == 0 class TestTradingWorkflow: """Test suite for TradingWorkflow class.""" def test_workflow_initialization(self): """Test workflow initialization.""" workflow = TradingWorkflow("AAPL", 1000.0) assert workflow.symbol == "AAPL" assert workflow.initial_capital == 1000.0 assert workflow.enable_parallel is True def test_graph_build(self): """Test graph compilation.""" workflow = TradingWorkflow("AAPL", 1000.0) # Build graph graph = workflow._build_graph() assert graph is not None def test_graph_property(self): """Test graph property access.""" workflow = TradingWorkflow("AAPL", 1000.0) # Access graph property (should build on first access) graph = workflow.graph assert graph is not None # Second access should return cached graph assert workflow.graph is graph @pytest.mark.asyncio async def test_workflow_run_basic(self): """Test basic workflow execution.""" workflow = TradingWorkflow("AAPL", 1000.0) final_state = await workflow.run(debug=True) assert final_state is not None assert "completed_steps" in final_state @pytest.mark.asyncio async def test_workflow_run_outputs_generated(self): """Test that workflow generates expected outputs.""" workflow = TradingWorkflow("AAPL", 1000.0) final_state = await workflow.run() # Check that at least some analyses completed completed = final_state.get("completed_steps", []) assert len(completed) > 0 def test_run_sync(self): """Test synchronous workflow execution.""" workflow = TradingWorkflow("AAPL", 1000.0) final_state = workflow.run_sync() assert final_state is not None assert "completed_steps" in final_state def test_get_final_decision(self): """Test getting final decision from state.""" workflow = TradingWorkflow("AAPL", 1000.0) final_state = workflow.run_sync() decision = workflow.get_final_decision(final_state) # Decision may be None if workflow didn't complete, but shouldn't error if decision: assert "symbol" in decision assert decision["symbol"] == "AAPL" def test_visualize(self): """Test workflow visualization generation.""" workflow = TradingWorkflow("AAPL", 1000.0) mermaid = workflow.visualize() assert "flowchart" in mermaid assert "MarketAnalysis" in mermaid assert "SentimentAnalysis" in mermaid assert "FundamentalAnalysis" in mermaid assert "BullBearDebate" in mermaid assert "DecisionFusion" in mermaid assert "RiskAssessment" in mermaid class TestWorkflowIntegration: """Integration tests for the complete workflow.""" @pytest.mark.asyncio async def test_full_workflow_streaming(self): """Test workflow with streaming.""" workflow = TradingWorkflow("MSFT", 1000.0) updates = [] async for update in workflow.astream(debug=True): updates.append(update) assert len(updates) > 0 def test_convenience_function(self): """Test the convenience function run_trading_workflow.""" decision = run_trading_workflow("GOOGL", 1000.0, debug=False) assert decision is not None assert "symbol" in decision assert decision["symbol"] == "GOOGL" class TestWorkflowEdgeCases: """Test edge cases and error handling.""" def test_empty_symbol(self): """Test workflow with empty symbol.""" workflow = TradingWorkflow("", 1000.0) final_state = workflow.run_sync() # Should still complete without errors assert final_state is not None def test_zero_capital(self): """Test workflow with zero capital.""" workflow = TradingWorkflow("AAPL", 0.0) final_state = workflow.run_sync() # Should still complete assert final_state is not None @pytest.mark.asyncio async def test_multiple_workflows(self): """Test running multiple workflows concurrently.""" workflow1 = TradingWorkflow("AAPL", 1000.0) workflow2 = TradingWorkflow("GOOGL", 1000.0) workflow3 = TradingWorkflow("MSFT", 1000.0) # Run all three concurrently results = await asyncio.gather( workflow1.run(), workflow2.run(), workflow3.run(), ) assert len(results) == 3 for result in results: assert result is not None if __name__ == "__main__": pytest.main([__file__, "-v"])