"""Unit tests for SentimentAnalyst.

This module tests the SentimentAnalyst class including sentiment analysis,
news collection, report generation, and cost deduction.
"""
|
|
|
|
import asyncio
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from openclaw.agents.base import ActivityType
|
|
from openclaw.agents.sentiment_analyst import (
|
|
SentimentAnalyst,
|
|
SentimentReport,
|
|
SentimentSource,
|
|
)
|
|
from openclaw.core.economy import SurvivalStatus
|
|
|
|
|
|
class TestSentimentAnalystInitialization:
    """Checks on the state of a freshly constructed SentimentAnalyst."""

    def test_default_initialization(self):
        """Only the required arguments are given; the rest take defaults."""
        analyst = SentimentAnalyst(agent_id="sentiment-1", initial_capital=10000.0)

        # Values that were passed in explicitly.
        assert analyst.agent_id == "sentiment-1"
        assert analyst.balance == 10000.0
        # Values expected when the optional arguments are omitted.
        assert analyst.skill_level == 0.5
        assert analyst.max_sources == 10
        assert analyst._analysis_history == []
        assert analyst.decision_cost == 0.08

    def test_custom_initialization(self):
        """Explicit skill_level and max_sources are stored on the agent."""
        overrides = {
            "agent_id": "sentiment-2",
            "initial_capital": 5000.0,
            "skill_level": 0.8,
            "max_sources": 15,
        }

        analyst = SentimentAnalyst(**overrides)

        assert analyst.agent_id == "sentiment-2"
        assert analyst.balance == 5000.0
        assert analyst.skill_level == 0.8
        assert analyst.max_sources == 15

    def test_inherits_from_base_agent(self):
        """A SentimentAnalyst instance is also a BaseAgent instance."""
        from openclaw.agents.base import BaseAgent

        analyst = SentimentAnalyst(agent_id="test", initial_capital=10000.0)

        assert isinstance(analyst, BaseAgent)
|
|
|
|
|
|
class TestDecideActivity:
    """Checks which activity decide_activity picks at various balances."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh analyst with 10000.0 of starting capital."""
        return SentimentAnalyst(agent_id="test", initial_capital=10000.0)

    def test_bankrupt_agent_only_rests(self, agent):
        """With a zero balance the only acceptable decision is REST."""
        agent.economic_tracker.balance = 0  # Bankrupt

        chosen = asyncio.run(agent.decide_activity())

        assert chosen == ActivityType.REST

    def test_critical_status_prefers_learning(self, agent):
        """At a critical balance the decision is LEARN or PAPER_TRADE."""
        agent.economic_tracker.balance = 3500.0  # Critical
        agent.state.skill_level = 0.5

        chosen = asyncio.run(agent.decide_activity())

        acceptable = (ActivityType.LEARN, ActivityType.PAPER_TRADE)
        assert chosen in acceptable

    def test_thriving_status_prefers_analyzing(self, agent):
        """At a thriving balance at least half of 20 decisions are ANALYZE."""
        agent.economic_tracker.balance = 20000.0  # Thriving

        # The decision is not deterministic, so sample it repeatedly.
        outcomes = []
        for _ in range(20):
            outcomes.append(asyncio.run(agent.decide_activity()))

        # Most should be ANALYZE
        analyze_count = sum(1 for outcome in outcomes if outcome == ActivityType.ANALYZE)
        assert analyze_count >= 10  # At least half

    def test_struggling_status_less_analyzing(self, agent):
        """At a struggling balance at least 3 of 20 decisions are PAPER_TRADE."""
        agent.economic_tracker.balance = 8500.0  # Struggling

        # Sample the decision repeatedly, as above.
        outcomes = []
        for _ in range(20):
            outcomes.append(asyncio.run(agent.decide_activity()))

        # Some should be paper trade
        paper_trade_count = outcomes.count(ActivityType.PAPER_TRADE)
        assert paper_trade_count >= 3
