249 lines
8.1 KiB
Python
249 lines
8.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
import os
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import pytest
|
|
from agentscope.agent import ReActAgent
|
|
from agentscope.model import DashScopeChatModel
|
|
from agentscope.tool import Toolkit
|
|
from agentscope.message import Msg
|
|
from agentscope.formatter import DashScopeChatFormatter
|
|
from agentscope.memory import InMemoryMemory
|
|
from agentscope.tool import (
|
|
view_text_file,
|
|
write_text_file,
|
|
)
|
|
|
|
# Import the main function and related components
|
|
from data_juicer_agent.main import main
|
|
from data_juicer_agent.agent_factory import create_agent
|
|
from data_juicer_agent.tools import (
|
|
dj_toolkit,
|
|
dj_dev_toolkit,
|
|
dj_tools,
|
|
dj_dev_tools,
|
|
mcp_tools,
|
|
get_mcp_toolkit,
|
|
execute_safe_command,
|
|
query_dj_operators,
|
|
get_basic_files,
|
|
get_operator_example,
|
|
configure_data_juicer_path,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_toolkit():
|
|
"""Create a mocked Toolkit instance"""
|
|
return Mock(spec=Toolkit)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_model():
|
|
"""Create a mocked DashScopeChatModel"""
|
|
model = Mock(spec=DashScopeChatModel)
|
|
model.call = AsyncMock(
|
|
return_value=Msg("assistant", "test response", role="assistant"),
|
|
)
|
|
return model
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_formatter():
|
|
"""Create a mocked DashScopeChatFormatter"""
|
|
return Mock(spec=DashScopeChatFormatter)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_memory():
|
|
"""Create a mocked InMemoryMemory"""
|
|
return Mock(spec=InMemoryMemory)
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_mcp_client():
|
|
"""Create a mocked MCP client"""
|
|
mock_client = Mock()
|
|
mock_client.name = "DJ_recipe_flow"
|
|
mock_client.connect = AsyncMock()
|
|
mock_client.close = AsyncMock()
|
|
mock_client.get_callable_function = AsyncMock()
|
|
mock_client.list_tools = AsyncMock()
|
|
return mock_client
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_agent(
|
|
mock_model, # pylint: disable=redefined-outer-name
|
|
mock_formatter, # pylint: disable=redefined-outer-name
|
|
mock_toolkit, # pylint: disable=redefined-outer-name
|
|
mock_memory, # pylint: disable=redefined-outer-name
|
|
):
|
|
"""Create a mocked ReActAgent instance"""
|
|
agent = Mock(spec=ReActAgent)
|
|
agent.model = mock_model
|
|
agent.formatter = mock_formatter
|
|
agent.toolkit = mock_toolkit
|
|
agent.memory = mock_memory
|
|
agent.__call__ = AsyncMock(
|
|
return_value=Msg("assistant", "test response", role="assistant"),
|
|
)
|
|
return agent
|
|
|
|
|
|
class TestDataJuicerAgent:
|
|
"""Test suite for the data_juicer_agent functionality"""
|
|
|
|
def named_mock_agent(
|
|
self,
|
|
name,
|
|
mock_agent, # pylint: disable=redefined-outer-name
|
|
):
|
|
"""Create a named mock agent for testing"""
|
|
agent_instance = Mock(spec=ReActAgent)
|
|
agent_instance.model = mock_agent.model
|
|
agent_instance.formatter = mock_agent.formatter
|
|
agent_instance.toolkit = mock_agent.toolkit
|
|
agent_instance.memory = mock_agent.memory
|
|
agent_instance.__call__ = mock_agent.__call__
|
|
agent_instance.name = name
|
|
return agent_instance
|
|
|
|
def _named_mock_agent_side_effect(
|
|
self,
|
|
mock_agent, # pylint: disable=redefined-outer-name
|
|
):
|
|
"""Side effect function for creating named mock agents"""
|
|
return lambda name, *args, **kwargs: self.named_mock_agent(
|
|
name,
|
|
mock_agent,
|
|
*args,
|
|
**kwargs,
|
|
)
|
|
|
|
async def mock_user_func(self):
|
|
return Msg("user", "exit", role="user")
|
|
|
|
def test_dj_toolkit_initialization(self):
|
|
"""Test DJ toolkit initialization and tool registration"""
|
|
assert dj_toolkit.tools.get("execute_safe_command") is not None
|
|
assert dj_toolkit.tools.get("view_text_file") is not None
|
|
assert dj_toolkit.tools.get("write_text_file") is not None
|
|
assert dj_toolkit.tools.get("query_dj_operators") is not None
|
|
|
|
# Verify tool list contains expected tools
|
|
expected_tools = [
|
|
execute_safe_command,
|
|
view_text_file,
|
|
write_text_file,
|
|
query_dj_operators,
|
|
]
|
|
assert len(dj_tools) == len(expected_tools)
|
|
for tool in expected_tools:
|
|
assert tool in dj_tools
|
|
|
|
def test_dj_dev_toolkit_initialization(self):
|
|
"""Test DJ development toolkit initialization and tool registration"""
|
|
assert dj_dev_toolkit.tools.get("view_text_file") is not None
|
|
assert dj_dev_toolkit.tools.get("write_text_file") is not None
|
|
assert dj_dev_toolkit.tools.get("get_basic_files") is not None
|
|
assert dj_dev_toolkit.tools.get("get_operator_example") is not None
|
|
assert (
|
|
dj_dev_toolkit.tools.get("configure_data_juicer_path") is not None
|
|
)
|
|
|
|
# Verify tool list contains expected tools
|
|
expected_tools = [
|
|
view_text_file,
|
|
write_text_file,
|
|
get_basic_files,
|
|
get_operator_example,
|
|
configure_data_juicer_path,
|
|
]
|
|
assert len(dj_dev_tools) == len(expected_tools)
|
|
for tool in expected_tools:
|
|
assert tool in dj_dev_tools
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_mcp_tools_list(
|
|
self,
|
|
mock_mcp_client, # pylint: disable=redefined-outer-name
|
|
):
|
|
"""Test MCP tools list contains expected tools"""
|
|
with patch(
|
|
"agentscope.mcp.HttpStatefulClient",
|
|
return_value=mock_mcp_client,
|
|
) as mock_client_cls:
|
|
await get_mcp_toolkit()
|
|
assert mock_client_cls.assert_called_once
|
|
|
|
expected_tools = [view_text_file, write_text_file]
|
|
assert len(mcp_tools) == len(expected_tools)
|
|
for tool in expected_tools:
|
|
assert tool in mcp_tools
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_agent_initialization(
|
|
self,
|
|
mock_model, # pylint: disable=redefined-outer-name
|
|
mock_formatter, # pylint: disable=redefined-outer-name
|
|
mock_toolkit, # pylint: disable=redefined-outer-name
|
|
mock_memory, # pylint: disable=redefined-outer-name
|
|
):
|
|
"""Test ReActAgent initialization"""
|
|
with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key"}):
|
|
agent = create_agent(
|
|
name="DataJuicer",
|
|
sys_prompt="You are {name}, a agent.",
|
|
toolkit=mock_toolkit,
|
|
description="test description",
|
|
model=mock_model,
|
|
formatter=mock_formatter,
|
|
memory=mock_memory,
|
|
)
|
|
|
|
assert agent.name == "DataJuicer"
|
|
assert "DataJuicer" in agent.sys_prompt
|
|
assert "test" in agent.__doc__
|
|
assert agent.model == mock_model
|
|
assert agent.formatter == mock_formatter
|
|
assert agent.toolkit == mock_toolkit
|
|
assert agent.memory == mock_memory
|
|
assert isinstance(agent, ReActAgent)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_main_with_multiple_agents_loading(
|
|
self,
|
|
mock_agent, # pylint: disable=redefined-outer-name
|
|
mock_mcp_client, # pylint: disable=redefined-outer-name
|
|
):
|
|
"""Test main function loads multiple agents successfully"""
|
|
with patch.dict(os.environ, {"DASHSCOPE_API_KEY": "test_key"}):
|
|
mock_mcp_clients = [mock_mcp_client]
|
|
|
|
with patch(
|
|
"data_juicer_agent.tools.mcp_helpers._create_clients",
|
|
return_value=mock_mcp_clients,
|
|
):
|
|
with patch(
|
|
"data_juicer_agent.main.create_agent",
|
|
side_effect=self._named_mock_agent_side_effect(mock_agent),
|
|
) as mock_create_agent:
|
|
with patch(
|
|
"data_juicer_agent.main.user",
|
|
side_effect=self.mock_user_func,
|
|
):
|
|
await main(
|
|
use_studio=False,
|
|
available_agents=["dj", "dj_dev", "dj_mcp"],
|
|
retrieval_mode="auto",
|
|
)
|
|
|
|
# Validate multiple agents are correctly created
|
|
# (dj, dj_dev, dj_mcp, and router)
|
|
assert mock_create_agent.call_count == 4
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main(["-v", __file__])
|