refactor: reformat test fixtures in data_juicer_agent_test.py
This commit is contained in:
@@ -6,7 +6,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
root_path = Path(__file__).parent.parent
|
root_path = Path(__file__).parent.parent
|
||||||
sys.path.insert(0, str(root_path))
|
sys.path.insert(0, str(root_path))
|
||||||
sys.path.insert(0, str(Path(root_path)/"data_juicer_agent"))
|
sys.path.insert(0, str(Path(root_path) / "data_juicer_agent"))
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import AsyncMock, Mock, patch
|
from unittest.mock import AsyncMock, Mock, patch
|
||||||
@@ -39,16 +39,14 @@ from data_juicer_agent.tools import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestDataJuicerAgent:
|
@pytest.fixture
|
||||||
"""Test suite for the data_juicer_agent functionality"""
|
def mock_toolkit():
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_toolkit(self):
|
|
||||||
"""Create a mocked Toolkit instance"""
|
"""Create a mocked Toolkit instance"""
|
||||||
return Mock(spec=Toolkit)
|
return Mock(spec=Toolkit)
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_model(self):
|
@pytest.fixture
|
||||||
|
def mock_model():
|
||||||
"""Create a mocked DashScopeChatModel"""
|
"""Create a mocked DashScopeChatModel"""
|
||||||
model = Mock(spec=DashScopeChatModel)
|
model = Mock(spec=DashScopeChatModel)
|
||||||
model.call = AsyncMock(
|
model.call = AsyncMock(
|
||||||
@@ -56,18 +54,21 @@ class TestDataJuicerAgent:
|
|||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_formatter(self):
|
@pytest.fixture
|
||||||
|
def mock_formatter():
|
||||||
"""Create a mocked DashScopeChatFormatter"""
|
"""Create a mocked DashScopeChatFormatter"""
|
||||||
return Mock(spec=DashScopeChatFormatter)
|
return Mock(spec=DashScopeChatFormatter)
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_memory(self):
|
@pytest.fixture
|
||||||
|
def mock_memory():
|
||||||
"""Create a mocked InMemoryMemory"""
|
"""Create a mocked InMemoryMemory"""
|
||||||
return Mock(spec=InMemoryMemory)
|
return Mock(spec=InMemoryMemory)
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_mcp_client(self):
|
@pytest.fixture
|
||||||
|
def mock_mcp_client():
|
||||||
"""Create a mocked MCP client"""
|
"""Create a mocked MCP client"""
|
||||||
mock_client = Mock()
|
mock_client = Mock()
|
||||||
mock_client.name = "DJ_recipe_flow"
|
mock_client.name = "DJ_recipe_flow"
|
||||||
@@ -77,6 +78,29 @@ class TestDataJuicerAgent:
|
|||||||
mock_client.list_tools = AsyncMock()
|
mock_client.list_tools = AsyncMock()
|
||||||
return mock_client
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_agent(
|
||||||
|
mock_model,
|
||||||
|
mock_formatter,
|
||||||
|
mock_toolkit,
|
||||||
|
mock_memory,
|
||||||
|
):
|
||||||
|
"""Create a mocked ReActAgent instance"""
|
||||||
|
agent = Mock(spec=ReActAgent)
|
||||||
|
agent.model = mock_model
|
||||||
|
agent.formatter = mock_formatter
|
||||||
|
agent.toolkit = mock_toolkit
|
||||||
|
agent.memory = mock_memory
|
||||||
|
agent.__call__ = AsyncMock(
|
||||||
|
return_value=Msg("assistant", "test response", role="assistant"),
|
||||||
|
)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataJuicerAgent:
|
||||||
|
"""Test suite for the data_juicer_agent functionality"""
|
||||||
|
|
||||||
def create_named_mock_agent(self, name, mock_agent, *args, **kwargs):
|
def create_named_mock_agent(self, name, mock_agent, *args, **kwargs):
|
||||||
"""Create a named mock agent for testing"""
|
"""Create a named mock agent for testing"""
|
||||||
agent_instance = Mock(spec=ReActAgent)
|
agent_instance = Mock(spec=ReActAgent)
|
||||||
@@ -88,24 +112,8 @@ class TestDataJuicerAgent:
|
|||||||
agent_instance.name = name
|
agent_instance.name = name
|
||||||
return agent_instance
|
return agent_instance
|
||||||
|
|
||||||
@pytest.fixture
|
async def mock_user_func(self, msg=None):
|
||||||
def mock_agent(
|
return Msg("user", "exit", role="user")
|
||||||
self,
|
|
||||||
mock_model,
|
|
||||||
mock_formatter,
|
|
||||||
mock_toolkit,
|
|
||||||
mock_memory,
|
|
||||||
):
|
|
||||||
"""Create a mocked ReActAgent instance"""
|
|
||||||
agent = Mock(spec=ReActAgent)
|
|
||||||
agent.model = mock_model
|
|
||||||
agent.formatter = mock_formatter
|
|
||||||
agent.toolkit = mock_toolkit
|
|
||||||
agent.memory = mock_memory
|
|
||||||
agent.__call__ = AsyncMock(
|
|
||||||
return_value=Msg("assistant", "test response", role="assistant"),
|
|
||||||
)
|
|
||||||
return agent
|
|
||||||
|
|
||||||
def test_dj_toolkit_initialization(self):
|
def test_dj_toolkit_initialization(self):
|
||||||
"""Test DJ toolkit initialization and tool registration"""
|
"""Test DJ toolkit initialization and tool registration"""
|
||||||
@@ -174,6 +182,7 @@ class TestDataJuicerAgent:
|
|||||||
name="DataJuicer",
|
name="DataJuicer",
|
||||||
sys_prompt="You are {name}, a agent.",
|
sys_prompt="You are {name}, a agent.",
|
||||||
toolkit=mock_toolkit,
|
toolkit=mock_toolkit,
|
||||||
|
description="test description",
|
||||||
model=mock_model,
|
model=mock_model,
|
||||||
formatter=mock_formatter,
|
formatter=mock_formatter,
|
||||||
memory=mock_memory,
|
memory=mock_memory,
|
||||||
@@ -181,13 +190,12 @@ class TestDataJuicerAgent:
|
|||||||
|
|
||||||
assert agent.name == "DataJuicer"
|
assert agent.name == "DataJuicer"
|
||||||
assert "DataJuicer" in agent.sys_prompt
|
assert "DataJuicer" in agent.sys_prompt
|
||||||
|
assert "test" in agent.__doc__
|
||||||
assert agent.model == mock_model
|
assert agent.model == mock_model
|
||||||
assert agent.formatter == mock_formatter
|
assert agent.formatter == mock_formatter
|
||||||
assert agent.toolkit == mock_toolkit
|
assert agent.toolkit == mock_toolkit
|
||||||
assert agent.memory == mock_memory
|
assert agent.memory == mock_memory
|
||||||
|
assert isinstance(agent, ReActAgent)
|
||||||
async def mock_user_func(self, msg=None):
|
|
||||||
return Msg("user", "exit", role="user")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_main_with_multiple_agents_loading(self, mock_agent, mock_mcp_client):
|
async def test_main_with_multiple_agents_loading(self, mock_agent, mock_mcp_client):
|
||||||
@@ -218,21 +226,6 @@ class TestDataJuicerAgent:
|
|||||||
# Validate multiple agents are correctly created (dj, dj_dev, dj_mcp, and router)
|
# Validate multiple agents are correctly created (dj, dj_dev, dj_mcp, and router)
|
||||||
assert mock_create_agent.call_count == 4
|
assert mock_create_agent.call_count == 4
|
||||||
|
|
||||||
# Validate router agent is created
|
|
||||||
create_calls = mock_create_agent.call_args_list
|
|
||||||
router_agent_created = any(
|
|
||||||
call[0][0] == "Router"
|
|
||||||
for call in create_calls # First parameter is name
|
|
||||||
)
|
|
||||||
assert router_agent_created, "Router agent should be created"
|
|
||||||
|
|
||||||
# Validate dj_mcp agent is created
|
|
||||||
mcp_agent_created = any(
|
|
||||||
call[0][0] == "mcp_datajuicer_agent"
|
|
||||||
for call in create_calls # First parameter is name
|
|
||||||
)
|
|
||||||
assert mcp_agent_created, "MCP agent should be created"
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main(["-v", __file__])
|
pytest.main(["-v", __file__])
|
||||||
|
|||||||
Reference in New Issue
Block a user