|
|
|
|
|
|
class TestAnalyzeSentiment:
    """Test analyze_sentiment method."""

    # Slack allowed when comparing a balance delta with the decision cost,
    # so float rounding in the deduction cannot fail a test spuriously.
    COST_TOLERANCE = 0.001

    @pytest.fixture
    def agent(self):
        """Create a test agent."""
        return SentimentAnalyst(agent_id="test", initial_capital=10000.0)

    def test_returns_sentiment_report(self, agent):
        """Test that analyze_sentiment returns a populated SentimentReport."""
        result = asyncio.run(agent.analyze_sentiment("AAPL"))

        assert isinstance(result, SentimentReport)
        assert result.symbol == "AAPL"
        assert result.overall_sentiment in ["bullish", "bearish", "neutral"]
        assert -1.0 <= result.sentiment_score <= 1.0
        assert len(result.sources) > 0
        assert result.summary != ""

    def test_deducts_decision_cost(self, agent):
        """Test that analyze_sentiment lowers the balance by at least the decision cost."""
        initial_balance = agent.balance

        asyncio.run(agent.analyze_sentiment("AAPL"))

        # The balance must go down at all ...
        assert agent.balance < initial_balance
        # ... and by no less than the decision cost. Additional costs may be
        # deducted on top, so this is a lower bound, compared with the same
        # tolerance used by test_exact_decision_cost_deducted.
        floor = initial_balance - agent.decision_cost
        assert agent.balance <= floor + self.COST_TOLERANCE

    def test_exact_decision_cost_deducted(self, agent):
        """Test that the balance change is at least agent.decision_cost.

        Despite the test name, this is a lower bound rather than an exact
        match: the assertion only requires the change to reach the decision
        cost, within COST_TOLERANCE.
        """
        initial_balance = agent.balance
        expected_cost = agent.decision_cost

        asyncio.run(agent.analyze_sentiment("AAPL"))

        balance_change = initial_balance - agent.balance
        assert balance_change >= expected_cost - self.COST_TOLERANCE

    def test_sentiment_score_in_range(self, agent):
        """Test that sentiment score is between -1.0 and 1.0 inclusive."""
        result = asyncio.run(agent.analyze_sentiment("TSLA"))

        assert -1.0 <= result.sentiment_score <= 1.0

    def test_confidence_in_range(self, agent):
        """Test that confidence is greater than 0.0 and at most 1.0."""
        result = asyncio.run(agent.analyze_sentiment("NVDA"))

        assert 0.0 < result.confidence <= 1.0

    def test_sources_populated(self, agent):
        """Test that the report carries SentimentSource items with title and source set."""
        result = asyncio.run(agent.analyze_sentiment("MSFT"))

        assert len(result.sources) > 0
        assert all(isinstance(s, SentimentSource) for s in result.sources)
        assert all(s.title for s in result.sources)
        assert all(s.source for s in result.sources)

    def test_sample_headlines_populated(self, agent):
        """Test that between one and three sample headlines are reported."""
        result = asyncio.run(agent.analyze_sentiment("GOOGL"))

        assert len(result.sample_headlines) > 0
        assert len(result.sample_headlines) <= 3

    def test_timestamp_populated(self, agent):
        """Test that timestamp is non-empty and contains a "T" separator."""
        result = asyncio.run(agent.analyze_sentiment("AMZN"))

        assert result.timestamp != ""
        assert "T" in result.timestamp  # ISO format has T

    def test_history_recorded(self, agent):
        """Test that each analysis appends one report to the history."""
        initial_history_len = len(agent._analysis_history)

        asyncio.run(agent.analyze_sentiment("META"))

        assert len(agent._analysis_history) == initial_history_len + 1
        assert agent._analysis_history[-1].symbol == "META"
|
|
|
|
|
|
class TestAnalyze:
    """Checks on the dictionary-returning analyze coroutine."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh analyst with 10000.0 of starting capital."""
        return SentimentAnalyst(agent_id="test", initial_capital=10000.0)

    def test_analyze_returns_dict(self, agent):
        """The result is a dict for the requested symbol with all expected keys."""
        payload = asyncio.run(agent.analyze("AAPL"))

        assert isinstance(payload, dict)
        assert payload["symbol"] == "AAPL"

        expected_keys = (
            "sentiment",
            "score",
            "confidence",
            "summary",
            "sources_analyzed",
            "cost",
        )
        for key in expected_keys:
            assert key in payload

    def test_analyze_deducts_cost(self, agent):
        """Running an analysis leaves the balance strictly lower."""
        balance_before = agent.balance

        asyncio.run(agent.analyze("AAPL"))

        assert agent.balance < balance_before
|
|
|
|
|
|
class TestCollectNewsData:
    """Checks on the list produced by _collect_news_data."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh analyst with 10000.0 of starting capital."""
        return SentimentAnalyst(agent_id="test", initial_capital=10000.0)

    def test_returns_list_of_sources(self, agent):
        """The result is a non-empty list made up of SentimentSource items."""
        collected = agent._collect_news_data("AAPL")

        assert isinstance(collected, list)
        assert len(collected) > 0
        for item in collected:
            assert isinstance(item, SentimentSource)

    def test_sources_within_max_limit(self, agent):
        """No more than agent.max_sources items are returned."""
        collected = agent._collect_news_data("TSLA")

        assert len(collected) <= agent.max_sources

    def test_sources_have_required_fields(self, agent):
        """Every item has a title, a known outlet, a bounded relevance and a valid label."""
        known_outlets = ["Reuters", "Bloomberg", "CNBC", "WSJ", "TechCrunch"]
        valid_labels = ["positive", "negative", "neutral"]

        collected = agent._collect_news_data("NVDA")

        for item in collected:
            assert item.title != ""
            assert item.source in known_outlets
            assert 0.0 <= item.relevance_score <= 1.0
            assert item.raw_sentiment in valid_labels
