Initial commit: Pixel AI comic/video creation platform

- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents
- Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow
- Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.)
- All HTTP clients migrated from sync requests to async httpx
- Admin-managed API keys via environment variables
- SSRF vulnerability fixed in ensure_url()
This commit is contained in:
张鹏
2026-04-29 01:20:12 +08:00
commit f9f4560459
808 changed files with 151724 additions and 0 deletions

View File

@@ -0,0 +1,348 @@
"""
测试 BaseRepository 功能
验证通用仓储模式的 CRUD 操作、过滤、排序和分页功能
"""
import pytest
from datetime import datetime
from sqlmodel import Session, create_engine, SQLModel
from sqlalchemy.pool import StaticPool
from src.models.entities import ProjectDB, TaskDB
from src.repositories.base_repository import BaseRepository
from src.repositories.task_repository import TaskRepository
@pytest.fixture
def engine():
"""创建内存数据库引擎用于测试"""
engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
SQLModel.metadata.create_all(engine)
return engine
@pytest.fixture
def session(engine):
"""创建数据库会话"""
with Session(engine) as session:
yield session
@pytest.fixture
def task_repository(session):
"""创建 TaskRepository 实例"""
return TaskRepository(session)
@pytest.fixture
def sample_tasks(session):
"""创建示例任务数据"""
tasks = [
TaskDB(
id=f"task_{i}",
type="image" if i % 2 == 0 else "video",
status="pending" if i < 3 else "processing",
model="flux-dev",
params={"prompt": f"test prompt {i}"},
user_id="user_1" if i < 5 else "user_2",
project_id="project_1",
created_at=datetime.now().timestamp() + i,
updated_at=datetime.now().timestamp() + i
)
for i in range(10)
]
for task in tasks:
session.add(task)
session.commit()
return tasks
class TestBaseRepositoryCRUD:
"""测试基础 CRUD 操作"""
def test_create(self, task_repository, session):
"""测试创建记录"""
task = TaskDB(
id="test_task",
type="image",
status="pending",
model="flux-dev",
params={"prompt": "test"},
created_at=datetime.now().timestamp(),
updated_at=datetime.now().timestamp()
)
created = task_repository.create(task)
assert created.id == "test_task"
assert created.type == "image"
assert created.status == "pending"
def test_get(self, task_repository, sample_tasks):
"""测试获取单条记录"""
task = task_repository.get("task_0")
assert task is not None
assert task.id == "task_0"
assert task.type == "image"
def test_get_not_found(self, task_repository):
"""测试获取不存在的记录"""
task = task_repository.get("nonexistent")
assert task is None
def test_get_by_field(self, task_repository, sample_tasks):
"""测试按字段获取记录"""
task = task_repository.get_by_field("user_id", "user_1")
assert task is not None
assert task.user_id == "user_1"
def test_update(self, task_repository, sample_tasks):
"""测试更新记录"""
task = task_repository.get("task_0")
task.status = "completed"
updated = task_repository.update(task)
assert updated.status == "completed"
# 验证更新已持久化
retrieved = task_repository.get("task_0")
assert retrieved.status == "completed"
def test_update_by_id(self, task_repository, sample_tasks):
"""测试按 ID 更新记录"""
updated = task_repository.update_by_id("task_0", {"status": "failed"})
assert updated is not None
assert updated.status == "failed"
def test_delete(self, task_repository, sample_tasks):
"""测试删除记录"""
result = task_repository.delete("task_0")
assert result is True
# 验证已删除
task = task_repository.get("task_0")
assert task is None
def test_delete_not_found(self, task_repository):
"""测试删除不存在的记录"""
result = task_repository.delete("nonexistent")
assert result is False
def test_soft_delete(self, task_repository, sample_tasks):
"""测试软删除"""
result = task_repository.soft_delete("task_0")
assert result is True
# 验证 deleted_at 已设置
task = task_repository.get("task_0")
assert task.deleted_at is not None
def test_exists(self, task_repository, sample_tasks):
"""测试检查记录是否存在"""
assert task_repository.exists("task_0") is True
assert task_repository.exists("nonexistent") is False
def test_exists_by_field(self, task_repository, sample_tasks):
"""测试按字段检查记录是否存在"""
assert task_repository.exists_by_field("user_id", "user_1") is True
assert task_repository.exists_by_field("user_id", "nonexistent") is False
class TestBaseRepositoryFiltering:
"""测试过滤功能"""
def test_list_with_simple_filter(self, task_repository, sample_tasks):
"""测试简单过滤"""
tasks = task_repository.list(filters={"type": "image"})
assert len(tasks) == 5
assert all(task.type == "image" for task in tasks)
def test_list_with_multiple_filters(self, task_repository, sample_tasks):
"""测试多个过滤条件"""
tasks = task_repository.list(filters={
"type": "image",
"status": "pending"
})
assert len(tasks) == 2 # task_0 和 task_2
assert all(task.type == "image" and task.status == "pending" for task in tasks)
def test_list_with_gt_filter(self, task_repository, sample_tasks):
"""测试大于过滤"""
# 获取前5个任务user_1
tasks = task_repository.list(filters={"user_id": "user_1"})
assert len(tasks) == 5
def test_list_with_in_filter(self, task_repository, sample_tasks):
"""测试 IN 过滤"""
tasks = task_repository.list(filters={
"status__in": ["pending", "processing"]
})
assert len(tasks) == 10
def test_list_with_like_filter(self, task_repository, sample_tasks):
"""测试 LIKE 过滤"""
# 注意SQLite 的 LIKE 是大小写不敏感的
tasks = task_repository.list(filters={"model__like": "%flux%"})
assert len(tasks) == 10
class TestBaseRepositorySorting:
"""测试排序功能"""
def test_list_with_asc_sort(self, task_repository, sample_tasks):
"""测试升序排序"""
tasks = task_repository.list(sort_by="created_at", sort_order="asc")
# 验证按创建时间升序
for i in range(len(tasks) - 1):
assert tasks[i].created_at <= tasks[i + 1].created_at
def test_list_with_desc_sort(self, task_repository, sample_tasks):
"""测试降序排序"""
tasks = task_repository.list(sort_by="created_at", sort_order="desc")
# 验证按创建时间降序
for i in range(len(tasks) - 1):
assert tasks[i].created_at >= tasks[i + 1].created_at
class TestBaseRepositoryPagination:
"""测试分页功能"""
def test_list_with_pagination(self, task_repository, sample_tasks):
"""测试分页"""
# 第一页
page1 = task_repository.list(skip=0, limit=3)
assert len(page1) == 3
# 第二页
page2 = task_repository.list(skip=3, limit=3)
assert len(page2) == 3
# 验证不重复
page1_ids = {task.id for task in page1}
page2_ids = {task.id for task in page2}
assert len(page1_ids & page2_ids) == 0
def test_list_paginated(self, task_repository, sample_tasks):
"""测试分页方法"""
records, total = task_repository.list_paginated(page=1, page_size=3)
assert len(records) == 3
assert total == 10
# 第二页
records, total = task_repository.list_paginated(page=2, page_size=3)
assert len(records) == 3
assert total == 10
def test_count(self, task_repository, sample_tasks):
"""测试计数"""
total = task_repository.count()
assert total == 10
# 带过滤的计数
count = task_repository.count(filters={"type": "image"})
assert count == 5
class TestBaseRepositoryBatchOperations:
"""测试批量操作"""
def test_create_many(self, task_repository, session):
"""测试批量创建"""
tasks = [
TaskDB(
id=f"batch_task_{i}",
type="image",
status="pending",
model="flux-dev",
params={},
created_at=datetime.now().timestamp(),
updated_at=datetime.now().timestamp()
)
for i in range(5)
]
created = task_repository.create_many(tasks)
assert len(created) == 5
# 验证已创建
for task in created:
retrieved = task_repository.get(task.id)
assert retrieved is not None
class TestTaskRepositorySpecific:
"""测试 TaskRepository 特定方法"""
def test_list_by_status(self, task_repository, sample_tasks):
"""测试按状态列出任务"""
tasks = task_repository.list_by_status("pending")
assert len(tasks) == 3
assert all(task.status == "pending" for task in tasks)
def test_list_by_user(self, task_repository, sample_tasks):
"""测试按用户列出任务"""
tasks = task_repository.list_by_user("user_1")
assert len(tasks) == 5
assert all(task.user_id == "user_1" for task in tasks)
def test_list_by_project(self, task_repository, sample_tasks):
"""测试按项目列出任务"""
tasks = task_repository.list_by_project("project_1")
assert len(tasks) == 10
assert all(task.project_id == "project_1" for task in tasks)
def test_count_by_status(self, task_repository, sample_tasks):
"""测试按状态计数"""
count = task_repository.count_by_status("pending")
assert count == 3
def test_count_by_user(self, task_repository, sample_tasks):
"""测试按用户计数"""
count = task_repository.count_by_user("user_1")
assert count == 5
class TestQueryPerformanceTracking:
"""测试查询性能跟踪"""
def test_query_stats_tracking(self, task_repository, sample_tasks):
"""测试查询统计跟踪"""
# 执行一些查询
task_repository.list(limit=5)
task_repository.get("task_0")
task_repository.count()
# 获取统计信息
stats = task_repository.get_query_stats()
assert stats["query_count"] >= 3
assert stats["total_query_time"] > 0
assert stats["avg_query_time"] > 0