"""Unit tests for the agent learning memory system."""
|
|
|
|
import tempfile
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from openclaw.memory import (
|
|
BM25Index,
|
|
DecisionMemory,
|
|
ErrorMemory,
|
|
LearningMemory,
|
|
MarketMemory,
|
|
MemoryDocument,
|
|
MemoryType,
|
|
TradeMemory,
|
|
)
|
|
|
|
|
|
class TestMemoryDocument:
    """Checks construction and dict round-tripping of ``MemoryDocument``."""

    def test_document_creation(self):
        """Constructor arguments are stored and the remaining fields get defaults."""
        created = MemoryDocument(
            doc_id="test_001",
            content="test content",
            memory_type="trade_memory",
            importance=0.8,
        )

        # Fields passed explicitly.
        assert created.doc_id == "test_001"
        assert created.content == "test content"
        assert created.memory_type == "trade_memory"
        assert created.importance == 0.8
        # Fields that were not passed.
        assert created.access_count == 0
        assert created.last_accessed is None
        assert isinstance(created.timestamp, datetime)

    def test_document_to_dict(self):
        """``to_dict`` exposes every supplied field under its own key."""
        supplied = {
            "doc_id": "test_002",
            "content": "test content",
            "memory_type": "market_memory",
            "importance": 0.5,
            "metadata": {"key": "value"},
        }

        serialized = MemoryDocument(**supplied).to_dict()

        for key in ("doc_id", "content", "memory_type", "importance", "metadata"):
            assert serialized[key] == supplied[key]

    def test_document_from_dict(self):
        """``from_dict`` rebuilds a document from a plain dictionary."""
        payload = {
            "doc_id": "test_003",
            "content": "test content",
            "memory_type": "error_memory",
            "timestamp": datetime.now().isoformat(),
            "metadata": {"error": "test"},
            "importance": 0.9,
            "access_count": 5,
            "last_accessed": None,
        }

        rebuilt = MemoryDocument.from_dict(payload)

        checked_fields = ("doc_id", "content", "memory_type", "importance", "access_count")
        for field_name in checked_fields:
            assert getattr(rebuilt, field_name) == payload[field_name]

    def test_document_serialization_with_last_accessed(self):
        """``last_accessed`` is written as an ISO string and restored as a datetime."""
        accessed_at = datetime.now()
        original = MemoryDocument(
            doc_id="test_004",
            content="test",
            memory_type="trade_memory",
            last_accessed=accessed_at,
        )

        serialized = original.to_dict()
        assert serialized["last_accessed"] == accessed_at.isoformat()

        round_tripped = MemoryDocument.from_dict(serialized)
        assert round_tripped.last_accessed == accessed_at
