# stock/tests/unit/test_learning_memory.py
# 2026-02-27 03:17:12 +08:00
#
# 843 lines
# 26 KiB
# Python

"""Unit tests for Agent learning memory system."""
import tempfile
from datetime import datetime, timedelta
from pathlib import Path
import pytest
from openclaw.memory import (
BM25Index,
DecisionMemory,
ErrorMemory,
LearningMemory,
MarketMemory,
MemoryDocument,
MemoryType,
TradeMemory,
)
class TestMemoryDocument:
    """Exercise construction and dict round-tripping of ``MemoryDocument``."""

    def test_document_creation(self):
        """A freshly built document keeps its fields and starts unaccessed."""
        document = MemoryDocument(
            doc_id="test_001",
            content="test content",
            memory_type="trade_memory",
            importance=0.8,
        )
        expected_fields = {
            "doc_id": "test_001",
            "content": "test content",
            "memory_type": "trade_memory",
            "importance": 0.8,
            "access_count": 0,
        }
        for attr_name, expected_value in expected_fields.items():
            assert getattr(document, attr_name) == expected_value
        assert document.last_accessed is None
        assert isinstance(document.timestamp, datetime)

    def test_document_to_dict(self):
        """``to_dict`` exposes each constructor field under a same-named key."""
        serialized = MemoryDocument(
            doc_id="test_002",
            content="test content",
            memory_type="market_memory",
            importance=0.5,
            metadata={"key": "value"},
        ).to_dict()
        expected_entries = {
            "doc_id": "test_002",
            "content": "test content",
            "memory_type": "market_memory",
            "importance": 0.5,
            "metadata": {"key": "value"},
        }
        for key, expected_value in expected_entries.items():
            assert serialized[key] == expected_value

    def test_document_from_dict(self):
        """``from_dict`` rebuilds a document from its serialized mapping."""
        payload = {
            "doc_id": "test_003",
            "content": "test content",
            "memory_type": "error_memory",
            "timestamp": datetime.now().isoformat(),
            "metadata": {"error": "test"},
            "importance": 0.9,
            "access_count": 5,
            "last_accessed": None,
        }
        restored = MemoryDocument.from_dict(payload)
        assert (restored.doc_id, restored.content, restored.memory_type) == (
            "test_003",
            "test content",
            "error_memory",
        )
        assert restored.importance == 0.9
        assert restored.access_count == 5

    def test_document_serialization_with_last_accessed(self):
        """``last_accessed`` is emitted as ISO text and parsed back exactly."""
        accessed_at = datetime.now()
        original = MemoryDocument(
            doc_id="test_004",
            content="test",
            memory_type="trade_memory",
            last_accessed=accessed_at,
        )
        as_dict = original.to_dict()
        assert as_dict["last_accessed"] == accessed_at.isoformat()
        round_tripped = MemoryDocument.from_dict(as_dict)
        assert round_tripped.last_accessed == accessed_at
