"""Unit tests for Agent learning memory system.""" import tempfile from datetime import datetime, timedelta from pathlib import Path import pytest from openclaw.memory import ( BM25Index, DecisionMemory, ErrorMemory, LearningMemory, MarketMemory, MemoryDocument, MemoryType, TradeMemory, ) class TestMemoryDocument: """Test MemoryDocument data class.""" def test_document_creation(self): """Test creating a MemoryDocument.""" doc = MemoryDocument( doc_id="test_001", content="test content", memory_type="trade_memory", importance=0.8, ) assert doc.doc_id == "test_001" assert doc.content == "test content" assert doc.memory_type == "trade_memory" assert doc.importance == 0.8 assert doc.access_count == 0 assert doc.last_accessed is None assert isinstance(doc.timestamp, datetime) def test_document_to_dict(self): """Test converting document to dictionary.""" doc = MemoryDocument( doc_id="test_002", content="test content", memory_type="market_memory", importance=0.5, metadata={"key": "value"}, ) data = doc.to_dict() assert data["doc_id"] == "test_002" assert data["content"] == "test content" assert data["memory_type"] == "market_memory" assert data["importance"] == 0.5 assert data["metadata"] == {"key": "value"} def test_document_from_dict(self): """Test creating document from dictionary.""" data = { "doc_id": "test_003", "content": "test content", "memory_type": "error_memory", "timestamp": datetime.now().isoformat(), "metadata": {"error": "test"}, "importance": 0.9, "access_count": 5, "last_accessed": None, } doc = MemoryDocument.from_dict(data) assert doc.doc_id == "test_003" assert doc.content == "test content" assert doc.memory_type == "error_memory" assert doc.importance == 0.9 assert doc.access_count == 5 def test_document_serialization_with_last_accessed(self): """Test serialization with last_accessed timestamp.""" now = datetime.now() doc = MemoryDocument( doc_id="test_004", content="test", memory_type="trade_memory", last_accessed=now, ) data = doc.to_dict() assert data["last_accessed"] == now.isoformat() restored = MemoryDocument.from_dict(data) assert restored.last_accessed == now class TestBM25Index: """Test BM25 index implementation.""" def test_index_initialization(self): """Test BM25 index initialization.""" index = BM25Index(k1=1.2, b=0.75) assert index.k1 == 1.2 assert index.b == 0.75 assert index.num_docs == 0 assert index.avg_doc_length == 0.0 def test_add_document(self): """Test adding documents to index.""" index = BM25Index() doc = MemoryDocument( doc_id="doc_001", content="buy apple stock momentum strategy", memory_type="trade_memory", ) index.add_document(doc) assert index.num_docs == 1 assert "doc_001" in index.documents assert index.avg_doc_length > 0 def test_add_multiple_documents(self): """Test adding multiple documents.""" index = BM25Index() docs = [ MemoryDocument(f"doc_{i}", f"content {i} test", "trade_memory") for i in range(5) ] for doc in docs: index.add_document(doc) assert index.num_docs == 5 assert index.avg_doc_length > 0 def test_remove_document(self): """Test removing documents from index.""" index = BM25Index() doc = MemoryDocument("doc_001", "test content", "trade_memory") index.add_document(doc) result = index.remove_document("doc_001") assert result is True assert index.num_docs == 0 assert "doc_001" not in index.documents def test_remove_nonexistent_document(self): """Test removing non-existent document.""" index = BM25Index() result = index.remove_document("nonexistent") assert result is False def test_search_basic(self): """Test basic search functionality.""" index = BM25Index() # Add documents index.add_document( MemoryDocument("doc_1", "buy apple stock with momentum", "trade_memory") ) index.add_document( MemoryDocument("doc_2", "sell microsoft stock breakout", "trade_memory") ) index.add_document( MemoryDocument("doc_3", "market analysis volatile regime", "market_memory") ) # Search results = index.search("buy apple stock", top_k=2) assert len(results) > 0 assert results[0][0].doc_id == "doc_1" # Most relevant def test_search_with_memory_type_filter(self): """Test search with memory type filter.""" index = BM25Index() index.add_document( MemoryDocument("doc_1", "buy apple stock momentum", "trade_memory") ) index.add_document( MemoryDocument("doc_2", "buy microsoft stock breakout", "trade_memory") ) index.add_document( MemoryDocument("doc_3", "buy signal market analysis", "market_memory") ) # Search only trade_memory results = index.search("buy", memory_type="trade_memory", top_k=5) assert len(results) == 2 for doc, _ in results: assert doc.memory_type == "trade_memory" def test_search_empty_query(self): """Test search with empty query.""" index = BM25Index() index.add_document(MemoryDocument("doc_1", "test content", "trade_memory")) results = index.search("") assert results == [] def test_get_document(self): """Test retrieving document by ID.""" index = BM25Index() doc = MemoryDocument("doc_001", "test content", "trade_memory") index.add_document(doc) retrieved = index.get_document("doc_001") assert retrieved is not None assert retrieved.doc_id == "doc_001" not_found = index.get_document("nonexistent") assert not_found is None def test_update_document(self): """Test updating document fields.""" index = BM25Index() doc = MemoryDocument("doc_001", "test content", "trade_memory", importance=0.5) index.add_document(doc) result = index.update_document("doc_001", importance=0.9) assert result is True assert index.get_document("doc_001").importance == 0.9 def test_update_nonexistent_document(self): """Test updating non-existent document.""" index = BM25Index() result = index.update_document("nonexistent", importance=0.9) assert result is False def test_update_document_content(self): """Test updating document content (triggers re-index).""" index = BM25Index() doc = MemoryDocument("doc_001", "original content", "trade_memory") index.add_document(doc) result = index.update_document("doc_001", content="updated content") assert result is True assert index.get_document("doc_001").content == "updated content" # Document should still be searchable results = index.search("updated") assert len(results) > 0 def test_get_stats(self): """Test getting index statistics.""" index = BM25Index() for i in range(3): index.add_document( MemoryDocument(f"doc_{i}", f"content {i}", "trade_memory") ) index.add_document(MemoryDocument("doc_market", "market data", "market_memory")) stats = index.get_stats() assert stats["num_documents"] == 4 assert stats["memory_types"]["trade_memory"] == 3 assert stats["memory_types"]["market_memory"] == 1 assert stats["avg_doc_length"] > 0 def test_save_and_load(self): """Test saving and loading index.""" with tempfile.TemporaryDirectory() as tmpdir: index_path = Path(tmpdir) / "index.pkl" # Create and save index index = BM25Index() index.add_document( MemoryDocument("doc_1", "test content", "trade_memory", importance=0.8) ) index.save(index_path) # Load into new index new_index = BM25Index() result = new_index.load(index_path) assert result is True assert new_index.num_docs == 1 assert "doc_1" in new_index.documents assert new_index.get_document("doc_1").importance == 0.8 def test_load_nonexistent_file(self): """Test loading from non-existent file.""" index = BM25Index() result = index.load(Path("/nonexistent/path/index.pkl")) assert result is False class TestTradeMemory: """Test TradeMemory data class.""" def test_trade_memory_creation(self): """Test creating TradeMemory.""" memory = TradeMemory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0, ) assert memory.symbol == "AAPL" assert memory.action == "buy" assert memory.quantity == 100 assert memory.pnl == 500.0 def test_trade_memory_to_text(self): """Test converting TradeMemory to text.""" memory = TradeMemory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0, strategy="momentum", outcome="profitable breakout", ) text = memory.to_text() assert "AAPL" in text assert "buy" in text assert "momentum" in text assert "profitable breakout" in text def test_trade_memory_to_dict(self): """Test converting TradeMemory to dictionary.""" memory = TradeMemory( symbol="MSFT", action="sell", quantity=50, price=300.0, pnl=-200.0, ) data = memory.to_dict() assert data["symbol"] == "MSFT" assert data["action"] == "sell" assert data["pnl"] == -200.0 class TestMarketMemory: """Test MarketMemory data class.""" def test_market_memory_creation(self): """Test creating MarketMemory.""" memory = MarketMemory( symbol="AAPL", market_regime="trending", sentiment="bullish", ) assert memory.symbol == "AAPL" assert memory.market_regime == "trending" assert memory.sentiment == "bullish" def test_market_memory_to_text(self): """Test converting MarketMemory to text.""" memory = MarketMemory( symbol="AAPL", market_regime="volatile", sentiment="extreme_fear", indicators={"rsi": 70.5, "macd": 1.2}, events=["earnings", "fed_meeting"], ) text = memory.to_text() assert "AAPL" in text assert "volatile" in text assert "extreme_fear" in text assert "earnings" in text class TestDecisionMemory: """Test DecisionMemory data class.""" def test_decision_memory_creation(self): """Test creating DecisionMemory.""" memory = DecisionMemory( decision_type="entry", context="breakout detected", confidence=0.8, ) assert memory.decision_type == "entry" assert memory.context == "breakout detected" assert memory.confidence == 0.8 def test_decision_memory_to_text(self): """Test converting DecisionMemory to text.""" memory = DecisionMemory( decision_type="exit", context="profit target reached", reasoning="technical resistance", expected_outcome="profit", actual_outcome="profit", confidence=0.9, factors=["rsi_overbought", "resistance_level"], ) text = memory.to_text() assert "exit" in text assert "profit target reached" in text assert "technical resistance" in text assert "rsi_overbought" in text class TestErrorMemory: """Test ErrorMemory data class.""" def test_error_memory_creation(self): """Test creating ErrorMemory.""" memory = ErrorMemory( error_type="connection_error", error_message="Failed to connect to API", severity="high", ) assert memory.error_type == "connection_error" assert memory.error_message == "Failed to connect to API" assert memory.severity == "high" def test_error_memory_to_text(self): """Test converting ErrorMemory to text.""" memory = ErrorMemory( error_type="api_error", error_message="Rate limit exceeded", context="order placement", recovery_action="wait and retry", severity="critical", preventability="yes", ) text = memory.to_text() assert "api_error" in text assert "Rate limit exceeded" in text assert "critical" in text assert "wait and retry" in text class TestLearningMemory: """Test LearningMemory class.""" def test_learning_memory_initialization(self): """Test LearningMemory initialization.""" memory = LearningMemory(agent_id="test_agent") assert memory.agent_id == "test_agent" assert memory.max_memories == 10000 assert memory.decay_enabled is True def test_add_trade_memory(self): """Test adding trade memory.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0, strategy="momentum", outcome="profitable", ) assert doc_id is not None assert memory.index.num_docs == 1 def test_add_market_memory(self): """Test adding market memory.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_market_memory( symbol="AAPL", market_regime="trending", sentiment="bullish", indicators={"rsi": 60.0}, ) assert doc_id is not None assert memory.index.num_docs == 1 def test_add_decision_memory(self): """Test adding decision memory.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_decision_memory( decision_type="entry", context="breakout pattern", reasoning="volume surge", confidence=0.8, ) assert doc_id is not None assert memory.index.num_docs == 1 def test_add_error_memory(self): """Test adding error memory.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_error_memory( error_type="timeout", error_message="Connection timed out", severity="high", ) assert doc_id is not None assert memory.index.num_docs == 1 def test_search_similar_trades(self): """Test searching similar trades.""" memory = LearningMemory(agent_id="test_agent") memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0, strategy="momentum", outcome="profitable", ) memory.add_trade_memory( symbol="MSFT", action="buy", quantity=50, price=300.0, pnl=200.0, strategy="breakout", outcome="profitable", ) results = memory.search_similar_trades(symbol="AAPL", top_k=2) assert len(results) > 0 assert results[0]["data"]["symbol"] == "AAPL" def test_search_similar_trades_with_min_pnl(self): """Test searching similar trades with P&L filter.""" memory = LearningMemory(agent_id="test_agent") memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0 ) memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=100.0 ) results = memory.search_similar_trades(symbol="AAPL", min_pnl=200.0) assert len(results) == 1 assert results[0]["data"]["pnl"] == 500.0 def test_search_similar_market_states(self): """Test searching similar market states.""" memory = LearningMemory(agent_id="test_agent") memory.add_market_memory( symbol="AAPL", market_regime="volatile", sentiment="extreme_fear", ) memory.add_market_memory( symbol="MSFT", market_regime="trending", sentiment="neutral", ) results = memory.search_similar_market_states(regime="volatile") assert len(results) > 0 def test_get_decision_suggestions(self): """Test getting decision suggestions.""" memory = LearningMemory(agent_id="test_agent") memory.add_decision_memory( decision_type="entry", context="breakout pattern", reasoning="volume surge", expected_outcome="profit", actual_outcome="profit", confidence=0.8, ) suggestions = memory.get_decision_suggestions( context="breakout detected", decision_type="entry", ) assert len(suggestions) > 0 assert suggestions[0]["decision_type"] == "entry" def test_get_error_lessons(self): """Test getting error lessons.""" memory = LearningMemory(agent_id="test_agent") memory.add_error_memory( error_type="api_error", error_message="Rate limit exceeded", recovery_action="implement backoff", ) lessons = memory.get_error_lessons(error_type="api_error") assert len(lessons) > 0 def test_update_memory_importance(self): """Test updating memory importance.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0 ) result = memory.update_memory_importance(doc_id, 0.9) assert result is True doc = memory.index.get_document(doc_id) assert doc.importance == 0.9 def test_mark_important(self): """Test marking memory as important.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0 ) result = memory.mark_important(doc_id) assert result is True assert memory.index.get_document(doc_id).importance == 1.0 def test_delete_memory(self): """Test deleting a memory.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0 ) result = memory.delete_memory(doc_id) assert result is True assert memory.index.num_docs == 0 def test_clear_all_memories(self): """Test clearing all memories.""" memory = LearningMemory(agent_id="test_agent") for i in range(5): memory.add_trade_memory( symbol=f"SYM{i}", action="buy", quantity=100, price=100.0, pnl=100.0, ) memory.clear_all_memories() assert memory.index.num_docs == 0 def test_get_memory_stats(self): """Test getting memory statistics.""" memory = LearningMemory(agent_id="test_agent") memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0 ) memory.add_market_memory(symbol="AAPL", market_regime="trending") stats = memory.get_memory_stats() assert stats["agent_id"] == "test_agent" assert stats["total_memories"] == 2 assert stats["memory_types"]["trade_memory"] == 1 assert stats["memory_types"]["market_memory"] == 1 def test_save_and_load(self): """Test saving and loading learning memory.""" with tempfile.TemporaryDirectory() as tmpdir: storage_dir = Path(tmpdir) / "memory" # Create and populate memory memory = LearningMemory(agent_id="test_agent", storage_dir=storage_dir) memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0, strategy="momentum", ) memory.save() # Load into new memory instance new_memory = LearningMemory(agent_id="test_agent", storage_dir=storage_dir) assert new_memory.index.num_docs == 1 doc = list(new_memory.index.documents.values())[0] assert doc.metadata["symbol"] == "AAPL" def test_memory_importance_calculation_trade(self): """Test importance calculation for trades.""" memory = LearningMemory(agent_id="test_agent") # High P&L trade should have higher importance doc_id_high = memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=1500.0 ) doc_id_low = memory.add_trade_memory( symbol="MSFT", action="buy", quantity=100, price=150.0, pnl=50.0 ) high_doc = memory.index.get_document(doc_id_high) low_doc = memory.index.get_document(doc_id_low) assert high_doc.importance > low_doc.importance def test_memory_importance_calculation_market(self): """Test importance calculation for market memories.""" memory = LearningMemory(agent_id="test_agent") # Volatile regime should have higher importance doc_id_volatile = memory.add_market_memory( symbol="AAPL", market_regime="volatile", sentiment="extreme_fear" ) doc_id_normal = memory.add_market_memory( symbol="MSFT", market_regime="trending", sentiment="neutral" ) volatile_doc = memory.index.get_document(doc_id_volatile) normal_doc = memory.index.get_document(doc_id_normal) assert volatile_doc.importance > normal_doc.importance def test_memory_importance_calculation_error(self): """Test importance calculation for errors.""" memory = LearningMemory(agent_id="test_agent") # Critical error should have higher importance doc_id_critical = memory.add_error_memory( error_type="api_error", error_message="test", severity="critical", ) doc_id_low = memory.add_error_memory( error_type="minor_error", error_message="test", severity="low", ) critical_doc = memory.index.get_document(doc_id_critical) low_doc = memory.index.get_document(doc_id_low) assert critical_doc.importance > low_doc.importance def test_memory_decay(self): """Test memory decay mechanism.""" memory = LearningMemory(agent_id="test_agent", decay_enabled=True, decay_days=30) doc_id = memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0 ) # Manually set timestamp to old date (older than decay_days * 3) memory.index.documents[doc_id].timestamp = datetime.now() - timedelta(days=100) memory.index.documents[doc_id].importance = 0.05 # Low importance (below threshold) # Add another memory to trigger decay memory.add_trade_memory( symbol="MSFT", action="buy", quantity=100, price=150.0, pnl=500.0 ) # Old memory should be removed (100 days > 30*3=90 days AND importance < 0.1) assert doc_id not in memory.index.documents def test_memory_limit_enforcement(self): """Test memory limit enforcement.""" memory = LearningMemory(agent_id="test_agent", max_memories=5) for i in range(10): memory.add_trade_memory( symbol=f"SYM{i}", action="buy", quantity=100, price=100.0, pnl=100.0, ) assert memory.index.num_docs <= 5 def test_trade_memory_with_market_conditions(self): """Test trade memory with market conditions.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0, market_conditions={"trend": "up", "volatility": "high"}, ) doc = memory.index.get_document(doc_id) assert doc.metadata["market_conditions"]["trend"] == "up" def test_access_count_tracking(self): """Test that access count is tracked during searches.""" memory = LearningMemory(agent_id="test_agent") doc_id = memory.add_trade_memory( symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0 ) # Initial access count assert memory.index.get_document(doc_id).access_count == 0 # Search should increment access count memory.search_similar_trades(symbol="AAPL") assert memory.index.get_document(doc_id).access_count == 1 class TestMemoryType: """Test MemoryType enum.""" def test_memory_type_values(self): """Test memory type values.""" assert MemoryType.TRADE.value == "trade_memory" assert MemoryType.MARKET.value == "market_memory" assert MemoryType.DECISION.value == "decision_memory" assert MemoryType.ERROR.value == "error_memory"