|
|
|
|
|
|
class TestBM25Index:
    """Tests for ``BM25Index``: add/remove/update, search, stats and persistence."""

    def test_index_initialization(self):
        """A new index keeps its ``k1``/``b`` parameters and starts empty."""
        index = BM25Index(k1=1.2, b=0.75)

        assert index.k1 == 1.2
        assert index.b == 0.75
        assert index.num_docs == 0
        assert index.avg_doc_length == 0.0

    def test_add_document(self):
        """Adding one document registers it and yields a positive average length."""
        index = BM25Index()
        doc = MemoryDocument(
            doc_id="doc_001",
            content="buy apple stock momentum strategy",
            memory_type="trade_memory",
        )

        index.add_document(doc)

        assert index.num_docs == 1
        assert "doc_001" in index.documents
        assert index.avg_doc_length > 0

    def test_add_multiple_documents(self):
        """Adding five documents is reflected in ``num_docs``."""
        index = BM25Index()

        docs = [
            MemoryDocument(f"doc_{i}", f"content {i} test", "trade_memory")
            for i in range(5)
        ]

        for doc in docs:
            index.add_document(doc)

        assert index.num_docs == 5
        assert index.avg_doc_length > 0

    def test_remove_document(self):
        """Removing an indexed document returns True and drops it from the index."""
        index = BM25Index()
        doc = MemoryDocument("doc_001", "test content", "trade_memory")
        index.add_document(doc)

        result = index.remove_document("doc_001")

        assert result is True
        assert index.num_docs == 0
        assert "doc_001" not in index.documents

    def test_remove_nonexistent_document(self):
        """Removing an unknown id returns False."""
        index = BM25Index()
        result = index.remove_document("nonexistent")

        assert result is False

    def test_search_basic(self):
        """The document sharing the most query terms is ranked first."""
        index = BM25Index()

        # Add documents
        index.add_document(
            MemoryDocument("doc_1", "buy apple stock with momentum", "trade_memory")
        )
        index.add_document(
            MemoryDocument("doc_2", "sell microsoft stock breakout", "trade_memory")
        )
        index.add_document(
            MemoryDocument("doc_3", "market analysis volatile regime", "market_memory")
        )

        # Search
        results = index.search("buy apple stock", top_k=2)

        assert len(results) > 0
        # Each result is a (document, score) pair; the first is the most relevant.
        assert results[0][0].doc_id == "doc_1"

    def test_search_with_memory_type_filter(self):
        """``memory_type`` restricts results to documents of that type."""
        index = BM25Index()

        index.add_document(
            MemoryDocument("doc_1", "buy apple stock momentum", "trade_memory")
        )
        index.add_document(
            MemoryDocument("doc_2", "buy microsoft stock breakout", "trade_memory")
        )
        index.add_document(
            MemoryDocument("doc_3", "buy signal market analysis", "market_memory")
        )

        # Search only trade_memory; all three documents contain "buy".
        results = index.search("buy", memory_type="trade_memory", top_k=5)

        assert len(results) == 2
        for doc, _ in results:
            assert doc.memory_type == "trade_memory"

    def test_search_empty_query(self):
        """An empty query returns an empty list."""
        index = BM25Index()
        index.add_document(MemoryDocument("doc_1", "test content", "trade_memory"))

        results = index.search("")
        assert results == []

    def test_get_document(self):
        """``get_document`` returns the document by id, or None when unknown."""
        index = BM25Index()
        doc = MemoryDocument("doc_001", "test content", "trade_memory")
        index.add_document(doc)

        retrieved = index.get_document("doc_001")
        assert retrieved is not None
        assert retrieved.doc_id == "doc_001"

        not_found = index.get_document("nonexistent")
        assert not_found is None

    def test_update_document(self):
        """Updating a field on an indexed document returns True and persists it."""
        index = BM25Index()
        doc = MemoryDocument("doc_001", "test content", "trade_memory", importance=0.5)
        index.add_document(doc)

        result = index.update_document("doc_001", importance=0.9)

        assert result is True
        assert index.get_document("doc_001").importance == 0.9

    def test_update_nonexistent_document(self):
        """Updating an unknown id returns False."""
        index = BM25Index()
        result = index.update_document("nonexistent", importance=0.9)
        assert result is False

    def test_update_document_content(self):
        """After a content update the document is findable by its new terms."""
        index = BM25Index()
        doc = MemoryDocument("doc_001", "original content", "trade_memory")
        index.add_document(doc)

        result = index.update_document("doc_001", content="updated content")

        assert result is True
        assert index.get_document("doc_001").content == "updated content"
        # Document should still be searchable
        results = index.search("updated")
        assert len(results) > 0

    def test_get_stats(self):
        """``get_stats`` reports the total and per-type document counts."""
        index = BM25Index()

        for i in range(3):
            index.add_document(
                MemoryDocument(f"doc_{i}", f"content {i}", "trade_memory")
            )

        index.add_document(MemoryDocument("doc_market", "market data", "market_memory"))

        stats = index.get_stats()

        assert stats["num_documents"] == 4
        assert stats["memory_types"]["trade_memory"] == 3
        assert stats["memory_types"]["market_memory"] == 1
        assert stats["avg_doc_length"] > 0

    def test_save_and_load(self):
        """An index saved to disk can be loaded into a fresh instance."""
        with tempfile.TemporaryDirectory() as tmpdir:
            index_path = Path(tmpdir) / "index.pkl"

            # Create and save index
            index = BM25Index()
            index.add_document(
                MemoryDocument("doc_1", "test content", "trade_memory", importance=0.8)
            )
            index.save(index_path)

            # Load into new index
            new_index = BM25Index()
            result = new_index.load(index_path)

            assert result is True
            assert new_index.num_docs == 1
            assert "doc_1" in new_index.documents
            assert new_index.get_document("doc_1").importance == 0.8

    def test_load_nonexistent_file(self):
        """Loading from a path that does not exist returns False."""
        with tempfile.TemporaryDirectory() as tmpdir:
            # A freshly created temporary directory is empty, so this file is
            # guaranteed to be missing regardless of the host filesystem layout.
            missing_path = Path(tmpdir) / "index.pkl"

            index = BM25Index()
            result = index.load(missing_path)

        assert result is False