class TestBM25Index:
    """Test BM25 index implementation."""

    def test_index_initialization(self):
        """A new index stores its tuning parameters and starts empty."""
        index = BM25Index(k1=1.2, b=0.75)
        assert index.k1 == 1.2
        assert index.b == 0.75
        assert index.num_docs == 0
        assert index.avg_doc_length == 0.0

    def test_add_document(self):
        """Adding one document registers it and updates the length statistic."""
        index = BM25Index()
        doc = MemoryDocument(
            doc_id="doc_001",
            content="buy apple stock momentum strategy",
            memory_type="trade_memory",
        )
        index.add_document(doc)
        assert index.num_docs == 1
        assert "doc_001" in index.documents
        assert index.avg_doc_length > 0

    def test_add_multiple_documents(self):
        """Every added document is counted."""
        index = BM25Index()
        docs = [
            MemoryDocument(f"doc_{i}", f"content {i} test", "trade_memory")
            for i in range(5)
        ]
        for doc in docs:
            index.add_document(doc)
        assert index.num_docs == 5
        assert index.avg_doc_length > 0

    def test_remove_document(self):
        """Removing a document reports success and forgets it completely."""
        index = BM25Index()
        index.add_document(MemoryDocument("doc_001", "test content", "trade_memory"))
        result = index.remove_document("doc_001")
        assert result is True
        assert index.num_docs == 0
        assert "doc_001" not in index.documents
        # Lookup by id must also miss once the document is gone.
        assert index.get_document("doc_001") is None

    def test_remove_nonexistent_document(self):
        """Removing an unknown id reports failure."""
        index = BM25Index()
        result = index.remove_document("nonexistent")
        assert result is False

    def test_search_basic(self):
        """The document sharing the most query terms ranks first."""
        index = BM25Index()
        index.add_document(
            MemoryDocument("doc_1", "buy apple stock with momentum", "trade_memory")
        )
        index.add_document(
            MemoryDocument("doc_2", "sell microsoft stock breakout", "trade_memory")
        )
        index.add_document(
            MemoryDocument("doc_3", "market analysis volatile regime", "market_memory")
        )
        results = index.search("buy apple stock", top_k=2)
        # Non-empty, and top_k must cap the number of hits returned.
        assert 0 < len(results) <= 2
        assert results[0][0].doc_id == "doc_1"  # Most relevant

    def test_search_with_memory_type_filter(self):
        """A ``memory_type`` filter excludes documents of other types."""
        index = BM25Index()
        index.add_document(
            MemoryDocument("doc_1", "buy apple stock momentum", "trade_memory")
        )
        index.add_document(
            MemoryDocument("doc_2", "buy microsoft stock breakout", "trade_memory")
        )
        index.add_document(
            MemoryDocument("doc_3", "buy signal market analysis", "market_memory")
        )
        results = index.search("buy", memory_type="trade_memory", top_k=5)
        assert len(results) == 2
        for doc, _ in results:
            assert doc.memory_type == "trade_memory"

    def test_search_empty_query(self):
        """An empty query yields no results."""
        index = BM25Index()
        index.add_document(MemoryDocument("doc_1", "test content", "trade_memory"))
        results = index.search("")
        assert results == []

    def test_get_document(self):
        """Lookup by id returns the stored document, or None when unknown."""
        index = BM25Index()
        index.add_document(MemoryDocument("doc_001", "test content", "trade_memory"))
        retrieved = index.get_document("doc_001")
        assert retrieved is not None
        assert retrieved.doc_id == "doc_001"
        assert index.get_document("nonexistent") is None

    def test_update_document(self):
        """Non-content fields can be updated in place."""
        index = BM25Index()
        index.add_document(
            MemoryDocument("doc_001", "test content", "trade_memory", importance=0.5)
        )
        result = index.update_document("doc_001", importance=0.9)
        assert result is True
        assert index.get_document("doc_001").importance == 0.9

    def test_update_nonexistent_document(self):
        """Updating an unknown id reports failure."""
        index = BM25Index()
        result = index.update_document("nonexistent", importance=0.9)
        assert result is False

    def test_update_document_content(self):
        """Updating content re-indexes so the new terms become searchable."""
        index = BM25Index()
        index.add_document(MemoryDocument("doc_001", "original content", "trade_memory"))
        result = index.update_document("doc_001", content="updated content")
        assert result is True
        assert index.get_document("doc_001").content == "updated content"
        # The updated document itself must be the hit for its new term.
        results = index.search("updated")
        assert len(results) > 0
        assert results[0][0].doc_id == "doc_001"

    def test_get_stats(self):
        """Stats report the document total and a per-memory-type breakdown."""
        index = BM25Index()
        for i in range(3):
            index.add_document(
                MemoryDocument(f"doc_{i}", f"content {i}", "trade_memory")
            )
        index.add_document(MemoryDocument("doc_market", "market data", "market_memory"))
        stats = index.get_stats()
        assert stats["num_documents"] == 4
        assert stats["memory_types"]["trade_memory"] == 3
        assert stats["memory_types"]["market_memory"] == 1
        assert stats["avg_doc_length"] > 0

    def test_save_and_load(self):
        """A saved index can be loaded into a fresh instance."""
        with tempfile.TemporaryDirectory() as tmpdir:
            index_path = Path(tmpdir) / "index.pkl"
            index = BM25Index()
            index.add_document(
                MemoryDocument("doc_1", "test content", "trade_memory", importance=0.8)
            )
            index.save(index_path)

            new_index = BM25Index()
            result = new_index.load(index_path)
            assert result is True
            assert new_index.num_docs == 1
            assert "doc_1" in new_index.documents
            assert new_index.get_document("doc_1").importance == 0.8

    def test_load_nonexistent_file(self):
        """Loading from a path that does not exist reports failure."""
        # Use a path under a fresh temp dir so the outcome never depends on
        # what happens to exist on the host filesystem.
        with tempfile.TemporaryDirectory() as tmpdir:
            missing_path = Path(tmpdir) / "missing" / "index.pkl"
            index = BM25Index()
            result = index.load(missing_path)
            assert result is False
