diff --git a/tests/data_juicer_agent_test.py b/tests/data_juicer_agent_test.py index 2263028..f0e7911 100644 --- a/tests/data_juicer_agent_test.py +++ b/tests/data_juicer_agent_test.py @@ -6,7 +6,7 @@ from pathlib import Path root_path = Path(__file__).parent.parent sys.path.insert(0, str(root_path)) -sys.path.insert(0, str(Path(root_path)/"data_juicer_agent")) +sys.path.insert(0, str(Path(root_path) / "data_juicer_agent")) import pytest from unittest.mock import AsyncMock, Mock, patch @@ -39,44 +39,68 @@ from data_juicer_agent.tools import ( ) +@pytest.fixture +def mock_toolkit(): + """Create a mocked Toolkit instance""" + return Mock(spec=Toolkit) + + +@pytest.fixture +def mock_model(): + """Create a mocked DashScopeChatModel""" + model = Mock(spec=DashScopeChatModel) + model.call = AsyncMock( + return_value=Msg("assistant", "test response", role="assistant"), + ) + return model + + +@pytest.fixture +def mock_formatter(): + """Create a mocked DashScopeChatFormatter""" + return Mock(spec=DashScopeChatFormatter) + + +@pytest.fixture +def mock_memory(): + """Create a mocked InMemoryMemory""" + return Mock(spec=InMemoryMemory) + + +@pytest.fixture +def mock_mcp_client(): + """Create a mocked MCP client""" + mock_client = Mock() + mock_client.name = "DJ_recipe_flow" + mock_client.connect = AsyncMock() + mock_client.close = AsyncMock() + mock_client.get_callable_function = AsyncMock() + mock_client.list_tools = AsyncMock() + return mock_client + + +@pytest.fixture +def mock_agent( + mock_model, + mock_formatter, + mock_toolkit, + mock_memory, +): + """Create a mocked ReActAgent instance""" + agent = Mock(spec=ReActAgent) + agent.model = mock_model + agent.formatter = mock_formatter + agent.toolkit = mock_toolkit + agent.memory = mock_memory + agent.__call__ = AsyncMock( + return_value=Msg("assistant", "test response", role="assistant"), + ) + return agent + + class TestDataJuicerAgent: """Test suite for the data_juicer_agent functionality""" - @pytest.fixture - def mock_toolkit(self): - """Create a mocked Toolkit instance""" - return Mock(spec=Toolkit) - - @pytest.fixture - def mock_model(self): - """Create a mocked DashScopeChatModel""" - model = Mock(spec=DashScopeChatModel) - model.call = AsyncMock( - return_value=Msg("assistant", "test response", role="assistant"), - ) - return model - - @pytest.fixture - def mock_formatter(self): - """Create a mocked DashScopeChatFormatter""" - return Mock(spec=DashScopeChatFormatter) - - @pytest.fixture - def mock_memory(self): - """Create a mocked InMemoryMemory""" - return Mock(spec=InMemoryMemory) - - @pytest.fixture - def mock_mcp_client(self): - """Create a mocked MCP client""" - mock_client = Mock() - mock_client.name = "DJ_recipe_flow" - mock_client.connect = AsyncMock() - mock_client.close = AsyncMock() - mock_client.get_callable_function = AsyncMock() - mock_client.list_tools = AsyncMock() - return mock_client - def create_named_mock_agent(self, name, mock_agent, *args, **kwargs): """Create a named mock agent for testing""" agent_instance = Mock(spec=ReActAgent) @@ -88,24 +112,8 @@ class TestDataJuicerAgent: agent_instance.name = name return agent_instance - @pytest.fixture - def mock_agent( - self, - mock_model, - mock_formatter, - mock_toolkit, - mock_memory, - ): - """Create a mocked ReActAgent instance""" - agent = Mock(spec=ReActAgent) - agent.model = mock_model - agent.formatter = mock_formatter - agent.toolkit = mock_toolkit - agent.memory = mock_memory - agent.__call__ = AsyncMock( - return_value=Msg("assistant", "test response", role="assistant"), - ) - return agent + async def mock_user_func(self, msg=None): + return Msg("user", "exit", role="user") def test_dj_toolkit_initialization(self): """Test DJ toolkit initialization and tool registration""" @@ -149,12 +157,12 @@ class TestDataJuicerAgent: async def test_mcp_tools_list(self, mock_mcp_client): """Test MCP tools list contains expected tools and MCP client binding""" with patch( - "agentscope.mcp.HttpStatefulClient", - return_value=mock_mcp_client, - ) as mock_client_cls: + "agentscope.mcp.HttpStatefulClient", + return_value=mock_mcp_client, + ) as mock_client_cls: await get_mcp_toolkit() assert mock_client_cls.assert_called_once - + expected_tools = [view_text_file, write_text_file] assert len(mcp_tools) == len(expected_tools) for tool in expected_tools: @@ -174,6 +182,7 @@ class TestDataJuicerAgent: name="DataJuicer", sys_prompt="You are {name}, a agent.", toolkit=mock_toolkit, + description="test description", model=mock_model, formatter=mock_formatter, memory=mock_memory, @@ -181,13 +190,12 @@ class TestDataJuicerAgent: assert agent.name == "DataJuicer" assert "DataJuicer" in agent.sys_prompt + assert "test" in agent.__doc__ assert agent.model == mock_model assert agent.formatter == mock_formatter assert agent.toolkit == mock_toolkit assert agent.memory == mock_memory - - async def mock_user_func(self, msg=None): - return Msg("user", "exit", role="user") + assert isinstance(agent, ReActAgent) @pytest.mark.asyncio async def test_main_with_multiple_agents_loading(self, mock_agent, mock_mcp_client): @@ -218,21 +226,6 @@ class TestDataJuicerAgent: # Validate multiple agents are correctly created (dj, dj_dev, dj_mcp, and router) assert mock_create_agent.call_count == 4 - # Validate router agent is created - create_calls = mock_create_agent.call_args_list - router_agent_created = any( - call[0][0] == "Router" - for call in create_calls # First parameter is name - ) - assert router_agent_created, "Router agent should be created" - - # Validate dj_mcp agent is created - mcp_agent_created = any( - call[0][0] == "mcp_datajuicer_agent" - for call in create_calls # First parameter is name - ) - assert mcp_agent_created, "MCP agent should be created" - if __name__ == "__main__": pytest.main(["-v", __file__])