|
|
|
|
|
|
class TestCalculateSentiment:
    """Checks on the (score, confidence) pair from _calculate_sentiment."""

    @staticmethod
    def _reuters_item(title, content, raw_sentiment, relevance_score):
        """Build a single Reuters source with a fixed timestamp."""
        return SentimentSource(
            title=title,
            content=content,
            source="Reuters",
            timestamp="2024-01-01T00:00:00",
            raw_sentiment=raw_sentiment,
            relevance_score=relevance_score,
        )

    @pytest.fixture
    def agent(self):
        """Provide a fresh analyst with 10000.0 of starting capital."""
        return SentimentAnalyst(agent_id="test", initial_capital=10000.0)

    def test_bullish_sources_positive_score(self, agent):
        """One strongly positive item gives a score above zero."""
        items = [
            self._reuters_item(
                "Stock surges on strong earnings growth",
                "Company reports record profit and expansion",
                "positive",
                1.0,
            )
        ]

        score, confidence = agent._calculate_sentiment(items)

        assert score > 0
        assert 0.0 < confidence <= 1.0

    def test_bearish_sources_negative_score(self, agent):
        """One strongly negative item gives a score below zero."""
        items = [
            self._reuters_item(
                "Stock crashes amid bankruptcy fears",
                "Company faces major losses and layoffs",
                "negative",
                1.0,
            )
        ]

        score, confidence = agent._calculate_sentiment(items)

        assert score < 0
        assert 0.0 < confidence <= 1.0

    def test_neutral_sources_near_zero_score(self, agent):
        """One neutral item keeps the score within [-0.5, 0.5]."""
        items = [
            self._reuters_item(
                "Company announces regular meeting",
                "Standard board meeting scheduled",
                "neutral",
                0.5,
            )
        ]

        score, confidence = agent._calculate_sentiment(items)

        assert -0.5 <= score <= 0.5

    def test_empty_sources_returns_zero(self, agent):
        """With no items at all, both score and confidence are 0.0."""
        score, confidence = agent._calculate_sentiment([])

        assert score == 0.0
        assert confidence == 0.0
|
|
|
|
|
|
class TestGenerateSummary:
    """Test _generate_summary method."""

    @pytest.fixture
    def agent(self):
        """Create a test agent."""
        return SentimentAnalyst(agent_id="test", initial_capital=10000.0)

    def test_bullish_summary(self, agent):
        """Test that a bullish summary names the sentiment, the symbol and the score."""
        sources = [
            SentimentSource(title="Good news", content="", source="A", timestamp="", raw_sentiment="positive"),
            SentimentSource(title="Good news", content="", source="B", timestamp="", raw_sentiment="positive"),
        ]

        summary = agent._generate_summary("AAPL", "bullish", 0.5, sources)

        assert "bullish" in summary.lower()
        assert "AAPL" in summary
        # "0.5" is a substring of "0.50", so this one check accepts the score
        # rendered with either one or two decimal places.
        assert "0.5" in summary

    def test_bearish_summary(self, agent):
        """Test that a bearish summary names the sentiment and the symbol."""
        sources = [
            SentimentSource(title="Bad news", content="", source="A", timestamp="", raw_sentiment="negative"),
        ]

        summary = agent._generate_summary("TSLA", "bearish", -0.3, sources)

        assert "bearish" in summary.lower()
        assert "TSLA" in summary

    def test_neutral_summary(self, agent):
        """Test that a neutral summary names the sentiment and the symbol."""
        sources = [
            SentimentSource(title="News", content="", source="A", timestamp="", raw_sentiment="neutral"),
        ]

        summary = agent._generate_summary("NVDA", "neutral", 0.0, sources)

        assert "neutral" in summary.lower()
        assert "NVDA" in summary
|
|
|
|
|
|
class TestGetAnalysisHistory:
    """Checks on the list handed out by get_analysis_history."""

    @pytest.fixture
    def agent(self):
        """Provide a fresh analyst with 10000.0 of starting capital."""
        return SentimentAnalyst(agent_id="test", initial_capital=10000.0)

    def test_returns_copy(self, agent):
        """Mutating the returned list must not touch the agent's own history."""
        asyncio.run(agent.analyze_sentiment("AAPL"))

        snapshot = agent.get_analysis_history()
        snapshot.append(None)  # Tamper with what was handed out

        # The internal list still holds only the single real analysis.
        assert len(agent._analysis_history) == 1

    def test_returns_all_analyses(self, agent):
        """Two analyses in a row produce two history entries."""
        for symbol in ("AAPL", "TSLA"):
            asyncio.run(agent.analyze_sentiment(symbol))

        snapshot = agent.get_analysis_history()

        assert len(snapshot) == 2