class TestTradeMemory:
    """Check the ``TradeMemory`` record and its text/dict renderings."""

    def test_trade_memory_creation(self):
        """Constructor arguments are stored verbatim on the instance."""
        trade = TradeMemory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
        )
        assert (trade.symbol, trade.action) == ("AAPL", "buy")
        assert trade.quantity == 100
        assert trade.pnl == 500.0

    def test_trade_memory_to_text(self):
        """The text rendering mentions symbol, side, strategy and outcome."""
        rendered = TradeMemory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
            strategy="momentum",
            outcome="profitable breakout",
        ).to_text()
        for fragment in ("AAPL", "buy", "momentum", "profitable breakout"):
            assert fragment in rendered

    def test_trade_memory_to_dict(self):
        """The dict form keeps symbol, side and a negative P&L."""
        as_dict = TradeMemory(
            symbol="MSFT",
            action="sell",
            quantity=50,
            price=300.0,
            pnl=-200.0,
        ).to_dict()
        expected = {"symbol": "MSFT", "action": "sell", "pnl": -200.0}
        for key, value in expected.items():
            assert as_dict[key] == value
class TestMarketMemory:
    """Check the ``MarketMemory`` record and its text rendering."""

    def test_market_memory_creation(self):
        """Symbol, regime and sentiment are stored as given."""
        snapshot = MarketMemory(
            symbol="AAPL",
            market_regime="trending",
            sentiment="bullish",
        )
        observed = (snapshot.symbol, snapshot.market_regime, snapshot.sentiment)
        assert observed == ("AAPL", "trending", "bullish")

    def test_market_memory_to_text(self):
        """The text rendering includes symbol, regime, sentiment and events."""
        snapshot = MarketMemory(
            symbol="AAPL",
            market_regime="volatile",
            sentiment="extreme_fear",
            indicators={"rsi": 70.5, "macd": 1.2},
            events=["earnings", "fed_meeting"],
        )
        rendered = snapshot.to_text()
        for fragment in ("AAPL", "volatile", "extreme_fear", "earnings"):
            assert fragment in rendered
class TestDecisionMemory:
    """Check the ``DecisionMemory`` record and its text rendering."""

    def test_decision_memory_creation(self):
        """Decision type, context and confidence are stored as given."""
        decision = DecisionMemory(
            decision_type="entry",
            context="breakout detected",
            confidence=0.8,
        )
        assert decision.decision_type == "entry"
        assert decision.context == "breakout detected"
        assert decision.confidence == 0.8

    def test_decision_memory_to_text(self):
        """The text rendering surfaces type, context, reasoning and factors."""
        decision = DecisionMemory(
            decision_type="exit",
            context="profit target reached",
            reasoning="technical resistance",
            expected_outcome="profit",
            actual_outcome="profit",
            confidence=0.9,
            factors=["rsi_overbought", "resistance_level"],
        )
        rendered = decision.to_text()
        required = (
            "exit",
            "profit target reached",
            "technical resistance",
            "rsi_overbought",
        )
        for fragment in required:
            assert fragment in rendered
