# Snapshot of evotraders/tests/data_juicer_agent_test.py taken 2025-11-10 18:17:27 +08:00 (249 lines, 8.1 KiB, Python)
# -*- coding: utf-8 -*-
import os
from unittest.mock import AsyncMock, Mock, patch
import pytest
from agentscope.agent import ReActAgent
from agentscope.model import DashScopeChatModel
from agentscope.tool import Toolkit
from agentscope.message import Msg
from agentscope.formatter import DashScopeChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.tool import (
view_text_file,
write_text_file,
)
# Import the main function and related components
from data_juicer_agent.main import main
from data_juicer_agent.agent_factory import create_agent
from data_juicer_agent.tools import (
dj_toolkit,
dj_dev_toolkit,
dj_tools,
dj_dev_tools,
mcp_tools,
get_mcp_toolkit,
execute_safe_command,
query_dj_operators,
get_basic_files,
get_operator_example,
configure_data_juicer_path,
)
@pytest.fixture
def mock_toolkit():
    """Provide a stand-in for ``Toolkit``.

    The mock is spec'd on the real class, so reading an attribute that
    ``Toolkit`` does not define raises ``AttributeError``.
    """
    toolkit = Mock(spec=Toolkit)
    return toolkit
@pytest.fixture
def mock_model():
    """Provide a ``DashScopeChatModel`` stand-in with an awaitable ``call``.

    Awaiting ``model.call(...)`` yields an assistant ``Msg`` whose content
    is ``"test response"``.
    """
    model = Mock(spec=DashScopeChatModel)
    canned_reply = Msg("assistant", "test response", role="assistant")
    model.call = AsyncMock(return_value=canned_reply)
    return model
@pytest.fixture
def mock_formatter():
    """Provide a stand-in for ``DashScopeChatFormatter``.

    The mock is spec'd on the real class, so reading an attribute the
    formatter does not define raises ``AttributeError``.
    """
    formatter = Mock(spec=DashScopeChatFormatter)
    return formatter
@pytest.fixture
def mock_memory():
    """Provide a stand-in for ``InMemoryMemory``.

    The mock is spec'd on the real class, so reading an attribute the
    memory class does not define raises ``AttributeError``.
    """
    memory = Mock(spec=InMemoryMemory)
    return memory
@pytest.fixture
def mock_mcp_client():
    """Provide a fake MCP client whose ``name`` is ``"DJ_recipe_flow"``.

    ``connect``, ``close``, ``get_callable_function`` and ``list_tools``
    are each an independent ``AsyncMock``, so they can be awaited and
    inspected separately.
    """
    client = Mock()
    # ``name`` means something else as a Mock constructor argument, so it is
    # assigned after construction to become a plain attribute.
    client.name = "DJ_recipe_flow"
    coroutine_methods = (
        "connect",
        "close",
        "get_callable_function",
        "list_tools",
    )
    for method_name in coroutine_methods:
        setattr(client, method_name, AsyncMock())
    return client
@pytest.fixture
def mock_agent(
    mock_model,  # pylint: disable=redefined-outer-name
    mock_formatter,  # pylint: disable=redefined-outer-name
    mock_toolkit,  # pylint: disable=redefined-outer-name
    mock_memory,  # pylint: disable=redefined-outer-name
):
    """Create a mocked ReActAgent wired to the other mock fixtures.

    The mock exposes ``model``, ``formatter``, ``toolkit`` and ``memory``
    from the corresponding fixtures. Awaiting ``agent(...)`` yields an
    assistant ``Msg`` with content ``"test response"``.
    """
    agent = Mock(spec=ReActAgent)
    agent.model = mock_model
    agent.formatter = mock_formatter
    agent.toolkit = mock_toolkit
    agent.memory = mock_memory
    call_mock = AsyncMock(
        return_value=Msg("assistant", "test response", role="assistant"),
    )
    # Kept as an attribute for code that reads ``agent.__call__`` directly.
    agent.__call__ = call_mock
    # Mock dispatches ``agent(...)`` through its *type*, so the instance-level
    # ``__call__`` above is never invoked by a call expression. Routing calls
    # through ``side_effect`` makes ``await agent(...)`` return the canned Msg.
    agent.side_effect = call_mock
    return agent
