Initial commit: Pixel AI comic/video creation platform
- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents - Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow - Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.) - All HTTP clients migrated from sync requests to async httpx - Admin-managed API keys via environment variables - SSRF vulnerability fixed in ensure_url()
This commit is contained in:
1
backend/src/__init__.py
Normal file
1
backend/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Backend application package."""
|
||||
33
backend/src/admin_config.py
Normal file
33
backend/src/admin_config.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from sqladmin import Admin, ModelView
|
||||
from src.models.entities import ProjectDB, AssetDB, EpisodeDB, StoryboardDB
|
||||
|
||||
class ProjectAdmin(ModelView, model=ProjectDB):
|
||||
column_list = [ProjectDB.id, ProjectDB.name, ProjectDB.type, ProjectDB.status, ProjectDB.created_at]
|
||||
column_searchable_list = [ProjectDB.name, ProjectDB.id]
|
||||
column_sortable_list = [ProjectDB.created_at]
|
||||
icon = "fa-solid fa-diagram-project"
|
||||
|
||||
class AssetAdmin(ModelView, model=AssetDB):
|
||||
column_list = [AssetDB.id, AssetDB.name, AssetDB.type, AssetDB.project_id]
|
||||
column_searchable_list = [AssetDB.name, AssetDB.id, AssetDB.type]
|
||||
# 列_filters = ["type", "project_id"]
|
||||
# 列_filters = [AssetDB.type, AssetDB.project_id]
|
||||
icon = "fa-solid fa-cube"
|
||||
|
||||
class EpisodeAdmin(ModelView, model=EpisodeDB):
|
||||
column_list = [EpisodeDB.id, EpisodeDB.title, EpisodeDB.order_index, EpisodeDB.status, EpisodeDB.project_id]
|
||||
column_searchable_list = [EpisodeDB.title, EpisodeDB.id]
|
||||
# 列_filters = [EpisodeDB.status, EpisodeDB.project_id]
|
||||
icon = "fa-solid fa-film"
|
||||
|
||||
class StoryboardAdmin(ModelView, model=StoryboardDB):
|
||||
column_list = [StoryboardDB.id, StoryboardDB.project_id, StoryboardDB.episode_id, StoryboardDB.order_index]
|
||||
# 列_filters = [StoryboardDB.project_id, StoryboardDB.episode_id]
|
||||
icon = "fa-solid fa-image"
|
||||
|
||||
def setup_admin(app, engine):
|
||||
admin = Admin(app, engine, title="Pixel管理后台")
|
||||
admin.add_view(ProjectAdmin)
|
||||
admin.add_view(AssetAdmin)
|
||||
admin.add_view(EpisodeAdmin)
|
||||
admin.add_view(StoryboardAdmin)
|
||||
32
backend/src/api/admin/__init__.py
Normal file
32
backend/src/api/admin/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Admin API Package
|
||||
|
||||
管理 API 模块,包含:
|
||||
- dashboard: 仪表板统计和系统资源路由
|
||||
- users: 用户管理路由
|
||||
- projects: 项目管理路由
|
||||
- tasks: 任务管理路由
|
||||
- settings: 系统设置路由
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .dashboard import router as dashboard_router
|
||||
from .users import router as users_router
|
||||
from .projects import router as projects_router
|
||||
from .tasks import router as tasks_router
|
||||
from .settings import router as settings_router
|
||||
|
||||
# 创建主路由器
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
# 包含所有子路由
|
||||
router.include_router(dashboard_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(projects_router)
|
||||
router.include_router(tasks_router)
|
||||
router.include_router(settings_router)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
]
|
||||
85
backend/src/api/admin/dashboard.py
Normal file
85
backend/src/api/admin/dashboard.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Admin API - Dashboard Routes
|
||||
|
||||
包含仪表板统计和系统资源相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.admin_service import admin_service
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/dashboard", tags=["admin-dashboard"])
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ResponseModel)
|
||||
async def get_dashboard_stats(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get dashboard statistics
|
||||
|
||||
Returns counts of users, projects, and tasks by status.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
stats = await admin_service.get_dashboard_stats()
|
||||
return ResponseModel(data=stats.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting dashboard stats: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/system", response_model=ResponseModel)
|
||||
async def get_system_resources(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get system resource information
|
||||
|
||||
Returns CPU, memory, disk usage and uptime.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
resources = await admin_service.get_system_resources()
|
||||
return ResponseModel(data=resources.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system resources: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/activity", response_model=ResponseModel)
|
||||
async def get_recent_activity(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get recent system activity
|
||||
|
||||
Returns recent user, project, and task activities.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
activity = await admin_service.get_recent_activity(limit=limit)
|
||||
return ResponseModel(data=activity.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting recent activity: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
147
backend/src/api/admin/projects.py
Normal file
147
backend/src/api/admin/projects.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Admin API - Projects Routes
|
||||
|
||||
包含项目管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.admin_service import admin_service
|
||||
from src.utils.pagination import Paginator
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/projects", tags=["admin-projects"])
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_projects(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="Page number, starting from 1"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="Items per page"),
|
||||
sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
|
||||
filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
List all projects with pagination
|
||||
|
||||
Supports filtering by status, type, and search by name.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Parse filters
|
||||
filters = {}
|
||||
if filter:
|
||||
import json
|
||||
try:
|
||||
filters = json.loads(filter)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Parse sort
|
||||
sort_by = "created_at"
|
||||
sort_order = "desc"
|
||||
if sort:
|
||||
parts = sort.split(":")
|
||||
if len(parts) == 2:
|
||||
sort_by = parts[0]
|
||||
sort_order = parts[1]
|
||||
|
||||
items, total = await admin_service.list_projects(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filters=filters if filters else None,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=[item.model_dump(by_alias=True) for item in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing projects: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ResponseModel)
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get project details by ID
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
project = await admin_service.get_project(project_id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Project not found",
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=project.model_dump())
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting project: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{project_id}", response_model=ResponseModel)
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Delete a project
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
success = await admin_service.delete_project(project_id)
|
||||
if not success:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Project not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": project_id,
|
||||
"deleted": True,
|
||||
"message": "Project deleted successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting project: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
81
backend/src/api/admin/settings.py
Normal file
81
backend/src/api/admin/settings.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Admin API - Settings Routes
|
||||
|
||||
包含系统设置相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.models.admin_schemas import SystemSettingUpdateRequest
|
||||
from src.services.admin_service import admin_service
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/settings", tags=["admin-settings"])
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def get_system_settings(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get system settings
|
||||
|
||||
Returns all configurable system settings.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
settings = await admin_service.get_system_settings()
|
||||
return ResponseModel(data=settings.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system settings: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{key}", response_model=ResponseModel)
|
||||
async def update_system_setting(
|
||||
key: str,
|
||||
request: SystemSettingUpdateRequest,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Update a system setting
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
setting = await admin_service.update_system_setting(key, request.value)
|
||||
if not setting:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Setting not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"key": key,
|
||||
"value": request.value,
|
||||
"updated_at": setting.updated_at,
|
||||
"message": "Setting updated successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating system setting: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
443
backend/src/api/admin/tasks.py
Normal file
443
backend/src/api/admin/tasks.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Admin API - Tasks Routes
|
||||
|
||||
包含任务管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.models.admin_schemas import SystemSettingUpdateRequest
|
||||
from src.services.admin_service import admin_service
|
||||
from src.utils.pagination import Paginator
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/tasks", tags=["admin-tasks"])
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_tasks(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="Page number, starting from 1"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="Items per page"),
|
||||
sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
|
||||
filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
List all tasks with pagination
|
||||
|
||||
Supports filtering by status, type, provider, user_id, and project_id.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Parse filters
|
||||
filters = {}
|
||||
if filter:
|
||||
import json
|
||||
try:
|
||||
filters = json.loads(filter)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Parse sort
|
||||
sort_by = "created_at"
|
||||
sort_order = "desc"
|
||||
if sort:
|
||||
parts = sort.split(":")
|
||||
if len(parts) == 2:
|
||||
sort_by = parts[0]
|
||||
sort_order = parts[1]
|
||||
|
||||
items, total = await admin_service.list_tasks(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filters=filters if filters else None,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=[item.model_dump() for item in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ResponseModel)
|
||||
async def get_task_stats(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get task statistics
|
||||
|
||||
Returns counts by status, type, provider, and success rate.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
stats = await admin_service.get_task_stats()
|
||||
return ResponseModel(data=stats.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task stats: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/queue", response_model=ResponseModel)
|
||||
async def get_task_queue_status(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get task queue status
|
||||
|
||||
Returns queue length, processing count, and worker status.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
status = await admin_service.get_task_queue_status()
|
||||
return ResponseModel(data=status.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task queue status: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=ResponseModel)
|
||||
async def get_task_detail(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get task detail by ID
|
||||
|
||||
Returns detailed task information including user and project details.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
task = await admin_service.get_task_detail(task_id)
|
||||
if not task:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Task not found",
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=task.model_dump())
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task detail: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{task_id}/retry", response_model=ResponseModel)
|
||||
async def retry_task(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Retry a failed task
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
task = await admin_service.retry_task(task_id)
|
||||
if not task:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Task not found or cannot be retried",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": task_id,
|
||||
"status": task.status,
|
||||
"message": "Task queued for retry",
|
||||
"retry_count": task.retry_count,
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrying task: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{task_id}", response_model=ResponseModel)
|
||||
async def delete_task(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Delete a task
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
success = await admin_service.delete_task(task_id)
|
||||
if not success:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Task not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": task_id,
|
||||
"deleted": True,
|
||||
"message": "Task deleted successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting task: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/batch-retry", response_model=ResponseModel)
|
||||
async def batch_retry_tasks(
|
||||
task_ids: List[str],
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Batch retry failed tasks
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import TaskDB
|
||||
from datetime import datetime
|
||||
|
||||
retried = []
|
||||
failed = []
|
||||
|
||||
with Session(engine) as session:
|
||||
for task_id in task_ids:
|
||||
task = session.get(TaskDB, task_id)
|
||||
if not task:
|
||||
failed.append({"id": task_id, "reason": "Task not found"})
|
||||
continue
|
||||
|
||||
if task.status not in ["failed", "timeout"]:
|
||||
failed.append({"id": task_id, "reason": f"Cannot retry task with status: {task.status}"})
|
||||
continue
|
||||
|
||||
task.status = "pending"
|
||||
task.retry_count = 0
|
||||
task.error = None
|
||||
task.updated_at = datetime.now().timestamp()
|
||||
session.add(task)
|
||||
retried.append(task_id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"retried": retried,
|
||||
"failed": failed,
|
||||
"total": len(task_ids),
|
||||
"success_count": len(retried),
|
||||
"message": f"Successfully queued {len(retried)} tasks for retry",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error batch retrying tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/batch-cancel", response_model=ResponseModel)
|
||||
async def batch_cancel_tasks(
|
||||
task_ids: List[str],
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Batch cancel pending/processing tasks
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import TaskDB
|
||||
from datetime import datetime
|
||||
|
||||
cancelled = []
|
||||
failed = []
|
||||
|
||||
with Session(engine) as session:
|
||||
for task_id in task_ids:
|
||||
task = session.get(TaskDB, task_id)
|
||||
if not task:
|
||||
failed.append({"id": task_id, "reason": "Task not found"})
|
||||
continue
|
||||
|
||||
if task.status not in ["pending", "processing"]:
|
||||
failed.append({"id": task_id, "reason": f"Cannot cancel task with status: {task.status}"})
|
||||
continue
|
||||
|
||||
task.status = "cancelled"
|
||||
task.updated_at = datetime.now().timestamp()
|
||||
session.add(task)
|
||||
cancelled.append(task_id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"cancelled": cancelled,
|
||||
"failed": failed,
|
||||
"total": len(task_ids),
|
||||
"success_count": len(cancelled),
|
||||
"message": f"Successfully cancelled {len(cancelled)} tasks",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error batch cancelling tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/batch-delete", response_model=ResponseModel)
|
||||
async def batch_delete_tasks(
|
||||
task_ids: List[str],
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Batch delete tasks
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import TaskDB
|
||||
|
||||
deleted = []
|
||||
failed = []
|
||||
|
||||
with Session(engine) as session:
|
||||
for task_id in task_ids:
|
||||
task = session.get(TaskDB, task_id)
|
||||
if not task:
|
||||
failed.append({"id": task_id, "reason": "Task not found"})
|
||||
continue
|
||||
|
||||
session.delete(task)
|
||||
deleted.append(task_id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"deleted": deleted,
|
||||
"failed": failed,
|
||||
"total": len(task_ids),
|
||||
"success_count": len(deleted),
|
||||
"message": f"Successfully deleted {len(deleted)} tasks",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error batch deleting tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cleanup-completed", response_model=ResponseModel)
|
||||
async def cleanup_completed_tasks(
|
||||
days: int = Query(30, ge=1, le=365, description="清理多少天前的已完成任务"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Cleanup completed tasks older than specified days
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import TaskDB
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
cutoff_timestamp = time.time() - (days * 24 * 60 * 60)
|
||||
|
||||
with Session(engine) as session:
|
||||
# Find completed tasks older than cutoff
|
||||
old_tasks = session.exec(
|
||||
select(TaskDB).where(
|
||||
TaskDB.status == "success",
|
||||
TaskDB.completed_at < cutoff_timestamp
|
||||
)
|
||||
).all()
|
||||
|
||||
deleted_count = 0
|
||||
for task in old_tasks:
|
||||
session.delete(task)
|
||||
deleted_count += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"deleted_count": deleted_count,
|
||||
"days": days,
|
||||
"cutoff_date": datetime.fromtimestamp(cutoff_timestamp).isoformat(),
|
||||
"message": f"Successfully deleted {deleted_count} completed tasks older than {days} days",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up completed tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
317
backend/src/api/admin/users.py
Normal file
317
backend/src/api/admin/users.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Admin API - Users Routes
|
||||
|
||||
包含用户管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel, PaginationParams
|
||||
from src.models.admin_schemas import (
|
||||
AdminUserCreateRequest,
|
||||
AdminUserUpdateRequest,
|
||||
)
|
||||
from src.services.admin_service import admin_service
|
||||
from src.services.user_service import user_service
|
||||
from src.utils.pagination import Paginator
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["admin-users"])
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_users(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="Page number, starting from 1"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="Items per page"),
|
||||
sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
|
||||
filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
List all users with pagination
|
||||
|
||||
Supports filtering by is_active, is_superuser, and search by username/email.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Parse filters
|
||||
filters = {}
|
||||
if filter:
|
||||
import json
|
||||
try:
|
||||
filters = json.loads(filter)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Parse sort
|
||||
sort_by = "created_at"
|
||||
sort_order = "desc"
|
||||
if sort:
|
||||
parts = sort.split(":")
|
||||
if len(parts) == 2:
|
||||
sort_by = parts[0]
|
||||
sort_order = parts[1]
|
||||
|
||||
items, total = await admin_service.list_users(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filters=filters if filters else None,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=[item.model_dump() for item in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ResponseModel)
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get user details by ID
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
user = await admin_service.get_user(user_id)
|
||||
if not user:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="User not found",
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=user.model_dump())
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ResponseModel)
|
||||
async def create_user(
|
||||
request: AdminUserCreateRequest,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Create a new user
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Check if username already exists
|
||||
existing_user = await user_service.get_user_by_username(request.username)
|
||||
if existing_user:
|
||||
raise AppException(
|
||||
code=ErrorCode.CONFLICT,
|
||||
message="Username already exists",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Check if email already exists
|
||||
existing_email = await user_service.get_user_by_email(request.email)
|
||||
if existing_email:
|
||||
raise AppException(
|
||||
code=ErrorCode.CONFLICT,
|
||||
message="Email already registered",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Create user using user_service
|
||||
user = await user_service.create_user(
|
||||
username=request.username,
|
||||
email=request.email,
|
||||
password=request.password
|
||||
)
|
||||
|
||||
# Update additional fields if provided
|
||||
update_data = {}
|
||||
if request.is_active is not None:
|
||||
update_data["is_active"] = request.is_active
|
||||
if request.is_superuser is not None:
|
||||
update_data["is_superuser"] = request.is_superuser
|
||||
if request.roles:
|
||||
update_data["roles"] = request.roles
|
||||
if request.permissions:
|
||||
update_data["permissions"] = request.permissions
|
||||
|
||||
if update_data:
|
||||
user = await admin_service.update_user(user.id, update_data)
|
||||
|
||||
logger.info(f"Admin {current_user.username} created user: {user.username} ({user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
data=user.model_dump(),
|
||||
message="User created successfully"
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=ResponseModel)
|
||||
async def update_user(
|
||||
user_id: str,
|
||||
request: AdminUserUpdateRequest,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Update user information
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Prevent self-demotion from superuser
|
||||
if user_id == current_user.id and request.is_superuser is False:
|
||||
raise AppException(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Cannot remove your own superuser status",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
user = await admin_service.update_user(user_id, update_data)
|
||||
|
||||
if not user:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="User not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(data=user.model_dump())
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{user_id}/toggle-active", response_model=ResponseModel)
|
||||
async def toggle_user_active(
|
||||
user_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Toggle user active status
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Prevent self-deactivation
|
||||
if user_id == current_user.id:
|
||||
raise AppException(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Cannot deactivate your own account",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
user = await admin_service.get_user(user_id)
|
||||
if not user:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="User not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
new_status = not user.is_active
|
||||
updated_user = await admin_service.toggle_user_active(user_id, new_status)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": user_id,
|
||||
"is_active": new_status,
|
||||
"message": f"User {'activated' if new_status else 'deactivated'} successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error toggling user active status: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ResponseModel)
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Delete a user
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Prevent self-deletion
|
||||
if user_id == current_user.id:
|
||||
raise AppException(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Cannot delete your own account",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
success = await admin_service.delete_user(user_id)
|
||||
if not success:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="User not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": user_id,
|
||||
"deleted": True,
|
||||
"message": "User deleted successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
206
backend/src/api/audit_logs.py
Normal file
206
backend/src/api/audit_logs.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Audit Log API
|
||||
|
||||
操作审计日志 API 端点。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.pagination import Paginator
|
||||
from src.services.audit_log_service import audit_log_service
|
||||
from src.models.audit_log import AuditLogDB
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(prefix="/admin/audit-logs", tags=["admin-audit-logs"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ===== 响应模型 =====
|
||||
|
||||
class AuditLogListItem(BaseModel):
|
||||
"""审计日志列表项"""
|
||||
id: str
|
||||
user_id: Optional[str]
|
||||
username: Optional[str]
|
||||
action: str
|
||||
resource_type: Optional[str]
|
||||
resource_id: Optional[str]
|
||||
ip_address: Optional[str]
|
||||
created_at: str # ISO format
|
||||
|
||||
|
||||
class AuditLogDetailResponse(BaseModel):
|
||||
"""审计日志详情"""
|
||||
id: str
|
||||
user_id: Optional[str]
|
||||
username: Optional[str]
|
||||
action: str
|
||||
resource_type: Optional[str]
|
||||
resource_id: Optional[str]
|
||||
ip_address: Optional[str]
|
||||
user_agent: Optional[str]
|
||||
details: Optional[Dict[str, Any]]
|
||||
created_at: str
|
||||
|
||||
|
||||
# ===== API 端点 =====
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_audit_logs(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
user_id: Optional[str] = Query(None, description="按用户 ID 过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
resource_id: Optional[str] = Query(None, description="按资源 ID 过滤"),
|
||||
start_date: Optional[str] = Query(None, description="开始日期 (ISO format)"),
|
||||
end_date: Optional[str] = Query(None, description="结束日期 (ISO format)"),
|
||||
sort_by: str = Query("created_at", description="排序字段"),
|
||||
sort_order: str = Query("desc", description="排序方向"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
列出审计日志
|
||||
|
||||
支持高级搜索:按用户、操作类型、资源类型、时间范围过滤。
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
# 构建过滤条件
|
||||
filters = {}
|
||||
if user_id:
|
||||
filters["user_id"] = user_id
|
||||
if action:
|
||||
filters["action"] = action
|
||||
if resource_type:
|
||||
filters["resource_type"] = resource_type
|
||||
if resource_id:
|
||||
filters["resource_id"] = resource_id
|
||||
if start_date:
|
||||
filters["start_date"] = start_date
|
||||
if end_date:
|
||||
filters["end_date"] = end_date
|
||||
|
||||
logs, total = audit_log_service.list_logs(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filters=filters,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
# 转换为响应格式
|
||||
items = [
|
||||
AuditLogListItem(
|
||||
id=log.id,
|
||||
user_id=log.user_id,
|
||||
username=log.username,
|
||||
action=log.action,
|
||||
resource_type=log.resource_type,
|
||||
resource_id=log.resource_id,
|
||||
ip_address=log.ip_address,
|
||||
created_at=datetime.fromtimestamp(log.created_at).isoformat(),
|
||||
).model_dump()
|
||||
for log in logs
|
||||
]
|
||||
|
||||
paginator = Paginator(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing audit logs: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{log_id}", response_model=ResponseModel)
|
||||
async def get_audit_log(
|
||||
log_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取审计日志详情
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
log = audit_log_service.get_log(log_id)
|
||||
if not log:
|
||||
raise AppException(
|
||||
message="Audit log not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data=AuditLogDetailResponse(
|
||||
id=log.id,
|
||||
user_id=log.user_id,
|
||||
username=log.username,
|
||||
action=log.action,
|
||||
resource_type=log.resource_type,
|
||||
resource_id=log.resource_id,
|
||||
ip_address=log.ip_address,
|
||||
user_agent=log.user_agent,
|
||||
details=log.details,
|
||||
created_at=datetime.fromtimestamp(log.created_at).isoformat(),
|
||||
).model_dump()
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit log: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/export/csv", response_model=ResponseModel)
|
||||
async def export_audit_logs(
|
||||
user_id: Optional[str] = Query(None, description="按用户 ID 过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
start_date: Optional[str] = Query(None, description="开始日期 (ISO format)"),
|
||||
end_date: Optional[str] = Query(None, description="结束日期 (ISO format)"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
导出审计日志为 CSV
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
# 构建过滤条件
|
||||
filters = {}
|
||||
if user_id:
|
||||
filters["user_id"] = user_id
|
||||
if action:
|
||||
filters["action"] = action
|
||||
if resource_type:
|
||||
filters["resource_type"] = resource_type
|
||||
if start_date:
|
||||
filters["start_date"] = start_date
|
||||
if end_date:
|
||||
filters["end_date"] = end_date
|
||||
|
||||
csv_content = audit_log_service.export_logs(filters=filters)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"content": csv_content,
|
||||
"filename": f"audit_logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting audit logs: {e}")
|
||||
raise
|
||||
927
backend/src/api/auth.py
Normal file
927
backend/src/api/auth.py
Normal file
@@ -0,0 +1,927 @@
|
||||
"""
|
||||
认证 API 路由
|
||||
|
||||
提供用户登录、注册、获取当前用户信息等接口。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import io
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, status, Depends, UploadFile, File, Request
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from PIL import Image
|
||||
|
||||
from src.auth.jwt import (
|
||||
create_token_pair,
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
verify_token,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
)
|
||||
from src.services.token_blacklist_service import token_blacklist_service
|
||||
from src.services.session_service import session_service
|
||||
from src.services.email_service import email_service
|
||||
from src.config.settings import REDIS_ENABLED, NODE_ENV
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
from src.auth.models import RefreshTokenRequest
|
||||
from src.services.user_service import user_service
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||
|
||||
|
||||
def _build_user_payload(user: UserAuth) -> dict:
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"is_active": user.is_active,
|
||||
"is_superuser": user.is_superuser,
|
||||
}
|
||||
|
||||
|
||||
def _extract_bearer_token(request: Request) -> Optional[str]:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:]
|
||||
return None
|
||||
|
||||
|
||||
async def _get_access_payload(request: Request):
|
||||
access_token = _extract_bearer_token(request)
|
||||
if not access_token:
|
||||
return None
|
||||
return await verify_token(access_token, token_type="access")
|
||||
|
||||
|
||||
def _serialize_session(session_db, current_session_id: Optional[str] = None) -> dict:
|
||||
return {
|
||||
"id": session_db.id,
|
||||
"session_family_id": session_db.session_family_id,
|
||||
"status": session_db.status,
|
||||
"device_name": session_db.device_name,
|
||||
"ip_address": session_db.ip_address,
|
||||
"user_agent": session_db.user_agent,
|
||||
"created_at": session_db.created_at,
|
||||
"updated_at": session_db.updated_at,
|
||||
"expires_at": session_db.expires_at,
|
||||
"last_used_at": session_db.last_used_at,
|
||||
"revoked_at": session_db.revoked_at,
|
||||
"revoked_reason": session_db.revoked_reason,
|
||||
"current": session_db.id == current_session_id,
|
||||
}
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求"""
|
||||
|
||||
username: str = Field(..., description="用户名或邮箱")
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""注册请求"""
|
||||
|
||||
username: str = Field(..., min_length=3, max_length=50, description="用户名")
|
||||
email: EmailStr = Field(..., description="邮箱")
|
||||
password: str = Field(..., min_length=6, description="密码")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Token 响应"""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES * 60 # 从配置读取
|
||||
|
||||
|
||||
class UserInfoResponse(BaseModel):
|
||||
"""用户信息响应"""
|
||||
|
||||
id: str
|
||||
username: str
|
||||
email: Optional[str]
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
|
||||
|
||||
@router.post("/login", response_model=ResponseModel)
|
||||
async def login(request: LoginRequest, http_request: Request):
|
||||
"""
|
||||
用户登录
|
||||
|
||||
支持使用用户名或邮箱登录。
|
||||
成功返回 access_token 和 refresh_token。
|
||||
"""
|
||||
try:
|
||||
# 使用 authenticate_user 方法验证用户(包含密码验证)
|
||||
user = await user_service.authenticate_user(request.username, request.password)
|
||||
|
||||
if not user:
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid credentials",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
# 检查用户是否激活
|
||||
if not user.is_active:
|
||||
raise AppException(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="User account is inactive",
|
||||
status_code=403
|
||||
)
|
||||
|
||||
# 更新最后登录时间
|
||||
await user_service.update_last_login(user.id)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
session_family_id = str(uuid.uuid4())
|
||||
tokens = create_token_pair(
|
||||
user_id=user.id,
|
||||
scopes=["user"],
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
)
|
||||
session_service.create_session(
|
||||
user_id=user.id,
|
||||
refresh_token=tokens.refresh_token,
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
ip_address=http_request.client.host if http_request.client else None,
|
||||
user_agent=http_request.headers.get("user-agent"),
|
||||
)
|
||||
|
||||
logger.info(f"User logged in: {user.username} ({user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
"session_id": session_id,
|
||||
"user": _build_user_payload(user),
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Login error: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="Internal server error",
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=ResponseModel)
|
||||
async def register(request: RegisterRequest, http_request: Request):
|
||||
"""
|
||||
用户注册
|
||||
|
||||
创建新用户账号,用户名和邮箱必须唯一。
|
||||
"""
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
existing_user = await user_service.get_user_by_username(request.username)
|
||||
if existing_user:
|
||||
raise AppException(
|
||||
code=ErrorCode.CONFLICT,
|
||||
message="Username already registered",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
existing_email = await user_service.get_user_by_email(request.email)
|
||||
if existing_email:
|
||||
raise AppException(
|
||||
code=ErrorCode.CONFLICT,
|
||||
message="Email already registered",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 创建用户
|
||||
user = await user_service.create_user(
|
||||
username=request.username, email=request.email, password=request.password
|
||||
)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
session_family_id = str(uuid.uuid4())
|
||||
tokens = create_token_pair(
|
||||
user_id=user.id,
|
||||
scopes=["user"],
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
)
|
||||
session_service.create_session(
|
||||
user_id=user.id,
|
||||
refresh_token=tokens.refresh_token,
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
ip_address=http_request.client.host if http_request.client else None,
|
||||
user_agent=http_request.headers.get("user-agent"),
|
||||
)
|
||||
|
||||
logger.info(f"User registered: {user.username} ({user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
"session_id": session_id,
|
||||
"user": _build_user_payload(user),
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Registration error: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="Internal server error",
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ResponseModel)
|
||||
async def refresh_token(request: RefreshTokenRequest, http_request: Request):
|
||||
"""
|
||||
刷新 access token
|
||||
|
||||
使用 refresh token 获取新的 access token。
|
||||
"""
|
||||
try:
|
||||
# 验证 refresh token(检查黑名单)
|
||||
payload = await verify_token(request.refresh_token, token_type="refresh")
|
||||
if not payload:
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid or revoked refresh token",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
if not payload.sid:
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Refresh token session is missing",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
user = await user_service.get_user_by_id(payload.sub)
|
||||
if not user or not user.is_active:
|
||||
if user:
|
||||
session_service.revoke_user_sessions(user.id, reason="inactive_user")
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="User session is no longer active",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
new_session_id = str(uuid.uuid4())
|
||||
tokens = create_token_pair(
|
||||
user_id=payload.sub,
|
||||
scopes=payload.scopes or ["user"],
|
||||
session_id=new_session_id,
|
||||
session_family_id=payload.sfid,
|
||||
)
|
||||
|
||||
rotated_session = session_service.rotate_refresh_token(
|
||||
payload.sid,
|
||||
request.refresh_token,
|
||||
tokens.refresh_token,
|
||||
new_session_id=new_session_id,
|
||||
ip_address=http_request.client.host if http_request.client else None,
|
||||
user_agent=http_request.headers.get("user-agent"),
|
||||
)
|
||||
if not rotated_session:
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid or replayed refresh token",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
user_data = None
|
||||
if user:
|
||||
user_data = _build_user_payload(user)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
"session_id": rotated_session.id,
|
||||
**({"user": user_data} if user_data else {}),
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh error: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="Internal server error",
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=ResponseModel)
|
||||
async def get_me(current_user: UserAuth = Depends(get_current_user)):
|
||||
"""
|
||||
获取当前登录用户信息
|
||||
|
||||
需要有效的 access token。
|
||||
"""
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": current_user.id,
|
||||
"username": current_user.username,
|
||||
"email": current_user.email,
|
||||
"avatar_url": current_user.avatar_url,
|
||||
"is_active": current_user.is_active,
|
||||
"is_superuser": current_user.is_superuser,
|
||||
"permissions": current_user.permissions,
|
||||
"roles": current_user.roles,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/session", response_model=ResponseModel)
|
||||
async def get_current_session(
|
||||
http_request: Request,
|
||||
current_user: UserAuth = Depends(get_current_user),
|
||||
):
|
||||
payload = await _get_access_payload(http_request)
|
||||
session_data = None
|
||||
|
||||
if payload and payload.sid:
|
||||
session_db = session_service.get_session(payload.sid)
|
||||
if session_db and session_db.user_id == current_user.id:
|
||||
session_data = _serialize_session(session_db, current_session_id=payload.sid)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"authenticated": True,
|
||||
"session_id": payload.sid if payload else None,
|
||||
"user": {
|
||||
"id": current_user.id,
|
||||
"username": current_user.username,
|
||||
"email": current_user.email,
|
||||
"avatar_url": current_user.avatar_url,
|
||||
"is_active": current_user.is_active,
|
||||
"is_superuser": current_user.is_superuser,
|
||||
"permissions": current_user.permissions,
|
||||
"roles": current_user.roles,
|
||||
},
|
||||
"session": session_data,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=ResponseModel)
|
||||
async def list_sessions(
|
||||
http_request: Request,
|
||||
include_inactive: bool = False,
|
||||
current_user: UserAuth = Depends(get_current_user),
|
||||
):
|
||||
payload = await _get_access_payload(http_request)
|
||||
current_session_id = payload.sid if payload else None
|
||||
sessions = session_service.list_user_sessions(
|
||||
current_user.id,
|
||||
include_inactive=include_inactive,
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"items": [
|
||||
_serialize_session(session_db, current_session_id=current_session_id)
|
||||
for session_db in sessions
|
||||
],
|
||||
"total": len(sessions),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}", response_model=ResponseModel)
|
||||
async def revoke_session(
|
||||
session_id: str,
|
||||
http_request: Request,
|
||||
current_user: UserAuth = Depends(get_current_user),
|
||||
):
|
||||
payload = await _get_access_payload(http_request)
|
||||
revoked = session_service.revoke_user_session(
|
||||
current_user.id,
|
||||
session_id,
|
||||
reason="user_revoke",
|
||||
)
|
||||
|
||||
if not revoked:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Session not found",
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
revoked_access = False
|
||||
if payload and payload.sid == session_id:
|
||||
access_token = _extract_bearer_token(http_request)
|
||||
if access_token:
|
||||
revoked_access = await token_blacklist_service.revoke_token(
|
||||
access_token,
|
||||
reason="user_revoke",
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="Session revoked successfully",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"revoked": True,
|
||||
"current": payload.sid == session_id if payload else False,
|
||||
"revoked_access": revoked_access,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=ResponseModel)
|
||||
async def logout(
|
||||
request: RefreshTokenRequest,
|
||||
http_request: Request,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
用户登出
|
||||
|
||||
将当前 access token 和 refresh token 加入黑名单,使其失效。
|
||||
客户端也需要清除本地存储的 token。
|
||||
"""
|
||||
try:
|
||||
access_token = _extract_bearer_token(http_request)
|
||||
revoked_access = False
|
||||
if access_token:
|
||||
revoked_access = await token_blacklist_service.revoke_token(
|
||||
access_token,
|
||||
reason="logout",
|
||||
)
|
||||
|
||||
# 撤销 refresh token(如果提供)
|
||||
revoked_refresh = False
|
||||
if request.refresh_token:
|
||||
refresh_payload = await verify_token(request.refresh_token, token_type="refresh")
|
||||
if refresh_payload and refresh_payload.sid:
|
||||
session_service.revoke_session(refresh_payload.sid, reason="logout")
|
||||
await token_blacklist_service.revoke_token(
|
||||
request.refresh_token,
|
||||
reason="logout"
|
||||
)
|
||||
revoked_refresh = True
|
||||
|
||||
logger.info(f"User logged out: {current_user.username} ({current_user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
message="Logged out successfully",
|
||||
data={
|
||||
"user_id": current_user.id,
|
||||
"revoked": revoked_access or revoked_refresh,
|
||||
"revoked_access": revoked_access,
|
||||
"revoked_refresh": revoked_refresh,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Logout error: {e}")
|
||||
# 即使撤销失败也返回成功,因为客户端应该清除本地 token
|
||||
return ResponseModel(message="Logged out successfully")
|
||||
|
||||
|
||||
class LogoutAllRequest(BaseModel):
|
||||
"""登出所有设备请求"""
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/logout-all", response_model=ResponseModel)
|
||||
async def logout_all_devices(
|
||||
request: LogoutAllRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
从所有设备登出
|
||||
|
||||
撤销该用户的所有 Token,使用户在所有设备上登出。
|
||||
"""
|
||||
try:
|
||||
# 撤销该用户所有 token
|
||||
await token_blacklist_service.revoke_all_user_tokens(
|
||||
current_user.id,
|
||||
reason="logout_all"
|
||||
)
|
||||
revoked_sessions = session_service.revoke_user_sessions(
|
||||
current_user.id,
|
||||
reason="logout_all",
|
||||
)
|
||||
|
||||
logger.info(f"User logged out from all devices: {current_user.username} ({current_user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
message="Logged out from all devices successfully",
|
||||
data={
|
||||
"user_id": current_user.id,
|
||||
"all_devices": True,
|
||||
"revoked_sessions": revoked_sessions,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Logout all error: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="Failed to logout from all devices",
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/avatar", response_model=ResponseModel)
|
||||
async def upload_avatar(
|
||||
file: UploadFile = File(...),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
上传用户头像
|
||||
|
||||
支持 JPG、PNG 格式,最大 5MB。
|
||||
头像会被裁剪为正方形并压缩至 256x256 像素。
|
||||
"""
|
||||
try:
|
||||
# 验证文件类型
|
||||
allowed_types = {'image/jpeg', 'image/png', 'image/jpg'}
|
||||
if file.content_type not in allowed_types:
|
||||
raise AppException(
|
||||
message="Only JPG and PNG images are allowed",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 读取文件内容
|
||||
contents = await file.read()
|
||||
max_size = 5 * 1024 * 1024 # 5MB
|
||||
if len(contents) > max_size:
|
||||
raise AppException(
|
||||
message="File size exceeds 5MB limit",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 使用 PIL 处理图片
|
||||
image = Image.open(io.BytesIO(contents))
|
||||
|
||||
# 转换为 RGB(处理RGBA图片)
|
||||
if image.mode in ('RGBA', 'LA', 'P'):
|
||||
background = Image.new('RGB', image.size, (255, 255, 255))
|
||||
if image.mode == 'P':
|
||||
image = image.convert('RGBA')
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
background.paste(image, mask=image.split()[-1] if image.mode in ('RGBA', 'LA') else None)
|
||||
image = background
|
||||
|
||||
# 裁剪为正方形(从中心)
|
||||
width, height = image.size
|
||||
min_dim = min(width, height)
|
||||
left = (width - min_dim) // 2
|
||||
top = (height - min_dim) // 2
|
||||
right = left + min_dim
|
||||
bottom = top + min_dim
|
||||
image = image.crop((left, top, right, bottom))
|
||||
|
||||
# 调整大小为 256x256
|
||||
image = image.resize((256, 256), Image.Resampling.LANCZOS)
|
||||
|
||||
# 保存为 PNG
|
||||
output = io.BytesIO()
|
||||
image.save(output, format='PNG', optimize=True)
|
||||
output.seek(0)
|
||||
|
||||
# 生成存储路径: avatars/{user_id}/{timestamp}.png
|
||||
timestamp = int(datetime.now().timestamp())
|
||||
storage_path = f"avatars/{current_user.id}/{timestamp}.png"
|
||||
|
||||
# 上传到存储
|
||||
avatar_url = storage_manager.save(storage_path, output.getvalue())
|
||||
|
||||
# 更新用户头像 URL
|
||||
updated_user = await user_service.update_user(
|
||||
current_user.id,
|
||||
avatar_url=avatar_url
|
||||
)
|
||||
|
||||
if not updated_user:
|
||||
raise AppException(
|
||||
message="Failed to update user avatar",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
logger.info(f"Avatar uploaded for user {current_user.username}: {avatar_url}")
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"avatar_url": avatar_url,
|
||||
"user": {
|
||||
"id": updated_user.id,
|
||||
"username": updated_user.username,
|
||||
"email": updated_user.email,
|
||||
"avatar_url": updated_user.avatar_url,
|
||||
"is_active": updated_user.is_active,
|
||||
"is_superuser": updated_user.is_superuser,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Avatar upload error: {e}")
|
||||
raise AppException(
|
||||
message=f"Failed to upload avatar: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
# ===== Password Reset Models =====
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
"""忘记密码请求"""
|
||||
email: EmailStr = Field(..., description="用户邮箱")
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
"""重置密码请求"""
|
||||
token: str = Field(..., description="重置令牌")
|
||||
new_password: str = Field(..., min_length=6, description="新密码")
|
||||
|
||||
|
||||
class VerifyResetTokenRequest(BaseModel):
|
||||
"""验证重置令牌请求"""
|
||||
token: str = Field(..., description="重置令牌")
|
||||
|
||||
|
||||
# Password reset token storage (using Redis if available, else memory)
|
||||
_reset_tokens = {} # In-memory fallback {token: {"user_id": str, "expires": float}}
|
||||
|
||||
|
||||
async def _store_reset_token(token: str, user_id: str, expires_in: int = 3600) -> None:
|
||||
"""Store reset token in Redis or memory"""
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from src.config.settings import REDIS_URL
|
||||
redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
await redis_client.setex(f"pwd_reset:{token}", expires_in, user_id)
|
||||
await redis_client.close()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis not available for reset token, using memory: {e}")
|
||||
|
||||
# Fallback to memory
|
||||
_reset_tokens[token] = {
|
||||
"user_id": user_id,
|
||||
"expires": datetime.now().timestamp() + expires_in
|
||||
}
|
||||
|
||||
|
||||
async def _get_reset_token_user_id(token: str) -> Optional[str]:
|
||||
"""Get user_id from reset token"""
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from src.config.settings import REDIS_URL
|
||||
redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
user_id = await redis_client.get(f"pwd_reset:{token}")
|
||||
await redis_client.close()
|
||||
return user_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to memory
|
||||
token_data = _reset_tokens.get(token)
|
||||
if token_data:
|
||||
if token_data["expires"] > datetime.now().timestamp():
|
||||
return token_data["user_id"]
|
||||
else:
|
||||
# Expired, remove it
|
||||
del _reset_tokens[token]
|
||||
return None
|
||||
|
||||
|
||||
async def _delete_reset_token(token: str) -> None:
|
||||
"""Delete reset token"""
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from src.config.settings import REDIS_URL
|
||||
redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
await redis_client.delete(f"pwd_reset:{token}")
|
||||
await redis_client.close()
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to memory
|
||||
if token in _reset_tokens:
|
||||
del _reset_tokens[token]
|
||||
|
||||
|
||||
def _generate_reset_token() -> str:
|
||||
"""Generate secure reset token"""
|
||||
import secrets
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@router.post("/forgot-password", response_model=ResponseModel)
|
||||
async def forgot_password(request: ForgotPasswordRequest):
|
||||
"""
|
||||
请求密码重置
|
||||
|
||||
发送密码重置邮件。
|
||||
令牌有效期为 1 小时。
|
||||
"""
|
||||
try:
|
||||
# 查找用户
|
||||
user = await user_service.get_user_by_email(request.email)
|
||||
if not user:
|
||||
# 为了安全,即使用户不存在也返回成功,不暴露邮箱是否存在
|
||||
logger.info(f"Password reset requested for non-existent email: {request.email}")
|
||||
return ResponseModel(
|
||||
message="If the email exists, a reset link has been sent"
|
||||
)
|
||||
|
||||
# 生成重置令牌
|
||||
token = _generate_reset_token()
|
||||
|
||||
# 存储令牌(1小时有效)
|
||||
await _store_reset_token(token, user.id, expires_in=3600)
|
||||
|
||||
# 发送重置邮件
|
||||
email_result = await email_service.send_password_reset(
|
||||
to_email=user.email,
|
||||
username=user.username,
|
||||
reset_token=token,
|
||||
expires_in=1
|
||||
)
|
||||
|
||||
if email_result["success"]:
|
||||
logger.info(f"Password reset email sent to {user.email}")
|
||||
return ResponseModel(
|
||||
message="Password reset email sent"
|
||||
)
|
||||
else:
|
||||
# 邮件发送失败,但在开发环境可以返回令牌
|
||||
logger.warning(f"Failed to send reset email: {email_result.get('error')}")
|
||||
|
||||
# 生产环境不返回令牌
|
||||
if NODE_ENV == "production":
|
||||
return ResponseModel(
|
||||
message="Failed to send reset email. Please try again later."
|
||||
)
|
||||
else:
|
||||
# 开发环境返回令牌以便测试
|
||||
return ResponseModel(
|
||||
data={
|
||||
"message": "Password reset email sent",
|
||||
"reset_token": token, # 开发环境返回,生产环境不应返回
|
||||
"expires_in": 3600,
|
||||
"email_error": email_result.get("error")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Forgot password error: {e}")
|
||||
raise AppException(
|
||||
message="Failed to process password reset request",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/verify-reset-token", response_model=ResponseModel)
|
||||
async def verify_reset_token(request: VerifyResetTokenRequest):
|
||||
"""
|
||||
验证密码重置令牌
|
||||
|
||||
检查令牌是否有效,用于前端重置密码页面验证。
|
||||
"""
|
||||
try:
|
||||
user_id = await _get_reset_token_user_id(request.token)
|
||||
|
||||
if not user_id:
|
||||
raise AppException(
|
||||
message="Invalid or expired reset token",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 获取用户信息
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise AppException(
|
||||
message="User not found",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"valid": True,
|
||||
"user_id": user_id,
|
||||
"email": user.email
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Verify reset token error: {e}")
|
||||
raise AppException(
|
||||
message="Failed to verify reset token",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reset-password", response_model=ResponseModel)
|
||||
async def reset_password(request: ResetPasswordRequest):
|
||||
"""
|
||||
重置密码
|
||||
|
||||
使用有效的重置令牌设置新密码,并撤销该用户所有 Token。
|
||||
"""
|
||||
try:
|
||||
# 验证令牌
|
||||
user_id = await _get_reset_token_user_id(request.token)
|
||||
|
||||
if not user_id:
|
||||
raise AppException(
|
||||
message="Invalid or expired reset token",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 更新密码
|
||||
updated_user = await user_service.update_user(
|
||||
user_id,
|
||||
password=request.new_password
|
||||
)
|
||||
|
||||
if not updated_user:
|
||||
raise AppException(
|
||||
message="User not found",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 删除已使用的令牌
|
||||
await _delete_reset_token(request.token)
|
||||
|
||||
# 撤销该用户所有 Token(强制重新登录)
|
||||
await token_blacklist_service.revoke_all_user_tokens(
|
||||
user_id,
|
||||
reason="password_reset"
|
||||
)
|
||||
|
||||
# 使该用户所有缓存失效
|
||||
await user_service.invalidate_user_cache(user_id)
|
||||
|
||||
logger.info(f"Password reset successful for user {updated_user.username}")
|
||||
|
||||
return ResponseModel(
|
||||
message="Password reset successful",
|
||||
data={
|
||||
"user_id": updated_user.id,
|
||||
"email": updated_user.email
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Reset password error: {e}")
|
||||
raise AppException(
|
||||
message="Failed to reset password",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
81
backend/src/api/canvas.py
Normal file
81
backend/src/api/canvas.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Query
|
||||
from sqlmodel import Session
|
||||
from src.config.database import engine
|
||||
from src.models.entities import CanvasDB
|
||||
from src.models.schemas import ResponseModel, CanvasState
|
||||
from src.mappers import CanvasMapper
|
||||
from src.utils.errors import BusinessException, ErrorCode
|
||||
|
||||
router = APIRouter(tags=["canvas"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@router.get("/canvas", response_model=ResponseModel)
|
||||
async def get_canvas_state(id: str = "default", projectId: str = Query(None)):
|
||||
"""Get canvas state by ID
|
||||
|
||||
Returns empty default state if canvas not found.
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to fetch canvas
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
canvas_db = session.get(CanvasDB, id)
|
||||
|
||||
if not canvas_db:
|
||||
# If not found, return empty default state
|
||||
return ResponseModel(data=CanvasState(id=id, projectId=projectId))
|
||||
|
||||
# Map to CanvasState
|
||||
canvas_state = CanvasState(
|
||||
id=canvas_db.id,
|
||||
projectId=canvas_db.project_id,
|
||||
nodes=canvas_db.nodes,
|
||||
connections=canvas_db.connections,
|
||||
groups=canvas_db.groups,
|
||||
history=canvas_db.history,
|
||||
historyIndex=canvas_db.history_index,
|
||||
updated_at=canvas_db.updated_at
|
||||
)
|
||||
return ResponseModel(data=canvas_state)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch canvas {id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.CANVAS_NOT_FOUND,
|
||||
"Failed to fetch canvas state",
|
||||
{"canvas_id": id, "reason": str(e)}
|
||||
)
|
||||
|
||||
@router.post("/canvas", response_model=ResponseModel)
|
||||
async def save_canvas_state(state: CanvasState):
|
||||
"""Save canvas state
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to save canvas
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
# Use merge to handle both insert and update gracefully
|
||||
canvas_db = CanvasDB(
|
||||
id=state.id,
|
||||
project_id=state.projectId,
|
||||
nodes=state.nodes,
|
||||
connections=state.connections,
|
||||
groups=state.groups,
|
||||
history=state.history,
|
||||
history_index=state.history_index,
|
||||
updated_at=state.updated_at
|
||||
)
|
||||
|
||||
session.merge(canvas_db)
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(data={"id": state.id, "updated": True})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save canvas {state.id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to save canvas state",
|
||||
{"canvas_id": state.id, "reason": str(e)}
|
||||
)
|
||||
428
backend/src/api/canvas_metadata.py
Normal file
428
backend/src/api/canvas_metadata.py
Normal file
@@ -0,0 +1,428 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Query, Depends
|
||||
from sqlmodel import Session
|
||||
from src.config.database import engine
|
||||
from src.services.canvas_metadata_service import CanvasMetadataService
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
CanvasMetadata,
|
||||
CreateGeneralCanvasRequest,
|
||||
UpdateCanvasMetadataRequest
|
||||
)
|
||||
from src.utils.errors import ResourceNotFoundException, BusinessException, ErrorCode
|
||||
|
||||
router = APIRouter(tags=["canvas_metadata"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_session():
|
||||
""" 依赖 to get database session"""
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/canvases", response_model=ResponseModel)
|
||||
async def list_project_canvases(
|
||||
project_id: str,
|
||||
canvas_type: Optional[str] = Query(None, description="Filter by type: general, asset, storyboard"),
|
||||
include_deleted: bool = Query(False, description="Include deleted canvases"),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" 查询 all canvases for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
canvas_type: Optional filter by canvas type (general, asset, storyboard)
|
||||
include_deleted: Whether to include soft-deleted canvases
|
||||
|
||||
Returns:
|
||||
List of canvas metadata
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to list canvases
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvases = service.list_canvases(project_id, canvas_type, include_deleted)
|
||||
|
||||
# 转换 to dict format
|
||||
canvas_data = [
|
||||
CanvasMetadata(
|
||||
id=c.id,
|
||||
projectId=c.project_id,
|
||||
canvasType=c.canvas_type,
|
||||
relatedEntityType=c.related_entity_type,
|
||||
relatedEntityId=c.related_entity_id,
|
||||
name=c.name,
|
||||
description=c.description,
|
||||
orderIndex=c.order_index,
|
||||
isPinned=c.is_pinned,
|
||||
tags=c.tags,
|
||||
nodeCount=c.node_count,
|
||||
lastAccessedAt=c.last_accessed_at,
|
||||
accessCount=c.access_count,
|
||||
createdAt=c.created_at,
|
||||
updatedAt=c.updated_at,
|
||||
deletedAt=c.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
for c in canvases
|
||||
]
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list canvases for project {project_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to list canvases",
|
||||
{"project_id": project_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/canvases/{canvas_id}/metadata", response_model=ResponseModel)
|
||||
async def get_canvas_metadata(
|
||||
canvas_id: str,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" Get canvas metadata by ID.
|
||||
|
||||
Args:
|
||||
canvas_id: Canvas ID
|
||||
|
||||
Returns:
|
||||
Canvas metadata
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: Canvas not found
|
||||
BusinessException: Failed to get canvas metadata
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvas = service.get_canvas(canvas_id)
|
||||
|
||||
if not canvas:
|
||||
raise ResourceNotFoundException("canvas", canvas_id)
|
||||
|
||||
# 更新 access statistics
|
||||
service.update_access_stats(canvas.id)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except ResourceNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get canvas metadata {canvas_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.CANVAS_NOT_FOUND,
|
||||
"Failed to get canvas metadata",
|
||||
{"canvas_id": canvas_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/assets/{asset_id}/canvas", response_model=ResponseModel)
|
||||
async def get_asset_canvas(
|
||||
asset_id: str,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" Get canvas associated with an asset.
|
||||
Automatically creates the canvas if it doesn't exist.
|
||||
|
||||
Args:
|
||||
asset_id: Asset ID
|
||||
|
||||
Returns:
|
||||
Canvas metadata
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: Asset not found
|
||||
BusinessException: Failed to get asset canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvas = service.get_or_create_asset_canvas(asset_id)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except ValueError as e:
|
||||
raise ResourceNotFoundException("asset", asset_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get asset canvas {asset_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to get asset canvas",
|
||||
{"asset_id": asset_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/storyboards/{storyboard_id}/canvas", response_model=ResponseModel)
|
||||
async def get_storyboard_canvas(
|
||||
storyboard_id: str,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" Get canvas associated with a storyboard.
|
||||
Automatically creates the canvas if it doesn't exist.
|
||||
|
||||
Args:
|
||||
storyboard_id: Storyboard ID
|
||||
|
||||
Returns:
|
||||
Canvas metadata
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: Storyboard not found
|
||||
BusinessException: Failed to get storyboard canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvas = service.get_or_create_storyboard_canvas(storyboard_id)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except ValueError as e:
|
||||
raise ResourceNotFoundException("storyboard", storyboard_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get storyboard canvas {storyboard_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to get storyboard canvas",
|
||||
{"storyboard_id": storyboard_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/canvases", response_model=ResponseModel)
|
||||
async def create_general_canvas(
|
||||
project_id: str,
|
||||
request: CreateGeneralCanvasRequest,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" Create a new general canvas for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
request: Canvas creation request with name and optional description
|
||||
|
||||
Returns:
|
||||
Created canvas metadata
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to create canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvas = service.create_general_canvas(
|
||||
project_id=project_id,
|
||||
name=request.name,
|
||||
description=request.description
|
||||
)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create general canvas for project {project_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to create canvas",
|
||||
{"project_id": project_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.put("/canvases/{canvas_id}/metadata", response_model=ResponseModel)
|
||||
async def update_canvas_metadata(
|
||||
canvas_id: str,
|
||||
request: UpdateCanvasMetadataRequest,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" 更新 canvas metadata.
|
||||
|
||||
Args:
|
||||
canvas_id: Canvas ID
|
||||
request: Update request with optional fields
|
||||
|
||||
Returns:
|
||||
Updated canvas metadata
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: Canvas not found
|
||||
BusinessException: Failed to update canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
|
||||
# 转换 request to dict, excluding unset fields
|
||||
updates = request.model_dump(exclude_unset=True, by_alias=False)
|
||||
|
||||
canvas = service.update_canvas(canvas_id, updates)
|
||||
|
||||
if not canvas:
|
||||
raise ResourceNotFoundException("canvas", canvas_id)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except ResourceNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update canvas metadata {canvas_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to update canvas",
|
||||
{"canvas_id": canvas_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.put("/projects/{project_id}/canvases/reorder", response_model=ResponseModel)
|
||||
async def reorder_canvases(
|
||||
project_id: str,
|
||||
canvas_orders: List[dict],
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" 批处理 update canvas order indices.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
canvas_orders: List of {id, order_index} objects
|
||||
|
||||
Returns:
|
||||
Success confirmation
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to reorder canvases
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
service.reorder_canvases(canvas_orders)
|
||||
|
||||
return ResponseModel(data={"updated": True})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reorder canvases for project {project_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to reorder canvases",
|
||||
{"project_id": project_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/canvases/{canvas_id}", response_model=ResponseModel)
|
||||
async def delete_canvas(
|
||||
canvas_id: str,
|
||||
hard_delete: bool = Query(False, description="Permanently delete (true) or soft delete (false)"),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" 删除 a canvas.
|
||||
|
||||
Args:
|
||||
canvas_id: Canvas ID
|
||||
hard_delete: If true, permanently delete; if false, soft delete
|
||||
|
||||
Returns:
|
||||
Success confirmation
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to delete canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
service.delete_canvas(canvas_id, hard_delete)
|
||||
|
||||
return ResponseModel(data={"deleted": True})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete canvas {canvas_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to delete canvas",
|
||||
{"canvas_id": canvas_id, "hard_delete": hard_delete, "reason": str(e)}
|
||||
)
|
||||
61
backend/src/api/chat.py
Normal file
61
backend/src/api/chat.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import List, Optional
|
||||
|
||||
from src.services.agent_engine import AgentScopeService
|
||||
from src.utils.errors import BusinessException, ErrorCode
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
messages: List[Message]
|
||||
model: Optional[str] = None
|
||||
stream: bool = True
|
||||
temperature: Optional[float] = 0.7
|
||||
max_tokens: Optional[int] = Field(None, alias="maxTokens")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@router.post("/completions")
|
||||
async def chat_completions(
|
||||
request: ChatCompletionRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
OpenAI-compatible chat completion endpoint.
|
||||
Supports streaming and non-streaming modes.
|
||||
Now integrated with AgentScope.
|
||||
|
||||
Raises:
|
||||
BusinessException: Chat completion failed
|
||||
"""
|
||||
if not request.stream:
|
||||
raise BusinessException(
|
||||
ErrorCode.INVALID_PARAMETER,
|
||||
"Non-stream mode is not supported",
|
||||
{"field": "stream", "expected": True}
|
||||
)
|
||||
|
||||
messages_payload = [m.model_dump() for m in request.messages]
|
||||
|
||||
try:
|
||||
agent_service = AgentScopeService(user_id=current_user.id)
|
||||
return StreamingResponse(
|
||||
agent_service.stream_chat(messages_payload),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"AgentScope service error: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.GENERATION_FAILED,
|
||||
"Chat completion failed",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
45
backend/src/api/config/__init__.py
Normal file
45
backend/src/api/config/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Config API routes - split into multiple modules for maintainability.
|
||||
|
||||
This package provides configuration management endpoints including:
|
||||
- System configuration (/config/system)
|
||||
- Provider and model configuration (/config/providers, /config/models, etc.)
|
||||
- Style configuration (/config/styles)
|
||||
- Health monitoring (/config/health)
|
||||
- Configuration validation (/config/validate)
|
||||
- Admin operations (/admin/*)
|
||||
- Storage utilities (/storage/*)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import system, providers, styles, health, validation, admin, storage
|
||||
|
||||
# Create main router
|
||||
router = APIRouter()
|
||||
|
||||
# Include all sub-routers
|
||||
router.include_router(system.router)
|
||||
router.include_router(providers.router)
|
||||
router.include_router(styles.router)
|
||||
router.include_router(health.router)
|
||||
router.include_router(validation.router)
|
||||
router.include_router(admin.router)
|
||||
router.include_router(storage.router)
|
||||
|
||||
# Export models for backward compatibility
|
||||
from .models import (
|
||||
SignUrlRequest,
|
||||
SystemConfig,
|
||||
ModelRegistrationRequest,
|
||||
ProviderKeyField,
|
||||
ProviderConfigResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
"SignUrlRequest",
|
||||
"SystemConfig",
|
||||
"ModelRegistrationRequest",
|
||||
"ProviderKeyField",
|
||||
"ProviderConfigResponse",
|
||||
]
|
||||
321
backend/src/api/config/admin.py
Normal file
321
backend/src/api/config/admin.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Provider and model management endpoints (Admin Only)."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.utils.errors import AppException, ErrorCode
|
||||
from src.services.provider.health import health_monitor
|
||||
from src.utils.service_loader import register_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/admin/providers/{provider_id}/toggle", response_model=ResponseModel)
|
||||
async def toggle_provider(
|
||||
provider_id: str,
|
||||
enabled: bool,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Enable/disable Provider (Admin only)."""
|
||||
try:
|
||||
provider_config = ModelRegistry.get_provider_config(provider_id)
|
||||
if not provider_config:
|
||||
raise AppException(
|
||||
message=f"Provider not found: {provider_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
provider_config["enabled"] = enabled
|
||||
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
custom_providers_path = os.path.join(src_dir, "config", "services", "custom_providers.json")
|
||||
|
||||
existing_configs = []
|
||||
if os.path.exists(custom_providers_path):
|
||||
try:
|
||||
with open(custom_providers_path, 'r', encoding='utf-8') as f:
|
||||
existing_configs = json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to load custom providers config %s: %s", custom_providers_path, e)
|
||||
|
||||
updated = False
|
||||
for i, config in enumerate(existing_configs):
|
||||
if config.get("id") == provider_id:
|
||||
existing_configs[i] = {**config, **provider_config}
|
||||
updated = True
|
||||
break
|
||||
|
||||
if not updated:
|
||||
existing_configs.append(provider_config)
|
||||
|
||||
os.makedirs(os.path.dirname(custom_providers_path), exist_ok=True)
|
||||
with open(custom_providers_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_configs, f, indent=4, ensure_ascii=False)
|
||||
|
||||
if enabled:
|
||||
register_service(provider_config)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"provider_id": provider_id,
|
||||
"enabled": enabled,
|
||||
"message": f"Provider {'enabled' if enabled else 'disabled'} successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error toggling provider: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/admin/providers/{provider_id}", response_model=ResponseModel)
|
||||
async def delete_provider(
|
||||
provider_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Delete custom Provider (Admin only)."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
custom_providers_path = os.path.join(src_dir, "config", "services", "custom_providers.json")
|
||||
|
||||
if not os.path.exists(custom_providers_path):
|
||||
raise AppException(
|
||||
message="Custom providers configuration not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
with open(custom_providers_path, 'r', encoding='utf-8') as f:
|
||||
existing_configs = json.load(f)
|
||||
|
||||
initial_len = len(existing_configs)
|
||||
existing_configs = [c for c in existing_configs if c.get("id") != provider_id]
|
||||
|
||||
if len(existing_configs) == initial_len:
|
||||
raise AppException(
|
||||
message="Provider not found or not a custom provider",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
with open(custom_providers_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_configs, f, indent=4, ensure_ascii=False)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"provider_id": provider_id,
|
||||
"deleted": True,
|
||||
"message": "Provider deleted successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting provider: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/models/{model_id}/toggle", response_model=ResponseModel)
|
||||
async def toggle_model(
|
||||
model_id: str,
|
||||
enabled: bool,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Enable/disable model (Admin only)."""
|
||||
try:
|
||||
model_config = ModelRegistry.get_config(model_id)
|
||||
if not model_config:
|
||||
raise AppException(
|
||||
message=f"Model not found: {model_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
model_config["enabled"] = enabled
|
||||
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
custom_models_path = os.path.join(src_dir, "config", "services", "custom_models.json")
|
||||
|
||||
existing_configs = []
|
||||
if os.path.exists(custom_models_path):
|
||||
try:
|
||||
with open(custom_models_path, 'r', encoding='utf-8') as f:
|
||||
existing_configs = json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to load custom models config %s: %s", custom_models_path, e)
|
||||
|
||||
updated = False
|
||||
for i, config in enumerate(existing_configs):
|
||||
if config.get("id") == model_id:
|
||||
existing_configs[i] = {**config, "enabled": enabled}
|
||||
updated = True
|
||||
break
|
||||
|
||||
if not updated:
|
||||
existing_configs.append({**model_config, "enabled": enabled})
|
||||
|
||||
os.makedirs(os.path.dirname(custom_models_path), exist_ok=True)
|
||||
with open(custom_models_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_configs, f, indent=4, ensure_ascii=False)
|
||||
|
||||
if enabled:
|
||||
register_service(model_config)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"model_id": model_id,
|
||||
"enabled": enabled,
|
||||
"message": f"Model {'enabled' if enabled else 'disabled'} successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error toggling model: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/models/{model_id}/default", response_model=ResponseModel)
|
||||
async def set_default_model(
|
||||
model_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Set default model (Admin only)."""
|
||||
try:
|
||||
model_config = ModelRegistry.get_config(model_id)
|
||||
if not model_config:
|
||||
raise AppException(
|
||||
message=f"Model not found: {model_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
model_type_str = model_config.get("type")
|
||||
if not model_type_str:
|
||||
raise AppException(
|
||||
message="Model type not found",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
try:
|
||||
model_type = ModelType(model_type_str.lower())
|
||||
except ValueError:
|
||||
raise AppException(
|
||||
message=f"Invalid model type: {model_type_str}",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
ModelRegistry.set_default_by_id(model_type, model_id)
|
||||
|
||||
config_key_map = {
|
||||
"image": "defaultImageModel",
|
||||
"video": "defaultVideoModel",
|
||||
"audio": "defaultAudioModel",
|
||||
"lyrics": "defaultLyricsModel",
|
||||
"music": "defaultMusicModel",
|
||||
"llm": "defaultLLMModel",
|
||||
}
|
||||
|
||||
config_key = config_key_map.get(model_type_str.lower())
|
||||
if config_key:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
config_path = os.path.join(src_dir, "config", "user_config.json")
|
||||
|
||||
existing_data = {}
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
existing_data[config_key] = model_id
|
||||
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:defaults")
|
||||
await cache.delete("config:system")
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"model_id": model_id,
|
||||
"model_type": model_type_str,
|
||||
"is_default": True,
|
||||
"message": f"Default {model_type_str} model set to {model_id}",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting default model: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/models/{model_id}/test", response_model=ResponseModel)
|
||||
async def test_model(
|
||||
model_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Test model connection (Admin only)."""
|
||||
try:
|
||||
service = ModelRegistry.get(model_id)
|
||||
if not service:
|
||||
raise AppException(
|
||||
message=f"Model service not found: {model_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
result = await health_monitor.check_service_health(model_id, service)
|
||||
|
||||
health_monitor.update_health(model_id, result)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"model_id": model_id,
|
||||
"success": result.status.value == "healthy",
|
||||
"status": result.status.value,
|
||||
"latency_ms": result.latency_ms,
|
||||
"message": "Model test completed",
|
||||
"error": result.error,
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing model: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
137
backend/src/api/config/health.py
Normal file
137
backend/src/api/config/health.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Health check and monitoring endpoints."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.provider.registry import ModelRegistry
|
||||
from src.utils.errors import AppException, ErrorCode
|
||||
from src.services.provider.health import health_monitor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/config/health", response_model=ResponseModel)
|
||||
async def get_health_status():
|
||||
"""Get health status of all registered services."""
|
||||
try:
|
||||
summary = health_monitor.get_health_summary()
|
||||
all_health = health_monitor.get_all_health()
|
||||
|
||||
services = {}
|
||||
for service_id, health in all_health.items():
|
||||
services[service_id] = {
|
||||
"status": health.status.value,
|
||||
"last_check": health.last_check.isoformat() if health.last_check else None,
|
||||
"last_success": health.last_success.isoformat() if health.last_success else None,
|
||||
"last_failure": health.last_failure.isoformat() if health.last_failure else None,
|
||||
"consecutive_failures": health.consecutive_failures,
|
||||
"consecutive_successes": health.consecutive_successes,
|
||||
"total_checks": health.total_checks,
|
||||
"total_failures": health.total_failures,
|
||||
"success_rate": health.get_success_rate(),
|
||||
"avg_latency_ms": health.avg_latency_ms,
|
||||
"should_circuit_break": health.should_circuit_break()
|
||||
}
|
||||
|
||||
return ResponseModel(data={
|
||||
"summary": summary,
|
||||
"services": services,
|
||||
"unhealthy": health_monitor.get_unhealthy_services(),
|
||||
"degraded": health_monitor.get_degraded_services()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get health status: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/health/{service_id}", response_model=ResponseModel)
|
||||
async def get_service_health(service_id: str):
|
||||
"""Get detailed health status for a specific service."""
|
||||
try:
|
||||
health = health_monitor.get_health(service_id)
|
||||
|
||||
if not health:
|
||||
raise AppException(
|
||||
message=f"Service not found: {service_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
history = [
|
||||
{
|
||||
"status": result.status.value,
|
||||
"latency_ms": result.latency_ms,
|
||||
"timestamp": result.timestamp.isoformat(),
|
||||
"error": result.error,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
for result in health.history
|
||||
]
|
||||
|
||||
return ResponseModel(data={
|
||||
"service_id": service_id,
|
||||
"status": health.status.value,
|
||||
"last_check": health.last_check.isoformat() if health.last_check else None,
|
||||
"last_success": health.last_success.isoformat() if health.last_success else None,
|
||||
"last_failure": health.last_failure.isoformat() if health.last_failure else None,
|
||||
"consecutive_failures": health.consecutive_failures,
|
||||
"consecutive_successes": health.consecutive_successes,
|
||||
"total_checks": health.total_checks,
|
||||
"total_failures": health.total_failures,
|
||||
"success_rate": health.get_success_rate(),
|
||||
"avg_latency_ms": health.avg_latency_ms,
|
||||
"should_circuit_break": health.should_circuit_break(),
|
||||
"history": history
|
||||
})
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get service health: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config/health/{service_id}/check", response_model=ResponseModel)
|
||||
async def check_service_health(service_id: str):
|
||||
"""Manually trigger a health check for a specific service."""
|
||||
try:
|
||||
service = ModelRegistry.get(service_id)
|
||||
if not service:
|
||||
raise AppException(
|
||||
message=f"Service not found: {service_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
health_monitor.register_service(service_id)
|
||||
|
||||
result = await health_monitor.check_service_health(service_id, service)
|
||||
|
||||
health_monitor.update_health(service_id, result)
|
||||
|
||||
return ResponseModel(data={
|
||||
"service_id": service_id,
|
||||
"status": result.status.value,
|
||||
"latency_ms": result.latency_ms,
|
||||
"timestamp": result.timestamp.isoformat(),
|
||||
"error": result.error
|
||||
})
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check service health: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
49
backend/src/api/config/models.py
Normal file
49
backend/src/api/config/models.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Pydantic models for config API."""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SignUrlRequest(BaseModel):
|
||||
"""Request to sign a URL."""
|
||||
url: str
|
||||
|
||||
|
||||
class SystemConfig(BaseModel):
|
||||
"""System configuration model."""
|
||||
default_image_model: Optional[str] = None
|
||||
default_video_model: Optional[str] = None
|
||||
default_audio_model: Optional[str] = None
|
||||
default_llm_model: Optional[str] = None
|
||||
default_style: Optional[str] = None
|
||||
default_resolution: Optional[str] = None
|
||||
default_ratio: Optional[str] = None
|
||||
|
||||
|
||||
class ModelRegistrationRequest(BaseModel):
|
||||
"""Request to register a new model."""
|
||||
model_id: str
|
||||
provider: str
|
||||
name: str
|
||||
type: str
|
||||
enabled: bool = True
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ProviderKeyField(BaseModel):
|
||||
"""Provider key field configuration."""
|
||||
name: str
|
||||
label: str
|
||||
placeholder: str
|
||||
required: bool
|
||||
type: Optional[str] = "text" # 'password' or 'text'
|
||||
|
||||
|
||||
class ProviderConfigResponse(BaseModel):
|
||||
"""Provider configuration response."""
|
||||
id: str
|
||||
name: str
|
||||
icon: str
|
||||
description: str
|
||||
fields: List[ProviderKeyField]
|
||||
helpUrl: Optional[str] = None
|
||||
410
backend/src/api/config/providers.py
Normal file
410
backend/src/api/config/providers.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""Provider and model configuration endpoints."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.utils.errors import BusinessException, ErrorCode, ResourceNotFoundException, InvalidParameterException
|
||||
from src.utils.service_loader import register_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ModelRegistrationRequest(BaseModel):
|
||||
"""Request to register a new model."""
|
||||
name: str
|
||||
model_id: str
|
||||
provider: str
|
||||
type: str
|
||||
family: Optional[str] = None
|
||||
capabilities: Optional[Dict[str, bool]] = None
|
||||
variants: Optional[Dict[str, str]] = None
|
||||
resolutions: Optional[Any] = None
|
||||
durations: Optional[Any] = None
|
||||
voices: Optional[List[Dict[str, str]]] = None
|
||||
args: Optional[List[Any]] = None
|
||||
|
||||
|
||||
class ProviderKeyField(BaseModel):
|
||||
"""Provider key field configuration."""
|
||||
name: str
|
||||
label: str
|
||||
placeholder: str
|
||||
required: bool
|
||||
type: Optional[str] = "text"
|
||||
|
||||
|
||||
class ProviderConfigResponse(BaseModel):
|
||||
"""Provider configuration response."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
fields: List[ProviderKeyField]
|
||||
helpUrl: Optional[str] = None
|
||||
|
||||
|
||||
# Supported providers configuration with key field definitions
|
||||
def _build_provider_configs() -> List[ProviderConfigResponse]:
|
||||
"""Build provider configs from ModelRegistry metadata and registered models.
|
||||
|
||||
This dynamically generates provider configurations from provider.json files,
|
||||
avoiding hardcoded configurations.
|
||||
"""
|
||||
configs = []
|
||||
|
||||
# Get all registered providers from ModelRegistry
|
||||
providers = ModelRegistry.list_providers()
|
||||
provider_ids = {p["id"] for p in providers}
|
||||
|
||||
# Build config for each provider
|
||||
for provider_id in provider_ids:
|
||||
metadata = ModelRegistry.get_provider_metadata(provider_id) or {}
|
||||
|
||||
# Get fields from metadata or use defaults
|
||||
fields_data = metadata.get("fields", [])
|
||||
fields = [ProviderKeyField(**f) for f in fields_data] if fields_data else [
|
||||
ProviderKeyField(name="apiKey", label="API Key", placeholder="sk-...", required=True, type="password")
|
||||
]
|
||||
|
||||
# Build config response
|
||||
config = ProviderConfigResponse(
|
||||
id=provider_id,
|
||||
name=metadata.get("name") or provider_id.capitalize(),
|
||||
description=metadata.get("description", ""),
|
||||
fields=fields,
|
||||
helpUrl=metadata.get("helpUrl") or metadata.get("dashboard_url")
|
||||
)
|
||||
configs.append(config)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
@router.get("/config/providers", response_model=ResponseModel)
|
||||
async def get_providers():
|
||||
"""Get list of supported providers and their capabilities (cached for 5 minutes)."""
|
||||
cache_key = "config:providers"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_providers = await cache.get(cache_key)
|
||||
if cached_providers is not None:
|
||||
return ResponseModel(data=cached_providers)
|
||||
|
||||
providers = ModelRegistry.list_providers()
|
||||
result = {"providers": providers}
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, result, ttl=300)
|
||||
|
||||
return ResponseModel(data=result)
|
||||
|
||||
|
||||
@router.post("/config/models", response_model=ResponseModel)
|
||||
async def register_new_model(request: ModelRegistrationRequest):
|
||||
"""Register a new model configuration.
|
||||
|
||||
Raises:
|
||||
InvalidParameterException: Invalid model configuration
|
||||
ResourceNotFoundException: Template not found
|
||||
BusinessException: Failed to register model
|
||||
"""
|
||||
try:
|
||||
provider = request.provider.lower()
|
||||
model_type = request.type.lower()
|
||||
family = request.family.lower() if request.family else "default"
|
||||
|
||||
module_name = ""
|
||||
class_name = ""
|
||||
|
||||
# Find template service from registry
|
||||
try:
|
||||
type_enum = None
|
||||
try:
|
||||
type_enum = ModelType(model_type)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
services = ModelRegistry.find_services(
|
||||
provider=provider,
|
||||
model_type=type_enum
|
||||
)
|
||||
|
||||
if not services and not type_enum:
|
||||
all_services = ModelRegistry.find_services(provider=provider)
|
||||
services = [s for s in all_services if s.get('type') == model_type]
|
||||
|
||||
if not services:
|
||||
raise ResourceNotFoundException("template", f"{provider}/{model_type}")
|
||||
|
||||
template = services[0]
|
||||
if family and family != "default":
|
||||
for svc in services:
|
||||
svc_id = svc.get('id', '').lower()
|
||||
svc_class = (svc.get('class') or svc.get('class_name', '')).lower()
|
||||
if family in svc_id or family in svc_class:
|
||||
template = svc
|
||||
break
|
||||
|
||||
module_name = template.get('module')
|
||||
class_name = template.get('class') or template.get('class_name')
|
||||
|
||||
if not module_name or not class_name:
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Template service configuration is missing module or class",
|
||||
{"template": template}
|
||||
)
|
||||
|
||||
except (ResourceNotFoundException, BusinessException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding template for {provider}/{model_type}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to resolve service template",
|
||||
{"provider": provider, "type": model_type, "reason": str(e)}
|
||||
)
|
||||
|
||||
service_id = f"{provider}-{request.model_id.replace('.', '-').replace('/', '-')}"
|
||||
|
||||
config = {
|
||||
"id": service_id,
|
||||
"module": module_name,
|
||||
"class": class_name,
|
||||
"name": request.name,
|
||||
"args": [request.model_id],
|
||||
"type": model_type,
|
||||
"provider": provider,
|
||||
"enabled": True
|
||||
}
|
||||
|
||||
if request.capabilities:
|
||||
config["capabilities"] = request.capabilities
|
||||
if request.variants:
|
||||
config["variants"] = request.variants
|
||||
if request.resolutions:
|
||||
config["resolutions"] = request.resolutions
|
||||
if request.durations:
|
||||
config["durations"] = request.durations
|
||||
if request.args:
|
||||
config["args"] = request.args
|
||||
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
custom_config_path = os.path.join(src_dir, "config", "services", "custom_models.json")
|
||||
|
||||
os.makedirs(os.path.dirname(custom_config_path), exist_ok=True)
|
||||
|
||||
existing_configs = []
|
||||
if os.path.exists(custom_config_path):
|
||||
try:
|
||||
with open(custom_config_path, 'r', encoding='utf-8') as f:
|
||||
existing_configs = json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to load custom config %s: %s", custom_config_path, e)
|
||||
|
||||
existing_configs = [c for c in existing_configs if c.get("id") != service_id]
|
||||
existing_configs.append(config)
|
||||
|
||||
with open(custom_config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_configs, f, indent=4, ensure_ascii=False)
|
||||
|
||||
register_service(config)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:models")
|
||||
await cache.delete("config:defaults")
|
||||
|
||||
return ResponseModel(data={"status": "success", "model": config})
|
||||
|
||||
except (ResourceNotFoundException, BusinessException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register model: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to register model",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/models", response_model=ResponseModel)
|
||||
async def get_models_config():
|
||||
"""Get all registered models configuration grouped by type (cached for 5 minutes)."""
|
||||
cache_key = "config:models"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_models = await cache.get(cache_key)
|
||||
if cached_models is not None:
|
||||
return ResponseModel(data=cached_models)
|
||||
|
||||
models = ModelRegistry.list_models()
|
||||
|
||||
grouped = {
|
||||
"image": {},
|
||||
"video": {},
|
||||
"audio": {},
|
||||
"lyrics": {},
|
||||
"music": {},
|
||||
"llm": {}
|
||||
}
|
||||
|
||||
for model_id, config_dict in models.items():
|
||||
model_config = config_dict.copy() if isinstance(config_dict, dict) else {}
|
||||
|
||||
model_type_str = model_config.get("type")
|
||||
if not model_type_str or model_type_str not in grouped:
|
||||
continue
|
||||
|
||||
try:
|
||||
model_type = ModelType(model_type_str.lower())
|
||||
default_id = ModelRegistry.get_default_id(model_type)
|
||||
model_config["is_default"] = (model_id == default_id)
|
||||
except (ValueError, KeyError):
|
||||
model_config["is_default"] = False
|
||||
|
||||
model_config["id"] = model_id
|
||||
|
||||
if "provider" not in model_config:
|
||||
if "/" in model_id:
|
||||
model_config["provider"] = model_id.split("/", 1)[0]
|
||||
|
||||
if "model_key" not in model_config:
|
||||
if "/" in model_id:
|
||||
model_config["model_key"] = model_id.split("/", 1)[1]
|
||||
|
||||
grouped[model_type_str][model_id] = model_config
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, grouped, ttl=300)
|
||||
|
||||
return ResponseModel(data=grouped)
|
||||
|
||||
|
||||
@router.get("/config/models/search", response_model=ResponseModel)
|
||||
async def search_models(
|
||||
provider: Optional[str] = None,
|
||||
type: Optional[str] = None,
|
||||
enabled_only: bool = True
|
||||
):
|
||||
"""Search for models by criteria."""
|
||||
try:
|
||||
model_type = None
|
||||
if type:
|
||||
try:
|
||||
model_type = ModelType(type.lower())
|
||||
except ValueError:
|
||||
raise InvalidParameterException("type", f"Invalid model type: {type}")
|
||||
|
||||
results = ModelRegistry.find_services(
|
||||
provider=provider,
|
||||
model_type=model_type,
|
||||
enabled_only=enabled_only
|
||||
)
|
||||
|
||||
return ResponseModel(data={
|
||||
"count": len(results),
|
||||
"models": results
|
||||
})
|
||||
except InvalidParameterException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to search models: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to search models",
|
||||
{"provider": provider, "type": type, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/models/{model_id}", response_model=ResponseModel)
|
||||
async def get_model_config(model_id: str):
|
||||
"""Get detailed configuration for a specific model."""
|
||||
try:
|
||||
config = ModelRegistry.get_config(model_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("model", model_id)
|
||||
|
||||
return ResponseModel(data=config)
|
||||
except ResourceNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model config: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to get model configuration",
|
||||
{"model_id": model_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/defaults", response_model=ResponseModel)
|
||||
async def get_defaults():
|
||||
"""Get default models for each type (cached for 5 minutes)."""
|
||||
cache_key = "config:defaults"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_defaults = await cache.get(cache_key)
|
||||
if cached_defaults is not None:
|
||||
return ResponseModel(data=cached_defaults)
|
||||
|
||||
try:
|
||||
defaults = {}
|
||||
for model_type in ModelType:
|
||||
default_id = ModelRegistry.get_default_id(model_type)
|
||||
if default_id:
|
||||
defaults[model_type.value] = {
|
||||
"id": default_id,
|
||||
"config": ModelRegistry.get_config(default_id)
|
||||
}
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, defaults, ttl=300)
|
||||
|
||||
return ResponseModel(data=defaults)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get defaults: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to get default models",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/provider-configs", response_model=ResponseModel)
|
||||
async def get_provider_configs():
|
||||
"""Get supported provider configurations for user API key management.
|
||||
|
||||
Configurations are dynamically built from provider.json files,
|
||||
ensuring a single source of truth for provider metadata.
|
||||
"""
|
||||
cache_key = "config:provider_configs"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_configs = await cache.get(cache_key)
|
||||
if cached_configs is not None:
|
||||
return ResponseModel(data={"providers": cached_configs})
|
||||
|
||||
# Dynamically build provider configs from registry metadata
|
||||
configs = _build_provider_configs()
|
||||
providers_data = [config.model_dump() for config in configs]
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, providers_data, ttl=300)
|
||||
|
||||
return ResponseModel(data={"providers": providers_data})
|
||||
24
backend/src/api/config/storage.py
Normal file
24
backend/src/api/config/storage.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Storage-related endpoints."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.errors import ErrorCode
|
||||
from src.services.storage_service import storage_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SignUrlRequest:
|
||||
"""Request to sign a URL."""
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/storage/sign-url", response_model=ResponseModel)
|
||||
async def sign_url_endpoint(request: dict):
|
||||
"""Re-sign an OSS URL."""
|
||||
new_url = storage_manager.sign_url(request.get("url", ""))
|
||||
return ResponseModel(code=ErrorCode.SUCCESS, message="URL signed", data={"url": new_url})
|
||||
208
backend/src/api/config/styles.py
Normal file
208
backend/src/api/config/styles.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Style configuration endpoints."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.utils.errors import AppException, ErrorCode, ResourceNotFoundException, BusinessException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/config/styles", response_model=ResponseModel)
|
||||
async def get_styles():
|
||||
"""Get all style configurations (cached for 5 minutes)."""
|
||||
cache_key = "config:styles"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_styles = await cache.get(cache_key)
|
||||
if cached_styles is not None:
|
||||
return ResponseModel(data=cached_styles)
|
||||
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
result = {"styles": []}
|
||||
if os.path.exists(styles_path):
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
result = {"styles": data.get("styles", [])}
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, result, ttl=300)
|
||||
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load styles: {e}")
|
||||
return ResponseModel(data={"styles": []})
|
||||
|
||||
|
||||
@router.post("/config/styles", response_model=ResponseModel)
|
||||
async def create_style(style: dict):
|
||||
"""Create a new style configuration."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
if not os.path.exists(styles_path):
|
||||
data = {"styles": []}
|
||||
else:
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
if 'id' not in style or not style['id']:
|
||||
style['id'] = str(uuid.uuid4())
|
||||
|
||||
data["styles"].append(style)
|
||||
|
||||
with open(styles_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:styles")
|
||||
|
||||
return ResponseModel(data=style)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create style: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.put("/config/styles/{style_id}", response_model=ResponseModel)
|
||||
async def update_style(style_id: str, style: dict):
|
||||
"""Update an existing style configuration."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
if not os.path.exists(styles_path):
|
||||
raise AppException(
|
||||
message="Styles configuration not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
styles = data.get("styles", [])
|
||||
updated = False
|
||||
for i, s in enumerate(styles):
|
||||
if s.get('id') == style_id:
|
||||
styles[i] = {**s, **style, "id": style_id}
|
||||
updated = True
|
||||
break
|
||||
|
||||
if not updated:
|
||||
raise AppException(
|
||||
message="Style not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
data["styles"] = styles
|
||||
with open(styles_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:styles")
|
||||
|
||||
return ResponseModel(data=styles[i])
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update style: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/config/styles/{style_id}", response_model=ResponseModel)
|
||||
async def delete_style(style_id: str):
|
||||
"""Delete a style configuration."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
if not os.path.exists(styles_path):
|
||||
raise ResourceNotFoundException("styles configuration", "styles.json")
|
||||
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
styles = data.get("styles", [])
|
||||
initial_len = len(styles)
|
||||
styles = [s for s in styles if s.get('id') != style_id]
|
||||
|
||||
if len(styles) == initial_len:
|
||||
raise ResourceNotFoundException("style", style_id)
|
||||
|
||||
data["styles"] = styles
|
||||
with open(styles_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:styles")
|
||||
|
||||
return ResponseModel(data={"status": "success"})
|
||||
except ResourceNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete style: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to delete style",
|
||||
{"style_id": style_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/generation-options", response_model=ResponseModel)
|
||||
async def get_generation_options():
|
||||
"""Get generation options (cached for 5 minutes)."""
|
||||
cache_key = "config:generation_options"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_options = await cache.get(cache_key)
|
||||
if cached_options is not None:
|
||||
return ResponseModel(data=cached_options)
|
||||
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
options_path = os.path.join(src_dir, "config", "generation_options.json")
|
||||
|
||||
result = {}
|
||||
if os.path.exists(options_path):
|
||||
with open(options_path, 'r', encoding='utf-8') as f:
|
||||
result = json.load(f)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, result, ttl=300)
|
||||
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load generation options: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
131
backend/src/api/config/system.py
Normal file
131
backend/src/api/config/system.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""System configuration endpoints."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.utils.errors import BusinessException, ErrorCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SystemConfig(BaseModel):
|
||||
"""System configuration model."""
|
||||
default_image_model: Optional[str] = None
|
||||
default_video_model: Optional[str] = None
|
||||
default_audio_model: Optional[str] = None
|
||||
default_llm_model: Optional[str] = None
|
||||
default_style: Optional[str] = None
|
||||
default_resolution: Optional[str] = None
|
||||
default_ratio: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@router.get("/config/system", response_model=ResponseModel)
|
||||
async def get_system_config():
|
||||
"""Get system configuration (cached for 5 minutes).
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to load system config
|
||||
"""
|
||||
cache_key = "config:system"
|
||||
|
||||
# Try cache first
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_config = await cache.get(cache_key)
|
||||
if cached_config is not None:
|
||||
return ResponseModel(data=cached_config)
|
||||
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
config_path = os.path.join(src_dir, "config", "user_config.json")
|
||||
|
||||
config_data = {}
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
# Cache result for 5 minutes
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, config_data, ttl=300)
|
||||
|
||||
return ResponseModel(data=config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load system config: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to load system configuration",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config/system", response_model=ResponseModel)
|
||||
async def update_system_config(config: SystemConfig):
|
||||
"""Update system configuration.
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to update system config
|
||||
"""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
config_path = os.path.join(src_dir, "config", "user_config.json")
|
||||
|
||||
# Load existing to merge
|
||||
existing_data = {}
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
existing_data = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Failed to parse existing config %s: %s", config_path, e)
|
||||
|
||||
new_data = config.model_dump(exclude_unset=True)
|
||||
merged_data = {**existing_data, **new_data}
|
||||
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(merged_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Update default models in registry if changed
|
||||
config_key_to_type = {
|
||||
'defaultImageModel': ModelType.IMAGE,
|
||||
'defaultVideoModel': ModelType.VIDEO,
|
||||
'defaultAudioModel': ModelType.AUDIO,
|
||||
'defaultLyricsModel': ModelType.LYRICS,
|
||||
'defaultMusicModel': ModelType.MUSIC,
|
||||
'defaultLLMModel': ModelType.LLM
|
||||
}
|
||||
|
||||
for config_key, model_type in config_key_to_type.items():
|
||||
if config_key in new_data:
|
||||
model_id = new_data[config_key]
|
||||
if model_id:
|
||||
ModelRegistry.set_default_by_id(model_type, model_id)
|
||||
logger.info(f"Updated default {model_type.value} model to: {model_id}")
|
||||
|
||||
# Invalidate all related caches
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:system")
|
||||
await cache.delete("config:models")
|
||||
await cache.delete("config:defaults")
|
||||
|
||||
return ResponseModel(data=merged_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update system config: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to update system configuration",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
79
backend/src/api/config/validation.py
Normal file
79
backend/src/api/config/validation.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Configuration validation endpoints."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.errors import AppException, ErrorCode
|
||||
from src.services.provider.validation import ConfigValidator, validate_all_configs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/config/validate", response_model=ResponseModel)
|
||||
async def validate_configurations():
|
||||
"""Validate all service configurations."""
|
||||
try:
|
||||
report = validate_all_configs()
|
||||
return ResponseModel(data=report)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate configurations: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config/validate/service", response_model=ResponseModel)
|
||||
async def validate_service_configuration(config: Dict[str, Any]):
|
||||
"""Validate a single service configuration."""
|
||||
try:
|
||||
is_valid, errors = ConfigValidator.validate_config_deep(config)
|
||||
return ResponseModel(data={
|
||||
"valid": is_valid,
|
||||
"errors": errors,
|
||||
"config": config
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate service configuration: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/validate/file/{filename}", response_model=ResponseModel)
|
||||
async def validate_config_file(filename: str):
|
||||
"""Validate a specific configuration file."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
config_path = os.path.join(src_dir, "config", "services", filename)
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
raise AppException(
|
||||
message=f"Configuration file not found: {filename}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
is_valid, errors = ConfigValidator.validate_config_file(config_path)
|
||||
return ResponseModel(data={
|
||||
"filename": filename,
|
||||
"valid": is_valid,
|
||||
"errors": errors
|
||||
})
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate config file: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
45
backend/src/api/generations/__init__.py
Normal file
45
backend/src/api/generations/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Generations Module - 生成相关路由模块
|
||||
|
||||
将各个子模块的路由组合成统一的 router
|
||||
|
||||
子模块:
|
||||
- script.py - 脚本分析端点
|
||||
- tasks.py - 任务状态端点
|
||||
- image.py - 图片生成端点
|
||||
- video.py - 视频生成端点
|
||||
- audio.py - 音频生成和导出端点
|
||||
- music.py - 音乐生成端点
|
||||
- batch.py - 批量生成端点
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .script import router as script_router
|
||||
from .tasks import router as tasks_router
|
||||
from .image import router as image_router
|
||||
from .video import router as video_router
|
||||
from .audio import router as audio_router
|
||||
from .music import router as music_router
|
||||
from .batch import router as batch_router
|
||||
|
||||
# 创建主路由
|
||||
router = APIRouter(tags=["generations"])
|
||||
|
||||
# 包含所有子路由
|
||||
router.include_router(script_router)
|
||||
router.include_router(tasks_router)
|
||||
router.include_router(image_router)
|
||||
router.include_router(video_router)
|
||||
router.include_router(audio_router)
|
||||
router.include_router(music_router)
|
||||
router.include_router(batch_router)
|
||||
|
||||
# 导出辅助函数供其他模块使用
|
||||
from .helpers import resolve_service, ensure_url, get_available_styles_from_config
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
"resolve_service",
|
||||
"ensure_url",
|
||||
"get_available_styles_from_config",
|
||||
]
|
||||
159
backend/src/api/generations/audio.py
Normal file
159
backend/src/api/generations/audio.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Audio & Export Endpoints - 音频生成和导出模块
|
||||
|
||||
包含:
|
||||
- POST /audio/generate - 音频生成
|
||||
- POST /generations/audio - 统一音频生成
|
||||
- POST /export/project/{project_id} - 项目导出
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
AudioGenerationRequest,
|
||||
GenerateAudioRequest,
|
||||
ExportProjectRequest
|
||||
)
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.services.export_service import VideoExportService
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
from .helpers import resolve_service, check_user_api_key
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化 export service
|
||||
export_service = VideoExportService()
|
||||
|
||||
|
||||
@router.post("/audio/generate", response_model=ResponseModel)
|
||||
async def generate_audio(
|
||||
request: GenerateAudioRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""生成 audio from text(兼容旧接口,内部走统一任务系统)"""
|
||||
logger.info("Audio generation request (legacy): model=%s, text=%s...", request.model, request.text[:50])
|
||||
|
||||
# 兼容旧接口:model 可为空,回退到默认音频模型
|
||||
model_name = request.model or ModelRegistry.get_default_id(ModelType.AUDIO)
|
||||
if not model_name:
|
||||
raise AppException(
|
||||
message="Audio service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 验证模型可解析并获取服务
|
||||
audio_service = resolve_service(model_name, ModelType.AUDIO)
|
||||
|
||||
# 检查用户是否配置了 API Key
|
||||
# 从 audio_service 获取 provider_id
|
||||
provider = getattr(audio_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
try:
|
||||
task_params = {
|
||||
"text": request.text,
|
||||
"voice": request.voice,
|
||||
"model": model_name,
|
||||
"project_id": request.project_id,
|
||||
"storyboard_id": request.storyboard_id,
|
||||
"extra_params": request.extra_params or {},
|
||||
}
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="audio",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=request.project_id,
|
||||
)
|
||||
|
||||
# 兼容前端:保留 audio_url 字段但异步任务初始为空
|
||||
return ResponseModel(data={
|
||||
"audio_url": None,
|
||||
"task_id": task.id,
|
||||
"status": task.status
|
||||
})
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generations/audio", response_model=ResponseModel)
|
||||
async def generate_audio_unified(
|
||||
request: AudioGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""统一音频生成端点(推荐)"""
|
||||
logger.info("Audio generation request: model=%s, text=%s...", request.model, request.text[:50])
|
||||
|
||||
audio_service = resolve_service(request.model, ModelType.AUDIO)
|
||||
|
||||
# 检查用户是否配置了 API Key
|
||||
# 从 audio_service 获取 provider_id
|
||||
provider = getattr(audio_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
if not audio_service:
|
||||
raise AppException(
|
||||
message="Audio service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
try:
|
||||
task_params = request.model_dump()
|
||||
# Get user_id from authenticated user
|
||||
user_id = current_user.id
|
||||
|
||||
# Handle project_id from multiple sources
|
||||
# Priority 1: request.project_id (from request body, alias="projectId")
|
||||
# Priority 2: request.source_id if source=='project' (from extra_params, alias="sourceId")
|
||||
project_id = request.project_id
|
||||
if not project_id and request.source == "project":
|
||||
project_id = request.source_id
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="audio",
|
||||
model=request.model,
|
||||
params=task_params,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create audio task: %s", e)
|
||||
raise AppException(
|
||||
message=f"Failed to create task: {e}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
return ResponseModel(data={
|
||||
"task_id": task.id,
|
||||
"status": task.status
|
||||
})
|
||||
|
||||
|
||||
@router.post("/export/project/{project_id}", response_model=ResponseModel)
|
||||
async def export_project(project_id: str, request: ExportProjectRequest):
|
||||
""" 导出 project to video"""
|
||||
try:
|
||||
result = await export_service.export_project(project_id, format=request.format)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
320
backend/src/api/generations/batch.py
Normal file
320
backend/src/api/generations/batch.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Batch Generation Endpoints - 批量生成模块
|
||||
|
||||
包含批量生成相关的端点:
|
||||
- POST /generations/batch - 批量创建生成任务
|
||||
"""
|
||||
import logging
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
ImageGenerationRequest,
|
||||
VideoGenerationRequest,
|
||||
AudioGenerationRequest,
|
||||
MusicGenerationRequest
|
||||
)
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
from .helpers import resolve_service, ensure_url, check_user_api_key
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BatchGenerationRequest(BaseModel):
|
||||
"""批量生成请求"""
|
||||
items: List[Dict[str, Any]] = Field(..., description="生成任务列表")
|
||||
|
||||
model_config = ConfigDict(json_schema_extra={
|
||||
"example": {
|
||||
"items": [
|
||||
{
|
||||
"type": "image",
|
||||
"prompt": "A beautiful landscape",
|
||||
"model": "wanx2.1-t2i-turbo"
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
class BatchGenerationResponse(BaseModel):
|
||||
"""批量生成响应"""
|
||||
task_ids: List[str] = Field(..., description="创建的任务ID列表")
|
||||
total: int = Field(..., description="总任务数")
|
||||
created: int = Field(..., description="成功创建的任务数")
|
||||
failed: int = Field(..., description="失败的任务数")
|
||||
errors: List[Dict[str, str]] = Field(default=[], description="错误信息列表")
|
||||
|
||||
|
||||
@router.post("/generations/batch", response_model=ResponseModel)
|
||||
async def generate_batch(
|
||||
request: BatchGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""批量创建生成任务
|
||||
|
||||
支持同时创建多个图片、视频、音频、音乐生成任务
|
||||
|
||||
Args:
|
||||
request: 批量生成请求,包含任务列表
|
||||
|
||||
Returns:
|
||||
批量生成响应,包含创建的任务ID列表
|
||||
"""
|
||||
logger.info(f"Batch generation request: {len(request.items)} items from user {current_user.id}")
|
||||
|
||||
if not request.items:
|
||||
raise AppException(
|
||||
message="No items provided for batch generation",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
if len(request.items) > 20:
|
||||
raise AppException(
|
||||
message="Batch size exceeds maximum limit of 20",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
task_ids = []
|
||||
errors = []
|
||||
|
||||
for idx, item in enumerate(request.items):
|
||||
try:
|
||||
item_type = item.get("type", "image")
|
||||
|
||||
if item_type == "image":
|
||||
task_id = await _create_image_task(item, current_user)
|
||||
elif item_type == "video":
|
||||
task_id = await _create_video_task(item, current_user)
|
||||
elif item_type == "audio":
|
||||
task_id = await _create_audio_task(item, current_user)
|
||||
elif item_type == "music":
|
||||
task_id = await _create_music_task(item, current_user)
|
||||
else:
|
||||
raise ValueError(f"Unsupported generation type: {item_type}")
|
||||
|
||||
task_ids.append(task_id)
|
||||
logger.info(f"Created {item_type} task {task_id} for batch item {idx}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create task for batch item {idx}: {e}")
|
||||
errors.append({
|
||||
"index": idx,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
result = BatchGenerationResponse(
|
||||
task_ids=task_ids,
|
||||
total=len(request.items),
|
||||
created=len(task_ids),
|
||||
failed=len(errors),
|
||||
errors=errors
|
||||
)
|
||||
|
||||
logger.info(f"Batch generation completed: {result.created}/{result.total} tasks created")
|
||||
|
||||
return ResponseModel(data=result)
|
||||
|
||||
|
||||
async def _create_image_task(item: Dict[str, Any], current_user: UserAuth) -> str:
|
||||
"""创建图片生成任务"""
|
||||
image_service = resolve_service(item.get("model"), ModelType.IMAGE)
|
||||
|
||||
if not image_service:
|
||||
raise ValueError("Image service not configured")
|
||||
|
||||
# 检查 API Key
|
||||
provider = getattr(image_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
# 解析尺寸
|
||||
aspect_ratio = item.get("aspect_ratio", "16:9")
|
||||
resolution = item.get("resolution", "1K")
|
||||
size = _resolve_image_size(image_service, aspect_ratio, resolution)
|
||||
|
||||
# 构建参数
|
||||
task_params = {
|
||||
"prompt": item.get("prompt", ""),
|
||||
"size": size,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"resolution": resolution,
|
||||
"image_inputs": [_sanitize_url(url) for url in item.get("image_inputs", []) if url],
|
||||
"extra_params": item.get("extra_params", {})
|
||||
}
|
||||
|
||||
# 处理 source 相关字段
|
||||
if item.get("source"):
|
||||
task_params["source"] = item["source"]
|
||||
if item.get("source_id"):
|
||||
task_params["source_id"] = item["source_id"]
|
||||
if item.get("project_id"):
|
||||
task_params["project_id"] = item["project_id"]
|
||||
|
||||
model_name = item.get("model") or getattr(image_service, "model_id", "default_image")
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="image",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=item.get("project_id")
|
||||
)
|
||||
|
||||
return task.id
|
||||
|
||||
|
||||
async def _create_video_task(item: Dict[str, Any], current_user: UserAuth) -> str:
|
||||
"""创建视频生成任务"""
|
||||
video_service = resolve_service(item.get("model"), ModelType.VIDEO)
|
||||
|
||||
if not video_service:
|
||||
raise ValueError("Video service not configured")
|
||||
|
||||
# 检查 API Key
|
||||
provider = getattr(video_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
# 构建参数
|
||||
task_params = {
|
||||
"prompt": item.get("prompt", ""),
|
||||
"duration": item.get("duration", 5),
|
||||
"resolution": item.get("resolution", "720p"),
|
||||
"image_inputs": [_sanitize_url(url) for url in item.get("image_inputs", []) if url],
|
||||
"extra_params": item.get("extra_params", {})
|
||||
}
|
||||
|
||||
if item.get("source"):
|
||||
task_params["source"] = item["source"]
|
||||
if item.get("source_id"):
|
||||
task_params["source_id"] = item["source_id"]
|
||||
if item.get("project_id"):
|
||||
task_params["project_id"] = item["project_id"]
|
||||
|
||||
model_name = item.get("model") or getattr(video_service, "model_id", "default_video")
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="video",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=item.get("project_id")
|
||||
)
|
||||
|
||||
return task.id
|
||||
|
||||
|
||||
async def _create_audio_task(item: Dict[str, Any], current_user: UserAuth) -> str:
|
||||
"""创建音频生成任务"""
|
||||
audio_service = resolve_service(item.get("model"), ModelType.AUDIO)
|
||||
|
||||
if not audio_service:
|
||||
raise ValueError("Audio service not configured")
|
||||
|
||||
# 检查 API Key
|
||||
provider = getattr(audio_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
task_params = {
|
||||
"text": item.get("text", item.get("prompt", "")),
|
||||
"voice": item.get("voice", "default"),
|
||||
"speed": item.get("speed", 1.0),
|
||||
"extra_params": item.get("extra_params", {})
|
||||
}
|
||||
|
||||
if item.get("source"):
|
||||
task_params["source"] = item["source"]
|
||||
if item.get("source_id"):
|
||||
task_params["source_id"] = item["source_id"]
|
||||
|
||||
model_name = item.get("model") or getattr(audio_service, "model_id", "default_audio")
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="audio",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=item.get("project_id")
|
||||
)
|
||||
|
||||
return task.id
|
||||
|
||||
|
||||
async def _create_music_task(item: Dict[str, Any], current_user: UserAuth) -> str:
|
||||
"""创建音乐生成任务"""
|
||||
music_service = resolve_service(item.get("model"), ModelType.MUSIC)
|
||||
|
||||
if not music_service:
|
||||
raise ValueError("Music service not configured")
|
||||
|
||||
# 检查 API Key
|
||||
provider = getattr(music_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
task_params = {
|
||||
"prompt": item.get("prompt", ""),
|
||||
"duration": item.get("duration", 30),
|
||||
"extra_params": item.get("extra_params", {})
|
||||
}
|
||||
|
||||
if item.get("source"):
|
||||
task_params["source"] = item["source"]
|
||||
if item.get("source_id"):
|
||||
task_params["source_id"] = item["source_id"]
|
||||
|
||||
model_name = item.get("model") or getattr(music_service, "model_id", "default_music")
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="music",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=item.get("project_id")
|
||||
)
|
||||
|
||||
return task.id
|
||||
|
||||
|
||||
def _resolve_image_size(image_service, aspect_ratio: str, resolution: str) -> str:
|
||||
"""解析图片尺寸"""
|
||||
model_config = getattr(image_service, "config", {})
|
||||
resolutions_config = model_config.get("resolutions", {})
|
||||
|
||||
if resolutions_config and resolution in resolutions_config:
|
||||
ratio_map = resolutions_config[resolution]
|
||||
if isinstance(ratio_map, dict) and aspect_ratio in ratio_map:
|
||||
return ratio_map[aspect_ratio]
|
||||
|
||||
# 默认尺寸
|
||||
defaults = {
|
||||
"1K": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048"
|
||||
}
|
||||
}
|
||||
return defaults.get(resolution, {}).get(aspect_ratio, "1024*1024")
|
||||
|
||||
|
||||
def _sanitize_url(url: str) -> str:
|
||||
"""清理 URL"""
|
||||
if not url:
|
||||
return url
|
||||
return ensure_url(url)
|
||||
396
backend/src/api/generations/helpers.py
Normal file
396
backend/src/api/generations/helpers.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
Generation Helpers - 辅助函数模块
|
||||
|
||||
包含:
|
||||
- resolve_service: 服务路由函数
|
||||
- ensure_url: URL 处理函数
|
||||
- _get_available_styles_from_config: 样式配置获取
|
||||
"""
|
||||
import logging
|
||||
import base64
|
||||
import uuid
|
||||
import os
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
from src.services.user_api_key_service import user_api_key_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_test_env() -> bool:
|
||||
return (
|
||||
os.getenv("PYTEST_CURRENT_TEST") is not None
|
||||
or os.getenv("PIXEL_TEST_MODE") == "1"
|
||||
or os.getenv("NODE_ENV") == "test"
|
||||
)
|
||||
|
||||
|
||||
def _ensure_model_registry_loaded() -> None:
|
||||
if ModelRegistry.list_models():
|
||||
return
|
||||
|
||||
from src.utils.service_loader import load_services_from_config
|
||||
|
||||
services_config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"config",
|
||||
"services",
|
||||
)
|
||||
load_services_from_config(services_config_path)
|
||||
|
||||
|
||||
def _reload_model_registry() -> None:
|
||||
from src.utils.service_loader import load_services_from_config
|
||||
|
||||
services_config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"config",
|
||||
"services",
|
||||
)
|
||||
load_services_from_config(services_config_path)
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def resolve_service(
|
||||
model: str,
|
||||
model_type: ModelType
|
||||
):
|
||||
"""
|
||||
根据复合 ID 路由到正确的服务
|
||||
|
||||
路由逻辑:
|
||||
1. 验证 model 必须是 provider/model_key 格式
|
||||
2. 解析 provider 和 model_key
|
||||
3. 尝试用复合 ID 直接查找
|
||||
4. 在该 provider 下搜索匹配的 model_key
|
||||
5. 如果都找不到,抛出 HTTPException(404)
|
||||
|
||||
Args:
|
||||
model: 复合 ID 格式 "provider/model_key"(如 'dashscope/qwen-image')
|
||||
model_type: 模型类型(IMAGE, VIDEO, AUDIO 等)
|
||||
|
||||
Returns:
|
||||
服务实例
|
||||
|
||||
Raises:
|
||||
ValueError: 当 model 格式不正确时
|
||||
HTTPException: 当模型不存在时
|
||||
"""
|
||||
_ensure_model_registry_loaded()
|
||||
|
||||
# 1. 验证格式
|
||||
if not model or '/' not in model:
|
||||
raise ValueError(
|
||||
f"Model must be in format 'provider/model_key', got: '{model}'. "
|
||||
f"Example: 'dashscope/qwen-image'"
|
||||
)
|
||||
|
||||
# 2. 解析 provider 和 model_key
|
||||
parts = model.split('/', 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(
|
||||
f"Model format invalid: '{model}'. "
|
||||
f"Must have exactly one '/' separator."
|
||||
)
|
||||
|
||||
provider, model_key = parts
|
||||
if not provider or not model_key:
|
||||
raise ValueError(
|
||||
f"Model format invalid: '{model}'. "
|
||||
f"Both provider and model_key must be non-empty."
|
||||
)
|
||||
|
||||
logger.info(f"Resolving service: model='{model}', type={model_type.value}")
|
||||
|
||||
# 3. 方式1: 直接用复合 ID 查找
|
||||
service = ModelRegistry.get(model)
|
||||
if service:
|
||||
logger.info(f"Found service by composite ID: '{model}'")
|
||||
return service
|
||||
|
||||
# 4. 方式2: 在 provider 下搜索 model_key
|
||||
services = ModelRegistry.find_services(provider=provider, model_type=model_type)
|
||||
matching = [
|
||||
s for s in services
|
||||
if s.get('model_key') == model_key or s.get('id') == model
|
||||
]
|
||||
|
||||
if matching:
|
||||
service = ModelRegistry.get(matching[0].get('id'))
|
||||
logger.info(
|
||||
f"Found service by provider+model_key: "
|
||||
f"provider='{provider}', model_key='{model_key}'"
|
||||
)
|
||||
return service
|
||||
|
||||
# 4.5 可能 registry 被测试或运行时临时覆盖,尝试强制回填一次
|
||||
_reload_model_registry()
|
||||
|
||||
service = ModelRegistry.get(model)
|
||||
if service:
|
||||
logger.info(f"Found service by composite ID after registry reload: '{model}'")
|
||||
return service
|
||||
|
||||
services = ModelRegistry.find_services(provider=provider, model_type=model_type)
|
||||
matching = [
|
||||
s for s in services
|
||||
if s.get('model_key') == model_key or s.get('id') == model
|
||||
]
|
||||
if matching:
|
||||
service = ModelRegistry.get(matching[0].get('id'))
|
||||
logger.info(
|
||||
f"Found service by provider+model_key after registry reload: "
|
||||
f"provider='{provider}', model_key='{model_key}'"
|
||||
)
|
||||
return service
|
||||
|
||||
# 5. 未找到
|
||||
logger.error(f"Model '{model}' not found")
|
||||
raise AppException(
|
||||
message=f"Model '{model}' not found. Available models can be fetched from /api/v1/models",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
|
||||
def ensure_url(url_or_base64: str) -> str:
|
||||
"""
|
||||
Ensure the input is a valid URL that external services can access.
|
||||
- If it's a relative path and using local storage, convert to Base64.
|
||||
- If it's a relative path and using OSS, upload to OSS.
|
||||
- If it's a Base64 string, return as is.
|
||||
- If it's a localhost URL, download and convert to Base64 or upload to OSS.
|
||||
- If it's already a public URL, sign it if needed.
|
||||
"""
|
||||
if not url_or_base64:
|
||||
return url_or_base64
|
||||
|
||||
# Check if it's a relative path (starts with /files/ or /uploads/)
|
||||
if url_or_base64.startswith('/files/') or url_or_base64.startswith('/uploads/'):
|
||||
try:
|
||||
from src.config.settings import DATA_DIR, UPLOAD_DIR, STORAGE_TYPE
|
||||
|
||||
logger.info(f"[ensure_url] Detected relative path: {url_or_base64}")
|
||||
|
||||
# Determine the file path
|
||||
if url_or_base64.startswith('/files/'):
|
||||
rel_path = url_or_base64[7:] # Remove '/files/'
|
||||
file_path = os.path.join(DATA_DIR, rel_path)
|
||||
else: # /uploads/
|
||||
rel_path = url_or_base64[9:] # Remove '/uploads/'
|
||||
file_path = os.path.join(UPLOAD_DIR, rel_path)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"[ensure_url] File not found: {file_path}")
|
||||
return url_or_base64
|
||||
|
||||
# Read the file
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Determine extension and MIME type
|
||||
ext = os.path.splitext(file_path)[1].lstrip('.') or 'png'
|
||||
mime_type = f"image/{ext}"
|
||||
if ext == 'jpg':
|
||||
mime_type = "image/jpeg"
|
||||
|
||||
# If using OSS, upload the file to OSS
|
||||
if STORAGE_TYPE == 'oss':
|
||||
filename = f"temp/{uuid.uuid4()}.{ext}"
|
||||
oss_url = storage_manager.save(filename, file_data)
|
||||
logger.info(f"[ensure_url] Uploaded to OSS: {oss_url}")
|
||||
return oss_url
|
||||
|
||||
# If using local storage, convert to Base64
|
||||
else:
|
||||
base64_data = base64.b64encode(file_data).decode('utf-8')
|
||||
base64_url = f"data:{mime_type};base64,{base64_data}"
|
||||
logger.info(f"[ensure_url] Converted to Base64 (length: {len(base64_url)})")
|
||||
return base64_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ensure_url] Failed to convert relative path: {e}")
|
||||
return url_or_base64
|
||||
|
||||
# Check if it's a localhost URL — read from disk instead of HTTP to prevent SSRF
|
||||
if any(host in url_or_base64.lower() for host in ['localhost', '127.0.0.1', '0.0.0.0']):
|
||||
try:
|
||||
from src.config.settings import STORAGE_TYPE
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger.info(f"[ensure_url] Detected localhost URL, reading from disk: {url_or_base64}")
|
||||
|
||||
# Extract the path component (e.g. /files/xxx or /uploads/xxx)
|
||||
parsed = urlparse(url_or_base64)
|
||||
path = parsed.path
|
||||
|
||||
# Map URL path to local filesystem
|
||||
if path.startswith('/files/'):
|
||||
rel_path = path[7:]
|
||||
file_path = os.path.join(DATA_DIR, rel_path)
|
||||
elif path.startswith('/uploads/'):
|
||||
rel_path = path[9:]
|
||||
file_path = os.path.join(UPLOAD_DIR, rel_path)
|
||||
else:
|
||||
logger.warning(f"[ensure_url] Unsupported localhost path: {path}")
|
||||
return url_or_base64
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"[ensure_url] File not found: {file_path}")
|
||||
return url_or_base64
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
ext = os.path.splitext(file_path)[1].lstrip('.') or 'png'
|
||||
mime_type = f"image/{ext}"
|
||||
if ext == 'jpg':
|
||||
mime_type = "image/jpeg"
|
||||
|
||||
if STORAGE_TYPE == 'oss':
|
||||
filename = f"temp/{uuid.uuid4()}.{ext}"
|
||||
saved_url = storage_manager.save(filename, file_data)
|
||||
logger.info(f"[ensure_url] Uploaded to OSS: {saved_url}")
|
||||
return saved_url
|
||||
else:
|
||||
base64_data = base64.b64encode(file_data).decode('utf-8')
|
||||
base64_url = f"data:{mime_type};base64,{base64_data}"
|
||||
logger.info(f"[ensure_url] Converted to Base64 (length: {len(base64_url)})")
|
||||
return base64_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ensure_url] Failed to convert localhost URL: {e}")
|
||||
return url_or_base64
|
||||
|
||||
# Check if Base64 - return as is (already in correct format)
|
||||
if url_or_base64.startswith("data:"):
|
||||
logger.info(f"[ensure_url] Already Base64 format, returning as is")
|
||||
return url_or_base64
|
||||
|
||||
# If it's already a public URL, check if it belongs to our OSS bucket
|
||||
if url_or_base64.startswith(('http://', 'https://')):
|
||||
from src.config.settings import STORAGE_TYPE
|
||||
from src.utils.oss_utils import OSS_BUCKET
|
||||
|
||||
# Only sign URLs from our own OSS bucket; external URLs should pass through as-is
|
||||
if STORAGE_TYPE == 'oss' and OSS_BUCKET and f"{OSS_BUCKET}." in url_or_base64:
|
||||
return storage_manager.sign_url(url_or_base64) or url_or_base64
|
||||
else:
|
||||
# External public URL (e.g. dashscope-result, other CDN), return as-is
|
||||
logger.info(f"[ensure_url] External public URL, returning as-is")
|
||||
return url_or_base64
|
||||
|
||||
# Fallback: sign it if needed
|
||||
return storage_manager.sign_url(url_or_base64) or url_or_base64
|
||||
|
||||
|
||||
async def get_available_styles_from_config() -> List[dict]:
|
||||
""" 获取 available style configurations from backend config file.
|
||||
|
||||
Returns:
|
||||
List of style dicts with full information (id, name, desc, type, etc.)
|
||||
Example: [{'id': 'cyberpunk', 'name': '赛博朋克', 'desc': '高对比度霓虹灯光...', ...}, ...]
|
||||
"""
|
||||
cache_key = "config:styles_full"
|
||||
|
||||
# Cache first (5 minutes TTL)
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_styles = await cache.get(cache_key)
|
||||
if cached_styles is not None:
|
||||
return cached_styles
|
||||
|
||||
# Get from config file
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
styles_list = []
|
||||
if os.path.exists(styles_path):
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
styles_list = data.get('styles', [])
|
||||
|
||||
# Fallback to common styles with descriptions if config doesn't exist
|
||||
if not styles_list:
|
||||
styles_list = [
|
||||
{"id": "cyberpunk", "name": "赛博朋克", "desc": "高对比度霓虹灯光,未来主义建筑,机械改造"},
|
||||
{"id": "ink", "name": "水墨风", "desc": "传统中国水墨渲染,黑白灰为主,意境深远"},
|
||||
{"id": "pixar", "name": "皮克斯风格", "desc": "3D卡通渲染,色彩鲜艳,光影柔和,表情夸张"},
|
||||
{"id": "anime", "name": "日漫", "desc": "典型日本动画风格,线条清晰,赛璐璐上色"},
|
||||
{"id": "chinese-anime", "name": "国漫风", "desc": "中国现代动画风格,融合传统与现代元素"},
|
||||
{"id": "realistic", "name": "写实", "desc": "电影级写实渲染,细节丰富,光照真实"},
|
||||
{"id": "hand-drawn", "name": "手绘", "desc": "传统手绘质感,笔触明显,艺术感强"},
|
||||
{"id": "watercolor", "name": "水彩画", "desc": "水彩晕染效果,色彩通透,艺术感强"},
|
||||
{"id": "oil-painting", "name": "油画", "desc": "厚涂油画质感,笔触丰富,光影层次感强"},
|
||||
{"id": "pixel-art", "name": "像素风", "desc": "复古8-bit/16-bit像素艺术,怀旧游戏风格"},
|
||||
]
|
||||
|
||||
# 缓存 for 5 minutes
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, styles_list, ttl=300)
|
||||
|
||||
return styles_list
|
||||
|
||||
|
||||
def check_user_api_key(user_id: str, provider: str) -> None:
|
||||
"""
|
||||
检查系统是否配置了指定 Provider 的 API Key。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(保留参数兼容性)
|
||||
provider: 服务商 ID (如 'dashscope', 'openai', 'kling' 等)
|
||||
|
||||
Raises:
|
||||
AppException: 如果系统未配置该 Provider 的 API Key
|
||||
"""
|
||||
if _is_test_env():
|
||||
return
|
||||
|
||||
import os
|
||||
env_map = {
|
||||
"dashscope": "DASHSCOPE_API_KEY",
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"kling": "KLING_ACCESS_KEY",
|
||||
"midjourney": "MIDJOURNEY_API_KEY",
|
||||
"modelscope": "MODELSCOPE_API_TOKEN",
|
||||
"volcengine": "VOLCENGINE_API_KEY",
|
||||
"minimax": "MINIMAX_API_KEY",
|
||||
"google": "GOOGLE_API_KEY",
|
||||
}
|
||||
env_var = env_map.get(provider)
|
||||
if env_var and os.getenv(env_var):
|
||||
return
|
||||
|
||||
provider_names = {
|
||||
"dashscope": "阿里云 DashScope",
|
||||
"openai": "OpenAI",
|
||||
"kling": "可灵 AI",
|
||||
"midjourney": "Midjourney",
|
||||
"modelscope": "ModelScope",
|
||||
"volcengine": "火山引擎",
|
||||
"minimax": "MiniMax",
|
||||
"google": "Google",
|
||||
}
|
||||
provider_name = provider_names.get(provider, provider)
|
||||
|
||||
raise AppException(
|
||||
message=f"未配置 {provider_name} 的 API Key",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400,
|
||||
details={
|
||||
"error": "API_KEY_NOT_CONFIGURED",
|
||||
"message": f"尚未配置 {provider_name} 的 API Key,无法提交生成任务。",
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"solution": "请联系管理员在 .env 中配置对应的 API Key。",
|
||||
}
|
||||
)
|
||||
208
backend/src/api/generations/image.py
Normal file
208
backend/src/api/generations/image.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Image Generation Endpoints - 图片生成模块
|
||||
|
||||
包含图片生成相关的端点:
|
||||
- POST /generations/image - 图片生成
|
||||
- POST /image/upscale - 图片放大
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
ImageGenerationRequest,
|
||||
UpscaleImageRequest, UpscaleImageResponse
|
||||
)
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
from .helpers import resolve_service, ensure_url, check_user_api_key
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/generations/image", response_model=ResponseModel)
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 泛型 Image Generation Endpoint.
|
||||
Supports generation for any context (Storyboard, Asset, etc.)
|
||||
"""
|
||||
# Logging
|
||||
logger.info(f"Image generation request: model={request.model}, prompt={request.prompt[:50]}...")
|
||||
|
||||
# 1. Resolve Model - 使用统一的路由逻辑
|
||||
image_service = resolve_service(request.model, ModelType.IMAGE)
|
||||
|
||||
if not image_service:
|
||||
raise AppException(
|
||||
message="Image service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 1.5 检查用户是否配置了 API Key
|
||||
# 从 image_service 获取 provider_id
|
||||
provider = getattr(image_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
# 2. Resolve size from aspect_ratio and resolution (BEFORE creating task)
|
||||
aspect_ratio = request.aspect_ratio
|
||||
size = None
|
||||
|
||||
# 调试 logging
|
||||
logger.info(f"[ImageGeneration] Request received:")
|
||||
logger.info(f" - Prompt: {request.prompt[:100]}...")
|
||||
logger.info(f" - Model: {request.model}")
|
||||
logger.info(f" - Aspect Ratio: {aspect_ratio}")
|
||||
logger.info(f" - Resolution: {request.resolution}")
|
||||
logger.info(f" - Image inputs count: {len(request.image_inputs or [])}")
|
||||
if request.image_inputs:
|
||||
for i, url in enumerate(request.image_inputs):
|
||||
logger.info(f" - Image input {i}: {url[:100]}...")
|
||||
|
||||
if aspect_ratio:
|
||||
# Look up resolution in model config
|
||||
model_config = getattr(image_service, "config", {})
|
||||
resolutions_config = model_config.get("resolutions", {})
|
||||
|
||||
# Use provided resolution level (e.g. "1K", "2K", "4K") or default to "1K"
|
||||
res_level = request.resolution or "1K"
|
||||
|
||||
if resolutions_config and res_level in resolutions_config and isinstance(resolutions_config[res_level], dict):
|
||||
ratio_map = resolutions_config[res_level]
|
||||
if aspect_ratio in ratio_map:
|
||||
size = ratio_map[aspect_ratio]
|
||||
logger.info(f" - Resolved size from config: {size} (resolution={res_level}, ratio={aspect_ratio})")
|
||||
|
||||
if not size:
|
||||
defaults = {
|
||||
"1K": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280",
|
||||
"2.35:1": "1280*544"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048",
|
||||
"4:3": "2560*1920",
|
||||
"3:4": "1920*2560"
|
||||
},
|
||||
"4K": {
|
||||
"16:9": "3840*2160",
|
||||
"9:16": "2160*3840",
|
||||
"1:1": "4096*4096",
|
||||
"4:3": "3840*2880",
|
||||
"3:4": "2880*3840"
|
||||
}
|
||||
}
|
||||
size = defaults.get(res_level, {}).get(aspect_ratio)
|
||||
if size:
|
||||
logger.info(f" - Resolved size from defaults: {size} (resolution={res_level}, ratio={aspect_ratio})")
|
||||
|
||||
# 3. Create Task in DB with resolved params
|
||||
try:
|
||||
model_name = request.model or getattr(image_service, "model_id", "default_image")
|
||||
|
||||
# 构建 params dict with resolved size
|
||||
task_params = request.model_dump()
|
||||
task_params["size"] = size # Override with resolved size (provider expects 'size')
|
||||
|
||||
# lora_strength 已经在 extra_params 中(前端传递),不需要额外处理
|
||||
# TaskManager 会从 extra_params 中提取并合并到一级参数
|
||||
|
||||
# 清理 Image URLs (Convert relative paths, localhost URLs, Base64)
|
||||
logger.info(f"[ImageGeneration] Before URL sanitization - image_inputs: {task_params.get('image_inputs')}")
|
||||
|
||||
for key in ["image_inputs"]:
|
||||
if key in task_params and isinstance(task_params[key], list):
|
||||
original_urls = task_params[key]
|
||||
task_params[key] = [ensure_url(url) for url in task_params[key] if url]
|
||||
logger.info(f"[ImageGeneration] Sanitized {key}: {original_urls} -> {task_params[key]}")
|
||||
|
||||
# Use unified task manager
|
||||
# Get user_id from authenticated user
|
||||
user_id = current_user.id
|
||||
|
||||
# Handle project_id from multiple sources
|
||||
# Priority 1: request.project_id (from request body, alias="projectId")
|
||||
# Priority 2: request.source_id if source=='project' (from extra_params, alias="sourceId")
|
||||
project_id = request.project_id
|
||||
if not project_id and request.source == 'project':
|
||||
project_id = request.source_id
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="image",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create task: {e}")
|
||||
raise AppException(
|
||||
message=f"Failed to create task: {e}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 日志 context if provided
|
||||
if request.source and request.source_id:
|
||||
logger.info(f"Image generation started for {request.source}:{request.source_id}, task_id: {task.id}")
|
||||
|
||||
# Unified task manager handles execution automatically
|
||||
return ResponseModel(data={
|
||||
"task_id": task.id,
|
||||
"status": task.status
|
||||
})
|
||||
|
||||
|
||||
@router.post("/image/upscale", response_model=ResponseModel)
|
||||
async def upscale_image(request: UpscaleImageRequest):
|
||||
"""Upscale image resolution"""
|
||||
upscale_service = ModelRegistry.get_default(ModelType.UPSCALE)
|
||||
if not upscale_service:
|
||||
raise AppException(
|
||||
message="Default Upscale service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
try:
|
||||
# Prepare kwargs
|
||||
upscale_kwargs = {
|
||||
"image_url": request.image_url,
|
||||
"rate": request.rate
|
||||
}
|
||||
if request.extra_params:
|
||||
upscale_kwargs.update(request.extra_params)
|
||||
|
||||
upscaled_url = await upscale_service.upscale_image(**upscale_kwargs)
|
||||
|
||||
if not upscaled_url:
|
||||
raise AppException(
|
||||
message="Upscaling failed or returned no URL",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
return ResponseModel(data=UpscaleImageResponse(
|
||||
original_url=request.image_url,
|
||||
upscaled_url=upscaled_url
|
||||
))
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
82
backend/src/api/generations/music.py
Normal file
82
backend/src/api/generations/music.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Music Generation Endpoints - 音乐生成模块
|
||||
|
||||
包含音乐生成相关端点:
|
||||
- POST /generations/music - 音乐生成
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import ResponseModel, MusicGenerationRequest
|
||||
from src.services.provider.registry import ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
from .helpers import resolve_service, check_user_api_key
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/generations/music", response_model=ResponseModel)
|
||||
async def generate_music(
|
||||
request: MusicGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""统一 Music/Lyrics Generation Endpoint."""
|
||||
logger.info("Music unified generation request: mode=%s, model=%s", request.generation_mode, request.model)
|
||||
|
||||
# 1) Resolve model by mode
|
||||
model_type = ModelType.LYRICS if request.generation_mode == "lyrics" else ModelType.MUSIC
|
||||
music_service = resolve_service(request.model, model_type)
|
||||
|
||||
# 1.5 检查用户是否配置了 API Key
|
||||
# 从 music_service 获取 provider_id
|
||||
provider = getattr(music_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
if not music_service:
|
||||
raise AppException(
|
||||
message=f"{request.generation_mode} service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 2) Unified async task for both lyrics/music modes
|
||||
try:
|
||||
model_name = request.model or getattr(music_service, "model_id", "default_music")
|
||||
task_params = request.model_dump()
|
||||
|
||||
# Handle project_id from multiple sources
|
||||
# Priority 1: request.project_id (from request body, alias="projectId")
|
||||
# Priority 2: request.source_id if source=='project' (from extra_params, alias="sourceId")
|
||||
project_id = request.project_id
|
||||
if not project_id and request.source == "project":
|
||||
project_id = request.source_id
|
||||
|
||||
# Get user_id from authenticated user
|
||||
user_id = current_user.id
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="music",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create unified music task: %s", e)
|
||||
raise AppException(
|
||||
message=f"Failed to create task: {e}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
return ResponseModel(data={
|
||||
"task_id": task.id,
|
||||
"status": task.status,
|
||||
"mode": request.generation_mode
|
||||
})
|
||||
407
backend/src/api/generations/script.py
Normal file
407
backend/src/api/generations/script.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""
|
||||
Script Analysis Endpoints - 脚本分析模块
|
||||
|
||||
包含所有与 LLM 脚本分析相关的端点:
|
||||
- /script/analyze - 分析小说文本
|
||||
- /prompt/optimize - 优化提示词
|
||||
- /style/recommend - 推荐风格
|
||||
- /script/summary - 小说摘要
|
||||
- /script/characters - 角色提取
|
||||
- /script/scenes - 场景提取
|
||||
- /script/props - 道具提取
|
||||
- /script/storyboards - 分镜拆分
|
||||
- /script/split - 章节拆分
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends
|
||||
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
ScriptProcessRequest, ScriptResponse,
|
||||
PromptOptimizationRequest, PromptOptimizationResponse,
|
||||
StyleRecommendationRequest, StyleRecommendationResponse,
|
||||
NovelSummaryRequest, NovelSummaryResponse,
|
||||
CharacterExtractionRequest, CharacterExtractionResponse,
|
||||
SceneExtractionRequest, SceneExtractionResponse,
|
||||
PropExtractionRequest, PropExtractionResponse,
|
||||
StoryboardSplitRequest, StoryboardSplitResponse,
|
||||
ChapterSplitRequest, ChapterSplitResponse,
|
||||
CharacterAsset, SceneAsset, PropAsset
|
||||
)
|
||||
from src.services.script import script_service
|
||||
from src.services.project_service import project_manager
|
||||
|
||||
from .helpers import get_available_styles_from_config
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/script/analyze", response_model=ResponseModel)
|
||||
async def analyze_script(
|
||||
request: ScriptProcessRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Analyze novel text and generate script"""
|
||||
try:
|
||||
# Handle skip_storyboard priority (explicit param vs extra_params)
|
||||
skip_sb = request.skip_storyboard
|
||||
extra = request.extra_params or {}
|
||||
if 'skip_storyboard' in extra:
|
||||
skip_sb = extra.pop('skip_storyboard')
|
||||
|
||||
response_data = await script_service.analyze_novel(
|
||||
novel_text=request.novel_text,
|
||||
project_id=request.project_id,
|
||||
model_name=request.model,
|
||||
max_input_tokens=request.max_input_tokens,
|
||||
skip_storyboard=skip_sb,
|
||||
user_id=current_user.id,
|
||||
**extra
|
||||
)
|
||||
|
||||
# Automatically save extracted assets to the project
|
||||
if request.project_id:
|
||||
# 获取 existing project assets for deduplication
|
||||
existing_assets_map = {}
|
||||
try:
|
||||
current_project = project_manager.get_project(request.project_id)
|
||||
if current_project and current_project.assets:
|
||||
for asset in current_project.assets:
|
||||
existing_assets_map[(asset.type, asset.name)] = asset
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch project for deduplication: {e}")
|
||||
|
||||
# 辅助函数 to create and add asset
|
||||
def save_asset(item: dict, asset_type: str):
|
||||
name = item.get('name')
|
||||
if not name:
|
||||
return
|
||||
|
||||
# Check if exists
|
||||
existing_asset = existing_assets_map.get((asset_type, name))
|
||||
|
||||
if existing_asset:
|
||||
logger.info(f"Asset {name} ({asset_type}) already exists, updating.")
|
||||
asset_id = existing_asset.id
|
||||
else:
|
||||
asset_id = str(uuid.uuid4())
|
||||
|
||||
# 创建 specific asset models
|
||||
try:
|
||||
asset = None
|
||||
if asset_type == 'character':
|
||||
asset = CharacterAsset(
|
||||
id=asset_id,
|
||||
name=item.get('name'),
|
||||
desc=item.get('desc', ''),
|
||||
tags=item.get('tags', []),
|
||||
age=item.get('age'),
|
||||
role=item.get('role'),
|
||||
appearance=item.get('appearance'),
|
||||
image_prompt=item.get('image_prompt')
|
||||
)
|
||||
elif asset_type == 'scene':
|
||||
asset = SceneAsset(
|
||||
id=asset_id,
|
||||
name=item.get('name'),
|
||||
desc=item.get('desc', ''),
|
||||
tags=item.get('tags', []),
|
||||
location=item.get('location'),
|
||||
time_of_day=item.get('time_of_day'),
|
||||
atmosphere=item.get('atmosphere'),
|
||||
image_prompt=item.get('image_prompt')
|
||||
)
|
||||
elif asset_type == 'prop':
|
||||
asset = PropAsset(
|
||||
id=asset_id,
|
||||
name=item.get('name'),
|
||||
desc=item.get('desc', ''),
|
||||
tags=item.get('tags', []),
|
||||
usage=item.get('usage'),
|
||||
image_prompt=item.get('image_prompt')
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
if existing_asset:
|
||||
project_manager.update_asset(request.project_id, asset_id, asset)
|
||||
else:
|
||||
project_manager.add_asset(request.project_id, asset)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save asset {item.get('name')}: {e}")
|
||||
|
||||
# Save characters
|
||||
for char in (response_data.characters or []):
|
||||
save_asset(char.model_dump(), 'character')
|
||||
|
||||
# Save scenes
|
||||
for scene in (response_data.scenes or []):
|
||||
save_asset(scene.model_dump(), 'scene')
|
||||
|
||||
# Save props
|
||||
for prop in (response_data.props or []):
|
||||
save_asset(prop.model_dump(), 'prop')
|
||||
|
||||
# Save summary if available
|
||||
if response_data.summary:
|
||||
try:
|
||||
project_manager.update_project(request.project_id, {"description": response_data.summary})
|
||||
logger.info(f"Updated project description with generated summary.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update project description: {e}")
|
||||
|
||||
return ResponseModel(data=response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Script analysis failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/prompt/optimize", response_model=ResponseModel)
|
||||
async def optimize_prompt_endpoint(
|
||||
request: PromptOptimizationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Optimize prompt for image or video generation"""
|
||||
try:
|
||||
optimized_prompt = await script_service.optimize_prompt(
|
||||
prompt=request.prompt,
|
||||
target_type=request.target_type,
|
||||
template=request.template,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
language=request.language or "Chinese",
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=PromptOptimizationResponse(
|
||||
original_prompt=request.prompt,
|
||||
optimized_prompt=optimized_prompt,
|
||||
target_type=request.target_type
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Prompt optimization failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/style/recommend", response_model=ResponseModel)
|
||||
async def recommend_style(
|
||||
request: StyleRecommendationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Recommend an art style based on novel text.
|
||||
|
||||
Backend automatically fetches available styles with full descriptions from config.
|
||||
"""
|
||||
try:
|
||||
# Automatically fetch styles with full information from backend config
|
||||
styles_full = await get_available_styles_from_config()
|
||||
logger.info(f"Auto-fetched {len(styles_full)} styles from backend config")
|
||||
|
||||
# 格式化 style information for AI: include name and description
|
||||
available_styles = [
|
||||
f"{style['name']} - {style.get('desc', '')}"
|
||||
for style in styles_full
|
||||
if style.get('name')
|
||||
]
|
||||
|
||||
logger.info(f"Formatted {len(available_styles)} styles with descriptions for AI")
|
||||
|
||||
result = await script_service.recommend_style(
|
||||
novel_text=request.novel_text,
|
||||
available_styles=available_styles,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=StyleRecommendationResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Style recommendation failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/summary", response_model=ResponseModel)
|
||||
async def summarize_novel_endpoint(
|
||||
request: NovelSummaryRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 总和marize novel text into a concise overview (without characters)"""
|
||||
try:
|
||||
result = await script_service.summarize_novel(
|
||||
novel_text=request.novel_text,
|
||||
language=request.language,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
global_summary=request.global_summary,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=NovelSummaryResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Novel summary failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/characters", response_model=ResponseModel)
|
||||
async def extract_characters_endpoint(
|
||||
request: CharacterExtractionRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 提取 characters from novel text"""
|
||||
try:
|
||||
result = await script_service.extract_characters(
|
||||
novel_text=request.novel_text,
|
||||
language=request.language,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
global_summary=request.global_summary,
|
||||
known_characters=request.known_characters,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=CharacterExtractionResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Character extraction failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/scenes", response_model=ResponseModel)
|
||||
async def extract_scenes_endpoint(
|
||||
request: SceneExtractionRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 提取 scenes from novel text"""
|
||||
try:
|
||||
result = await script_service.extract_scenes(
|
||||
novel_text=request.novel_text,
|
||||
language=request.language,
|
||||
model_name=request.model,
|
||||
global_summary=request.global_summary,
|
||||
known_scenes=request.known_scenes,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=SceneExtractionResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Scene extraction failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/props", response_model=ResponseModel)
|
||||
async def extract_props_endpoint(
|
||||
request: PropExtractionRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 提取 props from novel text"""
|
||||
try:
|
||||
result = await script_service.extract_props(
|
||||
novel_text=request.novel_text,
|
||||
language=request.language,
|
||||
model_name=request.model,
|
||||
global_summary=request.global_summary,
|
||||
known_props=request.known_props,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=PropExtractionResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Prop extraction failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/storyboards", response_model=ResponseModel)
|
||||
async def split_storyboards_endpoint(
|
||||
request: StoryboardSplitRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Split novel text into storyboards (shots)"""
|
||||
try:
|
||||
result = await script_service.split_storyboards(
|
||||
novel_text=request.novel_text,
|
||||
project_id=request.project_id,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
language=request.language,
|
||||
known_characters=request.known_characters,
|
||||
known_scenes=request.known_scenes,
|
||||
known_props=request.known_props,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=StoryboardSplitResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Storyboard splitting failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/split", response_model=ResponseModel)
|
||||
async def split_chapters_endpoint(
|
||||
request: ChapterSplitRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Split novel text into chapters using regex or agent"""
|
||||
try:
|
||||
# Check if agent splitting is requested via extra_params
|
||||
extra = request.extra_params or {}
|
||||
use_agent = extra.get("use_agent", False)
|
||||
|
||||
chapters = await script_service.split_chapters(
|
||||
novel_text=request.novel_text,
|
||||
regex_pattern=request.regex_pattern,
|
||||
use_agent=use_agent,
|
||||
model_name=request.model,
|
||||
language=extra.get("language", "Chinese"),
|
||||
user_id=current_user.id
|
||||
)
|
||||
return ResponseModel(data=ChapterSplitResponse(
|
||||
chapters=chapters,
|
||||
total_chapters=len(chapters)
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Chapter split failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
212
backend/src/api/generations/tasks.py
Normal file
212
backend/src/api/generations/tasks.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Task Status Endpoints - 任务状态模块
|
||||
|
||||
包含任务管理相关的端点:
|
||||
- GET /tasks - 列出任务
|
||||
- GET /tasks/{task_id} - 获取任务状态
|
||||
"""
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Query, Request
|
||||
from typing import Optional
|
||||
|
||||
from src.models.schemas import ResponseModel, PaginationParams
|
||||
from src.utils.pagination import Paginator
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.provider.base import TaskStatus
|
||||
from src.services.task_manager import task_manager
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.config.settings import OSS_BUCKET, OSS_REGION
|
||||
from src.utils.errors import TaskNotFoundException, AppException, ErrorCode
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=ResponseModel)
|
||||
async def list_tasks(
|
||||
request: Request,
|
||||
type: Optional[str] = Query(None, description="Filter by task type"),
|
||||
page: int = Query(1, ge=1, description="页码,从 1 开始"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量,最大 100"),
|
||||
):
|
||||
"""列表 tasks with optional filtering.
|
||||
|
||||
Args:
|
||||
type: 任务类型过滤
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
分页的任务列表
|
||||
"""
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 获取任务列表
|
||||
tasks = task_manager.list_tasks(type=type, limit=page_size, offset=offset)
|
||||
|
||||
# 获取总数(简化处理,实际应该查询总数)
|
||||
total = len(tasks)
|
||||
|
||||
# 创建分页器
|
||||
paginator = Paginator(
|
||||
items=tasks,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=ResponseModel)
|
||||
async def get_task_status(task_id: str):
|
||||
"""
|
||||
Check status of a task.
|
||||
"""
|
||||
# 1. Check local TaskManager
|
||||
task = await task_manager.get_task(task_id)
|
||||
|
||||
if task:
|
||||
# If task is active (pending/processing), check provider status
|
||||
active_statuses = ["pending", "processing", "queued", "running", "submitted"]
|
||||
if task.status.lower() in active_statuses and task.provider_task_id:
|
||||
service = ModelRegistry.get(task.model)
|
||||
# Fallback to default if model service not found by name
|
||||
if not service:
|
||||
if task.type == "image":
|
||||
service = ModelRegistry.get_default(ModelType.IMAGE)
|
||||
elif task.type == "video":
|
||||
service = ModelRegistry.get_default(ModelType.VIDEO)
|
||||
elif task.type == "audio":
|
||||
service = ModelRegistry.get_default(ModelType.AUDIO)
|
||||
elif task.type == "music":
|
||||
service = ModelRegistry.get_default(ModelType.MUSIC)
|
||||
|
||||
if service and hasattr(service, 'check_status'):
|
||||
try:
|
||||
result = await service.check_status(task.provider_task_id, user_id=task.user_id)
|
||||
|
||||
# 更新 status
|
||||
# 映射 TaskStatus enum to string
|
||||
new_status = result.status.value.lower()
|
||||
if new_status == "succeeded":
|
||||
new_status = "success"
|
||||
|
||||
# If succeeded, process results (OSS upload)
|
||||
final_results = {}
|
||||
if result.status == TaskStatus.SUCCEEDED and result.results:
|
||||
processed_urls = []
|
||||
for idx, item in enumerate(result.results):
|
||||
if item.url:
|
||||
# Determine extension
|
||||
ext = "png"
|
||||
if task.type == "video":
|
||||
ext = "mp4"
|
||||
elif task.type in ("audio", "music"):
|
||||
ext = "mp3"
|
||||
|
||||
key = f"generated/{task.type}s/{task.id}_{idx}.{ext}"
|
||||
|
||||
# Save to storage (Local or OSS)
|
||||
oss_url = await asyncio.to_thread(storage_manager.save_from_url, item.url, key)
|
||||
processed_urls.append(oss_url or item.url)
|
||||
|
||||
final_results = {"urls": processed_urls}
|
||||
|
||||
# 更新 task in DB
|
||||
update_data = {"status": new_status}
|
||||
if final_results:
|
||||
update_data["result"] = final_results
|
||||
if result.error:
|
||||
update_data["error"] = str(result.error)
|
||||
|
||||
task = task_manager.update_task(task.id, **update_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check provider status for task {task.id}: {e}")
|
||||
# Don't fail the request, just return current state
|
||||
|
||||
# Retry OSS upload if succeeded but urls are not OSS (resilience)
|
||||
elif (task.status.lower() == "succeeded" or task.status.lower() == "success") and task.result and "urls" in task.result:
|
||||
try:
|
||||
urls = task.result["urls"]
|
||||
updated_urls = []
|
||||
needs_update = False
|
||||
|
||||
for idx, url in enumerate(urls):
|
||||
# Check if it looks like an OSS url
|
||||
is_oss = False
|
||||
if OSS_BUCKET and OSS_REGION and f"{OSS_BUCKET}.{OSS_REGION}" in url:
|
||||
is_oss = True
|
||||
|
||||
if not is_oss:
|
||||
# Attempt upload
|
||||
ext = "png"
|
||||
if task.type == "video":
|
||||
ext = "mp4"
|
||||
elif task.type in ("audio", "music"):
|
||||
ext = "mp3"
|
||||
key = f"generated/{task.type}s/{task.id}_{idx}.{ext}"
|
||||
|
||||
new_url = await asyncio.to_thread(storage_manager.save_from_url, url, key)
|
||||
if new_url and new_url != url:
|
||||
updated_urls.append(new_url)
|
||||
needs_update = True
|
||||
else:
|
||||
updated_urls.append(url)
|
||||
else:
|
||||
# It is OSS, maybe re-sign it
|
||||
signed_url = storage_manager.sign_url(url)
|
||||
updated_urls.append(signed_url if signed_url else url)
|
||||
|
||||
if needs_update:
|
||||
task = task_manager.update_task(task.id, result={"urls": updated_urls})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retry OSS upload for task {task.id}: {e}")
|
||||
|
||||
return ResponseModel(data=task)
|
||||
|
||||
raise AppException(
|
||||
message="Task not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/cancel", response_model=ResponseModel)
|
||||
async def cancel_task(task_id: str):
|
||||
"""
|
||||
Cancel a pending/processing task.
|
||||
仅更新本地任务状态为 CANCELLED,不保证下游 provider 真正中止计算,
|
||||
但对于绝大多数异步生成场景已经足够(前端不再轮询,结果也不会再写入)。
|
||||
"""
|
||||
try:
|
||||
cancelled = await task_manager.cancel_task(task_id)
|
||||
except TaskNotFoundException:
|
||||
raise AppException(
|
||||
message="Task not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel task {task_id}: {e}")
|
||||
raise AppException(
|
||||
message="Failed to cancel task",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
if not cancelled:
|
||||
# 已完成任务无法取消
|
||||
raise AppException(
|
||||
message="Task already finished and cannot be cancelled",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 返回简单结果,前端如需最新任务详情可以再调用 GET /tasks/{task_id}
|
||||
return ResponseModel(data={"task_id": task_id, "cancelled": True})
|
||||
136
backend/src/api/generations/video.py
Normal file
136
backend/src/api/generations/video.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Video Generation Endpoints - 视频生成模块
|
||||
|
||||
包含视频生成相关的端点:
|
||||
- POST /generations/video - 视频生成
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import ResponseModel, VideoGenerationRequest
|
||||
from src.services.provider.registry import ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
from .helpers import resolve_service, ensure_url, check_user_api_key
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/generations/video", response_model=ResponseModel)
|
||||
async def generate_video(
|
||||
request: VideoGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 泛型 Video Generation Endpoint.
|
||||
"""
|
||||
# 0. Sanitize URLs Early (Convert Base64 -> URL)
|
||||
try:
|
||||
if request.image_inputs:
|
||||
request.image_inputs = [ensure_url(url) for url in request.image_inputs if url]
|
||||
if request.video_inputs:
|
||||
request.video_inputs = [ensure_url(url) for url in request.video_inputs if url]
|
||||
if request.audio_inputs:
|
||||
request.audio_inputs = [ensure_url(url) for url in request.audio_inputs if url]
|
||||
|
||||
# Also check extra_params for common image keys
|
||||
if request.extra_params:
|
||||
for key in ["image_url", "video_url", "input_image", "input_video", "audio_url"]:
|
||||
if key in request.extra_params and isinstance(request.extra_params[key], str):
|
||||
request.extra_params[key] = ensure_url(request.extra_params[key])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sanitize inputs: {e}")
|
||||
|
||||
# 1. Resolve Model - 使用统一的路由逻辑
|
||||
video_service = resolve_service(request.model, ModelType.VIDEO)
|
||||
|
||||
if not video_service:
|
||||
raise AppException(
|
||||
message="Video service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 1.5 检查用户是否配置了 API Key
|
||||
# 从 video_service 获取 provider_id
|
||||
provider = getattr(video_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
# 2. Resolve size from aspect_ratio (BEFORE creating task)
|
||||
aspect_ratio = request.aspect_ratio
|
||||
size = None
|
||||
|
||||
if aspect_ratio:
|
||||
# Look up resolution in model config
|
||||
model_config = getattr(video_service, "config", {}) or {}
|
||||
if isinstance(model_config, dict):
|
||||
resolutions_config = model_config.get("resolutions") or {}
|
||||
else:
|
||||
resolutions_config = {}
|
||||
|
||||
# Use provided resolution level (e.g. "720P", "1080P") or default to "720P"
|
||||
res_level = request.resolution or "720P"
|
||||
|
||||
if resolutions_config and res_level in resolutions_config and isinstance(resolutions_config[res_level], dict):
|
||||
ratio_map = resolutions_config[res_level]
|
||||
if aspect_ratio in ratio_map:
|
||||
size = ratio_map[aspect_ratio]
|
||||
|
||||
if not size:
|
||||
defaults = {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
}
|
||||
size = defaults.get(aspect_ratio)
|
||||
|
||||
# 3. Create Task in DB with resolved params
|
||||
try:
|
||||
model_name = request.model or getattr(video_service, "model_id", "default_video")
|
||||
|
||||
# 构建 params dict with resolved size
|
||||
task_params = request.model_dump()
|
||||
if size:
|
||||
task_params["size"] = size # Override with resolved size
|
||||
|
||||
# Use unified task manager
|
||||
# Get user_id from authenticated user
|
||||
user_id = current_user.id
|
||||
|
||||
# Handle project_id from multiple sources
|
||||
# Priority 1: request.project_id (from request body, alias="projectId")
|
||||
# Priority 2: request.source_id if source=='project' (from extra_params, alias="sourceId")
|
||||
project_id = request.project_id
|
||||
if not project_id and request.source == 'project':
|
||||
project_id = request.source_id
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="video",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create task: {e}")
|
||||
raise AppException(
|
||||
message=f"Failed to create task: {e}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 日志 context if provided
|
||||
if request.source and request.source_id:
|
||||
logger.info(f"Video generation started for {request.source}:{request.source_id}, task_id: {task.id}")
|
||||
|
||||
# Unified task manager handles execution automatically
|
||||
return ResponseModel(data={
|
||||
"task_id": task.id,
|
||||
"status": task.status
|
||||
})
|
||||
350
backend/src/api/health.py
Normal file
350
backend/src/api/health.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
健康检查和监控端点
|
||||
|
||||
提供系统健康状态、性能指标和诊断信息
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Response
|
||||
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.task_manager import task_manager
|
||||
from src.services.provider.registry import ModelRegistry
|
||||
from src.services.provider.health import health_monitor
|
||||
from src.config.database import engine, get_pool_status
|
||||
from sqlmodel import Session, text
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 应用启动时间
|
||||
_start_time = time.time()
|
||||
|
||||
|
||||
@router.get("/health", response_model=ResponseModel)
|
||||
async def health_check():
|
||||
"""基础健康检查
|
||||
|
||||
Returns:
|
||||
基本的健康状态信息
|
||||
"""
|
||||
return ResponseModel(data={
|
||||
"status": "healthy",
|
||||
"service": "Pixel Backend",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"uptime_seconds": int(time.time() - _start_time)
|
||||
})
|
||||
|
||||
|
||||
@router.get("/health/detailed", response_model=ResponseModel)
|
||||
async def detailed_health_check():
|
||||
"""详细健康检查
|
||||
|
||||
检查所有关键组件的健康状态
|
||||
|
||||
Returns:
|
||||
详细的健康状态信息,包括各组件状态
|
||||
"""
|
||||
health_status = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"uptime_seconds": int(time.time() - _start_time),
|
||||
"components": {}
|
||||
}
|
||||
|
||||
# 检查数据库连接
|
||||
try:
|
||||
start = time.time()
|
||||
with Session(engine) as session:
|
||||
session.exec(text("SELECT 1"))
|
||||
latency_ms = (time.time() - start) * 1000
|
||||
|
||||
# Get connection pool status
|
||||
pool_status = get_pool_status()
|
||||
|
||||
health_status["components"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful",
|
||||
"latency_ms": round(latency_ms, 2),
|
||||
"pool": pool_status
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["components"]["database"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Database connection failed: {str(e)}"
|
||||
}
|
||||
|
||||
# 检查Redis连接
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
from src.services.cache_service import get_cache_service
|
||||
cache = get_cache_service()
|
||||
|
||||
if cache._connected:
|
||||
start = time.time()
|
||||
await cache._redis.ping()
|
||||
latency_ms = (time.time() - start) * 1000
|
||||
|
||||
# Get Redis info
|
||||
info = await cache._redis.info()
|
||||
health_status["components"]["redis"] = {
|
||||
"status": "healthy",
|
||||
"message": "Redis connection successful",
|
||||
"latency_ms": round(latency_ms, 2),
|
||||
"version": info.get("redis_version"),
|
||||
"connected_clients": info.get("connected_clients"),
|
||||
"used_memory_human": info.get("used_memory_human")
|
||||
}
|
||||
else:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["redis"] = {
|
||||
"status": "unhealthy",
|
||||
"message": "Redis not connected"
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["redis"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Redis connection failed: {str(e)}"
|
||||
}
|
||||
else:
|
||||
health_status["components"]["redis"] = {
|
||||
"status": "disabled",
|
||||
"message": "Redis is disabled in configuration"
|
||||
}
|
||||
|
||||
# 检查任务管理器
|
||||
try:
|
||||
stats = task_manager.get_stats()
|
||||
health_status["components"]["task_manager"] = {
|
||||
"status": "healthy",
|
||||
"stats": stats
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["task_manager"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Task manager error: {str(e)}"
|
||||
}
|
||||
|
||||
# 检查模型注册表
|
||||
try:
|
||||
models = ModelRegistry.list_models()
|
||||
health_status["components"]["model_registry"] = {
|
||||
"status": "healthy",
|
||||
"total_models": len(models)
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["model_registry"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Model registry error: {str(e)}"
|
||||
}
|
||||
|
||||
# 检查 AI 服务健康状态
|
||||
try:
|
||||
service_health = health_monitor.get_health_summary()
|
||||
health_status["components"]["ai_services"] = {
|
||||
"status": "healthy" if service_health["healthy"] == service_health["total"] else "degraded",
|
||||
"summary": service_health
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["ai_services"] = {
|
||||
"status": "unknown",
|
||||
"message": f"Health monitor error: {str(e)}"
|
||||
}
|
||||
|
||||
return ResponseModel(data=health_status)
|
||||
|
||||
|
||||
@router.get("/health/live", response_model=ResponseModel)
|
||||
async def liveness_probe():
|
||||
""" Kubernetes liveness probe
|
||||
|
||||
简单检查应用是否运行
|
||||
|
||||
Returns:
|
||||
200 OK if alive
|
||||
"""
|
||||
return ResponseModel(data={"status": "alive"})
|
||||
|
||||
|
||||
@router.get("/health/ready", response_model=ResponseModel)
|
||||
async def readiness_probe():
|
||||
""" Kubernetes readiness probe
|
||||
|
||||
检查应用是否准备好接收流量
|
||||
|
||||
Returns:
|
||||
200 OK if ready, 503 if not ready
|
||||
"""
|
||||
# 检查关键组件
|
||||
ready = True
|
||||
components = {}
|
||||
|
||||
# 检查数据库
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
session.exec(text("SELECT 1"))
|
||||
components["database"] = "ready"
|
||||
except Exception as e:
|
||||
ready = False
|
||||
components["database"] = f"not ready: {str(e)}"
|
||||
|
||||
# 检查Redis(如果启用)
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
from src.services.cache_service import get_cache_service
|
||||
cache = get_cache_service()
|
||||
|
||||
if cache._connected:
|
||||
await cache._redis.ping()
|
||||
components["redis"] = "ready"
|
||||
else:
|
||||
ready = False
|
||||
components["redis"] = "not ready: not connected"
|
||||
except Exception as e:
|
||||
ready = False
|
||||
components["redis"] = f"not ready: {str(e)}"
|
||||
else:
|
||||
components["redis"] = "disabled"
|
||||
|
||||
# 检查任务管理器
|
||||
try:
|
||||
task_manager.get_stats()
|
||||
components["task_manager"] = "ready"
|
||||
except Exception as e:
|
||||
ready = False
|
||||
components["task_manager"] = f"not ready: {str(e)}"
|
||||
|
||||
if ready:
|
||||
return ResponseModel(data={"status": "ready", "components": components})
|
||||
else:
|
||||
raise AppException(
|
||||
message="Service not ready",
|
||||
code=ErrorCode.SERVICE_UNAVAILABLE,
|
||||
status_code=503,
|
||||
details={"status": "not ready", "components": components}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def prometheus_metrics():
|
||||
"""Prometheus metrics endpoint
|
||||
|
||||
导出 Prometheus 格式的监控指标
|
||||
|
||||
Returns:
|
||||
Prometheus 格式的指标数据
|
||||
"""
|
||||
# 更新 system and database metrics before returning
|
||||
from src.middlewares.metrics import update_system_metrics, update_database_metrics
|
||||
update_system_metrics()
|
||||
update_database_metrics()
|
||||
|
||||
return Response(
|
||||
content=generate_latest(),
|
||||
media_type=CONTENT_TYPE_LATEST
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/tasks", response_model=ResponseModel)
|
||||
async def task_metrics():
|
||||
"""任务管理器指标
|
||||
|
||||
Returns:
|
||||
任务管理器的详细统计信息
|
||||
"""
|
||||
try:
|
||||
stats = task_manager.get_stats()
|
||||
return ResponseModel(data=stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get task metrics: {e}", exc_info=True)
|
||||
return ResponseModel(
|
||||
code=500,
|
||||
message="Failed to get task metrics",
|
||||
data={"error": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/models", response_model=ResponseModel)
|
||||
async def model_metrics():
|
||||
"""模型服务指标
|
||||
|
||||
Returns:
|
||||
所有注册模型的健康状态和统计信息
|
||||
"""
|
||||
try:
|
||||
# 获取所有模型
|
||||
models = ModelRegistry.list_models()
|
||||
|
||||
# 获取健康状态
|
||||
health_summary = health_monitor.get_health_summary()
|
||||
all_health = health_monitor.get_all_health()
|
||||
|
||||
# 构建响应
|
||||
model_stats = []
|
||||
for model_id, config in models.items():
|
||||
health = all_health.get(model_id)
|
||||
model_stat = {
|
||||
"id": model_id,
|
||||
"name": config.get("name"),
|
||||
"type": config.get("type"),
|
||||
"provider": config.get("provider"),
|
||||
"enabled": config.get("enabled", True)
|
||||
}
|
||||
|
||||
if health:
|
||||
model_stat["health"] = {
|
||||
"status": health.status.value,
|
||||
"success_rate": health.get_success_rate(),
|
||||
"avg_latency_ms": health.avg_latency_ms,
|
||||
"total_checks": health.total_checks,
|
||||
"total_failures": health.total_failures
|
||||
}
|
||||
|
||||
model_stats.append(model_stat)
|
||||
|
||||
return ResponseModel(data={
|
||||
"summary": health_summary,
|
||||
"models": model_stats
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model metrics: {e}", exc_info=True)
|
||||
return ResponseModel(
|
||||
code=500,
|
||||
message="Failed to get model metrics",
|
||||
data={"error": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/debug/info", response_model=ResponseModel)
|
||||
async def debug_info():
|
||||
"""调试信息
|
||||
|
||||
提供系统配置和运行时信息(仅用于开发环境)
|
||||
|
||||
Returns:
|
||||
系统调试信息
|
||||
"""
|
||||
import sys
|
||||
import platform
|
||||
from src.config.settings import NODE_ENV, STORAGE_TYPE, REDIS_ENABLED
|
||||
|
||||
return ResponseModel(data={
|
||||
"environment": NODE_ENV,
|
||||
"python_version": sys.version,
|
||||
"platform": platform.platform(),
|
||||
"storage_type": STORAGE_TYPE,
|
||||
"redis_enabled": REDIS_ENABLED,
|
||||
"uptime_seconds": int(time.time() - _start_time),
|
||||
"start_time": datetime.fromtimestamp(_start_time).isoformat()
|
||||
})
|
||||
29
backend/src/api/projects/__init__.py
Normal file
29
backend/src/api/projects/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Projects API Package
|
||||
|
||||
项目 API 模块,包含:
|
||||
- core: 项目核心 CRUD 功能
|
||||
- episodes: 剧集管理
|
||||
- assets: 资产管理
|
||||
- storyboards: 分镜管理
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .core import router as core_router
|
||||
from .episodes import router as episodes_router
|
||||
from .assets import router as assets_router
|
||||
from .storyboards import router as storyboards_router
|
||||
|
||||
# 创建主路由器
|
||||
router = APIRouter(tags=["projects"])
|
||||
|
||||
# 包含所有子路由
|
||||
router.include_router(core_router)
|
||||
router.include_router(episodes_router)
|
||||
router.include_router(assets_router)
|
||||
router.include_router(storyboards_router)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
]
|
||||
305
backend/src/api/projects/assets.py
Normal file
305
backend/src/api/projects/assets.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Projects API - Assets Module
|
||||
|
||||
包含资产管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
Asset, CreateAssetRequest, UpdateAssetRequest,
|
||||
CharacterAsset, SceneAsset, PropAsset, OtherAsset,
|
||||
GenerationRecord,
|
||||
)
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.project_service import project_manager
|
||||
from src.services.asset_deduplication_service import deduplicate_project_assets as deduplicate_assets_service
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["projects-assets"])
|
||||
|
||||
|
||||
def _check_project_access(project_id: str, user_id: str):
|
||||
"""检查用户是否有权限访问项目"""
|
||||
project = project_manager.get_project(project_id, user_id=user_id)
|
||||
return project
|
||||
|
||||
|
||||
# --- Asset Management ---
|
||||
|
||||
@router.post("/projects/{project_id}/assets", response_model=ResponseModel)
|
||||
async def create_asset(
|
||||
project_id: str,
|
||||
request: dict,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Create a new asset.
|
||||
Accepts a dictionary, validates based on 'type' field.
|
||||
"""
|
||||
try:
|
||||
# Check project access
|
||||
if not _check_project_access(project_id, current_user.id):
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
asset_type = request.get('type', 'other')
|
||||
if 'id' not in request:
|
||||
request['id'] = str(uuid.uuid4())
|
||||
|
||||
asset = None
|
||||
if asset_type == 'character':
|
||||
asset = TypeAdapter(CharacterAsset).validate_python(request)
|
||||
elif asset_type == 'scene':
|
||||
asset = TypeAdapter(SceneAsset).validate_python(request)
|
||||
elif asset_type == 'prop':
|
||||
asset = TypeAdapter(PropAsset).validate_python(request)
|
||||
else:
|
||||
asset = TypeAdapter(OtherAsset).validate_python(request)
|
||||
|
||||
updated_project = project_manager.add_asset(project_id, asset)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.put("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||
async def update_asset(project_id: str, asset_id: str, request: UpdateAssetRequest):
|
||||
"""Update an asset"""
|
||||
try:
|
||||
project = project_manager.get_project(project_id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
existing_asset = next((a for a in project.assets if a.id == asset_id), None)
|
||||
if not existing_asset:
|
||||
raise AppException(
|
||||
message="Asset not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
# Merge updates
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
updated_asset = existing_asset.model_copy(update=update_data)
|
||||
|
||||
result = project_manager.update_asset(project_id, asset_id, updated_asset)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||
async def delete_asset(project_id: str, asset_id: str):
|
||||
"""Delete an asset"""
|
||||
try:
|
||||
result = project_manager.delete_asset(project_id, asset_id)
|
||||
if not result:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/assets/deduplicate", response_model=ResponseModel)
|
||||
async def deduplicate_project_assets(
|
||||
project_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Deduplicate all assets in the project using LLM; update storyboard references.
|
||||
"""
|
||||
try:
|
||||
updated_project = await deduplicate_assets_service(project_id, current_user.id)
|
||||
return ResponseModel(data=updated_project)
|
||||
except ValueError as e:
|
||||
if "not found" in str(e).lower():
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Asset deduplication failed: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# --- Asset Management ---
|
||||
@router.get("/projects/{project_id}/assets", response_model=ResponseModel)
|
||||
async def list_project_assets(
|
||||
project_id: str,
|
||||
asset_type: Optional[str] = Query(None, description="Filter by asset type"),
|
||||
search_query: Optional[str] = Query(None, description="Search by name or description"),
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
):
|
||||
"""List assets for a project with optional filtering and pagination"""
|
||||
try:
|
||||
result = project_manager.list_assets(project_id, asset_type, search_query, limit, offset)
|
||||
if result is None:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.post("/projects/{project_id}/assets", response_model=ResponseModel)
|
||||
async def create_asset(project_id: str, request: CreateAssetRequest):
|
||||
"""Create a new asset"""
|
||||
try:
|
||||
asset_data = request.model_dump(by_alias=True)
|
||||
asset_data['id'] = str(uuid.uuid4())
|
||||
|
||||
# Validate and convert to Asset Union
|
||||
new_asset = TypeAdapter(Asset).validate_python(asset_data)
|
||||
|
||||
updated_project = project_manager.add_asset(project_id, new_asset)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.post("/projects/{project_id}/assets/batch", response_model=ResponseModel)
|
||||
async def batch_create_assets(project_id: str, request: List[CreateAssetRequest]):
|
||||
"""Batch create assets"""
|
||||
try:
|
||||
new_assets = []
|
||||
for asset_req in request:
|
||||
asset_data = asset_req.model_dump(by_alias=True)
|
||||
asset_data['id'] = str(uuid.uuid4())
|
||||
# Validate and convert to Asset Union
|
||||
new_asset = TypeAdapter(Asset).validate_python(asset_data)
|
||||
new_assets.append(new_asset)
|
||||
|
||||
updated_project = project_manager.batch_add_assets(project_id, new_assets)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.put("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||
async def update_asset(project_id: str, asset_id: str, request: UpdateAssetRequest):
|
||||
"""Update an asset"""
|
||||
try:
|
||||
project = project_manager.get_project(project_id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
existing_asset = next((a for a in project.assets if a.id == asset_id), None)
|
||||
if not existing_asset:
|
||||
raise AppException(
|
||||
message="Asset not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
# Manually validate generations if present to ensure they are model instances not dicts
|
||||
if 'generations' in update_data and update_data['generations']:
|
||||
update_data['generations'] = TypeAdapter(List[GenerationRecord]).validate_python(update_data['generations'])
|
||||
|
||||
updated_asset = existing_asset.model_copy(update=update_data)
|
||||
|
||||
updated_project = project_manager.update_asset(project_id, asset_id, updated_asset)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||
async def delete_asset(project_id: str, asset_id: str):
|
||||
"""Delete an asset"""
|
||||
try:
|
||||
updated_project = project_manager.delete_asset(project_id, asset_id)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
367
backend/src/api/projects/core.py
Normal file
367
backend/src/api/projects/core.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Projects API - Core Module
|
||||
|
||||
包含项目核心 CRUD 相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Query, Request, Depends
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
CreateProjectRequest, UpdateProjectRequest,
|
||||
InitializeProjectRequest,
|
||||
)
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.project_service import project_manager
|
||||
from src.services.project_initialization_service import run_initialization_pipeline
|
||||
from src.services.script_project_initialization_service import run_script_project_initialization
|
||||
from src.services.script_to_canvas_service import convert_script_project_to_canvas
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["projects"])
|
||||
|
||||
|
||||
def _check_project_access(project_id: str, user_id: str) -> Optional:
|
||||
"""检查用户是否有权限访问项目"""
|
||||
project = project_manager.get_project(project_id, user_id=user_id)
|
||||
return project
|
||||
|
||||
|
||||
def _progress_callback(project_id: str):
|
||||
def _cb(step: str, percentage: int, message: str, details: Optional[dict] = None):
|
||||
try:
|
||||
project_manager.update_project(
|
||||
project_id,
|
||||
{
|
||||
"progress": {
|
||||
"current_step": step,
|
||||
"percentage": percentage,
|
||||
"message": message,
|
||||
"details": details or {},
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to update progress: %s", e)
|
||||
return _cb
|
||||
|
||||
|
||||
# --- 项目管理 ---
|
||||
async def _create_initializing_project(
|
||||
request: InitializeProjectRequest,
|
||||
current_user: UserAuth,
|
||||
background_tasks: BackgroundTasks,
|
||||
project_type: str,
|
||||
):
|
||||
project = project_manager.create_project(
|
||||
name=request.name,
|
||||
description="正在从小说初始化...",
|
||||
type=project_type,
|
||||
chapters=[],
|
||||
assets=[],
|
||||
status="initializing",
|
||||
user_id=current_user.id
|
||||
)
|
||||
progress_cb = _progress_callback(project.id)
|
||||
if project_type == "script":
|
||||
background_tasks.add_task(
|
||||
run_script_project_initialization,
|
||||
project.id,
|
||||
request.novel_text,
|
||||
progress_cb,
|
||||
current_user.id,
|
||||
request.model,
|
||||
request.provider,
|
||||
)
|
||||
else:
|
||||
background_tasks.add_task(
|
||||
run_initialization_pipeline,
|
||||
project.id,
|
||||
request.novel_text,
|
||||
request.style,
|
||||
progress_cb,
|
||||
current_user.id,
|
||||
)
|
||||
return project
|
||||
|
||||
|
||||
@router.post("/projects/initialize-from-novel", response_model=ResponseModel)
|
||||
@router.post("/projects/init-from-novel", response_model=ResponseModel)
|
||||
@router.post("/projects/initialize-canvas-from-novel", response_model=ResponseModel)
|
||||
async def initialize_canvas_project_from_novel(
|
||||
request: InitializeProjectRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
使用多智能体管道从小说文本初始化项目(后台任务)
|
||||
步骤: 1. 立即创建项目(草稿) 2. 触发后台任务进行完整初始化
|
||||
"""
|
||||
try:
|
||||
project = await _create_initializing_project(request, current_user, background_tasks, "canvas")
|
||||
return ResponseModel(data=project)
|
||||
except Exception as e:
|
||||
logger.error("Error initializing project: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/initialize-script-from-novel", response_model=ResponseModel)
|
||||
async def initialize_script_project_from_novel(
|
||||
request: InitializeProjectRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""使用多 Agent 改编流程从小说文本初始化剧本项目。"""
|
||||
try:
|
||||
project = await _create_initializing_project(request, current_user, background_tasks, "script")
|
||||
return ResponseModel(data=project)
|
||||
except Exception as e:
|
||||
logger.error("Error initializing script project: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.post("/projects", response_model=ResponseModel)
|
||||
async def create_project(
|
||||
request: CreateProjectRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new project"""
|
||||
try:
|
||||
project = project_manager.create_project(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
type=request.type,
|
||||
chapters=request.chapters,
|
||||
assets=request.assets,
|
||||
user_id=current_user.id
|
||||
)
|
||||
return ResponseModel(data=project)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating project: {e}", exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.get("/projects", response_model=ResponseModel)
|
||||
async def list_projects(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量,最大100"),
|
||||
sort: Optional[str] = Query(None, description="排序字段,格式: field:asc 或 field:desc"),
|
||||
filter: Optional[str] = Query(None, description="过滤条件,JSON格式"),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List all projects with pagination support.
|
||||
|
||||
Args:
|
||||
page: 页码,从1开始
|
||||
page_size: 每页数量 (1-100, 默认: 20)
|
||||
sort: 排序字段,格式: "field:asc" 或 "field:desc"
|
||||
filter: 过滤条件,JSON格式
|
||||
|
||||
Returns:
|
||||
分页的项目列表
|
||||
"""
|
||||
try:
|
||||
from src.models.schemas import PaginationParams
|
||||
from src.utils.pagination import Paginator
|
||||
|
||||
# 创建分页参数
|
||||
pagination_params = PaginationParams(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort=sort,
|
||||
filter=filter
|
||||
)
|
||||
|
||||
# 计算偏移量
|
||||
offset = pagination_params.get_offset()
|
||||
limit = pagination_params.get_limit()
|
||||
|
||||
# 获取当前用户ID,只返回该用户的项目
|
||||
user_id = current_user.id if current_user else None
|
||||
|
||||
# 获取项目列表(按用户过滤)
|
||||
projects = project_manager.list_projects(limit=limit, offset=offset, user_id=user_id)
|
||||
|
||||
# 获取总数(按用户过滤)
|
||||
total = project_manager.count_projects(user_id=user_id)
|
||||
|
||||
# 创建分页器
|
||||
paginator = Paginator(
|
||||
items=projects,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
# 返回分页响应
|
||||
return paginator.to_response(request)
|
||||
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.get("/projects/{project_id}", response_model=ResponseModel)
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
include_assets: bool = Query(default=True, description="Whether to include assets in the response"),
|
||||
include_referenced_assets: bool = Query(default=False, description="Whether to include only referenced assets (if include_assets is False)"),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Get project details"""
|
||||
project = project_manager.get_project(
|
||||
project_id,
|
||||
include_assets=include_assets,
|
||||
include_referenced_assets=include_referenced_assets,
|
||||
user_id=current_user.id
|
||||
)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=project)
|
||||
|
||||
@router.put("/projects/{project_id}", response_model=ResponseModel)
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
request: UpdateProjectRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Update a project"""
|
||||
try:
|
||||
# First check if project exists and belongs to user
|
||||
existing = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not existing:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
updated_project = project_manager.update_project(project_id, update_data)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}", response_model=ResponseModel)
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a project"""
|
||||
try:
|
||||
# First check if project exists and belongs to user
|
||||
existing = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not existing:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
project_manager.delete_project(project_id)
|
||||
return ResponseModel(data={"id": project_id, "deleted": True})
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/convert-to-canvas", response_model=ResponseModel)
|
||||
async def convert_project_to_canvas(
|
||||
project_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""将剧本项目转换为独立的画布项目。"""
|
||||
try:
|
||||
source_project = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not source_project:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
if source_project.type != "script":
|
||||
raise AppException(
|
||||
message="Only script projects can be converted to canvas projects",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
chapters = [
|
||||
{
|
||||
"id": episode.id,
|
||||
"title": episode.title,
|
||||
"order": episode.order,
|
||||
"content": episode.content,
|
||||
"summary": episode.desc,
|
||||
"status": episode.status,
|
||||
}
|
||||
for episode in (source_project.episodes or [])
|
||||
]
|
||||
|
||||
target_project = project_manager.create_project(
|
||||
name=f"{source_project.name} - 画布",
|
||||
description="正在从剧本项目生成画布...",
|
||||
type="canvas",
|
||||
chapters=chapters,
|
||||
assets=[],
|
||||
status="initializing",
|
||||
user_id=current_user.id
|
||||
)
|
||||
progress_cb = _progress_callback(target_project.id)
|
||||
background_tasks.add_task(
|
||||
convert_script_project_to_canvas,
|
||||
source_project.id,
|
||||
target_project.id,
|
||||
progress_cb,
|
||||
current_user.id,
|
||||
)
|
||||
return ResponseModel(data=target_project)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error converting project to canvas: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
179
backend/src/api/projects/episodes.py
Normal file
179
backend/src/api/projects/episodes.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Projects API - Episodes Module
|
||||
|
||||
包含剧集管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
Episode, CreateEpisodeRequest, UpdateEpisodeRequest,
|
||||
)
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.project_service import project_manager
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["projects-episodes"])
|
||||
|
||||
|
||||
def _check_project_access(project_id: str, user_id: str):
|
||||
"""检查用户是否有权限访问项目"""
|
||||
project = project_manager.get_project(project_id, user_id=user_id)
|
||||
return project
|
||||
|
||||
|
||||
# --- Episode Management ---
|
||||
@router.post("/projects/{project_id}/episodes", response_model=ResponseModel)
|
||||
async def create_episode(
|
||||
project_id: str,
|
||||
request: CreateEpisodeRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new episode"""
|
||||
try:
|
||||
# Check project access
|
||||
if not _check_project_access(project_id, current_user.id):
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
episode_id = str(uuid.uuid4())
|
||||
new_episode = Episode(
|
||||
id=episode_id,
|
||||
title=request.title,
|
||||
order=request.order,
|
||||
desc=request.desc,
|
||||
status=request.status
|
||||
)
|
||||
updated_project = project_manager.add_episode(project_id, new_episode)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found or operation failed",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.put("/projects/{project_id}/episodes/{episode_id}", response_model=ResponseModel)
|
||||
async def update_episode(
|
||||
project_id: str,
|
||||
episode_id: str,
|
||||
request: UpdateEpisodeRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Update an episode"""
|
||||
try:
|
||||
project = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
existing_ep = next((ep for ep in project.episodes if ep.id == episode_id), None)
|
||||
if not existing_ep:
|
||||
raise AppException(
|
||||
message="Episode not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True, by_alias=True)
|
||||
updated_ep = existing_ep.model_copy(update=update_data)
|
||||
|
||||
updated_project = project_manager.update_episode(project_id, episode_id, updated_ep)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/episodes/{episode_id}", response_model=ResponseModel)
|
||||
async def delete_episode(
|
||||
project_id: str,
|
||||
episode_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Delete an episode"""
|
||||
try:
|
||||
project = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
updated_project = project_manager.delete_episode(project_id, episode_id)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/episodes/{episode_id}/analyze", response_model=ResponseModel)
|
||||
async def analyze_episode(
|
||||
project_id: str,
|
||||
episode_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Analyze a specific episode: extract summary, characters/scenes/props, storyboards.
|
||||
"""
|
||||
try:
|
||||
from src.services.episode_analysis_service import analyze_episode as analyze_episode_service
|
||||
final_project = await analyze_episode_service(project_id, episode_id, current_user.id)
|
||||
return ResponseModel(data=final_project)
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "not found" in msg.lower():
|
||||
raise AppException(
|
||||
message=msg,
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
if "empty" in msg.lower():
|
||||
raise AppException(
|
||||
message=msg,
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
raise AppException(
|
||||
message=msg,
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Episode analysis failed: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
113
backend/src/api/projects/storyboards.py
Normal file
113
backend/src/api/projects/storyboards.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Projects API - Storyboards Module
|
||||
|
||||
包含分镜管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
Storyboard, CreateStoryboardRequest, UpdateStoryboardRequest,
|
||||
GenerationRecord,
|
||||
)
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["projects-storyboards"])
|
||||
|
||||
|
||||
# --- Storyboard Management ---
|
||||
@router.post("/projects/{project_id}/storyboards", response_model=ResponseModel)
|
||||
async def create_storyboard(project_id: str, request: CreateStoryboardRequest):
|
||||
"""Create a new storyboard"""
|
||||
try:
|
||||
from src.services.project_service import project_manager
|
||||
|
||||
sb_id = str(uuid.uuid4())
|
||||
sb_data = request.model_dump(by_alias=True)
|
||||
sb_data['id'] = sb_id
|
||||
|
||||
new_sb = Storyboard(**sb_data)
|
||||
|
||||
updated_project = project_manager.add_storyboard(project_id, new_sb)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.put("/projects/{project_id}/storyboards/{storyboard_id}", response_model=ResponseModel)
|
||||
async def update_storyboard(project_id: str, storyboard_id: str, request: UpdateStoryboardRequest):
|
||||
"""Update a storyboard"""
|
||||
try:
|
||||
from src.services.project_service import project_manager
|
||||
|
||||
project = project_manager.get_project(project_id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
existing_sb = next((sb for sb in project.storyboards if sb.id == storyboard_id), None)
|
||||
if not existing_sb:
|
||||
raise AppException(
|
||||
message="Storyboard not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True, by_alias=True)
|
||||
|
||||
# Manually validate generations if present to ensure they are model instances not dicts
|
||||
if 'generations' in update_data and update_data['generations']:
|
||||
update_data['generations'] = TypeAdapter(List[GenerationRecord]).validate_python(update_data['generations'])
|
||||
|
||||
updated_sb = existing_sb.model_copy(update=update_data)
|
||||
|
||||
updated_project = project_manager.update_storyboard(project_id, storyboard_id, updated_sb)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/storyboards/{storyboard_id}", response_model=ResponseModel)
|
||||
async def delete_storyboard(project_id: str, storyboard_id: str):
|
||||
"""Delete a storyboard"""
|
||||
try:
|
||||
from src.services.project_service import project_manager
|
||||
|
||||
updated_project = project_manager.delete_storyboard(project_id, storyboard_id)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
240
backend/src/api/prompt_templates.py
Normal file
240
backend/src/api/prompt_templates.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Prompt Template API - 提示词模板接口
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.models.prompt_template import (
|
||||
PromptTemplateCreate,
|
||||
PromptTemplateUpdate,
|
||||
PromptTemplateCategory
|
||||
)
|
||||
from src.services.prompt_template_service import PromptTemplateService
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.utils.errors import AppException, ErrorCode
|
||||
|
||||
router = APIRouter(tags=["prompt-templates"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/prompt-templates/init", response_model=ResponseModel)
|
||||
async def init_templates(current_user: UserAuth = Depends(get_current_user)):
|
||||
"""初始化系统默认模板(仅管理员)"""
|
||||
# 这里可以添加管理员权限检查
|
||||
PromptTemplateService.init_default_templates()
|
||||
return ResponseModel(message="Templates initialized successfully")
|
||||
|
||||
|
||||
@router.get("/prompt-templates", response_model=ResponseModel)
|
||||
async def list_templates(
|
||||
category: Optional[str] = Query(None, description="分类过滤"),
|
||||
target_type: Optional[str] = Query(None, description="目标类型过滤: image/video/audio/music"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
favorites_only: bool = Query(False, description="仅显示收藏"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""获取提示词模板列表"""
|
||||
try:
|
||||
result = PromptTemplateService.list_templates(
|
||||
user_id=current_user.id,
|
||||
category=category,
|
||||
target_type=target_type,
|
||||
search=search,
|
||||
favorites_only=favorites_only,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing templates: {e}")
|
||||
raise AppException(
|
||||
message="Failed to list templates",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/prompt-templates/categories", response_model=ResponseModel)
|
||||
async def get_categories(
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""获取模板分类列表"""
|
||||
categories = PromptTemplateService.get_categories()
|
||||
return ResponseModel(data=categories)
|
||||
|
||||
|
||||
@router.post("/prompt-templates", response_model=ResponseModel)
|
||||
async def create_template(
|
||||
data: PromptTemplateCreate,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""创建新模板"""
|
||||
try:
|
||||
template = PromptTemplateService.create_template(
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
return ResponseModel(
|
||||
message="Template created successfully",
|
||||
data=template
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating template: {e}")
|
||||
raise AppException(
|
||||
message="Failed to create template",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/prompt-templates/{template_id}", response_model=ResponseModel)
|
||||
async def get_template(
|
||||
template_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""获取单个模板详情"""
|
||||
template = PromptTemplateService.get_template(
|
||||
template_id=template_id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
if not template:
|
||||
raise AppException(
|
||||
message="Template not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(data=template)
|
||||
|
||||
|
||||
@router.put("/prompt-templates/{template_id}", response_model=ResponseModel)
|
||||
async def update_template(
|
||||
template_id: str,
|
||||
data: PromptTemplateUpdate,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""更新模板"""
|
||||
try:
|
||||
template = PromptTemplateService.update_template(
|
||||
template_id=template_id,
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
|
||||
if not template:
|
||||
raise AppException(
|
||||
message="Template not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="Template updated successfully",
|
||||
data=template
|
||||
)
|
||||
except PermissionError as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.PERMISSION_DENIED,
|
||||
status_code=403
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating template: {e}")
|
||||
raise AppException(
|
||||
message="Failed to update template",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/prompt-templates/{template_id}", response_model=ResponseModel)
|
||||
async def delete_template(
|
||||
template_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""删除模板"""
|
||||
try:
|
||||
success = PromptTemplateService.delete_template(
|
||||
template_id=template_id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise AppException(
|
||||
message="Template not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(message="Template deleted successfully")
|
||||
except PermissionError as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.PERMISSION_DENIED,
|
||||
status_code=403
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting template: {e}")
|
||||
raise AppException(
|
||||
message="Failed to delete template",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/prompt-templates/{template_id}/favorite", response_model=ResponseModel)
|
||||
async def toggle_favorite(
|
||||
template_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""切换模板收藏状态"""
|
||||
try:
|
||||
is_favorite = PromptTemplateService.toggle_favorite(
|
||||
template_id=template_id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="Favorite toggled successfully",
|
||||
data={"is_favorite": is_favorite}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error toggling favorite: {e}")
|
||||
raise AppException(
|
||||
message="Failed to toggle favorite",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/prompt-templates/{template_id}/apply", response_model=ResponseModel)
|
||||
async def apply_template(
|
||||
template_id: str,
|
||||
user_prompt: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""应用模板到提示词"""
|
||||
result = PromptTemplateService.apply_template(
|
||||
template_id=template_id,
|
||||
user_prompt=user_prompt
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise AppException(
|
||||
message="Template not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"original": user_prompt,
|
||||
"enhanced": result
|
||||
}
|
||||
)
|
||||
195
backend/src/api/skills.py
Normal file
195
backend/src/api/skills.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Agent Skills Management API
|
||||
|
||||
Provides endpoints for managing agent skills.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from src.services.agent_engine import AgentScopeService
|
||||
from src.services.agent_engine.toolkit import ToolkitFactory
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(tags=["skills"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillInfo(BaseModel):
|
||||
"""Skill information model"""
|
||||
name: str
|
||||
description: str
|
||||
path: str
|
||||
|
||||
|
||||
class RegisterSkillRequest(BaseModel):
|
||||
"""Request model for registering a skill"""
|
||||
skill_dir: str
|
||||
|
||||
|
||||
class SkillPromptResponse(BaseModel):
|
||||
"""Response model for skill prompt"""
|
||||
prompt: Optional[str]
|
||||
skills_count: int
|
||||
|
||||
|
||||
@router.get("/skills", response_model=ResponseModel)
|
||||
async def list_skills():
|
||||
"""List all available agent skills.
|
||||
|
||||
Returns:
|
||||
List of skill information
|
||||
"""
|
||||
skills = []
|
||||
|
||||
try:
|
||||
# 使用新的 ToolkitFactory
|
||||
all_skills = ToolkitFactory.list_skills()
|
||||
|
||||
for skill_path in all_skills:
|
||||
# skill_path 格式: "domain/skill_name"
|
||||
domain, skill_name = skill_path.split('/')
|
||||
|
||||
# 读取 SKILL.md 获取描述
|
||||
from pathlib import Path
|
||||
# Fix path: services/agent_engine/skills (was services/agents/skills)
|
||||
skill_dir = Path(__file__).parent.parent / "services" / "agent_engine" / "skills" / domain / skill_name
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
|
||||
description = "No description available"
|
||||
if skill_md.exists():
|
||||
try:
|
||||
with open(skill_md, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Extract description from YAML front matter
|
||||
if content.startswith('---'):
|
||||
parts = content.split('---', 2)
|
||||
if len(parts) >= 3:
|
||||
import yaml
|
||||
try:
|
||||
metadata = yaml.safe_load(parts[1])
|
||||
description = metadata.get('description', description)
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
|
||||
skills.append(SkillInfo(
|
||||
name=skill_path, # 使用完整路径作为名称
|
||||
description=description,
|
||||
path=str(skill_dir)
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read skill {skill_path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list skills: {e}")
|
||||
|
||||
return ResponseModel(data=skills)
|
||||
|
||||
|
||||
@router.post("/skills/register", response_model=ResponseModel)
|
||||
async def register_skill(request: RegisterSkillRequest):
|
||||
"""Register a new agent skill.
|
||||
|
||||
Args:
|
||||
request: Skill registration request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: If skill registration fails
|
||||
"""
|
||||
try:
|
||||
# Validate skill directory
|
||||
if not os.path.exists(request.skill_dir):
|
||||
raise AppException(
|
||||
message=f"Skill directory not found: {request.skill_dir}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
skill_md = os.path.join(request.skill_dir, "SKILL.md")
|
||||
if not os.path.exists(skill_md):
|
||||
raise AppException(
|
||||
message="SKILL.md file not found in skill directory",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Register skill (this would need to be implemented in AgentScopeService)
|
||||
logger.info(f"Registered skill from: {request.skill_dir}")
|
||||
|
||||
return ResponseModel(data={"message": "Skill registered successfully", "path": request.skill_dir})
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register skill: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/prompt", response_model=ResponseModel)
|
||||
async def get_skills_prompt():
|
||||
"""Get the combined prompt for all registered skills.
|
||||
|
||||
This prompt can be attached to the agent's system prompt.
|
||||
|
||||
Returns:
|
||||
Combined skills prompt
|
||||
"""
|
||||
try:
|
||||
# 使用 ToolkitFactory
|
||||
all_skills = ToolkitFactory.list_skills()
|
||||
skills_count = len(all_skills)
|
||||
|
||||
# 获取 skills prompt
|
||||
if skills_count > 0:
|
||||
toolkit = ToolkitFactory.get_toolkit()
|
||||
prompt = toolkit.get_agent_skill_prompt()
|
||||
else:
|
||||
prompt = None
|
||||
|
||||
return ResponseModel(data={"prompt": prompt, "skills_count": skills_count})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get skills prompt: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/skills/{skill_name}", response_model=ResponseModel)
|
||||
async def remove_skill(skill_name: str):
|
||||
"""Remove a registered agent skill.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill to remove
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: If skill removal fails
|
||||
"""
|
||||
try:
|
||||
# This would need to be implemented in AgentScopeService
|
||||
logger.info(f"Removed skill: {skill_name}")
|
||||
|
||||
return ResponseModel(data={"message": f"Skill '{skill_name}' removed successfully"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove skill: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
112
backend/src/api/storage.py
Normal file
112
backend/src/api/storage.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Query, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
import httpx
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.utils.errors import InvalidParameterException, StorageException
|
||||
|
||||
router = APIRouter(prefix="/storage", tags=["Storage"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@router.delete("/files", response_model=ResponseModel)
|
||||
async def delete_file(path: str = Query(..., description="File path or URL to delete")):
|
||||
""" 删除 a file from storage.
|
||||
|
||||
Raises:
|
||||
InvalidParameterException: Path is invalid
|
||||
StorageException: Failed to delete file
|
||||
"""
|
||||
# 参数 validation
|
||||
if not path or not path.strip():
|
||||
raise InvalidParameterException("path", "File path cannot be empty")
|
||||
|
||||
# Blob URLs are client-side only, nothing to delete on server
|
||||
if path.startswith('blob:'):
|
||||
return ResponseModel(data={"deleted": True, "path": path})
|
||||
|
||||
try:
|
||||
# 提取 key from URL if needed
|
||||
key = path
|
||||
if path.startswith(('http://', 'https://')):
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(path)
|
||||
key = parsed.path.lstrip('/')
|
||||
|
||||
# Handle local storage prefix /files/
|
||||
if key.startswith('files/'):
|
||||
key = key[6:]
|
||||
|
||||
success = storage_manager.delete(key)
|
||||
if success:
|
||||
return ResponseModel(data={"deleted": True, "path": key})
|
||||
else:
|
||||
# File doesn't exist or deletion failed
|
||||
raise StorageException("delete", "File not found or deletion failed")
|
||||
|
||||
except StorageException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {path}: {e}", exc_info=True)
|
||||
raise StorageException("delete", str(e))
|
||||
|
||||
|
||||
@router.get("/download")
|
||||
async def download_file(url: str = Query(..., description="Original or signed media URL"),
|
||||
filename: Optional[str] = Query(None, description="Optional download filename")):
|
||||
"""
|
||||
专用下载接口:
|
||||
- 后端代理拉取远端资源
|
||||
- 设置 Content-Disposition=attachment,避免浏览器直接在当前页打开视频/音频
|
||||
"""
|
||||
if not url or not url.strip():
|
||||
raise InvalidParameterException("url", "Download url cannot be empty")
|
||||
|
||||
# 如果是 OSS 原始地址,可以在这里重新签名(视业务需要)
|
||||
try:
|
||||
signed_url = storage_manager.sign_url(url) or url
|
||||
except Exception:
|
||||
# 签名失败时退回原始 URL
|
||||
signed_url = url
|
||||
|
||||
try:
|
||||
client_timeout = httpx.Timeout(60.0)
|
||||
async with httpx.AsyncClient(timeout=client_timeout, follow_redirects=True) as client:
|
||||
upstream = await client.get(signed_url)
|
||||
upstream.raise_for_status()
|
||||
|
||||
# 透传部分头部(例如 Content-Type)
|
||||
content_type = upstream.headers.get("content-type", "application/octet-stream")
|
||||
content_length = upstream.headers.get("content-length")
|
||||
|
||||
# 下载文件名:优先使用 query 参数,其次从 URL / 头部推断
|
||||
final_name = filename
|
||||
if not final_name:
|
||||
# 从 URL path 中取最后一段
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(url)
|
||||
candidate = (parsed.path or "").rstrip("/").split("/")[-1] or "download"
|
||||
final_name = candidate
|
||||
|
||||
headers = {
|
||||
"Content-Disposition": f'attachment; filename="{final_name}"',
|
||||
"Content-Type": content_type,
|
||||
}
|
||||
if content_length is not None:
|
||||
headers["Content-Length"] = content_length
|
||||
|
||||
return StreamingResponse(
|
||||
iter(upstream.iter_bytes()),
|
||||
status_code=upstream.status_code,
|
||||
headers=headers,
|
||||
media_type=content_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to proxy download for {url}: {e}", exc_info=True)
|
||||
# 返回一个统一错误响应,而不是让浏览器打开原地址
|
||||
return Response(
|
||||
content="Failed to download file",
|
||||
status_code=500,
|
||||
media_type="text/plain; charset=utf-8",
|
||||
)
|
||||
289
backend/src/api/storage_admin.py
Normal file
289
backend/src/api/storage_admin.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
Storage Admin API
|
||||
|
||||
存储管理 API 端点,用于管理员管理存储资源。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel import Session, select, func
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.pagination import Paginator
|
||||
from src.config.database import engine
|
||||
from src.models.entities import ProjectDB, UserDB
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(prefix="/admin/storage", tags=["admin-storage"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ===== 响应模型 =====
|
||||
|
||||
class StorageStatsResponse(BaseModel):
|
||||
"""存储统计响应"""
|
||||
total_capacity: int = Field(..., description="总容量 (bytes)")
|
||||
used_space: int = Field(..., description="已用空间 (bytes)")
|
||||
file_count: int = Field(..., description="文件数量")
|
||||
usage_percent: float = Field(..., description="使用百分比")
|
||||
|
||||
|
||||
class StorageFileItem(BaseModel):
|
||||
"""存储文件项"""
|
||||
id: str
|
||||
path: str
|
||||
size: int
|
||||
user_id: str
|
||||
username: Optional[str]
|
||||
project_id: Optional[str]
|
||||
project_name: Optional[str]
|
||||
created_at: float
|
||||
|
||||
|
||||
class StorageUserRankingItem(BaseModel):
|
||||
"""用户存储排行项"""
|
||||
user_id: str
|
||||
username: str
|
||||
storage_used: int
|
||||
file_count: int
|
||||
rank: int
|
||||
|
||||
|
||||
class StorageConfigResponse(BaseModel):
|
||||
"""存储配置响应"""
|
||||
storage_type: str
|
||||
base_path: Optional[str]
|
||||
max_file_size: int
|
||||
allowed_extensions: List[str]
|
||||
|
||||
|
||||
# ===== API 端点 =====
|
||||
|
||||
@router.get("/stats", response_model=ResponseModel)
|
||||
async def get_storage_stats(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取存储统计信息
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
# 计算存储统计
|
||||
with Session(engine) as session:
|
||||
# 这里需要根据实际存储实现调整
|
||||
# 暂时返回一个基础统计
|
||||
total_capacity = 100 * 1024 * 1024 * 1024 # 100GB
|
||||
used_space = 0 # 待实现
|
||||
file_count = 0 # 待实现
|
||||
|
||||
return ResponseModel(
|
||||
data=StorageStatsResponse(
|
||||
total_capacity=total_capacity,
|
||||
used_space=used_space,
|
||||
file_count=file_count,
|
||||
usage_percent=(used_space / total_capacity * 100) if total_capacity > 0 else 0,
|
||||
).model_dump()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage stats: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files", response_model=ResponseModel)
|
||||
async def list_storage_files(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
user_id: Optional[str] = Query(None, description="按用户 ID 过滤"),
|
||||
project_id: Optional[str] = Query(None, description="按项目 ID 过滤"),
|
||||
search: Optional[str] = Query(None, description="搜索文件路径"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
列出存储文件
|
||||
|
||||
支持按用户、项目过滤。
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
# 查询项目关联的文件
|
||||
query = select(ProjectDB)
|
||||
|
||||
if user_id:
|
||||
query = query.where(ProjectDB.user_id == user_id)
|
||||
|
||||
total = session.exec(
|
||||
select(func.count()).select_from(query.subquery())
|
||||
).one()
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
projects = session.exec(query.offset(offset).limit(page_size)).all()
|
||||
|
||||
items = []
|
||||
for project in projects:
|
||||
items.append(
|
||||
StorageFileItem(
|
||||
id=project.id,
|
||||
path=f"/projects/{project.id}",
|
||||
size=0, # 待实现
|
||||
user_id=project.user_id or "",
|
||||
username=None, # 待关联用户表
|
||||
project_id=project.id,
|
||||
project_name=project.name,
|
||||
created_at=project.created_at,
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing storage files: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/users/ranking", response_model=ResponseModel)
|
||||
async def get_storage_user_ranking(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取用户存储使用排行
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
# 查询用户项目数量作为排行依据
|
||||
query = """
|
||||
SELECT user_id, COUNT(*) as project_count
|
||||
FROM projects
|
||||
WHERE deleted_at IS NULL
|
||||
GROUP BY user_id
|
||||
ORDER BY project_count DESC
|
||||
"""
|
||||
# 简化的实现
|
||||
users_query = select(UserDB)
|
||||
users = session.exec(users_query.limit(page_size).offset((page - 1) * page_size)).all()
|
||||
|
||||
items = []
|
||||
for idx, user in enumerate(users):
|
||||
project_count = session.exec(
|
||||
select(func.count()).where(ProjectDB.user_id == user.id)
|
||||
).one()
|
||||
|
||||
items.append(
|
||||
StorageUserRankingItem(
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
storage_used=0, # 待实现
|
||||
file_count=project_count,
|
||||
rank=(page - 1) * page_size + idx + 1,
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
total = session.exec(select(func.count()).select_from(UserDB)).one()
|
||||
|
||||
paginator = Paginator(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage ranking: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cleanup", response_model=ResponseModel)
|
||||
async def cleanup_orphan_files(
|
||||
dry_run: bool = Query(True, description="是否仅模拟执行"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
清理孤立文件(没有关联项目的文件)
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
# 简化的实现
|
||||
return ResponseModel(
|
||||
data={
|
||||
"orphan_count": 0,
|
||||
"deleted_count": 0,
|
||||
"deleted_size": 0,
|
||||
"dry_run": dry_run,
|
||||
"message": "清理功能待实现",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up orphan files: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config", response_model=ResponseModel)
|
||||
async def get_storage_config(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取存储配置信息
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
from src.config.settings import (
|
||||
STORAGE_TYPE,
|
||||
STORAGE_BASE_PATH,
|
||||
STORAGE_MAX_FILE_SIZE,
|
||||
STORAGE_ALLOWED_EXTENSIONS,
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data=StorageConfigResponse(
|
||||
storage_type=STORAGE_TYPE or "local",
|
||||
base_path=STORAGE_BASE_PATH,
|
||||
max_file_size=STORAGE_MAX_FILE_SIZE,
|
||||
allowed_extensions=STORAGE_ALLOWED_EXTENSIONS or [],
|
||||
).model_dump()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage config: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
175
backend/src/api/tasks.py
Normal file
175
backend/src/api/tasks.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Tasks Controller - 任务管理
|
||||
使用新的任务管理器 V2
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Query, Request, Depends
|
||||
|
||||
from src.models.schemas import ResponseModel, Task, PaginationParams
|
||||
from src.services.task_manager import task_manager, TaskPriority
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.utils.errors import (
|
||||
TaskNotFoundException,
|
||||
InvalidParameterException,
|
||||
BusinessException,
|
||||
ErrorCode
|
||||
)
|
||||
from src.utils.pagination import Paginator
|
||||
|
||||
router = APIRouter(tags=["tasks"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 注意:更具体的路由(如 /tasks/stats)必须在参数化路由(如 /tasks/{task_id})之前定义
|
||||
# 否则 FastAPI 会将 "stats" 当作 task_id 处理
|
||||
|
||||
|
||||
@router.get("/tasks/stats", response_model=ResponseModel)
|
||||
async def get_task_stats():
|
||||
"""获取任务统计信息
|
||||
|
||||
Returns:
|
||||
任务统计数据
|
||||
"""
|
||||
stats = task_manager.get_stats()
|
||||
return ResponseModel(data=stats)
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=ResponseModel)
|
||||
async def list_tasks(
|
||||
request: Request,
|
||||
task_type: Optional[str] = Query(None, description="Filter by task type (image, video, script)"),
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量,最大100"),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""列出任务(分页)
|
||||
|
||||
只返回当前用户的任务
|
||||
|
||||
Args:
|
||||
task_type: 任务类型过滤
|
||||
status: 状态过滤
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
分页的任务列表
|
||||
"""
|
||||
try:
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 获取当前用户的任务
|
||||
tasks = task_manager.list_tasks(
|
||||
type=task_type,
|
||||
limit=page_size,
|
||||
offset=offset,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
# 计算总数
|
||||
total = len(tasks) # 简化处理,实际应该查询总数
|
||||
|
||||
# 创建分页器
|
||||
paginator = Paginator(
|
||||
items=tasks,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing tasks: {e}", exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=ResponseModel)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""获取任务状态
|
||||
|
||||
只能访问当前用户的任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
任务详情
|
||||
|
||||
Raises:
|
||||
TaskNotFoundException: 任务不存在或无权访问
|
||||
"""
|
||||
task = await task_manager.get_task(task_id)
|
||||
|
||||
if not task:
|
||||
raise TaskNotFoundException(task_id)
|
||||
|
||||
# 检查任务是否属于当前用户
|
||||
if task.user_id and task.user_id != current_user.id:
|
||||
raise TaskNotFoundException(task_id)
|
||||
|
||||
return ResponseModel(data=task)
|
||||
|
||||
|
||||
@router.delete("/tasks/{task_id}", response_model=ResponseModel)
|
||||
async def cancel_task(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""取消任务
|
||||
|
||||
只能取消当前用户的任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
取消结果
|
||||
|
||||
Raises:
|
||||
TaskNotFoundException: 任务不存在或无权访问
|
||||
BusinessException: 取消失败
|
||||
"""
|
||||
try:
|
||||
# 先检查任务是否属于当前用户
|
||||
task = await task_manager.get_task(task_id)
|
||||
if not task:
|
||||
raise TaskNotFoundException(task_id)
|
||||
|
||||
if task.user_id and task.user_id != current_user.id:
|
||||
raise TaskNotFoundException(task_id)
|
||||
|
||||
success = await task_manager.cancel_task(task_id)
|
||||
|
||||
if not success:
|
||||
raise BusinessException(
|
||||
ErrorCode.TASK_CANCELLED,
|
||||
"Task cannot be cancelled (already completed or failed)",
|
||||
{"task_id": task_id}
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="Task cancelled successfully",
|
||||
data={"task_id": task_id, "cancelled": True}
|
||||
)
|
||||
|
||||
except TaskNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel task {task_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to cancel task",
|
||||
{"task_id": task_id, "reason": str(e)}
|
||||
)
|
||||
371
backend/src/api/user_api_keys.py
Normal file
371
backend/src/api/user_api_keys.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""
|
||||
用户 API Key 管理接口
|
||||
|
||||
提供用户管理自己的 API Key 的 CRUD 接口,以及管理员管理所有用户 API Key 的接口。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, status, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.auth.dependencies import get_current_user, require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.user_api_key_service import user_api_key_service
|
||||
from src.models.schemas import ResponseModel, PaginationParams
|
||||
from src.utils.pagination import Paginator
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(prefix="/user-api-keys", tags=["user-api-keys"])
|
||||
|
||||
|
||||
# ===== 请求/响应模型 =====
|
||||
|
||||
class ApiKeyCreateRequest(BaseModel):
|
||||
"""创建 API Key 请求"""
|
||||
provider: str = Field(..., description="提供商,如 openai, dashscope")
|
||||
api_key: str = Field(..., description="API Key 值", min_length=1)
|
||||
name: Optional[str] = Field(None, description="自定义名称")
|
||||
extra_config: Optional[dict] = Field(None, description="额外配置")
|
||||
|
||||
|
||||
class ApiKeyUpdateRequest(BaseModel):
|
||||
"""更新 API Key 请求"""
|
||||
name: Optional[str] = Field(None, description="新名称")
|
||||
api_key: Optional[str] = Field(None, description="新 API Key 值")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
extra_config: Optional[dict] = Field(None, description="额外配置")
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
|
||||
"""API Key 响应"""
|
||||
id: str
|
||||
user_id: str
|
||||
provider: str
|
||||
name: str
|
||||
masked_key: str
|
||||
is_active: bool
|
||||
created_at: float
|
||||
updated_at: float
|
||||
last_used_at: Optional[float]
|
||||
usage_count: int
|
||||
extra_config: Optional[dict]
|
||||
|
||||
|
||||
class ApiKeyVerifyResponse(BaseModel):
|
||||
"""API Key 验证响应"""
|
||||
valid: bool
|
||||
provider: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# ===== 管理员响应模型 =====
|
||||
|
||||
class AdminApiKeyListItem(BaseModel):
|
||||
"""管理员 API Key 列表项"""
|
||||
id: str
|
||||
user_id: str
|
||||
username: str = Field(..., description="关联用户名")
|
||||
email: str = Field(..., description="关联用户邮箱")
|
||||
provider: str
|
||||
name: str
|
||||
masked_key: str
|
||||
is_active: bool
|
||||
created_at: float
|
||||
last_used_at: Optional[float]
|
||||
usage_count: int
|
||||
|
||||
|
||||
class AdminApiKeyUsageRecord(BaseModel):
|
||||
"""API Key 使用记录"""
|
||||
task_id: Optional[str]
|
||||
task_type: Optional[str]
|
||||
model: Optional[str]
|
||||
provider: str
|
||||
credits_used: Optional[float]
|
||||
created_at: str
|
||||
|
||||
|
||||
# ===== API 端点 =====
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_api_keys(
|
||||
current_user: UserAuth = Depends(get_current_user),
|
||||
include_inactive: bool = False
|
||||
):
|
||||
"""
|
||||
获取当前用户的所有 API Key
|
||||
|
||||
Args:
|
||||
include_inactive: 是否包含已禁用的 Key
|
||||
"""
|
||||
keys = user_api_key_service.get_user_api_keys(
|
||||
user_id=current_user.id,
|
||||
include_inactive=include_inactive
|
||||
)
|
||||
return ResponseModel(data=keys)
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED, response_model=ResponseModel)
|
||||
async def create_api_key(
|
||||
request: ApiKeyCreateRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建新的 API Key
|
||||
"""
|
||||
try:
|
||||
key_db = user_api_key_service.create_api_key(
|
||||
user_id=current_user.id,
|
||||
provider=request.provider,
|
||||
api_key=request.api_key,
|
||||
name=request.name,
|
||||
extra_config=request.extra_config
|
||||
)
|
||||
key = user_api_key_service.get_api_key_by_id(key_db.id, current_user.id)
|
||||
return ResponseModel(data=key)
|
||||
except ValueError as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{key_id}", response_model=ResponseModel)
|
||||
async def get_api_key(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取指定 API Key 的详细信息
|
||||
"""
|
||||
key = user_api_key_service.get_api_key_by_id(key_id, current_user.id)
|
||||
if not key:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=key)
|
||||
|
||||
|
||||
@router.put("/{key_id}", response_model=ResponseModel)
|
||||
async def update_api_key(
|
||||
key_id: str,
|
||||
request: ApiKeyUpdateRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
更新 API Key
|
||||
"""
|
||||
key = user_api_key_service.update_api_key(
|
||||
key_id=key_id,
|
||||
user_id=current_user.id,
|
||||
name=request.name,
|
||||
api_key=request.api_key,
|
||||
is_active=request.is_active,
|
||||
extra_config=request.extra_config
|
||||
)
|
||||
if not key:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=key)
|
||||
|
||||
|
||||
@router.delete("/{key_id}", response_model=ResponseModel)
|
||||
async def delete_api_key(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
删除 API Key
|
||||
"""
|
||||
success = user_api_key_service.delete_api_key(key_id, current_user.id)
|
||||
if not success:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data={"deleted": True, "id": key_id})
|
||||
|
||||
|
||||
@router.post("/{key_id}/verify", response_model=ResponseModel)
|
||||
async def verify_api_key(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
验证 API Key 是否有效
|
||||
"""
|
||||
result = await user_api_key_service.verify_api_key(key_id, current_user.id)
|
||||
return ResponseModel(data=result)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 管理员 API 端点
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/admin/api-keys", response_model=ResponseModel)
|
||||
async def admin_list_api_keys(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
user_id: Optional[str] = Query(None, description="按用户 ID 过滤"),
|
||||
provider: Optional[str] = Query(None, description="按提供商过滤"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
管理员列出所有 API Keys
|
||||
|
||||
支持分页、按用户 ID、提供商过滤。
|
||||
需要管理员权限。
|
||||
"""
|
||||
from sqlmodel import Session, select, func
|
||||
from src.config.database import engine
|
||||
from src.models.entities import UserApiKeyDB, UserDB
|
||||
|
||||
with Session(engine) as session:
|
||||
query = select(UserApiKeyDB)
|
||||
|
||||
# 应用过滤
|
||||
if user_id:
|
||||
query = query.where(UserApiKeyDB.user_id == user_id)
|
||||
if provider:
|
||||
query = query.where(UserApiKeyDB.provider == provider)
|
||||
|
||||
# 获取总数
|
||||
total = session.exec(
|
||||
select(func.count()).select_from(query.subquery())
|
||||
).one()
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.order_by(UserApiKeyDB.created_at.desc()).offset(offset).limit(page_size)
|
||||
|
||||
api_keys = session.exec(query).all()
|
||||
|
||||
# 构建响应项(关联用户信息)
|
||||
items = []
|
||||
for key in api_keys:
|
||||
user = session.get(UserDB, key.user_id)
|
||||
# 获取脱敏 key
|
||||
raw_key = key.encrypted_key[:8] + "***" if key.encrypted_key else ""
|
||||
items.append(
|
||||
AdminApiKeyListItem(
|
||||
id=key.id,
|
||||
user_id=key.user_id,
|
||||
username=user.username if user else "Unknown",
|
||||
email=user.email if user else "",
|
||||
provider=key.provider,
|
||||
name=key.name or f"{key.provider} Key",
|
||||
masked_key=raw_key,
|
||||
is_active=key.is_active,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
usage_count=key.usage_count,
|
||||
)
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=[item.model_dump() for item in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
|
||||
|
||||
@router.post("/admin/api-keys/{key_id}/revoke", response_model=ResponseModel)
|
||||
async def admin_revoke_api_key(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
撤销 API Key
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import UserApiKeyDB
|
||||
|
||||
with Session(engine) as session:
|
||||
key = session.get(UserApiKeyDB, key_id)
|
||||
if not key:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
key.is_active = False
|
||||
key.updated_at = datetime.now().timestamp()
|
||||
session.add(key)
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": key_id,
|
||||
"is_active": False,
|
||||
"message": "API Key 已成功撤销",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/api-keys/{key_id}/usage", response_model=ResponseModel)
|
||||
async def admin_get_api_key_usage(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取 API Key 使用记录
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import UserApiKeyDB, TaskDB
|
||||
|
||||
with Session(engine) as session:
|
||||
key = session.get(UserApiKeyDB, key_id)
|
||||
if not key:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
# 查询相关任务记录
|
||||
tasks = session.exec(
|
||||
select(TaskDB)
|
||||
.where(TaskDB.user_id == key.user_id)
|
||||
.order_by(TaskDB.created_at.desc())
|
||||
.limit(50)
|
||||
).all()
|
||||
|
||||
usage_records = [
|
||||
AdminApiKeyUsageRecord(
|
||||
task_id=task.id,
|
||||
task_type=task.type,
|
||||
model=task.model,
|
||||
provider=task.provider or key.provider,
|
||||
credits_used=None,
|
||||
created_at=datetime.fromtimestamp(task.created_at).isoformat(),
|
||||
)
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"key_id": key_id,
|
||||
"usage_count": key.usage_count,
|
||||
"records": [r.model_dump() for r in usage_records],
|
||||
}
|
||||
)
|
||||
41
backend/src/auth/__init__.py
Normal file
41
backend/src/auth/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
认证授权模块
|
||||
|
||||
提供 JWT Token 认证、OAuth2 集成和 HTTP 轮询状态查询。
|
||||
"""
|
||||
|
||||
from src.auth.jwt import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_token,
|
||||
verify_refresh_token,
|
||||
TokenPayload,
|
||||
TokenPair,
|
||||
)
|
||||
from src.auth.dependencies import (
|
||||
get_current_user,
|
||||
get_current_active_user,
|
||||
require_permissions,
|
||||
)
|
||||
from src.auth.middleware import AuthMiddleware
|
||||
from src.auth.models import UserAuth, TokenData, RefreshTokenRequest
|
||||
|
||||
__all__ = [
|
||||
# JWT
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"verify_token",
|
||||
"verify_refresh_token",
|
||||
"TokenPayload",
|
||||
"TokenPair",
|
||||
# Dependencies
|
||||
"get_current_user",
|
||||
"get_current_active_user",
|
||||
"require_permissions",
|
||||
# Middleware
|
||||
"AuthMiddleware",
|
||||
# Models
|
||||
"UserAuth",
|
||||
"TokenData",
|
||||
"RefreshTokenRequest",
|
||||
]
|
||||
250
backend/src/auth/dependencies.py
Normal file
250
backend/src/auth/dependencies.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
FastAPI 认证依赖
|
||||
|
||||
提供可注入的依赖函数用于获取当前用户和验证权限。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, List, Annotated
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, HTTPBearer
|
||||
|
||||
from src.auth.jwt import verify_token
|
||||
from src.auth.models import UserAuth, TokenPayload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_test_env() -> bool:
|
||||
return (
|
||||
os.getenv("PYTEST_CURRENT_TEST") is not None
|
||||
or os.getenv("PIXEL_TEST_MODE") == "1"
|
||||
or os.getenv("NODE_ENV") == "test"
|
||||
)
|
||||
|
||||
|
||||
def _build_test_user() -> UserAuth:
|
||||
return UserAuth(
|
||||
id="test-user",
|
||||
username="test_user",
|
||||
email=None,
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
permissions=["*"],
|
||||
roles=["test"],
|
||||
created_at=None,
|
||||
last_login=None,
|
||||
)
|
||||
|
||||
# Lazy import to avoid circular import
|
||||
_user_service = None
|
||||
|
||||
def _get_user_service():
|
||||
global _user_service
|
||||
if _user_service is None:
|
||||
from src.services.user_service import user_service
|
||||
_user_service = user_service
|
||||
return _user_service
|
||||
|
||||
# OAuth2 scheme for Swagger UI
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="/api/v1/auth/login",
|
||||
auto_error=False
|
||||
)
|
||||
|
||||
# HTTP Bearer for direct token usage
|
||||
http_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[Optional[str], Depends(oauth2_scheme)]
|
||||
) -> UserAuth:
|
||||
"""
|
||||
获取当前认证用户(从 HTTP 请求)
|
||||
|
||||
Args:
|
||||
token: OAuth2 token from Authorization header
|
||||
|
||||
Returns:
|
||||
UserAuth 对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if authentication fails
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not token:
|
||||
if _is_test_env():
|
||||
return _build_test_user()
|
||||
raise credentials_exception
|
||||
|
||||
payload = await verify_token(token, token_type="access")
|
||||
if not payload:
|
||||
raise credentials_exception
|
||||
|
||||
if payload.sid:
|
||||
from src.services.session_service import session_service
|
||||
|
||||
if not session_service.is_session_active(payload.sid):
|
||||
raise credentials_exception
|
||||
|
||||
# 从缓存或数据库获取用户信息
|
||||
user = await _get_user_service().get_user_by_id(payload.sub)
|
||||
|
||||
if not user:
|
||||
raise credentials_exception
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is inactive"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[UserAuth, Depends(get_current_user)]
|
||||
) -> UserAuth:
|
||||
"""
|
||||
获取当前活跃用户
|
||||
|
||||
Args:
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
UserAuth 对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is inactive
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def require_permissions(required_permissions: List[str]):
|
||||
"""
|
||||
权限检查依赖工厂
|
||||
|
||||
Args:
|
||||
required_permissions: 需要的权限列表
|
||||
|
||||
Returns:
|
||||
依赖函数
|
||||
|
||||
Usage:
|
||||
@router.get("/admin")
|
||||
async def admin_endpoint(
|
||||
user: UserAuth = Depends(require_permissions(["admin"]))
|
||||
):
|
||||
...
|
||||
"""
|
||||
async def permission_checker(
|
||||
current_user: Annotated[UserAuth, Depends(get_current_user)]
|
||||
) -> UserAuth:
|
||||
# 超级用户跳过权限检查
|
||||
if current_user.is_superuser:
|
||||
return current_user
|
||||
|
||||
# 检查权限
|
||||
user_perms = set(current_user.permissions)
|
||||
required_perms = set(required_permissions)
|
||||
|
||||
if not required_perms.issubset(user_perms):
|
||||
missing = required_perms - user_perms
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing permissions: {', '.join(missing)}"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
return permission_checker
|
||||
|
||||
|
||||
async def optional_user(
|
||||
token: Annotated[Optional[str], Depends(oauth2_scheme)]
|
||||
) -> Optional[UserAuth]:
|
||||
"""
|
||||
可选的用户认证
|
||||
|
||||
如果提供了 token 则验证,否则返回 None
|
||||
|
||||
Args:
|
||||
token: OAuth2 token
|
||||
|
||||
Returns:
|
||||
UserAuth 对象或 None
|
||||
"""
|
||||
if not token:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await get_current_user(token)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
async def require_admin(
|
||||
current_user: Annotated[UserAuth, Depends(get_current_user)]
|
||||
) -> UserAuth:
|
||||
"""
|
||||
要求管理员权限
|
||||
|
||||
Args:
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
UserAuth 对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not an admin
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin access required"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
class AuthDependencies:
|
||||
"""
|
||||
认证依赖组合
|
||||
|
||||
提供常用的认证依赖组合
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def authenticated():
|
||||
"""返回需要认证的依赖"""
|
||||
return Depends(get_current_user)
|
||||
|
||||
@staticmethod
|
||||
def active_user():
|
||||
"""返回需要活跃用户的依赖"""
|
||||
return Depends(get_current_active_user)
|
||||
|
||||
@staticmethod
|
||||
def optional():
|
||||
"""返回可选认证的依赖"""
|
||||
return Depends(optional_user)
|
||||
|
||||
@staticmethod
|
||||
def permissions(perms: List[str]):
|
||||
"""返回需要特定权限的依赖"""
|
||||
return Depends(require_permissions(perms))
|
||||
|
||||
@staticmethod
|
||||
def admin():
|
||||
"""返回需要管理员权限的依赖"""
|
||||
return Depends(require_admin)
|
||||
346
backend/src/auth/jwt.py
Normal file
346
backend/src/auth/jwt.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
JWT Token 处理模块
|
||||
|
||||
提供 access token 和 refresh token 的创建、验证功能。
|
||||
支持 RSA 密钥对或 HMAC 对称加密。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from src.auth.models import TokenPayload, TokenPair
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Password hashing
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# JWT 配置
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
||||
|
||||
# 密钥配置(生产环境必须从环境变量获取)
|
||||
def _get_secret_key() -> str:
|
||||
"""从环境变量获取密钥,如果未设置则生成随机密钥(仅用于开发)"""
|
||||
secret = os.getenv("JWT_SECRET_KEY")
|
||||
if secret:
|
||||
return secret
|
||||
|
||||
# 开发环境:生成随机密钥并警告
|
||||
if os.getenv("NODE_ENV", "development") == "development":
|
||||
logger.warning(
|
||||
"JWT_SECRET_KEY not set in environment. "
|
||||
"Using auto-generated random key (tokens will not persist across restarts). "
|
||||
"Please set JWT_SECRET_KEY in production!"
|
||||
)
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
# 生产环境必须设置密钥
|
||||
raise ValueError(
|
||||
"JWT_SECRET_KEY must be set in environment variables for production. "
|
||||
"Generate a secure key with: python -c 'import secrets; print(secrets.token_hex(32))'"
|
||||
)
|
||||
|
||||
SECRET_KEY = _get_secret_key()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""生成密码哈希"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
user_id: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
session_id: Optional[str] = None,
|
||||
extra_claims: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
创建 access token
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
scopes: 权限范围列表
|
||||
expires_delta: 自定义过期时间
|
||||
extra_claims: 额外的 JWT claims
|
||||
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"exp": int(expire.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"type": "access",
|
||||
"scopes": scopes or [],
|
||||
"jti": str(uuid.uuid4()),
|
||||
"sid": session_id,
|
||||
}
|
||||
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
user_id: str,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
jti: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
session_family_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
创建 refresh token
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
expires_delta: 自定义过期时间
|
||||
jti: JWT ID(用于 token 撤销)
|
||||
|
||||
Returns:
|
||||
JWT refresh token string
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
token_jti = jti or str(uuid.uuid4())
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"exp": int(expire.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"type": "refresh",
|
||||
"jti": token_jti,
|
||||
"sid": session_id,
|
||||
"sfid": session_family_id,
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_token_pair(
|
||||
user_id: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
access_expires: Optional[timedelta] = None,
|
||||
refresh_expires: Optional[timedelta] = None,
|
||||
session_id: Optional[str] = None,
|
||||
session_family_id: Optional[str] = None,
|
||||
) -> TokenPair:
|
||||
"""
|
||||
创建 access token 和 refresh token 对
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
scopes: 权限范围
|
||||
access_expires: access token 过期时间
|
||||
refresh_expires: refresh token 过期时间
|
||||
|
||||
Returns:
|
||||
TokenPair 包含 access_token 和 refresh_token
|
||||
"""
|
||||
access_token = create_access_token(user_id, scopes, access_expires, session_id=session_id)
|
||||
refresh_token = create_refresh_token(
|
||||
user_id,
|
||||
refresh_expires,
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
)
|
||||
|
||||
return TokenPair(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
)
|
||||
|
||||
|
||||
async def verify_token(token: str, token_type: str = "access", check_blacklist: bool = True) -> Optional[TokenPayload]:
|
||||
"""
|
||||
验证 JWT token
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
token_type: 期望的 token 类型 ("access" 或 "refresh")
|
||||
check_blacklist: 是否检查黑名单(默认为 True)
|
||||
|
||||
Returns:
|
||||
TokenPayload 如果验证成功,否则 None
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 验证 token 类型
|
||||
if payload.get("type") != token_type:
|
||||
logger.warning(f"Token type mismatch: expected {token_type}, got {payload.get('type')}")
|
||||
return None
|
||||
|
||||
# 验证必要字段
|
||||
if "sub" not in payload or "exp" not in payload:
|
||||
logger.warning("Token missing required fields")
|
||||
return None
|
||||
|
||||
# 检查是否在黑名单中
|
||||
if check_blacklist:
|
||||
try:
|
||||
from src.services.token_blacklist_service import token_blacklist_service
|
||||
is_revoked = await token_blacklist_service.is_token_revoked(token)
|
||||
if is_revoked:
|
||||
logger.warning(f"Token has been revoked: {payload.get('jti')}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check token blacklist: {e}")
|
||||
# 黑名单检查失败,继续验证(避免服务不可用)
|
||||
|
||||
return TokenPayload(**payload)
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token verification: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def verify_token_sync(token: str, token_type: str = "access") -> Optional[TokenPayload]:
|
||||
"""
|
||||
同步验证 JWT token(不检查黑名单,用于中间件等同步上下文)
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
token_type: 期望的 token 类型 ("access" 或 "refresh")
|
||||
|
||||
Returns:
|
||||
TokenPayload 如果验证成功,否则 None
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 验证 token 类型
|
||||
if payload.get("type") != token_type:
|
||||
logger.warning(f"Token type mismatch: expected {token_type}, got {payload.get('type')}")
|
||||
return None
|
||||
|
||||
# 验证必要字段
|
||||
if "sub" not in payload or "exp" not in payload:
|
||||
logger.warning("Token missing required fields")
|
||||
return None
|
||||
|
||||
return TokenPayload(**payload)
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token verification: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def verify_refresh_token(token: str) -> Optional[TokenPayload]:
|
||||
"""
|
||||
验证 refresh token
|
||||
|
||||
Args:
|
||||
token: Refresh token string
|
||||
|
||||
Returns:
|
||||
TokenPayload 如果验证成功,否则 None
|
||||
"""
|
||||
return verify_token(token, token_type="refresh")
|
||||
|
||||
|
||||
def decode_token_unsafe(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
不解密直接解码 token(用于获取 payload 信息而不验证签名)
|
||||
|
||||
警告: 仅用于获取非敏感信息,不要用于认证!
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Payload dict 或 None
|
||||
"""
|
||||
try:
|
||||
# 分割 JWT
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
# 解码 payload (第二部分)
|
||||
import base64
|
||||
payload_b64 = parts[1]
|
||||
# 添加 padding
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
|
||||
payload_json = base64.urlsafe_b64decode(payload_b64)
|
||||
return json.loads(payload_json)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_token_expiry(token: str) -> Optional[datetime]:
|
||||
"""
|
||||
获取 token 的过期时间
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
过期时间或 None
|
||||
"""
|
||||
payload = decode_token_unsafe(token)
|
||||
if payload and "exp" in payload:
|
||||
return datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
return None
|
||||
|
||||
|
||||
def is_token_expired(token: str) -> bool:
|
||||
"""
|
||||
检查 token 是否已过期
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
如果已过期返回 True
|
||||
"""
|
||||
expiry = get_token_expiry(token)
|
||||
if expiry:
|
||||
return datetime.now(timezone.utc) > expiry
|
||||
return True
|
||||
|
||||
|
||||
# Import json for decode_token_unsafe
|
||||
import json
|
||||
343
backend/src/auth/middleware.py
Normal file
343
backend/src/auth/middleware.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
认证中间件
|
||||
|
||||
提供全局认证中间件和权限检查。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, List, Set
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from src.auth.jwt import verify_token
|
||||
from src.auth.models import UserAuth, TokenPayload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
全局认证中间件
|
||||
|
||||
为所有请求添加认证信息,可选择性地保护路由。
|
||||
"""
|
||||
|
||||
# 不需要认证的路径
|
||||
PUBLIC_PATHS: Set[str] = {
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/refresh",
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/v1/health",
|
||||
}
|
||||
|
||||
# 路径前缀(以这些前缀开头的路径不需要认证)
|
||||
PUBLIC_PREFIXES: List[str] = [
|
||||
"/static/",
|
||||
"/api/v1/public/",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
require_auth: bool = False,
|
||||
public_paths: Optional[Set[str]] = None,
|
||||
public_prefixes: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
初始化认证中间件
|
||||
|
||||
Args:
|
||||
app: ASGI 应用
|
||||
require_auth: 是否默认要求认证
|
||||
public_paths: 额外的公开路径
|
||||
public_prefixes: 额外的公开路径前缀
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.require_auth = require_auth
|
||||
|
||||
if public_paths:
|
||||
self.PUBLIC_PATHS.update(public_paths)
|
||||
if public_prefixes:
|
||||
self.PUBLIC_PREFIXES.extend(public_prefixes)
|
||||
|
||||
def _is_public_path(self, path: str) -> bool:
|
||||
"""检查路径是否公开"""
|
||||
# 精确匹配
|
||||
if path in self.PUBLIC_PATHS:
|
||||
return True
|
||||
|
||||
# 前缀匹配
|
||||
for prefix in self.PUBLIC_PREFIXES:
|
||||
if path.startswith(prefix):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 token"""
|
||||
# 从 Authorization header
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:]
|
||||
|
||||
# 从 query parameter (用于特殊场景,如直接链接分享)
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
return token
|
||||
|
||||
# 从 cookie
|
||||
token = request.cookies.get("access_token")
|
||||
if token:
|
||||
return token
|
||||
|
||||
return None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""
|
||||
处理请求
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
call_next: 下一个中间件/处理函数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
"""
|
||||
path = request.url.path
|
||||
|
||||
# 提取 token
|
||||
token = self._extract_token(request)
|
||||
|
||||
# 如果有 token,尝试验证
|
||||
user: Optional[UserAuth] = None
|
||||
if token:
|
||||
payload = await verify_token(token, token_type="access")
|
||||
if payload:
|
||||
user = UserAuth(
|
||||
id=payload.sub,
|
||||
username=f"user_{payload.sub[:8]}",
|
||||
is_active=True,
|
||||
permissions=payload.scopes or []
|
||||
)
|
||||
|
||||
# 将用户信息附加到请求状态
|
||||
request.state.user = user
|
||||
request.state.is_authenticated = user is not None
|
||||
|
||||
# 检查是否需要认证
|
||||
if self._is_public_path(path):
|
||||
# 公开路径,允许访问
|
||||
return await call_next(request)
|
||||
|
||||
# 非公开路径
|
||||
if self.require_auth and not user:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"code": "4010",
|
||||
"message": "Authentication required",
|
||||
"data": None
|
||||
}
|
||||
)
|
||||
|
||||
# 继续处理
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class PermissionMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
权限检查中间件
|
||||
|
||||
基于路径进行权限检查。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
path_permissions: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
初始化权限中间件
|
||||
|
||||
Args:
|
||||
app: ASGI 应用
|
||||
path_permissions: 路径权限映射 {"/path": ["permission1", "permission2"]}
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.path_permissions = path_permissions or {}
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""处理请求"""
|
||||
path = request.url.path
|
||||
user: Optional[UserAuth] = getattr(request.state, "user", None)
|
||||
|
||||
# 检查路径权限
|
||||
for path_pattern, required_perms in self.path_permissions.items():
|
||||
if path.startswith(path_pattern):
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"code": "4010",
|
||||
"message": "Authentication required",
|
||||
"data": None
|
||||
}
|
||||
)
|
||||
|
||||
if not user.is_superuser:
|
||||
user_perms = set(user.permissions)
|
||||
required = set(required_perms)
|
||||
|
||||
if not required.issubset(user_perms):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"code": "4030",
|
||||
"message": "Permission denied",
|
||||
"data": {"required": list(required)}
|
||||
}
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
基于 Redis 的速率限制中间件
|
||||
|
||||
为认证用户和匿名用户分别设置速率限制。
|
||||
使用 Redis 实现分布式速率限制。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
authenticated_limit: int = 1000, # per minute
|
||||
anonymous_limit: int = 60, # per minute
|
||||
window_size: int = 60, # 1 minute window
|
||||
):
|
||||
super().__init__(app)
|
||||
self.authenticated_limit = authenticated_limit
|
||||
self.anonymous_limit = anonymous_limit
|
||||
self.window_size = window_size
|
||||
self._redis_client = None
|
||||
self._enabled = False
|
||||
|
||||
async def _get_redis(self):
|
||||
"""Lazy initialize Redis connection"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
if self._redis_client is None:
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from src.config.settings import REDIS_URL
|
||||
self._redis_client = await aioredis.from_url(
|
||||
REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis not available for rate limiting: {e}")
|
||||
return None
|
||||
return self._redis_client
|
||||
|
||||
async def _check_rate_limit(self, key: str, limit: int) -> tuple[bool, int, int]:
|
||||
"""
|
||||
检查速率限制
|
||||
|
||||
Returns:
|
||||
(allowed, remaining, reset_after)
|
||||
"""
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
# Redis unavailable, allow request
|
||||
return True, limit, 0
|
||||
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
window_start = current_time - self.window_size
|
||||
|
||||
# Remove old entries
|
||||
await redis.zremrangebyscore(key, 0, window_start)
|
||||
|
||||
# Count current requests in window
|
||||
current_count = await redis.zcard(key)
|
||||
|
||||
if current_count >= limit:
|
||||
# Rate limit exceeded
|
||||
oldest = await redis.zrange(key, 0, 0, withscores=True)
|
||||
reset_after = int(oldest[0][1]) + self.window_size - current_time
|
||||
return False, 0, max(reset_after, 1)
|
||||
|
||||
# Add current request
|
||||
await redis.zadd(key, {str(current_time): current_time})
|
||||
await redis.expire(key, self.window_size)
|
||||
|
||||
remaining = limit - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit check failed: {e}")
|
||||
# Fail open - allow request if Redis fails
|
||||
return True, limit, 0
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""处理请求"""
|
||||
# Check if Redis is enabled
|
||||
if not self._enabled:
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
self._enabled = REDIS_ENABLED
|
||||
|
||||
# Get user identifier
|
||||
user: Optional[UserAuth] = getattr(request.state, "user", None)
|
||||
if user:
|
||||
client_id = f"ratelimit:auth:{user.id}"
|
||||
limit = self.authenticated_limit
|
||||
else:
|
||||
# Use IP address for anonymous users
|
||||
client_ip = request.headers.get("X-Forwarded-For", request.client.host)
|
||||
client_id = f"ratelimit:anon:{client_ip}"
|
||||
limit = self.anonymous_limit
|
||||
|
||||
# Check rate limit
|
||||
allowed, remaining, reset_after = await self._check_rate_limit(client_id, limit)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers
|
||||
if self._enabled:
|
||||
response.headers["X-RateLimit-Limit"] = str(limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
if reset_after > 0:
|
||||
response.headers["X-RateLimit-Reset-After"] = str(reset_after)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"code": "4290",
|
||||
"message": "Rate limit exceeded",
|
||||
"details": {
|
||||
"limit": limit,
|
||||
"reset_after": reset_after
|
||||
}
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(reset_after),
|
||||
"X-RateLimit-Limit": str(limit),
|
||||
"X-RateLimit-Remaining": "0"
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
91
backend/src/auth/models.py
Normal file
91
backend/src/auth/models.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
认证模型定义
|
||||
|
||||
包含用户认证相关的 Pydantic 模型。
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
"""Token 数据模型"""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = Field(description="Access token 过期时间(秒)")
|
||||
refresh_expires_in: int = Field(description="Refresh token 过期时间(秒)")
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT Token Payload"""
|
||||
sub: str = Field(description="用户 ID")
|
||||
exp: int = Field(description="过期时间戳")
|
||||
iat: int = Field(description="签发时间戳")
|
||||
type: str = Field(default="access", description="Token 类型: access 或 refresh")
|
||||
scopes: List[str] = Field(default=[], description="权限范围")
|
||||
jti: Optional[str] = Field(default=None, description="JWT ID,用于 token 撤销")
|
||||
sid: Optional[str] = Field(default=None, description="会话 ID")
|
||||
sfid: Optional[str] = Field(default=None, description="会话族 ID")
|
||||
|
||||
|
||||
class TokenPair(BaseModel):
|
||||
"""Token 对(access + refresh)"""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
session_id: Optional[str] = None
|
||||
session_family_id: Optional[str] = None
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""刷新 Token 请求"""
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class UserAuth(BaseModel):
|
||||
"""用户认证信息"""
|
||||
id: str
|
||||
username: str
|
||||
email: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
is_active: bool = True
|
||||
is_superuser: bool = False
|
||||
permissions: List[str] = []
|
||||
roles: List[str] = []
|
||||
created_at: Optional[float] = None
|
||||
last_login: Optional[float] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLoginRequest(BaseModel):
|
||||
"""用户登录请求"""
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class UserRegisterRequest(BaseModel):
|
||||
"""用户注册请求"""
|
||||
username: str
|
||||
email: str
|
||||
password: str
|
||||
password_confirm: str
|
||||
|
||||
|
||||
class UserPasswordChangeRequest(BaseModel):
|
||||
"""用户修改密码请求"""
|
||||
current_password: str
|
||||
new_password: str
|
||||
new_password_confirm: str
|
||||
|
||||
|
||||
class OAuth2UserInfo(BaseModel):
|
||||
"""OAuth2 用户信息"""
|
||||
provider: str
|
||||
provider_user_id: str
|
||||
email: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
raw_data: Dict[str, Any] = {}
|
||||
16
backend/src/cache/__init__.py
vendored
Normal file
16
backend/src/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
缓存模块
|
||||
|
||||
提供多级缓存支持:
|
||||
- Redis 缓存
|
||||
- 内存缓存(开发环境)
|
||||
- 缓存装饰器
|
||||
"""
|
||||
|
||||
from src.cache.redis_cache import RedisCache, cache_result, cache_with_ttl
|
||||
|
||||
__all__ = [
|
||||
"RedisCache",
|
||||
"cache_result",
|
||||
"cache_with_ttl",
|
||||
]
|
||||
328
backend/src/cache/redis_cache.py
vendored
Normal file
328
backend/src/cache/redis_cache.py
vendored
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Redis 缓存实现
|
||||
|
||||
提供缓存装饰器和缓存管理功能。
|
||||
"""
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Optional, Callable, Any, Union
|
||||
from datetime import timedelta
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio.client import Redis
|
||||
|
||||
from src.config.settings import REDIS_URL, REDIS_ENABLED
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisCache:
|
||||
"""
|
||||
Redis 缓存管理器
|
||||
|
||||
提供统一的缓存接口。
|
||||
"""
|
||||
|
||||
_instance: Optional["RedisCache"] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, redis_url: Optional[str] = None):
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
|
||||
self._redis_url = redis_url or REDIS_URL
|
||||
self._redis: Optional[Redis] = None
|
||||
self._enabled = REDIS_ENABLED
|
||||
self._initialized = True
|
||||
|
||||
async def connect(self):
|
||||
"""建立 Redis 连接"""
|
||||
if not self._enabled:
|
||||
logger.info("Redis cache is disabled")
|
||||
return
|
||||
|
||||
try:
|
||||
self._redis = await redis.from_url(
|
||||
self._redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False # 支持二进制数据
|
||||
)
|
||||
await self._redis.ping()
|
||||
logger.info(f"Connected to Redis cache at {self._redis_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis cache: {e}")
|
||||
self._enabled = False
|
||||
|
||||
async def disconnect(self):
|
||||
"""断开 Redis 连接"""
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
logger.info("Disconnected from Redis cache")
|
||||
|
||||
def _make_key(self, prefix: str, *args, **kwargs) -> str:
|
||||
"""
|
||||
生成缓存键
|
||||
|
||||
Args:
|
||||
prefix: 键前缀
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
缓存键
|
||||
"""
|
||||
key_parts = [prefix]
|
||||
|
||||
for arg in args:
|
||||
key_parts.append(str(arg))
|
||||
|
||||
for k, v in sorted(kwargs.items()):
|
||||
key_parts.append(f"{k}:{v}")
|
||||
|
||||
raw_key = "|".join(key_parts)
|
||||
# 使用哈希确保键长度合理
|
||||
return f"cache:{prefix}:{hashlib.md5(raw_key.encode()).hexdigest()}"
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
获取缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
缓存值或 None
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key)
|
||||
if data:
|
||||
return pickle.loads(data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Cache get error: {e}")
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: Optional[Union[int, timedelta]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
设置缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值
|
||||
ttl: 过期时间(秒或 timedelta)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return False
|
||||
|
||||
try:
|
||||
expire = None
|
||||
if isinstance(ttl, timedelta):
|
||||
expire = int(ttl.total_seconds())
|
||||
elif isinstance(ttl, int):
|
||||
expire = ttl
|
||||
|
||||
data = pickle.dumps(value)
|
||||
await self._redis.set(key, data, ex=expire)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Cache set error: {e}")
|
||||
return False
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
删除缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._redis.delete(key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Cache delete error: {e}")
|
||||
return False
|
||||
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
删除匹配模式的缓存
|
||||
|
||||
Args:
|
||||
pattern: 匹配模式
|
||||
|
||||
Returns:
|
||||
删除的键数量
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return 0
|
||||
|
||||
try:
|
||||
keys = await self._redis.keys(pattern)
|
||||
if keys:
|
||||
return await self._redis.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Cache delete pattern error: {e}")
|
||||
return 0
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
检查键是否存在
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return False
|
||||
|
||||
try:
|
||||
return await self._redis.exists(key) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Cache exists error: {e}")
|
||||
return False
|
||||
|
||||
async def ttl(self, key: str) -> int:
|
||||
"""
|
||||
获取键的剩余生存时间
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1 表示永不过期,-2 表示不存在
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return -2
|
||||
|
||||
try:
|
||||
return await self._redis.ttl(key)
|
||||
except Exception as e:
|
||||
logger.error(f"Cache ttl error: {e}")
|
||||
return -2
|
||||
|
||||
async def clear(self) -> bool:
|
||||
"""
|
||||
清空所有缓存(谨慎使用)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._redis.flushdb()
|
||||
logger.warning("Cache cleared")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Cache clear error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局缓存实例
|
||||
cache = RedisCache()
|
||||
|
||||
|
||||
def cache_result(
|
||||
prefix: str,
|
||||
ttl: Optional[Union[int, timedelta]] = 300,
|
||||
key_func: Optional[Callable] = None
|
||||
):
|
||||
"""
|
||||
缓存装饰器
|
||||
|
||||
自动缓存函数结果。
|
||||
|
||||
Args:
|
||||
prefix: 缓存键前缀
|
||||
ttl: 过期时间(秒)
|
||||
key_func: 自定义键生成函数
|
||||
|
||||
Usage:
|
||||
@cache_result("user", ttl=300)
|
||||
async def get_user(user_id: str):
|
||||
return await fetch_user(user_id)
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# 生成缓存键
|
||||
if key_func:
|
||||
cache_key = key_func(*args, **kwargs)
|
||||
else:
|
||||
cache_key = cache._make_key(prefix, *args, **kwargs)
|
||||
|
||||
# 尝试从缓存获取
|
||||
result = await cache.get(cache_key)
|
||||
if result is not None:
|
||||
logger.debug(f"Cache hit: {cache_key}")
|
||||
return result
|
||||
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# 存入缓存
|
||||
await cache.set(cache_key, result, ttl)
|
||||
logger.debug(f"Cache set: {cache_key}")
|
||||
|
||||
return result
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
# 同步函数不支持缓存
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def cache_with_ttl(ttl: int = 300):
|
||||
"""
|
||||
简化版缓存装饰器
|
||||
|
||||
使用函数名作为前缀。
|
||||
|
||||
Args:
|
||||
ttl: 过期时间(秒)
|
||||
|
||||
Usage:
|
||||
@cache_with_ttl(300)
|
||||
async def get_user(user_id: str):
|
||||
return await fetch_user(user_id)
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
return cache_result(prefix=func.__name__, ttl=ttl)(func)
|
||||
return decorator
|
||||
|
||||
|
||||
# 导入 asyncio 用于检查协程
|
||||
import asyncio
|
||||
198
backend/src/config/database.py
Normal file
198
backend/src/config/database.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from sqlmodel import create_engine, SQLModel, Session
|
||||
from sqlalchemy.pool import StaticPool, QueuePool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import event
|
||||
from src.config.settings import DB_PATH, DATABASE_URL
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Slow query threshold in seconds
|
||||
SLOW_QUERY_THRESHOLD = float(os.getenv("SLOW_QUERY_THRESHOLD", "1.0"))
|
||||
|
||||
# Connection pool configuration from environment variables
|
||||
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20"))
|
||||
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "10"))
|
||||
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
|
||||
POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "3600"))
|
||||
POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
|
||||
|
||||
# Create the database engine with optimized connection pooling
|
||||
if DATABASE_URL:
|
||||
# PostgreSQL Configuration with configurable connection pool
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
echo=False,
|
||||
poolclass=QueuePool,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_timeout=POOL_TIMEOUT,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=POOL_PRE_PING,
|
||||
# Additional optimizations
|
||||
connect_args={
|
||||
"connect_timeout": 10,
|
||||
"options": "-c statement_timeout=120000" # 120 second statement timeout for AI operations
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"PostgreSQL engine created with pool_size={POOL_SIZE}, "
|
||||
f"max_overflow={MAX_OVERFLOW}, pool_timeout={POOL_TIMEOUT}s, "
|
||||
f"pool_recycle={POOL_RECYCLE}s, pool_pre_ping={POOL_PRE_PING}"
|
||||
)
|
||||
else:
|
||||
# SQLite Configuration (Fallback)
|
||||
engine = create_engine(
|
||||
f"sqlite:///{DB_PATH}",
|
||||
echo=False,
|
||||
connect_args={
|
||||
"check_same_thread": False,
|
||||
"timeout": 30 # 30 second busy timeout
|
||||
},
|
||||
poolclass=StaticPool, # Use StaticPool for SQLite
|
||||
)
|
||||
logger.info(f"SQLite engine created with StaticPool at {DB_PATH}")
|
||||
|
||||
|
||||
# Add slow query logging
|
||||
@event.listens_for(engine, "before_cursor_execute")
|
||||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
""" 记录 query start time."""
|
||||
conn.info.setdefault("query_start_time", []).append(time.time())
|
||||
|
||||
|
||||
@event.listens_for(engine, "after_cursor_execute")
|
||||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
""" 日志 slow queries."""
|
||||
total_time = time.time() - conn.info["query_start_time"].pop()
|
||||
|
||||
if total_time > SLOW_QUERY_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Slow query detected (took {total_time:.2f}s): {statement[:200]}",
|
||||
extra={
|
||||
"query_time": total_time,
|
||||
"query": statement[:500],
|
||||
"parameters": str(parameters)[:200] if parameters else None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def init_db():
|
||||
""" Initialize the database tables."""
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
# Try to create all tables
|
||||
try:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
except OperationalError as e:
|
||||
if "already exists" in str(e):
|
||||
# Table or index already exists, skip creation
|
||||
logger.warning(f"Skipping table/index creation: {e}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def get_session():
|
||||
""" 依赖 for getting a database session.
|
||||
Ensures proper cleanup and connection management.
|
||||
"""
|
||||
with Session(engine) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# 会话 is automatically closed by context manager
|
||||
# This ensures connections are returned to the pool
|
||||
pass
|
||||
|
||||
|
||||
def get_pool_status() -> dict:
|
||||
""" Get current connection pool status for monitoring.
|
||||
Returns pool statistics including size, checked out connections, etc.
|
||||
"""
|
||||
pool = engine.pool
|
||||
|
||||
status = {
|
||||
"pool_size": getattr(pool, "size", lambda: 0)(),
|
||||
"checked_out": getattr(pool, "checkedout", lambda: 0)(),
|
||||
"overflow": getattr(pool, "overflow", lambda: 0)(),
|
||||
"checked_in": getattr(pool, "checkedin", lambda: 0)(),
|
||||
}
|
||||
|
||||
# Add pool-specific attributes if available
|
||||
if hasattr(pool, "_pool"):
|
||||
status["available"] = pool._pool.qsize()
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def check_database_health() -> tuple[bool, str]:
|
||||
"""
|
||||
Check database connectivity and health.
|
||||
Returns (is_healthy, message).
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
# Simple connectivity probe without model dependency.
|
||||
session.exec(text("SELECT 1"))
|
||||
return True, "Database connection healthy"
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
return False, f"Database connection failed: {str(e)}"
|
||||
|
||||
|
||||
def check_pool_health() -> dict:
|
||||
"""
|
||||
Comprehensive pool health check with recommendations.
|
||||
Returns detailed pool status and health indicators.
|
||||
"""
|
||||
status = get_pool_status()
|
||||
pool = engine.pool
|
||||
|
||||
# Calculate health metrics
|
||||
total_connections = status.get("pool_size", 0) + status.get("overflow", 0)
|
||||
checked_out = status.get("checked_out", 0)
|
||||
available = status.get("available", 0)
|
||||
|
||||
# Health checks
|
||||
health = {
|
||||
"status": "healthy",
|
||||
"checks": {},
|
||||
"recommendations": [],
|
||||
"pool_status": status
|
||||
}
|
||||
|
||||
# Check 1: Connection exhaustion
|
||||
if checked_out >= total_connections * 0.9:
|
||||
health["checks"]["exhaustion"] = "critical"
|
||||
health["status"] = "critical"
|
||||
health["recommendations"].append(
|
||||
f"Pool near exhaustion: {checked_out}/{total_connections} connections in use. "
|
||||
f"Consider increasing DB_POOL_SIZE or DB_MAX_OVERFLOW."
|
||||
)
|
||||
elif checked_out >= total_connections * 0.75:
|
||||
health["checks"]["exhaustion"] = "warning"
|
||||
health["status"] = "warning"
|
||||
health["recommendations"].append(
|
||||
f"Pool usage high: {checked_out}/{total_connections} connections in use."
|
||||
)
|
||||
else:
|
||||
health["checks"]["exhaustion"] = "healthy"
|
||||
|
||||
# Check 2: Available connections
|
||||
if available == 0 and total_connections > 0:
|
||||
health["checks"]["availability"] = "warning"
|
||||
if health["status"] == "healthy":
|
||||
health["status"] = "warning"
|
||||
else:
|
||||
health["checks"]["availability"] = "healthy"
|
||||
|
||||
# Check 3: Pool size configuration
|
||||
if total_connections < 5:
|
||||
health["recommendations"].append(
|
||||
"Pool size may be too small for production workloads. "
|
||||
"Consider DB_POOL_SIZE >= 10"
|
||||
)
|
||||
|
||||
return health
|
||||
217
backend/src/config/database_async.py
Normal file
217
backend/src/config/database_async.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
异步数据库配置模块
|
||||
|
||||
提供 SQLAlchemy 2.0 AsyncSession 支持,用于完全异步的数据库操作。
|
||||
支持 PostgreSQL (asyncpg) 和 SQLite (aiosqlite)。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
create_async_engine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
AsyncEngine,
|
||||
)
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from src.config.settings import DB_PATH, DATABASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Slow query threshold in seconds
|
||||
SLOW_QUERY_THRESHOLD = float(os.getenv("SLOW_QUERY_THRESHOLD", "1.0"))
|
||||
|
||||
# Connection pool configuration from environment variables
|
||||
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20"))
|
||||
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "10"))
|
||||
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
|
||||
POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "3600"))
|
||||
POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
|
||||
|
||||
# Global engine instance
|
||||
async_engine: AsyncEngine | None = None
|
||||
async_session_maker: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
def _get_async_url() -> str:
|
||||
"""Generate async database URL from settings."""
|
||||
if DATABASE_URL:
|
||||
# Convert PostgreSQL URL to async version
|
||||
if DATABASE_URL.startswith("postgresql://"):
|
||||
return DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
elif DATABASE_URL.startswith("postgresql+psycopg2://"):
|
||||
return DATABASE_URL.replace("postgresql+psycopg2://", "postgresql+asyncpg://", 1)
|
||||
return DATABASE_URL
|
||||
else:
|
||||
# SQLite async URL
|
||||
return f"sqlite+aiosqlite:///{DB_PATH}"
|
||||
|
||||
|
||||
def init_async_engine() -> AsyncEngine:
|
||||
"""Initialize async database engine with optimized connection pooling."""
|
||||
global async_engine
|
||||
|
||||
if async_engine is not None:
|
||||
return async_engine
|
||||
|
||||
async_url = _get_async_url()
|
||||
|
||||
if "postgresql" in async_url:
|
||||
# PostgreSQL async configuration
|
||||
async_engine = create_async_engine(
|
||||
async_url,
|
||||
echo=False,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_timeout=POOL_TIMEOUT,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=POOL_PRE_PING,
|
||||
connect_args={
|
||||
"timeout": 10,
|
||||
"command_timeout": 30,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Async PostgreSQL engine created with pool_size={POOL_SIZE}, "
|
||||
f"max_overflow={MAX_OVERFLOW}, pool_timeout={POOL_TIMEOUT}s"
|
||||
)
|
||||
else:
|
||||
# SQLite async configuration
|
||||
async_engine = create_async_engine(
|
||||
async_url,
|
||||
echo=False,
|
||||
poolclass=NullPool, # SQLite doesn't support connection pooling well
|
||||
connect_args={
|
||||
"timeout": 30,
|
||||
"check_same_thread": False,
|
||||
},
|
||||
)
|
||||
logger.info(f"Async SQLite engine created at {DB_PATH}")
|
||||
|
||||
return async_engine
|
||||
|
||||
|
||||
def init_async_session_maker() -> async_sessionmaker[AsyncSession]:
|
||||
"""Initialize async session maker."""
|
||||
global async_session_maker
|
||||
|
||||
if async_session_maker is not None:
|
||||
return async_session_maker
|
||||
|
||||
engine = init_async_engine()
|
||||
|
||||
async_session_maker = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
return async_session_maker
|
||||
|
||||
|
||||
async def init_db_async():
|
||||
"""Initialize database tables asynchronously."""
|
||||
engine = init_async_engine()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
logger.info("Database tables initialized (async)")
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Dependency for getting async database session.
|
||||
|
||||
Usage:
|
||||
@router.get("/items")
|
||||
async def get_items(session: AsyncSession = Depends(get_async_session)):
|
||||
...
|
||||
"""
|
||||
session_maker = init_async_session_maker()
|
||||
|
||||
async with session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def get_async_session_context() -> AsyncSession:
|
||||
"""
|
||||
Get async session as context manager.
|
||||
|
||||
Usage:
|
||||
async with get_async_session_context() as session:
|
||||
result = await session.execute(...)
|
||||
"""
|
||||
session_maker = init_async_session_maker()
|
||||
return session_maker()
|
||||
|
||||
|
||||
async def get_pool_status_async() -> dict:
|
||||
"""
|
||||
Get current connection pool status for monitoring.
|
||||
Returns pool statistics including size, checked out connections, etc.
|
||||
"""
|
||||
engine = init_async_engine()
|
||||
pool = engine.pool
|
||||
|
||||
status = {
|
||||
"pool_size": getattr(pool, "size", lambda: 0)(),
|
||||
"checked_out": getattr(pool, "checkedout", lambda: 0)(),
|
||||
"overflow": getattr(pool, "overflow", lambda: 0)(),
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
|
||||
async def check_database_health_async() -> tuple[bool, str]:
|
||||
"""
|
||||
Check database connectivity and health asynchronously.
|
||||
Returns (is_healthy, message).
|
||||
"""
|
||||
try:
|
||||
engine = init_async_engine()
|
||||
async with engine.connect() as conn:
|
||||
start_time = time.time()
|
||||
await conn.execute(text("SELECT 1"))
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
return True, f"Database connection healthy ({elapsed:.3f}s)"
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
return False, f"Database connection failed: {str(e)}"
|
||||
|
||||
|
||||
async def close_async_engine():
|
||||
"""Close async engine and cleanup connections."""
|
||||
global async_engine, async_session_maker
|
||||
|
||||
if async_engine is not None:
|
||||
await async_engine.dispose()
|
||||
async_engine = None
|
||||
async_session_maker = None
|
||||
logger.info("Async database engine closed")
|
||||
|
||||
|
||||
# Import guards for sync operations
|
||||
async def migrate_sync_to_async():
|
||||
"""
|
||||
Migration helper to ensure both sync and async engines are initialized.
|
||||
During transition period, both can coexist.
|
||||
"""
|
||||
# Initialize async engine
|
||||
init_async_engine()
|
||||
init_async_session_maker()
|
||||
|
||||
# Sync engine is already initialized in database.py
|
||||
logger.info("Both sync and async database engines are ready")
|
||||
465
backend/src/config/generation_options.json
Normal file
465
backend/src/config/generation_options.json
Normal file
@@ -0,0 +1,465 @@
|
||||
{
|
||||
"image": {
|
||||
"aspectRatios": {
|
||||
"label": "比例",
|
||||
"options": [
|
||||
{ "value": "1:1", "label": "1:1 (正方形)" },
|
||||
{ "value": "3:4", "label": "3:4 (竖版)" },
|
||||
{ "value": "4:3", "label": "4:3 (横版)" },
|
||||
{ "value": "9:16", "label": "9:16 (手机竖版)" },
|
||||
{ "value": "16:9", "label": "16:9 (宽屏)" }
|
||||
],
|
||||
"default": "16:9"
|
||||
},
|
||||
"resolutions": {
|
||||
"label": "分辨率",
|
||||
"options": [
|
||||
{ "value": "1K", "label": "1K (标准)", "description": "1280x720 或 1024x1024" },
|
||||
{ "value": "2K", "label": "2K (高清)", "description": "2560x1440 或 2048x2048" },
|
||||
{ "value": "4K", "label": "4K (超高清)", "description": "3840x2160 或 4096x4096" }
|
||||
],
|
||||
"default": "1K"
|
||||
},
|
||||
"counts": {
|
||||
"label": "生成数量",
|
||||
"options": [
|
||||
{ "value": 1, "label": "1张" },
|
||||
{ "value": 2, "label": "2张" },
|
||||
{ "value": 3, "label": "3张" },
|
||||
{ "value": 4, "label": "4张" }
|
||||
],
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"templates": {
|
||||
"label": "模板类型",
|
||||
"options": [
|
||||
{ "value": "general", "label": "通用", "description": "通用图片生成" },
|
||||
{ "value": "character_white_bg", "label": "角色白底图", "description": "角色图片,白色背景" },
|
||||
{ "value": "character_three_view", "label": "角色三视图", "description": "角色正面、侧面、背面视图" },
|
||||
{ "value": "storyboard_integrated", "label": "分镜出图", "description": "基于分镜生成图片" }
|
||||
],
|
||||
"default": "general"
|
||||
}
|
||||
},
|
||||
"video": {
|
||||
"aspectRatios": {
|
||||
"label": "比例",
|
||||
"options": [
|
||||
{ "value": "1:1", "label": "1:1 (正方形)" },
|
||||
{ "value": "3:4", "label": "3:4 (竖版)" },
|
||||
{ "value": "4:3", "label": "4:3 (横版)" },
|
||||
{ "value": "9:16", "label": "9:16 (手机竖版)" },
|
||||
{ "value": "16:9", "label": "16:9 (宽屏)" }
|
||||
],
|
||||
"default": "16:9"
|
||||
},
|
||||
"resolutions": {
|
||||
"label": "分辨率",
|
||||
"options": [
|
||||
{ "value": "720p", "label": "720P (1280x720)", "width": 1280, "height": 720 },
|
||||
{ "value": "1080p", "label": "1080P (1920x1080)", "width": 1920, "height": 1080 }
|
||||
],
|
||||
"default": "1080p"
|
||||
},
|
||||
"durations": {
|
||||
"label": "时长",
|
||||
"options": [
|
||||
{ "value": "2s", "label": "2秒" },
|
||||
{ "value": "3s", "label": "3秒" },
|
||||
{ "value": "4s", "label": "4秒" },
|
||||
{ "value": "5s", "label": "5秒" },
|
||||
{ "value": "6s", "label": "6秒" },
|
||||
{ "value": "7s", "label": "7秒" },
|
||||
{ "value": "8s", "label": "8秒" },
|
||||
{ "value": "10s", "label": "10秒" }
|
||||
],
|
||||
"default": "5s",
|
||||
"description": "视频时长(秒)"
|
||||
},
|
||||
"counts": {
|
||||
"label": "生成数量",
|
||||
"options": [
|
||||
{ "value": 1, "label": "1个" },
|
||||
{ "value": 2, "label": "2个" },
|
||||
{ "value": 3, "label": "3个" },
|
||||
{ "value": 4, "label": "4个" }
|
||||
],
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 4
|
||||
}
|
||||
},
|
||||
"audio": {
|
||||
"durations": {
|
||||
"label": "时长",
|
||||
"options": [
|
||||
{ "value": "5s", "label": "5秒" },
|
||||
{ "value": "10s", "label": "10秒" },
|
||||
{ "value": "30s", "label": "30秒" },
|
||||
{ "value": "60s", "label": "1分钟" }
|
||||
],
|
||||
"default": "10s"
|
||||
},
|
||||
"formats": {
|
||||
"label": "音频格式",
|
||||
"options": [
|
||||
{ "value": "mp3", "label": "MP3" },
|
||||
{ "value": "wav", "label": "WAV" }
|
||||
],
|
||||
"default": "mp3"
|
||||
}
|
||||
},
|
||||
"music": {
|
||||
"formats": {
|
||||
"label": "音频格式",
|
||||
"options": [
|
||||
{ "value": "mp3", "label": "MP3" },
|
||||
{ "value": "wav", "label": "WAV" },
|
||||
{ "value": "pcm", "label": "PCM" }
|
||||
],
|
||||
"default": "mp3"
|
||||
},
|
||||
"sampleRates": {
|
||||
"label": "采样率",
|
||||
"options": [
|
||||
{ "value": 16000, "label": "16kHz" },
|
||||
{ "value": 24000, "label": "24kHz" },
|
||||
{ "value": 32000, "label": "32kHz" },
|
||||
{ "value": 44100, "label": "44.1kHz" }
|
||||
],
|
||||
"default": 44100
|
||||
},
|
||||
"bitrates": {
|
||||
"label": "码率",
|
||||
"options": [
|
||||
{ "value": 32000, "label": "32kbps" },
|
||||
{ "value": 64000, "label": "64kbps" },
|
||||
{ "value": 128000, "label": "128kbps" },
|
||||
{ "value": 256000, "label": "256kbps" }
|
||||
],
|
||||
"default": 256000
|
||||
},
|
||||
"outputFormats": {
|
||||
"label": "返回格式",
|
||||
"options": [
|
||||
{ "value": "url", "label": "URL (24小时有效)" },
|
||||
{ "value": "hex", "label": "HEX" }
|
||||
],
|
||||
"default": "url"
|
||||
}
|
||||
},
|
||||
"script": {
|
||||
"movieTones": {
|
||||
"label": "整体基调",
|
||||
"options": [
|
||||
{ "value": "悬疑/惊悚", "label": "悬疑/惊悚 (Suspense/Thriller)", "en": "Suspense/Thriller" },
|
||||
{ "value": "古装/权谋", "label": "古装/权谋 (Historical/Political)", "en": "Historical/Political" },
|
||||
{ "value": "现代/都市", "label": "现代/都市 (Modern/Urban)", "en": "Modern/Urban" },
|
||||
{ "value": "科幻/未来", "label": "科幻/未来 (Sci-Fi/Future)", "en": "Sci-Fi/Future" },
|
||||
{ "value": "喜剧/荒诞", "label": "喜剧/荒诞 (Comedy/Absurd)", "en": "Comedy/Absurd" },
|
||||
{ "value": "动作/犯罪", "label": "动作/犯罪 (Action/Crime)", "en": "Action/Crime" },
|
||||
{ "value": "爱情/治愈", "label": "爱情/治愈 (Romance/Healing)", "en": "Romance/Healing" },
|
||||
{ "value": "奇幻/仙侠", "label": "奇幻/仙侠 (Fantasy/Xianxia)", "en": "Fantasy/Xianxia" },
|
||||
{ "value": "现实/人文", "label": "现实/人文 (Realistic/Humanistic)", "en": "Realistic/Humanistic" }
|
||||
],
|
||||
"default": "现代/都市"
|
||||
},
|
||||
"targetAudiences": {
|
||||
"label": "目标受众",
|
||||
"options": [
|
||||
{ "value": "全年龄段", "label": "全年龄段 (All Ages)", "en": "All Ages" },
|
||||
{ "value": "青少年", "label": "青少年 (Teenagers)", "en": "Teenagers" },
|
||||
{ "value": "年轻女性", "label": "年轻女性 (Young Women)", "en": "Young Women" },
|
||||
{ "value": "年轻男性", "label": "年轻男性 (Young Men)", "en": "Young Men" },
|
||||
{ "value": "成年观众", "label": "成年观众 (Adults)", "en": "Adults" },
|
||||
{ "value": "家庭观众", "label": "家庭观众 (Family)", "en": "Family" },
|
||||
{ "value": "资深影迷", "label": "资深影迷 (Cinephiles)", "en": "Cinephiles" }
|
||||
],
|
||||
"default": "全年龄段"
|
||||
}
|
||||
},
|
||||
"director": {
|
||||
"narrativeStyles": {
|
||||
"label": "叙事手法",
|
||||
"options": [
|
||||
{ "value": "线性叙事", "label": "线性叙事 (Linear)", "en": "Linear Narrative" },
|
||||
{ "value": "非线性/插叙", "label": "非线性/插叙 (Non-linear)", "en": "Non-linear Narrative" },
|
||||
{ "value": "多线并行", "label": "多线并行 (Multi-line)", "en": "Multi-line Narrative" },
|
||||
{ "value": "倒叙", "label": "倒叙 (Flashback)", "en": "Flashback" },
|
||||
{ "value": "意识流", "label": "意识流 (Stream of Consciousness)", "en": "Stream of Consciousness" },
|
||||
{ "value": "伪纪录片", "label": "伪纪录片 (Mockumentary)", "en": "Mockumentary" }
|
||||
],
|
||||
"default": "线性叙事"
|
||||
},
|
||||
"editingPaces": {
|
||||
"label": "剪辑节奏",
|
||||
"options": [
|
||||
{ "value": "缓慢/沉浸", "label": "缓慢/沉浸 (Slow/Immersive)", "en": "Slow/Immersive" },
|
||||
{ "value": "明快/流畅", "label": "明快/流畅 (Brisk/Fluid)", "en": "Brisk/Fluid" },
|
||||
{ "value": "快速/凌厉", "label": "快速/凌厉 (Fast/Sharp)", "en": "Fast/Sharp" },
|
||||
{ "value": "极速/碎片化", "label": "极速/碎片化 (Rapid/Fragmented)", "en": "Rapid/Fragmented" },
|
||||
{ "value": "舒缓/诗意", "label": "舒缓/诗意 (Soothing/Poetic)", "en": "Soothing/Poetic" }
|
||||
],
|
||||
"default": "明快/流畅"
|
||||
},
|
||||
"styles": {
|
||||
"label": "导演风格",
|
||||
"options": [
|
||||
{ "value": "孔笙 (Kong Sheng)", "label": "孔笙 (厚重/写实/山海情)", "en": "Kong Sheng Style" },
|
||||
{ "value": "郑晓龙 (Zheng Xiaolong)", "label": "郑晓龙 (宫廷/传奇/甄嬛传)", "en": "Zheng Xiaolong Style" },
|
||||
{ "value": "张黎 (Zhang Li)", "label": "张黎 (历史/权谋/大明王朝)", "en": "Zhang Li Style" },
|
||||
{ "value": "辛爽 (Xin Shuang)", "label": "辛爽 (悬疑/美学/漫长的季节)", "en": "Xin Shuang Style" },
|
||||
{ "value": "王家卫 (Wong Kar-wai)", "label": "王家卫 (繁花/光影/霓虹)", "en": "Wong Kar-wai Style" },
|
||||
{ "value": "李路 (Li Lu)", "label": "李路 (史诗/人世间/人民的名义)", "en": "Li Lu Style" },
|
||||
{ "value": "曹盾 (Cao Dun)", "label": "曹盾 (视觉/长镜头/长安十二时辰)", "en": "Cao Dun Style" },
|
||||
{ "value": "王伟 (Wang Wei)", "label": "王伟 (硬汉/刑侦/白夜追凶)", "en": "Wang Wei Style" },
|
||||
{ "value": "汪俊 (Wang Jun)", "label": "汪俊 (都市/细腻/小欢喜)", "en": "Wang Jun Style" },
|
||||
{ "value": "徐纪周 (Xu Jizhou)", "label": "徐纪周 (群像/狂飙/快节奏)", "en": "Xu Jizhou Style" },
|
||||
{ "value": "吕行 (Lu Xing)", "label": "吕行 (犯罪/人性/无证之罪)", "en": "Lu Xing Style" },
|
||||
{ "value": "丁黑 (Ding Hei)", "label": "丁黑 (警察荣誉/那年花开)", "en": "Ding Hei Style" }
|
||||
],
|
||||
"default": "孔笙 (Kong Sheng)"
|
||||
}
|
||||
},
|
||||
"character": {
|
||||
"genders": {
|
||||
"label": "性别",
|
||||
"options": [
|
||||
{ "value": "男", "label": "男" },
|
||||
{ "value": "女", "label": "女" },
|
||||
{ "value": "未知", "label": "未知" }
|
||||
],
|
||||
"default": "未知"
|
||||
},
|
||||
"costumeStyles": {
|
||||
"label": "服装风格",
|
||||
"options": [
|
||||
{ "value": "现代日常", "label": "现代日常 (Modern Daily)", "en": "Modern Daily" },
|
||||
{ "value": "古装汉服", "label": "古装汉服 (Ancient Hanfu)", "en": "Ancient Hanfu" },
|
||||
{ "value": "民国风情", "label": "民国风情 (Republic Era)", "en": "Republic Era" },
|
||||
{ "value": "赛博科幻", "label": "赛博科幻 (Cyberpunk Sci-Fi)", "en": "Cyberpunk Sci-Fi" },
|
||||
{ "value": "职业制服", "label": "职业制服 (Professional Uniform)", "en": "Professional Uniform" },
|
||||
{ "value": "街头潮流", "label": "街头潮流 (Streetwear)", "en": "Streetwear" },
|
||||
{ "value": "极简森系", "label": "极简森系 (Minimalist Mori)", "en": "Minimalist Mori" },
|
||||
{ "value": "奢华礼服", "label": "奢华礼服 (Luxury/Formal)", "en": "Luxury/Formal" }
|
||||
],
|
||||
"default": "现代日常"
|
||||
},
|
||||
"roles": {
|
||||
"label": "角色定位",
|
||||
"options": [
|
||||
{ "value": "主角", "label": "主角" },
|
||||
{ "value": "配角", "label": "配角" },
|
||||
{ "value": "反派", "label": "反派" },
|
||||
{ "value": "龙套", "label": "龙套" },
|
||||
{ "value": "群演", "label": "群演" }
|
||||
],
|
||||
"default": "配角"
|
||||
},
|
||||
"emotions": {
|
||||
"label": "情绪基调",
|
||||
"options": [
|
||||
{ "value": "平静", "label": "平静 (Neutral)", "en": "Neutral/Calm" },
|
||||
{ "value": "喜悦", "label": "喜悦 (Happy)", "en": "Happy/Joyful" },
|
||||
{ "value": "悲伤", "label": "悲伤 (Sad)", "en": "Sad/Sorrowful" },
|
||||
{ "value": "愤怒", "label": "愤怒 (Angry)", "en": "Angry/Furious" },
|
||||
{ "value": "恐惧", "label": "恐惧 (Fearful)", "en": "Fearful/Scared" },
|
||||
{ "value": "惊讶", "label": "惊讶 (Surprised)", "en": "Surprised/Shocked" },
|
||||
{ "value": "自信", "label": "自信 (Confident)", "en": "Confident/Bold" },
|
||||
{ "value": "思索", "label": "思索 (Thinking)", "en": "Thinking/Pensive" }
|
||||
],
|
||||
"default": "平静"
|
||||
}
|
||||
},
|
||||
"storyboard": {
|
||||
"shotTypes": {
|
||||
"label": "镜头类型",
|
||||
"options": [
|
||||
{ "value": "大远景 (ELS)", "label": "大远景 (ELS)", "en": "Extreme Long Shot (ELS)" },
|
||||
{ "value": "远景 (LS)", "label": "远景 (LS)", "en": "Long Shot (LS)" },
|
||||
{ "value": "全景 (FS)", "label": "全景 (FS)", "en": "Full Shot (FS)" },
|
||||
{ "value": "中远景 (MLS)", "label": "中远景 (MLS)", "en": "Medium Long Shot (MLS)" },
|
||||
{ "value": "中景 (MS)", "label": "中景 (MS)", "en": "Medium Shot (MS)" },
|
||||
{ "value": "中特写 (MCU)", "label": "中特写 (MCU)", "en": "Medium Close-Up (MCU)" },
|
||||
{ "value": "特写 (CU)", "label": "特写 (CU)", "en": "Close-Up (CU)" },
|
||||
{ "value": "大特写 (ECU)", "label": "大特写 (ECU)", "en": "Extreme Close-Up (ECU)" },
|
||||
{ "value": "建立镜头", "label": "建立镜头", "en": "Establishing Shot" },
|
||||
{ "value": "主观镜头 (POV)", "label": "主观镜头 (POV)", "en": "Point of View (POV)" },
|
||||
{ "value": "过肩镜头 (OTS)", "label": "过肩镜头 (OTS)", "en": "Over the Shoulder (OTS)" }
|
||||
],
|
||||
"default": "中景 (MS)"
|
||||
},
|
||||
"cameraMovements": {
|
||||
"label": "运镜方式",
|
||||
"options": [
|
||||
{ "value": "固定镜头 (Static)", "label": "固定镜头 (Static)", "en": "Static" },
|
||||
{ "value": "左摇 (Pan Left)", "label": "左摇 (Pan Left)", "en": "Pan Left" },
|
||||
{ "value": "右摇 (Pan Right)", "label": "右摇 (Pan Right)", "en": "Pan Right" },
|
||||
{ "value": "上仰 (Tilt Up)", "label": "上仰 (Tilt Up)", "en": "Tilt Up" },
|
||||
{ "value": "下俯 (Tilt Down)", "label": "下俯 (Tilt Down)", "en": "Tilt Down" },
|
||||
{ "value": "推镜头 (Zoom In)", "label": "推镜头 (Zoom In)", "en": "Zoom In" },
|
||||
{ "value": "拉镜头 (Zoom Out)", "label": "拉镜头 (Zoom Out)", "en": "Zoom Out" },
|
||||
{ "value": "前移 (Dolly In)", "label": "前移 (Dolly In)", "en": "Dolly In" },
|
||||
{ "value": "后移 (Dolly Out)", "label": "后移 (Dolly Out)", "en": "Dolly Out" },
|
||||
{ "value": "跟随 (Tracking)", "label": "跟随 (Tracking)", "en": "Tracking" },
|
||||
{ "value": "环绕 (Arc)", "label": "环绕 (Arc)", "en": "Arc" },
|
||||
{ "value": "手持 (Handheld)", "label": "手持 (Handheld)", "en": "Handheld" }
|
||||
],
|
||||
"default": "固定镜头 (Static)"
|
||||
},
|
||||
"transitions": {
|
||||
"label": "转场效果",
|
||||
"options": [
|
||||
{ "value": "切 (Cut)", "label": "切 (Cut)", "en": "Cut" },
|
||||
{ "value": "叠化 (Dissolve)", "label": "叠化 (Dissolve)", "en": "Dissolve" },
|
||||
{ "value": "淡入 (Fade In)", "label": "淡入 (Fade In)", "en": "Fade In" },
|
||||
{ "value": "淡出 (Fade Out)", "label": "淡出 (Fade Out)", "en": "Fade Out" },
|
||||
{ "value": "划像 (Wipe)", "label": "划像 (Wipe)", "en": "Wipe" },
|
||||
{ "value": "圈入 (Iris In)", "label": "圈入 (Iris In)", "en": "Iris In" },
|
||||
{ "value": "圈出 (Iris Out)", "label": "圈出 (Iris Out)", "en": "Iris Out" },
|
||||
{ "value": "匹配剪辑 (Match Cut)", "label": "匹配剪辑 (Match Cut)", "en": "Match Cut" },
|
||||
{ "value": "跳接 (Jump Cut)", "label": "跳接 (Jump Cut)", "en": "Jump Cut" }
|
||||
],
|
||||
"default": "切 (Cut)"
|
||||
},
|
||||
"compositions": {
|
||||
"label": "构图方式",
|
||||
"options": [
|
||||
{ "value": "三分法 (Rule of Thirds)", "label": "三分法 (Rule of Thirds)", "en": "Rule of Thirds" },
|
||||
{ "value": "中心构图 (Center Framed)", "label": "中心构图 (Center Framed)", "en": "Center Framed" },
|
||||
{ "value": "对称构图 (Symmetrical)", "label": "对称构图 (Symmetrical)", "en": "Symmetrical" },
|
||||
{ "value": "引导线 (Leading Lines)", "label": "引导线 (Leading Lines)", "en": "Leading Lines" },
|
||||
{ "value": "对角线 (Diagonal)", "label": "对角线 (Diagonal)", "en": "Diagonal" },
|
||||
{ "value": "框架构图 (Framing)", "label": "框架构图 (Framing)", "en": "Framing" },
|
||||
{ "value": "极简留白 (Minimalist/Negative Space)", "label": "极简留白 (Minimalist/Negative Space)", "en": "Minimalist/Negative Space" },
|
||||
{ "value": "黄金螺旋 (Golden Spiral)", "label": "黄金螺旋 (Golden Spiral)", "en": "Golden Spiral" }
|
||||
],
|
||||
"default": "三分法 (Rule of Thirds)"
|
||||
}
|
||||
},
|
||||
"scene": {
|
||||
"timesOfDay": {
|
||||
"label": "拍摄时段",
|
||||
"options": [
|
||||
{ "value": "清晨", "label": "清晨 (Dawn)", "en": "Dawn" },
|
||||
{ "value": "早晨", "label": "早晨 (Morning)", "en": "Morning" },
|
||||
{ "value": "正午", "label": "正午 (Noon)", "en": "Noon" },
|
||||
{ "value": "下午", "label": "下午 (Afternoon)", "en": "Afternoon" },
|
||||
{ "value": "黄金时刻", "label": "黄金时刻 (Golden Hour)", "en": "Golden Hour" },
|
||||
{ "value": "傍晚", "label": "傍晚 (Dusk)", "en": "Dusk" },
|
||||
{ "value": "蓝调时刻", "label": "蓝调时刻 (Blue Hour)", "en": "Blue Hour" },
|
||||
{ "value": "夜晚", "label": "夜晚 (Night)", "en": "Night" },
|
||||
{ "value": "深夜", "label": "深夜 (Late Night)", "en": "Late Night" }
|
||||
],
|
||||
"default": "早晨"
|
||||
},
|
||||
"environmentTypes": {
|
||||
"label": "空间类型",
|
||||
"options": [
|
||||
{ "value": "室内", "label": "室内 (Interior)", "en": "Interior" },
|
||||
{ "value": "室外", "label": "室外 (Exterior)", "en": "Exterior" },
|
||||
{ "value": "半室外", "label": "半室外 (Semi-Exterior)", "en": "Semi-Exterior" }
|
||||
],
|
||||
"default": "室内"
|
||||
},
|
||||
"weather": {
|
||||
"label": "天气环境",
|
||||
"options": [
|
||||
{ "value": "晴朗", "label": "晴朗 (Clear)", "en": "Clear/Sunny" },
|
||||
{ "value": "多云", "label": "多云 (Cloudy)", "en": "Partly Cloudy" },
|
||||
{ "value": "阴天", "label": "阴天 (Overcast)", "en": "Overcast" },
|
||||
{ "value": "小雨", "label": "小雨 (Rainy)", "en": "Light Rain" },
|
||||
{ "value": "大雨", "label": "大雨 (Heavy Rain)", "en": "Heavy Rain" },
|
||||
{ "value": "暴风雨", "label": "暴风雨 (Storm)", "en": "Storm" },
|
||||
{ "value": "小雪", "label": "小雪 (Snowy)", "en": "Light Snow" },
|
||||
{ "value": "大雪", "label": "大雪 (Heavy Snow)", "en": "Heavy Snow" },
|
||||
{ "value": "大雾", "label": "大雾 (Foggy)", "en": "Dense Fog" },
|
||||
{ "value": "沙尘", "label": "沙尘 (Dusty)", "en": "Dusty/Sandstorm" }
|
||||
],
|
||||
"default": "晴朗"
|
||||
}
|
||||
},
|
||||
"cinematic": {
|
||||
"visualStyles": {
|
||||
"label": "视觉画风",
|
||||
"options": [
|
||||
{ "value": "现实主义/纪录片感", "label": "现实主义/纪录片感 (Realistic/Documentary)", "en": "Realistic/Documentary" },
|
||||
{ "value": "电影质感/胶片风", "label": "电影质感/胶片风 (Cinematic/Film)", "en": "Cinematic/Film" },
|
||||
{ "value": "赛博朋克/霓虹", "label": "赛博朋克/霓虹 (Cyberpunk/Neon)", "en": "Cyberpunk/Neon" },
|
||||
{ "value": "古风/水墨", "label": "古风/水墨 (Ancient/Ink Wash)", "en": "Ancient/Ink Wash" },
|
||||
{ "value": "极简主义/冷淡", "label": "极简主义/冷淡 (Minimalist/Cold)", "en": "Minimalist/Cold" },
|
||||
{ "value": "唯美/梦幻", "label": "唯美/梦幻 (Aesthetic/Dreamy)", "en": "Aesthetic/Dreamy" },
|
||||
{ "value": "暗黑/哥特", "label": "暗黑/哥特 (Dark/Gothic)", "en": "Dark/Gothic" },
|
||||
{ "value": "动画/二次元", "label": "动画/二次元 (Anime/2D)", "en": "Anime/2D" }
|
||||
],
|
||||
"default": "电影质感/胶片风"
|
||||
},
|
||||
"cameraAngles": {
|
||||
"label": "镜头角度",
|
||||
"options": [
|
||||
{ "value": "平视", "label": "平视 (Eye Level)", "en": "Eye Level" },
|
||||
{ "value": "俯拍", "label": "俯拍 (High Angle)", "en": "High Angle" },
|
||||
{ "value": "仰拍", "label": "仰拍 (Low Angle)", "en": "Low Angle" },
|
||||
{ "value": "顶拍", "label": "顶拍 (Top Down)", "en": "Top Down/Bird's Eye" },
|
||||
{ "value": "底拍", "label": "底拍 (Worm's Eye)", "en": "Worm's Eye View" },
|
||||
{ "value": "斜角镜头", "label": "斜角镜头 (Dutch Angle)", "en": "Dutch Angle" }
|
||||
],
|
||||
"default": "平视"
|
||||
},
|
||||
"lighting": {
|
||||
"label": "灯光风格",
|
||||
"options": [
|
||||
{ "value": "电影感光效", "label": "电影感 (Cinematic)", "en": "Cinematic Lighting" },
|
||||
{ "value": "自然光", "label": "自然光 (Natural)", "en": "Natural Lighting" },
|
||||
{ "value": "柔光", "label": "柔光 (Soft)", "en": "Soft Lighting" },
|
||||
{ "value": "强反差/明暗对照", "label": "强反差 (High Contrast)", "en": "Chiaroscuro/High Contrast" },
|
||||
{ "value": "轮廓光", "label": "轮廓光 (Rim Lighting)", "en": "Rim Lighting" },
|
||||
{ "value": "三点式亮光", "label": "三点式亮光 (Three-point Lighting)", "en": "Three-point Lighting" },
|
||||
{ "value": "工作室光效", "label": "工作室 (Studio)", "en": "Studio Lighting" }
|
||||
],
|
||||
"default": "电影感光效"
|
||||
},
|
||||
"colorStyle": {
|
||||
"label": "色调氛围",
|
||||
"options": [
|
||||
{ "value": "电影感", "label": "电影感 (Cinematic)", "en": "Cinematic" },
|
||||
{ "value": "暖色调", "label": "暖色调 (Warm Tones)", "en": "Warm Tones" },
|
||||
{ "value": "冷色调", "label": "冷色调 (Cold Tones)", "en": "Cold Tones" },
|
||||
{ "value": "黑白", "label": "黑白 (Black and White)", "en": "Black and White" },
|
||||
{ "value": "复古/胶片感", "label": "复古/胶片 (Vintage)", "en": "Vintage Film" },
|
||||
{ "value": "赛博朋克", "label": "赛博朋克 (Cyberpunk)", "en": "Cyberpunk" },
|
||||
{ "value": "高饱和", "label": "高饱和 (Vibrant)", "en": "Vibrant" },
|
||||
{ "value": "低饱和/灰色调", "label": "低饱和 (Muted)", "en": "Muted/Desaturated" }
|
||||
],
|
||||
"default": "电影感"
|
||||
},
|
||||
"lenses": {
|
||||
"label": "镜头焦距",
|
||||
"options": [
|
||||
{ "value": "广角镜头", "label": "超广角 (14-24mm)", "en": "Ultra Wide Lens" },
|
||||
{ "value": "标准广角", "label": "广角 (35mm)", "en": "Wide Angle Lens" },
|
||||
{ "value": "标准镜头", "label": "标准 (50mm)", "en": "Standard/Nifty Fifty" },
|
||||
{ "value": "人像镜头", "label": "长焦 (85mm)", "en": "Portrait/Short Telephoto" },
|
||||
{ "value": "远摄镜头", "label": "超长焦 (200mm+)", "en": "Telephoto Lens" },
|
||||
{ "value": "鱼眼镜头", "label": "鱼眼 (Fisheye)", "en": "Fisheye Lens" },
|
||||
{ "value": "变焦镜头", "label": "变焦 (Anamorphic)", "en": "Anamorphic Lens" }
|
||||
],
|
||||
"default": "标准镜头"
|
||||
},
|
||||
"focus": {
|
||||
"label": "焦点控制",
|
||||
"options": [
|
||||
{ "value": "自动对焦", "label": "自动对焦 (Auto)", "en": "Auto Focus" },
|
||||
{ "value": "浅景深/虚化", "label": "浅景深 (Shallow Bokeh)", "en": "Shallow Depth of Field" },
|
||||
{ "value": "大深景", "label": "大深景 (Deep Focus)", "en": "Deep Depth of Field" },
|
||||
{ "value": "焦点转移", "label": "变焦对焦 (Rack Focus)", "en": "Rack Focus" },
|
||||
{ "value": "微距焦点", "label": "微距 (Macro)", "en": "Macro Focus" },
|
||||
{ "value": "移轴效果", "label": "移轴 (Tilt-Shift)", "en": "Tilt-Shift" }
|
||||
],
|
||||
"default": "自动对焦"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
16
backend/src/config/script_agent_config.json
Normal file
16
backend/src/config/script_agent_config.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"roles": {
|
||||
"story_architect": "moonshot-v1-8k",
|
||||
"character_consultant": "moonshot-v1-8k",
|
||||
"scriptwriter": "moonshot-v1-8k",
|
||||
"director": "moonshot-v1-8k",
|
||||
"auditor": "moonshot-v1-8k",
|
||||
"moderator": "moonshot-v1-8k",
|
||||
"psychologist": "moonshot-v1-8k",
|
||||
"visualizer": "moonshot-v1-8k",
|
||||
"continuity_manager": "moonshot-v1-8k",
|
||||
"showrunner": "moonshot-v1-8k",
|
||||
"chief_editor": "moonshot-v1-8k",
|
||||
"specialist": "moonshot-v1-8k"
|
||||
}
|
||||
}
|
||||
5
backend/src/config/services/alibaba/provider.json
Normal file
5
backend/src/config/services/alibaba/provider.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"id": "aliyun",
|
||||
"name": "阿里云",
|
||||
"description": "阿里云提供的包括图像超分、视频超分等服务"
|
||||
}
|
||||
8
backend/src/config/services/alibaba/upscale.json
Normal file
8
backend/src/config/services/alibaba/upscale.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"videoenhan": {
|
||||
"name": "阿里巴巴图像超分",
|
||||
"class": "backend.src.services.post_process.super_resolution.SuperResolutionService",
|
||||
"args": [],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/anthropic/provider.json
Normal file
16
backend/src/config/services/anthropic/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "anthropic",
|
||||
"name": "Anthropic",
|
||||
"description": "Claude 系列模型",
|
||||
"dashboard_url": "https://console.anthropic.com/settings/keys",
|
||||
"helpUrl": "https://console.anthropic.com/settings/keys",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-ant-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
488
backend/src/config/services/dashscope/audio.json
Normal file
488
backend/src/config/services/dashscope/audio.json
Normal file
@@ -0,0 +1,488 @@
|
||||
{
|
||||
"cosyvoice-v3-plus": {
|
||||
"name": "CosyVoice-V3-Plus",
|
||||
"class": "backend.src.services.provider.dashscope.audio.DashScopeAudioService",
|
||||
"args": ["cosyvoice-v3-plus"],
|
||||
"voices": [
|
||||
{
|
||||
"id": "longanyang",
|
||||
"name": "龙安洋",
|
||||
"gender": "male",
|
||||
"desc": "阳光大男孩"
|
||||
},
|
||||
{
|
||||
"id": "longanhuan",
|
||||
"name": "龙安欢",
|
||||
"gender": "female",
|
||||
"desc": "欢脱元气女"
|
||||
},
|
||||
{
|
||||
"id": "longhuhu_v3",
|
||||
"name": "龙呼呼",
|
||||
"gender": "female",
|
||||
"desc": "天真烂漫女童"
|
||||
},
|
||||
{
|
||||
"id": "longpaopao_v3",
|
||||
"name": "龙泡泡",
|
||||
"gender": "female",
|
||||
"desc": "飞天泡泡音"
|
||||
},
|
||||
{
|
||||
"id": "longjielidou_v3",
|
||||
"name": "龙杰力豆",
|
||||
"gender": "male",
|
||||
"desc": "阳光顽皮男童"
|
||||
},
|
||||
{
|
||||
"id": "longxian_v3",
|
||||
"name": "龙仙",
|
||||
"gender": "female",
|
||||
"desc": "豪放可爱女童"
|
||||
},
|
||||
{
|
||||
"id": "longling_v3",
|
||||
"name": "龙铃",
|
||||
"gender": "female",
|
||||
"desc": "稚气呆板女童"
|
||||
},
|
||||
{
|
||||
"id": "longshanshan_v3",
|
||||
"name": "龙闪闪",
|
||||
"gender": "female",
|
||||
"desc": "戏剧化童声"
|
||||
},
|
||||
{
|
||||
"id": "longniuniu_v3",
|
||||
"name": "龙牛牛",
|
||||
"gender": "male",
|
||||
"desc": "阳光男童声"
|
||||
},
|
||||
{
|
||||
"id": "longjiaxin_v3",
|
||||
"name": "龙嘉欣",
|
||||
"gender": "female",
|
||||
"desc": "优雅粤语女"
|
||||
},
|
||||
{
|
||||
"id": "longjiayi_v3",
|
||||
"name": "龙嘉怡",
|
||||
"gender": "female",
|
||||
"desc": "知性粤语女"
|
||||
},
|
||||
{
|
||||
"id": "longanyue_v3",
|
||||
"name": "龙安粤",
|
||||
"gender": "male",
|
||||
"desc": "欢脱粤语男"
|
||||
},
|
||||
{
|
||||
"id": "longlaotie_v3",
|
||||
"name": "龙老铁",
|
||||
"gender": "male",
|
||||
"desc": "东北直率男"
|
||||
},
|
||||
{
|
||||
"id": "longshange_v3",
|
||||
"name": "龙陕哥",
|
||||
"gender": "male",
|
||||
"desc": "原味陕北男"
|
||||
},
|
||||
{
|
||||
"id": "longanmin_v3",
|
||||
"name": "龙安闽",
|
||||
"gender": "female",
|
||||
"desc": "清纯闽南女"
|
||||
},
|
||||
{
|
||||
"id": "loongkyong_v3",
|
||||
"name": "loongkyong",
|
||||
"gender": "female",
|
||||
"desc": "韩语女"
|
||||
},
|
||||
{
|
||||
"id": "loongriko_v3",
|
||||
"name": "Riko",
|
||||
"gender": "female",
|
||||
"desc": "二次元日语女"
|
||||
},
|
||||
{
|
||||
"id": "loongtomoka_v3",
|
||||
"name": "loongtomoka",
|
||||
"gender": "female",
|
||||
"desc": "日语女"
|
||||
},
|
||||
{
|
||||
"id": "longfei_v3",
|
||||
"name": "龙飞",
|
||||
"gender": "male",
|
||||
"desc": "热血磁性男"
|
||||
},
|
||||
{
|
||||
"id": "longxiaochun_v3",
|
||||
"name": "龙小淳",
|
||||
"gender": "female",
|
||||
"desc": "清丽温柔女"
|
||||
},
|
||||
{
|
||||
"id": "longxiaoxia_v3",
|
||||
"name": "龙小夏",
|
||||
"gender": "female",
|
||||
"desc": "活泼甜美女"
|
||||
},
|
||||
{
|
||||
"id": "longshu_v3",
|
||||
"name": "龙舒",
|
||||
"gender": "female",
|
||||
"desc": "知性温婉女"
|
||||
},
|
||||
{
|
||||
"id": "longyue_v3",
|
||||
"name": "龙悦",
|
||||
"gender": "male",
|
||||
"desc": "阳光青年男"
|
||||
},
|
||||
{
|
||||
"id": "longcheng_v3",
|
||||
"name": "龙城",
|
||||
"gender": "male",
|
||||
"desc": "成熟稳重男"
|
||||
},
|
||||
{
|
||||
"id": "longhua_v3",
|
||||
"name": "龙华",
|
||||
"gender": "male",
|
||||
"desc": "标准男声"
|
||||
},
|
||||
{
|
||||
"id": "longwan_v3",
|
||||
"name": "龙婉",
|
||||
"gender": "female",
|
||||
"desc": "温婉知性女"
|
||||
},
|
||||
{
|
||||
"id": "longjing_v3",
|
||||
"name": "龙静",
|
||||
"gender": "female",
|
||||
"desc": "标准女声"
|
||||
},
|
||||
{
|
||||
"id": "longmiao_v3",
|
||||
"name": "龙淼",
|
||||
"gender": "female",
|
||||
"desc": "标准女声"
|
||||
},
|
||||
{
|
||||
"id": "longshuo_v3",
|
||||
"name": "龙硕",
|
||||
"gender": "male",
|
||||
"desc": "标准男声"
|
||||
},
|
||||
{
|
||||
"id": "longxiang_v3",
|
||||
"name": "龙翔",
|
||||
"gender": "male",
|
||||
"desc": "标准男声"
|
||||
},
|
||||
{
|
||||
"id": "longyuan_v3",
|
||||
"name": "龙源",
|
||||
"gender": "male",
|
||||
"desc": "标准男声"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
},
|
||||
"qwen3-tts-flash": {
|
||||
"name": "Qwen3-TTS-Flash",
|
||||
"class": "backend.src.services.provider.dashscope.audio.DashScopeAudioService",
|
||||
"args": ["qwen3-tts-flash"],
|
||||
"voices": [
|
||||
{
|
||||
"id": "Cherry",
|
||||
"name": "芊悦",
|
||||
"gender": "female",
|
||||
"desc": "阳光积极、亲切自然小姐姐"
|
||||
},
|
||||
{
|
||||
"id": "Serena",
|
||||
"name": "苏瑶",
|
||||
"gender": "female",
|
||||
"desc": "温柔小姐姐"
|
||||
},
|
||||
{
|
||||
"id": "Ethan",
|
||||
"name": "晨煦",
|
||||
"gender": "male",
|
||||
"desc": "阳光、温暖、活力、朝气"
|
||||
},
|
||||
{
|
||||
"id": "Chelsie",
|
||||
"name": "千雪",
|
||||
"gender": "female",
|
||||
"desc": "二次元虚拟女友"
|
||||
},
|
||||
{
|
||||
"id": "Momo",
|
||||
"name": "茉兔",
|
||||
"gender": "female",
|
||||
"desc": "撒娇搞怪,逗你开心"
|
||||
},
|
||||
{
|
||||
"id": "Vivian",
|
||||
"name": "十三",
|
||||
"gender": "female",
|
||||
"desc": "拽拽的、可爱的小暴躁"
|
||||
},
|
||||
{ "id": "Moon", "name": "月白", "gender": "male", "desc": "率性帅气" },
|
||||
{
|
||||
"id": "Maia",
|
||||
"name": "四月",
|
||||
"gender": "female",
|
||||
"desc": "知性与温柔的碰撞"
|
||||
},
|
||||
{ "id": "Kai", "name": "凯", "gender": "male", "desc": "耳朵的一场SPA" },
|
||||
{
|
||||
"id": "Nofish",
|
||||
"name": "不吃鱼",
|
||||
"gender": "male",
|
||||
"desc": "不会翘舌音的设计师"
|
||||
},
|
||||
{
|
||||
"id": "Bella",
|
||||
"name": "萌宝",
|
||||
"gender": "female",
|
||||
"desc": "喝酒不打醉拳的小萝莉"
|
||||
},
|
||||
{
|
||||
"id": "Jennifer",
|
||||
"name": "詹妮弗",
|
||||
"gender": "female",
|
||||
"desc": "品牌级、电影质感般美语女声"
|
||||
},
|
||||
{
|
||||
"id": "Ryan",
|
||||
"name": "甜茶",
|
||||
"gender": "male",
|
||||
"desc": "节奏拉满,戏感炸裂"
|
||||
},
|
||||
{
|
||||
"id": "Katerina",
|
||||
"name": "卡捷琳娜",
|
||||
"gender": "female",
|
||||
"desc": "御姐音色,韵律回味十足"
|
||||
},
|
||||
{
|
||||
"id": "Aiden",
|
||||
"name": "艾登",
|
||||
"gender": "male",
|
||||
"desc": "精通厨艺的美语大男孩"
|
||||
},
|
||||
{
|
||||
"id": "Eldric Sage",
|
||||
"name": "沧明子",
|
||||
"gender": "male",
|
||||
"desc": "沉稳睿智的老者"
|
||||
},
|
||||
{
|
||||
"id": "Mia",
|
||||
"name": "乖小妹",
|
||||
"gender": "female",
|
||||
"desc": "温顺如春水,乖巧如初雪"
|
||||
},
|
||||
{
|
||||
"id": "Mochi",
|
||||
"name": "沙小弥",
|
||||
"gender": "male",
|
||||
"desc": "聪明伶俐的小大人"
|
||||
},
|
||||
{
|
||||
"id": "Bellona",
|
||||
"name": "燕铮莺",
|
||||
"gender": "female",
|
||||
"desc": "金戈铁马,千面人声"
|
||||
},
|
||||
{
|
||||
"id": "Vincent",
|
||||
"name": "田叔",
|
||||
"gender": "male",
|
||||
"desc": "沙哑烟嗓,江湖豪情"
|
||||
},
|
||||
{
|
||||
"id": "Bunny",
|
||||
"name": "萌小姬",
|
||||
"gender": "female",
|
||||
"desc": "萌属性爆棚的小萝莉"
|
||||
},
|
||||
{
|
||||
"id": "Neil",
|
||||
"name": "阿闻",
|
||||
"gender": "male",
|
||||
"desc": "字正腔圆的新闻主持人"
|
||||
},
|
||||
{
|
||||
"id": "Elias",
|
||||
"name": "墨讲师",
|
||||
"gender": "female",
|
||||
"desc": "严谨又通俗的知识讲解"
|
||||
},
|
||||
{
|
||||
"id": "Arthur",
|
||||
"name": "徐大爷",
|
||||
"gender": "male",
|
||||
"desc": "质朴嗓音,奇闻异事"
|
||||
},
|
||||
{
|
||||
"id": "Nini",
|
||||
"name": "邻家妹妹",
|
||||
"gender": "female",
|
||||
"desc": "又软又黏的甜蜜嗓音"
|
||||
},
|
||||
{
|
||||
"id": "Ebona",
|
||||
"name": "诡婆婆",
|
||||
"gender": "female",
|
||||
"desc": "幽暗低语,神秘诡异"
|
||||
},
|
||||
{
|
||||
"id": "Seren",
|
||||
"name": "小婉",
|
||||
"gender": "female",
|
||||
"desc": "温和舒缓,助眠音色"
|
||||
},
|
||||
{
|
||||
"id": "Pip",
|
||||
"name": "顽屁小孩",
|
||||
"gender": "male",
|
||||
"desc": "调皮捣蛋却充满童真"
|
||||
},
|
||||
{
|
||||
"id": "Stella",
|
||||
"name": "少女阿月",
|
||||
"gender": "female",
|
||||
"desc": "甜到发腻的迷糊少女音"
|
||||
},
|
||||
{
|
||||
"id": "Bodega",
|
||||
"name": "博德加",
|
||||
"gender": "male",
|
||||
"desc": "热情的西班牙大叔"
|
||||
},
|
||||
{
|
||||
"id": "Sonrisa",
|
||||
"name": "索尼莎",
|
||||
"gender": "female",
|
||||
"desc": "热情开朗的拉美大姐"
|
||||
},
|
||||
{
|
||||
"id": "Alek",
|
||||
"name": "阿列克",
|
||||
"gender": "male",
|
||||
"desc": "战斗民族的冷与暖"
|
||||
},
|
||||
{
|
||||
"id": "Dolce",
|
||||
"name": "多尔切",
|
||||
"gender": "male",
|
||||
"desc": "慵懒的意大利大叔"
|
||||
},
|
||||
{
|
||||
"id": "Sohee",
|
||||
"name": "素熙",
|
||||
"gender": "female",
|
||||
"desc": "温柔开朗的韩国欧尼"
|
||||
},
|
||||
{
|
||||
"id": "Ono Anna",
|
||||
"name": "小野杏",
|
||||
"gender": "female",
|
||||
"desc": "鬼灵精怪的青梅竹马"
|
||||
},
|
||||
{
|
||||
"id": "Lenn",
|
||||
"name": "莱恩",
|
||||
"gender": "male",
|
||||
"desc": "理性底色的德国青年"
|
||||
},
|
||||
{
|
||||
"id": "Emilien",
|
||||
"name": "埃米尔安",
|
||||
"gender": "male",
|
||||
"desc": "浪漫的法国大哥哥"
|
||||
},
|
||||
{
|
||||
"id": "Andre",
|
||||
"name": "安德雷",
|
||||
"gender": "male",
|
||||
"desc": "声音磁性,沉稳男生"
|
||||
},
|
||||
{
|
||||
"id": "Radio Gol",
|
||||
"name": "拉迪奥·戈尔",
|
||||
"gender": "male",
|
||||
"desc": "足球诗人解说员"
|
||||
},
|
||||
{
|
||||
"id": "Jada",
|
||||
"name": "上海-阿珍",
|
||||
"gender": "female",
|
||||
"desc": "风风火火的沪上阿姐"
|
||||
},
|
||||
{
|
||||
"id": "Dylan",
|
||||
"name": "北京-晓东",
|
||||
"gender": "male",
|
||||
"desc": "北京胡同里长大的少年"
|
||||
},
|
||||
{
|
||||
"id": "Li",
|
||||
"name": "南京-老李",
|
||||
"gender": "male",
|
||||
"desc": "耐心的瑜伽老师"
|
||||
},
|
||||
{
|
||||
"id": "Marcus",
|
||||
"name": "陕西-秦川",
|
||||
"gender": "male",
|
||||
"desc": "面宽话短,心实声沉"
|
||||
},
|
||||
{
|
||||
"id": "Roy",
|
||||
"name": "闽南-阿杰",
|
||||
"gender": "male",
|
||||
"desc": "诙谐直爽的台湾哥仔"
|
||||
},
|
||||
{
|
||||
"id": "Peter",
|
||||
"name": "天津-李彼得",
|
||||
"gender": "male",
|
||||
"desc": "天津相声,专业捧哏"
|
||||
},
|
||||
{
|
||||
"id": "Sunny",
|
||||
"name": "四川-晴儿",
|
||||
"gender": "female",
|
||||
"desc": "甜到你心里的川妹子"
|
||||
},
|
||||
{
|
||||
"id": "Eric",
|
||||
"name": "四川-程川",
|
||||
"gender": "male",
|
||||
"desc": "跳脱市井的成都男子"
|
||||
},
|
||||
{
|
||||
"id": "Rocky",
|
||||
"name": "粤语-阿强",
|
||||
"gender": "male",
|
||||
"desc": "幽默风趣,在线陪聊"
|
||||
},
|
||||
{
|
||||
"id": "Kiki",
|
||||
"name": "粤语-阿清",
|
||||
"gender": "female",
|
||||
"desc": "甜美的港妹闺蜜"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
143
backend/src/config/services/dashscope/image.json
Normal file
143
backend/src/config/services/dashscope/image.json
Normal file
@@ -0,0 +1,143 @@
|
||||
{
|
||||
"z-image": {
|
||||
"name": "Z-Image",
|
||||
"class": "backend.src.services.provider.dashscope.image.ZImageService",
|
||||
"args": [
|
||||
"z-image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "z-image-turbo"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1536*864",
|
||||
"9:16": "864*1536",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280",
|
||||
"21:9": "1680*720",
|
||||
"9:21": "720*1680"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2048*1152",
|
||||
"9:16": "1152*2048",
|
||||
"21:9": "2016*864",
|
||||
"9:21": "864*2016"
|
||||
}
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.6-image": {
|
||||
"name": "Wan 2.6",
|
||||
"class": "backend.src.services.provider.dashscope.image.WanImageService",
|
||||
"args": [
|
||||
"wan2.6-image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "wan2.6-t2i",
|
||||
"i2i": "wan2.6-image"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048",
|
||||
"4:3": "2560*1920",
|
||||
"3:4": "1920*2560"
|
||||
}
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.5-image": {
|
||||
"name": "Wan 2.5",
|
||||
"class": "backend.src.services.provider.dashscope.image.WanImageService",
|
||||
"args": [
|
||||
"wan2.5-image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "wan2.5-t2i-preview",
|
||||
"i2i": "wan2.5-i2i-preview"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048",
|
||||
"4:3": "2560*1920",
|
||||
"3:4": "1920*2560"
|
||||
}
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"qwen-image": {
|
||||
"name": "Qwen Image",
|
||||
"class": "backend.src.services.provider.dashscope.image.QwenImageService",
|
||||
"args": [
|
||||
"qwen-image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "qwen-image-plus",
|
||||
"i2i": "qwen-image-edit-plus"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1664*928",
|
||||
"9:16": "928*1664",
|
||||
"1:1": "1328*1328",
|
||||
"4:3": "1472*1140",
|
||||
"3:4": "1140*1472"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048",
|
||||
"4:3": "2560*1920",
|
||||
"3:4": "1920*2560"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
39
backend/src/config/services/dashscope/llm.json
Normal file
39
backend/src/config/services/dashscope/llm.json
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"qwen-plus": {
|
||||
"name": "Qwen Plus",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["qwen-plus"],
|
||||
"enabled": true
|
||||
},
|
||||
"qwen3-max": {
|
||||
"name": "Qwen Max",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["qwen3-max"],
|
||||
"enabled": true
|
||||
},
|
||||
"qwen-flash": {
|
||||
"name": "Qwen Flash",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["qwen-flash"],
|
||||
"enabled": true
|
||||
},
|
||||
"deepseek": {
|
||||
"name": "Deepseek",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["deepseek-v3.2"],
|
||||
"enabled": true
|
||||
},
|
||||
"kimi-k2.5": {
|
||||
"name": "Kimi K2.5",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["kimi-k2.5"],
|
||||
"enabled": true
|
||||
},
|
||||
"MiniMax-M2.1": {
|
||||
"name": "MiniMax M2.1",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["MiniMax-M2.1"],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/dashscope/provider.json
Normal file
16
backend/src/config/services/dashscope/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "dashscope",
|
||||
"name": "阿里云百炼",
|
||||
"description": "阿里云提供的大模型服务",
|
||||
"dashboard_url": "https://dashscope.console.aliyun.com/",
|
||||
"helpUrl": "https://dashscope.console.aliyun.com/apiKey",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
196
backend/src/config/services/dashscope/video.json
Normal file
196
backend/src/config/services/dashscope/video.json
Normal file
@@ -0,0 +1,196 @@
|
||||
{
|
||||
"wan2.6-video": {
|
||||
"name": "Wan 2.6",
|
||||
"class": "backend.src.services.provider.dashscope.video.WanVideoService",
|
||||
"args": [
|
||||
"wan2.6-video"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsVideoToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": false,
|
||||
"supportsMultiImage": true,
|
||||
"supportsMultiVideo": true,
|
||||
"supportsAudio": true,
|
||||
"supportsShotType": true,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "wan2.6-t2v",
|
||||
"i2v": "wan2.6-i2v",
|
||||
"r2v": "wan2.6-r2v"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1920*1920",
|
||||
"4:3": "1920*1440",
|
||||
"3:4": "1440*1920"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"min": 2,
|
||||
"max": 10
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.6-video-flash": {
|
||||
"name": "Wan 2.6 Flash",
|
||||
"class": "backend.src.services.provider.dashscope.video.WanVideoService",
|
||||
"args": [
|
||||
"wan2.6-video-flash"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsVideoToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": false,
|
||||
"supportsMultiImage": true,
|
||||
"supportsMultiVideo": true,
|
||||
"supportsAudio": true,
|
||||
"supportsShotType": true,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "wan2.6-t2v",
|
||||
"i2v": "wan2.6-i2v-flash",
|
||||
"r2v": "wan2.6-r2v-flash"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1920*1920",
|
||||
"4:3": "1920*1440",
|
||||
"3:4": "1440*1920"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"min": 2,
|
||||
"max": 10
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.5-video": {
|
||||
"name": "Wan 2.5",
|
||||
"class": "backend.src.services.provider.dashscope.video.WanVideoService",
|
||||
"args": [
|
||||
"wan2.5-video"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": false,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "wan2.5-t2v-preview",
|
||||
"i2v": "wan2.5-i2v-preview"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1920*1920",
|
||||
"4:3": "1920*1440",
|
||||
"3:4": "1440*1920"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
5,
|
||||
10
|
||||
]
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.2-video": {
|
||||
"name": "Wan 2.2",
|
||||
"class": "backend.src.services.provider.dashscope.video.WanVideoService",
|
||||
"args": [
|
||||
"wan2.2-video"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "wan2.2-t2v-plus",
|
||||
"i2v": "wan2.2-i2v-flash",
|
||||
"kf2v": "wan2.2-kf2v-flash"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1920*1920",
|
||||
"4:3": "1920*1440",
|
||||
"3:4": "1440*1920"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
5
|
||||
]
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
7
backend/src/config/services/default.json
Normal file
7
backend/src/config/services/default.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"llm": "qwen-plus",
|
||||
"image": "z-image",
|
||||
"video": "wan2.6-video",
|
||||
"audio": "qwen3-tts-flash",
|
||||
"upscale": "ali-videoenhan/videoenhan"
|
||||
}
|
||||
27
backend/src/config/services/google/llm.json
Normal file
27
backend/src/config/services/google/llm.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"gemini-1.5-pro": {
|
||||
"name": "Gemini-1.5-Pro",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"gemini-1.5-pro"
|
||||
],
|
||||
"enabled": false
|
||||
},
|
||||
"gemini-1.5-flash": {
|
||||
"name": "Gemini-1.5-Flash",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"gemini-1.5-flash"
|
||||
],
|
||||
"enabled": false
|
||||
},
|
||||
"gemini-2.0-flash": {
|
||||
"name": "Gemini-2.0-Flash",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"gemini-2.0-flash"
|
||||
],
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/google/provider.json
Normal file
16
backend/src/config/services/google/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "google",
|
||||
"name": "Google",
|
||||
"description": "Gemini 系列模型",
|
||||
"dashboard_url": "https://aistudio.google.com/app/apikey",
|
||||
"helpUrl": "https://aistudio.google.com/app/apikey",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
27
backend/src/config/services/kling/provider.json
Normal file
27
backend/src/config/services/kling/provider.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"id": "kling",
|
||||
"name": "可灵 AI",
|
||||
"description": "快手视频生成 - 需要 Access Key 和 Secret Key",
|
||||
"dashboard_url": "https://klingai.kuaishou.com/",
|
||||
"param_mapping": {
|
||||
"api_key": "access_key",
|
||||
"api_secret": "secret_key"
|
||||
},
|
||||
"helpUrl": "https://klingai.kuaishou.com/",
|
||||
"fields": [
|
||||
{
|
||||
"name": "accessKey",
|
||||
"label": "Access Key",
|
||||
"placeholder": "Access Key",
|
||||
"required": true,
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"name": "secretKey",
|
||||
"label": "Secret Key",
|
||||
"placeholder": "Secret Key",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
124
backend/src/config/services/kling/video.json
Normal file
124
backend/src/config/services/kling/video.json
Normal file
@@ -0,0 +1,124 @@
|
||||
{
|
||||
"kling-v2-5-turbo": {
|
||||
"name": "Kling V2.5 Turbo",
|
||||
"class": "src.services.provider.kling.KlingVideoService",
|
||||
"args": [
|
||||
"kling-v2-5-turbo"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": false,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsCameraControl": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "kling-v2-5-turbo",
|
||||
"i2v": "kling-v2-5-turbo",
|
||||
"kf2v": "kling-v2-5-turbo"
|
||||
},
|
||||
"modes": ["std", "pro"],
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1024*1024"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [5, 10]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"kling-v2-6": {
|
||||
"name": "Kling V2.6",
|
||||
"class": "src.services.provider.kling.KlingVideoService",
|
||||
"args": [
|
||||
"kling-v2-6"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": true,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsCameraControl": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "kling-v2-6",
|
||||
"i2v": "kling-v2-6",
|
||||
"kf2v": "kling-v2-6"
|
||||
},
|
||||
"modes": ["pro"],
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1024*1024"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [5, 10]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"kling-video-o1": {
|
||||
"name": "Kling Omni (O1)",
|
||||
"class": "src.services.provider.kling.KlingVideoService",
|
||||
"args": [
|
||||
"kling-video-o1"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsVideoToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": true,
|
||||
"supportsMultiVideo": true,
|
||||
"supportsAudio": true,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "kling-video-o1",
|
||||
"i2v": "kling-video-o1",
|
||||
"v2v": "kling-video-o1",
|
||||
"kf2v": "kling-video-o1"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"min": 3,
|
||||
"max": 10
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
34
backend/src/config/services/midjourney/image.json
Normal file
34
backend/src/config/services/midjourney/image.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"midjourney": {
|
||||
"name": "Midjourney",
|
||||
"class": "src.services.provider.midjourney.MidjourneyImageService",
|
||||
"args": [
|
||||
"midjourney"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "midjourney"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"1:1": "1024*1024",
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"2K": {
|
||||
"1:1": "1456*1456",
|
||||
"16:9": "1456*816",
|
||||
"9:16": "816*1456",
|
||||
"4:3": "1232*928",
|
||||
"3:4": "928*1232"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
23
backend/src/config/services/midjourney/provider.json
Normal file
23
backend/src/config/services/midjourney/provider.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"id": "midjourney",
|
||||
"name": "Midjourney(悠船)",
|
||||
"description": "悠船 API - 需要 App ID 和 Secret Key",
|
||||
"dashboard_url": "https://ali.youchuan.cn/",
|
||||
"helpUrl": "https://ali.youchuan.cn/",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "App ID",
|
||||
"placeholder": "应用 ID",
|
||||
"required": true,
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"name": "apiSecret",
|
||||
"label": "Secret Key",
|
||||
"placeholder": "密钥",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
378
backend/src/config/services/minimax/audio.json
Normal file
378
backend/src/config/services/minimax/audio.json
Normal file
@@ -0,0 +1,378 @@
|
||||
{
|
||||
"speech-2.8-hd": {
|
||||
"name": "speech-2.8-hd",
|
||||
"class": "src.services.provider.minimax.MiniMaxAudioService",
|
||||
"args": [
|
||||
"speech-2.8-hd"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "male-qn-qingse",
|
||||
"name": "青涩青年",
|
||||
"gender": "male",
|
||||
"desc": "青涩青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-jingying",
|
||||
"name": "精英青年",
|
||||
"gender": "male",
|
||||
"desc": "精英青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-badao",
|
||||
"name": "霸道青年",
|
||||
"gender": "male",
|
||||
"desc": "沉稳有力的中文普通话青年男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-daxuesheng",
|
||||
"name": "青年大学生",
|
||||
"gender": "male",
|
||||
"desc": "青年大学生风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "female-shaonv",
|
||||
"name": "少女",
|
||||
"gender": "female",
|
||||
"desc": "清亮少女风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-yujie",
|
||||
"name": "御姐",
|
||||
"gender": "female",
|
||||
"desc": "成熟干练风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-chengshu",
|
||||
"name": "成熟女性",
|
||||
"gender": "female",
|
||||
"desc": "成熟女性风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-tianmei",
|
||||
"name": "甜美女性",
|
||||
"gender": "female",
|
||||
"desc": "甜美风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_News_Anchor",
|
||||
"name": "新闻女声",
|
||||
"gender": "female",
|
||||
"desc": "专业播音腔的中年女性新闻主播,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Male_Announcer",
|
||||
"name": "播报男声",
|
||||
"gender": "male",
|
||||
"desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Gentleman",
|
||||
"name": "温润男声",
|
||||
"gender": "male",
|
||||
"desc": "温润磁性的青年男性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Sweet_Lady",
|
||||
"name": "甜美女声",
|
||||
"gender": "female",
|
||||
"desc": "温柔甜美的青年女性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(F)",
|
||||
"name": "粤语专业女主持",
|
||||
"gender": "female",
|
||||
"desc": "中性、专业的青年女性粤语主持人声音"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(M)",
|
||||
"name": "粤语专业男主持",
|
||||
"gender": "male",
|
||||
"desc": "中性、专业的青年男性粤语主持人声音"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
},
|
||||
"speech-2.6-hd": {
|
||||
"name": "speech-2.6-hd",
|
||||
"class": "src.services.provider.minimax.MiniMaxAudioService",
|
||||
"args": [
|
||||
"speech-2.6-hd"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "male-qn-qingse",
|
||||
"name": "青涩青年",
|
||||
"gender": "male",
|
||||
"desc": "青涩青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-jingying",
|
||||
"name": "精英青年",
|
||||
"gender": "male",
|
||||
"desc": "精英青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-badao",
|
||||
"name": "霸道青年",
|
||||
"gender": "male",
|
||||
"desc": "沉稳有力的中文普通话青年男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-daxuesheng",
|
||||
"name": "青年大学生",
|
||||
"gender": "male",
|
||||
"desc": "青年大学生风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "female-shaonv",
|
||||
"name": "少女",
|
||||
"gender": "female",
|
||||
"desc": "清亮少女风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-yujie",
|
||||
"name": "御姐",
|
||||
"gender": "female",
|
||||
"desc": "成熟干练风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-chengshu",
|
||||
"name": "成熟女性",
|
||||
"gender": "female",
|
||||
"desc": "成熟女性风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-tianmei",
|
||||
"name": "甜美女性",
|
||||
"gender": "female",
|
||||
"desc": "甜美风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_News_Anchor",
|
||||
"name": "新闻女声",
|
||||
"gender": "female",
|
||||
"desc": "专业、播音腔的中年女性新闻主播,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Male_Announcer",
|
||||
"name": "播报男声",
|
||||
"gender": "male",
|
||||
"desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Gentleman",
|
||||
"name": "温润男声",
|
||||
"gender": "male",
|
||||
"desc": "温润磁性的青年男性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Sweet_Lady",
|
||||
"name": "甜美女声",
|
||||
"gender": "female",
|
||||
"desc": "温柔甜美的青年女性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(F)",
|
||||
"name": "粤语专业女主持",
|
||||
"gender": "female",
|
||||
"desc": "中性、专业的青年女性粤语主持人声音"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(M)",
|
||||
"name": "粤语专业男主持",
|
||||
"gender": "male",
|
||||
"desc": "中性、专业的青年男性粤语主持人声音"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
},
|
||||
"speech-2.8-turbo": {
|
||||
"name": "speech-2.8-turbo",
|
||||
"class": "src.services.provider.minimax.MiniMaxAudioService",
|
||||
"args": [
|
||||
"speech-2.8-turbo"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "male-qn-qingse",
|
||||
"name": "青涩青年",
|
||||
"gender": "male",
|
||||
"desc": "青涩青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-jingying",
|
||||
"name": "精英青年",
|
||||
"gender": "male",
|
||||
"desc": "精英青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-badao",
|
||||
"name": "霸道青年",
|
||||
"gender": "male",
|
||||
"desc": "沉稳有力的中文普通话青年男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-daxuesheng",
|
||||
"name": "青年大学生",
|
||||
"gender": "male",
|
||||
"desc": "青年大学生风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "female-shaonv",
|
||||
"name": "少女",
|
||||
"gender": "female",
|
||||
"desc": "清亮少女风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-yujie",
|
||||
"name": "御姐",
|
||||
"gender": "female",
|
||||
"desc": "成熟干练风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-chengshu",
|
||||
"name": "成熟女性",
|
||||
"gender": "female",
|
||||
"desc": "成熟女性风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-tianmei",
|
||||
"name": "甜美女性",
|
||||
"gender": "female",
|
||||
"desc": "甜美风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_News_Anchor",
|
||||
"name": "新闻女声",
|
||||
"gender": "female",
|
||||
"desc": "专业、播音腔的中年女性新闻主播,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Male_Announcer",
|
||||
"name": "播报男声",
|
||||
"gender": "male",
|
||||
"desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Gentleman",
|
||||
"name": "温润男声",
|
||||
"gender": "male",
|
||||
"desc": "温润磁性的青年男性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Sweet_Lady",
|
||||
"name": "甜美女声",
|
||||
"gender": "female",
|
||||
"desc": "温柔甜美的青年女性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(F)",
|
||||
"name": "粤语专业女主持",
|
||||
"gender": "female",
|
||||
"desc": "中性、专业的青年女性粤语主持人声音"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(M)",
|
||||
"name": "粤语专业男主持",
|
||||
"gender": "male",
|
||||
"desc": "中性、专业的青年男性粤语主持人声音"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
},
|
||||
"speech-2.6-turbo": {
|
||||
"name": "speech-2.6-turbo",
|
||||
"class": "src.services.provider.minimax.MiniMaxAudioService",
|
||||
"args": [
|
||||
"speech-2.6-turbo"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "male-qn-qingse",
|
||||
"name": "青涩青年",
|
||||
"gender": "male",
|
||||
"desc": "青涩青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-jingying",
|
||||
"name": "精英青年",
|
||||
"gender": "male",
|
||||
"desc": "精英青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-badao",
|
||||
"name": "霸道青年",
|
||||
"gender": "male",
|
||||
"desc": "沉稳有力的中文普通话青年男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-daxuesheng",
|
||||
"name": "青年大学生",
|
||||
"gender": "male",
|
||||
"desc": "青年大学生风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "female-shaonv",
|
||||
"name": "少女",
|
||||
"gender": "female",
|
||||
"desc": "清亮少女风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-yujie",
|
||||
"name": "御姐",
|
||||
"gender": "female",
|
||||
"desc": "成熟干练风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-chengshu",
|
||||
"name": "成熟女性",
|
||||
"gender": "female",
|
||||
"desc": "成熟女性风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-tianmei",
|
||||
"name": "甜美女性",
|
||||
"gender": "female",
|
||||
"desc": "甜美风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_News_Anchor",
|
||||
"name": "新闻女声",
|
||||
"gender": "female",
|
||||
"desc": "专业、播音腔的中年女性新闻主播,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Male_Announcer",
|
||||
"name": "播报男声",
|
||||
"gender": "male",
|
||||
"desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Gentleman",
|
||||
"name": "温润男声",
|
||||
"gender": "male",
|
||||
"desc": "温润磁性的青年男性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Sweet_Lady",
|
||||
"name": "甜美女声",
|
||||
"gender": "female",
|
||||
"desc": "温柔甜美的青年女性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(F)",
|
||||
"name": "粤语专业女主持",
|
||||
"gender": "female",
|
||||
"desc": "中性、专业的青年女性粤语主持人声音"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(M)",
|
||||
"name": "粤语专业男主持",
|
||||
"gender": "male",
|
||||
"desc": "中性、专业的青年男性粤语主持人声音"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
64
backend/src/config/services/minimax/image.json
Normal file
64
backend/src/config/services/minimax/image.json
Normal file
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"image-01": {
|
||||
"name": "image-01",
|
||||
"class": "src.services.provider.minimax.MiniMaxImageService",
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"1:1": "1024x1024",
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"4:3": "1152x864",
|
||||
"3:4": "864x1152",
|
||||
"3:2": "1248x832",
|
||||
"2:3": "832x1248",
|
||||
"21:9": "1344x576"
|
||||
},
|
||||
"2K": {
|
||||
"1:1": "2048x2048",
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"4:3": "2304x1728",
|
||||
"3:4": "1728x2304",
|
||||
"3:2": "2496x1664",
|
||||
"2:3": "1664x2496",
|
||||
"21:9": "2688x1152"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
},
|
||||
"image-01-live": {
|
||||
"name": "image-01-live",
|
||||
"class": "src.services.provider.minimax.MiniMaxImageService",
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"1:1": "1024x1024",
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"4:3": "1152x864",
|
||||
"3:4": "864x1152",
|
||||
"3:2": "1248x832",
|
||||
"2:3": "832x1248",
|
||||
"21:9": "1344x576"
|
||||
},
|
||||
"2K": {
|
||||
"1:1": "2048x2048",
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"4:3": "2304x1728",
|
||||
"3:4": "1728x2304",
|
||||
"3:2": "2496x1664",
|
||||
"2:3": "1664x2496",
|
||||
"21:9": "2688x1152"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
}
|
||||
}
|
||||
11
backend/src/config/services/minimax/llm.json
Normal file
11
backend/src/config/services/minimax/llm.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"base_url": "https://api.minimaxi.com/v1",
|
||||
"MiniMax-M2.11": {
|
||||
"name": "MiniMax M2.1",
|
||||
"class": "src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"MiniMax-M2.1"
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
22
backend/src/config/services/minimax/music.json
Normal file
22
backend/src/config/services/minimax/music.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"lyrics-2.5": {
|
||||
"name": "lyrics-2.5",
|
||||
"class": "src.services.provider.minimax.MiniMaxMusicService",
|
||||
"args": [
|
||||
"music-2.5"
|
||||
],
|
||||
"type": "lyrics",
|
||||
"enabled": true,
|
||||
"is_default": true
|
||||
},
|
||||
"music-2.5": {
|
||||
"name": "music-2.5",
|
||||
"class": "src.services.provider.minimax.MiniMaxMusicService",
|
||||
"args": [
|
||||
"music-2.5"
|
||||
],
|
||||
"type": "music",
|
||||
"enabled": true,
|
||||
"is_default": true
|
||||
}
|
||||
}
|
||||
23
backend/src/config/services/minimax/provider.json
Normal file
23
backend/src/config/services/minimax/provider.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"id": "minimax",
|
||||
"name": "MiniMax",
|
||||
"description": "海螺AI",
|
||||
"dashboard_url": "https://platform.minimaxi.com/",
|
||||
"helpUrl": "https://platform.minimaxi.com/user-center/basic-information/interface-key",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-api-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
},
|
||||
{
|
||||
"name": "groupId",
|
||||
"label": "Group ID (可选)",
|
||||
"placeholder": "分组 ID",
|
||||
"required": false,
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}
|
||||
68
backend/src/config/services/minimax/video.json
Normal file
68
backend/src/config/services/minimax/video.json
Normal file
@@ -0,0 +1,68 @@
|
||||
{
|
||||
"MiniMax-Hailuo-2.3": {
|
||||
"name": "Hailuo 2.3",
|
||||
"class": "src.services.provider.minimax.MiniMaxVideoService",
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": false,
|
||||
"supportsNegativePrompt": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
6,
|
||||
10
|
||||
]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
},
|
||||
"MiniMax-Hailuo-02": {
|
||||
"name": "Hailuo 02",
|
||||
"class": "src.services.provider.minimax.MiniMaxVideoService",
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": false,
|
||||
"supportsNegativePrompt": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
6,
|
||||
10
|
||||
]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
},
|
||||
"T2V-01-Director": {
|
||||
"name": "T2V-01-Director",
|
||||
"class": "src.services.provider.minimax.MiniMaxVideoService",
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": false,
|
||||
"supportsNegativePrompt": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
6,
|
||||
10
|
||||
]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
}
|
||||
}
|
||||
143
backend/src/config/services/modelscope/image.json
Normal file
143
backend/src/config/services/modelscope/image.json
Normal file
@@ -0,0 +1,143 @@
|
||||
{
|
||||
"qwen-image": {
|
||||
"name": "Qwen Image",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"Qwen/Qwen-Image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1664x928",
|
||||
"9:16": "928x1664",
|
||||
"1:1": "1328x1328"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"qwen-image-edit": {
|
||||
"name": "Qwen Image Edit",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"Qwen/Qwen-Image-Edit-2511"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1664x928",
|
||||
"9:16": "928x1664",
|
||||
"1:1": "1328x1328",
|
||||
"4:3": "1328x1024",
|
||||
"3:4": "1024x1328"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2560x1920",
|
||||
"3:4": "1920x2560"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"flux-dev": {
|
||||
"name": "FLUX.2 Dev",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"black-forest-labs/FLUX.2-dev"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"1:1": "1024x1024",
|
||||
"4:3": "1024x768",
|
||||
"3:4": "768x1024"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2048x1536",
|
||||
"3:4": "1536x2048"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"z-image-turbo": {
|
||||
"name": "Z Image Turbo",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"Tongyi-MAI/Z-Image-Turbo"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"1:1": "1024x1024",
|
||||
"4:3": "1024x768",
|
||||
"3:4": "768x1024"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2048x1536",
|
||||
"3:4": "1536x2048"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"awportrait-z": {
|
||||
"name": "AWPortrait Z",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"LiblibAI/AWPortrait-Z"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"1:1": "1024x1024",
|
||||
"4:3": "1024x768",
|
||||
"3:4": "768x1024"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2048x1536",
|
||||
"3:4": "1536x2048"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/modelscope/provider.json
Normal file
16
backend/src/config/services/modelscope/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "modelscope",
|
||||
"name": "ModelScope",
|
||||
"description": "开源模型平台",
|
||||
"dashboard_url": "https://modelscope.cn/",
|
||||
"helpUrl": "https://www.modelscope.cn/my/myaccesstoken",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Token",
|
||||
"placeholder": "ms-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
11
backend/src/config/services/moonshot/llm.json
Normal file
11
backend/src/config/services/moonshot/llm.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"base_url": "https://api.moonshot.cn/v1",
|
||||
"kimi-k2.5": {
|
||||
"name": "Kimi 2.5",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"kimi-k2.5"
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/moonshot/provider.json
Normal file
16
backend/src/config/services/moonshot/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "moonshot",
|
||||
"name": "月之暗面",
|
||||
"description": "Kimi 大模型",
|
||||
"dashboard_url": "https://platform.moonshot.cn/",
|
||||
"helpUrl": "https://platform.moonshot.cn/",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
42
backend/src/config/services/openai/audio.json
Normal file
42
backend/src/config/services/openai/audio.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"tts-1": {
|
||||
"name": "TTS-1",
|
||||
"class": "src.services.provider.openai.OpenAIAudioService",
|
||||
"args": [
|
||||
"tts-1"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "alloy",
|
||||
"name": "Alloy",
|
||||
"gender": "female"
|
||||
},
|
||||
{
|
||||
"id": "nova",
|
||||
"name": "Nova",
|
||||
"gender": "female"
|
||||
},
|
||||
{
|
||||
"id": "shimmer",
|
||||
"name": "Shimmer",
|
||||
"gender": "female"
|
||||
},
|
||||
{
|
||||
"id": "fable",
|
||||
"name": "Fable",
|
||||
"gender": "female"
|
||||
},
|
||||
{
|
||||
"id": "echo",
|
||||
"name": "Echo",
|
||||
"gender": "male"
|
||||
},
|
||||
{
|
||||
"id": "onyx",
|
||||
"name": "Onyx",
|
||||
"gender": "male"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
25
backend/src/config/services/openai/image.json
Normal file
25
backend/src/config/services/openai/image.json
Normal file
@@ -0,0 +1,25 @@
|
||||
{
|
||||
"dall-e-3": {
|
||||
"name": "DALL-E 3",
|
||||
"class": "src.services.provider.openai.OpenAIImageService",
|
||||
"args": [
|
||||
"dall-e-3"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "dall-e-3"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"1:1": "1024*1024",
|
||||
"16:9": "1792*1024",
|
||||
"9:16": "1024*1792"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 1},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/openai/provider.json
Normal file
16
backend/src/config/services/openai/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "openai",
|
||||
"name": "OpenAI",
|
||||
"description": "GPT-4、DALL-E、Sora",
|
||||
"dashboard_url": "https://platform.openai.com/",
|
||||
"helpUrl": "https://platform.openai.com/api-keys",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
7
backend/src/config/services/provider.json.example
Normal file
7
backend/src/config/services/provider.json.example
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"id": "provider-id",
|
||||
"name": "Provider Display Name",
|
||||
"description": "Optional description",
|
||||
"api_key": "YOUR_API_KEY_OR_SET_VIA_ENV",
|
||||
"dashboard_url": "https://example.com/dashboard"
|
||||
}
|
||||
42
backend/src/config/services/volcengine/image.json
Normal file
42
backend/src/config/services/volcengine/image.json
Normal file
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"doubao-seedream-4.5": {
|
||||
"name": "SeeDream 4.5",
|
||||
"class": "src.services.provider.volcengine.image.VolcengineImageService",
|
||||
"args": [
|
||||
"doubao-seedream-4.5"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "doubao-seedream-4-5-251128",
|
||||
"i2i": "doubao-seedream-4-5-251128"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"1:1": "1024x1024",
|
||||
"4:3": "1152x864",
|
||||
"3:4": "864x1152"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2304x1728",
|
||||
"3:4": "1728x2304"
|
||||
},
|
||||
"4K": {
|
||||
"16:9": "3840x2160",
|
||||
"9:16": "2160x3840",
|
||||
"1:1": "4096x4096",
|
||||
"4:3": "3456x2592",
|
||||
"3:4": "2592x3456"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
11
backend/src/config/services/volcengine/llm.json
Normal file
11
backend/src/config/services/volcengine/llm.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"base_url": "https://ark.cn-beijing.volces.com/api/v3",
|
||||
"doubao-1.5-pro": {
|
||||
"name": "Doubao 1.5 Pro",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"doubao-1-5-pro-32k-250115"
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/volcengine/provider.json
Normal file
16
backend/src/config/services/volcengine/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "volcengine",
|
||||
"name": "火山引擎",
|
||||
"description": "豆包大模型",
|
||||
"dashboard_url": "https://console.volcengine.com/ark/",
|
||||
"helpUrl": "https://console.volcengine.com/ark/region:ark+cn-beijing/apiKey",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
138
backend/src/config/services/volcengine/video.json
Normal file
138
backend/src/config/services/volcengine/video.json
Normal file
@@ -0,0 +1,138 @@
|
||||
{
|
||||
"doubao-seedance-1.5-pro": {
|
||||
"name": "SeeDance 1.5 Pro",
|
||||
"class": "src.services.provider.volcengine.video.VolcengineVideoService",
|
||||
"args": [
|
||||
"doubao-seedance-1.5-pro"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "doubao-seedance-1-5-pro-251215",
|
||||
"i2v": "doubao-seedance-1-5-pro-251215",
|
||||
"kf2v": "doubao-seedance-1-5-pro-251215"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"1:1": "1280x1280",
|
||||
"4:3": "1280x960",
|
||||
"3:4": "960x1280"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
4,
|
||||
12
|
||||
]
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"doubao-seedance-1.0-pro": {
|
||||
"name": "SeeDance 1.0 Pro",
|
||||
"class": "src.services.provider.volcengine.video.VolcengineVideoService",
|
||||
"args": [
|
||||
"doubao-seedance-1-0-pro-250528"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "doubao-seedance-1-0-pro-250528",
|
||||
"i2v": "doubao-seedance-1-0-pro-250528",
|
||||
"kf2v": "doubao-seedance-1-0-pro-250528"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1248×704",
|
||||
"9:16": "704x1248",
|
||||
"1:1": "960×960",
|
||||
"4:3": "1120×832",
|
||||
"3:4": "832x1120"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920x1088",
|
||||
"9:16": "1088x1920",
|
||||
"1:1": "1440×1440",
|
||||
"4:3": "1664×1248",
|
||||
"3:4": "1248x1664"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
2,
|
||||
12
|
||||
]
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"doubao-seedance-1.0-pro-fast": {
|
||||
"name": "SeeDance 1.0 Pro Fast",
|
||||
"class": "src.services.provider.volcengine.video.VolcengineVideoService",
|
||||
"args": [
|
||||
"doubao-seedance-1-0-pro-fast-251015"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": false,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "doubao-seedance-1-0-pro-fast-251015",
|
||||
"i2v": "doubao-seedance-1-0-pro-fast-251015"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1248×704",
|
||||
"9:16": "704x1248",
|
||||
"1:1": "960×960",
|
||||
"4:3": "1120×832",
|
||||
"3:4": "832x1120"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920x1088",
|
||||
"9:16": "1088x1920",
|
||||
"1:1": "1440×1440",
|
||||
"4:3": "1664×1248",
|
||||
"3:4": "1248x1664"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
2,
|
||||
12
|
||||
]
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
98
backend/src/config/settings.py
Normal file
98
backend/src/config/settings.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
# Load Storage Configuration
|
||||
SETTINGS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
STORAGE_CONFIG_PATH = os.path.join(SETTINGS_DIR, 'storage.json')
|
||||
|
||||
storage_config = {}
|
||||
if os.path.exists(STORAGE_CONFIG_PATH):
|
||||
try:
|
||||
with open(STORAGE_CONFIG_PATH, 'r') as f:
|
||||
storage_config = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load storage config: {e}")
|
||||
|
||||
# 服务器 Configuration
|
||||
PY_PORT = int(os.getenv('PY_PORT', '8000'))
|
||||
NODE_ENV = os.getenv('NODE_ENV', 'development')
|
||||
|
||||
# CORS Configuration
|
||||
ALLOWED_ORIGINS = [o for o in os.getenv('CORS_ALLOWED_ORIGINS', '').split(',') if o]
|
||||
DEV_ALLOWED_ORIGINS = [
|
||||
o for o in os.getenv(
|
||||
'CORS_DEV_ALLOWED_ORIGINS',
|
||||
'http://localhost:3000,http://127.0.0.1:3000,http://localhost:3001,http://127.0.0.1:3001'
|
||||
).split(',') if o
|
||||
]
|
||||
ALLOW_DEV_ORIGINS = (os.getenv('ALLOW_DEV_ORIGINS', '1') != '0' and NODE_ENV != 'production')
|
||||
|
||||
# 数据库 Configuration
|
||||
DATA_DIR = os.getenv('DATA_DIR') or os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', 'data'))
|
||||
DB_PATH = os.getenv('DB_PATH') or os.path.join(DATA_DIR, 'pixel.db')
|
||||
DATABASE_URL = os.getenv('DATABASE_URL')
|
||||
|
||||
# Alibaba Cloud OSS Configuration (env vars take precedence over storage.json)
|
||||
OSS_REGION = os.getenv('OSS_REGION') or storage_config.get('OSS_REGION', 'oss-cn-shanghai')
|
||||
OSS_ENDPOINT = os.getenv('OSS_ENDPOINT') or storage_config.get('OSS_ENDPOINT', 'oss-cn-shanghai.aliyuncs.com')
|
||||
|
||||
OSS_BUCKET = os.getenv('OSS_BUCKET') or storage_config.get('OSS_BUCKET')
|
||||
ALIBABA_CLOUD_ACCESS_KEY_ID = os.getenv('ALIBABA_CLOUD_ACCESS_KEY_ID') or storage_config.get('ALIBABA_CLOUD_ACCESS_KEY_ID')
|
||||
ALIBABA_CLOUD_ACCESS_KEY_SECRET = os.getenv('ALIBABA_CLOUD_ACCESS_KEY_SECRET') or storage_config.get('ALIBABA_CLOUD_ACCESS_KEY_SECRET')
|
||||
|
||||
# DashScope Configuration (Qwen, Wanx)
|
||||
DASHSCOPE_API_KEY = os.getenv('DASHSCOPE_API_KEY')
|
||||
|
||||
# 模型Scope Configuration
|
||||
MODELSCOPE_API_TOKEN = os.getenv('MODELSCOPE_API_TOKEN')
|
||||
|
||||
# Volcengine Configuration
|
||||
VOLCENGINE_API_KEY = os.getenv('VOLCENGINE_API_KEY') # 火山方舟 (LLM)
|
||||
|
||||
# Google Configuration
|
||||
GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
|
||||
|
||||
# OpenAI Configuration
|
||||
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
|
||||
OPENAI_BASE_URL = os.getenv('OPENAI_BASE_URL') # 可选: for proxies
|
||||
|
||||
# MiniMax Configuration
|
||||
MINIMAX_API_KEY = os.getenv('MINIMAX_API_KEY')
|
||||
MINIMAX_GROUP_ID = os.getenv('MINIMAX_GROUP_ID') # Sometimes needed
|
||||
|
||||
# Kling AI Configuration
|
||||
KLING_ACCESS_KEY = os.getenv('KLING_ACCESS_KEY')
|
||||
KLING_SECRET_KEY = os.getenv('KLING_SECRET_KEY')
|
||||
KLING_API_BASE = os.getenv('KLING_API_BASE', 'https://api-beijing.klingai.com/v1')
|
||||
|
||||
# Midjourney Configuration (Youchuan / Proxy)
|
||||
MIDJOURNEY_API_KEY = os.getenv('MIDJOURNEY_API_KEY')
|
||||
MIDJOURNEY_PROXY_URL = os.getenv('MIDJOURNEY_PROXY_URL')
|
||||
YOUCHUAN_APP_ID = os.getenv('YOUCHUAN_APP_ID')
|
||||
YOUCHUAN_SECRET_KEY = os.getenv('YOUCHUAN_SECRET_KEY')
|
||||
|
||||
# Application Settings
|
||||
UPLOAD_DIR = os.path.join(DATA_DIR, 'uploads')
|
||||
|
||||
# 存储 Configuration
|
||||
STORAGE_TYPE = os.getenv('STORAGE_TYPE') or storage_config.get('STORAGE_TYPE', 'local') # 'local' or 'oss'
|
||||
PROJECTS_DIR = os.path.join(DATA_DIR, "projects")
|
||||
CANVAS_DIR = os.path.join(DATA_DIR, "canvas")
|
||||
|
||||
# Redis Configuration
|
||||
REDIS_URL = os.getenv('REDIS_URL', 'redis://localhost:6379')
|
||||
REDIS_ENABLED = os.getenv('REDIS_ENABLED', '1') != '0'
|
||||
|
||||
# 追踪 Configuration
|
||||
TRACING_ENABLED = os.getenv('TRACING_ENABLED', '0') != '0'
|
||||
OTLP_ENDPOINT = os.getenv('OTLP_ENDPOINT', 'http://localhost:4317')
|
||||
|
||||
# Task Management Configuration
|
||||
# Unified task manager is now always used
|
||||
6
backend/src/config/storage.example.json
Normal file
6
backend/src/config/storage.example.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"OSS_REGION": "oss-cn-shanghai",
|
||||
"OSS_ENDPOINT": "oss-cn-shanghai.aliyuncs.com",
|
||||
"OSS_BUCKET": "your-bucket-name",
|
||||
"STORAGE_TYPE": "oss"
|
||||
}
|
||||
152
backend/src/config/styles.json
Normal file
152
backend/src/config/styles.json
Normal file
@@ -0,0 +1,152 @@
|
||||
{
|
||||
"styles": [
|
||||
{
|
||||
"id": "cyberpunk",
|
||||
"name": "赛博朋克",
|
||||
"type": "prompt",
|
||||
"desc": "高对比度霓虹灯光,未来主义建筑,机械改造",
|
||||
"prompt": "cyberpunk style, neon lights, high contrast, futuristic buildings, mechanical modifications",
|
||||
"color": "from-pink-500/20 to-blue-500/20"
|
||||
},
|
||||
{
|
||||
"id": "ink",
|
||||
"name": "水墨风",
|
||||
"type": "prompt",
|
||||
"desc": "传统中国水墨渲染,黑白灰为主,意境深远",
|
||||
"prompt": "traditional chinese ink painting style, black and white, artistic conception, ink wash",
|
||||
"color": "from-gray-200/10 to-gray-900/10"
|
||||
},
|
||||
{
|
||||
"id": "pixar",
|
||||
"name": "皮克斯风格",
|
||||
"type": "prompt",
|
||||
"desc": "3D卡通渲染,色彩鲜艳,光影柔和,表情夸张",
|
||||
"prompt": "pixar style, 3d cartoon rendering, vibrant colors, soft lighting, expressive",
|
||||
"color": "from-orange-400/20 to-yellow-400/20"
|
||||
},
|
||||
{
|
||||
"id": "anime",
|
||||
"name": "日漫",
|
||||
"type": "prompt",
|
||||
"desc": "典型日本动画风格,线条清晰,赛璐璐上色",
|
||||
"prompt": "japanese anime style, clear lines, cel shading, 2d animation",
|
||||
"color": "from-purple-500/20 to-pink-500/20"
|
||||
},
|
||||
{
|
||||
"id": "chinese-anime",
|
||||
"name": "国漫风",
|
||||
"type": "prompt",
|
||||
"desc": "中国现代动画风格,融合传统与现代元素,色彩华丽",
|
||||
"prompt": "chinese donghua style, chinese anime, elegant, detailed background, vibrant colors, mix of traditional and modern aesthetics",
|
||||
"color": "from-red-500/20 to-yellow-500/20"
|
||||
},
|
||||
{
|
||||
"id": "cel-shading",
|
||||
"name": "赛璐璐风",
|
||||
"type": "prompt",
|
||||
"desc": "经典的赛璐璐上色风格,阴影边缘硬朗,色彩鲜明",
|
||||
"prompt": "cel shading, hard shadows, flat colors, anime coloring style, clean lines",
|
||||
"color": "from-blue-400/20 to-indigo-400/20"
|
||||
},
|
||||
{
|
||||
"id": "korean-webtoon",
|
||||
"name": "韩漫风",
|
||||
"type": "prompt",
|
||||
"desc": "韩国条漫风格,人物美型,色彩明亮,光影细腻",
|
||||
"prompt": "manhwa style, korean webtoon, beautiful characters, detailed eyes, soft lighting, vibrant digital art",
|
||||
"color": "from-pink-400/20 to-rose-400/20"
|
||||
},
|
||||
{
|
||||
"id": "american-comic",
|
||||
"name": "美漫风",
|
||||
"type": "prompt",
|
||||
"desc": "美国漫画风格,线条粗犷,阴影浓重,动态感强",
|
||||
"prompt": "american comic book style, bold lines, heavy shadows, dynamic poses, halftone patterns, marvel/dc style",
|
||||
"color": "from-red-600/20 to-blue-600/20"
|
||||
},
|
||||
{
|
||||
"id": "ghibli",
|
||||
"name": "吉卜力风格",
|
||||
"type": "prompt",
|
||||
"desc": "宫崎骏动画风格,色彩清新自然,细节丰富,治愈系",
|
||||
"prompt": "studio ghibli style, miyazaki hayao style, anime scenery, vibrant colors, lush greenery, detailed clouds, soothing atmosphere",
|
||||
"color": "from-green-400/20 to-blue-400/20"
|
||||
},
|
||||
{
|
||||
"id": "realistic",
|
||||
"name": "写实",
|
||||
"type": "prompt",
|
||||
"desc": "电影级写实渲染,细节丰富,光照真实",
|
||||
"prompt": "cinematic realistic, highly detailed, photorealistic, 8k, movie quality",
|
||||
"color": "from-blue-900/20 to-slate-900/20"
|
||||
},
|
||||
{
|
||||
"id": "hand-drawn",
|
||||
"name": "手绘",
|
||||
"type": "prompt",
|
||||
"desc": "传统手绘质感,笔触明显,艺术感强",
|
||||
"prompt": "hand-drawn style, visible brush strokes, artistic, sketch",
|
||||
"color": "from-emerald-500/20 to-teal-500/20"
|
||||
},
|
||||
{
|
||||
"id": "watercolor",
|
||||
"name": "水彩画",
|
||||
"type": "prompt",
|
||||
"desc": "水彩晕染效果,色彩通透,艺术感强",
|
||||
"prompt": "watercolor painting, wet on wet, soft blending, artistic, translucent colors, paper texture",
|
||||
"color": "from-cyan-400/20 to-blue-300/20"
|
||||
},
|
||||
{
|
||||
"id": "oil-painting",
|
||||
"name": "油画",
|
||||
"type": "prompt",
|
||||
"desc": "厚涂油画质感,笔触丰富,光影层次感强",
|
||||
"prompt": "oil painting, impasto, textured canvas, classical art, rich colors, visible brushwork",
|
||||
"color": "from-amber-700/20 to-yellow-600/20"
|
||||
},
|
||||
{
|
||||
"id": "pixel-art",
|
||||
"name": "像素风",
|
||||
"type": "prompt",
|
||||
"desc": "复古8-bit/16-bit像素艺术,怀旧游戏风格",
|
||||
"prompt": "pixel art, 16-bit, retro game style, dot art, low resolution aesthetics",
|
||||
"color": "from-purple-600/20 to-indigo-600/20"
|
||||
},
|
||||
{
|
||||
"id": "ukiyo-e",
|
||||
"name": "浮世绘",
|
||||
"type": "prompt",
|
||||
"desc": "日本传统木刻版画风格,线条流畅,色彩古朴",
|
||||
"prompt": "ukiyo-e style, japanese woodblock print, traditional japanese art, flat colors, bold outlines",
|
||||
"color": "from-red-400/20 to-orange-300/20"
|
||||
},
|
||||
{
|
||||
"id": "vaporwave",
|
||||
"name": "蒸汽波",
|
||||
"type": "prompt",
|
||||
"desc": "80年代复古未来主义,霓虹色彩,故障艺术",
|
||||
"prompt": "vaporwave style, 80s retro aesthetics, neon pink and blue, glitch art, surrealism, statue, palm trees",
|
||||
"color": "from-pink-600/20 to-purple-600/20"
|
||||
},
|
||||
{
|
||||
"id": "low-poly",
|
||||
"name": "低多边形",
|
||||
"type": "prompt",
|
||||
"desc": "3D几何多边形风格,简约抽象,棱角分明",
|
||||
"prompt": "low poly style, 3d geometric, minimalist, angular, flat shading, isometric",
|
||||
"color": "from-blue-500/20 to-cyan-500/20"
|
||||
},
|
||||
{
|
||||
"id": "clay-style-lora",
|
||||
"name": "黏土风 (LoRA)",
|
||||
"type": "lora",
|
||||
"desc": "特殊的黏土材质风格",
|
||||
"lora": {
|
||||
"id": "lora-clay-v1",
|
||||
"base_model": "modelscope/qwen-image",
|
||||
"trigger_word": "clay style"
|
||||
},
|
||||
"color": "from-amber-600/20 to-orange-600/20"
|
||||
}
|
||||
]
|
||||
}
|
||||
8
backend/src/config/user_config.json
Normal file
8
backend/src/config/user_config.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"defaultImageModel": "dashscope/z-image",
|
||||
"defaultVideoModel": "dashscope/wan2.6-video",
|
||||
"defaultAudioModel": "minimax/speech-2.8-turbo",
|
||||
"defaultLLMModel": "dashscope/qwen-plus",
|
||||
"defaultStyle": "anime",
|
||||
"defaultAspectRatio": "16:9"
|
||||
}
|
||||
117
backend/src/constants/common.py
Normal file
117
backend/src/constants/common.py
Normal file
@@ -0,0 +1,117 @@
|
||||
from typing import List, Dict
|
||||
|
||||
# Character Options
|
||||
CHARACTER_ROLES: Dict[str, List[str]] = {
|
||||
"zh": ["主角", "配角", "反派", "龙套", "群演"],
|
||||
"en": ["Leading Role", "Supporting Role", "Villain", "Minor Role", "Extra"]
|
||||
}
|
||||
|
||||
CHARACTER_GENDERS: Dict[str, List[str]] = {
|
||||
"zh": ["男", "女", "未知"],
|
||||
"en": ["Male", "Female", "Unknown"]
|
||||
}
|
||||
|
||||
# Storyboard Options
|
||||
# Common shot types in filmmaking
|
||||
SHOT_TYPES: Dict[str, List[str]] = {
|
||||
"en": [
|
||||
"Extreme Long Shot (ELS)",
|
||||
"Long Shot (LS)",
|
||||
"Full Shot (FS)",
|
||||
"Medium Long Shot (MLS)",
|
||||
"Medium Shot (MS)",
|
||||
"Medium Close-Up (MCU)",
|
||||
"Close-Up (CU)",
|
||||
"Extreme Close-Up (ECU)",
|
||||
"Establishing Shot",
|
||||
"Point of View (POV)",
|
||||
"Over the Shoulder (OTS)"
|
||||
],
|
||||
"zh": [
|
||||
"大远景 (ELS)",
|
||||
"远景 (LS)",
|
||||
"全景 (FS)",
|
||||
"中远景 (MLS)",
|
||||
"中景 (MS)",
|
||||
"中特写 (MCU)",
|
||||
"特写 (CU)",
|
||||
"大特写 (ECU)",
|
||||
"建立镜头",
|
||||
"主观镜头 (POV)",
|
||||
"过肩镜头 (OTS)"
|
||||
]
|
||||
}
|
||||
|
||||
# Common camera movements
|
||||
CAMERA_MOVEMENTS: Dict[str, List[str]] = {
|
||||
"en": [
|
||||
"Static",
|
||||
"Pan Left",
|
||||
"Pan Right",
|
||||
"Tilt Up",
|
||||
"Tilt Down",
|
||||
"Zoom In",
|
||||
"Zoom Out",
|
||||
"Dolly In",
|
||||
"Dolly Out",
|
||||
"Truck Left",
|
||||
"Truck Right",
|
||||
"Pedestal Up",
|
||||
"Pedestal Down",
|
||||
"Tracking",
|
||||
"Arc",
|
||||
"Handheld",
|
||||
"Crane/Boom",
|
||||
"Drone/Aerial",
|
||||
"Rack Focus"
|
||||
],
|
||||
"zh": [
|
||||
"固定镜头 (Static)",
|
||||
"左摇 (Pan Left)",
|
||||
"右摇 (Pan Right)",
|
||||
"上仰 (Tilt Up)",
|
||||
"下俯 (Tilt Down)",
|
||||
"推镜头 (Zoom In)",
|
||||
"拉镜头 (Zoom Out)",
|
||||
"前移 (Dolly In)",
|
||||
"后移 (Dolly Out)",
|
||||
"左移 (Truck Left)",
|
||||
"右移 (Truck Right)",
|
||||
"升镜头 (Pedestal Up)",
|
||||
"降镜头 (Pedestal Down)",
|
||||
"跟随 (Tracking)",
|
||||
"环绕 (Arc)",
|
||||
"手持 (Handheld)",
|
||||
"摇臂 (Crane/Boom)",
|
||||
"航拍 (Drone/Aerial)",
|
||||
"变焦 (Rack Focus)"
|
||||
]
|
||||
}
|
||||
|
||||
# Common transitions
|
||||
TRANSITIONS: Dict[str, List[str]] = {
|
||||
"en": [
|
||||
"Cut",
|
||||
"Dissolve",
|
||||
"Fade In",
|
||||
"Fade Out",
|
||||
"Wipe",
|
||||
"Iris In",
|
||||
"Iris Out",
|
||||
"Match Cut",
|
||||
"Jump Cut",
|
||||
"Crossfade"
|
||||
],
|
||||
"zh": [
|
||||
"切 (Cut)",
|
||||
"叠化 (Dissolve)",
|
||||
"淡入 (Fade In)",
|
||||
"淡出 (Fade Out)",
|
||||
"划像 (Wipe)",
|
||||
"圈入 (Iris In)",
|
||||
"圈出 (Iris Out)",
|
||||
"匹配剪辑 (Match Cut)",
|
||||
"跳接 (Jump Cut)",
|
||||
"交叉淡入淡出 (Crossfade)"
|
||||
]
|
||||
}
|
||||
367
backend/src/main.py
Normal file
367
backend/src/main.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.middleware.gzip import GZipMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
# 导入 configuration
|
||||
from src.config.settings import (
|
||||
ALLOWED_ORIGINS,
|
||||
DEV_ALLOWED_ORIGINS,
|
||||
ALLOW_DEV_ORIGINS,
|
||||
UPLOAD_DIR,
|
||||
DATA_DIR,
|
||||
STORAGE_TYPE,
|
||||
NODE_ENV,
|
||||
REDIS_ENABLED,
|
||||
TRACING_ENABLED,
|
||||
OTLP_ENDPOINT
|
||||
)
|
||||
from src.config.database import init_db, engine
|
||||
# 导入所有实体模型,确保 SQLModel.metadata 包含所有表
|
||||
from src.models.entities import UserDB, UserApiKeyDB, ProjectDB, AssetDB, EpisodeDB, StoryboardDB, TaskDB, CanvasDB, CanvasMetadataDB
|
||||
from src.models.session import UserSessionDB
|
||||
from src.models.prompt_template import PromptTemplate, PromptTemplateFavorite
|
||||
from src.utils.service_loader import load_services_from_config
|
||||
from src.utils.logging import setup_logging
|
||||
from src.admin_config import setup_admin
|
||||
|
||||
# 初始化 logging system
|
||||
log_level = "DEBUG" if NODE_ENV == "development" else "INFO"
|
||||
setup_logging(level=log_level, use_json=True)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Starting Pixel API in {NODE_ENV} mode")
|
||||
|
||||
# 初始化 and Register Models
|
||||
# Load from configuration directory
|
||||
services_config_path = os.path.join(os.path.dirname(__file__), "config", "services")
|
||||
load_services_from_config(services_config_path)
|
||||
|
||||
from src.api import config, projects, generations, storage, chat, canvas, canvas_metadata, tasks, health, skills, user_api_keys, auth, admin, audit_logs, storage_admin, prompt_templates
|
||||
from src.middlewares.error_handler import setup_error_handler
|
||||
from src.middlewares.request_tracking import setup_request_tracking
|
||||
from src.middlewares.response_formatter import setup_response_formatter
|
||||
from src.middlewares.metrics import (
|
||||
setup_metrics_middleware,
|
||||
set_application_info,
|
||||
get_metrics,
|
||||
get_metrics_content_type
|
||||
)
|
||||
from src.middlewares.tracing import setup_tracing, shutdown_tracing
|
||||
from src.middlewares.rate_limiter import (
|
||||
setup_rate_limiter,
|
||||
init_rate_limiter,
|
||||
shutdown_rate_limiter
|
||||
)
|
||||
from src.middlewares.security import (
|
||||
setup_security_middleware,
|
||||
setup_security_headers_middleware,
|
||||
init_security_monitor,
|
||||
shutdown_security_monitor
|
||||
)
|
||||
from src.middlewares.performance import setup_performance_monitoring
|
||||
|
||||
# 验证 default model configuration
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
|
||||
try:
|
||||
# Load user config to check defaults
|
||||
user_config_path = os.path.join(os.path.dirname(__file__), "config", "user_config.json")
|
||||
if os.path.exists(user_config_path):
|
||||
with open(user_config_path, 'r', encoding='utf-8') as f:
|
||||
user_config = json.load(f)
|
||||
|
||||
# 验证 each default model
|
||||
model_mappings = {
|
||||
'defaultImageModel': ModelType.IMAGE,
|
||||
'defaultVideoModel': ModelType.VIDEO,
|
||||
'defaultAudioModel': ModelType.AUDIO,
|
||||
'defaultLyricsModel': ModelType.LYRICS,
|
||||
'defaultMusicModel': ModelType.MUSIC,
|
||||
'defaultLLMModel': ModelType.LLM
|
||||
}
|
||||
|
||||
for config_key, model_type in model_mappings.items():
|
||||
default_id = user_config.get(config_key)
|
||||
if default_id:
|
||||
model_config = ModelRegistry.get_config(default_id)
|
||||
if not model_config:
|
||||
logger.warning(
|
||||
f"⚠️ Configured {config_key} '{default_id}' not found in registry. "
|
||||
f"Please check your config/services/*.json files."
|
||||
)
|
||||
else:
|
||||
logger.info(f"✓ Default {model_type.value} model: {default_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate default models: {e}")
|
||||
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
def custom_generate_unique_id(route: APIRoute):
|
||||
return route.name
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Lifespan context manager for startup and shutdown events"""
|
||||
# 导入 Unified Task Manager
|
||||
from src.services.task_manager import task_manager
|
||||
|
||||
# 启动up
|
||||
logger.info("Application startup")
|
||||
|
||||
# 初始化 database
|
||||
init_db()
|
||||
|
||||
# 集合 application info for metrics
|
||||
set_application_info(version="0.1.0", environment=NODE_ENV)
|
||||
|
||||
# 初始化 cache service if Redis is enabled
|
||||
if REDIS_ENABLED:
|
||||
from src.services.cache_service import get_cache_service
|
||||
cache = get_cache_service()
|
||||
await cache.connect()
|
||||
|
||||
# 初始化 rate limiter
|
||||
await init_rate_limiter()
|
||||
|
||||
# 初始化 security monitor
|
||||
await init_security_monitor()
|
||||
|
||||
# 启动 Unified Task Manager
|
||||
await task_manager.start()
|
||||
logger.info("✓ Unified Task Manager started")
|
||||
|
||||
# 清理up stuck projects
|
||||
from src.services.project_service import project_manager
|
||||
project_manager.cleanup_stuck_projects()
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
logger.info("Application shutdown")
|
||||
|
||||
# 停止 Unified Task Manager
|
||||
await task_manager.stop()
|
||||
logger.info("✓ Unified Task Manager stopped")
|
||||
|
||||
# Shutdown tracing
|
||||
if TRACING_ENABLED:
|
||||
shutdown_tracing()
|
||||
|
||||
# Disconnect cache service and rate limiter
|
||||
if REDIS_ENABLED:
|
||||
from src.services.cache_service import get_cache_service
|
||||
cache = get_cache_service()
|
||||
await cache.disconnect()
|
||||
|
||||
# Shutdown rate limiter
|
||||
await shutdown_rate_limiter()
|
||||
|
||||
# Shutdown security monitor
|
||||
await shutdown_security_monitor()
|
||||
|
||||
# 初始化 FastAPI app
|
||||
app = FastAPI(
|
||||
title="Pixel API",
|
||||
description="""
|
||||
# Pixel - AI视频创作平台 API
|
||||
|
||||
Pixel是一个智能平台,利用AI从剧本创作漫画和视频。
|
||||
|
||||
## 核心功能
|
||||
|
||||
- **图片生成**: 支持多种AI模型生成高质量图片
|
||||
- **视频生成**: 从文本或图片生成视频内容
|
||||
- **剧本分析**: 智能分析剧本,提取角色、场景和道具
|
||||
- **项目管理**: 组织和管理创意项目
|
||||
- **Canvas编辑**: 可视化编辑分镜和资产
|
||||
- **任务管理**: 异步任务调度和状态跟踪
|
||||
|
||||
## 支持的AI提供商
|
||||
|
||||
- **图片**: DashScope (Flux, Wanx), ModelScope (Kolors)
|
||||
- **视频**: Kling, Hailuo (MiniMax), ModelScope (CogVideoX, Wanx)
|
||||
- **文本**: DashScope (Qwen), Google (Gemini), VolcEngine (Doubao)
|
||||
|
||||
## API版本
|
||||
|
||||
当前版本: v1
|
||||
所有API端点使用 `/api/v1` 前缀
|
||||
|
||||
## 认证
|
||||
|
||||
部分端点需要API密钥认证。请在请求头中包含:
|
||||
```
|
||||
Authorization: Bearer YOUR_API_KEY
|
||||
```
|
||||
|
||||
## 速率限制
|
||||
|
||||
- 默认: 100 请求/分钟 (per IP)
|
||||
- 认证用户: 1000 请求/分钟
|
||||
|
||||
## 错误处理
|
||||
|
||||
所有错误响应遵循统一格式:
|
||||
```json
|
||||
{
|
||||
"code": "错误代码",
|
||||
"message": "错误描述",
|
||||
"details": {},
|
||||
"request_id": "请求ID",
|
||||
"timestamp": "时间戳"
|
||||
}
|
||||
```
|
||||
|
||||
## 支持
|
||||
|
||||
- 文档: [ARCHITECTURE.md](https://github.com/your-repo/ARCHITECTURE.md)
|
||||
- 问题反馈: [GitHub Issues](https://github.com/your-repo/issues)
|
||||
""",
|
||||
version="0.1.0",
|
||||
generate_unique_id_function=custom_generate_unique_id,
|
||||
lifespan=lifespan,
|
||||
openapi_tags=[
|
||||
{
|
||||
"name": "health",
|
||||
"description": "健康检查和系统状态"
|
||||
},
|
||||
{
|
||||
"name": "generations",
|
||||
"description": "AI生成服务 - 图片、视频、音频生成"
|
||||
},
|
||||
{
|
||||
"name": "tasks",
|
||||
"description": "任务管理 - 查询任务状态、取消任务"
|
||||
},
|
||||
{
|
||||
"name": "projects",
|
||||
"description": "项目管理 - 创建、查询、更新、删除项目"
|
||||
},
|
||||
{
|
||||
"name": "canvas",
|
||||
"description": "Canvas操作 - 节点和边的管理"
|
||||
},
|
||||
{
|
||||
"name": "canvas_metadata",
|
||||
"description": "Canvas元数据 - 保存和加载Canvas状态"
|
||||
},
|
||||
{
|
||||
"name": "storage",
|
||||
"description": "文件存储 - 上传和管理文件"
|
||||
},
|
||||
{
|
||||
"name": "config",
|
||||
"description": "配置管理 - 模型配置、用户设置"
|
||||
},
|
||||
{
|
||||
"name": "chat",
|
||||
"description": "聊天服务 - AI对话和剧本分析"
|
||||
},
|
||||
],
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc",
|
||||
openapi_url="/openapi.json",
|
||||
contact={
|
||||
"name": "Pixel Team",
|
||||
"email": "support@pixel.ai"
|
||||
},
|
||||
license_info={
|
||||
"name": "MIT",
|
||||
"url": "https://opensource.org/licenses/MIT"
|
||||
}
|
||||
)
|
||||
|
||||
# --- Middleware Setup ---
|
||||
# 错误 handler middleware (must be first to catch all errors)
|
||||
setup_error_handler(app)
|
||||
|
||||
# 请求 tracking middleware (adds request ID and logging)
|
||||
setup_request_tracking(app)
|
||||
|
||||
# 响应 formatter middleware (adds metadata to responses)
|
||||
setup_response_formatter(app)
|
||||
|
||||
# 安全 headers middleware (add security headers to all responses)
|
||||
setup_security_headers_middleware(app)
|
||||
|
||||
# 安全 middleware (check for blocked IPs early)
|
||||
setup_security_middleware(app)
|
||||
|
||||
# 比率 limiter middleware (before metrics to track rate-limited requests)
|
||||
setup_rate_limiter(app)
|
||||
|
||||
# 性能 monitoring middleware
|
||||
setup_performance_monitoring(app, slow_request_threshold=1.0)
|
||||
|
||||
# 指标 middleware
|
||||
setup_metrics_middleware(app)
|
||||
|
||||
# Distributed tracing (must be setup before CORS)
|
||||
setup_tracing(
|
||||
app,
|
||||
service_name="pixel-api",
|
||||
service_version="0.1.0",
|
||||
otlp_endpoint=OTLP_ENDPOINT,
|
||||
enabled=TRACING_ENABLED
|
||||
)
|
||||
|
||||
# GZip compression middleware (compress responses > 1KB)
|
||||
app.add_middleware(GZipMiddleware, minimum_size=1000)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=DEV_ALLOWED_ORIGINS if ALLOW_DEV_ORIGINS else ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# --- Static Files Setup ---
|
||||
# Ensure upload directory exists
|
||||
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
||||
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
|
||||
|
||||
# Mount data directory for local storage access
|
||||
if STORAGE_TYPE == 'local':
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
app.mount("/files", StaticFiles(directory=DATA_DIR), name="files")
|
||||
|
||||
# --- Admin Interface Setup ---
|
||||
setup_admin(app, engine)
|
||||
|
||||
# --- Include Routers ---
|
||||
# Add /api/v1 prefix to all routers for API versioning
|
||||
app.include_router(config.router, prefix="/api/v1", tags=["config"])
|
||||
app.include_router(projects.router, prefix="/api/v1", tags=["projects"])
|
||||
app.include_router(generations.router, prefix="/api/v1", tags=["generations"])
|
||||
app.include_router(canvas.router, prefix="/api/v1", tags=["canvas"])
|
||||
app.include_router(canvas_metadata.router, prefix="/api/v1", tags=["canvas-metadata"])
|
||||
app.include_router(tasks.router, prefix="/api/v1", tags=["tasks"])
|
||||
app.include_router(skills.router, prefix="/api/v1", tags=["skills"])
|
||||
app.include_router(auth.router, prefix="/api/v1", tags=["auth"])
|
||||
app.include_router(user_api_keys.router, prefix="/api/v1", tags=["user-api-keys"])
|
||||
app.include_router(admin.router, prefix="/api/v1", tags=["admin"])
|
||||
app.include_router(prompt_templates.router, prefix="/api/v1", tags=["prompt-templates"])
|
||||
app.include_router(audit_logs.router, prefix="/api/v1", tags=["audit-logs"])
|
||||
app.include_router(storage_admin.router, prefix="/api/v1", tags=["admin-storage"])
|
||||
|
||||
# 存储 and chat routers have their own prefixes, so we add /api/v1 before them
|
||||
# This results in /api/v1/storage/* and /api/v1/chat/*
|
||||
app.include_router(storage.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
|
||||
# 健康检查 routes (no prefix, used for k8s probes and monitoring)
|
||||
app.include_router(health.router)
|
||||
|
||||
# Chat 同时挂载到根路径:兼容 OpenAI 客户端直连 /chat/completions(无 /api/v1 前缀)
|
||||
# 与 /api/v1/chat/* 共用同一 router,仅路径不同
|
||||
app.include_router(chat.router)
|
||||
|
||||
# --- API Endpoints ---
|
||||
# 注意: Health check and metrics endpoints are now in health.py controller
|
||||
297
backend/src/mappers/README.md
Normal file
297
backend/src/mappers/README.md
Normal file
@@ -0,0 +1,297 @@
|
||||
# Data Model Mappers
|
||||
|
||||
This module provides mappers for converting between database entities and API schemas, following the architecture optimization requirements.
|
||||
|
||||
## Overview
|
||||
|
||||
The mapper module implements the separation of concerns between:
|
||||
- **Database Entities** (`src/models/entities.py`): SQLModel classes representing database tables
|
||||
- **API Schemas** (`src/models/schemas.py`): Pydantic models for API request/response validation
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
┌─────────────────────┐
|
||||
│ API Layer │
|
||||
│ (Controllers) │
|
||||
└──────────┬──────────┘
|
||||
│ Uses Schemas
|
||||
▼
|
||||
┌─────────────────────┐
|
||||
│ Mappers │
|
||||
│ (Conversion Logic) │
|
||||
└──────────┬──────────┘
|
||||
│ Uses Entities
|
||||
▼
|
||||
┌─────────────────────┐
|
||||
│ Repository Layer │
|
||||
│ (Data Access) │
|
||||
└─────────────────────┘
|
||||
```
|
||||
|
||||
## Available Mappers
|
||||
|
||||
### ProjectMapper
|
||||
|
||||
Converts between `ProjectDB` entity and `ProjectData` schema.
|
||||
|
||||
**Methods:**
|
||||
- `to_schema(db, include_relations=False)`: Convert entity to schema
|
||||
- `to_entity(schema, project_id=None, user_id=None)`: Convert create request to entity
|
||||
- `update_entity(db, schema)`: Update entity with data from update request
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
from src.mappers import ProjectMapper
|
||||
from src.models.schemas import CreateProjectRequest
|
||||
|
||||
# Create entity from request
|
||||
request = CreateProjectRequest(name="My Project", type="video")
|
||||
project_db = ProjectMapper.to_entity(request)
|
||||
|
||||
# Convert entity to response schema
|
||||
project_data = ProjectMapper.to_schema(project_db, include_relations=True)
|
||||
```
|
||||
|
||||
### AssetMapper
|
||||
|
||||
Converts between `AssetDB` entity and `Asset` schemas (CharacterAsset, SceneAsset, PropAsset, OtherAsset).
|
||||
|
||||
**Methods:**
|
||||
- `to_schema(db)`: Convert entity to appropriate asset schema based on type
|
||||
- `to_entity(schema, project_id, asset_id=None)`: Convert create request to entity
|
||||
- `update_entity(db, schema)`: Update entity with data from update request
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
from src.mappers import AssetMapper
|
||||
from src.models.schemas import CreateCharacterAssetRequest
|
||||
|
||||
# Create entity from request
|
||||
request = CreateCharacterAssetRequest(
|
||||
type="character",
|
||||
name="Hero",
|
||||
desc="Main character",
|
||||
age="25"
|
||||
)
|
||||
asset_db = AssetMapper.to_entity(request, project_id="proj_123")
|
||||
|
||||
# Convert entity to response schema
|
||||
asset = AssetMapper.to_schema(asset_db)
|
||||
```
|
||||
|
||||
### EpisodeMapper
|
||||
|
||||
Converts between `EpisodeDB` entity and `Episode` schema.
|
||||
|
||||
**Methods:**
|
||||
- `to_schema(db)`: Convert entity to schema
|
||||
- `to_entity(schema, project_id, episode_id=None)`: Convert create request to entity
|
||||
- `update_entity(db, schema)`: Update entity with data from update request
|
||||
|
||||
### StoryboardMapper
|
||||
|
||||
Converts between `StoryboardDB` entity and `Storyboard` schema.
|
||||
|
||||
**Methods:**
|
||||
- `to_schema(db)`: Convert entity to schema
|
||||
- `to_entity(schema, project_id, storyboard_id=None)`: Convert create request to entity
|
||||
- `update_entity(db, schema)`: Update entity with data from update request
|
||||
|
||||
### TaskMapper
|
||||
|
||||
Converts between `TaskDB` entity and `Task` schema.
|
||||
|
||||
**Methods:**
|
||||
- `to_schema(db)`: Convert entity to schema
|
||||
- `to_entity(task_type, model, params, ...)`: Create entity from task parameters
|
||||
- `update_status(db, status, result=None, error=None, provider_task_id=None)`: Update task status
|
||||
- `increment_retry(db)`: Increment retry count
|
||||
|
||||
**Example:**
|
||||
```python
|
||||
from src.mappers import TaskMapper
|
||||
|
||||
# Create task entity
|
||||
task_db = TaskMapper.to_entity(
|
||||
task_type="image",
|
||||
model="flux-dev",
|
||||
params={"prompt": "test"},
|
||||
status="pending"
|
||||
)
|
||||
|
||||
# Update task status
|
||||
TaskMapper.update_status(task_db, status="processing", provider_task_id="123")
|
||||
|
||||
# Convert to schema for API response
|
||||
task = TaskMapper.to_schema(task_db)
|
||||
```
|
||||
|
||||
### CanvasMapper
|
||||
|
||||
Converts between `CanvasDB` entity and `CanvasState` schema.
|
||||
|
||||
**Methods:**
|
||||
- `to_schema(db)`: Convert entity to schema
|
||||
- `to_entity(schema)`: Convert schema to entity
|
||||
- `update_entity(db, schema)`: Update entity with data from schema
|
||||
|
||||
### CanvasMetadataMapper
|
||||
|
||||
Converts between `CanvasMetadataDB` entity and `CanvasMetadata` schema.
|
||||
|
||||
**Methods:**
|
||||
- `to_schema(db)`: Convert entity to schema
|
||||
- `to_entity(schema, project_id, canvas_id=None)`: Convert create request to entity
|
||||
- `create_asset_canvas(project_id, asset_id, asset_name, canvas_id=None)`: Create asset canvas metadata
|
||||
- `create_storyboard_canvas(project_id, storyboard_id, storyboard_shot, canvas_id=None)`: Create storyboard canvas metadata
|
||||
- `update_entity(db, schema)`: Update entity with data from update request
|
||||
- `update_access(db)`: Update canvas access tracking
|
||||
|
||||
## Design Principles
|
||||
|
||||
### 1. Single Responsibility
|
||||
|
||||
Each mapper is responsible for converting between one entity type and its corresponding schema(s).
|
||||
|
||||
### 2. Explicit Conversion
|
||||
|
||||
All conversions are explicit through mapper methods. No implicit conversions or magic methods.
|
||||
|
||||
### 3. Type Safety
|
||||
|
||||
Mappers preserve type information and use Pydantic validation for schemas.
|
||||
|
||||
### 4. Separation of Concerns
|
||||
|
||||
- **Entities**: Database structure and relationships
|
||||
- **Schemas**: API contracts and validation
|
||||
- **Mappers**: Conversion logic only
|
||||
|
||||
### 5. Testability
|
||||
|
||||
Mappers are pure functions (no side effects) and easily testable in isolation.
|
||||
|
||||
## Usage Guidelines
|
||||
|
||||
### In Controllers (API Layer)
|
||||
|
||||
Controllers should:
|
||||
1. Receive request schemas
|
||||
2. Use mappers to convert to entities
|
||||
3. Pass entities to service/repository layer
|
||||
4. Use mappers to convert entities back to response schemas
|
||||
|
||||
```python
|
||||
@router.post("/projects")
|
||||
async def create_project(request: CreateProjectRequest):
|
||||
# Convert request to entity
|
||||
project_db = ProjectMapper.to_entity(request)
|
||||
|
||||
# Save to database (via repository)
|
||||
saved_project = repository.create(project_db)
|
||||
|
||||
# Convert entity to response
|
||||
project_data = ProjectMapper.to_schema(saved_project)
|
||||
|
||||
return BaseResponse(data=project_data)
|
||||
```
|
||||
|
||||
### In Repositories (Data Access Layer)
|
||||
|
||||
Repositories should:
|
||||
1. Work with entities internally
|
||||
2. Use mappers when returning data to service layer
|
||||
|
||||
```python
|
||||
def get_project(self, project_id: str) -> Optional[ProjectData]:
|
||||
project_db = session.get(ProjectDB, project_id)
|
||||
if not project_db:
|
||||
return None
|
||||
|
||||
# Use mapper to convert to schema
|
||||
return ProjectMapper.to_schema(project_db, include_relations=True)
|
||||
```
|
||||
|
||||
### In Services (Business Logic Layer)
|
||||
|
||||
Services should:
|
||||
1. Work with schemas (domain objects)
|
||||
2. Use mappers when interacting with repositories
|
||||
|
||||
## Benefits
|
||||
|
||||
### 1. Maintainability
|
||||
|
||||
- Single source of truth for conversion logic
|
||||
- Easy to update when models change
|
||||
- Clear separation of concerns
|
||||
|
||||
### 2. Consistency
|
||||
|
||||
- All conversions follow the same pattern
|
||||
- Reduces code duplication
|
||||
- Ensures consistent field mapping
|
||||
|
||||
### 3. Testability
|
||||
|
||||
- Mappers can be tested independently
|
||||
- No database required for mapper tests
|
||||
- Easy to verify conversion correctness
|
||||
|
||||
### 4. Type Safety
|
||||
|
||||
- Full TypeScript/Python type checking
|
||||
- Pydantic validation for schemas
|
||||
- SQLModel validation for entities
|
||||
|
||||
## Testing
|
||||
|
||||
Run mapper tests:
|
||||
|
||||
```bash
|
||||
python backend/test_mappers.py
|
||||
```
|
||||
|
||||
All mappers have comprehensive unit tests covering:
|
||||
- Entity to schema conversion
|
||||
- Schema to entity conversion
|
||||
- Entity updates
|
||||
- Edge cases and error handling
|
||||
|
||||
## Migration Guide
|
||||
|
||||
When updating existing code to use mappers:
|
||||
|
||||
1. **Identify manual conversions**: Look for code that manually creates schemas from entities
|
||||
2. **Replace with mapper calls**: Use appropriate mapper method
|
||||
3. **Update imports**: Import mapper instead of creating schemas directly
|
||||
4. **Test thoroughly**: Ensure all fields are correctly mapped
|
||||
|
||||
### Before:
|
||||
```python
|
||||
project_data = ProjectData(
|
||||
id=project_db.id,
|
||||
name=project_db.name,
|
||||
description=project_db.description,
|
||||
# ... many more fields
|
||||
)
|
||||
```
|
||||
|
||||
### After:
|
||||
```python
|
||||
from src.mappers import ProjectMapper
|
||||
|
||||
project_data = ProjectMapper.to_schema(project_db)
|
||||
```
|
||||
|
||||
## Requirements Satisfied
|
||||
|
||||
This mapper module satisfies the following requirements from the architecture optimization spec:
|
||||
|
||||
- **Requirement 5.1**: Separate database entities from API schemas ✓
|
||||
- **Requirement 5.2**: Define database entities in dedicated entities module ✓
|
||||
- **Requirement 5.3**: Define API schemas in dedicated schemas module ✓
|
||||
- **Requirement 5.4**: Use explicit mapper functions for conversion ✓
|
||||
- **Requirement 5.5**: Each field defined in exactly one authoritative location ✓
|
||||
- **Requirement 5.6**: Use inheritance/composition to share common fields ✓
|
||||
18
backend/src/mappers/__init__.py
Normal file
18
backend/src/mappers/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# Mappers for converting between database entities and API schemas
|
||||
from .project_mapper import ProjectMapper
|
||||
from .asset_mapper import AssetMapper
|
||||
from .episode_mapper import EpisodeMapper
|
||||
from .storyboard_mapper import StoryboardMapper
|
||||
from .task_mapper import TaskMapper
|
||||
from .canvas_mapper import CanvasMapper
|
||||
from .canvas_metadata_mapper import CanvasMetadataMapper
|
||||
|
||||
__all__ = [
|
||||
'ProjectMapper',
|
||||
'AssetMapper',
|
||||
'EpisodeMapper',
|
||||
'StoryboardMapper',
|
||||
'TaskMapper',
|
||||
'CanvasMapper',
|
||||
'CanvasMetadataMapper',
|
||||
]
|
||||
204
backend/src/mappers/asset_mapper.py
Normal file
204
backend/src/mappers/asset_mapper.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Mapper for converting between AssetDB entity and Asset schemas"""
|
||||
from typing import Union, Dict, Any
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from src.models.entities import AssetDB
|
||||
from src.models.schemas import (
|
||||
Asset,
|
||||
CharacterAsset,
|
||||
SceneAsset,
|
||||
PropAsset,
|
||||
OtherAsset,
|
||||
CreateAssetRequest,
|
||||
CreateCharacterAssetRequest,
|
||||
CreateSceneAssetRequest,
|
||||
CreatePropAssetRequest,
|
||||
CreateOtherAssetRequest,
|
||||
UpdateAssetRequest,
|
||||
GenerationRecord,
|
||||
)
|
||||
|
||||
|
||||
class AssetMapper:
|
||||
"""Mapper for Asset entity and schemas"""
|
||||
|
||||
@staticmethod
|
||||
def to_schema(db: AssetDB) -> Asset:
|
||||
"""Convert AssetDB entity to Asset schema
|
||||
|
||||
Args:
|
||||
db: AssetDB entity from database
|
||||
|
||||
Returns:
|
||||
Asset schema (CharacterAsset, SceneAsset, PropAsset, or OtherAsset)
|
||||
"""
|
||||
# Parse generations
|
||||
generations = []
|
||||
if db.generations:
|
||||
try:
|
||||
from pydantic import TypeAdapter
|
||||
adapter = TypeAdapter(list[GenerationRecord])
|
||||
generations = adapter.validate_python(db.generations)
|
||||
except Exception:
|
||||
generations = []
|
||||
|
||||
# Base fields common to all asset types
|
||||
base_data = {
|
||||
'id': db.id,
|
||||
'type': db.type,
|
||||
'name': db.name,
|
||||
'desc': db.desc,
|
||||
'tags': db.tags or [],
|
||||
'image_url': db.image_url,
|
||||
'image_urls': db.image_urls,
|
||||
'video_urls': db.video_urls,
|
||||
'image_prompt': db.image_prompt,
|
||||
'generations': generations,
|
||||
}
|
||||
|
||||
# Add type-specific fields from extra_data
|
||||
extra_data = db.extra_data or {}
|
||||
|
||||
if db.type == 'character':
|
||||
return CharacterAsset(
|
||||
**base_data,
|
||||
age=extra_data.get('age'),
|
||||
gender=extra_data.get('gender'),
|
||||
role=extra_data.get('role'),
|
||||
emotion=extra_data.get('emotion'),
|
||||
appearance=extra_data.get('appearance'),
|
||||
)
|
||||
elif db.type == 'scene':
|
||||
return SceneAsset(
|
||||
**base_data,
|
||||
location=extra_data.get('location'),
|
||||
time_of_day=extra_data.get('time_of_day'),
|
||||
environment_type=extra_data.get('environment_type'),
|
||||
weather=extra_data.get('weather'),
|
||||
atmosphere=extra_data.get('atmosphere'),
|
||||
)
|
||||
elif db.type == 'prop':
|
||||
return PropAsset(
|
||||
**base_data,
|
||||
usage=extra_data.get('usage'),
|
||||
)
|
||||
else:
|
||||
return OtherAsset(**base_data)
|
||||
|
||||
@staticmethod
|
||||
def to_entity(
|
||||
schema: CreateAssetRequest,
|
||||
project_id: str,
|
||||
asset_id: str = None
|
||||
) -> AssetDB:
|
||||
"""Convert CreateAssetRequest schema to AssetDB entity
|
||||
|
||||
Args:
|
||||
schema: CreateAssetRequest from API
|
||||
project_id: Project ID this asset belongs to
|
||||
asset_id: Optional asset ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
AssetDB entity for database storage
|
||||
"""
|
||||
# 提取 type-specific fields into extra_data
|
||||
extra_data = {}
|
||||
|
||||
if isinstance(schema, CreateCharacterAssetRequest):
|
||||
extra_data = {
|
||||
'age': schema.age,
|
||||
'gender': schema.gender,
|
||||
'role': schema.role,
|
||||
'appearance': schema.appearance,
|
||||
}
|
||||
elif isinstance(schema, CreateSceneAssetRequest):
|
||||
extra_data = {
|
||||
'location': schema.location,
|
||||
'time_of_day': schema.time_of_day,
|
||||
'atmosphere': schema.atmosphere,
|
||||
}
|
||||
elif isinstance(schema, CreatePropAssetRequest):
|
||||
extra_data = {
|
||||
'usage': schema.usage,
|
||||
}
|
||||
|
||||
return AssetDB(
|
||||
id=asset_id or str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
type=schema.type,
|
||||
name=schema.name,
|
||||
desc=schema.desc,
|
||||
tags=schema.tags or [],
|
||||
image_url=schema.image_url,
|
||||
image_urls=schema.image_urls,
|
||||
video_urls=schema.video_urls,
|
||||
image_prompt=schema.image_prompt,
|
||||
extra_data=extra_data,
|
||||
generations=[],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_entity(db: AssetDB, schema: UpdateAssetRequest) -> AssetDB:
|
||||
"""Update AssetDB entity with data from UpdateAssetRequest schema
|
||||
|
||||
Args:
|
||||
db: Existing AssetDB entity
|
||||
schema: UpdateAssetRequest with new data
|
||||
|
||||
Returns:
|
||||
Updated AssetDB entity
|
||||
"""
|
||||
# Update base fields
|
||||
if schema.name is not None:
|
||||
db.name = schema.name
|
||||
|
||||
if schema.desc is not None:
|
||||
db.desc = schema.desc
|
||||
|
||||
if schema.tags is not None:
|
||||
db.tags = schema.tags
|
||||
|
||||
if schema.image_url is not None:
|
||||
db.image_url = schema.image_url
|
||||
|
||||
if schema.image_urls is not None:
|
||||
db.image_urls = schema.image_urls
|
||||
|
||||
if schema.video_urls is not None:
|
||||
db.video_urls = schema.video_urls
|
||||
|
||||
if schema.image_prompt is not None:
|
||||
db.image_prompt = schema.image_prompt
|
||||
|
||||
if schema.generations is not None:
|
||||
# Convert GenerationRecord objects to dicts
|
||||
db.generations = [gen.model_dump() for gen in schema.generations]
|
||||
|
||||
# Update type-specific fields in extra_data
|
||||
extra_data = db.extra_data or {}
|
||||
|
||||
if schema.age is not None:
|
||||
extra_data['age'] = schema.age
|
||||
|
||||
if schema.role is not None:
|
||||
extra_data['role'] = schema.role
|
||||
|
||||
if schema.appearance is not None:
|
||||
extra_data['appearance'] = schema.appearance
|
||||
|
||||
if schema.location is not None:
|
||||
extra_data['location'] = schema.location
|
||||
|
||||
if schema.time_of_day is not None:
|
||||
extra_data['time_of_day'] = schema.time_of_day
|
||||
|
||||
if schema.atmosphere is not None:
|
||||
extra_data['atmosphere'] = schema.atmosphere
|
||||
|
||||
if schema.usage is not None:
|
||||
extra_data['usage'] = schema.usage
|
||||
|
||||
db.extra_data = extra_data
|
||||
|
||||
return db
|
||||
72
backend/src/mappers/canvas_mapper.py
Normal file
72
backend/src/mappers/canvas_mapper.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Mapper for converting between CanvasDB entity and CanvasState schemas"""
|
||||
from datetime import datetime
|
||||
|
||||
from src.models.entities import CanvasDB
|
||||
from src.models.schemas import CanvasState
|
||||
|
||||
|
||||
class CanvasMapper:
|
||||
"""Mapper for Canvas entity and schemas"""
|
||||
|
||||
@staticmethod
|
||||
def to_schema(db: CanvasDB) -> CanvasState:
|
||||
"""Convert CanvasDB entity to CanvasState schema
|
||||
|
||||
Args:
|
||||
db: CanvasDB entity from database
|
||||
|
||||
Returns:
|
||||
CanvasState schema for API response
|
||||
"""
|
||||
return CanvasState(
|
||||
id=db.id,
|
||||
projectId=db.project_id,
|
||||
nodes=db.nodes or [],
|
||||
connections=db.connections or [],
|
||||
groups=db.groups or [],
|
||||
history=db.history or [],
|
||||
historyIndex=db.history_index,
|
||||
updatedAt=db.updated_at,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_entity(schema: CanvasState) -> CanvasDB:
|
||||
"""Convert CanvasState schema to CanvasDB entity
|
||||
|
||||
Args:
|
||||
schema: CanvasState from API
|
||||
|
||||
Returns:
|
||||
CanvasDB entity for database storage
|
||||
"""
|
||||
return CanvasDB(
|
||||
id=schema.id,
|
||||
project_id=schema.projectId,
|
||||
nodes=schema.nodes,
|
||||
connections=schema.connections,
|
||||
groups=schema.groups,
|
||||
history=schema.history,
|
||||
history_index=schema.history_index,
|
||||
updated_at=schema.updated_at,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_entity(db: CanvasDB, schema: CanvasState) -> CanvasDB:
|
||||
"""Update CanvasDB entity with data from CanvasState schema
|
||||
|
||||
Args:
|
||||
db: Existing CanvasDB entity
|
||||
schema: CanvasState with new data
|
||||
|
||||
Returns:
|
||||
Updated CanvasDB entity
|
||||
"""
|
||||
db.project_id = schema.projectId
|
||||
db.nodes = schema.nodes
|
||||
db.connections = schema.connections
|
||||
db.groups = schema.groups
|
||||
db.history = schema.history
|
||||
db.history_index = schema.history_index
|
||||
db.updated_at = datetime.now().timestamp()
|
||||
|
||||
return db
|
||||
195
backend/src/mappers/canvas_metadata_mapper.py
Normal file
195
backend/src/mappers/canvas_metadata_mapper.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Mapper for converting between CanvasMetadataDB entity and CanvasMetadata schemas"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from src.models.entities import CanvasMetadataDB
|
||||
from src.models.schemas import (
|
||||
CanvasMetadata,
|
||||
CreateGeneralCanvasRequest,
|
||||
UpdateCanvasMetadataRequest,
|
||||
)
|
||||
|
||||
|
||||
class CanvasMetadataMapper:
|
||||
"""Mapper for CanvasMetadata entity and schemas"""
|
||||
|
||||
@staticmethod
|
||||
def to_schema(db: CanvasMetadataDB) -> CanvasMetadata:
|
||||
"""Convert CanvasMetadataDB entity to CanvasMetadata schema
|
||||
|
||||
Args:
|
||||
db: CanvasMetadataDB entity from database
|
||||
|
||||
Returns:
|
||||
CanvasMetadata schema for API response
|
||||
"""
|
||||
return CanvasMetadata(
|
||||
id=db.id,
|
||||
projectId=db.project_id,
|
||||
canvasType=db.canvas_type,
|
||||
relatedEntityType=db.related_entity_type,
|
||||
relatedEntityId=db.related_entity_id,
|
||||
name=db.name,
|
||||
description=db.description,
|
||||
orderIndex=db.order_index,
|
||||
isPinned=db.is_pinned,
|
||||
tags=db.tags or [],
|
||||
nodeCount=db.node_count,
|
||||
lastAccessedAt=db.last_accessed_at,
|
||||
accessCount=db.access_count,
|
||||
createdAt=db.created_at,
|
||||
updatedAt=db.updated_at,
|
||||
deletedAt=db.deleted_at,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_entity(
|
||||
schema: CreateGeneralCanvasRequest,
|
||||
project_id: str,
|
||||
canvas_id: Optional[str] = None
|
||||
) -> CanvasMetadataDB:
|
||||
"""Convert CreateGeneralCanvasRequest schema to CanvasMetadataDB entity
|
||||
|
||||
Args:
|
||||
schema: CreateGeneralCanvasRequest from API
|
||||
project_id: Project ID this canvas belongs to
|
||||
canvas_id: Optional canvas ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
CanvasMetadataDB entity for database storage
|
||||
"""
|
||||
now = datetime.now().timestamp()
|
||||
|
||||
return CanvasMetadataDB(
|
||||
id=canvas_id or str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
canvas_type="general",
|
||||
name=schema.name,
|
||||
description=schema.description,
|
||||
order_index=0,
|
||||
is_pinned=False,
|
||||
tags=[],
|
||||
node_count=0,
|
||||
access_count=0,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_asset_canvas(
|
||||
project_id: str,
|
||||
asset_id: str,
|
||||
asset_name: str,
|
||||
canvas_id: Optional[str] = None
|
||||
) -> CanvasMetadataDB:
|
||||
""" Create CanvasMetadataDB entity for an asset canvas
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
asset_id: Asset ID this canvas is linked to
|
||||
asset_name: Asset name for canvas name
|
||||
canvas_id: Optional canvas ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
CanvasMetadataDB entity for database storage
|
||||
"""
|
||||
now = datetime.now().timestamp()
|
||||
|
||||
return CanvasMetadataDB(
|
||||
id=canvas_id or str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
canvas_type="asset",
|
||||
related_entity_type="asset",
|
||||
related_entity_id=asset_id,
|
||||
name=f"{asset_name} Canvas",
|
||||
order_index=0,
|
||||
is_pinned=False,
|
||||
tags=[],
|
||||
node_count=0,
|
||||
access_count=0,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_storyboard_canvas(
|
||||
project_id: str,
|
||||
storyboard_id: str,
|
||||
storyboard_shot: str,
|
||||
canvas_id: Optional[str] = None
|
||||
) -> CanvasMetadataDB:
|
||||
""" Create CanvasMetadataDB entity for a storyboard canvas
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
storyboard_id: Storyboard ID this canvas is linked to
|
||||
storyboard_shot: Storyboard shot for canvas name
|
||||
canvas_id: Optional canvas ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
CanvasMetadataDB entity for database storage
|
||||
"""
|
||||
now = datetime.now().timestamp()
|
||||
|
||||
return CanvasMetadataDB(
|
||||
id=canvas_id or str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
canvas_type="storyboard",
|
||||
related_entity_type="storyboard",
|
||||
related_entity_id=storyboard_id,
|
||||
name=f"{storyboard_shot} Canvas",
|
||||
order_index=0,
|
||||
is_pinned=False,
|
||||
tags=[],
|
||||
node_count=0,
|
||||
access_count=0,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_entity(
|
||||
db: CanvasMetadataDB,
|
||||
schema: UpdateCanvasMetadataRequest
|
||||
) -> CanvasMetadataDB:
|
||||
"""Update CanvasMetadataDB entity with data from UpdateCanvasMetadataRequest schema
|
||||
|
||||
Args:
|
||||
db: Existing CanvasMetadataDB entity
|
||||
schema: UpdateCanvasMetadataRequest with new data
|
||||
|
||||
Returns:
|
||||
Updated CanvasMetadataDB entity
|
||||
"""
|
||||
if schema.name is not None:
|
||||
db.name = schema.name
|
||||
|
||||
if schema.description is not None:
|
||||
db.description = schema.description
|
||||
|
||||
if schema.is_pinned is not None:
|
||||
db.is_pinned = schema.is_pinned
|
||||
|
||||
if schema.tags is not None:
|
||||
db.tags = schema.tags
|
||||
|
||||
db.updated_at = datetime.now().timestamp()
|
||||
|
||||
return db
|
||||
|
||||
@staticmethod
|
||||
def update_access(db: CanvasMetadataDB) -> CanvasMetadataDB:
|
||||
"""Update canvas access tracking
|
||||
|
||||
Args:
|
||||
db: Existing CanvasMetadataDB entity
|
||||
|
||||
Returns:
|
||||
Updated CanvasMetadataDB entity
|
||||
"""
|
||||
db.last_accessed_at = datetime.now().timestamp()
|
||||
db.access_count += 1
|
||||
db.updated_at = datetime.now().timestamp()
|
||||
|
||||
return db
|
||||
83
backend/src/mappers/episode_mapper.py
Normal file
83
backend/src/mappers/episode_mapper.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Mapper for converting between EpisodeDB entity and Episode schemas"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from src.models.entities import EpisodeDB
|
||||
from src.models.schemas import (
|
||||
Episode,
|
||||
CreateEpisodeRequest,
|
||||
UpdateEpisodeRequest,
|
||||
)
|
||||
|
||||
|
||||
class EpisodeMapper:
|
||||
"""Mapper for Episode entity and schemas"""
|
||||
|
||||
@staticmethod
|
||||
def to_schema(db: EpisodeDB) -> Episode:
|
||||
"""Convert EpisodeDB entity to Episode schema
|
||||
|
||||
Args:
|
||||
db: EpisodeDB entity from database
|
||||
|
||||
Returns:
|
||||
Episode schema for API response
|
||||
"""
|
||||
return Episode(
|
||||
id=db.id,
|
||||
title=db.title,
|
||||
order=db.order_index,
|
||||
desc=db.desc,
|
||||
content=db.content,
|
||||
status=db.status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def to_entity(
|
||||
schema: CreateEpisodeRequest,
|
||||
project_id: str,
|
||||
episode_id: str = None
|
||||
) -> EpisodeDB:
|
||||
"""Convert CreateEpisodeRequest schema to EpisodeDB entity
|
||||
|
||||
Args:
|
||||
schema: CreateEpisodeRequest from API
|
||||
project_id: Project ID this episode belongs to
|
||||
episode_id: Optional episode ID (generated if not provided)
|
||||
|
||||
Returns:
|
||||
EpisodeDB entity for database storage
|
||||
"""
|
||||
return EpisodeDB(
|
||||
id=episode_id or str(uuid.uuid4()),
|
||||
project_id=project_id,
|
||||
order_index=schema.order,
|
||||
title=schema.title,
|
||||
desc=schema.desc,
|
||||
status=schema.status,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_entity(db: EpisodeDB, schema: UpdateEpisodeRequest) -> EpisodeDB:
|
||||
"""Update EpisodeDB entity with data from UpdateEpisodeRequest schema
|
||||
|
||||
Args:
|
||||
db: Existing EpisodeDB entity
|
||||
schema: UpdateEpisodeRequest with new data
|
||||
|
||||
Returns:
|
||||
Updated EpisodeDB entity
|
||||
"""
|
||||
if schema.title is not None:
|
||||
db.title = schema.title
|
||||
|
||||
if schema.order is not None:
|
||||
db.order_index = schema.order
|
||||
|
||||
if schema.desc is not None:
|
||||
db.desc = schema.desc
|
||||
|
||||
if schema.status is not None:
|
||||
db.status = schema.status
|
||||
|
||||
return db
|
||||
150
backend/src/mappers/project_mapper.py
Normal file
150
backend/src/mappers/project_mapper.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Mapper for converting between ProjectDB entity and Project schemas"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from src.models.entities import ProjectDB
|
||||
from src.models.schemas import (
|
||||
ProjectData,
|
||||
CreateProjectRequest,
|
||||
UpdateProjectRequest,
|
||||
InitializationProgress,
|
||||
)
|
||||
|
||||
|
||||
class ProjectMapper:
|
||||
"""Mapper for Project entity and schemas"""
|
||||
|
||||
@staticmethod
|
||||
def to_schema(db: ProjectDB, include_relations: bool = False) -> ProjectData:
|
||||
"""Convert ProjectDB entity to ProjectData schema
|
||||
|
||||
Args:
|
||||
db: ProjectDB entity from database
|
||||
include_relations: Whether to include related entities (assets, episodes, storyboards)
|
||||
|
||||
Returns:
|
||||
ProjectData schema for API response
|
||||
"""
|
||||
# Parse progress from database
|
||||
progress = None
|
||||
if db.progress:
|
||||
try:
|
||||
progress = InitializationProgress(**db.progress)
|
||||
except Exception:
|
||||
progress = None
|
||||
|
||||
# Parse error from database
|
||||
error = db.error if db.error else None
|
||||
|
||||
# Convert timestamps to datetime
|
||||
created_at = datetime.fromtimestamp(db.created_at)
|
||||
updated_at = datetime.fromtimestamp(db.updated_at)
|
||||
|
||||
# Build base project data
|
||||
project_data = ProjectData(
|
||||
id=db.id,
|
||||
name=db.name,
|
||||
description=db.description,
|
||||
type=db.type,
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
status=db.status,
|
||||
resolution=db.resolution,
|
||||
ratio=db.ratio,
|
||||
style_id=db.style_id,
|
||||
style_params=db.style_params,
|
||||
chapters=db.chapters,
|
||||
progress=progress,
|
||||
error=error,
|
||||
assets=[],
|
||||
episodes=[],
|
||||
storyboards=[],
|
||||
general_canvases=[],
|
||||
user_id=db.user_id,
|
||||
)
|
||||
|
||||
# Include relations if requested
|
||||
if include_relations:
|
||||
from .asset_mapper import AssetMapper
|
||||
from .episode_mapper import EpisodeMapper
|
||||
from .storyboard_mapper import StoryboardMapper
|
||||
|
||||
if db.assets:
|
||||
project_data.assets = [AssetMapper.to_schema(asset) for asset in db.assets]
|
||||
|
||||
if db.episodes:
|
||||
project_data.episodes = [EpisodeMapper.to_schema(episode) for episode in db.episodes]
|
||||
|
||||
if db.storyboards:
|
||||
project_data.storyboards = [StoryboardMapper.to_schema(storyboard) for storyboard in db.storyboards]
|
||||
|
||||
return project_data
|
||||
|
||||
@staticmethod
|
||||
def to_entity(
|
||||
schema: CreateProjectRequest,
|
||||
project_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ProjectDB:
|
||||
"""Convert CreateProjectRequest schema to ProjectDB entity
|
||||
|
||||
Args:
|
||||
schema: CreateProjectRequest from API
|
||||
project_id: Optional project ID (generated if not provided)
|
||||
user_id: Optional user ID for tracking
|
||||
|
||||
Returns:
|
||||
ProjectDB entity for database storage
|
||||
"""
|
||||
now = datetime.now().timestamp()
|
||||
|
||||
return ProjectDB(
|
||||
id=project_id or str(uuid.uuid4()),
|
||||
name=schema.name,
|
||||
description=schema.description,
|
||||
type=schema.type,
|
||||
status="active",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
chapters=schema.chapters,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def update_entity(db: ProjectDB, schema: UpdateProjectRequest) -> ProjectDB:
|
||||
"""Update ProjectDB entity with data from UpdateProjectRequest schema
|
||||
|
||||
Args:
|
||||
db: Existing ProjectDB entity
|
||||
schema: UpdateProjectRequest with new data
|
||||
|
||||
Returns:
|
||||
Updated ProjectDB entity
|
||||
"""
|
||||
# Update only provided fields
|
||||
if schema.name is not None:
|
||||
db.name = schema.name
|
||||
|
||||
if schema.description is not None:
|
||||
db.description = schema.description
|
||||
|
||||
if schema.resolution is not None:
|
||||
db.resolution = schema.resolution
|
||||
|
||||
if schema.ratio is not None:
|
||||
db.ratio = schema.ratio
|
||||
|
||||
if schema.style_id is not None:
|
||||
db.style_id = schema.style_id
|
||||
|
||||
if schema.style_params is not None:
|
||||
db.style_params = schema.style_params
|
||||
|
||||
if schema.chapters is not None:
|
||||
db.chapters = schema.chapters
|
||||
|
||||
# Update timestamp
|
||||
db.updated_at = datetime.now().timestamp()
|
||||
|
||||
return db
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user