Initial commit: Pixel AI comic/video creation platform
- FastAPI backend with SQLModel, Alembic migrations, AgentScope agents - Next.js 15 frontend with React 19, Tailwind, Zustand, React Flow - Multi-provider AI system (DashScope, Kling, MiniMax, Volcengine, OpenAI, etc.) - All HTTP clients migrated from sync requests to async httpx - Admin-managed API keys via environment variables - SSRF vulnerability fixed in ensure_url()
This commit is contained in:
108
backend/.env.example
Normal file
108
backend/.env.example
Normal file
@@ -0,0 +1,108 @@
|
||||
# ===========================================
|
||||
# Pixel Backend Environment Configuration
|
||||
# ===========================================
|
||||
# Copy this file to .env and fill in your values.
|
||||
|
||||
# ---- Server ----
|
||||
NODE_ENV=development
|
||||
PY_PORT=8000
|
||||
|
||||
# ---- Database ----
|
||||
# Default: SQLite (backend/data/pixel.db)
|
||||
# DATABASE_URL=postgresql://user:pass@localhost:5432/pixel
|
||||
# DB_PATH= # override SQLite file path
|
||||
# DATA_DIR= # override data directory
|
||||
|
||||
# Database connection pool
|
||||
# DB_POOL_SIZE=20
|
||||
# DB_MAX_OVERFLOW=10
|
||||
# DB_POOL_TIMEOUT=30
|
||||
# DB_POOL_RECYCLE=3600
|
||||
# DB_POOL_PRE_PING=true
|
||||
# SLOW_QUERY_THRESHOLD=1.0
|
||||
|
||||
# ---- Redis ----
|
||||
REDIS_URL=redis://localhost:6379
|
||||
REDIS_ENABLED=1
|
||||
|
||||
# ---- JWT Auth ----
|
||||
# Auto-generated in dev; MUST set in production
|
||||
# JWT_SECRET_KEY=your-secret-key-here
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
REFRESH_TOKEN_EXPIRE_DAYS=7
|
||||
|
||||
# ---- Encryption Key (for user API key storage) ----
|
||||
# Auto-generated in dev; MUST set in production
|
||||
# MASTER_ENCRYPTION_KEY=your-fernet-key-here
|
||||
|
||||
# ---- CORS ----
|
||||
# CORS_ALLOWED_ORIGINS=https://your-app.example.com
|
||||
# CORS_DEV_ALLOWED_ORIGINS=http://localhost:3000,http://127.0.0.1:3000
|
||||
# ALLOW_DEV_ORIGINS=1
|
||||
|
||||
# ---- Storage (OSS) ----
|
||||
STORAGE_TYPE=local
|
||||
# OSS_REGION=oss-cn-shanghai
|
||||
# OSS_ENDPOINT=oss-cn-shanghai.aliyuncs.com
|
||||
# OSS_BUCKET=your-bucket-name
|
||||
# ALIBABA_CLOUD_ACCESS_KEY_ID=your_key
|
||||
# ALIBABA_CLOUD_ACCESS_KEY_SECRET=your_secret
|
||||
|
||||
# ---- Email (SMTP) ----
|
||||
# SMTP_HOST=
|
||||
# SMTP_PORT=587
|
||||
# SMTP_USER=
|
||||
# SMTP_PASSWORD=
|
||||
# SMTP_FROM=
|
||||
# SMTP_TLS=true
|
||||
# FRONTEND_URL=http://localhost:3000
|
||||
|
||||
# ===========================================
|
||||
# AI Provider API Keys
|
||||
# All users share these system-level keys.
|
||||
# ===========================================
|
||||
|
||||
# DashScope (Qwen LLM, Wanx Image, Z-Image)
|
||||
# DASHSCOPE_API_KEY=sk-xxx
|
||||
# DASHSCOPE_BASE_URL= # optional
|
||||
|
||||
# VolcEngine / 火山引擎 (Doubao LLM, video)
|
||||
# VOLCENGINE_API_KEY=xxx
|
||||
|
||||
# Google (Gemini LLM)
|
||||
# GOOGLE_API_KEY=xxx
|
||||
|
||||
# OpenAI
|
||||
# OPENAI_API_KEY=sk-xxx
|
||||
# OPENAI_BASE_URL= # optional, for proxies
|
||||
|
||||
# MiniMax / 海螺 (video, audio, music)
|
||||
# MINIMAX_API_KEY=xxx
|
||||
# MINIMAX_GROUP_ID=xxx
|
||||
|
||||
# Kling / 可灵 (video) — requires both access_key and secret_key
|
||||
# KLING_ACCESS_KEY=xxx
|
||||
# KLING_SECRET_KEY=xxx
|
||||
# KLING_API_BASE=https://api-beijing.klingai.com/v1
|
||||
|
||||
# Midjourney / 有川 (image)
|
||||
# MIDJOURNEY_API_KEY=xxx
|
||||
# MIDJOURNEY_PROXY_URL=xxx
|
||||
# YOUCHUAN_APP_ID=xxx
|
||||
# YOUCHUAN_SECRET_KEY=xxx
|
||||
|
||||
# ModelScope (image, video)
|
||||
# MODELSCOPE_API_TOKEN=xxx
|
||||
|
||||
# ---- Script Agent (AgentScope) ----
|
||||
# Override keys specifically for script analysis agents
|
||||
# SCRIPT_DASHSCOPE_API_KEY=xxx
|
||||
# SCRIPT_DASHSCOPE_BASE_URL=xxx
|
||||
# SCRIPT_OPENAI_API_KEY=xxx
|
||||
# SCRIPT_OPENAI_BASE_URL=xxx
|
||||
|
||||
# ---- Monitoring ----
|
||||
# ENABLE_METRICS=true
|
||||
# LOG_LEVEL=INFO
|
||||
# TRACING_ENABLED=0
|
||||
# OTLP_ENDPOINT=http://localhost:4317
|
||||
40
backend/Dockerfile
Normal file
40
backend/Dockerfile
Normal file
@@ -0,0 +1,40 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
curl \
|
||||
git \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install uv
|
||||
RUN pip install uv
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Copy dependency files
|
||||
COPY pyproject.toml uv.lock ./
|
||||
|
||||
# Create virtual environment and install dependencies
|
||||
# Using --frozen to ensure strict adherence to uv.lock
|
||||
RUN uv sync --frozen --no-install-project
|
||||
|
||||
# Add virtual environment to PATH
|
||||
ENV PATH="/app/.venv/bin:$PATH"
|
||||
|
||||
# Copy source code
|
||||
COPY src ./src
|
||||
|
||||
# Install the project itself (if needed)
|
||||
RUN uv sync --frozen
|
||||
|
||||
# Create directories for data/uploads if they don't exist
|
||||
RUN mkdir -p data/uploads
|
||||
|
||||
# Expose port
|
||||
EXPOSE 8000
|
||||
|
||||
# Start command
|
||||
CMD ["uv", "run", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
147
backend/alembic.ini
Normal file
147
backend/alembic.ini
Normal file
@@ -0,0 +1,147 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts.
|
||||
# this is typically a path given in POSIX (e.g. forward slashes)
|
||||
# format, relative to the token %(here)s which refers to the location of this
|
||||
# ini file
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory. for multiple paths, the path separator
|
||||
# is defined by "path_separator" below.
|
||||
prepend_sys_path = .
|
||||
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the tzdata library which can be installed by adding
|
||||
# `alembic[tz]` to the pip requirements.
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to <script_location>/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "path_separator"
|
||||
# below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
|
||||
|
||||
# path_separator; This indicates what character is used to split lists of file
|
||||
# paths, including version_locations and prepend_sys_path within configparser
|
||||
# files such as alembic.ini.
|
||||
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
|
||||
# to provide os-dependent path splitting.
|
||||
#
|
||||
# Note that in order to support legacy alembic.ini files, this default does NOT
|
||||
# take place if path_separator is not present in alembic.ini. If this
|
||||
# option is omitted entirely, fallback logic is as follows:
|
||||
#
|
||||
# 1. Parsing of the version_locations option falls back to using the legacy
|
||||
# "version_path_separator" key, which if absent then falls back to the legacy
|
||||
# behavior of splitting on spaces and/or commas.
|
||||
# 2. Parsing of the prepend_sys_path option falls back to the legacy
|
||||
# behavior of splitting on spaces, commas, or colons.
|
||||
#
|
||||
# Valid values for path_separator are:
|
||||
#
|
||||
# path_separator = :
|
||||
# path_separator = ;
|
||||
# path_separator = space
|
||||
# path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# database URL. This is consumed by the user-maintained env.py script only.
|
||||
# other means of configuring database URLs may be customized within the env.py
|
||||
# file.
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
|
||||
# hooks = ruff
|
||||
# ruff.type = module
|
||||
# ruff.module = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Alternatively, use the exec runner to execute a binary found on your PATH
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = ruff
|
||||
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration. This is also consumed by the user-maintained
|
||||
# env.py script only.
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
1
backend/alembic/README
Normal file
1
backend/alembic/README
Normal file
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
94
backend/alembic/env.py
Normal file
94
backend/alembic/env.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import sys
|
||||
import os
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from alembic import context
|
||||
|
||||
# Add backend directory to sys.path so we can import src
|
||||
# Add project root to sys.path to allow importing backend
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
|
||||
|
||||
# Import models
|
||||
from backend.src.models.entities import *
|
||||
from backend.src.models.session import *
|
||||
from backend.src.config.settings import DB_PATH, DATABASE_URL
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# Override sqlalchemy.url with the one from config
|
||||
if DATABASE_URL:
|
||||
config.set_main_option("sqlalchemy.url", DATABASE_URL)
|
||||
else:
|
||||
config.set_main_option("sqlalchemy.url", f"sqlite:///{DB_PATH}")
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
target_metadata = SQLModel.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
28
backend/alembic/script.py.mako
Normal file
28
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
339
backend/alembic/verify_migration.py
Normal file
339
backend/alembic/verify_migration.py
Normal file
@@ -0,0 +1,339 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Migration Verification Script for Canvas Metadata Table
|
||||
|
||||
This script verifies that the canvas_metadata migration was successful by:
|
||||
1. Checking that the canvas_metadata table exists
|
||||
2. Verifying all required columns exist with correct types
|
||||
3. Checking that all indexes were created
|
||||
4. Validating data migration from general_canvases
|
||||
5. Validating data migration from asset canvases
|
||||
6. Validating data migration from storyboard canvases
|
||||
7. Checking data integrity and consistency
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from sqlalchemy import create_engine, inspect, text
|
||||
from sqlalchemy.engine import Engine
|
||||
import json
|
||||
|
||||
# Add backend directory to sys.path
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))
|
||||
|
||||
from backend.src.config.settings import DB_PATH
|
||||
|
||||
|
||||
class MigrationVerifier:
|
||||
def __init__(self, db_path: str):
|
||||
self.db_path = db_path
|
||||
self.engine = create_engine(f"sqlite:///{db_path}")
|
||||
self.inspector = inspect(self.engine)
|
||||
self.errors = []
|
||||
self.warnings = []
|
||||
self.success_count = 0
|
||||
self.total_checks = 0
|
||||
|
||||
def check(self, condition: bool, success_msg: str, error_msg: str):
|
||||
"""Helper method to track check results"""
|
||||
self.total_checks += 1
|
||||
if condition:
|
||||
self.success_count += 1
|
||||
print(f"✅ {success_msg}")
|
||||
else:
|
||||
self.errors.append(error_msg)
|
||||
print(f"❌ {error_msg}")
|
||||
|
||||
def warn(self, message: str):
|
||||
"""Helper method to track warnings"""
|
||||
self.warnings.append(message)
|
||||
print(f"⚠️ {message}")
|
||||
|
||||
def verify_table_exists(self) -> bool:
|
||||
"""Verify that canvas_metadata table exists"""
|
||||
print("\n=== Checking Table Existence ===")
|
||||
tables = self.inspector.get_table_names()
|
||||
self.check(
|
||||
'canvas_metadata' in tables,
|
||||
"canvas_metadata table exists",
|
||||
"canvas_metadata table does not exist"
|
||||
)
|
||||
return 'canvas_metadata' in tables
|
||||
|
||||
def verify_columns(self) -> bool:
|
||||
"""Verify all required columns exist with correct types"""
|
||||
print("\n=== Checking Columns ===")
|
||||
|
||||
required_columns = {
|
||||
'id': 'VARCHAR',
|
||||
'project_id': 'VARCHAR',
|
||||
'canvas_type': 'VARCHAR',
|
||||
'related_entity_type': 'VARCHAR',
|
||||
'related_entity_id': 'VARCHAR',
|
||||
'name': 'VARCHAR',
|
||||
'description': 'VARCHAR',
|
||||
'order_index': 'INTEGER',
|
||||
'is_pinned': 'BOOLEAN',
|
||||
'tags': 'JSON',
|
||||
'node_count': 'INTEGER',
|
||||
'last_accessed_at': 'FLOAT',
|
||||
'access_count': 'INTEGER',
|
||||
'created_at': 'FLOAT',
|
||||
'updated_at': 'FLOAT',
|
||||
'deleted_at': 'FLOAT',
|
||||
'legacy_id': 'VARCHAR'
|
||||
}
|
||||
|
||||
columns = self.inspector.get_columns('canvas_metadata')
|
||||
column_dict = {col['name']: col for col in columns}
|
||||
|
||||
all_columns_exist = True
|
||||
for col_name, expected_type in required_columns.items():
|
||||
if col_name in column_dict:
|
||||
col_type = str(column_dict[col_name]['type']).upper()
|
||||
# SQLite stores JSON as TEXT, so we need to check for that
|
||||
if expected_type == 'JSON' and 'TEXT' in col_type:
|
||||
print(f"✅ Column '{col_name}' exists with type {col_type} (JSON stored as TEXT)")
|
||||
elif expected_type in col_type or col_type in expected_type:
|
||||
print(f"✅ Column '{col_name}' exists with type {col_type}")
|
||||
else:
|
||||
self.warn(f"Column '{col_name}' exists but type is {col_type}, expected {expected_type}")
|
||||
else:
|
||||
all_columns_exist = False
|
||||
self.errors.append(f"Column '{col_name}' is missing")
|
||||
print(f"❌ Column '{col_name}' is missing")
|
||||
|
||||
return all_columns_exist
|
||||
|
||||
def verify_indexes(self) -> bool:
|
||||
"""Verify all required indexes exist"""
|
||||
print("\n=== Checking Indexes ===")
|
||||
|
||||
required_indexes = [
|
||||
'ix_canvas_metadata_project_id',
|
||||
'ix_canvas_metadata_canvas_type',
|
||||
'ix_canvas_metadata_related_entity_id',
|
||||
'ix_canvas_metadata_legacy_id',
|
||||
'ix_canvas_metadata_project_type',
|
||||
'ix_canvas_metadata_type_entity'
|
||||
]
|
||||
|
||||
indexes = self.inspector.get_indexes('canvas_metadata')
|
||||
index_names = [idx['name'] for idx in indexes]
|
||||
|
||||
all_indexes_exist = True
|
||||
for idx_name in required_indexes:
|
||||
if idx_name in index_names:
|
||||
print(f"✅ Index '{idx_name}' exists")
|
||||
else:
|
||||
all_indexes_exist = False
|
||||
self.errors.append(f"Index '{idx_name}' is missing")
|
||||
print(f"❌ Index '{idx_name}' is missing")
|
||||
|
||||
return all_indexes_exist
|
||||
|
||||
def verify_foreign_keys(self) -> bool:
|
||||
"""Verify foreign key constraints"""
|
||||
print("\n=== Checking Foreign Keys ===")
|
||||
|
||||
fks = self.inspector.get_foreign_keys('canvas_metadata')
|
||||
|
||||
has_project_fk = any(
|
||||
fk['referred_table'] == 'projects' and 'project_id' in fk['constrained_columns']
|
||||
for fk in fks
|
||||
)
|
||||
|
||||
self.check(
|
||||
has_project_fk,
|
||||
"Foreign key to projects table exists",
|
||||
"Foreign key to projects table is missing"
|
||||
)
|
||||
|
||||
return has_project_fk
|
||||
|
||||
def verify_data_migration(self) -> bool:
|
||||
"""Verify data was migrated correctly"""
|
||||
print("\n=== Checking Data Migration ===")
|
||||
|
||||
with self.engine.connect() as conn:
|
||||
# Check if any canvas_metadata records exist
|
||||
result = conn.execute(text("SELECT COUNT(*) FROM canvas_metadata"))
|
||||
count = result.scalar()
|
||||
|
||||
if count > 0:
|
||||
print(f"✅ Found {count} canvas metadata records")
|
||||
|
||||
# Check general canvases
|
||||
result = conn.execute(text(
|
||||
"SELECT COUNT(*) FROM canvas_metadata WHERE canvas_type = 'general'"
|
||||
))
|
||||
general_count = result.scalar()
|
||||
print(f" - General canvases: {general_count}")
|
||||
|
||||
# Check asset canvases
|
||||
result = conn.execute(text(
|
||||
"SELECT COUNT(*) FROM canvas_metadata WHERE canvas_type = 'asset'"
|
||||
))
|
||||
asset_count = result.scalar()
|
||||
print(f" - Asset canvases: {asset_count}")
|
||||
|
||||
# Check storyboard canvases
|
||||
result = conn.execute(text(
|
||||
"SELECT COUNT(*) FROM canvas_metadata WHERE canvas_type = 'storyboard'"
|
||||
))
|
||||
storyboard_count = result.scalar()
|
||||
print(f" - Storyboard canvases: {storyboard_count}")
|
||||
|
||||
# Verify legacy_id mapping for migrated canvases
|
||||
result = conn.execute(text(
|
||||
"SELECT COUNT(*) FROM canvas_metadata WHERE legacy_id IS NOT NULL"
|
||||
))
|
||||
legacy_count = result.scalar()
|
||||
if legacy_count > 0:
|
||||
print(f"✅ Found {legacy_count} canvases with legacy_id mapping")
|
||||
|
||||
return True
|
||||
else:
|
||||
self.warn("No canvas metadata records found (this is OK if database is empty)")
|
||||
return True
|
||||
|
||||
def verify_data_integrity(self) -> bool:
|
||||
"""Verify data integrity constraints"""
|
||||
print("\n=== Checking Data Integrity ===")
|
||||
|
||||
with self.engine.connect() as conn:
|
||||
# Check if canvases table exists
|
||||
tables = self.inspector.get_table_names()
|
||||
if 'canvases' not in tables:
|
||||
self.warn("canvases table does not exist yet - skipping canvas content check")
|
||||
orphaned_metadata = 0
|
||||
else:
|
||||
# Check that all canvas_metadata records have corresponding canvas content
|
||||
result = conn.execute(text("""
|
||||
SELECT COUNT(*)
|
||||
FROM canvas_metadata cm
|
||||
LEFT JOIN canvases c ON cm.id = c.id
|
||||
WHERE c.id IS NULL
|
||||
"""))
|
||||
orphaned_metadata = result.scalar()
|
||||
|
||||
self.check(
|
||||
orphaned_metadata == 0,
|
||||
f"All canvas metadata records have corresponding canvas content",
|
||||
f"Found {orphaned_metadata} canvas metadata records without canvas content"
|
||||
)
|
||||
|
||||
# Check that related_entity_id is set for asset and storyboard canvases
|
||||
result = conn.execute(text("""
|
||||
SELECT COUNT(*)
|
||||
FROM canvas_metadata
|
||||
WHERE canvas_type IN ('asset', 'storyboard')
|
||||
AND related_entity_id IS NULL
|
||||
"""))
|
||||
missing_entity_id = result.scalar()
|
||||
|
||||
self.check(
|
||||
missing_entity_id == 0,
|
||||
"All asset/storyboard canvases have related_entity_id",
|
||||
f"Found {missing_entity_id} asset/storyboard canvases without related_entity_id"
|
||||
)
|
||||
|
||||
# Check that general canvases don't have related_entity_id
|
||||
result = conn.execute(text("""
|
||||
SELECT COUNT(*)
|
||||
FROM canvas_metadata
|
||||
WHERE canvas_type = 'general'
|
||||
AND related_entity_id IS NOT NULL
|
||||
"""))
|
||||
invalid_general = result.scalar()
|
||||
|
||||
self.check(
|
||||
invalid_general == 0,
|
||||
"General canvases don't have related_entity_id",
|
||||
f"Found {invalid_general} general canvases with related_entity_id"
|
||||
)
|
||||
|
||||
return orphaned_metadata == 0 and missing_entity_id == 0 and invalid_general == 0
|
||||
|
||||
def verify_project_relationship(self) -> bool:
|
||||
"""Verify that ProjectDB relationship is working"""
|
||||
print("\n=== Checking Project Relationship ===")
|
||||
|
||||
with self.engine.connect() as conn:
|
||||
# Check that all canvas_metadata records reference valid projects
|
||||
result = conn.execute(text("""
|
||||
SELECT COUNT(*)
|
||||
FROM canvas_metadata cm
|
||||
LEFT JOIN projects p ON cm.project_id = p.id
|
||||
WHERE p.id IS NULL
|
||||
"""))
|
||||
orphaned_canvases = result.scalar()
|
||||
|
||||
self.check(
|
||||
orphaned_canvases == 0,
|
||||
"All canvas metadata records reference valid projects",
|
||||
f"Found {orphaned_canvases} canvas metadata records with invalid project_id"
|
||||
)
|
||||
|
||||
return orphaned_canvases == 0
|
||||
|
||||
def run_all_checks(self) -> bool:
|
||||
"""Run all verification checks"""
|
||||
print("=" * 60)
|
||||
print("Canvas Metadata Migration Verification")
|
||||
print("=" * 60)
|
||||
print(f"Database: {self.db_path}")
|
||||
|
||||
if not os.path.exists(self.db_path):
|
||||
print(f"\n❌ Database file does not exist: {self.db_path}")
|
||||
return False
|
||||
|
||||
# Run all checks
|
||||
table_exists = self.verify_table_exists()
|
||||
if not table_exists:
|
||||
print("\n❌ Cannot continue verification - table does not exist")
|
||||
return False
|
||||
|
||||
self.verify_columns()
|
||||
self.verify_indexes()
|
||||
self.verify_foreign_keys()
|
||||
self.verify_data_migration()
|
||||
self.verify_data_integrity()
|
||||
self.verify_project_relationship()
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 60)
|
||||
print("Verification Summary")
|
||||
print("=" * 60)
|
||||
print(f"Total checks: {self.total_checks}")
|
||||
print(f"Passed: {self.success_count}")
|
||||
print(f"Failed: {len(self.errors)}")
|
||||
print(f"Warnings: {len(self.warnings)}")
|
||||
|
||||
if self.errors:
|
||||
print("\n❌ Errors:")
|
||||
for error in self.errors:
|
||||
print(f" - {error}")
|
||||
|
||||
if self.warnings:
|
||||
print("\n⚠️ Warnings:")
|
||||
for warning in self.warnings:
|
||||
print(f" - {warning}")
|
||||
|
||||
if len(self.errors) == 0:
|
||||
print("\n✅ All checks passed! Migration is successful.")
|
||||
return True
|
||||
else:
|
||||
print("\n❌ Migration verification failed. Please review the errors above.")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point"""
|
||||
verifier = MigrationVerifier(DB_PATH)
|
||||
success = verifier.run_all_checks()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
127
backend/alembic/versions/72f609dd9e66_initial_schema.py
Normal file
127
backend/alembic/versions/72f609dd9e66_initial_schema.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""Initial schema
|
||||
|
||||
Revision ID: 72f609dd9e66
|
||||
Revises:
|
||||
Create Date: 2026-01-08 09:52:59.473436
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '72f609dd9e66'
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('projects',
|
||||
sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('created_at', sa.Float(), nullable=False),
|
||||
sa.Column('updated_at', sa.Float(), nullable=False),
|
||||
sa.Column('resolution', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('ratio', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('style_preset', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('style_params', sa.JSON(), nullable=True),
|
||||
sa.Column('chapters', sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('tasks',
|
||||
sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('created_at', sa.Float(), nullable=False),
|
||||
sa.Column('updated_at', sa.Float(), nullable=False),
|
||||
sa.Column('model', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('params', sa.JSON(), nullable=True),
|
||||
sa.Column('provider_task_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('result', sa.JSON(), nullable=True),
|
||||
sa.Column('error', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('assets',
|
||||
sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('project_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('desc', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('tags', sa.JSON(), nullable=True),
|
||||
sa.Column('image_url', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('image_urls', sa.JSON(), nullable=True),
|
||||
sa.Column('video_urls', sa.JSON(), nullable=True),
|
||||
sa.Column('extra_data', sa.JSON(), nullable=True),
|
||||
sa.Column('generations', sa.JSON(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_assets_project_id'), 'assets', ['project_id'], unique=False)
|
||||
op.create_table('episodes',
|
||||
sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('project_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('order_index', sa.Integer(), nullable=False),
|
||||
sa.Column('title', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('desc', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('content', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('status', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_episodes_project_id'), 'episodes', ['project_id'], unique=False)
|
||||
op.create_table('storyboards',
|
||||
sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('project_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('episode_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('order_index', sa.Integer(), nullable=False),
|
||||
sa.Column('shot', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('desc', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('duration', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('scene_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('character_ids', sa.JSON(), nullable=True),
|
||||
sa.Column('prop_ids', sa.JSON(), nullable=True),
|
||||
sa.Column('voiceover', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('audio_desc', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('audio_url', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('camera_movement', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('transition', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('visual_anchor', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('visual_dynamics', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('director_note', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('image_prompt', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('video_script', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('image_urls', sa.JSON(), nullable=True),
|
||||
sa.Column('video_urls', sa.JSON(), nullable=True),
|
||||
sa.Column('generations', sa.JSON(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['episode_id'], ['episodes.id'], ),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_storyboards_episode_id'), 'storyboards', ['episode_id'], unique=False)
|
||||
op.create_index(op.f('ix_storyboards_project_id'), 'storyboards', ['project_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_storyboards_project_id'), table_name='storyboards')
|
||||
op.drop_index(op.f('ix_storyboards_episode_id'), table_name='storyboards')
|
||||
op.drop_table('storyboards')
|
||||
op.drop_index(op.f('ix_episodes_project_id'), table_name='episodes')
|
||||
op.drop_table('episodes')
|
||||
op.drop_index(op.f('ix_assets_project_id'), table_name='assets')
|
||||
op.drop_table('assets')
|
||||
op.drop_table('tasks')
|
||||
op.drop_table('projects')
|
||||
# ### end Alembic commands ###
|
||||
245
backend/alembic/versions/add_canvas_metadata_table.py
Normal file
245
backend/alembic/versions/add_canvas_metadata_table.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""add canvas metadata table
|
||||
|
||||
Revision ID: add_canvas_metadata
|
||||
Revises: bfac9b8e32f5
|
||||
Create Date: 2026-01-17 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'add_canvas_metadata'
|
||||
down_revision: Union[str, Sequence[str], None] = ('add_progress_tracking', 'add_prompt_fields')
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# 1. 创建 canvas_metadata 表
|
||||
op.create_table(
|
||||
'canvas_metadata',
|
||||
sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('project_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('canvas_type', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('related_entity_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('related_entity_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.Column('order_index', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('is_pinned', sa.Boolean(), nullable=False, server_default='0'),
|
||||
sa.Column('tags', sa.JSON(), nullable=True),
|
||||
sa.Column('node_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('last_accessed_at', sa.Float(), nullable=True),
|
||||
sa.Column('access_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('created_at', sa.Float(), nullable=False),
|
||||
sa.Column('updated_at', sa.Float(), nullable=False),
|
||||
sa.Column('deleted_at', sa.Float(), nullable=True),
|
||||
sa.Column('legacy_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.ForeignKeyConstraint(['project_id'], ['projects.id'])
|
||||
)
|
||||
|
||||
# 2. 创建索引
|
||||
op.create_index('ix_canvas_metadata_project_id', 'canvas_metadata', ['project_id'])
|
||||
op.create_index('ix_canvas_metadata_canvas_type', 'canvas_metadata', ['canvas_type'])
|
||||
op.create_index('ix_canvas_metadata_related_entity_id', 'canvas_metadata', ['related_entity_id'])
|
||||
op.create_index('ix_canvas_metadata_legacy_id', 'canvas_metadata', ['legacy_id'])
|
||||
op.create_index('ix_canvas_metadata_project_type', 'canvas_metadata', ['project_id', 'canvas_type'])
|
||||
op.create_index('ix_canvas_metadata_type_entity', 'canvas_metadata', ['canvas_type', 'related_entity_id'])
|
||||
|
||||
# 3. 迁移数据
|
||||
migrate_general_canvases()
|
||||
migrate_asset_canvases()
|
||||
migrate_storyboard_canvases()
|
||||
|
||||
|
||||
def migrate_general_canvases():
|
||||
"""迁移通用画布数据"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# 获取所有项目的 general_canvases
|
||||
try:
|
||||
projects = conn.execute(sa.text("SELECT id, general_canvases FROM projects")).fetchall()
|
||||
except:
|
||||
# 如果 general_canvases 列不存在,跳过
|
||||
return
|
||||
|
||||
for project in projects:
|
||||
project_id = project[0]
|
||||
general_canvases_json = project[1]
|
||||
|
||||
if not general_canvases_json:
|
||||
continue
|
||||
|
||||
try:
|
||||
canvases = json.loads(general_canvases_json) if isinstance(general_canvases_json, str) else general_canvases_json
|
||||
except:
|
||||
continue
|
||||
|
||||
if not isinstance(canvases, list):
|
||||
continue
|
||||
|
||||
for idx, canvas in enumerate(canvases):
|
||||
canvas_id = canvas.get('id')
|
||||
if not canvas_id:
|
||||
continue
|
||||
|
||||
# 插入到 canvas_metadata
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO canvas_metadata (
|
||||
id, project_id, canvas_type, name, order_index,
|
||||
created_at, updated_at
|
||||
) VALUES (
|
||||
:id, :project_id, 'general', :name, :order_index,
|
||||
:created_at, :updated_at
|
||||
)
|
||||
"""), {
|
||||
'id': canvas_id,
|
||||
'project_id': project_id,
|
||||
'name': canvas.get('name', f'Canvas {idx + 1}'),
|
||||
'order_index': idx,
|
||||
'created_at': canvas.get('createdAt', datetime.now().timestamp()),
|
||||
'updated_at': canvas.get('updatedAt', datetime.now().timestamp())
|
||||
})
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def migrate_asset_canvases():
|
||||
"""迁移素材画布数据"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# 查找所有以 canvas-asset- 开头的画布
|
||||
try:
|
||||
canvases = conn.execute(sa.text("""
|
||||
SELECT id, project_id, updated_at
|
||||
FROM canvases
|
||||
WHERE id LIKE 'canvas-asset-%'
|
||||
""")).fetchall()
|
||||
except:
|
||||
return
|
||||
|
||||
for canvas in canvases:
|
||||
old_id = canvas[0]
|
||||
project_id = canvas[1]
|
||||
updated_at = canvas[2]
|
||||
|
||||
# 提取 asset_id
|
||||
asset_id = old_id.replace('canvas-asset-', '')
|
||||
|
||||
# 查找对应的 asset
|
||||
try:
|
||||
asset = conn.execute(sa.text("""
|
||||
SELECT name FROM assets WHERE id = :asset_id
|
||||
"""), {'asset_id': asset_id}).fetchone()
|
||||
except:
|
||||
continue
|
||||
|
||||
if not asset:
|
||||
continue
|
||||
|
||||
# 生成新 UUID
|
||||
new_id = str(uuid.uuid4())
|
||||
|
||||
# 插入元数据
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO canvas_metadata (
|
||||
id, project_id, canvas_type, related_entity_type,
|
||||
related_entity_id, name, created_at, updated_at, legacy_id
|
||||
) VALUES (
|
||||
:id, :project_id, 'asset', 'asset',
|
||||
:asset_id, :name, :created_at, :updated_at, :legacy_id
|
||||
)
|
||||
"""), {
|
||||
'id': new_id,
|
||||
'project_id': project_id,
|
||||
'asset_id': asset_id,
|
||||
'name': asset[0],
|
||||
'created_at': updated_at,
|
||||
'updated_at': updated_at,
|
||||
'legacy_id': old_id
|
||||
})
|
||||
|
||||
# 更新 canvases 表的 ID
|
||||
conn.execute(sa.text("""
|
||||
UPDATE canvases SET id = :new_id WHERE id = :old_id
|
||||
"""), {'new_id': new_id, 'old_id': old_id})
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def migrate_storyboard_canvases():
|
||||
"""迁移分镜画布数据"""
|
||||
conn = op.get_bind()
|
||||
|
||||
# 查找所有以 canvas-storyboard- 开头的画布
|
||||
try:
|
||||
canvases = conn.execute(sa.text("""
|
||||
SELECT id, project_id, updated_at
|
||||
FROM canvases
|
||||
WHERE id LIKE 'canvas-storyboard-%'
|
||||
""")).fetchall()
|
||||
except:
|
||||
return
|
||||
|
||||
for canvas in canvases:
|
||||
old_id = canvas[0]
|
||||
project_id = canvas[1]
|
||||
updated_at = canvas[2]
|
||||
|
||||
storyboard_id = old_id.replace('canvas-storyboard-', '')
|
||||
|
||||
try:
|
||||
storyboard = conn.execute(sa.text("""
|
||||
SELECT shot FROM storyboards WHERE id = :storyboard_id
|
||||
"""), {'storyboard_id': storyboard_id}).fetchone()
|
||||
except:
|
||||
continue
|
||||
|
||||
if not storyboard:
|
||||
continue
|
||||
|
||||
new_id = str(uuid.uuid4())
|
||||
|
||||
conn.execute(sa.text("""
|
||||
INSERT INTO canvas_metadata (
|
||||
id, project_id, canvas_type, related_entity_type,
|
||||
related_entity_id, name, created_at, updated_at, legacy_id
|
||||
) VALUES (
|
||||
:id, :project_id, 'storyboard', 'storyboard',
|
||||
:storyboard_id, :name, :created_at, :updated_at, :legacy_id
|
||||
)
|
||||
"""), {
|
||||
'id': new_id,
|
||||
'project_id': project_id,
|
||||
'storyboard_id': storyboard_id,
|
||||
'name': storyboard[0],
|
||||
'created_at': updated_at,
|
||||
'updated_at': updated_at,
|
||||
'legacy_id': old_id
|
||||
})
|
||||
|
||||
conn.execute(sa.text("""
|
||||
UPDATE canvases SET id = :new_id WHERE id = :old_id
|
||||
"""), {'new_id': new_id, 'old_id': old_id})
|
||||
|
||||
conn.commit()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
# 回滚操作
|
||||
op.drop_index('ix_canvas_metadata_type_entity', 'canvas_metadata')
|
||||
op.drop_index('ix_canvas_metadata_project_type', 'canvas_metadata')
|
||||
op.drop_index('ix_canvas_metadata_legacy_id', 'canvas_metadata')
|
||||
op.drop_index('ix_canvas_metadata_related_entity_id', 'canvas_metadata')
|
||||
op.drop_index('ix_canvas_metadata_canvas_type', 'canvas_metadata')
|
||||
op.drop_index('ix_canvas_metadata_project_id', 'canvas_metadata')
|
||||
op.drop_table('canvas_metadata')
|
||||
41
backend/alembic/versions/add_cinematic_fields.py
Normal file
41
backend/alembic/versions/add_cinematic_fields.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""add cinematic and professional fields to assets and storyboards
|
||||
|
||||
Revision ID: add_cinematic_fields
|
||||
Revises: add_prompt_fields
|
||||
Create Date: 2026-01-20
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'add_cinematic_fields'
|
||||
down_revision = 'add_canvas_metadata'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Assets表已经使用extra_data存储这些字段,但为了查询效率,我们可以选择不添加直接列
|
||||
# 因为Asset的emotion, environment_type, weather等字段已经通过extra_data JSON存储
|
||||
# 如果未来需要索引查询,可以添加:
|
||||
# op.add_column('assets', sa.Column('emotion', sa.String(), nullable=True))
|
||||
# op.add_column('assets', sa.Column('environment_type', sa.String(), nullable=True))
|
||||
# op.add_column('assets', sa.Column('weather', sa.String(), nullable=True))
|
||||
|
||||
# Add cinematic control fields to storyboards table
|
||||
op.add_column('storyboards', sa.Column('camera_angle', sa.String(), nullable=True))
|
||||
op.add_column('storyboards', sa.Column('lens', sa.String(), nullable=True))
|
||||
op.add_column('storyboards', sa.Column('focus', sa.String(), nullable=True))
|
||||
op.add_column('storyboards', sa.Column('lighting', sa.String(), nullable=True))
|
||||
op.add_column('storyboards', sa.Column('color_style', sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove cinematic fields from storyboards
|
||||
op.drop_column('storyboards', 'color_style')
|
||||
op.drop_column('storyboards', 'lighting')
|
||||
op.drop_column('storyboards', 'focus')
|
||||
op.drop_column('storyboards', 'lens')
|
||||
op.drop_column('storyboards', 'camera_angle')
|
||||
100
backend/alembic/versions/add_indexes_and_optimizations.py
Normal file
100
backend/alembic/versions/add_indexes_and_optimizations.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Add indexes and database optimizations
|
||||
|
||||
Revision ID: add_indexes_opt
|
||||
Revises: bfac9b8e32f5
|
||||
Create Date: 2026-01-14 10:00:00.000000
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'add_indexes_opt'
|
||||
down_revision: Union[str, Sequence[str], None] = 'bfac9b8e32f5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add indexes, soft delete columns, and full-text search support."""
|
||||
|
||||
# Add soft delete columns
|
||||
op.add_column('projects', sa.Column('deleted_at', sa.Float(), nullable=True))
|
||||
op.add_column('assets', sa.Column('deleted_at', sa.Float(), nullable=True))
|
||||
op.add_column('episodes', sa.Column('deleted_at', sa.Float(), nullable=True))
|
||||
op.add_column('storyboards', sa.Column('deleted_at', sa.Float(), nullable=True))
|
||||
op.add_column('tasks', sa.Column('deleted_at', sa.Float(), nullable=True))
|
||||
|
||||
# Add indexes for frequently queried fields on projects
|
||||
op.create_index('idx_projects_created_at', 'projects', ['created_at'])
|
||||
op.create_index('idx_projects_updated_at', 'projects', ['updated_at'])
|
||||
op.create_index('idx_projects_status', 'projects', ['status'])
|
||||
op.create_index('idx_projects_deleted_at', 'projects', ['deleted_at'])
|
||||
|
||||
# Add indexes for tasks
|
||||
op.create_index('idx_tasks_status', 'tasks', ['status'])
|
||||
op.create_index('idx_tasks_type', 'tasks', ['type'])
|
||||
op.create_index('idx_tasks_created_at', 'tasks', ['created_at'])
|
||||
op.create_index('idx_tasks_type_status', 'tasks', ['type', 'status'])
|
||||
op.create_index('idx_tasks_deleted_at', 'tasks', ['deleted_at'])
|
||||
|
||||
# Add indexes for assets
|
||||
op.create_index('idx_assets_type', 'assets', ['type'])
|
||||
op.create_index('idx_assets_deleted_at', 'assets', ['deleted_at'])
|
||||
|
||||
# Add indexes for episodes
|
||||
op.create_index('idx_episodes_status', 'episodes', ['status'])
|
||||
op.create_index('idx_episodes_order_index', 'episodes', ['order_index'])
|
||||
op.create_index('idx_episodes_deleted_at', 'episodes', ['deleted_at'])
|
||||
|
||||
# Add indexes for storyboards
|
||||
op.create_index('idx_storyboards_type', 'storyboards', ['type'])
|
||||
op.create_index('idx_storyboards_order_index', 'storyboards', ['order_index'])
|
||||
op.create_index('idx_storyboards_deleted_at', 'storyboards', ['deleted_at'])
|
||||
|
||||
# Note: SQLite doesn't support full-text search indexes like PostgreSQL
|
||||
# For SQLite, we'll use the FTS5 virtual table approach in the application layer
|
||||
# or use LIKE queries with indexes on the name columns
|
||||
# Adding index on name columns for better LIKE query performance
|
||||
op.create_index('idx_projects_name', 'projects', ['name'])
|
||||
op.create_index('idx_assets_name', 'assets', ['name'])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove indexes, soft delete columns, and full-text search support."""
|
||||
|
||||
# Drop indexes
|
||||
op.drop_index('idx_assets_name', table_name='assets')
|
||||
op.drop_index('idx_projects_name', table_name='projects')
|
||||
|
||||
op.drop_index('idx_storyboards_deleted_at', table_name='storyboards')
|
||||
op.drop_index('idx_storyboards_order_index', table_name='storyboards')
|
||||
op.drop_index('idx_storyboards_type', table_name='storyboards')
|
||||
|
||||
op.drop_index('idx_episodes_deleted_at', table_name='episodes')
|
||||
op.drop_index('idx_episodes_order_index', table_name='episodes')
|
||||
op.drop_index('idx_episodes_status', table_name='episodes')
|
||||
|
||||
op.drop_index('idx_assets_deleted_at', table_name='assets')
|
||||
op.drop_index('idx_assets_type', table_name='assets')
|
||||
|
||||
op.drop_index('idx_tasks_deleted_at', table_name='tasks')
|
||||
op.drop_index('idx_tasks_type_status', table_name='tasks')
|
||||
op.drop_index('idx_tasks_created_at', table_name='tasks')
|
||||
op.drop_index('idx_tasks_type', table_name='tasks')
|
||||
op.drop_index('idx_tasks_status', table_name='tasks')
|
||||
|
||||
op.drop_index('idx_projects_deleted_at', table_name='projects')
|
||||
op.drop_index('idx_projects_status', table_name='projects')
|
||||
op.drop_index('idx_projects_updated_at', table_name='projects')
|
||||
op.drop_index('idx_projects_created_at', table_name='projects')
|
||||
|
||||
# Drop soft delete columns
|
||||
op.drop_column('tasks', 'deleted_at')
|
||||
op.drop_column('storyboards', 'deleted_at')
|
||||
op.drop_column('episodes', 'deleted_at')
|
||||
op.drop_column('assets', 'deleted_at')
|
||||
op.drop_column('projects', 'deleted_at')
|
||||
30
backend/alembic/versions/add_progress_tracking.py
Normal file
30
backend/alembic/versions/add_progress_tracking.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""add progress tracking fields
|
||||
|
||||
Revision ID: add_progress_tracking
|
||||
Revises: add_task_mgmt_fields
|
||||
Create Date: 2026-01-16 15:00:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'add_progress_tracking'
|
||||
down_revision = 'add_task_mgmt_fields'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add progress and error columns to projects table
|
||||
with op.batch_alter_table('projects', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('progress', sa.JSON(), nullable=True))
|
||||
batch_op.add_column(sa.Column('error', sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove progress and error columns from projects table
|
||||
with op.batch_alter_table('projects', schema=None) as batch_op:
|
||||
batch_op.drop_column('error')
|
||||
batch_op.drop_column('progress')
|
||||
36
backend/alembic/versions/add_prompt_fields.py
Normal file
36
backend/alembic/versions/add_prompt_fields.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""add prompt fields to assets and storyboards
|
||||
|
||||
Revision ID: add_prompt_fields
|
||||
Revises: add_task_mgmt_fields
|
||||
Create Date: 2026-01-16
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'add_prompt_fields'
|
||||
down_revision = 'add_task_mgmt_fields'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add image_prompt to assets table
|
||||
op.add_column('assets', sa.Column('image_prompt', sa.String(), nullable=True))
|
||||
|
||||
# Add prompt fields to storyboards table
|
||||
op.add_column('storyboards', sa.Column('original_text', sa.String(), nullable=True))
|
||||
op.add_column('storyboards', sa.Column('merge_image_prompt', sa.String(), nullable=True))
|
||||
op.add_column('storyboards', sa.Column('video_prompt', sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove fields from storyboards
|
||||
op.drop_column('storyboards', 'video_prompt')
|
||||
op.drop_column('storyboards', 'merge_image_prompt')
|
||||
op.drop_column('storyboards', 'original_text')
|
||||
|
||||
# Remove field from assets
|
||||
op.drop_column('assets', 'image_prompt')
|
||||
42
backend/alembic/versions/add_provider_to_tasks.py
Normal file
42
backend/alembic/versions/add_provider_to_tasks.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""add provider to tasks
|
||||
|
||||
Revision ID: add_provider_to_tasks
|
||||
Revises: bfac9b8e32f5
|
||||
Create Date: 2024-02-11
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'add_provider_to_tasks'
|
||||
down_revision = 'bfac9b8e32f5'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Add provider column to tasks table"""
|
||||
# Add provider column (nullable, indexed)
|
||||
op.add_column('tasks', sa.Column('provider', sa.String(), nullable=True))
|
||||
|
||||
# Create index on provider column for faster queries
|
||||
op.create_index(op.f('ix_tasks_provider'), 'tasks', ['provider'], unique=False)
|
||||
|
||||
# Optional: Migrate existing data by extracting provider from params
|
||||
# This is a data migration that can be run separately if needed
|
||||
op.execute("""
|
||||
UPDATE tasks
|
||||
SET provider = params->>'provider'
|
||||
WHERE params->>'provider' IS NOT NULL
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Remove provider column from tasks table"""
|
||||
# Drop index first
|
||||
op.drop_index(op.f('ix_tasks_provider'), table_name='tasks')
|
||||
|
||||
# Drop column
|
||||
op.drop_column('tasks', 'provider')
|
||||
50
backend/alembic/versions/add_task_management_fields.py
Normal file
50
backend/alembic/versions/add_task_management_fields.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""add task management fields
|
||||
|
||||
Revision ID: add_task_mgmt_fields
|
||||
Revises: add_indexes_opt
|
||||
Create Date: 2026-01-14
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'add_task_mgmt_fields'
|
||||
down_revision = 'add_indexes_opt'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# Add retry configuration fields
|
||||
op.add_column('tasks', sa.Column('retry_count', sa.Integer(), nullable=False, server_default='0'))
|
||||
op.add_column('tasks', sa.Column('max_retries', sa.Integer(), nullable=False, server_default='3'))
|
||||
|
||||
# Add timestamp fields for task lifecycle
|
||||
op.add_column('tasks', sa.Column('started_at', sa.Float(), nullable=True))
|
||||
op.add_column('tasks', sa.Column('completed_at', sa.Float(), nullable=True))
|
||||
|
||||
# Add user context fields
|
||||
op.add_column('tasks', sa.Column('user_id', sa.String(), nullable=True))
|
||||
op.add_column('tasks', sa.Column('project_id', sa.String(), nullable=True))
|
||||
|
||||
# Add indexes for new fields
|
||||
op.create_index('idx_tasks_user_id', 'tasks', ['user_id'])
|
||||
op.create_index('idx_tasks_project_id', 'tasks', ['project_id'])
|
||||
|
||||
# Note: deleted_at column already exists from previous migration
|
||||
|
||||
|
||||
def downgrade():
|
||||
# Remove indexes
|
||||
op.drop_index('idx_tasks_project_id', table_name='tasks')
|
||||
op.drop_index('idx_tasks_user_id', table_name='tasks')
|
||||
|
||||
# Remove columns
|
||||
op.drop_column('tasks', 'project_id')
|
||||
op.drop_column('tasks', 'user_id')
|
||||
op.drop_column('tasks', 'completed_at')
|
||||
op.drop_column('tasks', 'started_at')
|
||||
op.drop_column('tasks', 'max_retries')
|
||||
op.drop_column('tasks', 'retry_count')
|
||||
54
backend/alembic/versions/add_user_sessions.py
Normal file
54
backend/alembic/versions/add_user_sessions.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""add user_sessions table
|
||||
|
||||
Revision ID: add_user_sessions
|
||||
Revises: b546dbb9df98
|
||||
Create Date: 2026-03-09
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
revision: str = 'add_user_sessions'
|
||||
down_revision: Union[str, Sequence[str], None] = 'b546dbb9df98'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'user_sessions',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column('session_family_id', sa.String(), nullable=False),
|
||||
sa.Column('refresh_token_hash', sa.String(), nullable=False),
|
||||
sa.Column('status', sa.String(), nullable=False, server_default='active'),
|
||||
sa.Column('created_at', sa.Float(), nullable=False),
|
||||
sa.Column('updated_at', sa.Float(), nullable=False),
|
||||
sa.Column('expires_at', sa.Float(), nullable=False),
|
||||
sa.Column('last_used_at', sa.Float(), nullable=True),
|
||||
sa.Column('revoked_at', sa.Float(), nullable=True),
|
||||
sa.Column('revoked_reason', sa.String(), nullable=True),
|
||||
sa.Column('replaced_by_session_id', sa.String(), nullable=True),
|
||||
sa.Column('ip_address', sa.String(), nullable=True),
|
||||
sa.Column('user_agent', sa.Text(), nullable=True),
|
||||
sa.Column('device_name', sa.String(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index(op.f('ix_user_sessions_user_id'), 'user_sessions', ['user_id'], unique=False)
|
||||
op.create_index(op.f('ix_user_sessions_session_family_id'), 'user_sessions', ['session_family_id'], unique=False)
|
||||
op.create_index(op.f('ix_user_sessions_refresh_token_hash'), 'user_sessions', ['refresh_token_hash'], unique=False)
|
||||
op.create_index(op.f('ix_user_sessions_status'), 'user_sessions', ['status'], unique=False)
|
||||
op.create_index(op.f('ix_user_sessions_revoked_at'), 'user_sessions', ['revoked_at'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f('ix_user_sessions_revoked_at'), table_name='user_sessions')
|
||||
op.drop_index(op.f('ix_user_sessions_status'), table_name='user_sessions')
|
||||
op.drop_index(op.f('ix_user_sessions_refresh_token_hash'), table_name='user_sessions')
|
||||
op.drop_index(op.f('ix_user_sessions_session_family_id'), table_name='user_sessions')
|
||||
op.drop_index(op.f('ix_user_sessions_user_id'), table_name='user_sessions')
|
||||
op.drop_table('user_sessions')
|
||||
@@ -0,0 +1,72 @@
|
||||
"""add_users_and_api_keys_tables
|
||||
|
||||
Revision ID: b546dbb9df98
|
||||
Revises: rename_style_preset
|
||||
Create Date: 2026-02-14 13:01:36.394119
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b546dbb9df98'
|
||||
down_revision: Union[str, Sequence[str], None] = 'rename_style_preset'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema - Add users and user_api_keys tables."""
|
||||
# Create users table
|
||||
op.create_table(
|
||||
'users',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('username', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=True),
|
||||
sa.Column('password_hash', sa.String(), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'),
|
||||
sa.Column('is_superuser', sa.Boolean(), nullable=False, server_default='0'),
|
||||
sa.Column('permissions', sa.JSON(), nullable=False, server_default='[]'),
|
||||
sa.Column('roles', sa.JSON(), nullable=False, server_default='[]'),
|
||||
sa.Column('created_at', sa.Float(), nullable=False),
|
||||
sa.Column('updated_at', sa.Float(), nullable=False),
|
||||
sa.Column('last_login', sa.Float(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('email'),
|
||||
sa.UniqueConstraint('username')
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=False)
|
||||
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)
|
||||
|
||||
# Create user_api_keys table
|
||||
op.create_table(
|
||||
'user_api_keys',
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=False),
|
||||
sa.Column('provider', sa.String(), nullable=False),
|
||||
sa.Column('encrypted_key', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False, server_default='1'),
|
||||
sa.Column('created_at', sa.Float(), nullable=False),
|
||||
sa.Column('updated_at', sa.Float(), nullable=False),
|
||||
sa.Column('last_used_at', sa.Float(), nullable=True),
|
||||
sa.Column('usage_count', sa.Integer(), nullable=False, server_default='0'),
|
||||
sa.Column('extra_config', sa.JSON(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_user_api_keys_provider'), 'user_api_keys', ['provider'], unique=False)
|
||||
op.create_index(op.f('ix_user_api_keys_user_id'), 'user_api_keys', ['user_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema - Remove users and user_api_keys tables."""
|
||||
op.drop_index(op.f('ix_user_api_keys_user_id'), table_name='user_api_keys')
|
||||
op.drop_index(op.f('ix_user_api_keys_provider'), table_name='user_api_keys')
|
||||
op.drop_table('user_api_keys')
|
||||
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_email'), table_name='users')
|
||||
op.drop_table('users')
|
||||
@@ -0,0 +1,30 @@
|
||||
"""add location and time to storyboards
|
||||
|
||||
Revision ID: bfac9b8e32f5
|
||||
Revises: 72f609dd9e66
|
||||
Create Date: 2026-01-11 00:49:48.323949
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'bfac9b8e32f5'
|
||||
down_revision: Union[str, Sequence[str], None] = '72f609dd9e66'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
op.add_column('storyboards', sa.Column('location', sa.String(), nullable=True))
|
||||
op.add_column('storyboards', sa.Column('time', sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
op.drop_column('storyboards', 'time')
|
||||
op.drop_column('storyboards', 'location')
|
||||
34
backend/alembic/versions/rename_style_preset_to_style_id.py
Normal file
34
backend/alembic/versions/rename_style_preset_to_style_id.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""rename style_preset to style_id
|
||||
|
||||
Revision ID: rename_style_preset
|
||||
Revises: add_cinematic_fields, add_provider_to_tasks
|
||||
Create Date: 2024-02-11
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from typing import Union, Sequence
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'rename_style_preset'
|
||||
down_revision: Union[str, Sequence[str], None] = ('add_cinematic_fields', 'add_provider_to_tasks')
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Rename style_preset column to style_id in projects table"""
|
||||
# SQLite doesn't support ALTER COLUMN RENAME directly
|
||||
# We need to use a workaround with table recreation
|
||||
|
||||
with op.batch_alter_table('projects', schema=None) as batch_op:
|
||||
# Rename the column
|
||||
batch_op.alter_column('style_preset', new_column_name='style_id')
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Revert style_id column back to style_preset"""
|
||||
with op.batch_alter_table('projects', schema=None) as batch_op:
|
||||
# Rename back
|
||||
batch_op.alter_column('style_id', new_column_name='style_preset')
|
||||
49
backend/pyproject.toml
Normal file
49
backend/pyproject.toml
Normal file
@@ -0,0 +1,49 @@
|
||||
[project]
|
||||
name = "backend"
|
||||
version = "0.1.0"
|
||||
description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"alibabacloud-tea-openapi>=0.4.2",
|
||||
"alibabacloud-tea-util>=0.3.14",
|
||||
"alibabacloud-videoenhan20200320>=4.0.0",
|
||||
"fastapi>=0.127.0",
|
||||
"oss2>=2.19.1",
|
||||
"pydantic>=2.12.5",
|
||||
"python-dotenv>=1.2.1",
|
||||
"python-multipart>=0.0.21",
|
||||
"requests>=2.32.5",
|
||||
"httpx>=0.27.0",
|
||||
"uvicorn>=0.40.0",
|
||||
"modelscope>=1.29.2",
|
||||
"volcengine>=1.0.100",
|
||||
"google-generativeai>=0.8.4",
|
||||
"sqlmodel>=0.0.31",
|
||||
"alembic>=1.17.2",
|
||||
"agentscope>=1.0.11",
|
||||
"redis>=5.0.0",
|
||||
"prometheus-client>=0.20.0",
|
||||
"opentelemetry-api>=1.25.0",
|
||||
"opentelemetry-sdk>=1.25.0",
|
||||
"opentelemetry-instrumentation-fastapi>=0.46b0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc>=1.25.0",
|
||||
"fastmcp>=2.0.0",
|
||||
"tenacity>=8.2.3",
|
||||
"psycopg2-binary>=2.9.9",
|
||||
"asyncpg>=0.29.0",
|
||||
"sqladmin[full]>=0.16.0",
|
||||
"itsdangerous>=2.1.2",
|
||||
"psutil>=5.9.0",
|
||||
"Pillow>=11.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
pythonpath = ["."]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"hypothesis>=6.151.5",
|
||||
]
|
||||
84
backend/scripts/generate_error_codes_ts.py
Normal file
84
backend/scripts/generate_error_codes_ts.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate TypeScript ErrorCode enum from Python ErrorCode enum
|
||||
|
||||
Usage:
|
||||
python scripts/generate_error_codes_ts.py > ../frontend/src/lib/errors.ts
|
||||
|
||||
This script synchronizes the frontend ErrorCode enum with the backend definition
|
||||
to ensure consistency across the project.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add src to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))
|
||||
|
||||
from src.utils.errors import ErrorCode
|
||||
|
||||
|
||||
def generate_typescript():
|
||||
lines = [
|
||||
"import { logger } from './utils/logger';",
|
||||
"",
|
||||
"/**",
|
||||
" * Frontend Error Handling Module",
|
||||
" * Provides standardized error codes, user-friendly messages, and retry logic",
|
||||
" *",
|
||||
" * Error codes are synchronized with backend:",
|
||||
" * Format: 4-digit string",
|
||||
" * - 0000: Success",
|
||||
" * - 1xxx: General errors",
|
||||
" * - 2xxx: Business errors",
|
||||
" * - 3xxx: Task errors",
|
||||
" * - 4xxx: AI service errors",
|
||||
" * - 5xxx: Storage errors",
|
||||
" */",
|
||||
"",
|
||||
"/**",
|
||||
" * Error codes matching backend error responses",
|
||||
" * Auto-synchronized with backend/src/utils/errors.py",
|
||||
" */",
|
||||
"export enum ErrorCode {",
|
||||
]
|
||||
|
||||
# Group error codes by category for better organization
|
||||
categories = {
|
||||
"0000": "Success",
|
||||
"1": "General errors",
|
||||
"2": "Business errors",
|
||||
"3": "Task errors",
|
||||
"4": "AI service errors",
|
||||
"5": "Storage errors",
|
||||
}
|
||||
|
||||
current_category = None
|
||||
|
||||
# Add backend error codes
|
||||
for code in ErrorCode:
|
||||
category = code.value[0] if code.value != "0000" else "0000"
|
||||
|
||||
if category != current_category:
|
||||
if current_category is not None:
|
||||
lines.append("")
|
||||
if category in categories:
|
||||
lines.append(f" // {categories[category]} ({category}xxx)")
|
||||
current_category = category
|
||||
|
||||
lines.append(f" {code.name} = '{code.value}',")
|
||||
|
||||
# Add frontend-specific error codes
|
||||
lines.extend([
|
||||
"",
|
||||
" // Frontend-specific errors (not from backend)",
|
||||
" NETWORK_ERROR = 'NET01',",
|
||||
" TIMEOUT_ERROR = 'TIM01',",
|
||||
"}",
|
||||
])
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(generate_typescript())
|
||||
1
backend/src/__init__.py
Normal file
1
backend/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Backend application package."""
|
||||
33
backend/src/admin_config.py
Normal file
33
backend/src/admin_config.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from sqladmin import Admin, ModelView
|
||||
from src.models.entities import ProjectDB, AssetDB, EpisodeDB, StoryboardDB
|
||||
|
||||
class ProjectAdmin(ModelView, model=ProjectDB):
|
||||
column_list = [ProjectDB.id, ProjectDB.name, ProjectDB.type, ProjectDB.status, ProjectDB.created_at]
|
||||
column_searchable_list = [ProjectDB.name, ProjectDB.id]
|
||||
column_sortable_list = [ProjectDB.created_at]
|
||||
icon = "fa-solid fa-diagram-project"
|
||||
|
||||
class AssetAdmin(ModelView, model=AssetDB):
|
||||
column_list = [AssetDB.id, AssetDB.name, AssetDB.type, AssetDB.project_id]
|
||||
column_searchable_list = [AssetDB.name, AssetDB.id, AssetDB.type]
|
||||
# 列_filters = ["type", "project_id"]
|
||||
# 列_filters = [AssetDB.type, AssetDB.project_id]
|
||||
icon = "fa-solid fa-cube"
|
||||
|
||||
class EpisodeAdmin(ModelView, model=EpisodeDB):
|
||||
column_list = [EpisodeDB.id, EpisodeDB.title, EpisodeDB.order_index, EpisodeDB.status, EpisodeDB.project_id]
|
||||
column_searchable_list = [EpisodeDB.title, EpisodeDB.id]
|
||||
# 列_filters = [EpisodeDB.status, EpisodeDB.project_id]
|
||||
icon = "fa-solid fa-film"
|
||||
|
||||
class StoryboardAdmin(ModelView, model=StoryboardDB):
|
||||
column_list = [StoryboardDB.id, StoryboardDB.project_id, StoryboardDB.episode_id, StoryboardDB.order_index]
|
||||
# 列_filters = [StoryboardDB.project_id, StoryboardDB.episode_id]
|
||||
icon = "fa-solid fa-image"
|
||||
|
||||
def setup_admin(app, engine):
|
||||
admin = Admin(app, engine, title="Pixel管理后台")
|
||||
admin.add_view(ProjectAdmin)
|
||||
admin.add_view(AssetAdmin)
|
||||
admin.add_view(EpisodeAdmin)
|
||||
admin.add_view(StoryboardAdmin)
|
||||
32
backend/src/api/admin/__init__.py
Normal file
32
backend/src/api/admin/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""
|
||||
Admin API Package
|
||||
|
||||
管理 API 模块,包含:
|
||||
- dashboard: 仪表板统计和系统资源路由
|
||||
- users: 用户管理路由
|
||||
- projects: 项目管理路由
|
||||
- tasks: 任务管理路由
|
||||
- settings: 系统设置路由
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .dashboard import router as dashboard_router
|
||||
from .users import router as users_router
|
||||
from .projects import router as projects_router
|
||||
from .tasks import router as tasks_router
|
||||
from .settings import router as settings_router
|
||||
|
||||
# 创建主路由器
|
||||
router = APIRouter(prefix="/admin", tags=["admin"])
|
||||
|
||||
# 包含所有子路由
|
||||
router.include_router(dashboard_router)
|
||||
router.include_router(users_router)
|
||||
router.include_router(projects_router)
|
||||
router.include_router(tasks_router)
|
||||
router.include_router(settings_router)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
]
|
||||
85
backend/src/api/admin/dashboard.py
Normal file
85
backend/src/api/admin/dashboard.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Admin API - Dashboard Routes
|
||||
|
||||
包含仪表板统计和系统资源相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.admin_service import admin_service
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/dashboard", tags=["admin-dashboard"])
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ResponseModel)
|
||||
async def get_dashboard_stats(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get dashboard statistics
|
||||
|
||||
Returns counts of users, projects, and tasks by status.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
stats = await admin_service.get_dashboard_stats()
|
||||
return ResponseModel(data=stats.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting dashboard stats: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/system", response_model=ResponseModel)
|
||||
async def get_system_resources(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get system resource information
|
||||
|
||||
Returns CPU, memory, disk usage and uptime.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
resources = await admin_service.get_system_resources()
|
||||
return ResponseModel(data=resources.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system resources: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/activity", response_model=ResponseModel)
|
||||
async def get_recent_activity(
|
||||
limit: int = Query(20, ge=1, le=100),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get recent system activity
|
||||
|
||||
Returns recent user, project, and task activities.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
activity = await admin_service.get_recent_activity(limit=limit)
|
||||
return ResponseModel(data=activity.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting recent activity: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
147
backend/src/api/admin/projects.py
Normal file
147
backend/src/api/admin/projects.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Admin API - Projects Routes
|
||||
|
||||
包含项目管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.admin_service import admin_service
|
||||
from src.utils.pagination import Paginator
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/projects", tags=["admin-projects"])
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_projects(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="Page number, starting from 1"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="Items per page"),
|
||||
sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
|
||||
filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
List all projects with pagination
|
||||
|
||||
Supports filtering by status, type, and search by name.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Parse filters
|
||||
filters = {}
|
||||
if filter:
|
||||
import json
|
||||
try:
|
||||
filters = json.loads(filter)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Parse sort
|
||||
sort_by = "created_at"
|
||||
sort_order = "desc"
|
||||
if sort:
|
||||
parts = sort.split(":")
|
||||
if len(parts) == 2:
|
||||
sort_by = parts[0]
|
||||
sort_order = parts[1]
|
||||
|
||||
items, total = await admin_service.list_projects(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filters=filters if filters else None,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=[item.model_dump(by_alias=True) for item in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing projects: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{project_id}", response_model=ResponseModel)
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get project details by ID
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
project = await admin_service.get_project(project_id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Project not found",
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=project.model_dump())
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting project: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{project_id}", response_model=ResponseModel)
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Delete a project
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
success = await admin_service.delete_project(project_id)
|
||||
if not success:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Project not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": project_id,
|
||||
"deleted": True,
|
||||
"message": "Project deleted successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting project: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
81
backend/src/api/admin/settings.py
Normal file
81
backend/src/api/admin/settings.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Admin API - Settings Routes
|
||||
|
||||
包含系统设置相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.models.admin_schemas import SystemSettingUpdateRequest
|
||||
from src.services.admin_service import admin_service
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/settings", tags=["admin-settings"])
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def get_system_settings(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get system settings
|
||||
|
||||
Returns all configurable system settings.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
settings = await admin_service.get_system_settings()
|
||||
return ResponseModel(data=settings.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system settings: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{key}", response_model=ResponseModel)
|
||||
async def update_system_setting(
|
||||
key: str,
|
||||
request: SystemSettingUpdateRequest,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Update a system setting
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
setting = await admin_service.update_system_setting(key, request.value)
|
||||
if not setting:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Setting not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"key": key,
|
||||
"value": request.value,
|
||||
"updated_at": setting.updated_at,
|
||||
"message": "Setting updated successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating system setting: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
443
backend/src/api/admin/tasks.py
Normal file
443
backend/src/api/admin/tasks.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""
|
||||
Admin API - Tasks Routes
|
||||
|
||||
包含任务管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.models.admin_schemas import SystemSettingUpdateRequest
|
||||
from src.services.admin_service import admin_service
|
||||
from src.utils.pagination import Paginator
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/tasks", tags=["admin-tasks"])
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_tasks(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="Page number, starting from 1"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="Items per page"),
|
||||
sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
|
||||
filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
List all tasks with pagination
|
||||
|
||||
Supports filtering by status, type, provider, user_id, and project_id.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Parse filters
|
||||
filters = {}
|
||||
if filter:
|
||||
import json
|
||||
try:
|
||||
filters = json.loads(filter)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Parse sort
|
||||
sort_by = "created_at"
|
||||
sort_order = "desc"
|
||||
if sort:
|
||||
parts = sort.split(":")
|
||||
if len(parts) == 2:
|
||||
sort_by = parts[0]
|
||||
sort_order = parts[1]
|
||||
|
||||
items, total = await admin_service.list_tasks(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filters=filters if filters else None,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=[item.model_dump() for item in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ResponseModel)
|
||||
async def get_task_stats(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get task statistics
|
||||
|
||||
Returns counts by status, type, provider, and success rate.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
stats = await admin_service.get_task_stats()
|
||||
return ResponseModel(data=stats.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task stats: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/queue", response_model=ResponseModel)
|
||||
async def get_task_queue_status(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get task queue status
|
||||
|
||||
Returns queue length, processing count, and worker status.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
status = await admin_service.get_task_queue_status()
|
||||
return ResponseModel(data=status.model_dump())
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task queue status: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=ResponseModel)
|
||||
async def get_task_detail(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get task detail by ID
|
||||
|
||||
Returns detailed task information including user and project details.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
task = await admin_service.get_task_detail(task_id)
|
||||
if not task:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Task not found",
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=task.model_dump())
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting task detail: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{task_id}/retry", response_model=ResponseModel)
|
||||
async def retry_task(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Retry a failed task
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
task = await admin_service.retry_task(task_id)
|
||||
if not task:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Task not found or cannot be retried",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": task_id,
|
||||
"status": task.status,
|
||||
"message": "Task queued for retry",
|
||||
"retry_count": task.retry_count,
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrying task: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{task_id}", response_model=ResponseModel)
|
||||
async def delete_task(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Delete a task
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
success = await admin_service.delete_task(task_id)
|
||||
if not success:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Task not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": task_id,
|
||||
"deleted": True,
|
||||
"message": "Task deleted successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting task: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/batch-retry", response_model=ResponseModel)
|
||||
async def batch_retry_tasks(
|
||||
task_ids: List[str],
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Batch retry failed tasks
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import TaskDB
|
||||
from datetime import datetime
|
||||
|
||||
retried = []
|
||||
failed = []
|
||||
|
||||
with Session(engine) as session:
|
||||
for task_id in task_ids:
|
||||
task = session.get(TaskDB, task_id)
|
||||
if not task:
|
||||
failed.append({"id": task_id, "reason": "Task not found"})
|
||||
continue
|
||||
|
||||
if task.status not in ["failed", "timeout"]:
|
||||
failed.append({"id": task_id, "reason": f"Cannot retry task with status: {task.status}"})
|
||||
continue
|
||||
|
||||
task.status = "pending"
|
||||
task.retry_count = 0
|
||||
task.error = None
|
||||
task.updated_at = datetime.now().timestamp()
|
||||
session.add(task)
|
||||
retried.append(task_id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"retried": retried,
|
||||
"failed": failed,
|
||||
"total": len(task_ids),
|
||||
"success_count": len(retried),
|
||||
"message": f"Successfully queued {len(retried)} tasks for retry",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error batch retrying tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/batch-cancel", response_model=ResponseModel)
|
||||
async def batch_cancel_tasks(
|
||||
task_ids: List[str],
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Batch cancel pending/processing tasks
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import TaskDB
|
||||
from datetime import datetime
|
||||
|
||||
cancelled = []
|
||||
failed = []
|
||||
|
||||
with Session(engine) as session:
|
||||
for task_id in task_ids:
|
||||
task = session.get(TaskDB, task_id)
|
||||
if not task:
|
||||
failed.append({"id": task_id, "reason": "Task not found"})
|
||||
continue
|
||||
|
||||
if task.status not in ["pending", "processing"]:
|
||||
failed.append({"id": task_id, "reason": f"Cannot cancel task with status: {task.status}"})
|
||||
continue
|
||||
|
||||
task.status = "cancelled"
|
||||
task.updated_at = datetime.now().timestamp()
|
||||
session.add(task)
|
||||
cancelled.append(task_id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"cancelled": cancelled,
|
||||
"failed": failed,
|
||||
"total": len(task_ids),
|
||||
"success_count": len(cancelled),
|
||||
"message": f"Successfully cancelled {len(cancelled)} tasks",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error batch cancelling tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/batch-delete", response_model=ResponseModel)
|
||||
async def batch_delete_tasks(
|
||||
task_ids: List[str],
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Batch delete tasks
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import TaskDB
|
||||
|
||||
deleted = []
|
||||
failed = []
|
||||
|
||||
with Session(engine) as session:
|
||||
for task_id in task_ids:
|
||||
task = session.get(TaskDB, task_id)
|
||||
if not task:
|
||||
failed.append({"id": task_id, "reason": "Task not found"})
|
||||
continue
|
||||
|
||||
session.delete(task)
|
||||
deleted.append(task_id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"deleted": deleted,
|
||||
"failed": failed,
|
||||
"total": len(task_ids),
|
||||
"success_count": len(deleted),
|
||||
"message": f"Successfully deleted {len(deleted)} tasks",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error batch deleting tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cleanup-completed", response_model=ResponseModel)
|
||||
async def cleanup_completed_tasks(
|
||||
days: int = Query(30, ge=1, le=365, description="清理多少天前的已完成任务"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Cleanup completed tasks older than specified days
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import TaskDB
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
cutoff_timestamp = time.time() - (days * 24 * 60 * 60)
|
||||
|
||||
with Session(engine) as session:
|
||||
# Find completed tasks older than cutoff
|
||||
old_tasks = session.exec(
|
||||
select(TaskDB).where(
|
||||
TaskDB.status == "success",
|
||||
TaskDB.completed_at < cutoff_timestamp
|
||||
)
|
||||
).all()
|
||||
|
||||
deleted_count = 0
|
||||
for task in old_tasks:
|
||||
session.delete(task)
|
||||
deleted_count += 1
|
||||
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"deleted_count": deleted_count,
|
||||
"days": days,
|
||||
"cutoff_date": datetime.fromtimestamp(cutoff_timestamp).isoformat(),
|
||||
"message": f"Successfully deleted {deleted_count} completed tasks older than {days} days",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up completed tasks: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
317
backend/src/api/admin/users.py
Normal file
317
backend/src/api/admin/users.py
Normal file
@@ -0,0 +1,317 @@
|
||||
"""
|
||||
Admin API - Users Routes
|
||||
|
||||
包含用户管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel, PaginationParams
|
||||
from src.models.admin_schemas import (
|
||||
AdminUserCreateRequest,
|
||||
AdminUserUpdateRequest,
|
||||
)
|
||||
from src.services.admin_service import admin_service
|
||||
from src.services.user_service import user_service
|
||||
from src.utils.pagination import Paginator
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["admin-users"])
|
||||
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_users(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="Page number, starting from 1"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="Items per page"),
|
||||
sort: Optional[str] = Query(None, description="Sort field, format: field:asc or field:desc"),
|
||||
filter: Optional[str] = Query(None, description="Filter conditions, JSON format"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
List all users with pagination
|
||||
|
||||
Supports filtering by is_active, is_superuser, and search by username/email.
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Parse filters
|
||||
filters = {}
|
||||
if filter:
|
||||
import json
|
||||
try:
|
||||
filters = json.loads(filter)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Parse sort
|
||||
sort_by = "created_at"
|
||||
sort_order = "desc"
|
||||
if sort:
|
||||
parts = sort.split(":")
|
||||
if len(parts) == 2:
|
||||
sort_by = parts[0]
|
||||
sort_order = parts[1]
|
||||
|
||||
items, total = await admin_service.list_users(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filters=filters if filters else None,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=[item.model_dump() for item in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing users: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ResponseModel)
|
||||
async def get_user(
|
||||
user_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Get user details by ID
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
user = await admin_service.get_user(user_id)
|
||||
if not user:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="User not found",
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=user.model_dump())
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("", response_model=ResponseModel)
|
||||
async def create_user(
|
||||
request: AdminUserCreateRequest,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Create a new user
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Check if username already exists
|
||||
existing_user = await user_service.get_user_by_username(request.username)
|
||||
if existing_user:
|
||||
raise AppException(
|
||||
code=ErrorCode.CONFLICT,
|
||||
message="Username already exists",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Check if email already exists
|
||||
existing_email = await user_service.get_user_by_email(request.email)
|
||||
if existing_email:
|
||||
raise AppException(
|
||||
code=ErrorCode.CONFLICT,
|
||||
message="Email already registered",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Create user using user_service
|
||||
user = await user_service.create_user(
|
||||
username=request.username,
|
||||
email=request.email,
|
||||
password=request.password
|
||||
)
|
||||
|
||||
# Update additional fields if provided
|
||||
update_data = {}
|
||||
if request.is_active is not None:
|
||||
update_data["is_active"] = request.is_active
|
||||
if request.is_superuser is not None:
|
||||
update_data["is_superuser"] = request.is_superuser
|
||||
if request.roles:
|
||||
update_data["roles"] = request.roles
|
||||
if request.permissions:
|
||||
update_data["permissions"] = request.permissions
|
||||
|
||||
if update_data:
|
||||
user = await admin_service.update_user(user.id, update_data)
|
||||
|
||||
logger.info(f"Admin {current_user.username} created user: {user.username} ({user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
data=user.model_dump(),
|
||||
message="User created successfully"
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.put("/{user_id}", response_model=ResponseModel)
|
||||
async def update_user(
|
||||
user_id: str,
|
||||
request: AdminUserUpdateRequest,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Update user information
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Prevent self-demotion from superuser
|
||||
if user_id == current_user.id and request.is_superuser is False:
|
||||
raise AppException(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Cannot remove your own superuser status",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
user = await admin_service.update_user(user_id, update_data)
|
||||
|
||||
if not user:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="User not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(data=user.model_dump())
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{user_id}/toggle-active", response_model=ResponseModel)
|
||||
async def toggle_user_active(
|
||||
user_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Toggle user active status
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Prevent self-deactivation
|
||||
if user_id == current_user.id:
|
||||
raise AppException(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Cannot deactivate your own account",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
user = await admin_service.get_user(user_id)
|
||||
if not user:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="User not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
new_status = not user.is_active
|
||||
updated_user = await admin_service.toggle_user_active(user_id, new_status)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": user_id,
|
||||
"is_active": new_status,
|
||||
"message": f"User {'activated' if new_status else 'deactivated'} successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error toggling user active status: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ResponseModel)
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
) -> ResponseModel:
|
||||
"""
|
||||
Delete a user
|
||||
|
||||
Requires admin privileges.
|
||||
"""
|
||||
try:
|
||||
# Prevent self-deletion
|
||||
if user_id == current_user.id:
|
||||
raise AppException(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="Cannot delete your own account",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
success = await admin_service.delete_user(user_id)
|
||||
if not success:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="User not found",
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": user_id,
|
||||
"deleted": True,
|
||||
"message": "User deleted successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message=str(e),
|
||||
status_code=500
|
||||
)
|
||||
206
backend/src/api/audit_logs.py
Normal file
206
backend/src/api/audit_logs.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Audit Log API
|
||||
|
||||
操作审计日志 API 端点。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.pagination import Paginator
|
||||
from src.services.audit_log_service import audit_log_service
|
||||
from src.models.audit_log import AuditLogDB
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(prefix="/admin/audit-logs", tags=["admin-audit-logs"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ===== 响应模型 =====
|
||||
|
||||
class AuditLogListItem(BaseModel):
|
||||
"""审计日志列表项"""
|
||||
id: str
|
||||
user_id: Optional[str]
|
||||
username: Optional[str]
|
||||
action: str
|
||||
resource_type: Optional[str]
|
||||
resource_id: Optional[str]
|
||||
ip_address: Optional[str]
|
||||
created_at: str # ISO format
|
||||
|
||||
|
||||
class AuditLogDetailResponse(BaseModel):
|
||||
"""审计日志详情"""
|
||||
id: str
|
||||
user_id: Optional[str]
|
||||
username: Optional[str]
|
||||
action: str
|
||||
resource_type: Optional[str]
|
||||
resource_id: Optional[str]
|
||||
ip_address: Optional[str]
|
||||
user_agent: Optional[str]
|
||||
details: Optional[Dict[str, Any]]
|
||||
created_at: str
|
||||
|
||||
|
||||
# ===== API 端点 =====
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_audit_logs(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
user_id: Optional[str] = Query(None, description="按用户 ID 过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
resource_id: Optional[str] = Query(None, description="按资源 ID 过滤"),
|
||||
start_date: Optional[str] = Query(None, description="开始日期 (ISO format)"),
|
||||
end_date: Optional[str] = Query(None, description="结束日期 (ISO format)"),
|
||||
sort_by: str = Query("created_at", description="排序字段"),
|
||||
sort_order: str = Query("desc", description="排序方向"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
列出审计日志
|
||||
|
||||
支持高级搜索:按用户、操作类型、资源类型、时间范围过滤。
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
# 构建过滤条件
|
||||
filters = {}
|
||||
if user_id:
|
||||
filters["user_id"] = user_id
|
||||
if action:
|
||||
filters["action"] = action
|
||||
if resource_type:
|
||||
filters["resource_type"] = resource_type
|
||||
if resource_id:
|
||||
filters["resource_id"] = resource_id
|
||||
if start_date:
|
||||
filters["start_date"] = start_date
|
||||
if end_date:
|
||||
filters["end_date"] = end_date
|
||||
|
||||
logs, total = audit_log_service.list_logs(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
filters=filters,
|
||||
sort_by=sort_by,
|
||||
sort_order=sort_order,
|
||||
)
|
||||
|
||||
# 转换为响应格式
|
||||
items = [
|
||||
AuditLogListItem(
|
||||
id=log.id,
|
||||
user_id=log.user_id,
|
||||
username=log.username,
|
||||
action=log.action,
|
||||
resource_type=log.resource_type,
|
||||
resource_id=log.resource_id,
|
||||
ip_address=log.ip_address,
|
||||
created_at=datetime.fromtimestamp(log.created_at).isoformat(),
|
||||
).model_dump()
|
||||
for log in logs
|
||||
]
|
||||
|
||||
paginator = Paginator(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing audit logs: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/{log_id}", response_model=ResponseModel)
|
||||
async def get_audit_log(
|
||||
log_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取审计日志详情
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
log = audit_log_service.get_log(log_id)
|
||||
if not log:
|
||||
raise AppException(
|
||||
message="Audit log not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data=AuditLogDetailResponse(
|
||||
id=log.id,
|
||||
user_id=log.user_id,
|
||||
username=log.username,
|
||||
action=log.action,
|
||||
resource_type=log.resource_type,
|
||||
resource_id=log.resource_id,
|
||||
ip_address=log.ip_address,
|
||||
user_agent=log.user_agent,
|
||||
details=log.details,
|
||||
created_at=datetime.fromtimestamp(log.created_at).isoformat(),
|
||||
).model_dump()
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit log: {e}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/export/csv", response_model=ResponseModel)
|
||||
async def export_audit_logs(
|
||||
user_id: Optional[str] = Query(None, description="按用户 ID 过滤"),
|
||||
action: Optional[str] = Query(None, description="按操作类型过滤"),
|
||||
resource_type: Optional[str] = Query(None, description="按资源类型过滤"),
|
||||
start_date: Optional[str] = Query(None, description="开始日期 (ISO format)"),
|
||||
end_date: Optional[str] = Query(None, description="结束日期 (ISO format)"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
导出审计日志为 CSV
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
# 构建过滤条件
|
||||
filters = {}
|
||||
if user_id:
|
||||
filters["user_id"] = user_id
|
||||
if action:
|
||||
filters["action"] = action
|
||||
if resource_type:
|
||||
filters["resource_type"] = resource_type
|
||||
if start_date:
|
||||
filters["start_date"] = start_date
|
||||
if end_date:
|
||||
filters["end_date"] = end_date
|
||||
|
||||
csv_content = audit_log_service.export_logs(filters=filters)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"content": csv_content,
|
||||
"filename": f"audit_logs_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting audit logs: {e}")
|
||||
raise
|
||||
927
backend/src/api/auth.py
Normal file
927
backend/src/api/auth.py
Normal file
@@ -0,0 +1,927 @@
|
||||
"""
|
||||
认证 API 路由
|
||||
|
||||
提供用户登录、注册、获取当前用户信息等接口。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import io
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, status, Depends, UploadFile, File, Request
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
from PIL import Image
|
||||
|
||||
from src.auth.jwt import (
|
||||
create_token_pair,
|
||||
verify_password,
|
||||
get_password_hash,
|
||||
verify_token,
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
)
|
||||
from src.services.token_blacklist_service import token_blacklist_service
|
||||
from src.services.session_service import session_service
|
||||
from src.services.email_service import email_service
|
||||
from src.config.settings import REDIS_ENABLED, NODE_ENV
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
from src.auth.models import RefreshTokenRequest
|
||||
from src.services.user_service import user_service
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||
|
||||
|
||||
def _build_user_payload(user: UserAuth) -> dict:
|
||||
return {
|
||||
"id": user.id,
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"is_active": user.is_active,
|
||||
"is_superuser": user.is_superuser,
|
||||
}
|
||||
|
||||
|
||||
def _extract_bearer_token(request: Request) -> Optional[str]:
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:]
|
||||
return None
|
||||
|
||||
|
||||
async def _get_access_payload(request: Request):
|
||||
access_token = _extract_bearer_token(request)
|
||||
if not access_token:
|
||||
return None
|
||||
return await verify_token(access_token, token_type="access")
|
||||
|
||||
|
||||
def _serialize_session(session_db, current_session_id: Optional[str] = None) -> dict:
|
||||
return {
|
||||
"id": session_db.id,
|
||||
"session_family_id": session_db.session_family_id,
|
||||
"status": session_db.status,
|
||||
"device_name": session_db.device_name,
|
||||
"ip_address": session_db.ip_address,
|
||||
"user_agent": session_db.user_agent,
|
||||
"created_at": session_db.created_at,
|
||||
"updated_at": session_db.updated_at,
|
||||
"expires_at": session_db.expires_at,
|
||||
"last_used_at": session_db.last_used_at,
|
||||
"revoked_at": session_db.revoked_at,
|
||||
"revoked_reason": session_db.revoked_reason,
|
||||
"current": session_db.id == current_session_id,
|
||||
}
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
"""登录请求"""
|
||||
|
||||
username: str = Field(..., description="用户名或邮箱")
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
"""注册请求"""
|
||||
|
||||
username: str = Field(..., min_length=3, max_length=50, description="用户名")
|
||||
email: EmailStr = Field(..., description="邮箱")
|
||||
password: str = Field(..., min_length=6, description="密码")
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
"""Token 响应"""
|
||||
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = ACCESS_TOKEN_EXPIRE_MINUTES * 60 # 从配置读取
|
||||
|
||||
|
||||
class UserInfoResponse(BaseModel):
|
||||
"""用户信息响应"""
|
||||
|
||||
id: str
|
||||
username: str
|
||||
email: Optional[str]
|
||||
is_active: bool
|
||||
is_superuser: bool
|
||||
|
||||
|
||||
@router.post("/login", response_model=ResponseModel)
|
||||
async def login(request: LoginRequest, http_request: Request):
|
||||
"""
|
||||
用户登录
|
||||
|
||||
支持使用用户名或邮箱登录。
|
||||
成功返回 access_token 和 refresh_token。
|
||||
"""
|
||||
try:
|
||||
# 使用 authenticate_user 方法验证用户(包含密码验证)
|
||||
user = await user_service.authenticate_user(request.username, request.password)
|
||||
|
||||
if not user:
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid credentials",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
# 检查用户是否激活
|
||||
if not user.is_active:
|
||||
raise AppException(
|
||||
code=ErrorCode.FORBIDDEN,
|
||||
message="User account is inactive",
|
||||
status_code=403
|
||||
)
|
||||
|
||||
# 更新最后登录时间
|
||||
await user_service.update_last_login(user.id)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
session_family_id = str(uuid.uuid4())
|
||||
tokens = create_token_pair(
|
||||
user_id=user.id,
|
||||
scopes=["user"],
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
)
|
||||
session_service.create_session(
|
||||
user_id=user.id,
|
||||
refresh_token=tokens.refresh_token,
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
ip_address=http_request.client.host if http_request.client else None,
|
||||
user_agent=http_request.headers.get("user-agent"),
|
||||
)
|
||||
|
||||
logger.info(f"User logged in: {user.username} ({user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
"session_id": session_id,
|
||||
"user": _build_user_payload(user),
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Login error: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="Internal server error",
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=ResponseModel)
|
||||
async def register(request: RegisterRequest, http_request: Request):
|
||||
"""
|
||||
用户注册
|
||||
|
||||
创建新用户账号,用户名和邮箱必须唯一。
|
||||
"""
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
existing_user = await user_service.get_user_by_username(request.username)
|
||||
if existing_user:
|
||||
raise AppException(
|
||||
code=ErrorCode.CONFLICT,
|
||||
message="Username already registered",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
existing_email = await user_service.get_user_by_email(request.email)
|
||||
if existing_email:
|
||||
raise AppException(
|
||||
code=ErrorCode.CONFLICT,
|
||||
message="Email already registered",
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 创建用户
|
||||
user = await user_service.create_user(
|
||||
username=request.username, email=request.email, password=request.password
|
||||
)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
session_family_id = str(uuid.uuid4())
|
||||
tokens = create_token_pair(
|
||||
user_id=user.id,
|
||||
scopes=["user"],
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
)
|
||||
session_service.create_session(
|
||||
user_id=user.id,
|
||||
refresh_token=tokens.refresh_token,
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
ip_address=http_request.client.host if http_request.client else None,
|
||||
user_agent=http_request.headers.get("user-agent"),
|
||||
)
|
||||
|
||||
logger.info(f"User registered: {user.username} ({user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
"session_id": session_id,
|
||||
"user": _build_user_payload(user),
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Registration error: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="Internal server error",
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ResponseModel)
|
||||
async def refresh_token(request: RefreshTokenRequest, http_request: Request):
|
||||
"""
|
||||
刷新 access token
|
||||
|
||||
使用 refresh token 获取新的 access token。
|
||||
"""
|
||||
try:
|
||||
# 验证 refresh token(检查黑名单)
|
||||
payload = await verify_token(request.refresh_token, token_type="refresh")
|
||||
if not payload:
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid or revoked refresh token",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
if not payload.sid:
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Refresh token session is missing",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
user = await user_service.get_user_by_id(payload.sub)
|
||||
if not user or not user.is_active:
|
||||
if user:
|
||||
session_service.revoke_user_sessions(user.id, reason="inactive_user")
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="User session is no longer active",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
new_session_id = str(uuid.uuid4())
|
||||
tokens = create_token_pair(
|
||||
user_id=payload.sub,
|
||||
scopes=payload.scopes or ["user"],
|
||||
session_id=new_session_id,
|
||||
session_family_id=payload.sfid,
|
||||
)
|
||||
|
||||
rotated_session = session_service.rotate_refresh_token(
|
||||
payload.sid,
|
||||
request.refresh_token,
|
||||
tokens.refresh_token,
|
||||
new_session_id=new_session_id,
|
||||
ip_address=http_request.client.host if http_request.client else None,
|
||||
user_agent=http_request.headers.get("user-agent"),
|
||||
)
|
||||
if not rotated_session:
|
||||
raise AppException(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid or replayed refresh token",
|
||||
status_code=401,
|
||||
details={"headers": {"WWW-Authenticate": "Bearer"}}
|
||||
)
|
||||
|
||||
user_data = None
|
||||
if user:
|
||||
user_data = _build_user_payload(user)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"access_token": tokens.access_token,
|
||||
"refresh_token": tokens.refresh_token,
|
||||
"token_type": "bearer",
|
||||
"expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
"session_id": rotated_session.id,
|
||||
**({"user": user_data} if user_data else {}),
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token refresh error: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="Internal server error",
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=ResponseModel)
|
||||
async def get_me(current_user: UserAuth = Depends(get_current_user)):
|
||||
"""
|
||||
获取当前登录用户信息
|
||||
|
||||
需要有效的 access token。
|
||||
"""
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": current_user.id,
|
||||
"username": current_user.username,
|
||||
"email": current_user.email,
|
||||
"avatar_url": current_user.avatar_url,
|
||||
"is_active": current_user.is_active,
|
||||
"is_superuser": current_user.is_superuser,
|
||||
"permissions": current_user.permissions,
|
||||
"roles": current_user.roles,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/session", response_model=ResponseModel)
|
||||
async def get_current_session(
|
||||
http_request: Request,
|
||||
current_user: UserAuth = Depends(get_current_user),
|
||||
):
|
||||
payload = await _get_access_payload(http_request)
|
||||
session_data = None
|
||||
|
||||
if payload and payload.sid:
|
||||
session_db = session_service.get_session(payload.sid)
|
||||
if session_db and session_db.user_id == current_user.id:
|
||||
session_data = _serialize_session(session_db, current_session_id=payload.sid)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"authenticated": True,
|
||||
"session_id": payload.sid if payload else None,
|
||||
"user": {
|
||||
"id": current_user.id,
|
||||
"username": current_user.username,
|
||||
"email": current_user.email,
|
||||
"avatar_url": current_user.avatar_url,
|
||||
"is_active": current_user.is_active,
|
||||
"is_superuser": current_user.is_superuser,
|
||||
"permissions": current_user.permissions,
|
||||
"roles": current_user.roles,
|
||||
},
|
||||
"session": session_data,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/sessions", response_model=ResponseModel)
|
||||
async def list_sessions(
|
||||
http_request: Request,
|
||||
include_inactive: bool = False,
|
||||
current_user: UserAuth = Depends(get_current_user),
|
||||
):
|
||||
payload = await _get_access_payload(http_request)
|
||||
current_session_id = payload.sid if payload else None
|
||||
sessions = session_service.list_user_sessions(
|
||||
current_user.id,
|
||||
include_inactive=include_inactive,
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"items": [
|
||||
_serialize_session(session_db, current_session_id=current_session_id)
|
||||
for session_db in sessions
|
||||
],
|
||||
"total": len(sessions),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/sessions/{session_id}", response_model=ResponseModel)
|
||||
async def revoke_session(
|
||||
session_id: str,
|
||||
http_request: Request,
|
||||
current_user: UserAuth = Depends(get_current_user),
|
||||
):
|
||||
payload = await _get_access_payload(http_request)
|
||||
revoked = session_service.revoke_user_session(
|
||||
current_user.id,
|
||||
session_id,
|
||||
reason="user_revoke",
|
||||
)
|
||||
|
||||
if not revoked:
|
||||
raise AppException(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message="Session not found",
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
revoked_access = False
|
||||
if payload and payload.sid == session_id:
|
||||
access_token = _extract_bearer_token(http_request)
|
||||
if access_token:
|
||||
revoked_access = await token_blacklist_service.revoke_token(
|
||||
access_token,
|
||||
reason="user_revoke",
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="Session revoked successfully",
|
||||
data={
|
||||
"session_id": session_id,
|
||||
"revoked": True,
|
||||
"current": payload.sid == session_id if payload else False,
|
||||
"revoked_access": revoked_access,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=ResponseModel)
|
||||
async def logout(
|
||||
request: RefreshTokenRequest,
|
||||
http_request: Request,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
用户登出
|
||||
|
||||
将当前 access token 和 refresh token 加入黑名单,使其失效。
|
||||
客户端也需要清除本地存储的 token。
|
||||
"""
|
||||
try:
|
||||
access_token = _extract_bearer_token(http_request)
|
||||
revoked_access = False
|
||||
if access_token:
|
||||
revoked_access = await token_blacklist_service.revoke_token(
|
||||
access_token,
|
||||
reason="logout",
|
||||
)
|
||||
|
||||
# 撤销 refresh token(如果提供)
|
||||
revoked_refresh = False
|
||||
if request.refresh_token:
|
||||
refresh_payload = await verify_token(request.refresh_token, token_type="refresh")
|
||||
if refresh_payload and refresh_payload.sid:
|
||||
session_service.revoke_session(refresh_payload.sid, reason="logout")
|
||||
await token_blacklist_service.revoke_token(
|
||||
request.refresh_token,
|
||||
reason="logout"
|
||||
)
|
||||
revoked_refresh = True
|
||||
|
||||
logger.info(f"User logged out: {current_user.username} ({current_user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
message="Logged out successfully",
|
||||
data={
|
||||
"user_id": current_user.id,
|
||||
"revoked": revoked_access or revoked_refresh,
|
||||
"revoked_access": revoked_access,
|
||||
"revoked_refresh": revoked_refresh,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Logout error: {e}")
|
||||
# 即使撤销失败也返回成功,因为客户端应该清除本地 token
|
||||
return ResponseModel(message="Logged out successfully")
|
||||
|
||||
|
||||
class LogoutAllRequest(BaseModel):
|
||||
"""登出所有设备请求"""
|
||||
pass
|
||||
|
||||
|
||||
@router.post("/logout-all", response_model=ResponseModel)
|
||||
async def logout_all_devices(
|
||||
request: LogoutAllRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
从所有设备登出
|
||||
|
||||
撤销该用户的所有 Token,使用户在所有设备上登出。
|
||||
"""
|
||||
try:
|
||||
# 撤销该用户所有 token
|
||||
await token_blacklist_service.revoke_all_user_tokens(
|
||||
current_user.id,
|
||||
reason="logout_all"
|
||||
)
|
||||
revoked_sessions = session_service.revoke_user_sessions(
|
||||
current_user.id,
|
||||
reason="logout_all",
|
||||
)
|
||||
|
||||
logger.info(f"User logged out from all devices: {current_user.username} ({current_user.id})")
|
||||
|
||||
return ResponseModel(
|
||||
message="Logged out from all devices successfully",
|
||||
data={
|
||||
"user_id": current_user.id,
|
||||
"all_devices": True,
|
||||
"revoked_sessions": revoked_sessions,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Logout all error: {e}")
|
||||
raise AppException(
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
message="Failed to logout from all devices",
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/avatar", response_model=ResponseModel)
|
||||
async def upload_avatar(
|
||||
file: UploadFile = File(...),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
上传用户头像
|
||||
|
||||
支持 JPG、PNG 格式,最大 5MB。
|
||||
头像会被裁剪为正方形并压缩至 256x256 像素。
|
||||
"""
|
||||
try:
|
||||
# 验证文件类型
|
||||
allowed_types = {'image/jpeg', 'image/png', 'image/jpg'}
|
||||
if file.content_type not in allowed_types:
|
||||
raise AppException(
|
||||
message="Only JPG and PNG images are allowed",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 读取文件内容
|
||||
contents = await file.read()
|
||||
max_size = 5 * 1024 * 1024 # 5MB
|
||||
if len(contents) > max_size:
|
||||
raise AppException(
|
||||
message="File size exceeds 5MB limit",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 使用 PIL 处理图片
|
||||
image = Image.open(io.BytesIO(contents))
|
||||
|
||||
# 转换为 RGB(处理RGBA图片)
|
||||
if image.mode in ('RGBA', 'LA', 'P'):
|
||||
background = Image.new('RGB', image.size, (255, 255, 255))
|
||||
if image.mode == 'P':
|
||||
image = image.convert('RGBA')
|
||||
if image.mode in ('RGBA', 'LA'):
|
||||
background.paste(image, mask=image.split()[-1] if image.mode in ('RGBA', 'LA') else None)
|
||||
image = background
|
||||
|
||||
# 裁剪为正方形(从中心)
|
||||
width, height = image.size
|
||||
min_dim = min(width, height)
|
||||
left = (width - min_dim) // 2
|
||||
top = (height - min_dim) // 2
|
||||
right = left + min_dim
|
||||
bottom = top + min_dim
|
||||
image = image.crop((left, top, right, bottom))
|
||||
|
||||
# 调整大小为 256x256
|
||||
image = image.resize((256, 256), Image.Resampling.LANCZOS)
|
||||
|
||||
# 保存为 PNG
|
||||
output = io.BytesIO()
|
||||
image.save(output, format='PNG', optimize=True)
|
||||
output.seek(0)
|
||||
|
||||
# 生成存储路径: avatars/{user_id}/{timestamp}.png
|
||||
timestamp = int(datetime.now().timestamp())
|
||||
storage_path = f"avatars/{current_user.id}/{timestamp}.png"
|
||||
|
||||
# 上传到存储
|
||||
avatar_url = storage_manager.save(storage_path, output.getvalue())
|
||||
|
||||
# 更新用户头像 URL
|
||||
updated_user = await user_service.update_user(
|
||||
current_user.id,
|
||||
avatar_url=avatar_url
|
||||
)
|
||||
|
||||
if not updated_user:
|
||||
raise AppException(
|
||||
message="Failed to update user avatar",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
logger.info(f"Avatar uploaded for user {current_user.username}: {avatar_url}")
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"avatar_url": avatar_url,
|
||||
"user": {
|
||||
"id": updated_user.id,
|
||||
"username": updated_user.username,
|
||||
"email": updated_user.email,
|
||||
"avatar_url": updated_user.avatar_url,
|
||||
"is_active": updated_user.is_active,
|
||||
"is_superuser": updated_user.is_superuser,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Avatar upload error: {e}")
|
||||
raise AppException(
|
||||
message=f"Failed to upload avatar: {str(e)}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
# ===== Password Reset Models =====
|
||||
|
||||
class ForgotPasswordRequest(BaseModel):
|
||||
"""忘记密码请求"""
|
||||
email: EmailStr = Field(..., description="用户邮箱")
|
||||
|
||||
|
||||
class ResetPasswordRequest(BaseModel):
|
||||
"""重置密码请求"""
|
||||
token: str = Field(..., description="重置令牌")
|
||||
new_password: str = Field(..., min_length=6, description="新密码")
|
||||
|
||||
|
||||
class VerifyResetTokenRequest(BaseModel):
|
||||
"""验证重置令牌请求"""
|
||||
token: str = Field(..., description="重置令牌")
|
||||
|
||||
|
||||
# Password reset token storage (using Redis if available, else memory)
|
||||
_reset_tokens = {} # In-memory fallback {token: {"user_id": str, "expires": float}}
|
||||
|
||||
|
||||
async def _store_reset_token(token: str, user_id: str, expires_in: int = 3600) -> None:
|
||||
"""Store reset token in Redis or memory"""
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from src.config.settings import REDIS_URL
|
||||
redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
await redis_client.setex(f"pwd_reset:{token}", expires_in, user_id)
|
||||
await redis_client.close()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis not available for reset token, using memory: {e}")
|
||||
|
||||
# Fallback to memory
|
||||
_reset_tokens[token] = {
|
||||
"user_id": user_id,
|
||||
"expires": datetime.now().timestamp() + expires_in
|
||||
}
|
||||
|
||||
|
||||
async def _get_reset_token_user_id(token: str) -> Optional[str]:
|
||||
"""Get user_id from reset token"""
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from src.config.settings import REDIS_URL
|
||||
redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
user_id = await redis_client.get(f"pwd_reset:{token}")
|
||||
await redis_client.close()
|
||||
return user_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to memory
|
||||
token_data = _reset_tokens.get(token)
|
||||
if token_data:
|
||||
if token_data["expires"] > datetime.now().timestamp():
|
||||
return token_data["user_id"]
|
||||
else:
|
||||
# Expired, remove it
|
||||
del _reset_tokens[token]
|
||||
return None
|
||||
|
||||
|
||||
async def _delete_reset_token(token: str) -> None:
|
||||
"""Delete reset token"""
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from src.config.settings import REDIS_URL
|
||||
redis_client = await aioredis.from_url(REDIS_URL, encoding="utf-8", decode_responses=True)
|
||||
await redis_client.delete(f"pwd_reset:{token}")
|
||||
await redis_client.close()
|
||||
return
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to memory
|
||||
if token in _reset_tokens:
|
||||
del _reset_tokens[token]
|
||||
|
||||
|
||||
def _generate_reset_token() -> str:
|
||||
"""Generate secure reset token"""
|
||||
import secrets
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
|
||||
@router.post("/forgot-password", response_model=ResponseModel)
|
||||
async def forgot_password(request: ForgotPasswordRequest):
|
||||
"""
|
||||
请求密码重置
|
||||
|
||||
发送密码重置邮件。
|
||||
令牌有效期为 1 小时。
|
||||
"""
|
||||
try:
|
||||
# 查找用户
|
||||
user = await user_service.get_user_by_email(request.email)
|
||||
if not user:
|
||||
# 为了安全,即使用户不存在也返回成功,不暴露邮箱是否存在
|
||||
logger.info(f"Password reset requested for non-existent email: {request.email}")
|
||||
return ResponseModel(
|
||||
message="If the email exists, a reset link has been sent"
|
||||
)
|
||||
|
||||
# 生成重置令牌
|
||||
token = _generate_reset_token()
|
||||
|
||||
# 存储令牌(1小时有效)
|
||||
await _store_reset_token(token, user.id, expires_in=3600)
|
||||
|
||||
# 发送重置邮件
|
||||
email_result = await email_service.send_password_reset(
|
||||
to_email=user.email,
|
||||
username=user.username,
|
||||
reset_token=token,
|
||||
expires_in=1
|
||||
)
|
||||
|
||||
if email_result["success"]:
|
||||
logger.info(f"Password reset email sent to {user.email}")
|
||||
return ResponseModel(
|
||||
message="Password reset email sent"
|
||||
)
|
||||
else:
|
||||
# 邮件发送失败,但在开发环境可以返回令牌
|
||||
logger.warning(f"Failed to send reset email: {email_result.get('error')}")
|
||||
|
||||
# 生产环境不返回令牌
|
||||
if NODE_ENV == "production":
|
||||
return ResponseModel(
|
||||
message="Failed to send reset email. Please try again later."
|
||||
)
|
||||
else:
|
||||
# 开发环境返回令牌以便测试
|
||||
return ResponseModel(
|
||||
data={
|
||||
"message": "Password reset email sent",
|
||||
"reset_token": token, # 开发环境返回,生产环境不应返回
|
||||
"expires_in": 3600,
|
||||
"email_error": email_result.get("error")
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Forgot password error: {e}")
|
||||
raise AppException(
|
||||
message="Failed to process password reset request",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/verify-reset-token", response_model=ResponseModel)
|
||||
async def verify_reset_token(request: VerifyResetTokenRequest):
|
||||
"""
|
||||
验证密码重置令牌
|
||||
|
||||
检查令牌是否有效,用于前端重置密码页面验证。
|
||||
"""
|
||||
try:
|
||||
user_id = await _get_reset_token_user_id(request.token)
|
||||
|
||||
if not user_id:
|
||||
raise AppException(
|
||||
message="Invalid or expired reset token",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 获取用户信息
|
||||
user = await user_service.get_user_by_id(user_id)
|
||||
if not user:
|
||||
raise AppException(
|
||||
message="User not found",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"valid": True,
|
||||
"user_id": user_id,
|
||||
"email": user.email
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Verify reset token error: {e}")
|
||||
raise AppException(
|
||||
message="Failed to verify reset token",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reset-password", response_model=ResponseModel)
|
||||
async def reset_password(request: ResetPasswordRequest):
|
||||
"""
|
||||
重置密码
|
||||
|
||||
使用有效的重置令牌设置新密码,并撤销该用户所有 Token。
|
||||
"""
|
||||
try:
|
||||
# 验证令牌
|
||||
user_id = await _get_reset_token_user_id(request.token)
|
||||
|
||||
if not user_id:
|
||||
raise AppException(
|
||||
message="Invalid or expired reset token",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 更新密码
|
||||
updated_user = await user_service.update_user(
|
||||
user_id,
|
||||
password=request.new_password
|
||||
)
|
||||
|
||||
if not updated_user:
|
||||
raise AppException(
|
||||
message="User not found",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 删除已使用的令牌
|
||||
await _delete_reset_token(request.token)
|
||||
|
||||
# 撤销该用户所有 Token(强制重新登录)
|
||||
await token_blacklist_service.revoke_all_user_tokens(
|
||||
user_id,
|
||||
reason="password_reset"
|
||||
)
|
||||
|
||||
# 使该用户所有缓存失效
|
||||
await user_service.invalidate_user_cache(user_id)
|
||||
|
||||
logger.info(f"Password reset successful for user {updated_user.username}")
|
||||
|
||||
return ResponseModel(
|
||||
message="Password reset successful",
|
||||
data={
|
||||
"user_id": updated_user.id,
|
||||
"email": updated_user.email
|
||||
}
|
||||
)
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Reset password error: {e}")
|
||||
raise AppException(
|
||||
message="Failed to reset password",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
81
backend/src/api/canvas.py
Normal file
81
backend/src/api/canvas.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Query
|
||||
from sqlmodel import Session
|
||||
from src.config.database import engine
|
||||
from src.models.entities import CanvasDB
|
||||
from src.models.schemas import ResponseModel, CanvasState
|
||||
from src.mappers import CanvasMapper
|
||||
from src.utils.errors import BusinessException, ErrorCode
|
||||
|
||||
router = APIRouter(tags=["canvas"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@router.get("/canvas", response_model=ResponseModel)
|
||||
async def get_canvas_state(id: str = "default", projectId: str = Query(None)):
|
||||
"""Get canvas state by ID
|
||||
|
||||
Returns empty default state if canvas not found.
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to fetch canvas
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
canvas_db = session.get(CanvasDB, id)
|
||||
|
||||
if not canvas_db:
|
||||
# If not found, return empty default state
|
||||
return ResponseModel(data=CanvasState(id=id, projectId=projectId))
|
||||
|
||||
# Map to CanvasState
|
||||
canvas_state = CanvasState(
|
||||
id=canvas_db.id,
|
||||
projectId=canvas_db.project_id,
|
||||
nodes=canvas_db.nodes,
|
||||
connections=canvas_db.connections,
|
||||
groups=canvas_db.groups,
|
||||
history=canvas_db.history,
|
||||
historyIndex=canvas_db.history_index,
|
||||
updated_at=canvas_db.updated_at
|
||||
)
|
||||
return ResponseModel(data=canvas_state)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch canvas {id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.CANVAS_NOT_FOUND,
|
||||
"Failed to fetch canvas state",
|
||||
{"canvas_id": id, "reason": str(e)}
|
||||
)
|
||||
|
||||
@router.post("/canvas", response_model=ResponseModel)
|
||||
async def save_canvas_state(state: CanvasState):
|
||||
"""Save canvas state
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to save canvas
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
# Use merge to handle both insert and update gracefully
|
||||
canvas_db = CanvasDB(
|
||||
id=state.id,
|
||||
project_id=state.projectId,
|
||||
nodes=state.nodes,
|
||||
connections=state.connections,
|
||||
groups=state.groups,
|
||||
history=state.history,
|
||||
history_index=state.history_index,
|
||||
updated_at=state.updated_at
|
||||
)
|
||||
|
||||
session.merge(canvas_db)
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(data={"id": state.id, "updated": True})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save canvas {state.id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to save canvas state",
|
||||
{"canvas_id": state.id, "reason": str(e)}
|
||||
)
|
||||
428
backend/src/api/canvas_metadata.py
Normal file
428
backend/src/api/canvas_metadata.py
Normal file
@@ -0,0 +1,428 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from fastapi import APIRouter, Query, Depends
|
||||
from sqlmodel import Session
|
||||
from src.config.database import engine
|
||||
from src.services.canvas_metadata_service import CanvasMetadataService
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
CanvasMetadata,
|
||||
CreateGeneralCanvasRequest,
|
||||
UpdateCanvasMetadataRequest
|
||||
)
|
||||
from src.utils.errors import ResourceNotFoundException, BusinessException, ErrorCode
|
||||
|
||||
router = APIRouter(tags=["canvas_metadata"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_session():
|
||||
""" 依赖 to get database session"""
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}/canvases", response_model=ResponseModel)
|
||||
async def list_project_canvases(
|
||||
project_id: str,
|
||||
canvas_type: Optional[str] = Query(None, description="Filter by type: general, asset, storyboard"),
|
||||
include_deleted: bool = Query(False, description="Include deleted canvases"),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" 查询 all canvases for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
canvas_type: Optional filter by canvas type (general, asset, storyboard)
|
||||
include_deleted: Whether to include soft-deleted canvases
|
||||
|
||||
Returns:
|
||||
List of canvas metadata
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to list canvases
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvases = service.list_canvases(project_id, canvas_type, include_deleted)
|
||||
|
||||
# 转换 to dict format
|
||||
canvas_data = [
|
||||
CanvasMetadata(
|
||||
id=c.id,
|
||||
projectId=c.project_id,
|
||||
canvasType=c.canvas_type,
|
||||
relatedEntityType=c.related_entity_type,
|
||||
relatedEntityId=c.related_entity_id,
|
||||
name=c.name,
|
||||
description=c.description,
|
||||
orderIndex=c.order_index,
|
||||
isPinned=c.is_pinned,
|
||||
tags=c.tags,
|
||||
nodeCount=c.node_count,
|
||||
lastAccessedAt=c.last_accessed_at,
|
||||
accessCount=c.access_count,
|
||||
createdAt=c.created_at,
|
||||
updatedAt=c.updated_at,
|
||||
deletedAt=c.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
for c in canvases
|
||||
]
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list canvases for project {project_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to list canvases",
|
||||
{"project_id": project_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/canvases/{canvas_id}/metadata", response_model=ResponseModel)
|
||||
async def get_canvas_metadata(
|
||||
canvas_id: str,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" Get canvas metadata by ID.
|
||||
|
||||
Args:
|
||||
canvas_id: Canvas ID
|
||||
|
||||
Returns:
|
||||
Canvas metadata
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: Canvas not found
|
||||
BusinessException: Failed to get canvas metadata
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvas = service.get_canvas(canvas_id)
|
||||
|
||||
if not canvas:
|
||||
raise ResourceNotFoundException("canvas", canvas_id)
|
||||
|
||||
# 更新 access statistics
|
||||
service.update_access_stats(canvas.id)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except ResourceNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get canvas metadata {canvas_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.CANVAS_NOT_FOUND,
|
||||
"Failed to get canvas metadata",
|
||||
{"canvas_id": canvas_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/assets/{asset_id}/canvas", response_model=ResponseModel)
|
||||
async def get_asset_canvas(
|
||||
asset_id: str,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" Get canvas associated with an asset.
|
||||
Automatically creates the canvas if it doesn't exist.
|
||||
|
||||
Args:
|
||||
asset_id: Asset ID
|
||||
|
||||
Returns:
|
||||
Canvas metadata
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: Asset not found
|
||||
BusinessException: Failed to get asset canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvas = service.get_or_create_asset_canvas(asset_id)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except ValueError as e:
|
||||
raise ResourceNotFoundException("asset", asset_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get asset canvas {asset_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to get asset canvas",
|
||||
{"asset_id": asset_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/storyboards/{storyboard_id}/canvas", response_model=ResponseModel)
|
||||
async def get_storyboard_canvas(
|
||||
storyboard_id: str,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" Get canvas associated with a storyboard.
|
||||
Automatically creates the canvas if it doesn't exist.
|
||||
|
||||
Args:
|
||||
storyboard_id: Storyboard ID
|
||||
|
||||
Returns:
|
||||
Canvas metadata
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: Storyboard not found
|
||||
BusinessException: Failed to get storyboard canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvas = service.get_or_create_storyboard_canvas(storyboard_id)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except ValueError as e:
|
||||
raise ResourceNotFoundException("storyboard", storyboard_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get storyboard canvas {storyboard_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to get storyboard canvas",
|
||||
{"storyboard_id": storyboard_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/canvases", response_model=ResponseModel)
|
||||
async def create_general_canvas(
|
||||
project_id: str,
|
||||
request: CreateGeneralCanvasRequest,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" Create a new general canvas for a project.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
request: Canvas creation request with name and optional description
|
||||
|
||||
Returns:
|
||||
Created canvas metadata
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to create canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
canvas = service.create_general_canvas(
|
||||
project_id=project_id,
|
||||
name=request.name,
|
||||
description=request.description
|
||||
)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create general canvas for project {project_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to create canvas",
|
||||
{"project_id": project_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.put("/canvases/{canvas_id}/metadata", response_model=ResponseModel)
|
||||
async def update_canvas_metadata(
|
||||
canvas_id: str,
|
||||
request: UpdateCanvasMetadataRequest,
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" 更新 canvas metadata.
|
||||
|
||||
Args:
|
||||
canvas_id: Canvas ID
|
||||
request: Update request with optional fields
|
||||
|
||||
Returns:
|
||||
Updated canvas metadata
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: Canvas not found
|
||||
BusinessException: Failed to update canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
|
||||
# 转换 request to dict, excluding unset fields
|
||||
updates = request.model_dump(exclude_unset=True, by_alias=False)
|
||||
|
||||
canvas = service.update_canvas(canvas_id, updates)
|
||||
|
||||
if not canvas:
|
||||
raise ResourceNotFoundException("canvas", canvas_id)
|
||||
|
||||
# 转换 to response format
|
||||
canvas_data = CanvasMetadata(
|
||||
id=canvas.id,
|
||||
projectId=canvas.project_id,
|
||||
canvasType=canvas.canvas_type,
|
||||
relatedEntityType=canvas.related_entity_type,
|
||||
relatedEntityId=canvas.related_entity_id,
|
||||
name=canvas.name,
|
||||
description=canvas.description,
|
||||
orderIndex=canvas.order_index,
|
||||
isPinned=canvas.is_pinned,
|
||||
tags=canvas.tags,
|
||||
nodeCount=canvas.node_count,
|
||||
lastAccessedAt=canvas.last_accessed_at,
|
||||
accessCount=canvas.access_count,
|
||||
createdAt=canvas.created_at,
|
||||
updatedAt=canvas.updated_at,
|
||||
deletedAt=canvas.deleted_at
|
||||
).model_dump(by_alias=True)
|
||||
|
||||
return ResponseModel(data=canvas_data)
|
||||
except ResourceNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update canvas metadata {canvas_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to update canvas",
|
||||
{"canvas_id": canvas_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.put("/projects/{project_id}/canvases/reorder", response_model=ResponseModel)
|
||||
async def reorder_canvases(
|
||||
project_id: str,
|
||||
canvas_orders: List[dict],
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" 批处理 update canvas order indices.
|
||||
|
||||
Args:
|
||||
project_id: Project ID
|
||||
canvas_orders: List of {id, order_index} objects
|
||||
|
||||
Returns:
|
||||
Success confirmation
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to reorder canvases
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
service.reorder_canvases(canvas_orders)
|
||||
|
||||
return ResponseModel(data={"updated": True})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reorder canvases for project {project_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to reorder canvases",
|
||||
{"project_id": project_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/canvases/{canvas_id}", response_model=ResponseModel)
|
||||
async def delete_canvas(
|
||||
canvas_id: str,
|
||||
hard_delete: bool = Query(False, description="Permanently delete (true) or soft delete (false)"),
|
||||
session: Session = Depends(get_session)
|
||||
):
|
||||
""" 删除 a canvas.
|
||||
|
||||
Args:
|
||||
canvas_id: Canvas ID
|
||||
hard_delete: If true, permanently delete; if false, soft delete
|
||||
|
||||
Returns:
|
||||
Success confirmation
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to delete canvas
|
||||
"""
|
||||
try:
|
||||
service = CanvasMetadataService(session)
|
||||
service.delete_canvas(canvas_id, hard_delete)
|
||||
|
||||
return ResponseModel(data={"deleted": True})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete canvas {canvas_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to delete canvas",
|
||||
{"canvas_id": canvas_id, "hard_delete": hard_delete, "reason": str(e)}
|
||||
)
|
||||
61
backend/src/api/chat.py
Normal file
61
backend/src/api/chat.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import List, Optional
|
||||
|
||||
from src.services.agent_engine import AgentScopeService
|
||||
from src.utils.errors import BusinessException, ErrorCode
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
messages: List[Message]
|
||||
model: Optional[str] = None
|
||||
stream: bool = True
|
||||
temperature: Optional[float] = 0.7
|
||||
max_tokens: Optional[int] = Field(None, alias="maxTokens")
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
@router.post("/completions")
|
||||
async def chat_completions(
|
||||
request: ChatCompletionRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
OpenAI-compatible chat completion endpoint.
|
||||
Supports streaming and non-streaming modes.
|
||||
Now integrated with AgentScope.
|
||||
|
||||
Raises:
|
||||
BusinessException: Chat completion failed
|
||||
"""
|
||||
if not request.stream:
|
||||
raise BusinessException(
|
||||
ErrorCode.INVALID_PARAMETER,
|
||||
"Non-stream mode is not supported",
|
||||
{"field": "stream", "expected": True}
|
||||
)
|
||||
|
||||
messages_payload = [m.model_dump() for m in request.messages]
|
||||
|
||||
try:
|
||||
agent_service = AgentScopeService(user_id=current_user.id)
|
||||
return StreamingResponse(
|
||||
agent_service.stream_chat(messages_payload),
|
||||
media_type="text/event-stream"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"AgentScope service error: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.GENERATION_FAILED,
|
||||
"Chat completion failed",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
45
backend/src/api/config/__init__.py
Normal file
45
backend/src/api/config/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Config API routes - split into multiple modules for maintainability.
|
||||
|
||||
This package provides configuration management endpoints including:
|
||||
- System configuration (/config/system)
|
||||
- Provider and model configuration (/config/providers, /config/models, etc.)
|
||||
- Style configuration (/config/styles)
|
||||
- Health monitoring (/config/health)
|
||||
- Configuration validation (/config/validate)
|
||||
- Admin operations (/admin/*)
|
||||
- Storage utilities (/storage/*)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from . import system, providers, styles, health, validation, admin, storage
|
||||
|
||||
# Create main router
|
||||
router = APIRouter()
|
||||
|
||||
# Include all sub-routers
|
||||
router.include_router(system.router)
|
||||
router.include_router(providers.router)
|
||||
router.include_router(styles.router)
|
||||
router.include_router(health.router)
|
||||
router.include_router(validation.router)
|
||||
router.include_router(admin.router)
|
||||
router.include_router(storage.router)
|
||||
|
||||
# Export models for backward compatibility
|
||||
from .models import (
|
||||
SignUrlRequest,
|
||||
SystemConfig,
|
||||
ModelRegistrationRequest,
|
||||
ProviderKeyField,
|
||||
ProviderConfigResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
"SignUrlRequest",
|
||||
"SystemConfig",
|
||||
"ModelRegistrationRequest",
|
||||
"ProviderKeyField",
|
||||
"ProviderConfigResponse",
|
||||
]
|
||||
321
backend/src/api/config/admin.py
Normal file
321
backend/src/api/config/admin.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Provider and model management endpoints (Admin Only)."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.utils.errors import AppException, ErrorCode
|
||||
from src.services.provider.health import health_monitor
|
||||
from src.utils.service_loader import register_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/admin/providers/{provider_id}/toggle", response_model=ResponseModel)
|
||||
async def toggle_provider(
|
||||
provider_id: str,
|
||||
enabled: bool,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Enable/disable Provider (Admin only)."""
|
||||
try:
|
||||
provider_config = ModelRegistry.get_provider_config(provider_id)
|
||||
if not provider_config:
|
||||
raise AppException(
|
||||
message=f"Provider not found: {provider_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
provider_config["enabled"] = enabled
|
||||
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
custom_providers_path = os.path.join(src_dir, "config", "services", "custom_providers.json")
|
||||
|
||||
existing_configs = []
|
||||
if os.path.exists(custom_providers_path):
|
||||
try:
|
||||
with open(custom_providers_path, 'r', encoding='utf-8') as f:
|
||||
existing_configs = json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to load custom providers config %s: %s", custom_providers_path, e)
|
||||
|
||||
updated = False
|
||||
for i, config in enumerate(existing_configs):
|
||||
if config.get("id") == provider_id:
|
||||
existing_configs[i] = {**config, **provider_config}
|
||||
updated = True
|
||||
break
|
||||
|
||||
if not updated:
|
||||
existing_configs.append(provider_config)
|
||||
|
||||
os.makedirs(os.path.dirname(custom_providers_path), exist_ok=True)
|
||||
with open(custom_providers_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_configs, f, indent=4, ensure_ascii=False)
|
||||
|
||||
if enabled:
|
||||
register_service(provider_config)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"provider_id": provider_id,
|
||||
"enabled": enabled,
|
||||
"message": f"Provider {'enabled' if enabled else 'disabled'} successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error toggling provider: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/admin/providers/{provider_id}", response_model=ResponseModel)
|
||||
async def delete_provider(
|
||||
provider_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Delete custom Provider (Admin only)."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
custom_providers_path = os.path.join(src_dir, "config", "services", "custom_providers.json")
|
||||
|
||||
if not os.path.exists(custom_providers_path):
|
||||
raise AppException(
|
||||
message="Custom providers configuration not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
with open(custom_providers_path, 'r', encoding='utf-8') as f:
|
||||
existing_configs = json.load(f)
|
||||
|
||||
initial_len = len(existing_configs)
|
||||
existing_configs = [c for c in existing_configs if c.get("id") != provider_id]
|
||||
|
||||
if len(existing_configs) == initial_len:
|
||||
raise AppException(
|
||||
message="Provider not found or not a custom provider",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
with open(custom_providers_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_configs, f, indent=4, ensure_ascii=False)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"provider_id": provider_id,
|
||||
"deleted": True,
|
||||
"message": "Provider deleted successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting provider: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/models/{model_id}/toggle", response_model=ResponseModel)
|
||||
async def toggle_model(
|
||||
model_id: str,
|
||||
enabled: bool,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Enable/disable model (Admin only)."""
|
||||
try:
|
||||
model_config = ModelRegistry.get_config(model_id)
|
||||
if not model_config:
|
||||
raise AppException(
|
||||
message=f"Model not found: {model_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
model_config["enabled"] = enabled
|
||||
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
custom_models_path = os.path.join(src_dir, "config", "services", "custom_models.json")
|
||||
|
||||
existing_configs = []
|
||||
if os.path.exists(custom_models_path):
|
||||
try:
|
||||
with open(custom_models_path, 'r', encoding='utf-8') as f:
|
||||
existing_configs = json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to load custom models config %s: %s", custom_models_path, e)
|
||||
|
||||
updated = False
|
||||
for i, config in enumerate(existing_configs):
|
||||
if config.get("id") == model_id:
|
||||
existing_configs[i] = {**config, "enabled": enabled}
|
||||
updated = True
|
||||
break
|
||||
|
||||
if not updated:
|
||||
existing_configs.append({**model_config, "enabled": enabled})
|
||||
|
||||
os.makedirs(os.path.dirname(custom_models_path), exist_ok=True)
|
||||
with open(custom_models_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_configs, f, indent=4, ensure_ascii=False)
|
||||
|
||||
if enabled:
|
||||
register_service(model_config)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"model_id": model_id,
|
||||
"enabled": enabled,
|
||||
"message": f"Model {'enabled' if enabled else 'disabled'} successfully",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error toggling model: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/models/{model_id}/default", response_model=ResponseModel)
|
||||
async def set_default_model(
|
||||
model_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Set default model (Admin only)."""
|
||||
try:
|
||||
model_config = ModelRegistry.get_config(model_id)
|
||||
if not model_config:
|
||||
raise AppException(
|
||||
message=f"Model not found: {model_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
model_type_str = model_config.get("type")
|
||||
if not model_type_str:
|
||||
raise AppException(
|
||||
message="Model type not found",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
try:
|
||||
model_type = ModelType(model_type_str.lower())
|
||||
except ValueError:
|
||||
raise AppException(
|
||||
message=f"Invalid model type: {model_type_str}",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
ModelRegistry.set_default_by_id(model_type, model_id)
|
||||
|
||||
config_key_map = {
|
||||
"image": "defaultImageModel",
|
||||
"video": "defaultVideoModel",
|
||||
"audio": "defaultAudioModel",
|
||||
"lyrics": "defaultLyricsModel",
|
||||
"music": "defaultMusicModel",
|
||||
"llm": "defaultLLMModel",
|
||||
}
|
||||
|
||||
config_key = config_key_map.get(model_type_str.lower())
|
||||
if config_key:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
config_path = os.path.join(src_dir, "config", "user_config.json")
|
||||
|
||||
existing_data = {}
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
existing_data[config_key] = model_id
|
||||
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:defaults")
|
||||
await cache.delete("config:system")
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"model_id": model_id,
|
||||
"model_type": model_type_str,
|
||||
"is_default": True,
|
||||
"message": f"Default {model_type_str} model set to {model_id}",
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error setting default model: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/models/{model_id}/test", response_model=ResponseModel)
|
||||
async def test_model(
|
||||
model_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""Test model connection (Admin only)."""
|
||||
try:
|
||||
service = ModelRegistry.get(model_id)
|
||||
if not service:
|
||||
raise AppException(
|
||||
message=f"Model service not found: {model_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
result = await health_monitor.check_service_health(model_id, service)
|
||||
|
||||
health_monitor.update_health(model_id, result)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"model_id": model_id,
|
||||
"success": result.status.value == "healthy",
|
||||
"status": result.status.value,
|
||||
"latency_ms": result.latency_ms,
|
||||
"message": "Model test completed",
|
||||
"error": result.error,
|
||||
}
|
||||
)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing model: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
137
backend/src/api/config/health.py
Normal file
137
backend/src/api/config/health.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Health check and monitoring endpoints."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.provider.registry import ModelRegistry
|
||||
from src.utils.errors import AppException, ErrorCode
|
||||
from src.services.provider.health import health_monitor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/config/health", response_model=ResponseModel)
|
||||
async def get_health_status():
|
||||
"""Get health status of all registered services."""
|
||||
try:
|
||||
summary = health_monitor.get_health_summary()
|
||||
all_health = health_monitor.get_all_health()
|
||||
|
||||
services = {}
|
||||
for service_id, health in all_health.items():
|
||||
services[service_id] = {
|
||||
"status": health.status.value,
|
||||
"last_check": health.last_check.isoformat() if health.last_check else None,
|
||||
"last_success": health.last_success.isoformat() if health.last_success else None,
|
||||
"last_failure": health.last_failure.isoformat() if health.last_failure else None,
|
||||
"consecutive_failures": health.consecutive_failures,
|
||||
"consecutive_successes": health.consecutive_successes,
|
||||
"total_checks": health.total_checks,
|
||||
"total_failures": health.total_failures,
|
||||
"success_rate": health.get_success_rate(),
|
||||
"avg_latency_ms": health.avg_latency_ms,
|
||||
"should_circuit_break": health.should_circuit_break()
|
||||
}
|
||||
|
||||
return ResponseModel(data={
|
||||
"summary": summary,
|
||||
"services": services,
|
||||
"unhealthy": health_monitor.get_unhealthy_services(),
|
||||
"degraded": health_monitor.get_degraded_services()
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get health status: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/health/{service_id}", response_model=ResponseModel)
|
||||
async def get_service_health(service_id: str):
|
||||
"""Get detailed health status for a specific service."""
|
||||
try:
|
||||
health = health_monitor.get_health(service_id)
|
||||
|
||||
if not health:
|
||||
raise AppException(
|
||||
message=f"Service not found: {service_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
history = [
|
||||
{
|
||||
"status": result.status.value,
|
||||
"latency_ms": result.latency_ms,
|
||||
"timestamp": result.timestamp.isoformat(),
|
||||
"error": result.error,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
for result in health.history
|
||||
]
|
||||
|
||||
return ResponseModel(data={
|
||||
"service_id": service_id,
|
||||
"status": health.status.value,
|
||||
"last_check": health.last_check.isoformat() if health.last_check else None,
|
||||
"last_success": health.last_success.isoformat() if health.last_success else None,
|
||||
"last_failure": health.last_failure.isoformat() if health.last_failure else None,
|
||||
"consecutive_failures": health.consecutive_failures,
|
||||
"consecutive_successes": health.consecutive_successes,
|
||||
"total_checks": health.total_checks,
|
||||
"total_failures": health.total_failures,
|
||||
"success_rate": health.get_success_rate(),
|
||||
"avg_latency_ms": health.avg_latency_ms,
|
||||
"should_circuit_break": health.should_circuit_break(),
|
||||
"history": history
|
||||
})
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get service health: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config/health/{service_id}/check", response_model=ResponseModel)
|
||||
async def check_service_health(service_id: str):
|
||||
"""Manually trigger a health check for a specific service."""
|
||||
try:
|
||||
service = ModelRegistry.get(service_id)
|
||||
if not service:
|
||||
raise AppException(
|
||||
message=f"Service not found: {service_id}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
health_monitor.register_service(service_id)
|
||||
|
||||
result = await health_monitor.check_service_health(service_id, service)
|
||||
|
||||
health_monitor.update_health(service_id, result)
|
||||
|
||||
return ResponseModel(data={
|
||||
"service_id": service_id,
|
||||
"status": result.status.value,
|
||||
"latency_ms": result.latency_ms,
|
||||
"timestamp": result.timestamp.isoformat(),
|
||||
"error": result.error
|
||||
})
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check service health: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
49
backend/src/api/config/models.py
Normal file
49
backend/src/api/config/models.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Pydantic models for config API."""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SignUrlRequest(BaseModel):
|
||||
"""Request to sign a URL."""
|
||||
url: str
|
||||
|
||||
|
||||
class SystemConfig(BaseModel):
|
||||
"""System configuration model."""
|
||||
default_image_model: Optional[str] = None
|
||||
default_video_model: Optional[str] = None
|
||||
default_audio_model: Optional[str] = None
|
||||
default_llm_model: Optional[str] = None
|
||||
default_style: Optional[str] = None
|
||||
default_resolution: Optional[str] = None
|
||||
default_ratio: Optional[str] = None
|
||||
|
||||
|
||||
class ModelRegistrationRequest(BaseModel):
|
||||
"""Request to register a new model."""
|
||||
model_id: str
|
||||
provider: str
|
||||
name: str
|
||||
type: str
|
||||
enabled: bool = True
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ProviderKeyField(BaseModel):
|
||||
"""Provider key field configuration."""
|
||||
name: str
|
||||
label: str
|
||||
placeholder: str
|
||||
required: bool
|
||||
type: Optional[str] = "text" # 'password' or 'text'
|
||||
|
||||
|
||||
class ProviderConfigResponse(BaseModel):
|
||||
"""Provider configuration response."""
|
||||
id: str
|
||||
name: str
|
||||
icon: str
|
||||
description: str
|
||||
fields: List[ProviderKeyField]
|
||||
helpUrl: Optional[str] = None
|
||||
410
backend/src/api/config/providers.py
Normal file
410
backend/src/api/config/providers.py
Normal file
@@ -0,0 +1,410 @@
|
||||
"""Provider and model configuration endpoints."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.utils.errors import BusinessException, ErrorCode, ResourceNotFoundException, InvalidParameterException
|
||||
from src.utils.service_loader import register_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ModelRegistrationRequest(BaseModel):
|
||||
"""Request to register a new model."""
|
||||
name: str
|
||||
model_id: str
|
||||
provider: str
|
||||
type: str
|
||||
family: Optional[str] = None
|
||||
capabilities: Optional[Dict[str, bool]] = None
|
||||
variants: Optional[Dict[str, str]] = None
|
||||
resolutions: Optional[Any] = None
|
||||
durations: Optional[Any] = None
|
||||
voices: Optional[List[Dict[str, str]]] = None
|
||||
args: Optional[List[Any]] = None
|
||||
|
||||
|
||||
class ProviderKeyField(BaseModel):
|
||||
"""Provider key field configuration."""
|
||||
name: str
|
||||
label: str
|
||||
placeholder: str
|
||||
required: bool
|
||||
type: Optional[str] = "text"
|
||||
|
||||
|
||||
class ProviderConfigResponse(BaseModel):
|
||||
"""Provider configuration response."""
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
fields: List[ProviderKeyField]
|
||||
helpUrl: Optional[str] = None
|
||||
|
||||
|
||||
# Supported providers configuration with key field definitions
|
||||
def _build_provider_configs() -> List[ProviderConfigResponse]:
|
||||
"""Build provider configs from ModelRegistry metadata and registered models.
|
||||
|
||||
This dynamically generates provider configurations from provider.json files,
|
||||
avoiding hardcoded configurations.
|
||||
"""
|
||||
configs = []
|
||||
|
||||
# Get all registered providers from ModelRegistry
|
||||
providers = ModelRegistry.list_providers()
|
||||
provider_ids = {p["id"] for p in providers}
|
||||
|
||||
# Build config for each provider
|
||||
for provider_id in provider_ids:
|
||||
metadata = ModelRegistry.get_provider_metadata(provider_id) or {}
|
||||
|
||||
# Get fields from metadata or use defaults
|
||||
fields_data = metadata.get("fields", [])
|
||||
fields = [ProviderKeyField(**f) for f in fields_data] if fields_data else [
|
||||
ProviderKeyField(name="apiKey", label="API Key", placeholder="sk-...", required=True, type="password")
|
||||
]
|
||||
|
||||
# Build config response
|
||||
config = ProviderConfigResponse(
|
||||
id=provider_id,
|
||||
name=metadata.get("name") or provider_id.capitalize(),
|
||||
description=metadata.get("description", ""),
|
||||
fields=fields,
|
||||
helpUrl=metadata.get("helpUrl") or metadata.get("dashboard_url")
|
||||
)
|
||||
configs.append(config)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
@router.get("/config/providers", response_model=ResponseModel)
|
||||
async def get_providers():
|
||||
"""Get list of supported providers and their capabilities (cached for 5 minutes)."""
|
||||
cache_key = "config:providers"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_providers = await cache.get(cache_key)
|
||||
if cached_providers is not None:
|
||||
return ResponseModel(data=cached_providers)
|
||||
|
||||
providers = ModelRegistry.list_providers()
|
||||
result = {"providers": providers}
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, result, ttl=300)
|
||||
|
||||
return ResponseModel(data=result)
|
||||
|
||||
|
||||
@router.post("/config/models", response_model=ResponseModel)
|
||||
async def register_new_model(request: ModelRegistrationRequest):
|
||||
"""Register a new model configuration.
|
||||
|
||||
Raises:
|
||||
InvalidParameterException: Invalid model configuration
|
||||
ResourceNotFoundException: Template not found
|
||||
BusinessException: Failed to register model
|
||||
"""
|
||||
try:
|
||||
provider = request.provider.lower()
|
||||
model_type = request.type.lower()
|
||||
family = request.family.lower() if request.family else "default"
|
||||
|
||||
module_name = ""
|
||||
class_name = ""
|
||||
|
||||
# Find template service from registry
|
||||
try:
|
||||
type_enum = None
|
||||
try:
|
||||
type_enum = ModelType(model_type)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
services = ModelRegistry.find_services(
|
||||
provider=provider,
|
||||
model_type=type_enum
|
||||
)
|
||||
|
||||
if not services and not type_enum:
|
||||
all_services = ModelRegistry.find_services(provider=provider)
|
||||
services = [s for s in all_services if s.get('type') == model_type]
|
||||
|
||||
if not services:
|
||||
raise ResourceNotFoundException("template", f"{provider}/{model_type}")
|
||||
|
||||
template = services[0]
|
||||
if family and family != "default":
|
||||
for svc in services:
|
||||
svc_id = svc.get('id', '').lower()
|
||||
svc_class = (svc.get('class') or svc.get('class_name', '')).lower()
|
||||
if family in svc_id or family in svc_class:
|
||||
template = svc
|
||||
break
|
||||
|
||||
module_name = template.get('module')
|
||||
class_name = template.get('class') or template.get('class_name')
|
||||
|
||||
if not module_name or not class_name:
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Template service configuration is missing module or class",
|
||||
{"template": template}
|
||||
)
|
||||
|
||||
except (ResourceNotFoundException, BusinessException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding template for {provider}/{model_type}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to resolve service template",
|
||||
{"provider": provider, "type": model_type, "reason": str(e)}
|
||||
)
|
||||
|
||||
service_id = f"{provider}-{request.model_id.replace('.', '-').replace('/', '-')}"
|
||||
|
||||
config = {
|
||||
"id": service_id,
|
||||
"module": module_name,
|
||||
"class": class_name,
|
||||
"name": request.name,
|
||||
"args": [request.model_id],
|
||||
"type": model_type,
|
||||
"provider": provider,
|
||||
"enabled": True
|
||||
}
|
||||
|
||||
if request.capabilities:
|
||||
config["capabilities"] = request.capabilities
|
||||
if request.variants:
|
||||
config["variants"] = request.variants
|
||||
if request.resolutions:
|
||||
config["resolutions"] = request.resolutions
|
||||
if request.durations:
|
||||
config["durations"] = request.durations
|
||||
if request.args:
|
||||
config["args"] = request.args
|
||||
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
custom_config_path = os.path.join(src_dir, "config", "services", "custom_models.json")
|
||||
|
||||
os.makedirs(os.path.dirname(custom_config_path), exist_ok=True)
|
||||
|
||||
existing_configs = []
|
||||
if os.path.exists(custom_config_path):
|
||||
try:
|
||||
with open(custom_config_path, 'r', encoding='utf-8') as f:
|
||||
existing_configs = json.load(f)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
logger.warning("Failed to load custom config %s: %s", custom_config_path, e)
|
||||
|
||||
existing_configs = [c for c in existing_configs if c.get("id") != service_id]
|
||||
existing_configs.append(config)
|
||||
|
||||
with open(custom_config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(existing_configs, f, indent=4, ensure_ascii=False)
|
||||
|
||||
register_service(config)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:models")
|
||||
await cache.delete("config:defaults")
|
||||
|
||||
return ResponseModel(data={"status": "success", "model": config})
|
||||
|
||||
except (ResourceNotFoundException, BusinessException):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register model: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to register model",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/models", response_model=ResponseModel)
|
||||
async def get_models_config():
|
||||
"""Get all registered models configuration grouped by type (cached for 5 minutes)."""
|
||||
cache_key = "config:models"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_models = await cache.get(cache_key)
|
||||
if cached_models is not None:
|
||||
return ResponseModel(data=cached_models)
|
||||
|
||||
models = ModelRegistry.list_models()
|
||||
|
||||
grouped = {
|
||||
"image": {},
|
||||
"video": {},
|
||||
"audio": {},
|
||||
"lyrics": {},
|
||||
"music": {},
|
||||
"llm": {}
|
||||
}
|
||||
|
||||
for model_id, config_dict in models.items():
|
||||
model_config = config_dict.copy() if isinstance(config_dict, dict) else {}
|
||||
|
||||
model_type_str = model_config.get("type")
|
||||
if not model_type_str or model_type_str not in grouped:
|
||||
continue
|
||||
|
||||
try:
|
||||
model_type = ModelType(model_type_str.lower())
|
||||
default_id = ModelRegistry.get_default_id(model_type)
|
||||
model_config["is_default"] = (model_id == default_id)
|
||||
except (ValueError, KeyError):
|
||||
model_config["is_default"] = False
|
||||
|
||||
model_config["id"] = model_id
|
||||
|
||||
if "provider" not in model_config:
|
||||
if "/" in model_id:
|
||||
model_config["provider"] = model_id.split("/", 1)[0]
|
||||
|
||||
if "model_key" not in model_config:
|
||||
if "/" in model_id:
|
||||
model_config["model_key"] = model_id.split("/", 1)[1]
|
||||
|
||||
grouped[model_type_str][model_id] = model_config
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, grouped, ttl=300)
|
||||
|
||||
return ResponseModel(data=grouped)
|
||||
|
||||
|
||||
@router.get("/config/models/search", response_model=ResponseModel)
|
||||
async def search_models(
|
||||
provider: Optional[str] = None,
|
||||
type: Optional[str] = None,
|
||||
enabled_only: bool = True
|
||||
):
|
||||
"""Search for models by criteria."""
|
||||
try:
|
||||
model_type = None
|
||||
if type:
|
||||
try:
|
||||
model_type = ModelType(type.lower())
|
||||
except ValueError:
|
||||
raise InvalidParameterException("type", f"Invalid model type: {type}")
|
||||
|
||||
results = ModelRegistry.find_services(
|
||||
provider=provider,
|
||||
model_type=model_type,
|
||||
enabled_only=enabled_only
|
||||
)
|
||||
|
||||
return ResponseModel(data={
|
||||
"count": len(results),
|
||||
"models": results
|
||||
})
|
||||
except InvalidParameterException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to search models: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to search models",
|
||||
{"provider": provider, "type": type, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/models/{model_id}", response_model=ResponseModel)
|
||||
async def get_model_config(model_id: str):
|
||||
"""Get detailed configuration for a specific model."""
|
||||
try:
|
||||
config = ModelRegistry.get_config(model_id)
|
||||
if not config:
|
||||
raise ResourceNotFoundException("model", model_id)
|
||||
|
||||
return ResponseModel(data=config)
|
||||
except ResourceNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model config: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to get model configuration",
|
||||
{"model_id": model_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/defaults", response_model=ResponseModel)
|
||||
async def get_defaults():
|
||||
"""Get default models for each type (cached for 5 minutes)."""
|
||||
cache_key = "config:defaults"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_defaults = await cache.get(cache_key)
|
||||
if cached_defaults is not None:
|
||||
return ResponseModel(data=cached_defaults)
|
||||
|
||||
try:
|
||||
defaults = {}
|
||||
for model_type in ModelType:
|
||||
default_id = ModelRegistry.get_default_id(model_type)
|
||||
if default_id:
|
||||
defaults[model_type.value] = {
|
||||
"id": default_id,
|
||||
"config": ModelRegistry.get_config(default_id)
|
||||
}
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, defaults, ttl=300)
|
||||
|
||||
return ResponseModel(data=defaults)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get defaults: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to get default models",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/provider-configs", response_model=ResponseModel)
|
||||
async def get_provider_configs():
|
||||
"""Get supported provider configurations for user API key management.
|
||||
|
||||
Configurations are dynamically built from provider.json files,
|
||||
ensuring a single source of truth for provider metadata.
|
||||
"""
|
||||
cache_key = "config:provider_configs"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_configs = await cache.get(cache_key)
|
||||
if cached_configs is not None:
|
||||
return ResponseModel(data={"providers": cached_configs})
|
||||
|
||||
# Dynamically build provider configs from registry metadata
|
||||
configs = _build_provider_configs()
|
||||
providers_data = [config.model_dump() for config in configs]
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, providers_data, ttl=300)
|
||||
|
||||
return ResponseModel(data={"providers": providers_data})
|
||||
24
backend/src/api/config/storage.py
Normal file
24
backend/src/api/config/storage.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Storage-related endpoints."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.errors import ErrorCode
|
||||
from src.services.storage_service import storage_manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SignUrlRequest:
|
||||
"""Request to sign a URL."""
|
||||
url: str
|
||||
|
||||
|
||||
@router.post("/storage/sign-url", response_model=ResponseModel)
|
||||
async def sign_url_endpoint(request: dict):
|
||||
"""Re-sign an OSS URL."""
|
||||
new_url = storage_manager.sign_url(request.get("url", ""))
|
||||
return ResponseModel(code=ErrorCode.SUCCESS, message="URL signed", data={"url": new_url})
|
||||
208
backend/src/api/config/styles.py
Normal file
208
backend/src/api/config/styles.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Style configuration endpoints."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.utils.errors import AppException, ErrorCode, ResourceNotFoundException, BusinessException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/config/styles", response_model=ResponseModel)
|
||||
async def get_styles():
|
||||
"""Get all style configurations (cached for 5 minutes)."""
|
||||
cache_key = "config:styles"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_styles = await cache.get(cache_key)
|
||||
if cached_styles is not None:
|
||||
return ResponseModel(data=cached_styles)
|
||||
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
result = {"styles": []}
|
||||
if os.path.exists(styles_path):
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
result = {"styles": data.get("styles", [])}
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, result, ttl=300)
|
||||
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load styles: {e}")
|
||||
return ResponseModel(data={"styles": []})
|
||||
|
||||
|
||||
@router.post("/config/styles", response_model=ResponseModel)
|
||||
async def create_style(style: dict):
|
||||
"""Create a new style configuration."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
if not os.path.exists(styles_path):
|
||||
data = {"styles": []}
|
||||
else:
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
if 'id' not in style or not style['id']:
|
||||
style['id'] = str(uuid.uuid4())
|
||||
|
||||
data["styles"].append(style)
|
||||
|
||||
with open(styles_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:styles")
|
||||
|
||||
return ResponseModel(data=style)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create style: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.put("/config/styles/{style_id}", response_model=ResponseModel)
|
||||
async def update_style(style_id: str, style: dict):
|
||||
"""Update an existing style configuration."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
if not os.path.exists(styles_path):
|
||||
raise AppException(
|
||||
message="Styles configuration not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
styles = data.get("styles", [])
|
||||
updated = False
|
||||
for i, s in enumerate(styles):
|
||||
if s.get('id') == style_id:
|
||||
styles[i] = {**s, **style, "id": style_id}
|
||||
updated = True
|
||||
break
|
||||
|
||||
if not updated:
|
||||
raise AppException(
|
||||
message="Style not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
data["styles"] = styles
|
||||
with open(styles_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:styles")
|
||||
|
||||
return ResponseModel(data=styles[i])
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update style: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/config/styles/{style_id}", response_model=ResponseModel)
|
||||
async def delete_style(style_id: str):
|
||||
"""Delete a style configuration."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
if not os.path.exists(styles_path):
|
||||
raise ResourceNotFoundException("styles configuration", "styles.json")
|
||||
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
styles = data.get("styles", [])
|
||||
initial_len = len(styles)
|
||||
styles = [s for s in styles if s.get('id') != style_id]
|
||||
|
||||
if len(styles) == initial_len:
|
||||
raise ResourceNotFoundException("style", style_id)
|
||||
|
||||
data["styles"] = styles
|
||||
with open(styles_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:styles")
|
||||
|
||||
return ResponseModel(data={"status": "success"})
|
||||
except ResourceNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete style: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to delete style",
|
||||
{"style_id": style_id, "reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/generation-options", response_model=ResponseModel)
|
||||
async def get_generation_options():
|
||||
"""Get generation options (cached for 5 minutes)."""
|
||||
cache_key = "config:generation_options"
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_options = await cache.get(cache_key)
|
||||
if cached_options is not None:
|
||||
return ResponseModel(data=cached_options)
|
||||
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
options_path = os.path.join(src_dir, "config", "generation_options.json")
|
||||
|
||||
result = {}
|
||||
if os.path.exists(options_path):
|
||||
with open(options_path, 'r', encoding='utf-8') as f:
|
||||
result = json.load(f)
|
||||
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, result, ttl=300)
|
||||
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load generation options: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
131
backend/src/api/config/system.py
Normal file
131
backend/src/api/config/system.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""System configuration endpoints."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.utils.errors import BusinessException, ErrorCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SystemConfig(BaseModel):
|
||||
"""System configuration model."""
|
||||
default_image_model: Optional[str] = None
|
||||
default_video_model: Optional[str] = None
|
||||
default_audio_model: Optional[str] = None
|
||||
default_llm_model: Optional[str] = None
|
||||
default_style: Optional[str] = None
|
||||
default_resolution: Optional[str] = None
|
||||
default_ratio: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
@router.get("/config/system", response_model=ResponseModel)
|
||||
async def get_system_config():
|
||||
"""Get system configuration (cached for 5 minutes).
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to load system config
|
||||
"""
|
||||
cache_key = "config:system"
|
||||
|
||||
# Try cache first
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_config = await cache.get(cache_key)
|
||||
if cached_config is not None:
|
||||
return ResponseModel(data=cached_config)
|
||||
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
config_path = os.path.join(src_dir, "config", "user_config.json")
|
||||
|
||||
config_data = {}
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
config_data = json.load(f)
|
||||
|
||||
# Cache result for 5 minutes
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, config_data, ttl=300)
|
||||
|
||||
return ResponseModel(data=config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load system config: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to load system configuration",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config/system", response_model=ResponseModel)
|
||||
async def update_system_config(config: SystemConfig):
|
||||
"""Update system configuration.
|
||||
|
||||
Raises:
|
||||
BusinessException: Failed to update system config
|
||||
"""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
config_path = os.path.join(src_dir, "config", "user_config.json")
|
||||
|
||||
# Load existing to merge
|
||||
existing_data = {}
|
||||
if os.path.exists(config_path):
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
existing_data = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning("Failed to parse existing config %s: %s", config_path, e)
|
||||
|
||||
new_data = config.model_dump(exclude_unset=True)
|
||||
merged_data = {**existing_data, **new_data}
|
||||
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(merged_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
# Update default models in registry if changed
|
||||
config_key_to_type = {
|
||||
'defaultImageModel': ModelType.IMAGE,
|
||||
'defaultVideoModel': ModelType.VIDEO,
|
||||
'defaultAudioModel': ModelType.AUDIO,
|
||||
'defaultLyricsModel': ModelType.LYRICS,
|
||||
'defaultMusicModel': ModelType.MUSIC,
|
||||
'defaultLLMModel': ModelType.LLM
|
||||
}
|
||||
|
||||
for config_key, model_type in config_key_to_type.items():
|
||||
if config_key in new_data:
|
||||
model_id = new_data[config_key]
|
||||
if model_id:
|
||||
ModelRegistry.set_default_by_id(model_type, model_id)
|
||||
logger.info(f"Updated default {model_type.value} model to: {model_id}")
|
||||
|
||||
# Invalidate all related caches
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.delete("config:system")
|
||||
await cache.delete("config:models")
|
||||
await cache.delete("config:defaults")
|
||||
|
||||
return ResponseModel(data=merged_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update system config: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to update system configuration",
|
||||
{"reason": str(e)}
|
||||
)
|
||||
79
backend/src/api/config/validation.py
Normal file
79
backend/src/api/config/validation.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Configuration validation endpoints."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.errors import AppException, ErrorCode
|
||||
from src.services.provider.validation import ConfigValidator, validate_all_configs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/config/validate", response_model=ResponseModel)
|
||||
async def validate_configurations():
|
||||
"""Validate all service configurations."""
|
||||
try:
|
||||
report = validate_all_configs()
|
||||
return ResponseModel(data=report)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate configurations: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/config/validate/service", response_model=ResponseModel)
|
||||
async def validate_service_configuration(config: Dict[str, Any]):
|
||||
"""Validate a single service configuration."""
|
||||
try:
|
||||
is_valid, errors = ConfigValidator.validate_config_deep(config)
|
||||
return ResponseModel(data={
|
||||
"valid": is_valid,
|
||||
"errors": errors,
|
||||
"config": config
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate service configuration: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config/validate/file/{filename}", response_model=ResponseModel)
|
||||
async def validate_config_file(filename: str):
|
||||
"""Validate a specific configuration file."""
|
||||
try:
|
||||
src_dir = os.path.dirname(os.path.dirname(__file__))
|
||||
config_path = os.path.join(src_dir, "config", "services", filename)
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
raise AppException(
|
||||
message=f"Configuration file not found: {filename}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
is_valid, errors = ConfigValidator.validate_config_file(config_path)
|
||||
return ResponseModel(data={
|
||||
"filename": filename,
|
||||
"valid": is_valid,
|
||||
"errors": errors
|
||||
})
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to validate config file: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
45
backend/src/api/generations/__init__.py
Normal file
45
backend/src/api/generations/__init__.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Generations Module - 生成相关路由模块
|
||||
|
||||
将各个子模块的路由组合成统一的 router
|
||||
|
||||
子模块:
|
||||
- script.py - 脚本分析端点
|
||||
- tasks.py - 任务状态端点
|
||||
- image.py - 图片生成端点
|
||||
- video.py - 视频生成端点
|
||||
- audio.py - 音频生成和导出端点
|
||||
- music.py - 音乐生成端点
|
||||
- batch.py - 批量生成端点
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .script import router as script_router
|
||||
from .tasks import router as tasks_router
|
||||
from .image import router as image_router
|
||||
from .video import router as video_router
|
||||
from .audio import router as audio_router
|
||||
from .music import router as music_router
|
||||
from .batch import router as batch_router
|
||||
|
||||
# 创建主路由
|
||||
router = APIRouter(tags=["generations"])
|
||||
|
||||
# 包含所有子路由
|
||||
router.include_router(script_router)
|
||||
router.include_router(tasks_router)
|
||||
router.include_router(image_router)
|
||||
router.include_router(video_router)
|
||||
router.include_router(audio_router)
|
||||
router.include_router(music_router)
|
||||
router.include_router(batch_router)
|
||||
|
||||
# 导出辅助函数供其他模块使用
|
||||
from .helpers import resolve_service, ensure_url, get_available_styles_from_config
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
"resolve_service",
|
||||
"ensure_url",
|
||||
"get_available_styles_from_config",
|
||||
]
|
||||
159
backend/src/api/generations/audio.py
Normal file
159
backend/src/api/generations/audio.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Audio & Export Endpoints - 音频生成和导出模块
|
||||
|
||||
包含:
|
||||
- POST /audio/generate - 音频生成
|
||||
- POST /generations/audio - 统一音频生成
|
||||
- POST /export/project/{project_id} - 项目导出
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
AudioGenerationRequest,
|
||||
GenerateAudioRequest,
|
||||
ExportProjectRequest
|
||||
)
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.services.export_service import VideoExportService
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
from .helpers import resolve_service, check_user_api_key
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 初始化 export service
|
||||
export_service = VideoExportService()
|
||||
|
||||
|
||||
@router.post("/audio/generate", response_model=ResponseModel)
|
||||
async def generate_audio(
|
||||
request: GenerateAudioRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""生成 audio from text(兼容旧接口,内部走统一任务系统)"""
|
||||
logger.info("Audio generation request (legacy): model=%s, text=%s...", request.model, request.text[:50])
|
||||
|
||||
# 兼容旧接口:model 可为空,回退到默认音频模型
|
||||
model_name = request.model or ModelRegistry.get_default_id(ModelType.AUDIO)
|
||||
if not model_name:
|
||||
raise AppException(
|
||||
message="Audio service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 验证模型可解析并获取服务
|
||||
audio_service = resolve_service(model_name, ModelType.AUDIO)
|
||||
|
||||
# 检查用户是否配置了 API Key
|
||||
# 从 audio_service 获取 provider_id
|
||||
provider = getattr(audio_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
try:
|
||||
task_params = {
|
||||
"text": request.text,
|
||||
"voice": request.voice,
|
||||
"model": model_name,
|
||||
"project_id": request.project_id,
|
||||
"storyboard_id": request.storyboard_id,
|
||||
"extra_params": request.extra_params or {},
|
||||
}
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="audio",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=request.project_id,
|
||||
)
|
||||
|
||||
# 兼容前端:保留 audio_url 字段但异步任务初始为空
|
||||
return ResponseModel(data={
|
||||
"audio_url": None,
|
||||
"task_id": task.id,
|
||||
"status": task.status
|
||||
})
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/generations/audio", response_model=ResponseModel)
|
||||
async def generate_audio_unified(
|
||||
request: AudioGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""统一音频生成端点(推荐)"""
|
||||
logger.info("Audio generation request: model=%s, text=%s...", request.model, request.text[:50])
|
||||
|
||||
audio_service = resolve_service(request.model, ModelType.AUDIO)
|
||||
|
||||
# 检查用户是否配置了 API Key
|
||||
# 从 audio_service 获取 provider_id
|
||||
provider = getattr(audio_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
if not audio_service:
|
||||
raise AppException(
|
||||
message="Audio service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
try:
|
||||
task_params = request.model_dump()
|
||||
# Get user_id from authenticated user
|
||||
user_id = current_user.id
|
||||
|
||||
# Handle project_id from multiple sources
|
||||
# Priority 1: request.project_id (from request body, alias="projectId")
|
||||
# Priority 2: request.source_id if source=='project' (from extra_params, alias="sourceId")
|
||||
project_id = request.project_id
|
||||
if not project_id and request.source == "project":
|
||||
project_id = request.source_id
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="audio",
|
||||
model=request.model,
|
||||
params=task_params,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create audio task: %s", e)
|
||||
raise AppException(
|
||||
message=f"Failed to create task: {e}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
return ResponseModel(data={
|
||||
"task_id": task.id,
|
||||
"status": task.status
|
||||
})
|
||||
|
||||
|
||||
@router.post("/export/project/{project_id}", response_model=ResponseModel)
|
||||
async def export_project(project_id: str, request: ExportProjectRequest):
|
||||
""" 导出 project to video"""
|
||||
try:
|
||||
result = await export_service.export_project(project_id, format=request.format)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
320
backend/src/api/generations/batch.py
Normal file
320
backend/src/api/generations/batch.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Batch Generation Endpoints - 批量生成模块
|
||||
|
||||
包含批量生成相关的端点:
|
||||
- POST /generations/batch - 批量创建生成任务
|
||||
"""
|
||||
import logging
|
||||
from typing import List, Dict, Any
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
ImageGenerationRequest,
|
||||
VideoGenerationRequest,
|
||||
AudioGenerationRequest,
|
||||
MusicGenerationRequest
|
||||
)
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
from .helpers import resolve_service, ensure_url, check_user_api_key
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BatchGenerationRequest(BaseModel):
|
||||
"""批量生成请求"""
|
||||
items: List[Dict[str, Any]] = Field(..., description="生成任务列表")
|
||||
|
||||
model_config = ConfigDict(json_schema_extra={
|
||||
"example": {
|
||||
"items": [
|
||||
{
|
||||
"type": "image",
|
||||
"prompt": "A beautiful landscape",
|
||||
"model": "wanx2.1-t2i-turbo"
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
class BatchGenerationResponse(BaseModel):
|
||||
"""批量生成响应"""
|
||||
task_ids: List[str] = Field(..., description="创建的任务ID列表")
|
||||
total: int = Field(..., description="总任务数")
|
||||
created: int = Field(..., description="成功创建的任务数")
|
||||
failed: int = Field(..., description="失败的任务数")
|
||||
errors: List[Dict[str, str]] = Field(default=[], description="错误信息列表")
|
||||
|
||||
|
||||
@router.post("/generations/batch", response_model=ResponseModel)
|
||||
async def generate_batch(
|
||||
request: BatchGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""批量创建生成任务
|
||||
|
||||
支持同时创建多个图片、视频、音频、音乐生成任务
|
||||
|
||||
Args:
|
||||
request: 批量生成请求,包含任务列表
|
||||
|
||||
Returns:
|
||||
批量生成响应,包含创建的任务ID列表
|
||||
"""
|
||||
logger.info(f"Batch generation request: {len(request.items)} items from user {current_user.id}")
|
||||
|
||||
if not request.items:
|
||||
raise AppException(
|
||||
message="No items provided for batch generation",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
if len(request.items) > 20:
|
||||
raise AppException(
|
||||
message="Batch size exceeds maximum limit of 20",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
task_ids = []
|
||||
errors = []
|
||||
|
||||
for idx, item in enumerate(request.items):
|
||||
try:
|
||||
item_type = item.get("type", "image")
|
||||
|
||||
if item_type == "image":
|
||||
task_id = await _create_image_task(item, current_user)
|
||||
elif item_type == "video":
|
||||
task_id = await _create_video_task(item, current_user)
|
||||
elif item_type == "audio":
|
||||
task_id = await _create_audio_task(item, current_user)
|
||||
elif item_type == "music":
|
||||
task_id = await _create_music_task(item, current_user)
|
||||
else:
|
||||
raise ValueError(f"Unsupported generation type: {item_type}")
|
||||
|
||||
task_ids.append(task_id)
|
||||
logger.info(f"Created {item_type} task {task_id} for batch item {idx}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create task for batch item {idx}: {e}")
|
||||
errors.append({
|
||||
"index": idx,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
result = BatchGenerationResponse(
|
||||
task_ids=task_ids,
|
||||
total=len(request.items),
|
||||
created=len(task_ids),
|
||||
failed=len(errors),
|
||||
errors=errors
|
||||
)
|
||||
|
||||
logger.info(f"Batch generation completed: {result.created}/{result.total} tasks created")
|
||||
|
||||
return ResponseModel(data=result)
|
||||
|
||||
|
||||
async def _create_image_task(item: Dict[str, Any], current_user: UserAuth) -> str:
|
||||
"""创建图片生成任务"""
|
||||
image_service = resolve_service(item.get("model"), ModelType.IMAGE)
|
||||
|
||||
if not image_service:
|
||||
raise ValueError("Image service not configured")
|
||||
|
||||
# 检查 API Key
|
||||
provider = getattr(image_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
# 解析尺寸
|
||||
aspect_ratio = item.get("aspect_ratio", "16:9")
|
||||
resolution = item.get("resolution", "1K")
|
||||
size = _resolve_image_size(image_service, aspect_ratio, resolution)
|
||||
|
||||
# 构建参数
|
||||
task_params = {
|
||||
"prompt": item.get("prompt", ""),
|
||||
"size": size,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"resolution": resolution,
|
||||
"image_inputs": [_sanitize_url(url) for url in item.get("image_inputs", []) if url],
|
||||
"extra_params": item.get("extra_params", {})
|
||||
}
|
||||
|
||||
# 处理 source 相关字段
|
||||
if item.get("source"):
|
||||
task_params["source"] = item["source"]
|
||||
if item.get("source_id"):
|
||||
task_params["source_id"] = item["source_id"]
|
||||
if item.get("project_id"):
|
||||
task_params["project_id"] = item["project_id"]
|
||||
|
||||
model_name = item.get("model") or getattr(image_service, "model_id", "default_image")
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="image",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=item.get("project_id")
|
||||
)
|
||||
|
||||
return task.id
|
||||
|
||||
|
||||
async def _create_video_task(item: Dict[str, Any], current_user: UserAuth) -> str:
|
||||
"""创建视频生成任务"""
|
||||
video_service = resolve_service(item.get("model"), ModelType.VIDEO)
|
||||
|
||||
if not video_service:
|
||||
raise ValueError("Video service not configured")
|
||||
|
||||
# 检查 API Key
|
||||
provider = getattr(video_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
# 构建参数
|
||||
task_params = {
|
||||
"prompt": item.get("prompt", ""),
|
||||
"duration": item.get("duration", 5),
|
||||
"resolution": item.get("resolution", "720p"),
|
||||
"image_inputs": [_sanitize_url(url) for url in item.get("image_inputs", []) if url],
|
||||
"extra_params": item.get("extra_params", {})
|
||||
}
|
||||
|
||||
if item.get("source"):
|
||||
task_params["source"] = item["source"]
|
||||
if item.get("source_id"):
|
||||
task_params["source_id"] = item["source_id"]
|
||||
if item.get("project_id"):
|
||||
task_params["project_id"] = item["project_id"]
|
||||
|
||||
model_name = item.get("model") or getattr(video_service, "model_id", "default_video")
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="video",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=item.get("project_id")
|
||||
)
|
||||
|
||||
return task.id
|
||||
|
||||
|
||||
async def _create_audio_task(item: Dict[str, Any], current_user: UserAuth) -> str:
|
||||
"""创建音频生成任务"""
|
||||
audio_service = resolve_service(item.get("model"), ModelType.AUDIO)
|
||||
|
||||
if not audio_service:
|
||||
raise ValueError("Audio service not configured")
|
||||
|
||||
# 检查 API Key
|
||||
provider = getattr(audio_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
task_params = {
|
||||
"text": item.get("text", item.get("prompt", "")),
|
||||
"voice": item.get("voice", "default"),
|
||||
"speed": item.get("speed", 1.0),
|
||||
"extra_params": item.get("extra_params", {})
|
||||
}
|
||||
|
||||
if item.get("source"):
|
||||
task_params["source"] = item["source"]
|
||||
if item.get("source_id"):
|
||||
task_params["source_id"] = item["source_id"]
|
||||
|
||||
model_name = item.get("model") or getattr(audio_service, "model_id", "default_audio")
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="audio",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=item.get("project_id")
|
||||
)
|
||||
|
||||
return task.id
|
||||
|
||||
|
||||
async def _create_music_task(item: Dict[str, Any], current_user: UserAuth) -> str:
|
||||
"""创建音乐生成任务"""
|
||||
music_service = resolve_service(item.get("model"), ModelType.MUSIC)
|
||||
|
||||
if not music_service:
|
||||
raise ValueError("Music service not configured")
|
||||
|
||||
# 检查 API Key
|
||||
provider = getattr(music_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
task_params = {
|
||||
"prompt": item.get("prompt", ""),
|
||||
"duration": item.get("duration", 30),
|
||||
"extra_params": item.get("extra_params", {})
|
||||
}
|
||||
|
||||
if item.get("source"):
|
||||
task_params["source"] = item["source"]
|
||||
if item.get("source_id"):
|
||||
task_params["source_id"] = item["source_id"]
|
||||
|
||||
model_name = item.get("model") or getattr(music_service, "model_id", "default_music")
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="music",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=current_user.id,
|
||||
project_id=item.get("project_id")
|
||||
)
|
||||
|
||||
return task.id
|
||||
|
||||
|
||||
def _resolve_image_size(image_service, aspect_ratio: str, resolution: str) -> str:
|
||||
"""解析图片尺寸"""
|
||||
model_config = getattr(image_service, "config", {})
|
||||
resolutions_config = model_config.get("resolutions", {})
|
||||
|
||||
if resolutions_config and resolution in resolutions_config:
|
||||
ratio_map = resolutions_config[resolution]
|
||||
if isinstance(ratio_map, dict) and aspect_ratio in ratio_map:
|
||||
return ratio_map[aspect_ratio]
|
||||
|
||||
# 默认尺寸
|
||||
defaults = {
|
||||
"1K": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048"
|
||||
}
|
||||
}
|
||||
return defaults.get(resolution, {}).get(aspect_ratio, "1024*1024")
|
||||
|
||||
|
||||
def _sanitize_url(url: str) -> str:
|
||||
"""清理 URL"""
|
||||
if not url:
|
||||
return url
|
||||
return ensure_url(url)
|
||||
396
backend/src/api/generations/helpers.py
Normal file
396
backend/src/api/generations/helpers.py
Normal file
@@ -0,0 +1,396 @@
|
||||
"""
|
||||
Generation Helpers - 辅助函数模块
|
||||
|
||||
包含:
|
||||
- resolve_service: 服务路由函数
|
||||
- ensure_url: URL 处理函数
|
||||
- _get_available_styles_from_config: 样式配置获取
|
||||
"""
|
||||
import logging
|
||||
import base64
|
||||
import uuid
|
||||
import os
|
||||
import json
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
from src.services.cache_service import get_cache_service
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
from src.services.user_api_key_service import user_api_key_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_test_env() -> bool:
|
||||
return (
|
||||
os.getenv("PYTEST_CURRENT_TEST") is not None
|
||||
or os.getenv("PIXEL_TEST_MODE") == "1"
|
||||
or os.getenv("NODE_ENV") == "test"
|
||||
)
|
||||
|
||||
|
||||
def _ensure_model_registry_loaded() -> None:
|
||||
if ModelRegistry.list_models():
|
||||
return
|
||||
|
||||
from src.utils.service_loader import load_services_from_config
|
||||
|
||||
services_config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"config",
|
||||
"services",
|
||||
)
|
||||
load_services_from_config(services_config_path)
|
||||
|
||||
|
||||
def _reload_model_registry() -> None:
|
||||
from src.utils.service_loader import load_services_from_config
|
||||
|
||||
services_config_path = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
|
||||
"config",
|
||||
"services",
|
||||
)
|
||||
load_services_from_config(services_config_path)
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def resolve_service(
|
||||
model: str,
|
||||
model_type: ModelType
|
||||
):
|
||||
"""
|
||||
根据复合 ID 路由到正确的服务
|
||||
|
||||
路由逻辑:
|
||||
1. 验证 model 必须是 provider/model_key 格式
|
||||
2. 解析 provider 和 model_key
|
||||
3. 尝试用复合 ID 直接查找
|
||||
4. 在该 provider 下搜索匹配的 model_key
|
||||
5. 如果都找不到,抛出 HTTPException(404)
|
||||
|
||||
Args:
|
||||
model: 复合 ID 格式 "provider/model_key"(如 'dashscope/qwen-image')
|
||||
model_type: 模型类型(IMAGE, VIDEO, AUDIO 等)
|
||||
|
||||
Returns:
|
||||
服务实例
|
||||
|
||||
Raises:
|
||||
ValueError: 当 model 格式不正确时
|
||||
HTTPException: 当模型不存在时
|
||||
"""
|
||||
_ensure_model_registry_loaded()
|
||||
|
||||
# 1. 验证格式
|
||||
if not model or '/' not in model:
|
||||
raise ValueError(
|
||||
f"Model must be in format 'provider/model_key', got: '{model}'. "
|
||||
f"Example: 'dashscope/qwen-image'"
|
||||
)
|
||||
|
||||
# 2. 解析 provider 和 model_key
|
||||
parts = model.split('/', 1)
|
||||
if len(parts) != 2:
|
||||
raise ValueError(
|
||||
f"Model format invalid: '{model}'. "
|
||||
f"Must have exactly one '/' separator."
|
||||
)
|
||||
|
||||
provider, model_key = parts
|
||||
if not provider or not model_key:
|
||||
raise ValueError(
|
||||
f"Model format invalid: '{model}'. "
|
||||
f"Both provider and model_key must be non-empty."
|
||||
)
|
||||
|
||||
logger.info(f"Resolving service: model='{model}', type={model_type.value}")
|
||||
|
||||
# 3. 方式1: 直接用复合 ID 查找
|
||||
service = ModelRegistry.get(model)
|
||||
if service:
|
||||
logger.info(f"Found service by composite ID: '{model}'")
|
||||
return service
|
||||
|
||||
# 4. 方式2: 在 provider 下搜索 model_key
|
||||
services = ModelRegistry.find_services(provider=provider, model_type=model_type)
|
||||
matching = [
|
||||
s for s in services
|
||||
if s.get('model_key') == model_key or s.get('id') == model
|
||||
]
|
||||
|
||||
if matching:
|
||||
service = ModelRegistry.get(matching[0].get('id'))
|
||||
logger.info(
|
||||
f"Found service by provider+model_key: "
|
||||
f"provider='{provider}', model_key='{model_key}'"
|
||||
)
|
||||
return service
|
||||
|
||||
# 4.5 可能 registry 被测试或运行时临时覆盖,尝试强制回填一次
|
||||
_reload_model_registry()
|
||||
|
||||
service = ModelRegistry.get(model)
|
||||
if service:
|
||||
logger.info(f"Found service by composite ID after registry reload: '{model}'")
|
||||
return service
|
||||
|
||||
services = ModelRegistry.find_services(provider=provider, model_type=model_type)
|
||||
matching = [
|
||||
s for s in services
|
||||
if s.get('model_key') == model_key or s.get('id') == model
|
||||
]
|
||||
if matching:
|
||||
service = ModelRegistry.get(matching[0].get('id'))
|
||||
logger.info(
|
||||
f"Found service by provider+model_key after registry reload: "
|
||||
f"provider='{provider}', model_key='{model_key}'"
|
||||
)
|
||||
return service
|
||||
|
||||
# 5. 未找到
|
||||
logger.error(f"Model '{model}' not found")
|
||||
raise AppException(
|
||||
message=f"Model '{model}' not found. Available models can be fetched from /api/v1/models",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
|
||||
def ensure_url(url_or_base64: str) -> str:
|
||||
"""
|
||||
Ensure the input is a valid URL that external services can access.
|
||||
- If it's a relative path and using local storage, convert to Base64.
|
||||
- If it's a relative path and using OSS, upload to OSS.
|
||||
- If it's a Base64 string, return as is.
|
||||
- If it's a localhost URL, download and convert to Base64 or upload to OSS.
|
||||
- If it's already a public URL, sign it if needed.
|
||||
"""
|
||||
if not url_or_base64:
|
||||
return url_or_base64
|
||||
|
||||
# Check if it's a relative path (starts with /files/ or /uploads/)
|
||||
if url_or_base64.startswith('/files/') or url_or_base64.startswith('/uploads/'):
|
||||
try:
|
||||
from src.config.settings import DATA_DIR, UPLOAD_DIR, STORAGE_TYPE
|
||||
|
||||
logger.info(f"[ensure_url] Detected relative path: {url_or_base64}")
|
||||
|
||||
# Determine the file path
|
||||
if url_or_base64.startswith('/files/'):
|
||||
rel_path = url_or_base64[7:] # Remove '/files/'
|
||||
file_path = os.path.join(DATA_DIR, rel_path)
|
||||
else: # /uploads/
|
||||
rel_path = url_or_base64[9:] # Remove '/uploads/'
|
||||
file_path = os.path.join(UPLOAD_DIR, rel_path)
|
||||
|
||||
# Check if file exists
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"[ensure_url] File not found: {file_path}")
|
||||
return url_or_base64
|
||||
|
||||
# Read the file
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
# Determine extension and MIME type
|
||||
ext = os.path.splitext(file_path)[1].lstrip('.') or 'png'
|
||||
mime_type = f"image/{ext}"
|
||||
if ext == 'jpg':
|
||||
mime_type = "image/jpeg"
|
||||
|
||||
# If using OSS, upload the file to OSS
|
||||
if STORAGE_TYPE == 'oss':
|
||||
filename = f"temp/{uuid.uuid4()}.{ext}"
|
||||
oss_url = storage_manager.save(filename, file_data)
|
||||
logger.info(f"[ensure_url] Uploaded to OSS: {oss_url}")
|
||||
return oss_url
|
||||
|
||||
# If using local storage, convert to Base64
|
||||
else:
|
||||
base64_data = base64.b64encode(file_data).decode('utf-8')
|
||||
base64_url = f"data:{mime_type};base64,{base64_data}"
|
||||
logger.info(f"[ensure_url] Converted to Base64 (length: {len(base64_url)})")
|
||||
return base64_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ensure_url] Failed to convert relative path: {e}")
|
||||
return url_or_base64
|
||||
|
||||
# Check if it's a localhost URL — read from disk instead of HTTP to prevent SSRF
|
||||
if any(host in url_or_base64.lower() for host in ['localhost', '127.0.0.1', '0.0.0.0']):
|
||||
try:
|
||||
from src.config.settings import STORAGE_TYPE
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger.info(f"[ensure_url] Detected localhost URL, reading from disk: {url_or_base64}")
|
||||
|
||||
# Extract the path component (e.g. /files/xxx or /uploads/xxx)
|
||||
parsed = urlparse(url_or_base64)
|
||||
path = parsed.path
|
||||
|
||||
# Map URL path to local filesystem
|
||||
if path.startswith('/files/'):
|
||||
rel_path = path[7:]
|
||||
file_path = os.path.join(DATA_DIR, rel_path)
|
||||
elif path.startswith('/uploads/'):
|
||||
rel_path = path[9:]
|
||||
file_path = os.path.join(UPLOAD_DIR, rel_path)
|
||||
else:
|
||||
logger.warning(f"[ensure_url] Unsupported localhost path: {path}")
|
||||
return url_or_base64
|
||||
|
||||
if not os.path.exists(file_path):
|
||||
logger.error(f"[ensure_url] File not found: {file_path}")
|
||||
return url_or_base64
|
||||
|
||||
with open(file_path, 'rb') as f:
|
||||
file_data = f.read()
|
||||
|
||||
ext = os.path.splitext(file_path)[1].lstrip('.') or 'png'
|
||||
mime_type = f"image/{ext}"
|
||||
if ext == 'jpg':
|
||||
mime_type = "image/jpeg"
|
||||
|
||||
if STORAGE_TYPE == 'oss':
|
||||
filename = f"temp/{uuid.uuid4()}.{ext}"
|
||||
saved_url = storage_manager.save(filename, file_data)
|
||||
logger.info(f"[ensure_url] Uploaded to OSS: {saved_url}")
|
||||
return saved_url
|
||||
else:
|
||||
base64_data = base64.b64encode(file_data).decode('utf-8')
|
||||
base64_url = f"data:{mime_type};base64,{base64_data}"
|
||||
logger.info(f"[ensure_url] Converted to Base64 (length: {len(base64_url)})")
|
||||
return base64_url
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[ensure_url] Failed to convert localhost URL: {e}")
|
||||
return url_or_base64
|
||||
|
||||
# Check if Base64 - return as is (already in correct format)
|
||||
if url_or_base64.startswith("data:"):
|
||||
logger.info(f"[ensure_url] Already Base64 format, returning as is")
|
||||
return url_or_base64
|
||||
|
||||
# If it's already a public URL, check if it belongs to our OSS bucket
|
||||
if url_or_base64.startswith(('http://', 'https://')):
|
||||
from src.config.settings import STORAGE_TYPE
|
||||
from src.utils.oss_utils import OSS_BUCKET
|
||||
|
||||
# Only sign URLs from our own OSS bucket; external URLs should pass through as-is
|
||||
if STORAGE_TYPE == 'oss' and OSS_BUCKET and f"{OSS_BUCKET}." in url_or_base64:
|
||||
return storage_manager.sign_url(url_or_base64) or url_or_base64
|
||||
else:
|
||||
# External public URL (e.g. dashscope-result, other CDN), return as-is
|
||||
logger.info(f"[ensure_url] External public URL, returning as-is")
|
||||
return url_or_base64
|
||||
|
||||
# Fallback: sign it if needed
|
||||
return storage_manager.sign_url(url_or_base64) or url_or_base64
|
||||
|
||||
|
||||
async def get_available_styles_from_config() -> List[dict]:
|
||||
""" 获取 available style configurations from backend config file.
|
||||
|
||||
Returns:
|
||||
List of style dicts with full information (id, name, desc, type, etc.)
|
||||
Example: [{'id': 'cyberpunk', 'name': '赛博朋克', 'desc': '高对比度霓虹灯光...', ...}, ...]
|
||||
"""
|
||||
cache_key = "config:styles_full"
|
||||
|
||||
# Cache first (5 minutes TTL)
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
cached_styles = await cache.get(cache_key)
|
||||
if cached_styles is not None:
|
||||
return cached_styles
|
||||
|
||||
# Get from config file
|
||||
src_dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
|
||||
styles_path = os.path.join(src_dir, "config", "styles.json")
|
||||
|
||||
styles_list = []
|
||||
if os.path.exists(styles_path):
|
||||
with open(styles_path, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
styles_list = data.get('styles', [])
|
||||
|
||||
# Fallback to common styles with descriptions if config doesn't exist
|
||||
if not styles_list:
|
||||
styles_list = [
|
||||
{"id": "cyberpunk", "name": "赛博朋克", "desc": "高对比度霓虹灯光,未来主义建筑,机械改造"},
|
||||
{"id": "ink", "name": "水墨风", "desc": "传统中国水墨渲染,黑白灰为主,意境深远"},
|
||||
{"id": "pixar", "name": "皮克斯风格", "desc": "3D卡通渲染,色彩鲜艳,光影柔和,表情夸张"},
|
||||
{"id": "anime", "name": "日漫", "desc": "典型日本动画风格,线条清晰,赛璐璐上色"},
|
||||
{"id": "chinese-anime", "name": "国漫风", "desc": "中国现代动画风格,融合传统与现代元素"},
|
||||
{"id": "realistic", "name": "写实", "desc": "电影级写实渲染,细节丰富,光照真实"},
|
||||
{"id": "hand-drawn", "name": "手绘", "desc": "传统手绘质感,笔触明显,艺术感强"},
|
||||
{"id": "watercolor", "name": "水彩画", "desc": "水彩晕染效果,色彩通透,艺术感强"},
|
||||
{"id": "oil-painting", "name": "油画", "desc": "厚涂油画质感,笔触丰富,光影层次感强"},
|
||||
{"id": "pixel-art", "name": "像素风", "desc": "复古8-bit/16-bit像素艺术,怀旧游戏风格"},
|
||||
]
|
||||
|
||||
# 缓存 for 5 minutes
|
||||
if REDIS_ENABLED:
|
||||
cache = get_cache_service()
|
||||
await cache.set(cache_key, styles_list, ttl=300)
|
||||
|
||||
return styles_list
|
||||
|
||||
|
||||
def check_user_api_key(user_id: str, provider: str) -> None:
|
||||
"""
|
||||
检查系统是否配置了指定 Provider 的 API Key。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(保留参数兼容性)
|
||||
provider: 服务商 ID (如 'dashscope', 'openai', 'kling' 等)
|
||||
|
||||
Raises:
|
||||
AppException: 如果系统未配置该 Provider 的 API Key
|
||||
"""
|
||||
if _is_test_env():
|
||||
return
|
||||
|
||||
import os
|
||||
env_map = {
|
||||
"dashscope": "DASHSCOPE_API_KEY",
|
||||
"openai": "OPENAI_API_KEY",
|
||||
"kling": "KLING_ACCESS_KEY",
|
||||
"midjourney": "MIDJOURNEY_API_KEY",
|
||||
"modelscope": "MODELSCOPE_API_TOKEN",
|
||||
"volcengine": "VOLCENGINE_API_KEY",
|
||||
"minimax": "MINIMAX_API_KEY",
|
||||
"google": "GOOGLE_API_KEY",
|
||||
}
|
||||
env_var = env_map.get(provider)
|
||||
if env_var and os.getenv(env_var):
|
||||
return
|
||||
|
||||
provider_names = {
|
||||
"dashscope": "阿里云 DashScope",
|
||||
"openai": "OpenAI",
|
||||
"kling": "可灵 AI",
|
||||
"midjourney": "Midjourney",
|
||||
"modelscope": "ModelScope",
|
||||
"volcengine": "火山引擎",
|
||||
"minimax": "MiniMax",
|
||||
"google": "Google",
|
||||
}
|
||||
provider_name = provider_names.get(provider, provider)
|
||||
|
||||
raise AppException(
|
||||
message=f"未配置 {provider_name} 的 API Key",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400,
|
||||
details={
|
||||
"error": "API_KEY_NOT_CONFIGURED",
|
||||
"message": f"尚未配置 {provider_name} 的 API Key,无法提交生成任务。",
|
||||
"provider": provider,
|
||||
"provider_name": provider_name,
|
||||
"solution": "请联系管理员在 .env 中配置对应的 API Key。",
|
||||
}
|
||||
)
|
||||
208
backend/src/api/generations/image.py
Normal file
208
backend/src/api/generations/image.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Image Generation Endpoints - 图片生成模块
|
||||
|
||||
包含图片生成相关的端点:
|
||||
- POST /generations/image - 图片生成
|
||||
- POST /image/upscale - 图片放大
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
ImageGenerationRequest,
|
||||
UpscaleImageRequest, UpscaleImageResponse
|
||||
)
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
from .helpers import resolve_service, ensure_url, check_user_api_key
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/generations/image", response_model=ResponseModel)
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 泛型 Image Generation Endpoint.
|
||||
Supports generation for any context (Storyboard, Asset, etc.)
|
||||
"""
|
||||
# Logging
|
||||
logger.info(f"Image generation request: model={request.model}, prompt={request.prompt[:50]}...")
|
||||
|
||||
# 1. Resolve Model - 使用统一的路由逻辑
|
||||
image_service = resolve_service(request.model, ModelType.IMAGE)
|
||||
|
||||
if not image_service:
|
||||
raise AppException(
|
||||
message="Image service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 1.5 检查用户是否配置了 API Key
|
||||
# 从 image_service 获取 provider_id
|
||||
provider = getattr(image_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
# 2. Resolve size from aspect_ratio and resolution (BEFORE creating task)
|
||||
aspect_ratio = request.aspect_ratio
|
||||
size = None
|
||||
|
||||
# 调试 logging
|
||||
logger.info(f"[ImageGeneration] Request received:")
|
||||
logger.info(f" - Prompt: {request.prompt[:100]}...")
|
||||
logger.info(f" - Model: {request.model}")
|
||||
logger.info(f" - Aspect Ratio: {aspect_ratio}")
|
||||
logger.info(f" - Resolution: {request.resolution}")
|
||||
logger.info(f" - Image inputs count: {len(request.image_inputs or [])}")
|
||||
if request.image_inputs:
|
||||
for i, url in enumerate(request.image_inputs):
|
||||
logger.info(f" - Image input {i}: {url[:100]}...")
|
||||
|
||||
if aspect_ratio:
|
||||
# Look up resolution in model config
|
||||
model_config = getattr(image_service, "config", {})
|
||||
resolutions_config = model_config.get("resolutions", {})
|
||||
|
||||
# Use provided resolution level (e.g. "1K", "2K", "4K") or default to "1K"
|
||||
res_level = request.resolution or "1K"
|
||||
|
||||
if resolutions_config and res_level in resolutions_config and isinstance(resolutions_config[res_level], dict):
|
||||
ratio_map = resolutions_config[res_level]
|
||||
if aspect_ratio in ratio_map:
|
||||
size = ratio_map[aspect_ratio]
|
||||
logger.info(f" - Resolved size from config: {size} (resolution={res_level}, ratio={aspect_ratio})")
|
||||
|
||||
if not size:
|
||||
defaults = {
|
||||
"1K": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280",
|
||||
"2.35:1": "1280*544"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048",
|
||||
"4:3": "2560*1920",
|
||||
"3:4": "1920*2560"
|
||||
},
|
||||
"4K": {
|
||||
"16:9": "3840*2160",
|
||||
"9:16": "2160*3840",
|
||||
"1:1": "4096*4096",
|
||||
"4:3": "3840*2880",
|
||||
"3:4": "2880*3840"
|
||||
}
|
||||
}
|
||||
size = defaults.get(res_level, {}).get(aspect_ratio)
|
||||
if size:
|
||||
logger.info(f" - Resolved size from defaults: {size} (resolution={res_level}, ratio={aspect_ratio})")
|
||||
|
||||
# 3. Create Task in DB with resolved params
|
||||
try:
|
||||
model_name = request.model or getattr(image_service, "model_id", "default_image")
|
||||
|
||||
# 构建 params dict with resolved size
|
||||
task_params = request.model_dump()
|
||||
task_params["size"] = size # Override with resolved size (provider expects 'size')
|
||||
|
||||
# lora_strength 已经在 extra_params 中(前端传递),不需要额外处理
|
||||
# TaskManager 会从 extra_params 中提取并合并到一级参数
|
||||
|
||||
# 清理 Image URLs (Convert relative paths, localhost URLs, Base64)
|
||||
logger.info(f"[ImageGeneration] Before URL sanitization - image_inputs: {task_params.get('image_inputs')}")
|
||||
|
||||
for key in ["image_inputs"]:
|
||||
if key in task_params and isinstance(task_params[key], list):
|
||||
original_urls = task_params[key]
|
||||
task_params[key] = [ensure_url(url) for url in task_params[key] if url]
|
||||
logger.info(f"[ImageGeneration] Sanitized {key}: {original_urls} -> {task_params[key]}")
|
||||
|
||||
# Use unified task manager
|
||||
# Get user_id from authenticated user
|
||||
user_id = current_user.id
|
||||
|
||||
# Handle project_id from multiple sources
|
||||
# Priority 1: request.project_id (from request body, alias="projectId")
|
||||
# Priority 2: request.source_id if source=='project' (from extra_params, alias="sourceId")
|
||||
project_id = request.project_id
|
||||
if not project_id and request.source == 'project':
|
||||
project_id = request.source_id
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="image",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create task: {e}")
|
||||
raise AppException(
|
||||
message=f"Failed to create task: {e}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 日志 context if provided
|
||||
if request.source and request.source_id:
|
||||
logger.info(f"Image generation started for {request.source}:{request.source_id}, task_id: {task.id}")
|
||||
|
||||
# Unified task manager handles execution automatically
|
||||
return ResponseModel(data={
|
||||
"task_id": task.id,
|
||||
"status": task.status
|
||||
})
|
||||
|
||||
|
||||
@router.post("/image/upscale", response_model=ResponseModel)
|
||||
async def upscale_image(request: UpscaleImageRequest):
|
||||
"""Upscale image resolution"""
|
||||
upscale_service = ModelRegistry.get_default(ModelType.UPSCALE)
|
||||
if not upscale_service:
|
||||
raise AppException(
|
||||
message="Default Upscale service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
try:
|
||||
# Prepare kwargs
|
||||
upscale_kwargs = {
|
||||
"image_url": request.image_url,
|
||||
"rate": request.rate
|
||||
}
|
||||
if request.extra_params:
|
||||
upscale_kwargs.update(request.extra_params)
|
||||
|
||||
upscaled_url = await upscale_service.upscale_image(**upscale_kwargs)
|
||||
|
||||
if not upscaled_url:
|
||||
raise AppException(
|
||||
message="Upscaling failed or returned no URL",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
return ResponseModel(data=UpscaleImageResponse(
|
||||
original_url=request.image_url,
|
||||
upscaled_url=upscaled_url
|
||||
))
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
82
backend/src/api/generations/music.py
Normal file
82
backend/src/api/generations/music.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Music Generation Endpoints - 音乐生成模块
|
||||
|
||||
包含音乐生成相关端点:
|
||||
- POST /generations/music - 音乐生成
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import ResponseModel, MusicGenerationRequest
|
||||
from src.services.provider.registry import ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
from .helpers import resolve_service, check_user_api_key
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/generations/music", response_model=ResponseModel)
|
||||
async def generate_music(
|
||||
request: MusicGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""统一 Music/Lyrics Generation Endpoint."""
|
||||
logger.info("Music unified generation request: mode=%s, model=%s", request.generation_mode, request.model)
|
||||
|
||||
# 1) Resolve model by mode
|
||||
model_type = ModelType.LYRICS if request.generation_mode == "lyrics" else ModelType.MUSIC
|
||||
music_service = resolve_service(request.model, model_type)
|
||||
|
||||
# 1.5 检查用户是否配置了 API Key
|
||||
# 从 music_service 获取 provider_id
|
||||
provider = getattr(music_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
if not music_service:
|
||||
raise AppException(
|
||||
message=f"{request.generation_mode} service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 2) Unified async task for both lyrics/music modes
|
||||
try:
|
||||
model_name = request.model or getattr(music_service, "model_id", "default_music")
|
||||
task_params = request.model_dump()
|
||||
|
||||
# Handle project_id from multiple sources
|
||||
# Priority 1: request.project_id (from request body, alias="projectId")
|
||||
# Priority 2: request.source_id if source=='project' (from extra_params, alias="sourceId")
|
||||
project_id = request.project_id
|
||||
if not project_id and request.source == "project":
|
||||
project_id = request.source_id
|
||||
|
||||
# Get user_id from authenticated user
|
||||
user_id = current_user.id
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="music",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create unified music task: %s", e)
|
||||
raise AppException(
|
||||
message=f"Failed to create task: {e}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
return ResponseModel(data={
|
||||
"task_id": task.id,
|
||||
"status": task.status,
|
||||
"mode": request.generation_mode
|
||||
})
|
||||
407
backend/src/api/generations/script.py
Normal file
407
backend/src/api/generations/script.py
Normal file
@@ -0,0 +1,407 @@
|
||||
"""
|
||||
Script Analysis Endpoints - 脚本分析模块
|
||||
|
||||
包含所有与 LLM 脚本分析相关的端点:
|
||||
- /script/analyze - 分析小说文本
|
||||
- /prompt/optimize - 优化提示词
|
||||
- /style/recommend - 推荐风格
|
||||
- /script/summary - 小说摘要
|
||||
- /script/characters - 角色提取
|
||||
- /script/scenes - 场景提取
|
||||
- /script/props - 道具提取
|
||||
- /script/storyboards - 分镜拆分
|
||||
- /script/split - 章节拆分
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends
|
||||
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
ScriptProcessRequest, ScriptResponse,
|
||||
PromptOptimizationRequest, PromptOptimizationResponse,
|
||||
StyleRecommendationRequest, StyleRecommendationResponse,
|
||||
NovelSummaryRequest, NovelSummaryResponse,
|
||||
CharacterExtractionRequest, CharacterExtractionResponse,
|
||||
SceneExtractionRequest, SceneExtractionResponse,
|
||||
PropExtractionRequest, PropExtractionResponse,
|
||||
StoryboardSplitRequest, StoryboardSplitResponse,
|
||||
ChapterSplitRequest, ChapterSplitResponse,
|
||||
CharacterAsset, SceneAsset, PropAsset
|
||||
)
|
||||
from src.services.script import script_service
|
||||
from src.services.project_service import project_manager
|
||||
|
||||
from .helpers import get_available_styles_from_config
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/script/analyze", response_model=ResponseModel)
|
||||
async def analyze_script(
|
||||
request: ScriptProcessRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Analyze novel text and generate script"""
|
||||
try:
|
||||
# Handle skip_storyboard priority (explicit param vs extra_params)
|
||||
skip_sb = request.skip_storyboard
|
||||
extra = request.extra_params or {}
|
||||
if 'skip_storyboard' in extra:
|
||||
skip_sb = extra.pop('skip_storyboard')
|
||||
|
||||
response_data = await script_service.analyze_novel(
|
||||
novel_text=request.novel_text,
|
||||
project_id=request.project_id,
|
||||
model_name=request.model,
|
||||
max_input_tokens=request.max_input_tokens,
|
||||
skip_storyboard=skip_sb,
|
||||
user_id=current_user.id,
|
||||
**extra
|
||||
)
|
||||
|
||||
# Automatically save extracted assets to the project
|
||||
if request.project_id:
|
||||
# 获取 existing project assets for deduplication
|
||||
existing_assets_map = {}
|
||||
try:
|
||||
current_project = project_manager.get_project(request.project_id)
|
||||
if current_project and current_project.assets:
|
||||
for asset in current_project.assets:
|
||||
existing_assets_map[(asset.type, asset.name)] = asset
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch project for deduplication: {e}")
|
||||
|
||||
# 辅助函数 to create and add asset
|
||||
def save_asset(item: dict, asset_type: str):
|
||||
name = item.get('name')
|
||||
if not name:
|
||||
return
|
||||
|
||||
# Check if exists
|
||||
existing_asset = existing_assets_map.get((asset_type, name))
|
||||
|
||||
if existing_asset:
|
||||
logger.info(f"Asset {name} ({asset_type}) already exists, updating.")
|
||||
asset_id = existing_asset.id
|
||||
else:
|
||||
asset_id = str(uuid.uuid4())
|
||||
|
||||
# 创建 specific asset models
|
||||
try:
|
||||
asset = None
|
||||
if asset_type == 'character':
|
||||
asset = CharacterAsset(
|
||||
id=asset_id,
|
||||
name=item.get('name'),
|
||||
desc=item.get('desc', ''),
|
||||
tags=item.get('tags', []),
|
||||
age=item.get('age'),
|
||||
role=item.get('role'),
|
||||
appearance=item.get('appearance'),
|
||||
image_prompt=item.get('image_prompt')
|
||||
)
|
||||
elif asset_type == 'scene':
|
||||
asset = SceneAsset(
|
||||
id=asset_id,
|
||||
name=item.get('name'),
|
||||
desc=item.get('desc', ''),
|
||||
tags=item.get('tags', []),
|
||||
location=item.get('location'),
|
||||
time_of_day=item.get('time_of_day'),
|
||||
atmosphere=item.get('atmosphere'),
|
||||
image_prompt=item.get('image_prompt')
|
||||
)
|
||||
elif asset_type == 'prop':
|
||||
asset = PropAsset(
|
||||
id=asset_id,
|
||||
name=item.get('name'),
|
||||
desc=item.get('desc', ''),
|
||||
tags=item.get('tags', []),
|
||||
usage=item.get('usage'),
|
||||
image_prompt=item.get('image_prompt')
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
if existing_asset:
|
||||
project_manager.update_asset(request.project_id, asset_id, asset)
|
||||
else:
|
||||
project_manager.add_asset(request.project_id, asset)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save asset {item.get('name')}: {e}")
|
||||
|
||||
# Save characters
|
||||
for char in (response_data.characters or []):
|
||||
save_asset(char.model_dump(), 'character')
|
||||
|
||||
# Save scenes
|
||||
for scene in (response_data.scenes or []):
|
||||
save_asset(scene.model_dump(), 'scene')
|
||||
|
||||
# Save props
|
||||
for prop in (response_data.props or []):
|
||||
save_asset(prop.model_dump(), 'prop')
|
||||
|
||||
# Save summary if available
|
||||
if response_data.summary:
|
||||
try:
|
||||
project_manager.update_project(request.project_id, {"description": response_data.summary})
|
||||
logger.info(f"Updated project description with generated summary.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update project description: {e}")
|
||||
|
||||
return ResponseModel(data=response_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Script analysis failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/prompt/optimize", response_model=ResponseModel)
|
||||
async def optimize_prompt_endpoint(
|
||||
request: PromptOptimizationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Optimize prompt for image or video generation"""
|
||||
try:
|
||||
optimized_prompt = await script_service.optimize_prompt(
|
||||
prompt=request.prompt,
|
||||
target_type=request.target_type,
|
||||
template=request.template,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
language=request.language or "Chinese",
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=PromptOptimizationResponse(
|
||||
original_prompt=request.prompt,
|
||||
optimized_prompt=optimized_prompt,
|
||||
target_type=request.target_type
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Prompt optimization failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/style/recommend", response_model=ResponseModel)
|
||||
async def recommend_style(
|
||||
request: StyleRecommendationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Recommend an art style based on novel text.
|
||||
|
||||
Backend automatically fetches available styles with full descriptions from config.
|
||||
"""
|
||||
try:
|
||||
# Automatically fetch styles with full information from backend config
|
||||
styles_full = await get_available_styles_from_config()
|
||||
logger.info(f"Auto-fetched {len(styles_full)} styles from backend config")
|
||||
|
||||
# 格式化 style information for AI: include name and description
|
||||
available_styles = [
|
||||
f"{style['name']} - {style.get('desc', '')}"
|
||||
for style in styles_full
|
||||
if style.get('name')
|
||||
]
|
||||
|
||||
logger.info(f"Formatted {len(available_styles)} styles with descriptions for AI")
|
||||
|
||||
result = await script_service.recommend_style(
|
||||
novel_text=request.novel_text,
|
||||
available_styles=available_styles,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=StyleRecommendationResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Style recommendation failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/summary", response_model=ResponseModel)
|
||||
async def summarize_novel_endpoint(
|
||||
request: NovelSummaryRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 总和marize novel text into a concise overview (without characters)"""
|
||||
try:
|
||||
result = await script_service.summarize_novel(
|
||||
novel_text=request.novel_text,
|
||||
language=request.language,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
global_summary=request.global_summary,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=NovelSummaryResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Novel summary failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/characters", response_model=ResponseModel)
|
||||
async def extract_characters_endpoint(
|
||||
request: CharacterExtractionRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 提取 characters from novel text"""
|
||||
try:
|
||||
result = await script_service.extract_characters(
|
||||
novel_text=request.novel_text,
|
||||
language=request.language,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
global_summary=request.global_summary,
|
||||
known_characters=request.known_characters,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=CharacterExtractionResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Character extraction failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/scenes", response_model=ResponseModel)
|
||||
async def extract_scenes_endpoint(
|
||||
request: SceneExtractionRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 提取 scenes from novel text"""
|
||||
try:
|
||||
result = await script_service.extract_scenes(
|
||||
novel_text=request.novel_text,
|
||||
language=request.language,
|
||||
model_name=request.model,
|
||||
global_summary=request.global_summary,
|
||||
known_scenes=request.known_scenes,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=SceneExtractionResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Scene extraction failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/props", response_model=ResponseModel)
|
||||
async def extract_props_endpoint(
|
||||
request: PropExtractionRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 提取 props from novel text"""
|
||||
try:
|
||||
result = await script_service.extract_props(
|
||||
novel_text=request.novel_text,
|
||||
language=request.language,
|
||||
model_name=request.model,
|
||||
global_summary=request.global_summary,
|
||||
known_props=request.known_props,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=PropExtractionResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Prop extraction failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/storyboards", response_model=ResponseModel)
|
||||
async def split_storyboards_endpoint(
|
||||
request: StoryboardSplitRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Split novel text into storyboards (shots)"""
|
||||
try:
|
||||
result = await script_service.split_storyboards(
|
||||
novel_text=request.novel_text,
|
||||
project_id=request.project_id,
|
||||
model_name=request.model,
|
||||
provider=request.provider,
|
||||
language=request.language,
|
||||
known_characters=request.known_characters,
|
||||
known_scenes=request.known_scenes,
|
||||
known_props=request.known_props,
|
||||
user_id=current_user.id,
|
||||
**(request.extra_params or {})
|
||||
)
|
||||
return ResponseModel(data=StoryboardSplitResponse(**result))
|
||||
except Exception as e:
|
||||
logger.error(f"Storyboard splitting failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/script/split", response_model=ResponseModel)
|
||||
async def split_chapters_endpoint(
|
||||
request: ChapterSplitRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Split novel text into chapters using regex or agent"""
|
||||
try:
|
||||
# Check if agent splitting is requested via extra_params
|
||||
extra = request.extra_params or {}
|
||||
use_agent = extra.get("use_agent", False)
|
||||
|
||||
chapters = await script_service.split_chapters(
|
||||
novel_text=request.novel_text,
|
||||
regex_pattern=request.regex_pattern,
|
||||
use_agent=use_agent,
|
||||
model_name=request.model,
|
||||
language=extra.get("language", "Chinese"),
|
||||
user_id=current_user.id
|
||||
)
|
||||
return ResponseModel(data=ChapterSplitResponse(
|
||||
chapters=chapters,
|
||||
total_chapters=len(chapters)
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Chapter split failed: {str(e)}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
212
backend/src/api/generations/tasks.py
Normal file
212
backend/src/api/generations/tasks.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Task Status Endpoints - 任务状态模块
|
||||
|
||||
包含任务管理相关的端点:
|
||||
- GET /tasks - 列出任务
|
||||
- GET /tasks/{task_id} - 获取任务状态
|
||||
"""
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
from fastapi import APIRouter, Query, Request
|
||||
from typing import Optional
|
||||
|
||||
from src.models.schemas import ResponseModel, PaginationParams
|
||||
from src.utils.pagination import Paginator
|
||||
from src.services.provider.registry import ModelRegistry, ModelType
|
||||
from src.services.provider.base import TaskStatus
|
||||
from src.services.task_manager import task_manager
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.config.settings import OSS_BUCKET, OSS_REGION
|
||||
from src.utils.errors import TaskNotFoundException, AppException, ErrorCode
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=ResponseModel)
|
||||
async def list_tasks(
|
||||
request: Request,
|
||||
type: Optional[str] = Query(None, description="Filter by task type"),
|
||||
page: int = Query(1, ge=1, description="页码,从 1 开始"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量,最大 100"),
|
||||
):
|
||||
"""列表 tasks with optional filtering.
|
||||
|
||||
Args:
|
||||
type: 任务类型过滤
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
分页的任务列表
|
||||
"""
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 获取任务列表
|
||||
tasks = task_manager.list_tasks(type=type, limit=page_size, offset=offset)
|
||||
|
||||
# 获取总数(简化处理,实际应该查询总数)
|
||||
total = len(tasks)
|
||||
|
||||
# 创建分页器
|
||||
paginator = Paginator(
|
||||
items=tasks,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=ResponseModel)
|
||||
async def get_task_status(task_id: str):
|
||||
"""
|
||||
Check status of a task.
|
||||
"""
|
||||
# 1. Check local TaskManager
|
||||
task = await task_manager.get_task(task_id)
|
||||
|
||||
if task:
|
||||
# If task is active (pending/processing), check provider status
|
||||
active_statuses = ["pending", "processing", "queued", "running", "submitted"]
|
||||
if task.status.lower() in active_statuses and task.provider_task_id:
|
||||
service = ModelRegistry.get(task.model)
|
||||
# Fallback to default if model service not found by name
|
||||
if not service:
|
||||
if task.type == "image":
|
||||
service = ModelRegistry.get_default(ModelType.IMAGE)
|
||||
elif task.type == "video":
|
||||
service = ModelRegistry.get_default(ModelType.VIDEO)
|
||||
elif task.type == "audio":
|
||||
service = ModelRegistry.get_default(ModelType.AUDIO)
|
||||
elif task.type == "music":
|
||||
service = ModelRegistry.get_default(ModelType.MUSIC)
|
||||
|
||||
if service and hasattr(service, 'check_status'):
|
||||
try:
|
||||
result = await service.check_status(task.provider_task_id, user_id=task.user_id)
|
||||
|
||||
# 更新 status
|
||||
# 映射 TaskStatus enum to string
|
||||
new_status = result.status.value.lower()
|
||||
if new_status == "succeeded":
|
||||
new_status = "success"
|
||||
|
||||
# If succeeded, process results (OSS upload)
|
||||
final_results = {}
|
||||
if result.status == TaskStatus.SUCCEEDED and result.results:
|
||||
processed_urls = []
|
||||
for idx, item in enumerate(result.results):
|
||||
if item.url:
|
||||
# Determine extension
|
||||
ext = "png"
|
||||
if task.type == "video":
|
||||
ext = "mp4"
|
||||
elif task.type in ("audio", "music"):
|
||||
ext = "mp3"
|
||||
|
||||
key = f"generated/{task.type}s/{task.id}_{idx}.{ext}"
|
||||
|
||||
# Save to storage (Local or OSS)
|
||||
oss_url = await asyncio.to_thread(storage_manager.save_from_url, item.url, key)
|
||||
processed_urls.append(oss_url or item.url)
|
||||
|
||||
final_results = {"urls": processed_urls}
|
||||
|
||||
# 更新 task in DB
|
||||
update_data = {"status": new_status}
|
||||
if final_results:
|
||||
update_data["result"] = final_results
|
||||
if result.error:
|
||||
update_data["error"] = str(result.error)
|
||||
|
||||
task = task_manager.update_task(task.id, **update_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check provider status for task {task.id}: {e}")
|
||||
# Don't fail the request, just return current state
|
||||
|
||||
# Retry OSS upload if succeeded but urls are not OSS (resilience)
|
||||
elif (task.status.lower() == "succeeded" or task.status.lower() == "success") and task.result and "urls" in task.result:
|
||||
try:
|
||||
urls = task.result["urls"]
|
||||
updated_urls = []
|
||||
needs_update = False
|
||||
|
||||
for idx, url in enumerate(urls):
|
||||
# Check if it looks like an OSS url
|
||||
is_oss = False
|
||||
if OSS_BUCKET and OSS_REGION and f"{OSS_BUCKET}.{OSS_REGION}" in url:
|
||||
is_oss = True
|
||||
|
||||
if not is_oss:
|
||||
# Attempt upload
|
||||
ext = "png"
|
||||
if task.type == "video":
|
||||
ext = "mp4"
|
||||
elif task.type in ("audio", "music"):
|
||||
ext = "mp3"
|
||||
key = f"generated/{task.type}s/{task.id}_{idx}.{ext}"
|
||||
|
||||
new_url = await asyncio.to_thread(storage_manager.save_from_url, url, key)
|
||||
if new_url and new_url != url:
|
||||
updated_urls.append(new_url)
|
||||
needs_update = True
|
||||
else:
|
||||
updated_urls.append(url)
|
||||
else:
|
||||
# It is OSS, maybe re-sign it
|
||||
signed_url = storage_manager.sign_url(url)
|
||||
updated_urls.append(signed_url if signed_url else url)
|
||||
|
||||
if needs_update:
|
||||
task = task_manager.update_task(task.id, result={"urls": updated_urls})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to retry OSS upload for task {task.id}: {e}")
|
||||
|
||||
return ResponseModel(data=task)
|
||||
|
||||
raise AppException(
|
||||
message="Task not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
|
||||
@router.post("/tasks/{task_id}/cancel", response_model=ResponseModel)
|
||||
async def cancel_task(task_id: str):
|
||||
"""
|
||||
Cancel a pending/processing task.
|
||||
仅更新本地任务状态为 CANCELLED,不保证下游 provider 真正中止计算,
|
||||
但对于绝大多数异步生成场景已经足够(前端不再轮询,结果也不会再写入)。
|
||||
"""
|
||||
try:
|
||||
cancelled = await task_manager.cancel_task(task_id)
|
||||
except TaskNotFoundException:
|
||||
raise AppException(
|
||||
message="Task not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel task {task_id}: {e}")
|
||||
raise AppException(
|
||||
message="Failed to cancel task",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
if not cancelled:
|
||||
# 已完成任务无法取消
|
||||
raise AppException(
|
||||
message="Task already finished and cannot be cancelled",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# 返回简单结果,前端如需最新任务详情可以再调用 GET /tasks/{task_id}
|
||||
return ResponseModel(data={"task_id": task_id, "cancelled": True})
|
||||
136
backend/src/api/generations/video.py
Normal file
136
backend/src/api/generations/video.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Video Generation Endpoints - 视频生成模块
|
||||
|
||||
包含视频生成相关的端点:
|
||||
- POST /generations/video - 视频生成
|
||||
"""
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import ResponseModel, VideoGenerationRequest
|
||||
from src.services.provider.registry import ModelType
|
||||
from src.services.task_manager import task_manager
|
||||
from src.auth.dependencies import get_current_user, UserAuth
|
||||
|
||||
from .helpers import resolve_service, ensure_url, check_user_api_key
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/generations/video", response_model=ResponseModel)
|
||||
async def generate_video(
|
||||
request: VideoGenerationRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
""" 泛型 Video Generation Endpoint.
|
||||
"""
|
||||
# 0. Sanitize URLs Early (Convert Base64 -> URL)
|
||||
try:
|
||||
if request.image_inputs:
|
||||
request.image_inputs = [ensure_url(url) for url in request.image_inputs if url]
|
||||
if request.video_inputs:
|
||||
request.video_inputs = [ensure_url(url) for url in request.video_inputs if url]
|
||||
if request.audio_inputs:
|
||||
request.audio_inputs = [ensure_url(url) for url in request.audio_inputs if url]
|
||||
|
||||
# Also check extra_params for common image keys
|
||||
if request.extra_params:
|
||||
for key in ["image_url", "video_url", "input_image", "input_video", "audio_url"]:
|
||||
if key in request.extra_params and isinstance(request.extra_params[key], str):
|
||||
request.extra_params[key] = ensure_url(request.extra_params[key])
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sanitize inputs: {e}")
|
||||
|
||||
# 1. Resolve Model - 使用统一的路由逻辑
|
||||
video_service = resolve_service(request.model, ModelType.VIDEO)
|
||||
|
||||
if not video_service:
|
||||
raise AppException(
|
||||
message="Video service not configured",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 1.5 检查用户是否配置了 API Key
|
||||
# 从 video_service 获取 provider_id
|
||||
provider = getattr(video_service, 'provider_id', None)
|
||||
if provider:
|
||||
check_user_api_key(current_user.id, provider)
|
||||
|
||||
# 2. Resolve size from aspect_ratio (BEFORE creating task)
|
||||
aspect_ratio = request.aspect_ratio
|
||||
size = None
|
||||
|
||||
if aspect_ratio:
|
||||
# Look up resolution in model config
|
||||
model_config = getattr(video_service, "config", {}) or {}
|
||||
if isinstance(model_config, dict):
|
||||
resolutions_config = model_config.get("resolutions") or {}
|
||||
else:
|
||||
resolutions_config = {}
|
||||
|
||||
# Use provided resolution level (e.g. "720P", "1080P") or default to "720P"
|
||||
res_level = request.resolution or "720P"
|
||||
|
||||
if resolutions_config and res_level in resolutions_config and isinstance(resolutions_config[res_level], dict):
|
||||
ratio_map = resolutions_config[res_level]
|
||||
if aspect_ratio in ratio_map:
|
||||
size = ratio_map[aspect_ratio]
|
||||
|
||||
if not size:
|
||||
defaults = {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
}
|
||||
size = defaults.get(aspect_ratio)
|
||||
|
||||
# 3. Create Task in DB with resolved params
|
||||
try:
|
||||
model_name = request.model or getattr(video_service, "model_id", "default_video")
|
||||
|
||||
# 构建 params dict with resolved size
|
||||
task_params = request.model_dump()
|
||||
if size:
|
||||
task_params["size"] = size # Override with resolved size
|
||||
|
||||
# Use unified task manager
|
||||
# Get user_id from authenticated user
|
||||
user_id = current_user.id
|
||||
|
||||
# Handle project_id from multiple sources
|
||||
# Priority 1: request.project_id (from request body, alias="projectId")
|
||||
# Priority 2: request.source_id if source=='project' (from extra_params, alias="sourceId")
|
||||
project_id = request.project_id
|
||||
if not project_id and request.source == 'project':
|
||||
project_id = request.source_id
|
||||
|
||||
task = await task_manager.create_task(
|
||||
task_type="video",
|
||||
model=model_name,
|
||||
params=task_params,
|
||||
user_id=user_id,
|
||||
project_id=project_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create task: {e}")
|
||||
raise AppException(
|
||||
message=f"Failed to create task: {e}",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# 日志 context if provided
|
||||
if request.source and request.source_id:
|
||||
logger.info(f"Video generation started for {request.source}:{request.source_id}, task_id: {task.id}")
|
||||
|
||||
# Unified task manager handles execution automatically
|
||||
return ResponseModel(data={
|
||||
"task_id": task.id,
|
||||
"status": task.status
|
||||
})
|
||||
350
backend/src/api/health.py
Normal file
350
backend/src/api/health.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
健康检查和监控端点
|
||||
|
||||
提供系统健康状态、性能指标和诊断信息
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Response
|
||||
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.task_manager import task_manager
|
||||
from src.services.provider.registry import ModelRegistry
|
||||
from src.services.provider.health import health_monitor
|
||||
from src.config.database import engine, get_pool_status
|
||||
from sqlmodel import Session, text
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 应用启动时间
|
||||
_start_time = time.time()
|
||||
|
||||
|
||||
@router.get("/health", response_model=ResponseModel)
|
||||
async def health_check():
|
||||
"""基础健康检查
|
||||
|
||||
Returns:
|
||||
基本的健康状态信息
|
||||
"""
|
||||
return ResponseModel(data={
|
||||
"status": "healthy",
|
||||
"service": "Pixel Backend",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"uptime_seconds": int(time.time() - _start_time)
|
||||
})
|
||||
|
||||
|
||||
@router.get("/health/detailed", response_model=ResponseModel)
|
||||
async def detailed_health_check():
|
||||
"""详细健康检查
|
||||
|
||||
检查所有关键组件的健康状态
|
||||
|
||||
Returns:
|
||||
详细的健康状态信息,包括各组件状态
|
||||
"""
|
||||
health_status = {
|
||||
"status": "healthy",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"uptime_seconds": int(time.time() - _start_time),
|
||||
"components": {}
|
||||
}
|
||||
|
||||
# 检查数据库连接
|
||||
try:
|
||||
start = time.time()
|
||||
with Session(engine) as session:
|
||||
session.exec(text("SELECT 1"))
|
||||
latency_ms = (time.time() - start) * 1000
|
||||
|
||||
# Get connection pool status
|
||||
pool_status = get_pool_status()
|
||||
|
||||
health_status["components"]["database"] = {
|
||||
"status": "healthy",
|
||||
"message": "Database connection successful",
|
||||
"latency_ms": round(latency_ms, 2),
|
||||
"pool": pool_status
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "unhealthy"
|
||||
health_status["components"]["database"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Database connection failed: {str(e)}"
|
||||
}
|
||||
|
||||
# 检查Redis连接
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
from src.services.cache_service import get_cache_service
|
||||
cache = get_cache_service()
|
||||
|
||||
if cache._connected:
|
||||
start = time.time()
|
||||
await cache._redis.ping()
|
||||
latency_ms = (time.time() - start) * 1000
|
||||
|
||||
# Get Redis info
|
||||
info = await cache._redis.info()
|
||||
health_status["components"]["redis"] = {
|
||||
"status": "healthy",
|
||||
"message": "Redis connection successful",
|
||||
"latency_ms": round(latency_ms, 2),
|
||||
"version": info.get("redis_version"),
|
||||
"connected_clients": info.get("connected_clients"),
|
||||
"used_memory_human": info.get("used_memory_human")
|
||||
}
|
||||
else:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["redis"] = {
|
||||
"status": "unhealthy",
|
||||
"message": "Redis not connected"
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["redis"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Redis connection failed: {str(e)}"
|
||||
}
|
||||
else:
|
||||
health_status["components"]["redis"] = {
|
||||
"status": "disabled",
|
||||
"message": "Redis is disabled in configuration"
|
||||
}
|
||||
|
||||
# 检查任务管理器
|
||||
try:
|
||||
stats = task_manager.get_stats()
|
||||
health_status["components"]["task_manager"] = {
|
||||
"status": "healthy",
|
||||
"stats": stats
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["task_manager"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Task manager error: {str(e)}"
|
||||
}
|
||||
|
||||
# 检查模型注册表
|
||||
try:
|
||||
models = ModelRegistry.list_models()
|
||||
health_status["components"]["model_registry"] = {
|
||||
"status": "healthy",
|
||||
"total_models": len(models)
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["model_registry"] = {
|
||||
"status": "unhealthy",
|
||||
"message": f"Model registry error: {str(e)}"
|
||||
}
|
||||
|
||||
# 检查 AI 服务健康状态
|
||||
try:
|
||||
service_health = health_monitor.get_health_summary()
|
||||
health_status["components"]["ai_services"] = {
|
||||
"status": "healthy" if service_health["healthy"] == service_health["total"] else "degraded",
|
||||
"summary": service_health
|
||||
}
|
||||
except Exception as e:
|
||||
health_status["status"] = "degraded"
|
||||
health_status["components"]["ai_services"] = {
|
||||
"status": "unknown",
|
||||
"message": f"Health monitor error: {str(e)}"
|
||||
}
|
||||
|
||||
return ResponseModel(data=health_status)
|
||||
|
||||
|
||||
@router.get("/health/live", response_model=ResponseModel)
|
||||
async def liveness_probe():
|
||||
""" Kubernetes liveness probe
|
||||
|
||||
简单检查应用是否运行
|
||||
|
||||
Returns:
|
||||
200 OK if alive
|
||||
"""
|
||||
return ResponseModel(data={"status": "alive"})
|
||||
|
||||
|
||||
@router.get("/health/ready", response_model=ResponseModel)
|
||||
async def readiness_probe():
|
||||
""" Kubernetes readiness probe
|
||||
|
||||
检查应用是否准备好接收流量
|
||||
|
||||
Returns:
|
||||
200 OK if ready, 503 if not ready
|
||||
"""
|
||||
# 检查关键组件
|
||||
ready = True
|
||||
components = {}
|
||||
|
||||
# 检查数据库
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
session.exec(text("SELECT 1"))
|
||||
components["database"] = "ready"
|
||||
except Exception as e:
|
||||
ready = False
|
||||
components["database"] = f"not ready: {str(e)}"
|
||||
|
||||
# 检查Redis(如果启用)
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
if REDIS_ENABLED:
|
||||
try:
|
||||
from src.services.cache_service import get_cache_service
|
||||
cache = get_cache_service()
|
||||
|
||||
if cache._connected:
|
||||
await cache._redis.ping()
|
||||
components["redis"] = "ready"
|
||||
else:
|
||||
ready = False
|
||||
components["redis"] = "not ready: not connected"
|
||||
except Exception as e:
|
||||
ready = False
|
||||
components["redis"] = f"not ready: {str(e)}"
|
||||
else:
|
||||
components["redis"] = "disabled"
|
||||
|
||||
# 检查任务管理器
|
||||
try:
|
||||
task_manager.get_stats()
|
||||
components["task_manager"] = "ready"
|
||||
except Exception as e:
|
||||
ready = False
|
||||
components["task_manager"] = f"not ready: {str(e)}"
|
||||
|
||||
if ready:
|
||||
return ResponseModel(data={"status": "ready", "components": components})
|
||||
else:
|
||||
raise AppException(
|
||||
message="Service not ready",
|
||||
code=ErrorCode.SERVICE_UNAVAILABLE,
|
||||
status_code=503,
|
||||
details={"status": "not ready", "components": components}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics")
|
||||
async def prometheus_metrics():
|
||||
"""Prometheus metrics endpoint
|
||||
|
||||
导出 Prometheus 格式的监控指标
|
||||
|
||||
Returns:
|
||||
Prometheus 格式的指标数据
|
||||
"""
|
||||
# 更新 system and database metrics before returning
|
||||
from src.middlewares.metrics import update_system_metrics, update_database_metrics
|
||||
update_system_metrics()
|
||||
update_database_metrics()
|
||||
|
||||
return Response(
|
||||
content=generate_latest(),
|
||||
media_type=CONTENT_TYPE_LATEST
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/tasks", response_model=ResponseModel)
|
||||
async def task_metrics():
|
||||
"""任务管理器指标
|
||||
|
||||
Returns:
|
||||
任务管理器的详细统计信息
|
||||
"""
|
||||
try:
|
||||
stats = task_manager.get_stats()
|
||||
return ResponseModel(data=stats)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get task metrics: {e}", exc_info=True)
|
||||
return ResponseModel(
|
||||
code=500,
|
||||
message="Failed to get task metrics",
|
||||
data={"error": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/metrics/models", response_model=ResponseModel)
|
||||
async def model_metrics():
|
||||
"""模型服务指标
|
||||
|
||||
Returns:
|
||||
所有注册模型的健康状态和统计信息
|
||||
"""
|
||||
try:
|
||||
# 获取所有模型
|
||||
models = ModelRegistry.list_models()
|
||||
|
||||
# 获取健康状态
|
||||
health_summary = health_monitor.get_health_summary()
|
||||
all_health = health_monitor.get_all_health()
|
||||
|
||||
# 构建响应
|
||||
model_stats = []
|
||||
for model_id, config in models.items():
|
||||
health = all_health.get(model_id)
|
||||
model_stat = {
|
||||
"id": model_id,
|
||||
"name": config.get("name"),
|
||||
"type": config.get("type"),
|
||||
"provider": config.get("provider"),
|
||||
"enabled": config.get("enabled", True)
|
||||
}
|
||||
|
||||
if health:
|
||||
model_stat["health"] = {
|
||||
"status": health.status.value,
|
||||
"success_rate": health.get_success_rate(),
|
||||
"avg_latency_ms": health.avg_latency_ms,
|
||||
"total_checks": health.total_checks,
|
||||
"total_failures": health.total_failures
|
||||
}
|
||||
|
||||
model_stats.append(model_stat)
|
||||
|
||||
return ResponseModel(data={
|
||||
"summary": health_summary,
|
||||
"models": model_stats
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get model metrics: {e}", exc_info=True)
|
||||
return ResponseModel(
|
||||
code=500,
|
||||
message="Failed to get model metrics",
|
||||
data={"error": str(e)}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/debug/info", response_model=ResponseModel)
|
||||
async def debug_info():
|
||||
"""调试信息
|
||||
|
||||
提供系统配置和运行时信息(仅用于开发环境)
|
||||
|
||||
Returns:
|
||||
系统调试信息
|
||||
"""
|
||||
import sys
|
||||
import platform
|
||||
from src.config.settings import NODE_ENV, STORAGE_TYPE, REDIS_ENABLED
|
||||
|
||||
return ResponseModel(data={
|
||||
"environment": NODE_ENV,
|
||||
"python_version": sys.version,
|
||||
"platform": platform.platform(),
|
||||
"storage_type": STORAGE_TYPE,
|
||||
"redis_enabled": REDIS_ENABLED,
|
||||
"uptime_seconds": int(time.time() - _start_time),
|
||||
"start_time": datetime.fromtimestamp(_start_time).isoformat()
|
||||
})
|
||||
29
backend/src/api/projects/__init__.py
Normal file
29
backend/src/api/projects/__init__.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Projects API Package
|
||||
|
||||
项目 API 模块,包含:
|
||||
- core: 项目核心 CRUD 功能
|
||||
- episodes: 剧集管理
|
||||
- assets: 资产管理
|
||||
- storyboards: 分镜管理
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .core import router as core_router
|
||||
from .episodes import router as episodes_router
|
||||
from .assets import router as assets_router
|
||||
from .storyboards import router as storyboards_router
|
||||
|
||||
# 创建主路由器
|
||||
router = APIRouter(tags=["projects"])
|
||||
|
||||
# 包含所有子路由
|
||||
router.include_router(core_router)
|
||||
router.include_router(episodes_router)
|
||||
router.include_router(assets_router)
|
||||
router.include_router(storyboards_router)
|
||||
|
||||
__all__ = [
|
||||
"router",
|
||||
]
|
||||
305
backend/src/api/projects/assets.py
Normal file
305
backend/src/api/projects/assets.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Projects API - Assets Module
|
||||
|
||||
包含资产管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
Asset, CreateAssetRequest, UpdateAssetRequest,
|
||||
CharacterAsset, SceneAsset, PropAsset, OtherAsset,
|
||||
GenerationRecord,
|
||||
)
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.project_service import project_manager
|
||||
from src.services.asset_deduplication_service import deduplicate_project_assets as deduplicate_assets_service
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["projects-assets"])
|
||||
|
||||
|
||||
def _check_project_access(project_id: str, user_id: str):
|
||||
"""检查用户是否有权限访问项目"""
|
||||
project = project_manager.get_project(project_id, user_id=user_id)
|
||||
return project
|
||||
|
||||
|
||||
# --- Asset Management ---
|
||||
|
||||
@router.post("/projects/{project_id}/assets", response_model=ResponseModel)
|
||||
async def create_asset(
|
||||
project_id: str,
|
||||
request: dict,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Create a new asset.
|
||||
Accepts a dictionary, validates based on 'type' field.
|
||||
"""
|
||||
try:
|
||||
# Check project access
|
||||
if not _check_project_access(project_id, current_user.id):
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
asset_type = request.get('type', 'other')
|
||||
if 'id' not in request:
|
||||
request['id'] = str(uuid.uuid4())
|
||||
|
||||
asset = None
|
||||
if asset_type == 'character':
|
||||
asset = TypeAdapter(CharacterAsset).validate_python(request)
|
||||
elif asset_type == 'scene':
|
||||
asset = TypeAdapter(SceneAsset).validate_python(request)
|
||||
elif asset_type == 'prop':
|
||||
asset = TypeAdapter(PropAsset).validate_python(request)
|
||||
else:
|
||||
asset = TypeAdapter(OtherAsset).validate_python(request)
|
||||
|
||||
updated_project = project_manager.add_asset(project_id, asset)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.put("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||
async def update_asset(project_id: str, asset_id: str, request: UpdateAssetRequest):
|
||||
"""Update an asset"""
|
||||
try:
|
||||
project = project_manager.get_project(project_id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
existing_asset = next((a for a in project.assets if a.id == asset_id), None)
|
||||
if not existing_asset:
|
||||
raise AppException(
|
||||
message="Asset not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
# Merge updates
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
updated_asset = existing_asset.model_copy(update=update_data)
|
||||
|
||||
result = project_manager.update_asset(project_id, asset_id, updated_asset)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||
async def delete_asset(project_id: str, asset_id: str):
|
||||
"""Delete an asset"""
|
||||
try:
|
||||
result = project_manager.delete_asset(project_id, asset_id)
|
||||
if not result:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/assets/deduplicate", response_model=ResponseModel)
|
||||
async def deduplicate_project_assets(
|
||||
project_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Deduplicate all assets in the project using LLM; update storyboard references.
|
||||
"""
|
||||
try:
|
||||
updated_project = await deduplicate_assets_service(project_id, current_user.id)
|
||||
return ResponseModel(data=updated_project)
|
||||
except ValueError as e:
|
||||
if "not found" in str(e).lower():
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Asset deduplication failed: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
# --- Asset Management ---
|
||||
@router.get("/projects/{project_id}/assets", response_model=ResponseModel)
|
||||
async def list_project_assets(
|
||||
project_id: str,
|
||||
asset_type: Optional[str] = Query(None, description="Filter by asset type"),
|
||||
search_query: Optional[str] = Query(None, description="Search by name or description"),
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
):
|
||||
"""List assets for a project with optional filtering and pagination"""
|
||||
try:
|
||||
result = project_manager.list_assets(project_id, asset_type, search_query, limit, offset)
|
||||
if result is None:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.post("/projects/{project_id}/assets", response_model=ResponseModel)
|
||||
async def create_asset(project_id: str, request: CreateAssetRequest):
|
||||
"""Create a new asset"""
|
||||
try:
|
||||
asset_data = request.model_dump(by_alias=True)
|
||||
asset_data['id'] = str(uuid.uuid4())
|
||||
|
||||
# Validate and convert to Asset Union
|
||||
new_asset = TypeAdapter(Asset).validate_python(asset_data)
|
||||
|
||||
updated_project = project_manager.add_asset(project_id, new_asset)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.post("/projects/{project_id}/assets/batch", response_model=ResponseModel)
|
||||
async def batch_create_assets(project_id: str, request: List[CreateAssetRequest]):
|
||||
"""Batch create assets"""
|
||||
try:
|
||||
new_assets = []
|
||||
for asset_req in request:
|
||||
asset_data = asset_req.model_dump(by_alias=True)
|
||||
asset_data['id'] = str(uuid.uuid4())
|
||||
# Validate and convert to Asset Union
|
||||
new_asset = TypeAdapter(Asset).validate_python(asset_data)
|
||||
new_assets.append(new_asset)
|
||||
|
||||
updated_project = project_manager.batch_add_assets(project_id, new_assets)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.put("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||
async def update_asset(project_id: str, asset_id: str, request: UpdateAssetRequest):
|
||||
"""Update an asset"""
|
||||
try:
|
||||
project = project_manager.get_project(project_id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
existing_asset = next((a for a in project.assets if a.id == asset_id), None)
|
||||
if not existing_asset:
|
||||
raise AppException(
|
||||
message="Asset not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
# Manually validate generations if present to ensure they are model instances not dicts
|
||||
if 'generations' in update_data and update_data['generations']:
|
||||
update_data['generations'] = TypeAdapter(List[GenerationRecord]).validate_python(update_data['generations'])
|
||||
|
||||
updated_asset = existing_asset.model_copy(update=update_data)
|
||||
|
||||
updated_project = project_manager.update_asset(project_id, asset_id, updated_asset)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/assets/{asset_id}", response_model=ResponseModel)
|
||||
async def delete_asset(project_id: str, asset_id: str):
|
||||
"""Delete an asset"""
|
||||
try:
|
||||
updated_project = project_manager.delete_asset(project_id, asset_id)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
367
backend/src/api/projects/core.py
Normal file
367
backend/src/api/projects/core.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Projects API - Core Module
|
||||
|
||||
包含项目核心 CRUD 相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Query, Request, Depends
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
CreateProjectRequest, UpdateProjectRequest,
|
||||
InitializeProjectRequest,
|
||||
)
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.project_service import project_manager
|
||||
from src.services.project_initialization_service import run_initialization_pipeline
|
||||
from src.services.script_project_initialization_service import run_script_project_initialization
|
||||
from src.services.script_to_canvas_service import convert_script_project_to_canvas
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["projects"])
|
||||
|
||||
|
||||
def _check_project_access(project_id: str, user_id: str) -> Optional:
|
||||
"""检查用户是否有权限访问项目"""
|
||||
project = project_manager.get_project(project_id, user_id=user_id)
|
||||
return project
|
||||
|
||||
|
||||
def _progress_callback(project_id: str):
|
||||
def _cb(step: str, percentage: int, message: str, details: Optional[dict] = None):
|
||||
try:
|
||||
project_manager.update_project(
|
||||
project_id,
|
||||
{
|
||||
"progress": {
|
||||
"current_step": step,
|
||||
"percentage": percentage,
|
||||
"message": message,
|
||||
"details": details or {},
|
||||
}
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to update progress: %s", e)
|
||||
return _cb
|
||||
|
||||
|
||||
# --- 项目管理 ---
|
||||
async def _create_initializing_project(
|
||||
request: InitializeProjectRequest,
|
||||
current_user: UserAuth,
|
||||
background_tasks: BackgroundTasks,
|
||||
project_type: str,
|
||||
):
|
||||
project = project_manager.create_project(
|
||||
name=request.name,
|
||||
description="正在从小说初始化...",
|
||||
type=project_type,
|
||||
chapters=[],
|
||||
assets=[],
|
||||
status="initializing",
|
||||
user_id=current_user.id
|
||||
)
|
||||
progress_cb = _progress_callback(project.id)
|
||||
if project_type == "script":
|
||||
background_tasks.add_task(
|
||||
run_script_project_initialization,
|
||||
project.id,
|
||||
request.novel_text,
|
||||
progress_cb,
|
||||
current_user.id,
|
||||
request.model,
|
||||
request.provider,
|
||||
)
|
||||
else:
|
||||
background_tasks.add_task(
|
||||
run_initialization_pipeline,
|
||||
project.id,
|
||||
request.novel_text,
|
||||
request.style,
|
||||
progress_cb,
|
||||
current_user.id,
|
||||
)
|
||||
return project
|
||||
|
||||
|
||||
@router.post("/projects/initialize-from-novel", response_model=ResponseModel)
|
||||
@router.post("/projects/init-from-novel", response_model=ResponseModel)
|
||||
@router.post("/projects/initialize-canvas-from-novel", response_model=ResponseModel)
|
||||
async def initialize_canvas_project_from_novel(
|
||||
request: InitializeProjectRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
使用多智能体管道从小说文本初始化项目(后台任务)
|
||||
步骤: 1. 立即创建项目(草稿) 2. 触发后台任务进行完整初始化
|
||||
"""
|
||||
try:
|
||||
project = await _create_initializing_project(request, current_user, background_tasks, "canvas")
|
||||
return ResponseModel(data=project)
|
||||
except Exception as e:
|
||||
logger.error("Error initializing project: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/initialize-script-from-novel", response_model=ResponseModel)
|
||||
async def initialize_script_project_from_novel(
|
||||
request: InitializeProjectRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""使用多 Agent 改编流程从小说文本初始化剧本项目。"""
|
||||
try:
|
||||
project = await _create_initializing_project(request, current_user, background_tasks, "script")
|
||||
return ResponseModel(data=project)
|
||||
except Exception as e:
|
||||
logger.error("Error initializing script project: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.post("/projects", response_model=ResponseModel)
|
||||
async def create_project(
|
||||
request: CreateProjectRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new project"""
|
||||
try:
|
||||
project = project_manager.create_project(
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
type=request.type,
|
||||
chapters=request.chapters,
|
||||
assets=request.assets,
|
||||
user_id=current_user.id
|
||||
)
|
||||
return ResponseModel(data=project)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating project: {e}", exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.get("/projects", response_model=ResponseModel)
|
||||
async def list_projects(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量,最大100"),
|
||||
sort: Optional[str] = Query(None, description="排序字段,格式: field:asc 或 field:desc"),
|
||||
filter: Optional[str] = Query(None, description="过滤条件,JSON格式"),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
List all projects with pagination support.
|
||||
|
||||
Args:
|
||||
page: 页码,从1开始
|
||||
page_size: 每页数量 (1-100, 默认: 20)
|
||||
sort: 排序字段,格式: "field:asc" 或 "field:desc"
|
||||
filter: 过滤条件,JSON格式
|
||||
|
||||
Returns:
|
||||
分页的项目列表
|
||||
"""
|
||||
try:
|
||||
from src.models.schemas import PaginationParams
|
||||
from src.utils.pagination import Paginator
|
||||
|
||||
# 创建分页参数
|
||||
pagination_params = PaginationParams(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
sort=sort,
|
||||
filter=filter
|
||||
)
|
||||
|
||||
# 计算偏移量
|
||||
offset = pagination_params.get_offset()
|
||||
limit = pagination_params.get_limit()
|
||||
|
||||
# 获取当前用户ID,只返回该用户的项目
|
||||
user_id = current_user.id if current_user else None
|
||||
|
||||
# 获取项目列表(按用户过滤)
|
||||
projects = project_manager.list_projects(limit=limit, offset=offset, user_id=user_id)
|
||||
|
||||
# 获取总数(按用户过滤)
|
||||
total = project_manager.count_projects(user_id=user_id)
|
||||
|
||||
# 创建分页器
|
||||
paginator = Paginator(
|
||||
items=projects,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
# 返回分页响应
|
||||
return paginator.to_response(request)
|
||||
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.get("/projects/{project_id}", response_model=ResponseModel)
|
||||
async def get_project(
|
||||
project_id: str,
|
||||
include_assets: bool = Query(default=True, description="Whether to include assets in the response"),
|
||||
include_referenced_assets: bool = Query(default=False, description="Whether to include only referenced assets (if include_assets is False)"),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Get project details"""
|
||||
project = project_manager.get_project(
|
||||
project_id,
|
||||
include_assets=include_assets,
|
||||
include_referenced_assets=include_referenced_assets,
|
||||
user_id=current_user.id
|
||||
)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=project)
|
||||
|
||||
@router.put("/projects/{project_id}", response_model=ResponseModel)
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
request: UpdateProjectRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Update a project"""
|
||||
try:
|
||||
# First check if project exists and belongs to user
|
||||
existing = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not existing:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
updated_project = project_manager.update_project(project_id, update_data)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}", response_model=ResponseModel)
|
||||
async def delete_project(
|
||||
project_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Delete a project"""
|
||||
try:
|
||||
# First check if project exists and belongs to user
|
||||
existing = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not existing:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
project_manager.delete_project(project_id)
|
||||
return ResponseModel(data={"id": project_id, "deleted": True})
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/convert-to-canvas", response_model=ResponseModel)
|
||||
async def convert_project_to_canvas(
|
||||
project_id: str,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""将剧本项目转换为独立的画布项目。"""
|
||||
try:
|
||||
source_project = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not source_project:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
if source_project.type != "script":
|
||||
raise AppException(
|
||||
message="Only script projects can be converted to canvas projects",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
chapters = [
|
||||
{
|
||||
"id": episode.id,
|
||||
"title": episode.title,
|
||||
"order": episode.order,
|
||||
"content": episode.content,
|
||||
"summary": episode.desc,
|
||||
"status": episode.status,
|
||||
}
|
||||
for episode in (source_project.episodes or [])
|
||||
]
|
||||
|
||||
target_project = project_manager.create_project(
|
||||
name=f"{source_project.name} - 画布",
|
||||
description="正在从剧本项目生成画布...",
|
||||
type="canvas",
|
||||
chapters=chapters,
|
||||
assets=[],
|
||||
status="initializing",
|
||||
user_id=current_user.id
|
||||
)
|
||||
progress_cb = _progress_callback(target_project.id)
|
||||
background_tasks.add_task(
|
||||
convert_script_project_to_canvas,
|
||||
source_project.id,
|
||||
target_project.id,
|
||||
progress_cb,
|
||||
current_user.id,
|
||||
)
|
||||
return ResponseModel(data=target_project)
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error converting project to canvas: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
179
backend/src/api/projects/episodes.py
Normal file
179
backend/src/api/projects/episodes.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Projects API - Episodes Module
|
||||
|
||||
包含剧集管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
Episode, CreateEpisodeRequest, UpdateEpisodeRequest,
|
||||
)
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.project_service import project_manager
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["projects-episodes"])
|
||||
|
||||
|
||||
def _check_project_access(project_id: str, user_id: str):
|
||||
"""检查用户是否有权限访问项目"""
|
||||
project = project_manager.get_project(project_id, user_id=user_id)
|
||||
return project
|
||||
|
||||
|
||||
# --- Episode Management ---
|
||||
@router.post("/projects/{project_id}/episodes", response_model=ResponseModel)
|
||||
async def create_episode(
|
||||
project_id: str,
|
||||
request: CreateEpisodeRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Create a new episode"""
|
||||
try:
|
||||
# Check project access
|
||||
if not _check_project_access(project_id, current_user.id):
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
episode_id = str(uuid.uuid4())
|
||||
new_episode = Episode(
|
||||
id=episode_id,
|
||||
title=request.title,
|
||||
order=request.order,
|
||||
desc=request.desc,
|
||||
status=request.status
|
||||
)
|
||||
updated_project = project_manager.add_episode(project_id, new_episode)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found or operation failed",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.put("/projects/{project_id}/episodes/{episode_id}", response_model=ResponseModel)
|
||||
async def update_episode(
|
||||
project_id: str,
|
||||
episode_id: str,
|
||||
request: UpdateEpisodeRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Update an episode"""
|
||||
try:
|
||||
project = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
existing_ep = next((ep for ep in project.episodes if ep.id == episode_id), None)
|
||||
if not existing_ep:
|
||||
raise AppException(
|
||||
message="Episode not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True, by_alias=True)
|
||||
updated_ep = existing_ep.model_copy(update=update_data)
|
||||
|
||||
updated_project = project_manager.update_episode(project_id, episode_id, updated_ep)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/episodes/{episode_id}", response_model=ResponseModel)
|
||||
async def delete_episode(
|
||||
project_id: str,
|
||||
episode_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""Delete an episode"""
|
||||
try:
|
||||
project = project_manager.get_project(project_id, user_id=current_user.id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found or access denied",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
updated_project = project_manager.delete_episode(project_id, episode_id)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/episodes/{episode_id}/analyze", response_model=ResponseModel)
|
||||
async def analyze_episode(
|
||||
project_id: str,
|
||||
episode_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Analyze a specific episode: extract summary, characters/scenes/props, storyboards.
|
||||
"""
|
||||
try:
|
||||
from src.services.episode_analysis_service import analyze_episode as analyze_episode_service
|
||||
final_project = await analyze_episode_service(project_id, episode_id, current_user.id)
|
||||
return ResponseModel(data=final_project)
|
||||
except ValueError as e:
|
||||
msg = str(e)
|
||||
if "not found" in msg.lower():
|
||||
raise AppException(
|
||||
message=msg,
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
if "empty" in msg.lower():
|
||||
raise AppException(
|
||||
message=msg,
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
raise AppException(
|
||||
message=msg,
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Episode analysis failed: %s", e, exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
113
backend/src/api/projects/storyboards.py
Normal file
113
backend/src/api/projects/storyboards.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Projects API - Storyboards Module
|
||||
|
||||
包含分镜管理相关的 API 路由。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from src.models.schemas import (
|
||||
ResponseModel,
|
||||
Storyboard, CreateStoryboardRequest, UpdateStoryboardRequest,
|
||||
GenerationRecord,
|
||||
)
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["projects-storyboards"])
|
||||
|
||||
|
||||
# --- Storyboard Management ---
|
||||
@router.post("/projects/{project_id}/storyboards", response_model=ResponseModel)
|
||||
async def create_storyboard(project_id: str, request: CreateStoryboardRequest):
|
||||
"""Create a new storyboard"""
|
||||
try:
|
||||
from src.services.project_service import project_manager
|
||||
|
||||
sb_id = str(uuid.uuid4())
|
||||
sb_data = request.model_dump(by_alias=True)
|
||||
sb_data['id'] = sb_id
|
||||
|
||||
new_sb = Storyboard(**sb_data)
|
||||
|
||||
updated_project = project_manager.add_storyboard(project_id, new_sb)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.put("/projects/{project_id}/storyboards/{storyboard_id}", response_model=ResponseModel)
|
||||
async def update_storyboard(project_id: str, storyboard_id: str, request: UpdateStoryboardRequest):
|
||||
"""Update a storyboard"""
|
||||
try:
|
||||
from src.services.project_service import project_manager
|
||||
|
||||
project = project_manager.get_project(project_id)
|
||||
if not project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
existing_sb = next((sb for sb in project.storyboards if sb.id == storyboard_id), None)
|
||||
if not existing_sb:
|
||||
raise AppException(
|
||||
message="Storyboard not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True, by_alias=True)
|
||||
|
||||
# Manually validate generations if present to ensure they are model instances not dicts
|
||||
if 'generations' in update_data and update_data['generations']:
|
||||
update_data['generations'] = TypeAdapter(List[GenerationRecord]).validate_python(update_data['generations'])
|
||||
|
||||
updated_sb = existing_sb.model_copy(update=update_data)
|
||||
|
||||
updated_project = project_manager.update_storyboard(project_id, storyboard_id, updated_sb)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
@router.delete("/projects/{project_id}/storyboards/{storyboard_id}", response_model=ResponseModel)
|
||||
async def delete_storyboard(project_id: str, storyboard_id: str):
|
||||
"""Delete a storyboard"""
|
||||
try:
|
||||
from src.services.project_service import project_manager
|
||||
|
||||
updated_project = project_manager.delete_storyboard(project_id, storyboard_id)
|
||||
if not updated_project:
|
||||
raise AppException(
|
||||
message="Project not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=updated_project)
|
||||
except Exception as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
240
backend/src/api/prompt_templates.py
Normal file
240
backend/src/api/prompt_templates.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
Prompt Template API - 提示词模板接口
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.models.prompt_template import (
|
||||
PromptTemplateCreate,
|
||||
PromptTemplateUpdate,
|
||||
PromptTemplateCategory
|
||||
)
|
||||
from src.services.prompt_template_service import PromptTemplateService
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.utils.errors import AppException, ErrorCode
|
||||
|
||||
router = APIRouter(tags=["prompt-templates"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@router.post("/prompt-templates/init", response_model=ResponseModel)
|
||||
async def init_templates(current_user: UserAuth = Depends(get_current_user)):
|
||||
"""初始化系统默认模板(仅管理员)"""
|
||||
# 这里可以添加管理员权限检查
|
||||
PromptTemplateService.init_default_templates()
|
||||
return ResponseModel(message="Templates initialized successfully")
|
||||
|
||||
|
||||
@router.get("/prompt-templates", response_model=ResponseModel)
|
||||
async def list_templates(
|
||||
category: Optional[str] = Query(None, description="分类过滤"),
|
||||
target_type: Optional[str] = Query(None, description="目标类型过滤: image/video/audio/music"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
favorites_only: bool = Query(False, description="仅显示收藏"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""获取提示词模板列表"""
|
||||
try:
|
||||
result = PromptTemplateService.list_templates(
|
||||
user_id=current_user.id,
|
||||
category=category,
|
||||
target_type=target_type,
|
||||
search=search,
|
||||
favorites_only=favorites_only,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
return ResponseModel(data=result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing templates: {e}")
|
||||
raise AppException(
|
||||
message="Failed to list templates",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/prompt-templates/categories", response_model=ResponseModel)
|
||||
async def get_categories(
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""获取模板分类列表"""
|
||||
categories = PromptTemplateService.get_categories()
|
||||
return ResponseModel(data=categories)
|
||||
|
||||
|
||||
@router.post("/prompt-templates", response_model=ResponseModel)
|
||||
async def create_template(
|
||||
data: PromptTemplateCreate,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""创建新模板"""
|
||||
try:
|
||||
template = PromptTemplateService.create_template(
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
return ResponseModel(
|
||||
message="Template created successfully",
|
||||
data=template
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating template: {e}")
|
||||
raise AppException(
|
||||
message="Failed to create template",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/prompt-templates/{template_id}", response_model=ResponseModel)
|
||||
async def get_template(
|
||||
template_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""获取单个模板详情"""
|
||||
template = PromptTemplateService.get_template(
|
||||
template_id=template_id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
if not template:
|
||||
raise AppException(
|
||||
message="Template not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(data=template)
|
||||
|
||||
|
||||
@router.put("/prompt-templates/{template_id}", response_model=ResponseModel)
|
||||
async def update_template(
|
||||
template_id: str,
|
||||
data: PromptTemplateUpdate,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""更新模板"""
|
||||
try:
|
||||
template = PromptTemplateService.update_template(
|
||||
template_id=template_id,
|
||||
user_id=current_user.id,
|
||||
data=data
|
||||
)
|
||||
|
||||
if not template:
|
||||
raise AppException(
|
||||
message="Template not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="Template updated successfully",
|
||||
data=template
|
||||
)
|
||||
except PermissionError as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.PERMISSION_DENIED,
|
||||
status_code=403
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating template: {e}")
|
||||
raise AppException(
|
||||
message="Failed to update template",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/prompt-templates/{template_id}", response_model=ResponseModel)
|
||||
async def delete_template(
|
||||
template_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""删除模板"""
|
||||
try:
|
||||
success = PromptTemplateService.delete_template(
|
||||
template_id=template_id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise AppException(
|
||||
message="Template not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(message="Template deleted successfully")
|
||||
except PermissionError as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.PERMISSION_DENIED,
|
||||
status_code=403
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting template: {e}")
|
||||
raise AppException(
|
||||
message="Failed to delete template",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/prompt-templates/{template_id}/favorite", response_model=ResponseModel)
|
||||
async def toggle_favorite(
|
||||
template_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""切换模板收藏状态"""
|
||||
try:
|
||||
is_favorite = PromptTemplateService.toggle_favorite(
|
||||
template_id=template_id,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="Favorite toggled successfully",
|
||||
data={"is_favorite": is_favorite}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error toggling favorite: {e}")
|
||||
raise AppException(
|
||||
message="Failed to toggle favorite",
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/prompt-templates/{template_id}/apply", response_model=ResponseModel)
|
||||
async def apply_template(
|
||||
template_id: str,
|
||||
user_prompt: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""应用模板到提示词"""
|
||||
result = PromptTemplateService.apply_template(
|
||||
template_id=template_id,
|
||||
user_prompt=user_prompt
|
||||
)
|
||||
|
||||
if result is None:
|
||||
raise AppException(
|
||||
message="Template not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"original": user_prompt,
|
||||
"enhanced": result
|
||||
}
|
||||
)
|
||||
195
backend/src/api/skills.py
Normal file
195
backend/src/api/skills.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Agent Skills Management API
|
||||
|
||||
Provides endpoints for managing agent skills.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from src.services.agent_engine import AgentScopeService
|
||||
from src.services.agent_engine.toolkit import ToolkitFactory
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(tags=["skills"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SkillInfo(BaseModel):
|
||||
"""Skill information model"""
|
||||
name: str
|
||||
description: str
|
||||
path: str
|
||||
|
||||
|
||||
class RegisterSkillRequest(BaseModel):
|
||||
"""Request model for registering a skill"""
|
||||
skill_dir: str
|
||||
|
||||
|
||||
class SkillPromptResponse(BaseModel):
|
||||
"""Response model for skill prompt"""
|
||||
prompt: Optional[str]
|
||||
skills_count: int
|
||||
|
||||
|
||||
@router.get("/skills", response_model=ResponseModel)
|
||||
async def list_skills():
|
||||
"""List all available agent skills.
|
||||
|
||||
Returns:
|
||||
List of skill information
|
||||
"""
|
||||
skills = []
|
||||
|
||||
try:
|
||||
# 使用新的 ToolkitFactory
|
||||
all_skills = ToolkitFactory.list_skills()
|
||||
|
||||
for skill_path in all_skills:
|
||||
# skill_path 格式: "domain/skill_name"
|
||||
domain, skill_name = skill_path.split('/')
|
||||
|
||||
# 读取 SKILL.md 获取描述
|
||||
from pathlib import Path
|
||||
# Fix path: services/agent_engine/skills (was services/agents/skills)
|
||||
skill_dir = Path(__file__).parent.parent / "services" / "agent_engine" / "skills" / domain / skill_name
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
|
||||
description = "No description available"
|
||||
if skill_md.exists():
|
||||
try:
|
||||
with open(skill_md, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Extract description from YAML front matter
|
||||
if content.startswith('---'):
|
||||
parts = content.split('---', 2)
|
||||
if len(parts) >= 3:
|
||||
import yaml
|
||||
try:
|
||||
metadata = yaml.safe_load(parts[1])
|
||||
description = metadata.get('description', description)
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
|
||||
skills.append(SkillInfo(
|
||||
name=skill_path, # 使用完整路径作为名称
|
||||
description=description,
|
||||
path=str(skill_dir)
|
||||
))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read skill {skill_path}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list skills: {e}")
|
||||
|
||||
return ResponseModel(data=skills)
|
||||
|
||||
|
||||
@router.post("/skills/register", response_model=ResponseModel)
|
||||
async def register_skill(request: RegisterSkillRequest):
|
||||
"""Register a new agent skill.
|
||||
|
||||
Args:
|
||||
request: Skill registration request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: If skill registration fails
|
||||
"""
|
||||
try:
|
||||
# Validate skill directory
|
||||
if not os.path.exists(request.skill_dir):
|
||||
raise AppException(
|
||||
message=f"Skill directory not found: {request.skill_dir}",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
skill_md = os.path.join(request.skill_dir, "SKILL.md")
|
||||
if not os.path.exists(skill_md):
|
||||
raise AppException(
|
||||
message="SKILL.md file not found in skill directory",
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
# Register skill (this would need to be implemented in AgentScopeService)
|
||||
logger.info(f"Registered skill from: {request.skill_dir}")
|
||||
|
||||
return ResponseModel(data={"message": "Skill registered successfully", "path": request.skill_dir})
|
||||
|
||||
except AppException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to register skill: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/skills/prompt", response_model=ResponseModel)
|
||||
async def get_skills_prompt():
|
||||
"""Get the combined prompt for all registered skills.
|
||||
|
||||
This prompt can be attached to the agent's system prompt.
|
||||
|
||||
Returns:
|
||||
Combined skills prompt
|
||||
"""
|
||||
try:
|
||||
# 使用 ToolkitFactory
|
||||
all_skills = ToolkitFactory.list_skills()
|
||||
skills_count = len(all_skills)
|
||||
|
||||
# 获取 skills prompt
|
||||
if skills_count > 0:
|
||||
toolkit = ToolkitFactory.get_toolkit()
|
||||
prompt = toolkit.get_agent_skill_prompt()
|
||||
else:
|
||||
prompt = None
|
||||
|
||||
return ResponseModel(data={"prompt": prompt, "skills_count": skills_count})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get skills prompt: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/skills/{skill_name}", response_model=ResponseModel)
|
||||
async def remove_skill(skill_name: str):
|
||||
"""Remove a registered agent skill.
|
||||
|
||||
Args:
|
||||
skill_name: Name of the skill to remove
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
|
||||
Raises:
|
||||
HTTPException: If skill removal fails
|
||||
"""
|
||||
try:
|
||||
# This would need to be implemented in AgentScopeService
|
||||
logger.info(f"Removed skill: {skill_name}")
|
||||
|
||||
return ResponseModel(data={"message": f"Skill '{skill_name}' removed successfully"})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove skill: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
112
backend/src/api/storage.py
Normal file
112
backend/src/api/storage.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Query, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
import httpx
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.utils.errors import InvalidParameterException, StorageException
|
||||
|
||||
router = APIRouter(prefix="/storage", tags=["Storage"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@router.delete("/files", response_model=ResponseModel)
|
||||
async def delete_file(path: str = Query(..., description="File path or URL to delete")):
|
||||
""" 删除 a file from storage.
|
||||
|
||||
Raises:
|
||||
InvalidParameterException: Path is invalid
|
||||
StorageException: Failed to delete file
|
||||
"""
|
||||
# 参数 validation
|
||||
if not path or not path.strip():
|
||||
raise InvalidParameterException("path", "File path cannot be empty")
|
||||
|
||||
# Blob URLs are client-side only, nothing to delete on server
|
||||
if path.startswith('blob:'):
|
||||
return ResponseModel(data={"deleted": True, "path": path})
|
||||
|
||||
try:
|
||||
# 提取 key from URL if needed
|
||||
key = path
|
||||
if path.startswith(('http://', 'https://')):
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(path)
|
||||
key = parsed.path.lstrip('/')
|
||||
|
||||
# Handle local storage prefix /files/
|
||||
if key.startswith('files/'):
|
||||
key = key[6:]
|
||||
|
||||
success = storage_manager.delete(key)
|
||||
if success:
|
||||
return ResponseModel(data={"deleted": True, "path": key})
|
||||
else:
|
||||
# File doesn't exist or deletion failed
|
||||
raise StorageException("delete", "File not found or deletion failed")
|
||||
|
||||
except StorageException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file {path}: {e}", exc_info=True)
|
||||
raise StorageException("delete", str(e))
|
||||
|
||||
|
||||
@router.get("/download")
|
||||
async def download_file(url: str = Query(..., description="Original or signed media URL"),
|
||||
filename: Optional[str] = Query(None, description="Optional download filename")):
|
||||
"""
|
||||
专用下载接口:
|
||||
- 后端代理拉取远端资源
|
||||
- 设置 Content-Disposition=attachment,避免浏览器直接在当前页打开视频/音频
|
||||
"""
|
||||
if not url or not url.strip():
|
||||
raise InvalidParameterException("url", "Download url cannot be empty")
|
||||
|
||||
# 如果是 OSS 原始地址,可以在这里重新签名(视业务需要)
|
||||
try:
|
||||
signed_url = storage_manager.sign_url(url) or url
|
||||
except Exception:
|
||||
# 签名失败时退回原始 URL
|
||||
signed_url = url
|
||||
|
||||
try:
|
||||
client_timeout = httpx.Timeout(60.0)
|
||||
async with httpx.AsyncClient(timeout=client_timeout, follow_redirects=True) as client:
|
||||
upstream = await client.get(signed_url)
|
||||
upstream.raise_for_status()
|
||||
|
||||
# 透传部分头部(例如 Content-Type)
|
||||
content_type = upstream.headers.get("content-type", "application/octet-stream")
|
||||
content_length = upstream.headers.get("content-length")
|
||||
|
||||
# 下载文件名:优先使用 query 参数,其次从 URL / 头部推断
|
||||
final_name = filename
|
||||
if not final_name:
|
||||
# 从 URL path 中取最后一段
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(url)
|
||||
candidate = (parsed.path or "").rstrip("/").split("/")[-1] or "download"
|
||||
final_name = candidate
|
||||
|
||||
headers = {
|
||||
"Content-Disposition": f'attachment; filename="{final_name}"',
|
||||
"Content-Type": content_type,
|
||||
}
|
||||
if content_length is not None:
|
||||
headers["Content-Length"] = content_length
|
||||
|
||||
return StreamingResponse(
|
||||
iter(upstream.iter_bytes()),
|
||||
status_code=upstream.status_code,
|
||||
headers=headers,
|
||||
media_type=content_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to proxy download for {url}: {e}", exc_info=True)
|
||||
# 返回一个统一错误响应,而不是让浏览器打开原地址
|
||||
return Response(
|
||||
content="Failed to download file",
|
||||
status_code=500,
|
||||
media_type="text/plain; charset=utf-8",
|
||||
)
|
||||
289
backend/src/api/storage_admin.py
Normal file
289
backend/src/api/storage_admin.py
Normal file
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
Storage Admin API
|
||||
|
||||
存储管理 API 端点,用于管理员管理存储资源。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlmodel import Session, select, func
|
||||
|
||||
from src.auth.dependencies import require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.models.schemas import ResponseModel
|
||||
from src.utils.pagination import Paginator
|
||||
from src.config.database import engine
|
||||
from src.models.entities import ProjectDB, UserDB
|
||||
from src.services.storage_service import storage_manager
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(prefix="/admin/storage", tags=["admin-storage"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ===== 响应模型 =====
|
||||
|
||||
class StorageStatsResponse(BaseModel):
|
||||
"""存储统计响应"""
|
||||
total_capacity: int = Field(..., description="总容量 (bytes)")
|
||||
used_space: int = Field(..., description="已用空间 (bytes)")
|
||||
file_count: int = Field(..., description="文件数量")
|
||||
usage_percent: float = Field(..., description="使用百分比")
|
||||
|
||||
|
||||
class StorageFileItem(BaseModel):
|
||||
"""存储文件项"""
|
||||
id: str
|
||||
path: str
|
||||
size: int
|
||||
user_id: str
|
||||
username: Optional[str]
|
||||
project_id: Optional[str]
|
||||
project_name: Optional[str]
|
||||
created_at: float
|
||||
|
||||
|
||||
class StorageUserRankingItem(BaseModel):
|
||||
"""用户存储排行项"""
|
||||
user_id: str
|
||||
username: str
|
||||
storage_used: int
|
||||
file_count: int
|
||||
rank: int
|
||||
|
||||
|
||||
class StorageConfigResponse(BaseModel):
|
||||
"""存储配置响应"""
|
||||
storage_type: str
|
||||
base_path: Optional[str]
|
||||
max_file_size: int
|
||||
allowed_extensions: List[str]
|
||||
|
||||
|
||||
# ===== API 端点 =====
|
||||
|
||||
@router.get("/stats", response_model=ResponseModel)
|
||||
async def get_storage_stats(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取存储统计信息
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
# 计算存储统计
|
||||
with Session(engine) as session:
|
||||
# 这里需要根据实际存储实现调整
|
||||
# 暂时返回一个基础统计
|
||||
total_capacity = 100 * 1024 * 1024 * 1024 # 100GB
|
||||
used_space = 0 # 待实现
|
||||
file_count = 0 # 待实现
|
||||
|
||||
return ResponseModel(
|
||||
data=StorageStatsResponse(
|
||||
total_capacity=total_capacity,
|
||||
used_space=used_space,
|
||||
file_count=file_count,
|
||||
usage_percent=(used_space / total_capacity * 100) if total_capacity > 0 else 0,
|
||||
).model_dump()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage stats: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files", response_model=ResponseModel)
|
||||
async def list_storage_files(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
user_id: Optional[str] = Query(None, description="按用户 ID 过滤"),
|
||||
project_id: Optional[str] = Query(None, description="按项目 ID 过滤"),
|
||||
search: Optional[str] = Query(None, description="搜索文件路径"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
列出存储文件
|
||||
|
||||
支持按用户、项目过滤。
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
# 查询项目关联的文件
|
||||
query = select(ProjectDB)
|
||||
|
||||
if user_id:
|
||||
query = query.where(ProjectDB.user_id == user_id)
|
||||
|
||||
total = session.exec(
|
||||
select(func.count()).select_from(query.subquery())
|
||||
).one()
|
||||
|
||||
offset = (page - 1) * page_size
|
||||
projects = session.exec(query.offset(offset).limit(page_size)).all()
|
||||
|
||||
items = []
|
||||
for project in projects:
|
||||
items.append(
|
||||
StorageFileItem(
|
||||
id=project.id,
|
||||
path=f"/projects/{project.id}",
|
||||
size=0, # 待实现
|
||||
user_id=project.user_id or "",
|
||||
username=None, # 待关联用户表
|
||||
project_id=project.id,
|
||||
project_name=project.name,
|
||||
created_at=project.created_at,
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing storage files: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/users/ranking", response_model=ResponseModel)
|
||||
async def get_storage_user_ranking(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取用户存储使用排行
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
# 查询用户项目数量作为排行依据
|
||||
query = """
|
||||
SELECT user_id, COUNT(*) as project_count
|
||||
FROM projects
|
||||
WHERE deleted_at IS NULL
|
||||
GROUP BY user_id
|
||||
ORDER BY project_count DESC
|
||||
"""
|
||||
# 简化的实现
|
||||
users_query = select(UserDB)
|
||||
users = session.exec(users_query.limit(page_size).offset((page - 1) * page_size)).all()
|
||||
|
||||
items = []
|
||||
for idx, user in enumerate(users):
|
||||
project_count = session.exec(
|
||||
select(func.count()).where(ProjectDB.user_id == user.id)
|
||||
).one()
|
||||
|
||||
items.append(
|
||||
StorageUserRankingItem(
|
||||
user_id=user.id,
|
||||
username=user.username,
|
||||
storage_used=0, # 待实现
|
||||
file_count=project_count,
|
||||
rank=(page - 1) * page_size + idx + 1,
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
total = session.exec(select(func.count()).select_from(UserDB)).one()
|
||||
|
||||
paginator = Paginator(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage ranking: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cleanup", response_model=ResponseModel)
|
||||
async def cleanup_orphan_files(
|
||||
dry_run: bool = Query(True, description="是否仅模拟执行"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
清理孤立文件(没有关联项目的文件)
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
# 简化的实现
|
||||
return ResponseModel(
|
||||
data={
|
||||
"orphan_count": 0,
|
||||
"deleted_count": 0,
|
||||
"deleted_size": 0,
|
||||
"dry_run": dry_run,
|
||||
"message": "清理功能待实现",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up orphan files: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/config", response_model=ResponseModel)
|
||||
async def get_storage_config(
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取存储配置信息
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
try:
|
||||
from src.config.settings import (
|
||||
STORAGE_TYPE,
|
||||
STORAGE_BASE_PATH,
|
||||
STORAGE_MAX_FILE_SIZE,
|
||||
STORAGE_ALLOWED_EXTENSIONS,
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
data=StorageConfigResponse(
|
||||
storage_type=STORAGE_TYPE or "local",
|
||||
base_path=STORAGE_BASE_PATH,
|
||||
max_file_size=STORAGE_MAX_FILE_SIZE,
|
||||
allowed_extensions=STORAGE_ALLOWED_EXTENSIONS or [],
|
||||
).model_dump()
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting storage config: {e}")
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
175
backend/src/api/tasks.py
Normal file
175
backend/src/api/tasks.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Tasks Controller - 任务管理
|
||||
使用新的任务管理器 V2
|
||||
"""
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Query, Request, Depends
|
||||
|
||||
from src.models.schemas import ResponseModel, Task, PaginationParams
|
||||
from src.services.task_manager import task_manager, TaskPriority
|
||||
from src.auth.dependencies import get_current_user
|
||||
from src.auth.models import UserAuth
|
||||
from src.utils.errors import (
|
||||
TaskNotFoundException,
|
||||
InvalidParameterException,
|
||||
BusinessException,
|
||||
ErrorCode
|
||||
)
|
||||
from src.utils.pagination import Paginator
|
||||
|
||||
router = APIRouter(tags=["tasks"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 注意:更具体的路由(如 /tasks/stats)必须在参数化路由(如 /tasks/{task_id})之前定义
|
||||
# 否则 FastAPI 会将 "stats" 当作 task_id 处理
|
||||
|
||||
|
||||
@router.get("/tasks/stats", response_model=ResponseModel)
|
||||
async def get_task_stats():
|
||||
"""获取任务统计信息
|
||||
|
||||
Returns:
|
||||
任务统计数据
|
||||
"""
|
||||
stats = task_manager.get_stats()
|
||||
return ResponseModel(data=stats)
|
||||
|
||||
|
||||
@router.get("/tasks", response_model=ResponseModel)
|
||||
async def list_tasks(
|
||||
request: Request,
|
||||
task_type: Optional[str] = Query(None, description="Filter by task type (image, video, script)"),
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量,最大100"),
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""列出任务(分页)
|
||||
|
||||
只返回当前用户的任务
|
||||
|
||||
Args:
|
||||
task_type: 任务类型过滤
|
||||
status: 状态过滤
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
|
||||
Returns:
|
||||
分页的任务列表
|
||||
"""
|
||||
try:
|
||||
# 计算偏移量
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# 获取当前用户的任务
|
||||
tasks = task_manager.list_tasks(
|
||||
type=task_type,
|
||||
limit=page_size,
|
||||
offset=offset,
|
||||
user_id=current_user.id
|
||||
)
|
||||
|
||||
# 计算总数
|
||||
total = len(tasks) # 简化处理,实际应该查询总数
|
||||
|
||||
# 创建分页器
|
||||
paginator = Paginator(
|
||||
items=tasks,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing tasks: {e}", exc_info=True)
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INTERNAL_ERROR,
|
||||
status_code=500
|
||||
)
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=ResponseModel)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""获取任务状态
|
||||
|
||||
只能访问当前用户的任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
任务详情
|
||||
|
||||
Raises:
|
||||
TaskNotFoundException: 任务不存在或无权访问
|
||||
"""
|
||||
task = await task_manager.get_task(task_id)
|
||||
|
||||
if not task:
|
||||
raise TaskNotFoundException(task_id)
|
||||
|
||||
# 检查任务是否属于当前用户
|
||||
if task.user_id and task.user_id != current_user.id:
|
||||
raise TaskNotFoundException(task_id)
|
||||
|
||||
return ResponseModel(data=task)
|
||||
|
||||
|
||||
@router.delete("/tasks/{task_id}", response_model=ResponseModel)
|
||||
async def cancel_task(
|
||||
task_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""取消任务
|
||||
|
||||
只能取消当前用户的任务
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
|
||||
Returns:
|
||||
取消结果
|
||||
|
||||
Raises:
|
||||
TaskNotFoundException: 任务不存在或无权访问
|
||||
BusinessException: 取消失败
|
||||
"""
|
||||
try:
|
||||
# 先检查任务是否属于当前用户
|
||||
task = await task_manager.get_task(task_id)
|
||||
if not task:
|
||||
raise TaskNotFoundException(task_id)
|
||||
|
||||
if task.user_id and task.user_id != current_user.id:
|
||||
raise TaskNotFoundException(task_id)
|
||||
|
||||
success = await task_manager.cancel_task(task_id)
|
||||
|
||||
if not success:
|
||||
raise BusinessException(
|
||||
ErrorCode.TASK_CANCELLED,
|
||||
"Task cannot be cancelled (already completed or failed)",
|
||||
{"task_id": task_id}
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
message="Task cancelled successfully",
|
||||
data={"task_id": task_id, "cancelled": True}
|
||||
)
|
||||
|
||||
except TaskNotFoundException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cancel task {task_id}: {e}", exc_info=True)
|
||||
raise BusinessException(
|
||||
ErrorCode.UNKNOWN_ERROR,
|
||||
"Failed to cancel task",
|
||||
{"task_id": task_id, "reason": str(e)}
|
||||
)
|
||||
371
backend/src/api/user_api_keys.py
Normal file
371
backend/src/api/user_api_keys.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""
|
||||
用户 API Key 管理接口
|
||||
|
||||
提供用户管理自己的 API Key 的 CRUD 接口,以及管理员管理所有用户 API Key 的接口。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, status, Query, Request
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.auth.dependencies import get_current_user, require_admin
|
||||
from src.auth.models import UserAuth
|
||||
from src.services.user_api_key_service import user_api_key_service
|
||||
from src.models.schemas import ResponseModel, PaginationParams
|
||||
from src.utils.pagination import Paginator
|
||||
from src.utils.errors import ErrorCode, AppException
|
||||
|
||||
router = APIRouter(prefix="/user-api-keys", tags=["user-api-keys"])
|
||||
|
||||
|
||||
# ===== 请求/响应模型 =====
|
||||
|
||||
class ApiKeyCreateRequest(BaseModel):
|
||||
"""创建 API Key 请求"""
|
||||
provider: str = Field(..., description="提供商,如 openai, dashscope")
|
||||
api_key: str = Field(..., description="API Key 值", min_length=1)
|
||||
name: Optional[str] = Field(None, description="自定义名称")
|
||||
extra_config: Optional[dict] = Field(None, description="额外配置")
|
||||
|
||||
|
||||
class ApiKeyUpdateRequest(BaseModel):
|
||||
"""更新 API Key 请求"""
|
||||
name: Optional[str] = Field(None, description="新名称")
|
||||
api_key: Optional[str] = Field(None, description="新 API Key 值")
|
||||
is_active: Optional[bool] = Field(None, description="是否启用")
|
||||
extra_config: Optional[dict] = Field(None, description="额外配置")
|
||||
|
||||
|
||||
class ApiKeyResponse(BaseModel):
|
||||
"""API Key 响应"""
|
||||
id: str
|
||||
user_id: str
|
||||
provider: str
|
||||
name: str
|
||||
masked_key: str
|
||||
is_active: bool
|
||||
created_at: float
|
||||
updated_at: float
|
||||
last_used_at: Optional[float]
|
||||
usage_count: int
|
||||
extra_config: Optional[dict]
|
||||
|
||||
|
||||
class ApiKeyVerifyResponse(BaseModel):
|
||||
"""API Key 验证响应"""
|
||||
valid: bool
|
||||
provider: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# ===== 管理员响应模型 =====
|
||||
|
||||
class AdminApiKeyListItem(BaseModel):
|
||||
"""管理员 API Key 列表项"""
|
||||
id: str
|
||||
user_id: str
|
||||
username: str = Field(..., description="关联用户名")
|
||||
email: str = Field(..., description="关联用户邮箱")
|
||||
provider: str
|
||||
name: str
|
||||
masked_key: str
|
||||
is_active: bool
|
||||
created_at: float
|
||||
last_used_at: Optional[float]
|
||||
usage_count: int
|
||||
|
||||
|
||||
class AdminApiKeyUsageRecord(BaseModel):
|
||||
"""API Key 使用记录"""
|
||||
task_id: Optional[str]
|
||||
task_type: Optional[str]
|
||||
model: Optional[str]
|
||||
provider: str
|
||||
credits_used: Optional[float]
|
||||
created_at: str
|
||||
|
||||
|
||||
# ===== API 端点 =====
|
||||
|
||||
@router.get("", response_model=ResponseModel)
|
||||
async def list_api_keys(
|
||||
current_user: UserAuth = Depends(get_current_user),
|
||||
include_inactive: bool = False
|
||||
):
|
||||
"""
|
||||
获取当前用户的所有 API Key
|
||||
|
||||
Args:
|
||||
include_inactive: 是否包含已禁用的 Key
|
||||
"""
|
||||
keys = user_api_key_service.get_user_api_keys(
|
||||
user_id=current_user.id,
|
||||
include_inactive=include_inactive
|
||||
)
|
||||
return ResponseModel(data=keys)
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_201_CREATED, response_model=ResponseModel)
|
||||
async def create_api_key(
|
||||
request: ApiKeyCreateRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建新的 API Key
|
||||
"""
|
||||
try:
|
||||
key_db = user_api_key_service.create_api_key(
|
||||
user_id=current_user.id,
|
||||
provider=request.provider,
|
||||
api_key=request.api_key,
|
||||
name=request.name,
|
||||
extra_config=request.extra_config
|
||||
)
|
||||
key = user_api_key_service.get_api_key_by_id(key_db.id, current_user.id)
|
||||
return ResponseModel(data=key)
|
||||
except ValueError as e:
|
||||
raise AppException(
|
||||
message=str(e),
|
||||
code=ErrorCode.INVALID_PARAMETER,
|
||||
status_code=400
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{key_id}", response_model=ResponseModel)
|
||||
async def get_api_key(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取指定 API Key 的详细信息
|
||||
"""
|
||||
key = user_api_key_service.get_api_key_by_id(key_id, current_user.id)
|
||||
if not key:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=key)
|
||||
|
||||
|
||||
@router.put("/{key_id}", response_model=ResponseModel)
|
||||
async def update_api_key(
|
||||
key_id: str,
|
||||
request: ApiKeyUpdateRequest,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
更新 API Key
|
||||
"""
|
||||
key = user_api_key_service.update_api_key(
|
||||
key_id=key_id,
|
||||
user_id=current_user.id,
|
||||
name=request.name,
|
||||
api_key=request.api_key,
|
||||
is_active=request.is_active,
|
||||
extra_config=request.extra_config
|
||||
)
|
||||
if not key:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data=key)
|
||||
|
||||
|
||||
@router.delete("/{key_id}", response_model=ResponseModel)
|
||||
async def delete_api_key(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
删除 API Key
|
||||
"""
|
||||
success = user_api_key_service.delete_api_key(key_id, current_user.id)
|
||||
if not success:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
return ResponseModel(data={"deleted": True, "id": key_id})
|
||||
|
||||
|
||||
@router.post("/{key_id}/verify", response_model=ResponseModel)
|
||||
async def verify_api_key(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
验证 API Key 是否有效
|
||||
"""
|
||||
result = await user_api_key_service.verify_api_key(key_id, current_user.id)
|
||||
return ResponseModel(data=result)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 管理员 API 端点
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/admin/api-keys", response_model=ResponseModel)
|
||||
async def admin_list_api_keys(
|
||||
request: Request,
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
user_id: Optional[str] = Query(None, description="按用户 ID 过滤"),
|
||||
provider: Optional[str] = Query(None, description="按提供商过滤"),
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
管理员列出所有 API Keys
|
||||
|
||||
支持分页、按用户 ID、提供商过滤。
|
||||
需要管理员权限。
|
||||
"""
|
||||
from sqlmodel import Session, select, func
|
||||
from src.config.database import engine
|
||||
from src.models.entities import UserApiKeyDB, UserDB
|
||||
|
||||
with Session(engine) as session:
|
||||
query = select(UserApiKeyDB)
|
||||
|
||||
# 应用过滤
|
||||
if user_id:
|
||||
query = query.where(UserApiKeyDB.user_id == user_id)
|
||||
if provider:
|
||||
query = query.where(UserApiKeyDB.provider == provider)
|
||||
|
||||
# 获取总数
|
||||
total = session.exec(
|
||||
select(func.count()).select_from(query.subquery())
|
||||
).one()
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
query = query.order_by(UserApiKeyDB.created_at.desc()).offset(offset).limit(page_size)
|
||||
|
||||
api_keys = session.exec(query).all()
|
||||
|
||||
# 构建响应项(关联用户信息)
|
||||
items = []
|
||||
for key in api_keys:
|
||||
user = session.get(UserDB, key.user_id)
|
||||
# 获取脱敏 key
|
||||
raw_key = key.encrypted_key[:8] + "***" if key.encrypted_key else ""
|
||||
items.append(
|
||||
AdminApiKeyListItem(
|
||||
id=key.id,
|
||||
user_id=key.user_id,
|
||||
username=user.username if user else "Unknown",
|
||||
email=user.email if user else "",
|
||||
provider=key.provider,
|
||||
name=key.name or f"{key.provider} Key",
|
||||
masked_key=raw_key,
|
||||
is_active=key.is_active,
|
||||
created_at=key.created_at,
|
||||
last_used_at=key.last_used_at,
|
||||
usage_count=key.usage_count,
|
||||
)
|
||||
)
|
||||
|
||||
paginator = Paginator(
|
||||
items=[item.model_dump() for item in items],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
return paginator.to_response(request)
|
||||
|
||||
|
||||
@router.post("/admin/api-keys/{key_id}/revoke", response_model=ResponseModel)
|
||||
async def admin_revoke_api_key(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
撤销 API Key
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import UserApiKeyDB
|
||||
|
||||
with Session(engine) as session:
|
||||
key = session.get(UserApiKeyDB, key_id)
|
||||
if not key:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
key.is_active = False
|
||||
key.updated_at = datetime.now().timestamp()
|
||||
session.add(key)
|
||||
session.commit()
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"id": key_id,
|
||||
"is_active": False,
|
||||
"message": "API Key 已成功撤销",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/api-keys/{key_id}/usage", response_model=ResponseModel)
|
||||
async def admin_get_api_key_usage(
|
||||
key_id: str,
|
||||
current_user: UserAuth = Depends(require_admin),
|
||||
):
|
||||
"""
|
||||
获取 API Key 使用记录
|
||||
|
||||
需要管理员权限。
|
||||
"""
|
||||
from sqlmodel import Session, select
|
||||
from src.config.database import engine
|
||||
from src.models.entities import UserApiKeyDB, TaskDB
|
||||
|
||||
with Session(engine) as session:
|
||||
key = session.get(UserApiKeyDB, key_id)
|
||||
if not key:
|
||||
raise AppException(
|
||||
message="API Key not found",
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
status_code=404
|
||||
)
|
||||
|
||||
# 查询相关任务记录
|
||||
tasks = session.exec(
|
||||
select(TaskDB)
|
||||
.where(TaskDB.user_id == key.user_id)
|
||||
.order_by(TaskDB.created_at.desc())
|
||||
.limit(50)
|
||||
).all()
|
||||
|
||||
usage_records = [
|
||||
AdminApiKeyUsageRecord(
|
||||
task_id=task.id,
|
||||
task_type=task.type,
|
||||
model=task.model,
|
||||
provider=task.provider or key.provider,
|
||||
credits_used=None,
|
||||
created_at=datetime.fromtimestamp(task.created_at).isoformat(),
|
||||
)
|
||||
for task in tasks
|
||||
]
|
||||
|
||||
return ResponseModel(
|
||||
data={
|
||||
"key_id": key_id,
|
||||
"usage_count": key.usage_count,
|
||||
"records": [r.model_dump() for r in usage_records],
|
||||
}
|
||||
)
|
||||
41
backend/src/auth/__init__.py
Normal file
41
backend/src/auth/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
认证授权模块
|
||||
|
||||
提供 JWT Token 认证、OAuth2 集成和 HTTP 轮询状态查询。
|
||||
"""
|
||||
|
||||
from src.auth.jwt import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
verify_token,
|
||||
verify_refresh_token,
|
||||
TokenPayload,
|
||||
TokenPair,
|
||||
)
|
||||
from src.auth.dependencies import (
|
||||
get_current_user,
|
||||
get_current_active_user,
|
||||
require_permissions,
|
||||
)
|
||||
from src.auth.middleware import AuthMiddleware
|
||||
from src.auth.models import UserAuth, TokenData, RefreshTokenRequest
|
||||
|
||||
__all__ = [
|
||||
# JWT
|
||||
"create_access_token",
|
||||
"create_refresh_token",
|
||||
"verify_token",
|
||||
"verify_refresh_token",
|
||||
"TokenPayload",
|
||||
"TokenPair",
|
||||
# Dependencies
|
||||
"get_current_user",
|
||||
"get_current_active_user",
|
||||
"require_permissions",
|
||||
# Middleware
|
||||
"AuthMiddleware",
|
||||
# Models
|
||||
"UserAuth",
|
||||
"TokenData",
|
||||
"RefreshTokenRequest",
|
||||
]
|
||||
250
backend/src/auth/dependencies.py
Normal file
250
backend/src/auth/dependencies.py
Normal file
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
FastAPI 认证依赖
|
||||
|
||||
提供可注入的依赖函数用于获取当前用户和验证权限。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, List, Annotated
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer, HTTPBearer
|
||||
|
||||
from src.auth.jwt import verify_token
|
||||
from src.auth.models import UserAuth, TokenPayload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_test_env() -> bool:
|
||||
return (
|
||||
os.getenv("PYTEST_CURRENT_TEST") is not None
|
||||
or os.getenv("PIXEL_TEST_MODE") == "1"
|
||||
or os.getenv("NODE_ENV") == "test"
|
||||
)
|
||||
|
||||
|
||||
def _build_test_user() -> UserAuth:
|
||||
return UserAuth(
|
||||
id="test-user",
|
||||
username="test_user",
|
||||
email=None,
|
||||
is_active=True,
|
||||
is_superuser=True,
|
||||
permissions=["*"],
|
||||
roles=["test"],
|
||||
created_at=None,
|
||||
last_login=None,
|
||||
)
|
||||
|
||||
# Lazy import to avoid circular import
|
||||
_user_service = None
|
||||
|
||||
def _get_user_service():
|
||||
global _user_service
|
||||
if _user_service is None:
|
||||
from src.services.user_service import user_service
|
||||
_user_service = user_service
|
||||
return _user_service
|
||||
|
||||
# OAuth2 scheme for Swagger UI
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="/api/v1/auth/login",
|
||||
auto_error=False
|
||||
)
|
||||
|
||||
# HTTP Bearer for direct token usage
|
||||
http_bearer = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[Optional[str], Depends(oauth2_scheme)]
|
||||
) -> UserAuth:
|
||||
"""
|
||||
获取当前认证用户(从 HTTP 请求)
|
||||
|
||||
Args:
|
||||
token: OAuth2 token from Authorization header
|
||||
|
||||
Returns:
|
||||
UserAuth 对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 401 if authentication fails
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
if not token:
|
||||
if _is_test_env():
|
||||
return _build_test_user()
|
||||
raise credentials_exception
|
||||
|
||||
payload = await verify_token(token, token_type="access")
|
||||
if not payload:
|
||||
raise credentials_exception
|
||||
|
||||
if payload.sid:
|
||||
from src.services.session_service import session_service
|
||||
|
||||
if not session_service.is_session_active(payload.sid):
|
||||
raise credentials_exception
|
||||
|
||||
# 从缓存或数据库获取用户信息
|
||||
user = await _get_user_service().get_user_by_id(payload.sub)
|
||||
|
||||
if not user:
|
||||
raise credentials_exception
|
||||
|
||||
if not user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="User is inactive"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[UserAuth, Depends(get_current_user)]
|
||||
) -> UserAuth:
|
||||
"""
|
||||
获取当前活跃用户
|
||||
|
||||
Args:
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
UserAuth 对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is inactive
|
||||
"""
|
||||
if not current_user.is_active:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Inactive user"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
def require_permissions(required_permissions: List[str]):
|
||||
"""
|
||||
权限检查依赖工厂
|
||||
|
||||
Args:
|
||||
required_permissions: 需要的权限列表
|
||||
|
||||
Returns:
|
||||
依赖函数
|
||||
|
||||
Usage:
|
||||
@router.get("/admin")
|
||||
async def admin_endpoint(
|
||||
user: UserAuth = Depends(require_permissions(["admin"]))
|
||||
):
|
||||
...
|
||||
"""
|
||||
async def permission_checker(
|
||||
current_user: Annotated[UserAuth, Depends(get_current_user)]
|
||||
) -> UserAuth:
|
||||
# 超级用户跳过权限检查
|
||||
if current_user.is_superuser:
|
||||
return current_user
|
||||
|
||||
# 检查权限
|
||||
user_perms = set(current_user.permissions)
|
||||
required_perms = set(required_permissions)
|
||||
|
||||
if not required_perms.issubset(user_perms):
|
||||
missing = required_perms - user_perms
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Missing permissions: {', '.join(missing)}"
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
return permission_checker
|
||||
|
||||
|
||||
async def optional_user(
|
||||
token: Annotated[Optional[str], Depends(oauth2_scheme)]
|
||||
) -> Optional[UserAuth]:
|
||||
"""
|
||||
可选的用户认证
|
||||
|
||||
如果提供了 token 则验证,否则返回 None
|
||||
|
||||
Args:
|
||||
token: OAuth2 token
|
||||
|
||||
Returns:
|
||||
UserAuth 对象或 None
|
||||
"""
|
||||
if not token:
|
||||
return None
|
||||
|
||||
try:
|
||||
return await get_current_user(token)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
async def require_admin(
|
||||
current_user: Annotated[UserAuth, Depends(get_current_user)]
|
||||
) -> UserAuth:
|
||||
"""
|
||||
要求管理员权限
|
||||
|
||||
Args:
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
UserAuth 对象
|
||||
|
||||
Raises:
|
||||
HTTPException: 403 if user is not an admin
|
||||
"""
|
||||
if not current_user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Admin access required"
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
class AuthDependencies:
|
||||
"""
|
||||
认证依赖组合
|
||||
|
||||
提供常用的认证依赖组合
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def authenticated():
|
||||
"""返回需要认证的依赖"""
|
||||
return Depends(get_current_user)
|
||||
|
||||
@staticmethod
|
||||
def active_user():
|
||||
"""返回需要活跃用户的依赖"""
|
||||
return Depends(get_current_active_user)
|
||||
|
||||
@staticmethod
|
||||
def optional():
|
||||
"""返回可选认证的依赖"""
|
||||
return Depends(optional_user)
|
||||
|
||||
@staticmethod
|
||||
def permissions(perms: List[str]):
|
||||
"""返回需要特定权限的依赖"""
|
||||
return Depends(require_permissions(perms))
|
||||
|
||||
@staticmethod
|
||||
def admin():
|
||||
"""返回需要管理员权限的依赖"""
|
||||
return Depends(require_admin)
|
||||
346
backend/src/auth/jwt.py
Normal file
346
backend/src/auth/jwt.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""
|
||||
JWT Token 处理模块
|
||||
|
||||
提供 access token 和 refresh token 的创建、验证功能。
|
||||
支持 RSA 密钥对或 HMAC 对称加密。
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from src.auth.models import TokenPayload, TokenPair
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Password hashing
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
# JWT 配置
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
|
||||
REFRESH_TOKEN_EXPIRE_DAYS = int(os.getenv("REFRESH_TOKEN_EXPIRE_DAYS", "7"))
|
||||
|
||||
# 密钥配置(生产环境必须从环境变量获取)
|
||||
def _get_secret_key() -> str:
|
||||
"""从环境变量获取密钥,如果未设置则生成随机密钥(仅用于开发)"""
|
||||
secret = os.getenv("JWT_SECRET_KEY")
|
||||
if secret:
|
||||
return secret
|
||||
|
||||
# 开发环境:生成随机密钥并警告
|
||||
if os.getenv("NODE_ENV", "development") == "development":
|
||||
logger.warning(
|
||||
"JWT_SECRET_KEY not set in environment. "
|
||||
"Using auto-generated random key (tokens will not persist across restarts). "
|
||||
"Please set JWT_SECRET_KEY in production!"
|
||||
)
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
# 生产环境必须设置密钥
|
||||
raise ValueError(
|
||||
"JWT_SECRET_KEY must be set in environment variables for production. "
|
||||
"Generate a secure key with: python -c 'import secrets; print(secrets.token_hex(32))'"
|
||||
)
|
||||
|
||||
SECRET_KEY = _get_secret_key()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""生成密码哈希"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
user_id: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
session_id: Optional[str] = None,
|
||||
extra_claims: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
创建 access token
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
scopes: 权限范围列表
|
||||
expires_delta: 自定义过期时间
|
||||
extra_claims: 额外的 JWT claims
|
||||
|
||||
Returns:
|
||||
JWT access token string
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"exp": int(expire.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"type": "access",
|
||||
"scopes": scopes or [],
|
||||
"jti": str(uuid.uuid4()),
|
||||
"sid": session_id,
|
||||
}
|
||||
|
||||
if extra_claims:
|
||||
payload.update(extra_claims)
|
||||
|
||||
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_refresh_token(
|
||||
user_id: str,
|
||||
expires_delta: Optional[timedelta] = None,
|
||||
jti: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
session_family_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
创建 refresh token
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
expires_delta: 自定义过期时间
|
||||
jti: JWT ID(用于 token 撤销)
|
||||
|
||||
Returns:
|
||||
JWT refresh token string
|
||||
"""
|
||||
if expires_delta:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
else:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
token_jti = jti or str(uuid.uuid4())
|
||||
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"exp": int(expire.timestamp()),
|
||||
"iat": int(now.timestamp()),
|
||||
"type": "refresh",
|
||||
"jti": token_jti,
|
||||
"sid": session_id,
|
||||
"sfid": session_family_id,
|
||||
}
|
||||
|
||||
encoded_jwt = jwt.encode(payload, SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
def create_token_pair(
|
||||
user_id: str,
|
||||
scopes: Optional[List[str]] = None,
|
||||
access_expires: Optional[timedelta] = None,
|
||||
refresh_expires: Optional[timedelta] = None,
|
||||
session_id: Optional[str] = None,
|
||||
session_family_id: Optional[str] = None,
|
||||
) -> TokenPair:
|
||||
"""
|
||||
创建 access token 和 refresh token 对
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
scopes: 权限范围
|
||||
access_expires: access token 过期时间
|
||||
refresh_expires: refresh token 过期时间
|
||||
|
||||
Returns:
|
||||
TokenPair 包含 access_token 和 refresh_token
|
||||
"""
|
||||
access_token = create_access_token(user_id, scopes, access_expires, session_id=session_id)
|
||||
refresh_token = create_refresh_token(
|
||||
user_id,
|
||||
refresh_expires,
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
)
|
||||
|
||||
return TokenPair(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
session_id=session_id,
|
||||
session_family_id=session_family_id,
|
||||
)
|
||||
|
||||
|
||||
async def verify_token(token: str, token_type: str = "access", check_blacklist: bool = True) -> Optional[TokenPayload]:
|
||||
"""
|
||||
验证 JWT token
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
token_type: 期望的 token 类型 ("access" 或 "refresh")
|
||||
check_blacklist: 是否检查黑名单(默认为 True)
|
||||
|
||||
Returns:
|
||||
TokenPayload 如果验证成功,否则 None
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 验证 token 类型
|
||||
if payload.get("type") != token_type:
|
||||
logger.warning(f"Token type mismatch: expected {token_type}, got {payload.get('type')}")
|
||||
return None
|
||||
|
||||
# 验证必要字段
|
||||
if "sub" not in payload or "exp" not in payload:
|
||||
logger.warning("Token missing required fields")
|
||||
return None
|
||||
|
||||
# 检查是否在黑名单中
|
||||
if check_blacklist:
|
||||
try:
|
||||
from src.services.token_blacklist_service import token_blacklist_service
|
||||
is_revoked = await token_blacklist_service.is_token_revoked(token)
|
||||
if is_revoked:
|
||||
logger.warning(f"Token has been revoked: {payload.get('jti')}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check token blacklist: {e}")
|
||||
# 黑名单检查失败,继续验证(避免服务不可用)
|
||||
|
||||
return TokenPayload(**payload)
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token verification: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def verify_token_sync(token: str, token_type: str = "access") -> Optional[TokenPayload]:
|
||||
"""
|
||||
同步验证 JWT token(不检查黑名单,用于中间件等同步上下文)
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
token_type: 期望的 token 类型 ("access" 或 "refresh")
|
||||
|
||||
Returns:
|
||||
TokenPayload 如果验证成功,否则 None
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(token, SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||||
|
||||
# 验证 token 类型
|
||||
if payload.get("type") != token_type:
|
||||
logger.warning(f"Token type mismatch: expected {token_type}, got {payload.get('type')}")
|
||||
return None
|
||||
|
||||
# 验证必要字段
|
||||
if "sub" not in payload or "exp" not in payload:
|
||||
logger.warning("Token missing required fields")
|
||||
return None
|
||||
|
||||
return TokenPayload(**payload)
|
||||
|
||||
except JWTError as e:
|
||||
logger.warning(f"JWT verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during token verification: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def verify_refresh_token(token: str) -> Optional[TokenPayload]:
|
||||
"""
|
||||
验证 refresh token
|
||||
|
||||
Args:
|
||||
token: Refresh token string
|
||||
|
||||
Returns:
|
||||
TokenPayload 如果验证成功,否则 None
|
||||
"""
|
||||
return verify_token(token, token_type="refresh")
|
||||
|
||||
|
||||
def decode_token_unsafe(token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
不解密直接解码 token(用于获取 payload 信息而不验证签名)
|
||||
|
||||
警告: 仅用于获取非敏感信息,不要用于认证!
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
Payload dict 或 None
|
||||
"""
|
||||
try:
|
||||
# 分割 JWT
|
||||
parts = token.split(".")
|
||||
if len(parts) != 3:
|
||||
return None
|
||||
|
||||
# 解码 payload (第二部分)
|
||||
import base64
|
||||
payload_b64 = parts[1]
|
||||
# 添加 padding
|
||||
padding = 4 - len(payload_b64) % 4
|
||||
if padding != 4:
|
||||
payload_b64 += "=" * padding
|
||||
|
||||
payload_json = base64.urlsafe_b64decode(payload_b64)
|
||||
return json.loads(payload_json)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to decode token: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_token_expiry(token: str) -> Optional[datetime]:
|
||||
"""
|
||||
获取 token 的过期时间
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
过期时间或 None
|
||||
"""
|
||||
payload = decode_token_unsafe(token)
|
||||
if payload and "exp" in payload:
|
||||
return datetime.fromtimestamp(payload["exp"], tz=timezone.utc)
|
||||
return None
|
||||
|
||||
|
||||
def is_token_expired(token: str) -> bool:
|
||||
"""
|
||||
检查 token 是否已过期
|
||||
|
||||
Args:
|
||||
token: JWT token string
|
||||
|
||||
Returns:
|
||||
如果已过期返回 True
|
||||
"""
|
||||
expiry = get_token_expiry(token)
|
||||
if expiry:
|
||||
return datetime.now(timezone.utc) > expiry
|
||||
return True
|
||||
|
||||
|
||||
# Import json for decode_token_unsafe
|
||||
import json
|
||||
343
backend/src/auth/middleware.py
Normal file
343
backend/src/auth/middleware.py
Normal file
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
认证中间件
|
||||
|
||||
提供全局认证中间件和权限检查。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, List, Set
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response, JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from src.auth.jwt import verify_token
|
||||
from src.auth.models import UserAuth, TokenPayload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
全局认证中间件
|
||||
|
||||
为所有请求添加认证信息,可选择性地保护路由。
|
||||
"""
|
||||
|
||||
# 不需要认证的路径
|
||||
PUBLIC_PATHS: Set[str] = {
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/refresh",
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/api/v1/health",
|
||||
}
|
||||
|
||||
# 路径前缀(以这些前缀开头的路径不需要认证)
|
||||
PUBLIC_PREFIXES: List[str] = [
|
||||
"/static/",
|
||||
"/api/v1/public/",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
require_auth: bool = False,
|
||||
public_paths: Optional[Set[str]] = None,
|
||||
public_prefixes: Optional[List[str]] = None
|
||||
):
|
||||
"""
|
||||
初始化认证中间件
|
||||
|
||||
Args:
|
||||
app: ASGI 应用
|
||||
require_auth: 是否默认要求认证
|
||||
public_paths: 额外的公开路径
|
||||
public_prefixes: 额外的公开路径前缀
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.require_auth = require_auth
|
||||
|
||||
if public_paths:
|
||||
self.PUBLIC_PATHS.update(public_paths)
|
||||
if public_prefixes:
|
||||
self.PUBLIC_PREFIXES.extend(public_prefixes)
|
||||
|
||||
def _is_public_path(self, path: str) -> bool:
|
||||
"""检查路径是否公开"""
|
||||
# 精确匹配
|
||||
if path in self.PUBLIC_PATHS:
|
||||
return True
|
||||
|
||||
# 前缀匹配
|
||||
for prefix in self.PUBLIC_PREFIXES:
|
||||
if path.startswith(prefix):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _extract_token(self, request: Request) -> Optional[str]:
|
||||
"""从请求中提取 token"""
|
||||
# 从 Authorization header
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Bearer "):
|
||||
return auth_header[7:]
|
||||
|
||||
# 从 query parameter (用于特殊场景,如直接链接分享)
|
||||
token = request.query_params.get("token")
|
||||
if token:
|
||||
return token
|
||||
|
||||
# 从 cookie
|
||||
token = request.cookies.get("access_token")
|
||||
if token:
|
||||
return token
|
||||
|
||||
return None
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""
|
||||
处理请求
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
call_next: 下一个中间件/处理函数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
"""
|
||||
path = request.url.path
|
||||
|
||||
# 提取 token
|
||||
token = self._extract_token(request)
|
||||
|
||||
# 如果有 token,尝试验证
|
||||
user: Optional[UserAuth] = None
|
||||
if token:
|
||||
payload = await verify_token(token, token_type="access")
|
||||
if payload:
|
||||
user = UserAuth(
|
||||
id=payload.sub,
|
||||
username=f"user_{payload.sub[:8]}",
|
||||
is_active=True,
|
||||
permissions=payload.scopes or []
|
||||
)
|
||||
|
||||
# 将用户信息附加到请求状态
|
||||
request.state.user = user
|
||||
request.state.is_authenticated = user is not None
|
||||
|
||||
# 检查是否需要认证
|
||||
if self._is_public_path(path):
|
||||
# 公开路径,允许访问
|
||||
return await call_next(request)
|
||||
|
||||
# 非公开路径
|
||||
if self.require_auth and not user:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"code": "4010",
|
||||
"message": "Authentication required",
|
||||
"data": None
|
||||
}
|
||||
)
|
||||
|
||||
# 继续处理
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class PermissionMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
权限检查中间件
|
||||
|
||||
基于路径进行权限检查。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
path_permissions: Optional[dict] = None
|
||||
):
|
||||
"""
|
||||
初始化权限中间件
|
||||
|
||||
Args:
|
||||
app: ASGI 应用
|
||||
path_permissions: 路径权限映射 {"/path": ["permission1", "permission2"]}
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.path_permissions = path_permissions or {}
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""处理请求"""
|
||||
path = request.url.path
|
||||
user: Optional[UserAuth] = getattr(request.state, "user", None)
|
||||
|
||||
# 检查路径权限
|
||||
for path_pattern, required_perms in self.path_permissions.items():
|
||||
if path.startswith(path_pattern):
|
||||
if not user:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"code": "4010",
|
||||
"message": "Authentication required",
|
||||
"data": None
|
||||
}
|
||||
)
|
||||
|
||||
if not user.is_superuser:
|
||||
user_perms = set(user.permissions)
|
||||
required = set(required_perms)
|
||||
|
||||
if not required.issubset(user_perms):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"code": "4030",
|
||||
"message": "Permission denied",
|
||||
"data": {"required": list(required)}
|
||||
}
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
基于 Redis 的速率限制中间件
|
||||
|
||||
为认证用户和匿名用户分别设置速率限制。
|
||||
使用 Redis 实现分布式速率限制。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: ASGIApp,
|
||||
authenticated_limit: int = 1000, # per minute
|
||||
anonymous_limit: int = 60, # per minute
|
||||
window_size: int = 60, # 1 minute window
|
||||
):
|
||||
super().__init__(app)
|
||||
self.authenticated_limit = authenticated_limit
|
||||
self.anonymous_limit = anonymous_limit
|
||||
self.window_size = window_size
|
||||
self._redis_client = None
|
||||
self._enabled = False
|
||||
|
||||
async def _get_redis(self):
|
||||
"""Lazy initialize Redis connection"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
if self._redis_client is None:
|
||||
try:
|
||||
import redis.asyncio as aioredis
|
||||
from src.config.settings import REDIS_URL
|
||||
self._redis_client = await aioredis.from_url(
|
||||
REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Redis not available for rate limiting: {e}")
|
||||
return None
|
||||
return self._redis_client
|
||||
|
||||
async def _check_rate_limit(self, key: str, limit: int) -> tuple[bool, int, int]:
|
||||
"""
|
||||
检查速率限制
|
||||
|
||||
Returns:
|
||||
(allowed, remaining, reset_after)
|
||||
"""
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
# Redis unavailable, allow request
|
||||
return True, limit, 0
|
||||
|
||||
try:
|
||||
current_time = int(time.time())
|
||||
window_start = current_time - self.window_size
|
||||
|
||||
# Remove old entries
|
||||
await redis.zremrangebyscore(key, 0, window_start)
|
||||
|
||||
# Count current requests in window
|
||||
current_count = await redis.zcard(key)
|
||||
|
||||
if current_count >= limit:
|
||||
# Rate limit exceeded
|
||||
oldest = await redis.zrange(key, 0, 0, withscores=True)
|
||||
reset_after = int(oldest[0][1]) + self.window_size - current_time
|
||||
return False, 0, max(reset_after, 1)
|
||||
|
||||
# Add current request
|
||||
await redis.zadd(key, {str(current_time): current_time})
|
||||
await redis.expire(key, self.window_size)
|
||||
|
||||
remaining = limit - current_count - 1
|
||||
return True, remaining, 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Rate limit check failed: {e}")
|
||||
# Fail open - allow request if Redis fails
|
||||
return True, limit, 0
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response:
|
||||
"""处理请求"""
|
||||
# Check if Redis is enabled
|
||||
if not self._enabled:
|
||||
from src.config.settings import REDIS_ENABLED
|
||||
self._enabled = REDIS_ENABLED
|
||||
|
||||
# Get user identifier
|
||||
user: Optional[UserAuth] = getattr(request.state, "user", None)
|
||||
if user:
|
||||
client_id = f"ratelimit:auth:{user.id}"
|
||||
limit = self.authenticated_limit
|
||||
else:
|
||||
# Use IP address for anonymous users
|
||||
client_ip = request.headers.get("X-Forwarded-For", request.client.host)
|
||||
client_id = f"ratelimit:anon:{client_ip}"
|
||||
limit = self.anonymous_limit
|
||||
|
||||
# Check rate limit
|
||||
allowed, remaining, reset_after = await self._check_rate_limit(client_id, limit)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Add rate limit headers
|
||||
if self._enabled:
|
||||
response.headers["X-RateLimit-Limit"] = str(limit)
|
||||
response.headers["X-RateLimit-Remaining"] = str(remaining)
|
||||
if reset_after > 0:
|
||||
response.headers["X-RateLimit-Reset-After"] = str(reset_after)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"code": "4290",
|
||||
"message": "Rate limit exceeded",
|
||||
"details": {
|
||||
"limit": limit,
|
||||
"reset_after": reset_after
|
||||
}
|
||||
},
|
||||
headers={
|
||||
"Retry-After": str(reset_after),
|
||||
"X-RateLimit-Limit": str(limit),
|
||||
"X-RateLimit-Remaining": "0"
|
||||
}
|
||||
)
|
||||
|
||||
return response
|
||||
91
backend/src/auth/models.py
Normal file
91
backend/src/auth/models.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
认证模型定义
|
||||
|
||||
包含用户认证相关的 Pydantic 模型。
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
"""Token 数据模型"""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
expires_in: int = Field(description="Access token 过期时间(秒)")
|
||||
refresh_expires_in: int = Field(description="Refresh token 过期时间(秒)")
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT Token Payload"""
|
||||
sub: str = Field(description="用户 ID")
|
||||
exp: int = Field(description="过期时间戳")
|
||||
iat: int = Field(description="签发时间戳")
|
||||
type: str = Field(default="access", description="Token 类型: access 或 refresh")
|
||||
scopes: List[str] = Field(default=[], description="权限范围")
|
||||
jti: Optional[str] = Field(default=None, description="JWT ID,用于 token 撤销")
|
||||
sid: Optional[str] = Field(default=None, description="会话 ID")
|
||||
sfid: Optional[str] = Field(default=None, description="会话族 ID")
|
||||
|
||||
|
||||
class TokenPair(BaseModel):
|
||||
"""Token 对(access + refresh)"""
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
session_id: Optional[str] = None
|
||||
session_family_id: Optional[str] = None
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
"""刷新 Token 请求"""
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class UserAuth(BaseModel):
|
||||
"""用户认证信息"""
|
||||
id: str
|
||||
username: str
|
||||
email: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
is_active: bool = True
|
||||
is_superuser: bool = False
|
||||
permissions: List[str] = []
|
||||
roles: List[str] = []
|
||||
created_at: Optional[float] = None
|
||||
last_login: Optional[float] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class UserLoginRequest(BaseModel):
|
||||
"""用户登录请求"""
|
||||
username: str
|
||||
password: str
|
||||
|
||||
|
||||
class UserRegisterRequest(BaseModel):
|
||||
"""用户注册请求"""
|
||||
username: str
|
||||
email: str
|
||||
password: str
|
||||
password_confirm: str
|
||||
|
||||
|
||||
class UserPasswordChangeRequest(BaseModel):
|
||||
"""用户修改密码请求"""
|
||||
current_password: str
|
||||
new_password: str
|
||||
new_password_confirm: str
|
||||
|
||||
|
||||
class OAuth2UserInfo(BaseModel):
|
||||
"""OAuth2 用户信息"""
|
||||
provider: str
|
||||
provider_user_id: str
|
||||
email: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
raw_data: Dict[str, Any] = {}
|
||||
16
backend/src/cache/__init__.py
vendored
Normal file
16
backend/src/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
缓存模块
|
||||
|
||||
提供多级缓存支持:
|
||||
- Redis 缓存
|
||||
- 内存缓存(开发环境)
|
||||
- 缓存装饰器
|
||||
"""
|
||||
|
||||
from src.cache.redis_cache import RedisCache, cache_result, cache_with_ttl
|
||||
|
||||
__all__ = [
|
||||
"RedisCache",
|
||||
"cache_result",
|
||||
"cache_with_ttl",
|
||||
]
|
||||
328
backend/src/cache/redis_cache.py
vendored
Normal file
328
backend/src/cache/redis_cache.py
vendored
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
Redis 缓存实现
|
||||
|
||||
提供缓存装饰器和缓存管理功能。
|
||||
"""
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import pickle
|
||||
from typing import Optional, Callable, Any, Union
|
||||
from datetime import timedelta
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio.client import Redis
|
||||
|
||||
from src.config.settings import REDIS_URL, REDIS_ENABLED
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisCache:
|
||||
"""
|
||||
Redis 缓存管理器
|
||||
|
||||
提供统一的缓存接口。
|
||||
"""
|
||||
|
||||
_instance: Optional["RedisCache"] = None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""单例模式"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self, redis_url: Optional[str] = None):
|
||||
if hasattr(self, "_initialized"):
|
||||
return
|
||||
|
||||
self._redis_url = redis_url or REDIS_URL
|
||||
self._redis: Optional[Redis] = None
|
||||
self._enabled = REDIS_ENABLED
|
||||
self._initialized = True
|
||||
|
||||
async def connect(self):
|
||||
"""建立 Redis 连接"""
|
||||
if not self._enabled:
|
||||
logger.info("Redis cache is disabled")
|
||||
return
|
||||
|
||||
try:
|
||||
self._redis = await redis.from_url(
|
||||
self._redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=False # 支持二进制数据
|
||||
)
|
||||
await self._redis.ping()
|
||||
logger.info(f"Connected to Redis cache at {self._redis_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Redis cache: {e}")
|
||||
self._enabled = False
|
||||
|
||||
async def disconnect(self):
|
||||
"""断开 Redis 连接"""
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
self._redis = None
|
||||
logger.info("Disconnected from Redis cache")
|
||||
|
||||
def _make_key(self, prefix: str, *args, **kwargs) -> str:
|
||||
"""
|
||||
生成缓存键
|
||||
|
||||
Args:
|
||||
prefix: 键前缀
|
||||
*args: 位置参数
|
||||
**kwargs: 关键字参数
|
||||
|
||||
Returns:
|
||||
缓存键
|
||||
"""
|
||||
key_parts = [prefix]
|
||||
|
||||
for arg in args:
|
||||
key_parts.append(str(arg))
|
||||
|
||||
for k, v in sorted(kwargs.items()):
|
||||
key_parts.append(f"{k}:{v}")
|
||||
|
||||
raw_key = "|".join(key_parts)
|
||||
# 使用哈希确保键长度合理
|
||||
return f"cache:{prefix}:{hashlib.md5(raw_key.encode()).hexdigest()}"
|
||||
|
||||
async def get(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
获取缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
缓存值或 None
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return None
|
||||
|
||||
try:
|
||||
data = await self._redis.get(key)
|
||||
if data:
|
||||
return pickle.loads(data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Cache get error: {e}")
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
ttl: Optional[Union[int, timedelta]] = None
|
||||
) -> bool:
|
||||
"""
|
||||
设置缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
value: 缓存值
|
||||
ttl: 过期时间(秒或 timedelta)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return False
|
||||
|
||||
try:
|
||||
expire = None
|
||||
if isinstance(ttl, timedelta):
|
||||
expire = int(ttl.total_seconds())
|
||||
elif isinstance(ttl, int):
|
||||
expire = ttl
|
||||
|
||||
data = pickle.dumps(value)
|
||||
await self._redis.set(key, data, ex=expire)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Cache set error: {e}")
|
||||
return False
|
||||
|
||||
async def delete(self, key: str) -> bool:
|
||||
"""
|
||||
删除缓存值
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._redis.delete(key)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Cache delete error: {e}")
|
||||
return False
|
||||
|
||||
async def delete_pattern(self, pattern: str) -> int:
|
||||
"""
|
||||
删除匹配模式的缓存
|
||||
|
||||
Args:
|
||||
pattern: 匹配模式
|
||||
|
||||
Returns:
|
||||
删除的键数量
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return 0
|
||||
|
||||
try:
|
||||
keys = await self._redis.keys(pattern)
|
||||
if keys:
|
||||
return await self._redis.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.error(f"Cache delete pattern error: {e}")
|
||||
return 0
|
||||
|
||||
async def exists(self, key: str) -> bool:
|
||||
"""
|
||||
检查键是否存在
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
是否存在
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return False
|
||||
|
||||
try:
|
||||
return await self._redis.exists(key) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Cache exists error: {e}")
|
||||
return False
|
||||
|
||||
async def ttl(self, key: str) -> int:
|
||||
"""
|
||||
获取键的剩余生存时间
|
||||
|
||||
Args:
|
||||
key: 缓存键
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1 表示永不过期,-2 表示不存在
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return -2
|
||||
|
||||
try:
|
||||
return await self._redis.ttl(key)
|
||||
except Exception as e:
|
||||
logger.error(f"Cache ttl error: {e}")
|
||||
return -2
|
||||
|
||||
async def clear(self) -> bool:
|
||||
"""
|
||||
清空所有缓存(谨慎使用)
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled or not self._redis:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._redis.flushdb()
|
||||
logger.warning("Cache cleared")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Cache clear error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局缓存实例
|
||||
cache = RedisCache()
|
||||
|
||||
|
||||
def cache_result(
|
||||
prefix: str,
|
||||
ttl: Optional[Union[int, timedelta]] = 300,
|
||||
key_func: Optional[Callable] = None
|
||||
):
|
||||
"""
|
||||
缓存装饰器
|
||||
|
||||
自动缓存函数结果。
|
||||
|
||||
Args:
|
||||
prefix: 缓存键前缀
|
||||
ttl: 过期时间(秒)
|
||||
key_func: 自定义键生成函数
|
||||
|
||||
Usage:
|
||||
@cache_result("user", ttl=300)
|
||||
async def get_user(user_id: str):
|
||||
return await fetch_user(user_id)
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
# 生成缓存键
|
||||
if key_func:
|
||||
cache_key = key_func(*args, **kwargs)
|
||||
else:
|
||||
cache_key = cache._make_key(prefix, *args, **kwargs)
|
||||
|
||||
# 尝试从缓存获取
|
||||
result = await cache.get(cache_key)
|
||||
if result is not None:
|
||||
logger.debug(f"Cache hit: {cache_key}")
|
||||
return result
|
||||
|
||||
# 执行函数
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# 存入缓存
|
||||
await cache.set(cache_key, result, ttl)
|
||||
logger.debug(f"Cache set: {cache_key}")
|
||||
|
||||
return result
|
||||
|
||||
@functools.wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
# 同步函数不支持缓存
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def cache_with_ttl(ttl: int = 300):
|
||||
"""
|
||||
简化版缓存装饰器
|
||||
|
||||
使用函数名作为前缀。
|
||||
|
||||
Args:
|
||||
ttl: 过期时间(秒)
|
||||
|
||||
Usage:
|
||||
@cache_with_ttl(300)
|
||||
async def get_user(user_id: str):
|
||||
return await fetch_user(user_id)
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
return cache_result(prefix=func.__name__, ttl=ttl)(func)
|
||||
return decorator
|
||||
|
||||
|
||||
# 导入 asyncio 用于检查协程
|
||||
import asyncio
|
||||
198
backend/src/config/database.py
Normal file
198
backend/src/config/database.py
Normal file
@@ -0,0 +1,198 @@
|
||||
from sqlmodel import create_engine, SQLModel, Session
|
||||
from sqlalchemy.pool import StaticPool, QueuePool
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import event
|
||||
from src.config.settings import DB_PATH, DATABASE_URL
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Slow query threshold in seconds
|
||||
SLOW_QUERY_THRESHOLD = float(os.getenv("SLOW_QUERY_THRESHOLD", "1.0"))
|
||||
|
||||
# Connection pool configuration from environment variables
|
||||
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20"))
|
||||
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "10"))
|
||||
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
|
||||
POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "3600"))
|
||||
POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
|
||||
|
||||
# Create the database engine with optimized connection pooling
|
||||
if DATABASE_URL:
|
||||
# PostgreSQL Configuration with configurable connection pool
|
||||
engine = create_engine(
|
||||
DATABASE_URL,
|
||||
echo=False,
|
||||
poolclass=QueuePool,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_timeout=POOL_TIMEOUT,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=POOL_PRE_PING,
|
||||
# Additional optimizations
|
||||
connect_args={
|
||||
"connect_timeout": 10,
|
||||
"options": "-c statement_timeout=120000" # 120 second statement timeout for AI operations
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
f"PostgreSQL engine created with pool_size={POOL_SIZE}, "
|
||||
f"max_overflow={MAX_OVERFLOW}, pool_timeout={POOL_TIMEOUT}s, "
|
||||
f"pool_recycle={POOL_RECYCLE}s, pool_pre_ping={POOL_PRE_PING}"
|
||||
)
|
||||
else:
|
||||
# SQLite Configuration (Fallback)
|
||||
engine = create_engine(
|
||||
f"sqlite:///{DB_PATH}",
|
||||
echo=False,
|
||||
connect_args={
|
||||
"check_same_thread": False,
|
||||
"timeout": 30 # 30 second busy timeout
|
||||
},
|
||||
poolclass=StaticPool, # Use StaticPool for SQLite
|
||||
)
|
||||
logger.info(f"SQLite engine created with StaticPool at {DB_PATH}")
|
||||
|
||||
|
||||
# Add slow query logging
|
||||
@event.listens_for(engine, "before_cursor_execute")
|
||||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
""" 记录 query start time."""
|
||||
conn.info.setdefault("query_start_time", []).append(time.time())
|
||||
|
||||
|
||||
@event.listens_for(engine, "after_cursor_execute")
|
||||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
""" 日志 slow queries."""
|
||||
total_time = time.time() - conn.info["query_start_time"].pop()
|
||||
|
||||
if total_time > SLOW_QUERY_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Slow query detected (took {total_time:.2f}s): {statement[:200]}",
|
||||
extra={
|
||||
"query_time": total_time,
|
||||
"query": statement[:500],
|
||||
"parameters": str(parameters)[:200] if parameters else None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def init_db():
|
||||
""" Initialize the database tables."""
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
# Try to create all tables
|
||||
try:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
except OperationalError as e:
|
||||
if "already exists" in str(e):
|
||||
# Table or index already exists, skip creation
|
||||
logger.warning(f"Skipping table/index creation: {e}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
def get_session():
|
||||
""" 依赖 for getting a database session.
|
||||
Ensures proper cleanup and connection management.
|
||||
"""
|
||||
with Session(engine) as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
# 会话 is automatically closed by context manager
|
||||
# This ensures connections are returned to the pool
|
||||
pass
|
||||
|
||||
|
||||
def get_pool_status() -> dict:
|
||||
""" Get current connection pool status for monitoring.
|
||||
Returns pool statistics including size, checked out connections, etc.
|
||||
"""
|
||||
pool = engine.pool
|
||||
|
||||
status = {
|
||||
"pool_size": getattr(pool, "size", lambda: 0)(),
|
||||
"checked_out": getattr(pool, "checkedout", lambda: 0)(),
|
||||
"overflow": getattr(pool, "overflow", lambda: 0)(),
|
||||
"checked_in": getattr(pool, "checkedin", lambda: 0)(),
|
||||
}
|
||||
|
||||
# Add pool-specific attributes if available
|
||||
if hasattr(pool, "_pool"):
|
||||
status["available"] = pool._pool.qsize()
|
||||
|
||||
return status
|
||||
|
||||
|
||||
def check_database_health() -> tuple[bool, str]:
|
||||
"""
|
||||
Check database connectivity and health.
|
||||
Returns (is_healthy, message).
|
||||
"""
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
# Simple connectivity probe without model dependency.
|
||||
session.exec(text("SELECT 1"))
|
||||
return True, "Database connection healthy"
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
return False, f"Database connection failed: {str(e)}"
|
||||
|
||||
|
||||
def check_pool_health() -> dict:
|
||||
"""
|
||||
Comprehensive pool health check with recommendations.
|
||||
Returns detailed pool status and health indicators.
|
||||
"""
|
||||
status = get_pool_status()
|
||||
pool = engine.pool
|
||||
|
||||
# Calculate health metrics
|
||||
total_connections = status.get("pool_size", 0) + status.get("overflow", 0)
|
||||
checked_out = status.get("checked_out", 0)
|
||||
available = status.get("available", 0)
|
||||
|
||||
# Health checks
|
||||
health = {
|
||||
"status": "healthy",
|
||||
"checks": {},
|
||||
"recommendations": [],
|
||||
"pool_status": status
|
||||
}
|
||||
|
||||
# Check 1: Connection exhaustion
|
||||
if checked_out >= total_connections * 0.9:
|
||||
health["checks"]["exhaustion"] = "critical"
|
||||
health["status"] = "critical"
|
||||
health["recommendations"].append(
|
||||
f"Pool near exhaustion: {checked_out}/{total_connections} connections in use. "
|
||||
f"Consider increasing DB_POOL_SIZE or DB_MAX_OVERFLOW."
|
||||
)
|
||||
elif checked_out >= total_connections * 0.75:
|
||||
health["checks"]["exhaustion"] = "warning"
|
||||
health["status"] = "warning"
|
||||
health["recommendations"].append(
|
||||
f"Pool usage high: {checked_out}/{total_connections} connections in use."
|
||||
)
|
||||
else:
|
||||
health["checks"]["exhaustion"] = "healthy"
|
||||
|
||||
# Check 2: Available connections
|
||||
if available == 0 and total_connections > 0:
|
||||
health["checks"]["availability"] = "warning"
|
||||
if health["status"] == "healthy":
|
||||
health["status"] = "warning"
|
||||
else:
|
||||
health["checks"]["availability"] = "healthy"
|
||||
|
||||
# Check 3: Pool size configuration
|
||||
if total_connections < 5:
|
||||
health["recommendations"].append(
|
||||
"Pool size may be too small for production workloads. "
|
||||
"Consider DB_POOL_SIZE >= 10"
|
||||
)
|
||||
|
||||
return health
|
||||
217
backend/src/config/database_async.py
Normal file
217
backend/src/config/database_async.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
异步数据库配置模块
|
||||
|
||||
提供 SQLAlchemy 2.0 AsyncSession 支持,用于完全异步的数据库操作。
|
||||
支持 PostgreSQL (asyncpg) 和 SQLite (aiosqlite)。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
create_async_engine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
AsyncEngine,
|
||||
)
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from src.config.settings import DB_PATH, DATABASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Slow query threshold in seconds
|
||||
SLOW_QUERY_THRESHOLD = float(os.getenv("SLOW_QUERY_THRESHOLD", "1.0"))
|
||||
|
||||
# Connection pool configuration from environment variables
|
||||
POOL_SIZE = int(os.getenv("DB_POOL_SIZE", "20"))
|
||||
MAX_OVERFLOW = int(os.getenv("DB_MAX_OVERFLOW", "10"))
|
||||
POOL_TIMEOUT = int(os.getenv("DB_POOL_TIMEOUT", "30"))
|
||||
POOL_RECYCLE = int(os.getenv("DB_POOL_RECYCLE", "3600"))
|
||||
POOL_PRE_PING = os.getenv("DB_POOL_PRE_PING", "true").lower() == "true"
|
||||
|
||||
# Global engine instance
|
||||
async_engine: AsyncEngine | None = None
|
||||
async_session_maker: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
def _get_async_url() -> str:
|
||||
"""Generate async database URL from settings."""
|
||||
if DATABASE_URL:
|
||||
# Convert PostgreSQL URL to async version
|
||||
if DATABASE_URL.startswith("postgresql://"):
|
||||
return DATABASE_URL.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
elif DATABASE_URL.startswith("postgresql+psycopg2://"):
|
||||
return DATABASE_URL.replace("postgresql+psycopg2://", "postgresql+asyncpg://", 1)
|
||||
return DATABASE_URL
|
||||
else:
|
||||
# SQLite async URL
|
||||
return f"sqlite+aiosqlite:///{DB_PATH}"
|
||||
|
||||
|
||||
def init_async_engine() -> AsyncEngine:
|
||||
"""Initialize async database engine with optimized connection pooling."""
|
||||
global async_engine
|
||||
|
||||
if async_engine is not None:
|
||||
return async_engine
|
||||
|
||||
async_url = _get_async_url()
|
||||
|
||||
if "postgresql" in async_url:
|
||||
# PostgreSQL async configuration
|
||||
async_engine = create_async_engine(
|
||||
async_url,
|
||||
echo=False,
|
||||
pool_size=POOL_SIZE,
|
||||
max_overflow=MAX_OVERFLOW,
|
||||
pool_timeout=POOL_TIMEOUT,
|
||||
pool_recycle=POOL_RECYCLE,
|
||||
pool_pre_ping=POOL_PRE_PING,
|
||||
connect_args={
|
||||
"timeout": 10,
|
||||
"command_timeout": 30,
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
f"Async PostgreSQL engine created with pool_size={POOL_SIZE}, "
|
||||
f"max_overflow={MAX_OVERFLOW}, pool_timeout={POOL_TIMEOUT}s"
|
||||
)
|
||||
else:
|
||||
# SQLite async configuration
|
||||
async_engine = create_async_engine(
|
||||
async_url,
|
||||
echo=False,
|
||||
poolclass=NullPool, # SQLite doesn't support connection pooling well
|
||||
connect_args={
|
||||
"timeout": 30,
|
||||
"check_same_thread": False,
|
||||
},
|
||||
)
|
||||
logger.info(f"Async SQLite engine created at {DB_PATH}")
|
||||
|
||||
return async_engine
|
||||
|
||||
|
||||
def init_async_session_maker() -> async_sessionmaker[AsyncSession]:
|
||||
"""Initialize async session maker."""
|
||||
global async_session_maker
|
||||
|
||||
if async_session_maker is not None:
|
||||
return async_session_maker
|
||||
|
||||
engine = init_async_engine()
|
||||
|
||||
async_session_maker = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autocommit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
return async_session_maker
|
||||
|
||||
|
||||
async def init_db_async():
|
||||
"""Initialize database tables asynchronously."""
|
||||
engine = init_async_engine()
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
logger.info("Database tables initialized (async)")
|
||||
|
||||
|
||||
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
"""
|
||||
Dependency for getting async database session.
|
||||
|
||||
Usage:
|
||||
@router.get("/items")
|
||||
async def get_items(session: AsyncSession = Depends(get_async_session)):
|
||||
...
|
||||
"""
|
||||
session_maker = init_async_session_maker()
|
||||
|
||||
async with session_maker() as session:
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
||||
async def get_async_session_context() -> AsyncSession:
|
||||
"""
|
||||
Get async session as context manager.
|
||||
|
||||
Usage:
|
||||
async with get_async_session_context() as session:
|
||||
result = await session.execute(...)
|
||||
"""
|
||||
session_maker = init_async_session_maker()
|
||||
return session_maker()
|
||||
|
||||
|
||||
async def get_pool_status_async() -> dict:
|
||||
"""
|
||||
Get current connection pool status for monitoring.
|
||||
Returns pool statistics including size, checked out connections, etc.
|
||||
"""
|
||||
engine = init_async_engine()
|
||||
pool = engine.pool
|
||||
|
||||
status = {
|
||||
"pool_size": getattr(pool, "size", lambda: 0)(),
|
||||
"checked_out": getattr(pool, "checkedout", lambda: 0)(),
|
||||
"overflow": getattr(pool, "overflow", lambda: 0)(),
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
|
||||
async def check_database_health_async() -> tuple[bool, str]:
|
||||
"""
|
||||
Check database connectivity and health asynchronously.
|
||||
Returns (is_healthy, message).
|
||||
"""
|
||||
try:
|
||||
engine = init_async_engine()
|
||||
async with engine.connect() as conn:
|
||||
start_time = time.time()
|
||||
await conn.execute(text("SELECT 1"))
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
return True, f"Database connection healthy ({elapsed:.3f}s)"
|
||||
except Exception as e:
|
||||
logger.error(f"Database health check failed: {e}")
|
||||
return False, f"Database connection failed: {str(e)}"
|
||||
|
||||
|
||||
async def close_async_engine():
|
||||
"""Close async engine and cleanup connections."""
|
||||
global async_engine, async_session_maker
|
||||
|
||||
if async_engine is not None:
|
||||
await async_engine.dispose()
|
||||
async_engine = None
|
||||
async_session_maker = None
|
||||
logger.info("Async database engine closed")
|
||||
|
||||
|
||||
# Import guards for sync operations
|
||||
async def migrate_sync_to_async():
|
||||
"""
|
||||
Migration helper to ensure both sync and async engines are initialized.
|
||||
During transition period, both can coexist.
|
||||
"""
|
||||
# Initialize async engine
|
||||
init_async_engine()
|
||||
init_async_session_maker()
|
||||
|
||||
# Sync engine is already initialized in database.py
|
||||
logger.info("Both sync and async database engines are ready")
|
||||
465
backend/src/config/generation_options.json
Normal file
465
backend/src/config/generation_options.json
Normal file
@@ -0,0 +1,465 @@
|
||||
{
|
||||
"image": {
|
||||
"aspectRatios": {
|
||||
"label": "比例",
|
||||
"options": [
|
||||
{ "value": "1:1", "label": "1:1 (正方形)" },
|
||||
{ "value": "3:4", "label": "3:4 (竖版)" },
|
||||
{ "value": "4:3", "label": "4:3 (横版)" },
|
||||
{ "value": "9:16", "label": "9:16 (手机竖版)" },
|
||||
{ "value": "16:9", "label": "16:9 (宽屏)" }
|
||||
],
|
||||
"default": "16:9"
|
||||
},
|
||||
"resolutions": {
|
||||
"label": "分辨率",
|
||||
"options": [
|
||||
{ "value": "1K", "label": "1K (标准)", "description": "1280x720 或 1024x1024" },
|
||||
{ "value": "2K", "label": "2K (高清)", "description": "2560x1440 或 2048x2048" },
|
||||
{ "value": "4K", "label": "4K (超高清)", "description": "3840x2160 或 4096x4096" }
|
||||
],
|
||||
"default": "1K"
|
||||
},
|
||||
"counts": {
|
||||
"label": "生成数量",
|
||||
"options": [
|
||||
{ "value": 1, "label": "1张" },
|
||||
{ "value": 2, "label": "2张" },
|
||||
{ "value": 3, "label": "3张" },
|
||||
{ "value": 4, "label": "4张" }
|
||||
],
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"templates": {
|
||||
"label": "模板类型",
|
||||
"options": [
|
||||
{ "value": "general", "label": "通用", "description": "通用图片生成" },
|
||||
{ "value": "character_white_bg", "label": "角色白底图", "description": "角色图片,白色背景" },
|
||||
{ "value": "character_three_view", "label": "角色三视图", "description": "角色正面、侧面、背面视图" },
|
||||
{ "value": "storyboard_integrated", "label": "分镜出图", "description": "基于分镜生成图片" }
|
||||
],
|
||||
"default": "general"
|
||||
}
|
||||
},
|
||||
"video": {
|
||||
"aspectRatios": {
|
||||
"label": "比例",
|
||||
"options": [
|
||||
{ "value": "1:1", "label": "1:1 (正方形)" },
|
||||
{ "value": "3:4", "label": "3:4 (竖版)" },
|
||||
{ "value": "4:3", "label": "4:3 (横版)" },
|
||||
{ "value": "9:16", "label": "9:16 (手机竖版)" },
|
||||
{ "value": "16:9", "label": "16:9 (宽屏)" }
|
||||
],
|
||||
"default": "16:9"
|
||||
},
|
||||
"resolutions": {
|
||||
"label": "分辨率",
|
||||
"options": [
|
||||
{ "value": "720p", "label": "720P (1280x720)", "width": 1280, "height": 720 },
|
||||
{ "value": "1080p", "label": "1080P (1920x1080)", "width": 1920, "height": 1080 }
|
||||
],
|
||||
"default": "1080p"
|
||||
},
|
||||
"durations": {
|
||||
"label": "时长",
|
||||
"options": [
|
||||
{ "value": "2s", "label": "2秒" },
|
||||
{ "value": "3s", "label": "3秒" },
|
||||
{ "value": "4s", "label": "4秒" },
|
||||
{ "value": "5s", "label": "5秒" },
|
||||
{ "value": "6s", "label": "6秒" },
|
||||
{ "value": "7s", "label": "7秒" },
|
||||
{ "value": "8s", "label": "8秒" },
|
||||
{ "value": "10s", "label": "10秒" }
|
||||
],
|
||||
"default": "5s",
|
||||
"description": "视频时长(秒)"
|
||||
},
|
||||
"counts": {
|
||||
"label": "生成数量",
|
||||
"options": [
|
||||
{ "value": 1, "label": "1个" },
|
||||
{ "value": 2, "label": "2个" },
|
||||
{ "value": 3, "label": "3个" },
|
||||
{ "value": 4, "label": "4个" }
|
||||
],
|
||||
"default": 1,
|
||||
"min": 1,
|
||||
"max": 4
|
||||
}
|
||||
},
|
||||
"audio": {
|
||||
"durations": {
|
||||
"label": "时长",
|
||||
"options": [
|
||||
{ "value": "5s", "label": "5秒" },
|
||||
{ "value": "10s", "label": "10秒" },
|
||||
{ "value": "30s", "label": "30秒" },
|
||||
{ "value": "60s", "label": "1分钟" }
|
||||
],
|
||||
"default": "10s"
|
||||
},
|
||||
"formats": {
|
||||
"label": "音频格式",
|
||||
"options": [
|
||||
{ "value": "mp3", "label": "MP3" },
|
||||
{ "value": "wav", "label": "WAV" }
|
||||
],
|
||||
"default": "mp3"
|
||||
}
|
||||
},
|
||||
"music": {
|
||||
"formats": {
|
||||
"label": "音频格式",
|
||||
"options": [
|
||||
{ "value": "mp3", "label": "MP3" },
|
||||
{ "value": "wav", "label": "WAV" },
|
||||
{ "value": "pcm", "label": "PCM" }
|
||||
],
|
||||
"default": "mp3"
|
||||
},
|
||||
"sampleRates": {
|
||||
"label": "采样率",
|
||||
"options": [
|
||||
{ "value": 16000, "label": "16kHz" },
|
||||
{ "value": 24000, "label": "24kHz" },
|
||||
{ "value": 32000, "label": "32kHz" },
|
||||
{ "value": 44100, "label": "44.1kHz" }
|
||||
],
|
||||
"default": 44100
|
||||
},
|
||||
"bitrates": {
|
||||
"label": "码率",
|
||||
"options": [
|
||||
{ "value": 32000, "label": "32kbps" },
|
||||
{ "value": 64000, "label": "64kbps" },
|
||||
{ "value": 128000, "label": "128kbps" },
|
||||
{ "value": 256000, "label": "256kbps" }
|
||||
],
|
||||
"default": 256000
|
||||
},
|
||||
"outputFormats": {
|
||||
"label": "返回格式",
|
||||
"options": [
|
||||
{ "value": "url", "label": "URL (24小时有效)" },
|
||||
{ "value": "hex", "label": "HEX" }
|
||||
],
|
||||
"default": "url"
|
||||
}
|
||||
},
|
||||
"script": {
|
||||
"movieTones": {
|
||||
"label": "整体基调",
|
||||
"options": [
|
||||
{ "value": "悬疑/惊悚", "label": "悬疑/惊悚 (Suspense/Thriller)", "en": "Suspense/Thriller" },
|
||||
{ "value": "古装/权谋", "label": "古装/权谋 (Historical/Political)", "en": "Historical/Political" },
|
||||
{ "value": "现代/都市", "label": "现代/都市 (Modern/Urban)", "en": "Modern/Urban" },
|
||||
{ "value": "科幻/未来", "label": "科幻/未来 (Sci-Fi/Future)", "en": "Sci-Fi/Future" },
|
||||
{ "value": "喜剧/荒诞", "label": "喜剧/荒诞 (Comedy/Absurd)", "en": "Comedy/Absurd" },
|
||||
{ "value": "动作/犯罪", "label": "动作/犯罪 (Action/Crime)", "en": "Action/Crime" },
|
||||
{ "value": "爱情/治愈", "label": "爱情/治愈 (Romance/Healing)", "en": "Romance/Healing" },
|
||||
{ "value": "奇幻/仙侠", "label": "奇幻/仙侠 (Fantasy/Xianxia)", "en": "Fantasy/Xianxia" },
|
||||
{ "value": "现实/人文", "label": "现实/人文 (Realistic/Humanistic)", "en": "Realistic/Humanistic" }
|
||||
],
|
||||
"default": "现代/都市"
|
||||
},
|
||||
"targetAudiences": {
|
||||
"label": "目标受众",
|
||||
"options": [
|
||||
{ "value": "全年龄段", "label": "全年龄段 (All Ages)", "en": "All Ages" },
|
||||
{ "value": "青少年", "label": "青少年 (Teenagers)", "en": "Teenagers" },
|
||||
{ "value": "年轻女性", "label": "年轻女性 (Young Women)", "en": "Young Women" },
|
||||
{ "value": "年轻男性", "label": "年轻男性 (Young Men)", "en": "Young Men" },
|
||||
{ "value": "成年观众", "label": "成年观众 (Adults)", "en": "Adults" },
|
||||
{ "value": "家庭观众", "label": "家庭观众 (Family)", "en": "Family" },
|
||||
{ "value": "资深影迷", "label": "资深影迷 (Cinephiles)", "en": "Cinephiles" }
|
||||
],
|
||||
"default": "全年龄段"
|
||||
}
|
||||
},
|
||||
"director": {
|
||||
"narrativeStyles": {
|
||||
"label": "叙事手法",
|
||||
"options": [
|
||||
{ "value": "线性叙事", "label": "线性叙事 (Linear)", "en": "Linear Narrative" },
|
||||
{ "value": "非线性/插叙", "label": "非线性/插叙 (Non-linear)", "en": "Non-linear Narrative" },
|
||||
{ "value": "多线并行", "label": "多线并行 (Multi-line)", "en": "Multi-line Narrative" },
|
||||
{ "value": "倒叙", "label": "倒叙 (Flashback)", "en": "Flashback" },
|
||||
{ "value": "意识流", "label": "意识流 (Stream of Consciousness)", "en": "Stream of Consciousness" },
|
||||
{ "value": "伪纪录片", "label": "伪纪录片 (Mockumentary)", "en": "Mockumentary" }
|
||||
],
|
||||
"default": "线性叙事"
|
||||
},
|
||||
"editingPaces": {
|
||||
"label": "剪辑节奏",
|
||||
"options": [
|
||||
{ "value": "缓慢/沉浸", "label": "缓慢/沉浸 (Slow/Immersive)", "en": "Slow/Immersive" },
|
||||
{ "value": "明快/流畅", "label": "明快/流畅 (Brisk/Fluid)", "en": "Brisk/Fluid" },
|
||||
{ "value": "快速/凌厉", "label": "快速/凌厉 (Fast/Sharp)", "en": "Fast/Sharp" },
|
||||
{ "value": "极速/碎片化", "label": "极速/碎片化 (Rapid/Fragmented)", "en": "Rapid/Fragmented" },
|
||||
{ "value": "舒缓/诗意", "label": "舒缓/诗意 (Soothing/Poetic)", "en": "Soothing/Poetic" }
|
||||
],
|
||||
"default": "明快/流畅"
|
||||
},
|
||||
"styles": {
|
||||
"label": "导演风格",
|
||||
"options": [
|
||||
{ "value": "孔笙 (Kong Sheng)", "label": "孔笙 (厚重/写实/山海情)", "en": "Kong Sheng Style" },
|
||||
{ "value": "郑晓龙 (Zheng Xiaolong)", "label": "郑晓龙 (宫廷/传奇/甄嬛传)", "en": "Zheng Xiaolong Style" },
|
||||
{ "value": "张黎 (Zhang Li)", "label": "张黎 (历史/权谋/大明王朝)", "en": "Zhang Li Style" },
|
||||
{ "value": "辛爽 (Xin Shuang)", "label": "辛爽 (悬疑/美学/漫长的季节)", "en": "Xin Shuang Style" },
|
||||
{ "value": "王家卫 (Wong Kar-wai)", "label": "王家卫 (繁花/光影/霓虹)", "en": "Wong Kar-wai Style" },
|
||||
{ "value": "李路 (Li Lu)", "label": "李路 (史诗/人世间/人民的名义)", "en": "Li Lu Style" },
|
||||
{ "value": "曹盾 (Cao Dun)", "label": "曹盾 (视觉/长镜头/长安十二时辰)", "en": "Cao Dun Style" },
|
||||
{ "value": "王伟 (Wang Wei)", "label": "王伟 (硬汉/刑侦/白夜追凶)", "en": "Wang Wei Style" },
|
||||
{ "value": "汪俊 (Wang Jun)", "label": "汪俊 (都市/细腻/小欢喜)", "en": "Wang Jun Style" },
|
||||
{ "value": "徐纪周 (Xu Jizhou)", "label": "徐纪周 (群像/狂飙/快节奏)", "en": "Xu Jizhou Style" },
|
||||
{ "value": "吕行 (Lu Xing)", "label": "吕行 (犯罪/人性/无证之罪)", "en": "Lu Xing Style" },
|
||||
{ "value": "丁黑 (Ding Hei)", "label": "丁黑 (警察荣誉/那年花开)", "en": "Ding Hei Style" }
|
||||
],
|
||||
"default": "孔笙 (Kong Sheng)"
|
||||
}
|
||||
},
|
||||
"character": {
|
||||
"genders": {
|
||||
"label": "性别",
|
||||
"options": [
|
||||
{ "value": "男", "label": "男" },
|
||||
{ "value": "女", "label": "女" },
|
||||
{ "value": "未知", "label": "未知" }
|
||||
],
|
||||
"default": "未知"
|
||||
},
|
||||
"costumeStyles": {
|
||||
"label": "服装风格",
|
||||
"options": [
|
||||
{ "value": "现代日常", "label": "现代日常 (Modern Daily)", "en": "Modern Daily" },
|
||||
{ "value": "古装汉服", "label": "古装汉服 (Ancient Hanfu)", "en": "Ancient Hanfu" },
|
||||
{ "value": "民国风情", "label": "民国风情 (Republic Era)", "en": "Republic Era" },
|
||||
{ "value": "赛博科幻", "label": "赛博科幻 (Cyberpunk Sci-Fi)", "en": "Cyberpunk Sci-Fi" },
|
||||
{ "value": "职业制服", "label": "职业制服 (Professional Uniform)", "en": "Professional Uniform" },
|
||||
{ "value": "街头潮流", "label": "街头潮流 (Streetwear)", "en": "Streetwear" },
|
||||
{ "value": "极简森系", "label": "极简森系 (Minimalist Mori)", "en": "Minimalist Mori" },
|
||||
{ "value": "奢华礼服", "label": "奢华礼服 (Luxury/Formal)", "en": "Luxury/Formal" }
|
||||
],
|
||||
"default": "现代日常"
|
||||
},
|
||||
"roles": {
|
||||
"label": "角色定位",
|
||||
"options": [
|
||||
{ "value": "主角", "label": "主角" },
|
||||
{ "value": "配角", "label": "配角" },
|
||||
{ "value": "反派", "label": "反派" },
|
||||
{ "value": "龙套", "label": "龙套" },
|
||||
{ "value": "群演", "label": "群演" }
|
||||
],
|
||||
"default": "配角"
|
||||
},
|
||||
"emotions": {
|
||||
"label": "情绪基调",
|
||||
"options": [
|
||||
{ "value": "平静", "label": "平静 (Neutral)", "en": "Neutral/Calm" },
|
||||
{ "value": "喜悦", "label": "喜悦 (Happy)", "en": "Happy/Joyful" },
|
||||
{ "value": "悲伤", "label": "悲伤 (Sad)", "en": "Sad/Sorrowful" },
|
||||
{ "value": "愤怒", "label": "愤怒 (Angry)", "en": "Angry/Furious" },
|
||||
{ "value": "恐惧", "label": "恐惧 (Fearful)", "en": "Fearful/Scared" },
|
||||
{ "value": "惊讶", "label": "惊讶 (Surprised)", "en": "Surprised/Shocked" },
|
||||
{ "value": "自信", "label": "自信 (Confident)", "en": "Confident/Bold" },
|
||||
{ "value": "思索", "label": "思索 (Thinking)", "en": "Thinking/Pensive" }
|
||||
],
|
||||
"default": "平静"
|
||||
}
|
||||
},
|
||||
"storyboard": {
|
||||
"shotTypes": {
|
||||
"label": "镜头类型",
|
||||
"options": [
|
||||
{ "value": "大远景 (ELS)", "label": "大远景 (ELS)", "en": "Extreme Long Shot (ELS)" },
|
||||
{ "value": "远景 (LS)", "label": "远景 (LS)", "en": "Long Shot (LS)" },
|
||||
{ "value": "全景 (FS)", "label": "全景 (FS)", "en": "Full Shot (FS)" },
|
||||
{ "value": "中远景 (MLS)", "label": "中远景 (MLS)", "en": "Medium Long Shot (MLS)" },
|
||||
{ "value": "中景 (MS)", "label": "中景 (MS)", "en": "Medium Shot (MS)" },
|
||||
{ "value": "中特写 (MCU)", "label": "中特写 (MCU)", "en": "Medium Close-Up (MCU)" },
|
||||
{ "value": "特写 (CU)", "label": "特写 (CU)", "en": "Close-Up (CU)" },
|
||||
{ "value": "大特写 (ECU)", "label": "大特写 (ECU)", "en": "Extreme Close-Up (ECU)" },
|
||||
{ "value": "建立镜头", "label": "建立镜头", "en": "Establishing Shot" },
|
||||
{ "value": "主观镜头 (POV)", "label": "主观镜头 (POV)", "en": "Point of View (POV)" },
|
||||
{ "value": "过肩镜头 (OTS)", "label": "过肩镜头 (OTS)", "en": "Over the Shoulder (OTS)" }
|
||||
],
|
||||
"default": "中景 (MS)"
|
||||
},
|
||||
"cameraMovements": {
|
||||
"label": "运镜方式",
|
||||
"options": [
|
||||
{ "value": "固定镜头 (Static)", "label": "固定镜头 (Static)", "en": "Static" },
|
||||
{ "value": "左摇 (Pan Left)", "label": "左摇 (Pan Left)", "en": "Pan Left" },
|
||||
{ "value": "右摇 (Pan Right)", "label": "右摇 (Pan Right)", "en": "Pan Right" },
|
||||
{ "value": "上仰 (Tilt Up)", "label": "上仰 (Tilt Up)", "en": "Tilt Up" },
|
||||
{ "value": "下俯 (Tilt Down)", "label": "下俯 (Tilt Down)", "en": "Tilt Down" },
|
||||
{ "value": "推镜头 (Zoom In)", "label": "推镜头 (Zoom In)", "en": "Zoom In" },
|
||||
{ "value": "拉镜头 (Zoom Out)", "label": "拉镜头 (Zoom Out)", "en": "Zoom Out" },
|
||||
{ "value": "前移 (Dolly In)", "label": "前移 (Dolly In)", "en": "Dolly In" },
|
||||
{ "value": "后移 (Dolly Out)", "label": "后移 (Dolly Out)", "en": "Dolly Out" },
|
||||
{ "value": "跟随 (Tracking)", "label": "跟随 (Tracking)", "en": "Tracking" },
|
||||
{ "value": "环绕 (Arc)", "label": "环绕 (Arc)", "en": "Arc" },
|
||||
{ "value": "手持 (Handheld)", "label": "手持 (Handheld)", "en": "Handheld" }
|
||||
],
|
||||
"default": "固定镜头 (Static)"
|
||||
},
|
||||
"transitions": {
|
||||
"label": "转场效果",
|
||||
"options": [
|
||||
{ "value": "切 (Cut)", "label": "切 (Cut)", "en": "Cut" },
|
||||
{ "value": "叠化 (Dissolve)", "label": "叠化 (Dissolve)", "en": "Dissolve" },
|
||||
{ "value": "淡入 (Fade In)", "label": "淡入 (Fade In)", "en": "Fade In" },
|
||||
{ "value": "淡出 (Fade Out)", "label": "淡出 (Fade Out)", "en": "Fade Out" },
|
||||
{ "value": "划像 (Wipe)", "label": "划像 (Wipe)", "en": "Wipe" },
|
||||
{ "value": "圈入 (Iris In)", "label": "圈入 (Iris In)", "en": "Iris In" },
|
||||
{ "value": "圈出 (Iris Out)", "label": "圈出 (Iris Out)", "en": "Iris Out" },
|
||||
{ "value": "匹配剪辑 (Match Cut)", "label": "匹配剪辑 (Match Cut)", "en": "Match Cut" },
|
||||
{ "value": "跳接 (Jump Cut)", "label": "跳接 (Jump Cut)", "en": "Jump Cut" }
|
||||
],
|
||||
"default": "切 (Cut)"
|
||||
},
|
||||
"compositions": {
|
||||
"label": "构图方式",
|
||||
"options": [
|
||||
{ "value": "三分法 (Rule of Thirds)", "label": "三分法 (Rule of Thirds)", "en": "Rule of Thirds" },
|
||||
{ "value": "中心构图 (Center Framed)", "label": "中心构图 (Center Framed)", "en": "Center Framed" },
|
||||
{ "value": "对称构图 (Symmetrical)", "label": "对称构图 (Symmetrical)", "en": "Symmetrical" },
|
||||
{ "value": "引导线 (Leading Lines)", "label": "引导线 (Leading Lines)", "en": "Leading Lines" },
|
||||
{ "value": "对角线 (Diagonal)", "label": "对角线 (Diagonal)", "en": "Diagonal" },
|
||||
{ "value": "框架构图 (Framing)", "label": "框架构图 (Framing)", "en": "Framing" },
|
||||
{ "value": "极简留白 (Minimalist/Negative Space)", "label": "极简留白 (Minimalist/Negative Space)", "en": "Minimalist/Negative Space" },
|
||||
{ "value": "黄金螺旋 (Golden Spiral)", "label": "黄金螺旋 (Golden Spiral)", "en": "Golden Spiral" }
|
||||
],
|
||||
"default": "三分法 (Rule of Thirds)"
|
||||
}
|
||||
},
|
||||
"scene": {
|
||||
"timesOfDay": {
|
||||
"label": "拍摄时段",
|
||||
"options": [
|
||||
{ "value": "清晨", "label": "清晨 (Dawn)", "en": "Dawn" },
|
||||
{ "value": "早晨", "label": "早晨 (Morning)", "en": "Morning" },
|
||||
{ "value": "正午", "label": "正午 (Noon)", "en": "Noon" },
|
||||
{ "value": "下午", "label": "下午 (Afternoon)", "en": "Afternoon" },
|
||||
{ "value": "黄金时刻", "label": "黄金时刻 (Golden Hour)", "en": "Golden Hour" },
|
||||
{ "value": "傍晚", "label": "傍晚 (Dusk)", "en": "Dusk" },
|
||||
{ "value": "蓝调时刻", "label": "蓝调时刻 (Blue Hour)", "en": "Blue Hour" },
|
||||
{ "value": "夜晚", "label": "夜晚 (Night)", "en": "Night" },
|
||||
{ "value": "深夜", "label": "深夜 (Late Night)", "en": "Late Night" }
|
||||
],
|
||||
"default": "早晨"
|
||||
},
|
||||
"environmentTypes": {
|
||||
"label": "空间类型",
|
||||
"options": [
|
||||
{ "value": "室内", "label": "室内 (Interior)", "en": "Interior" },
|
||||
{ "value": "室外", "label": "室外 (Exterior)", "en": "Exterior" },
|
||||
{ "value": "半室外", "label": "半室外 (Semi-Exterior)", "en": "Semi-Exterior" }
|
||||
],
|
||||
"default": "室内"
|
||||
},
|
||||
"weather": {
|
||||
"label": "天气环境",
|
||||
"options": [
|
||||
{ "value": "晴朗", "label": "晴朗 (Clear)", "en": "Clear/Sunny" },
|
||||
{ "value": "多云", "label": "多云 (Cloudy)", "en": "Partly Cloudy" },
|
||||
{ "value": "阴天", "label": "阴天 (Overcast)", "en": "Overcast" },
|
||||
{ "value": "小雨", "label": "小雨 (Rainy)", "en": "Light Rain" },
|
||||
{ "value": "大雨", "label": "大雨 (Heavy Rain)", "en": "Heavy Rain" },
|
||||
{ "value": "暴风雨", "label": "暴风雨 (Storm)", "en": "Storm" },
|
||||
{ "value": "小雪", "label": "小雪 (Snowy)", "en": "Light Snow" },
|
||||
{ "value": "大雪", "label": "大雪 (Heavy Snow)", "en": "Heavy Snow" },
|
||||
{ "value": "大雾", "label": "大雾 (Foggy)", "en": "Dense Fog" },
|
||||
{ "value": "沙尘", "label": "沙尘 (Dusty)", "en": "Dusty/Sandstorm" }
|
||||
],
|
||||
"default": "晴朗"
|
||||
}
|
||||
},
|
||||
"cinematic": {
|
||||
"visualStyles": {
|
||||
"label": "视觉画风",
|
||||
"options": [
|
||||
{ "value": "现实主义/纪录片感", "label": "现实主义/纪录片感 (Realistic/Documentary)", "en": "Realistic/Documentary" },
|
||||
{ "value": "电影质感/胶片风", "label": "电影质感/胶片风 (Cinematic/Film)", "en": "Cinematic/Film" },
|
||||
{ "value": "赛博朋克/霓虹", "label": "赛博朋克/霓虹 (Cyberpunk/Neon)", "en": "Cyberpunk/Neon" },
|
||||
{ "value": "古风/水墨", "label": "古风/水墨 (Ancient/Ink Wash)", "en": "Ancient/Ink Wash" },
|
||||
{ "value": "极简主义/冷淡", "label": "极简主义/冷淡 (Minimalist/Cold)", "en": "Minimalist/Cold" },
|
||||
{ "value": "唯美/梦幻", "label": "唯美/梦幻 (Aesthetic/Dreamy)", "en": "Aesthetic/Dreamy" },
|
||||
{ "value": "暗黑/哥特", "label": "暗黑/哥特 (Dark/Gothic)", "en": "Dark/Gothic" },
|
||||
{ "value": "动画/二次元", "label": "动画/二次元 (Anime/2D)", "en": "Anime/2D" }
|
||||
],
|
||||
"default": "电影质感/胶片风"
|
||||
},
|
||||
"cameraAngles": {
|
||||
"label": "镜头角度",
|
||||
"options": [
|
||||
{ "value": "平视", "label": "平视 (Eye Level)", "en": "Eye Level" },
|
||||
{ "value": "俯拍", "label": "俯拍 (High Angle)", "en": "High Angle" },
|
||||
{ "value": "仰拍", "label": "仰拍 (Low Angle)", "en": "Low Angle" },
|
||||
{ "value": "顶拍", "label": "顶拍 (Top Down)", "en": "Top Down/Bird's Eye" },
|
||||
{ "value": "底拍", "label": "底拍 (Worm's Eye)", "en": "Worm's Eye View" },
|
||||
{ "value": "斜角镜头", "label": "斜角镜头 (Dutch Angle)", "en": "Dutch Angle" }
|
||||
],
|
||||
"default": "平视"
|
||||
},
|
||||
"lighting": {
|
||||
"label": "灯光风格",
|
||||
"options": [
|
||||
{ "value": "电影感光效", "label": "电影感 (Cinematic)", "en": "Cinematic Lighting" },
|
||||
{ "value": "自然光", "label": "自然光 (Natural)", "en": "Natural Lighting" },
|
||||
{ "value": "柔光", "label": "柔光 (Soft)", "en": "Soft Lighting" },
|
||||
{ "value": "强反差/明暗对照", "label": "强反差 (High Contrast)", "en": "Chiaroscuro/High Contrast" },
|
||||
{ "value": "轮廓光", "label": "轮廓光 (Rim Lighting)", "en": "Rim Lighting" },
|
||||
{ "value": "三点式亮光", "label": "三点式亮光 (Three-point Lighting)", "en": "Three-point Lighting" },
|
||||
{ "value": "工作室光效", "label": "工作室 (Studio)", "en": "Studio Lighting" }
|
||||
],
|
||||
"default": "电影感光效"
|
||||
},
|
||||
"colorStyle": {
|
||||
"label": "色调氛围",
|
||||
"options": [
|
||||
{ "value": "电影感", "label": "电影感 (Cinematic)", "en": "Cinematic" },
|
||||
{ "value": "暖色调", "label": "暖色调 (Warm Tones)", "en": "Warm Tones" },
|
||||
{ "value": "冷色调", "label": "冷色调 (Cold Tones)", "en": "Cold Tones" },
|
||||
{ "value": "黑白", "label": "黑白 (Black and White)", "en": "Black and White" },
|
||||
{ "value": "复古/胶片感", "label": "复古/胶片 (Vintage)", "en": "Vintage Film" },
|
||||
{ "value": "赛博朋克", "label": "赛博朋克 (Cyberpunk)", "en": "Cyberpunk" },
|
||||
{ "value": "高饱和", "label": "高饱和 (Vibrant)", "en": "Vibrant" },
|
||||
{ "value": "低饱和/灰色调", "label": "低饱和 (Muted)", "en": "Muted/Desaturated" }
|
||||
],
|
||||
"default": "电影感"
|
||||
},
|
||||
"lenses": {
|
||||
"label": "镜头焦距",
|
||||
"options": [
|
||||
{ "value": "广角镜头", "label": "超广角 (14-24mm)", "en": "Ultra Wide Lens" },
|
||||
{ "value": "标准广角", "label": "广角 (35mm)", "en": "Wide Angle Lens" },
|
||||
{ "value": "标准镜头", "label": "标准 (50mm)", "en": "Standard/Nifty Fifty" },
|
||||
{ "value": "人像镜头", "label": "长焦 (85mm)", "en": "Portrait/Short Telephoto" },
|
||||
{ "value": "远摄镜头", "label": "超长焦 (200mm+)", "en": "Telephoto Lens" },
|
||||
{ "value": "鱼眼镜头", "label": "鱼眼 (Fisheye)", "en": "Fisheye Lens" },
|
||||
{ "value": "变焦镜头", "label": "变焦 (Anamorphic)", "en": "Anamorphic Lens" }
|
||||
],
|
||||
"default": "标准镜头"
|
||||
},
|
||||
"focus": {
|
||||
"label": "焦点控制",
|
||||
"options": [
|
||||
{ "value": "自动对焦", "label": "自动对焦 (Auto)", "en": "Auto Focus" },
|
||||
{ "value": "浅景深/虚化", "label": "浅景深 (Shallow Bokeh)", "en": "Shallow Depth of Field" },
|
||||
{ "value": "大深景", "label": "大深景 (Deep Focus)", "en": "Deep Depth of Field" },
|
||||
{ "value": "焦点转移", "label": "变焦对焦 (Rack Focus)", "en": "Rack Focus" },
|
||||
{ "value": "微距焦点", "label": "微距 (Macro)", "en": "Macro Focus" },
|
||||
{ "value": "移轴效果", "label": "移轴 (Tilt-Shift)", "en": "Tilt-Shift" }
|
||||
],
|
||||
"default": "自动对焦"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
16
backend/src/config/script_agent_config.json
Normal file
16
backend/src/config/script_agent_config.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"roles": {
|
||||
"story_architect": "moonshot-v1-8k",
|
||||
"character_consultant": "moonshot-v1-8k",
|
||||
"scriptwriter": "moonshot-v1-8k",
|
||||
"director": "moonshot-v1-8k",
|
||||
"auditor": "moonshot-v1-8k",
|
||||
"moderator": "moonshot-v1-8k",
|
||||
"psychologist": "moonshot-v1-8k",
|
||||
"visualizer": "moonshot-v1-8k",
|
||||
"continuity_manager": "moonshot-v1-8k",
|
||||
"showrunner": "moonshot-v1-8k",
|
||||
"chief_editor": "moonshot-v1-8k",
|
||||
"specialist": "moonshot-v1-8k"
|
||||
}
|
||||
}
|
||||
5
backend/src/config/services/alibaba/provider.json
Normal file
5
backend/src/config/services/alibaba/provider.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"id": "aliyun",
|
||||
"name": "阿里云",
|
||||
"description": "阿里云提供的包括图像超分、视频超分等服务"
|
||||
}
|
||||
8
backend/src/config/services/alibaba/upscale.json
Normal file
8
backend/src/config/services/alibaba/upscale.json
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"videoenhan": {
|
||||
"name": "阿里巴巴图像超分",
|
||||
"class": "backend.src.services.post_process.super_resolution.SuperResolutionService",
|
||||
"args": [],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/anthropic/provider.json
Normal file
16
backend/src/config/services/anthropic/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "anthropic",
|
||||
"name": "Anthropic",
|
||||
"description": "Claude 系列模型",
|
||||
"dashboard_url": "https://console.anthropic.com/settings/keys",
|
||||
"helpUrl": "https://console.anthropic.com/settings/keys",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-ant-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
488
backend/src/config/services/dashscope/audio.json
Normal file
488
backend/src/config/services/dashscope/audio.json
Normal file
@@ -0,0 +1,488 @@
|
||||
{
|
||||
"cosyvoice-v3-plus": {
|
||||
"name": "CosyVoice-V3-Plus",
|
||||
"class": "backend.src.services.provider.dashscope.audio.DashScopeAudioService",
|
||||
"args": ["cosyvoice-v3-plus"],
|
||||
"voices": [
|
||||
{
|
||||
"id": "longanyang",
|
||||
"name": "龙安洋",
|
||||
"gender": "male",
|
||||
"desc": "阳光大男孩"
|
||||
},
|
||||
{
|
||||
"id": "longanhuan",
|
||||
"name": "龙安欢",
|
||||
"gender": "female",
|
||||
"desc": "欢脱元气女"
|
||||
},
|
||||
{
|
||||
"id": "longhuhu_v3",
|
||||
"name": "龙呼呼",
|
||||
"gender": "female",
|
||||
"desc": "天真烂漫女童"
|
||||
},
|
||||
{
|
||||
"id": "longpaopao_v3",
|
||||
"name": "龙泡泡",
|
||||
"gender": "female",
|
||||
"desc": "飞天泡泡音"
|
||||
},
|
||||
{
|
||||
"id": "longjielidou_v3",
|
||||
"name": "龙杰力豆",
|
||||
"gender": "male",
|
||||
"desc": "阳光顽皮男童"
|
||||
},
|
||||
{
|
||||
"id": "longxian_v3",
|
||||
"name": "龙仙",
|
||||
"gender": "female",
|
||||
"desc": "豪放可爱女童"
|
||||
},
|
||||
{
|
||||
"id": "longling_v3",
|
||||
"name": "龙铃",
|
||||
"gender": "female",
|
||||
"desc": "稚气呆板女童"
|
||||
},
|
||||
{
|
||||
"id": "longshanshan_v3",
|
||||
"name": "龙闪闪",
|
||||
"gender": "female",
|
||||
"desc": "戏剧化童声"
|
||||
},
|
||||
{
|
||||
"id": "longniuniu_v3",
|
||||
"name": "龙牛牛",
|
||||
"gender": "male",
|
||||
"desc": "阳光男童声"
|
||||
},
|
||||
{
|
||||
"id": "longjiaxin_v3",
|
||||
"name": "龙嘉欣",
|
||||
"gender": "female",
|
||||
"desc": "优雅粤语女"
|
||||
},
|
||||
{
|
||||
"id": "longjiayi_v3",
|
||||
"name": "龙嘉怡",
|
||||
"gender": "female",
|
||||
"desc": "知性粤语女"
|
||||
},
|
||||
{
|
||||
"id": "longanyue_v3",
|
||||
"name": "龙安粤",
|
||||
"gender": "male",
|
||||
"desc": "欢脱粤语男"
|
||||
},
|
||||
{
|
||||
"id": "longlaotie_v3",
|
||||
"name": "龙老铁",
|
||||
"gender": "male",
|
||||
"desc": "东北直率男"
|
||||
},
|
||||
{
|
||||
"id": "longshange_v3",
|
||||
"name": "龙陕哥",
|
||||
"gender": "male",
|
||||
"desc": "原味陕北男"
|
||||
},
|
||||
{
|
||||
"id": "longanmin_v3",
|
||||
"name": "龙安闽",
|
||||
"gender": "female",
|
||||
"desc": "清纯闽南女"
|
||||
},
|
||||
{
|
||||
"id": "loongkyong_v3",
|
||||
"name": "loongkyong",
|
||||
"gender": "female",
|
||||
"desc": "韩语女"
|
||||
},
|
||||
{
|
||||
"id": "loongriko_v3",
|
||||
"name": "Riko",
|
||||
"gender": "female",
|
||||
"desc": "二次元日语女"
|
||||
},
|
||||
{
|
||||
"id": "loongtomoka_v3",
|
||||
"name": "loongtomoka",
|
||||
"gender": "female",
|
||||
"desc": "日语女"
|
||||
},
|
||||
{
|
||||
"id": "longfei_v3",
|
||||
"name": "龙飞",
|
||||
"gender": "male",
|
||||
"desc": "热血磁性男"
|
||||
},
|
||||
{
|
||||
"id": "longxiaochun_v3",
|
||||
"name": "龙小淳",
|
||||
"gender": "female",
|
||||
"desc": "清丽温柔女"
|
||||
},
|
||||
{
|
||||
"id": "longxiaoxia_v3",
|
||||
"name": "龙小夏",
|
||||
"gender": "female",
|
||||
"desc": "活泼甜美女"
|
||||
},
|
||||
{
|
||||
"id": "longshu_v3",
|
||||
"name": "龙舒",
|
||||
"gender": "female",
|
||||
"desc": "知性温婉女"
|
||||
},
|
||||
{
|
||||
"id": "longyue_v3",
|
||||
"name": "龙悦",
|
||||
"gender": "male",
|
||||
"desc": "阳光青年男"
|
||||
},
|
||||
{
|
||||
"id": "longcheng_v3",
|
||||
"name": "龙城",
|
||||
"gender": "male",
|
||||
"desc": "成熟稳重男"
|
||||
},
|
||||
{
|
||||
"id": "longhua_v3",
|
||||
"name": "龙华",
|
||||
"gender": "male",
|
||||
"desc": "标准男声"
|
||||
},
|
||||
{
|
||||
"id": "longwan_v3",
|
||||
"name": "龙婉",
|
||||
"gender": "female",
|
||||
"desc": "温婉知性女"
|
||||
},
|
||||
{
|
||||
"id": "longjing_v3",
|
||||
"name": "龙静",
|
||||
"gender": "female",
|
||||
"desc": "标准女声"
|
||||
},
|
||||
{
|
||||
"id": "longmiao_v3",
|
||||
"name": "龙淼",
|
||||
"gender": "female",
|
||||
"desc": "标准女声"
|
||||
},
|
||||
{
|
||||
"id": "longshuo_v3",
|
||||
"name": "龙硕",
|
||||
"gender": "male",
|
||||
"desc": "标准男声"
|
||||
},
|
||||
{
|
||||
"id": "longxiang_v3",
|
||||
"name": "龙翔",
|
||||
"gender": "male",
|
||||
"desc": "标准男声"
|
||||
},
|
||||
{
|
||||
"id": "longyuan_v3",
|
||||
"name": "龙源",
|
||||
"gender": "male",
|
||||
"desc": "标准男声"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
},
|
||||
"qwen3-tts-flash": {
|
||||
"name": "Qwen3-TTS-Flash",
|
||||
"class": "backend.src.services.provider.dashscope.audio.DashScopeAudioService",
|
||||
"args": ["qwen3-tts-flash"],
|
||||
"voices": [
|
||||
{
|
||||
"id": "Cherry",
|
||||
"name": "芊悦",
|
||||
"gender": "female",
|
||||
"desc": "阳光积极、亲切自然小姐姐"
|
||||
},
|
||||
{
|
||||
"id": "Serena",
|
||||
"name": "苏瑶",
|
||||
"gender": "female",
|
||||
"desc": "温柔小姐姐"
|
||||
},
|
||||
{
|
||||
"id": "Ethan",
|
||||
"name": "晨煦",
|
||||
"gender": "male",
|
||||
"desc": "阳光、温暖、活力、朝气"
|
||||
},
|
||||
{
|
||||
"id": "Chelsie",
|
||||
"name": "千雪",
|
||||
"gender": "female",
|
||||
"desc": "二次元虚拟女友"
|
||||
},
|
||||
{
|
||||
"id": "Momo",
|
||||
"name": "茉兔",
|
||||
"gender": "female",
|
||||
"desc": "撒娇搞怪,逗你开心"
|
||||
},
|
||||
{
|
||||
"id": "Vivian",
|
||||
"name": "十三",
|
||||
"gender": "female",
|
||||
"desc": "拽拽的、可爱的小暴躁"
|
||||
},
|
||||
{ "id": "Moon", "name": "月白", "gender": "male", "desc": "率性帅气" },
|
||||
{
|
||||
"id": "Maia",
|
||||
"name": "四月",
|
||||
"gender": "female",
|
||||
"desc": "知性与温柔的碰撞"
|
||||
},
|
||||
{ "id": "Kai", "name": "凯", "gender": "male", "desc": "耳朵的一场SPA" },
|
||||
{
|
||||
"id": "Nofish",
|
||||
"name": "不吃鱼",
|
||||
"gender": "male",
|
||||
"desc": "不会翘舌音的设计师"
|
||||
},
|
||||
{
|
||||
"id": "Bella",
|
||||
"name": "萌宝",
|
||||
"gender": "female",
|
||||
"desc": "喝酒不打醉拳的小萝莉"
|
||||
},
|
||||
{
|
||||
"id": "Jennifer",
|
||||
"name": "詹妮弗",
|
||||
"gender": "female",
|
||||
"desc": "品牌级、电影质感般美语女声"
|
||||
},
|
||||
{
|
||||
"id": "Ryan",
|
||||
"name": "甜茶",
|
||||
"gender": "male",
|
||||
"desc": "节奏拉满,戏感炸裂"
|
||||
},
|
||||
{
|
||||
"id": "Katerina",
|
||||
"name": "卡捷琳娜",
|
||||
"gender": "female",
|
||||
"desc": "御姐音色,韵律回味十足"
|
||||
},
|
||||
{
|
||||
"id": "Aiden",
|
||||
"name": "艾登",
|
||||
"gender": "male",
|
||||
"desc": "精通厨艺的美语大男孩"
|
||||
},
|
||||
{
|
||||
"id": "Eldric Sage",
|
||||
"name": "沧明子",
|
||||
"gender": "male",
|
||||
"desc": "沉稳睿智的老者"
|
||||
},
|
||||
{
|
||||
"id": "Mia",
|
||||
"name": "乖小妹",
|
||||
"gender": "female",
|
||||
"desc": "温顺如春水,乖巧如初雪"
|
||||
},
|
||||
{
|
||||
"id": "Mochi",
|
||||
"name": "沙小弥",
|
||||
"gender": "male",
|
||||
"desc": "聪明伶俐的小大人"
|
||||
},
|
||||
{
|
||||
"id": "Bellona",
|
||||
"name": "燕铮莺",
|
||||
"gender": "female",
|
||||
"desc": "金戈铁马,千面人声"
|
||||
},
|
||||
{
|
||||
"id": "Vincent",
|
||||
"name": "田叔",
|
||||
"gender": "male",
|
||||
"desc": "沙哑烟嗓,江湖豪情"
|
||||
},
|
||||
{
|
||||
"id": "Bunny",
|
||||
"name": "萌小姬",
|
||||
"gender": "female",
|
||||
"desc": "萌属性爆棚的小萝莉"
|
||||
},
|
||||
{
|
||||
"id": "Neil",
|
||||
"name": "阿闻",
|
||||
"gender": "male",
|
||||
"desc": "字正腔圆的新闻主持人"
|
||||
},
|
||||
{
|
||||
"id": "Elias",
|
||||
"name": "墨讲师",
|
||||
"gender": "female",
|
||||
"desc": "严谨又通俗的知识讲解"
|
||||
},
|
||||
{
|
||||
"id": "Arthur",
|
||||
"name": "徐大爷",
|
||||
"gender": "male",
|
||||
"desc": "质朴嗓音,奇闻异事"
|
||||
},
|
||||
{
|
||||
"id": "Nini",
|
||||
"name": "邻家妹妹",
|
||||
"gender": "female",
|
||||
"desc": "又软又黏的甜蜜嗓音"
|
||||
},
|
||||
{
|
||||
"id": "Ebona",
|
||||
"name": "诡婆婆",
|
||||
"gender": "female",
|
||||
"desc": "幽暗低语,神秘诡异"
|
||||
},
|
||||
{
|
||||
"id": "Seren",
|
||||
"name": "小婉",
|
||||
"gender": "female",
|
||||
"desc": "温和舒缓,助眠音色"
|
||||
},
|
||||
{
|
||||
"id": "Pip",
|
||||
"name": "顽屁小孩",
|
||||
"gender": "male",
|
||||
"desc": "调皮捣蛋却充满童真"
|
||||
},
|
||||
{
|
||||
"id": "Stella",
|
||||
"name": "少女阿月",
|
||||
"gender": "female",
|
||||
"desc": "甜到发腻的迷糊少女音"
|
||||
},
|
||||
{
|
||||
"id": "Bodega",
|
||||
"name": "博德加",
|
||||
"gender": "male",
|
||||
"desc": "热情的西班牙大叔"
|
||||
},
|
||||
{
|
||||
"id": "Sonrisa",
|
||||
"name": "索尼莎",
|
||||
"gender": "female",
|
||||
"desc": "热情开朗的拉美大姐"
|
||||
},
|
||||
{
|
||||
"id": "Alek",
|
||||
"name": "阿列克",
|
||||
"gender": "male",
|
||||
"desc": "战斗民族的冷与暖"
|
||||
},
|
||||
{
|
||||
"id": "Dolce",
|
||||
"name": "多尔切",
|
||||
"gender": "male",
|
||||
"desc": "慵懒的意大利大叔"
|
||||
},
|
||||
{
|
||||
"id": "Sohee",
|
||||
"name": "素熙",
|
||||
"gender": "female",
|
||||
"desc": "温柔开朗的韩国欧尼"
|
||||
},
|
||||
{
|
||||
"id": "Ono Anna",
|
||||
"name": "小野杏",
|
||||
"gender": "female",
|
||||
"desc": "鬼灵精怪的青梅竹马"
|
||||
},
|
||||
{
|
||||
"id": "Lenn",
|
||||
"name": "莱恩",
|
||||
"gender": "male",
|
||||
"desc": "理性底色的德国青年"
|
||||
},
|
||||
{
|
||||
"id": "Emilien",
|
||||
"name": "埃米尔安",
|
||||
"gender": "male",
|
||||
"desc": "浪漫的法国大哥哥"
|
||||
},
|
||||
{
|
||||
"id": "Andre",
|
||||
"name": "安德雷",
|
||||
"gender": "male",
|
||||
"desc": "声音磁性,沉稳男生"
|
||||
},
|
||||
{
|
||||
"id": "Radio Gol",
|
||||
"name": "拉迪奥·戈尔",
|
||||
"gender": "male",
|
||||
"desc": "足球诗人解说员"
|
||||
},
|
||||
{
|
||||
"id": "Jada",
|
||||
"name": "上海-阿珍",
|
||||
"gender": "female",
|
||||
"desc": "风风火火的沪上阿姐"
|
||||
},
|
||||
{
|
||||
"id": "Dylan",
|
||||
"name": "北京-晓东",
|
||||
"gender": "male",
|
||||
"desc": "北京胡同里长大的少年"
|
||||
},
|
||||
{
|
||||
"id": "Li",
|
||||
"name": "南京-老李",
|
||||
"gender": "male",
|
||||
"desc": "耐心的瑜伽老师"
|
||||
},
|
||||
{
|
||||
"id": "Marcus",
|
||||
"name": "陕西-秦川",
|
||||
"gender": "male",
|
||||
"desc": "面宽话短,心实声沉"
|
||||
},
|
||||
{
|
||||
"id": "Roy",
|
||||
"name": "闽南-阿杰",
|
||||
"gender": "male",
|
||||
"desc": "诙谐直爽的台湾哥仔"
|
||||
},
|
||||
{
|
||||
"id": "Peter",
|
||||
"name": "天津-李彼得",
|
||||
"gender": "male",
|
||||
"desc": "天津相声,专业捧哏"
|
||||
},
|
||||
{
|
||||
"id": "Sunny",
|
||||
"name": "四川-晴儿",
|
||||
"gender": "female",
|
||||
"desc": "甜到你心里的川妹子"
|
||||
},
|
||||
{
|
||||
"id": "Eric",
|
||||
"name": "四川-程川",
|
||||
"gender": "male",
|
||||
"desc": "跳脱市井的成都男子"
|
||||
},
|
||||
{
|
||||
"id": "Rocky",
|
||||
"name": "粤语-阿强",
|
||||
"gender": "male",
|
||||
"desc": "幽默风趣,在线陪聊"
|
||||
},
|
||||
{
|
||||
"id": "Kiki",
|
||||
"name": "粤语-阿清",
|
||||
"gender": "female",
|
||||
"desc": "甜美的港妹闺蜜"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
143
backend/src/config/services/dashscope/image.json
Normal file
143
backend/src/config/services/dashscope/image.json
Normal file
@@ -0,0 +1,143 @@
|
||||
{
|
||||
"z-image": {
|
||||
"name": "Z-Image",
|
||||
"class": "backend.src.services.provider.dashscope.image.ZImageService",
|
||||
"args": [
|
||||
"z-image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "z-image-turbo"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1536*864",
|
||||
"9:16": "864*1536",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280",
|
||||
"21:9": "1680*720",
|
||||
"9:21": "720*1680"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2048*1152",
|
||||
"9:16": "1152*2048",
|
||||
"21:9": "2016*864",
|
||||
"9:21": "864*2016"
|
||||
}
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.6-image": {
|
||||
"name": "Wan 2.6",
|
||||
"class": "backend.src.services.provider.dashscope.image.WanImageService",
|
||||
"args": [
|
||||
"wan2.6-image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "wan2.6-t2i",
|
||||
"i2i": "wan2.6-image"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048",
|
||||
"4:3": "2560*1920",
|
||||
"3:4": "1920*2560"
|
||||
}
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.5-image": {
|
||||
"name": "Wan 2.5",
|
||||
"class": "backend.src.services.provider.dashscope.image.WanImageService",
|
||||
"args": [
|
||||
"wan2.5-image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "wan2.5-t2i-preview",
|
||||
"i2i": "wan2.5-i2i-preview"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048",
|
||||
"4:3": "2560*1920",
|
||||
"3:4": "1920*2560"
|
||||
}
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"qwen-image": {
|
||||
"name": "Qwen Image",
|
||||
"class": "backend.src.services.provider.dashscope.image.QwenImageService",
|
||||
"args": [
|
||||
"qwen-image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "qwen-image-plus",
|
||||
"i2i": "qwen-image-edit-plus"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1664*928",
|
||||
"9:16": "928*1664",
|
||||
"1:1": "1328*1328",
|
||||
"4:3": "1472*1140",
|
||||
"3:4": "1140*1472"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560*1440",
|
||||
"9:16": "1440*2560",
|
||||
"1:1": "2048*2048",
|
||||
"4:3": "2560*1920",
|
||||
"3:4": "1920*2560"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
39
backend/src/config/services/dashscope/llm.json
Normal file
39
backend/src/config/services/dashscope/llm.json
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
"qwen-plus": {
|
||||
"name": "Qwen Plus",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["qwen-plus"],
|
||||
"enabled": true
|
||||
},
|
||||
"qwen3-max": {
|
||||
"name": "Qwen Max",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["qwen3-max"],
|
||||
"enabled": true
|
||||
},
|
||||
"qwen-flash": {
|
||||
"name": "Qwen Flash",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["qwen-flash"],
|
||||
"enabled": true
|
||||
},
|
||||
"deepseek": {
|
||||
"name": "Deepseek",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["deepseek-v3.2"],
|
||||
"enabled": true
|
||||
},
|
||||
"kimi-k2.5": {
|
||||
"name": "Kimi K2.5",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["kimi-k2.5"],
|
||||
"enabled": true
|
||||
},
|
||||
"MiniMax-M2.1": {
|
||||
"name": "MiniMax M2.1",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": ["MiniMax-M2.1"],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/dashscope/provider.json
Normal file
16
backend/src/config/services/dashscope/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "dashscope",
|
||||
"name": "阿里云百炼",
|
||||
"description": "阿里云提供的大模型服务",
|
||||
"dashboard_url": "https://dashscope.console.aliyun.com/",
|
||||
"helpUrl": "https://dashscope.console.aliyun.com/apiKey",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
196
backend/src/config/services/dashscope/video.json
Normal file
196
backend/src/config/services/dashscope/video.json
Normal file
@@ -0,0 +1,196 @@
|
||||
{
|
||||
"wan2.6-video": {
|
||||
"name": "Wan 2.6",
|
||||
"class": "backend.src.services.provider.dashscope.video.WanVideoService",
|
||||
"args": [
|
||||
"wan2.6-video"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsVideoToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": false,
|
||||
"supportsMultiImage": true,
|
||||
"supportsMultiVideo": true,
|
||||
"supportsAudio": true,
|
||||
"supportsShotType": true,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "wan2.6-t2v",
|
||||
"i2v": "wan2.6-i2v",
|
||||
"r2v": "wan2.6-r2v"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1920*1920",
|
||||
"4:3": "1920*1440",
|
||||
"3:4": "1440*1920"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"min": 2,
|
||||
"max": 10
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.6-video-flash": {
|
||||
"name": "Wan 2.6 Flash",
|
||||
"class": "backend.src.services.provider.dashscope.video.WanVideoService",
|
||||
"args": [
|
||||
"wan2.6-video-flash"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsVideoToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": false,
|
||||
"supportsMultiImage": true,
|
||||
"supportsMultiVideo": true,
|
||||
"supportsAudio": true,
|
||||
"supportsShotType": true,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "wan2.6-t2v",
|
||||
"i2v": "wan2.6-i2v-flash",
|
||||
"r2v": "wan2.6-r2v-flash"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1920*1920",
|
||||
"4:3": "1920*1440",
|
||||
"3:4": "1440*1920"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"min": 2,
|
||||
"max": 10
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.5-video": {
|
||||
"name": "Wan 2.5",
|
||||
"class": "backend.src.services.provider.dashscope.video.WanVideoService",
|
||||
"args": [
|
||||
"wan2.5-video"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": false,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "wan2.5-t2v-preview",
|
||||
"i2v": "wan2.5-i2v-preview"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1920*1920",
|
||||
"4:3": "1920*1440",
|
||||
"3:4": "1440*1920"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
5,
|
||||
10
|
||||
]
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
},
|
||||
"wan2.2-video": {
|
||||
"name": "Wan 2.2",
|
||||
"class": "backend.src.services.provider.dashscope.video.WanVideoService",
|
||||
"args": [
|
||||
"wan2.2-video"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "wan2.2-t2v-plus",
|
||||
"i2v": "wan2.2-i2v-flash",
|
||||
"kf2v": "wan2.2-kf2v-flash"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1280*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1920*1920",
|
||||
"4:3": "1920*1440",
|
||||
"3:4": "1440*1920"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
5
|
||||
]
|
||||
},
|
||||
"counts": {
|
||||
"min": 1,
|
||||
"max": 4
|
||||
},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
7
backend/src/config/services/default.json
Normal file
7
backend/src/config/services/default.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"llm": "qwen-plus",
|
||||
"image": "z-image",
|
||||
"video": "wan2.6-video",
|
||||
"audio": "qwen3-tts-flash",
|
||||
"upscale": "ali-videoenhan/videoenhan"
|
||||
}
|
||||
27
backend/src/config/services/google/llm.json
Normal file
27
backend/src/config/services/google/llm.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"gemini-1.5-pro": {
|
||||
"name": "Gemini-1.5-Pro",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"gemini-1.5-pro"
|
||||
],
|
||||
"enabled": false
|
||||
},
|
||||
"gemini-1.5-flash": {
|
||||
"name": "Gemini-1.5-Flash",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"gemini-1.5-flash"
|
||||
],
|
||||
"enabled": false
|
||||
},
|
||||
"gemini-2.0-flash": {
|
||||
"name": "Gemini-2.0-Flash",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"gemini-2.0-flash"
|
||||
],
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/google/provider.json
Normal file
16
backend/src/config/services/google/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "google",
|
||||
"name": "Google",
|
||||
"description": "Gemini 系列模型",
|
||||
"dashboard_url": "https://aistudio.google.com/app/apikey",
|
||||
"helpUrl": "https://aistudio.google.com/app/apikey",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
27
backend/src/config/services/kling/provider.json
Normal file
27
backend/src/config/services/kling/provider.json
Normal file
@@ -0,0 +1,27 @@
|
||||
{
|
||||
"id": "kling",
|
||||
"name": "可灵 AI",
|
||||
"description": "快手视频生成 - 需要 Access Key 和 Secret Key",
|
||||
"dashboard_url": "https://klingai.kuaishou.com/",
|
||||
"param_mapping": {
|
||||
"api_key": "access_key",
|
||||
"api_secret": "secret_key"
|
||||
},
|
||||
"helpUrl": "https://klingai.kuaishou.com/",
|
||||
"fields": [
|
||||
{
|
||||
"name": "accessKey",
|
||||
"label": "Access Key",
|
||||
"placeholder": "Access Key",
|
||||
"required": true,
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"name": "secretKey",
|
||||
"label": "Secret Key",
|
||||
"placeholder": "Secret Key",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
124
backend/src/config/services/kling/video.json
Normal file
124
backend/src/config/services/kling/video.json
Normal file
@@ -0,0 +1,124 @@
|
||||
{
|
||||
"kling-v2-5-turbo": {
|
||||
"name": "Kling V2.5 Turbo",
|
||||
"class": "src.services.provider.kling.KlingVideoService",
|
||||
"args": [
|
||||
"kling-v2-5-turbo"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": false,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsCameraControl": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "kling-v2-5-turbo",
|
||||
"i2v": "kling-v2-5-turbo",
|
||||
"kf2v": "kling-v2-5-turbo"
|
||||
},
|
||||
"modes": ["std", "pro"],
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1024*1024"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [5, 10]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"kling-v2-6": {
|
||||
"name": "Kling V2.6",
|
||||
"class": "src.services.provider.kling.KlingVideoService",
|
||||
"args": [
|
||||
"kling-v2-6"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": true,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsCameraControl": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "kling-v2-6",
|
||||
"i2v": "kling-v2-6",
|
||||
"kf2v": "kling-v2-6"
|
||||
},
|
||||
"modes": ["pro"],
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024"
|
||||
},
|
||||
"1080P": {
|
||||
"16:9": "1920*1080",
|
||||
"9:16": "1080*1920",
|
||||
"1:1": "1024*1024"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"values": [5, 10]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"kling-video-o1": {
|
||||
"name": "Kling Omni (O1)",
|
||||
"class": "src.services.provider.kling.KlingVideoService",
|
||||
"args": [
|
||||
"kling-video-o1"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsVideoToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": true,
|
||||
"supportsMultiVideo": true,
|
||||
"supportsAudio": true,
|
||||
"supportsNegativePrompt": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2v": "kling-video-o1",
|
||||
"i2v": "kling-video-o1",
|
||||
"v2v": "kling-video-o1",
|
||||
"kf2v": "kling-video-o1"
|
||||
},
|
||||
"resolutions": {
|
||||
"720P": {
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"1:1": "1024*1024"
|
||||
}
|
||||
},
|
||||
"durations": {
|
||||
"min": 3,
|
||||
"max": 10
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
34
backend/src/config/services/midjourney/image.json
Normal file
34
backend/src/config/services/midjourney/image.json
Normal file
@@ -0,0 +1,34 @@
|
||||
{
|
||||
"midjourney": {
|
||||
"name": "Midjourney",
|
||||
"class": "src.services.provider.midjourney.MidjourneyImageService",
|
||||
"args": [
|
||||
"midjourney"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"variants": {
|
||||
"t2i": "midjourney"
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"1:1": "1024*1024",
|
||||
"16:9": "1280*720",
|
||||
"9:16": "720*1280",
|
||||
"4:3": "1280*960",
|
||||
"3:4": "960*1280"
|
||||
},
|
||||
"2K": {
|
||||
"1:1": "1456*1456",
|
||||
"16:9": "1456*816",
|
||||
"9:16": "816*1456",
|
||||
"4:3": "1232*928",
|
||||
"3:4": "928*1232"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": false
|
||||
}
|
||||
}
|
||||
23
backend/src/config/services/midjourney/provider.json
Normal file
23
backend/src/config/services/midjourney/provider.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"id": "midjourney",
|
||||
"name": "Midjourney(悠船)",
|
||||
"description": "悠船 API - 需要 App ID 和 Secret Key",
|
||||
"dashboard_url": "https://ali.youchuan.cn/",
|
||||
"helpUrl": "https://ali.youchuan.cn/",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "App ID",
|
||||
"placeholder": "应用 ID",
|
||||
"required": true,
|
||||
"type": "text"
|
||||
},
|
||||
{
|
||||
"name": "apiSecret",
|
||||
"label": "Secret Key",
|
||||
"placeholder": "密钥",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
378
backend/src/config/services/minimax/audio.json
Normal file
378
backend/src/config/services/minimax/audio.json
Normal file
@@ -0,0 +1,378 @@
|
||||
{
|
||||
"speech-2.8-hd": {
|
||||
"name": "speech-2.8-hd",
|
||||
"class": "src.services.provider.minimax.MiniMaxAudioService",
|
||||
"args": [
|
||||
"speech-2.8-hd"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "male-qn-qingse",
|
||||
"name": "青涩青年",
|
||||
"gender": "male",
|
||||
"desc": "青涩青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-jingying",
|
||||
"name": "精英青年",
|
||||
"gender": "male",
|
||||
"desc": "精英青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-badao",
|
||||
"name": "霸道青年",
|
||||
"gender": "male",
|
||||
"desc": "沉稳有力的中文普通话青年男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-daxuesheng",
|
||||
"name": "青年大学生",
|
||||
"gender": "male",
|
||||
"desc": "青年大学生风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "female-shaonv",
|
||||
"name": "少女",
|
||||
"gender": "female",
|
||||
"desc": "清亮少女风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-yujie",
|
||||
"name": "御姐",
|
||||
"gender": "female",
|
||||
"desc": "成熟干练风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-chengshu",
|
||||
"name": "成熟女性",
|
||||
"gender": "female",
|
||||
"desc": "成熟女性风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-tianmei",
|
||||
"name": "甜美女性",
|
||||
"gender": "female",
|
||||
"desc": "甜美风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_News_Anchor",
|
||||
"name": "新闻女声",
|
||||
"gender": "female",
|
||||
"desc": "专业播音腔的中年女性新闻主播,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Male_Announcer",
|
||||
"name": "播报男声",
|
||||
"gender": "male",
|
||||
"desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Gentleman",
|
||||
"name": "温润男声",
|
||||
"gender": "male",
|
||||
"desc": "温润磁性的青年男性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Sweet_Lady",
|
||||
"name": "甜美女声",
|
||||
"gender": "female",
|
||||
"desc": "温柔甜美的青年女性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(F)",
|
||||
"name": "粤语专业女主持",
|
||||
"gender": "female",
|
||||
"desc": "中性、专业的青年女性粤语主持人声音"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(M)",
|
||||
"name": "粤语专业男主持",
|
||||
"gender": "male",
|
||||
"desc": "中性、专业的青年男性粤语主持人声音"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
},
|
||||
"speech-2.6-hd": {
|
||||
"name": "speech-2.6-hd",
|
||||
"class": "src.services.provider.minimax.MiniMaxAudioService",
|
||||
"args": [
|
||||
"speech-2.6-hd"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "male-qn-qingse",
|
||||
"name": "青涩青年",
|
||||
"gender": "male",
|
||||
"desc": "青涩青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-jingying",
|
||||
"name": "精英青年",
|
||||
"gender": "male",
|
||||
"desc": "精英青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-badao",
|
||||
"name": "霸道青年",
|
||||
"gender": "male",
|
||||
"desc": "沉稳有力的中文普通话青年男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-daxuesheng",
|
||||
"name": "青年大学生",
|
||||
"gender": "male",
|
||||
"desc": "青年大学生风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "female-shaonv",
|
||||
"name": "少女",
|
||||
"gender": "female",
|
||||
"desc": "清亮少女风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-yujie",
|
||||
"name": "御姐",
|
||||
"gender": "female",
|
||||
"desc": "成熟干练风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-chengshu",
|
||||
"name": "成熟女性",
|
||||
"gender": "female",
|
||||
"desc": "成熟女性风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-tianmei",
|
||||
"name": "甜美女性",
|
||||
"gender": "female",
|
||||
"desc": "甜美风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_News_Anchor",
|
||||
"name": "新闻女声",
|
||||
"gender": "female",
|
||||
"desc": "专业、播音腔的中年女性新闻主播,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Male_Announcer",
|
||||
"name": "播报男声",
|
||||
"gender": "male",
|
||||
"desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Gentleman",
|
||||
"name": "温润男声",
|
||||
"gender": "male",
|
||||
"desc": "温润磁性的青年男性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Sweet_Lady",
|
||||
"name": "甜美女声",
|
||||
"gender": "female",
|
||||
"desc": "温柔甜美的青年女性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(F)",
|
||||
"name": "粤语专业女主持",
|
||||
"gender": "female",
|
||||
"desc": "中性、专业的青年女性粤语主持人声音"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(M)",
|
||||
"name": "粤语专业男主持",
|
||||
"gender": "male",
|
||||
"desc": "中性、专业的青年男性粤语主持人声音"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
},
|
||||
"speech-2.8-turbo": {
|
||||
"name": "speech-2.8-turbo",
|
||||
"class": "src.services.provider.minimax.MiniMaxAudioService",
|
||||
"args": [
|
||||
"speech-2.8-turbo"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "male-qn-qingse",
|
||||
"name": "青涩青年",
|
||||
"gender": "male",
|
||||
"desc": "青涩青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-jingying",
|
||||
"name": "精英青年",
|
||||
"gender": "male",
|
||||
"desc": "精英青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-badao",
|
||||
"name": "霸道青年",
|
||||
"gender": "male",
|
||||
"desc": "沉稳有力的中文普通话青年男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-daxuesheng",
|
||||
"name": "青年大学生",
|
||||
"gender": "male",
|
||||
"desc": "青年大学生风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "female-shaonv",
|
||||
"name": "少女",
|
||||
"gender": "female",
|
||||
"desc": "清亮少女风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-yujie",
|
||||
"name": "御姐",
|
||||
"gender": "female",
|
||||
"desc": "成熟干练风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-chengshu",
|
||||
"name": "成熟女性",
|
||||
"gender": "female",
|
||||
"desc": "成熟女性风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-tianmei",
|
||||
"name": "甜美女性",
|
||||
"gender": "female",
|
||||
"desc": "甜美风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_News_Anchor",
|
||||
"name": "新闻女声",
|
||||
"gender": "female",
|
||||
"desc": "专业、播音腔的中年女性新闻主播,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Male_Announcer",
|
||||
"name": "播报男声",
|
||||
"gender": "male",
|
||||
"desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Gentleman",
|
||||
"name": "温润男声",
|
||||
"gender": "male",
|
||||
"desc": "温润磁性的青年男性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Sweet_Lady",
|
||||
"name": "甜美女声",
|
||||
"gender": "female",
|
||||
"desc": "温柔甜美的青年女性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(F)",
|
||||
"name": "粤语专业女主持",
|
||||
"gender": "female",
|
||||
"desc": "中性、专业的青年女性粤语主持人声音"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(M)",
|
||||
"name": "粤语专业男主持",
|
||||
"gender": "male",
|
||||
"desc": "中性、专业的青年男性粤语主持人声音"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
},
|
||||
"speech-2.6-turbo": {
|
||||
"name": "speech-2.6-turbo",
|
||||
"class": "src.services.provider.minimax.MiniMaxAudioService",
|
||||
"args": [
|
||||
"speech-2.6-turbo"
|
||||
],
|
||||
"voices": [
|
||||
{
|
||||
"id": "male-qn-qingse",
|
||||
"name": "青涩青年",
|
||||
"gender": "male",
|
||||
"desc": "青涩青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-jingying",
|
||||
"name": "精英青年",
|
||||
"gender": "male",
|
||||
"desc": "精英青年风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-badao",
|
||||
"name": "霸道青年",
|
||||
"gender": "male",
|
||||
"desc": "沉稳有力的中文普通话青年男声"
|
||||
},
|
||||
{
|
||||
"id": "male-qn-daxuesheng",
|
||||
"name": "青年大学生",
|
||||
"gender": "male",
|
||||
"desc": "青年大学生风格中文普通话男声"
|
||||
},
|
||||
{
|
||||
"id": "female-shaonv",
|
||||
"name": "少女",
|
||||
"gender": "female",
|
||||
"desc": "清亮少女风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-yujie",
|
||||
"name": "御姐",
|
||||
"gender": "female",
|
||||
"desc": "成熟干练风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-chengshu",
|
||||
"name": "成熟女性",
|
||||
"gender": "female",
|
||||
"desc": "成熟女性风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "female-tianmei",
|
||||
"name": "甜美女性",
|
||||
"gender": "female",
|
||||
"desc": "甜美风格中文普通话女声"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_News_Anchor",
|
||||
"name": "新闻女声",
|
||||
"gender": "female",
|
||||
"desc": "专业、播音腔的中年女性新闻主播,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Male_Announcer",
|
||||
"name": "播报男声",
|
||||
"gender": "male",
|
||||
"desc": "富有磁性的中年男性播报员声音,标准普通话,清晰而权威"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Gentleman",
|
||||
"name": "温润男声",
|
||||
"gender": "male",
|
||||
"desc": "温润磁性的青年男性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Chinese (Mandarin)_Sweet_Lady",
|
||||
"name": "甜美女声",
|
||||
"gender": "female",
|
||||
"desc": "温柔甜美的青年女性声音,标准普通话"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(F)",
|
||||
"name": "粤语专业女主持",
|
||||
"gender": "female",
|
||||
"desc": "中性、专业的青年女性粤语主持人声音"
|
||||
},
|
||||
{
|
||||
"id": "Cantonese_ProfessionalHost(M)",
|
||||
"name": "粤语专业男主持",
|
||||
"gender": "male",
|
||||
"desc": "中性、专业的青年男性粤语主持人声音"
|
||||
}
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
64
backend/src/config/services/minimax/image.json
Normal file
64
backend/src/config/services/minimax/image.json
Normal file
@@ -0,0 +1,64 @@
|
||||
{
|
||||
"image-01": {
|
||||
"name": "image-01",
|
||||
"class": "src.services.provider.minimax.MiniMaxImageService",
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"1:1": "1024x1024",
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"4:3": "1152x864",
|
||||
"3:4": "864x1152",
|
||||
"3:2": "1248x832",
|
||||
"2:3": "832x1248",
|
||||
"21:9": "1344x576"
|
||||
},
|
||||
"2K": {
|
||||
"1:1": "2048x2048",
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"4:3": "2304x1728",
|
||||
"3:4": "1728x2304",
|
||||
"3:2": "2496x1664",
|
||||
"2:3": "1664x2496",
|
||||
"21:9": "2688x1152"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
},
|
||||
"image-01-live": {
|
||||
"name": "image-01-live",
|
||||
"class": "src.services.provider.minimax.MiniMaxImageService",
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": false
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"1:1": "1024x1024",
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"4:3": "1152x864",
|
||||
"3:4": "864x1152",
|
||||
"3:2": "1248x832",
|
||||
"2:3": "832x1248",
|
||||
"21:9": "1344x576"
|
||||
},
|
||||
"2K": {
|
||||
"1:1": "2048x2048",
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"4:3": "2304x1728",
|
||||
"3:4": "1728x2304",
|
||||
"3:2": "2496x1664",
|
||||
"2:3": "1664x2496",
|
||||
"21:9": "2688x1152"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
}
|
||||
}
|
||||
11
backend/src/config/services/minimax/llm.json
Normal file
11
backend/src/config/services/minimax/llm.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"base_url": "https://api.minimaxi.com/v1",
|
||||
"MiniMax-M2.11": {
|
||||
"name": "MiniMax M2.1",
|
||||
"class": "src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"MiniMax-M2.1"
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
22
backend/src/config/services/minimax/music.json
Normal file
22
backend/src/config/services/minimax/music.json
Normal file
@@ -0,0 +1,22 @@
|
||||
{
|
||||
"lyrics-2.5": {
|
||||
"name": "lyrics-2.5",
|
||||
"class": "src.services.provider.minimax.MiniMaxMusicService",
|
||||
"args": [
|
||||
"music-2.5"
|
||||
],
|
||||
"type": "lyrics",
|
||||
"enabled": true,
|
||||
"is_default": true
|
||||
},
|
||||
"music-2.5": {
|
||||
"name": "music-2.5",
|
||||
"class": "src.services.provider.minimax.MiniMaxMusicService",
|
||||
"args": [
|
||||
"music-2.5"
|
||||
],
|
||||
"type": "music",
|
||||
"enabled": true,
|
||||
"is_default": true
|
||||
}
|
||||
}
|
||||
23
backend/src/config/services/minimax/provider.json
Normal file
23
backend/src/config/services/minimax/provider.json
Normal file
@@ -0,0 +1,23 @@
|
||||
{
|
||||
"id": "minimax",
|
||||
"name": "MiniMax",
|
||||
"description": "海螺AI",
|
||||
"dashboard_url": "https://platform.minimaxi.com/",
|
||||
"helpUrl": "https://platform.minimaxi.com/user-center/basic-information/interface-key",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-api-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
},
|
||||
{
|
||||
"name": "groupId",
|
||||
"label": "Group ID (可选)",
|
||||
"placeholder": "分组 ID",
|
||||
"required": false,
|
||||
"type": "text"
|
||||
}
|
||||
]
|
||||
}
|
||||
68
backend/src/config/services/minimax/video.json
Normal file
68
backend/src/config/services/minimax/video.json
Normal file
@@ -0,0 +1,68 @@
|
||||
{
|
||||
"MiniMax-Hailuo-2.3": {
|
||||
"name": "Hailuo 2.3",
|
||||
"class": "src.services.provider.minimax.MiniMaxVideoService",
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": false,
|
||||
"supportsNegativePrompt": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
6,
|
||||
10
|
||||
]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
},
|
||||
"MiniMax-Hailuo-02": {
|
||||
"name": "Hailuo 02",
|
||||
"class": "src.services.provider.minimax.MiniMaxVideoService",
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": false,
|
||||
"supportsNegativePrompt": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
6,
|
||||
10
|
||||
]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
},
|
||||
"T2V-01-Director": {
|
||||
"name": "T2V-01-Director",
|
||||
"class": "src.services.provider.minimax.MiniMaxVideoService",
|
||||
"capabilities": {
|
||||
"supportsTextToVideo": true,
|
||||
"supportsImageToVideo": true,
|
||||
"supportsFirstFrame": true,
|
||||
"supportsLastFrame": true,
|
||||
"supportsMultiImage": false,
|
||||
"supportsMultiVideo": false,
|
||||
"supportsAudio": false,
|
||||
"supportsNegativePrompt": false,
|
||||
"supportsLora": false
|
||||
},
|
||||
"durations": {
|
||||
"values": [
|
||||
6,
|
||||
10
|
||||
]
|
||||
},
|
||||
"counts": {"min": 1, "max": 4}
|
||||
}
|
||||
}
|
||||
143
backend/src/config/services/modelscope/image.json
Normal file
143
backend/src/config/services/modelscope/image.json
Normal file
@@ -0,0 +1,143 @@
|
||||
{
|
||||
"qwen-image": {
|
||||
"name": "Qwen Image",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"Qwen/Qwen-Image"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1664x928",
|
||||
"9:16": "928x1664",
|
||||
"1:1": "1328x1328"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"qwen-image-edit": {
|
||||
"name": "Qwen Image Edit",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"Qwen/Qwen-Image-Edit-2511"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": true,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1664x928",
|
||||
"9:16": "928x1664",
|
||||
"1:1": "1328x1328",
|
||||
"4:3": "1328x1024",
|
||||
"3:4": "1024x1328"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2560x1920",
|
||||
"3:4": "1920x2560"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"flux-dev": {
|
||||
"name": "FLUX.2 Dev",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"black-forest-labs/FLUX.2-dev"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"1:1": "1024x1024",
|
||||
"4:3": "1024x768",
|
||||
"3:4": "768x1024"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2048x1536",
|
||||
"3:4": "1536x2048"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"z-image-turbo": {
|
||||
"name": "Z Image Turbo",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"Tongyi-MAI/Z-Image-Turbo"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"1:1": "1024x1024",
|
||||
"4:3": "1024x768",
|
||||
"3:4": "768x1024"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2048x1536",
|
||||
"3:4": "1536x2048"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
},
|
||||
"awportrait-z": {
|
||||
"name": "AWPortrait Z",
|
||||
"class": "backend.src.services.provider.modelscope.image.ModelScopeImageService",
|
||||
"args": [
|
||||
"LiblibAI/AWPortrait-Z"
|
||||
],
|
||||
"capabilities": {
|
||||
"supportsRefImage": false,
|
||||
"supportsLora": true
|
||||
},
|
||||
"resolutions": {
|
||||
"1K": {
|
||||
"16:9": "1280x720",
|
||||
"9:16": "720x1280",
|
||||
"1:1": "1024x1024",
|
||||
"4:3": "1024x768",
|
||||
"3:4": "768x1024"
|
||||
},
|
||||
"2K": {
|
||||
"16:9": "2560x1440",
|
||||
"9:16": "1440x2560",
|
||||
"1:1": "2048x2048",
|
||||
"4:3": "2048x1536",
|
||||
"3:4": "1536x2048"
|
||||
}
|
||||
},
|
||||
"counts": {"min": 1, "max": 4},
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/modelscope/provider.json
Normal file
16
backend/src/config/services/modelscope/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "modelscope",
|
||||
"name": "ModelScope",
|
||||
"description": "开源模型平台",
|
||||
"dashboard_url": "https://modelscope.cn/",
|
||||
"helpUrl": "https://www.modelscope.cn/my/myaccesstoken",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Token",
|
||||
"placeholder": "ms-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
11
backend/src/config/services/moonshot/llm.json
Normal file
11
backend/src/config/services/moonshot/llm.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"base_url": "https://api.moonshot.cn/v1",
|
||||
"kimi-k2.5": {
|
||||
"name": "Kimi 2.5",
|
||||
"class": "backend.src.services.provider.openai_service.OpenAIService",
|
||||
"args": [
|
||||
"kimi-k2.5"
|
||||
],
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
16
backend/src/config/services/moonshot/provider.json
Normal file
16
backend/src/config/services/moonshot/provider.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"id": "moonshot",
|
||||
"name": "月之暗面",
|
||||
"description": "Kimi 大模型",
|
||||
"dashboard_url": "https://platform.moonshot.cn/",
|
||||
"helpUrl": "https://platform.moonshot.cn/",
|
||||
"fields": [
|
||||
{
|
||||
"name": "apiKey",
|
||||
"label": "API Key",
|
||||
"placeholder": "sk-...",
|
||||
"required": true,
|
||||
"type": "password"
|
||||
}
|
||||
]
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user