class TestDataJuicerAgent:
    """Test suite for the data_juicer_agent functionality."""

    def named_mock_agent(
        self,
        name,
        mock_agent,  # pylint: disable=redefined-outer-name
    ):
        """Create a ReActAgent-shaped mock sharing ``mock_agent``'s parts.

        Args:
            name: Value assigned to the new mock's ``name`` attribute.
            mock_agent: Template mock whose ``model``, ``formatter``,
                ``toolkit``, ``memory`` and ``__call__`` are reused.

        Returns:
            A new ``Mock(spec=ReActAgent)`` carrying ``name``.
        """
        agent_instance = Mock(spec=ReActAgent)
        agent_instance.model = mock_agent.model
        agent_instance.formatter = mock_agent.formatter
        agent_instance.toolkit = mock_agent.toolkit
        agent_instance.memory = mock_agent.memory
        agent_instance.__call__ = mock_agent.__call__
        # Assigning ``__call__`` on a Mock instance does not change what
        # ``agent_instance(...)`` does (Mock dispatches through its type), so
        # route real calls to the same callable via ``side_effect``.
        agent_instance.side_effect = mock_agent.__call__
        agent_instance.name = name
        return agent_instance

    def _named_mock_agent_side_effect(
        self,
        mock_agent,  # pylint: disable=redefined-outer-name
    ):
        """Build a ``create_agent`` side effect that returns named mocks.

        The patched ``create_agent`` may receive any mix of positional and
        keyword arguments (``sys_prompt``, ``toolkit``, ...). Only ``name``
        matters to the mock, so the rest are accepted and ignored. They are
        not forwarded, because ``named_mock_agent`` does not accept them.
        """

        def _create(name, *_args, **_kwargs):
            return self.named_mock_agent(name, mock_agent)

        return _create

    async def mock_user_func(self, *_args, **_kwargs):
        """Stand in for the interactive user by always replying ``"exit"``.

        Any arguments the patched ``user`` is called with are ignored.
        """
        return Msg("user", "exit", role="user")

    def test_dj_toolkit_initialization(self):
        """Test DJ toolkit initialization and tool registration."""
        assert dj_toolkit.tools.get("execute_safe_command") is not None
        assert dj_toolkit.tools.get("view_text_file") is not None
        assert dj_toolkit.tools.get("write_text_file") is not None
        assert dj_toolkit.tools.get("query_dj_operators") is not None

        # Verify the exported tool list contains exactly the expected tools.
        expected_tools = [
            execute_safe_command,
            view_text_file,
            write_text_file,
            query_dj_operators,
        ]
        assert len(dj_tools) == len(expected_tools)
        for tool in expected_tools:
            assert tool in dj_tools

    def test_dj_dev_toolkit_initialization(self):
        """Test DJ development toolkit initialization and tool registration."""
        assert dj_dev_toolkit.tools.get("view_text_file") is not None
        assert dj_dev_toolkit.tools.get("write_text_file") is not None
        assert dj_dev_toolkit.tools.get("get_basic_files") is not None
        assert dj_dev_toolkit.tools.get("get_operator_example") is not None
        assert (
            dj_dev_toolkit.tools.get("configure_data_juicer_path") is not None
        )

        # Verify the exported tool list contains exactly the expected tools.
        expected_tools = [
            view_text_file,
            write_text_file,
            get_basic_files,
            get_operator_example,
            configure_data_juicer_path,
        ]
        assert len(dj_dev_tools) == len(expected_tools)
        for tool in expected_tools:
            assert tool in dj_dev_tools

    @pytest.mark.asyncio
    async def test_mcp_tools_list(
        self,
        mock_mcp_client,  # pylint: disable=redefined-outer-name
    ):
        """Test MCP toolkit creation and the MCP tools list contents."""
        with patch(
            "agentscope.mcp.HttpStatefulClient",
            return_value=mock_mcp_client,
        ) as mock_client_cls:
            await get_mcp_toolkit()
            # ``assert mock.assert_called_once`` only checks that the bound
            # method is truthy and always passes; it must be *called*.
            # NOTE(review): this observes the client only if get_mcp_toolkit
            # looks up HttpStatefulClient on ``agentscope.mcp`` at call time
            # - confirm the patch target matches the import style there.
            mock_client_cls.assert_called_once()

        expected_tools = [view_text_file, write_text_file]
        assert len(mcp_tools) == len(expected_tools)
        for tool in expected_tools:
            assert tool in mcp_tools

    @pytest.mark.asyncio
    async def test_agent_initialization(
        self,
        mock_model,  # pylint: disable=redefined-outer-name
        mock_formatter,  # pylint: disable=redefined-outer-name
        mock_toolkit,  # pylint: disable=redefined-outer-name
        mock_memory,  # pylint: disable=redefined-outer-name
    ):
        """Test that create_agent wires its arguments into a ReActAgent."""
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key"}):
            agent = create_agent(
                name="DataJuicer",
                sys_prompt="You are {name}, a agent.",
                toolkit=mock_toolkit,
                description="test description",
                model=mock_model,
                formatter=mock_formatter,
                memory=mock_memory,
            )

            assert agent.name == "DataJuicer"
            # The ``{name}`` placeholder is expected to be filled in.
            assert "DataJuicer" in agent.sys_prompt
            assert "test" in agent.__doc__
            assert agent.model == mock_model
            assert agent.formatter == mock_formatter
            assert agent.toolkit == mock_toolkit
            assert agent.memory == mock_memory
            assert isinstance(agent, ReActAgent)

    @pytest.mark.asyncio
    async def test_main_with_multiple_agents_loading(
        self,
        mock_agent,  # pylint: disable=redefined-outer-name
        mock_mcp_client,  # pylint: disable=redefined-outer-name
    ):
        """Test that main creates every requested agent plus the router."""
        with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key"}):
            mock_mcp_clients = [mock_mcp_client]
            with patch(
                "data_juicer_agent.tools.mcp_helpers._create_clients",
                return_value=mock_mcp_clients,
            ):
                with patch(
                    "data_juicer_agent.main.create_agent",
                    side_effect=self._named_mock_agent_side_effect(mock_agent),
                ) as mock_create_agent:
                    # The fake user answers "exit" on every prompt.
                    with patch(
                        "data_juicer_agent.main.user",
                        side_effect=self.mock_user_func,
                    ):
                        await main(
                            use_studio=False,
                            available_agents=["dj", "dj_dev", "dj_mcp"],
                            retrieval_mode="auto",
                        )

                    # Validate that all agents were created
                    # (dj, dj_dev, dj_mcp, and router).
                    assert mock_create_agent.call_count == 4
if __name__ == "__main__":
    # Propagate pytest's exit status so a direct script run reports failures
    # through its process exit code instead of always exiting 0.
    raise SystemExit(pytest.main(["-v", __file__]))