class TestErrorMemory:
    """Check the ``ErrorMemory`` record and its text rendering."""

    def test_error_memory_creation(self):
        """Error type, message and severity are stored as given."""
        failure = ErrorMemory(
            error_type="connection_error",
            error_message="Failed to connect to API",
            severity="high",
        )
        assert (failure.error_type, failure.error_message, failure.severity) == (
            "connection_error",
            "Failed to connect to API",
            "high",
        )

    def test_error_memory_to_text(self):
        """The text rendering surfaces type, message, severity and recovery."""
        failure = ErrorMemory(
            error_type="api_error",
            error_message="Rate limit exceeded",
            context="order placement",
            recovery_action="wait and retry",
            severity="critical",
            preventability="yes",
        )
        rendered = failure.to_text()
        for fragment in ("api_error", "Rate limit exceeded", "critical", "wait and retry"):
            assert fragment in rendered
class TestLearningMemory:
    """Test LearningMemory class."""

    @staticmethod
    def _add_basic_trade(memory, symbol="AAPL", pnl=500.0, price=150.0):
        """Record a 100-share ``buy`` of *symbol* on *memory*; return its doc id.

        Shared by the many tests that only need "some trade" in the store.
        The leading underscore keeps pytest from collecting it as a test.
        """
        return memory.add_trade_memory(
            symbol=symbol, action="buy", quantity=100, price=price, pnl=pnl
        )

    def test_learning_memory_initialization(self):
        """Default construction sets the agent id, size cap and decay flag."""
        memory = LearningMemory(agent_id="test_agent")
        assert memory.agent_id == "test_agent"
        assert memory.max_memories == 10000
        assert memory.decay_enabled is True

    def test_add_trade_memory(self):
        """Adding a trade returns an id and indexes one document."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = memory.add_trade_memory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
            strategy="momentum",
            outcome="profitable",
        )
        assert doc_id is not None
        assert memory.index.num_docs == 1

    def test_add_market_memory(self):
        """Adding a market snapshot returns an id and indexes one document."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = memory.add_market_memory(
            symbol="AAPL",
            market_regime="trending",
            sentiment="bullish",
            indicators={"rsi": 60.0},
        )
        assert doc_id is not None
        assert memory.index.num_docs == 1

    def test_add_decision_memory(self):
        """Adding a decision returns an id and indexes one document."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = memory.add_decision_memory(
            decision_type="entry",
            context="breakout pattern",
            reasoning="volume surge",
            confidence=0.8,
        )
        assert doc_id is not None
        assert memory.index.num_docs == 1

    def test_add_error_memory(self):
        """Adding an error returns an id and indexes one document."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = memory.add_error_memory(
            error_type="timeout",
            error_message="Connection timed out",
            severity="high",
        )
        assert doc_id is not None
        assert memory.index.num_docs == 1

    def test_search_similar_trades(self):
        """A symbol search ranks trades of that symbol first."""
        memory = LearningMemory(agent_id="test_agent")
        memory.add_trade_memory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
            strategy="momentum",
            outcome="profitable",
        )
        memory.add_trade_memory(
            symbol="MSFT",
            action="buy",
            quantity=50,
            price=300.0,
            pnl=200.0,
            strategy="breakout",
            outcome="profitable",
        )
        results = memory.search_similar_trades(symbol="AAPL", top_k=2)
        assert len(results) > 0
        assert results[0]["data"]["symbol"] == "AAPL"

    def test_search_similar_trades_with_min_pnl(self):
        """``min_pnl`` drops trades whose P&L is below the threshold."""
        memory = LearningMemory(agent_id="test_agent")
        self._add_basic_trade(memory, pnl=500.0)
        self._add_basic_trade(memory, pnl=100.0)
        results = memory.search_similar_trades(symbol="AAPL", min_pnl=200.0)
        assert len(results) == 1
        assert results[0]["data"]["pnl"] == 500.0

    def test_search_similar_market_states(self):
        """A regime search finds at least the matching snapshot."""
        memory = LearningMemory(agent_id="test_agent")
        memory.add_market_memory(
            symbol="AAPL",
            market_regime="volatile",
            sentiment="extreme_fear",
        )
        memory.add_market_memory(
            symbol="MSFT",
            market_regime="trending",
            sentiment="neutral",
        )
        results = memory.search_similar_market_states(regime="volatile")
        assert len(results) > 0

    def test_get_decision_suggestions(self):
        """Past decisions of the requested type are offered as suggestions."""
        memory = LearningMemory(agent_id="test_agent")
        memory.add_decision_memory(
            decision_type="entry",
            context="breakout pattern",
            reasoning="volume surge",
            expected_outcome="profit",
            actual_outcome="profit",
            confidence=0.8,
        )
        suggestions = memory.get_decision_suggestions(
            context="breakout detected",
            decision_type="entry",
        )
        assert len(suggestions) > 0
        assert suggestions[0]["decision_type"] == "entry"

    def test_get_error_lessons(self):
        """Recorded errors of a type are returned as lessons."""
        memory = LearningMemory(agent_id="test_agent")
        memory.add_error_memory(
            error_type="api_error",
            error_message="Rate limit exceeded",
            recovery_action="implement backoff",
        )
        lessons = memory.get_error_lessons(error_type="api_error")
        assert len(lessons) > 0

    def test_update_memory_importance(self):
        """Importance can be overwritten for an existing memory."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = self._add_basic_trade(memory)
        result = memory.update_memory_importance(doc_id, 0.9)
        assert result is True
        assert memory.index.get_document(doc_id).importance == 0.9

    def test_mark_important(self):
        """Marking a memory important pins its importance to 1.0."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = self._add_basic_trade(memory)
        result = memory.mark_important(doc_id)
        assert result is True
        assert memory.index.get_document(doc_id).importance == 1.0

    def test_delete_memory(self):
        """Deleting a memory removes it from the index entirely."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = self._add_basic_trade(memory)
        result = memory.delete_memory(doc_id)
        assert result is True
        assert memory.index.num_docs == 0
        # The deleted id must no longer resolve either.
        assert memory.index.get_document(doc_id) is None

    def test_clear_all_memories(self):
        """Clearing empties the index."""
        memory = LearningMemory(agent_id="test_agent")
        for i in range(5):
            self._add_basic_trade(memory, symbol=f"SYM{i}", pnl=100.0, price=100.0)
        memory.clear_all_memories()
        assert memory.index.num_docs == 0

    def test_get_memory_stats(self):
        """Stats report the agent id, the total and a per-type breakdown."""
        memory = LearningMemory(agent_id="test_agent")
        self._add_basic_trade(memory)
        memory.add_market_memory(symbol="AAPL", market_regime="trending")
        stats = memory.get_memory_stats()
        assert stats["agent_id"] == "test_agent"
        assert stats["total_memories"] == 2
        assert stats["memory_types"]["trade_memory"] == 1
        assert stats["memory_types"]["market_memory"] == 1

    def test_save_and_load(self):
        """A new instance on the same storage dir sees previously saved memories."""
        with tempfile.TemporaryDirectory() as tmpdir:
            storage_dir = Path(tmpdir) / "memory"
            memory = LearningMemory(agent_id="test_agent", storage_dir=storage_dir)
            memory.add_trade_memory(
                symbol="AAPL",
                action="buy",
                quantity=100,
                price=150.0,
                pnl=500.0,
                strategy="momentum",
            )
            memory.save()

            new_memory = LearningMemory(agent_id="test_agent", storage_dir=storage_dir)
            assert new_memory.index.num_docs == 1
            doc = next(iter(new_memory.index.documents.values()))
            assert doc.metadata["symbol"] == "AAPL"

    def test_memory_importance_calculation_trade(self):
        """Larger P&L yields higher importance."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id_high = self._add_basic_trade(memory, pnl=1500.0)
        doc_id_low = self._add_basic_trade(memory, symbol="MSFT", pnl=50.0)
        high_doc = memory.index.get_document(doc_id_high)
        low_doc = memory.index.get_document(doc_id_low)
        assert high_doc.importance > low_doc.importance

    def test_memory_importance_calculation_market(self):
        """A volatile, fearful regime is more important than a calm trend."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id_volatile = memory.add_market_memory(
            symbol="AAPL", market_regime="volatile", sentiment="extreme_fear"
        )
        doc_id_normal = memory.add_market_memory(
            symbol="MSFT", market_regime="trending", sentiment="neutral"
        )
        volatile_doc = memory.index.get_document(doc_id_volatile)
        normal_doc = memory.index.get_document(doc_id_normal)
        assert volatile_doc.importance > normal_doc.importance

    def test_memory_importance_calculation_error(self):
        """Critical errors are more important than low-severity ones."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id_critical = memory.add_error_memory(
            error_type="api_error",
            error_message="test",
            severity="critical",
        )
        doc_id_low = memory.add_error_memory(
            error_type="minor_error",
            error_message="test",
            severity="low",
        )
        critical_doc = memory.index.get_document(doc_id_critical)
        low_doc = memory.index.get_document(doc_id_low)
        assert critical_doc.importance > low_doc.importance

    def test_memory_decay(self):
        """Old, unimportant memories are evicted; fresh ones are kept."""
        memory = LearningMemory(agent_id="test_agent", decay_enabled=True, decay_days=30)
        doc_id = self._add_basic_trade(memory)
        # Age the memory past decay_days * 3 and drop its importance below
        # the pruning threshold this test expects (0.1).
        stale = memory.index.documents[doc_id]
        stale.timestamp = datetime.now() - timedelta(days=100)
        stale.importance = 0.05
        # Adding another memory triggers the decay pass.
        fresh_id = self._add_basic_trade(memory, symbol="MSFT")
        # Old memory should be removed (100 days > 30*3=90 days AND importance < 0.1)...
        assert doc_id not in memory.index.documents
        # ...but decay must not discard the brand-new memory along with it.
        assert fresh_id in memory.index.documents

    def test_memory_limit_enforcement(self):
        """The store never holds more than ``max_memories`` documents."""
        memory = LearningMemory(agent_id="test_agent", max_memories=5)
        for i in range(10):
            self._add_basic_trade(memory, symbol=f"SYM{i}", pnl=100.0, price=100.0)
        assert memory.index.num_docs <= 5

    def test_trade_memory_with_market_conditions(self):
        """Market conditions passed with a trade are kept in its metadata."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = memory.add_trade_memory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
            market_conditions={"trend": "up", "volatility": "high"},
        )
        doc = memory.index.get_document(doc_id)
        assert doc.metadata["market_conditions"]["trend"] == "up"

    def test_access_count_tracking(self):
        """A search that returns a memory increments its access count."""
        memory = LearningMemory(agent_id="test_agent")
        doc_id = self._add_basic_trade(memory)
        assert memory.index.get_document(doc_id).access_count == 0
        memory.search_similar_trades(symbol="AAPL")
        assert memory.index.get_document(doc_id).access_count == 1
class TestMemoryType:
    """Pin the string value behind each ``MemoryType`` member."""

    def test_memory_type_values(self):
        """Each enum member maps to its ``<kind>_memory`` tag."""
        expected_tags = (
            (MemoryType.TRADE, "trade_memory"),
            (MemoryType.MARKET, "market_memory"),
            (MemoryType.DECISION, "decision_memory"),
            (MemoryType.ERROR, "error_memory"),
        )
        for member, tag in expected_tags:
            assert member.value == tag