refactor: reformat test fixtures in data_juicer_agent_test.py

This commit is contained in:
cmgzn
2025-10-31 15:37:26 +08:00
parent 3816343f2e
commit 4dcf72a683

View File

@@ -6,7 +6,7 @@ from pathlib import Path
root_path = Path(__file__).parent.parent root_path = Path(__file__).parent.parent
sys.path.insert(0, str(root_path)) sys.path.insert(0, str(root_path))
sys.path.insert(0, str(Path(root_path)/"data_juicer_agent")) sys.path.insert(0, str(Path(root_path) / "data_juicer_agent"))
import pytest import pytest
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
@@ -39,44 +39,68 @@ from data_juicer_agent.tools import (
) )
@pytest.fixture
def mock_toolkit():
"""Create a mocked Toolkit instance"""
return Mock(spec=Toolkit)
@pytest.fixture
def mock_model():
"""Create a mocked DashScopeChatModel"""
model = Mock(spec=DashScopeChatModel)
model.call = AsyncMock(
return_value=Msg("assistant", "test response", role="assistant"),
)
return model
@pytest.fixture
def mock_formatter():
"""Create a mocked DashScopeChatFormatter"""
return Mock(spec=DashScopeChatFormatter)
@pytest.fixture
def mock_memory():
"""Create a mocked InMemoryMemory"""
return Mock(spec=InMemoryMemory)
@pytest.fixture
def mock_mcp_client():
"""Create a mocked MCP client"""
mock_client = Mock()
mock_client.name = "DJ_recipe_flow"
mock_client.connect = AsyncMock()
mock_client.close = AsyncMock()
mock_client.get_callable_function = AsyncMock()
mock_client.list_tools = AsyncMock()
return mock_client
@pytest.fixture
def mock_agent(
mock_model,
mock_formatter,
mock_toolkit,
mock_memory,
):
"""Create a mocked ReActAgent instance"""
agent = Mock(spec=ReActAgent)
agent.model = mock_model
agent.formatter = mock_formatter
agent.toolkit = mock_toolkit
agent.memory = mock_memory
agent.__call__ = AsyncMock(
return_value=Msg("assistant", "test response", role="assistant"),
)
return agent
class TestDataJuicerAgent: class TestDataJuicerAgent:
"""Test suite for the data_juicer_agent functionality""" """Test suite for the data_juicer_agent functionality"""
@pytest.fixture
def mock_toolkit(self):
"""Create a mocked Toolkit instance"""
return Mock(spec=Toolkit)
@pytest.fixture
def mock_model(self):
"""Create a mocked DashScopeChatModel"""
model = Mock(spec=DashScopeChatModel)
model.call = AsyncMock(
return_value=Msg("assistant", "test response", role="assistant"),
)
return model
@pytest.fixture
def mock_formatter(self):
"""Create a mocked DashScopeChatFormatter"""
return Mock(spec=DashScopeChatFormatter)
@pytest.fixture
def mock_memory(self):
"""Create a mocked InMemoryMemory"""
return Mock(spec=InMemoryMemory)
@pytest.fixture
def mock_mcp_client(self):
"""Create a mocked MCP client"""
mock_client = Mock()
mock_client.name = "DJ_recipe_flow"
mock_client.connect = AsyncMock()
mock_client.close = AsyncMock()
mock_client.get_callable_function = AsyncMock()
mock_client.list_tools = AsyncMock()
return mock_client
def create_named_mock_agent(self, name, mock_agent, *args, **kwargs): def create_named_mock_agent(self, name, mock_agent, *args, **kwargs):
"""Create a named mock agent for testing""" """Create a named mock agent for testing"""
agent_instance = Mock(spec=ReActAgent) agent_instance = Mock(spec=ReActAgent)
@@ -88,24 +112,8 @@ class TestDataJuicerAgent:
agent_instance.name = name agent_instance.name = name
return agent_instance return agent_instance
@pytest.fixture async def mock_user_func(self, msg=None):
def mock_agent( return Msg("user", "exit", role="user")
self,
mock_model,
mock_formatter,
mock_toolkit,
mock_memory,
):
"""Create a mocked ReActAgent instance"""
agent = Mock(spec=ReActAgent)
agent.model = mock_model
agent.formatter = mock_formatter
agent.toolkit = mock_toolkit
agent.memory = mock_memory
agent.__call__ = AsyncMock(
return_value=Msg("assistant", "test response", role="assistant"),
)
return agent
def test_dj_toolkit_initialization(self): def test_dj_toolkit_initialization(self):
"""Test DJ toolkit initialization and tool registration""" """Test DJ toolkit initialization and tool registration"""
@@ -149,9 +157,9 @@ class TestDataJuicerAgent:
async def test_mcp_tools_list(self, mock_mcp_client): async def test_mcp_tools_list(self, mock_mcp_client):
"""Test MCP tools list contains expected tools and MCP client binding""" """Test MCP tools list contains expected tools and MCP client binding"""
with patch( with patch(
"agentscope.mcp.HttpStatefulClient", "agentscope.mcp.HttpStatefulClient",
return_value=mock_mcp_client, return_value=mock_mcp_client,
) as mock_client_cls: ) as mock_client_cls:
await get_mcp_toolkit() await get_mcp_toolkit()
assert mock_client_cls.assert_called_once assert mock_client_cls.assert_called_once
@@ -174,6 +182,7 @@ class TestDataJuicerAgent:
name="DataJuicer", name="DataJuicer",
sys_prompt="You are {name}, a agent.", sys_prompt="You are {name}, a agent.",
toolkit=mock_toolkit, toolkit=mock_toolkit,
description="test description",
model=mock_model, model=mock_model,
formatter=mock_formatter, formatter=mock_formatter,
memory=mock_memory, memory=mock_memory,
@@ -181,13 +190,12 @@ class TestDataJuicerAgent:
assert agent.name == "DataJuicer" assert agent.name == "DataJuicer"
assert "DataJuicer" in agent.sys_prompt assert "DataJuicer" in agent.sys_prompt
assert "test" in agent.__doc__
assert agent.model == mock_model assert agent.model == mock_model
assert agent.formatter == mock_formatter assert agent.formatter == mock_formatter
assert agent.toolkit == mock_toolkit assert agent.toolkit == mock_toolkit
assert agent.memory == mock_memory assert agent.memory == mock_memory
assert isinstance(agent, ReActAgent)
async def mock_user_func(self, msg=None):
return Msg("user", "exit", role="user")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_main_with_multiple_agents_loading(self, mock_agent, mock_mcp_client): async def test_main_with_multiple_agents_loading(self, mock_agent, mock_mcp_client):
@@ -218,21 +226,6 @@ class TestDataJuicerAgent:
# Validate multiple agents are correctly created (dj, dj_dev, dj_mcp, and router) # Validate multiple agents are correctly created (dj, dj_dev, dj_mcp, and router)
assert mock_create_agent.call_count == 4 assert mock_create_agent.call_count == 4
# Validate router agent is created
create_calls = mock_create_agent.call_args_list
router_agent_created = any(
call[0][0] == "Router"
for call in create_calls # First parameter is name
)
assert router_agent_created, "Router agent should be created"
# Validate dj_mcp agent is created
mcp_agent_created = any(
call[0][0] == "mcp_datajuicer_agent"
for call in create_calls # First parameter is name
)
assert mcp_agent_created, "MCP agent should be created"
if __name__ == "__main__": if __name__ == "__main__":
pytest.main(["-v", __file__]) pytest.main(["-v", __file__])