|
|
|
|
|
|
class TestTradeMemory:
    """Checks the ``TradeMemory`` record: fields, text form and dict form."""

    def test_trade_memory_creation(self):
        """Constructor arguments are readable back as attributes."""
        trade = TradeMemory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
        )

        observed = (trade.symbol, trade.action, trade.quantity, trade.pnl)
        for actual, wanted in zip(observed, ("AAPL", "buy", 100, 500.0)):
            assert actual == wanted

    def test_trade_memory_to_text(self):
        """The text rendering mentions symbol, action, strategy and outcome."""
        trade = TradeMemory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
            strategy="momentum",
            outcome="profitable breakout",
        )

        rendered = trade.to_text()

        for fragment in ("AAPL", "buy", "momentum", "profitable breakout"):
            assert fragment in rendered

    def test_trade_memory_to_dict(self):
        """The dict rendering carries symbol, action and (negative) P&L."""
        trade = TradeMemory(
            symbol="MSFT",
            action="sell",
            quantity=50,
            price=300.0,
            pnl=-200.0,
        )

        as_dict = trade.to_dict()

        expected = {"symbol": "MSFT", "action": "sell", "pnl": -200.0}
        for key, wanted in expected.items():
            assert as_dict[key] == wanted
|
|
|
|
|
|
class TestMarketMemory:
    """Checks the ``MarketMemory`` record: fields and text form."""

    def test_market_memory_creation(self):
        """Symbol, regime and sentiment are readable back as attributes."""
        snapshot = MarketMemory(
            symbol="AAPL",
            market_regime="trending",
            sentiment="bullish",
        )

        observed = (snapshot.symbol, snapshot.market_regime, snapshot.sentiment)
        for actual, wanted in zip(observed, ("AAPL", "trending", "bullish")):
            assert actual == wanted

    def test_market_memory_to_text(self):
        """The text rendering mentions symbol, regime, sentiment and an event."""
        snapshot = MarketMemory(
            symbol="AAPL",
            market_regime="volatile",
            sentiment="extreme_fear",
            indicators={"rsi": 70.5, "macd": 1.2},
            events=["earnings", "fed_meeting"],
        )

        rendered = snapshot.to_text()

        for fragment in ("AAPL", "volatile", "extreme_fear", "earnings"):
            assert fragment in rendered
|
|
|
|
|
|
class TestDecisionMemory:
    """Checks the ``DecisionMemory`` record: fields and text form."""

    def test_decision_memory_creation(self):
        """Decision type, context and confidence are readable back as attributes."""
        decision = DecisionMemory(
            decision_type="entry",
            context="breakout detected",
            confidence=0.8,
        )

        observed = (decision.decision_type, decision.context, decision.confidence)
        for actual, wanted in zip(observed, ("entry", "breakout detected", 0.8)):
            assert actual == wanted

    def test_decision_memory_to_text(self):
        """The text rendering mentions type, context, reasoning and a factor."""
        decision = DecisionMemory(
            decision_type="exit",
            context="profit target reached",
            reasoning="technical resistance",
            expected_outcome="profit",
            actual_outcome="profit",
            confidence=0.9,
            factors=["rsi_overbought", "resistance_level"],
        )

        rendered = decision.to_text()

        expected_fragments = (
            "exit",
            "profit target reached",
            "technical resistance",
            "rsi_overbought",
        )
        for fragment in expected_fragments:
            assert fragment in rendered