|
|
|
|
|
|
class TestGetSentimentTrend:
    """Test get_sentiment_trend method."""

    @pytest.fixture
    def agent(self):
        """Create a test agent."""
        return SentimentAnalyst(agent_id="test", initial_capital=10000.0)

    def test_returns_none_for_no_history(self, agent):
        """Test that method returns None when no history."""
        result = agent.get_sentiment_trend("AAPL")

        assert result is None

    def test_returns_trend_data(self, agent):
        """Test that two analyses of one symbol yield a trend dict with the expected keys."""
        # Feed the analysis a fixed, single positive source instead of
        # whatever _collect_news_data would normally produce.
        with patch.object(agent, '_collect_news_data') as mock_collect:
            mock_collect.return_value = [
                SentimentSource(
                    title="Positive news",
                    content="",
                    source="Reuters",
                    timestamp="",
                    raw_sentiment="positive",
                    relevance_score=1.0,
                )
            ]
            asyncio.run(agent.analyze_sentiment("AAPL"))
            asyncio.run(agent.analyze_sentiment("AAPL"))

        result = agent.get_sentiment_trend("AAPL")

        assert result is not None
        assert result["symbol"] == "AAPL"
        assert "average_score" in result
        assert "trend" in result
        assert "analyses_count" in result
        assert "latest_sentiment" in result

    def test_trend_improving(self, agent):
        """Test that trend shows improving when scores increase."""
        # Seed the history directly with three reports whose scores rise.
        rising_scores = [
            ("neutral", 0.0),
            ("bullish", 0.5),
            ("bullish", 0.8),
        ]
        for sentiment, score in rising_scores:
            agent._analysis_history.append(
                SentimentReport(
                    symbol="AAPL",
                    overall_sentiment=sentiment,
                    sentiment_score=score,
                    sources=[],
                    summary="",
                )
            )

        result = agent.get_sentiment_trend("AAPL")

        # get_sentiment_trend can return None (see the no-history test);
        # check that first so a regression fails as an assertion, not as a
        # TypeError from subscripting None.
        assert result is not None
        assert result["trend"] == "improving"
|
|
|
|
|
|
class TestSentimentSource:
    """Checks on constructing SentimentSource records."""

    def test_creation(self):
        """Fields passed to the constructor are readable back unchanged."""
        given = {
            "title": "Test Title",
            "content": "Test Content",
            "source": "Reuters",
            "timestamp": "2024-01-01T00:00:00",
            "raw_sentiment": "positive",
            "relevance_score": 0.8,
        }

        record = SentimentSource(**given)

        assert record.title == given["title"]
        assert record.content == given["content"]
        assert record.source == given["source"]
        assert record.raw_sentiment == given["raw_sentiment"]
        assert record.relevance_score == given["relevance_score"]

    def test_default_values(self):
        """Omitted raw_sentiment and relevance_score fall back to "" and 0.5."""
        record = SentimentSource(title="Test", content="", source="Reuters", timestamp="")

        assert record.raw_sentiment == ""
        assert record.relevance_score == 0.5
|
|
|
|
|
|
class TestSentimentReport:
    """Checks on constructing SentimentReport records."""

    def test_creation(self):
        """Fields passed to the constructor are readable back unchanged."""
        given = {
            "symbol": "AAPL",
            "overall_sentiment": "bullish",
            "sentiment_score": 0.75,
            "sources": [],
            "summary": "Positive outlook",
            "confidence": 0.8,
            "sample_headlines": ["Good news"],
        }

        record = SentimentReport(**given)

        assert record.symbol == given["symbol"]
        assert record.overall_sentiment == given["overall_sentiment"]
        assert record.sentiment_score == given["sentiment_score"]
        assert record.summary == given["summary"]
        assert record.confidence == given["confidence"]

    def test_default_values(self):
        """Omitted fields get a non-empty timestamp, 0.5 confidence and no headlines."""
        record = SentimentReport(
            symbol="TSLA",
            overall_sentiment="neutral",
            sentiment_score=0.0,
            sources=[],
            summary="",
        )

        assert record.timestamp != ""
        assert record.confidence == 0.5
        assert record.sample_headlines == []
|
|
|
|
|
|
class TestDecisionCost:
    """Checks on the decision_cost value of SentimentAnalyst."""

    def test_decision_cost_value(self):
        """An instance reports a decision_cost of 0.08."""
        analyst = SentimentAnalyst(agent_id="test", initial_capital=10000.0)

        cost = analyst.decision_cost

        assert cost == 0.08

    def test_decision_cost_class_attribute(self):
        """The class itself, without any instance, exposes decision_cost as 0.08."""
        cost = SentimentAnalyst.decision_cost

        assert cost == 0.08
|