Add pre commit (#26)
This commit is contained in:
24
.eslintrc
Normal file
24
.eslintrc
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
{
|
||||||
|
"env": {
|
||||||
|
"browser": true,
|
||||||
|
"es2021": true
|
||||||
|
},
|
||||||
|
"parserOptions": {
|
||||||
|
"ecmaVersion": 2021,
|
||||||
|
"sourceType": "module",
|
||||||
|
"ecmaFeatures": {
|
||||||
|
"jsx": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rules": {
|
||||||
|
"semi": ["error", "always"],
|
||||||
|
"quotes": ["error", "double"],
|
||||||
|
"indent": ["error", 2],
|
||||||
|
"linebreak-style": ["error", "unix"],
|
||||||
|
"brace-style": ["error", "1tbs"],
|
||||||
|
"curly": ["error", "all"],
|
||||||
|
"no-eval": ["error"],
|
||||||
|
"prefer-const": ["error"],
|
||||||
|
"arrow-spacing": ["error", { "before": true, "after": true }]
|
||||||
|
}
|
||||||
|
}
|
||||||
22
.github/workflows/pre-commit.yml
vendored
22
.github/workflows/pre-commit.yml
vendored
@@ -13,9 +13,27 @@ jobs:
|
|||||||
OS: ${{ matrix.os }}
|
OS: ${{ matrix.os }}
|
||||||
PYTHON: '3.10'
|
PYTHON: '3.10'
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v3
|
- uses: actions/checkout@v4
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v3
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
- name: Update setuptools and wheel
|
||||||
|
run: |
|
||||||
|
pip install setuptools==68.2.2 wheel==0.41.2
|
||||||
|
- name: Install pre-commit
|
||||||
|
run: |
|
||||||
|
pip install pre-commit
|
||||||
|
- name: Install pre-commit hooks
|
||||||
|
run: |
|
||||||
|
pre-commit install
|
||||||
|
- name: Run pre-commit
|
||||||
|
run: |
|
||||||
|
pre-commit run --all-files > pre-commit.log 2>&1 || true
|
||||||
|
cat pre-commit.log
|
||||||
|
if grep -q Failed pre-commit.log; then
|
||||||
|
echo -e "\e[41m [**FAIL**] Please install pre-commit and format your code first. \e[0m"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo -e "\e[46m ********************************Passed******************************** \e[0m"
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
name: deep_research_runtime_test
|
name: deep_research_runtime_test
|
||||||
on:
|
on: [push, pull_request]
|
||||||
schedule:
|
|
||||||
- cron: '0 0 */3 * *'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
@@ -33,5 +30,7 @@ jobs:
|
|||||||
pip install pytest pytest-asyncio pytest-mock
|
pip install pytest pytest-asyncio pytest-mock
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
|
env:
|
||||||
|
PYTHONPATH: ${{ github.workspace }}/deep_research/agent_deep_research
|
||||||
run: |
|
run: |
|
||||||
python -m pytest tests/agent_deep_research_test.py -v
|
python -m pytest tests/agent_deep_research_test.py -v
|
||||||
@@ -1,9 +1,6 @@
|
|||||||
name: BrowserAgent Tests
|
name: BrowserAgent Tests
|
||||||
|
|
||||||
on:
|
on: [push]
|
||||||
schedule:
|
|
||||||
- cron: '0 0 */3 * *'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
@@ -33,13 +30,14 @@ jobs:
|
|||||||
|
|
||||||
- name: Install Dependencies
|
- name: Install Dependencies
|
||||||
run: |
|
run: |
|
||||||
cd browser_agent/agent_browser
|
cd browser_use/agent_browser
|
||||||
python -m pip install --upgrade pip
|
python -m pip install --upgrade pip
|
||||||
pip install pytest pytest-asyncio
|
pip install pytest pytest-asyncio
|
||||||
pip install -r requirements.txt
|
pip install -r requirements.txt
|
||||||
|
|
||||||
- name: Run Tests
|
- name: Run Tests
|
||||||
env:
|
env:
|
||||||
|
PYTHONPATH: ${{ github.workspace }}/browser_use/agent_browser
|
||||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||||
run: |
|
run: |
|
||||||
# ✅ Ensure test-results directory exists
|
# ✅ Ensure test-results directory exists
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
name: browser_use_fullstack_runtime_test
|
name: browser_use_fullstack_runtime_test
|
||||||
on:
|
on: [push]
|
||||||
schedule:
|
|
||||||
- cron: '0 0 */3 * *'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
@@ -37,6 +34,7 @@ jobs:
|
|||||||
- name: Run tests
|
- name: Run tests
|
||||||
env:
|
env:
|
||||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||||
|
PYTHONPATH: ${{ github.workspace }}/browser_use/browser_use_fullstack_runtime/backend
|
||||||
run: |
|
run: |
|
||||||
# ✅ Use validated path from debug output
|
# ✅ Use validated path from debug output
|
||||||
python -m pytest tests/browser_use_fullstack_runtime_test.py -v
|
python -m pytest tests/browser_use_fullstack_runtime_test.py -v
|
||||||
@@ -1,8 +1,5 @@
|
|||||||
name: Conversational Agents Chatbot Test
|
name: Conversational Agents Chatbot Test
|
||||||
on:
|
on: [push]
|
||||||
schedule:
|
|
||||||
- cron: '0 0 */3 * *'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
@@ -30,6 +27,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
env:
|
env:
|
||||||
|
PYTHONPATH: ${{ github.workspace }}/conversational_agents/chatbot
|
||||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||||
run: |
|
run: |
|
||||||
# ✅ Use correct relative path
|
# ✅ Use correct relative path
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
name: Flask API Runtime Test
|
name: Flask API Runtime Test
|
||||||
on:
|
on: [push]
|
||||||
schedule:
|
|
||||||
- cron: '0 0 */3 * *'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
@@ -33,5 +30,7 @@ jobs:
|
|||||||
pip install pytest pytest-asyncio
|
pip install pytest pytest-asyncio
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
|
env:
|
||||||
|
PYTHONPATH: ${{ github.workspace }}/conversational_agents/chatbot_fullstack_runtime
|
||||||
run: |
|
run: |
|
||||||
python -m pytest tests/conversational_agents_chatbot_fullstack_runtime_webserver_test.py -v
|
python -m pytest tests/conversational_agents_chatbot_fullstack_runtime_webserver_test.py -v
|
||||||
6
.github/workflows/test_evaluation.yml
vendored
6
.github/workflows/test_evaluation.yml
vendored
@@ -1,8 +1,5 @@
|
|||||||
name: ACE Benchmark Evaluation Test
|
name: ACE Benchmark Evaluation Test
|
||||||
on:
|
on: [push]
|
||||||
schedule:
|
|
||||||
- cron: '0 0 */3 * *'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
@@ -33,6 +30,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
env:
|
env:
|
||||||
|
PYTHONPATH: ${{ env.GITHUB_WORKSPACE }}/evaluation/ace_bench
|
||||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||||
run: |
|
run: |
|
||||||
python -m pytest tests/evaluation_test.py -v
|
python -m pytest tests/evaluation_test.py -v
|
||||||
11
.github/workflows/test_game.yml
vendored
11
.github/workflows/test_game.yml
vendored
@@ -1,9 +1,6 @@
|
|||||||
name: Run test_game.py
|
name: Run test_game.py
|
||||||
|
|
||||||
on:
|
on: [push]
|
||||||
schedule:
|
|
||||||
- cron: '0 0 */3 * *'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
@@ -21,7 +18,7 @@ jobs:
|
|||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v4
|
uses: actions/setup-python@v4
|
||||||
with:
|
with:
|
||||||
python-version: 3.10
|
python-version: "3.10"
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
@@ -32,7 +29,7 @@ jobs:
|
|||||||
- name: Run game_test.py
|
- name: Run game_test.py
|
||||||
env:
|
env:
|
||||||
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
DASHSCOPE_API_KEY: ${{ secrets.DASHSCOPE_API_KEY }}
|
||||||
PYTHONPATH: ${{ env.GITHUB_WORKSPACE }}/games/game_werewolves
|
PYTHONPATH: $GITHUB_WORKSPACE/games/game_werewolves
|
||||||
run: |
|
run: |
|
||||||
# ✅ Ensure correct working directory
|
# ✅ Ensure correct working directory
|
||||||
python -m pytest tests/game_test.py -v
|
PYTHONPATH=$GITHUB_WORKSPACE/games/game_werewolves python -m pytest tests/game_test.py -v
|
||||||
121
.pre-commit-config.yaml
Normal file
121
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||||
|
rev: v4.3.0
|
||||||
|
hooks:
|
||||||
|
- id: check-ast
|
||||||
|
- id: sort-simple-yaml
|
||||||
|
- id: check-yaml
|
||||||
|
exclude: |
|
||||||
|
(?x)^(
|
||||||
|
meta.yaml
|
||||||
|
)$
|
||||||
|
- id: check-xml
|
||||||
|
- id: check-toml
|
||||||
|
- id: check-docstring-first
|
||||||
|
- id: check-json
|
||||||
|
- id: fix-encoding-pragma
|
||||||
|
- id: detect-private-key
|
||||||
|
- id: trailing-whitespace
|
||||||
|
- repo: https://github.com/asottile/add-trailing-comma
|
||||||
|
rev: v3.1.0
|
||||||
|
hooks:
|
||||||
|
- id: add-trailing-comma
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
|
rev: v1.7.0
|
||||||
|
hooks:
|
||||||
|
- id: mypy
|
||||||
|
exclude:
|
||||||
|
(?x)(
|
||||||
|
pb2\.py$
|
||||||
|
| grpc\.py$
|
||||||
|
| ^docs
|
||||||
|
| \.html$
|
||||||
|
)
|
||||||
|
args: [
|
||||||
|
--ignore-missing-imports,
|
||||||
|
--disable-error-code=var-annotated,
|
||||||
|
--disable-error-code=union-attr,
|
||||||
|
--disable-error-code=assignment,
|
||||||
|
--disable-error-code=attr-defined,
|
||||||
|
--disable-error-code=import-untyped,
|
||||||
|
--disable-error-code=truthy-function,
|
||||||
|
--follow-imports=skip,
|
||||||
|
--explicit-package-bases,
|
||||||
|
]
|
||||||
|
- repo: https://github.com/psf/black
|
||||||
|
rev: 23.3.0
|
||||||
|
hooks:
|
||||||
|
- id: black
|
||||||
|
args: [ --line-length=79 ]
|
||||||
|
- repo: https://github.com/PyCQA/flake8
|
||||||
|
rev: 6.1.0
|
||||||
|
hooks:
|
||||||
|
- id: flake8
|
||||||
|
args: [ "--extend-ignore=E203"]
|
||||||
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
|
rev: v3.0.2
|
||||||
|
hooks:
|
||||||
|
- id: pylint
|
||||||
|
exclude:
|
||||||
|
(?x)(
|
||||||
|
^docs
|
||||||
|
| pb2\.py$
|
||||||
|
| grpc\.py$
|
||||||
|
| \.demo$
|
||||||
|
| \.md$
|
||||||
|
| \.html$
|
||||||
|
)
|
||||||
|
args: [
|
||||||
|
--disable=W0511,
|
||||||
|
--disable=W0718,
|
||||||
|
--disable=W0122,
|
||||||
|
--disable=C0103,
|
||||||
|
--disable=R0913,
|
||||||
|
--disable=E0401,
|
||||||
|
--disable=E1101,
|
||||||
|
--disable=C0415,
|
||||||
|
--disable=W0603,
|
||||||
|
--disable=R1705,
|
||||||
|
--disable=R0914,
|
||||||
|
--disable=E0601,
|
||||||
|
--disable=W0602,
|
||||||
|
--disable=W0604,
|
||||||
|
--disable=R0801,
|
||||||
|
--disable=R0902,
|
||||||
|
--disable=R0903,
|
||||||
|
--disable=C0123,
|
||||||
|
--disable=W0231,
|
||||||
|
--disable=W1113,
|
||||||
|
--disable=W0221,
|
||||||
|
--disable=R0401,
|
||||||
|
--disable=W0632,
|
||||||
|
--disable=W0123,
|
||||||
|
--disable=C3001,
|
||||||
|
--disable=W0201,
|
||||||
|
--disable=C0302,
|
||||||
|
--disable=W1203,
|
||||||
|
--disable=C2801,
|
||||||
|
--disable=C0114, # Disable missing module docstring for quick dev
|
||||||
|
--disable=C0115, # Disable missing class docstring for quick dev
|
||||||
|
--disable=C0116, # Disable missing function or method docstring for quick dev
|
||||||
|
]
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-eslint
|
||||||
|
rev: v7.32.0
|
||||||
|
hooks:
|
||||||
|
- id: eslint
|
||||||
|
files: \.(js|jsx)$
|
||||||
|
exclude: '.*js_third_party.*'
|
||||||
|
args: [ '--fix' ]
|
||||||
|
- repo: https://github.com/thibaudcolas/pre-commit-stylelint
|
||||||
|
rev: v14.4.0
|
||||||
|
hooks:
|
||||||
|
- id: stylelint
|
||||||
|
files: \.(css)$
|
||||||
|
exclude: '.*css_third_party.*'
|
||||||
|
args: [ '--fix' ]
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-prettier
|
||||||
|
rev: 'v3.0.0'
|
||||||
|
hooks:
|
||||||
|
- id: prettier
|
||||||
|
additional_dependencies: [ 'prettier@3.0.0' ]
|
||||||
|
files: \.(tsx?)$
|
||||||
6
.stylelintrc
Normal file
6
.stylelintrc
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"rules": {
|
||||||
|
"indentation": 2,
|
||||||
|
"string-quotes": "double"
|
||||||
|
}
|
||||||
|
}
|
||||||
14
README_zh.md
14
README_zh.md
@@ -137,9 +137,9 @@ AgentScope Runtime 是一个**全面的运行时框架**,主要解决部署和
|
|||||||
|
|
||||||
如果你:
|
如果你:
|
||||||
|
|
||||||
- 需要安装帮助
|
- 需要安装帮助
|
||||||
- 遇到问题
|
- 遇到问题
|
||||||
- 想了解某个示例的工作方式
|
- 想了解某个示例的工作方式
|
||||||
|
|
||||||
请:
|
请:
|
||||||
|
|
||||||
@@ -157,10 +157,10 @@ AgentScope Runtime 是一个**全面的运行时框架**,主要解决部署和
|
|||||||
|
|
||||||
欢迎提交:
|
欢迎提交:
|
||||||
|
|
||||||
- Bug 报告
|
- Bug 报告
|
||||||
- 新功能请求
|
- 新功能请求
|
||||||
- 文档改进
|
- 文档改进
|
||||||
- 代码贡献
|
- 代码贡献
|
||||||
|
|
||||||
详情见 [Contributing](https://github.com/agentscope-ai/agentscope-samples/blob/main/CONTRIBUTING_zh.md) 文档。
|
详情见 [Contributing](https://github.com/agentscope-ai/agentscope-samples/blob/main/CONTRIBUTING_zh.md) 文档。
|
||||||
|
|
||||||
|
|||||||
@@ -4,20 +4,25 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Optional
|
from typing import Optional, Any
|
||||||
|
|
||||||
from agentscope.agent import ReActAgent
|
from agentscope.agent import ReActAgent
|
||||||
from agentscope.formatter import FormatterBase
|
from agentscope.formatter import FormatterBase
|
||||||
from agentscope.memory import MemoryBase
|
from agentscope.memory import MemoryBase
|
||||||
from agentscope.message import Msg, TextBlock, ToolUseBlock
|
from agentscope.message import (
|
||||||
|
Msg,
|
||||||
|
ToolUseBlock,
|
||||||
|
TextBlock,
|
||||||
|
)
|
||||||
from agentscope.model import ChatModelBase
|
from agentscope.model import ChatModelBase
|
||||||
from agentscope.token import OpenAITokenCounter, TokenCounterBase
|
|
||||||
from agentscope.tool import Toolkit
|
from agentscope.tool import Toolkit
|
||||||
|
from agentscope.token import TokenCounterBase, OpenAITokenCounter
|
||||||
|
|
||||||
_BROWSER_AGENT_DEFAULT_SYS_PROMPT = (
|
_BROWSER_AGENT_DEFAULT_SYS_PROMPT = (
|
||||||
"You are a helpful browser automation assistant. "
|
"You are a helpful browser automation assistant. "
|
||||||
"You can navigate websites, take screenshots, and interact with web pages."
|
"You can navigate websites, take screenshots, and interact with web pages."
|
||||||
"Always describe what you see and meta_planner_agent your next steps clearly. "
|
"Always describe what you see and meta_planner_agent"
|
||||||
|
" your next steps clearly. "
|
||||||
"When taking actions, explain what you're doing and why."
|
"When taking actions, explain what you're doing and why."
|
||||||
)
|
)
|
||||||
_BROWSER_AGENT_REASONING_PROMPT = (
|
_BROWSER_AGENT_REASONING_PROMPT = (
|
||||||
@@ -318,7 +323,7 @@ class BrowserAgent(ReActAgent):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Format the prompt for the model
|
# Format the prompt for the model
|
||||||
prompt = self.formatter.format(
|
prompt = await self.formatter.format(
|
||||||
msgs=[
|
msgs=[
|
||||||
Msg("system", self.sys_prompt, "system"),
|
Msg("system", self.sys_prompt, "system"),
|
||||||
*memory_msgs,
|
*memory_msgs,
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from agentscope.memory import InMemoryMemory
|
|||||||
from agentscope.model import DashScopeChatModel
|
from agentscope.model import DashScopeChatModel
|
||||||
from agentscope.tool import Toolkit
|
from agentscope.tool import Toolkit
|
||||||
|
|
||||||
from .browser_agent import BrowserAgent # pylint: disable=C0411
|
from browser_agent import BrowserAgent # pylint: disable=C0411
|
||||||
|
|
||||||
|
|
||||||
async def main() -> None:
|
async def main() -> None:
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from agentscope.agent import ReActAgent
|
from agentscope.agent import ReActAgent
|
||||||
from agentscope_runtime.engine import LocalDeployManager, Runner
|
|
||||||
from agentscope.model import DashScopeChatModel
|
from agentscope.model import DashScopeChatModel
|
||||||
|
from agentscope_runtime.engine import LocalDeployManager, Runner
|
||||||
from agentscope_runtime.engine.agents.agentscope_agent import AgentScopeAgent
|
from agentscope_runtime.engine.agents.agentscope_agent import AgentScopeAgent
|
||||||
from agentscope_runtime.engine.services.context_manager import ContextManager
|
from agentscope_runtime.engine.services.context_manager import ContextManager
|
||||||
|
|
||||||
@@ -23,12 +23,13 @@ async def _local_deploy():
|
|||||||
model = DashScopeChatModel(
|
model = DashScopeChatModel(
|
||||||
model_name="qwen-turbo",
|
model_name="qwen-turbo",
|
||||||
api_key=os.getenv("DASHSCOPE_API_KEY"),
|
api_key=os.getenv("DASHSCOPE_API_KEY"),
|
||||||
|
|
||||||
)
|
)
|
||||||
agent = AgentScopeAgent(
|
agent = AgentScopeAgent(
|
||||||
name="Friday",
|
name="Friday",
|
||||||
model=model,
|
model=model,
|
||||||
agent_config={"sys_prompt": "A simple LLM agent to generate a short response"},
|
agent_config={
|
||||||
|
"sys_prompt": "A simple LLM agent to generate a short response",
|
||||||
|
},
|
||||||
agent_builder=ReActAgent,
|
agent_builder=ReActAgent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,14 +3,13 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Tuple, Optional, Union, Dict, Any, Generator
|
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from flask import Flask, jsonify, request
|
from flask import Flask, request, jsonify
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
from flask_sqlalchemy import SQLAlchemy
|
||||||
from werkzeug.security import check_password_hash, generate_password_hash
|
from werkzeug.security import generate_password_hash, check_password_hash
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
@@ -48,10 +47,10 @@ class User(db.Model):
|
|||||||
cascade="all, delete-orphan",
|
cascade="all, delete-orphan",
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_password(self, password: str) -> None:
|
def set_password(self, password):
|
||||||
self.password_hash = generate_password_hash(password)
|
self.password_hash = generate_password_hash(password)
|
||||||
|
|
||||||
def check_password(self, password: str) -> bool:
|
def check_password(self, password):
|
||||||
return check_password_hash(self.password_hash, password)
|
return check_password_hash(self.password_hash, password)
|
||||||
|
|
||||||
|
|
||||||
@@ -90,7 +89,7 @@ class Message(db.Model):
|
|||||||
# Create database tables
|
# Create database tables
|
||||||
|
|
||||||
|
|
||||||
def create_tables() -> None:
|
def create_tables():
|
||||||
db.create_all()
|
db.create_all()
|
||||||
|
|
||||||
# Create sample users (if none exist)
|
# Create sample users (if none exist)
|
||||||
@@ -105,9 +104,7 @@ def create_tables() -> None:
|
|||||||
|
|
||||||
|
|
||||||
# functions
|
# functions
|
||||||
def parse_sse_line(
|
def parse_sse_line(line):
|
||||||
line: bytes,
|
|
||||||
) -> Tuple[Optional[str], Optional[Union[str, int]]]:
|
|
||||||
line = line.decode("utf-8").strip()
|
line = line.decode("utf-8").strip()
|
||||||
if line.startswith("data: "):
|
if line.startswith("data: "):
|
||||||
return "data", line[6:]
|
return "data", line[6:]
|
||||||
@@ -120,10 +117,7 @@ def parse_sse_line(
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def sse_client(
|
def sse_client(url, data=None):
|
||||||
url: str,
|
|
||||||
data: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Generator[str, None, None]:
|
|
||||||
headers = {
|
headers = {
|
||||||
"Accept": "text/event-stream",
|
"Accept": "text/event-stream",
|
||||||
"Cache-Control": "no-cache",
|
"Cache-Control": "no-cache",
|
||||||
@@ -158,17 +152,13 @@ def sse_client(
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def call_runner(
|
def call_runner(query, query_user_id, query_session_id):
|
||||||
query: str,
|
|
||||||
query_user_id: str,
|
|
||||||
query_session_id: str,
|
|
||||||
) -> Generator[str, None, None]:
|
|
||||||
server_port = int(os.environ.get("SERVER_PORT", "8090"))
|
server_port = int(os.environ.get("SERVER_PORT", "8090"))
|
||||||
server_endpoint = os.environ.get("SERVER_ENDPOINT", "agent")
|
server_endpoint = os.environ.get("SERVER_ENDPOINT", "agent")
|
||||||
server_host = os.environ.get("SERVER_HOST", "localhost")
|
server_host = os.environ.get("SERVER_HOST", "localhost")
|
||||||
|
|
||||||
url = f"http://{server_host}:{server_port}/{server_endpoint}"
|
url = f"http://{server_host}:{server_port}/{server_endpoint}"
|
||||||
data_arg: Dict[str, Any] = {
|
data_arg = {
|
||||||
"input": [
|
"input": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|||||||
@@ -95,14 +95,14 @@ class ReflectFailure(BaseModel):
|
|||||||
"properties": {
|
"properties": {
|
||||||
"need_rephrase": {
|
"need_rephrase": {
|
||||||
"type": "boolean",
|
"type": "boolean",
|
||||||
"description": "Set to 'true' if the failed subtask "
|
"description": "Set to 'true' if the failed"
|
||||||
"needs to be rephrased due to a design "
|
" subtask needs to be rephrased due to a design "
|
||||||
"flaw or misunderstanding; otherwise, 'false'.",
|
"flaw or misunderstanding; otherwise, 'false'.",
|
||||||
},
|
},
|
||||||
"rephrased_plan": {
|
"rephrased_plan": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The modified working meta_planner_agent "
|
"description": "The modified working "
|
||||||
"with only the inappropriate "
|
"meta_planner_agent with only the inappropriate "
|
||||||
"subtask replaced by its improved version. If no "
|
"subtask replaced by its improved version. If no "
|
||||||
"rephrasing is needed, provide an empty string.",
|
"rephrasing is needed, provide an empty string.",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -1,35 +1,46 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""Deep Research Agent"""
|
"""Deep Research Agent"""
|
||||||
# pylint: disable=too-many-lines, no-name-in-module
|
# pylint: disable=too-many-lines, no-name-in-module
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from copy import deepcopy
|
import json
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from typing import Type, Optional, Any, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional, Tuple, Type
|
from copy import deepcopy
|
||||||
|
|
||||||
import shortuuid
|
import shortuuid
|
||||||
from agentscope import logger, setup_logger
|
|
||||||
from agentscope.agent import ReActAgent
|
|
||||||
from agentscope.formatter import FormatterBase
|
|
||||||
from agentscope.mcp import StatefulClientBase
|
|
||||||
from agentscope.memory import MemoryBase
|
|
||||||
from agentscope.message import Msg, TextBlock, ToolResultBlock, ToolUseBlock
|
|
||||||
from agentscope.model import ChatModelBase
|
|
||||||
from agentscope.tool import ToolResponse, view_text_file, write_text_file
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from ..agent_deep_research.utils import (
|
from built_in_prompt.promptmodule import (
|
||||||
get_dynamic_tool_call_json,
|
|
||||||
get_structure_output,
|
|
||||||
load_prompt_dict,
|
|
||||||
truncate_search_result,
|
|
||||||
)
|
|
||||||
from .built_in_prompt.promptmodule import (
|
|
||||||
FollowupJudge,
|
|
||||||
ReflectFailure,
|
|
||||||
SubtasksDecomposition,
|
SubtasksDecomposition,
|
||||||
WebExtraction,
|
WebExtraction,
|
||||||
|
FollowupJudge,
|
||||||
|
ReflectFailure,
|
||||||
|
)
|
||||||
|
from utils import (
|
||||||
|
truncate_search_result,
|
||||||
|
load_prompt_dict,
|
||||||
|
get_dynamic_tool_call_json,
|
||||||
|
get_structure_output,
|
||||||
|
)
|
||||||
|
from agentscope import logger, setup_logger
|
||||||
|
from agentscope.mcp import StatefulClientBase
|
||||||
|
from agentscope.agent import ReActAgent
|
||||||
|
from agentscope.model import ChatModelBase
|
||||||
|
from agentscope.formatter import FormatterBase
|
||||||
|
from agentscope.memory import MemoryBase
|
||||||
|
from agentscope.tool import (
|
||||||
|
ToolResponse,
|
||||||
|
view_text_file,
|
||||||
|
write_text_file,
|
||||||
|
)
|
||||||
|
from agentscope.message import (
|
||||||
|
Msg,
|
||||||
|
ToolUseBlock,
|
||||||
|
TextBlock,
|
||||||
|
ToolResultBlock,
|
||||||
)
|
)
|
||||||
|
|
||||||
_DEEP_RESEARCH_AGENT_DEFAULT_SYS_PROMPT = "You're a helpful assistant."
|
_DEEP_RESEARCH_AGENT_DEFAULT_SYS_PROMPT = "You're a helpful assistant."
|
||||||
@@ -149,7 +160,7 @@ class DeepResearchAgent(ReActAgent):
|
|||||||
# register all necessary tools for deep research agent
|
# register all necessary tools for deep research agent
|
||||||
self.toolkit.register_tool_function(view_text_file)
|
self.toolkit.register_tool_function(view_text_file)
|
||||||
self.toolkit.register_tool_function(write_text_file)
|
self.toolkit.register_tool_function(write_text_file)
|
||||||
asyncio.create_task(
|
asyncio.get_running_loop().create_task(
|
||||||
self.toolkit.register_mcp_client(search_mcp_client),
|
self.toolkit.register_mcp_client(search_mcp_client),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -213,16 +224,12 @@ class DeepResearchAgent(ReActAgent):
|
|||||||
reasoning_prompt = self.prompt_dict["reasoning_prompt"].format_map(
|
reasoning_prompt = self.prompt_dict["reasoning_prompt"].format_map(
|
||||||
{
|
{
|
||||||
"objective": self.current_subtask[-1].objective,
|
"objective": self.current_subtask[-1].objective,
|
||||||
"meta_planner_agent": (
|
"plan": cur_plan
|
||||||
cur_plan
|
if cur_plan
|
||||||
if cur_plan
|
else "There is no working plan now.",
|
||||||
else "There is no working meta_planner_agent now."
|
"knowledge_gap": f"## Knowledge Gaps:\n {cur_know_gap}"
|
||||||
),
|
if cur_know_gap
|
||||||
"knowledge_gap": (
|
else "",
|
||||||
f"## Knowledge Gaps:\n {cur_know_gap}"
|
|
||||||
if cur_know_gap
|
|
||||||
else ""
|
|
||||||
),
|
|
||||||
"depth": len(self.current_subtask),
|
"depth": len(self.current_subtask),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -300,7 +307,9 @@ class DeepResearchAgent(ReActAgent):
|
|||||||
# Async generator handling
|
# Async generator handling
|
||||||
async for chunk in tool_res:
|
async for chunk in tool_res:
|
||||||
# Turn into a tool result block
|
# Turn into a tool result block
|
||||||
tool_res_msg.content[0]["output"] = chunk.content # type: ignore[index]
|
tool_res_msg.content[0][ # type: ignore[index]
|
||||||
|
"output"
|
||||||
|
] = chunk.content
|
||||||
|
|
||||||
# Skip the printing of the finish function call
|
# Skip the printing of the finish function call
|
||||||
if (
|
if (
|
||||||
@@ -488,13 +497,14 @@ class DeepResearchAgent(ReActAgent):
|
|||||||
|
|
||||||
async def decompose_and_expand_subtask(self) -> ToolResponse:
|
async def decompose_and_expand_subtask(self) -> ToolResponse:
|
||||||
"""Identify the knowledge gaps of the current subtask and generate a
|
"""Identify the knowledge gaps of the current subtask and generate a
|
||||||
working meta_planner_agent by subtask decomposition. The working meta_planner_agent includes
|
working meta_planner_agent by subtask decomposition.
|
||||||
|
The working meta_planner_agent includes
|
||||||
necessary steps for task completion and expanded steps.
|
necessary steps for task completion and expanded steps.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ToolResponse:
|
ToolResponse:
|
||||||
The knowledge gaps and working meta_planner_agent of the current subtask
|
The knowledge gaps and working meta_planner_agent
|
||||||
in JSON format.
|
of the current subtask in JSON format.
|
||||||
"""
|
"""
|
||||||
if len(self.current_subtask) <= self.max_depth:
|
if len(self.current_subtask) <= self.max_depth:
|
||||||
decompose_sys_prompt = self.prompt_dict["decompose_sys_prompt"]
|
decompose_sys_prompt = self.prompt_dict["decompose_sys_prompt"]
|
||||||
@@ -947,7 +957,8 @@ class DeepResearchAgent(ReActAgent):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ToolResponse:
|
ToolResponse:
|
||||||
The reflection about meta_planner_agent rephrasing and subtask decomposition.
|
The reflection about meta_planner_agent
|
||||||
|
rephrasing and subtask decomposition.
|
||||||
"""
|
"""
|
||||||
reflect_sys_prompt = self.prompt_dict["reflect_sys_prompt"]
|
reflect_sys_prompt = self.prompt_dict["reflect_sys_prompt"]
|
||||||
conversation_history = ""
|
conversation_history = ""
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from agentscope.memory import InMemoryMemory
|
|||||||
from agentscope.message import Msg
|
from agentscope.message import Msg
|
||||||
from agentscope.model import DashScopeChatModel
|
from agentscope.model import DashScopeChatModel
|
||||||
|
|
||||||
from .deep_research_agent import DeepResearchAgent
|
from deep_research_agent import DeepResearchAgent
|
||||||
|
|
||||||
|
|
||||||
async def main(user_query: str) -> None:
|
async def main(user_query: str) -> None:
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
"""The utilities for deep research agent"""
|
"""The utilities for deep research agent"""
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Any, Sequence, Type, Union
|
from typing import Union, Sequence, Any, Type
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from agentscope.tool import Toolkit, ToolResponse
|
from agentscope.tool import Toolkit, ToolResponse
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
TOOL_RESULTS_MAX_WORDS = 5000
|
TOOL_RESULTS_MAX_WORDS = 5000
|
||||||
|
|
||||||
@@ -281,11 +281,14 @@ def load_prompt_dict() -> dict:
|
|||||||
|
|
||||||
prompt_dict["summarize_hint"] = (
|
prompt_dict["summarize_hint"] = (
|
||||||
"Based on your work history above, examine which step in the "
|
"Based on your work history above, examine which step in the "
|
||||||
"following working meta_planner_agent has been completed. Mark the completed "
|
"following working meta_planner_agent has been "
|
||||||
|
"completed. Mark the completed "
|
||||||
"step with [DONE] at the end of its line (e.g., k. step k [DONE]) "
|
"step with [DONE] at the end of its line (e.g., k. step k [DONE]) "
|
||||||
"and leave the uncompleted steps unchanged. You MUST return only "
|
"and leave the uncompleted steps unchanged. You MUST return only "
|
||||||
"the updated meta_planner_agent, preserving exactly the same format as the "
|
"the updated meta_planner_agent, preserving exactly "
|
||||||
"original meta_planner_agent. Do not include any explanations, reasoning, "
|
"the same format as the "
|
||||||
|
"original meta_planner_agent. Do not include any "
|
||||||
|
"explanations, reasoning, "
|
||||||
"or section headers such as '## Working Plan:', just output the"
|
"or section headers such as '## Working Plan:', just output the"
|
||||||
"updated meta_planner_agent itself."
|
"updated meta_planner_agent itself."
|
||||||
"\n\n## Working Plan:\n{meta_planner_agent}"
|
"\n\n## Working Plan:\n{meta_planner_agent}"
|
||||||
@@ -304,11 +307,13 @@ def load_prompt_dict() -> dict:
|
|||||||
"following report that consolidates and summarizes the essential "
|
"following report that consolidates and summarizes the essential "
|
||||||
"findings:\n {intermediate_report}\n\n"
|
"findings:\n {intermediate_report}\n\n"
|
||||||
"Such report has been saved to the {report_path}. "
|
"Such report has been saved to the {report_path}. "
|
||||||
"I will now **proceed to the next item** in the working meta_planner_agent."
|
"I will now **proceed to the next item** "
|
||||||
|
"in the working meta_planner_agent."
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_dict["save_report_hint"] = (
|
prompt_dict["save_report_hint"] = (
|
||||||
"The milestone results of the current item in working meta_planner_agent "
|
"The milestone results of the current "
|
||||||
|
"item in working meta_planner_agent "
|
||||||
"are summarized into the following report:\n{intermediate_report}"
|
"are summarized into the following report:\n{intermediate_report}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from hmac import new as hmac_new
|
from hmac import new as hmac_new
|
||||||
from typing import Any, Dict, List
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from utils import format_time
|
from .utils import format_time
|
||||||
|
|
||||||
|
|
||||||
class CustomSearchTool:
|
class CustomSearchTool:
|
||||||
|
|||||||
@@ -3,22 +3,25 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from agentscope_runtime.engine.agents.langgraph_agent import LangGraphAgent
|
from typing import List, Dict, Any, Optional
|
||||||
from agentscope_runtime.engine.helpers.helper import simple_call_agent_direct
|
|
||||||
from configuration import Configuration
|
|
||||||
from custom_search_tool import CustomSearchTool
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from langchain_core.runnables import RunnableConfig
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langgraph.graph import END, START, StateGraph
|
from langgraph.graph import START, END
|
||||||
|
from langgraph.graph import StateGraph
|
||||||
from langgraph.types import Send
|
from langgraph.types import Send
|
||||||
|
from agentscope_runtime.engine.agents.langgraph_agent import LangGraphAgent
|
||||||
|
from agentscope_runtime.engine.helpers.helper import simple_call_agent_direct
|
||||||
|
|
||||||
|
from configuration import Configuration
|
||||||
|
from custom_search_tool import CustomSearchTool
|
||||||
from llm_prompts import (
|
from llm_prompts import (
|
||||||
answer_instructions,
|
|
||||||
query_writer_instructions,
|
query_writer_instructions,
|
||||||
reflection_instructions,
|
|
||||||
web_searcher_instructions,
|
web_searcher_instructions,
|
||||||
|
reflection_instructions,
|
||||||
|
answer_instructions,
|
||||||
)
|
)
|
||||||
from llm_utils import call_dashscope, extract_json_from_qwen
|
from llm_utils import call_dashscope, extract_json_from_qwen
|
||||||
from state import (
|
from state import (
|
||||||
@@ -27,12 +30,12 @@ from state import (
|
|||||||
ReflectionState,
|
ReflectionState,
|
||||||
WebSearchState,
|
WebSearchState,
|
||||||
)
|
)
|
||||||
from utils import (
|
from .utils import (
|
||||||
custom_get_citations,
|
|
||||||
custom_resolve_urls,
|
|
||||||
get_current_date,
|
|
||||||
get_research_topic,
|
get_research_topic,
|
||||||
insert_citation_markers,
|
insert_citation_markers,
|
||||||
|
custom_resolve_urls,
|
||||||
|
custom_get_citations,
|
||||||
|
get_current_date,
|
||||||
)
|
)
|
||||||
|
|
||||||
load_dotenv("../.env")
|
load_dotenv("../.env")
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
from datetime import datetime
|
||||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
|
from langchain_core.messages import AnyMessage, AIMessage, HumanMessage
|
||||||
|
|
||||||
|
|
||||||
def get_current_date() -> str:
|
def get_current_date():
|
||||||
return datetime.now().strftime("%B %d, %Y")
|
return datetime.now().strftime("%B %d, %Y")
|
||||||
|
|
||||||
|
|
||||||
@@ -39,7 +38,7 @@ def get_research_topic(messages: List[AnyMessage]) -> str:
|
|||||||
return research_topic
|
return research_topic
|
||||||
|
|
||||||
|
|
||||||
def insert_citation_markers(text: str, citations_list: List[Dict]) -> str:
|
def insert_citation_markers(text, citations_list):
|
||||||
"""
|
"""
|
||||||
Inserts citation markers into a text string based on start and end indices.
|
Inserts citation markers into a text string based on start and end indices.
|
||||||
|
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ from agentscope.tool import Toolkit
|
|||||||
|
|
||||||
|
|
||||||
async def react_agent_solution(
|
async def react_agent_solution(
|
||||||
ace_task: Task,
|
ace_task: Task,
|
||||||
pre_hook: Callable,
|
pre_hook: Callable,
|
||||||
) -> SolutionOutput:
|
) -> SolutionOutput:
|
||||||
"""Run ReAct agent with the given task in ACEBench.
|
"""Run ReAct agent with the given task in ACEBench.
|
||||||
|
|
||||||
@@ -42,8 +42,8 @@ async def react_agent_solution(
|
|||||||
agent = ReActAgent(
|
agent = ReActAgent(
|
||||||
name="Friday",
|
name="Friday",
|
||||||
sys_prompt="You are a helpful assistant named Friday. "
|
sys_prompt="You are a helpful assistant named Friday. "
|
||||||
"Your target is to solve the given task with your tools."
|
"Your target is to solve the given task with your tools."
|
||||||
"Try to solve the task as best as you can.",
|
"Try to solve the task as best as you can.",
|
||||||
model=DashScopeChatModel(
|
model=DashScopeChatModel(
|
||||||
api_key=os.environ.get("DASHSCOPE_API_KEY"),
|
api_key=os.environ.get("DASHSCOPE_API_KEY"),
|
||||||
model_name="qwen-max",
|
model_name="qwen-max",
|
||||||
|
|||||||
@@ -2,8 +2,8 @@
|
|||||||
"""The structured output models used in the werewolf game."""
|
"""The structured output models used in the werewolf game."""
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from agentscope.agent import AgentBase
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from agentscope.agent import AgentBase
|
||||||
|
|
||||||
|
|
||||||
class DiscussionModel(BaseModel):
|
class DiscussionModel(BaseModel):
|
||||||
@@ -44,7 +44,9 @@ def get_poison_model(agents: list[AgentBase]) -> type[BaseModel]:
|
|||||||
poison: bool = Field(
|
poison: bool = Field(
|
||||||
description="Do you want to use the poison potion",
|
description="Do you want to use the poison potion",
|
||||||
)
|
)
|
||||||
name: Literal[tuple(_.name for _ in agents)] | None = Field( # type: ignore
|
name: Literal[ # type: ignore
|
||||||
|
tuple(_.name for _ in agents)
|
||||||
|
] | None = Field(
|
||||||
description="The name of the player you want to poison, if you "
|
description="The name of the player you want to poison, if you "
|
||||||
"don't want to poison anyone, just leave it empty",
|
"don't want to poison anyone, just leave it empty",
|
||||||
default=None,
|
default=None,
|
||||||
@@ -75,7 +77,9 @@ def get_hunter_model(agents: list[AgentBase]) -> type[BaseModel]:
|
|||||||
shoot: bool = Field(
|
shoot: bool = Field(
|
||||||
description="Whether you want to use the shooting ability or not",
|
description="Whether you want to use the shooting ability or not",
|
||||||
)
|
)
|
||||||
name: Literal[tuple(_.name for _ in agents)] | None = Field( # type: ignore
|
name: Literal[ # type: ignore
|
||||||
|
tuple(_.name for _ in agents)
|
||||||
|
] | None = Field(
|
||||||
description="The name of the player you want to shoot, if you "
|
description="The name of the player you want to shoot, if you "
|
||||||
"don't want to the ability, just leave it empty",
|
"don't want to the ability, just leave it empty",
|
||||||
default=None,
|
default=None,
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
"""
|
"""
|
||||||
[Your Sample Name] - Entry Point
|
[Your Sample Name] - Entry Point
|
||||||
|
|
||||||
@@ -5,6 +6,7 @@ This example demonstrates [brief description].
|
|||||||
"""
|
"""
|
||||||
import agentscope
|
import agentscope
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""Main function to run the example."""
|
"""Main function to run the example."""
|
||||||
print(agentscope.__version__)
|
print(agentscope.__version__)
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
# tests/agent_deep_research_test.py
|
# -*- coding: utf-8 -*-
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import Mock, AsyncMock, patch
|
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from agentscope.formatter import DashScopeChatFormatter
|
from agentscope.formatter import DashScopeChatFormatter
|
||||||
@@ -12,7 +11,9 @@ from agentscope.memory import InMemoryMemory
|
|||||||
from agentscope.message import Msg
|
from agentscope.message import Msg
|
||||||
from agentscope.model import DashScopeChatModel
|
from agentscope.model import DashScopeChatModel
|
||||||
|
|
||||||
from deep_research.agent_deep_research.deep_research_agent import DeepResearchAgent
|
from deep_research.agent_deep_research.deep_research_agent import (
|
||||||
|
DeepResearchAgent,
|
||||||
|
)
|
||||||
from deep_research.agent_deep_research.main import main
|
from deep_research.agent_deep_research.main import main
|
||||||
|
|
||||||
|
|
||||||
@@ -70,12 +71,15 @@ class TestDeepResearchAgent:
|
|||||||
|
|
||||||
def test_agent_initialization(
|
def test_agent_initialization(
|
||||||
self,
|
self,
|
||||||
mock_model,
|
mock_model, # pylint: disable=redefined-outer-name
|
||||||
mock_tavily_client,
|
mock_tavily_client, # pylint: disable=redefined-outer-name
|
||||||
temp_working_dir,
|
temp_working_dir, # pylint: disable=redefined-outer-name
|
||||||
):
|
):
|
||||||
"""Test agent initialization with valid parameters"""
|
"""Test agent initialization with valid parameters"""
|
||||||
with patch("asyncio.create_task"):
|
mock_loop = MagicMock()
|
||||||
|
mock_task = AsyncMock()
|
||||||
|
mock_loop.create_task = MagicMock(return_value=mock_task)
|
||||||
|
with patch("asyncio.get_running_loop", return_value=mock_loop):
|
||||||
agent = DeepResearchAgent(
|
agent = DeepResearchAgent(
|
||||||
name="Friday",
|
name="Friday",
|
||||||
sys_prompt="You are a helpful assistant named Friday.",
|
sys_prompt="You are a helpful assistant named Friday.",
|
||||||
@@ -87,17 +91,17 @@ class TestDeepResearchAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert agent.name == "Friday"
|
assert agent.name == "Friday"
|
||||||
assert agent.sys_prompt.startswith("You are a helpful assistant named Friday.")
|
assert agent.sys_prompt.startswith(
|
||||||
|
"You are a helpful assistant named Friday.",
|
||||||
|
)
|
||||||
assert agent.tmp_file_storage_dir == temp_working_dir
|
assert agent.tmp_file_storage_dir == temp_working_dir
|
||||||
assert os.path.exists(temp_working_dir)
|
assert os.path.exists(temp_working_dir)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_main_function_success(
|
async def test_main_function_success(
|
||||||
self,
|
self,
|
||||||
mock_env_vars,
|
mock_tavily_client, # pylint: disable=redefined-outer-name
|
||||||
mock_tavily_client,
|
temp_working_dir, # pylint: disable=redefined-outer-name
|
||||||
mock_model,
|
|
||||||
temp_working_dir,
|
|
||||||
):
|
):
|
||||||
"""Test main function with successful execution"""
|
"""Test main function with successful execution"""
|
||||||
with patch(
|
with patch(
|
||||||
@@ -109,17 +113,26 @@ class TestDeepResearchAgent:
|
|||||||
autospec=True,
|
autospec=True,
|
||||||
) as mock_agent_class:
|
) as mock_agent_class:
|
||||||
mock_agent = AsyncMock()
|
mock_agent = AsyncMock()
|
||||||
mock_agent.return_value = Msg("Friday", "Test response", "assistant")
|
mock_agent.return_value = Msg(
|
||||||
|
"Friday",
|
||||||
|
"Test response",
|
||||||
|
"assistant",
|
||||||
|
)
|
||||||
mock_agent_class.return_value = mock_agent
|
mock_agent_class.return_value = mock_agent
|
||||||
|
|
||||||
with patch("os.makedirs") as mock_makedirs:
|
with patch("os.makedirs") as mock_makedirs:
|
||||||
with patch.dict(os.environ, {"AGENT_OPERATION_DIR": temp_working_dir}):
|
with patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{"AGENT_OPERATION_DIR": temp_working_dir},
|
||||||
|
):
|
||||||
test_query = "Test research question"
|
test_query = "Test research question"
|
||||||
msg = Msg("Bob", test_query, "user")
|
|
||||||
|
|
||||||
await main(test_query)
|
await main(test_query)
|
||||||
|
|
||||||
mock_makedirs.assert_called_once_with(temp_working_dir, exist_ok=True)
|
mock_makedirs.assert_called_once_with(
|
||||||
|
temp_working_dir,
|
||||||
|
exist_ok=True,
|
||||||
|
)
|
||||||
mock_agent_class.assert_called_once()
|
mock_agent_class.assert_called_once()
|
||||||
|
|
||||||
# ✅ Use assert_called_once() + manual argument check
|
# ✅ Use assert_called_once() + manual argument check
|
||||||
@@ -138,8 +151,7 @@ class TestDeepResearchAgent:
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_cleanup(
|
async def test_agent_cleanup(
|
||||||
self,
|
self,
|
||||||
mock_env_vars,
|
mock_tavily_client, # pylint: disable=redefined-outer-name
|
||||||
mock_tavily_client,
|
|
||||||
):
|
):
|
||||||
"""Test proper cleanup of resources"""
|
"""Test proper cleanup of resources"""
|
||||||
with patch(
|
with patch(
|
||||||
@@ -151,7 +163,10 @@ class TestDeepResearchAgent:
|
|||||||
|
|
||||||
mock_tavily_client.close.assert_called_once()
|
mock_tavily_client.close.assert_called_once()
|
||||||
|
|
||||||
def test_working_directory_creation(self, temp_working_dir):
|
def test_working_directory_creation(
|
||||||
|
self,
|
||||||
|
temp_working_dir, # pylint: disable=redefined-outer-name
|
||||||
|
):
|
||||||
"""Test working directory is created correctly"""
|
"""Test working directory is created correctly"""
|
||||||
test_dir = os.path.join(temp_working_dir, "test_subdir")
|
test_dir = os.path.join(temp_working_dir, "test_subdir")
|
||||||
os.makedirs(test_dir, exist_ok=True)
|
os.makedirs(test_dir, exist_ok=True)
|
||||||
@@ -161,17 +176,28 @@ class TestDeepResearchAgent:
|
|||||||
|
|
||||||
class TestErrorHandling:
|
class TestErrorHandling:
|
||||||
"""Test suite for error handling scenarios"""
|
"""Test suite for error handling scenarios"""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_filesystem_errors(self, mock_env_vars, mock_tavily_client):
|
async def test_filesystem_errors(
|
||||||
|
self,
|
||||||
|
mock_tavily_client, # pylint: disable=redefined-outer-name
|
||||||
|
):
|
||||||
"""Test handling of filesystem errors"""
|
"""Test handling of filesystem errors"""
|
||||||
with patch(
|
with patch(
|
||||||
"deep_research.agent_deep_research.main.StdIOStatefulClient",
|
"deep_research.agent_deep_research.main.StdIOStatefulClient",
|
||||||
return_value=mock_tavily_client,
|
return_value=mock_tavily_client,
|
||||||
):
|
):
|
||||||
with patch.dict(os.environ, {"AGENT_OPERATION_DIR": "/invalid/path"}):
|
with patch.dict(
|
||||||
with patch("os.makedirs", side_effect=PermissionError("Permission denied")):
|
os.environ,
|
||||||
|
{"AGENT_OPERATION_DIR": "/invalid/path"},
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"os.makedirs",
|
||||||
|
side_effect=PermissionError("Permission denied"),
|
||||||
|
):
|
||||||
with pytest.raises(PermissionError):
|
with pytest.raises(PermissionError):
|
||||||
await main("Test query")
|
await main("Test query")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main(["-v", __file__])
|
pytest.main(["-v", __file__])
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import pytest
|
from typing import Dict
|
||||||
import asyncio
|
|
||||||
from typing import Dict, Any, AsyncGenerator
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import pytest
|
||||||
from agentscope.message import Msg
|
from agentscope.message import Msg
|
||||||
from agentscope.tool import Toolkit
|
from agentscope.tool import Toolkit
|
||||||
from agentscope.memory import MemoryBase
|
from agentscope.memory import MemoryBase
|
||||||
@@ -22,7 +21,10 @@ def mock_dependencies() -> Dict[str, MagicMock]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def agent(mock_dependencies: Dict[str, MagicMock]) -> BrowserAgent:
|
def agent(
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
mock_dependencies: Dict[str, MagicMock],
|
||||||
|
) -> BrowserAgent:
|
||||||
return BrowserAgent(
|
return BrowserAgent(
|
||||||
name="TestBot",
|
name="TestBot",
|
||||||
model=mock_dependencies["model"],
|
model=mock_dependencies["model"],
|
||||||
@@ -36,17 +38,28 @@ def agent(mock_dependencies: Dict[str, MagicMock]) -> BrowserAgent:
|
|||||||
# -----------------------------
|
# -----------------------------
|
||||||
# ✅ Hook registration verification (adapted for ReActAgentBase)
|
# ✅ Hook registration verification (adapted for ReActAgentBase)
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
def test_hooks_registered(agent: BrowserAgent) -> None:
|
def test_hooks_registered(
|
||||||
# Verify instance-level hooks
|
agent: BrowserAgent, # pylint: disable=redefined-outer-name
|
||||||
assert hasattr(agent, "_instance_pre_reply_hooks")
|
) -> None:
|
||||||
|
"""Verify instance-level hooks are registered"""
|
||||||
|
# Disable pylint warning for protected member access
|
||||||
|
assert hasattr(
|
||||||
|
agent,
|
||||||
|
"_instance_pre_reply_hooks",
|
||||||
|
) # pylint: disable=protected-access
|
||||||
assert (
|
assert (
|
||||||
"browser_agent_default_url_pre_reply"
|
"browser_agent_default_url_pre_reply"
|
||||||
|
# pylint: disable=protected-access
|
||||||
in agent._instance_pre_reply_hooks
|
in agent._instance_pre_reply_hooks
|
||||||
)
|
)
|
||||||
|
|
||||||
assert hasattr(agent, "_instance_pre_reasoning_hooks")
|
assert hasattr(
|
||||||
|
agent,
|
||||||
|
"_instance_pre_reasoning_hooks",
|
||||||
|
) # pylint: disable=protected-access
|
||||||
assert (
|
assert (
|
||||||
"browser_agent_observe_pre_reasoning"
|
"browser_agent_observe_pre_reasoning"
|
||||||
|
# pylint: disable=protected-access
|
||||||
in agent._instance_pre_reasoning_hooks
|
in agent._instance_pre_reasoning_hooks
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -55,15 +68,20 @@ def test_hooks_registered(agent: BrowserAgent) -> None:
|
|||||||
# ✅ Navigation hook test (direct hook invocation)
|
# ✅ Navigation hook test (direct hook invocation)
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pre_reply_hook_navigation(agent: BrowserAgent) -> None:
|
async def test_pre_reply_hook_navigation(
|
||||||
|
agent: BrowserAgent, # pylint: disable=redefined-outer-name
|
||||||
|
) -> None:
|
||||||
|
# pylint: disable=protected-access
|
||||||
agent._has_initial_navigated = False
|
agent._has_initial_navigated = False
|
||||||
|
|
||||||
# Get instance-level hook function
|
# Get instance-level hook function
|
||||||
|
# pylint: disable=protected-access
|
||||||
hook_func = agent._instance_pre_reply_hooks[
|
hook_func = agent._instance_pre_reply_hooks[
|
||||||
"browser_agent_default_url_pre_reply"
|
"browser_agent_default_url_pre_reply"
|
||||||
]
|
]
|
||||||
await hook_func(agent) # Directly invoke hook function
|
await hook_func(agent) # Directly invoke hook function
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
assert agent._has_initial_navigated is True
|
assert agent._has_initial_navigated is True
|
||||||
assert agent.toolkit.call_tool_function.called
|
assert agent.toolkit.call_tool_function.called
|
||||||
|
|
||||||
@@ -72,13 +90,17 @@ async def test_pre_reply_hook_navigation(agent: BrowserAgent) -> None:
|
|||||||
# ✅ Snapshot hook test (fix content attribute access issue)
|
# ✅ Snapshot hook test (fix content attribute access issue)
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_observe_pre_reasoning(agent: BrowserAgent) -> None:
|
async def test_observe_pre_reasoning(
|
||||||
|
agent: BrowserAgent, # pylint: disable=redefined-outer-name
|
||||||
|
) -> None:
|
||||||
# Mock tool response (fix: use Msg object with content attribute)
|
# Mock tool response (fix: use Msg object with content attribute)
|
||||||
mock_response = AsyncMock()
|
mock_response = AsyncMock()
|
||||||
mock_response.__aiter__.return_value = [
|
mock_response.__aiter__.return_value = [
|
||||||
Msg("system", [{"text": "Snapshot content"}], "system"),
|
Msg("system", [{"text": "Snapshot content"}], "system"),
|
||||||
]
|
]
|
||||||
agent.toolkit.call_tool_function = AsyncMock(return_value=mock_response)
|
agent.toolkit.call_tool_function = AsyncMock(
|
||||||
|
return_value=mock_response,
|
||||||
|
)
|
||||||
|
|
||||||
# Replace memory add method
|
# Replace memory add method
|
||||||
with patch.object(
|
with patch.object(
|
||||||
@@ -87,6 +109,7 @@ async def test_observe_pre_reasoning(agent: BrowserAgent) -> None:
|
|||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
) as mock_add:
|
) as mock_add:
|
||||||
# Get instance-level hook function
|
# Get instance-level hook function
|
||||||
|
# pylint: disable=protected-access
|
||||||
hook_func = agent._instance_pre_reasoning_hooks[
|
hook_func = agent._instance_pre_reasoning_hooks[
|
||||||
"browser_agent_observe_pre_reasoning"
|
"browser_agent_observe_pre_reasoning"
|
||||||
]
|
]
|
||||||
@@ -100,7 +123,9 @@ async def test_observe_pre_reasoning(agent: BrowserAgent) -> None:
|
|||||||
# -----------------------------
|
# -----------------------------
|
||||||
# ✅ Text filtering test (improved regex)
|
# ✅ Text filtering test (improved regex)
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
def test_filter_execution_text(agent: BrowserAgent) -> None:
|
def test_filter_execution_text(
|
||||||
|
agent: BrowserAgent, # pylint: disable=redefined-outer-name
|
||||||
|
) -> None:
|
||||||
text = """
|
text = """
|
||||||
### New console messages
|
### New console messages
|
||||||
Some console output
|
Some console output
|
||||||
@@ -112,6 +137,7 @@ def test_filter_execution_text(agent: BrowserAgent) -> None:
|
|||||||
```
|
```
|
||||||
Regular text content
|
Regular text content
|
||||||
"""
|
"""
|
||||||
|
# pylint: disable=protected-access
|
||||||
filtered = agent._filter_execution_text(text)
|
filtered = agent._filter_execution_text(text)
|
||||||
|
|
||||||
assert "console output" not in filtered
|
assert "console output" not in filtered
|
||||||
@@ -124,7 +150,9 @@ def test_filter_execution_text(agent: BrowserAgent) -> None:
|
|||||||
# ✅ Memory summarization test (already passing)
|
# ✅ Memory summarization test (already passing)
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_memory_summarizing(agent: BrowserAgent) -> None:
|
async def test_memory_summarizing(
|
||||||
|
agent: BrowserAgent, # pylint: disable=redefined-outer-name
|
||||||
|
) -> None:
|
||||||
agent.memory.get_memory = AsyncMock(
|
agent.memory.get_memory = AsyncMock(
|
||||||
return_value=[MagicMock(role="user", content="Original question")]
|
return_value=[MagicMock(role="user", content="Original question")]
|
||||||
* 25,
|
* 25,
|
||||||
@@ -136,6 +164,7 @@ async def test_memory_summarizing(agent: BrowserAgent) -> None:
|
|||||||
content=[MagicMock(text="Summary text")],
|
content=[MagicMock(text="Summary text")],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
await agent._memory_summarizing()
|
await agent._memory_summarizing()
|
||||||
|
|
||||||
assert agent.memory.clear.called
|
assert agent.memory.clear.called
|
||||||
|
|||||||
@@ -1,20 +1,23 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import pytest
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
|
|
||||||
from browser_use.browser_use_fullstack_runtime.backend.agentscope_browseruse_agent import (
|
|
||||||
AgentscopeBrowseruseAgent,
|
|
||||||
RunStatus,
|
|
||||||
)
|
|
||||||
from browser_use.browser_use_fullstack_runtime.backend.async_quart_service import (
|
|
||||||
app,
|
|
||||||
)
|
|
||||||
from quart.testing import QuartClient
|
from quart.testing import QuartClient
|
||||||
|
|
||||||
|
from browser_use.browser_use_fullstack_runtime.backend import (
|
||||||
|
agentscope_browseruse_agent as agent_module,
|
||||||
|
)
|
||||||
|
from browser_use.browser_use_fullstack_runtime.backend import (
|
||||||
|
async_quart_service as service,
|
||||||
|
)
|
||||||
|
|
||||||
|
AgentscopeBrowseruseAgent = agent_module.AgentscopeBrowseruseAgent
|
||||||
|
RunStatus = agent_module.RunStatus
|
||||||
|
app = service.app
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# 🧪 Singleton Test Configuration
|
# 🧪 Singleton Test Configuration
|
||||||
@@ -31,13 +34,17 @@ def event_loop():
|
|||||||
async def agent_singleton():
|
async def agent_singleton():
|
||||||
"""Session-scoped single instance of AgentscopeBrowseruseAgent"""
|
"""Session-scoped single instance of AgentscopeBrowseruseAgent"""
|
||||||
with patch(
|
with patch(
|
||||||
"browser_use.browser_use_fullstack_runtime.backend.agentscope_browseruse_agent.SandboxService",
|
"browser_use.browser_use_fullstack_runtime."
|
||||||
|
"backend.agentscope_browseruse_agent.SandboxService",
|
||||||
) as MockSandboxService, patch(
|
) as MockSandboxService, patch(
|
||||||
"browser_use.browser_use_fullstack_runtime.backend.agentscope_browseruse_agent.InMemoryMemoryService",
|
"browser_use.browser_use_fullstack_runtime."
|
||||||
|
"backend.agentscope_browseruse_agent.InMemoryMemoryService",
|
||||||
) as MockMemoryService, patch(
|
) as MockMemoryService, patch(
|
||||||
"browser_use.browser_use_fullstack_runtime.backend.agentscope_browseruse_agent.InMemorySessionHistoryService",
|
"browser_use.browser_use_fullstack_runtime."
|
||||||
|
"backend.agentscope_browseruse_agent.InMemorySessionHistoryService",
|
||||||
) as MockHistoryService, patch(
|
) as MockHistoryService, patch(
|
||||||
"agentscope_runtime.sandbox.manager.container_clients.docker_client.docker",
|
"agentscope_runtime.sandbox.manager."
|
||||||
|
"container_clients.docker_client.docker",
|
||||||
) as mock_docker, patch(
|
) as mock_docker, patch(
|
||||||
"agentscope_runtime.sandbox.manager.sandbox_manager.SandboxManager",
|
"agentscope_runtime.sandbox.manager.sandbox_manager.SandboxManager",
|
||||||
) as MockSandboxManager:
|
) as MockSandboxManager:
|
||||||
@@ -88,16 +95,20 @@ async def test_app():
|
|||||||
# ✅ AgentscopeBrowseruseAgent Singleton Tests
|
# ✅ AgentscopeBrowseruseAgent Singleton Tests
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_agent_singleton_initialization(agent_singleton):
|
async def agent_singleton_singleton_initialization(
|
||||||
|
agent_singleton, # pylint: disable=redefined-outer-name
|
||||||
|
):
|
||||||
"""Test agent singleton initialization"""
|
"""Test agent singleton initialization"""
|
||||||
agent = agent_singleton
|
agent = agent_singleton # pylint: disable=redefined-outer-name
|
||||||
assert isinstance(agent, AgentscopeBrowseruseAgent)
|
assert isinstance(agent, AgentscopeBrowseruseAgent)
|
||||||
assert hasattr(agent, "agent")
|
assert hasattr(agent, "agent")
|
||||||
assert hasattr(agent, "runner")
|
assert hasattr(agent, "runner")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_chat_method(agent_singleton):
|
async def test_chat_method(
|
||||||
|
agent_singleton,
|
||||||
|
): # pylint: disable=redefined-outer-name
|
||||||
"""Test chat method handles messages"""
|
"""Test chat method handles messages"""
|
||||||
mock_request = {
|
mock_request = {
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -108,20 +119,28 @@ async def test_chat_method(agent_singleton):
|
|||||||
# ✅ Create mock object with object/status properties
|
# ✅ Create mock object with object/status properties
|
||||||
mock_event = SimpleNamespace(
|
mock_event = SimpleNamespace(
|
||||||
object="message",
|
object="message",
|
||||||
status=RunStatus.Completed,
|
status=agent_module.RunStatus.Completed,
|
||||||
content=[{"type": "text", "text": "Test response"}],
|
content=[{"type": "text", "text": "Test response"}],
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch.object(agent_singleton.runner, "stream_query") as mock_stream:
|
with patch.object(
|
||||||
|
agent_singleton.runner, # pylint: disable=redefined-outer-name
|
||||||
|
"stream_query",
|
||||||
|
) as mock_stream:
|
||||||
# ✅ Return object with properties
|
# ✅ Return object with properties
|
||||||
async def mock_stream_query(*args, **kwargs):
|
async def mock_stream_query(*_args, **_kwargs):
|
||||||
yield mock_event
|
yield mock_event
|
||||||
|
|
||||||
mock_stream.side_effect = mock_stream_query
|
mock_stream.side_effect = mock_stream_query
|
||||||
|
|
||||||
responses = []
|
responses = []
|
||||||
async for response in agent_singleton.chat(mock_request["messages"]):
|
async for response in agent_singleton.chat(
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
mock_request["messages"],
|
||||||
|
):
|
||||||
responses.append(response)
|
responses.append(response)
|
||||||
|
|
||||||
assert len(responses) == 1
|
assert len(responses) == 1
|
||||||
assert responses[0][0]["text"] == "Test response" # ✅ Fix property access
|
assert (
|
||||||
|
responses[0][0]["text"] == "Test response"
|
||||||
|
) # ✅ Fix property access
|
||||||
|
|||||||
@@ -1,264 +1,61 @@
|
|||||||
from datetime import datetime, timezone
|
# -*- coding: utf-8 -*-
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import tempfile
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import MagicMock, patch
|
import conversational_agents.chatbot_fullstack_runtime.backend.web_server as ws
|
||||||
from flask import Flask, request, jsonify
|
|
||||||
from flask_sqlalchemy import SQLAlchemy
|
|
||||||
from werkzeug.security import generate_password_hash, check_password_hash
|
|
||||||
|
|
||||||
# Initialize db instance
|
|
||||||
db = SQLAlchemy()
|
|
||||||
|
|
||||||
|
|
||||||
# Define model classes (defined once)
|
app = ws.app
|
||||||
class User(db.Model):
|
_db = ws.db
|
||||||
__tablename__ = "user"
|
User = ws.User
|
||||||
id = db.Column(db.Integer, primary_key=True)
|
|
||||||
username = db.Column(db.String(80), unique=True, nullable=False)
|
|
||||||
password_hash = db.Column(db.String(120), nullable=False)
|
|
||||||
name = db.Column(db.String(100), nullable=False)
|
|
||||||
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
|
|
||||||
|
|
||||||
def set_password(self, password):
|
|
||||||
self.password_hash = generate_password_hash(password)
|
|
||||||
|
|
||||||
def check_password(self, password):
|
|
||||||
return check_password_hash(self.password_hash, password)
|
|
||||||
|
|
||||||
|
|
||||||
class Conversation(db.Model):
|
def generate_unique_username():
|
||||||
__tablename__ = "conversation"
|
return f"testuser_{int(time.time())}"
|
||||||
id = db.Column(db.Integer, primary_key=True)
|
|
||||||
title = db.Column(db.String(200), nullable=False)
|
|
||||||
user_id = db.Column(db.Integer, db.ForeignKey("user.id"), nullable=False)
|
@pytest.fixture
|
||||||
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
|
def client_and_username():
|
||||||
updated_at = db.Column(
|
"""Create an Isolated Test Client and Username"""
|
||||||
db.DateTime,
|
db_fd, db_path = tempfile.mkstemp(suffix=".db")
|
||||||
default=lambda: datetime.now(timezone.utc),
|
app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_path}"
|
||||||
onupdate=lambda: datetime.now(timezone.utc),
|
app.config["TESTING"] = True
|
||||||
|
client = app.test_client()
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
_db.drop_all()
|
||||||
|
_db.create_all()
|
||||||
|
|
||||||
|
# Generate Unique Username
|
||||||
|
username = generate_unique_username()
|
||||||
|
password = "testpass"
|
||||||
|
user = User(username=username, name="Test User")
|
||||||
|
user.set_password(password)
|
||||||
|
_db.session.add(user)
|
||||||
|
_db.session.commit()
|
||||||
|
|
||||||
|
yield client, username, password
|
||||||
|
|
||||||
|
os.close(db_fd)
|
||||||
|
os.unlink(db_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_login_success(
|
||||||
|
# pylint: disable=redefined-outer-name
|
||||||
|
client_and_username,
|
||||||
|
):
|
||||||
|
"""Test Successful User Login"""
|
||||||
|
client, username, password = client_and_username
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/login",
|
||||||
|
json={
|
||||||
|
"username": username,
|
||||||
|
"password": password,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
messages = db.relationship("Message", backref="conversation", lazy=True)
|
|
||||||
|
|
||||||
|
|
||||||
class Message(db.Model):
|
|
||||||
__tablename__ = "message"
|
|
||||||
id = db.Column(db.Integer, primary_key=True)
|
|
||||||
text = db.Column(db.Text, nullable=False)
|
|
||||||
sender = db.Column(db.String(20), nullable=False)
|
|
||||||
conversation_id = db.Column(db.Integer, db.ForeignKey("conversation.id"), nullable=False)
|
|
||||||
created_at = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc))
|
|
||||||
|
|
||||||
|
|
||||||
# Thoroughly isolated test Flask application
|
|
||||||
@pytest.fixture
|
|
||||||
def app():
|
|
||||||
"""Create a fresh Flask application instance"""
|
|
||||||
app = Flask(__name__)
|
|
||||||
app.config.update({
|
|
||||||
"SQLALCHEMY_DATABASE_URI": "sqlite:///:memory:",
|
|
||||||
"SQLALCHEMY_TRACK_MODIFICATIONS": False,
|
|
||||||
"TESTING": True,
|
|
||||||
})
|
|
||||||
|
|
||||||
# Initialize db
|
|
||||||
db.init_app(app)
|
|
||||||
|
|
||||||
# Define routes
|
|
||||||
@app.route("/api/login", methods=["POST"])
|
|
||||||
def login():
|
|
||||||
data = request.get_json()
|
|
||||||
username = data.get("username")
|
|
||||||
password = data.get("password")
|
|
||||||
|
|
||||||
if not username or not password:
|
|
||||||
return jsonify({"error": "Username and password cannot be empty"}), 400
|
|
||||||
|
|
||||||
user = User.query.filter_by(username=username).first()
|
|
||||||
if user and user.check_password(password):
|
|
||||||
return jsonify({
|
|
||||||
"id": user.id,
|
|
||||||
"username": user.username,
|
|
||||||
"name": user.name,
|
|
||||||
"created_at": user.created_at.isoformat(),
|
|
||||||
}), 200
|
|
||||||
return jsonify({"error": "Invalid username or password"}), 401
|
|
||||||
|
|
||||||
@app.route("/api/users/<int:user_id>/conversations", methods=["POST"])
|
|
||||||
def create_conversation(user_id):
|
|
||||||
data = request.get_json()
|
|
||||||
title = data.get("title", f"Conversation {datetime.now().strftime('%Y-%m-%d %H:%M')}")
|
|
||||||
conversation = Conversation(title=title, user_id=user_id)
|
|
||||||
db.session.add(conversation)
|
|
||||||
db.session.commit()
|
|
||||||
return jsonify({
|
|
||||||
"id": conversation.id,
|
|
||||||
"title": conversation.title,
|
|
||||||
"user_id": conversation.user_id,
|
|
||||||
"created_at": conversation.created_at.isoformat(),
|
|
||||||
"updated_at": conversation.updated_at.isoformat(),
|
|
||||||
}), 201
|
|
||||||
|
|
||||||
@app.route("/api/conversations/<int:conversation_id>", methods=["GET"])
|
|
||||||
def get_conversation(conversation_id):
|
|
||||||
conversation = Conversation.query.get(conversation_id)
|
|
||||||
if not conversation:
|
|
||||||
return jsonify({"error": "Conversation not found"}), 404
|
|
||||||
|
|
||||||
messages = Message.query.filter_by(conversation_id=conversation_id).order_by(Message.created_at.asc()).all()
|
|
||||||
messages_data = [{
|
|
||||||
"id": msg.id,
|
|
||||||
"text": msg.text,
|
|
||||||
"sender": msg.sender,
|
|
||||||
"created_at": msg.created_at.isoformat(),
|
|
||||||
} for msg in messages]
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
"id": conversation.id,
|
|
||||||
"title": conversation.title,
|
|
||||||
"user_id": conversation.user_id,
|
|
||||||
"messages": messages_data,
|
|
||||||
"created_at": conversation.created_at.isoformat(),
|
|
||||||
"updated_at": conversation.updated_at.isoformat(),
|
|
||||||
}), 200
|
|
||||||
|
|
||||||
@app.route("/api/conversations/<int:conversation_id>/messages", methods=["POST"])
|
|
||||||
def send_message(conversation_id):
|
|
||||||
conversation = Conversation.query.get(conversation_id)
|
|
||||||
if not conversation:
|
|
||||||
return jsonify({"error": "Conversation not found"}), 404
|
|
||||||
|
|
||||||
data = request.get_json()
|
|
||||||
text = data.get("text")
|
|
||||||
sender = data.get("sender", "user")
|
|
||||||
|
|
||||||
if not text:
|
|
||||||
return jsonify({"error": "Message content cannot be empty"}), 400
|
|
||||||
|
|
||||||
# Create user message
|
|
||||||
user_message = Message(
|
|
||||||
text=text,
|
|
||||||
sender=sender,
|
|
||||||
conversation_id=conversation_id
|
|
||||||
)
|
|
||||||
db.session.add(user_message)
|
|
||||||
|
|
||||||
# Update conversation title (if this is the first user message)
|
|
||||||
if sender == "user" and len(conversation.messages) <= 1:
|
|
||||||
conversation.title = text[:20] + ("..." if len(text) > 20 else "")
|
|
||||||
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# Simulate AI response
|
|
||||||
ai_message = Message(
|
|
||||||
text="Test response part 1 Test response part 2",
|
|
||||||
sender="ai",
|
|
||||||
conversation_id=conversation_id
|
|
||||||
)
|
|
||||||
db.session.add(ai_message)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
return jsonify({
|
|
||||||
"id": user_message.id,
|
|
||||||
"text": user_message.text,
|
|
||||||
"sender": user_message.sender,
|
|
||||||
"created_at": user_message.created_at.isoformat(),
|
|
||||||
}), 201
|
|
||||||
|
|
||||||
# Initialize database
|
|
||||||
with app.app_context():
|
|
||||||
db.create_all()
|
|
||||||
# Create example users
|
|
||||||
if not User.query.first():
|
|
||||||
user1 = User(username="user1", name="Bruce")
|
|
||||||
user1.set_password("password123")
|
|
||||||
db.session.add(user1)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
yield app
|
|
||||||
|
|
||||||
with app.app_context():
|
|
||||||
db.drop_all()
|
|
||||||
db.session.remove()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def client(app):
|
|
||||||
"""Flask test client"""
|
|
||||||
return app.test_client()
|
|
||||||
|
|
||||||
|
|
||||||
# Mock call_runner function
|
|
||||||
def mock_call_runner(query, session_id, user_id):
|
|
||||||
"""Mock function for call_runner"""
|
|
||||||
yield "Test response part 1"
|
|
||||||
yield " Test response part 2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_login_success(app, client):
|
|
||||||
"""Test successful user login"""
|
|
||||||
with app.app_context():
|
|
||||||
user = User(username="test", name="Test User")
|
|
||||||
user.set_password("testpass")
|
|
||||||
db.session.add(user)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
response = client.post("/api/login", json={
|
|
||||||
"username": "test",
|
|
||||||
"password": "testpass",
|
|
||||||
})
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.get_json()
|
data = response.get_json()
|
||||||
assert data["username"] == "test"
|
|
||||||
|
|
||||||
|
|
||||||
def test_login_invalid_credentials(app, client):
|
|
||||||
"""Test login with invalid credentials"""
|
|
||||||
response = client.post("/api/login", json={
|
|
||||||
"username": "test",
|
|
||||||
"password": "wrongpass"
|
|
||||||
})
|
|
||||||
assert response.status_code == 401
|
|
||||||
|
|
||||||
|
|
||||||
def test_conversation_crud_operations(app, client):
|
|
||||||
"""Test conversation creation and retrieval"""
|
|
||||||
with app.app_context():
|
|
||||||
user = User(username="test", name="Test User")
|
|
||||||
user.set_password("testpass")
|
|
||||||
db.session.add(user)
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
create_response = client.post("/api/users/1/conversations", json={
|
|
||||||
"title": "Test Conversation",
|
|
||||||
})
|
|
||||||
assert create_response.status_code == 201
|
|
||||||
conversation_id = create_response.get_json()["id"]
|
|
||||||
|
|
||||||
get_response = client.get(f"/api/conversations/{conversation_id}")
|
|
||||||
assert get_response.status_code == 200
|
|
||||||
assert "Test Conversation" in get_response.get_json()["title"]
|
|
||||||
|
|
||||||
|
|
||||||
@patch("tests.conversational_agents_chatbot_fullstack_runtime_webserver_test.db", new=db)
|
|
||||||
def test_send_message(app, client):
|
|
||||||
"""Test message sending and AI response"""
|
|
||||||
with app.app_context():
|
|
||||||
user = User(username="test", name="Test User")
|
|
||||||
user.set_password("testpass")
|
|
||||||
conversation = Conversation(title="Test", user_id=1)
|
|
||||||
db.session.add_all([user, conversation])
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
response = client.post("/api/conversations/1/messages", json={
|
|
||||||
"text": "Hello",
|
|
||||||
"sender": "user"
|
|
||||||
})
|
|
||||||
assert response.status_code == 201
|
|
||||||
data = response.get_json()
|
|
||||||
assert "id" in data
|
assert "id" in data
|
||||||
assert "Hello" in data["text"]
|
assert data["username"] == username
|
||||||
|
|
||||||
# ✅ Move the query into the application context
|
|
||||||
with app.app_context():
|
|
||||||
messages = Message.query.filter_by(conversation_id=1).all()
|
|
||||||
assert len(messages) == 2 # User + AI response
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
import pytest
|
||||||
from agentscope.message import Msg
|
from agentscope.message import Msg
|
||||||
from agentscope.agent import ReActAgent
|
from agentscope.agent import ReActAgent
|
||||||
from agentscope.tool import Toolkit
|
from agentscope.tool import Toolkit
|
||||||
@@ -12,13 +12,14 @@ class TestReActAgent:
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_agent(self):
|
def test_agent(self):
|
||||||
"""Fixture to create a test ReAct agent with fully mocked dependencies"""
|
"""Fixture to create a test ReAct
|
||||||
|
agent with fully mocked dependencies"""
|
||||||
|
|
||||||
async def model_response(*args, **kwargs):
|
async def model_response():
|
||||||
yield Msg(
|
yield Msg(
|
||||||
name="Friday",
|
name="Friday",
|
||||||
content="Mocked model response",
|
content="Mocked model response",
|
||||||
role="assistant"
|
role="assistant",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_model = AsyncMock()
|
mock_model = AsyncMock()
|
||||||
@@ -36,10 +37,12 @@ class TestReActAgent:
|
|||||||
model=mock_model,
|
model=mock_model,
|
||||||
formatter=mock_formatter,
|
formatter=mock_formatter,
|
||||||
toolkit=Toolkit(),
|
toolkit=Toolkit(),
|
||||||
memory=mock_memory
|
memory=mock_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
agent._reasoning_hint_msgs = AsyncMock()
|
agent._reasoning_hint_msgs = AsyncMock()
|
||||||
|
# pylint: disable=protected-access
|
||||||
agent._reasoning_hint_msgs.get_memory = AsyncMock(return_value=[])
|
agent._reasoning_hint_msgs.get_memory = AsyncMock(return_value=[])
|
||||||
|
|
||||||
return agent
|
return agent
|
||||||
@@ -47,16 +50,16 @@ class TestReActAgent:
|
|||||||
async def test_exit_command(self, test_agent, monkeypatch):
|
async def test_exit_command(self, test_agent, monkeypatch):
|
||||||
"""Test exit command handling"""
|
"""Test exit command handling"""
|
||||||
|
|
||||||
async def exit_model_response(*args, **kwargs):
|
async def exit_model_response(*_args, **_kwargs):
|
||||||
yield Msg(
|
yield Msg(
|
||||||
name="Friday",
|
name="Friday",
|
||||||
content="exit",
|
content="exit",
|
||||||
role="assistant"
|
role="assistant",
|
||||||
)
|
)
|
||||||
|
|
||||||
test_agent.model.side_effect = exit_model_response
|
test_agent.model.side_effect = exit_model_response
|
||||||
|
|
||||||
monkeypatch.setattr('builtins.input', lambda _: "exit")
|
monkeypatch.setattr("builtins.input", lambda _: "exit")
|
||||||
|
|
||||||
msg = Msg(name="User", content="exit", role="user")
|
msg = Msg(name="User", content="exit", role="user")
|
||||||
response = await test_agent(msg)
|
response = await test_agent(msg)
|
||||||
@@ -66,11 +69,13 @@ class TestReActAgent:
|
|||||||
async def test_conversation_flow(self, monkeypatch):
|
async def test_conversation_flow(self, monkeypatch):
|
||||||
"""Test full conversation flow"""
|
"""Test full conversation flow"""
|
||||||
|
|
||||||
async def model_response(*args, **kwargs):
|
async def model_response(*_args, **_kwargs):
|
||||||
yield Msg(
|
yield Msg(
|
||||||
name="Friday",
|
name="Friday",
|
||||||
content="Thought: I need to use a tool\nAction: execute_shell_command\nAction Input: echo 'Hello World'",
|
content="Thought: I need to use a tool\n"
|
||||||
role="assistant"
|
"Action: execute_shell_command\n"
|
||||||
|
"Action Input: echo 'Hello World'",
|
||||||
|
role="assistant",
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_model = AsyncMock()
|
mock_model = AsyncMock()
|
||||||
@@ -88,11 +93,11 @@ class TestReActAgent:
|
|||||||
model=mock_model,
|
model=mock_model,
|
||||||
formatter=mock_formatter,
|
formatter=mock_formatter,
|
||||||
toolkit=Toolkit(),
|
toolkit=Toolkit(),
|
||||||
memory=mock_memory
|
memory=mock_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
monkeypatch.setattr('builtins.input', lambda _: "Test command")
|
monkeypatch.setattr("builtins.input", lambda _: "Test command")
|
||||||
|
|
||||||
msg = Msg(name="User", content="Test command", role="user")
|
msg = Msg(name="User", content="Test command", role="user")
|
||||||
response = await agent(msg)
|
response = await agent(msg)
|
||||||
assert "Thought:" in response.content
|
assert "Thought:" in response.content
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
# tests/evaluation_test.py
|
# -*- coding: utf-8 -*-
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
from unittest.mock import Mock, AsyncMock, patch
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
from typing import List, Dict, Any, Tuple, Callable
|
from typing import List, Dict, Any, Tuple, Callable
|
||||||
@@ -29,19 +28,21 @@ class TestReActAgentSolution:
|
|||||||
def mock_pre_hook(self) -> Mock:
|
def mock_pre_hook(self) -> Mock:
|
||||||
"""Create a mock pre-hook function that returns None"""
|
"""Create a mock pre-hook function that returns None"""
|
||||||
|
|
||||||
def pre_hook_return(*args, **kwargs):
|
def pre_hook_return():
|
||||||
"""Mock function that returns None (no modifications)"""
|
"""Mock function that returns None (no modifications)"""
|
||||||
return None
|
return None
|
||||||
|
|
||||||
mock = Mock()
|
mock = Mock()
|
||||||
mock.__name__ = "save_logging"
|
mock.__name__ = "save_logging"
|
||||||
mock.side_effect = pre_hook_return # ✅ Return None to avoid parameter pollution
|
mock.side_effect = (
|
||||||
|
pre_hook_return # ✅ Return None to avoid parameter pollution
|
||||||
|
)
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
def _create_mock_tools(self) -> List[Tuple[Callable, Dict[str, Any]]]:
|
def _create_mock_tools(self) -> List[Tuple[Callable, Dict[str, Any]]]:
|
||||||
"""Create mock tool functions with schemas"""
|
"""Create mock tool functions with schemas"""
|
||||||
|
|
||||||
def mock_tool(*args, **kwargs):
|
def mock_tool():
|
||||||
return "tool_response"
|
return "tool_response"
|
||||||
|
|
||||||
tool_schema = {
|
tool_schema = {
|
||||||
@@ -110,8 +111,15 @@ class TestMainFunction:
|
|||||||
mock_evaluator_class.return_value = mock_evaluator
|
mock_evaluator_class.return_value = mock_evaluator
|
||||||
|
|
||||||
# ✅ Simulate _download_data and _load_data
|
# ✅ Simulate _download_data and _load_data
|
||||||
with patch("agentscope.evaluate._ace_benchmark._ace_benchmark.ACEBenchmark._download_data"):
|
with patch(
|
||||||
with patch("agentscope.evaluate._ace_benchmark._ace_benchmark.ACEBenchmark._load_data", return_value=[]):
|
"agentscope.evaluate._ace_benchmark."
|
||||||
|
"_ace_benchmark.ACEBenchmark._download_data",
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"agentscope.evaluate._ace_benchmark."
|
||||||
|
"_ace_benchmark.ACEBenchmark._load_data",
|
||||||
|
return_value=[],
|
||||||
|
):
|
||||||
# Run main function
|
# Run main function
|
||||||
await ace_main.main()
|
await ace_main.main()
|
||||||
|
|
||||||
@@ -137,12 +145,19 @@ class TestMainFunction:
|
|||||||
mock_evaluator_class.return_value = mock_evaluator
|
mock_evaluator_class.return_value = mock_evaluator
|
||||||
|
|
||||||
# ✅ Simulate _download_data and _load_data
|
# ✅ Simulate _download_data and _load_data
|
||||||
with patch("agentscope.evaluate._ace_benchmark._ace_benchmark.ACEBenchmark._download_data"):
|
with patch(
|
||||||
with patch("agentscope.evaluate._ace_benchmark._ace_benchmark.ACEBenchmark._load_data", return_value=[]):
|
"agentscope.evaluate._ace_benchmark._ace_benchmark."
|
||||||
|
"ACEBenchmark._download_data",
|
||||||
|
):
|
||||||
|
with patch(
|
||||||
|
"agentscope.evaluate._ace_benchmark."
|
||||||
|
"_ace_benchmark.ACEBenchmark._load_data",
|
||||||
|
return_value=[],
|
||||||
|
):
|
||||||
# Run main function
|
# Run main function
|
||||||
await ace_main.main()
|
await ace_main.main()
|
||||||
|
|
||||||
# Verify evaluation execution
|
# Verify evaluation execution
|
||||||
mock_evaluator.run.assert_called_once_with(
|
mock_evaluator.run.assert_called_once_with(
|
||||||
ace_main.react_agent_solution,
|
ace_main.react_agent_solution,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import os
|
|
||||||
import asyncio
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import AsyncMock, patch, MagicMock
|
from unittest.mock import AsyncMock, patch, MagicMock
|
||||||
|
import pytest
|
||||||
from agentscope.agent import ReActAgent
|
from agentscope.agent import ReActAgent
|
||||||
from agentscope.model import ChatModelBase
|
from agentscope.model import ChatModelBase
|
||||||
from agentscope.formatter import FormatterBase
|
from agentscope.formatter import FormatterBase
|
||||||
@@ -47,9 +45,12 @@ async def test_witch_resurrect() -> None:
|
|||||||
async def mock_model(**kwargs):
|
async def mock_model(**kwargs):
|
||||||
return {"resurrect": kwargs.get("resurrect", False)}
|
return {"resurrect": kwargs.get("resurrect", False)}
|
||||||
|
|
||||||
with patch("games.game_werewolves.game.WitchResurrectModel", side_effect=mock_model):
|
with patch(
|
||||||
|
"games.game_werewolves.game.WitchResurrectModel",
|
||||||
|
side_effect=mock_model,
|
||||||
|
):
|
||||||
result = await game.WitchResurrectModel(**{"resurrect": True})
|
result = await game.WitchResurrectModel(**{"resurrect": True})
|
||||||
assert result["resurrect"] == True
|
assert result["resurrect"] is True
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
@@ -84,8 +85,9 @@ def test_vote_model_generation() -> None:
|
|||||||
name=f"Player{i}",
|
name=f"Player{i}",
|
||||||
sys_prompt=f"Vote system prompt {i}",
|
sys_prompt=f"Vote system prompt {i}",
|
||||||
model=mock_model,
|
model=mock_model,
|
||||||
formatter=mock_formatter
|
formatter=mock_formatter,
|
||||||
) for i in range(3)
|
)
|
||||||
|
for i in range(3)
|
||||||
]
|
]
|
||||||
|
|
||||||
VoteModel = structured_model.get_vote_model(agents)
|
VoteModel = structured_model.get_vote_model(agents)
|
||||||
@@ -105,10 +107,10 @@ def test_witch_poison_model_fields() -> None:
|
|||||||
name="Player1",
|
name="Player1",
|
||||||
sys_prompt="Poison system prompt",
|
sys_prompt="Poison system prompt",
|
||||||
model=mock_model,
|
model=mock_model,
|
||||||
formatter=mock_formatter
|
formatter=mock_formatter,
|
||||||
)
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
PoisonModel = structured_model.get_poison_model(agents)
|
PoisonModel = structured_model.get_poison_model(agents)
|
||||||
assert "poison" in PoisonModel.model_fields
|
assert "poison" in PoisonModel.model_fields
|
||||||
assert "name" in PoisonModel.model_fields
|
assert "name" in PoisonModel.model_fields
|
||||||
|
|||||||
Reference in New Issue
Block a user