|
|
|
|
|
|
class TestErrorMemory:
    """Checks the ``ErrorMemory`` record: fields and text form."""

    def test_error_memory_creation(self):
        """Error type, message and severity are readable back as attributes."""
        failure = ErrorMemory(
            error_type="connection_error",
            error_message="Failed to connect to API",
            severity="high",
        )

        observed = (failure.error_type, failure.error_message, failure.severity)
        wanted_values = ("connection_error", "Failed to connect to API", "high")
        for actual, wanted in zip(observed, wanted_values):
            assert actual == wanted

    def test_error_memory_to_text(self):
        """The text rendering mentions type, message, severity and recovery action."""
        failure = ErrorMemory(
            error_type="api_error",
            error_message="Rate limit exceeded",
            context="order placement",
            recovery_action="wait and retry",
            severity="critical",
            preventability="yes",
        )

        rendered = failure.to_text()

        for fragment in ("api_error", "Rate limit exceeded", "critical", "wait and retry"):
            assert fragment in rendered
|
|
|
|
|
|
class TestLearningMemory:
    """End-to-end checks of ``LearningMemory``: adding, searching, scoring, pruning."""

    @staticmethod
    def _record_aapl_buy(store):
        """Add the baseline AAPL buy (100 @ 150.0, pnl 500.0) and return its id."""
        return store.add_trade_memory(
            symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=500.0
        )

    def test_learning_memory_initialization(self):
        """A default instance keeps its agent id, a 10000 cap and decay enabled."""
        store = LearningMemory(agent_id="test_agent")

        assert store.agent_id == "test_agent"
        assert store.max_memories == 10000
        assert store.decay_enabled is True

    def test_add_trade_memory(self):
        """Adding a trade returns an id and grows the index by one."""
        store = LearningMemory(agent_id="test_agent")

        trade_id = store.add_trade_memory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
            strategy="momentum",
            outcome="profitable",
        )

        assert trade_id is not None
        assert store.index.num_docs == 1

    def test_add_market_memory(self):
        """Adding a market snapshot returns an id and grows the index by one."""
        store = LearningMemory(agent_id="test_agent")

        snapshot_id = store.add_market_memory(
            symbol="AAPL",
            market_regime="trending",
            sentiment="bullish",
            indicators={"rsi": 60.0},
        )

        assert snapshot_id is not None
        assert store.index.num_docs == 1

    def test_add_decision_memory(self):
        """Adding a decision returns an id and grows the index by one."""
        store = LearningMemory(agent_id="test_agent")

        decision_id = store.add_decision_memory(
            decision_type="entry",
            context="breakout pattern",
            reasoning="volume surge",
            confidence=0.8,
        )

        assert decision_id is not None
        assert store.index.num_docs == 1

    def test_add_error_memory(self):
        """Adding an error returns an id and grows the index by one."""
        store = LearningMemory(agent_id="test_agent")

        error_id = store.add_error_memory(
            error_type="timeout",
            error_message="Connection timed out",
            severity="high",
        )

        assert error_id is not None
        assert store.index.num_docs == 1

    def test_search_similar_trades(self):
        """Searching by symbol ranks a trade for that symbol first."""
        store = LearningMemory(agent_id="test_agent")

        trades = (
            {
                "symbol": "AAPL",
                "action": "buy",
                "quantity": 100,
                "price": 150.0,
                "pnl": 500.0,
                "strategy": "momentum",
                "outcome": "profitable",
            },
            {
                "symbol": "MSFT",
                "action": "buy",
                "quantity": 50,
                "price": 300.0,
                "pnl": 200.0,
                "strategy": "breakout",
                "outcome": "profitable",
            },
        )
        for trade in trades:
            store.add_trade_memory(**trade)

        matches = store.search_similar_trades(symbol="AAPL", top_k=2)

        assert len(matches) > 0
        assert matches[0]["data"]["symbol"] == "AAPL"

    def test_search_similar_trades_with_min_pnl(self):
        """``min_pnl`` drops trades whose P&L is below the threshold."""
        store = LearningMemory(agent_id="test_agent")

        # Two otherwise identical trades, one above and one below the filter.
        for pnl in (500.0, 100.0):
            store.add_trade_memory(
                symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=pnl
            )

        matches = store.search_similar_trades(symbol="AAPL", min_pnl=200.0)

        assert len(matches) == 1
        assert matches[0]["data"]["pnl"] == 500.0

    def test_search_similar_market_states(self):
        """Searching by regime returns at least one stored market snapshot."""
        store = LearningMemory(agent_id="test_agent")

        snapshots = (
            ("AAPL", "volatile", "extreme_fear"),
            ("MSFT", "trending", "neutral"),
        )
        for symbol, regime, sentiment in snapshots:
            store.add_market_memory(
                symbol=symbol,
                market_regime=regime,
                sentiment=sentiment,
            )

        matches = store.search_similar_market_states(regime="volatile")

        assert len(matches) > 0

    def test_get_decision_suggestions(self):
        """A stored entry decision is suggested for a similar entry context."""
        store = LearningMemory(agent_id="test_agent")

        store.add_decision_memory(
            decision_type="entry",
            context="breakout pattern",
            reasoning="volume surge",
            expected_outcome="profit",
            actual_outcome="profit",
            confidence=0.8,
        )

        suggested = store.get_decision_suggestions(
            context="breakout detected",
            decision_type="entry",
        )

        assert len(suggested) > 0
        assert suggested[0]["decision_type"] == "entry"

    def test_get_error_lessons(self):
        """A stored error is returned when asking for lessons of its type."""
        store = LearningMemory(agent_id="test_agent")

        store.add_error_memory(
            error_type="api_error",
            error_message="Rate limit exceeded",
            recovery_action="implement backoff",
        )

        lessons = store.get_error_lessons(error_type="api_error")

        assert len(lessons) > 0

    def test_update_memory_importance(self):
        """Setting a new importance returns True and is visible on the document."""
        store = LearningMemory(agent_id="test_agent")
        trade_id = self._record_aapl_buy(store)

        updated = store.update_memory_importance(trade_id, 0.9)

        assert updated is True
        stored_doc = store.index.get_document(trade_id)
        assert stored_doc.importance == 0.9

    def test_mark_important(self):
        """Marking a memory important returns True and sets importance to 1.0."""
        store = LearningMemory(agent_id="test_agent")
        trade_id = self._record_aapl_buy(store)

        marked = store.mark_important(trade_id)

        assert marked is True
        assert store.index.get_document(trade_id).importance == 1.0

    def test_delete_memory(self):
        """Deleting the only memory returns True and empties the index."""
        store = LearningMemory(agent_id="test_agent")
        trade_id = self._record_aapl_buy(store)

        deleted = store.delete_memory(trade_id)

        assert deleted is True
        assert store.index.num_docs == 0

    def test_clear_all_memories(self):
        """Clearing removes every stored memory."""
        store = LearningMemory(agent_id="test_agent")

        for n in range(5):
            store.add_trade_memory(
                symbol=f"SYM{n}",
                action="buy",
                quantity=100,
                price=100.0,
                pnl=100.0,
            )

        store.clear_all_memories()

        assert store.index.num_docs == 0

    def test_get_memory_stats(self):
        """Stats report the agent id plus total and per-type memory counts."""
        store = LearningMemory(agent_id="test_agent")

        self._record_aapl_buy(store)
        store.add_market_memory(symbol="AAPL", market_regime="trending")

        summary = store.get_memory_stats()

        assert summary["agent_id"] == "test_agent"
        assert summary["total_memories"] == 2
        per_type = summary["memory_types"]
        assert per_type["trade_memory"] == 1
        assert per_type["market_memory"] == 1

    def test_save_and_load(self):
        """A second instance on the same ``storage_dir`` sees the saved trade."""
        with tempfile.TemporaryDirectory() as scratch:
            storage_dir = Path(scratch) / "memory"

            # First instance: store one trade and persist it.
            writer = LearningMemory(agent_id="test_agent", storage_dir=storage_dir)
            writer.add_trade_memory(
                symbol="AAPL",
                action="buy",
                quantity=100,
                price=150.0,
                pnl=500.0,
                strategy="momentum",
            )
            writer.save()

            # Second instance: constructed against the same directory.
            reader = LearningMemory(agent_id="test_agent", storage_dir=storage_dir)

            assert reader.index.num_docs == 1
            restored_docs = list(reader.index.documents.values())
            first_doc = restored_docs[0]
            assert first_doc.metadata["symbol"] == "AAPL"

    def test_memory_importance_calculation_trade(self):
        """A trade with a larger P&L is scored as more important."""
        store = LearningMemory(agent_id="test_agent")

        # Same trade shape, only symbol and P&L differ.
        big_win_id = store.add_trade_memory(
            symbol="AAPL", action="buy", quantity=100, price=150.0, pnl=1500.0
        )
        small_win_id = store.add_trade_memory(
            symbol="MSFT", action="buy", quantity=100, price=150.0, pnl=50.0
        )

        big_win = store.index.get_document(big_win_id)
        small_win = store.index.get_document(small_win_id)

        assert big_win.importance > small_win.importance

    def test_memory_importance_calculation_market(self):
        """A volatile / extreme-fear snapshot outranks a trending / neutral one."""
        store = LearningMemory(agent_id="test_agent")

        turbulent_id = store.add_market_memory(
            symbol="AAPL", market_regime="volatile", sentiment="extreme_fear"
        )
        calm_id = store.add_market_memory(
            symbol="MSFT", market_regime="trending", sentiment="neutral"
        )

        turbulent = store.index.get_document(turbulent_id)
        calm = store.index.get_document(calm_id)

        assert turbulent.importance > calm.importance

    def test_memory_importance_calculation_error(self):
        """A critical error is scored as more important than a low-severity one."""
        store = LearningMemory(agent_id="test_agent")

        critical_id = store.add_error_memory(
            error_type="api_error",
            error_message="test",
            severity="critical",
        )
        minor_id = store.add_error_memory(
            error_type="minor_error",
            error_message="test",
            severity="low",
        )

        critical = store.index.get_document(critical_id)
        minor = store.index.get_document(minor_id)

        assert critical.importance > minor.importance

    def test_memory_decay(self):
        """An old, low-importance memory is dropped when another memory is added."""
        store = LearningMemory(agent_id="test_agent", decay_enabled=True, decay_days=30)
        stale_id = self._record_aapl_buy(store)

        # Age the document past decay_days * 3 and push its importance below
        # the threshold so it qualifies for removal.
        stale_doc = store.index.documents[stale_id]
        stale_doc.timestamp = datetime.now() - timedelta(days=100)
        stale_doc.importance = 0.05

        # Adding another memory triggers decay.
        store.add_trade_memory(
            symbol="MSFT", action="buy", quantity=100, price=150.0, pnl=500.0
        )

        # 100 days > 30*3=90 days AND importance < 0.1, so it should be gone.
        assert stale_id not in store.index.documents

    def test_memory_limit_enforcement(self):
        """The index never holds more than ``max_memories`` documents."""
        store = LearningMemory(agent_id="test_agent", max_memories=5)

        # Add twice as many memories as the cap allows.
        for n in range(10):
            store.add_trade_memory(
                symbol=f"SYM{n}",
                action="buy",
                quantity=100,
                price=100.0,
                pnl=100.0,
            )

        assert store.index.num_docs <= 5

    def test_trade_memory_with_market_conditions(self):
        """``market_conditions`` passed with a trade end up in the doc metadata."""
        store = LearningMemory(agent_id="test_agent")

        trade_id = store.add_trade_memory(
            symbol="AAPL",
            action="buy",
            quantity=100,
            price=150.0,
            pnl=500.0,
            market_conditions={"trend": "up", "volatility": "high"},
        )

        stored_doc = store.index.get_document(trade_id)
        assert stored_doc.metadata["market_conditions"]["trend"] == "up"

    def test_access_count_tracking(self):
        """A search that hits a document bumps its access count from 0 to 1."""
        store = LearningMemory(agent_id="test_agent")
        trade_id = self._record_aapl_buy(store)

        # Nothing has searched for the document yet.
        assert store.index.get_document(trade_id).access_count == 0

        # Search should increment access count.
        store.search_similar_trades(symbol="AAPL")

        assert store.index.get_document(trade_id).access_count == 1
|
|
|
|
|
|
class TestMemoryType:
    """Checks the string values behind the ``MemoryType`` enum members."""

    def test_memory_type_values(self):
        """Each member maps to its ``*_memory`` string value."""
        assert MemoryType.TRADE.value == "trade_memory"
        remaining = (
            (MemoryType.MARKET, "market_memory"),
            (MemoryType.DECISION, "decision_memory"),
            (MemoryType.ERROR, "error_memory"),
        )
        for member, wanted in remaining:
            assert member.value == wanted
|